Skip to content

Commit 2eda136

Browse files
✨ Support middleware to wrap httpx.send calls.
1 parent daf2d6d commit 2eda136

File tree

6 files changed

+31
-2
lines changed

6 files changed

+31
-2
lines changed

ChangeLog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ and the format of this file is based on [Keep a Changelog](https://keepachangelo
1111
### Added
1212
- Accept session_factory in `ClientBase.__init__`.
1313
- Helper function to iterate over pages.
14+
- Accept middleware.
1415

1516
### Fixed
1617
- Handling collections in request bodies.

src/lapidary/runtime/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
'FormExplode',
99
'Header',
1010
'HttpErrorResponse',
11+
'HttpxMiddleware',
1112
'LapidaryError',
1213
'LapidaryResponseError',
1314
'Metadata',
@@ -35,6 +36,7 @@
3536

3637
from .annotations import Body, Cookie, Header, Metadata, Path, Query, Response, Responses, StatusCode
3738
from .client_base import ClientBase, lapidary_user_agent
39+
from .middleware import HttpxMiddleware
3840
from .model import ModelBase
3941
from .model.error import HttpErrorResponse, LapidaryError, LapidaryResponseError, UnexpectedResponse
4042
from .model.param_serialization import Form, FormExplode, SimpleMultimap, SimpleString

src/lapidary/runtime/client_base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
import typing_extensions as typing
88

99
from .http_consts import USER_AGENT
10+
from .middleware import HttpxMiddleware
1011
from .model.auth import AuthRegistry
1112

1213
if typing.TYPE_CHECKING:
1314
import types
14-
from collections.abc import Iterable
15+
from collections.abc import Iterable, Sequence
1516

1617
from .types_ import ClientArgs, NamedAuth, SecurityRequirements, SessionFactory
1718

@@ -29,13 +30,15 @@ def __init__(
2930
self,
3031
security: Iterable[SecurityRequirements] | None = None,
3132
session_factory: SessionFactory = httpx.AsyncClient,
33+
middlewares: Sequence[HttpxMiddleware] = (),
3234
**httpx_kwargs: typing.Unpack[ClientArgs],
3335
) -> None:
3436
self._client = session_factory(**httpx_kwargs)
3537
if USER_AGENT not in self._client.headers:
3638
self._client.headers[USER_AGENT] = lapidary_user_agent()
3739

3840
self._auth_registry = AuthRegistry(security)
41+
self._middlewares = middlewares
3942

4043
async def __aenter__(self: typing.Self) -> typing.Self:
4144
await self._client.__aenter__()

src/lapidary/runtime/middleware.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import abc
2+
from typing import Generic, TypeVar
3+
4+
import httpx
5+
6+
State = TypeVar('State')
7+
8+
9+
class HttpxMiddleware(Generic[State]):
10+
@abc.abstractmethod
11+
async def handle_request(self, request: httpx.Request) -> State:
12+
pass
13+
14+
async def handle_response(self, response: httpx.Response, request: httpx.Request, state: State) -> None:
15+
pass

src/lapidary/runtime/model/op.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,17 @@ def mk_exchange_fn(
3131
async def exchange(self: 'ClientBase', **kwargs) -> typing.Any:
3232
request, auth = request_adapter.build_request(self, kwargs)
3333

34+
mw_state = []
35+
for mw in self._middlewares:
36+
mw_state.append(await mw.handle_request(request))
37+
3438
response = await self._client.send(request, auth=auth)
3539

3640
await response.aread()
41+
42+
for mw, state in zip(reversed(self._middlewares), reversed(mw_state)):
43+
await mw.handle_response(response, request, state)
44+
3745
status_code, result = response_handler.handle_response(response)
3846
if status_code >= 400:
3947
raise HttpErrorResponse(status_code, result[1], result[0])

src/lapidary/runtime/paging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from collections.abc import AsyncIterable, Awaitable, Callable
22
from typing import Optional, TypeVar
33

4-
from typing_extensions import ParamSpec, Unpack
4+
from typing_extensions import ParamSpec
55

66
P = ParamSpec('P')
77
R = TypeVar('R')

0 commit comments

Comments
 (0)