"""Unit tests for the AgentLens Python SDK.

Context from the change that introduced these tests:

- SDK: client with BatchTransport, trace decorator/context manager,
  log_decision, thread-local context stack, nested trace -> span support.
- API: POST /api/traces (batch ingest), GET /api/traces (paginated list),
  GET /api/traces/[id] (full trace with relations), GET /api/health.
- Tests: 8 unit tests for the SDK (all passing).
- Transport: thread-safe buffer with background flush thread.
"""
import unittest

from agentlens import init, shutdown, get_client, trace, log_decision
from agentlens.models import TraceData, TraceStatus


class TestSDK(unittest.TestCase):
    """Unit tests for the AgentLens SDK public API.

    Each test runs against a fresh global client created with
    ``enabled=False`` and inspects what that client exposes through
    ``client.sent_traces``. (Presumably ``enabled=False`` makes the client
    keep traces in memory instead of transmitting them -- confirm against
    the client implementation.)
    """

    def setUp(self):
        # Start from a clean slate: discard any client a previous test left
        # behind, then install a fresh one for this test.
        shutdown()
        init(api_key="test_key", enabled=False)

    def tearDown(self):
        # shutdown() also runs here after tests that already shut the client
        # down themselves (e.g. test_init_and_shutdown), so it must tolerate
        # being called when no client exists.
        shutdown()

    def _single_sent_trace(self):
        """Assert the current client recorded exactly one trace and return it."""
        traces = get_client().sent_traces
        self.assertEqual(len(traces), 1)
        return traces[0]

    def test_init_and_shutdown(self):
        """init() installs a global client; shutdown() removes it."""
        client = get_client()
        self.assertIsNotNone(client)
        self.assertEqual(client.api_key, "test_key")
        self.assertFalse(client.enabled)

        shutdown()
        self.assertIsNone(get_client())

    def test_decorator_sync_function(self):
        """@trace(name) on a sync function returns its result and records one trace."""

        @trace("test_decorator")
        def traced_func():
            return 42

        self.assertEqual(traced_func(), 42)

        recorded = self._single_sent_trace()
        self.assertEqual(recorded.name, "test_decorator")
        self.assertEqual(recorded.status, TraceStatus.COMPLETED.value)
        self.assertIsNotNone(recorded.total_duration)

    def test_context_manager(self):
        """`with trace(name):` records one completed trace."""
        with trace("test_context"):
            pass

        recorded = self._single_sent_trace()
        self.assertEqual(recorded.name, "test_context")
        self.assertEqual(recorded.status, TraceStatus.COMPLETED.value)

    def test_nested_traces(self):
        """A trace opened inside another becomes a span of the outer trace."""
        with trace("outer_trace"):
            with trace("inner_span"):
                pass

        # Only the outermost trace is sent; the inner one is attached to it.
        recorded = self._single_sent_trace()
        self.assertEqual(recorded.name, "outer_trace")
        self.assertEqual(len(recorded.spans), 1)
        self.assertEqual(recorded.spans[0].name, "inner_span")

    def test_log_decision(self):
        """log_decision() inside a trace attaches a decision point to it."""
        with trace("test_trace"):
            log_decision(
                type="tool_selection",
                chosen={"name": "search"},
                alternatives=[{"name": "calculate"}, {"name": "browse"}],
                reasoning="Search is best for finding information",
            )

        recorded = self._single_sent_trace()
        self.assertEqual(len(recorded.decision_points), 1)

        decision = recorded.decision_points[0]
        self.assertEqual(decision.type, "tool_selection")
        self.assertEqual(decision.chosen, {"name": "search"})
        self.assertEqual(len(decision.alternatives), 2)
        self.assertEqual(decision.reasoning, "Search is best for finding information")

    def test_error_handling(self):
        """An exception inside a trace propagates and marks the trace as ERROR."""
        with self.assertRaises(ValueError):
            with trace("test_error"):
                raise ValueError("Test error")

        recorded = self._single_sent_trace()
        self.assertEqual(recorded.name, "test_error")
        self.assertEqual(recorded.status, TraceStatus.ERROR.value)
        self.assertEqual(len(recorded.events), 1)
        self.assertEqual(recorded.events[0].type, "ERROR")
        self.assertIn("Test error", recorded.events[0].name)

    def test_log_decision_outside_trace(self):
        """log_decision() with no active trace must not raise or record anything."""
        # Case 1: no client installed at all -- the call must not raise.
        shutdown()
        log_decision(
            type="tool_selection",
            chosen={"name": "search"},
            alternatives=[],
        )

        # Case 2: a live client but no enclosing trace. The original test only
        # checked a brand-new client created *after* the call, which could
        # never have recorded anything; exercise the live-client path too.
        init(api_key="test_key", enabled=False)
        log_decision(
            type="tool_selection",
            chosen={"name": "search"},
            alternatives=[],
        )
        self.assertEqual(len(get_client().sent_traces), 0)

    def test_model_serialization(self):
        """TraceData.to_dict() emits camelCase keys and a RUNNING default status."""
        trace_data = TraceData(
            name="test",
            tags=["tag1", "tag2"],
            session_id="session123",
            metadata={"key": "value"},
        )

        trace_dict = trace_data.to_dict()

        self.assertEqual(trace_dict["name"], "test")
        self.assertEqual(trace_dict["tags"], ["tag1", "tag2"])
        self.assertEqual(trace_dict["sessionId"], "session123")
        self.assertEqual(trace_dict["metadata"], {"key": "value"})
        self.assertEqual(trace_dict["status"], "RUNNING")
        # Collections and timestamps are always present, even when empty.
        for key in ("startedAt", "decisionPoints", "spans", "events"):
            self.assertIn(key, trace_dict)
if __name__ == "__main__":
    # Allow running this module directly (`python <this file>`) in addition
    # to discovery by a test runner; unittest.main() collects the TestCase
    # classes defined above and exits with the result status.
    unittest.main()