Skip to content

Commit 9c72c83

Browse files
committed
asgi middleware
1 parent 8846697 commit 9c72c83

File tree

7 files changed

+96
-3
lines changed

7 files changed

+96
-3
lines changed

context_async_sqlalchemy/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,15 @@
2323
close_all_sessions,
2424
)
2525
from .run_in_new_context import run_in_new_ctx
26+
27+
from .asgi_utils import (
28+
ASGIHTTPDBSessionMiddleware,
29+
)
30+
2631
from .starlette_utils import (
2732
add_starlette_http_db_session_middleware,
2833
starlette_http_db_session_middleware,
34+
StarletteHTTPDBSessionMiddleware,
2935
)
3036

3137
from .fastapi_utils import (
@@ -53,8 +59,10 @@
5359
"commit_all_sessions",
5460
"rollback_all_sessions",
5561
"close_all_sessions",
62+
"ASGIHTTPDBSessionMiddleware",
5663
"add_starlette_http_db_session_middleware",
5764
"starlette_http_db_session_middleware",
65+
"StarletteHTTPDBSessionMiddleware",
5866
"fastapi_http_db_session_middleware",
5967
"add_fastapi_http_db_session_middleware",
6068
]
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .middleware import ASGIHTTPDBSessionMiddleware
2+
3+
__all__ = [
4+
"ASGIHTTPDBSessionMiddleware",
5+
]
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from starlette.types import ASGIApp, Receive, Scope, Send
2+
from http import HTTPStatus
3+
4+
from context_async_sqlalchemy import (
5+
init_db_session_ctx,
6+
is_context_initiated,
7+
reset_db_session_ctx,
8+
auto_commit_by_status_code,
9+
rollback_all_sessions,
10+
)
11+
12+
13+
class ASGIHTTPDBSessionMiddleware:
14+
"""Database session lifecycle management."""
15+
16+
def __init__(self, app: ASGIApp):
17+
self.app = app
18+
19+
async def __call__(self, scope: Scope, receive: Receive, send: Send):
20+
"""
21+
Database session lifecycle management.
22+
The session itself is created on demand in db_session().
23+
24+
Transaction auto-commit is implemented if there is no exception and
25+
the response status is < 400. Otherwise, a rollback is performed.
26+
27+
But you can commit or rollback manually in the handler.
28+
"""
29+
if scope["type"] != "http":
30+
await self.app(scope, receive, send)
31+
return
32+
33+
# Tests have different session management rules
34+
# so if the context variable is already set, we do nothing
35+
if is_context_initiated():
36+
await self.app(scope, receive, send)
37+
return
38+
39+
# We set the context here, meaning all child coroutines
40+
# will receive the same context.
41+
# And even if a child coroutine requests the session first,
42+
# the container itself is shared, and this coroutine will
43+
# add the session to container = shared context.
44+
token = init_db_session_ctx()
45+
46+
status_code = HTTPStatus.INTERNAL_SERVER_ERROR
47+
48+
async def send_wrapper(message):
49+
nonlocal status_code
50+
if message["type"] == "http.response.start":
51+
status_code = message["status"]
52+
await send(message)
53+
54+
try:
55+
await self.app(scope, receive, send_wrapper)
56+
# using the status code, we decide to commit or rollback
57+
# all sessions
58+
await auto_commit_by_status_code(status_code)
59+
except Exception:
60+
# If an exception occurs, we roll all sessions back
61+
await rollback_all_sessions()
62+
raise
63+
finally:
64+
# Close all sessions and clear the context
65+
await reset_db_session_ctx(token)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from .http_middleware import (
22
add_starlette_http_db_session_middleware,
33
starlette_http_db_session_middleware,
4+
StarletteHTTPDBSessionMiddleware,
45
)
56

67
__all__ = [
78
"add_starlette_http_db_session_middleware",
89
"starlette_http_db_session_middleware",
10+
"StarletteHTTPDBSessionMiddleware",
911
]

context_async_sqlalchemy/starlette_utils/http_middleware.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,17 @@ def add_starlette_http_db_session_middleware(app: Starlette) -> None:
2222
)
2323

2424

25+
class StarletteHTTPDBSessionMiddleware(BaseHTTPMiddleware):
26+
"""Database session lifecycle management."""
27+
28+
async def dispatch(
29+
self, request: Request, call_next: RequestResponseEndpoint
30+
) -> Response:
31+
return await starlette_http_db_session_middleware(
32+
request, call_next
33+
)
34+
35+
2536
async def starlette_http_db_session_middleware(
2637
request: Request, call_next: RequestResponseEndpoint
2738
) -> Response:

examples/fastapi_example/setup_app.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ def setup_app() -> FastAPI:
2626
"""
2727
A convenient entry point for app configuration.
2828
Convenient for testing.
29-
You don't have to follow my example (though I recommend it).
29+
30+
You don't have to follow my example here.
3031
"""
3132
app = FastAPI(
3233
lifespan=lifespan,

examples/starlette_example/setup_app.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from starlette.applications import Starlette
77
from starlette.routing import Route
88

9-
9+
from context_async_sqlalchemy import ASGIHTTPDBSessionMiddleware
1010
from context_async_sqlalchemy.starlette_utils import (
1111
add_starlette_http_db_session_middleware,
1212
)
@@ -29,7 +29,8 @@ def setup_app() -> Starlette:
2929
"""
3030
A convenient entry point for app configuration.
3131
Convenient for testing.
32-
You don't have to follow my example.
32+
33+
You don't have to follow my example here.
3334
"""
3435
app = Starlette(
3536
debug=True,

0 commit comments

Comments
 (0)