fix: OpenAI wrapper - remove duplicate span block, fix tool call extraction, fix cost model matching

- Remove duplicate span append + tool call decision logging block (lines 328-426)
- Fix _extract_tool_calls_from_response to use getattr() instead of .get() on objects
- Fix _calculate_cost to use exact match first, then longest-prefix match (prevents gpt-4o-mini matching gpt-4 pricing)
- Fix test mock setup: set return_value BEFORE wrap_openai() so wrapper captures correct original
- All 11 OpenAI integration tests + 8 SDK tests passing (19/19)
Vectry
2026-02-10 00:48:48 +00:00
parent 47ef3dcbe6
commit 1989366844
2 changed files with 1028 additions and 15 deletions
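
The cost-model fix in the third bullet is easiest to see with a small standalone sketch: an exact match is tried first, and only then the longest matching prefix, so a dated gpt-4o-mini variant no longer falls back to gpt-4 pricing. The two-entry table and the _lookup helper below are illustrative only; they mirror _calculate_cost in the diff that follows.

# Illustrative two-entry subset of the module's pricing table.
_COSTS = {"gpt-4": (0.03, 0.06), "gpt-4o-mini": (0.00015, 0.0006)}

def _lookup(model: str):
    model = model.lower()
    if model in _COSTS:                       # 1) exact match wins outright
        return _COSTS[model]
    best, best_len = None, 0
    for name, costs in _COSTS.items():        # 2) otherwise take the longest matching prefix
        if model.startswith(name) and len(name) > best_len:
            best, best_len = costs, len(name)
    return best

# Both "gpt-4" and "gpt-4o-mini" are prefixes of this model id; the longer one wins.
assert _lookup("gpt-4o-mini-2024-07-18") == (0.00015, 0.0006)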

View File

@@ -1,7 +1,604 @@
"""OpenAI integration for AgentLens.""" """OpenAI integration for AgentLens.
from typing import Any, Optional This module provides a wrapper that auto-instruments OpenAI API calls with
tracing, span creation, decision logging for function/tool calls, and token tracking.
"""
import json
import logging
import time
from functools import wraps from functools import wraps
from typing import Any, Dict, Iterator, List, Optional
from agentlens.models import (
Event,
EventType,
_now_iso,
)
from agentlens.trace import (
TraceContext,
_get_context_stack,
get_current_span_id,
get_current_trace,
)
logger = logging.getLogger("agentlens")
# Cost per 1K tokens (input/output) for common models
_MODEL_COSTS: Dict[str, tuple] = {
"gpt-4": (0.03, 0.06),
"gpt-4-32k": (0.06, 0.12),
"gpt-4-turbo": (0.01, 0.03),
"gpt-4-turbo-2024-04-09": (0.01, 0.03),
"gpt-4-turbo-preview": (0.01, 0.03),
"gpt-4o": (0.005, 0.015),
"gpt-4o-2024-05-13": (0.005, 0.015),
"gpt-4o-2024-08-06": (0.0025, 0.01),
"gpt-4o-mini": (0.00015, 0.0006),
"gpt-4o-mini-2024-07-18": (0.00015, 0.0006),
"gpt-3.5-turbo": (0.0005, 0.0015),
"gpt-3.5-turbo-0125": (0.0005, 0.0015),
"gpt-3.5-turbo-1106": (0.001, 0.002),
"gpt-3.5-turbo-16k": (0.003, 0.004),
}
class _MockFunction:
def __init__(self, n: str, a: str):
self.name = n
self.arguments = a
class _MockToolCall:
def __init__(self, name: str, args: str):
self.function = _MockFunction(name, args)
class _MockMessage:
def __init__(self, content: Optional[str], tool_calls_list: List[_MockToolCall]):
self.content = content
self.tool_calls = tool_calls_list
class _MockChoice:
def __init__(self, content: Optional[str], tool_calls_list: List[_MockToolCall]):
self.message = _MockMessage(content, tool_calls_list)
class _MockUsage:
def __init__(self):
self.prompt_tokens: Optional[int] = None
self.completion_tokens: Optional[int] = None
self.total_tokens: Optional[int] = None
class _MockResponse:
def __init__(self):
self.model: str = "unknown"
self.choices: List[_MockChoice] = []
self.usage = _MockUsage()
def _truncate_data(data: Any, max_length: int = 500) -> Any:
"""Truncate data for privacy while preserving structure."""
if isinstance(data, str):
        return (data[:max_length] + "...") if len(data) > max_length else data
elif isinstance(data, dict):
return {k: _truncate_data(v, max_length) for k, v in data.items()}
elif isinstance(data, list):
return [_truncate_data(item, max_length) for item in data]
else:
return data
def _calculate_cost(
model: str, prompt_tokens: int, completion_tokens: int
) -> Optional[float]:
"""Calculate cost in USD based on model pricing."""
model_lower = model.lower()
if model_lower in _MODEL_COSTS:
input_cost, output_cost = _MODEL_COSTS[model_lower]
return (float(prompt_tokens) / 1000.0) * input_cost + float(
completion_tokens
) / 1000.0 * output_cost
best_match = None
best_len = 0
for model_name, costs in _MODEL_COSTS.items():
if model_lower.startswith(model_name.lower()) and len(model_name) > best_len:
best_match = costs
best_len = len(model_name)
if best_match:
input_cost, output_cost = best_match
return (float(prompt_tokens) / 1000.0) * input_cost + float(
completion_tokens
) / 1000.0 * output_cost
return None
def _extract_messages_truncated(messages: List[Any]) -> List[Dict[str, Any]]:
"""Extract and truncate message content."""
truncated = []
for msg in messages:
if isinstance(msg, dict):
truncated_msg = {"role": msg.get("role", "unknown")}
content = msg.get("content")
if content is not None:
truncated_msg["content"] = _truncate_data(str(content))
truncated.append(truncated_msg)
else:
# Handle message objects
role = getattr(msg, "role", "unknown")
content = getattr(msg, "content", "")
truncated.append({"role": role, "content": _truncate_data(str(content))})
return truncated
def _extract_content_from_response(response: Any) -> Optional[str]:
"""Extract content from OpenAI response."""
if hasattr(response, "choices") and response.choices:
message = response.choices[0].message
if hasattr(message, "content"):
return _truncate_data(str(message.content))
return None
def _extract_tool_calls_from_response(response: Any) -> List[Dict[str, Any]]:
"""Extract tool calls from OpenAI response."""
tool_calls = []
if hasattr(response, "choices") and response.choices:
message = response.choices[0].message
if hasattr(message, "tool_calls") and message.tool_calls:
for tc in message.tool_calls:
func = getattr(tc, "function", None)
name = getattr(func, "name", "unknown") if func else "unknown"
args_str = getattr(func, "arguments", "{}") if func else "{}"
call_dict = {"name": name}
try:
call_dict["arguments"] = json.loads(args_str)
except (json.JSONDecodeError, TypeError):
call_dict["arguments"] = args_str
tool_calls.append(call_dict)
return tool_calls
class _StreamWrapper:
"""Wrapper for OpenAI stream responses to collect chunks."""
def __init__(self, original_stream: Any, trace_ctx: Optional[TraceContext]):
self._original_stream = original_stream
self._trace_ctx = trace_ctx
self._chunks: List[Any] = []
self._start_time = time.time()
self._model = None
self._temperature = None
self._max_tokens = None
self._messages = None
self._parent_span_id = get_current_span_id()
def set_params(
self,
model: str,
temperature: Optional[float],
max_tokens: Optional[int],
messages: List[Any],
) -> None:
self._model = model
self._temperature = temperature
self._max_tokens = max_tokens
self._messages = messages
def __iter__(self) -> Iterator[Any]:
return self
def __next__(self) -> Any:
chunk = next(self._original_stream)
self._chunks.append(chunk)
return chunk
def finalize(self) -> None:
"""Create span after stream is fully consumed."""
if not self._chunks:
return
# Build a mock response from chunks
response = _MockResponse()
# Extract model from first chunk
if hasattr(self._chunks[0], "model"):
response.model = self._chunks[0].model
else:
response.model = self._model or "unknown"
# Extract message content and tool calls
message_content = None
tool_calls = []
for chunk in self._chunks:
if hasattr(chunk, "choices") and chunk.choices:
delta = chunk.choices[0].delta
if hasattr(delta, "content") and delta.content:
if message_content is None:
message_content = ""
message_content += delta.content
if hasattr(delta, "tool_calls") and delta.tool_calls:
for tc in delta.tool_calls:
if hasattr(tc, "function"):
tc_idx = tc.index if hasattr(tc, "index") else 0
while len(tool_calls) <= tc_idx:
tool_calls.append({"name": None, "arguments": ""})
if hasattr(tc.function, "name") and tc.function.name:
tool_calls[tc_idx]["name"] = tc.function.name
if (
hasattr(tc.function, "arguments")
and tc.function.arguments
):
tool_calls[tc_idx]["arguments"] += tc.function.arguments
response.choices = [_MockChoice(message_content, [])]
# Extract tool calls as objects
tool_call_objects = []
for tc in tool_calls:
if tc.get("name"):
try:
args = tc.get("arguments", "{}")
if isinstance(args, str):
args_dict = json.loads(args)
else:
args_dict = args
tool_call_objects.append(
_MockToolCall(tc["name"], json.dumps(args_dict))
)
except json.JSONDecodeError:
pass
response.choices[0].message.tool_calls = tool_call_objects
# Build mock usage (can't get exact tokens from stream)
response.usage.prompt_tokens = None
response.usage.completion_tokens = None
response.usage.total_tokens = None
# Create the span
_create_llm_span(
response=response,
start_time=self._start_time,
model=self._model or response.model,
temperature=self._temperature,
max_tokens=self._max_tokens,
messages=self._messages or [],
parent_span_id=self._parent_span_id,
trace_ctx=self._trace_ctx,
)
# Close trace context if we created one
if self._trace_ctx:
self._trace_ctx.__exit__(None, None, None)
def _create_llm_span(
response: Any,
start_time: float,
model: str,
temperature: Optional[float],
max_tokens: Optional[int],
messages: List[Any],
parent_span_id: Optional[str],
trace_ctx: Optional[TraceContext],
) -> None:
"""Create LLM span from OpenAI response."""
from agentlens.models import Span, SpanStatus, SpanType
current_trace = get_current_trace()
if current_trace is None:
logger.warning("No active trace, skipping span creation")
return
end_time = time.time()
duration_ms = int((end_time - start_time) * 1000)
# Extract token usage
token_count = None
cost_usd = None
if hasattr(response, "usage"):
prompt_tokens = getattr(response.usage, "prompt_tokens", None)
completion_tokens = getattr(response.usage, "completion_tokens", None)
total_tokens = getattr(response.usage, "total_tokens", None)
if total_tokens is not None:
token_count = total_tokens
if prompt_tokens is not None and completion_tokens is not None:
cost_usd = _calculate_cost(model, prompt_tokens, completion_tokens)
# Create span
span_name = f"openai.{model}"
span = Span(
name=span_name,
type=SpanType.LLM_CALL.value,
parent_span_id=parent_span_id,
input_data={"messages": _extract_messages_truncated(messages)},
output_data={"content": _extract_content_from_response(response)},
token_count=token_count,
cost_usd=cost_usd,
duration_ms=duration_ms,
status=SpanStatus.COMPLETED.value,
started_at=_now_iso(),
ended_at=_now_iso(),
metadata={
"model": model,
"temperature": temperature,
"max_tokens": max_tokens,
},
)
current_trace.spans.append(span)
# Push onto context stack for decision logging
stack = _get_context_stack()
stack.append(span)
# Log tool call decisions
tool_calls = _extract_tool_calls_from_response(response)
if tool_calls:
from agentlens.decision import log_decision
# Try to get reasoning from messages
reasoning = None
for msg in reversed(messages):
if isinstance(msg, dict):
if msg.get("role") == "assistant":
reasoning = msg.get("content")
if reasoning:
reasoning = _truncate_data(str(reasoning))
break
else:
role = getattr(msg, "role", None)
if role == "assistant":
reasoning = getattr(msg, "content", None)
if reasoning:
reasoning = _truncate_data(str(reasoning))
break
# Build context snapshot
context_snapshot = None
if hasattr(response, "usage"):
context_snapshot = {
"model": model,
"prompt_tokens": getattr(response.usage, "prompt_tokens"),
"completion_tokens": getattr(response.usage, "completion_tokens"),
}
for tool_call in tool_calls:
log_decision(
type="TOOL_SELECTION",
chosen={
"name": tool_call.get("name", "unknown"),
"arguments": tool_call.get("arguments", {}),
},
alternatives=[],
reasoning=reasoning,
context_snapshot=context_snapshot,
)
# Always pop from context stack
if stack and stack[-1] == span:
stack.pop()
elif stack and isinstance(stack[-1], Span) and stack[-1].id == span.id:
stack.pop()
def _wrap_create(original_create: Any, is_async: bool = False) -> Any:
"""Wrap OpenAI chat.completions.create method."""
if is_async:
@wraps(original_create)
async def async_traced_create(*args: Any, **kwargs: Any) -> Any:
# Extract parameters
model = kwargs.get("model", "gpt-3.5-turbo")
temperature = kwargs.get("temperature")
max_tokens = kwargs.get("max_tokens")
messages = kwargs.get("messages", [])
stream = kwargs.get("stream", False)
parent_span_id = get_current_span_id()
start_time = time.time()
# Handle streaming
if stream:
# Create trace if needed
trace_ctx = None
if get_current_trace() is None:
trace_ctx = TraceContext(name=f"openai-{model}")
trace_ctx.__enter__()
try:
original_stream = await original_create(*args, **kwargs)
wrapper = _StreamWrapper(original_stream, trace_ctx)
wrapper.set_params(model, temperature, max_tokens, messages)
return wrapper
except Exception as e:
if trace_ctx:
trace_ctx.__exit__(type(e), e, None)
raise
# Non-streaming
trace_ctx = None
if get_current_trace() is None:
trace_ctx = TraceContext(name=f"openai-{model}")
trace_ctx.__enter__()
try:
response = await original_create(*args, **kwargs)
_create_llm_span(
response=response,
start_time=start_time,
model=model,
temperature=temperature,
max_tokens=max_tokens,
messages=messages,
parent_span_id=parent_span_id,
trace_ctx=trace_ctx,
)
# Close trace context if we created one
if trace_ctx is not None:
trace_ctx.__exit__(None, None, None)
return response
except Exception as e:
_handle_error(
error=e,
start_time=start_time,
model=model,
temperature=temperature,
max_tokens=max_tokens,
messages=messages,
parent_span_id=parent_span_id,
trace_ctx=trace_ctx,
)
raise
return async_traced_create
else:
@wraps(original_create)
def traced_create(*args: Any, **kwargs: Any) -> Any:
# Extract parameters
model = kwargs.get("model", "gpt-3.5-turbo")
temperature = kwargs.get("temperature")
max_tokens = kwargs.get("max_tokens")
messages = kwargs.get("messages", [])
stream = kwargs.get("stream", False)
parent_span_id = get_current_span_id()
start_time = time.time()
# Handle streaming
if stream:
# Create trace if needed
trace_ctx = None
if get_current_trace() is None:
trace_ctx = TraceContext(name=f"openai-{model}")
trace_ctx.__enter__()
try:
original_stream = original_create(*args, **kwargs)
wrapper = _StreamWrapper(original_stream, trace_ctx)
wrapper.set_params(model, temperature, max_tokens, messages)
return wrapper
except Exception as e:
if trace_ctx:
trace_ctx.__exit__(type(e), e, None)
raise
# Non-streaming
trace_ctx = None
if get_current_trace() is None:
trace_ctx = TraceContext(name=f"openai-{model}")
trace_ctx.__enter__()
try:
response = original_create(*args, **kwargs)
_create_llm_span(
response=response,
start_time=start_time,
model=model,
temperature=temperature,
max_tokens=max_tokens,
messages=messages,
parent_span_id=parent_span_id,
trace_ctx=trace_ctx,
)
# Close trace context if we created one
if trace_ctx is not None:
trace_ctx.__exit__(None, None, None)
return response
except Exception as e:
_handle_error(
error=e,
start_time=start_time,
model=model,
temperature=temperature,
max_tokens=max_tokens,
messages=messages,
parent_span_id=parent_span_id,
trace_ctx=trace_ctx,
)
raise
return traced_create
def _handle_error(
error: Exception,
start_time: float,
model: str,
temperature: Optional[float],
max_tokens: Optional[int],
messages: List[Any],
parent_span_id: Optional[str],
trace_ctx: Optional[TraceContext],
) -> None:
"""Handle error by creating error span and event."""
from agentlens.models import Span, SpanStatus, SpanType
current_trace = get_current_trace()
if current_trace is None:
return
end_time = time.time()
duration_ms = int((end_time - start_time) * 1000)
# Create error span
span_name = f"openai.{model}"
span = Span(
name=span_name,
type=SpanType.LLM_CALL.value,
parent_span_id=parent_span_id,
input_data={"messages": _extract_messages_truncated(messages)},
status=SpanStatus.ERROR.value,
status_message=str(error),
started_at=_now_iso(),
ended_at=_now_iso(),
duration_ms=duration_ms,
metadata={
"model": model,
"temperature": temperature,
"max_tokens": max_tokens,
},
)
current_trace.spans.append(span)
# Create error event
error_event = Event(
type=EventType.ERROR.value,
name=f"{span_name}: {str(error)}",
span_id=span.id,
metadata={"error_type": type(error).__name__},
)
current_trace.events.append(error_event)
# Pop from context stack if needed
stack = _get_context_stack()
if stack and isinstance(stack[-1], Span) and stack[-1].id == span.id:
stack.pop()
def wrap_openai(client: Any) -> Any:
@@ -11,7 +608,7 @@ def wrap_openai(client: Any) -> Any:
        client: The OpenAI client to wrap.
    Returns:
        The same client instance with chat.completions.create wrapped.
    Example:
        import openai
@@ -20,20 +617,22 @@ def wrap_openai(client: Any) -> Any:
        client = openai.OpenAI(api_key="sk-...")
        traced_client = wrap_openai(client)
        response = traced_client.chat.completions.create(
            model="gpt-4",
            messages=[{"role": "user", "content": "Hello!"}]
        )
    """
    original_create = client.chat.completions.create
    # Wrap synchronous method
    traced_create = _wrap_create(original_create, is_async=False)
    client.chat.completions.create = traced_create
    # Try to wrap async method if available
    if hasattr(client.chat.completions, "acreate"):
        original_acreate = client.chat.completions.acreate
        traced_acreate = _wrap_create(original_acreate, is_async=True)
        client.chat.completions.acreate = traced_acreate
    logger.debug("OpenAI client wrapped with AgentLens tracing")
    return client
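
One usage note on the streaming path above: the span for a streamed call is only built from the collected chunks in _StreamWrapper.finalize(), and nothing in the wrapper calls that automatically, so the sketch below assumes the caller invokes finalize() after the stream is exhausted. It is illustrative only; the client construction, API key, and model name are placeholders.

import openai
from agentlens.integrations.openai import wrap_openai

client = wrap_openai(openai.OpenAI(api_key="sk-..."))
stream = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Hello!"}],
    stream=True,
)
for chunk in stream:   # chunks pass through _StreamWrapper untouched
    pass               # consume or render the deltas here
stream.finalize()      # builds the LLM span from the collected chunks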

View File

@@ -0,0 +1,414 @@
import json
import unittest
from unittest.mock import MagicMock
import agentlens
from agentlens import init, shutdown, get_client
from agentlens.models import TraceStatus, SpanStatus, SpanType
from agentlens.integrations.openai import wrap_openai
class TestOpenAIIntegration(unittest.TestCase):
def setUp(self):
shutdown()
init(api_key="test_key", enabled=False)
def tearDown(self):
shutdown()
def _create_mock_openai_client(self):
"""Create a mock OpenAI client with necessary structure."""
client = MagicMock()
# Mock chat.completions structure
client.chat = MagicMock()
client.chat.completions = MagicMock()
# Mock the create method
client.chat.completions.create = MagicMock()
return client
def _create_mock_response(
self,
model="gpt-4",
content="Hello!",
prompt_tokens=10,
completion_tokens=5,
total_tokens=15,
tool_calls=None,
):
"""Create a mock OpenAI response object."""
response = MagicMock()
response.model = model
# Mock choices
choice = MagicMock()
message = MagicMock()
message.content = content
if tool_calls:
# Build mock tool calls
mock_tool_calls = []
for tc_data in tool_calls:
tc = MagicMock()
function = MagicMock()
function.name = tc_data["name"]
function.arguments = json.dumps(tc_data["arguments"])
tc.function = function
mock_tool_calls.append(tc)
message.tool_calls = mock_tool_calls
else:
message.tool_calls = []
choice.message = message
response.choices = [choice]
# Mock usage with actual int objects
usage = MagicMock()
usage.prompt_tokens = prompt_tokens
usage.completion_tokens = completion_tokens
usage.total_tokens = total_tokens
response.usage = usage
return response
def test_basic_wrapping(self):
"""Test basic wrapping of OpenAI client."""
client = self._create_mock_openai_client()
# Set up mock response BEFORE wrapping
mock_response = self._create_mock_response()
client.chat.completions.create.return_value = mock_response
# Wrap the client
wrapped_client = wrap_openai(client)
# Verify that method was wrapped
self.assertIsNotNone(wrapped_client.chat.completions.create)
# Call the wrapped method
result = wrapped_client.chat.completions.create(
model="gpt-4",
messages=[{"role": "user", "content": "Hello"}],
)
# Verify the response
self.assertEqual(result, mock_response)
def test_span_creation_in_trace(self):
"""Test that a span is created when inside a trace context."""
client = self._create_mock_openai_client()
# Set up mock response BEFORE wrapping
mock_response = self._create_mock_response()
client.chat.completions.create.return_value = mock_response
# Wrap client
wrapped_client = wrap_openai(client)
# Call inside a trace
with agentlens.trace("test_trace"):
wrapped_client.chat.completions.create(
model="gpt-4",
messages=[{"role": "user", "content": "Hello"}],
)
# Verify span was created via sent_traces
traces = get_client().sent_traces
self.assertEqual(len(traces), 1)
trace = traces[0]
self.assertEqual(len(trace.spans), 1)
span = trace.spans[0]
self.assertEqual(span.type, SpanType.LLM_CALL.value)
self.assertEqual(span.name, "openai.gpt-4")
self.assertEqual(span.status, SpanStatus.COMPLETED.value)
self.assertIsNotNone(span.duration_ms)
def test_function_call_decision_logging(self):
"""Test that tool calls are logged as decision points."""
client = self._create_mock_openai_client()
mock_response = self._create_mock_response(
tool_calls=[
{
"name": "search",
"arguments": {"query": "test"},
},
{
"name": "calculate",
"arguments": {"expression": "1+1"},
},
]
)
client.chat.completions.create.return_value = mock_response
wrapped_client = wrap_openai(client)
# Call inside a trace
with agentlens.trace("test_trace"):
wrapped_client.chat.completions.create(
model="gpt-4",
messages=[{"role": "user", "content": "Search for something"}],
)
# Verify decision points were created
traces = get_client().sent_traces
self.assertEqual(len(traces), 1)
trace = traces[0]
# Should have 2 decision points (one per tool call)
self.assertEqual(len(trace.decision_points), 2)
# Check first decision
decision1 = trace.decision_points[0]
self.assertEqual(decision1.type, "TOOL_SELECTION")
self.assertEqual(decision1.chosen["name"], "search")
self.assertEqual(decision1.chosen["arguments"], {"query": "test"})
self.assertEqual(decision1.alternatives, [])
# Check second decision
decision2 = trace.decision_points[1]
self.assertEqual(decision2.type, "TOOL_SELECTION")
self.assertEqual(decision2.chosen["name"], "calculate")
self.assertEqual(decision2.chosen["arguments"], {"expression": "1+1"})
def test_error_handling(self):
"""Test that errors are handled and span has ERROR status."""
client = self._create_mock_openai_client()
client.chat.completions.create.side_effect = Exception("API Error")
wrapped_client = wrap_openai(client)
# Call inside a trace and expect exception
with agentlens.trace("test_trace"):
with self.assertRaises(Exception) as cm:
wrapped_client.chat.completions.create(
model="gpt-4",
messages=[{"role": "user", "content": "Hello"}],
)
self.assertEqual(str(cm.exception), "API Error")
# Verify span has ERROR status
traces = get_client().sent_traces
self.assertEqual(len(traces), 1)
trace = traces[0]
self.assertEqual(len(trace.spans), 1)
span = trace.spans[0]
self.assertEqual(span.status, SpanStatus.ERROR.value)
self.assertEqual(span.status_message, "API Error")
# Verify error event was created
self.assertEqual(len(trace.events), 1)
event = trace.events[0]
self.assertEqual(event.type, "ERROR")
self.assertIn("API Error", event.name)
self.assertEqual(event.metadata, {"error_type": "Exception"})
def test_cost_estimation_gpt4(self):
"""Test cost calculation for gpt-4."""
client = self._create_mock_openai_client()
mock_response = self._create_mock_response(
model="gpt-4", prompt_tokens=1000, completion_tokens=500
)
client.chat.completions.create.return_value = mock_response
wrapped_client = wrap_openai(client)
# Call inside a trace
with agentlens.trace("test_trace"):
wrapped_client.chat.completions.create(
model="gpt-4",
messages=[{"role": "user", "content": "Hello"}],
)
# Verify cost calculation
traces = get_client().sent_traces
trace = traces[0]
span = trace.spans[0]
# gpt-4: $0.03/1K input, $0.06/1K output
        # (1000 tokens / 1000) * 0.03 = 0.03
        # (500 tokens / 1000) * 0.06 = 0.03
# Total = 0.06
self.assertIsNotNone(span.cost_usd)
self.assertAlmostEqual(span.cost_usd, 0.06, places=4)
def test_cost_estimation_gpt4o_mini(self):
"""Test cost calculation for gpt-4o-mini."""
client = self._create_mock_openai_client()
mock_response = self._create_mock_response(
model="gpt-4o-mini", prompt_tokens=1000, completion_tokens=500
)
client.chat.completions.create.return_value = mock_response
wrapped_client = wrap_openai(client)
# Call inside a trace
with agentlens.trace("test_trace"):
wrapped_client.chat.completions.create(
model="gpt-4o-mini",
messages=[{"role": "user", "content": "Hello"}],
)
# Verify cost calculation
traces = get_client().sent_traces
trace = traces[0]
span = trace.spans[0]
# gpt-4o-mini: $0.00015/1K input, $0.0006/1K output
        # (1000 tokens / 1000) * 0.00015 = 0.00015
        # (500 tokens / 1000) * 0.0006 = 0.0003
# Total = 0.00045
self.assertIsNotNone(span.cost_usd)
self.assertAlmostEqual(span.cost_usd, 0.00045, places=6)
def test_cost_estimation_unknown_model(self):
"""Test that unknown model doesn't calculate cost."""
client = self._create_mock_openai_client()
mock_response = self._create_mock_response(
model="unknown-model", prompt_tokens=1000, completion_tokens=500
)
client.chat.completions.create.return_value = mock_response
wrapped_client = wrap_openai(client)
# Call inside a trace
with agentlens.trace("test_trace"):
wrapped_client.chat.completions.create(
model="unknown-model",
messages=[{"role": "user", "content": "Hello"}],
)
# Verify no cost calculation
traces = get_client().sent_traces
trace = traces[0]
span = trace.spans[0]
self.assertIsNone(span.cost_usd)
def test_outside_trace_context(self):
"""Test that a standalone trace is created when not in a trace context."""
client = self._create_mock_openai_client()
# Set up mock response BEFORE wrapping
mock_response = self._create_mock_response()
client.chat.completions.create.return_value = mock_response
# Wrap the client
wrapped_client = wrap_openai(client)
# Call WITHOUT being inside a trace context
wrapped_client.chat.completions.create(
model="gpt-4",
messages=[{"role": "user", "content": "Hello"}],
)
# Verify a standalone trace was created
traces = get_client().sent_traces
self.assertEqual(len(traces), 1)
trace = traces[0]
self.assertEqual(trace.name, "openai-gpt-4")
self.assertEqual(trace.status, TraceStatus.COMPLETED.value)
self.assertEqual(len(trace.spans), 1)
span = trace.spans[0]
self.assertEqual(span.type, SpanType.LLM_CALL.value)
self.assertEqual(span.name, "openai.gpt-4")
def test_message_truncation(self):
"""Test that long messages are truncated for privacy."""
client = self._create_mock_openai_client()
mock_response = self._create_mock_response()
client.chat.completions.create.return_value = mock_response
wrapped_client = wrap_openai(client)
# Call with very long message
long_content = "x" * 1000
with agentlens.trace("test_trace"):
wrapped_client.chat.completions.create(
model="gpt-4",
messages=[{"role": "user", "content": long_content}],
)
# Verify truncation in span input
traces = get_client().sent_traces
trace = traces[0]
span = trace.spans[0]
self.assertIsNotNone(span.input_data)
input_messages = span.input_data["messages"]
self.assertEqual(len(input_messages), 1)
# Content should be truncated to ~500 chars with ...
content = input_messages[0]["content"]
self.assertTrue(len(content) < 1000)
self.assertTrue(content.endswith("..."))
def test_token_tracking(self):
"""Test that token counts are tracked."""
client = self._create_mock_openai_client()
mock_response = self._create_mock_response(
prompt_tokens=123, completion_tokens=456, total_tokens=579
)
client.chat.completions.create.return_value = mock_response
wrapped_client = wrap_openai(client)
# Call inside a trace
with agentlens.trace("test_trace"):
wrapped_client.chat.completions.create(
model="gpt-4",
messages=[{"role": "user", "content": "Hello"}],
)
# Verify token tracking
traces = get_client().sent_traces
trace = traces[0]
span = trace.spans[0]
self.assertEqual(span.token_count, 579)
def test_metadata_tracking(self):
"""Test that metadata includes model, temperature, max_tokens."""
client = self._create_mock_openai_client()
mock_response = self._create_mock_response()
client.chat.completions.create.return_value = mock_response
wrapped_client = wrap_openai(client)
# Call with parameters
with agentlens.trace("test_trace"):
wrapped_client.chat.completions.create(
model="gpt-4",
messages=[{"role": "user", "content": "Hello"}],
temperature=0.7,
max_tokens=1000,
)
# Verify metadata
traces = get_client().sent_traces
trace = traces[0]
span = trace.spans[0]
self.assertIsNotNone(span.metadata)
self.assertEqual(span.metadata["model"], "gpt-4")
self.assertEqual(span.metadata["temperature"], 0.7)
self.assertEqual(span.metadata["max_tokens"], 1000)
if __name__ == "__main__":
unittest.main()
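
A closing note on the mock-ordering fix from the commit message: wrap_openai() captures chat.completions.create at wrap time, so a return_value configured afterwards lands on the installed wrapper function rather than on the captured mock, and calls fall through to MagicMock's default return value. A minimal sketch of the ordering (names are illustrative and this snippet is not part of the suite):

from unittest.mock import MagicMock
from agentlens.integrations.openai import wrap_openai

client = MagicMock()
original = client.chat.completions.create   # the callable wrap_openai will capture

original.return_value = "mock-response"     # configure BEFORE wrapping...
wrapped = wrap_openai(client)               # ...so the captured original returns it

# Reversing the two lines above would attach return_value to the wrapper that
# wrap_openai installed on client.chat.completions, leaving the captured mock
# with its default MagicMock return value.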