From 1989366844acd41497265c3bee35f348c904c3b9 Mon Sep 17 00:00:00 2001 From: Vectry Date: Tue, 10 Feb 2026 00:48:48 +0000 Subject: [PATCH] fix: OpenAI wrapper - remove duplicate span block, fix tool call extraction, fix cost model matching - Remove duplicate span append + tool call decision logging block (lines 328-426) - Fix _extract_tool_calls_from_response to use getattr() instead of .get() on objects - Fix _calculate_cost to use exact match first, then longest-prefix match (prevents gpt-4o-mini matching gpt-4 pricing) - Fix test mock setup: set return_value BEFORE wrap_openai() so wrapper captures correct original - All 11 OpenAI integration tests + 8 SDK tests passing (19/19) --- .../agentlens/integrations/openai.py | 629 +++++++++++++++++- .../tests/test_openai_integration.py | 414 ++++++++++++ 2 files changed, 1028 insertions(+), 15 deletions(-) create mode 100644 packages/sdk-python/tests/test_openai_integration.py diff --git a/packages/sdk-python/agentlens/integrations/openai.py b/packages/sdk-python/agentlens/integrations/openai.py index b248bac..c9f184b 100644 --- a/packages/sdk-python/agentlens/integrations/openai.py +++ b/packages/sdk-python/agentlens/integrations/openai.py @@ -1,7 +1,604 @@ -"""OpenAI integration for AgentLens.""" +"""OpenAI integration for AgentLens. -from typing import Any, Optional +This module provides a wrapper that auto-instruments OpenAI API calls with +tracing, span creation, decision logging for function/tool calls, and token tracking. +""" + +import json +import logging +import time from functools import wraps +from typing import Any, Dict, Iterator, List, Optional + +from agentlens.models import ( + Event, + EventType, + _now_iso, +) +from agentlens.trace import ( + TraceContext, + _get_context_stack, + get_current_span_id, + get_current_trace, +) + +logger = logging.getLogger("agentlens") + +# Cost per 1K tokens (input/output) for common models +_MODEL_COSTS: Dict[str, tuple] = { + "gpt-4": (0.03, 0.06), + "gpt-4-32k": (0.06, 0.12), + "gpt-4-turbo": (0.01, 0.03), + "gpt-4-turbo-2024-04-09": (0.01, 0.03), + "gpt-4-turbo-preview": (0.01, 0.03), + "gpt-4o": (0.005, 0.015), + "gpt-4o-2024-05-13": (0.005, 0.015), + "gpt-4o-2024-08-06": (0.0025, 0.01), + "gpt-4o-mini": (0.00015, 0.0006), + "gpt-4o-mini-2024-07-18": (0.00015, 0.0006), + "gpt-3.5-turbo": (0.0005, 0.0015), + "gpt-3.5-turbo-0125": (0.0005, 0.0015), + "gpt-3.5-turbo-1106": (0.001, 0.002), + "gpt-3.5-turbo-16k": (0.003, 0.004), +} + + +class _MockFunction: + def __init__(self, n: str, a: str): + self.name = n + self.arguments = a + + +class _MockToolCall: + def __init__(self, name: str, args: str): + self.function = _MockFunction(name, args) + + +class _MockMessage: + def __init__(self, content: Optional[str], tool_calls_list: List[_MockToolCall]): + self.content = content + self.tool_calls = tool_calls_list + + +class _MockChoice: + def __init__(self, content: Optional[str], tool_calls_list: List[_MockToolCall]): + self.message = _MockMessage(content, tool_calls_list) + + +class _MockUsage: + def __init__(self): + self.prompt_tokens: Optional[int] = None + self.completion_tokens: Optional[int] = None + self.total_tokens: Optional[int] = None + + +class _MockResponse: + def __init__(self): + self.model: str = "unknown" + self.choices: List[_MockChoice] = [] + self.usage = _MockUsage() + + +def _truncate_data(data: Any, max_length: int = 500) -> Any: + """Truncate data for privacy while preserving structure.""" + if isinstance(data, str): + return data[:max_length] + "..." 
if len(data) > max_length else data + elif isinstance(data, dict): + return {k: _truncate_data(v, max_length) for k, v in data.items()} + elif isinstance(data, list): + return [_truncate_data(item, max_length) for item in data] + else: + return data + + +def _calculate_cost( + model: str, prompt_tokens: int, completion_tokens: int +) -> Optional[float]: + """Calculate cost in USD based on model pricing.""" + model_lower = model.lower() + + if model_lower in _MODEL_COSTS: + input_cost, output_cost = _MODEL_COSTS[model_lower] + return (float(prompt_tokens) / 1000.0) * input_cost + float( + completion_tokens + ) / 1000.0 * output_cost + + best_match = None + best_len = 0 + for model_name, costs in _MODEL_COSTS.items(): + if model_lower.startswith(model_name.lower()) and len(model_name) > best_len: + best_match = costs + best_len = len(model_name) + + if best_match: + input_cost, output_cost = best_match + return (float(prompt_tokens) / 1000.0) * input_cost + float( + completion_tokens + ) / 1000.0 * output_cost + + return None + + +def _extract_messages_truncated(messages: List[Any]) -> List[Dict[str, Any]]: + """Extract and truncate message content.""" + truncated = [] + for msg in messages: + if isinstance(msg, dict): + truncated_msg = {"role": msg.get("role", "unknown")} + content = msg.get("content") + if content is not None: + truncated_msg["content"] = _truncate_data(str(content)) + truncated.append(truncated_msg) + else: + # Handle message objects + role = getattr(msg, "role", "unknown") + content = getattr(msg, "content", "") + truncated.append({"role": role, "content": _truncate_data(str(content))}) + return truncated + + +def _extract_content_from_response(response: Any) -> Optional[str]: + """Extract content from OpenAI response.""" + if hasattr(response, "choices") and response.choices: + message = response.choices[0].message + if hasattr(message, "content"): + return _truncate_data(str(message.content)) + return None + + +def _extract_tool_calls_from_response(response: Any) -> List[Dict[str, Any]]: + """Extract tool calls from OpenAI response.""" + tool_calls = [] + if hasattr(response, "choices") and response.choices: + message = response.choices[0].message + if hasattr(message, "tool_calls") and message.tool_calls: + for tc in message.tool_calls: + func = getattr(tc, "function", None) + name = getattr(func, "name", "unknown") if func else "unknown" + args_str = getattr(func, "arguments", "{}") if func else "{}" + call_dict = {"name": name} + try: + call_dict["arguments"] = json.loads(args_str) + except (json.JSONDecodeError, TypeError): + call_dict["arguments"] = args_str + tool_calls.append(call_dict) + return tool_calls + + +class _StreamWrapper: + """Wrapper for OpenAI stream responses to collect chunks.""" + + def __init__(self, original_stream: Any, trace_ctx: Optional[TraceContext]): + self._original_stream = original_stream + self._trace_ctx = trace_ctx + self._chunks: List[Any] = [] + self._start_time = time.time() + self._model = None + self._temperature = None + self._max_tokens = None + self._messages = None + self._parent_span_id = get_current_span_id() + + def set_params( + self, + model: str, + temperature: Optional[float], + max_tokens: Optional[int], + messages: List[Any], + ) -> None: + self._model = model + self._temperature = temperature + self._max_tokens = max_tokens + self._messages = messages + + def __iter__(self) -> Iterator[Any]: + return self + + def __next__(self) -> Any: + chunk = next(self._original_stream) + self._chunks.append(chunk) + return 
chunk + + def finalize(self) -> None: + """Create span after stream is fully consumed.""" + if not self._chunks: + return + + # Build a mock response from chunks + response = _MockResponse() + + # Extract model from first chunk + if hasattr(self._chunks[0], "model"): + response.model = self._chunks[0].model + else: + response.model = self._model or "unknown" + + # Extract message content and tool calls + message_content = None + tool_calls = [] + + for chunk in self._chunks: + if hasattr(chunk, "choices") and chunk.choices: + delta = chunk.choices[0].delta + if hasattr(delta, "content") and delta.content: + if message_content is None: + message_content = "" + message_content += delta.content + if hasattr(delta, "tool_calls") and delta.tool_calls: + for tc in delta.tool_calls: + if hasattr(tc, "function"): + tc_idx = tc.index if hasattr(tc, "index") else 0 + while len(tool_calls) <= tc_idx: + tool_calls.append({"name": None, "arguments": ""}) + if hasattr(tc.function, "name") and tc.function.name: + tool_calls[tc_idx]["name"] = tc.function.name + if ( + hasattr(tc.function, "arguments") + and tc.function.arguments + ): + tool_calls[tc_idx]["arguments"] += tc.function.arguments + + response.choices = [_MockChoice(message_content, [])] + + # Extract tool calls as objects + tool_call_objects = [] + for tc in tool_calls: + if tc.get("name"): + try: + args = tc.get("arguments", "{}") + if isinstance(args, str): + args_dict = json.loads(args) + else: + args_dict = args + tool_call_objects.append( + _MockToolCall(tc["name"], json.dumps(args_dict)) + ) + except json.JSONDecodeError: + pass + + response.choices[0].message.tool_calls = tool_call_objects + + # Build mock usage (can't get exact tokens from stream) + response.usage.prompt_tokens = None + response.usage.completion_tokens = None + response.usage.total_tokens = None + + # Create the span + _create_llm_span( + response=response, + start_time=self._start_time, + model=self._model or response.model, + temperature=self._temperature, + max_tokens=self._max_tokens, + messages=self._messages or [], + parent_span_id=self._parent_span_id, + trace_ctx=self._trace_ctx, + ) + + # Close trace context if we created one + if self._trace_ctx: + self._trace_ctx.__exit__(None, None, None) + + +def _create_llm_span( + response: Any, + start_time: float, + model: str, + temperature: Optional[float], + max_tokens: Optional[int], + messages: List[Any], + parent_span_id: Optional[str], + trace_ctx: Optional[TraceContext], +) -> None: + """Create LLM span from OpenAI response.""" + from agentlens.models import Span, SpanStatus, SpanType + + current_trace = get_current_trace() + if current_trace is None: + logger.warning("No active trace, skipping span creation") + return + + end_time = time.time() + duration_ms = int((end_time - start_time) * 1000) + + # Extract token usage + token_count = None + cost_usd = None + if hasattr(response, "usage"): + prompt_tokens = getattr(response.usage, "prompt_tokens", None) + completion_tokens = getattr(response.usage, "completion_tokens", None) + total_tokens = getattr(response.usage, "total_tokens", None) + + if total_tokens is not None: + token_count = total_tokens + + if prompt_tokens is not None and completion_tokens is not None: + cost_usd = _calculate_cost(model, prompt_tokens, completion_tokens) + + # Create span + span_name = f"openai.{model}" + span = Span( + name=span_name, + type=SpanType.LLM_CALL.value, + parent_span_id=parent_span_id, + input_data={"messages": _extract_messages_truncated(messages)}, + 
output_data={"content": _extract_content_from_response(response)}, + token_count=token_count, + cost_usd=cost_usd, + duration_ms=duration_ms, + status=SpanStatus.COMPLETED.value, + started_at=_now_iso(), + ended_at=_now_iso(), + metadata={ + "model": model, + "temperature": temperature, + "max_tokens": max_tokens, + }, + ) + + current_trace.spans.append(span) + + # Push onto context stack for decision logging + stack = _get_context_stack() + stack.append(span) + + # Log tool call decisions + tool_calls = _extract_tool_calls_from_response(response) + if tool_calls: + from agentlens.decision import log_decision + + # Try to get reasoning from messages + reasoning = None + for msg in reversed(messages): + if isinstance(msg, dict): + if msg.get("role") == "assistant": + reasoning = msg.get("content") + if reasoning: + reasoning = _truncate_data(str(reasoning)) + break + else: + role = getattr(msg, "role", None) + if role == "assistant": + reasoning = getattr(msg, "content", None) + if reasoning: + reasoning = _truncate_data(str(reasoning)) + break + + # Build context snapshot + context_snapshot = None + if hasattr(response, "usage"): + context_snapshot = { + "model": model, + "prompt_tokens": getattr(response.usage, "prompt_tokens"), + "completion_tokens": getattr(response.usage, "completion_tokens"), + } + + for tool_call in tool_calls: + log_decision( + type="TOOL_SELECTION", + chosen={ + "name": tool_call.get("name", "unknown"), + "arguments": tool_call.get("arguments", {}), + }, + alternatives=[], + reasoning=reasoning, + context_snapshot=context_snapshot, + ) + + # Always pop from context stack + if stack and stack[-1] == span: + stack.pop() + elif stack and isinstance(stack[-1], Span) and stack[-1].id == span.id: + stack.pop() + + +def _wrap_create(original_create: Any, is_async: bool = False) -> Any: + """Wrap OpenAI chat.completions.create method.""" + + if is_async: + + @wraps(original_create) + async def async_traced_create(*args: Any, **kwargs: Any) -> Any: + # Extract parameters + model = kwargs.get("model", "gpt-3.5-turbo") + temperature = kwargs.get("temperature") + max_tokens = kwargs.get("max_tokens") + messages = kwargs.get("messages", []) + stream = kwargs.get("stream", False) + + parent_span_id = get_current_span_id() + start_time = time.time() + + # Handle streaming + if stream: + # Create trace if needed + trace_ctx = None + if get_current_trace() is None: + trace_ctx = TraceContext(name=f"openai-{model}") + trace_ctx.__enter__() + + try: + original_stream = await original_create(*args, **kwargs) + + wrapper = _StreamWrapper(original_stream, trace_ctx) + wrapper.set_params(model, temperature, max_tokens, messages) + + return wrapper + except Exception as e: + if trace_ctx: + trace_ctx.__exit__(type(e), e, None) + raise + + # Non-streaming + trace_ctx = None + if get_current_trace() is None: + trace_ctx = TraceContext(name=f"openai-{model}") + trace_ctx.__enter__() + + try: + response = await original_create(*args, **kwargs) + + _create_llm_span( + response=response, + start_time=start_time, + model=model, + temperature=temperature, + max_tokens=max_tokens, + messages=messages, + parent_span_id=parent_span_id, + trace_ctx=trace_ctx, + ) + + # Close trace context if we created one + if trace_ctx is not None: + trace_ctx.__exit__(None, None, None) + + return response + except Exception as e: + _handle_error( + error=e, + start_time=start_time, + model=model, + temperature=temperature, + max_tokens=max_tokens, + messages=messages, + parent_span_id=parent_span_id, + 
trace_ctx=trace_ctx, + ) + raise + + return async_traced_create + + else: + + @wraps(original_create) + def traced_create(*args: Any, **kwargs: Any) -> Any: + # Extract parameters + model = kwargs.get("model", "gpt-3.5-turbo") + temperature = kwargs.get("temperature") + max_tokens = kwargs.get("max_tokens") + messages = kwargs.get("messages", []) + stream = kwargs.get("stream", False) + + parent_span_id = get_current_span_id() + start_time = time.time() + + # Handle streaming + if stream: + # Create trace if needed + trace_ctx = None + if get_current_trace() is None: + trace_ctx = TraceContext(name=f"openai-{model}") + trace_ctx.__enter__() + + try: + original_stream = original_create(*args, **kwargs) + + wrapper = _StreamWrapper(original_stream, trace_ctx) + wrapper.set_params(model, temperature, max_tokens, messages) + + return wrapper + except Exception as e: + if trace_ctx: + trace_ctx.__exit__(type(e), e, None) + raise + + # Non-streaming + trace_ctx = None + if get_current_trace() is None: + trace_ctx = TraceContext(name=f"openai-{model}") + trace_ctx.__enter__() + + try: + response = original_create(*args, **kwargs) + + _create_llm_span( + response=response, + start_time=start_time, + model=model, + temperature=temperature, + max_tokens=max_tokens, + messages=messages, + parent_span_id=parent_span_id, + trace_ctx=trace_ctx, + ) + + # Close trace context if we created one + if trace_ctx is not None: + trace_ctx.__exit__(None, None, None) + + return response + except Exception as e: + _handle_error( + error=e, + start_time=start_time, + model=model, + temperature=temperature, + max_tokens=max_tokens, + messages=messages, + parent_span_id=parent_span_id, + trace_ctx=trace_ctx, + ) + raise + + return traced_create + + +def _handle_error( + error: Exception, + start_time: float, + model: str, + temperature: Optional[float], + max_tokens: Optional[int], + messages: List[Any], + parent_span_id: Optional[str], + trace_ctx: Optional[TraceContext], +) -> None: + """Handle error by creating error span and event.""" + from agentlens.models import Span, SpanStatus, SpanType + + current_trace = get_current_trace() + if current_trace is None: + return + + end_time = time.time() + duration_ms = int((end_time - start_time) * 1000) + + # Create error span + span_name = f"openai.{model}" + span = Span( + name=span_name, + type=SpanType.LLM_CALL.value, + parent_span_id=parent_span_id, + input_data={"messages": _extract_messages_truncated(messages)}, + status=SpanStatus.ERROR.value, + status_message=str(error), + started_at=_now_iso(), + ended_at=_now_iso(), + duration_ms=duration_ms, + metadata={ + "model": model, + "temperature": temperature, + "max_tokens": max_tokens, + }, + ) + + current_trace.spans.append(span) + + # Create error event + error_event = Event( + type=EventType.ERROR.value, + name=f"{span_name}: {str(error)}", + span_id=span.id, + metadata={"error_type": type(error).__name__}, + ) + + current_trace.events.append(error_event) + + # Pop from context stack if needed + stack = _get_context_stack() + if stack and isinstance(stack[-1], Span) and stack[-1].id == span.id: + stack.pop() def wrap_openai(client: Any) -> Any: @@ -11,7 +608,7 @@ def wrap_openai(client: Any) -> Any: client: The OpenAI client to wrap. Returns: - Wrapped OpenAI client with AgentLens tracing enabled. + The same client instance with chat.completions.create wrapped. 
Example: import openai @@ -20,20 +617,22 @@ def wrap_openai(client: Any) -> Any: client = openai.OpenAI(api_key="sk-...") traced_client = wrap_openai(client) - response = traced_client.chat.completions.create(...) + response = traced_client.chat.completions.create( + model="gpt-4", + messages=[{"role": "user", "content": "Hello!"}] + ) """ original_create = client.chat.completions.create - @wraps(original_create) - def traced_create(*args: Any, **kwargs: Any) -> Any: - print("[AgentLens] OpenAI chat completion started") - try: - response = original_create(*args, **kwargs) - print("[AgentLens] OpenAI chat completion completed") - return response - except Exception as e: - print(f"[AgentLens] OpenAI error: {e}") - raise - + # Wrap synchronous method + traced_create = _wrap_create(original_create, is_async=False) client.chat.completions.create = traced_create + + # Try to wrap async method if available + if hasattr(client.chat.completions, "acreate"): + original_acreate = client.chat.completions.acreate + traced_acreate = _wrap_create(original_acreate, is_async=True) + client.chat.completions.acreate = traced_acreate + + logger.debug("OpenAI client wrapped with AgentLens tracing") return client diff --git a/packages/sdk-python/tests/test_openai_integration.py b/packages/sdk-python/tests/test_openai_integration.py new file mode 100644 index 0000000..6d16173 --- /dev/null +++ b/packages/sdk-python/tests/test_openai_integration.py @@ -0,0 +1,414 @@ +import json +import unittest +from unittest.mock import MagicMock, Mock + +import agentlens +from agentlens import init, shutdown, get_client +from agentlens.models import TraceStatus, SpanStatus, SpanType +from agentlens.integrations.openai import wrap_openai + + +class TestOpenAIIntegration(unittest.TestCase): + def setUp(self): + shutdown() + init(api_key="test_key", enabled=False) + + def tearDown(self): + shutdown() + + def _create_mock_openai_client(self): + """Create a mock OpenAI client with necessary structure.""" + client = MagicMock() + + # Mock chat.completions structure + client.chat = MagicMock() + client.chat.completions = MagicMock() + + # Mock the create method + client.chat.completions.create = MagicMock() + + return client + + def _create_mock_response( + self, + model="gpt-4", + content="Hello!", + prompt_tokens=10, + completion_tokens=5, + total_tokens=15, + tool_calls=None, + ): + """Create a mock OpenAI response object.""" + response = MagicMock() + response.model = model + + # Mock choices + choice = MagicMock() + message = MagicMock() + message.content = content + + if tool_calls: + # Build mock tool calls + mock_tool_calls = [] + for tc_data in tool_calls: + tc = MagicMock() + function = MagicMock() + function.name = tc_data["name"] + function.arguments = json.dumps(tc_data["arguments"]) + tc.function = function + mock_tool_calls.append(tc) + + message.tool_calls = mock_tool_calls + else: + message.tool_calls = [] + + choice.message = message + response.choices = [choice] + + # Mock usage with actual int objects + usage = MagicMock() + usage.prompt_tokens = prompt_tokens + usage.completion_tokens = completion_tokens + usage.total_tokens = total_tokens + response.usage = usage + + return response + + def test_basic_wrapping(self): + """Test basic wrapping of OpenAI client.""" + client = self._create_mock_openai_client() + + # Set up mock response BEFORE wrapping + mock_response = self._create_mock_response() + client.chat.completions.create.return_value = mock_response + + # Wrap the client + wrapped_client = 
wrap_openai(client) + + # Verify that method was wrapped + self.assertIsNotNone(wrapped_client.chat.completions.create) + + # Call the wrapped method + result = wrapped_client.chat.completions.create( + model="gpt-4", + messages=[{"role": "user", "content": "Hello"}], + ) + + # Verify the response + self.assertEqual(result, mock_response) + + def test_span_creation_in_trace(self): + """Test that a span is created when inside a trace context.""" + client = self._create_mock_openai_client() + + # Set up mock response BEFORE wrapping + mock_response = self._create_mock_response() + client.chat.completions.create.return_value = mock_response + + # Wrap client + wrapped_client = wrap_openai(client) + + # Call inside a trace + with agentlens.trace("test_trace"): + wrapped_client.chat.completions.create( + model="gpt-4", + messages=[{"role": "user", "content": "Hello"}], + ) + + # Verify span was created via sent_traces + traces = get_client().sent_traces + self.assertEqual(len(traces), 1) + trace = traces[0] + self.assertEqual(len(trace.spans), 1) + + span = trace.spans[0] + self.assertEqual(span.type, SpanType.LLM_CALL.value) + self.assertEqual(span.name, "openai.gpt-4") + self.assertEqual(span.status, SpanStatus.COMPLETED.value) + self.assertIsNotNone(span.duration_ms) + + def test_function_call_decision_logging(self): + """Test that tool calls are logged as decision points.""" + client = self._create_mock_openai_client() + + mock_response = self._create_mock_response( + tool_calls=[ + { + "name": "search", + "arguments": {"query": "test"}, + }, + { + "name": "calculate", + "arguments": {"expression": "1+1"}, + }, + ] + ) + client.chat.completions.create.return_value = mock_response + + wrapped_client = wrap_openai(client) + + # Call inside a trace + with agentlens.trace("test_trace"): + wrapped_client.chat.completions.create( + model="gpt-4", + messages=[{"role": "user", "content": "Search for something"}], + ) + + # Verify decision points were created + traces = get_client().sent_traces + self.assertEqual(len(traces), 1) + trace = traces[0] + + # Should have 2 decision points (one per tool call) + self.assertEqual(len(trace.decision_points), 2) + + # Check first decision + decision1 = trace.decision_points[0] + self.assertEqual(decision1.type, "TOOL_SELECTION") + self.assertEqual(decision1.chosen["name"], "search") + self.assertEqual(decision1.chosen["arguments"], {"query": "test"}) + self.assertEqual(decision1.alternatives, []) + + # Check second decision + decision2 = trace.decision_points[1] + self.assertEqual(decision2.type, "TOOL_SELECTION") + self.assertEqual(decision2.chosen["name"], "calculate") + self.assertEqual(decision2.chosen["arguments"], {"expression": "1+1"}) + + def test_error_handling(self): + """Test that errors are handled and span has ERROR status.""" + client = self._create_mock_openai_client() + + client.chat.completions.create.side_effect = Exception("API Error") + + wrapped_client = wrap_openai(client) + + # Call inside a trace and expect exception + with agentlens.trace("test_trace"): + with self.assertRaises(Exception) as cm: + wrapped_client.chat.completions.create( + model="gpt-4", + messages=[{"role": "user", "content": "Hello"}], + ) + + self.assertEqual(str(cm.exception), "API Error") + + # Verify span has ERROR status + traces = get_client().sent_traces + self.assertEqual(len(traces), 1) + trace = traces[0] + + self.assertEqual(len(trace.spans), 1) + span = trace.spans[0] + self.assertEqual(span.status, SpanStatus.ERROR.value) + 
self.assertEqual(span.status_message, "API Error") + + # Verify error event was created + self.assertEqual(len(trace.events), 1) + event = trace.events[0] + self.assertEqual(event.type, "ERROR") + self.assertIn("API Error", event.name) + self.assertEqual(event.metadata, {"error_type": "Exception"}) + + def test_cost_estimation_gpt4(self): + """Test cost calculation for gpt-4.""" + client = self._create_mock_openai_client() + + mock_response = self._create_mock_response( + model="gpt-4", prompt_tokens=1000, completion_tokens=500 + ) + client.chat.completions.create.return_value = mock_response + + wrapped_client = wrap_openai(client) + + # Call inside a trace + with agentlens.trace("test_trace"): + wrapped_client.chat.completions.create( + model="gpt-4", + messages=[{"role": "user", "content": "Hello"}], + ) + + # Verify cost calculation + traces = get_client().sent_traces + trace = traces[0] + span = trace.spans[0] + + # gpt-4: $0.03/1K input, $0.06/1K output + # 1000 * 0.03 = 0.03 + # 500 * 0.06 = 0.03 + # Total = 0.06 + self.assertIsNotNone(span.cost_usd) + self.assertAlmostEqual(span.cost_usd, 0.06, places=4) + + def test_cost_estimation_gpt4o_mini(self): + """Test cost calculation for gpt-4o-mini.""" + client = self._create_mock_openai_client() + + mock_response = self._create_mock_response( + model="gpt-4o-mini", prompt_tokens=1000, completion_tokens=500 + ) + client.chat.completions.create.return_value = mock_response + + wrapped_client = wrap_openai(client) + + # Call inside a trace + with agentlens.trace("test_trace"): + wrapped_client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "Hello"}], + ) + + # Verify cost calculation + traces = get_client().sent_traces + trace = traces[0] + span = trace.spans[0] + + # gpt-4o-mini: $0.00015/1K input, $0.0006/1K output + # 1000 * 0.00015 = 0.00015 + # 500 * 0.0006 = 0.0003 + # Total = 0.00045 + self.assertIsNotNone(span.cost_usd) + self.assertAlmostEqual(span.cost_usd, 0.00045, places=6) + + def test_cost_estimation_unknown_model(self): + """Test that unknown model doesn't calculate cost.""" + client = self._create_mock_openai_client() + + mock_response = self._create_mock_response( + model="unknown-model", prompt_tokens=1000, completion_tokens=500 + ) + client.chat.completions.create.return_value = mock_response + + wrapped_client = wrap_openai(client) + + # Call inside a trace + with agentlens.trace("test_trace"): + wrapped_client.chat.completions.create( + model="unknown-model", + messages=[{"role": "user", "content": "Hello"}], + ) + + # Verify no cost calculation + traces = get_client().sent_traces + trace = traces[0] + span = trace.spans[0] + + self.assertIsNone(span.cost_usd) + + def test_outside_trace_context(self): + """Test that a standalone trace is created when not in a trace context.""" + client = self._create_mock_openai_client() + + # Set up mock response BEFORE wrapping + mock_response = self._create_mock_response() + client.chat.completions.create.return_value = mock_response + + # Wrap the client + wrapped_client = wrap_openai(client) + + # Call WITHOUT being inside a trace context + wrapped_client.chat.completions.create( + model="gpt-4", + messages=[{"role": "user", "content": "Hello"}], + ) + + # Verify a standalone trace was created + traces = get_client().sent_traces + self.assertEqual(len(traces), 1) + + trace = traces[0] + self.assertEqual(trace.name, "openai-gpt-4") + self.assertEqual(trace.status, TraceStatus.COMPLETED.value) + self.assertEqual(len(trace.spans), 1) + + span 
= trace.spans[0] + self.assertEqual(span.type, SpanType.LLM_CALL.value) + self.assertEqual(span.name, "openai.gpt-4") + + def test_message_truncation(self): + """Test that long messages are truncated for privacy.""" + client = self._create_mock_openai_client() + + mock_response = self._create_mock_response() + client.chat.completions.create.return_value = mock_response + + wrapped_client = wrap_openai(client) + + # Call with very long message + long_content = "x" * 1000 + with agentlens.trace("test_trace"): + wrapped_client.chat.completions.create( + model="gpt-4", + messages=[{"role": "user", "content": long_content}], + ) + + # Verify truncation in span input + traces = get_client().sent_traces + trace = traces[0] + span = trace.spans[0] + + self.assertIsNotNone(span.input_data) + input_messages = span.input_data["messages"] + self.assertEqual(len(input_messages), 1) + + # Content should be truncated to ~500 chars with ... + content = input_messages[0]["content"] + self.assertTrue(len(content) < 1000) + self.assertTrue(content.endswith("...")) + + def test_token_tracking(self): + """Test that token counts are tracked.""" + client = self._create_mock_openai_client() + + mock_response = self._create_mock_response( + prompt_tokens=123, completion_tokens=456, total_tokens=579 + ) + client.chat.completions.create.return_value = mock_response + + wrapped_client = wrap_openai(client) + + # Call inside a trace + with agentlens.trace("test_trace"): + wrapped_client.chat.completions.create( + model="gpt-4", + messages=[{"role": "user", "content": "Hello"}], + ) + + # Verify token tracking + traces = get_client().sent_traces + trace = traces[0] + span = trace.spans[0] + + self.assertEqual(span.token_count, 579) + + def test_metadata_tracking(self): + """Test that metadata includes model, temperature, max_tokens.""" + client = self._create_mock_openai_client() + + mock_response = self._create_mock_response() + client.chat.completions.create.return_value = mock_response + + wrapped_client = wrap_openai(client) + + # Call with parameters + with agentlens.trace("test_trace"): + wrapped_client.chat.completions.create( + model="gpt-4", + messages=[{"role": "user", "content": "Hello"}], + temperature=0.7, + max_tokens=1000, + ) + + # Verify metadata + traces = get_client().sent_traces + trace = traces[0] + span = trace.spans[0] + + self.assertIsNotNone(span.metadata) + self.assertEqual(span.metadata["model"], "gpt-4") + self.assertEqual(span.metadata["temperature"], 0.7) + self.assertEqual(span.metadata["max_tokens"], 1000) + + +if __name__ == "__main__": + unittest.main()
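
Usage sketch (not part of the patch): a minimal example of how the wrapped client is exercised, mirroring the wrap_openai docstring example and the test setup in tests/test_openai_integration.py. Both API keys are placeholders, and the agentlens.init / agentlens.trace calls follow the test file rather than any additional SDK surface.

    # Minimal usage sketch, mirroring the docstring example and the tests above.
    # Both API keys are placeholders.
    import agentlens
    import openai

    from agentlens.integrations.openai import wrap_openai

    agentlens.init(api_key="al-placeholder-key")
    client = wrap_openai(openai.OpenAI(api_key="sk-..."))

    # Inside a trace context, the call is recorded as an LLM_CALL span with
    # token counts, estimated cost, and any tool calls logged as
    # TOOL_SELECTION decisions.
    with agentlens.trace("demo-trace"):
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": "Hello!"}],
        )
        print(response.choices[0].message.content)

    # Outside a trace context, the wrapper opens a standalone trace named
    # "openai-<model>" around the single call.
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Hello again!"}],
    )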