Skip to content

Commit

Permalink
Fix langchain integration (#3921)
Browse files Browse the repository at this point in the history
* Add optional `parent_span` argument to `POTelSpan` constructor and fix
`start_child`
* `run_id` is reused for the top level pipeline, so make sure to close
that span or else we get orphans
* Don't use context manager enter/exit since we're doing manual span
management
* Set correct statuses while finishing the spans
  • Loading branch information
sl0thentr0py authored Jan 13, 2025
1 parent ab5d8a7 commit 7cf7373
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 33 deletions.
2 changes: 1 addition & 1 deletion sentry_sdk/ai/monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import sentry_sdk.utils
from sentry_sdk import start_span
from sentry_sdk.tracing import Span
from sentry_sdk.tracing import POTelSpan as Span
from sentry_sdk.utils import ContextVar

from typing import TYPE_CHECKING
Expand Down
2 changes: 1 addition & 1 deletion sentry_sdk/ai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
if TYPE_CHECKING:
from typing import Any

from sentry_sdk.tracing import Span
from sentry_sdk.tracing import POTelSpan as Span
from sentry_sdk.utils import logger


Expand Down
39 changes: 22 additions & 17 deletions sentry_sdk/integrations/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

import sentry_sdk
from sentry_sdk.ai.monitoring import set_ai_pipeline_name, record_token_usage
from sentry_sdk.consts import OP, SPANDATA
from sentry_sdk.consts import OP, SPANDATA, SPANSTATUS
from sentry_sdk.ai.utils import set_data_normalized
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.tracing import Span
from sentry_sdk.tracing import POTelSpan as Span
from sentry_sdk.integrations import DidNotEnable, Integration
from sentry_sdk.utils import logger, capture_internal_exceptions

Expand Down Expand Up @@ -72,7 +72,6 @@ def setup_once():


class WatchedSpan:
span = None # type: Span
num_completion_tokens = 0 # type: int
num_prompt_tokens = 0 # type: int
no_collect_tokens = False # type: bool
Expand Down Expand Up @@ -123,8 +122,9 @@ def _handle_error(self, run_id, error):
span_data = self.span_map[run_id]
if not span_data:
return
sentry_sdk.capture_exception(error, span_data.span.scope)
span_data.span.__exit__(None, None, None)
sentry_sdk.capture_exception(error)
span_data.span.set_status(SPANSTATUS.INTERNAL_ERROR)
span_data.span.finish()
del self.span_map[run_id]

def _normalize_langchain_message(self, message):
Expand All @@ -136,23 +136,27 @@ def _normalize_langchain_message(self, message):
def _create_span(self, run_id, parent_id, **kwargs):
# type: (SentryLangchainCallback, UUID, Optional[Any], Any) -> WatchedSpan

watched_span = None # type: Optional[WatchedSpan]
if parent_id:
parent_span = self.span_map.get(parent_id) # type: Optional[WatchedSpan]
if parent_span:
watched_span = WatchedSpan(parent_span.span.start_child(**kwargs))
parent_span.children.append(watched_span)
if watched_span is None:
watched_span = WatchedSpan(
sentry_sdk.start_span(only_if_parent=True, **kwargs)
)
parent_watched_span = self.span_map.get(parent_id) if parent_id else None
sentry_span = sentry_sdk.start_span(
parent_span=parent_watched_span.span if parent_watched_span else None,
only_if_parent=True,
**kwargs,
)
watched_span = WatchedSpan(sentry_span)
if parent_watched_span:
parent_watched_span.children.append(watched_span)

if kwargs.get("op", "").startswith("ai.pipeline."):
if kwargs.get("name"):
set_ai_pipeline_name(kwargs.get("name"))
watched_span.is_pipeline = True

watched_span.span.__enter__()
# the same run_id is reused for the pipeline it seems
# so we need to end the older span to avoid orphan spans
existing_span_data = self.span_map.get(run_id)
if existing_span_data is not None:
self._exit_span(existing_span_data, run_id)

self.span_map[run_id] = watched_span
self.gc_span_map()
return watched_span
Expand All @@ -163,7 +167,8 @@ def _exit_span(self, span_data, run_id):
if span_data.is_pipeline:
set_ai_pipeline_name(None)

span_data.span.__exit__(None, None, None)
span_data.span.set_status(SPANSTATUS.OK)
span_data.span.finish()
del self.span_map[run_id]

def on_llm_start(
Expand Down
14 changes: 14 additions & 0 deletions sentry_sdk/integrations/opentelemetry/span_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,17 @@ def _common_span_transaction_attributes_as_json(self, span):
common_json["tags"] = tags

return common_json

def _log_debug_info(self):
# type: () -> None
import pprint

pprint.pprint(
{
format_span_id(span_id): [
(format_span_id(child.context.span_id), child.name)
for child in children
]
for span_id, children in self._children_spans.items()
}
)
19 changes: 13 additions & 6 deletions sentry_sdk/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,7 @@ def __init__(
source=TRANSACTION_SOURCE_CUSTOM, # type: str
attributes=None, # type: OTelSpanAttributes
only_if_parent=False, # type: bool
parent_span=None, # type: Optional[POTelSpan]
otel_span=None, # type: Optional[OtelSpan]
**_, # type: dict[str, object]
):
Expand All @@ -1231,7 +1232,7 @@ def __init__(
self._otel_span = otel_span
else:
skip_span = False
if only_if_parent:
if only_if_parent and parent_span is None:
parent_span_context = get_current_span().get_span_context()
skip_span = (
not parent_span_context.is_valid or parent_span_context.is_remote
Expand Down Expand Up @@ -1262,8 +1263,17 @@ def __init__(
if sampled is not None:
attributes[SentrySpanAttribute.CUSTOM_SAMPLED] = sampled

parent_context = None
if parent_span is not None:
parent_context = otel_trace.set_span_in_context(
parent_span._otel_span
)

self._otel_span = tracer.start_span(
span_name, start_time=start_timestamp, attributes=attributes
span_name,
context=parent_context,
start_time=start_timestamp,
attributes=attributes,
)

self.origin = origin or DEFAULT_SPAN_ORIGIN
Expand Down Expand Up @@ -1506,10 +1516,7 @@ def timestamp(self):

def start_child(self, **kwargs):
# type: (**Any) -> POTelSpan
kwargs.setdefault("sampled", self.sampled)

span = POTelSpan(only_if_parent=True, **kwargs)
return span
return POTelSpan(sampled=self.sampled, parent_span=self, **kwargs)

def iter_headers(self):
# type: () -> Iterator[Tuple[str, str]]
Expand Down
10 changes: 2 additions & 8 deletions tests/integrations/langchain/test_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,17 +187,11 @@ def test_langchain_agent(
assert "measurements" not in chat_spans[0]

if send_default_pii and include_prompts:
assert (
"You are very powerful"
in chat_spans[0]["data"]["ai.input_messages"][0]["content"]
)
assert "You are very powerful" in chat_spans[0]["data"]["ai.input_messages"]
assert "5" in chat_spans[0]["data"]["ai.responses"]
assert "word" in tool_exec_span["data"]["ai.input_messages"]
assert 5 == int(tool_exec_span["data"]["ai.responses"])
assert (
"You are very powerful"
in chat_spans[1]["data"]["ai.input_messages"][0]["content"]
)
assert "You are very powerful" in chat_spans[1]["data"]["ai.input_messages"]
assert "5" in chat_spans[1]["data"]["ai.responses"]
else:
assert "ai.input_messages" not in chat_spans[0].get("data", {})
Expand Down

0 comments on commit 7cf7373

Please sign in to comment.