feat: Python SDK real implementation + API ingestion routes
- SDK: client with BatchTransport, trace decorator/context manager, log_decision, thread-local context stack, nested trace→span support
- API: POST /api/traces (batch ingest), GET /api/traces (paginated list), GET /api/traces/[id] (full trace with relations), GET /api/health
- Tests: 8 unit tests for SDK (all passing)
- Transport: thread-safe buffer with background flush thread
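For context, a minimal usage sketch of the SDK introduced in this commit (the API key and endpoint are placeholders; parameter names follow the code below):

import agentlens
from agentlens import trace, log_decision
from agentlens.models import DecisionType

# Placeholder credentials; any deployment of the web app in this commit would do.
agentlens.init(api_key="al_demo_key", endpoint="http://localhost:3000")

@trace(name="research-agent", tags=["demo"])
def research(topic: str) -> str:
    # Recorded on the current trace via the thread-local context stack.
    log_decision(
        type=DecisionType.TOOL_SELECTION.value,
        chosen={"name": "search"},
        alternatives=[{"name": "calculate"}, {"name": "browse"}],
        reasoning="Search fits an information-gathering task",
    )
    return f"Researching: {topic}"

research("agent observability")
agentlens.shutdown()  # flushes any buffered traces before exit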
apps/web/src/app/api/health/route.ts (new file, 5 lines)
@@ -0,0 +1,5 @@
import { NextResponse } from "next/server";

export async function GET() {
  return NextResponse.json({ status: "ok", service: "agentlens", timestamp: new Date().toISOString() });
}
apps/web/src/app/api/traces/[id]/route.ts (new file, 46 lines)
@@ -0,0 +1,46 @@
import { NextResponse } from "next/server";
import { prisma } from "@/lib/prisma";

// GET /api/traces/[id] — Get single trace with all relations
export async function GET(
  _request: Request,
  { params }: { params: Promise<{ id: string }> }
) {
  try {
    const { id } = await params;

    if (!id || typeof id !== "string") {
      return NextResponse.json({ error: "Invalid trace ID" }, { status: 400 });
    }

    const trace = await prisma.trace.findUnique({
      where: { id },
      include: {
        decisionPoints: {
          orderBy: {
            timestamp: "asc",
          },
        },
        spans: {
          orderBy: {
            startedAt: "asc",
          },
        },
        events: {
          orderBy: {
            timestamp: "asc",
          },
        },
      },
    });

    if (!trace) {
      return NextResponse.json({ error: "Trace not found" }, { status: 404 });
    }

    return NextResponse.json({ trace }, { status: 200 });
  } catch (error) {
    console.error("Error retrieving trace:", error);
    return NextResponse.json({ error: "Internal server error" }, { status: 500 });
  }
}
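For illustration, a hedged sketch of calling this endpoint from Python (host and trace ID are placeholders; the response shape follows the handler above):

import httpx

# Placeholder host and ID; substitute a real trace ID returned by the ingest route.
resp = httpx.get("http://localhost:3000/api/traces/00000000-0000-0000-0000-000000000000")
if resp.status_code == 200:
    trace = resp.json()["trace"]
    print(trace["name"], trace["status"], len(trace["spans"]))
elif resp.status_code == 404:
    print("Trace not found")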
apps/web/src/app/api/traces/route.ts (new file, 324 lines)
@@ -0,0 +1,324 @@
import { NextRequest, NextResponse } from "next/server";
import { prisma } from "@/lib/prisma";
import { Prisma } from "@agentlens/database";

// Types
interface DecisionPointPayload {
  id: string;
  type: "TOOL_SELECTION" | "ROUTING" | "RETRY" | "ESCALATION" | "MEMORY_RETRIEVAL" | "PLANNING" | "CUSTOM";
  reasoning?: string;
  chosen: Prisma.JsonValue;
  alternatives: Prisma.JsonValue[];
  contextSnapshot?: Prisma.JsonValue;
  durationMs?: number;
  costUsd?: number;
  parentSpanId?: string;
  timestamp: string;
}

interface SpanPayload {
  id: string;
  parentSpanId?: string;
  name: string;
  type: "LLM_CALL" | "TOOL_CALL" | "MEMORY_OP" | "CHAIN" | "AGENT" | "CUSTOM";
  input?: Prisma.JsonValue;
  output?: Prisma.JsonValue;
  tokenCount?: number;
  costUsd?: number;
  durationMs?: number;
  status: "RUNNING" | "COMPLETED" | "ERROR";
  statusMessage?: string;
  startedAt: string;
  endedAt?: string;
  metadata?: Prisma.JsonValue;
}

interface EventPayload {
  id: string;
  spanId?: string;
  type: "ERROR" | "RETRY" | "FALLBACK" | "CONTEXT_OVERFLOW" | "USER_FEEDBACK" | "CUSTOM";
  name: string;
  metadata?: Prisma.JsonValue;
  timestamp: string;
}

interface TracePayload {
  id: string;
  name: string;
  sessionId?: string;
  status: "RUNNING" | "COMPLETED" | "ERROR";
  tags: string[];
  metadata?: Prisma.JsonValue;
  totalCost?: number;
  totalTokens?: number;
  totalDuration?: number;
  startedAt: string;
  endedAt?: string;
  decisionPoints: DecisionPointPayload[];
  spans: SpanPayload[];
  events: EventPayload[];
}

interface BatchTracesRequest {
  traces: TracePayload[];
}

// POST /api/traces — Batch ingest traces from SDK
export async function POST(request: NextRequest) {
  try {
    // Validate Authorization header
    const authHeader = request.headers.get("authorization");
    if (!authHeader || !authHeader.startsWith("Bearer ")) {
      return NextResponse.json({ error: "Missing or invalid Authorization header" }, { status: 401 });
    }

    const apiKey = authHeader.slice(7);
    if (!apiKey) {
      return NextResponse.json({ error: "Missing API key in Authorization header" }, { status: 401 });
    }

    // Parse and validate request body
    const body: BatchTracesRequest = await request.json();
    if (!body.traces || !Array.isArray(body.traces)) {
      return NextResponse.json({ error: "Request body must contain a 'traces' array" }, { status: 400 });
    }

    if (body.traces.length === 0) {
      return NextResponse.json({ error: "Traces array cannot be empty" }, { status: 400 });
    }

    // Validate each trace payload
    for (const trace of body.traces) {
      if (!trace.id || typeof trace.id !== "string") {
        return NextResponse.json({ error: "Each trace must have a valid 'id' (string)" }, { status: 400 });
      }
      if (!trace.name || typeof trace.name !== "string") {
        return NextResponse.json({ error: "Each trace must have a valid 'name' (string)" }, { status: 400 });
      }
      if (!trace.startedAt || typeof trace.startedAt !== "string") {
        return NextResponse.json({ error: "Each trace must have a valid 'startedAt' (ISO date string)" }, { status: 400 });
      }
      if (!["RUNNING", "COMPLETED", "ERROR"].includes(trace.status)) {
        return NextResponse.json({ error: `Invalid trace status: ${trace.status}` }, { status: 400 });
      }
      if (!Array.isArray(trace.tags)) {
        return NextResponse.json({ error: "Trace tags must be an array" }, { status: 400 });
      }
      if (!Array.isArray(trace.decisionPoints)) {
        return NextResponse.json({ error: "Trace decisionPoints must be an array" }, { status: 400 });
      }
      if (!Array.isArray(trace.spans)) {
        return NextResponse.json({ error: "Trace spans must be an array" }, { status: 400 });
      }
      if (!Array.isArray(trace.events)) {
        return NextResponse.json({ error: "Trace events must be an array" }, { status: 400 });
      }

      // Validate decision points
      for (const dp of trace.decisionPoints) {
        if (!dp.id || typeof dp.id !== "string") {
          return NextResponse.json({ error: "Each decision point must have a valid 'id' (string)" }, { status: 400 });
        }
        if (!["TOOL_SELECTION", "ROUTING", "RETRY", "ESCALATION", "MEMORY_RETRIEVAL", "PLANNING", "CUSTOM"].includes(dp.type)) {
          return NextResponse.json({ error: `Invalid decision point type: ${dp.type}` }, { status: 400 });
        }
        if (!dp.timestamp || typeof dp.timestamp !== "string") {
          return NextResponse.json({ error: "Each decision point must have a valid 'timestamp' (ISO date string)" }, { status: 400 });
        }
        if (!Array.isArray(dp.alternatives)) {
          return NextResponse.json({ error: "Decision point alternatives must be an array" }, { status: 400 });
        }
      }

      // Validate spans
      for (const span of trace.spans) {
        if (!span.id || typeof span.id !== "string") {
          return NextResponse.json({ error: "Each span must have a valid 'id' (string)" }, { status: 400 });
        }
        if (!span.name || typeof span.name !== "string") {
          return NextResponse.json({ error: "Each span must have a valid 'name' (string)" }, { status: 400 });
        }
        if (!["LLM_CALL", "TOOL_CALL", "MEMORY_OP", "CHAIN", "AGENT", "CUSTOM"].includes(span.type)) {
          return NextResponse.json({ error: `Invalid span type: ${span.type}` }, { status: 400 });
        }
        if (!span.startedAt || typeof span.startedAt !== "string") {
          return NextResponse.json({ error: "Each span must have a valid 'startedAt' (ISO date string)" }, { status: 400 });
        }
        if (!["RUNNING", "COMPLETED", "ERROR"].includes(span.status)) {
          return NextResponse.json({ error: `Invalid span status: ${span.status}` }, { status: 400 });
        }
      }

      // Validate events
      for (const event of trace.events) {
        if (!event.id || typeof event.id !== "string") {
          return NextResponse.json({ error: "Each event must have a valid 'id' (string)" }, { status: 400 });
        }
        if (!event.name || typeof event.name !== "string") {
          return NextResponse.json({ error: "Each event must have a valid 'name' (string)" }, { status: 400 });
        }
        if (!["ERROR", "RETRY", "FALLBACK", "CONTEXT_OVERFLOW", "USER_FEEDBACK", "CUSTOM"].includes(event.type)) {
          return NextResponse.json({ error: `Invalid event type: ${event.type}` }, { status: 400 });
        }
        if (!event.timestamp || typeof event.timestamp !== "string") {
          return NextResponse.json({ error: "Each event must have a valid 'timestamp' (ISO date string)" }, { status: 400 });
        }
      }
    }

    // Insert traces using transaction
    const result = await prisma.$transaction(
      body.traces.map((trace) =>
        prisma.trace.create({
          data: {
            id: trace.id,
            name: trace.name,
            sessionId: trace.sessionId,
            status: trace.status,
            tags: trace.tags,
            metadata: trace.metadata as Prisma.InputJsonValue,
            totalCost: trace.totalCost,
            totalTokens: trace.totalTokens,
            totalDuration: trace.totalDuration,
            startedAt: new Date(trace.startedAt),
            endedAt: trace.endedAt ? new Date(trace.endedAt) : null,
            decisionPoints: {
              create: trace.decisionPoints.map((dp) => ({
                id: dp.id,
                type: dp.type,
                reasoning: dp.reasoning,
                chosen: dp.chosen as Prisma.InputJsonValue,
                alternatives: dp.alternatives as Prisma.InputJsonValue[],
                contextSnapshot: dp.contextSnapshot as Prisma.InputJsonValue | undefined,
                durationMs: dp.durationMs,
                costUsd: dp.costUsd,
                parentSpanId: dp.parentSpanId,
                timestamp: new Date(dp.timestamp),
              })),
            },
            spans: {
              create: trace.spans.map((span) => ({
                id: span.id,
                parentSpanId: span.parentSpanId,
                name: span.name,
                type: span.type,
                input: span.input as Prisma.InputJsonValue | undefined,
                output: span.output as Prisma.InputJsonValue | undefined,
                tokenCount: span.tokenCount,
                costUsd: span.costUsd,
                durationMs: span.durationMs,
                status: span.status,
                statusMessage: span.statusMessage,
                startedAt: new Date(span.startedAt),
                endedAt: span.endedAt ? new Date(span.endedAt) : null,
                metadata: span.metadata as Prisma.InputJsonValue | undefined,
              })),
            },
            events: {
              create: trace.events.map((event) => ({
                id: event.id,
                spanId: event.spanId,
                type: event.type,
                name: event.name,
                metadata: event.metadata as Prisma.InputJsonValue | undefined,
                timestamp: new Date(event.timestamp),
              })),
            },
          },
        })
      )
    );

    return NextResponse.json({ success: true, count: result.length }, { status: 200 });
  } catch (error) {
    if (error instanceof SyntaxError) {
      return NextResponse.json({ error: "Invalid JSON in request body" }, { status: 400 });
    }

    // Handle unique constraint violations
    if (error instanceof Error && error.message.includes("Unique constraint")) {
      return NextResponse.json({ error: "Duplicate trace ID detected" }, { status: 409 });
    }

    console.error("Error processing traces:", error);
    return NextResponse.json({ error: "Internal server error" }, { status: 500 });
  }
}

// GET /api/traces — List traces with pagination
export async function GET(request: NextRequest) {
  try {
    const { searchParams } = new URL(request.url);
    const page = parseInt(searchParams.get("page") ?? "1", 10);
    const limit = parseInt(searchParams.get("limit") ?? "20", 10);
    const status = searchParams.get("status");
    const search = searchParams.get("search");
    const sessionId = searchParams.get("sessionId");

    // Validate pagination parameters
    if (isNaN(page) || page < 1) {
      return NextResponse.json({ error: "Invalid page parameter. Must be a positive integer." }, { status: 400 });
    }
    if (isNaN(limit) || limit < 1 || limit > 100) {
      return NextResponse.json({ error: "Invalid limit parameter. Must be between 1 and 100." }, { status: 400 });
    }

    // Validate status parameter if provided
    const validStatuses = ["RUNNING", "COMPLETED", "ERROR"];
    if (status && !validStatuses.includes(status)) {
      return NextResponse.json({ error: `Invalid status. Must be one of: ${validStatuses.join(", ")}` }, { status: 400 });
    }

    // Build where clause
    const where: Record<string, unknown> = {};
    if (status) {
      where.status = status;
    }
    if (search) {
      where.name = {
        contains: search,
        mode: "insensitive",
      };
    }
    if (sessionId) {
      where.sessionId = sessionId;
    }

    // Count total traces
    const total = await prisma.trace.count({ where });

    // Calculate pagination
    const skip = (page - 1) * limit;
    const totalPages = Math.ceil(total / limit);

    // Fetch traces with pagination
    const traces = await prisma.trace.findMany({
      where,
      include: {
        _count: {
          select: {
            decisionPoints: true,
            spans: true,
            events: true,
          },
        },
      },
      orderBy: {
        startedAt: "desc",
      },
      skip,
      take: limit,
    });

    return NextResponse.json({
      traces,
      total,
      page,
      limit,
      totalPages,
    }, { status: 200 });
  } catch (error) {
    console.error("Error listing traces:", error);
    return NextResponse.json({ error: "Internal server error" }, { status: 500 });
  }
}
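A hedged example of the batch payload this route accepts, posted with httpx (placeholder host and key; the handler only checks that a Bearer token is present, it does not yet verify the key):

import uuid
import httpx
from datetime import datetime, timezone

now = datetime.now(timezone.utc).isoformat()
payload = {
    "traces": [
        {
            "id": str(uuid.uuid4()),
            "name": "manual-ingest-demo",
            "status": "COMPLETED",
            "tags": ["demo"],
            "startedAt": now,
            "endedAt": now,
            "decisionPoints": [],
            "spans": [],
            "events": [],
        }
    ]
}
resp = httpx.post(
    "http://localhost:3000/api/traces",  # placeholder host
    json=payload,
    headers={"Authorization": "Bearer al_demo_key"},
)
print(resp.status_code, resp.json())  # expect 200 and {"success": True, "count": 1}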
apps/web/src/lib/prisma.ts (new file, 9 lines)
@@ -0,0 +1,9 @@
import { PrismaClient } from "@agentlens/database";

const globalForPrisma = globalThis as unknown as {
  prisma: PrismaClient | undefined;
};

export const prisma = globalForPrisma.prisma ?? new PrismaClient();

if (process.env.NODE_ENV !== "production") globalForPrisma.prisma = prisma;
packages/sdk-python/agentlens/__init__.py
@@ -1,8 +1,47 @@
"""AgentLens - Agent observability that traces decisions, not just API calls."""

from agentlens.client import init, shutdown
from agentlens.trace import trace
from agentlens.client import init, shutdown, get_client
from agentlens.decision import log_decision
from agentlens.models import (
    TraceData,
    DecisionPoint,
    Span,
    Event,
    TraceStatus,
    DecisionType,
    SpanType,
    SpanStatus,
    EventType,
)

import sys

if sys.version_info >= (3, 7):
    import importlib

    __trace_module = importlib.import_module("agentlens.trace")
    trace = getattr(__trace_module, "trace")
    TraceContext = getattr(__trace_module, "TraceContext")
    get_current_trace = getattr(__trace_module, "get_current_trace")
else:
    from agentlens.trace import trace, TraceContext, get_current_trace

__version__ = "0.1.0"
__all__ = ["init", "shutdown", "trace", "log_decision"]
__all__ = [
    "init",
    "shutdown",
    "get_client",
    "trace",
    "TraceContext",
    "get_current_trace",
    "log_decision",
    "TraceData",
    "DecisionPoint",
    "Span",
    "Event",
    "TraceStatus",
    "DecisionType",
    "SpanType",
    "SpanStatus",
    "EventType",
]
packages/sdk-python/agentlens/client.py
@@ -1,43 +1,107 @@
"""Client initialization and management for AgentLens."""

from typing import Optional
import atexit
import logging
from typing import Optional, List

from agentlens.models import TraceData
from agentlens.transport import BatchTransport

logger = logging.getLogger("agentlens")

_client: Optional["AgentLensClient"] = None

_client: Optional["_Client"] = None


class _Client:
    """Internal client class for managing AgentLens connection."""

    def __init__(self, api_key: str, endpoint: str) -> None:
class AgentLensClient:
    def __init__(
        self,
        api_key: str,
        endpoint: str = "https://agentlens.vectry.tech",
        flush_interval: float = 5.0,
        max_batch_size: int = 10,
        enabled: bool = True,
    ) -> None:
        self.api_key = api_key
        self.endpoint = endpoint
        self.is_shutdown = False
        self.endpoint = endpoint.rstrip("/")
        self.flush_interval = flush_interval
        self.max_batch_size = max_batch_size
        self.enabled = enabled
        self._transport: Optional[BatchTransport] = None
        self._sent_traces: List[TraceData] = []

        if enabled:
            try:
                self._transport = BatchTransport(
                    api_key=api_key,
                    endpoint=endpoint,
                    flush_interval=flush_interval,
                    max_batch_size=max_batch_size,
                )
            except ImportError:
                logger.warning("httpx not available, transport disabled")
                self.enabled = False

    def send_trace(self, trace: TraceData) -> None:
        if not self.enabled:
            self._sent_traces.append(trace)
            return

        if self._transport:
            self._transport.add(trace)

    def shutdown(self) -> None:
        """Shutdown the client."""
        self.is_shutdown = True
        if self._transport:
            self._transport.flush()
            self._transport.shutdown()
            self._transport = None

    @property
    def sent_traces(self) -> List[TraceData]:
        return self._sent_traces

    @property
    def transport(self) -> Optional["BatchTransport"]:
        return self._transport


def init(api_key: str, endpoint: str = "https://agentlens.vectry.tech") -> None:
    """Initialize the AgentLens client.
def init(
    api_key: str,
    endpoint: str = "https://agentlens.vectry.tech",
    flush_interval: float = 5.0,
    max_batch_size: int = 10,
    enabled: bool = True,
) -> None:
    """Initialize the AgentLens SDK.

    Args:
        api_key: Your AgentLens API key.
        endpoint: The AgentLens API endpoint (default: https://agentlens.vectry.tech).
        endpoint: The AgentLens API endpoint.
        flush_interval: Seconds between automatic flushes (default: 5.0).
        max_batch_size: Number of traces before auto-flush (default: 10).
        enabled: Set to False to disable sending (useful for testing).
    """
    global _client
    _client = _Client(api_key=api_key, endpoint=endpoint)
    if _client is not None:
        _client.shutdown()
    _client = AgentLensClient(
        api_key=api_key,
        endpoint=endpoint,
        flush_interval=flush_interval,
        max_batch_size=max_batch_size,
        enabled=enabled,
    )
    atexit.register(shutdown)
    logger.debug("AgentLens initialized: endpoint=%s", endpoint)


def shutdown() -> None:
    """Shutdown the AgentLens client."""
    global _client
    if _client:
    if _client is not None:
        _client.shutdown()
    _client = None
    logger.debug("AgentLens shutdown complete")


def get_client() -> Optional[_Client]:
    """Get the current client instance."""
def get_client() -> Optional[AgentLensClient]:
    """Get the current client instance. Returns None if not initialized."""
    return _client
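A short sketch of the enabled=False path exercised by the tests below: traces are collected in memory on sent_traces instead of going through BatchTransport.

from agentlens import init, shutdown, get_client, trace

init(api_key="test_key", enabled=False)  # no HTTP transport is created

with trace("offline-demo"):
    pass

client = get_client()
print(len(client.sent_traces), client.sent_traces[0].status)  # 1 COMPLETED
shutdown()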
packages/sdk-python/agentlens/decision.py
@@ -1,32 +1,47 @@
"""Decision logging for tracking agent decision points."""

from typing import Any, Dict, List
import logging
from typing import Any, Dict, List, Optional

from agentlens.models import DecisionPoint
from agentlens.trace import get_current_trace, get_current_span_id

logger = logging.getLogger("agentlens")


def log_decision(
    type: str,
    chosen: Any,
    alternatives: List[Any],
    chosen: Dict[str, Any],
    alternatives: Optional[List[Dict[str, Any]]] = None,
    reasoning: Optional[str] = None,
) -> None:
    """Log a decision point in the agent's reasoning.

    Args:
        type: Type of decision (e.g., "tool_selection", "routing", "retry").
        chosen: The option that was selected.
        alternatives: List of alternatives that were considered.
        reasoning: Optional explanation for the decision.

    Example:
        log_decision(
            type="tool_selection",
            chosen="search",
            alternatives=["search", "calculate", "browse"],
            reasoning="Search is most appropriate for finding information"
    context_snapshot: Optional[Dict[str, Any]] = None,
    cost_usd: Optional[float] = None,
    duration_ms: Optional[int] = None,
) -> Optional[DecisionPoint]:
    current_trace = get_current_trace()
    if current_trace is None:
        logger.warning(
            "AgentLens: log_decision called outside of a trace context. Decision not recorded."
        )
    """
    print(f"[AgentLens] Decision logged: {type}")
    print(f"[AgentLens] Chosen: {chosen}")
    print(f"[AgentLens] Alternatives: {alternatives}")
    if reasoning:
        print(f"[AgentLens] Reasoning: {reasoning}")
        return None

    parent_span_id = get_current_span_id()

    decision = DecisionPoint(
        type=type,
        chosen=chosen,
        alternatives=alternatives or [],
        reasoning=reasoning,
        context_snapshot=context_snapshot,
        cost_usd=cost_usd,
        duration_ms=duration_ms,
        parent_span_id=parent_span_id,
    )

    current_trace.decision_points.append(decision)
    logger.debug(
        "AgentLens: Decision logged: type=%s, chosen=%s",
        type,
        chosen.get("name", str(chosen)),
    )
    return decision
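The optional fields map onto DecisionPoint one-for-one; a hedged sketch with illustrative values:

from agentlens import trace, log_decision
from agentlens.models import DecisionType

with trace("support-router"):
    decision = log_decision(
        type=DecisionType.ROUTING.value,
        chosen={"name": "billing_queue", "confidence": 0.82},
        alternatives=[{"name": "tech_queue", "confidence": 0.11}],
        reasoning="Keywords matched billing intent",
        context_snapshot={"window_usage_pct": 37},
        cost_usd=0.0004,
        duration_ms=12,
    )
    print(decision.parent_span_id)  # None here: no enclosing span, only the trace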
packages/sdk-python/agentlens/models.py (new file, 211 lines)
@@ -0,0 +1,211 @@
"""Data models for AgentLens SDK, matching the server-side Prisma schema."""

import time
import uuid
from dataclasses import dataclass, field, asdict
from typing import Any, Dict, List, Optional
from enum import Enum


class TraceStatus(str, Enum):
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"
    ERROR = "ERROR"


class DecisionType(str, Enum):
    TOOL_SELECTION = "TOOL_SELECTION"
    ROUTING = "ROUTING"
    RETRY = "RETRY"
    ESCALATION = "ESCALATION"
    MEMORY_RETRIEVAL = "MEMORY_RETRIEVAL"
    PLANNING = "PLANNING"
    CUSTOM = "CUSTOM"


class SpanType(str, Enum):
    LLM_CALL = "LLM_CALL"
    TOOL_CALL = "TOOL_CALL"
    MEMORY_OP = "MEMORY_OP"
    CHAIN = "CHAIN"
    AGENT = "AGENT"
    CUSTOM = "CUSTOM"


class SpanStatus(str, Enum):
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"
    ERROR = "ERROR"


class EventType(str, Enum):
    ERROR = "ERROR"
    RETRY = "RETRY"
    FALLBACK = "FALLBACK"
    CONTEXT_OVERFLOW = "CONTEXT_OVERFLOW"
    USER_FEEDBACK = "USER_FEEDBACK"
    CUSTOM = "CUSTOM"


def _generate_id() -> str:
    return str(uuid.uuid4())


def _now_ms() -> int:
    return int(time.time() * 1000)


def _now_iso() -> str:
    from datetime import datetime, timezone

    return datetime.now(timezone.utc).isoformat()


@dataclass
class DecisionPoint:
    type: str  # DecisionType value
    chosen: Dict[str, Any]  # {name, confidence, params}
    alternatives: List[Dict[str, Any]]  # [{name, confidence, reason_rejected}]
    reasoning: Optional[str] = None
    context_snapshot: Optional[Dict[str, Any]] = (
        None  # {window_usage_pct, tokens_used, tokens_available}
    )
    duration_ms: Optional[int] = None
    cost_usd: Optional[float] = None
    parent_span_id: Optional[str] = None
    id: str = field(default_factory=_generate_id)
    timestamp: str = field(default_factory=_now_iso)

    def to_dict(self) -> Dict[str, Any]:
        d = {
            "id": self.id,
            "type": self.type,
            "chosen": self.chosen,
            "alternatives": self.alternatives,
            "timestamp": self.timestamp,
        }
        if self.reasoning is not None:
            d["reasoning"] = self.reasoning
        if self.context_snapshot is not None:
            d["contextSnapshot"] = self.context_snapshot
        if self.duration_ms is not None:
            d["durationMs"] = self.duration_ms
        if self.cost_usd is not None:
            d["costUsd"] = self.cost_usd
        if self.parent_span_id is not None:
            d["parentSpanId"] = self.parent_span_id
        return d


@dataclass
class Span:
    name: str
    type: str  # SpanType value
    id: str = field(default_factory=_generate_id)
    parent_span_id: Optional[str] = None
    input_data: Optional[Any] = None
    output_data: Optional[Any] = None
    token_count: Optional[int] = None
    cost_usd: Optional[float] = None
    duration_ms: Optional[int] = None
    status: str = field(default_factory=lambda: SpanStatus.RUNNING.value)
    status_message: Optional[str] = None
    started_at: str = field(default_factory=_now_iso)
    ended_at: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        d = {
            "id": self.id,
            "name": self.name,
            "type": self.type,
            "status": self.status,
            "startedAt": self.started_at,
        }
        if self.parent_span_id is not None:
            d["parentSpanId"] = self.parent_span_id
        if self.input_data is not None:
            d["input"] = self.input_data
        if self.output_data is not None:
            d["output"] = self.output_data
        if self.token_count is not None:
            d["tokenCount"] = self.token_count
        if self.cost_usd is not None:
            d["costUsd"] = self.cost_usd
        if self.duration_ms is not None:
            d["durationMs"] = self.duration_ms
        if self.status_message is not None:
            d["statusMessage"] = self.status_message
        if self.ended_at is not None:
            d["endedAt"] = self.ended_at
        if self.metadata is not None:
            d["metadata"] = self.metadata
        return d


@dataclass
class Event:
    type: str  # EventType value
    name: str
    span_id: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None
    id: str = field(default_factory=_generate_id)
    timestamp: str = field(default_factory=_now_iso)

    def to_dict(self) -> Dict[str, Any]:
        d = {
            "id": self.id,
            "type": self.type,
            "name": self.name,
            "timestamp": self.timestamp,
        }
        if self.span_id is not None:
            d["spanId"] = self.span_id
        if self.metadata is not None:
            d["metadata"] = self.metadata
        return d


@dataclass
class TraceData:
    """Represents a complete trace to be sent to the server."""

    name: str
    id: str = field(default_factory=_generate_id)
    session_id: Optional[str] = None
    status: str = field(default_factory=lambda: TraceStatus.RUNNING.value)
    tags: List[str] = field(default_factory=list)
    metadata: Optional[Dict[str, Any]] = None
    total_cost: Optional[float] = None
    total_tokens: Optional[int] = None
    total_duration: Optional[int] = None  # ms
    started_at: str = field(default_factory=_now_iso)
    ended_at: Optional[str] = None
    decision_points: List[DecisionPoint] = field(default_factory=list)
    spans: List[Span] = field(default_factory=list)
    events: List[Event] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        d = {
            "id": self.id,
            "name": self.name,
            "status": self.status,
            "tags": self.tags,
            "startedAt": self.started_at,
            "decisionPoints": [dp.to_dict() for dp in self.decision_points],
            "spans": [s.to_dict() for s in self.spans],
            "events": [e.to_dict() for e in self.events],
        }
        if self.session_id is not None:
            d["sessionId"] = self.session_id
        if self.metadata is not None:
            d["metadata"] = self.metadata
        if self.total_cost is not None:
            d["totalCost"] = self.total_cost
        if self.total_tokens is not None:
            d["totalTokens"] = self.total_tokens
        if self.total_duration is not None:
            d["totalDuration"] = self.total_duration
        if self.ended_at is not None:
            d["endedAt"] = self.ended_at
        return d
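A small sketch of how these dataclasses serialize to the camelCase payload the API expects (only fields that are set end up in the dict):

from agentlens.models import TraceData, Span, SpanType, SpanStatus

t = TraceData(name="serialization-demo", tags=["demo"], session_id="sess-1")
t.spans.append(
    Span(name="llm-step", type=SpanType.LLM_CALL.value, status=SpanStatus.COMPLETED.value)
)
d = t.to_dict()
print(d["sessionId"], d["status"], d["spans"][0]["type"])  # sess-1 RUNNING LLM_CALL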
packages/sdk-python/agentlens/trace.py
@@ -1,76 +1,191 @@
"""Trace decorator and context manager for instrumenting agent functions."""

from typing import Callable, Optional, Any
import asyncio
import logging
import threading
import time
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Union

from agentlens.models import (
    TraceData,
    Span,
    SpanType,
    TraceStatus,
    SpanStatus,
    Event,
    EventType,
    _now_iso,
)

logger = logging.getLogger("agentlens")

_context = threading.local()


def trace(name: Optional[str] = None) -> Callable[..., Any]:
    """Decorator to trace a function or method.
def _get_context_stack() -> List[Union[TraceData, Span]]:
    if not hasattr(_context, "stack"):
        _context.stack = []
    return _context.stack

    Args:
        name: Name for the trace. If not provided, uses the function name.

    Returns:
        Decorated function with tracing enabled.
def get_current_trace() -> Optional[TraceData]:
    stack = _get_context_stack()
    if not stack:
        return None
    for item in stack:
        if isinstance(item, TraceData):
            return item
    return None

    Example:
        @trace(name="research-agent")
        async def research(topic: str) -> str:
            return f"Researching: {topic}"
    """

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            trace_name = name or func.__name__
            print(f"[AgentLens] Starting trace: {trace_name}")
            try:
                result = await func(*args, **kwargs)
                print(f"[AgentLens] Completed trace: {trace_name}")
                return result
            except Exception as e:
                print(f"[AgentLens] Error in trace {trace_name}: {e}")
                raise
def get_current_span_id() -> Optional[str]:
    stack = _get_context_stack()
    if not stack:
        return None
    for item in reversed(stack):
        if isinstance(item, Span):
            return item.id
    return None

        @wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            trace_name = name or func.__name__
            print(f"[AgentLens] Starting trace: {trace_name}")
            try:
                result = func(*args, **kwargs)
                print(f"[AgentLens] Completed trace: {trace_name}")
                return result
            except Exception as e:
                print(f"[AgentLens] Error in trace {trace_name}: {e}")
                raise

        if hasattr(func, "__await__"):
            return async_wrapper
class TraceContext:
    def __init__(
        self,
        name: str,
        tags: Optional[List[str]] = None,
        session_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        self.name = name or "trace"
        self.tags = tags or []
        self.session_id = session_id
        self.metadata = metadata
        self._trace_data: Optional[TraceData] = None
        self._span: Optional[Span] = None
        self._start_time: float = 0
        self._is_nested: bool = False

    def __enter__(self) -> "TraceContext":
        self._start_time = time.time()
        stack = _get_context_stack()

        if stack:
            self._is_nested = True
            parent_trace = get_current_trace()
            parent_span_id = get_current_span_id()
            self._span = Span(
                name=self.name,
                type=SpanType.AGENT.value,
                parent_span_id=parent_span_id,
                started_at=_now_iso(),
            )
            if parent_trace:
                parent_trace.spans.append(self._span)
            stack.append(self._span)
        else:
            return sync_wrapper
            self._trace_data = TraceData(
                name=self.name,
                tags=self.tags,
                session_id=self.session_id,
                metadata=self.metadata,
                status=TraceStatus.RUNNING.value,
                started_at=_now_iso(),
            )
            stack.append(self._trace_data)

    return decorator


class Tracer:
    """Context manager for creating traces.

    Example:
        with Tracer(name="custom-operation"):
            # Your code here
            pass
    """

    def __init__(self, name: str) -> None:
        self.name = name

    def __enter__(self) -> "Tracer":
        print(f"[AgentLens] Starting trace: {self.name}")
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
        if exc_type is None:
            print(f"[AgentLens] Completed trace: {self.name}")
    def __exit__(
        self,
        exc_type: Optional[type],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        from agentlens.client import get_client

        client = get_client()
        end_time = time.time()
        total_duration = int((end_time - self._start_time) * 1000)

        stack = _get_context_stack()

        if self._is_nested and self._span:
            self._span.status = (
                SpanStatus.COMPLETED.value
                if exc_type is None
                else SpanStatus.ERROR.value
            )
            self._span.duration_ms = total_duration
            self._span.ended_at = _now_iso()
            stack.pop()
        elif self._trace_data:
            if exc_type is not None:
                self._trace_data.status = TraceStatus.ERROR.value
                error_event = Event(
                    type=EventType.ERROR.value,
                    name=str(exc_val) if exc_val else "Unknown error",
                )
                self._trace_data.events.append(error_event)
            else:
                self._trace_data.status = TraceStatus.COMPLETED.value

            self._trace_data.total_duration = total_duration
            self._trace_data.ended_at = _now_iso()
            stack.pop()

            client = get_client()
            if client:
                client.send_trace(self._trace_data)

    def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]:
        if asyncio.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
                with TraceContext(
                    name=self.name,
                    tags=self.tags,
                    session_id=self.session_id,
                    metadata=self.metadata,
                ):
                    return await func(*args, **kwargs)

            return async_wrapper
        else:
            print(f"[AgentLens] Error in trace {self.name}: {exc_val}")
            return False

            @wraps(func)
            def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
                with TraceContext(
                    name=self.name,
                    tags=self.tags,
                    session_id=self.session_id,
                    metadata=self.metadata,
                ):
                    return func(*args, **kwargs)

            return sync_wrapper

    @property
    def trace_id(self) -> Optional[str]:
        if self._trace_data:
            return self._trace_data.id
        return None


def trace(
    name: Union[Callable[..., Any], str, None] = None,
    tags: Optional[List[str]] = None,
    session_id: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> Union[TraceContext, Callable[..., Any]]:
    if callable(name):
        func = name
        ctx = TraceContext(
            name=func.__name__, tags=tags, session_id=session_id, metadata=metadata
        )
        return ctx(func)

    return TraceContext(
        name=name or "trace", tags=tags, session_id=session_id, metadata=metadata
    )
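A hedged sketch of the nested behavior: the outer context creates a TraceData, the inner one becomes an AGENT span attached to it via the thread-local stack.

from agentlens import init, shutdown, get_client, trace

init(api_key="test_key", enabled=False)

with trace("outer_trace"):
    with trace("inner_span"):
        pass

sent = get_client().sent_traces[0]
print(sent.name, [s.name for s in sent.spans])  # outer_trace ['inner_span']
shutdown()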
packages/sdk-python/agentlens/transport.py
@@ -1,38 +1,92 @@
"""Batch transport for sending data to AgentLens API."""
"""Batch transport for sending trace data to AgentLens API."""

from typing import List, Dict, Any
import json
import logging
import threading
import time
from typing import Any, Dict, List, Optional

import httpx

from agentlens.models import TraceData

logger = logging.getLogger("agentlens")


class BatchTransport:
    """Transport layer that batches events for efficient API calls.
    """Thread-safe batch transport that buffers and sends traces to API."""

    This class handles batching and sending of traces, decisions, and other
    events to the AgentLens backend.
    """
    def __init__(
        self,
        api_key: str,
        endpoint: str,
        max_batch_size: int = 10,
        flush_interval: float = 5.0,
    ) -> None:
        self._api_key = api_key
        self._endpoint = endpoint.rstrip("/")
        self._max_batch_size = max_batch_size
        self._flush_interval = flush_interval
        self._buffer: List[Dict[str, Any]] = []
        self._lock = threading.Lock()
        self._shutdown_event = threading.Event()
        self._flush_thread = threading.Thread(
            target=self._flush_loop, daemon=True, name="agentlens-flush"
        )
        self._flush_thread.start()
        self._client = httpx.Client(timeout=30.0)

    def __init__(self, max_batch_size: int = 100, flush_interval: float = 1.0) -> None:
        self.max_batch_size = max_batch_size
        self.flush_interval = flush_interval
        self._batch: List[Dict[str, Any]] = []
    def add(self, trace: TraceData) -> None:
        """Add a completed trace to send buffer."""
        trace_dict = trace.to_dict()
        with self._lock:
            self._buffer.append(trace_dict)
            should_flush = len(self._buffer) >= self._max_batch_size
        if should_flush:
            self._do_flush()

    def add(self, event: Dict[str, Any]) -> None:
        """Add an event to the batch.
    def _flush_loop(self) -> None:
        """Background loop that periodically flushes buffer."""
        while not self._shutdown_event.is_set():
            self._shutdown_event.wait(timeout=self._flush_interval)
            if not self._shutdown_event.is_set():
                self._do_flush()

        Args:
            event: Event data to be sent.
        """
        self._batch.append(event)
        if len(self._batch) >= self.max_batch_size:
            self.flush()
    def _do_flush(self) -> None:
        """Flush all buffered traces to the API."""
        with self._lock:
            if not self._buffer:
                return
            batch = self._buffer.copy()
            self._buffer.clear()

        try:
            response = self._client.post(
                f"{self._endpoint}/api/traces",
                json={"traces": batch},
                headers={
                    "Authorization": f"Bearer {self._api_key}",
                    "Content-Type": "application/json",
                },
            )
            if response.status_code >= 400:
                logger.warning(
                    "AgentLens: Failed to send traces (HTTP %d): %s",
                    response.status_code,
                    response.text[:200],
                )
        except Exception as e:
            logger.warning("AgentLens: Failed to send traces: %s", e)
            # Put traces back in buffer for retry (optional, up to you)
            # For now, drop on failure to avoid unbounded growth

    def flush(self) -> None:
        """Flush the batch by sending all pending events."""
        if not self._batch:
            return

        print(f"[AgentLens] Flushing batch of {len(self._batch)} events")
        self._batch.clear()
        """Manually flush all buffered traces."""
        self._do_flush()

    def shutdown(self) -> None:
        """Shutdown the transport, flushing any remaining events."""
        self.flush()
        """Shutdown transport: flush remaining traces and stop background thread."""
        self._shutdown_event.set()
        self._flush_thread.join(timeout=5.0)
        self._do_flush()  # Final flush
        self._client.close()
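Normally AgentLensClient owns the transport, but a hedged standalone sketch shows the lifecycle (placeholder endpoint and key; add() serializes and buffers, shutdown() does a final flush):

from agentlens.models import TraceData
from agentlens.transport import BatchTransport

transport = BatchTransport(
    api_key="al_demo_key",             # placeholder
    endpoint="http://localhost:3000",  # placeholder
    max_batch_size=10,
    flush_interval=5.0,
)
transport.add(TraceData(name="transport-demo", status="COMPLETED"))
transport.flush()     # send immediately instead of waiting for the background thread
transport.shutdown()  # stops the flush thread and closes the httpx client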
packages/sdk-python/tests/__init__.py (new file, empty)

packages/sdk-python/tests/test_sdk.py (new file, 129 lines)
@@ -0,0 +1,129 @@
import unittest

from agentlens import init, shutdown, get_client, trace, log_decision
from agentlens.models import TraceData, TraceStatus


class TestSDK(unittest.TestCase):
    def setUp(self):
        shutdown()
        init(api_key="test_key", enabled=False)

    def tearDown(self):
        shutdown()

    def test_init_and_shutdown(self):
        client = get_client()
        self.assertIsNotNone(client)
        self.assertEqual(client.api_key, "test_key")
        self.assertFalse(client.enabled)
        shutdown()
        self.assertIsNone(get_client())

    def test_decorator_sync_function(self):
        @trace("test_decorator")
        def test_func():
            return 42

        result = test_func()
        self.assertEqual(result, 42)

        client = get_client()
        traces = client.sent_traces
        self.assertEqual(len(traces), 1)
        self.assertEqual(traces[0].name, "test_decorator")
        self.assertEqual(traces[0].status, TraceStatus.COMPLETED.value)
        self.assertIsNotNone(traces[0].total_duration)

    def test_context_manager(self):
        with trace("test_context") as t:
            pass

        client = get_client()
        traces = client.sent_traces
        self.assertEqual(len(traces), 1)
        self.assertEqual(traces[0].name, "test_context")
        self.assertEqual(traces[0].status, TraceStatus.COMPLETED.value)

    def test_nested_traces(self):
        with trace("outer_trace"):
            with trace("inner_span"):
                pass

        client = get_client()
        traces = client.sent_traces
        self.assertEqual(len(traces), 1)
        self.assertEqual(traces[0].name, "outer_trace")
        self.assertEqual(len(traces[0].spans), 1)
        self.assertEqual(traces[0].spans[0].name, "inner_span")

    def test_log_decision(self):
        with trace("test_trace"):
            log_decision(
                type="tool_selection",
                chosen={"name": "search"},
                alternatives=[{"name": "calculate"}, {"name": "browse"}],
                reasoning="Search is best for finding information",
            )

        client = get_client()
        traces = client.sent_traces
        self.assertEqual(len(traces), 1)
        self.assertEqual(len(traces[0].decision_points), 1)
        decision = traces[0].decision_points[0]
        self.assertEqual(decision.type, "tool_selection")
        self.assertEqual(decision.chosen, {"name": "search"})
        self.assertEqual(len(decision.alternatives), 2)
        self.assertEqual(decision.reasoning, "Search is best for finding information")

    def test_error_handling(self):
        with self.assertRaises(ValueError):
            with trace("test_error"):
                raise ValueError("Test error")

        client = get_client()
        traces = client.sent_traces
        self.assertEqual(len(traces), 1)
        self.assertEqual(traces[0].name, "test_error")
        self.assertEqual(traces[0].status, TraceStatus.ERROR.value)
        self.assertEqual(len(traces[0].events), 1)
        self.assertEqual(traces[0].events[0].type, "ERROR")
        self.assertIn("Test error", traces[0].events[0].name)

    def test_log_decision_outside_trace(self):
        shutdown()

        log_decision(
            type="tool_selection",
            chosen={"name": "search"},
            alternatives=[],
        )

        init(api_key="test_key", enabled=False)

        client = get_client()
        self.assertEqual(len(client.sent_traces), 0)

    def test_model_serialization(self):
        trace_data = TraceData(
            name="test",
            tags=["tag1", "tag2"],
            session_id="session123",
            metadata={"key": "value"},
        )

        trace_dict = trace_data.to_dict()

        self.assertEqual(trace_dict["name"], "test")
        self.assertEqual(trace_dict["tags"], ["tag1", "tag2"])
        self.assertEqual(trace_dict["sessionId"], "session123")
        self.assertEqual(trace_dict["metadata"], {"key": "value"})
        self.assertEqual(trace_dict["status"], "RUNNING")
        self.assertIn("startedAt", trace_dict)
        self.assertIn("decisionPoints", trace_dict)
        self.assertIn("spans", trace_dict)
        self.assertIn("events", trace_dict)


if __name__ == "__main__":
    unittest.main()