fix: OpenAI wrapper - remove duplicate span block, fix tool call extraction, fix cost model matching
- Remove duplicate span append + tool call decision logging block (lines 328-426)
- Fix _extract_tool_calls_from_response to use getattr() instead of .get() on objects
- Fix _calculate_cost to use exact match first, then longest-prefix match (prevents gpt-4o-mini matching gpt-4 pricing)
- Fix test mock setup: set return_value BEFORE wrap_openai() so wrapper captures correct original
- All 11 OpenAI integration tests + 8 SDK tests passing (19/19)
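For context, a minimal sketch of the two extraction/matching fixes named above. Only the function names and the matching strategy come from this commit; the pricing table and function bodies below are illustrative assumptions, not the wrapper's actual code.

import json

# Assumed pricing table: USD per 1K tokens as (input, output). Values here
# are illustrative; the wrapper's real table may differ.
_MODEL_PRICING = {
    "gpt-4": (0.03, 0.06),
    "gpt-4o": (0.0025, 0.01),
    "gpt-4o-mini": (0.00015, 0.0006),
}

def _extract_tool_calls_from_response(response):
    """Responses are attribute-based objects, not dicts: .get() fails on
    the real SDK types (and silently returns a fresh MagicMock in tests),
    so read every field with getattr() and a default."""
    calls = []
    for choice in getattr(response, "choices", None) or []:
        message = getattr(choice, "message", None)
        for tc in getattr(message, "tool_calls", None) or []:
            function = getattr(tc, "function", None)
            args = getattr(function, "arguments", None)
            calls.append({
                "name": getattr(function, "name", None),
                # arguments arrive as a JSON string; decode when possible
                "arguments": json.loads(args) if isinstance(args, str) else args,
            })
    return calls

def _calculate_cost(model, prompt_tokens, completion_tokens):
    """Exact match first, then longest-prefix match: a naive prefix scan
    can resolve gpt-4o-mini to gpt-4 pricing, so longer prefixes must win."""
    pricing = _MODEL_PRICING.get(model)
    if pricing is None:
        for known in sorted(_MODEL_PRICING, key=len, reverse=True):
            if model.startswith(known):
                pricing = _MODEL_PRICING[known]
                break
    if pricing is None:
        return None  # unknown model: report no cost rather than a wrong one
    input_rate, output_rate = pricing
    return (prompt_tokens / 1000) * input_rate + (completion_tokens / 1000) * output_rate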
packages/sdk-python/tests/test_openai_integration.py (new file, +414 lines)
@@ -0,0 +1,414 @@
import json
import unittest
from unittest.mock import MagicMock, Mock

import agentlens
from agentlens import init, shutdown, get_client
from agentlens.models import TraceStatus, SpanStatus, SpanType
from agentlens.integrations.openai import wrap_openai


class TestOpenAIIntegration(unittest.TestCase):
    def setUp(self):
        shutdown()
        init(api_key="test_key", enabled=False)

    def tearDown(self):
        shutdown()

    def _create_mock_openai_client(self):
        """Create a mock OpenAI client with necessary structure."""
        client = MagicMock()

        # Mock chat.completions structure
        client.chat = MagicMock()
        client.chat.completions = MagicMock()

        # Mock the create method
        client.chat.completions.create = MagicMock()

        return client

    def _create_mock_response(
        self,
        model="gpt-4",
        content="Hello!",
        prompt_tokens=10,
        completion_tokens=5,
        total_tokens=15,
        tool_calls=None,
    ):
        """Create a mock OpenAI response object."""
        response = MagicMock()
        response.model = model

        # Mock choices
        choice = MagicMock()
        message = MagicMock()
        message.content = content

        if tool_calls:
            # Build mock tool calls
            mock_tool_calls = []
            for tc_data in tool_calls:
                tc = MagicMock()
                function = MagicMock()
                function.name = tc_data["name"]
                function.arguments = json.dumps(tc_data["arguments"])
                tc.function = function
                mock_tool_calls.append(tc)

            message.tool_calls = mock_tool_calls
        else:
            message.tool_calls = []

        choice.message = message
        response.choices = [choice]

        # Mock usage with actual int objects
        usage = MagicMock()
        usage.prompt_tokens = prompt_tokens
        usage.completion_tokens = completion_tokens
        usage.total_tokens = total_tokens
        response.usage = usage

        return response

    def test_basic_wrapping(self):
        """Test basic wrapping of OpenAI client."""
        client = self._create_mock_openai_client()

        # Set up mock response BEFORE wrapping
        mock_response = self._create_mock_response()
        client.chat.completions.create.return_value = mock_response

        # Wrap the client
        wrapped_client = wrap_openai(client)

        # Verify that method was wrapped
        self.assertIsNotNone(wrapped_client.chat.completions.create)

        # Call the wrapped method
        result = wrapped_client.chat.completions.create(
            model="gpt-4",
            messages=[{"role": "user", "content": "Hello"}],
        )

        # Verify the response
        self.assertEqual(result, mock_response)

    def test_span_creation_in_trace(self):
        """Test that a span is created when inside a trace context."""
        client = self._create_mock_openai_client()

        # Set up mock response BEFORE wrapping
        mock_response = self._create_mock_response()
        client.chat.completions.create.return_value = mock_response

        # Wrap client
        wrapped_client = wrap_openai(client)

        # Call inside a trace
        with agentlens.trace("test_trace"):
            wrapped_client.chat.completions.create(
                model="gpt-4",
                messages=[{"role": "user", "content": "Hello"}],
            )

        # Verify span was created via sent_traces
        traces = get_client().sent_traces
        self.assertEqual(len(traces), 1)
        trace = traces[0]
        self.assertEqual(len(trace.spans), 1)

        span = trace.spans[0]
        self.assertEqual(span.type, SpanType.LLM_CALL.value)
        self.assertEqual(span.name, "openai.gpt-4")
        self.assertEqual(span.status, SpanStatus.COMPLETED.value)
        self.assertIsNotNone(span.duration_ms)

    def test_function_call_decision_logging(self):
        """Test that tool calls are logged as decision points."""
        client = self._create_mock_openai_client()

        mock_response = self._create_mock_response(
            tool_calls=[
                {
                    "name": "search",
                    "arguments": {"query": "test"},
                },
                {
                    "name": "calculate",
                    "arguments": {"expression": "1+1"},
                },
            ]
        )
        client.chat.completions.create.return_value = mock_response

        wrapped_client = wrap_openai(client)

        # Call inside a trace
        with agentlens.trace("test_trace"):
            wrapped_client.chat.completions.create(
                model="gpt-4",
                messages=[{"role": "user", "content": "Search for something"}],
            )

        # Verify decision points were created
        traces = get_client().sent_traces
        self.assertEqual(len(traces), 1)
        trace = traces[0]

        # Should have 2 decision points (one per tool call)
        self.assertEqual(len(trace.decision_points), 2)

        # Check first decision
        decision1 = trace.decision_points[0]
        self.assertEqual(decision1.type, "TOOL_SELECTION")
        self.assertEqual(decision1.chosen["name"], "search")
        self.assertEqual(decision1.chosen["arguments"], {"query": "test"})
        self.assertEqual(decision1.alternatives, [])

        # Check second decision
        decision2 = trace.decision_points[1]
        self.assertEqual(decision2.type, "TOOL_SELECTION")
        self.assertEqual(decision2.chosen["name"], "calculate")
        self.assertEqual(decision2.chosen["arguments"], {"expression": "1+1"})

    def test_error_handling(self):
        """Test that errors are handled and span has ERROR status."""
        client = self._create_mock_openai_client()

        client.chat.completions.create.side_effect = Exception("API Error")

        wrapped_client = wrap_openai(client)

        # Call inside a trace and expect exception
        with agentlens.trace("test_trace"):
            with self.assertRaises(Exception) as cm:
                wrapped_client.chat.completions.create(
                    model="gpt-4",
                    messages=[{"role": "user", "content": "Hello"}],
                )

        self.assertEqual(str(cm.exception), "API Error")

        # Verify span has ERROR status
        traces = get_client().sent_traces
        self.assertEqual(len(traces), 1)
        trace = traces[0]

        self.assertEqual(len(trace.spans), 1)
        span = trace.spans[0]
        self.assertEqual(span.status, SpanStatus.ERROR.value)
        self.assertEqual(span.status_message, "API Error")

        # Verify error event was created
        self.assertEqual(len(trace.events), 1)
        event = trace.events[0]
        self.assertEqual(event.type, "ERROR")
        self.assertIn("API Error", event.name)
        self.assertEqual(event.metadata, {"error_type": "Exception"})

    def test_cost_estimation_gpt4(self):
        """Test cost calculation for gpt-4."""
        client = self._create_mock_openai_client()

        mock_response = self._create_mock_response(
            model="gpt-4", prompt_tokens=1000, completion_tokens=500
        )
        client.chat.completions.create.return_value = mock_response

        wrapped_client = wrap_openai(client)

        # Call inside a trace
        with agentlens.trace("test_trace"):
            wrapped_client.chat.completions.create(
                model="gpt-4",
                messages=[{"role": "user", "content": "Hello"}],
            )

        # Verify cost calculation
        traces = get_client().sent_traces
        trace = traces[0]
        span = trace.spans[0]

        # gpt-4: $0.03/1K input tokens, $0.06/1K output tokens
        # 1000 input tokens * $0.03/1K = $0.03
        # 500 output tokens * $0.06/1K = $0.03
        # Total = $0.06
        self.assertIsNotNone(span.cost_usd)
        self.assertAlmostEqual(span.cost_usd, 0.06, places=4)

    def test_cost_estimation_gpt4o_mini(self):
        """Test cost calculation for gpt-4o-mini."""
        client = self._create_mock_openai_client()

        mock_response = self._create_mock_response(
            model="gpt-4o-mini", prompt_tokens=1000, completion_tokens=500
        )
        client.chat.completions.create.return_value = mock_response

        wrapped_client = wrap_openai(client)

        # Call inside a trace
        with agentlens.trace("test_trace"):
            wrapped_client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[{"role": "user", "content": "Hello"}],
            )

        # Verify cost calculation
        traces = get_client().sent_traces
        trace = traces[0]
        span = trace.spans[0]

        # gpt-4o-mini: $0.00015/1K input tokens, $0.0006/1K output tokens
        # 1000 input tokens * $0.00015/1K = $0.00015
        # 500 output tokens * $0.0006/1K = $0.0003
        # Total = $0.00045
        self.assertIsNotNone(span.cost_usd)
        self.assertAlmostEqual(span.cost_usd, 0.00045, places=6)

    def test_cost_estimation_unknown_model(self):
        """Test that no cost is calculated for an unknown model."""
        client = self._create_mock_openai_client()

        mock_response = self._create_mock_response(
            model="unknown-model", prompt_tokens=1000, completion_tokens=500
        )
        client.chat.completions.create.return_value = mock_response

        wrapped_client = wrap_openai(client)

        # Call inside a trace
        with agentlens.trace("test_trace"):
            wrapped_client.chat.completions.create(
                model="unknown-model",
                messages=[{"role": "user", "content": "Hello"}],
            )

        # Verify no cost calculation
        traces = get_client().sent_traces
        trace = traces[0]
        span = trace.spans[0]

        self.assertIsNone(span.cost_usd)

    def test_outside_trace_context(self):
        """Test that a standalone trace is created when not in a trace context."""
        client = self._create_mock_openai_client()

        # Set up mock response BEFORE wrapping
        mock_response = self._create_mock_response()
        client.chat.completions.create.return_value = mock_response

        # Wrap the client
        wrapped_client = wrap_openai(client)

        # Call WITHOUT being inside a trace context
        wrapped_client.chat.completions.create(
            model="gpt-4",
            messages=[{"role": "user", "content": "Hello"}],
        )

        # Verify a standalone trace was created
        traces = get_client().sent_traces
        self.assertEqual(len(traces), 1)

        trace = traces[0]
        self.assertEqual(trace.name, "openai-gpt-4")
        self.assertEqual(trace.status, TraceStatus.COMPLETED.value)
        self.assertEqual(len(trace.spans), 1)

        span = trace.spans[0]
        self.assertEqual(span.type, SpanType.LLM_CALL.value)
        self.assertEqual(span.name, "openai.gpt-4")

    def test_message_truncation(self):
        """Test that long messages are truncated for privacy."""
        client = self._create_mock_openai_client()

        mock_response = self._create_mock_response()
        client.chat.completions.create.return_value = mock_response

        wrapped_client = wrap_openai(client)

        # Call with very long message
        long_content = "x" * 1000
        with agentlens.trace("test_trace"):
            wrapped_client.chat.completions.create(
                model="gpt-4",
                messages=[{"role": "user", "content": long_content}],
            )

        # Verify truncation in span input
        traces = get_client().sent_traces
        trace = traces[0]
        span = trace.spans[0]

        self.assertIsNotNone(span.input_data)
        input_messages = span.input_data["messages"]
        self.assertEqual(len(input_messages), 1)

        # Content should be truncated to ~500 chars with "..."
        content = input_messages[0]["content"]
        self.assertTrue(len(content) < 1000)
        self.assertTrue(content.endswith("..."))

    def test_token_tracking(self):
        """Test that token counts are tracked."""
        client = self._create_mock_openai_client()

        mock_response = self._create_mock_response(
            prompt_tokens=123, completion_tokens=456, total_tokens=579
        )
        client.chat.completions.create.return_value = mock_response

        wrapped_client = wrap_openai(client)

        # Call inside a trace
        with agentlens.trace("test_trace"):
            wrapped_client.chat.completions.create(
                model="gpt-4",
                messages=[{"role": "user", "content": "Hello"}],
            )

        # Verify token tracking
        traces = get_client().sent_traces
        trace = traces[0]
        span = trace.spans[0]

        self.assertEqual(span.token_count, 579)

    def test_metadata_tracking(self):
        """Test that metadata includes model, temperature, max_tokens."""
        client = self._create_mock_openai_client()

        mock_response = self._create_mock_response()
        client.chat.completions.create.return_value = mock_response

        wrapped_client = wrap_openai(client)

        # Call with parameters
        with agentlens.trace("test_trace"):
            wrapped_client.chat.completions.create(
                model="gpt-4",
                messages=[{"role": "user", "content": "Hello"}],
                temperature=0.7,
                max_tokens=1000,
            )

        # Verify metadata
        traces = get_client().sent_traces
        trace = traces[0]
        span = trace.spans[0]

        self.assertIsNotNone(span.metadata)
        self.assertEqual(span.metadata["model"], "gpt-4")
        self.assertEqual(span.metadata["temperature"], 0.7)
        self.assertEqual(span.metadata["max_tokens"], 1000)


if __name__ == "__main__":
    unittest.main()
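The __main__ guard makes the suite directly runnable; assuming agentlens is importable from the working directory, something like:

python packages/sdk-python/tests/test_openai_integration.py -v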