|
1 | | -from starlette.types import ASGIApp, Receive, Scope, Send |
| 1 | +from collections.abc import Awaitable, MutableMapping |
| 2 | +from typing import Any, Callable |
| 3 | + |
2 | 4 | from http import HTTPStatus |
3 | 5 |
|
4 | 6 | from context_async_sqlalchemy import ( |
|
9 | 11 | rollback_all_sessions, |
10 | 12 | ) |
11 | 13 |
|
| 14 | +Message = MutableMapping[str, Any] |
| 15 | +Receive = Callable[[], Awaitable[Message]] |
| 16 | +Scope = MutableMapping[str, Any] |
| 17 | +Send = Callable[[Message], Awaitable[None]] |
| 18 | +ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]] |
| 19 | + |
12 | 20 |
|
13 | 21 | class ASGIHTTPDBSessionMiddleware: |
14 | 22 | """Database session lifecycle management.""" |
15 | 23 |
|
16 | 24 | def __init__(self, app: ASGIApp): |
17 | 25 | self.app = app |
18 | 26 |
|
19 | | - async def __call__(self, scope: Scope, receive: Receive, send: Send): |
| 27 | + async def __call__( |
| 28 | + self, scope: Scope, receive: Receive, send: Send |
| 29 | + ) -> None: |
20 | 30 | """ |
21 | 31 | Database session lifecycle management. |
22 | 32 | The session itself is created on demand in db_session(). |
@@ -45,7 +55,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send): |
45 | 55 |
|
46 | 56 | status_code = HTTPStatus.INTERNAL_SERVER_ERROR |
47 | 57 |
|
48 | | - async def send_wrapper(message): |
| 58 | + async def send_wrapper(message: Message) -> None: |
49 | 59 | nonlocal status_code |
50 | 60 | if message["type"] == "http.response.start": |
51 | 61 | status_code = message["status"] |
|
0 commit comments