@@ -27,22 +27,22 @@

 start_time = datetime.now()

+
 class TTS(nn.Module):
-    def __init__(self,
-                 language,
-                 device='auto',
-                 use_hf=True,
-                 config_path=None,
-                 ckpt_path=None):
+    def __init__(
+        self, language, device="auto", use_hf=True, config_path=None, ckpt_path=None
+    ):
         super().__init__()
-        if device == 'auto':
-            device = 'cpu'
-            if torch.cuda.is_available(): device = 'cuda'
-            if torch.backends.mps.is_available(): device = 'mps'
-        if 'cuda' in device:
+        if device == "auto":
+            device = "cpu"
+            if torch.cuda.is_available():
+                device = "cuda"
+            if torch.backends.mps.is_available():
+                device = "mps"
+        if "cuda" in device:
             assert torch.cuda.is_available()

-        # config_path =
+        # config_path =
         hps = load_or_download_config(language, use_hf=use_hf, config_path=config_path)

         num_languages = hps.num_languages
@@ -64,16 +64,20 @@ def __init__(self,
         self.symbol_to_id = {s: i for i, s in enumerate(symbols)}
         self.hps = hps
         self.device = device
-
+
         # load state_dict
-        checkpoint_dict = load_or_download_model(language, device, use_hf=use_hf, ckpt_path=ckpt_path)
-        self.model.load_state_dict(checkpoint_dict['model'], strict=True)
-
-        language = language.split('_')[0]
-        self.language = 'ZH_MIX_EN' if language == 'ZH' else language # we support a ZH_MIX_EN model
+        checkpoint_dict = load_or_download_model(
+            language, device, use_hf=use_hf, ckpt_path=ckpt_path
+        )
+        self.model.load_state_dict(checkpoint_dict["model"], strict=True)
+
+        language = language.split("_")[0]
+        self.language = (
+            "ZH_MIX_EN" if language == "ZH" else language
+        )  # we support a ZH_MIX_EN model

     @staticmethod
-    def audio_numpy_concat(segment_data_list, sr, speed=1.):
+    def audio_numpy_concat(segment_data_list, sr, speed=1.0):
         audio_segments = []
         for segment_data in segment_data_list:
             audio_segments += segment_data.reshape(-1).tolist()
@@ -86,11 +90,24 @@ def split_sentences_into_pieces(text, language, quiet=False):
         texts = split_sentence(text, language_str=language)
         if not quiet:
             print(" > Text split to sentences.")
-            print('\n'.join(texts))
+            print("\n".join(texts))
             print(" > ===========================")
         return texts

-    def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8, speed=1.0, pbar=None, format=None, position=None, quiet=False,):
+    def tts_to_file(
+        self,
+        text,
+        speaker_id,
+        output_path=None,
+        sdp_ratio=0.2,
+        noise_scale=0.6,
+        noise_scale_w=0.8,
+        speed=1.0,
+        pbar=None,
+        format=None,
+        position=None,
+        quiet=False,
+    ):
         language = self.language
         texts = self.split_sentences_into_pieces(text, language, quiet)
         audio_list = []
@@ -104,10 +121,12 @@ def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_s
         else:
             tx = tqdm(texts)
         for t in tx:
-            if language in ['EN', 'ZH_MIX_EN']:
-                t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
+            if language in ["EN", "ZH_MIX_EN"]:
+                t = re.sub(r"([a-z])([A-Z])", r"\1 \2", t)
             device = self.device
-            bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(t, language, self.hps, device, self.symbol_to_id)
+            bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(
+                t, language, self.hps, device, self.symbol_to_id
+            )
             with torch.no_grad():
                 x_tst = phones.to(device).unsqueeze(0)
                 tones = tones.to(device).unsqueeze(0)
@@ -117,7 +136,8 @@ def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_s
                 x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
                 del phones
                 speakers = torch.LongTensor([speaker_id]).to(device)
-                audio = self.model.infer(
+                audio = (
+                    self.model.infer(
                         x_tst,
                         x_tst_lengths,
                         speakers,
@@ -128,26 +148,43 @@ def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_s
                         sdp_ratio=sdp_ratio,
                         noise_scale=noise_scale,
                         noise_scale_w=noise_scale_w,
-                        length_scale=1. / speed,
-                    )[0][0, 0].data.cpu().float().numpy()
+                        length_scale=1.0 / speed,
+                    )[0][0, 0]
+                    .data.cpu()
+                    .float()
+                    .numpy()
+                )
                 del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, speakers
-                #
+                #
                 audio_list.append(audio)
             torch.cuda.empty_cache()
-        audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
+        audio = self.audio_numpy_concat(
+            audio_list, sr=self.hps.data.sampling_rate, speed=speed
+        )

         if output_path is None:
             return audio
         else:
             if format:
-                soundfile.write(output_path, audio, self.hps.data.sampling_rate, format=format)
+                soundfile.write(
+                    output_path, audio, self.hps.data.sampling_rate, format=format
+                )
             else:
                 soundfile.write(output_path, audio, self.hps.data.sampling_rate)

-
-
-
-    def tts_to_base64(self, text, speaker_id, sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8, speed=1.0, pbar=None, format=None, position=None, quiet=False,):
+    def old_tts_to_base64(
+        self,
+        text,
+        speaker_id,
+        sdp_ratio=0.2,
+        noise_scale=0.6,
+        noise_scale_w=0.8,
+        speed=1.0,
+        pbar=None,
+        format=None,
+        position=None,
+        quiet=False,
+    ):
         language = self.language
         texts = self.split_sentences_into_pieces(text, language, quiet)
         audio_list = []
@@ -161,10 +198,12 @@ def tts_to_base64(self, text, speaker_id, sdp_ratio=0.2, noise_scale=0.6, noise_
         else:
             tx = tqdm(texts)
         for t in tx:
-            if language in ['EN', 'ZH_MIX_EN']:
-                t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
+            if language in ["EN", "ZH_MIX_EN"]:
+                t = re.sub(r"([a-z])([A-Z])", r"\1 \2", t)
             device = self.device
-            bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(t, language, self.hps, device, self.symbol_to_id)
+            bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(
+                t, language, self.hps, device, self.symbol_to_id
+            )
             with torch.no_grad():
                 x_tst = phones.to(device).unsqueeze(0)
                 tones = tones.to(device).unsqueeze(0)
@@ -174,7 +213,8 @@ def tts_to_base64(self, text, speaker_id, sdp_ratio=0.2, noise_scale=0.6, noise_
                 x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
                 del phones
                 speakers = torch.LongTensor([speaker_id]).to(device)
-                audio = self.model.infer(
+                audio = (
+                    self.model.infer(
                         x_tst,
                         x_tst_lengths,
                         speakers,
@@ -185,26 +225,149 @@ def tts_to_base64(self, text, speaker_id, sdp_ratio=0.2, noise_scale=0.6, noise_
                         sdp_ratio=sdp_ratio,
                         noise_scale=noise_scale,
                         noise_scale_w=noise_scale_w,
-                        length_scale=1. / speed,
-                    )[0][0, 0].data.cpu().float().numpy()
+                        length_scale=1.0 / speed,
+                    )[0][0, 0]
+                    .data.cpu()
+                    .float()
+                    .numpy()
+                )
                 del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, speakers
-                #
+                #
                 audio_list.append(audio)
             torch.cuda.empty_cache()
-        audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
+        audio = self.audio_numpy_concat(
+            audio_list, sr=self.hps.data.sampling_rate, speed=speed
+        )

         with io.BytesIO() as wav_buffer:
-            soundfile.write(wav_buffer, audio, self.hps.data.sampling_rate, format="WAV")
+            soundfile.write(
+                wav_buffer, audio, self.hps.data.sampling_rate, format="WAV"
+            )
             wav_buffer.seek(0)
             wav_bytes = wav_buffer.read()

-
         wav_base64 = base64.b64encode(wav_bytes).decode("utf-8")
         end_time = datetime.now()
         elapsed_time = end_time - start_time

-        return jsonable_encoder({
-            "audioContent": wav_base64,
-            "time_taken": elapsed_time
-        })
+        return jsonable_encoder(
+            {"audioContent": wav_base64, "time_taken": elapsed_time}
+        )
+
+    def tts_iter(
+        self,
+        text,
+        speaker_id,
+        sdp_ratio=0.2,
+        noise_scale=0.6,
+        noise_scale_w=0.8,
+        speed=1.0,
+        pbar=None,
+        position=None,
+        quiet=False,
+    ):
+        """
+        https://github.com/myshell-ai/MeloTTS/pull/88/files
+        """
+        language = self.language
+        texts = self.split_sentences_into_pieces(text, language, quiet)
+
+        if pbar:
+            tx = pbar(texts)
+        else:
+            if position:
+                tx = tqdm(texts, position=position)
+            elif quiet:
+                tx = texts
+            else:
+                tx = tqdm(texts)
+        for t in tx:
+            if language in ["EN", "ZH_MIX_EN"]:
+                t = re.sub(r"([a-z])([A-Z])", r"\1 \2", t)
+            device = self.device
+            bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(
+                t, language, self.hps, device, self.symbol_to_id
+            )
+            with torch.no_grad():
+                x_tst = phones.to(device).unsqueeze(0)
+                tones = tones.to(device).unsqueeze(0)
+                lang_ids = lang_ids.to(device).unsqueeze(0)
+                bert = bert.to(device).unsqueeze(0)
+                ja_bert = ja_bert.to(device).unsqueeze(0)
+                x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
+                del phones
+                speakers = torch.LongTensor([speaker_id]).to(device)
+                audio = (
+                    self.model.infer(
+                        x_tst,
+                        x_tst_lengths,
+                        speakers,
+                        tones,
+                        lang_ids,
+                        bert,
+                        ja_bert,
+                        sdp_ratio=sdp_ratio,
+                        noise_scale=noise_scale,
+                        noise_scale_w=noise_scale_w,
+                        length_scale=1.0 / speed,
+                    )[0][0, 0]
+                    .data.cpu()
+                    .float()
+                    .numpy()
+                )
+                del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, speakers
+
+                audio_segments = []
+                audio_segments += audio.reshape(-1).tolist()
+                audio_segments += [0] * int(
+                    (self.hps.data.sampling_rate * 0.05) / speed
+                )
+                audio_segments = np.array(audio_segments).astype(np.float32)
+
+                yield audio_segments
+
+            torch.cuda.empty_cache()
+    def tts_to_base64(
+        self,
+        text,
+        speaker_id,
+        sdp_ratio=0.2,
+        noise_scale=0.6,
+        noise_scale_w=0.8,
+        speed=1.0,
+        pbar=None,
+        format=None,
+        position=None,
+        quiet=False,
+    ):
+        audio_list = []
+        for audio in self.tts_iter(
+            text,
+            speaker_id,
+            sdp_ratio,
+            noise_scale,
+            noise_scale_w,
+            speed,
+            pbar,
+            position,
+            quiet,
+        ):
+            audio_list.append(audio)
+
+        audio = np.concatenate(audio_list)
+
+        with io.BytesIO() as wav_buffer:
+            soundfile.write(
+                wav_buffer, audio, self.hps.data.sampling_rate, format="WAV"
+            )
+            wav_buffer.seek(0)
+            wav_bytes = wav_buffer.read()
+
+        wav_base64 = base64.b64encode(wav_bytes).decode("utf-8")
+        end_time = datetime.now()
+        elapsed_time = end_time - start_time

+        return jsonable_encoder(
+            {"audioContent": wav_base64, "time_taken": elapsed_time}
+        )
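
For reference, a minimal sketch of how the new tts_iter generator and the rewritten tts_to_base64 could be consumed. This is not part of the commit: it assumes the class is exposed as melo.api.TTS (as in upstream MeloTTS) and that an "EN" model with an "EN-US" entry in spk2id is installed; those names are assumptions.

# Sketch only, not from this diff. Assumes melo.api.TTS plus an installed
# 'EN' model whose spk2id mapping contains 'EN-US'.
import base64

import numpy as np
import soundfile

from melo.api import TTS

model = TTS(language="EN", device="auto")
speaker_id = model.hps.data.spk2id["EN-US"]  # assumed speaker key

# tts_iter yields one sentence at a time as a float32 numpy array at
# model.hps.data.sampling_rate, padded with ~50 ms of silence, so a caller
# can play or transmit each chunk before the whole utterance is rendered.
chunks = list(model.tts_iter("Hello there. This streams per sentence.", speaker_id))
soundfile.write("streamed.wav", np.concatenate(chunks), model.hps.data.sampling_rate)

# The new tts_to_base64 drains the same generator and returns a dict whose
# "audioContent" field holds base64-encoded WAV bytes.
payload = model.tts_to_base64("Hello again.", speaker_id)
with open("from_base64.wav", "wb") as f:
    f.write(base64.b64decode(payload["audioContent"]))

Collecting every chunk into a list, as tts_to_base64 now does, keeps the base64 path byte-identical to concatenating segments up front, while streaming callers can instead act on each yielded chunk immediately.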