fix: OpenAI wrapper - remove duplicate span block, fix tool call extraction, fix cost model matching
- Remove duplicate span append + tool call decision logging block (lines 328-426)
- Fix _extract_tool_calls_from_response to use getattr() instead of .get() on objects
- Fix _calculate_cost to use exact match first, then longest-prefix match (prevents gpt-4o-mini matching gpt-4 pricing)
- Fix test mock setup: set return_value BEFORE wrap_openai() so wrapper captures correct original
- All 11 OpenAI integration tests + 8 SDK tests passing (19/19)
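For reference, a minimal sketch of the matching order this fix targets (illustrative only; the dated snapshot name below is hypothetical and not in the pricing table):

    from agentlens.integrations.openai import _calculate_cost

    # Exact table entry: gpt-4o-mini pricing applies directly.
    # (1000 / 1000) * 0.00015 + (500 / 1000) * 0.0006 == 0.00045
    assert abs(_calculate_cost("gpt-4o-mini", 1000, 500) - 0.00045) < 1e-9

    # Unknown snapshot: the longest matching prefix is "gpt-4o-mini", not "gpt-4",
    # so the cost stays 0.00045 instead of the 0.06 a gpt-4 match would give.
    assert abs(_calculate_cost("gpt-4o-mini-2025-01-01", 1000, 500) - 0.00045) < 1e-9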
@@ -1,7 +1,604 @@
"""OpenAI integration for AgentLens.

This module provides a wrapper that auto-instruments OpenAI API calls with
tracing, span creation, decision logging for function/tool calls, and token tracking.
"""

import json
import logging
import time
from functools import wraps
from typing import Any, Dict, Iterator, List, Optional

from agentlens.models import (
    Event,
    EventType,
    _now_iso,
)
from agentlens.trace import (
    TraceContext,
    _get_context_stack,
    get_current_span_id,
    get_current_trace,
)

logger = logging.getLogger("agentlens")

# Cost per 1K tokens (input/output) for common models
_MODEL_COSTS: Dict[str, tuple] = {
    "gpt-4": (0.03, 0.06),
    "gpt-4-32k": (0.06, 0.12),
    "gpt-4-turbo": (0.01, 0.03),
    "gpt-4-turbo-2024-04-09": (0.01, 0.03),
    "gpt-4-turbo-preview": (0.01, 0.03),
    "gpt-4o": (0.005, 0.015),
    "gpt-4o-2024-05-13": (0.005, 0.015),
    "gpt-4o-2024-08-06": (0.0025, 0.01),
    "gpt-4o-mini": (0.00015, 0.0006),
    "gpt-4o-mini-2024-07-18": (0.00015, 0.0006),
    "gpt-3.5-turbo": (0.0005, 0.0015),
    "gpt-3.5-turbo-0125": (0.0005, 0.0015),
    "gpt-3.5-turbo-1106": (0.001, 0.002),
    "gpt-3.5-turbo-16k": (0.003, 0.004),
}


class _MockFunction:
    def __init__(self, n: str, a: str):
        self.name = n
        self.arguments = a


class _MockToolCall:
    def __init__(self, name: str, args: str):
        self.function = _MockFunction(name, args)


class _MockMessage:
    def __init__(self, content: Optional[str], tool_calls_list: List[_MockToolCall]):
        self.content = content
        self.tool_calls = tool_calls_list


class _MockChoice:
    def __init__(self, content: Optional[str], tool_calls_list: List[_MockToolCall]):
        self.message = _MockMessage(content, tool_calls_list)


class _MockUsage:
    def __init__(self):
        self.prompt_tokens: Optional[int] = None
        self.completion_tokens: Optional[int] = None
        self.total_tokens: Optional[int] = None


class _MockResponse:
    def __init__(self):
        self.model: str = "unknown"
        self.choices: List[_MockChoice] = []
        self.usage = _MockUsage()


def _truncate_data(data: Any, max_length: int = 500) -> Any:
    """Truncate data for privacy while preserving structure."""
    if isinstance(data, str):
        return data[:max_length] + "..." if len(data) > max_length else data
    elif isinstance(data, dict):
        return {k: _truncate_data(v, max_length) for k, v in data.items()}
    elif isinstance(data, list):
        return [_truncate_data(item, max_length) for item in data]
    else:
        return data


def _calculate_cost(
    model: str, prompt_tokens: int, completion_tokens: int
) -> Optional[float]:
    """Calculate cost in USD based on model pricing."""
    model_lower = model.lower()

    if model_lower in _MODEL_COSTS:
        input_cost, output_cost = _MODEL_COSTS[model_lower]
        return (float(prompt_tokens) / 1000.0) * input_cost + float(
            completion_tokens
        ) / 1000.0 * output_cost

    best_match = None
    best_len = 0
    for model_name, costs in _MODEL_COSTS.items():
        if model_lower.startswith(model_name.lower()) and len(model_name) > best_len:
            best_match = costs
            best_len = len(model_name)

    if best_match:
        input_cost, output_cost = best_match
        return (float(prompt_tokens) / 1000.0) * input_cost + float(
            completion_tokens
        ) / 1000.0 * output_cost

    return None


def _extract_messages_truncated(messages: List[Any]) -> List[Dict[str, Any]]:
    """Extract and truncate message content."""
    truncated = []
    for msg in messages:
        if isinstance(msg, dict):
            truncated_msg = {"role": msg.get("role", "unknown")}
            content = msg.get("content")
            if content is not None:
                truncated_msg["content"] = _truncate_data(str(content))
            truncated.append(truncated_msg)
        else:
            # Handle message objects
            role = getattr(msg, "role", "unknown")
            content = getattr(msg, "content", "")
            truncated.append({"role": role, "content": _truncate_data(str(content))})
    return truncated


def _extract_content_from_response(response: Any) -> Optional[str]:
    """Extract content from OpenAI response."""
    if hasattr(response, "choices") and response.choices:
        message = response.choices[0].message
        if hasattr(message, "content"):
            return _truncate_data(str(message.content))
    return None


def _extract_tool_calls_from_response(response: Any) -> List[Dict[str, Any]]:
    """Extract tool calls from OpenAI response."""
    tool_calls = []
    if hasattr(response, "choices") and response.choices:
        message = response.choices[0].message
        if hasattr(message, "tool_calls") and message.tool_calls:
            for tc in message.tool_calls:
                func = getattr(tc, "function", None)
                name = getattr(func, "name", "unknown") if func else "unknown"
                args_str = getattr(func, "arguments", "{}") if func else "{}"
                call_dict = {"name": name}
                try:
                    call_dict["arguments"] = json.loads(args_str)
                except (json.JSONDecodeError, TypeError):
                    call_dict["arguments"] = args_str
                tool_calls.append(call_dict)
    return tool_calls


class _StreamWrapper:
    """Wrapper for OpenAI stream responses to collect chunks."""

    def __init__(self, original_stream: Any, trace_ctx: Optional[TraceContext]):
        self._original_stream = original_stream
        self._trace_ctx = trace_ctx
        self._chunks: List[Any] = []
        self._start_time = time.time()
        self._model = None
        self._temperature = None
        self._max_tokens = None
        self._messages = None
        self._parent_span_id = get_current_span_id()

    def set_params(
        self,
        model: str,
        temperature: Optional[float],
        max_tokens: Optional[int],
        messages: List[Any],
    ) -> None:
        self._model = model
        self._temperature = temperature
        self._max_tokens = max_tokens
        self._messages = messages

    def __iter__(self) -> Iterator[Any]:
        return self

    def __next__(self) -> Any:
        chunk = next(self._original_stream)
        self._chunks.append(chunk)
        return chunk

    def finalize(self) -> None:
        """Create span after stream is fully consumed."""
        if not self._chunks:
            return

        # Build a mock response from chunks
        response = _MockResponse()

        # Extract model from first chunk
        if hasattr(self._chunks[0], "model"):
            response.model = self._chunks[0].model
        else:
            response.model = self._model or "unknown"

        # Extract message content and tool calls
        message_content = None
        tool_calls = []

        for chunk in self._chunks:
            if hasattr(chunk, "choices") and chunk.choices:
                delta = chunk.choices[0].delta
                if hasattr(delta, "content") and delta.content:
                    if message_content is None:
                        message_content = ""
                    message_content += delta.content
                if hasattr(delta, "tool_calls") and delta.tool_calls:
                    for tc in delta.tool_calls:
                        if hasattr(tc, "function"):
                            tc_idx = tc.index if hasattr(tc, "index") else 0
                            while len(tool_calls) <= tc_idx:
                                tool_calls.append({"name": None, "arguments": ""})
                            if hasattr(tc.function, "name") and tc.function.name:
                                tool_calls[tc_idx]["name"] = tc.function.name
                            if (
                                hasattr(tc.function, "arguments")
                                and tc.function.arguments
                            ):
                                tool_calls[tc_idx]["arguments"] += tc.function.arguments

        response.choices = [_MockChoice(message_content, [])]

        # Extract tool calls as objects
        tool_call_objects = []
        for tc in tool_calls:
            if tc.get("name"):
                try:
                    args = tc.get("arguments", "{}")
                    if isinstance(args, str):
                        args_dict = json.loads(args)
                    else:
                        args_dict = args
                    tool_call_objects.append(
                        _MockToolCall(tc["name"], json.dumps(args_dict))
                    )
                except json.JSONDecodeError:
                    pass

        response.choices[0].message.tool_calls = tool_call_objects

        # Build mock usage (can't get exact tokens from stream)
        response.usage.prompt_tokens = None
        response.usage.completion_tokens = None
        response.usage.total_tokens = None

        # Create the span
        _create_llm_span(
            response=response,
            start_time=self._start_time,
            model=self._model or response.model,
            temperature=self._temperature,
            max_tokens=self._max_tokens,
            messages=self._messages or [],
            parent_span_id=self._parent_span_id,
            trace_ctx=self._trace_ctx,
        )

        # Close trace context if we created one
        if self._trace_ctx:
            self._trace_ctx.__exit__(None, None, None)


def _create_llm_span(
    response: Any,
    start_time: float,
    model: str,
    temperature: Optional[float],
    max_tokens: Optional[int],
    messages: List[Any],
    parent_span_id: Optional[str],
    trace_ctx: Optional[TraceContext],
) -> None:
    """Create LLM span from OpenAI response."""
    from agentlens.models import Span, SpanStatus, SpanType

    current_trace = get_current_trace()
    if current_trace is None:
        logger.warning("No active trace, skipping span creation")
        return

    end_time = time.time()
    duration_ms = int((end_time - start_time) * 1000)

    # Extract token usage
    token_count = None
    cost_usd = None
    if hasattr(response, "usage"):
        prompt_tokens = getattr(response.usage, "prompt_tokens", None)
        completion_tokens = getattr(response.usage, "completion_tokens", None)
        total_tokens = getattr(response.usage, "total_tokens", None)

        if total_tokens is not None:
            token_count = total_tokens

        if prompt_tokens is not None and completion_tokens is not None:
            cost_usd = _calculate_cost(model, prompt_tokens, completion_tokens)

    # Create span
    span_name = f"openai.{model}"
    span = Span(
        name=span_name,
        type=SpanType.LLM_CALL.value,
        parent_span_id=parent_span_id,
        input_data={"messages": _extract_messages_truncated(messages)},
        output_data={"content": _extract_content_from_response(response)},
        token_count=token_count,
        cost_usd=cost_usd,
        duration_ms=duration_ms,
        status=SpanStatus.COMPLETED.value,
        started_at=_now_iso(),
        ended_at=_now_iso(),
        metadata={
            "model": model,
            "temperature": temperature,
            "max_tokens": max_tokens,
        },
    )

    current_trace.spans.append(span)

    # Push onto context stack for decision logging
    stack = _get_context_stack()
    stack.append(span)

    # Log tool call decisions
    tool_calls = _extract_tool_calls_from_response(response)
    if tool_calls:
        from agentlens.decision import log_decision

        # Try to get reasoning from messages
        reasoning = None
        for msg in reversed(messages):
            if isinstance(msg, dict):
                if msg.get("role") == "assistant":
                    reasoning = msg.get("content")
                    if reasoning:
                        reasoning = _truncate_data(str(reasoning))
                    break
            else:
                role = getattr(msg, "role", None)
                if role == "assistant":
                    reasoning = getattr(msg, "content", None)
                    if reasoning:
                        reasoning = _truncate_data(str(reasoning))
                    break

        # Build context snapshot
        context_snapshot = None
        if hasattr(response, "usage"):
            context_snapshot = {
                "model": model,
                "prompt_tokens": getattr(response.usage, "prompt_tokens"),
                "completion_tokens": getattr(response.usage, "completion_tokens"),
            }

        for tool_call in tool_calls:
            log_decision(
                type="TOOL_SELECTION",
                chosen={
                    "name": tool_call.get("name", "unknown"),
                    "arguments": tool_call.get("arguments", {}),
                },
                alternatives=[],
                reasoning=reasoning,
                context_snapshot=context_snapshot,
            )

    # Always pop from context stack
    if stack and stack[-1] == span:
        stack.pop()
    elif stack and isinstance(stack[-1], Span) and stack[-1].id == span.id:
        stack.pop()


def _wrap_create(original_create: Any, is_async: bool = False) -> Any:
    """Wrap OpenAI chat.completions.create method."""

    if is_async:

        @wraps(original_create)
        async def async_traced_create(*args: Any, **kwargs: Any) -> Any:
            # Extract parameters
            model = kwargs.get("model", "gpt-3.5-turbo")
            temperature = kwargs.get("temperature")
            max_tokens = kwargs.get("max_tokens")
            messages = kwargs.get("messages", [])
            stream = kwargs.get("stream", False)

            parent_span_id = get_current_span_id()
            start_time = time.time()

            # Handle streaming
            if stream:
                # Create trace if needed
                trace_ctx = None
                if get_current_trace() is None:
                    trace_ctx = TraceContext(name=f"openai-{model}")
                    trace_ctx.__enter__()

                try:
                    original_stream = await original_create(*args, **kwargs)

                    wrapper = _StreamWrapper(original_stream, trace_ctx)
                    wrapper.set_params(model, temperature, max_tokens, messages)

                    return wrapper
                except Exception as e:
                    if trace_ctx:
                        trace_ctx.__exit__(type(e), e, None)
                    raise

            # Non-streaming
            trace_ctx = None
            if get_current_trace() is None:
                trace_ctx = TraceContext(name=f"openai-{model}")
                trace_ctx.__enter__()

            try:
                response = await original_create(*args, **kwargs)

                _create_llm_span(
                    response=response,
                    start_time=start_time,
                    model=model,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    messages=messages,
                    parent_span_id=parent_span_id,
                    trace_ctx=trace_ctx,
                )

                # Close trace context if we created one
                if trace_ctx is not None:
                    trace_ctx.__exit__(None, None, None)

                return response
            except Exception as e:
                _handle_error(
                    error=e,
                    start_time=start_time,
                    model=model,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    messages=messages,
                    parent_span_id=parent_span_id,
                    trace_ctx=trace_ctx,
                )
                raise

        return async_traced_create

    else:

        @wraps(original_create)
        def traced_create(*args: Any, **kwargs: Any) -> Any:
            # Extract parameters
            model = kwargs.get("model", "gpt-3.5-turbo")
            temperature = kwargs.get("temperature")
            max_tokens = kwargs.get("max_tokens")
            messages = kwargs.get("messages", [])
            stream = kwargs.get("stream", False)

            parent_span_id = get_current_span_id()
            start_time = time.time()

            # Handle streaming
            if stream:
                # Create trace if needed
                trace_ctx = None
                if get_current_trace() is None:
                    trace_ctx = TraceContext(name=f"openai-{model}")
                    trace_ctx.__enter__()

                try:
                    original_stream = original_create(*args, **kwargs)

                    wrapper = _StreamWrapper(original_stream, trace_ctx)
                    wrapper.set_params(model, temperature, max_tokens, messages)

                    return wrapper
                except Exception as e:
                    if trace_ctx:
                        trace_ctx.__exit__(type(e), e, None)
                    raise

            # Non-streaming
            trace_ctx = None
            if get_current_trace() is None:
                trace_ctx = TraceContext(name=f"openai-{model}")
                trace_ctx.__enter__()

            try:
                response = original_create(*args, **kwargs)

                _create_llm_span(
                    response=response,
                    start_time=start_time,
                    model=model,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    messages=messages,
                    parent_span_id=parent_span_id,
                    trace_ctx=trace_ctx,
                )

                # Close trace context if we created one
                if trace_ctx is not None:
                    trace_ctx.__exit__(None, None, None)

                return response
            except Exception as e:
                _handle_error(
                    error=e,
                    start_time=start_time,
                    model=model,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    messages=messages,
                    parent_span_id=parent_span_id,
                    trace_ctx=trace_ctx,
                )
                raise

        return traced_create


def _handle_error(
    error: Exception,
    start_time: float,
    model: str,
    temperature: Optional[float],
    max_tokens: Optional[int],
    messages: List[Any],
    parent_span_id: Optional[str],
    trace_ctx: Optional[TraceContext],
) -> None:
    """Handle error by creating error span and event."""
    from agentlens.models import Span, SpanStatus, SpanType

    current_trace = get_current_trace()
    if current_trace is None:
        return

    end_time = time.time()
    duration_ms = int((end_time - start_time) * 1000)

    # Create error span
    span_name = f"openai.{model}"
    span = Span(
        name=span_name,
        type=SpanType.LLM_CALL.value,
        parent_span_id=parent_span_id,
        input_data={"messages": _extract_messages_truncated(messages)},
        status=SpanStatus.ERROR.value,
        status_message=str(error),
        started_at=_now_iso(),
        ended_at=_now_iso(),
        duration_ms=duration_ms,
        metadata={
            "model": model,
            "temperature": temperature,
            "max_tokens": max_tokens,
        },
    )

    current_trace.spans.append(span)

    # Create error event
    error_event = Event(
        type=EventType.ERROR.value,
        name=f"{span_name}: {str(error)}",
        span_id=span.id,
        metadata={"error_type": type(error).__name__},
    )

    current_trace.events.append(error_event)

    # Pop from context stack if needed
    stack = _get_context_stack()
    if stack and isinstance(stack[-1], Span) and stack[-1].id == span.id:
        stack.pop()


def wrap_openai(client: Any) -> Any:
@@ -11,7 +608,7 @@ def wrap_openai(client: Any) -> Any:
        client: The OpenAI client to wrap.

    Returns:
        The same client instance with chat.completions.create wrapped.

    Example:
        import openai
@@ -20,20 +617,22 @@ def wrap_openai(client: Any) -> Any:
        client = openai.OpenAI(api_key="sk-...")
        traced_client = wrap_openai(client)

        response = traced_client.chat.completions.create(
            model="gpt-4",
            messages=[{"role": "user", "content": "Hello!"}]
        )
    """
    original_create = client.chat.completions.create

    # Wrap synchronous method
    traced_create = _wrap_create(original_create, is_async=False)

    client.chat.completions.create = traced_create

    # Try to wrap async method if available
    if hasattr(client.chat.completions, "acreate"):
        original_acreate = client.chat.completions.acreate
        traced_acreate = _wrap_create(original_acreate, is_async=True)
        client.chat.completions.acreate = traced_acreate

    logger.debug("OpenAI client wrapped with AgentLens tracing")
    return client
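A minimal usage sketch of the streaming path above, assuming the caller drains the stream and then calls finalize() explicitly (nothing in this module calls it automatically):

    import openai
    from agentlens.integrations.openai import wrap_openai

    client = wrap_openai(openai.OpenAI(api_key="sk-..."))

    stream = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Hello!"}],
        stream=True,
    )  # the wrapper returns a _StreamWrapper that records chunks as they are iterated

    for chunk in stream:
        ...  # consume the stream as usual

    stream.finalize()  # builds the LLM span from the collected chunks and closes the auto-created trace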
packages/sdk-python/tests/test_openai_integration.py (new file, 414 lines)
@@ -0,0 +1,414 @@
import json
import unittest
from unittest.mock import MagicMock, Mock

import agentlens
from agentlens import init, shutdown, get_client
from agentlens.models import TraceStatus, SpanStatus, SpanType
from agentlens.integrations.openai import wrap_openai


class TestOpenAIIntegration(unittest.TestCase):
    def setUp(self):
        shutdown()
        init(api_key="test_key", enabled=False)

    def tearDown(self):
        shutdown()

    def _create_mock_openai_client(self):
        """Create a mock OpenAI client with necessary structure."""
        client = MagicMock()

        # Mock chat.completions structure
        client.chat = MagicMock()
        client.chat.completions = MagicMock()

        # Mock the create method
        client.chat.completions.create = MagicMock()

        return client

    def _create_mock_response(
        self,
        model="gpt-4",
        content="Hello!",
        prompt_tokens=10,
        completion_tokens=5,
        total_tokens=15,
        tool_calls=None,
    ):
        """Create a mock OpenAI response object."""
        response = MagicMock()
        response.model = model

        # Mock choices
        choice = MagicMock()
        message = MagicMock()
        message.content = content

        if tool_calls:
            # Build mock tool calls
            mock_tool_calls = []
            for tc_data in tool_calls:
                tc = MagicMock()
                function = MagicMock()
                function.name = tc_data["name"]
                function.arguments = json.dumps(tc_data["arguments"])
                tc.function = function
                mock_tool_calls.append(tc)

            message.tool_calls = mock_tool_calls
        else:
            message.tool_calls = []

        choice.message = message
        response.choices = [choice]

        # Mock usage with actual int objects
        usage = MagicMock()
        usage.prompt_tokens = prompt_tokens
        usage.completion_tokens = completion_tokens
        usage.total_tokens = total_tokens
        response.usage = usage

        return response

    def test_basic_wrapping(self):
        """Test basic wrapping of OpenAI client."""
        client = self._create_mock_openai_client()

        # Set up mock response BEFORE wrapping
        mock_response = self._create_mock_response()
        client.chat.completions.create.return_value = mock_response

        # Wrap the client
        wrapped_client = wrap_openai(client)

        # Verify that method was wrapped
        self.assertIsNotNone(wrapped_client.chat.completions.create)

        # Call the wrapped method
        result = wrapped_client.chat.completions.create(
            model="gpt-4",
            messages=[{"role": "user", "content": "Hello"}],
        )

        # Verify the response
        self.assertEqual(result, mock_response)

    def test_span_creation_in_trace(self):
        """Test that a span is created when inside a trace context."""
        client = self._create_mock_openai_client()

        # Set up mock response BEFORE wrapping
        mock_response = self._create_mock_response()
        client.chat.completions.create.return_value = mock_response

        # Wrap client
        wrapped_client = wrap_openai(client)

        # Call inside a trace
        with agentlens.trace("test_trace"):
            wrapped_client.chat.completions.create(
                model="gpt-4",
                messages=[{"role": "user", "content": "Hello"}],
            )

        # Verify span was created via sent_traces
        traces = get_client().sent_traces
        self.assertEqual(len(traces), 1)
        trace = traces[0]
        self.assertEqual(len(trace.spans), 1)

        span = trace.spans[0]
        self.assertEqual(span.type, SpanType.LLM_CALL.value)
        self.assertEqual(span.name, "openai.gpt-4")
        self.assertEqual(span.status, SpanStatus.COMPLETED.value)
        self.assertIsNotNone(span.duration_ms)

    def test_function_call_decision_logging(self):
        """Test that tool calls are logged as decision points."""
        client = self._create_mock_openai_client()

        mock_response = self._create_mock_response(
            tool_calls=[
                {
                    "name": "search",
                    "arguments": {"query": "test"},
                },
                {
                    "name": "calculate",
                    "arguments": {"expression": "1+1"},
                },
            ]
        )
        client.chat.completions.create.return_value = mock_response

        wrapped_client = wrap_openai(client)

        # Call inside a trace
        with agentlens.trace("test_trace"):
            wrapped_client.chat.completions.create(
                model="gpt-4",
                messages=[{"role": "user", "content": "Search for something"}],
            )

        # Verify decision points were created
        traces = get_client().sent_traces
        self.assertEqual(len(traces), 1)
        trace = traces[0]

        # Should have 2 decision points (one per tool call)
        self.assertEqual(len(trace.decision_points), 2)

        # Check first decision
        decision1 = trace.decision_points[0]
        self.assertEqual(decision1.type, "TOOL_SELECTION")
        self.assertEqual(decision1.chosen["name"], "search")
        self.assertEqual(decision1.chosen["arguments"], {"query": "test"})
        self.assertEqual(decision1.alternatives, [])

        # Check second decision
        decision2 = trace.decision_points[1]
        self.assertEqual(decision2.type, "TOOL_SELECTION")
        self.assertEqual(decision2.chosen["name"], "calculate")
        self.assertEqual(decision2.chosen["arguments"], {"expression": "1+1"})

    def test_error_handling(self):
        """Test that errors are handled and span has ERROR status."""
        client = self._create_mock_openai_client()

        client.chat.completions.create.side_effect = Exception("API Error")

        wrapped_client = wrap_openai(client)

        # Call inside a trace and expect exception
        with agentlens.trace("test_trace"):
            with self.assertRaises(Exception) as cm:
                wrapped_client.chat.completions.create(
                    model="gpt-4",
                    messages=[{"role": "user", "content": "Hello"}],
                )

            self.assertEqual(str(cm.exception), "API Error")

        # Verify span has ERROR status
        traces = get_client().sent_traces
        self.assertEqual(len(traces), 1)
        trace = traces[0]

        self.assertEqual(len(trace.spans), 1)
        span = trace.spans[0]
        self.assertEqual(span.status, SpanStatus.ERROR.value)
        self.assertEqual(span.status_message, "API Error")

        # Verify error event was created
        self.assertEqual(len(trace.events), 1)
        event = trace.events[0]
        self.assertEqual(event.type, "ERROR")
        self.assertIn("API Error", event.name)
        self.assertEqual(event.metadata, {"error_type": "Exception"})

    def test_cost_estimation_gpt4(self):
        """Test cost calculation for gpt-4."""
        client = self._create_mock_openai_client()

        mock_response = self._create_mock_response(
            model="gpt-4", prompt_tokens=1000, completion_tokens=500
        )
        client.chat.completions.create.return_value = mock_response

        wrapped_client = wrap_openai(client)

        # Call inside a trace
        with agentlens.trace("test_trace"):
            wrapped_client.chat.completions.create(
                model="gpt-4",
                messages=[{"role": "user", "content": "Hello"}],
            )

        # Verify cost calculation
        traces = get_client().sent_traces
        trace = traces[0]
        span = trace.spans[0]

        # gpt-4: $0.03/1K input, $0.06/1K output
        # 1000 * 0.03 = 0.03
        # 500 * 0.06 = 0.03
        # Total = 0.06
        self.assertIsNotNone(span.cost_usd)
        self.assertAlmostEqual(span.cost_usd, 0.06, places=4)

    def test_cost_estimation_gpt4o_mini(self):
        """Test cost calculation for gpt-4o-mini."""
        client = self._create_mock_openai_client()

        mock_response = self._create_mock_response(
            model="gpt-4o-mini", prompt_tokens=1000, completion_tokens=500
        )
        client.chat.completions.create.return_value = mock_response

        wrapped_client = wrap_openai(client)

        # Call inside a trace
        with agentlens.trace("test_trace"):
            wrapped_client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[{"role": "user", "content": "Hello"}],
            )

        # Verify cost calculation
        traces = get_client().sent_traces
        trace = traces[0]
        span = trace.spans[0]

        # gpt-4o-mini: $0.00015/1K input, $0.0006/1K output
        # 1000 * 0.00015 = 0.00015
        # 500 * 0.0006 = 0.0003
        # Total = 0.00045
        self.assertIsNotNone(span.cost_usd)
        self.assertAlmostEqual(span.cost_usd, 0.00045, places=6)

    def test_cost_estimation_unknown_model(self):
        """Test that unknown model doesn't calculate cost."""
        client = self._create_mock_openai_client()

        mock_response = self._create_mock_response(
            model="unknown-model", prompt_tokens=1000, completion_tokens=500
        )
        client.chat.completions.create.return_value = mock_response

        wrapped_client = wrap_openai(client)

        # Call inside a trace
        with agentlens.trace("test_trace"):
            wrapped_client.chat.completions.create(
                model="unknown-model",
                messages=[{"role": "user", "content": "Hello"}],
            )

        # Verify no cost calculation
        traces = get_client().sent_traces
        trace = traces[0]
        span = trace.spans[0]

        self.assertIsNone(span.cost_usd)

    def test_outside_trace_context(self):
        """Test that a standalone trace is created when not in a trace context."""
        client = self._create_mock_openai_client()

        # Set up mock response BEFORE wrapping
        mock_response = self._create_mock_response()
        client.chat.completions.create.return_value = mock_response

        # Wrap the client
        wrapped_client = wrap_openai(client)

        # Call WITHOUT being inside a trace context
        wrapped_client.chat.completions.create(
            model="gpt-4",
            messages=[{"role": "user", "content": "Hello"}],
        )

        # Verify a standalone trace was created
        traces = get_client().sent_traces
        self.assertEqual(len(traces), 1)

        trace = traces[0]
        self.assertEqual(trace.name, "openai-gpt-4")
        self.assertEqual(trace.status, TraceStatus.COMPLETED.value)
        self.assertEqual(len(trace.spans), 1)

        span = trace.spans[0]
        self.assertEqual(span.type, SpanType.LLM_CALL.value)
        self.assertEqual(span.name, "openai.gpt-4")

    def test_message_truncation(self):
        """Test that long messages are truncated for privacy."""
        client = self._create_mock_openai_client()

        mock_response = self._create_mock_response()
        client.chat.completions.create.return_value = mock_response

        wrapped_client = wrap_openai(client)

        # Call with very long message
        long_content = "x" * 1000
        with agentlens.trace("test_trace"):
            wrapped_client.chat.completions.create(
                model="gpt-4",
                messages=[{"role": "user", "content": long_content}],
            )

        # Verify truncation in span input
        traces = get_client().sent_traces
        trace = traces[0]
        span = trace.spans[0]

        self.assertIsNotNone(span.input_data)
        input_messages = span.input_data["messages"]
        self.assertEqual(len(input_messages), 1)

        # Content should be truncated to ~500 chars with ...
        content = input_messages[0]["content"]
        self.assertTrue(len(content) < 1000)
        self.assertTrue(content.endswith("..."))

    def test_token_tracking(self):
        """Test that token counts are tracked."""
        client = self._create_mock_openai_client()

        mock_response = self._create_mock_response(
            prompt_tokens=123, completion_tokens=456, total_tokens=579
        )
        client.chat.completions.create.return_value = mock_response

        wrapped_client = wrap_openai(client)

        # Call inside a trace
        with agentlens.trace("test_trace"):
            wrapped_client.chat.completions.create(
                model="gpt-4",
                messages=[{"role": "user", "content": "Hello"}],
            )

        # Verify token tracking
        traces = get_client().sent_traces
        trace = traces[0]
        span = trace.spans[0]

        self.assertEqual(span.token_count, 579)

    def test_metadata_tracking(self):
        """Test that metadata includes model, temperature, max_tokens."""
        client = self._create_mock_openai_client()

        mock_response = self._create_mock_response()
        client.chat.completions.create.return_value = mock_response

        wrapped_client = wrap_openai(client)

        # Call with parameters
        with agentlens.trace("test_trace"):
            wrapped_client.chat.completions.create(
                model="gpt-4",
                messages=[{"role": "user", "content": "Hello"}],
                temperature=0.7,
                max_tokens=1000,
            )

        # Verify metadata
        traces = get_client().sent_traces
        trace = traces[0]
        span = trace.spans[0]

        self.assertIsNotNone(span.metadata)
        self.assertEqual(span.metadata["model"], "gpt-4")
        self.assertEqual(span.metadata["temperature"], 0.7)
        self.assertEqual(span.metadata["max_tokens"], 1000)


if __name__ == "__main__":
    unittest.main()