Skip to content

Commit 5f2f7b7

Browse files
authored
feat: add engine_args argument to engine creation functions (#242)
1 parent 1d106f8 commit 5f2f7b7

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

src/langchain_google_cloud_sql_pg/engine.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from concurrent.futures import Future
1919
from dataclasses import dataclass
2020
from threading import Thread
21-
from typing import TYPE_CHECKING, Any, Awaitable, Optional, TypeVar, Union
21+
from typing import TYPE_CHECKING, Any, Awaitable, Mapping, Optional, TypeVar, Union
2222

2323
import aiohttp
2424
import google.auth # type: ignore
@@ -143,6 +143,7 @@ async def _create(
143143
thread: Optional[Thread] = None,
144144
quota_project: Optional[str] = None,
145145
iam_account_email: Optional[str] = None,
146+
engine_args: Mapping = {},
146147
) -> PostgresEngine:
147148
"""Create a PostgresEngine instance.
148149
@@ -158,6 +159,9 @@ async def _create(
158159
thread (Optional[Thread]): Thread used to create the engine async.
159160
quota_project (Optional[str]): Project that provides quota for API calls.
160161
iam_account_email (Optional[str]): IAM service account email. Defaults to None.
162+
engine_args (Mapping): Additional arguments that are passed directly to
163+
:func:`~sqlalchemy.ext.asyncio.mymodule.MyClass.create_async_engine`. This can be
164+
used to specify additional parameters to the underlying pool during it's creation.
161165
162166
Raises:
163167
ValueError: If only one of `user` and `password` is specified.
@@ -211,6 +215,7 @@ async def getconn() -> asyncpg.Connection:
211215
engine = create_async_engine(
212216
"postgresql+asyncpg://",
213217
async_creator=getconn,
218+
**engine_args,
214219
)
215220
return cls(cls.__create_key, engine, loop, thread)
216221

@@ -226,6 +231,7 @@ def __start_background_loop(
226231
ip_type: Union[str, IPTypes] = IPTypes.PUBLIC,
227232
quota_project: Optional[str] = None,
228233
iam_account_email: Optional[str] = None,
234+
engine_args: Mapping = {},
229235
) -> Future:
230236
# Running a loop in a background thread allows us to support
231237
# async methods from non-async environments
@@ -247,6 +253,7 @@ def __start_background_loop(
247253
thread=cls._default_thread,
248254
quota_project=quota_project,
249255
iam_account_email=iam_account_email,
256+
engine_args=engine_args,
250257
)
251258
return asyncio.run_coroutine_threadsafe(coro, cls._default_loop)
252259

@@ -262,6 +269,7 @@ def from_instance(
262269
ip_type: Union[str, IPTypes] = IPTypes.PUBLIC,
263270
quota_project: Optional[str] = None,
264271
iam_account_email: Optional[str] = None,
272+
engine_args: Mapping = {},
265273
) -> PostgresEngine:
266274
"""Create a PostgresEngine from a Postgres instance.
267275
@@ -275,6 +283,9 @@ def from_instance(
275283
ip_type (Union[str, IPTypes], optional): IP address type. Defaults to IPTypes.PUBLIC.
276284
quota_project (Optional[str]): Project that provides quota for API calls.
277285
iam_account_email (Optional[str], optional): IAM service account email. Defaults to None.
286+
engine_args (Mapping): Additional arguments that are passed directly to
287+
:func:`~sqlalchemy.ext.asyncio.mymodule.MyClass.create_async_engine`. This can be
288+
used to specify additional parameters to the underlying pool during it's creation.
278289
279290
Returns:
280291
PostgresEngine: A newly created PostgresEngine instance.
@@ -289,6 +300,7 @@ def from_instance(
289300
ip_type,
290301
quota_project=quota_project,
291302
iam_account_email=iam_account_email,
303+
engine_args=engine_args,
292304
)
293305
return future.result()
294306

@@ -304,6 +316,7 @@ async def afrom_instance(
304316
ip_type: Union[str, IPTypes] = IPTypes.PUBLIC,
305317
quota_project: Optional[str] = None,
306318
iam_account_email: Optional[str] = None,
319+
engine_args: Mapping = {},
307320
) -> PostgresEngine:
308321
"""Create a PostgresEngine from a Postgres instance.
309322
@@ -317,6 +330,9 @@ async def afrom_instance(
317330
ip_type (Union[str, IPTypes], optional): IP address type. Defaults to IPTypes.PUBLIC.
318331
quota_project (Optional[str]): Project that provides quota for API calls.
319332
iam_account_email (Optional[str], optional): IAM service account email. Defaults to None.
333+
engine_args (Mapping): Additional arguments that are passed directly to
334+
:func:`~sqlalchemy.ext.asyncio.mymodule.MyClass.create_async_engine`. This can be
335+
used to specify additional parameters to the underlying pool during it's creation.
320336
321337
Returns:
322338
PostgresEngine: A newly created PostgresEngine instance.
@@ -331,6 +347,7 @@ async def afrom_instance(
331347
ip_type,
332348
quota_project=quota_project,
333349
iam_account_email=iam_account_email,
350+
engine_args=engine_args,
334351
)
335352
return await asyncio.wrap_future(future)
336353

@@ -346,7 +363,7 @@ def from_engine(
346363
@classmethod
347364
def from_engine_args(
348365
cls,
349-
url: Union[str | URL],
366+
url: str | URL,
350367
**kwargs: Any,
351368
) -> PostgresEngine:
352369
"""Create an PostgresEngine instance from arguments. These parameters are pass directly into sqlalchemy's create_async_engine function.

tests/test_engine.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,13 +110,21 @@ async def engine(self, db_project, db_region, db_instance, db_name):
110110
instance=db_instance,
111111
region=db_region,
112112
database=db_name,
113+
engine_args={
114+
# add some connection args to validate engine_args works correctly
115+
"pool_size": 3,
116+
"max_overflow": 2,
117+
},
113118
)
114119
yield engine
115120
await aexecute(engine, f'DROP TABLE "{CUSTOM_TABLE}"')
116121
await aexecute(engine, f'DROP TABLE "{DEFAULT_TABLE}"')
117122
await aexecute(engine, f'DROP TABLE "{INT_ID_CUSTOM_TABLE}"')
118123
await engine.close()
119124

125+
async def test_engine_args(self, engine):
126+
assert "Pool size: 3" in engine._pool.pool.status()
127+
120128
async def test_init_table(self, engine):
121129
await engine.ainit_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE)
122130
id = str(uuid.uuid4())

0 commit comments

Comments
 (0)