Skip to content

Commit 8059390

Browse files
authored
chore(samples): update samples for async/sync refactor (#221)
* chore(samples): update samples for async/sync refactor * clean * fix * clean
1 parent 1737adc commit 8059390

File tree

5 files changed

+29
-14
lines changed

5 files changed

+29
-14
lines changed

samples/index_tuning_sample/create_vector_embeddings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ async def create_vector_store_table(documents):
106106
ids = [str(uuid.uuid4()) for i in range(len(documents))]
107107
await vector_store.aadd_documents(documents, ids)
108108
print("Vector table created.")
109+
await engine.close()
110+
await engine._connector.close()
109111

110112

111113
async def main():

samples/index_tuning_sample/index_search.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,10 @@ async def query_vector_with_timing(vector_store, query):
8181

8282
async def hnsw_search(vector_store, knn_docs):
8383
hnsw_index = HNSWIndex(
84-
name="hnsw", distance_strategy=DISTANCE_STRATEGY, m=36, ef_construction=96
84+
name="hnsw",
85+
distance_strategy=DISTANCE_STRATEGY,
86+
m=36,
87+
ef_construction=96,
8588
)
8689
await vector_store.aapply_vector_index(hnsw_index)
8790
assert await vector_store.is_valid_index(hnsw_index.name)
@@ -156,6 +159,8 @@ async def main():
156159
print(
157160
f"IVFFLAT average recall: {ivfflat_average_recall} IVFFLAT latency: {ivfflat_average_latency}"
158161
)
162+
await vector_store._engine.close()
163+
await vector_store._engine._connector.close()
159164

160165

161166
if __name__ == "__main__":

samples/langchain_on_vertexai/clean_up.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
TABLE_NAME,
2525
USER,
2626
)
27+
from sqlalchemy import text
2728
from vertexai.preview import reasoning_engines # type: ignore
2829

2930
from langchain_google_cloud_sql_pg import PostgresEngine
@@ -41,10 +42,12 @@ async def delete_tables():
4142
password=PASSWORD,
4243
)
4344

44-
await engine._aexecute_outside_tx(f"DROP TABLE IF EXISTS {TABLE_NAME}")
45-
await engine._aexecute_outside_tx(f"DROP TABLE IF EXISTS {CHAT_TABLE_NAME}")
45+
async with engine._pool.connect() as conn:
46+
await conn.execute(text("COMMIT"))
47+
await conn.execute(text(f"DROP TABLE IF EXISTS {TABLE_NAME}"))
48+
await conn.execute(text(f"DROP TABLE IF EXISTS {CHAT_TABLE_NAME}"))
49+
await engine.close()
4650
await engine._connector.close_async()
47-
await engine._engine.dispose()
4851

4952

5053
def delete_engines():

samples/langchain_on_vertexai/create_embeddings.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import asyncio
15-
import os
1615
import uuid
1716

1817
from config import (
@@ -28,6 +27,7 @@
2827
from google.cloud import resourcemanager_v3 # type: ignore
2928
from langchain_community.document_loaders.csv_loader import CSVLoader
3029
from langchain_google_vertexai import VertexAIEmbeddings
30+
from sqlalchemy import text
3131

3232
from langchain_google_cloud_sql_pg import PostgresEngine, PostgresVectorStore
3333

@@ -41,10 +41,11 @@ async def create_databases():
4141
user=USER,
4242
password=PASSWORD,
4343
)
44-
await engine._aexecute_outside_tx(f'DROP DATABASE IF EXISTS "{DATABASE}"')
45-
await engine._aexecute_outside_tx(f'CREATE DATABASE "{DATABASE}"')
46-
await engine._connector.close_async()
47-
await engine._engine.dispose()
44+
async with engine._pool.connect() as conn:
45+
await conn.execute(text("COMMIT"))
46+
await conn.execute(text(f'DROP DATABASE IF EXISTS "{DATABASE}"'))
47+
await conn.execute(text(f'CREATE DATABASE "{DATABASE}"'))
48+
await engine.close()
4849

4950

5051
async def create_vectorstore():
@@ -69,7 +70,13 @@ async def create_vectorstore():
6970
)
7071
project_number = res.name.split("/")[1]
7172
IAM_USER = f"service-{project_number}@gcp-sa-aiplatform-re.iam"
72-
await engine._aexecute(f'GRANT SELECT ON {TABLE_NAME} TO "{IAM_USER}";')
73+
74+
async def grant_select(engine):
75+
async with engine._pool.connect() as conn:
76+
await conn.execute(text(f'GRANT SELECT ON {TABLE_NAME} TO "{IAM_USER}";'))
77+
await conn.commit()
78+
79+
await engine._run_as_async(grant_select(engine))
7380

7481
metadata = [
7582
"show_id",
@@ -94,8 +101,6 @@ async def create_vectorstore():
94101

95102
ids = [str(uuid.uuid4()) for i in range(len(docs))]
96103
await vector_store.aadd_documents(docs, ids=ids)
97-
await engine._connector.close_async()
98-
await engine._engine.dispose()
99104

100105

101106
async def main():
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
google-cloud-aiplatform[reasoningengine,langchain]==1.65.0
1+
google-cloud-aiplatform[reasoningengine,langchain]==1.68.0
22
google-cloud-resource-manager==1.12.5
33
langchain-community==0.2.16
4-
langchain-google-cloud-sql-pg==0.9.0
4+
langchain-google-cloud-sql-pg==0.10.0
55
langchain-google-vertexai==1.0.10

0 commit comments

Comments
 (0)