Skip to content

Commit ed1c9fd

Browse files
refactor: parse embedding dimensions as float[] for numpyV2 (#277)
1 parent e576609 commit ed1c9fd

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

src/langchain_google_cloud_sql_pg/async_vectorstore.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,11 @@ async def __aadd_embeddings(
246246
else ""
247247
)
248248
insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"("{self.id_column}", "{self.content_column}", "{self.embedding_column}"{metadata_col_names}'
249-
values = {"id": id, "content": content, "embedding": str(embedding)}
249+
values = {
250+
"id": id,
251+
"content": content,
252+
"embedding": str([float(dimension) for dimension in embedding]),
253+
}
250254
values_stmt = "VALUES (:id, :content, :embedding"
251255

252256
# Add metadata
@@ -496,9 +500,9 @@ async def __query_collection(
496500
columns.append(self.metadata_json_column)
497501

498502
column_names = ", ".join(f'"{col}"' for col in columns)
499-
500503
filter = f"WHERE {filter}" if filter else ""
501-
stmt = f"SELECT {column_names}, {search_function}({self.embedding_column}, '{embedding}') as distance FROM \"{self.schema_name}\".\"{self.table_name}\" {filter} ORDER BY {self.embedding_column} {operator} '{embedding}' LIMIT {k};"
504+
embedding_string = f"'{[float(dimension) for dimension in embedding]}'"
505+
stmt = f'SELECT {column_names}, {search_function}({self.embedding_column}, {embedding_string}) as distance FROM "{self.schema_name}"."{self.table_name}" {filter} ORDER BY {self.embedding_column} {operator} {embedding_string} LIMIT {k};'
502506
if self.index_query_options:
503507
async with self.pool.connect() as conn:
504508
await conn.execute(

tests/test_engine.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,9 @@ async def test_init_table(self, engine):
130130
id = str(uuid.uuid4())
131131
content = "coffee"
132132
embedding = await embeddings_service.aembed_query(content)
133-
stmt = f"INSERT INTO {DEFAULT_TABLE} (langchain_id, content, embedding) VALUES ('{id}', '{content}','{embedding}');"
133+
# Note: DeterministicFakeEmbedding generates a numpy array, so convert it to a list of plain float values
134+
embedding_string = [float(dimension) for dimension in embedding]
135+
stmt = f"INSERT INTO {DEFAULT_TABLE} (langchain_id, content, embedding) VALUES ('{id}', '{content}','{embedding_string}');"
134136
await aexecute(engine, stmt)
135137

136138
async def test_init_table_custom(self, engine):
@@ -350,7 +352,9 @@ async def test_init_table(self, engine):
350352
id = str(uuid.uuid4())
351353
content = "coffee"
352354
embedding = await embeddings_service.aembed_query(content)
353-
stmt = f"INSERT INTO {DEFAULT_TABLE_SYNC} (langchain_id, content, embedding) VALUES ('{id}', '{content}','{embedding}');"
355+
# Note: DeterministicFakeEmbedding generates a numpy array, so convert it to a list of plain float values
356+
embedding_string = [float(dimension) for dimension in embedding]
357+
stmt = f"INSERT INTO {DEFAULT_TABLE_SYNC} (langchain_id, content, embedding) VALUES ('{id}', '{content}','{embedding_string}');"
354358
await aexecute(engine, stmt)
355359

356360
async def test_init_table_custom(self, engine):

0 commit comments

Comments
 (0)