From 3fe9013838c5f73a69bf0a33ee990849d02b4d13 Mon Sep 17 00:00:00 2001 From: Vectry Date: Mon, 9 Feb 2026 23:25:34 +0000 Subject: [PATCH] feat: Python SDK real implementation + API ingestion routes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - SDK: client with BatchTransport, trace decorator/context manager, log_decision, thread-local context stack, nested trace→span support - API: POST /api/traces (batch ingest), GET /api/traces (paginated list), GET /api/traces/[id] (full trace with relations), GET /api/health - Tests: 8 unit tests for SDK (all passing) - Transport: thread-safe buffer with background flush thread --- apps/web/src/app/api/health/route.ts | 5 + apps/web/src/app/api/traces/[id]/route.ts | 46 +++ apps/web/src/app/api/traces/route.ts | 324 +++++++++++++++++++++ apps/web/src/lib/prisma.ts | 9 + packages/sdk-python/agentlens/__init__.py | 45 ++- packages/sdk-python/agentlens/client.py | 104 +++++-- packages/sdk-python/agentlens/decision.py | 63 ++-- packages/sdk-python/agentlens/models.py | 211 ++++++++++++++ packages/sdk-python/agentlens/trace.py | 235 +++++++++++---- packages/sdk-python/agentlens/transport.py | 106 +++++-- packages/sdk-python/tests/__init__.py | 0 packages/sdk-python/tests/test_sdk.py | 129 ++++++++ 12 files changed, 1144 insertions(+), 133 deletions(-) create mode 100644 apps/web/src/app/api/health/route.ts create mode 100644 apps/web/src/app/api/traces/[id]/route.ts create mode 100644 apps/web/src/app/api/traces/route.ts create mode 100644 apps/web/src/lib/prisma.ts create mode 100644 packages/sdk-python/agentlens/models.py create mode 100644 packages/sdk-python/tests/__init__.py create mode 100644 packages/sdk-python/tests/test_sdk.py diff --git a/apps/web/src/app/api/health/route.ts b/apps/web/src/app/api/health/route.ts new file mode 100644 index 0000000..dd4e332 --- /dev/null +++ b/apps/web/src/app/api/health/route.ts @@ -0,0 +1,5 @@ +import { NextResponse } from "next/server"; + +export async function GET() { + return NextResponse.json({ status: "ok", service: "agentlens", timestamp: new Date().toISOString() }); +} diff --git a/apps/web/src/app/api/traces/[id]/route.ts b/apps/web/src/app/api/traces/[id]/route.ts new file mode 100644 index 0000000..5cef035 --- /dev/null +++ b/apps/web/src/app/api/traces/[id]/route.ts @@ -0,0 +1,46 @@ +import { NextResponse } from "next/server"; +import { prisma } from "@/lib/prisma"; + +// GET /api/traces/[id] — Get single trace with all relations +export async function GET( + _request: Request, + { params }: { params: Promise<{ id: string }> } +) { + try { + const { id } = await params; + + if (!id || typeof id !== "string") { + return NextResponse.json({ error: "Invalid trace ID" }, { status: 400 }); + } + + const trace = await prisma.trace.findUnique({ + where: { id }, + include: { + decisionPoints: { + orderBy: { + timestamp: "asc", + }, + }, + spans: { + orderBy: { + startedAt: "asc", + }, + }, + events: { + orderBy: { + timestamp: "asc", + }, + }, + }, + }); + + if (!trace) { + return NextResponse.json({ error: "Trace not found" }, { status: 404 }); + } + + return NextResponse.json({ trace }, { status: 200 }); + } catch (error) { + console.error("Error retrieving trace:", error); + return NextResponse.json({ error: "Internal server error" }, { status: 500 }); + } +} diff --git a/apps/web/src/app/api/traces/route.ts b/apps/web/src/app/api/traces/route.ts new file mode 100644 index 0000000..d3393fa --- /dev/null +++ b/apps/web/src/app/api/traces/route.ts @@ -0,0 +1,324 
@@ +import { NextRequest, NextResponse } from "next/server"; +import { prisma } from "@/lib/prisma"; +import { Prisma } from "@agentlens/database"; + +// Types +interface DecisionPointPayload { + id: string; + type: "TOOL_SELECTION" | "ROUTING" | "RETRY" | "ESCALATION" | "MEMORY_RETRIEVAL" | "PLANNING" | "CUSTOM"; + reasoning?: string; + chosen: Prisma.JsonValue; + alternatives: Prisma.JsonValue[]; + contextSnapshot?: Prisma.JsonValue; + durationMs?: number; + costUsd?: number; + parentSpanId?: string; + timestamp: string; +} + +interface SpanPayload { + id: string; + parentSpanId?: string; + name: string; + type: "LLM_CALL" | "TOOL_CALL" | "MEMORY_OP" | "CHAIN" | "AGENT" | "CUSTOM"; + input?: Prisma.JsonValue; + output?: Prisma.JsonValue; + tokenCount?: number; + costUsd?: number; + durationMs?: number; + status: "RUNNING" | "COMPLETED" | "ERROR"; + statusMessage?: string; + startedAt: string; + endedAt?: string; + metadata?: Prisma.JsonValue; +} + +interface EventPayload { + id: string; + spanId?: string; + type: "ERROR" | "RETRY" | "FALLBACK" | "CONTEXT_OVERFLOW" | "USER_FEEDBACK" | "CUSTOM"; + name: string; + metadata?: Prisma.JsonValue; + timestamp: string; +} + +interface TracePayload { + id: string; + name: string; + sessionId?: string; + status: "RUNNING" | "COMPLETED" | "ERROR"; + tags: string[]; + metadata?: Prisma.JsonValue; + totalCost?: number; + totalTokens?: number; + totalDuration?: number; + startedAt: string; + endedAt?: string; + decisionPoints: DecisionPointPayload[]; + spans: SpanPayload[]; + events: EventPayload[]; +} + +interface BatchTracesRequest { + traces: TracePayload[]; +} + +// POST /api/traces — Batch ingest traces from SDK +export async function POST(request: NextRequest) { + try { + // Validate Authorization header + const authHeader = request.headers.get("authorization"); + if (!authHeader || !authHeader.startsWith("Bearer ")) { + return NextResponse.json({ error: "Missing or invalid Authorization header" }, { status: 401 }); + } + + const apiKey = authHeader.slice(7); + if (!apiKey) { + return NextResponse.json({ error: "Missing API key in Authorization header" }, { status: 401 }); + } + + // Parse and validate request body + const body: BatchTracesRequest = await request.json(); + if (!body.traces || !Array.isArray(body.traces)) { + return NextResponse.json({ error: "Request body must contain a 'traces' array" }, { status: 400 }); + } + + if (body.traces.length === 0) { + return NextResponse.json({ error: "Traces array cannot be empty" }, { status: 400 }); + } + + // Validate each trace payload + for (const trace of body.traces) { + if (!trace.id || typeof trace.id !== "string") { + return NextResponse.json({ error: "Each trace must have a valid 'id' (string)" }, { status: 400 }); + } + if (!trace.name || typeof trace.name !== "string") { + return NextResponse.json({ error: "Each trace must have a valid 'name' (string)" }, { status: 400 }); + } + if (!trace.startedAt || typeof trace.startedAt !== "string") { + return NextResponse.json({ error: "Each trace must have a valid 'startedAt' (ISO date string)" }, { status: 400 }); + } + if (!["RUNNING", "COMPLETED", "ERROR"].includes(trace.status)) { + return NextResponse.json({ error: `Invalid trace status: ${trace.status}` }, { status: 400 }); + } + if (!Array.isArray(trace.tags)) { + return NextResponse.json({ error: "Trace tags must be an array" }, { status: 400 }); + } + if (!Array.isArray(trace.decisionPoints)) { + return NextResponse.json({ error: "Trace decisionPoints must be an array" }, { status: 400 
}); + } + if (!Array.isArray(trace.spans)) { + return NextResponse.json({ error: "Trace spans must be an array" }, { status: 400 }); + } + if (!Array.isArray(trace.events)) { + return NextResponse.json({ error: "Trace events must be an array" }, { status: 400 }); + } + + // Validate decision points + for (const dp of trace.decisionPoints) { + if (!dp.id || typeof dp.id !== "string") { + return NextResponse.json({ error: "Each decision point must have a valid 'id' (string)" }, { status: 400 }); + } + if (!["TOOL_SELECTION", "ROUTING", "RETRY", "ESCALATION", "MEMORY_RETRIEVAL", "PLANNING", "CUSTOM"].includes(dp.type)) { + return NextResponse.json({ error: `Invalid decision point type: ${dp.type}` }, { status: 400 }); + } + if (!dp.timestamp || typeof dp.timestamp !== "string") { + return NextResponse.json({ error: "Each decision point must have a valid 'timestamp' (ISO date string)" }, { status: 400 }); + } + if (!Array.isArray(dp.alternatives)) { + return NextResponse.json({ error: "Decision point alternatives must be an array" }, { status: 400 }); + } + } + + // Validate spans + for (const span of trace.spans) { + if (!span.id || typeof span.id !== "string") { + return NextResponse.json({ error: "Each span must have a valid 'id' (string)" }, { status: 400 }); + } + if (!span.name || typeof span.name !== "string") { + return NextResponse.json({ error: "Each span must have a valid 'name' (string)" }, { status: 400 }); + } + if (!["LLM_CALL", "TOOL_CALL", "MEMORY_OP", "CHAIN", "AGENT", "CUSTOM"].includes(span.type)) { + return NextResponse.json({ error: `Invalid span type: ${span.type}` }, { status: 400 }); + } + if (!span.startedAt || typeof span.startedAt !== "string") { + return NextResponse.json({ error: "Each span must have a valid 'startedAt' (ISO date string)" }, { status: 400 }); + } + if (!["RUNNING", "COMPLETED", "ERROR"].includes(span.status)) { + return NextResponse.json({ error: `Invalid span status: ${span.status}` }, { status: 400 }); + } + } + + // Validate events + for (const event of trace.events) { + if (!event.id || typeof event.id !== "string") { + return NextResponse.json({ error: "Each event must have a valid 'id' (string)" }, { status: 400 }); + } + if (!event.name || typeof event.name !== "string") { + return NextResponse.json({ error: "Each event must have a valid 'name' (string)" }, { status: 400 }); + } + if (!["ERROR", "RETRY", "FALLBACK", "CONTEXT_OVERFLOW", "USER_FEEDBACK", "CUSTOM"].includes(event.type)) { + return NextResponse.json({ error: `Invalid event type: ${event.type}` }, { status: 400 }); + } + if (!event.timestamp || typeof event.timestamp !== "string") { + return NextResponse.json({ error: "Each event must have a valid 'timestamp' (ISO date string)" }, { status: 400 }); + } + } + } + + // Insert traces using transaction + const result = await prisma.$transaction( + body.traces.map((trace) => + prisma.trace.create({ + data: { + id: trace.id, + name: trace.name, + sessionId: trace.sessionId, + status: trace.status, + tags: trace.tags, + metadata: trace.metadata as Prisma.InputJsonValue, + totalCost: trace.totalCost, + totalTokens: trace.totalTokens, + totalDuration: trace.totalDuration, + startedAt: new Date(trace.startedAt), + endedAt: trace.endedAt ? 
new Date(trace.endedAt) : null, + decisionPoints: { + create: trace.decisionPoints.map((dp) => ({ + id: dp.id, + type: dp.type, + reasoning: dp.reasoning, + chosen: dp.chosen as Prisma.InputJsonValue, + alternatives: dp.alternatives as Prisma.InputJsonValue[], + contextSnapshot: dp.contextSnapshot as Prisma.InputJsonValue | undefined, + durationMs: dp.durationMs, + costUsd: dp.costUsd, + parentSpanId: dp.parentSpanId, + timestamp: new Date(dp.timestamp), + })), + }, + spans: { + create: trace.spans.map((span) => ({ + id: span.id, + parentSpanId: span.parentSpanId, + name: span.name, + type: span.type, + input: span.input as Prisma.InputJsonValue | undefined, + output: span.output as Prisma.InputJsonValue | undefined, + tokenCount: span.tokenCount, + costUsd: span.costUsd, + durationMs: span.durationMs, + status: span.status, + statusMessage: span.statusMessage, + startedAt: new Date(span.startedAt), + endedAt: span.endedAt ? new Date(span.endedAt) : null, + metadata: span.metadata as Prisma.InputJsonValue | undefined, + })), + }, + events: { + create: trace.events.map((event) => ({ + id: event.id, + spanId: event.spanId, + type: event.type, + name: event.name, + metadata: event.metadata as Prisma.InputJsonValue | undefined, + timestamp: new Date(event.timestamp), + })), + }, + }, + }) + ) + ); + + return NextResponse.json({ success: true, count: result.length }, { status: 200 }); + } catch (error) { + if (error instanceof SyntaxError) { + return NextResponse.json({ error: "Invalid JSON in request body" }, { status: 400 }); + } + + // Handle unique constraint violations + if (error instanceof Error && error.message.includes("Unique constraint")) { + return NextResponse.json({ error: "Duplicate trace ID detected" }, { status: 409 }); + } + + console.error("Error processing traces:", error); + return NextResponse.json({ error: "Internal server error" }, { status: 500 }); + } +} + +// GET /api/traces — List traces with pagination +export async function GET(request: NextRequest) { + try { + const { searchParams } = new URL(request.url); + const page = parseInt(searchParams.get("page") ?? "1", 10); + const limit = parseInt(searchParams.get("limit") ?? "20", 10); + const status = searchParams.get("status"); + const search = searchParams.get("search"); + const sessionId = searchParams.get("sessionId"); + + // Validate pagination parameters + if (isNaN(page) || page < 1) { + return NextResponse.json({ error: "Invalid page parameter. Must be a positive integer." }, { status: 400 }); + } + if (isNaN(limit) || limit < 1 || limit > 100) { + return NextResponse.json({ error: "Invalid limit parameter. Must be between 1 and 100." }, { status: 400 }); + } + + // Validate status parameter if provided + const validStatuses = ["RUNNING", "COMPLETED", "ERROR"]; + if (status && !validStatuses.includes(status)) { + return NextResponse.json({ error: `Invalid status. 
Must be one of: ${validStatuses.join(", ")}` }, { status: 400 }); + } + + // Build where clause + const where: Record<string, unknown> = {}; + if (status) { + where.status = status; + } + if (search) { + where.name = { + contains: search, + mode: "insensitive", + }; + } + if (sessionId) { + where.sessionId = sessionId; + } + + // Count total traces + const total = await prisma.trace.count({ where }); + + // Calculate pagination + const skip = (page - 1) * limit; + const totalPages = Math.ceil(total / limit); + + // Fetch traces with pagination + const traces = await prisma.trace.findMany({ + where, + include: { + _count: { + select: { + decisionPoints: true, + spans: true, + events: true, + }, + }, + }, + orderBy: { + startedAt: "desc", + }, + skip, + take: limit, + }); + + return NextResponse.json({ + traces, + total, + page, + limit, + totalPages, + }, { status: 200 }); + } catch (error) { + console.error("Error listing traces:", error); + return NextResponse.json({ error: "Internal server error" }, { status: 500 }); + } +} diff --git a/apps/web/src/lib/prisma.ts b/apps/web/src/lib/prisma.ts new file mode 100644 index 0000000..23646dd --- /dev/null +++ b/apps/web/src/lib/prisma.ts @@ -0,0 +1,9 @@ +import { PrismaClient } from "@agentlens/database"; + +const globalForPrisma = globalThis as unknown as { + prisma: PrismaClient | undefined; +}; + +export const prisma = globalForPrisma.prisma ?? new PrismaClient(); + +if (process.env.NODE_ENV !== "production") globalForPrisma.prisma = prisma; diff --git a/packages/sdk-python/agentlens/__init__.py b/packages/sdk-python/agentlens/__init__.py index fb5c384..851a6f2 100644 --- a/packages/sdk-python/agentlens/__init__.py +++ b/packages/sdk-python/agentlens/__init__.py @@ -1,8 +1,47 @@ """AgentLens - Agent observability that traces decisions, not just API calls.""" -from agentlens.client import init, shutdown -from agentlens.trace import trace +from agentlens.client import init, shutdown, get_client from agentlens.decision import log_decision +from agentlens.models import ( + TraceData, + DecisionPoint, + Span, + Event, + TraceStatus, + DecisionType, + SpanType, + SpanStatus, + EventType, +) + +import sys + +if sys.version_info >= (3, 7): + import importlib + + __trace_module = importlib.import_module("agentlens.trace") + trace = getattr(__trace_module, "trace") + TraceContext = getattr(__trace_module, "TraceContext") + get_current_trace = getattr(__trace_module, "get_current_trace") +else: + from agentlens.trace import trace, TraceContext, get_current_trace __version__ = "0.1.0" -__all__ = ["init", "shutdown", "trace", "log_decision"] +__all__ = [ + "init", + "shutdown", + "get_client", + "trace", + "TraceContext", + "get_current_trace", + "log_decision", + "TraceData", + "DecisionPoint", + "Span", + "Event", + "TraceStatus", + "DecisionType", + "SpanType", + "SpanStatus", + "EventType", +] diff --git a/packages/sdk-python/agentlens/client.py b/packages/sdk-python/agentlens/client.py index fb6e014..5dd90af 100644 --- a/packages/sdk-python/agentlens/client.py +++ b/packages/sdk-python/agentlens/client.py @@ -1,43 +1,107 @@ """Client initialization and management for AgentLens.""" -from typing import Optional +import atexit +import logging +from typing import Optional, List + +from agentlens.models import TraceData +from agentlens.transport import BatchTransport + +logger = logging.getLogger("agentlens") + +_client: Optional["AgentLensClient"] = None -_client: Optional["_Client"] = None - - -class _Client: - """Internal client class for managing AgentLens
connection.""" - - def __init__(self, api_key: str, endpoint: str) -> None: +class AgentLensClient: + def __init__( + self, + api_key: str, + endpoint: str = "https://agentlens.vectry.tech", + flush_interval: float = 5.0, + max_batch_size: int = 10, + enabled: bool = True, + ) -> None: self.api_key = api_key - self.endpoint = endpoint - self.is_shutdown = False + self.endpoint = endpoint.rstrip("/") + self.flush_interval = flush_interval + self.max_batch_size = max_batch_size + self.enabled = enabled + self._transport: Optional[BatchTransport] = None + self._sent_traces: List[TraceData] = [] + + if enabled: + try: + self._transport = BatchTransport( + api_key=api_key, + endpoint=endpoint, + flush_interval=flush_interval, + max_batch_size=max_batch_size, + ) + except ImportError: + logger.warning("httpx not available, transport disabled") + self.enabled = False + + def send_trace(self, trace: TraceData) -> None: + if not self.enabled: + self._sent_traces.append(trace) + return + + if self._transport: + self._transport.add(trace) def shutdown(self) -> None: - """Shutdown the client.""" - self.is_shutdown = True + if self._transport: + self._transport.flush() + self._transport.shutdown() + self._transport = None + + @property + def sent_traces(self) -> List[TraceData]: + return self._sent_traces + + @property + def transport(self) -> Optional["BatchTransport"]: + return self._transport -def init(api_key: str, endpoint: str = "https://agentlens.vectry.tech") -> None: - """Initialize the AgentLens client. +def init( + api_key: str, + endpoint: str = "https://agentlens.vectry.tech", + flush_interval: float = 5.0, + max_batch_size: int = 10, + enabled: bool = True, +) -> None: + """Initialize the AgentLens SDK. Args: api_key: Your AgentLens API key. - endpoint: The AgentLens API endpoint (default: https://agentlens.vectry.tech). + endpoint: The AgentLens API endpoint. + flush_interval: Seconds between automatic flushes (default: 5.0). + max_batch_size: Number of traces before auto-flush (default: 10). + enabled: Set to False to disable sending (useful for testing). """ global _client - _client = _Client(api_key=api_key, endpoint=endpoint) + if _client is not None: + _client.shutdown() + _client = AgentLensClient( + api_key=api_key, + endpoint=endpoint, + flush_interval=flush_interval, + max_batch_size=max_batch_size, + enabled=enabled, + ) + atexit.register(shutdown) + logger.debug("AgentLens initialized: endpoint=%s", endpoint) def shutdown() -> None: - """Shutdown the AgentLens client.""" global _client - if _client: + if _client is not None: _client.shutdown() _client = None + logger.debug("AgentLens shutdown complete") -def get_client() -> Optional[_Client]: - """Get the current client instance.""" +def get_client() -> Optional[AgentLensClient]: + """Get the current client instance. 
Returns None if not initialized.""" return _client diff --git a/packages/sdk-python/agentlens/decision.py b/packages/sdk-python/agentlens/decision.py index 121f3c3..b12e1dc 100644 --- a/packages/sdk-python/agentlens/decision.py +++ b/packages/sdk-python/agentlens/decision.py @@ -1,32 +1,47 @@ """Decision logging for tracking agent decision points.""" -from typing import Any, Dict, List +import logging +from typing import Any, Dict, List, Optional + +from agentlens.models import DecisionPoint +from agentlens.trace import get_current_trace, get_current_span_id + +logger = logging.getLogger("agentlens") def log_decision( type: str, - chosen: Any, - alternatives: List[Any], + chosen: Dict[str, Any], + alternatives: Optional[List[Dict[str, Any]]] = None, reasoning: Optional[str] = None, -) -> None: - """Log a decision point in the agent's reasoning. - - Args: - type: Type of decision (e.g., "tool_selection", "routing", "retry"). - chosen: The option that was selected. - alternatives: List of alternatives that were considered. - reasoning: Optional explanation for the decision. - - Example: - log_decision( - type="tool_selection", - chosen="search", - alternatives=["search", "calculate", "browse"], - reasoning="Search is most appropriate for finding information" + context_snapshot: Optional[Dict[str, Any]] = None, + cost_usd: Optional[float] = None, + duration_ms: Optional[int] = None, +) -> Optional[DecisionPoint]: + current_trace = get_current_trace() + if current_trace is None: + logger.warning( + "AgentLens: log_decision called outside of a trace context. Decision not recorded." ) - """ - print(f"[AgentLens] Decision logged: {type}") - print(f"[AgentLens] Chosen: {chosen}") - print(f"[AgentLens] Alternatives: {alternatives}") - if reasoning: - print(f"[AgentLens] Reasoning: {reasoning}") + return None + + parent_span_id = get_current_span_id() + + decision = DecisionPoint( + type=type, + chosen=chosen, + alternatives=alternatives or [], + reasoning=reasoning, + context_snapshot=context_snapshot, + cost_usd=cost_usd, + duration_ms=duration_ms, + parent_span_id=parent_span_id, + ) + + current_trace.decision_points.append(decision) + logger.debug( + "AgentLens: Decision logged: type=%s, chosen=%s", + type, + chosen.get("name", str(chosen)), + ) + return decision diff --git a/packages/sdk-python/agentlens/models.py b/packages/sdk-python/agentlens/models.py new file mode 100644 index 0000000..2a99541 --- /dev/null +++ b/packages/sdk-python/agentlens/models.py @@ -0,0 +1,211 @@ +"""Data models for AgentLens SDK, matching the server-side Prisma schema.""" + +import time +import uuid +from dataclasses import dataclass, field, asdict +from typing import Any, Dict, List, Optional +from enum import Enum + + +class TraceStatus(str, Enum): + RUNNING = "RUNNING" + COMPLETED = "COMPLETED" + ERROR = "ERROR" + + +class DecisionType(str, Enum): + TOOL_SELECTION = "TOOL_SELECTION" + ROUTING = "ROUTING" + RETRY = "RETRY" + ESCALATION = "ESCALATION" + MEMORY_RETRIEVAL = "MEMORY_RETRIEVAL" + PLANNING = "PLANNING" + CUSTOM = "CUSTOM" + + +class SpanType(str, Enum): + LLM_CALL = "LLM_CALL" + TOOL_CALL = "TOOL_CALL" + MEMORY_OP = "MEMORY_OP" + CHAIN = "CHAIN" + AGENT = "AGENT" + CUSTOM = "CUSTOM" + + +class SpanStatus(str, Enum): + RUNNING = "RUNNING" + COMPLETED = "COMPLETED" + ERROR = "ERROR" + + +class EventType(str, Enum): + ERROR = "ERROR" + RETRY = "RETRY" + FALLBACK = "FALLBACK" + CONTEXT_OVERFLOW = "CONTEXT_OVERFLOW" + USER_FEEDBACK = "USER_FEEDBACK" + CUSTOM = "CUSTOM" + + +def _generate_id() -> str: + 
return str(uuid.uuid4()) + + +def _now_ms() -> int: + return int(time.time() * 1000) + + +def _now_iso() -> str: + from datetime import datetime, timezone + + return datetime.now(timezone.utc).isoformat() + + +@dataclass +class DecisionPoint: + type: str # DecisionType value + chosen: Dict[str, Any] # {name, confidence, params} + alternatives: List[Dict[str, Any]] # [{name, confidence, reason_rejected}] + reasoning: Optional[str] = None + context_snapshot: Optional[Dict[str, Any]] = ( + None # {window_usage_pct, tokens_used, tokens_available} + ) + duration_ms: Optional[int] = None + cost_usd: Optional[float] = None + parent_span_id: Optional[str] = None + id: str = field(default_factory=_generate_id) + timestamp: str = field(default_factory=_now_iso) + + def to_dict(self) -> Dict[str, Any]: + d = { + "id": self.id, + "type": self.type, + "chosen": self.chosen, + "alternatives": self.alternatives, + "timestamp": self.timestamp, + } + if self.reasoning is not None: + d["reasoning"] = self.reasoning + if self.context_snapshot is not None: + d["contextSnapshot"] = self.context_snapshot + if self.duration_ms is not None: + d["durationMs"] = self.duration_ms + if self.cost_usd is not None: + d["costUsd"] = self.cost_usd + if self.parent_span_id is not None: + d["parentSpanId"] = self.parent_span_id + return d + + +@dataclass +class Span: + name: str + type: str # SpanType value + id: str = field(default_factory=_generate_id) + parent_span_id: Optional[str] = None + input_data: Optional[Any] = None + output_data: Optional[Any] = None + token_count: Optional[int] = None + cost_usd: Optional[float] = None + duration_ms: Optional[int] = None + status: str = field(default_factory=lambda: SpanStatus.RUNNING.value) + status_message: Optional[str] = None + started_at: str = field(default_factory=_now_iso) + ended_at: Optional[str] = None + metadata: Optional[Dict[str, Any]] = None + + def to_dict(self) -> Dict[str, Any]: + d = { + "id": self.id, + "name": self.name, + "type": self.type, + "status": self.status, + "startedAt": self.started_at, + } + if self.parent_span_id is not None: + d["parentSpanId"] = self.parent_span_id + if self.input_data is not None: + d["input"] = self.input_data + if self.output_data is not None: + d["output"] = self.output_data + if self.token_count is not None: + d["tokenCount"] = self.token_count + if self.cost_usd is not None: + d["costUsd"] = self.cost_usd + if self.duration_ms is not None: + d["durationMs"] = self.duration_ms + if self.status_message is not None: + d["statusMessage"] = self.status_message + if self.ended_at is not None: + d["endedAt"] = self.ended_at + if self.metadata is not None: + d["metadata"] = self.metadata + return d + + +@dataclass +class Event: + type: str # EventType value + name: str + span_id: Optional[str] = None + metadata: Optional[Dict[str, Any]] = None + id: str = field(default_factory=_generate_id) + timestamp: str = field(default_factory=_now_iso) + + def to_dict(self) -> Dict[str, Any]: + d = { + "id": self.id, + "type": self.type, + "name": self.name, + "timestamp": self.timestamp, + } + if self.span_id is not None: + d["spanId"] = self.span_id + if self.metadata is not None: + d["metadata"] = self.metadata + return d + + +@dataclass +class TraceData: + """Represents a complete trace to be sent to the server.""" + + name: str + id: str = field(default_factory=_generate_id) + session_id: Optional[str] = None + status: str = field(default_factory=lambda: TraceStatus.RUNNING.value) + tags: List[str] = field(default_factory=list) + 
metadata: Optional[Dict[str, Any]] = None + total_cost: Optional[float] = None + total_tokens: Optional[int] = None + total_duration: Optional[int] = None # ms + started_at: str = field(default_factory=_now_iso) + ended_at: Optional[str] = None + decision_points: List[DecisionPoint] = field(default_factory=list) + spans: List[Span] = field(default_factory=list) + events: List[Event] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + d = { + "id": self.id, + "name": self.name, + "status": self.status, + "tags": self.tags, + "startedAt": self.started_at, + "decisionPoints": [dp.to_dict() for dp in self.decision_points], + "spans": [s.to_dict() for s in self.spans], + "events": [e.to_dict() for e in self.events], + } + if self.session_id is not None: + d["sessionId"] = self.session_id + if self.metadata is not None: + d["metadata"] = self.metadata + if self.total_cost is not None: + d["totalCost"] = self.total_cost + if self.total_tokens is not None: + d["totalTokens"] = self.total_tokens + if self.total_duration is not None: + d["totalDuration"] = self.total_duration + if self.ended_at is not None: + d["endedAt"] = self.ended_at + return d diff --git a/packages/sdk-python/agentlens/trace.py b/packages/sdk-python/agentlens/trace.py index 60e148d..533d04d 100644 --- a/packages/sdk-python/agentlens/trace.py +++ b/packages/sdk-python/agentlens/trace.py @@ -1,76 +1,191 @@ """Trace decorator and context manager for instrumenting agent functions.""" -from typing import Callable, Optional, Any +import asyncio +import logging +import threading +import time from functools import wraps +from typing import Any, Callable, Dict, List, Optional, Union + +from agentlens.models import ( + TraceData, + Span, + SpanType, + TraceStatus, + SpanStatus, + Event, + EventType, + _now_iso, +) + +logger = logging.getLogger("agentlens") + +_context = threading.local() -def trace(name: Optional[str] = None) -> Callable[..., Any]: - """Decorator to trace a function or method. +def _get_context_stack() -> List[Union[TraceData, Span]]: + if not hasattr(_context, "stack"): + _context.stack = [] + return _context.stack - Args: - name: Name for the trace. If not provided, uses the function name. - Returns: - Decorated function with tracing enabled. 
+def get_current_trace() -> Optional[TraceData]: + stack = _get_context_stack() + if not stack: + return None + for item in stack: + if isinstance(item, TraceData): + return item + return None - Example: - @trace(name="research-agent") - async def research(topic: str) -> str: - return f"Researching: {topic}" - """ - def decorator(func: Callable[..., Any]) -> Callable[..., Any]: - @wraps(func) - async def async_wrapper(*args: Any, **kwargs: Any) -> Any: - trace_name = name or func.__name__ - print(f"[AgentLens] Starting trace: {trace_name}") - try: - result = await func(*args, **kwargs) - print(f"[AgentLens] Completed trace: {trace_name}") - return result - except Exception as e: - print(f"[AgentLens] Error in trace {trace_name}: {e}") - raise +def get_current_span_id() -> Optional[str]: + stack = _get_context_stack() + if not stack: + return None + for item in reversed(stack): + if isinstance(item, Span): + return item.id + return None - @wraps(func) - def sync_wrapper(*args: Any, **kwargs: Any) -> Any: - trace_name = name or func.__name__ - print(f"[AgentLens] Starting trace: {trace_name}") - try: - result = func(*args, **kwargs) - print(f"[AgentLens] Completed trace: {trace_name}") - return result - except Exception as e: - print(f"[AgentLens] Error in trace {trace_name}: {e}") - raise - if hasattr(func, "__await__"): - return async_wrapper +class TraceContext: + def __init__( + self, + name: str, + tags: Optional[List[str]] = None, + session_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + self.name = name or "trace" + self.tags = tags or [] + self.session_id = session_id + self.metadata = metadata + self._trace_data: Optional[TraceData] = None + self._span: Optional[Span] = None + self._start_time: float = 0 + self._is_nested: bool = False + + def __enter__(self) -> "TraceContext": + self._start_time = time.time() + stack = _get_context_stack() + + if stack: + self._is_nested = True + parent_trace = get_current_trace() + parent_span_id = get_current_span_id() + self._span = Span( + name=self.name, + type=SpanType.AGENT.value, + parent_span_id=parent_span_id, + started_at=_now_iso(), + ) + if parent_trace: + parent_trace.spans.append(self._span) + stack.append(self._span) else: - return sync_wrapper + self._trace_data = TraceData( + name=self.name, + tags=self.tags, + session_id=self.session_id, + metadata=self.metadata, + status=TraceStatus.RUNNING.value, + started_at=_now_iso(), + ) + stack.append(self._trace_data) - return decorator - - -class Tracer: - """Context manager for creating traces. 
- - Example: - with Tracer(name="custom-operation"): - # Your code here - pass - """ - - def __init__(self, name: str) -> None: - self.name = name - - def __enter__(self) -> "Tracer": - print(f"[AgentLens] Starting trace: {self.name}") return self - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: - if exc_type is None: - print(f"[AgentLens] Completed trace: {self.name}") + def __exit__( + self, + exc_type: Optional[type], + exc_val: Optional[BaseException], + exc_tb: Optional[Any], + ) -> None: + from agentlens.client import get_client + + client = get_client() + end_time = time.time() + total_duration = int((end_time - self._start_time) * 1000) + + stack = _get_context_stack() + + if self._is_nested and self._span: + self._span.status = ( + SpanStatus.COMPLETED.value + if exc_type is None + else SpanStatus.ERROR.value + ) + self._span.duration_ms = total_duration + self._span.ended_at = _now_iso() + stack.pop() + elif self._trace_data: + if exc_type is not None: + self._trace_data.status = TraceStatus.ERROR.value + error_event = Event( + type=EventType.ERROR.value, + name=str(exc_val) if exc_val else "Unknown error", + ) + self._trace_data.events.append(error_event) + else: + self._trace_data.status = TraceStatus.COMPLETED.value + + self._trace_data.total_duration = total_duration + self._trace_data.ended_at = _now_iso() + stack.pop() + + client = get_client() + if client: + client.send_trace(self._trace_data) + + def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]: + if asyncio.iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(*args: Any, **kwargs: Any) -> Any: + with TraceContext( + name=self.name, + tags=self.tags, + session_id=self.session_id, + metadata=self.metadata, + ): + return await func(*args, **kwargs) + + return async_wrapper else: - print(f"[AgentLens] Error in trace {self.name}: {exc_val}") - return False + + @wraps(func) + def sync_wrapper(*args: Any, **kwargs: Any) -> Any: + with TraceContext( + name=self.name, + tags=self.tags, + session_id=self.session_id, + metadata=self.metadata, + ): + return func(*args, **kwargs) + + return sync_wrapper + + @property + def trace_id(self) -> Optional[str]: + if self._trace_data: + return self._trace_data.id + return None + + +def trace( + name: Union[Callable[..., Any], str, None] = None, + tags: Optional[List[str]] = None, + session_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, +) -> Union[TraceContext, Callable[..., Any]]: + if callable(name): + func = name + ctx = TraceContext( + name=func.__name__, tags=tags, session_id=session_id, metadata=metadata + ) + return ctx(func) + + return TraceContext( + name=name or "trace", tags=tags, session_id=session_id, metadata=metadata + ) diff --git a/packages/sdk-python/agentlens/transport.py b/packages/sdk-python/agentlens/transport.py index 37b4e75..c804a7b 100644 --- a/packages/sdk-python/agentlens/transport.py +++ b/packages/sdk-python/agentlens/transport.py @@ -1,38 +1,92 @@ -"""Batch transport for sending data to AgentLens API.""" +"""Batch transport for sending trace data to AgentLens API.""" -from typing import List, Dict, Any +import json +import logging +import threading +import time +from typing import Any, Dict, List, Optional + +import httpx + +from agentlens.models import TraceData + +logger = logging.getLogger("agentlens") class BatchTransport: - """Transport layer that batches events for efficient API calls. 
+ """Thread-safe batch transport that buffers and sends traces to API.""" - This class handles batching and sending of traces, decisions, and other - events to the AgentLens backend. - """ + def __init__( + self, + api_key: str, + endpoint: str, + max_batch_size: int = 10, + flush_interval: float = 5.0, + ) -> None: + self._api_key = api_key + self._endpoint = endpoint.rstrip("/") + self._max_batch_size = max_batch_size + self._flush_interval = flush_interval + self._buffer: List[Dict[str, Any]] = [] + self._lock = threading.Lock() + self._shutdown_event = threading.Event() + self._flush_thread = threading.Thread( + target=self._flush_loop, daemon=True, name="agentlens-flush" + ) + self._flush_thread.start() + self._client = httpx.Client(timeout=30.0) - def __init__(self, max_batch_size: int = 100, flush_interval: float = 1.0) -> None: - self.max_batch_size = max_batch_size - self.flush_interval = flush_interval - self._batch: List[Dict[str, Any]] = [] + def add(self, trace: TraceData) -> None: + """Add a completed trace to send buffer.""" + trace_dict = trace.to_dict() + with self._lock: + self._buffer.append(trace_dict) + should_flush = len(self._buffer) >= self._max_batch_size + if should_flush: + self._do_flush() - def add(self, event: Dict[str, Any]) -> None: - """Add an event to the batch. + def _flush_loop(self) -> None: + """Background loop that periodically flushes buffer.""" + while not self._shutdown_event.is_set(): + self._shutdown_event.wait(timeout=self._flush_interval) + if not self._shutdown_event.is_set(): + self._do_flush() - Args: - event: Event data to be sent. - """ - self._batch.append(event) - if len(self._batch) >= self.max_batch_size: - self.flush() + def _do_flush(self) -> None: + """Flush all buffered traces to the API.""" + with self._lock: + if not self._buffer: + return + batch = self._buffer.copy() + self._buffer.clear() + + try: + response = self._client.post( + f"{self._endpoint}/api/traces", + json={"traces": batch}, + headers={ + "Authorization": f"Bearer {self._api_key}", + "Content-Type": "application/json", + }, + ) + if response.status_code >= 400: + logger.warning( + "AgentLens: Failed to send traces (HTTP %d): %s", + response.status_code, + response.text[:200], + ) + except Exception as e: + logger.warning("AgentLens: Failed to send traces: %s", e) + # Put traces back in buffer for retry (optional, up to you) + # For now, drop on failure to avoid unbounded growth def flush(self) -> None: - """Flush the batch by sending all pending events.""" - if not self._batch: - return - - print(f"[AgentLens] Flushing batch of {len(self._batch)} events") - self._batch.clear() + """Manually flush all buffered traces.""" + self._do_flush() def shutdown(self) -> None: - """Shutdown the transport, flushing any remaining events.""" - self.flush() + """Shutdown transport: flush remaining traces and stop background thread.""" + self._shutdown_event.set() + self._flush_thread.join(timeout=5.0) + self._do_flush() # Final flush + self._client.close() diff --git a/packages/sdk-python/tests/__init__.py b/packages/sdk-python/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/packages/sdk-python/tests/test_sdk.py b/packages/sdk-python/tests/test_sdk.py new file mode 100644 index 0000000..a259ba7 --- /dev/null +++ b/packages/sdk-python/tests/test_sdk.py @@ -0,0 +1,129 @@ +import unittest + +from agentlens import init, shutdown, get_client, trace, log_decision +from agentlens.models import TraceData, TraceStatus + + +class TestSDK(unittest.TestCase): + 
def setUp(self): + shutdown() + init(api_key="test_key", enabled=False) + + def tearDown(self): + shutdown() + + def test_init_and_shutdown(self): + client = get_client() + self.assertIsNotNone(client) + self.assertEqual(client.api_key, "test_key") + self.assertFalse(client.enabled) + shutdown() + self.assertIsNone(get_client()) + + def test_decorator_sync_function(self): + @trace("test_decorator") + def test_func(): + return 42 + + result = test_func() + self.assertEqual(result, 42) + + client = get_client() + traces = client.sent_traces + self.assertEqual(len(traces), 1) + self.assertEqual(traces[0].name, "test_decorator") + self.assertEqual(traces[0].status, TraceStatus.COMPLETED.value) + self.assertIsNotNone(traces[0].total_duration) + + def test_context_manager(self): + with trace("test_context") as t: + pass + + client = get_client() + traces = client.sent_traces + self.assertEqual(len(traces), 1) + self.assertEqual(traces[0].name, "test_context") + self.assertEqual(traces[0].status, TraceStatus.COMPLETED.value) + + def test_nested_traces(self): + with trace("outer_trace"): + with trace("inner_span"): + pass + + client = get_client() + traces = client.sent_traces + self.assertEqual(len(traces), 1) + self.assertEqual(traces[0].name, "outer_trace") + self.assertEqual(len(traces[0].spans), 1) + self.assertEqual(traces[0].spans[0].name, "inner_span") + + def test_log_decision(self): + with trace("test_trace"): + log_decision( + type="tool_selection", + chosen={"name": "search"}, + alternatives=[{"name": "calculate"}, {"name": "browse"}], + reasoning="Search is best for finding information", + ) + + client = get_client() + traces = client.sent_traces + self.assertEqual(len(traces), 1) + self.assertEqual(len(traces[0].decision_points), 1) + decision = traces[0].decision_points[0] + self.assertEqual(decision.type, "tool_selection") + self.assertEqual(decision.chosen, {"name": "search"}) + self.assertEqual(len(decision.alternatives), 2) + self.assertEqual(decision.reasoning, "Search is best for finding information") + + def test_error_handling(self): + with self.assertRaises(ValueError): + with trace("test_error"): + raise ValueError("Test error") + + client = get_client() + traces = client.sent_traces + self.assertEqual(len(traces), 1) + self.assertEqual(traces[0].name, "test_error") + self.assertEqual(traces[0].status, TraceStatus.ERROR.value) + self.assertEqual(len(traces[0].events), 1) + self.assertEqual(traces[0].events[0].type, "ERROR") + self.assertIn("Test error", traces[0].events[0].name) + + def test_log_decision_outside_trace(self): + shutdown() + + log_decision( + type="tool_selection", + chosen={"name": "search"}, + alternatives=[], + ) + + init(api_key="test_key", enabled=False) + + client = get_client() + self.assertEqual(len(client.sent_traces), 0) + + def test_model_serialization(self): + trace_data = TraceData( + name="test", + tags=["tag1", "tag2"], + session_id="session123", + metadata={"key": "value"}, + ) + + trace_dict = trace_data.to_dict() + + self.assertEqual(trace_dict["name"], "test") + self.assertEqual(trace_dict["tags"], ["tag1", "tag2"]) + self.assertEqual(trace_dict["sessionId"], "session123") + self.assertEqual(trace_dict["metadata"], {"key": "value"}) + self.assertEqual(trace_dict["status"], "RUNNING") + self.assertIn("startedAt", trace_dict) + self.assertIn("decisionPoints", trace_dict) + self.assertIn("spans", trace_dict) + self.assertIn("events", trace_dict) + + +if __name__ == "__main__": + unittest.main()
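
A minimal usage sketch of the SDK surface added above (init, the trace decorator/context manager, log_decision, shutdown). It is illustrative only: the API key and endpoint values are placeholders, and the DecisionType value mirrors the enum the ingest route validates.

import agentlens
from agentlens import trace, log_decision, DecisionType

agentlens.init(
    api_key="al_xxx",                          # placeholder key
    endpoint="https://agentlens.vectry.tech",  # default endpoint
    flush_interval=5.0,                        # seconds between background flushes
    max_batch_size=10,                         # traces buffered before auto-flush
)


@trace(name="research-agent", tags=["demo"])
def research(topic: str) -> str:
    # Nested trace() blocks become AGENT spans on the enclosing trace.
    with trace("pick-tool"):
        # The decision is attached to the current trace, with the active
        # span recorded as its parent.
        log_decision(
            type=DecisionType.TOOL_SELECTION.value,
            chosen={"name": "search", "confidence": 0.9},
            alternatives=[{"name": "calculate"}, {"name": "browse"}],
            reasoning="Search is most appropriate for finding information",
        )
    return f"Researching: {topic}"


research("agent observability")
agentlens.shutdown()  # flushes any traces still buffered in BatchTransport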
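
For reference, a sketch of the wire format BatchTransport posts to POST /api/traces, plus the paginated listing served by GET /api/traces. Field names follow the TracePayload/SpanPayload interfaces validated in apps/web/src/app/api/traces/route.ts; the base URL and API key are placeholders, and the ingest route currently only checks that a Bearer token is present.

import uuid
from datetime import datetime, timezone

import httpx

BASE_URL = "http://localhost:3000"  # placeholder: wherever the Next.js app is served
API_KEY = "al_xxx"                  # placeholder: any non-empty Bearer token passes today
now = datetime.now(timezone.utc).isoformat()

payload = {
    "traces": [
        {
            "id": str(uuid.uuid4()),
            "name": "manual-trace",
            "status": "COMPLETED",
            "tags": ["manual"],
            "startedAt": now,
            "endedAt": now,
            "decisionPoints": [],
            "spans": [
                {
                    "id": str(uuid.uuid4()),
                    "name": "llm-call",
                    "type": "LLM_CALL",
                    "status": "COMPLETED",
                    "startedAt": now,
                    "endedAt": now,
                },
            ],
            "events": [],
        },
    ],
}

# Batch ingest — the same request shape BatchTransport builds from TraceData.to_dict().
resp = httpx.post(
    f"{BASE_URL}/api/traces",
    json=payload,
    headers={"Authorization": f"Bearer {API_KEY}"},
)
print(resp.status_code, resp.json())  # on success: 200 with {"success": true, "count": 1}

# Paginated listing — supports page, limit, status, search, and sessionId filters.
resp = httpx.get(f"{BASE_URL}/api/traces", params={"page": 1, "limit": 20, "status": "COMPLETED"})
listing = resp.json()  # {traces, total, page, limit, totalPages}
print(listing["total"], listing["totalPages"])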