Skip to content

Commit 3938700

Browse files
authored
feat: Add tests to reach 90% coverage (#166)
* feat: Add tests to reach 90% coverage * replace alloydb occurrences * add more tests * fix key value error * add is valid index test * test * set test orders * delete skip python version
1 parent 8e61bc7 commit 3938700

File tree

8 files changed

+146
-36
lines changed

8 files changed

+146
-36
lines changed

.coveragerc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ omit =
55

66
[report]
77
show_missing = true
8-
fail_under = 82
8+
fail_under = 90

DEVELOPER.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ Learn more by reading [How should I write my commits?](https://github.com/google
2828

2929
### CI Platform Setup
3030

31-
Cloud Build is used to run tests against Google Cloud resources in test project: langchain-alloydb-testing.
31+
Cloud Build is used to run tests against Google Cloud resources in test project: langchain-cloud-sql-testing.
3232
Each test has a corresponding Cloud Build trigger, see [all triggers][triggers].
3333
These tests are registered as required tests in `.github/sync-repo-settings.yaml`.
3434

@@ -41,7 +41,7 @@ name: pg-integration-test-pr-py38
4141
description: Run integration tests on PR for Python 3.8
4242
filename: integration.cloudbuild.yaml
4343
github:
44-
name: langchain-google-alloydb-pg-python
44+
name: langchain-google-cloud-sql-pg-python
4545
owner: googleapis
4646
pullRequest:
4747
branch: .*

src/langchain_google_cloud_sql_pg/loader.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -555,29 +555,3 @@ def delete(self, docs: List[Document]) -> None:
555555
docs (List[langchain_core.documents.Document]): a list of documents to be deleted.
556556
"""
557557
self.engine._run_as_sync(self.adelete(docs))
558-
559-
async def _aload_table_schema(self) -> sqlalchemy.Table:
560-
"""
561-
Load table schema from existing table in PgSQL database.
562-
563-
Returns:
564-
(sqlalchemy.Table): The loaded table.
565-
"""
566-
metadata = sqlalchemy.MetaData()
567-
async with self.engine._engine.connect() as conn:
568-
await conn.run_sync(metadata.reflect, only=[self.table_name])
569-
570-
table = sqlalchemy.Table(self.table_name, metadata)
571-
# Extract the schema information
572-
schema = []
573-
for column in table.columns:
574-
schema.append(
575-
{
576-
"name": column.name,
577-
"type": column.type.python_type,
578-
"max_length": getattr(column.type, "length", None),
579-
"nullable": not column.nullable,
580-
}
581-
)
582-
583-
return metadata.tables[self.table_name]

src/langchain_google_cloud_sql_pg/vectorstore.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ async def create(
183183
del all_columns[id_column]
184184
del all_columns[content_column]
185185
del all_columns[embedding_column]
186-
metadata_columns = [k for k, _ in all_columns.keys()]
186+
metadata_columns = [k for k in all_columns.keys()]
187187

188188
return cls(
189189
cls.__create_key,

tests/test_cloudsql_vectorstore.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,19 @@ async def vs_custom(self, engine):
133133
yield vs
134134
await engine._aexecute(f'DROP TABLE IF EXISTS "{CUSTOM_TABLE}"')
135135

136+
async def test_init_with_constructor(self, engine):
137+
with pytest.raises(Exception):
138+
PostgresVectorStore(
139+
engine,
140+
embedding_service=embeddings_service,
141+
table_name=CUSTOM_TABLE,
142+
id_column="myid",
143+
content_column="noname",
144+
embedding_column="myembedding",
145+
metadata_columns=["page", "source"],
146+
metadata_json_column="mymeta",
147+
)
148+
136149
async def test_post_init(self, engine):
137150
with pytest.raises(ValueError):
138151
await PostgresVectorStore.create(
@@ -265,4 +278,60 @@ async def test_add_texts(self, engine_sync, vs_sync):
265278
results = engine_sync._fetch(f'SELECT * FROM "{DEFAULT_TABLE_SYNC}"')
266279
assert len(results) == 6
267280

281+
async def test_ignore_metadata_columns(self, vs_custom):
282+
column_to_ignore = "source"
283+
vs = await PostgresVectorStore.create(
284+
vs_custom.engine,
285+
embedding_service=embeddings_service,
286+
table_name=CUSTOM_TABLE,
287+
ignore_metadata_columns=[column_to_ignore],
288+
id_column="myid",
289+
content_column="mycontent",
290+
embedding_column="myembedding",
291+
metadata_json_column="mymeta",
292+
)
293+
assert column_to_ignore not in vs.metadata_columns
294+
295+
async def test_create_vectorstore_with_invalid_parameters(self, vs_custom):
296+
with pytest.raises(ValueError):
297+
await PostgresVectorStore.create(
298+
vs_custom.engine,
299+
embedding_service=embeddings_service,
300+
table_name=CUSTOM_TABLE,
301+
id_column="myid",
302+
content_column="mycontent",
303+
embedding_column="myembedding",
304+
metadata_columns=["random_column"], # invalid metadata column
305+
)
306+
with pytest.raises(ValueError):
307+
await PostgresVectorStore.create(
308+
vs_custom.engine,
309+
embedding_service=embeddings_service,
310+
table_name=CUSTOM_TABLE,
311+
id_column="myid",
312+
content_column="langchain_id", # invalid content column type
313+
embedding_column="myembedding",
314+
metadata_columns=["random_column"],
315+
)
316+
with pytest.raises(ValueError):
317+
await PostgresVectorStore.create(
318+
vs_custom.engine,
319+
embedding_service=embeddings_service,
320+
table_name=CUSTOM_TABLE,
321+
id_column="myid",
322+
content_column="mycontent",
323+
embedding_column="random_column", # invalid embedding column
324+
metadata_columns=["random_column"],
325+
)
326+
with pytest.raises(ValueError):
327+
await PostgresVectorStore.create(
328+
vs_custom.engine,
329+
embedding_service=embeddings_service,
330+
table_name=CUSTOM_TABLE,
331+
id_column="myid",
332+
content_column="mycontent",
333+
embedding_column="langchain_id", # invalid embedding column data type
334+
metadata_columns=["random_column"],
335+
)
336+
268337
# Need tests for store metadata=False

tests/test_cloudsql_vectorstore_index.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,6 @@ def get_env_var(key: str, desc: str) -> str:
5555

5656

5757
@pytest.mark.asyncio(scope="class")
58-
@pytest.mark.skipif(
59-
sys.version_info != (3, 11),
60-
reason="To prevent index clashes only run on python3.11 or higher",
61-
)
6258
class TestIndex:
6359
@pytest.fixture(scope="module")
6460
def db_project(self) -> str:
@@ -101,11 +97,13 @@ async def vs(self, engine):
10197
await engine._aexecute(f"DROP TABLE IF EXISTS {DEFAULT_TABLE}")
10298
await engine._engine.dispose()
10399

100+
@pytest.mark.run(order=1)
104101
async def test_aapply_vector_index(self, vs):
105102
index = HNSWIndex()
106103
await vs.aapply_vector_index(index)
107104
assert await vs.is_valid_index(DEFAULT_INDEX_NAME)
108105

106+
@pytest.mark.run(order=2)
109107
async def test_areindex(self, vs):
110108
if not await vs.is_valid_index(DEFAULT_INDEX_NAME):
111109
index = HNSWIndex()
@@ -114,6 +112,7 @@ async def test_areindex(self, vs):
114112
await vs.areindex(DEFAULT_INDEX_NAME)
115113
assert await vs.is_valid_index(DEFAULT_INDEX_NAME)
116114

115+
@pytest.mark.run(order=3)
117116
async def test_dropindex(self, vs):
118117
await vs.adrop_vector_index()
119118
result = await vs.is_valid_index(DEFAULT_INDEX_NAME)
@@ -130,3 +129,7 @@ async def test_aapply_vector_index_ivfflat(self, vs):
130129
await vs.aapply_vector_index(index)
131130
assert await vs.is_valid_index("secondindex")
132131
await vs.adrop_vector_index("secondindex")
132+
133+
async def test_is_valid_index(self, vs):
134+
is_valid = await vs.is_valid_index("invalid_index")
135+
assert is_valid == False

tests/test_cloudsql_vectorstore_search.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from langchain_core.documents import Document
2222

2323
from langchain_google_cloud_sql_pg import Column, PostgresEngine, PostgresVectorStore
24-
from langchain_google_cloud_sql_pg.indexes import HNSWQueryOptions, IVFFlatQueryOptions
24+
from langchain_google_cloud_sql_pg.indexes import DistanceStrategy, HNSWQueryOptions
2525

2626
DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_")
2727
CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_")
@@ -151,7 +151,7 @@ async def test_asimilarity_search_by_vector(self, vs):
151151
assert results[0][0] == Document(page_content="foo")
152152
assert results[0][1] == 0
153153

154-
async def test_similarity_search_with_relevance_scores_threshold(self, vs):
154+
async def test_similarity_search_with_relevance_scores_threshold_cosine(self, vs):
155155
score_threshold = {"score_threshold": 0}
156156
results = await vs.asimilarity_search_with_relevance_scores(
157157
"foo", **score_threshold
@@ -171,6 +171,23 @@ async def test_similarity_search_with_relevance_scores_threshold(self, vs):
171171
assert len(results) == 1
172172
assert results[0][0] == Document(page_content="foo")
173173

174+
async def test_similarity_search_with_relevance_scores_threshold_euclidean(
175+
self, engine
176+
):
177+
vs = await PostgresVectorStore.create(
178+
engine,
179+
embedding_service=embeddings_service,
180+
table_name=DEFAULT_TABLE,
181+
distance_strategy=DistanceStrategy.EUCLIDEAN,
182+
)
183+
184+
score_threshold = {"score_threshold": 0.9}
185+
results = await vs.asimilarity_search_with_relevance_scores(
186+
"foo", **score_threshold
187+
)
188+
assert len(results) == 1
189+
assert results[0][0] == Document(page_content="foo")
190+
174191
async def test_amax_marginal_relevance_search(self, vs):
175192
results = await vs.amax_marginal_relevance_search("bar")
176193
assert results[0] == Document(page_content="bar")

tests/test_postgresql_loader.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,29 @@ async def _cleanup_table(self, engine):
6969
query = f'DROP TABLE IF EXISTS "{table_name}"'
7070
await engine._aexecute(query)
7171

72+
async def test_create_loader_with_invalid_parameters(self, engine):
73+
with pytest.raises(ValueError):
74+
await PostgresLoader.create(
75+
engine=engine,
76+
)
77+
with pytest.raises(ValueError):
78+
79+
def fake_formatter():
80+
return None
81+
82+
await PostgresLoader.create(
83+
engine=engine,
84+
table_name=table_name,
85+
format="text",
86+
formatter=fake_formatter,
87+
)
88+
with pytest.raises(ValueError):
89+
await PostgresLoader.create(
90+
engine=engine,
91+
table_name=table_name,
92+
format="fake_format",
93+
)
94+
7295
async def test_load_from_query_default(self, engine):
7396
try:
7497
await self._cleanup_table(engine)
@@ -216,6 +239,30 @@ async def test_load_from_query_customized_content_default_metadata(self, engine)
216239
)
217240
]
218241

242+
loader = await PostgresLoader.create(
243+
engine=engine,
244+
query=f'SELECT * FROM "{table_name}";',
245+
content_columns=[
246+
"variety",
247+
"quantity_in_stock",
248+
"price_per_unit",
249+
],
250+
format="JSON",
251+
)
252+
253+
documents = await self._collect_async_items(loader.alazy_load())
254+
255+
assert documents == [
256+
Document(
257+
page_content='{"variety": "Granny Smith", "quantity_in_stock": 150, "price_per_unit": 1}',
258+
metadata={
259+
"fruit_id": 1,
260+
"fruit_name": "Apple",
261+
"organic": 1,
262+
},
263+
)
264+
]
265+
219266
finally:
220267
await self._cleanup_table(engine)
221268

0 commit comments

Comments
 (0)