|
| 1 | +from starlette.types import ASGIApp, Receive, Scope, Send |
| 2 | +from http import HTTPStatus |
| 3 | + |
| 4 | +from context_async_sqlalchemy import ( |
| 5 | + init_db_session_ctx, |
| 6 | + is_context_initiated, |
| 7 | + reset_db_session_ctx, |
| 8 | + auto_commit_by_status_code, |
| 9 | + rollback_all_sessions, |
| 10 | +) |
| 11 | + |
| 12 | + |
| 13 | +class ASGIHTTPDBSessionMiddleware: |
| 14 | + """Database session lifecycle management.""" |
| 15 | + |
| 16 | + def __init__(self, app: ASGIApp): |
| 17 | + self.app = app |
| 18 | + |
| 19 | + async def __call__(self, scope: Scope, receive: Receive, send: Send): |
| 20 | + """ |
| 21 | + Database session lifecycle management. |
| 22 | + The session itself is created on demand in db_session(). |
| 23 | +
|
| 24 | + Transaction auto-commit is implemented if there is no exception and |
| 25 | + the response status is < 400. Otherwise, a rollback is performed. |
| 26 | +
|
| 27 | + But you can commit or rollback manually in the handler. |
| 28 | + """ |
| 29 | + if scope["type"] != "http": |
| 30 | + await self.app(scope, receive, send) |
| 31 | + return |
| 32 | + |
| 33 | + # Tests have different session management rules |
| 34 | + # so if the context variable is already set, we do nothing |
| 35 | + if is_context_initiated(): |
| 36 | + await self.app(scope, receive, send) |
| 37 | + return |
| 38 | + |
| 39 | + # We set the context here, meaning all child coroutines |
| 40 | + # will receive the same context. |
| 41 | + # And even if a child coroutine requests the session first, |
| 42 | + # the container itself is shared, and this coroutine will |
| 43 | + # add the session to container = shared context. |
| 44 | + token = init_db_session_ctx() |
| 45 | + |
| 46 | + status_code = HTTPStatus.INTERNAL_SERVER_ERROR |
| 47 | + |
| 48 | + async def send_wrapper(message): |
| 49 | + nonlocal status_code |
| 50 | + if message["type"] == "http.response.start": |
| 51 | + status_code = message["status"] |
| 52 | + await send(message) |
| 53 | + |
| 54 | + try: |
| 55 | + await self.app(scope, receive, send_wrapper) |
| 56 | + # using the status code, we decide to commit or rollback |
| 57 | + # all sessions |
| 58 | + await auto_commit_by_status_code(status_code) |
| 59 | + except Exception: |
| 60 | + # If an exception occurs, we roll all sessions back |
| 61 | + await rollback_all_sessions() |
| 62 | + raise |
| 63 | + finally: |
| 64 | + # Close all sessions and clear the context |
| 65 | + await reset_db_session_ctx(token) |
0 commit comments