"""Unit tests for the public agentlens SDK entry points.

Covers client initialisation/shutdown, the ``trace`` decorator and context
manager, nested traces, ``log_decision`` and ``TraceData`` serialisation.
"""

import unittest

from agentlens import get_client, init, log_decision, shutdown, trace
from agentlens.models import TraceData, TraceStatus


class TestSDK(unittest.TestCase):
    """Exercise the SDK against a client created with ``enabled=False``."""

    def setUp(self):
        """Drop any existing client and create a fresh one for the test."""
        shutdown()
        init(api_key="test_key", enabled=False)

    def tearDown(self):
        """Shut the client down so state does not carry over between tests."""
        shutdown()

    def _only_sent_trace(self):
        """Assert the client holds exactly one sent trace and return it."""
        traces = get_client().sent_traces
        self.assertEqual(len(traces), 1)
        return traces[0]

    def test_init_and_shutdown(self):
        """``init`` exposes a configured client; ``shutdown`` removes it."""
        client = get_client()
        self.assertIsNotNone(client)
        self.assertEqual(client.api_key, "test_key")
        self.assertFalse(client.enabled)

        shutdown()
        self.assertIsNone(get_client())

    def test_decorator_sync_function(self):
        """``@trace`` keeps the return value and records a completed trace."""

        @trace("test_decorator")
        def test_func():
            return 42

        result = test_func()
        self.assertEqual(result, 42)

        sent = self._only_sent_trace()
        self.assertEqual(sent.name, "test_decorator")
        self.assertEqual(sent.status, TraceStatus.COMPLETED.value)
        self.assertIsNotNone(sent.total_duration)

    def test_context_manager(self):
        """``with trace(...)`` records a completed trace on normal exit."""
        with trace("test_context"):
            pass

        sent = self._only_sent_trace()
        self.assertEqual(sent.name, "test_context")
        self.assertEqual(sent.status, TraceStatus.COMPLETED.value)

    def test_nested_traces(self):
        """A trace opened inside another is expected as a span of the outer."""
        with trace("outer_trace"):
            with trace("inner_span"):
                pass

        sent = self._only_sent_trace()
        self.assertEqual(sent.name, "outer_trace")
        self.assertEqual(len(sent.spans), 1)
        self.assertEqual(sent.spans[0].name, "inner_span")

    def test_log_decision(self):
        """``log_decision`` inside a trace attaches one decision point."""
        with trace("test_trace"):
            log_decision(
                type="tool_selection",
                chosen={"name": "search"},
                alternatives=[{"name": "calculate"}, {"name": "browse"}],
                reasoning="Search is best for finding information",
            )

        sent = self._only_sent_trace()
        self.assertEqual(len(sent.decision_points), 1)

        decision = sent.decision_points[0]
        self.assertEqual(decision.type, "tool_selection")
        self.assertEqual(decision.chosen, {"name": "search"})
        self.assertEqual(len(decision.alternatives), 2)
        self.assertEqual(
            decision.reasoning, "Search is best for finding information"
        )

    def test_error_handling(self):
        """An exception propagates and the trace is recorded as an error."""
        with self.assertRaises(ValueError):
            with trace("test_error"):
                raise ValueError("Test error")

        sent = self._only_sent_trace()
        self.assertEqual(sent.name, "test_error")
        self.assertEqual(sent.status, TraceStatus.ERROR.value)
        self.assertEqual(len(sent.events), 1)
        self.assertEqual(sent.events[0].type, "ERROR")
        self.assertIn("Test error", sent.events[0].name)

    def test_log_decision_outside_trace(self):
        """``log_decision`` after ``shutdown`` must not raise."""
        shutdown()
        log_decision(
            type="tool_selection",
            chosen={"name": "search"},
            alternatives=[],
        )

        # NOTE(review): the client inspected below is created after the
        # call above, so this only checks that nothing shows up on the new
        # client; confirm whether a stronger assertion is wanted.
        init(api_key="test_key", enabled=False)
        client = get_client()
        self.assertEqual(len(client.sent_traces), 0)

    def test_model_serialization(self):
        """``TraceData.to_dict`` emits the expected keys and values."""
        trace_data = TraceData(
            name="test",
            tags=["tag1", "tag2"],
            session_id="session123",
            metadata={"key": "value"},
        )

        trace_dict = trace_data.to_dict()

        self.assertEqual(trace_dict["name"], "test")
        self.assertEqual(trace_dict["tags"], ["tag1", "tag2"])
        self.assertEqual(trace_dict["sessionId"], "session123")
        self.assertEqual(trace_dict["metadata"], {"key": "value"})
        self.assertEqual(trace_dict["status"], "RUNNING")
        for key in ("startedAt", "decisionPoints", "spans", "events"):
            self.assertIn(key, trace_dict)


if __name__ == "__main__":
    unittest.main()