agentlens/packages/sdk-python/agentlens/integrations/openai.py
Commit 1989366844 (Vectry): fix: OpenAI wrapper - remove duplicate span block, fix tool call extraction, fix cost model matching
- Remove duplicate span append + tool call decision logging block (lines 328-426)
- Fix _extract_tool_calls_from_response to use getattr() instead of .get() on objects
- Fix _calculate_cost to use exact match first, then longest-prefix match (prevents gpt-4o-mini matching gpt-4 pricing)
- Fix test mock setup: set return_value BEFORE wrap_openai() so wrapper captures correct original
- All 11 OpenAI integration tests + 8 SDK tests passing (19/19)
2026-02-10 00:48:48 +00:00

"""OpenAI integration for AgentLens.
This module provides a wrapper that auto-instruments OpenAI API calls with
tracing, span creation, decision logging for function/tool calls, and token tracking.
"""
import json
import logging
import time
from functools import wraps
from typing import Any, Dict, Iterator, List, Optional
from agentlens.models import (
Event,
EventType,
_now_iso,
)
from agentlens.trace import (
TraceContext,
_get_context_stack,
get_current_span_id,
get_current_trace,
)
logger = logging.getLogger("agentlens")
# USD cost per 1K tokens (input, output) for common models
_MODEL_COSTS: Dict[str, tuple] = {
"gpt-4": (0.03, 0.06),
"gpt-4-32k": (0.06, 0.12),
"gpt-4-turbo": (0.01, 0.03),
"gpt-4-turbo-2024-04-09": (0.01, 0.03),
"gpt-4-turbo-preview": (0.01, 0.03),
"gpt-4o": (0.005, 0.015),
"gpt-4o-2024-05-13": (0.005, 0.015),
"gpt-4o-2024-08-06": (0.0025, 0.01),
"gpt-4o-mini": (0.00015, 0.0006),
"gpt-4o-mini-2024-07-18": (0.00015, 0.0006),
"gpt-3.5-turbo": (0.0005, 0.0015),
"gpt-3.5-turbo-0125": (0.0005, 0.0015),
"gpt-3.5-turbo-1106": (0.001, 0.002),
"gpt-3.5-turbo-16k": (0.003, 0.004),
}
# Minimal stand-ins for OpenAI response objects, used to rebuild a complete
# response from streamed chunks in _StreamWrapper.finalize().
class _MockFunction:
    def __init__(self, name: str, arguments: str):
        self.name = name
        self.arguments = arguments
class _MockToolCall:
    def __init__(self, name: str, args: str):
        self.function = _MockFunction(name, args)
class _MockMessage:
def __init__(self, content: Optional[str], tool_calls_list: List[_MockToolCall]):
self.content = content
self.tool_calls = tool_calls_list
class _MockChoice:
def __init__(self, content: Optional[str], tool_calls_list: List[_MockToolCall]):
self.message = _MockMessage(content, tool_calls_list)
class _MockUsage:
def __init__(self):
self.prompt_tokens: Optional[int] = None
self.completion_tokens: Optional[int] = None
self.total_tokens: Optional[int] = None
class _MockResponse:
def __init__(self):
self.model: str = "unknown"
self.choices: List[_MockChoice] = []
self.usage = _MockUsage()
def _truncate_data(data: Any, max_length: int = 500) -> Any:
"""Truncate data for privacy while preserving structure."""
if isinstance(data, str):
return data[:max_length] + "..." if len(data) > max_length else data
elif isinstance(data, dict):
return {k: _truncate_data(v, max_length) for k, v in data.items()}
elif isinstance(data, list):
return [_truncate_data(item, max_length) for item in data]
else:
return data
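# Illustrative behavior of _truncate_data (values assumed, not taken from the
# test suite):
#   _truncate_data("a" * 600)          -> first 500 chars + "..."
#   _truncate_data({"k": "a" * 600})   -> {"k": <truncated string>}
#   _truncate_data([1, "short", None]) -> [1, "short", None]  (unchanged)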
def _calculate_cost(
    model: str, prompt_tokens: int, completion_tokens: int
) -> Optional[float]:
    """Calculate cost in USD based on model pricing.

    Tries an exact model match first, then falls back to the longest
    matching prefix so a model like "gpt-4o-mini-<date>" picks up
    "gpt-4o-mini" pricing rather than "gpt-4".
    """
    model_lower = model.lower()
    costs = _MODEL_COSTS.get(model_lower)
    if costs is None:
        # Longest-prefix fallback for dated or otherwise unlisted variants.
        best_len = 0
        for model_name, candidate in _MODEL_COSTS.items():
            if model_lower.startswith(model_name.lower()) and len(model_name) > best_len:
                costs = candidate
                best_len = len(model_name)
    if costs is None:
        return None
    input_cost, output_cost = costs
    return (prompt_tokens / 1000.0) * input_cost + (
        completion_tokens / 1000.0
    ) * output_cost
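# Illustrative lookups (the dated mini variant below is an assumed name):
#   _calculate_cost("gpt-4o-mini", 1000, 1000)
#       -> exact match: 1.0 * 0.00015 + 1.0 * 0.0006 = 0.00075 USD
#   _calculate_cost("gpt-4o-mini-2024-09-01", 1000, 1000)
#       -> no exact key; longest prefix is "gpt-4o-mini" (beating "gpt-4o"
#          and "gpt-4"), so mini pricing applies: 0.00075 USD
#   _calculate_cost("claude-3", 1000, 1000) -> None (unknown model)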
def _extract_messages_truncated(messages: List[Any]) -> List[Dict[str, Any]]:
"""Extract and truncate message content."""
truncated = []
for msg in messages:
if isinstance(msg, dict):
truncated_msg = {"role": msg.get("role", "unknown")}
content = msg.get("content")
if content is not None:
truncated_msg["content"] = _truncate_data(str(content))
truncated.append(truncated_msg)
else:
# Handle message objects
role = getattr(msg, "role", "unknown")
content = getattr(msg, "content", "")
truncated.append({"role": role, "content": _truncate_data(str(content))})
return truncated
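# Illustrative shapes (the object form assumes a hypothetical message type
# exposing .role and .content, as duck-typed above):
#   _extract_messages_truncated([{"role": "user", "content": "hi"}])
#       -> [{"role": "user", "content": "hi"}]
#   _extract_messages_truncated([msg_obj])  # msg_obj.role == "system"
#       -> [{"role": "system", "content": "<truncated str(content)>"}]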
def _extract_content_from_response(response: Any) -> Optional[str]:
    """Extract truncated message content from an OpenAI response."""
    if hasattr(response, "choices") and response.choices:
        message = response.choices[0].message
        content = getattr(message, "content", None)
        # Guard against None so we never return the literal string "None".
        if content is not None:
            return _truncate_data(str(content))
    return None
def _extract_tool_calls_from_response(response: Any) -> List[Dict[str, Any]]:
"""Extract tool calls from OpenAI response."""
tool_calls = []
if hasattr(response, "choices") and response.choices:
message = response.choices[0].message
if hasattr(message, "tool_calls") and message.tool_calls:
for tc in message.tool_calls:
func = getattr(tc, "function", None)
name = getattr(func, "name", "unknown") if func else "unknown"
args_str = getattr(func, "arguments", "{}") if func else "{}"
call_dict = {"name": name}
try:
call_dict["arguments"] = json.loads(args_str)
except (json.JSONDecodeError, TypeError):
call_dict["arguments"] = args_str
tool_calls.append(call_dict)
return tool_calls
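# Illustrative output for a hypothetical response with one tool call:
#   _extract_tool_calls_from_response(resp)
#       -> [{"name": "get_weather", "arguments": {"city": "Paris"}}]
# If the arguments string is not valid JSON it is kept verbatim:
#       -> [{"name": "get_weather", "arguments": "{broken"}]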
class _StreamWrapper:
"""Wrapper for OpenAI stream responses to collect chunks."""
    def __init__(self, original_stream: Any, trace_ctx: Optional[TraceContext]):
        self._original_stream = original_stream
        self._trace_ctx = trace_ctx
        self._chunks: List[Any] = []
        self._start_time = time.time()
        self._model = None
        self._temperature = None
        self._max_tokens = None
        self._messages = None
        self._parent_span_id = get_current_span_id()
        self._finalized = False
def set_params(
self,
model: str,
temperature: Optional[float],
max_tokens: Optional[int],
messages: List[Any],
) -> None:
self._model = model
self._temperature = temperature
self._max_tokens = max_tokens
self._messages = messages
def __iter__(self) -> Iterator[Any]:
return self
    def __next__(self) -> Any:
        try:
            chunk = next(self._original_stream)
        except StopIteration:
            # Stream exhausted: record the span before propagating.
            self.finalize()
            raise
        self._chunks.append(chunk)
        return chunk
    def finalize(self) -> None:
        """Create the span once the stream is fully consumed (idempotent)."""
        if self._finalized:
            return
        self._finalized = True
        if not self._chunks:
            # Nothing was streamed; still close any trace context we opened.
            if self._trace_ctx:
                self._trace_ctx.__exit__(None, None, None)
            return
# Build a mock response from chunks
response = _MockResponse()
# Extract model from first chunk
if hasattr(self._chunks[0], "model"):
response.model = self._chunks[0].model
else:
response.model = self._model or "unknown"
# Extract message content and tool calls
message_content = None
tool_calls = []
for chunk in self._chunks:
if hasattr(chunk, "choices") and chunk.choices:
delta = chunk.choices[0].delta
if hasattr(delta, "content") and delta.content:
if message_content is None:
message_content = ""
message_content += delta.content
if hasattr(delta, "tool_calls") and delta.tool_calls:
for tc in delta.tool_calls:
if hasattr(tc, "function"):
tc_idx = tc.index if hasattr(tc, "index") else 0
while len(tool_calls) <= tc_idx:
tool_calls.append({"name": None, "arguments": ""})
if hasattr(tc.function, "name") and tc.function.name:
tool_calls[tc_idx]["name"] = tc.function.name
if (
hasattr(tc.function, "arguments")
and tc.function.arguments
):
tool_calls[tc_idx]["arguments"] += tc.function.arguments
response.choices = [_MockChoice(message_content, [])]
        # Convert accumulated tool call fragments into mock objects
        tool_call_objects = []
        for tc in tool_calls:
            if tc.get("name"):
                args = tc.get("arguments", "{}")
                try:
                    args_dict = json.loads(args) if isinstance(args, str) else args
                    tool_call_objects.append(
                        _MockToolCall(tc["name"], json.dumps(args_dict))
                    )
                except json.JSONDecodeError:
                    # Malformed streamed JSON: keep the raw argument string
                    # instead of silently dropping the tool call.
                    tool_call_objects.append(_MockToolCall(tc["name"], args))
        response.choices[0].message.tool_calls = tool_call_objects
# Build mock usage (can't get exact tokens from stream)
response.usage.prompt_tokens = None
response.usage.completion_tokens = None
response.usage.total_tokens = None
# Create the span
_create_llm_span(
response=response,
start_time=self._start_time,
model=self._model or response.model,
temperature=self._temperature,
max_tokens=self._max_tokens,
messages=self._messages or [],
parent_span_id=self._parent_span_id,
trace_ctx=self._trace_ctx,
)
# Close trace context if we created one
if self._trace_ctx:
self._trace_ctx.__exit__(None, None, None)
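# Sketch of how the wrapper is consumed (illustrative, module-internal):
#   wrapped = _StreamWrapper(raw_stream, trace_ctx=None)
#   wrapped.set_params("gpt-4o", temperature=0.2, max_tokens=None, messages=[])
#   for chunk in wrapped:  # chunks pass through unchanged
#       ...
#   # The StopIteration from the underlying stream triggers finalize(), which
#   # rebuilds a _MockResponse from the buffered chunks and records the span.
#   # Token counts (and therefore cost) are unavailable for streamed calls.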
def _create_llm_span(
response: Any,
start_time: float,
model: str,
temperature: Optional[float],
max_tokens: Optional[int],
messages: List[Any],
parent_span_id: Optional[str],
trace_ctx: Optional[TraceContext],
) -> None:
"""Create LLM span from OpenAI response."""
from agentlens.models import Span, SpanStatus, SpanType
current_trace = get_current_trace()
if current_trace is None:
logger.warning("No active trace, skipping span creation")
return
end_time = time.time()
duration_ms = int((end_time - start_time) * 1000)
# Extract token usage
token_count = None
cost_usd = None
if hasattr(response, "usage"):
prompt_tokens = getattr(response.usage, "prompt_tokens", None)
completion_tokens = getattr(response.usage, "completion_tokens", None)
total_tokens = getattr(response.usage, "total_tokens", None)
if total_tokens is not None:
token_count = total_tokens
if prompt_tokens is not None and completion_tokens is not None:
cost_usd = _calculate_cost(model, prompt_tokens, completion_tokens)
# Create span
span_name = f"openai.{model}"
span = Span(
name=span_name,
type=SpanType.LLM_CALL.value,
parent_span_id=parent_span_id,
input_data={"messages": _extract_messages_truncated(messages)},
output_data={"content": _extract_content_from_response(response)},
token_count=token_count,
cost_usd=cost_usd,
duration_ms=duration_ms,
status=SpanStatus.COMPLETED.value,
started_at=_now_iso(),
ended_at=_now_iso(),
metadata={
"model": model,
"temperature": temperature,
"max_tokens": max_tokens,
},
)
current_trace.spans.append(span)
# Push onto context stack for decision logging
stack = _get_context_stack()
stack.append(span)
    # Log tool call decisions; the finally block guarantees the span is
    # popped from the context stack even if decision logging fails.
    try:
        tool_calls = _extract_tool_calls_from_response(response)
        if tool_calls:
            from agentlens.decision import log_decision
            # Try to get reasoning from the most recent assistant message
            reasoning = None
            for msg in reversed(messages):
                if isinstance(msg, dict):
                    if msg.get("role") == "assistant":
                        reasoning = msg.get("content")
                        if reasoning:
                            reasoning = _truncate_data(str(reasoning))
                        break
                else:
                    role = getattr(msg, "role", None)
                    if role == "assistant":
                        reasoning = getattr(msg, "content", None)
                        if reasoning:
                            reasoning = _truncate_data(str(reasoning))
                        break
            # Build context snapshot
            context_snapshot = None
            if hasattr(response, "usage"):
                context_snapshot = {
                    "model": model,
                    "prompt_tokens": getattr(response.usage, "prompt_tokens", None),
                    "completion_tokens": getattr(
                        response.usage, "completion_tokens", None
                    ),
                }
            for tool_call in tool_calls:
                log_decision(
                    type="TOOL_SELECTION",
                    chosen={
                        "name": tool_call.get("name", "unknown"),
                        "arguments": tool_call.get("arguments", {}),
                    },
                    alternatives=[],
                    reasoning=reasoning,
                    context_snapshot=context_snapshot,
                )
    finally:
        if stack and isinstance(stack[-1], Span) and stack[-1].id == span.id:
            stack.pop()
def _wrap_create(original_create: Any, is_async: bool = False) -> Any:
"""Wrap OpenAI chat.completions.create method."""
if is_async:
@wraps(original_create)
async def async_traced_create(*args: Any, **kwargs: Any) -> Any:
# Extract parameters
model = kwargs.get("model", "gpt-3.5-turbo")
temperature = kwargs.get("temperature")
max_tokens = kwargs.get("max_tokens")
messages = kwargs.get("messages", [])
stream = kwargs.get("stream", False)
parent_span_id = get_current_span_id()
start_time = time.time()
# Handle streaming
if stream:
# Create trace if needed
trace_ctx = None
if get_current_trace() is None:
trace_ctx = TraceContext(name=f"openai-{model}")
trace_ctx.__enter__()
try:
original_stream = await original_create(*args, **kwargs)
wrapper = _StreamWrapper(original_stream, trace_ctx)
wrapper.set_params(model, temperature, max_tokens, messages)
return wrapper
except Exception as e:
if trace_ctx:
trace_ctx.__exit__(type(e), e, None)
raise
# Non-streaming
trace_ctx = None
if get_current_trace() is None:
trace_ctx = TraceContext(name=f"openai-{model}")
trace_ctx.__enter__()
try:
response = await original_create(*args, **kwargs)
_create_llm_span(
response=response,
start_time=start_time,
model=model,
temperature=temperature,
max_tokens=max_tokens,
messages=messages,
parent_span_id=parent_span_id,
trace_ctx=trace_ctx,
)
# Close trace context if we created one
if trace_ctx is not None:
trace_ctx.__exit__(None, None, None)
return response
except Exception as e:
_handle_error(
error=e,
start_time=start_time,
model=model,
temperature=temperature,
max_tokens=max_tokens,
messages=messages,
parent_span_id=parent_span_id,
trace_ctx=trace_ctx,
)
raise
return async_traced_create
else:
@wraps(original_create)
def traced_create(*args: Any, **kwargs: Any) -> Any:
# Extract parameters
model = kwargs.get("model", "gpt-3.5-turbo")
temperature = kwargs.get("temperature")
max_tokens = kwargs.get("max_tokens")
messages = kwargs.get("messages", [])
stream = kwargs.get("stream", False)
parent_span_id = get_current_span_id()
start_time = time.time()
# Handle streaming
if stream:
# Create trace if needed
trace_ctx = None
if get_current_trace() is None:
trace_ctx = TraceContext(name=f"openai-{model}")
trace_ctx.__enter__()
try:
original_stream = original_create(*args, **kwargs)
wrapper = _StreamWrapper(original_stream, trace_ctx)
wrapper.set_params(model, temperature, max_tokens, messages)
return wrapper
except Exception as e:
if trace_ctx:
trace_ctx.__exit__(type(e), e, None)
raise
# Non-streaming
trace_ctx = None
if get_current_trace() is None:
trace_ctx = TraceContext(name=f"openai-{model}")
trace_ctx.__enter__()
try:
response = original_create(*args, **kwargs)
_create_llm_span(
response=response,
start_time=start_time,
model=model,
temperature=temperature,
max_tokens=max_tokens,
messages=messages,
parent_span_id=parent_span_id,
trace_ctx=trace_ctx,
)
# Close trace context if we created one
if trace_ctx is not None:
trace_ctx.__exit__(None, None, None)
return response
except Exception as e:
_handle_error(
error=e,
start_time=start_time,
model=model,
temperature=temperature,
max_tokens=max_tokens,
messages=messages,
parent_span_id=parent_span_id,
trace_ctx=trace_ctx,
)
raise
return traced_create
def _handle_error(
error: Exception,
start_time: float,
model: str,
temperature: Optional[float],
max_tokens: Optional[int],
messages: List[Any],
parent_span_id: Optional[str],
trace_ctx: Optional[TraceContext],
) -> None:
"""Handle error by creating error span and event."""
from agentlens.models import Span, SpanStatus, SpanType
current_trace = get_current_trace()
if current_trace is None:
return
end_time = time.time()
duration_ms = int((end_time - start_time) * 1000)
# Create error span
span_name = f"openai.{model}"
span = Span(
name=span_name,
type=SpanType.LLM_CALL.value,
parent_span_id=parent_span_id,
input_data={"messages": _extract_messages_truncated(messages)},
status=SpanStatus.ERROR.value,
status_message=str(error),
started_at=_now_iso(),
ended_at=_now_iso(),
duration_ms=duration_ms,
metadata={
"model": model,
"temperature": temperature,
"max_tokens": max_tokens,
},
)
current_trace.spans.append(span)
# Create error event
error_event = Event(
type=EventType.ERROR.value,
name=f"{span_name}: {str(error)}",
span_id=span.id,
metadata={"error_type": type(error).__name__},
)
current_trace.events.append(error_event)
# Pop from context stack if needed
stack = _get_context_stack()
if stack and isinstance(stack[-1], Span) and stack[-1].id == span.id:
stack.pop()
def wrap_openai(client: Any) -> Any:
    """Wrap an OpenAI client to add AgentLens tracing.

    Args:
        client: The OpenAI client to wrap.

    Returns:
        The same client instance with chat.completions.create wrapped.

    Example:
        import openai
        from agentlens.integrations.openai import wrap_openai

        client = openai.OpenAI(api_key="sk-...")
        traced_client = wrap_openai(client)
        response = traced_client.chat.completions.create(
            model="gpt-4",
            messages=[{"role": "user", "content": "Hello!"}],
        )
    """
original_create = client.chat.completions.create
# Wrap synchronous method
traced_create = _wrap_create(original_create, is_async=False)
client.chat.completions.create = traced_create
# Try to wrap async method if available
if hasattr(client.chat.completions, "acreate"):
original_acreate = client.chat.completions.acreate
traced_acreate = _wrap_create(original_acreate, is_async=True)
client.chat.completions.acreate = traced_acreate
logger.debug("OpenAI client wrapped with AgentLens tracing")
return client
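# Streaming usage sketch (illustrative; mirrors the docstring example):
#
#   client = wrap_openai(openai.OpenAI(api_key="sk-..."))
#   stream = client.chat.completions.create(
#       model="gpt-4o-mini",
#       messages=[{"role": "user", "content": "Hello!"}],
#       stream=True,
#   )
#   for chunk in stream:
#       ...  # chunks are forwarded untouched
#   # Once the stream is exhausted, AgentLens records one LLM span for the
#   # whole call; if no trace was active, the wrapper opened one named
#   # "openai-gpt-4o-mini" and closes it at that point.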