Skip to content

Commit 4e450f8

Browse files
authored
(feat): Support for text iterators in AsyncElevenLabs (elevenlabs#346)
1 parent 6db2fdd commit 4e450f8

File tree

3 files changed

+180
-12
lines changed

3 files changed

+180
-12
lines changed

src/elevenlabs/client.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .types import Voice, VoiceSettings, \
1414
PronunciationDictionaryVersionLocator, Model
1515
from .environment import ElevenLabsEnvironment
16-
from .realtime_tts import RealtimeTextToSpeechClient
16+
from .realtime_tts import RealtimeTextToSpeechClient, AsyncRealtimeTextToSpeechClient
1717
from .types import OutputFormat
1818

1919

@@ -257,6 +257,25 @@ class AsyncElevenLabs(AsyncBaseElevenLabs):
257257
api_key="YOUR_API_KEY",
258258
)
259259
"""
260+
def __init__(
261+
self,
262+
*,
263+
base_url: typing.Optional[str] = None,
264+
environment: ElevenLabsEnvironment = ElevenLabsEnvironment.PRODUCTION,
265+
api_key: typing.Optional[str] = os.getenv("ELEVEN_API_KEY"),
266+
timeout: typing.Optional[float] = None,
267+
follow_redirects: typing.Optional[bool] = True,
268+
httpx_client: typing.Optional[httpx.AsyncClient] = None
269+
):
270+
super().__init__(
271+
base_url=base_url,
272+
environment=environment,
273+
api_key=api_key,
274+
timeout=timeout,
275+
follow_redirects=follow_redirects,
276+
httpx_client=httpx_client,
277+
)
278+
self.text_to_speech = AsyncRealtimeTextToSpeechClient(client_wrapper=self._client_wrapper)
260279

261280
async def clone(
262281
self,
@@ -383,16 +402,28 @@ async def generate(
383402
model_id = model.model_id
384403

385404
if stream:
386-
return self.text_to_speech.convert_as_stream(
387-
voice_id=voice_id,
388-
model_id=model_id,
389-
voice_settings=voice_settings,
390-
optimize_streaming_latency=optimize_streaming_latency,
391-
output_format=output_format,
392-
text=text,
393-
request_options=request_options,
394-
pronunciation_dictionary_locators=pronunciation_dictionary_locators
395-
)
405+
if isinstance(text, str):
406+
return self.text_to_speech.convert_as_stream(
407+
voice_id=voice_id,
408+
model_id=model_id,
409+
voice_settings=voice_settings,
410+
optimize_streaming_latency=optimize_streaming_latency,
411+
output_format=output_format,
412+
text=text,
413+
request_options=request_options,
414+
pronunciation_dictionary_locators=pronunciation_dictionary_locators
415+
)
416+
elif isinstance(text, AsyncIterator):
417+
return self.text_to_speech.convert_realtime( # type: ignore
418+
voice_id=voice_id,
419+
voice_settings=voice_settings,
420+
output_format=output_format,
421+
text=text,
422+
request_options=request_options,
423+
model_id=model_id
424+
)
425+
else:
426+
raise ApiError(body="Text is neither a string nor an iterator.")
396427
else:
397428
if not isinstance(text, str):
398429
raise ApiError(body="Text must be a string when stream is False.")

src/elevenlabs/realtime_tts.py

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,17 @@
55
import json
66
import base64
77
import websockets
8+
import asyncio
89

910
from websockets.sync.client import connect
11+
from websockets.client import connect as async_connect
1012

1113
from .core.api_error import ApiError
1214
from .core.jsonable_encoder import jsonable_encoder
1315
from .core.remove_none_from_dict import remove_none_from_dict
1416
from .core.request_options import RequestOptions
1517
from .types.voice_settings import VoiceSettings
16-
from .text_to_speech.client import TextToSpeechClient
18+
from .text_to_speech.client import TextToSpeechClient, AsyncTextToSpeechClient
1719
from .types import OutputFormat
1820

1921
# this is used as the default value for optional parameters
@@ -37,6 +39,22 @@ def text_chunker(chunks: typing.Iterator[str]) -> typing.Iterator[str]:
3739
if buffer != "":
3840
yield buffer + " "
3941

42+
async def async_text_chunker(chunks: typing.AsyncIterator[str]) -> typing.AsyncIterator[str]:
43+
"""Used during input streaming to chunk text blocks and set last char to space"""
44+
splitters = (".", ",", "?", "!", ";", ":", "—", "-", "(", ")", "[", "]", "}", " ")
45+
buffer = ""
46+
async for text in chunks:
47+
if buffer.endswith(splitters):
48+
yield buffer if buffer.endswith(" ") else buffer + " "
49+
buffer = text
50+
elif text.startswith(splitters):
51+
output = buffer + text[0]
52+
yield output if output.endswith(" ") else output + " "
53+
buffer = text[1:]
54+
else:
55+
buffer += text
56+
if buffer != "":
57+
yield buffer + " "
4058

4159
class RealtimeTextToSpeechClient(TextToSpeechClient):
4260

@@ -137,3 +155,105 @@ def get_text() -> typing.Iterator[str]:
137155
raise ApiError(body=data, status_code=ce.code)
138156
elif ce.code != 1000:
139157
raise ApiError(body=ce.reason, status_code=ce.code)
158+
159+
160+
class AsyncRealtimeTextToSpeechClient(AsyncTextToSpeechClient):
161+
162+
async def convert_realtime(
163+
self,
164+
voice_id: str,
165+
*,
166+
text: typing.AsyncIterator[str],
167+
model_id: typing.Optional[str] = OMIT,
168+
output_format: typing.Optional[OutputFormat] = "mp3_44100_128",
169+
voice_settings: typing.Optional[VoiceSettings] = OMIT,
170+
request_options: typing.Optional[RequestOptions] = None,
171+
) -> typing.AsyncIterator[bytes]:
172+
"""
173+
Converts text into speech using a voice of your choice and returns audio.
174+
175+
Parameters:
176+
- voice_id: str. Voice ID to be used, you can use https://api.elevenlabs.io/v1/voices to list all the available voices.
177+
178+
- text: typing.Iterator[str]. The text that will get converted into speech.
179+
180+
- model_id: typing.Optional[str]. Identifier of the model that will be used, you can query them using GET /v1/models. The model needs to have support for text to speech, you can check this using the can_do_text_to_speech property.
181+
182+
- voice_settings: typing.Optional[VoiceSettings]. Voice settings overriding stored setttings for the given voice. They are applied only on the given request.
183+
184+
- request_options: typing.Optional[RequestOptions]. Request-specific configuration.
185+
---
186+
from elevenlabs import PronunciationDictionaryVersionLocator, VoiceSettings
187+
from elevenlabs.client import ElevenLabs
188+
189+
def get_text() -> typing.Iterator[str]:
190+
yield "Hello, how are you?"
191+
yield "I am fine, thank you."
192+
193+
client = ElevenLabs(
194+
api_key="YOUR_API_KEY",
195+
)
196+
client.text_to_speech.convert_realtime(
197+
voice_id="string",
198+
text=get_text(),
199+
model_id="string",
200+
voice_settings=VoiceSettings(
201+
stability=1.1,
202+
similarity_boost=1.1,
203+
style=1.1,
204+
use_speaker_boost=True,
205+
),
206+
)
207+
"""
208+
async with async_connect(
209+
urllib.parse.urljoin(
210+
"wss://api.elevenlabs.io/",
211+
f"v1/text-to-speech/{jsonable_encoder(voice_id)}/stream-input?model_id={model_id}&output_format={output_format}"
212+
),
213+
extra_headers=jsonable_encoder(
214+
remove_none_from_dict(
215+
{
216+
**self._client_wrapper.get_headers(),
217+
**(request_options.get("additional_headers", {}) if request_options is not None else {}),
218+
}
219+
)
220+
)
221+
) as socket:
222+
try:
223+
await socket.send(json.dumps(
224+
dict(
225+
text=" ",
226+
try_trigger_generation=True,
227+
voice_settings=voice_settings.dict() if voice_settings else None,
228+
generation_config=dict(
229+
chunk_length_schedule=[50],
230+
),
231+
)
232+
))
233+
except websockets.exceptions.ConnectionClosedError as ce:
234+
raise ApiError(body=ce.reason, status_code=ce.code)
235+
236+
try:
237+
async for text_chunk in async_text_chunker(text):
238+
data = dict(text=text_chunk, try_trigger_generation=True)
239+
await socket.send(json.dumps(data))
240+
try:
241+
async with asyncio.timeout(1e-4):
242+
data = json.loads(await socket.recv())
243+
if "audio" in data and data["audio"]:
244+
yield base64.b64decode(data["audio"]) # type: ignore
245+
except TimeoutError:
246+
pass
247+
248+
await socket.send(json.dumps(dict(text="")))
249+
250+
while True:
251+
252+
data = json.loads(await socket.recv())
253+
if "audio" in data and data["audio"]:
254+
yield base64.b64decode(data["audio"]) # type: ignore
255+
except websockets.exceptions.ConnectionClosed as ce:
256+
if "message" in data:
257+
raise ApiError(body=data, status_code=ce.code)
258+
elif ce.code != 1000:
259+
raise ApiError(body=ce.reason, status_code=ce.code)

tests/test_async_generation.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,20 @@ async def main():
2929
if not IN_GITHUB:
3030
play(out)
3131
asyncio.run(main())
32+
33+
def test_generate_stream() -> None:
34+
async def main():
35+
async def text_stream():
36+
yield "Hi there, I'm Eleven "
37+
yield "I'm a text to speech API "
38+
39+
audio_stream = await async_client.generate(
40+
text=text_stream(),
41+
voice="Nicole",
42+
model="eleven_monolingual_v1",
43+
stream=True
44+
)
45+
46+
if not IN_GITHUB:
47+
stream(audio_stream) # type: ignore
48+
asyncio.run(main())

0 commit comments

Comments
 (0)