1717import asyncio
1818from dataclasses import dataclass
1919from threading import Thread
20- from typing import TYPE_CHECKING , Awaitable , Dict , List , Optional , TypeVar , Union
20+ from typing import (
21+ TYPE_CHECKING ,
22+ Awaitable ,
23+ Dict ,
24+ List ,
25+ Optional ,
26+ Sequence ,
27+ TypeVar ,
28+ Union ,
29+ )
2130
2231import aiohttp
2332import google .auth # type: ignore
2433import google .auth .transport .requests # type: ignore
2534from google .cloud .sql .connector import Connector , IPTypes , RefreshStrategy
2635from sqlalchemy import MetaData , Table , text
36+ from sqlalchemy .engine .row import RowMapping
2737from sqlalchemy .exc import InvalidRequestError
2838from sqlalchemy .ext .asyncio import AsyncEngine , create_async_engine
2939
@@ -305,19 +315,21 @@ def from_engine(cls, engine: AsyncEngine) -> PostgresEngine:
305315 """Create an PostgresEngine instance from an AsyncEngine."""
306316 return cls (cls .__create_key , engine , None , None )
307317
308- async def _aexecute (self , query : str , params : Optional [dict ] = None ):
318+ async def _aexecute (self , query : str , params : Optional [dict ] = None ) -> None :
309319 """Execute a SQL query."""
310320 async with self ._engine .connect () as conn :
311321 await conn .execute (text (query ), params )
312322 await conn .commit ()
313323
314- async def _aexecute_outside_tx (self , query : str ):
324+ async def _aexecute_outside_tx (self , query : str ) -> None :
315325 """Execute a SQL query."""
316326 async with self ._engine .connect () as conn :
317327 await conn .execute (text ("COMMIT" ))
318328 await conn .execute (text (query ))
319329
320- async def _afetch (self , query : str , params : Optional [dict ] = None ):
330+ async def _afetch (
331+ self , query : str , params : Optional [dict ] = None
332+ ) -> Sequence [RowMapping ]:
321333 """Fetch results from a SQL query."""
322334 async with self ._engine .connect () as conn :
323335 result = await conn .execute (text (query ), params )
@@ -326,11 +338,11 @@ async def _afetch(self, query: str, params: Optional[dict] = None):
326338
327339 return result_fetch
328340
329- def _execute (self , query : str , params : Optional [dict ] = None ):
341+ def _execute (self , query : str , params : Optional [dict ] = None ) -> None :
330342 """Execute a SQL query."""
331343 return self ._run_as_sync (self ._aexecute (query , params ))
332344
333- def _fetch (self , query : str , params : Optional [dict ] = None ):
345+ def _fetch (self , query : str , params : Optional [dict ] = None ) -> Sequence [ RowMapping ] :
334346 """Fetch results from a SQL query."""
335347 return self ._run_as_sync (self ._afetch (query , params ))
336348
@@ -439,7 +451,7 @@ def init_vectorstore_table(
439451 )
440452 )
441453
442- async def ainit_chat_history_table (self , table_name ) -> None :
454+ async def ainit_chat_history_table (self , table_name : str ) -> None :
443455 """Create a Cloud SQL table to store chat history.
444456
445457 Args:
@@ -456,7 +468,7 @@ async def ainit_chat_history_table(self, table_name) -> None:
456468 );"""
457469 await self ._aexecute (create_table_query )
458470
459- def init_chat_history_table (self , table_name ) -> None :
471+ def init_chat_history_table (self , table_name : str ) -> None :
460472 """Create a Cloud SQL table to store chat history.
461473
462474 Args:
0 commit comments