Skip to content

Commit 6c3f476

Browse files
committed
1.2.0
1 parent 406f160 commit 6c3f476

File tree

21 files changed

+320
-240
lines changed

21 files changed

+320
-240
lines changed

README.md

Lines changed: 75 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@ A convenient way to configure and interact with async sqlalchemy session
1414
from context_async_sqlalchemy import db_session
1515
from sqlalchemy import insert
1616

17-
from ..models import ExampleTable
17+
from ..database import master # your configured connection to the database
18+
from ..models import ExampleTable # just some model for example
1819

1920
async def some_func() -> None:
2021
# Created a session (no connection to the database yet)
2122
# If you call db_session again, it will return the same session
2223
# even in child coroutines.
23-
session = await db_session()
24+
session = await db_session(master)
2425

2526
stmt = insert(ExampleTable).values(text="example_with_db_session")
2627

@@ -46,8 +47,7 @@ It also includes two types of test setups you can use in your projects.
4647

4748
#### 1. Configure the connection to the database
4849

49-
for example for PostgreSQL:
50-
50+
for example for PostgreSQL database.py:
5151
```python
5252
from sqlalchemy.ext.asyncio import (
5353
async_sessionmaker,
@@ -56,13 +56,14 @@ from sqlalchemy.ext.asyncio import (
5656
create_async_engine,
5757
)
5858

59+
from context_async_sqlalchemy import DBConnect
60+
5961

6062
def create_engine(host: str) -> AsyncEngine:
6163
"""
6264
database connection parameters.
63-
In production code, you will probably take these parameters from
64-
the environment.
6565
"""
66+
# In production code, you will probably take these parameters from env
6667
pg_user = "krylosov-aa"
6768
pg_password = ""
6869
pg_port = 6432
@@ -84,6 +85,14 @@ def create_session_maker(
8485
return async_sessionmaker(
8586
engine, class_=AsyncSession, expire_on_commit=False
8687
)
88+
89+
90+
master = DBConnect(
91+
host="127.0.0.1",
92+
engine_creator=create_engine,
93+
session_maker_creator=create_session_maker,
94+
)
95+
8796
```
8897

8998
#### 2. Manage Database connection lifecycle
@@ -94,36 +103,18 @@ Close the resources at the end of your application's life
94103
Example for FastAPI:
95104

96105
```python
97-
import asyncio
98-
from typing import Any, AsyncGenerator
99106
from contextlib import asynccontextmanager
107+
from typing import Any, AsyncGenerator
100108
from fastapi import FastAPI
101109

102-
from context_async_sqlalchemy import (
103-
master_connect,
104-
replica_connect,
105-
)
106-
107-
from .database import create_engine, create_session_maker
110+
from .database import master
108111

109-
async def setup_database() -> None:
110-
"""
111-
Here you pass the database connection parameters to the library.
112-
More specifically, the engine and session maker.
113-
"""
114-
master_connect.engine_creator = create_engine
115-
master_connect.session_maker_creator = create_session_maker
116-
await master_connect.connect("127.0.0.1")
117112

118113
@asynccontextmanager
119114
async def lifespan(app: FastAPI) -> AsyncGenerator[None, Any]:
120115
"""Database connection lifecycle management"""
121-
await setup_database()
122116
yield
123-
await asyncio.gather(
124-
master_connect.close(), # Close the engine if it was open
125-
replica_connect.close(), # Close the engine if it was open
126-
)
117+
await master.close() # Close the engine if it was open
127118
```
128119

129120

@@ -141,11 +132,11 @@ from starlette.middleware.base import ( # type: ignore[attr-defined]
141132
)
142133

143134
from context_async_sqlalchemy import (
144-
auto_commit_by_status_code,
145135
init_db_session_ctx,
146136
is_context_initiated,
147137
reset_db_session_ctx,
148-
rollback_db_session,
138+
auto_commit_by_status_code,
139+
rollback_all_sessions,
149140
)
150141

151142

@@ -161,7 +152,7 @@ async def fastapi_db_session_middleware(
161152
162153
But you can commit or rollback manually in the handler.
163154
"""
164-
# Tests may have different session management rules
155+
# Tests have different session management rules
165156
# so if the context variable is already set, we do nothing
166157
if is_context_initiated():
167158
return await call_next(request)
@@ -176,7 +167,7 @@ async def fastapi_db_session_middleware(
176167
await auto_commit_by_status_code(response.status_code)
177168
return response
178169
except Exception:
179-
await rollback_db_session()
170+
await rollback_all_sessions()
180171
raise
181172
finally:
182173
await reset_db_session_ctx(token)
@@ -197,11 +188,14 @@ app.add_middleware(
197188
#### 4. Write a function that will work with the session
198189

199190
```python
200-
from context_async_sqlalchemy import db_session
201191
from sqlalchemy import insert
202192

193+
from context_async_sqlalchemy import db_session
194+
195+
from ..database import master
203196
from ..models import ExampleTable
204197

198+
205199
async def handler_with_db_session() -> None:
206200
"""
207201
An example of a typical handle that uses a context session to work with
@@ -212,10 +206,58 @@ async def handler_with_db_session() -> None:
212206
# Created a session (no connection to the database yet)
213207
# If you call db_session again, it will return the same session
214208
# even in child coroutines.
215-
session = await db_session()
209+
session = await db_session(master)
216210

217211
stmt = insert(ExampleTable).values(text="example_with_db_session")
218212

219213
# On the first request, a connection and transaction were opened
220214
await session.execute(stmt)
221215
```
216+
217+
218+
## Master/Replica or several databases at the same time
219+
220+
This is why `db_session` and other functions accept `DBConnect` as input.
221+
This way, you can work with multiple hosts simultaneously,
222+
for example, with the master and the replica.
223+
224+
225+
Let's imagine that you have a third-party functionality that helps determine
226+
the master or replica.
227+
228+
In this example, the host is not set from the very beginning, but will be
229+
calculated during the first call to create a session.
230+
231+
```python
232+
from context_async_sqlalchemy import DBConnect
233+
234+
from master_replica_helper import get_master, get_replica
235+
236+
237+
async def renew_master_connect(connect: DBConnect) -> None:
238+
"""Updates the connection with the master if the master has changed"""
239+
master_host = await get_master()
240+
if master_host != connect.host:
241+
await connect.change_host(master_host)
242+
243+
244+
master = DBConnect(
245+
engine_creator=create_engine,
246+
session_maker_creator=create_session_maker,
247+
before_create_session_handler=renew_master_connect,
248+
)
249+
250+
251+
async def renew_replica_connect(connect: DBConnect) -> None:
252+
"""Updates the connection with the replica if the master has changed"""
253+
replica_host = await get_replica()
254+
if replica_host != connect.host:
255+
await connect.change_host(replica_host)
256+
257+
258+
replica = DBConnect(
259+
engine_creator=create_engine,
260+
session_maker_creator=create_session_maker,
261+
before_create_session_handler=renew_replica_connect,
262+
)
263+
```

context_async_sqlalchemy/__init__.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,22 @@
77
pop_db_session_from_context,
88
run_in_new_ctx,
99
)
10-
from .connect import (
11-
DBConnect,
12-
master_connect,
13-
replica_connect,
14-
)
10+
from .connect import DBConnect
1511
from .session import (
1612
db_session,
1713
atomic_db_session,
18-
run_with_new_db_session,
19-
run_with_new_atomic_db_session,
2014
commit_db_session,
2115
rollback_db_session,
2216
close_db_session,
2317
new_non_ctx_atomic_session,
2418
new_non_ctx_session,
2519
)
26-
from .auto_commit import auto_commit_by_status_code
20+
from .auto_commit import (
21+
auto_commit_by_status_code,
22+
commit_all_sessions,
23+
rollback_all_sessions,
24+
close_all_sessions,
25+
)
2726
from .fastapi_utils.middleware import fastapi_db_session_middleware
2827

2928
__all__ = [
@@ -35,17 +34,16 @@
3534
"pop_db_session_from_context",
3635
"run_in_new_ctx",
3736
"DBConnect",
38-
"master_connect",
39-
"replica_connect",
4037
"db_session",
4138
"atomic_db_session",
42-
"run_with_new_db_session",
43-
"run_with_new_atomic_db_session",
4439
"commit_db_session",
4540
"rollback_db_session",
4641
"close_db_session",
47-
"auto_commit_by_status_code",
48-
"fastapi_db_session_middleware",
4942
"new_non_ctx_atomic_session",
5043
"new_non_ctx_session",
44+
"auto_commit_by_status_code",
45+
"commit_all_sessions",
46+
"rollback_all_sessions",
47+
"close_all_sessions",
48+
"fastapi_db_session_middleware",
5149
]
Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,51 @@
11
from http import HTTPStatus
22

3-
from .context import get_db_session_from_context
3+
from .context import sessions_stream
44

55

66
async def auto_commit_by_status_code(status_code: int) -> None:
77
"""
88
Implements automatic commit or rollback.
9-
It should be used, for example, in the middleware or anywhere else
10-
where you expect session lifecycle management.
9+
It should be used in a middleware or anywhere else where you expect
10+
session lifecycle management.
1111
"""
12-
session = get_db_session_from_context()
12+
if status_code < HTTPStatus.BAD_REQUEST:
13+
await commit_all_sessions()
14+
else:
15+
await rollback_all_sessions()
1316

14-
if session and session.in_transaction():
15-
if status_code < HTTPStatus.BAD_REQUEST:
16-
await session.commit()
17-
else:
17+
18+
async def rollback_all_sessions() -> None:
19+
"""
20+
Rolls back all open context sessions.
21+
22+
It should be used in middleware or anywhere else where you expect
23+
lifecycle management and need to roll back all opened sessions.
24+
For example, inside an except block.
25+
"""
26+
for session in sessions_stream():
27+
if session.in_transaction():
1828
await session.rollback()
29+
30+
31+
async def commit_all_sessions() -> None:
32+
"""
33+
Commits all open context sessions.
34+
35+
It should be used in middleware or anywhere else where you expect
36+
lifecycle management and need to commit all opened sessions.
37+
"""
38+
for session in sessions_stream():
39+
if session.in_transaction():
40+
await session.commit()
41+
42+
43+
async def close_all_sessions() -> None:
44+
"""
45+
Closes all open context sessions.
46+
47+
It should be used in middleware or anywhere else where you expect
48+
lifecycle management and need to close all sessions.
49+
"""
50+
for session in sessions_stream():
51+
await session.close()

0 commit comments

Comments
 (0)