Skip to content

Commit ffaa87f

Browse files
feat: allow non-uuid data types for vectorstore primary key (#209)
* feat: allow non-uuid data types for vectorstore primary key * Update src/langchain_google_cloud_sql_pg/engine.py * Update src/langchain_google_cloud_sql_pg/engine.py * Update src/langchain_google_cloud_sql_pg/engine.py --------- Co-authored-by: Averi Kitsch <akitsch@google.com>
1 parent 7ef9335 commit ffaa87f

File tree

6 files changed

+272
-47
lines changed

6 files changed

+272
-47
lines changed

src/langchain_google_cloud_sql_pg/async_vectorstore.py

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -226,10 +226,14 @@ async def __aadd_embeddings(
226226
texts: Iterable[str],
227227
embeddings: List[List[float]],
228228
metadatas: Optional[List[dict]] = None,
229-
ids: Optional[List[str]] = None,
229+
ids: Optional[List] = None,
230230
**kwargs: Any,
231231
) -> List[str]:
232-
"""Add embeddings to the table."""
232+
"""Add embeddings to the table.
233+
234+
Raises:
235+
:class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.
236+
"""
233237
if not ids:
234238
ids = [str(uuid.uuid4()) for _ in texts]
235239
if not metadatas:
@@ -276,10 +280,14 @@ async def aadd_texts(
276280
self,
277281
texts: Iterable[str],
278282
metadatas: Optional[List[dict]] = None,
279-
ids: Optional[List[str]] = None,
283+
ids: Optional[List] = None,
280284
**kwargs: Any,
281285
) -> List[str]:
282-
"""Embed texts and add to the table."""
286+
"""Embed texts and add to the table.
287+
288+
Raises:
289+
:class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.
290+
"""
283291
embeddings = self.embedding_service.embed_documents(list(texts))
284292
ids = await self.__aadd_embeddings(
285293
texts, embeddings, metadatas=metadatas, ids=ids, **kwargs
@@ -289,21 +297,29 @@ async def aadd_texts(
289297
async def aadd_documents(
290298
self,
291299
documents: List[Document],
292-
ids: Optional[List[str]] = None,
300+
ids: Optional[List] = None,
293301
**kwargs: Any,
294302
) -> List[str]:
295-
"""Embed documents and add to the table"""
303+
"""Embed documents and add to the table.
304+
305+
Raises:
306+
:class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.
307+
"""
296308
texts = [doc.page_content for doc in documents]
297309
metadatas = [doc.metadata for doc in documents]
298310
ids = await self.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs)
299311
return ids
300312

301313
async def adelete(
302314
self,
303-
ids: Optional[List[str]] = None,
315+
ids: Optional[List] = None,
304316
**kwargs: Any,
305317
) -> Optional[bool]:
306-
"""Delete records from the table."""
318+
"""Delete records from the table.
319+
320+
Raises:
321+
:class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.
322+
"""
307323
if not ids:
308324
return False
309325

@@ -323,7 +339,7 @@ async def afrom_texts( # type: ignore[override]
323339
table_name: str,
324340
schema_name: str = "public",
325341
metadatas: Optional[List[dict]] = None,
326-
ids: Optional[List[str]] = None,
342+
ids: Optional[List] = None,
327343
content_column: str = "content",
328344
embedding_column: str = "embedding",
329345
metadata_columns: List[str] = [],
@@ -338,6 +354,7 @@ async def afrom_texts( # type: ignore[override]
338354
**kwargs: Any,
339355
) -> AsyncPostgresVectorStore:
340356
"""Create an AsyncPostgresVectorStore instance from texts.
357+
341358
Args:
342359
texts (List[str]): Texts to add to the vector store.
343360
embedding (Embeddings): Text embedding model to use.
@@ -358,6 +375,9 @@ async def afrom_texts( # type: ignore[override]
358375
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
359376
index_query_options (QueryOptions): Index query option.
360377
378+
Raises:
379+
:class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.
380+
361381
Returns:
362382
AsyncPostgresVectorStore
363383
"""
@@ -389,7 +409,7 @@ async def afrom_documents( # type: ignore[override]
389409
engine: PostgresEngine,
390410
table_name: str,
391411
schema_name: str = "public",
392-
ids: Optional[List[str]] = None,
412+
ids: Optional[List] = None,
393413
content_column: str = "content",
394414
embedding_column: str = "embedding",
395415
metadata_columns: List[str] = [],
@@ -425,6 +445,9 @@ async def afrom_documents( # type: ignore[override]
425445
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
426446
index_query_options (QueryOptions): Index query option.
427447
448+
Raises:
449+
:class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.
450+
428451
Returns:
429452
AsyncPostgresVectorStore
430453
"""
@@ -735,7 +758,7 @@ def add_texts(
735758
self,
736759
texts: Iterable[str],
737760
metadatas: Optional[List[dict]] = None,
738-
ids: Optional[List[str]] = None,
761+
ids: Optional[List] = None,
739762
**kwargs: Any,
740763
) -> List[str]:
741764
raise NotImplementedError(
@@ -745,7 +768,7 @@ def add_texts(
745768
def add_documents(
746769
self,
747770
documents: List[Document],
748-
ids: Optional[List[str]] = None,
771+
ids: Optional[List] = None,
749772
**kwargs: Any,
750773
) -> List[str]:
751774
raise NotImplementedError(
@@ -754,7 +777,7 @@ def add_documents(
754777

755778
def delete(
756779
self,
757-
ids: Optional[List[str]] = None,
780+
ids: Optional[List] = None,
758781
**kwargs: Any,
759782
) -> Optional[bool]:
760783
raise NotImplementedError(
@@ -769,7 +792,7 @@ def from_texts( # type: ignore[override]
769792
engine: PostgresEngine,
770793
table_name: str,
771794
metadatas: Optional[List[dict]] = None,
772-
ids: Optional[List[str]] = None,
795+
ids: Optional[List] = None,
773796
content_column: str = "content",
774797
embedding_column: str = "embedding",
775798
metadata_columns: List[str] = [],
@@ -789,7 +812,7 @@ def from_documents( # type: ignore[override]
789812
embedding: Embeddings,
790813
engine: PostgresEngine,
791814
table_name: str,
792-
ids: Optional[List[str]] = None,
815+
ids: Optional[List] = None,
793816
content_column: str = "content",
794817
embedding_column: str = "embedding",
795818
metadata_columns: List[str] = [],

src/langchain_google_cloud_sql_pg/engine.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ async def _ainit_vectorstore_table(
410410
embedding_column: str = "embedding",
411411
metadata_columns: List[Column] = [],
412412
metadata_json_column: str = "langchain_metadata",
413-
id_column: str = "langchain_id",
413+
id_column: Union[str, Column] = "langchain_id",
414414
overwrite_existing: bool = False,
415415
store_metadata: bool = True,
416416
) -> None:
@@ -430,14 +430,14 @@ async def _ainit_vectorstore_table(
430430
metadata. Default: []. Optional.
431431
metadata_json_column (str): The column to store extra metadata in JSON format.
432432
Default: "langchain_metadata". Optional.
433-
id_column (str): Name of the column to store ids.
434-
Default: "langchain_id". Optional,
433+
id_column (Union[str, Column]): Column to store ids.
434+
Default: "langchain_id" column name with data type UUID. Optional.
435435
overwrite_existing (bool): Whether to drop existing table. Default: False.
436436
store_metadata (bool): Whether to store metadata in the table.
437437
Default: True.
438-
439438
Raises:
440439
:class:`DuplicateTableError <asyncpg.exceptions.DuplicateTableError>`: if table already exists and overwrite flag is not set.
440+
:class:`UndefinedObjectError <asyncpg.exceptions.UndefinedObjectError>`: if the data type of the id column is not a PostgreSQL data type.
441441
"""
442442
async with self._pool.connect() as conn:
443443
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
@@ -450,8 +450,11 @@ async def _ainit_vectorstore_table(
450450
)
451451
await conn.commit()
452452

453+
id_data_type = "UUID" if isinstance(id_column, str) else id_column.data_type
454+
id_column_name = id_column if isinstance(id_column, str) else id_column.name
455+
453456
query = f"""CREATE TABLE "{schema_name}"."{table_name}"(
454-
"{id_column}" UUID PRIMARY KEY,
457+
"{id_column_name}" {id_data_type} PRIMARY KEY,
455458
"{content_column}" TEXT NOT NULL,
456459
"{embedding_column}" vector({vector_size}) NOT NULL"""
457460
for column in metadata_columns:
@@ -474,7 +477,7 @@ async def ainit_vectorstore_table(
474477
embedding_column: str = "embedding",
475478
metadata_columns: List[Column] = [],
476479
metadata_json_column: str = "langchain_metadata",
477-
id_column: str = "langchain_id",
480+
id_column: Union[str, Column] = "langchain_id",
478481
overwrite_existing: bool = False,
479482
store_metadata: bool = True,
480483
) -> None:
@@ -494,8 +497,8 @@ async def ainit_vectorstore_table(
494497
metadata. Default: []. Optional.
495498
metadata_json_column (str): The column to store extra metadata in JSON format.
496499
Default: "langchain_metadata". Optional.
497-
id_column (str): Name of the column to store ids.
498-
Default: "langchain_id". Optional,
500+
id_column (Union[str, Column]): Column to store ids.
501+
Default: "langchain_id" column name with data type UUID. Optional.
499502
overwrite_existing (bool): Whether to drop existing table. Default: False.
500503
store_metadata (bool): Whether to store metadata in the table.
501504
Default: True.
@@ -524,7 +527,7 @@ def init_vectorstore_table(
524527
embedding_column: str = "embedding",
525528
metadata_columns: List[Column] = [],
526529
metadata_json_column: str = "langchain_metadata",
527-
id_column: str = "langchain_id",
530+
id_column: Union[str, Column] = "langchain_id",
528531
overwrite_existing: bool = False,
529532
store_metadata: bool = True,
530533
) -> None:
@@ -544,11 +547,13 @@ def init_vectorstore_table(
544547
metadata. Default: []. Optional.
545548
metadata_json_column (str): The column to store extra metadata in JSON format.
546549
Default: "langchain_metadata". Optional.
547-
id_column (str): Name of the column to store ids.
548-
Default: "langchain_id". Optional,
550+
id_column (Union[str, Column]): Column to store ids.
551+
Default: "langchain_id" column name with data type UUID. Optional.
549552
overwrite_existing (bool): Whether to drop existing table. Default: False.
550553
store_metadata (bool): Whether to store metadata in the table.
551554
Default: True.
555+
Raises:
556+
:class:`UndefinedObjectError <asyncpg.exceptions.UndefinedObjectError>`: if the data type of the id column is not a PostgreSQL data type.
552557
"""
553558
self._run_as_sync(
554559
self._ainit_vectorstore_table(

0 commit comments

Comments
 (0)