Skip to content

Commit 1e0566a

Browse files
authored
feat: Add support for custom schema names (#191)
1 parent ec23308 commit 1e0566a

File tree

4 files changed

+107
-32
lines changed

4 files changed

+107
-32
lines changed

src/langchain_google_cloud_sql_pg/chat_message_history.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,13 @@
2424

2525

2626
async def _aget_messages(
27-
engine: PostgresEngine, session_id: str, table_name: str
27+
engine: PostgresEngine,
28+
session_id: str,
29+
table_name: str,
30+
schema_name: str = "public",
2831
) -> List[BaseMessage]:
2932
"""Retrieve the messages from PostgreSQL."""
30-
query = f"""SELECT data, type FROM "{table_name}" WHERE session_id = :session_id ORDER BY id;"""
33+
query = f"""SELECT data, type FROM "{schema_name}"."{table_name}" WHERE session_id = :session_id ORDER BY id;"""
3134
results = await engine._afetch(query, {"session_id": session_id})
3235
if not results:
3336
return []
@@ -49,6 +52,7 @@ def __init__(
4952
session_id: str,
5053
table_name: str,
5154
messages: List[BaseMessage],
55+
schema_name: str = "public",
5256
):
5357
"""PostgresChatMessageHistory constructor.
5458
@@ -58,6 +62,7 @@ def __init__(
5862
session_id (str): Retrieve the table content with this session ID.
5963
table_name (str): Table name that stores the chat message history.
6064
messages (List[BaseMessage]): Messages to store.
65+
schema_name (str, optional): Database schema name of the chat message history table. Defaults to "public".
6166
6267
Raises:
6368
Exception: If constructor is directly called by the user.
@@ -70,73 +75,80 @@ def __init__(
7075
self.session_id = session_id
7176
self.table_name = table_name
7277
self.messages = messages
78+
self.schema_name = schema_name
7379

7480
@classmethod
7581
async def create(
7682
cls,
7783
engine: PostgresEngine,
7884
session_id: str,
7985
table_name: str,
86+
schema_name: str = "public",
8087
) -> PostgresChatMessageHistory:
8188
"""Create a new PostgresChatMessageHistory instance.
8289
8390
Args:
8491
engine (PostgresEngine): Postgres engine to use.
8592
session_id (str): Retrieve the table content with this session ID.
8693
table_name (str): Table name that stores the chat message history.
94+
schema_name (str, optional): Schema name for the chat message history table. Defaults to "public".
8795
8896
Raises:
8997
IndexError: If the table provided does not contain required schema.
9098
9199
Returns:
92100
PostgresChatMessageHistory: A newly created instance of PostgresChatMessageHistory.
93101
"""
94-
table_schema = await engine._aload_table_schema(table_name)
102+
table_schema = await engine._aload_table_schema(table_name, schema_name)
95103
column_names = table_schema.columns.keys()
96104

97105
required_columns = ["id", "session_id", "data", "type"]
98106

99107
if not (all(x in column_names for x in required_columns)):
100108
raise IndexError(
101-
f"Table '{table_name}' has incorrect schema. Got "
109+
f"Table '{schema_name}'.'{table_name}' has incorrect schema. Got "
102110
f"column names '{column_names}' but required column names "
103111
f"'{required_columns}'.\nPlease create table with following schema:"
104-
f"\nCREATE TABLE {table_name} ("
112+
f"\nCREATE TABLE {schema_name}.{table_name} ("
105113
"\n id INT AUTO_INCREMENT PRIMARY KEY,"
106114
"\n session_id TEXT NOT NULL,"
107115
"\n data JSON NOT NULL,"
108116
"\n type TEXT NOT NULL"
109117
"\n);"
110118
)
111-
messages = await _aget_messages(engine, session_id, table_name)
112-
return cls(cls.__create_key, engine, session_id, table_name, messages)
119+
messages = await _aget_messages(engine, session_id, table_name, schema_name)
120+
return cls(
121+
cls.__create_key, engine, session_id, table_name, messages, schema_name
122+
)
113123

114124
@classmethod
115125
def create_sync(
116126
cls,
117127
engine: PostgresEngine,
118128
session_id: str,
119129
table_name: str,
130+
schema_name: str = "public",
120131
) -> PostgresChatMessageHistory:
121132
"""Create a new PostgresChatMessageHistory instance.
122133
123134
Args:
124135
engine (PostgresEngine): Postgres engine to use.
125136
session_id (str): Retrieve the table content with this session ID.
126137
table_name (str): Table name that stores the chat message history.
138+
schema_name (str, optional): Database schema name for the chat message history table. Defaults to "public".
127139
128140
Raises:
129141
IndexError: If the table provided does not contain required schema.
130142
131143
Returns:
132144
PostgresChatMessageHistory: A newly created instance of PostgresChatMessageHistory.
133145
"""
134-
coro = cls.create(engine, session_id, table_name)
146+
coro = cls.create(engine, session_id, table_name, schema_name)
135147
return engine._run_as_sync(coro)
136148

137149
async def aadd_message(self, message: BaseMessage) -> None:
138150
"""Append the message to the record in PostgreSQL"""
139-
query = f"""INSERT INTO "{self.table_name}"(session_id, data, type)
151+
query = f"""INSERT INTO "{self.schema_name}"."{self.table_name}"(session_id, data, type)
140152
VALUES (:session_id, :data, :type);
141153
"""
142154
await self.engine._aexecute(
@@ -148,7 +160,7 @@ async def aadd_message(self, message: BaseMessage) -> None:
148160
},
149161
)
150162
self.messages = await _aget_messages(
151-
self.engine, self.session_id, self.table_name
163+
self.engine, self.session_id, self.table_name, self.schema_name
152164
)
153165

154166
def add_message(self, message: BaseMessage) -> None:
@@ -166,7 +178,7 @@ def add_messages(self, messages: Sequence[BaseMessage]) -> None:
166178

167179
async def aclear(self) -> None:
168180
"""Clear session memory from PostgreSQL"""
169-
query = f"""DELETE FROM "{self.table_name}" WHERE session_id = :session_id;"""
181+
query = f"""DELETE FROM "{self.schema_name}"."{self.table_name}" WHERE session_id = :session_id;"""
170182
await self.engine._aexecute(query, {"session_id": self.session_id})
171183
self.messages = []
172184

@@ -177,7 +189,7 @@ def clear(self) -> None:
177189
async def async_messages(self) -> None:
178190
"""Retrieve the messages from Postgres."""
179191
self.messages = await _aget_messages(
180-
self.engine, self.session_id, self.table_name
192+
self.engine, self.session_id, self.table_name, self.schema_name
181193
)
182194

183195
def sync_messages(self) -> None:

src/langchain_google_cloud_sql_pg/engine.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,7 @@ async def ainit_vectorstore_table(
373373
self,
374374
table_name: str,
375375
vector_size: int,
376+
schema_name: str = "public",
376377
content_column: str = "content",
377378
embedding_column: str = "embedding",
378379
metadata_columns: List[Column] = [],
@@ -387,6 +388,8 @@ async def ainit_vectorstore_table(
387388
Args:
388389
table_name (str): The Postgres database table name.
389390
vector_size (int): Vector size for the embedding model to be used.
391+
schema_name (str): The schema name to store Postgres database table.
392+
Default: "public".
390393
content_column (str): Name of the column to store document content.
391394
Default: "page_content".
392395
embedding_column (str) : Name of the column to store vector embeddings.
@@ -407,9 +410,9 @@ async def ainit_vectorstore_table(
407410
await self._aexecute("CREATE EXTENSION IF NOT EXISTS vector")
408411

409412
if overwrite_existing:
410-
await self._aexecute(f'DROP TABLE IF EXISTS "{table_name}"')
413+
await self._aexecute(f'DROP TABLE IF EXISTS "{schema_name}"."{table_name}"')
411414

412-
query = f"""CREATE TABLE "{table_name}"(
415+
query = f"""CREATE TABLE "{schema_name}"."{table_name}"(
413416
"{id_column}" UUID PRIMARY KEY,
414417
"{content_column}" TEXT NOT NULL,
415418
"{embedding_column}" vector({vector_size}) NOT NULL"""
@@ -426,6 +429,7 @@ def init_vectorstore_table(
426429
self,
427430
table_name: str,
428431
vector_size: int,
432+
schema_name: str = "public",
429433
content_column: str = "content",
430434
embedding_column: str = "embedding",
431435
metadata_columns: List[Column] = [],
@@ -440,6 +444,8 @@ def init_vectorstore_table(
440444
Args:
441445
table_name (str): The Postgres database table name.
442446
vector_size (int): Vector size for the embedding model to be used.
447+
schema_name (str): The schema name to store Postgres database table.
448+
Default: "public".
443449
content_column (str): Name of the column to store document content.
444450
Default: "page_content".
445451
embedding_column (str) : Name of the column to store vector embeddings.
@@ -458,6 +464,7 @@ def init_vectorstore_table(
458464
self.ainit_vectorstore_table(
459465
table_name,
460466
vector_size,
467+
schema_name,
461468
content_column,
462469
embedding_column,
463470
metadata_columns,
@@ -468,41 +475,51 @@ def init_vectorstore_table(
468475
)
469476
)
470477

471-
async def ainit_chat_history_table(self, table_name: str) -> None:
478+
async def ainit_chat_history_table(
479+
self, table_name: str, schema_name: str = "public"
480+
) -> None:
472481
"""Create a Cloud SQL table to store chat history.
473482
474483
Args:
475484
table_name (str): Table name to store chat history.
485+
schema_name (str): Schema name to store chat history table.
486+
Default: "public".
476487
477488
Returns:
478489
None
479490
"""
480-
create_table_query = f"""CREATE TABLE IF NOT EXISTS "{table_name}"(
491+
create_table_query = f"""CREATE TABLE IF NOT EXISTS "{schema_name}"."{table_name}"(
481492
id SERIAL PRIMARY KEY,
482493
session_id TEXT NOT NULL,
483494
data JSONB NOT NULL,
484495
type TEXT NOT NULL
485496
);"""
486497
await self._aexecute(create_table_query)
487498

488-
def init_chat_history_table(self, table_name: str) -> None:
499+
def init_chat_history_table(
500+
self, table_name: str, schema_name: str = "public"
501+
) -> None:
489502
"""Create a Cloud SQL table to store chat history.
490503
491504
Args:
492505
table_name (str): Table name to store chat history.
506+
schema_name (str): Schema name to store chat history table.
507+
Default: "public".
493508
494509
Returns:
495510
None
496511
"""
497512
return self._run_as_sync(
498513
self.ainit_chat_history_table(
499514
table_name,
515+
schema_name,
500516
)
501517
)
502518

503519
async def ainit_document_table(
504520
self,
505521
table_name: str,
522+
schema_name: str = "public",
506523
content_column: str = "page_content",
507524
metadata_columns: List[Column] = [],
508525
metadata_json_column: str = "langchain_metadata",
@@ -513,6 +530,8 @@ async def ainit_document_table(
513530
514531
Args:
515532
table_name (str): The PgSQL database table name.
533+
schema_name (str): The schema name to store PgSQL database table.
534+
Default: "public".
516535
content_column (str): Name of the column to store document content.
517536
Default: "page_content".
518537
metadata_columns (List[sqlalchemy.Column]): A list of SQLAlchemy Columns
@@ -526,7 +545,7 @@ async def ainit_document_table(
526545
:class:`DuplicateTableError <asyncpg.exceptions.DuplicateTableError>`: if table already exists.
527546
"""
528547

529-
query = f"""CREATE TABLE "{table_name}"(
548+
query = f"""CREATE TABLE "{schema_name}"."{table_name}"(
530549
{content_column} TEXT NOT NULL
531550
"""
532551
for column in metadata_columns:
@@ -542,6 +561,7 @@ async def ainit_document_table(
542561
def init_document_table(
543562
self,
544563
table_name: str,
564+
schema_name: str = "public",
545565
content_column: str = "page_content",
546566
metadata_columns: List[Column] = [],
547567
metadata_json_column: str = "langchain_metadata",
@@ -552,6 +572,8 @@ def init_document_table(
552572
553573
Args:
554574
table_name (str): The PgSQL database table name.
575+
schema_name (str): The schema name to store PgSQL database table.
576+
Default: "public".
555577
content_column (str): Name of the column to store document content.
556578
metadata_columns (List[sqlalchemy.Column]): A list of SQLAlchemy Columns
557579
to create for custom metadata. Optional.
@@ -561,6 +583,7 @@ def init_document_table(
561583
return self._run_as_sync(
562584
self.ainit_document_table(
563585
table_name,
586+
schema_name,
564587
content_column,
565588
metadata_columns,
566589
metadata_json_column,
@@ -571,6 +594,7 @@ def init_document_table(
571594
async def _aload_table_schema(
572595
self,
573596
table_name: str,
597+
schema_name: str = "public",
574598
) -> Table:
575599
"""
576600
Load table schema from existing table in PgSQL database.
@@ -580,11 +604,15 @@ async def _aload_table_schema(
580604
metadata = MetaData()
581605
async with self._engine.connect() as conn:
582606
try:
583-
await conn.run_sync(metadata.reflect, only=[table_name])
607+
await conn.run_sync(
608+
metadata.reflect, schema=schema_name, only=[table_name]
609+
)
584610
except InvalidRequestError as e:
585-
raise ValueError(f"Table, {table_name}, does not exist: " + str(e))
611+
raise ValueError(
612+
f"Table, '{schema_name}'.'{table_name}', does not exist: " + str(e)
613+
)
586614

587-
table = Table(table_name, metadata)
615+
table = Table(table_name, metadata, schema=schema_name)
588616
# Extract the schema information
589617
schema = []
590618
for column in table.columns:
@@ -597,4 +625,4 @@ async def _aload_table_schema(
597625
}
598626
)
599627

600-
return metadata.tables[table_name]
628+
return metadata.tables[f"{schema_name}.{table_name}"]

0 commit comments

Comments
 (0)