Skip to content

Commit de16842

Browse files
authored
feat!: refactor to support both async and sync usage (#206)
* feat: separate Async only interface (#186) * feat: separate Async only interface * add tests * remove afetch/aexecute * respond to comments * update tests * add close * fix test * add tests * clean up * test * remove connector closing * lint * lint * feat: refactor vector store to wrap async class (#187) * feat: separate Async only interface * add tests * remove afetch/aexecute * respond to comments * update tests * add close * fix test * add tests * clean up * test * remove connector closing * lint * lint * feat: refactor vector store to wrap async class * rebase * refactor * remove changes * fix * update tests * respond to comments * feat: refactor chat message history (#192) * feat: separate Async only interface * add tests * remove afetch/aexecute * respond to comments * update tests * add close * fix test * add tests * clean up * test * remove connector closing * lint * lint * feat: refactor vector store to wrap async class * rebase * refactor * remove changes * fix * update tests * feat: refactor chat message history * lint * feat: refactor loader (#193) * feat: separate Async only interface * add tests * remove afetch/aexecute * respond to comments * update tests * add close * fix test * add tests * clean up * test * remove connector closing * lint * lint * feat: refactor vector store to wrap async class * rebase * refactor * remove changes * fix * update tests * feat: refactor loader * lint * update async tests * lint * feat: add from_engine_args (#194) * feat: add from_engine_args * lint * Update test_engine.py * add support for loop * add chat tests * tests * add proxy * Debug * fix tests * clean up * clean up * use wget * fix version * debug * feat: ensure schema support for refactor * fix
1 parent 0651231 commit de16842

21 files changed

+4655
-1169
lines changed

integration.cloudbuild.yaml

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,30 @@ steps:
2323
entrypoint: pip
2424
args: ["install", ".[test]", "--user"]
2525

26+
- id: proxy-install
27+
name: alpine:3.10
28+
entrypoint: sh
29+
args:
30+
- -c
31+
- |
32+
wget -O /workspace/cloud_sql_proxy https://storage.googleapis.com/cloudsql-proxy/v1.37.0/cloud_sql_proxy.linux.386
33+
chmod +x /workspace/cloud_sql_proxy
34+
2635
- id: Run integration tests
2736
name: python:${_VERSION}
28-
entrypoint: python
29-
args: ["-m", "pytest", "--cov=langchain_google_cloud_sql_pg", "--cov-config=.coveragerc", "tests/"]
37+
entrypoint: /bin/bash
3038
env:
3139
- "PROJECT_ID=$PROJECT_ID"
3240
- "INSTANCE_ID=$_INSTANCE_ID"
3341
- "DATABASE_ID=$_DATABASE_ID"
3442
- "REGION=$_REGION"
43+
- "IP_ADDRESS=$_IP_ADDRESS"
3544
secretEnv: ["DB_USER", "DB_PASSWORD", "IAM_ACCOUNT"]
45+
args:
46+
- "-c"
47+
- |
48+
/workspace/cloud_sql_proxy -dir=/workspace -instances=${_INSTANCE_CONNECTION_NAME}=tcp:$_IP_ADDRESS:$_DATABASE_PORT & sleep 2;
49+
python -m pytest --cov=langchain_google_cloud_sql_pg --cov-config=.coveragerc tests/
3650
3751
availableSecrets:
3852
secretManager:
@@ -44,9 +58,12 @@ availableSecrets:
4458
env: "IAM_ACCOUNT"
4559

4660
substitutions:
61+
_INSTANCE_CONNECTION_NAME: ${PROJECT_ID}:${_REGION}:${_INSTANCE_ID}
62+
_DATABASE_PORT: "5432"
4763
_DATABASE_ID: test-database
4864
_REGION: us-central1
4965
_VERSION: "3.8"
66+
_IP_ADDRESS: "127.0.0.1"
5067

5168
options:
5269
dynamicSubstitutions: true
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import json
18+
from typing import List, Sequence
19+
20+
from langchain_core.chat_history import BaseChatMessageHistory
21+
from langchain_core.messages import BaseMessage, messages_from_dict
22+
from sqlalchemy import text
23+
from sqlalchemy.ext.asyncio import AsyncEngine
24+
25+
from .engine import PostgresEngine
26+
27+
28+
class AsyncPostgresChatMessageHistory(BaseChatMessageHistory):
29+
"""Chat message history stored in an Cloud SQL for PostgreSQL database."""
30+
31+
__create_key = object()
32+
33+
def __init__(
34+
self,
35+
key: object,
36+
pool: AsyncEngine,
37+
session_id: str,
38+
table_name: str,
39+
schema_name: str = "public",
40+
):
41+
"""AsyncPostgresChatMessageHistory constructor.
42+
43+
Args:
44+
key (object): Key to prevent direct constructor usage.
45+
engine (PostgresEngine): Database connection pool.
46+
session_id (str): Retrieve the table content with this session ID.
47+
table_name (str): Table name that stores the chat message history.
48+
schema_name (str, optional): Database schema name of the chat message history table. Defaults to "public".
49+
50+
Raises:
51+
Exception: If constructor is directly called by the user.
52+
"""
53+
if key != AsyncPostgresChatMessageHistory.__create_key:
54+
raise Exception(
55+
"Only create class through 'create' or 'create_sync' methods!"
56+
)
57+
self.pool = pool
58+
self.session_id = session_id
59+
self.table_name = table_name
60+
self.schema_name = schema_name
61+
62+
@classmethod
63+
async def create(
64+
cls,
65+
engine: PostgresEngine,
66+
session_id: str,
67+
table_name: str,
68+
schema_name: str = "public",
69+
) -> AsyncPostgresChatMessageHistory:
70+
"""Create a new AsyncPostgresChatMessageHistory instance.
71+
72+
Args:
73+
engine (PostgresEngine): Postgres engine to use.
74+
session_id (str): Retrieve the table content with this session ID.
75+
table_name (str): Table name that stores the chat message history.
76+
schema_name (str, optional): Database schema name for the chat message history table. Defaults to "public".
77+
78+
Raises:
79+
IndexError: If the table provided does not contain required schema.
80+
81+
Returns:
82+
AsyncPostgresChatMessageHistory: A newly created instance of AsyncPostgresChatMessageHistory.
83+
"""
84+
table_schema = await engine._aload_table_schema(table_name, schema_name)
85+
column_names = table_schema.columns.keys()
86+
87+
required_columns = ["id", "session_id", "data", "type"]
88+
89+
if not (all(x in column_names for x in required_columns)):
90+
raise IndexError(
91+
f"Table '{schema_name}'.'{table_name}' has incorrect schema. Got "
92+
f"column names '{column_names}' but required column names "
93+
f"'{required_columns}'.\nPlease create table with following schema:"
94+
f"\nCREATE TABLE {schema_name}.{table_name} ("
95+
"\n id INT AUTO_INCREMENT PRIMARY KEY,"
96+
"\n session_id TEXT NOT NULL,"
97+
"\n data JSON NOT NULL,"
98+
"\n type TEXT NOT NULL"
99+
"\n);"
100+
)
101+
return cls(cls.__create_key, engine._pool, session_id, table_name)
102+
103+
async def aadd_message(self, message: BaseMessage) -> None:
104+
"""Append the message to the record in PostgreSQL"""
105+
query = f"""INSERT INTO "{self.schema_name}"."{self.table_name}"(session_id, data, type)
106+
VALUES (:session_id, :data, :type);
107+
"""
108+
async with self.pool.connect() as conn:
109+
await conn.execute(
110+
text(query),
111+
{
112+
"session_id": self.session_id,
113+
"data": json.dumps(message.dict()),
114+
"type": message.type,
115+
},
116+
)
117+
await conn.commit()
118+
119+
async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
120+
"""Append a list of messages to the record in PostgreSQL"""
121+
for message in messages:
122+
await self.aadd_message(message)
123+
124+
async def aclear(self) -> None:
125+
"""Clear session memory from PostgreSQL"""
126+
query = f"""DELETE FROM "{self.schema_name}"."{self.table_name}" WHERE session_id = :session_id;"""
127+
async with self.pool.connect() as conn:
128+
await conn.execute(text(query), {"session_id": self.session_id})
129+
await conn.commit()
130+
131+
async def _aget_messages(self) -> List[BaseMessage]:
132+
"""Retrieve the messages from PostgreSQL."""
133+
query = f"""SELECT data, type FROM "{self.schema_name}"."{self.table_name}" WHERE session_id = :session_id ORDER BY id;"""
134+
async with self.pool.connect() as conn:
135+
result = await conn.execute(text(query), {"session_id": self.session_id})
136+
result_map = result.mappings()
137+
results = result_map.fetchall()
138+
if not results:
139+
return []
140+
141+
items = [{"data": result["data"], "type": result["type"]} for result in results]
142+
messages = messages_from_dict(items)
143+
return messages
144+
145+
def clear(self) -> None:
146+
raise NotImplementedError(
147+
"Sync methods are not implemented for AsyncPostgresChatMessageHistory. Use PostgresChatMessageHistory interface instead."
148+
)

0 commit comments

Comments
 (0)