
Commit f4d88fa

Anthropic instrumentation (#181)
Co-authored-by: Alex Hall <alex.mojaki@gmail.com>
1 parent 3544608 commit f4d88fa

File tree

12 files changed: +1004 -297 lines

logfire/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@
 install_auto_tracing = DEFAULT_LOGFIRE_INSTANCE.install_auto_tracing
 instrument_fastapi = DEFAULT_LOGFIRE_INSTANCE.instrument_fastapi
 instrument_openai = DEFAULT_LOGFIRE_INSTANCE.instrument_openai
+instrument_anthropic = DEFAULT_LOGFIRE_INSTANCE.instrument_anthropic
 instrument_asyncpg = DEFAULT_LOGFIRE_INSTANCE.instrument_asyncpg
 instrument_psycopg = DEFAULT_LOGFIRE_INSTANCE.instrument_psycopg
 shutdown = DEFAULT_LOGFIRE_INSTANCE.shutdown
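
For orientation, a minimal usage sketch of the new export. This assumes `instrument_anthropic` mirrors the existing `instrument_openai` entry point and accepts the client instance to patch; the exact signature and keyword arguments are not shown in this excerpt.

import anthropic
import logfire

logfire.configure()

client = anthropic.Anthropic()
# Assumed call shape: patch this specific client so its requests produce spans.
logfire.instrument_anthropic(client)

response = client.messages.create(
    model='claude-3-haiku-20240307',
    max_tokens=1000,
    messages=[{'role': 'user', 'content': 'Hello'}],
)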
Lines changed: 76 additions & 0 deletions
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import anthropic
from anthropic.types import ContentBlockDeltaEvent, ContentBlockStartEvent, Message
from anthropic.types.beta.tools import ToolsBetaMessage

from .types import EndpointConfig

if TYPE_CHECKING:
    from anthropic._models import FinalRequestOptions
    from anthropic._types import ResponseT

    from ...main import LogfireSpan


__all__ = (
    'get_endpoint_config',
    'on_response',
    'is_async_client',
)


def get_endpoint_config(options: FinalRequestOptions) -> EndpointConfig:
    """Returns the endpoint config for Anthropic depending on the url."""
    url = options.url
    json_data = options.json_data
    if not isinstance(json_data, dict):
        raise ValueError('Expected `options.json_data` to be a dictionary')

    if url == '/v1/messages' or url == '/v1/messages?beta=tools':
        return EndpointConfig(
            message_template='Message with {request_data[model]!r}',
            span_data={'request_data': json_data},
            content_from_stream=content_from_messages,
        )
    else:
        raise ValueError(f'Unknown Anthropic API endpoint: `{url}`')


def content_from_messages(chunk: anthropic.types.MessageStreamEvent) -> str | None:
    if isinstance(chunk, ContentBlockStartEvent):
        return chunk.content_block.text
    if isinstance(chunk, ContentBlockDeltaEvent):
        return chunk.delta.text
    return None


def on_response(response: ResponseT, span: LogfireSpan) -> ResponseT:
    """Updates the span based on the type of response."""
    if isinstance(response, (Message, ToolsBetaMessage)):  # pragma: no branch
        block = response.content[0]
        message: dict[str, Any] = {'role': 'assistant'}
        if block.type == 'text':
            message['content'] = block.text
        else:
            message['tool_calls'] = [
                {
                    'function': {
                        'arguments': block.model_dump_json(include={'input'}),
                        'name': block.name,  # type: ignore
                    }
                }
                for block in response.content
            ]
        span.set_attribute('response_data', {'message': message, 'usage': response.usage})
    return response


def is_async_client(client: anthropic.Anthropic | anthropic.AsyncAnthropic):
    """Returns whether or not `client` is async."""
    if isinstance(client, anthropic.Anthropic):
        return False
    assert isinstance(client, anthropic.AsyncAnthropic), f'Unexpected Anthropic or AsyncAnthropic type, got: {client}'
    return True
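
A quick duck-typed sketch of what `get_endpoint_config` produces for the messages endpoint. `SimpleNamespace` stands in for the SDK's `FinalRequestOptions`, which this function only reads via `.url` and `.json_data`; the request payload values are illustrative.

from types import SimpleNamespace

options = SimpleNamespace(
    url='/v1/messages',
    json_data={'model': 'claude-3-haiku-20240307', 'max_tokens': 10},
)
config = get_endpoint_config(options)  # type: ignore[arg-type]
print(config.message_template)  # 'Message with {request_data[model]!r}'
print(config.content_from_stream is content_from_messages)  # True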
Lines changed: 165 additions & 0 deletions
from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, ContextManager, Iterator

from opentelemetry import context

from ...constants import ONE_SECOND_IN_NANOSECONDS

if TYPE_CHECKING:
    from ...main import Logfire, LogfireSpan
    from .types import EndpointConfig


__all__ = ('instrument_llm_provider',)


def instrument_llm_provider(
    logfire: Logfire,
    client: Any,
    suppress_otel: bool,
    scope_suffix: str,
    get_endpoint_config_fn: Callable[[Any], EndpointConfig],
    on_response_fn: Callable[[Any, LogfireSpan], Any],
    is_async_client_fn: Callable[[Any], bool],
) -> ContextManager[None]:
    """Instruments the provided `client` with `logfire`."""
    logfire_llm = logfire.with_settings(custom_scope_suffix=scope_suffix.lower(), tags=['LLM'])

    client._is_instrumented_by_logfire = True
    client._original_request_method = original_request_method = client._request

    is_async = is_async_client_fn(client)

    def _instrumentation_setup(**kwargs: Any) -> Any:
        if context.get_value('suppress_instrumentation'):
            return None, None, kwargs

        options = kwargs['options']
        try:
            message_template, span_data, content_from_stream = get_endpoint_config_fn(options)
        except ValueError as exc:
            logfire_llm.warn(
                'Unable to instrument {suffix} API call: {error}', suffix=scope_suffix, error=str(exc), kwargs=kwargs
            )
            return None, None, kwargs

        span_data['async'] = is_async

        stream = kwargs['stream']

        if stream and content_from_stream:
            stream_cls = kwargs['stream_cls']
            assert stream_cls is not None, 'Expected `stream_cls` when streaming'

            if is_async:

                class LogfireInstrumentedAsyncStream(stream_cls):
                    async def __stream__(self) -> AsyncIterator[Any]:
                        with record_streaming(logfire_llm, span_data, content_from_stream) as record_chunk:
                            async for chunk in super().__stream__():  # type: ignore
                                record_chunk(chunk)
                                yield chunk

                kwargs['stream_cls'] = LogfireInstrumentedAsyncStream
            else:

                class LogfireInstrumentedStream(stream_cls):
                    def __stream__(self) -> Iterator[Any]:
                        with record_streaming(logfire_llm, span_data, content_from_stream) as record_chunk:
                            for chunk in super().__stream__():  # type: ignore
                                record_chunk(chunk)
                                yield chunk

                kwargs['stream_cls'] = LogfireInstrumentedStream

        return message_template, span_data, kwargs

    def instrumented_llm_request_sync(**kwargs: Any) -> Any:
        message_template, span_data, kwargs = _instrumentation_setup(**kwargs)
        if message_template is None:
            return original_request_method(**kwargs)
        stream = kwargs['stream']
        with logfire_llm.span(message_template, **span_data) as span:
            with maybe_suppress_instrumentation(suppress_otel):
                if stream:
                    return original_request_method(**kwargs)
                else:
                    response = on_response_fn(original_request_method(**kwargs), span)
                    return response

    async def instrumented_llm_request_async(**kwargs: Any) -> Any:
        message_template, span_data, kwargs = _instrumentation_setup(**kwargs)
        if message_template is None:
            return await original_request_method(**kwargs)
        stream = kwargs['stream']
        with logfire_llm.span(message_template, **span_data) as span:
            with maybe_suppress_instrumentation(suppress_otel):
                if stream:
                    return await original_request_method(**kwargs)
                else:
                    response = on_response_fn(await original_request_method(**kwargs), span)
                    return response

    if is_async:
        client._request = instrumented_llm_request_async
    else:
        client._request = instrumented_llm_request_sync

    @contextmanager
    def uninstrument_context():
        """Context manager to remove instrumentation from the LLM client.

        The user isn't required (or even expected) to use this context manager,
        which is why the instrumenting has already happened before.
        It exists mostly for tests and just in case users want it.
        """
        try:
            yield
        finally:
            client._request = client._original_request_method
            del client._original_request_method
            client._is_instrumented_by_logfire = False

    return uninstrument_context()


@contextmanager
def maybe_suppress_instrumentation(suppress: bool) -> Iterator[None]:
    if suppress:
        new_context = context.set_value('suppress_instrumentation', True)
        token = context.attach(new_context)
        try:
            yield
        finally:
            context.detach(token)
    else:
        yield


@contextmanager
def record_streaming(
    logire_llm: Logfire,
    span_data: dict[str, Any],
    content_from_stream: Callable[[Any], str | None],
):
    content: list[str] = []

    def record_chunk(chunk: Any) -> Any:
        chunk_content = content_from_stream(chunk)
        if chunk_content is not None:
            content.append(chunk_content)

    timer = logire_llm._config.ns_timestamp_generator  # type: ignore
    start = timer()
    try:
        yield record_chunk
    finally:
        duration = (timer() - start) / ONE_SECOND_IN_NANOSECONDS
        logire_llm.info(
            'streaming response from {request_data[model]!r} took {duration:.2f}s',
            **span_data,
            duration=duration,
            response_data={'combined_chunk_content': ''.join(content), 'chunk_count': len(content)},
        )
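
To make the wiring concrete, here is a hedged sketch of how an Anthropic-specific entry point could call this helper. The `anthropic_provider` name and the surrounding call site are assumptions for illustration; only the parameter names come from the signature above.

import logfire

# Hypothetical glue: `anthropic_provider` is the provider module shown earlier,
# `client` is an anthropic.Anthropic or anthropic.AsyncAnthropic instance.
uninstrument = instrument_llm_provider(
    logfire=logfire.DEFAULT_LOGFIRE_INSTANCE,
    client=client,
    suppress_otel=False,
    scope_suffix='Anthropic',
    get_endpoint_config_fn=anthropic_provider.get_endpoint_config,
    on_response_fn=anthropic_provider.on_response,
    is_async_client_fn=anthropic_provider.is_async_client,
)

# The client is patched immediately; the returned context manager only restores
# the original `_request` on exit (mostly useful in tests).
with uninstrument:
    ...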
Lines changed: 107 additions & 0 deletions
from __future__ import annotations

from typing import TYPE_CHECKING, cast

import openai
from openai._legacy_response import LegacyAPIResponse
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.completion import Completion
from openai.types.create_embedding_response import CreateEmbeddingResponse
from openai.types.images_response import ImagesResponse

from .types import EndpointConfig

if TYPE_CHECKING:
    from openai._models import FinalRequestOptions
    from openai._types import ResponseT

    from ...main import LogfireSpan

__all__ = (
    'get_endpoint_config',
    'on_response',
    'is_async_client',
)


def get_endpoint_config(options: FinalRequestOptions) -> EndpointConfig:
    """Returns the endpoint config for OpenAI depending on the url."""
    url = options.url
    json_data = options.json_data
    if not isinstance(json_data, dict):
        raise ValueError('Expected `options.json_data` to be a dictionary')
    if 'model' not in json_data:
        # all OpenAI API calls have a model AFAIK
        raise ValueError('`model` not found in request data')

    if url == '/chat/completions':
        return EndpointConfig(
            message_template='Chat Completion with {request_data[model]!r}',
            span_data={'request_data': json_data},
            content_from_stream=content_from_chat_completions,
        )
    elif url == '/completions':
        return EndpointConfig(
            message_template='Completion with {request_data[model]!r}',
            span_data={'request_data': json_data},
            content_from_stream=content_from_completions,
        )
    elif url == '/embeddings':
        return EndpointConfig(
            message_template='Embedding Creation with {request_data[model]!r}',
            span_data={'request_data': json_data},
            content_from_stream=None,
        )
    elif url == '/images/generations':
        return EndpointConfig(
            message_template='Image Generation with {request_data[model]!r}',
            span_data={'request_data': json_data},
            content_from_stream=None,
        )
    else:
        raise ValueError(f'Unknown OpenAI API endpoint: `{url}`')


def content_from_completions(chunk: Completion | None) -> str | None:
    if chunk and chunk.choices:
        return chunk.choices[0].text
    return None  # pragma: no cover


def content_from_chat_completions(chunk: ChatCompletionChunk | None) -> str | None:
    if chunk and chunk.choices:
        return chunk.choices[0].delta.content
    return None


def on_response(response: ResponseT, span: LogfireSpan) -> ResponseT:
    """Updates the span based on the type of response."""
    if isinstance(response, LegacyAPIResponse):  # pragma: no cover
        on_response(response.parse(), span)  # type: ignore
        return cast('ResponseT', response)

    if isinstance(response, ChatCompletion):
        span.set_attribute(
            'response_data',
            {'message': response.choices[0].message, 'usage': response.usage},
        )
    elif isinstance(response, Completion):
        first_choice = response.choices[0]
        span.set_attribute(
            'response_data',
            {'finish_reason': first_choice.finish_reason, 'text': first_choice.text, 'usage': response.usage},
        )
    elif isinstance(response, CreateEmbeddingResponse):
        span.set_attribute('response_data', {'usage': response.usage})
    elif isinstance(response, ImagesResponse):  # pragma: no branch
        span.set_attribute('response_data', {'images': response.data})
    return response


def is_async_client(client: openai.OpenAI | openai.AsyncOpenAI):
    """Returns whether or not `client` is async."""
    if isinstance(client, openai.OpenAI):
        return False
    assert isinstance(client, openai.AsyncOpenAI), f'Unexpected OpenAI or AsyncOpenAI type, got: {client}'
    return True
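
The `LegacyAPIResponse` branch covers `with_raw_response` calls, which return a wrapper instead of the parsed model. A hedged sketch of the call pattern that reaches it, assuming an already-instrumented sync OpenAI client named `client`:

raw = client.chat.completions.with_raw_response.create(
    model='gpt-3.5-turbo',
    messages=[{'role': 'user', 'content': 'Hi'}],
)
completion = raw.parse()  # on_response fills response_data from the parsed ChatCompletion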
Lines changed: 13 additions & 0 deletions
from __future__ import annotations

from typing import Any, Callable, NamedTuple

from typing_extensions import LiteralString


class EndpointConfig(NamedTuple):
    """The configuration for the endpoint of a provider based on request url."""

    message_template: LiteralString
    span_data: dict[str, Any]
    content_from_stream: Callable[[Any], str | None] | None
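
Since this NamedTuple is the whole contract between the shared instrumentation and a provider module, supporting another provider mostly means returning one of these from its own `get_endpoint_config`. A hedged sketch for an imaginary provider (the endpoint path and message template below are made up):

def get_endpoint_config(options: Any) -> EndpointConfig:
    # Imaginary '/v1/generate' endpoint; real providers raise ValueError for unknown urls.
    if options.url == '/v1/generate':
        return EndpointConfig(
            message_template='Generation with {request_data[model]!r}',
            span_data={'request_data': options.json_data},
            content_from_stream=None,  # no streaming extraction in this sketch
        )
    raise ValueError(f'Unknown API endpoint: `{options.url}`')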
