|
2 | 2 | import base64 |
3 | 3 | import json |
4 | 4 | import threading |
5 | | -from typing import Callable, Optional, Awaitable, Union, Any |
| 5 | +from typing import Callable, Optional, Awaitable, Union, Any, Literal, Dict, Tuple |
6 | 6 | import asyncio |
7 | 7 | from concurrent.futures import ThreadPoolExecutor |
| 8 | +from enum import Enum |
8 | 9 |
|
9 | | -from websockets.sync.client import connect, ClientConnection |
| 10 | +from websockets.sync.client import connect, Connection |
10 | 11 | from websockets.exceptions import ConnectionClosedOK |
11 | 12 |
|
12 | 13 | from ..base_client import BaseElevenLabs |
13 | 14 |
|
14 | 15 |
|
| 16 | +class ClientToOrchestratorEvent(str, Enum): |
| 17 | + """Event types that can be sent from client to orchestrator.""" |
| 18 | + # Response to a ping request. |
| 19 | + PONG = "pong" |
| 20 | + CLIENT_TOOL_RESULT = "client_tool_result" |
| 21 | + CONVERSATION_INITIATION_CLIENT_DATA = "conversation_initiation_client_data" |
| 22 | + FEEDBACK = "feedback" |
| 23 | + # Non-interrupting content that is sent to the server to update the conversation state. |
| 24 | + CONTEXTUAL_UPDATE = "contextual_update" |
| 25 | + # User text message. |
| 26 | + USER_MESSAGE = "user_message" |
| 27 | + USER_ACTIVITY = "user_activity" |
| 28 | + |
| 29 | + |
| 30 | +class UserMessageClientToOrchestratorEvent: |
| 31 | + """Event for sending user text messages.""" |
| 32 | + |
| 33 | + def __init__(self, text: Optional[str] = None): |
| 34 | + self.type: Literal[ClientToOrchestratorEvent.USER_MESSAGE] = ClientToOrchestratorEvent.USER_MESSAGE |
| 35 | + self.text = text |
| 36 | + |
| 37 | + def to_dict(self) -> dict: |
| 38 | + return { |
| 39 | + "type": self.type, |
| 40 | + "text": self.text |
| 41 | + } |
| 42 | + |
| 43 | + |
| 44 | +class UserActivityClientToOrchestratorEvent: |
| 45 | + """Event for registering user activity (ping to prevent timeout).""" |
| 46 | + |
| 47 | + def __init__(self) -> None: |
| 48 | + self.type: Literal[ClientToOrchestratorEvent.USER_ACTIVITY] = ClientToOrchestratorEvent.USER_ACTIVITY |
| 49 | + |
| 50 | + def to_dict(self) -> dict: |
| 51 | + return { |
| 52 | + "type": self.type |
| 53 | + } |
| 54 | + |
| 55 | + |
| 56 | +class ContextualUpdateClientToOrchestratorEvent: |
| 57 | + """Event for sending non-interrupting contextual updates to the conversation state.""" |
| 58 | + |
| 59 | + def __init__(self, content: str): |
| 60 | + self.type: Literal[ClientToOrchestratorEvent.CONTEXTUAL_UPDATE] = ClientToOrchestratorEvent.CONTEXTUAL_UPDATE |
| 61 | + self.content = content |
| 62 | + |
| 63 | + def to_dict(self) -> dict: |
| 64 | + return { |
| 65 | + "type": self.type, |
| 66 | + "content": self.content |
| 67 | + } |
| 68 | + |
| 69 | + |
15 | 70 | class AudioInterface(ABC): |
16 | 71 | """AudioInterface provides an abstraction for handling audio input and output.""" |
17 | 72 |
|
@@ -63,8 +118,8 @@ class ClientTools: |
63 | 118 | ensuring non-blocking operation of the main conversation thread. |
64 | 119 | """ |
65 | 120 |
|
66 | | - def __init__(self): |
67 | | - self.tools: dict[str, tuple[Union[Callable[[dict], Any], Callable[[dict], Awaitable[Any]]], bool]] = {} |
| 121 | + def __init__(self) -> None: |
| 122 | + self.tools: Dict[str, Tuple[Union[Callable[[dict], Any], Callable[[dict], Awaitable[Any]]], bool]] = {} |
68 | 123 | self.lock = threading.Lock() |
69 | 124 | self._loop = None |
70 | 125 | self._thread = None |
@@ -141,6 +196,9 @@ def execute_tool(self, tool_name: str, parameters: dict, callback: Callable[[dic |
141 | 196 | """ |
142 | 197 | if not self._running.is_set(): |
143 | 198 | raise RuntimeError("ClientTools event loop is not running") |
| 199 | + |
| 200 | + if self._loop is None: |
| 201 | + raise RuntimeError("Event loop is not available") |
144 | 202 |
|
145 | 203 | async def _execute_and_callback(): |
146 | 204 | try: |
@@ -193,6 +251,7 @@ class Conversation: |
193 | 251 | _should_stop: threading.Event |
194 | 252 | _conversation_id: Optional[str] |
195 | 253 | _last_interrupt_id: int |
| 254 | + _ws: Optional[Connection] |
196 | 255 |
|
197 | 256 | def __init__( |
198 | 257 | self, |
@@ -240,7 +299,7 @@ def __init__( |
240 | 299 | self.client_tools.start() |
241 | 300 |
|
242 | 301 | self._thread = None |
243 | | - self._ws: Optional[ClientConnection] = None |
| 302 | + self._ws: Optional[Connection] = None |
244 | 303 | self._should_stop = threading.Event() |
245 | 304 | self._conversation_id = None |
246 | 305 | self._last_interrupt_id = 0 |
@@ -273,8 +332,68 @@ def wait_for_session_end(self) -> Optional[str]: |
273 | 332 | self._thread.join() |
274 | 333 | return self._conversation_id |
275 | 334 |
|
| 335 | + def send_user_message(self, text: str): |
| 336 | + """Send a text message from the user to the agent. |
| 337 | + |
| 338 | + Args: |
| 339 | + text: The text message to send to the agent. |
| 340 | + |
| 341 | + Raises: |
| 342 | + RuntimeError: If the session is not active or websocket is not connected. |
| 343 | + """ |
| 344 | + if not self._ws: |
| 345 | + raise RuntimeError("Session not started or websocket not connected.") |
| 346 | + |
| 347 | + event = UserMessageClientToOrchestratorEvent(text=text) |
| 348 | + try: |
| 349 | + self._ws.send(json.dumps(event.to_dict())) |
| 350 | + except Exception as e: |
| 351 | + print(f"Error sending user message: {e}") |
| 352 | + raise |
| 353 | + |
| 354 | + def register_user_activity(self): |
| 355 | + """Register user activity to prevent session timeout. |
| 356 | + |
| 357 | + This sends a ping to the orchestrator to reset the timeout timer. |
| 358 | + |
| 359 | + Raises: |
| 360 | + RuntimeError: If the session is not active or websocket is not connected. |
| 361 | + """ |
| 362 | + if not self._ws: |
| 363 | + raise RuntimeError("Session not started or websocket not connected.") |
| 364 | + |
| 365 | + event = UserActivityClientToOrchestratorEvent() |
| 366 | + try: |
| 367 | + self._ws.send(json.dumps(event.to_dict())) |
| 368 | + except Exception as e: |
| 369 | + print(f"Error registering user activity: {e}") |
| 370 | + raise |
| 371 | + |
| 372 | + def send_contextual_update(self, content: str): |
| 373 | + """Send a contextual update to the conversation. |
| 374 | + |
| 375 | + Contextual updates are non-interrupting content that is sent to the server |
| 376 | + to update the conversation state without directly prompting the agent. |
| 377 | + |
| 378 | + Args: |
| 379 | + content: The contextual information to send to the conversation. |
| 380 | + |
| 381 | + Raises: |
| 382 | + RuntimeError: If the session is not active or websocket is not connected. |
| 383 | + """ |
| 384 | + if not self._ws: |
| 385 | + raise RuntimeError("Session not started or websocket not connected.") |
| 386 | + |
| 387 | + event = ContextualUpdateClientToOrchestratorEvent(content=content) |
| 388 | + try: |
| 389 | + self._ws.send(json.dumps(event.to_dict())) |
| 390 | + except Exception as e: |
| 391 | + print(f"Error sending contextual update: {e}") |
| 392 | + raise |
| 393 | + |
276 | 394 | def _run(self, ws_url: str): |
277 | 395 | with connect(ws_url, max_size=16 * 1024 * 1024) as ws: |
| 396 | + self._ws = ws |
278 | 397 | ws.send( |
279 | 398 | json.dumps( |
280 | 399 | { |
@@ -316,6 +435,8 @@ def input_callback(audio): |
316 | 435 | except Exception as e: |
317 | 436 | print(f"Error receiving message: {e}") |
318 | 437 | self.end_session() |
| 438 | + |
| 439 | + self._ws = None |
319 | 440 |
|
320 | 441 | def _handle_message(self, message, ws): |
321 | 442 | if message["type"] == "conversation_initiation_metadata": |
@@ -372,16 +493,6 @@ def send_response(response): |
372 | 493 | else: |
373 | 494 | pass # Ignore all other message types. |
374 | 495 |
|
375 | | - def send_contextual_update(self, text: str): |
376 | | - if not self._ws: |
377 | | - raise RuntimeError("WebSocket is not connected") |
378 | | - |
379 | | - payload = { |
380 | | - "type": "contextual_update", |
381 | | - "text": text, |
382 | | - } |
383 | | - self._ws.send(json.dumps(payload)) |
384 | | - |
385 | 496 | def _get_wss_url(self): |
386 | 497 | base_ws_url = self.client._client_wrapper.get_environment().wss |
387 | 498 | return f"{base_ws_url}/v1/convai/conversation?agent_id={self.agent_id}" |
|
0 commit comments