22import base64
33import json
44import threading
5- from typing import Callable , Optional
5+ from typing import Callable , Optional , Awaitable , Union , Any
6+ import asyncio
7+ from concurrent .futures import ThreadPoolExecutor
68
79from websockets .sync .client import connect
810
@@ -52,22 +54,133 @@ def interrupt(self):
5254 """
5355 pass
5456
57+
class ClientTools:
    """Handles registration and execution of client-side tools that can be called by the agent.

    Async handlers are awaited on a dedicated event loop that runs in its own
    daemon thread; sync handlers are dispatched from that loop to a thread
    pool. Either way, the thread that calls `execute_tool` is not blocked
    while the tool runs.
    """

    def __init__(self):
        # Maps tool name -> (handler, is_async).
        self.tools: dict[str, tuple[Union[Callable[[dict], Any], Callable[[dict], Awaitable[Any]]], bool]] = {}
        # Guards every read and write of `self.tools`.
        self.lock = threading.Lock()
        self._loop = None
        self._thread = None
        # Set while the background event loop is available for scheduling.
        self._running = threading.Event()
        self.thread_pool = ThreadPoolExecutor()

    def start(self):
        """Start the event loop in a separate thread for handling async operations.

        Returns immediately if the loop is already running; otherwise blocks
        until the new loop has been created and marked as running.
        """
        if self._running.is_set():
            return

        def run_event_loop():
            self._loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self._loop)
            self._running.set()
            try:
                self._loop.run_forever()
            finally:
                self._running.clear()
                self._loop.close()
                self._loop = None

        self._thread = threading.Thread(target=run_event_loop, daemon=True, name="ClientTools-EventLoop")
        self._thread.start()
        # Wait for loop to be ready
        self._running.wait()

    def stop(self):
        """Gracefully stop the event loop and clean up resources.

        Safe to call when the loop was never started. After `stop()` the
        instance can be started again with `start()`.
        """
        # Snapshot the loop: the loop thread resets `self._loop` to None on exit.
        loop = self._loop
        if loop is not None and self._running.is_set():
            loop.call_soon_threadsafe(loop.stop)
            # A thread cannot join itself (e.g. stop() triggered by an async tool).
            if self._thread is not threading.current_thread():
                self._thread.join()
        self.thread_pool.shutdown(wait=False)
        # A shut-down executor rejects new work, which would break sync tools
        # after a restart. Executors create their worker threads lazily, so a
        # fresh idle pool costs nothing until it is actually used.
        self.thread_pool = ThreadPoolExecutor()

    def register(
        self,
        tool_name: str,
        handler: Union[Callable[[dict], Any], Callable[[dict], Awaitable[Any]]],
        is_async: bool = False,
    ) -> None:
        """Register a new tool that can be called by the AI agent.

        Args:
            tool_name: Unique identifier for the tool
            handler: Function that implements the tool's logic
            is_async: Whether the handler is an async function

        Raises:
            ValueError: If the handler is not callable or the name is already taken.
        """
        with self.lock:
            if not callable(handler):
                raise ValueError("Handler must be callable")
            if tool_name in self.tools:
                raise ValueError(f"Tool '{tool_name}' is already registered")
            self.tools[tool_name] = (handler, is_async)

    async def handle(self, tool_name: str, parameters: dict) -> Any:
        """Execute a registered tool with the given parameters.

        Async handlers are awaited directly; sync handlers run in the thread
        pool so they do not block the event loop.

        Returns the result of the tool execution.

        Raises:
            ValueError: If no tool is registered under `tool_name`.
        """
        with self.lock:
            if tool_name not in self.tools:
                raise ValueError(f"Tool '{tool_name}' is not registered")
            handler, is_async = self.tools[tool_name]

        if is_async:
            return await handler(parameters)
        # Inside a coroutine the running loop is the one to use.
        return await asyncio.get_running_loop().run_in_executor(self.thread_pool, handler, parameters)

    def execute_tool(self, tool_name: str, parameters: dict, callback: Callable[[dict], None]):
        """Execute a tool and send its result via the provided callback.

        This method is non-blocking and handles both sync and async tools.
        The callback is invoked on the event-loop thread with a
        `client_tool_result` dict; tool failures are reported through that
        dict (`is_error` set to True) rather than raised.

        Raises:
            RuntimeError: If the event loop is not running.
        """
        # Snapshot the loop so a concurrent shutdown cannot leave us holding None.
        loop = self._loop
        if loop is None or not self._running.is_set():
            raise RuntimeError("ClientTools event loop is not running")

        async def _execute_and_callback():
            try:
                result = await self.handle(tool_name, parameters)
                # Only a missing result (None) gets the placeholder message;
                # legitimate falsy results such as 0, False or "" are kept.
                if result is None:
                    result = f"Client tool: {tool_name} called successfully."
                response = {
                    "type": "client_tool_result",
                    "tool_call_id": parameters.get("tool_call_id"),
                    "result": result,
                    "is_error": False,
                }
            except Exception as e:
                # Boundary handler: any tool failure is reported to the agent.
                response = {
                    "type": "client_tool_result",
                    "tool_call_id": parameters.get("tool_call_id"),
                    "result": str(e),
                    "is_error": True,
                }
            callback(response)

        asyncio.run_coroutine_threadsafe(_execute_and_callback(), loop)
164+
class ConversationConfig:
    """Configuration options for the Conversation.

    Holds two dicts: `extra_body` and `conversation_config_override`.
    """

    def __init__(
        self,
        extra_body: Optional[dict] = None,
        conversation_config_override: Optional[dict] = None,
    ):
        # A falsy argument (None or an empty dict) is replaced by a fresh dict,
        # so instances never share a default.
        self.extra_body = extra_body if extra_body else {}
        self.conversation_config_override = (
            conversation_config_override if conversation_config_override else {}
        )
64-
175+
176+
65177class Conversation :
66178 client : BaseElevenLabs
67179 agent_id : str
68180 requires_auth : bool
69181 config : ConversationConfig
70182 audio_interface : AudioInterface
183+ client_tools : Optional [ClientTools ]
71184 callback_agent_response : Optional [Callable [[str ], None ]]
72185 callback_agent_response_correction : Optional [Callable [[str , str ], None ]]
73186 callback_user_transcript : Optional [Callable [[str ], None ]]
@@ -86,7 +199,7 @@ def __init__(
86199 requires_auth : bool ,
87200 audio_interface : AudioInterface ,
88201 config : Optional [ConversationConfig ] = None ,
89-
202+ client_tools : Optional [ ClientTools ] = None ,
90203 callback_agent_response : Optional [Callable [[str ], None ]] = None ,
91204 callback_agent_response_correction : Optional [Callable [[str , str ], None ]] = None ,
92205 callback_user_transcript : Optional [Callable [[str ], None ]] = None ,
@@ -101,6 +214,7 @@ def __init__(
101214 agent_id: The ID of the agent to converse with.
102215 requires_auth: Whether the agent requires authentication.
103216 audio_interface: The audio interface to use for input and output.
217+ client_tools: The client tools to use for the conversation.
104218 callback_agent_response: Callback for agent responses.
105219 callback_agent_response_correction: Callback for agent response corrections.
106220 First argument is the original response (previously given to
@@ -112,14 +226,16 @@ def __init__(
112226 self .client = client
113227 self .agent_id = agent_id
114228 self .requires_auth = requires_auth
115-
116229 self .audio_interface = audio_interface
117230 self .callback_agent_response = callback_agent_response
118231 self .config = config or ConversationConfig ()
232+ self .client_tools = client_tools or ClientTools ()
119233 self .callback_agent_response_correction = callback_agent_response_correction
120234 self .callback_user_transcript = callback_user_transcript
121235 self .callback_latency_measurement = callback_latency_measurement
122236
237+ self .client_tools .start ()
238+
123239 self ._thread = None
124240 self ._should_stop = threading .Event ()
125241 self ._conversation_id = None
@@ -135,8 +251,9 @@ def start_session(self):
135251 self ._thread .start ()
136252
def end_session(self):
    """Ends the conversation session and cleans up resources."""
    # Stop the audio interface first, then the client tools.
    self.audio_interface.stop()
    self.client_tools.stop()
    # Finally set the stop event; the tool-result sender checks this flag
    # before writing to the websocket.
    self._should_stop.set()
141258
142259 def wait_for_session_end (self ) -> Optional [str ]:
@@ -155,10 +272,10 @@ def _run(self, ws_url: str):
155272 with connect (ws_url ) as ws :
156273 ws .send (
157274 json .dumps (
158- {
159- "type" : "conversation_initiation_client_data" ,
160- "custom_llm_extra_body" : self .config .extra_body ,
161- "conversation_config_override" : self .config .conversation_config_override ,
275+ {
276+ "type" : "conversation_initiation_client_data" ,
277+ "custom_llm_extra_body" : self .config .extra_body ,
278+ "conversation_config_override" : self .config .conversation_config_override ,
162279 }
163280 )
164281 )
@@ -210,7 +327,7 @@ def _handle_message(self, message, ws):
210327 self .callback_user_transcript (event ["user_transcript" ].strip ())
211328 elif message ["type" ] == "interruption" :
212329 event = message ["interruption_event" ]
213- self .last_interrupt_id = int (event ["event_id" ])
330+ self ._last_interrupt_id = int (event ["event_id" ])
214331 self .audio_interface .interrupt ()
215332 elif message ["type" ] == "ping" :
216333 event = message ["ping_event" ]
@@ -224,6 +341,16 @@ def _handle_message(self, message, ws):
224341 )
225342 if self .callback_latency_measurement and event ["ping_ms" ]:
226343 self .callback_latency_measurement (int (event ["ping_ms" ]))
344+ elif message ["type" ] == "client_tool_call" :
345+ tool_call = message .get ("client_tool_call" , {})
346+ tool_name = tool_call .get ("tool_name" )
347+ parameters = {"tool_call_id" : tool_call ["tool_call_id" ], ** tool_call .get ("parameters" , {})}
348+
349+ def send_response (response ):
350+ if not self ._should_stop .is_set ():
351+ ws .send (json .dumps (response ))
352+
353+ self .client_tools .execute_tool (tool_name , parameters , send_response )
227354 else :
228355 pass # Ignore all other message types.
229356
0 commit comments