@@ -97,36 +97,40 @@ def get_text() -> typing.Iterator[str]:
9797 )
9898 )
9999 ) as socket :
100- socket .send (json .dumps (
101- dict (
102- text = " " ,
103- try_trigger_generation = True ,
104- voice_settings = voice_settings .dict () if voice_settings else None ,
105- generation_config = dict (
106- chunk_length_schedule = [50 ],
107- ),
108- )
109- ))
110-
111- for text_chunk in text_chunker (text ):
112- data = dict (text = text_chunk , try_trigger_generation = True )
113- socket .send (json .dumps (data ))
114- try :
115- data = json .loads (socket .recv (1e-4 ))
100+ try :
101+ socket .send (json .dumps (
102+ dict (
103+ text = " " ,
104+ try_trigger_generation = True ,
105+ voice_settings = voice_settings .dict () if voice_settings else None ,
106+ generation_config = dict (
107+ chunk_length_schedule = [50 ],
108+ ),
109+ )
110+ ))
111+ except websockets .exceptions .ConnectionClosedError as ce :
112+ raise ApiError (body = ce .reason , status_code = ce .code )
113+
114+ try :
115+ for text_chunk in text_chunker (text ):
116+ data = dict (text = text_chunk , try_trigger_generation = True )
117+ socket .send (json .dumps (data ))
118+ try :
119+ data = json .loads (socket .recv (1e-4 ))
120+ if "audio" in data and data ["audio" ]:
121+ yield base64 .b64decode (data ["audio" ]) # type: ignore
122+ except TimeoutError :
123+ pass
124+
125+ socket .send (json .dumps (dict (text = "" )))
126+
127+ while True :
128+
129+ data = json .loads (socket .recv ())
116130 if "audio" in data and data ["audio" ]:
117131 yield base64 .b64decode (data ["audio" ]) # type: ignore
118- except TimeoutError :
119- pass
120-
121- socket .send (json .dumps (dict (text = "" )))
122-
123- while True :
124- try :
125- data = json .loads (socket .recv ())
126- if "audio" in data and data ["audio" ]:
127- yield base64 .b64decode (data ["audio" ]) # type: ignore
128- except websockets .exceptions .ConnectionClosed :
129- if "message" in data :
130- raise ApiError (body = data )
131- break
132-
132+ except websockets .exceptions .ConnectionClosed as ce :
133+ if "message" in data :
134+ raise ApiError (body = data , status_code = ce .code )
135+ elif ce .code != 1000 :
136+ raise ApiError (body = ce .reason , status_code = ce .code )
0 commit comments