Skip to content

Commit 9745c47

Browse files
ci: Add mypy function type check (#130)
* ci: Add mypy function type check * fix typing * fix rebase conflict * more typing fix * fix type * moreeee fix --------- Co-authored-by: Averi Kitsch <akitsch@google.com>
1 parent 4eb3011 commit 9745c47

File tree

6 files changed

+44
-29
lines changed

6 files changed

+44
-29
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ profile = "black"
6262
[tool.mypy]
6363
python_version = 3.8
6464
warn_unused_configs = true
65+
disallow_incomplete_defs = true
66+
6567
exclude = [
6668
'docs/*',
6769
'noxfile.py'

src/langchain_google_cloud_sql_pg/chat_message_history.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class PostgresChatMessageHistory(BaseChatMessageHistory):
4444

4545
def __init__(
4646
self,
47-
key,
47+
key: object,
4848
engine: PostgresEngine,
4949
session_id: str,
5050
table_name: str,
@@ -77,7 +77,7 @@ async def create(
7777
engine: PostgresEngine,
7878
session_id: str,
7979
table_name: str,
80-
):
80+
) -> PostgresChatMessageHistory:
8181
"""Create a new PostgresChatMessageHistory instance.
8282
8383
Args:
@@ -117,7 +117,7 @@ def create_sync(
117117
engine: PostgresEngine,
118118
session_id: str,
119119
table_name: str,
120-
):
120+
) -> PostgresChatMessageHistory:
121121
"""Create a new PostgresChatMessageHistory instance.
122122
123123
Args:

src/langchain_google_cloud_sql_pg/engine.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,23 @@
1717
import asyncio
1818
from dataclasses import dataclass
1919
from threading import Thread
20-
from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, TypeVar, Union
20+
from typing import (
21+
TYPE_CHECKING,
22+
Awaitable,
23+
Dict,
24+
List,
25+
Optional,
26+
Sequence,
27+
TypeVar,
28+
Union,
29+
)
2130

2231
import aiohttp
2332
import google.auth # type: ignore
2433
import google.auth.transport.requests # type: ignore
2534
from google.cloud.sql.connector import Connector, IPTypes, RefreshStrategy
2635
from sqlalchemy import MetaData, Table, text
36+
from sqlalchemy.engine.row import RowMapping
2737
from sqlalchemy.exc import InvalidRequestError
2838
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
2939

@@ -305,19 +315,21 @@ def from_engine(cls, engine: AsyncEngine) -> PostgresEngine:
305315
"""Create an PostgresEngine instance from an AsyncEngine."""
306316
return cls(cls.__create_key, engine, None, None)
307317

308-
async def _aexecute(self, query: str, params: Optional[dict] = None):
318+
async def _aexecute(self, query: str, params: Optional[dict] = None) -> None:
309319
"""Execute a SQL query."""
310320
async with self._engine.connect() as conn:
311321
await conn.execute(text(query), params)
312322
await conn.commit()
313323

314-
async def _aexecute_outside_tx(self, query: str):
324+
async def _aexecute_outside_tx(self, query: str) -> None:
315325
"""Execute a SQL query."""
316326
async with self._engine.connect() as conn:
317327
await conn.execute(text("COMMIT"))
318328
await conn.execute(text(query))
319329

320-
async def _afetch(self, query: str, params: Optional[dict] = None):
330+
async def _afetch(
331+
self, query: str, params: Optional[dict] = None
332+
) -> Sequence[RowMapping]:
321333
"""Fetch results from a SQL query."""
322334
async with self._engine.connect() as conn:
323335
result = await conn.execute(text(query), params)
@@ -326,11 +338,11 @@ async def _afetch(self, query: str, params: Optional[dict] = None):
326338

327339
return result_fetch
328340

329-
def _execute(self, query: str, params: Optional[dict] = None):
341+
def _execute(self, query: str, params: Optional[dict] = None) -> None:
330342
"""Execute a SQL query."""
331343
return self._run_as_sync(self._aexecute(query, params))
332344

333-
def _fetch(self, query: str, params: Optional[dict] = None):
345+
def _fetch(self, query: str, params: Optional[dict] = None) -> Sequence[RowMapping]:
334346
"""Fetch results from a SQL query."""
335347
return self._run_as_sync(self._afetch(query, params))
336348

@@ -439,7 +451,7 @@ def init_vectorstore_table(
439451
)
440452
)
441453

442-
async def ainit_chat_history_table(self, table_name) -> None:
454+
async def ainit_chat_history_table(self, table_name: str) -> None:
443455
"""Create a Cloud SQL table to store chat history.
444456
445457
Args:
@@ -456,7 +468,7 @@ async def ainit_chat_history_table(self, table_name) -> None:
456468
);"""
457469
await self._aexecute(create_table_query)
458470

459-
def init_chat_history_table(self, table_name) -> None:
471+
def init_chat_history_table(self, table_name: str) -> None:
460472
"""Create a Cloud SQL table to store chat history.
461473
462474
Args:

src/langchain_google_cloud_sql_pg/loader.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,24 +36,24 @@
3636
DEFAULT_METADATA_COL = "langchain_metadata"
3737

3838

39-
def text_formatter(row, content_columns) -> str:
39+
def text_formatter(row: dict, content_columns: List[str]) -> str:
4040
"""txt document formatter."""
4141
return " ".join(str(row[column]) for column in content_columns if column in row)
4242

4343

44-
def csv_formatter(row, content_columns) -> str:
44+
def csv_formatter(row: dict, content_columns: List[str]) -> str:
4545
"""CSV document formatter."""
4646
return ", ".join(str(row[column]) for column in content_columns if column in row)
4747

4848

49-
def yaml_formatter(row, content_columns) -> str:
49+
def yaml_formatter(row: dict, content_columns: List[str]) -> str:
5050
"""YAML document formatter."""
5151
return "\n".join(
5252
f"{column}: {str(row[column])}" for column in content_columns if column in row
5353
)
5454

5555

56-
def json_formatter(row, content_columns) -> str:
56+
def json_formatter(row: dict, content_columns: List[str]) -> str:
5757
"""JSON document formatter."""
5858
dictionary = {}
5959
for column in content_columns:
@@ -116,7 +116,7 @@ class PostgresLoader(BaseLoader):
116116

117117
def __init__(
118118
self,
119-
key,
119+
key: object,
120120
engine: PostgresEngine,
121121
query: str,
122122
content_columns: List[str],
@@ -162,7 +162,7 @@ async def create(
162162
metadata_json_column: Optional[str] = None,
163163
format: Optional[str] = None,
164164
formatter: Optional[Callable] = None,
165-
):
165+
) -> PostgresLoader:
166166
"""Create a new PostgresLoader instance.
167167
168168
Args:
@@ -255,7 +255,7 @@ def create_sync(
255255
metadata_json_column: Optional[str] = None,
256256
format: Optional[str] = None,
257257
formatter: Optional[Callable] = None,
258-
):
258+
) -> PostgresLoader:
259259
"""Create a new PostgresLoader instance.
260260
261261
Args:
@@ -340,7 +340,7 @@ class PostgresDocumentSaver:
340340

341341
def __init__(
342342
self,
343-
key,
343+
key: object,
344344
engine: PostgresEngine,
345345
table_name: str,
346346
content_column: str,
@@ -378,7 +378,7 @@ async def create(
378378
content_column: str = DEFAULT_CONTENT_COL,
379379
metadata_columns: List[str] = [],
380380
metadata_json_column: Optional[str] = DEFAULT_METADATA_COL,
381-
):
381+
) -> PostgresDocumentSaver:
382382
"""Create an PostgresDocumentSaver instance.
383383
384384
Args:
@@ -435,7 +435,7 @@ def create_sync(
435435
content_column: str = DEFAULT_CONTENT_COL,
436436
metadata_columns: List[str] = [],
437437
metadata_json_column: str = DEFAULT_METADATA_COL,
438-
):
438+
) -> PostgresDocumentSaver:
439439
"""Create an PostgresDocumentSaver instance.
440440
441441
Args:

src/langchain_google_cloud_sql_pg/vectorstore.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@
1717

1818
import json
1919
import uuid
20-
from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union
20+
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type, Union
2121

2222
import numpy as np
2323
from langchain_core.documents import Document
2424
from langchain_core.embeddings import Embeddings
2525
from langchain_core.vectorstores import VectorStore
26+
from sqlalchemy.engine.row import RowMapping
2627

2728
from .engine import PostgresEngine
2829
from .indexes import (
@@ -42,7 +43,7 @@ class PostgresVectorStore(VectorStore):
4243

4344
def __init__(
4445
self,
45-
key,
46+
key: object,
4647
engine: PostgresEngine,
4748
embedding_service: Embeddings,
4849
table_name: str,
@@ -114,7 +115,7 @@ async def create(
114115
fetch_k: int = 20,
115116
lambda_mult: float = 0.5,
116117
index_query_options: Optional[QueryOptions] = None,
117-
):
118+
) -> PostgresVectorStore:
118119
"""Create a new PostgresVectorStore instance.
119120
120121
Args:
@@ -218,7 +219,7 @@ def create_sync(
218219
fetch_k: int = 20,
219220
lambda_mult: float = 0.5,
220221
index_query_options: Optional[QueryOptions] = None,
221-
):
222+
) -> PostgresVectorStore:
222223
"""Create a new PostgresVectorStore instance.
223224
224225
Args:
@@ -496,7 +497,7 @@ def from_texts( # type: ignore[override]
496497
id_column: str = "langchain_id",
497498
metadata_json_column: str = "langchain_metadata",
498499
**kwargs: Any,
499-
):
500+
) -> PostgresVectorStore:
500501
"""Create an PostgresVectorStore instance from texts.
501502
Args:
502503
texts (List[str]): Texts to add to the vector store.
@@ -589,7 +590,7 @@ async def __query_collection(
589590
k: Optional[int] = None,
590591
filter: Optional[str] = None,
591592
**kwargs: Any,
592-
) -> List[Any]:
593+
) -> Sequence[RowMapping]:
593594
"""Perform similarity search query on the vector store table."""
594595
k = k if k else self.k
595596
operator = self.distance_strategy.operator

tests/test_postgresql_chatmessagehistory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,14 @@ def test_chat_message_history(memory_engine: PostgresEngine) -> None:
7979
assert len(history.messages) == 0
8080

8181

82-
def test_chat_table(memory_engine: Any):
82+
def test_chat_table(memory_engine: Any) -> None:
8383
with pytest.raises(ValueError):
8484
PostgresChatMessageHistory.create_sync(
8585
engine=memory_engine, session_id="test", table_name="doesnotexist"
8686
)
8787

8888

89-
def test_chat_schema(memory_engine: Any):
89+
def test_chat_schema(memory_engine: Any) -> None:
9090
doc_table_name = "test_table" + str(uuid.uuid4())
9191
memory_engine.init_document_table(table_name=doc_table_name)
9292
with pytest.raises(IndexError):

0 commit comments

Comments
 (0)