feat: LangChain auto-instrumentation + dashboard UI
- LangChain: AgentLensCallbackHandler with auto-span creation for LLM calls, tool calls, chains, and agent decision logging
- Dashboard: trace list with search, status filters, pagination
- Dashboard: trace detail with Decision/Span/Event tabs
- Dashboard: sidebar layout, responsive design, dark theme
This commit is contained in:
@@ -1,55 +1,493 @@
|
||||
"""LangChain integration for AgentLens."""
|
||||
"""LangChain integration for AgentLens.
|
||||
|
||||
from typing import Any, Dict, Optional, Sequence
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.messages import BaseMessage
|
||||
This module provides a callback handler that auto-instruments LangChain chains,
|
||||
agents, LLM calls, and tool calls, creating Spans and DecisionPoints in AgentLens traces.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
try:
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.agents import AgentAction
|
||||
from langchain_core.messages import BaseMessage
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"langchain-core is required. Install with: pip install agentlens[langchain]"
|
||||
)
|
||||
|
||||
from agentlens.models import (
|
||||
Span,
|
||||
SpanType,
|
||||
SpanStatus,
|
||||
Event,
|
||||
EventType,
|
||||
_now_iso,
|
||||
)
|
||||
from agentlens.trace import (
|
||||
get_current_trace,
|
||||
get_current_span_id,
|
||||
_get_context_stack,
|
||||
TraceContext,
|
||||
)
|
||||
from agentlens.decision import log_decision
|
||||
|
||||
logger = logging.getLogger("agentlens")
|
||||
|
||||
|
||||
class AgentLensCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback handler for LangChain integration with AgentLens.
|
||||
|
||||
This handler captures LLM calls, tool calls, and agent actions
|
||||
to provide observability for LangChain-based agents.
|
||||
This handler captures LLM calls, tool calls, agent actions, and chain execution
|
||||
to provide observability for LangChain-based agents. It works both inside and
|
||||
outside of an existing AgentLens trace context.
|
||||
|
||||
Example usage:
|
||||
# Standalone (creates its own trace)
|
||||
handler = AgentLensCallbackHandler(trace_name="my-langchain-trace")
|
||||
chain.invoke(input, config={"callbacks": [handler]})
|
||||
|
||||
# Inside existing trace
|
||||
with trace(name="my-operation"):
|
||||
handler = AgentLensCallbackHandler()
|
||||
chain.invoke(input, config={"callbacks": [handler]})
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.trace_id: Optional[str] = None
|
||||
def __init__(
    self,
    trace_name: str = "langchain-trace",
    tags: Optional[List[str]] = None,
    session_id: Optional[str] = None,
) -> None:
    """Store trace settings and start with empty bookkeeping state.

    Nothing is opened here; the settings are only read later, when
    ``_get_or_create_trace`` finds no active trace and has to start one.

    Args:
        trace_name: Name for a trace this handler creates itself.
        tags: Optional tags passed to that trace.
        session_id: Optional session ID passed to that trace.
    """
    # Bookkeeping, all empty until callbacks start firing.
    self._top_level_run_id: Optional[UUID] = None  # run_id of the root chain
    self._trace_ctx: Optional[TraceContext] = None  # set only for a self-created trace
    self._run_map: Dict[UUID, Span] = {}  # LangChain run_id -> open Span

    # Settings for a self-created trace.
    self.session_id = session_id
    self.tags = tags
    self.trace_name = trace_name

    logger.debug(
        "AgentLensCallbackHandler initialized: trace_name=%s, tags=%s",
        self.trace_name,
        self.tags,
    )
|
||||
|
||||
def _get_or_create_trace(self) -> Optional[Any]:
    """Return the active trace, opening a handler-owned one if none exists.

    When ``get_current_trace()`` reports nothing and this handler has not
    opened a trace yet, a ``TraceContext`` is built from the constructor
    settings and entered. The result is whatever ``get_current_trace()``
    reports afterwards.

    Returns:
        The current trace, or None if there is still no current trace.
    """
    active = get_current_trace()
    if active is None:
        if self._trace_ctx is None:
            own_ctx = TraceContext(
                name=self.trace_name,
                tags=self.tags,
                session_id=self.session_id,
            )
            # Remember the context before entering it so the chain-end
            # callbacks can find and close it.
            self._trace_ctx = own_ctx
            own_ctx.__enter__()
            logger.debug(
                "AgentLensCallbackHandler created new trace: %s", self.trace_name
            )
        active = get_current_trace()
    return active
|
||||
|
||||
def _create_span(
    self,
    name: str,
    span_type: SpanType,
    run_id: UUID,
    parent_run_id: Optional[UUID],
    input_data: Any = None,
) -> Optional[Span]:
    """Open a RUNNING span for a LangChain run and register it.

    The span is appended to the trace's ``spans``, stored in
    ``self._run_map`` under ``run_id``, and pushed onto the context stack.

    Args:
        name: Span name.
        span_type: Span type; its ``.value`` is stored on the span.
        run_id: LangChain run ID of this operation.
        parent_run_id: LangChain run ID of the parent operation, if any.
        input_data: Input payload recorded on the span.

    Returns:
        The new Span, or None when no trace could be obtained.
    """
    owning_trace = self._get_or_create_trace()
    if owning_trace is None:
        logger.warning("No active trace, skipping span creation")
        return None

    # The current context span wins; the run map is consulted only when
    # no span is current.
    parent_id = get_current_span_id()
    if parent_id is None and parent_run_id is not None:
        known_parent = self._run_map.get(parent_run_id)
        parent_id = known_parent.id if known_parent else None

    type_label = span_type.value
    new_span = Span(
        name=name,
        type=type_label,
        parent_span_id=parent_id,
        input_data=input_data,
        status=SpanStatus.RUNNING.value,
        started_at=_now_iso(),
    )

    owning_trace.spans.append(new_span)
    self._run_map[run_id] = new_span
    # Make the new span current for operations started inside this run.
    _get_context_stack().append(new_span)

    logger.debug(
        "AgentLensCallbackHandler created span: type=%s, name=%s, run_id=%s",
        type_label,
        name,
        run_id,
    )
    return new_span
|
||||
|
||||
def _complete_span(
    self,
    run_id: UUID,
    output_data: Any = None,
    status: SpanStatus = SpanStatus.COMPLETED,
) -> None:
    """Close the span registered for ``run_id``.

    Removes the span from ``self._run_map``, stamps its final status and
    end time, stores the output if one is given, and removes the span from
    the context stack. Does nothing when ``run_id`` is unknown.

    Args:
        run_id: LangChain run ID.
        output_data: Output payload; stored only when not None.
        status: Final status; its ``.value`` is written to the span.
    """
    span = self._run_map.pop(run_id, None)
    if span is None:
        return

    span.status = status.value
    span.ended_at = _now_iso()
    # NOTE: no duration is computed here; only start/end timestamps are set.

    if output_data is not None:
        span.output_data = output_data

    # Remove this span wherever it sits on the stack, not only on top.
    # A run that finishes out of order would otherwise stay on the stack
    # and later be picked up as the parent of unrelated spans.
    # Assumes the stack is a list-like sequence (it is used with append /
    # [-1] / pop elsewhere in this class).
    stack = _get_context_stack()
    for idx in range(len(stack) - 1, -1, -1):
        entry = stack[idx]
        if isinstance(entry, Span) and entry.id == span.id:
            del stack[idx]
            break

    logger.debug(
        "AgentLensCallbackHandler completed span: name=%s, status=%s, run_id=%s",
        span.name,
        status.value,
        run_id,
    )
|
||||
|
||||
def _error_span(self, run_id: UUID, error: Exception) -> None:
    """Close the span for ``run_id`` as failed and record an error event.

    Sets the span's status to ERROR with ``str(error)`` as the status
    message, stamps the end time, appends an ERROR event to the current
    trace (if there is one), and removes the span from the context stack.
    Does nothing when ``run_id`` is unknown.

    Args:
        run_id: LangChain run ID.
        error: The exception that occurred.
    """
    # pop, not get: the run is over, so its entry must leave the map just
    # as it does in _complete_span. With get, every failed run stayed in
    # _run_map for the lifetime of the handler.
    span = self._run_map.pop(run_id, None)
    if span is None:
        return

    span.status = SpanStatus.ERROR.value
    span.status_message = str(error)
    span.ended_at = _now_iso()

    trace = get_current_trace()
    if trace:
        trace.events.append(
            Event(
                type=EventType.ERROR.value,
                name=f"{span.name}: {str(error)}",
                span_id=span.id,
                metadata={"error_type": type(error).__name__},
            )
        )

    # Remove this span wherever it sits on the stack, not only on top, so a
    # failed run can never remain as a stale "current" span.
    # Assumes the stack is a list-like sequence (it is used with append /
    # [-1] / pop elsewhere in this class).
    stack = _get_context_stack()
    for idx in range(len(stack) - 1, -1, -1):
        entry = stack[idx]
        if isinstance(entry, Span) and entry.id == span.id:
            del stack[idx]
            break

    logger.debug(
        "AgentLensCallbackHandler errored span: name=%s, error=%s, run_id=%s",
        span.name,
        error,
        run_id,
    )
|
||||
|
||||
def on_llm_start(
    self,
    serialized: Dict[str, Any],
    prompts: List[str],
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> None:
    """Open an LLM_CALL span when an LLM starts processing.

    The span name is taken from ``serialized``: the last element of a
    non-empty ``id`` list, else ``name``, else ``model_name``, else
    "unknown". The prompts are recorded as the span input.

    Args:
        serialized: Serialized description of the LLM. A missing (None)
            value is tolerated and treated as empty.
        prompts: Prompt strings sent to the LLM.
        run_id: LangChain run ID of this call.
        parent_run_id: LangChain run ID of the parent run, if any.
        **kwargs: Extra callback arguments; unused.
    """
    # Guard against serialized=None so a missing description cannot crash
    # the instrumented call.
    info = serialized or {}

    model_name = "unknown"
    id_path = info.get("id")
    if isinstance(id_path, list) and id_path:  # an empty list has no [-1]
        model_name = id_path[-1]
    elif "name" in info:
        model_name = info["name"]
    elif "model_name" in info:
        model_name = info["model_name"]

    self._create_span(
        name=model_name,
        span_type=SpanType.LLM_CALL,
        run_id=run_id,
        parent_run_id=parent_run_id,
        input_data={"prompts": prompts},
    )
|
||||
|
||||
def on_llm_end(
    self,
    response: LLMResult,
    *,
    run_id: UUID,
    **kwargs: Any,
) -> None:
    """Record token usage and generated text, then close the LLM span.

    Does nothing when no span is registered for ``run_id``.

    Args:
        response: LLM result. ``llm_output["token_usage"]["total_tokens"]``
            is copied to ``span.token_count`` when present; the non-empty
            ``text`` of every generation is stored as the span output.
        run_id: LangChain run ID of the call.
        **kwargs: Extra callback arguments; unused.
    """
    span = self._run_map.get(run_id)
    if span is None:
        return

    # Token usage is optional at every level, so default each lookup.
    llm_output = getattr(response, "llm_output", None) or {}
    token_usage = llm_output.get("token_usage") or {}
    total_tokens = token_usage.get("total_tokens")
    if total_tokens is not None:
        span.token_count = total_tokens

    # LLMResult.generations is a list of lists (one inner list per prompt).
    # The previous code read ``__dict__`` on each inner list, which raises
    # AttributeError. Flat inputs are still accepted.
    texts = []
    for batch in getattr(response, "generations", None) or []:
        candidates = batch if isinstance(batch, (list, tuple)) else [batch]
        for gen in candidates:
            if isinstance(gen, dict):
                text = gen.get("text", "")
            else:
                text = getattr(gen, "text", "")
            if text:
                texts.append(text)

    output_data = {"generations": texts} if texts else None
    self._complete_span(run_id, output_data=output_data)
|
||||
|
||||
def on_llm_error(
    self,
    error: Exception,
    *,
    run_id: UUID,
    **kwargs: Any,
) -> None:
    """Mark the LLM span for ``run_id`` as failed.

    Delegates to ``_error_span``.

    Args:
        error: The exception raised by the LLM call.
        run_id: LangChain run ID of the failed call.
        **kwargs: Extra callback arguments; unused.
    """
    self._error_span(run_id=run_id, error=error)
|
||||
|
||||
def on_chat_model_start(
    self,
    serialized: Dict[str, Any],
    messages: List[List[BaseMessage]],
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> None:
    """Open an LLM_CALL span when a chat model starts processing.

    The span name is taken from ``serialized``: the last element of a
    non-empty ``id`` list, else ``name``, else ``model_name`` (the same
    order ``on_llm_start`` uses), else "unknown". The stringified content
    of every message is recorded as the span input.

    Args:
        serialized: Serialized description of the chat model. A missing
            (None) value is tolerated and treated as empty.
        messages: Batches of messages sent to the model.
        run_id: LangChain run ID of this call.
        parent_run_id: LangChain run ID of the parent run, if any.
        **kwargs: Extra callback arguments; unused.
    """
    info = serialized or {}

    model_name = "unknown"
    id_path = info.get("id")
    if isinstance(id_path, list) and id_path:  # an empty list has no [-1]
        model_name = id_path[-1]
    elif "name" in info:
        model_name = info["name"]
    elif "model_name" in info:
        model_name = info["model_name"]

    # Read content through the public attribute (or key, for dict
    # messages) rather than through ``__dict__``, which is an
    # implementation detail of the message class.
    message_content = [
        str(msg.get("content", "") if isinstance(msg, dict) else getattr(msg, "content", ""))
        for msg_list in messages
        for msg in msg_list
    ]

    self._create_span(
        name=model_name,
        span_type=SpanType.LLM_CALL,
        run_id=run_id,
        parent_run_id=parent_run_id,
        input_data={"messages": message_content},
    )
|
||||
|
||||
def on_tool_start(
    self,
    serialized: Dict[str, Any],
    input_str: str,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> None:
    """Open a TOOL_CALL span when a tool starts executing.

    The span is named after ``serialized["name"]`` ("unknown-tool" when
    absent) and records the tool input string.

    Args:
        serialized: Serialized description of the tool. A missing (None)
            value is tolerated and treated as empty.
        input_str: Input passed to the tool.
        run_id: LangChain run ID of this tool call.
        parent_run_id: LangChain run ID of the parent run, if any.
        **kwargs: Extra callback arguments; unused.
    """
    # Guard against serialized=None so a missing description cannot crash
    # the instrumented tool call.
    tool_name = (serialized or {}).get("name", "unknown-tool")

    self._create_span(
        name=tool_name,
        span_type=SpanType.TOOL_CALL,
        run_id=run_id,
        parent_run_id=parent_run_id,
        input_data={"input": input_str},
    )
|
||||
|
||||
def on_tool_end(
    self,
    output: Union[str, Dict[str, Any]],
    *,
    run_id: UUID,
    **kwargs: Any,
) -> None:
    """Close the tool span for ``run_id``, recording the tool's output.

    Args:
        output: Value returned by the tool; stored under the "output" key.
        run_id: LangChain run ID of the tool call.
        **kwargs: Extra callback arguments; unused.
    """
    result_payload = {"output": output}
    self._complete_span(run_id, output_data=result_payload)
|
||||
|
||||
def on_tool_error(self, error: Exception, **kwargs: Any) -> None:
|
||||
def on_tool_error(
    self,
    error: Exception,
    *,
    run_id: UUID,
    **kwargs: Any,
) -> None:
    """Mark the tool span for ``run_id`` as failed.

    Delegates to ``_error_span``.

    Args:
        error: The exception raised by the tool.
        run_id: LangChain run ID of the failed tool call.
        **kwargs: Extra callback arguments; unused.
    """
    self._error_span(run_id=run_id, error=error)
|
||||
|
||||
def on_agent_action(self, action: Any, **kwargs: Any) -> None:
|
||||
"""Called when an agent performs an action."""
|
||||
print(f"[AgentLens] Agent action: {action.tool}")
|
||||
def on_agent_action(
    self,
    action: AgentAction,
    *,
    run_id: UUID,
    **kwargs: Any,
) -> None:
    """Record the agent's tool choice as a decision point.

    Calls ``log_decision`` with type "TOOL_SELECTION", the chosen tool's
    name and stringified input, and an empty list of alternatives.

    Args:
        action: The agent action; ``tool`` and ``tool_input`` are read.
        run_id: LangChain run ID; used only in the debug log.
        **kwargs: Extra callback arguments; unused.
    """
    tool_name = action.tool
    chosen_tool = {
        "name": tool_name,
        "input": str(action.tool_input),
    }
    log_decision(type="TOOL_SELECTION", chosen=chosen_tool, alternatives=[])

    logger.debug(
        "AgentLensCallbackHandler logged agent action: tool=%s, run_id=%s",
        tool_name,
        run_id,
    )
|
||||
|
||||
def on_chain_start(
    self,
    serialized: Dict[str, Any],
    inputs: Dict[str, Any],
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> None:
    """Open a CHAIN span when a chain starts executing.

    The span name is resolved in order: ``serialized["name"]``, the last
    element of a non-empty ``serialized["id"]``, the ``name`` keyword
    argument, then "unknown-chain". A chain without a parent run is
    remembered as the top-level run so the chain-end callbacks can close a
    handler-owned trace.

    Args:
        serialized: Serialized description of the chain. A missing (None)
            value is tolerated and treated as empty.
        inputs: Chain inputs, recorded as the span input.
        run_id: LangChain run ID of this chain run.
        parent_run_id: LangChain run ID of the parent run; None marks a
            top-level chain.
        **kwargs: Extra callback arguments; only ``name`` is read.
    """
    # The previous one-liner called ``serialized.get`` unconditionally and
    # indexed ``[-1]`` into ``id``, so it crashed on serialized=None and on
    # an empty ``id`` list.
    info = serialized or {}
    chain_name = info.get("name")
    if chain_name is None:
        id_path = info.get("id")
        if isinstance(id_path, (list, tuple)) and id_path:
            chain_name = id_path[-1]
        else:
            # NOTE(review): LangChain's callback manager is expected to
            # pass the run name as a ``name`` keyword - confirm against the
            # installed langchain-core version.
            chain_name = kwargs.get("name") or "unknown-chain"

    self._create_span(
        name=chain_name,
        span_type=SpanType.CHAIN,
        run_id=run_id,
        parent_run_id=parent_run_id,
        input_data=inputs,
    )

    # Remember the top-level run for the trace lifecycle.
    if parent_run_id is None:
        self._top_level_run_id = run_id
|
||||
|
||||
def on_chain_end(
    self,
    outputs: Dict[str, Any],
    *,
    run_id: UUID,
    **kwargs: Any,
) -> None:
    """Close the chain span and, for the top-level chain, the owned trace.

    The trace is closed only when ``run_id`` is the remembered top-level
    run and this handler opened the trace itself.

    Args:
        outputs: Chain outputs, recorded as the span output.
        run_id: LangChain run ID of the chain run.
        **kwargs: Extra callback arguments; unused.
    """
    self._complete_span(run_id, output_data=outputs)

    # Nothing more to do for nested chains or for traces owned by someone else.
    if self._top_level_run_id != run_id or self._trace_ctx is None:
        return

    self._trace_ctx.__exit__(None, None, None)
    logger.debug(
        "AgentLensCallbackHandler closed trace: %s (top-level chain completed)",
        self.trace_name,
    )
    self._trace_ctx = None
    self._top_level_run_id = None
|
||||
|
||||
def on_chain_error(
    self,
    error: Exception,
    *,
    run_id: UUID,
    **kwargs: Any,
) -> None:
    """Mark the chain span as failed and, for the top-level chain, close the owned trace.

    The trace is closed only when ``run_id`` is the remembered top-level
    run and this handler opened the trace itself.

    Args:
        error: The exception raised by the chain.
        run_id: LangChain run ID of the failed chain run.
        **kwargs: Extra callback arguments; unused.
    """
    self._error_span(run_id, error)

    # Nothing more to do for nested chains or for traces owned by someone else.
    if self._top_level_run_id != run_id or self._trace_ctx is None:
        return

    try:
        # Pass the real traceback (previously always None) so the trace
        # context receives the full (type, value, traceback) triple.
        self._trace_ctx.__exit__(type(error), error, error.__traceback__)
        logger.debug(
            "AgentLensCallbackHandler closed trace: %s (top-level chain errored)",
            self.trace_name,
        )
    finally:
        # Always drop the references, even if closing the trace raised, so
        # the handler never keeps pointing at a half-closed trace.
        self._trace_ctx = None
        self._top_level_run_id = None
|
||||
|
||||
Reference in New Issue
Block a user