feat: Python SDK real implementation + API ingestion routes

- SDK: client with BatchTransport, trace decorator/context manager,
  log_decision, thread-local context stack, nested trace→span support
  (usage sketch below)
- API: POST /api/traces (batch ingest), GET /api/traces (paginated list),
  GET /api/traces/[id] (full trace with relations), GET /api/health
- Tests: 8 unit tests for SDK (all passing)
- Transport: thread-safe buffer with background flush thread
Vectry
2026-02-09 23:25:34 +00:00
parent 9264866d1f
commit 3fe9013838
12 changed files with 1144 additions and 133 deletions
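
Usage sketch for the SDK surface added in this commit (the API key value is a placeholder; the endpoint shown is the SDK default):

import agentlens
from agentlens import trace, log_decision

agentlens.init(api_key="al_xxx", endpoint="https://agentlens.vectry.tech")

@trace("research-agent")  # also usable as a context manager: with trace("..."):
def research(topic: str) -> str:
    log_decision(
        type="tool_selection",
        chosen={"name": "search"},
        alternatives=[{"name": "calculate"}, {"name": "browse"}],
        reasoning="Search is most appropriate for finding information",
    )
    with trace("inner-step"):  # nested use is recorded as a span on the parent trace
        return f"Researching: {topic}"

research("agent observability")
agentlens.shutdown()  # flushes the background transport before exit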

View File

@@ -0,0 +1,5 @@
import { NextResponse } from "next/server";
export async function GET() {
return NextResponse.json({ status: "ok", service: "agentlens", timestamp: new Date().toISOString() });
}

View File

@@ -0,0 +1,46 @@
import { NextResponse } from "next/server";
import { prisma } from "@/lib/prisma";
// GET /api/traces/[id] — Get single trace with all relations
export async function GET(
_request: Request,
{ params }: { params: Promise<{ id: string }> }
) {
try {
const { id } = await params;
if (!id || typeof id !== "string") {
return NextResponse.json({ error: "Invalid trace ID" }, { status: 400 });
}
const trace = await prisma.trace.findUnique({
where: { id },
include: {
decisionPoints: {
orderBy: {
timestamp: "asc",
},
},
spans: {
orderBy: {
startedAt: "asc",
},
},
events: {
orderBy: {
timestamp: "asc",
},
},
},
});
if (!trace) {
return NextResponse.json({ error: "Trace not found" }, { status: 404 });
}
return NextResponse.json({ trace }, { status: 200 });
} catch (error) {
console.error("Error retrieving trace:", error);
return NextResponse.json({ error: "Internal server error" }, { status: 500 });
}
}
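
For reference, reading a stored trace back through this route might look like the following sketch (the base URL and trace id are placeholders; a local Next.js dev server is assumed):

import httpx

base_url = "http://localhost:3000"  # assumed local dev server
trace_id = "..."  # placeholder

resp = httpx.get(f"{base_url}/api/traces/{trace_id}")
if resp.status_code == 200:
    trace = resp.json()["trace"]
    print(trace["name"], len(trace["decisionPoints"]), len(trace["spans"]), len(trace["events"]))
else:
    print(resp.status_code, resp.json()["error"])  # 400, 404, or 500 per the handler above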

View File

@@ -0,0 +1,324 @@
import { NextRequest, NextResponse } from "next/server";
import { prisma } from "@/lib/prisma";
import { Prisma } from "@agentlens/database";
// Types
interface DecisionPointPayload {
id: string;
type: "TOOL_SELECTION" | "ROUTING" | "RETRY" | "ESCALATION" | "MEMORY_RETRIEVAL" | "PLANNING" | "CUSTOM";
reasoning?: string;
chosen: Prisma.JsonValue;
alternatives: Prisma.JsonValue[];
contextSnapshot?: Prisma.JsonValue;
durationMs?: number;
costUsd?: number;
parentSpanId?: string;
timestamp: string;
}
interface SpanPayload {
id: string;
parentSpanId?: string;
name: string;
type: "LLM_CALL" | "TOOL_CALL" | "MEMORY_OP" | "CHAIN" | "AGENT" | "CUSTOM";
input?: Prisma.JsonValue;
output?: Prisma.JsonValue;
tokenCount?: number;
costUsd?: number;
durationMs?: number;
status: "RUNNING" | "COMPLETED" | "ERROR";
statusMessage?: string;
startedAt: string;
endedAt?: string;
metadata?: Prisma.JsonValue;
}
interface EventPayload {
id: string;
spanId?: string;
type: "ERROR" | "RETRY" | "FALLBACK" | "CONTEXT_OVERFLOW" | "USER_FEEDBACK" | "CUSTOM";
name: string;
metadata?: Prisma.JsonValue;
timestamp: string;
}
interface TracePayload {
id: string;
name: string;
sessionId?: string;
status: "RUNNING" | "COMPLETED" | "ERROR";
tags: string[];
metadata?: Prisma.JsonValue;
totalCost?: number;
totalTokens?: number;
totalDuration?: number;
startedAt: string;
endedAt?: string;
decisionPoints: DecisionPointPayload[];
spans: SpanPayload[];
events: EventPayload[];
}
interface BatchTracesRequest {
traces: TracePayload[];
}
// POST /api/traces — Batch ingest traces from SDK
export async function POST(request: NextRequest) {
try {
// Validate Authorization header
const authHeader = request.headers.get("authorization");
if (!authHeader || !authHeader.startsWith("Bearer ")) {
return NextResponse.json({ error: "Missing or invalid Authorization header" }, { status: 401 });
}
const apiKey = authHeader.slice(7);
if (!apiKey) {
return NextResponse.json({ error: "Missing API key in Authorization header" }, { status: 401 });
}
// Parse and validate request body
const body: BatchTracesRequest = await request.json();
if (!body.traces || !Array.isArray(body.traces)) {
return NextResponse.json({ error: "Request body must contain a 'traces' array" }, { status: 400 });
}
if (body.traces.length === 0) {
return NextResponse.json({ error: "Traces array cannot be empty" }, { status: 400 });
}
// Validate each trace payload
for (const trace of body.traces) {
if (!trace.id || typeof trace.id !== "string") {
return NextResponse.json({ error: "Each trace must have a valid 'id' (string)" }, { status: 400 });
}
if (!trace.name || typeof trace.name !== "string") {
return NextResponse.json({ error: "Each trace must have a valid 'name' (string)" }, { status: 400 });
}
if (!trace.startedAt || typeof trace.startedAt !== "string") {
return NextResponse.json({ error: "Each trace must have a valid 'startedAt' (ISO date string)" }, { status: 400 });
}
if (!["RUNNING", "COMPLETED", "ERROR"].includes(trace.status)) {
return NextResponse.json({ error: `Invalid trace status: ${trace.status}` }, { status: 400 });
}
if (!Array.isArray(trace.tags)) {
return NextResponse.json({ error: "Trace tags must be an array" }, { status: 400 });
}
if (!Array.isArray(trace.decisionPoints)) {
return NextResponse.json({ error: "Trace decisionPoints must be an array" }, { status: 400 });
}
if (!Array.isArray(trace.spans)) {
return NextResponse.json({ error: "Trace spans must be an array" }, { status: 400 });
}
if (!Array.isArray(trace.events)) {
return NextResponse.json({ error: "Trace events must be an array" }, { status: 400 });
}
// Validate decision points
for (const dp of trace.decisionPoints) {
if (!dp.id || typeof dp.id !== "string") {
return NextResponse.json({ error: "Each decision point must have a valid 'id' (string)" }, { status: 400 });
}
if (!["TOOL_SELECTION", "ROUTING", "RETRY", "ESCALATION", "MEMORY_RETRIEVAL", "PLANNING", "CUSTOM"].includes(dp.type)) {
return NextResponse.json({ error: `Invalid decision point type: ${dp.type}` }, { status: 400 });
}
if (!dp.timestamp || typeof dp.timestamp !== "string") {
return NextResponse.json({ error: "Each decision point must have a valid 'timestamp' (ISO date string)" }, { status: 400 });
}
if (!Array.isArray(dp.alternatives)) {
return NextResponse.json({ error: "Decision point alternatives must be an array" }, { status: 400 });
}
}
// Validate spans
for (const span of trace.spans) {
if (!span.id || typeof span.id !== "string") {
return NextResponse.json({ error: "Each span must have a valid 'id' (string)" }, { status: 400 });
}
if (!span.name || typeof span.name !== "string") {
return NextResponse.json({ error: "Each span must have a valid 'name' (string)" }, { status: 400 });
}
if (!["LLM_CALL", "TOOL_CALL", "MEMORY_OP", "CHAIN", "AGENT", "CUSTOM"].includes(span.type)) {
return NextResponse.json({ error: `Invalid span type: ${span.type}` }, { status: 400 });
}
if (!span.startedAt || typeof span.startedAt !== "string") {
return NextResponse.json({ error: "Each span must have a valid 'startedAt' (ISO date string)" }, { status: 400 });
}
if (!["RUNNING", "COMPLETED", "ERROR"].includes(span.status)) {
return NextResponse.json({ error: `Invalid span status: ${span.status}` }, { status: 400 });
}
}
// Validate events
for (const event of trace.events) {
if (!event.id || typeof event.id !== "string") {
return NextResponse.json({ error: "Each event must have a valid 'id' (string)" }, { status: 400 });
}
if (!event.name || typeof event.name !== "string") {
return NextResponse.json({ error: "Each event must have a valid 'name' (string)" }, { status: 400 });
}
if (!["ERROR", "RETRY", "FALLBACK", "CONTEXT_OVERFLOW", "USER_FEEDBACK", "CUSTOM"].includes(event.type)) {
return NextResponse.json({ error: `Invalid event type: ${event.type}` }, { status: 400 });
}
if (!event.timestamp || typeof event.timestamp !== "string") {
return NextResponse.json({ error: "Each event must have a valid 'timestamp' (ISO date string)" }, { status: 400 });
}
}
}
// Insert traces using transaction
const result = await prisma.$transaction(
body.traces.map((trace) =>
prisma.trace.create({
data: {
id: trace.id,
name: trace.name,
sessionId: trace.sessionId,
status: trace.status,
tags: trace.tags,
metadata: trace.metadata as Prisma.InputJsonValue,
totalCost: trace.totalCost,
totalTokens: trace.totalTokens,
totalDuration: trace.totalDuration,
startedAt: new Date(trace.startedAt),
endedAt: trace.endedAt ? new Date(trace.endedAt) : null,
decisionPoints: {
create: trace.decisionPoints.map((dp) => ({
id: dp.id,
type: dp.type,
reasoning: dp.reasoning,
chosen: dp.chosen as Prisma.InputJsonValue,
alternatives: dp.alternatives as Prisma.InputJsonValue[],
contextSnapshot: dp.contextSnapshot as Prisma.InputJsonValue | undefined,
durationMs: dp.durationMs,
costUsd: dp.costUsd,
parentSpanId: dp.parentSpanId,
timestamp: new Date(dp.timestamp),
})),
},
spans: {
create: trace.spans.map((span) => ({
id: span.id,
parentSpanId: span.parentSpanId,
name: span.name,
type: span.type,
input: span.input as Prisma.InputJsonValue | undefined,
output: span.output as Prisma.InputJsonValue | undefined,
tokenCount: span.tokenCount,
costUsd: span.costUsd,
durationMs: span.durationMs,
status: span.status,
statusMessage: span.statusMessage,
startedAt: new Date(span.startedAt),
endedAt: span.endedAt ? new Date(span.endedAt) : null,
metadata: span.metadata as Prisma.InputJsonValue | undefined,
})),
},
events: {
create: trace.events.map((event) => ({
id: event.id,
spanId: event.spanId,
type: event.type,
name: event.name,
metadata: event.metadata as Prisma.InputJsonValue | undefined,
timestamp: new Date(event.timestamp),
})),
},
},
})
)
);
return NextResponse.json({ success: true, count: result.length }, { status: 200 });
} catch (error) {
if (error instanceof SyntaxError) {
return NextResponse.json({ error: "Invalid JSON in request body" }, { status: 400 });
}
// Handle unique constraint violations
if (error instanceof Error && error.message.includes("Unique constraint")) {
return NextResponse.json({ error: "Duplicate trace ID detected" }, { status: 409 });
}
console.error("Error processing traces:", error);
return NextResponse.json({ error: "Internal server error" }, { status: 500 });
}
}
// GET /api/traces — List traces with pagination
export async function GET(request: NextRequest) {
try {
const { searchParams } = new URL(request.url);
const page = parseInt(searchParams.get("page") ?? "1", 10);
const limit = parseInt(searchParams.get("limit") ?? "20", 10);
const status = searchParams.get("status");
const search = searchParams.get("search");
const sessionId = searchParams.get("sessionId");
// Validate pagination parameters
if (isNaN(page) || page < 1) {
return NextResponse.json({ error: "Invalid page parameter. Must be a positive integer." }, { status: 400 });
}
if (isNaN(limit) || limit < 1 || limit > 100) {
return NextResponse.json({ error: "Invalid limit parameter. Must be between 1 and 100." }, { status: 400 });
}
// Validate status parameter if provided
const validStatuses = ["RUNNING", "COMPLETED", "ERROR"];
if (status && !validStatuses.includes(status)) {
return NextResponse.json({ error: `Invalid status. Must be one of: ${validStatuses.join(", ")}` }, { status: 400 });
}
// Build where clause
const where: Record<string, unknown> = {};
if (status) {
where.status = status;
}
if (search) {
where.name = {
contains: search,
mode: "insensitive",
};
}
if (sessionId) {
where.sessionId = sessionId;
}
// Count total traces
const total = await prisma.trace.count({ where });
// Calculate pagination
const skip = (page - 1) * limit;
const totalPages = Math.ceil(total / limit);
// Fetch traces with pagination
const traces = await prisma.trace.findMany({
where,
include: {
_count: {
select: {
decisionPoints: true,
spans: true,
events: true,
},
},
},
orderBy: {
startedAt: "desc",
},
skip,
take: limit,
});
return NextResponse.json({
traces,
total,
page,
limit,
totalPages,
}, { status: 200 });
} catch (error) {
console.error("Error listing traces:", error);
return NextResponse.json({ error: "Internal server error" }, { status: 500 });
}
}
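
For reference, a batch ingest request that passes the validation above might look like this sketch (the base URL and API key are placeholders; the route currently only checks that a Bearer token is present):

import uuid
import httpx
from datetime import datetime, timezone

now = datetime.now(timezone.utc).isoformat()
payload = {
    "traces": [
        {
            "id": str(uuid.uuid4()),
            "name": "manual-ingest-example",
            "status": "COMPLETED",
            "tags": ["example"],
            "startedAt": now,
            "endedAt": now,
            "decisionPoints": [],
            "spans": [],
            "events": [],
        }
    ]
}
resp = httpx.post(
    "http://localhost:3000/api/traces",  # assumed local dev server
    json=payload,
    headers={"Authorization": "Bearer <api-key>"},
)
print(resp.status_code, resp.json())  # success response is {success: true, count: 1}; duplicate trace ids yield 409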

View File

@@ -0,0 +1,9 @@
import { PrismaClient } from "@agentlens/database";
const globalForPrisma = globalThis as unknown as {
prisma: PrismaClient | undefined;
};
export const prisma = globalForPrisma.prisma ?? new PrismaClient();
if (process.env.NODE_ENV !== "production") globalForPrisma.prisma = prisma;

View File

@@ -1,8 +1,47 @@
"""AgentLens - Agent observability that traces decisions, not just API calls."""
from agentlens.client import init, shutdown
from agentlens.trace import trace
from agentlens.client import init, shutdown, get_client
from agentlens.decision import log_decision
from agentlens.models import (
TraceData,
DecisionPoint,
Span,
Event,
TraceStatus,
DecisionType,
SpanType,
SpanStatus,
EventType,
)
import sys
if sys.version_info >= (3, 7):
import importlib
__trace_module = importlib.import_module("agentlens.trace")
trace = getattr(__trace_module, "trace")
TraceContext = getattr(__trace_module, "TraceContext")
get_current_trace = getattr(__trace_module, "get_current_trace")
else:
from agentlens.trace import trace, TraceContext, get_current_trace
__version__ = "0.1.0"
__all__ = ["init", "shutdown", "trace", "log_decision"]
__all__ = [
"init",
"shutdown",
"get_client",
"trace",
"TraceContext",
"get_current_trace",
"log_decision",
"TraceData",
"DecisionPoint",
"Span",
"Event",
"TraceStatus",
"DecisionType",
"SpanType",
"SpanStatus",
"EventType",
]

View File

@@ -1,43 +1,107 @@
"""Client initialization and management for AgentLens."""
from typing import Optional
import atexit
import logging
from typing import Optional, List
from agentlens.models import TraceData
from agentlens.transport import BatchTransport
logger = logging.getLogger("agentlens")
_client: Optional["AgentLensClient"] = None
_client: Optional["_Client"] = None
class _Client:
"""Internal client class for managing AgentLens connection."""
def __init__(self, api_key: str, endpoint: str) -> None:
class AgentLensClient:
def __init__(
self,
api_key: str,
endpoint: str = "https://agentlens.vectry.tech",
flush_interval: float = 5.0,
max_batch_size: int = 10,
enabled: bool = True,
) -> None:
self.api_key = api_key
self.endpoint = endpoint
self.is_shutdown = False
self.endpoint = endpoint.rstrip("/")
self.flush_interval = flush_interval
self.max_batch_size = max_batch_size
self.enabled = enabled
self._transport: Optional[BatchTransport] = None
self._sent_traces: List[TraceData] = []
if enabled:
try:
self._transport = BatchTransport(
api_key=api_key,
endpoint=endpoint,
flush_interval=flush_interval,
max_batch_size=max_batch_size,
)
except ImportError:
logger.warning("httpx not available, transport disabled")
self.enabled = False
def send_trace(self, trace: TraceData) -> None:
if not self.enabled:
self._sent_traces.append(trace)
return
if self._transport:
self._transport.add(trace)
def shutdown(self) -> None:
"""Shutdown the client."""
self.is_shutdown = True
if self._transport:
self._transport.flush()
self._transport.shutdown()
self._transport = None
@property
def sent_traces(self) -> List[TraceData]:
return self._sent_traces
@property
def transport(self) -> Optional["BatchTransport"]:
return self._transport
def init(api_key: str, endpoint: str = "https://agentlens.vectry.tech") -> None:
"""Initialize the AgentLens client.
def init(
api_key: str,
endpoint: str = "https://agentlens.vectry.tech",
flush_interval: float = 5.0,
max_batch_size: int = 10,
enabled: bool = True,
) -> None:
"""Initialize the AgentLens SDK.
Args:
api_key: Your AgentLens API key.
endpoint: The AgentLens API endpoint (default: https://agentlens.vectry.tech).
endpoint: The AgentLens API endpoint.
flush_interval: Seconds between automatic flushes (default: 5.0).
max_batch_size: Number of traces before auto-flush (default: 10).
enabled: Set to False to disable sending (useful for testing).
"""
global _client
_client = _Client(api_key=api_key, endpoint=endpoint)
if _client is not None:
_client.shutdown()
_client = AgentLensClient(
api_key=api_key,
endpoint=endpoint,
flush_interval=flush_interval,
max_batch_size=max_batch_size,
enabled=enabled,
)
atexit.register(shutdown)
logger.debug("AgentLens initialized: endpoint=%s", endpoint)
def shutdown() -> None:
"""Shutdown the AgentLens client."""
global _client
if _client:
if _client is not None:
_client.shutdown()
_client = None
logger.debug("AgentLens shutdown complete")
def get_client() -> Optional[_Client]:
"""Get the current client instance."""
def get_client() -> Optional[AgentLensClient]:
"""Get the current client instance. Returns None if not initialized."""
return _client

View File

@@ -1,32 +1,47 @@
"""Decision logging for tracking agent decision points."""
from typing import Any, Dict, List
import logging
from typing import Any, Dict, List, Optional
from agentlens.models import DecisionPoint
from agentlens.trace import get_current_trace, get_current_span_id
logger = logging.getLogger("agentlens")
def log_decision(
type: str,
chosen: Any,
alternatives: List[Any],
chosen: Dict[str, Any],
alternatives: Optional[List[Dict[str, Any]]] = None,
reasoning: Optional[str] = None,
) -> None:
"""Log a decision point in the agent's reasoning.
Args:
type: Type of decision (e.g., "tool_selection", "routing", "retry").
chosen: The option that was selected.
alternatives: List of alternatives that were considered.
reasoning: Optional explanation for the decision.
Example:
log_decision(
type="tool_selection",
chosen="search",
alternatives=["search", "calculate", "browse"],
reasoning="Search is most appropriate for finding information"
context_snapshot: Optional[Dict[str, Any]] = None,
cost_usd: Optional[float] = None,
duration_ms: Optional[int] = None,
) -> Optional[DecisionPoint]:
current_trace = get_current_trace()
if current_trace is None:
logger.warning(
"AgentLens: log_decision called outside of a trace context. Decision not recorded."
)
"""
print(f"[AgentLens] Decision logged: {type}")
print(f"[AgentLens] Chosen: {chosen}")
print(f"[AgentLens] Alternatives: {alternatives}")
if reasoning:
print(f"[AgentLens] Reasoning: {reasoning}")
return None
parent_span_id = get_current_span_id()
decision = DecisionPoint(
type=type,
chosen=chosen,
alternatives=alternatives or [],
reasoning=reasoning,
context_snapshot=context_snapshot,
cost_usd=cost_usd,
duration_ms=duration_ms,
parent_span_id=parent_span_id,
)
current_trace.decision_points.append(decision)
logger.debug(
"AgentLens: Decision logged: type=%s, chosen=%s",
type,
chosen.get("name", str(chosen)),
)
return decision

View File

@@ -0,0 +1,211 @@
"""Data models for AgentLens SDK, matching the server-side Prisma schema."""
import time
import uuid
from dataclasses import dataclass, field, asdict
from typing import Any, Dict, List, Optional
from enum import Enum
class TraceStatus(str, Enum):
RUNNING = "RUNNING"
COMPLETED = "COMPLETED"
ERROR = "ERROR"
class DecisionType(str, Enum):
TOOL_SELECTION = "TOOL_SELECTION"
ROUTING = "ROUTING"
RETRY = "RETRY"
ESCALATION = "ESCALATION"
MEMORY_RETRIEVAL = "MEMORY_RETRIEVAL"
PLANNING = "PLANNING"
CUSTOM = "CUSTOM"
class SpanType(str, Enum):
LLM_CALL = "LLM_CALL"
TOOL_CALL = "TOOL_CALL"
MEMORY_OP = "MEMORY_OP"
CHAIN = "CHAIN"
AGENT = "AGENT"
CUSTOM = "CUSTOM"
class SpanStatus(str, Enum):
RUNNING = "RUNNING"
COMPLETED = "COMPLETED"
ERROR = "ERROR"
class EventType(str, Enum):
ERROR = "ERROR"
RETRY = "RETRY"
FALLBACK = "FALLBACK"
CONTEXT_OVERFLOW = "CONTEXT_OVERFLOW"
USER_FEEDBACK = "USER_FEEDBACK"
CUSTOM = "CUSTOM"
def _generate_id() -> str:
return str(uuid.uuid4())
def _now_ms() -> int:
return int(time.time() * 1000)
def _now_iso() -> str:
from datetime import datetime, timezone
return datetime.now(timezone.utc).isoformat()
@dataclass
class DecisionPoint:
type: str # DecisionType value
chosen: Dict[str, Any] # {name, confidence, params}
alternatives: List[Dict[str, Any]] # [{name, confidence, reason_rejected}]
reasoning: Optional[str] = None
context_snapshot: Optional[Dict[str, Any]] = (
None # {window_usage_pct, tokens_used, tokens_available}
)
duration_ms: Optional[int] = None
cost_usd: Optional[float] = None
parent_span_id: Optional[str] = None
id: str = field(default_factory=_generate_id)
timestamp: str = field(default_factory=_now_iso)
def to_dict(self) -> Dict[str, Any]:
d = {
"id": self.id,
"type": self.type,
"chosen": self.chosen,
"alternatives": self.alternatives,
"timestamp": self.timestamp,
}
if self.reasoning is not None:
d["reasoning"] = self.reasoning
if self.context_snapshot is not None:
d["contextSnapshot"] = self.context_snapshot
if self.duration_ms is not None:
d["durationMs"] = self.duration_ms
if self.cost_usd is not None:
d["costUsd"] = self.cost_usd
if self.parent_span_id is not None:
d["parentSpanId"] = self.parent_span_id
return d
@dataclass
class Span:
name: str
type: str # SpanType value
id: str = field(default_factory=_generate_id)
parent_span_id: Optional[str] = None
input_data: Optional[Any] = None
output_data: Optional[Any] = None
token_count: Optional[int] = None
cost_usd: Optional[float] = None
duration_ms: Optional[int] = None
status: str = field(default_factory=lambda: SpanStatus.RUNNING.value)
status_message: Optional[str] = None
started_at: str = field(default_factory=_now_iso)
ended_at: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None
def to_dict(self) -> Dict[str, Any]:
d = {
"id": self.id,
"name": self.name,
"type": self.type,
"status": self.status,
"startedAt": self.started_at,
}
if self.parent_span_id is not None:
d["parentSpanId"] = self.parent_span_id
if self.input_data is not None:
d["input"] = self.input_data
if self.output_data is not None:
d["output"] = self.output_data
if self.token_count is not None:
d["tokenCount"] = self.token_count
if self.cost_usd is not None:
d["costUsd"] = self.cost_usd
if self.duration_ms is not None:
d["durationMs"] = self.duration_ms
if self.status_message is not None:
d["statusMessage"] = self.status_message
if self.ended_at is not None:
d["endedAt"] = self.ended_at
if self.metadata is not None:
d["metadata"] = self.metadata
return d
@dataclass
class Event:
type: str # EventType value
name: str
span_id: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None
id: str = field(default_factory=_generate_id)
timestamp: str = field(default_factory=_now_iso)
def to_dict(self) -> Dict[str, Any]:
d = {
"id": self.id,
"type": self.type,
"name": self.name,
"timestamp": self.timestamp,
}
if self.span_id is not None:
d["spanId"] = self.span_id
if self.metadata is not None:
d["metadata"] = self.metadata
return d
@dataclass
class TraceData:
"""Represents a complete trace to be sent to the server."""
name: str
id: str = field(default_factory=_generate_id)
session_id: Optional[str] = None
status: str = field(default_factory=lambda: TraceStatus.RUNNING.value)
tags: List[str] = field(default_factory=list)
metadata: Optional[Dict[str, Any]] = None
total_cost: Optional[float] = None
total_tokens: Optional[int] = None
total_duration: Optional[int] = None # ms
started_at: str = field(default_factory=_now_iso)
ended_at: Optional[str] = None
decision_points: List[DecisionPoint] = field(default_factory=list)
spans: List[Span] = field(default_factory=list)
events: List[Event] = field(default_factory=list)
def to_dict(self) -> Dict[str, Any]:
d = {
"id": self.id,
"name": self.name,
"status": self.status,
"tags": self.tags,
"startedAt": self.started_at,
"decisionPoints": [dp.to_dict() for dp in self.decision_points],
"spans": [s.to_dict() for s in self.spans],
"events": [e.to_dict() for e in self.events],
}
if self.session_id is not None:
d["sessionId"] = self.session_id
if self.metadata is not None:
d["metadata"] = self.metadata
if self.total_cost is not None:
d["totalCost"] = self.total_cost
if self.total_tokens is not None:
d["totalTokens"] = self.total_tokens
if self.total_duration is not None:
d["totalDuration"] = self.total_duration
if self.ended_at is not None:
d["endedAt"] = self.ended_at
return d

View File

@@ -1,76 +1,191 @@
"""Trace decorator and context manager for instrumenting agent functions."""
from typing import Callable, Optional, Any
import asyncio
import logging
import threading
import time
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Union
from agentlens.models import (
TraceData,
Span,
SpanType,
TraceStatus,
SpanStatus,
Event,
EventType,
_now_iso,
)
logger = logging.getLogger("agentlens")
_context = threading.local()
def trace(name: Optional[str] = None) -> Callable[..., Any]:
"""Decorator to trace a function or method.
def _get_context_stack() -> List[Union[TraceData, Span]]:
if not hasattr(_context, "stack"):
_context.stack = []
return _context.stack
Args:
name: Name for the trace. If not provided, uses the function name.
Returns:
Decorated function with tracing enabled.
def get_current_trace() -> Optional[TraceData]:
stack = _get_context_stack()
if not stack:
return None
for item in stack:
if isinstance(item, TraceData):
return item
return None
Example:
@trace(name="research-agent")
async def research(topic: str) -> str:
return f"Researching: {topic}"
"""
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
def get_current_span_id() -> Optional[str]:
stack = _get_context_stack()
if not stack:
return None
for item in reversed(stack):
if isinstance(item, Span):
return item.id
return None
class TraceContext:
def __init__(
self,
name: str,
tags: Optional[List[str]] = None,
session_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
self.name = name or "trace"
self.tags = tags or []
self.session_id = session_id
self.metadata = metadata
self._trace_data: Optional[TraceData] = None
self._span: Optional[Span] = None
self._start_time: float = 0
self._is_nested: bool = False
def __enter__(self) -> "TraceContext":
self._start_time = time.time()
stack = _get_context_stack()
if stack:
self._is_nested = True
parent_trace = get_current_trace()
parent_span_id = get_current_span_id()
self._span = Span(
name=self.name,
type=SpanType.AGENT.value,
parent_span_id=parent_span_id,
started_at=_now_iso(),
)
if parent_trace:
parent_trace.spans.append(self._span)
stack.append(self._span)
else:
self._trace_data = TraceData(
name=self.name,
tags=self.tags,
session_id=self.session_id,
metadata=self.metadata,
status=TraceStatus.RUNNING.value,
started_at=_now_iso(),
)
stack.append(self._trace_data)
return self
def __exit__(
self,
exc_type: Optional[type],
exc_val: Optional[BaseException],
exc_tb: Optional[Any],
) -> None:
from agentlens.client import get_client
client = get_client()
end_time = time.time()
total_duration = int((end_time - self._start_time) * 1000)
stack = _get_context_stack()
if self._is_nested and self._span:
self._span.status = (
SpanStatus.COMPLETED.value
if exc_type is None
else SpanStatus.ERROR.value
)
self._span.duration_ms = total_duration
self._span.ended_at = _now_iso()
stack.pop()
elif self._trace_data:
if exc_type is not None:
self._trace_data.status = TraceStatus.ERROR.value
error_event = Event(
type=EventType.ERROR.value,
name=str(exc_val) if exc_val else "Unknown error",
)
self._trace_data.events.append(error_event)
else:
self._trace_data.status = TraceStatus.COMPLETED.value
self._trace_data.total_duration = total_duration
self._trace_data.ended_at = _now_iso()
stack.pop()
client = get_client()
if client:
client.send_trace(self._trace_data)
def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]:
if asyncio.iscoroutinefunction(func):
@wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
trace_name = name or func.__name__
print(f"[AgentLens] Starting trace: {trace_name}")
try:
result = await func(*args, **kwargs)
print(f"[AgentLens] Completed trace: {trace_name}")
return result
except Exception as e:
print(f"[AgentLens] Error in trace {trace_name}: {e}")
raise
with TraceContext(
name=self.name,
tags=self.tags,
session_id=self.session_id,
metadata=self.metadata,
):
return await func(*args, **kwargs)
return async_wrapper
else:
@wraps(func)
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
trace_name = name or func.__name__
print(f"[AgentLens] Starting trace: {trace_name}")
try:
result = func(*args, **kwargs)
print(f"[AgentLens] Completed trace: {trace_name}")
return result
except Exception as e:
print(f"[AgentLens] Error in trace {trace_name}: {e}")
raise
with TraceContext(
name=self.name,
tags=self.tags,
session_id=self.session_id,
metadata=self.metadata,
):
return func(*args, **kwargs)
if hasattr(func, "__await__"):
return async_wrapper
else:
return sync_wrapper
return decorator
@property
def trace_id(self) -> Optional[str]:
if self._trace_data:
return self._trace_data.id
return None
class Tracer:
"""Context manager for creating traces.
def trace(
name: Union[Callable[..., Any], str, None] = None,
tags: Optional[List[str]] = None,
session_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Union[TraceContext, Callable[..., Any]]:
if callable(name):
func = name
ctx = TraceContext(
name=func.__name__, tags=tags, session_id=session_id, metadata=metadata
)
return ctx(func)
Example:
with Tracer(name="custom-operation"):
# Your code here
pass
"""
def __init__(self, name: str) -> None:
self.name = name
def __enter__(self) -> "Tracer":
print(f"[AgentLens] Starting trace: {self.name}")
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
if exc_type is None:
print(f"[AgentLens] Completed trace: {self.name}")
else:
print(f"[AgentLens] Error in trace {self.name}: {exc_val}")
return False
return TraceContext(
name=name or "trace", tags=tags, session_id=session_id, metadata=metadata
)

View File

@@ -1,38 +1,92 @@
"""Batch transport for sending data to AgentLens API."""
"""Batch transport for sending trace data to AgentLens API."""
from typing import List, Dict, Any
import json
import logging
import threading
import time
from typing import Any, Dict, List, Optional
import httpx
from agentlens.models import TraceData
logger = logging.getLogger("agentlens")
class BatchTransport:
"""Transport layer that batches events for efficient API calls.
"""Thread-safe batch transport that buffers and sends traces to API."""
This class handles batching and sending of traces, decisions, and other
events to the AgentLens backend.
"""
def __init__(
self,
api_key: str,
endpoint: str,
max_batch_size: int = 10,
flush_interval: float = 5.0,
) -> None:
self._api_key = api_key
self._endpoint = endpoint.rstrip("/")
self._max_batch_size = max_batch_size
self._flush_interval = flush_interval
self._buffer: List[Dict[str, Any]] = []
self._lock = threading.Lock()
self._shutdown_event = threading.Event()
self._flush_thread = threading.Thread(
target=self._flush_loop, daemon=True, name="agentlens-flush"
)
self._flush_thread.start()
self._client = httpx.Client(timeout=30.0)
def __init__(self, max_batch_size: int = 100, flush_interval: float = 1.0) -> None:
self.max_batch_size = max_batch_size
self.flush_interval = flush_interval
self._batch: List[Dict[str, Any]] = []
def add(self, trace: TraceData) -> None:
"""Add a completed trace to send buffer."""
trace_dict = trace.to_dict()
with self._lock:
self._buffer.append(trace_dict)
should_flush = len(self._buffer) >= self._max_batch_size
if should_flush:
self._do_flush()
def add(self, event: Dict[str, Any]) -> None:
"""Add an event to the batch.
def _flush_loop(self) -> None:
"""Background loop that periodically flushes buffer."""
while not self._shutdown_event.is_set():
self._shutdown_event.wait(timeout=self._flush_interval)
if not self._shutdown_event.is_set():
self._do_flush()
Args:
event: Event data to be sent.
"""
self._batch.append(event)
if len(self._batch) >= self.max_batch_size:
self.flush()
def _do_flush(self) -> None:
"""Flush all buffered traces to the API."""
with self._lock:
if not self._buffer:
return
batch = self._buffer.copy()
self._buffer.clear()
try:
response = self._client.post(
f"{self._endpoint}/api/traces",
json={"traces": batch},
headers={
"Authorization": f"Bearer {self._api_key}",
"Content-Type": "application/json",
},
)
if response.status_code >= 400:
logger.warning(
"AgentLens: Failed to send traces (HTTP %d): %s",
response.status_code,
response.text[:200],
)
except Exception as e:
logger.warning("AgentLens: Failed to send traces: %s", e)
# Put traces back in buffer for retry (optional, up to you)
# For now, drop on failure to avoid unbounded growth
def flush(self) -> None:
"""Flush the batch by sending all pending events."""
if not self._batch:
return
print(f"[AgentLens] Flushing batch of {len(self._batch)} events")
self._batch.clear()
"""Manually flush all buffered traces."""
self._do_flush()
def shutdown(self) -> None:
"""Shutdown the transport, flushing any remaining events."""
self.flush()
"""Shutdown transport: flush remaining traces and stop background thread."""
self._shutdown_event.set()
self._flush_thread.join(timeout=5.0)
self._do_flush() # Final flush
self._client.close()

View File

View File

@@ -0,0 +1,129 @@
import unittest
from agentlens import init, shutdown, get_client, trace, log_decision
from agentlens.models import TraceData, TraceStatus
class TestSDK(unittest.TestCase):
def setUp(self):
shutdown()
init(api_key="test_key", enabled=False)
def tearDown(self):
shutdown()
def test_init_and_shutdown(self):
client = get_client()
self.assertIsNotNone(client)
self.assertEqual(client.api_key, "test_key")
self.assertFalse(client.enabled)
shutdown()
self.assertIsNone(get_client())
def test_decorator_sync_function(self):
@trace("test_decorator")
def test_func():
return 42
result = test_func()
self.assertEqual(result, 42)
client = get_client()
traces = client.sent_traces
self.assertEqual(len(traces), 1)
self.assertEqual(traces[0].name, "test_decorator")
self.assertEqual(traces[0].status, TraceStatus.COMPLETED.value)
self.assertIsNotNone(traces[0].total_duration)
def test_context_manager(self):
with trace("test_context") as t:
pass
client = get_client()
traces = client.sent_traces
self.assertEqual(len(traces), 1)
self.assertEqual(traces[0].name, "test_context")
self.assertEqual(traces[0].status, TraceStatus.COMPLETED.value)
def test_nested_traces(self):
with trace("outer_trace"):
with trace("inner_span"):
pass
client = get_client()
traces = client.sent_traces
self.assertEqual(len(traces), 1)
self.assertEqual(traces[0].name, "outer_trace")
self.assertEqual(len(traces[0].spans), 1)
self.assertEqual(traces[0].spans[0].name, "inner_span")
def test_log_decision(self):
with trace("test_trace"):
log_decision(
type="tool_selection",
chosen={"name": "search"},
alternatives=[{"name": "calculate"}, {"name": "browse"}],
reasoning="Search is best for finding information",
)
client = get_client()
traces = client.sent_traces
self.assertEqual(len(traces), 1)
self.assertEqual(len(traces[0].decision_points), 1)
decision = traces[0].decision_points[0]
self.assertEqual(decision.type, "tool_selection")
self.assertEqual(decision.chosen, {"name": "search"})
self.assertEqual(len(decision.alternatives), 2)
self.assertEqual(decision.reasoning, "Search is best for finding information")
def test_error_handling(self):
with self.assertRaises(ValueError):
with trace("test_error"):
raise ValueError("Test error")
client = get_client()
traces = client.sent_traces
self.assertEqual(len(traces), 1)
self.assertEqual(traces[0].name, "test_error")
self.assertEqual(traces[0].status, TraceStatus.ERROR.value)
self.assertEqual(len(traces[0].events), 1)
self.assertEqual(traces[0].events[0].type, "ERROR")
self.assertIn("Test error", traces[0].events[0].name)
def test_log_decision_outside_trace(self):
shutdown()
log_decision(
type="tool_selection",
chosen={"name": "search"},
alternatives=[],
)
init(api_key="test_key", enabled=False)
client = get_client()
self.assertEqual(len(client.sent_traces), 0)
def test_model_serialization(self):
trace_data = TraceData(
name="test",
tags=["tag1", "tag2"],
session_id="session123",
metadata={"key": "value"},
)
trace_dict = trace_data.to_dict()
self.assertEqual(trace_dict["name"], "test")
self.assertEqual(trace_dict["tags"], ["tag1", "tag2"])
self.assertEqual(trace_dict["sessionId"], "session123")
self.assertEqual(trace_dict["metadata"], {"key": "value"})
self.assertEqual(trace_dict["status"], "RUNNING")
self.assertIn("startedAt", trace_dict)
self.assertIn("decisionPoints", trace_dict)
self.assertIn("spans", trace_dict)
self.assertIn("events", trace_dict)
if __name__ == "__main__":
unittest.main()