1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ import asyncio
1415import os
1516import uuid
17+ from typing import Any , Coroutine
1618
1719import pytest
1820import pytest_asyncio
3335table_name_async = "message_store" + str (uuid .uuid4 ())
3436
3537
38+ # Helper to bridge the Main Test Loop and the Engine Background Loop
39+ async def run_on_background (engine : PostgresEngine , coro : Coroutine ) -> Any :
40+ """Runs a coroutine on the engine's background loop."""
41+ if engine ._loop :
42+ return await asyncio .wrap_future (
43+ asyncio .run_coroutine_threadsafe (coro , engine ._loop )
44+ )
45+ return await coro
46+
47+
3648async def aexecute (engine : PostgresEngine , query : str ) -> None :
37- async with engine ._pool .connect () as conn :
38- await conn .execute (text (query ))
39- await conn .commit ()
49+ async def _impl ():
50+ async with engine ._pool .connect () as conn :
51+ await conn .execute (text (query ))
52+ await conn .commit ()
53+
54+ await run_on_background (engine , _impl ())
4055
4156
4257@pytest_asyncio .fixture
@@ -47,7 +62,10 @@ async def async_engine():
4762 instance = instance_id ,
4863 database = db_name ,
4964 )
50- await async_engine ._ainit_chat_history_table (table_name = table_name_async )
65+ await run_on_background (
66+ async_engine ,
67+ async_engine ._ainit_chat_history_table (table_name = table_name_async ),
68+ )
5169 yield async_engine
5270 # use default table for AsyncPostgresChatMessageHistory
5371 query = f'DROP TABLE IF EXISTS "{ table_name_async } "'
@@ -59,14 +77,19 @@ async def async_engine():
5977async def test_chat_message_history_async (
6078 async_engine : PostgresEngine ,
6179) -> None :
62- history = await AsyncPostgresChatMessageHistory .create (
63- engine = async_engine , session_id = "test" , table_name = table_name_async
80+ history = await run_on_background (
81+ async_engine ,
82+ AsyncPostgresChatMessageHistory .create (
83+ engine = async_engine , session_id = "test" , table_name = table_name_async
84+ ),
6485 )
6586 msg1 = HumanMessage (content = "hi!" )
6687 msg2 = AIMessage (content = "whats up?" )
67- await history .aadd_message (msg1 )
68- await history .aadd_message (msg2 )
69- messages = await history ._aget_messages ()
88+
89+ await run_on_background (async_engine , history .aadd_message (msg1 ))
90+ await run_on_background (async_engine , history .aadd_message (msg2 ))
91+
92+ messages = await run_on_background (async_engine , history ._aget_messages ())
7093
7194 # verify messages are correct
7295 assert messages [0 ].content == "hi!"
@@ -75,48 +98,71 @@ async def test_chat_message_history_async(
7598 assert type (messages [1 ]) is AIMessage
7699
77100 # verify clear() clears message history
78- await history .aclear ()
79- assert len (await history ._aget_messages ()) == 0
101+ await run_on_background (async_engine , history .aclear ())
102+ messages_after_clear = await run_on_background (
103+ async_engine , history ._aget_messages ()
104+ )
105+ assert len (messages_after_clear ) == 0
80106
81107
82108@pytest .mark .asyncio
83109async def test_chat_message_history_sync_messages (
84110 async_engine : PostgresEngine ,
85111) -> None :
86- history1 = await AsyncPostgresChatMessageHistory .create (
87- engine = async_engine , session_id = "test" , table_name = table_name_async
112+ history1 = await run_on_background (
113+ async_engine ,
114+ AsyncPostgresChatMessageHistory .create (
115+ engine = async_engine , session_id = "test" , table_name = table_name_async
116+ ),
88117 )
89- history2 = await AsyncPostgresChatMessageHistory .create (
90- engine = async_engine , session_id = "test" , table_name = table_name_async
118+ history2 = await run_on_background (
119+ async_engine ,
120+ AsyncPostgresChatMessageHistory .create (
121+ engine = async_engine , session_id = "test" , table_name = table_name_async
122+ ),
91123 )
92124 msg1 = HumanMessage (content = "hi!" )
93125 msg2 = AIMessage (content = "whats up?" )
94- await history1 .aadd_message (msg1 )
95- await history2 .aadd_message (msg2 )
126+ await run_on_background (async_engine , history1 .aadd_message (msg1 ))
127+ await run_on_background (async_engine , history2 .aadd_message (msg2 ))
128+
129+ len_history1 = len (await run_on_background (async_engine , history1 ._aget_messages ()))
130+ len_history2 = len (await run_on_background (async_engine , history2 ._aget_messages ()))
96131
97- assert len ( await history1 . _aget_messages ()) == 2
98- assert len ( await history2 . _aget_messages ()) == 2
132+ assert len_history1 == 2
133+ assert len_history2 == 2
99134
100135 # verify clear() clears message history
101- await history2 .aclear ()
102- assert len (await history2 ._aget_messages ()) == 0
136+ await run_on_background (async_engine , history2 .aclear ())
137+ len_history2_after_clear = len (
138+ await run_on_background (async_engine , history2 ._aget_messages ())
139+ )
140+ assert len_history2_after_clear == 0
103141
104142
105143@pytest .mark .asyncio
106144async def test_chat_table_async (async_engine ):
107145 with pytest .raises (ValueError ):
108- await AsyncPostgresChatMessageHistory .create (
109- engine = async_engine , session_id = "test" , table_name = "doesnotexist"
146+ await run_on_background (
147+ async_engine ,
148+ AsyncPostgresChatMessageHistory .create (
149+ engine = async_engine , session_id = "test" , table_name = "doesnotexist"
150+ ),
110151 )
111152
112153
113154@pytest .mark .asyncio
114155async def test_chat_schema_async (async_engine ):
115156 table_name = "test_table" + str (uuid .uuid4 ())
116- await async_engine ._ainit_document_table (table_name = table_name )
157+ await run_on_background (
158+ async_engine , async_engine ._ainit_document_table (table_name = table_name )
159+ )
117160 with pytest .raises (IndexError ):
118- await AsyncPostgresChatMessageHistory .create (
119- engine = async_engine , session_id = "test" , table_name = table_name
161+ await run_on_background (
162+ async_engine ,
163+ AsyncPostgresChatMessageHistory .create (
164+ engine = async_engine , session_id = "test" , table_name = table_name
165+ ),
120166 )
121167
122168 query = f'DROP TABLE IF EXISTS "{ table_name } "'
0 commit comments