 from .mel_processing import spectrogram_torch, spectrogram_torch_conv
 from .download_utils import load_or_download_config, load_or_download_model
 
+
+import io
+import base64
+from datetime import datetime
+
+from fastapi.encoders import jsonable_encoder
+
+
 class TTS(nn.Module):
     def __init__(self,
                  language,
@@ -133,3 +143,68 @@ def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_s |
             soundfile.write(output_path, audio, self.hps.data.sampling_rate, format=format)
         else:
             soundfile.write(output_path, audio, self.hps.data.sampling_rate)
+
+    def tts_to_base64(self, text, speaker_id, sdp_ratio=0.2, noise_scale=0.6,
+                      noise_scale_w=0.8, speed=1.0, pbar=None, position=None, quiet=False):
+        # Time each call so the response can report how long this synthesis took
+        start_time = datetime.now()
+        language = self.language
+        texts = self.split_sentences_into_pieces(text, language, quiet)
+        audio_list = []
+        if pbar:
+            tx = pbar(texts)
+        elif position:
+            tx = tqdm(texts, position=position)
+        elif quiet:
+            tx = texts
+        else:
+            tx = tqdm(texts)
+        for t in tx:
+            if language in ['EN', 'ZH_MIX_EN']:
+                # Split camelCase runs so grapheme-to-phoneme sees separate words
+                t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
+            device = self.device
+            bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(
+                t, language, self.hps, device, self.symbol_to_id)
+            with torch.no_grad():
+                x_tst = phones.to(device).unsqueeze(0)
+                tones = tones.to(device).unsqueeze(0)
+                lang_ids = lang_ids.to(device).unsqueeze(0)
+                bert = bert.to(device).unsqueeze(0)
+                ja_bert = ja_bert.to(device).unsqueeze(0)
+                x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
+                del phones
+                speakers = torch.LongTensor([speaker_id]).to(device)
+                # Synthesize this chunk and move the waveform to CPU as float numpy
+                audio = self.model.infer(
+                    x_tst,
+                    x_tst_lengths,
+                    speakers,
+                    tones,
+                    lang_ids,
+                    bert,
+                    ja_bert,
+                    sdp_ratio=sdp_ratio,
+                    noise_scale=noise_scale,
+                    noise_scale_w=noise_scale_w,
+                    length_scale=1. / speed,
+                )[0][0, 0].data.cpu().float().numpy()
+            # Free per-chunk tensors and cached GPU memory before the next iteration
+            del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, speakers
+            audio_list.append(audio)
+            torch.cuda.empty_cache()
+        audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
+
+        # Encode the concatenated waveform as an in-memory WAV file
+        with io.BytesIO() as wav_buffer:
+            soundfile.write(wav_buffer, audio, self.hps.data.sampling_rate, format="WAV")
+            wav_buffer.seek(0)
+            wav_bytes = wav_buffer.read()
+
+        # Base64-encode the audio so it can travel inside a JSON response
+        wav_base64 = base64.b64encode(wav_bytes).decode("utf-8")
+        elapsed_time = datetime.now() - start_time
+
+        return jsonable_encoder({
+            "audio_base64": wav_base64,
+            "time_taken": elapsed_time.total_seconds(),
+        })
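
For reference, here is a minimal sketch of how the new `tts_to_base64` method might be exercised end to end. It assumes the MeloTTS-style entry points (`TTS(language=..., device='auto')` and the `spk2id` table on `hps.data`) plus an `EN-US` speaker entry; those names are illustrative, not part of this change.

```python
import base64

from melo.api import TTS  # assumed import path for this package

# Load the English model once; device='auto' picks CUDA when available
model = TTS(language="EN", device="auto")
speaker_id = model.hps.data.spk2id["EN-US"]

# Synthesize and get a JSON-safe dict back
payload = model.tts_to_base64(
    "The quick brown fox jumps over the lazy dog.",
    speaker_id,
    speed=1.0,
    quiet=True,
)

# Round-trip: decode the base64 string back into a playable WAV file
with open("out.wav", "wb") as f:
    f.write(base64.b64decode(payload["audio_base64"]))
print(f"synthesized in {payload['time_taken']:.2f} s")
```

Because the return value has already been passed through `jsonable_encoder`, a FastAPI route handler could return it directly as the response body.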