Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added

- `opentelemetry-instrumentation-aiohttp-client`: add typechecking for aiohttp-client instrumentation
([#4006](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4006))

## Version 1.39.0/0.60b0 (2025-12-03)

### Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,20 @@ def response_hook(span: Span, params: typing.Union[
---
"""

from __future__ import annotations

import types
import typing
from timeit import default_timer
from typing import Collection
from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
TypedDict,
Union,
cast,
)
from urllib.parse import urlparse

import aiohttp
Expand Down Expand Up @@ -143,7 +153,9 @@ def response_hook(span: Span, params: typing.Union[
from opentelemetry.metrics import MeterProvider, get_meter
from opentelemetry.propagate import inject
from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE
from opentelemetry.semconv.metrics import MetricInstruments
from opentelemetry.semconv.metrics import (
MetricInstruments, # type: ignore[reportDeprecated]
)
from opentelemetry.semconv.metrics.http_metrics import (
HTTP_CLIENT_REQUEST_DURATION,
)
Expand All @@ -155,22 +167,39 @@ def response_hook(span: Span, params: typing.Union[
sanitize_method,
)

_UrlFilterT = typing.Optional[typing.Callable[[yarl.URL], str]]
_RequestHookT = typing.Optional[
typing.Callable[[Span, aiohttp.TraceRequestStartParams], None]
]
_ResponseHookT = typing.Optional[
typing.Callable[
[
Span,
typing.Union[
aiohttp.TraceRequestEndParams,
aiohttp.TraceRequestExceptionParams,
if TYPE_CHECKING:
from typing_extensions import Unpack

UrlFilterT = typing.Optional[typing.Callable[[yarl.URL], str]]
RequestHookT = typing.Optional[
typing.Callable[[Span, aiohttp.TraceRequestStartParams], None]
]
ResponseHookT = typing.Optional[
typing.Callable[
[
Span,
typing.Union[
aiohttp.TraceRequestEndParams,
aiohttp.TraceRequestExceptionParams,
],
],
],
None,
None,
]
]
]

class ClientSessionInitKwargs(TypedDict, total=False):
trace_configs: typing.Sequence[aiohttp.TraceConfig]

class InstrumentKwargs(TypedDict, total=False):
tracer_provider: trace.TracerProvider
meter_provider: MeterProvider
url_filter: UrlFilterT
request_hook: RequestHookT
response_hook: ResponseHookT
trace_configs: typing.Sequence[aiohttp.TraceConfig]

class UninstrumentKwargs(TypedDict, total=False):
pass


def _get_span_name(method: str) -> str:
Expand All @@ -181,10 +210,10 @@ def _get_span_name(method: str) -> str:


def _set_http_status_code_attribute(
span,
status_code,
metric_attributes=None,
sem_conv_opt_in_mode=_StabilityMode.DEFAULT,
span: Span,
status_code: int,
metric_attributes: Union[dict[str, Any], None] = None,
sem_conv_opt_in_mode: _StabilityMode = _StabilityMode.DEFAULT,
):
status_code_str = str(status_code)
try:
Expand All @@ -209,11 +238,11 @@ def _set_http_status_code_attribute(
# pylint: disable=too-many-locals
# pylint: disable=too-many-statements
def create_trace_config(
url_filter: _UrlFilterT = None,
request_hook: _RequestHookT = None,
response_hook: _ResponseHookT = None,
tracer_provider: TracerProvider = None,
meter_provider: MeterProvider = None,
url_filter: UrlFilterT = None,
request_hook: RequestHookT = None,
response_hook: ResponseHookT = None,
tracer_provider: Union[TracerProvider, None] = None,
meter_provider: Union[MeterProvider, None] = None,
sem_conv_opt_in_mode: _StabilityMode = _StabilityMode.DEFAULT,
) -> aiohttp.TraceConfig:
"""Create an aiohttp-compatible trace configuration.
Expand Down Expand Up @@ -268,12 +297,10 @@ def create_trace_config(
schema_url,
)

start_time = 0

duration_histogram_old = None
if _report_old(sem_conv_opt_in_mode):
duration_histogram_old = meter.create_histogram(
name=MetricInstruments.HTTP_CLIENT_DURATION,
name=MetricInstruments.HTTP_CLIENT_DURATION, # type: ignore[reportDeprecated]
unit="ms",
description="measures the duration of the outbound HTTP request",
explicit_bucket_boundaries_advisory=HTTP_DURATION_HISTOGRAM_BUCKETS_OLD,
Expand All @@ -293,52 +320,62 @@ def _end_trace(trace_config_ctx: types.SimpleNamespace):
elapsed_time = max(default_timer() - trace_config_ctx.start_time, 0)
if trace_config_ctx.token:
context_api.detach(trace_config_ctx.token)
trace_config_ctx.span.end()
if trace_config_ctx.span:
trace_config_ctx.span.end()

if trace_config_ctx.duration_histogram_old is not None:
duration_attrs_old = _filter_semconv_duration_attrs(
trace_config_ctx.metric_attributes,
_client_duration_attrs_old,
_client_duration_attrs_new,
_StabilityMode.DEFAULT,
duration_attrs_old = cast(
dict[str, Any],
_filter_semconv_duration_attrs(
trace_config_ctx.metric_attributes,
_client_duration_attrs_old,
_client_duration_attrs_new,
_StabilityMode.DEFAULT,
),
)
trace_config_ctx.duration_histogram_old.record(
max(round(elapsed_time * 1000), 0),
attributes=duration_attrs_old,
)
if trace_config_ctx.duration_histogram_new is not None:
duration_attrs_new = _filter_semconv_duration_attrs(
trace_config_ctx.metric_attributes,
_client_duration_attrs_old,
_client_duration_attrs_new,
_StabilityMode.HTTP,
duration_attrs_new = cast(
dict[str, Any],
_filter_semconv_duration_attrs(
trace_config_ctx.metric_attributes,
_client_duration_attrs_old,
_client_duration_attrs_new,
_StabilityMode.HTTP,
),
)
trace_config_ctx.duration_histogram_new.record(
elapsed_time, attributes=duration_attrs_new
)

async def on_request_start(
unused_session: aiohttp.ClientSession,
_session: aiohttp.ClientSession,
trace_config_ctx: types.SimpleNamespace,
params: aiohttp.TraceRequestStartParams,
):
if (
not is_http_instrumentation_enabled()
or trace_config_ctx.excluded_urls.url_disabled(str(params.url))
):
trace_config_ctx.span = None
return

trace_config_ctx.start_time = default_timer()
method = params.method
request_span_name = _get_span_name(method)
request_url = (
redact_url(trace_config_ctx.url_filter(params.url))
redact_url(
cast(Callable[[yarl.URL], str], trace_config_ctx.url_filter)(
params.url
)
)
if callable(trace_config_ctx.url_filter)
else redact_url(str(params.url))
)

span_attributes = {}
span_attributes: dict[str, Any] = {}
_set_http_method(
span_attributes,
method,
Expand Down Expand Up @@ -399,7 +436,7 @@ async def on_request_start(
inject(params.headers)

async def on_request_end(
unused_session: aiohttp.ClientSession,
_session: aiohttp.ClientSession,
trace_config_ctx: types.SimpleNamespace,
params: aiohttp.TraceRequestEndParams,
):
Expand All @@ -418,7 +455,7 @@ async def on_request_end(
_end_trace(trace_config_ctx)

async def on_request_exception(
unused_session: aiohttp.ClientSession,
_session: aiohttp.ClientSession,
trace_config_ctx: types.SimpleNamespace,
params: aiohttp.TraceRequestExceptionParams,
):
Expand All @@ -441,21 +478,25 @@ async def on_request_exception(

_end_trace(trace_config_ctx)

def _trace_config_ctx_factory(**kwargs):
def _trace_config_ctx_factory(**kwargs: Any) -> types.SimpleNamespace:
kwargs.setdefault("trace_request_ctx", {})
return types.SimpleNamespace(
tracer=tracer,
url_filter=url_filter,
start_time=start_time,
span=None,
token=None,
duration_histogram_old=duration_histogram_old,
duration_histogram_new=duration_histogram_new,
excluded_urls=excluded_urls,
metric_attributes={},
url_filter=url_filter,
excluded_urls=excluded_urls,
start_time=0,
**kwargs,
)

trace_config = aiohttp.TraceConfig(
trace_config_ctx_factory=_trace_config_ctx_factory
trace_config_ctx_factory=cast(
type[types.SimpleNamespace], _trace_config_ctx_factory
)
)

trace_config.on_request_start.append(on_request_start)
Expand All @@ -466,11 +507,11 @@ def _trace_config_ctx_factory(**kwargs):


def _instrument(
tracer_provider: TracerProvider = None,
meter_provider: MeterProvider = None,
url_filter: _UrlFilterT = None,
request_hook: _RequestHookT = None,
response_hook: _ResponseHookT = None,
tracer_provider: Union[TracerProvider, None] = None,
meter_provider: Union[MeterProvider, None] = None,
url_filter: UrlFilterT = None,
request_hook: RequestHookT = None,
response_hook: ResponseHookT = None,
trace_configs: typing.Optional[
typing.Sequence[aiohttp.TraceConfig]
] = None,
Expand All @@ -485,7 +526,12 @@ def _instrument(
trace_configs = trace_configs or ()

# pylint:disable=unused-argument
def instrumented_init(wrapped, instance, args, kwargs):
def instrumented_init(
wrapped: Callable[..., None],
_instance: aiohttp.ClientSession,
args: tuple[Any, ...],
kwargs: ClientSessionInitKwargs,
):
client_trace_configs = list(kwargs.get("trace_configs") or [])
client_trace_configs.extend(trace_configs)

Expand All @@ -497,13 +543,13 @@ def instrumented_init(wrapped, instance, args, kwargs):
meter_provider=meter_provider,
sem_conv_opt_in_mode=sem_conv_opt_in_mode,
)
trace_config._is_instrumented_by_opentelemetry = True
setattr(trace_config, "_is_instrumented_by_opentelemetry", True)
client_trace_configs.append(trace_config)

kwargs["trace_configs"] = client_trace_configs
return wrapped(*args, **kwargs)

wrapt.wrap_function_wrapper(
wrapt.wrap_function_wrapper( # type: ignore[reportUnknownVariableType]
aiohttp.ClientSession, "__init__", instrumented_init
)

Expand Down Expand Up @@ -533,7 +579,7 @@ class AioHttpClientInstrumentor(BaseInstrumentor):
def instrumentation_dependencies(self) -> Collection[str]:
return _instruments

def _instrument(self, **kwargs):
def _instrument(self, **kwargs: Unpack[InstrumentKwargs]):
"""Instruments aiohttp ClientSession

Args:
Expand Down Expand Up @@ -562,7 +608,7 @@ def _instrument(self, **kwargs):
sem_conv_opt_in_mode=_sem_conv_opt_in_mode,
)

def _uninstrument(self, **kwargs):
def _uninstrument(self, **kwargs: Unpack[UninstrumentKwargs]):
_uninstrument()

@staticmethod
Expand Down
Loading