Skip to content

Commit e12fc63

Browse files
authored
feat: add support for text messages, user activity and contextual updates (elevenlabs#565)
* feat: add support for text messages, user activity and contextual updates
* fix
* add annotations to init methods
* fix
* use inbuilt types
* fix
1 parent 615f5ab commit e12fc63

File tree

2 files changed

+127
-16
lines changed

2 files changed

+127
-16
lines changed

src/elevenlabs/conversational_ai/conversation.py

Lines changed: 126 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,71 @@
22
import base64
33
import json
44
import threading
5-
from typing import Callable, Optional, Awaitable, Union, Any
5+
from typing import Callable, Optional, Awaitable, Union, Any, Literal, Dict, Tuple
66
import asyncio
77
from concurrent.futures import ThreadPoolExecutor
8+
from enum import Enum
89

9-
from websockets.sync.client import connect, ClientConnection
10+
from websockets.sync.client import connect, Connection
1011
from websockets.exceptions import ConnectionClosedOK
1112

1213
from ..base_client import BaseElevenLabs
1314

1415

16+
class ClientToOrchestratorEvent(str, Enum):
    """Names of the events a client may send to the orchestrator.

    Every member is also a ``str``, so it compares equal to its raw value
    and ``json.dumps`` writes it out as that plain string.
    """

    # Reply to a ping request.
    PONG = "pong"

    CLIENT_TOOL_RESULT = "client_tool_result"

    CONVERSATION_INITIATION_CLIENT_DATA = "conversation_initiation_client_data"

    FEEDBACK = "feedback"

    # Non-interrupting content sent to the server to update the
    # conversation state.
    CONTEXTUAL_UPDATE = "contextual_update"

    # Text message typed by the user.
    USER_MESSAGE = "user_message"

    USER_ACTIVITY = "user_activity"
28+
29+
30+
class UserMessageClientToOrchestratorEvent:
    """Event for sending user text messages."""

    def __init__(self, text: Optional[str] = None) -> None:
        """Create the event.

        Args:
            text: The message text. Defaults to ``None``, which ``to_dict``
                passes through unchanged.
        """
        self.type: Literal[ClientToOrchestratorEvent.USER_MESSAGE] = ClientToOrchestratorEvent.USER_MESSAGE
        self.text = text

    def to_dict(self) -> dict:
        """Return the event as a dict of plain, JSON-serializable values.

        ``type`` is emitted as the enum member's underlying string rather
        than the member itself, so consumers that print, format or
        re-serialize the dict always see ``"user_message"`` regardless of how
        the running Python version renders mixed-in enum members.
        """
        return {
            "type": self.type.value,
            "text": self.text,
        }
42+
43+
44+
class UserActivityClientToOrchestratorEvent:
    """Event for registering user activity (ping to prevent timeout)."""

    def __init__(self) -> None:
        """Create the event; it carries no data besides its type."""
        self.type: Literal[ClientToOrchestratorEvent.USER_ACTIVITY] = ClientToOrchestratorEvent.USER_ACTIVITY

    def to_dict(self) -> dict:
        """Return the event as a dict of plain, JSON-serializable values.

        ``type`` is emitted as the enum member's underlying string rather
        than the member itself, so the dict contains only built-in types.
        """
        return {
            "type": self.type.value,
        }
54+
55+
56+
class ContextualUpdateClientToOrchestratorEvent:
    """Event for sending non-interrupting contextual updates to the conversation state."""

    def __init__(self, content: str) -> None:
        """Create the event.

        Args:
            content: The contextual information to attach to the conversation.
        """
        self.type: Literal[ClientToOrchestratorEvent.CONTEXTUAL_UPDATE] = ClientToOrchestratorEvent.CONTEXTUAL_UPDATE
        self.content = content

    def to_dict(self) -> dict:
        """Return the event as a dict of plain, JSON-serializable values.

        ``type`` is emitted as the enum member's underlying string rather
        than the member itself, so the dict contains only built-in types.
        """
        return {
            "type": self.type.value,
            "content": self.content,
        }
68+
69+
1570
class AudioInterface(ABC):
1671
"""AudioInterface provides an abstraction for handling audio input and output."""
1772

@@ -63,8 +118,8 @@ class ClientTools:
63118
ensuring non-blocking operation of the main conversation thread.
64119
"""
65120

66-
def __init__(self):
67-
self.tools: dict[str, tuple[Union[Callable[[dict], Any], Callable[[dict], Awaitable[Any]]], bool]] = {}
121+
def __init__(self) -> None:
122+
self.tools: Dict[str, Tuple[Union[Callable[[dict], Any], Callable[[dict], Awaitable[Any]]], bool]] = {}
68123
self.lock = threading.Lock()
69124
self._loop = None
70125
self._thread = None
@@ -141,6 +196,9 @@ def execute_tool(self, tool_name: str, parameters: dict, callback: Callable[[dic
141196
"""
142197
if not self._running.is_set():
143198
raise RuntimeError("ClientTools event loop is not running")
199+
200+
if self._loop is None:
201+
raise RuntimeError("Event loop is not available")
144202

145203
async def _execute_and_callback():
146204
try:
@@ -193,6 +251,7 @@ class Conversation:
193251
_should_stop: threading.Event
194252
_conversation_id: Optional[str]
195253
_last_interrupt_id: int
254+
_ws: Optional[Connection]
196255

197256
def __init__(
198257
self,
@@ -240,7 +299,7 @@ def __init__(
240299
self.client_tools.start()
241300

242301
self._thread = None
243-
self._ws: Optional[ClientConnection] = None
302+
self._ws: Optional[Connection] = None
244303
self._should_stop = threading.Event()
245304
self._conversation_id = None
246305
self._last_interrupt_id = 0
@@ -273,8 +332,68 @@ def wait_for_session_end(self) -> Optional[str]:
273332
self._thread.join()
274333
return self._conversation_id
275334

335+
def send_user_message(self, text: str) -> None:
    """Send a text message from the user to the agent.

    Args:
        text: The text message to send to the agent.

    Raises:
        RuntimeError: If the session is not active or websocket is not connected.
        Exception: Any error raised while serializing or sending is printed
            and then re-raised unchanged.
    """
    # Read the connection exactly once. `_run` resets `self._ws` to None when
    # it finishes (presumably on the session thread), so checking the
    # attribute and then reading it again could see None on the second read
    # and fail with AttributeError instead of the documented RuntimeError.
    ws = self._ws
    if not ws:
        raise RuntimeError("Session not started or websocket not connected.")

    event = UserMessageClientToOrchestratorEvent(text=text)
    try:
        ws.send(json.dumps(event.to_dict()))
    except Exception as e:
        print(f"Error sending user message: {e}")
        raise
353+
354+
def register_user_activity(self) -> None:
    """Register user activity to prevent session timeout.

    This sends a ping to the orchestrator to reset the timeout timer.

    Raises:
        RuntimeError: If the session is not active or websocket is not connected.
        Exception: Any error raised while serializing or sending is printed
            and then re-raised unchanged.
    """
    # Read the connection exactly once. `_run` resets `self._ws` to None when
    # it finishes (presumably on the session thread), so checking the
    # attribute and then reading it again could see None on the second read
    # and fail with AttributeError instead of the documented RuntimeError.
    ws = self._ws
    if not ws:
        raise RuntimeError("Session not started or websocket not connected.")

    event = UserActivityClientToOrchestratorEvent()
    try:
        ws.send(json.dumps(event.to_dict()))
    except Exception as e:
        print(f"Error registering user activity: {e}")
        raise
371+
372+
def send_contextual_update(self, content: str) -> None:
    """Send a contextual update to the conversation.

    Contextual updates are non-interrupting content that is sent to the server
    to update the conversation state without directly prompting the agent.

    Args:
        content: The contextual information to send to the conversation.

    Raises:
        RuntimeError: If the session is not active or websocket is not connected.
        Exception: Any error raised while serializing or sending is printed
            and then re-raised unchanged.
    """
    # Read the connection exactly once. `_run` resets `self._ws` to None when
    # it finishes (presumably on the session thread), so checking the
    # attribute and then reading it again could see None on the second read
    # and fail with AttributeError instead of the documented RuntimeError.
    ws = self._ws
    if not ws:
        raise RuntimeError("Session not started or websocket not connected.")

    event = ContextualUpdateClientToOrchestratorEvent(content=content)
    try:
        ws.send(json.dumps(event.to_dict()))
    except Exception as e:
        print(f"Error sending contextual update: {e}")
        raise
393+
276394
def _run(self, ws_url: str):
277395
with connect(ws_url, max_size=16 * 1024 * 1024) as ws:
396+
self._ws = ws
278397
ws.send(
279398
json.dumps(
280399
{
@@ -316,6 +435,8 @@ def input_callback(audio):
316435
except Exception as e:
317436
print(f"Error receiving message: {e}")
318437
self.end_session()
438+
439+
self._ws = None
319440

320441
def _handle_message(self, message, ws):
321442
if message["type"] == "conversation_initiation_metadata":
@@ -372,16 +493,6 @@ def send_response(response):
372493
else:
373494
pass # Ignore all other message types.
374495

375-
def send_contextual_update(self, text: str):
376-
if not self._ws:
377-
raise RuntimeError("WebSocket is not connected")
378-
379-
payload = {
380-
"type": "contextual_update",
381-
"text": text,
382-
}
383-
self._ws.send(json.dumps(payload))
384-
385496
def _get_wss_url(self):
386497
base_ws_url = self.client._client_wrapper.get_environment().wss
387498
return f"{base_ws_url}/v1/convai/conversation?agent_id={self.agent_id}"

tests/test_convai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,5 +191,5 @@ def test_conversation_with_contextual_update():
191191
conversation.wait_for_session_end()
192192

193193
# Assertions
194-
expected = json.dumps({"type": "contextual_update", "text": "User appears to be looking at pricing page"})
194+
expected = json.dumps({"type": "contextual_update", "content": "User appears to be looking at pricing page"})
195195
mock_ws.send.assert_any_call(expected)

0 commit comments

Comments
 (0)