diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml index 51cfc09bb5..f7ffc9c511 100644 --- a/.github/workflows/python.yaml +++ b/.github/workflows/python.yaml @@ -140,7 +140,7 @@ jobs: conda install -c conda-forge "ffmpeg<7" pip install "mlx>=0.22.0" pip install mlx-lm - pip install "mlx-vlm<0.2.0" + pip install "mlx-vlm>=0.3.4" pip install mlx-whisper pip install f5-tts-mlx pip install qwen-vl-utils!=0.0.9 diff --git a/doc/source/locale/zh_CN/LC_MESSAGES/models/model_abilities/audio.po b/doc/source/locale/zh_CN/LC_MESSAGES/models/model_abilities/audio.po index 85bc8ee34f..f7e3e0b4eb 100644 --- a/doc/source/locale/zh_CN/LC_MESSAGES/models/model_abilities/audio.po +++ b/doc/source/locale/zh_CN/LC_MESSAGES/models/model_abilities/audio.po @@ -8,7 +8,7 @@ msgid "" msgstr "" "Project-Id-Version: Xinference \n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2025-09-22 11:25+0800\n" +"POT-Creation-Date: 2025-11-10 11:08+0800\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language: zh_CN\n" @@ -701,9 +701,9 @@ msgid "" "``False`` , and setting it to ``True`` enables randomness:" msgstr "" "可以省略情绪参考音频,转而提供一个包含8个浮点数的列表,按以下顺序指定每种" -"情绪的强度: ``[快乐, 愤怒, 悲伤, 恐惧, 厌恶, 忧郁, 惊讶, 平静]`` 。您还可以" -"使用 ``use_random`` 参数在推理过程中引入随机性情绪;默认值为 ``False`` ,设置为 ``" -"True`` 即可启用随机性情绪。" +"情绪的强度: ``[快乐, 愤怒, 悲伤, 恐惧, 厌恶, 忧郁, 惊讶, 平静]`` 。您还" +"可以使用 ``use_random`` 参数在推理过程中引入随机性情绪;默认值为 ``False`" +"` ,设置为 ``True`` 即可启用随机性情绪。" #: ../../source/models/model_abilities/audio.rst:712 msgid "" @@ -714,10 +714,10 @@ msgid "" "for more natural sounding speech. You can introduce randomness with " "``use_random`` (default: ``False``; ``True`` enables randomness):" msgstr "" -"或者,您可以启用 ``use_emo_text`` 功能,根据您提供的 ``text`` 脚本引导情感" -"表达。您的文本脚本将自动转换为情感向量。使用文本情感模式时,建议将 ``emo_" -"alpha`` 设置为 0.6 左右(或更低),以获得更自然的语音效果。您可通过 ``use_" -"random`` 引入随机性(默认值:``False`` ;``True`` 启用随机性):" +"或者,您可以启用 ``use_emo_text`` 功能,根据您提供的 ``text`` 脚本引导" +"情感表达。您的文本脚本将自动转换为情感向量。使用文本情感模式时,建议将 ``" +"emo_alpha`` 设置为 0.6 左右(或更低),以获得更自然的语音效果。您可通过 `" +"`use_random`` 引入随机性(默认值:``False`` ;``True`` 启用随机性):" #: ../../source/models/model_abilities/audio.rst:737 msgid "" @@ -729,6 +729,86 @@ msgstr "" "您也可以通过 ``emo_text`` 参数直接提供特定的文本情绪描述。您的情绪文本将" "自动转换为情绪向量。这使您能够分别控制文本脚本和文本情绪描述:" +#: ../../source/models/model_abilities/audio.rst:761 +msgid "IndexTTS2 Offline Usage" +msgstr "IndexTTS2 离线使用" + +#: ../../source/models/model_abilities/audio.rst:763 +msgid "" +"IndexTTS2 requires several small models that are downloaded automatically" +" during initialization. For offline environments, you can download these " +"models to a single directory and specify the directory path." 
+msgstr "" +"IndexTTS2需要多个小型模型,这些模型会在初始化过程中自动下载。在离线环境中" +",您可以将这些模型下载到单一目录,并指定该目录路径。" + +#: ../../source/models/model_abilities/audio.rst:766 +msgid "**Easy Setup Method**" +msgstr "**简易设置方法**" + +#: ../../source/models/model_abilities/audio.rst:768 +msgid "" +"The simplest way to set up offline usage is to Use the `hf download` " +"command to download the small model in advance:" +msgstr "设置离线使用的最简单方法是使用: `hf download` 命令去提前下载所有小模型" + +#: ../../source/models/model_abilities/audio.rst:781 +msgid "The final directory structure should look like this:" +msgstr "最终的目录结构应如下所示:" + +#: ../../source/models/model_abilities/audio.rst:791 +msgid "**Required Models**" +msgstr "**支持的模型列表**" + +#: ../../source/models/model_abilities/audio.rst:793 +msgid "The small models are automatically mapped as follows:" +msgstr "小型模型将按以下方式自动映射:" + +#: ../../source/models/model_abilities/audio.rst:795 +msgid "" +"**w2v-bert-2.0** (``models--facebook--w2v-bert-2.0``) - Feature " +"extraction model" +msgstr "**w2v-bert-2.0** (``models--facebook--w2v-bert-2.0``) - 特征提取模型" + +#: ../../source/models/model_abilities/audio.rst:796 +msgid "**campplus** (``models--funasr--campplus``) - Speaker recognition model" +msgstr "**campplus** (``models--funasr--campplus``) - 说话人识别模型" + +#: ../../source/models/model_abilities/audio.rst:797 +msgid "" +"**bigvgan** (``models--nvidia--bigvgan_v2_22khz_80band_256x``) - Vocoder " +"model" +msgstr "" +"**bigvgan** (``models--nvidia--bigvgan_v2_22khz_80band_256x``) - 语音" +"编码器模型" + +#: ../../source/models/model_abilities/audio.rst:798 +msgid "" +"**semantic_codec** (``models--amphion--MaskGCT``) - Semantic " +"encoding/decoding model" +msgstr "**语义编解码器** (``models--amphion--MaskGCT``) - 语义编码/解码模型" + +#: ../../source/models/model_abilities/audio.rst:800 +msgid "**Launching IndexTTS2 with Offline Models**" +msgstr "**使用离线模式启动IndexTTS2**" + +#: ../../source/models/model_abilities/audio.rst:802 +msgid "" +"When launching IndexTTS2 with Web UI, you can add an additional " +"parameter: - ``small_models_dir`` - Path to directory containing all " +"small models" +msgstr "" +"在通过Web UI启动IndexTTS2时,可添加额外参数:- ``small_models_dir`` - " +"包含所有小型模型的目录路径" + +#: ../../source/models/model_abilities/audio.rst:805 +msgid "When launching with command line, you can add the option:" +msgstr "在通过命令行启动时,您可以添加以下选项:" + +#: ../../source/models/model_abilities/audio.rst:812 +msgid "When launching with Python client:" +msgstr "使用 Python 客户端启动时:" + #~ msgid "**random sampling**" #~ msgstr "" @@ -755,3 +835,58 @@ msgstr "" #~ "`False`; `True` enables randomness):" #~ msgstr "" +#~ msgid "" +#~ "The required small models are: 1. " +#~ "**w2v-bert-2.0** - Feature extraction model" +#~ " (place in ``w2v-bert-2.0/`` subdirectory)" +#~ " 2. **semantic_codec** - Semantic " +#~ "encoding/decoding model (place in " +#~ "``semantic_codec/`` subdirectory) 3. **campplus**" +#~ " - Speaker recognition model (place " +#~ "in ``campplus/`` subdirectory) 4. **bigvgan**" +#~ " - Vocoder model (place in " +#~ "``bigvgan/`` subdirectory)" +#~ msgstr "" +#~ "所需的小型模型包括:1. **w2v-" +#~ "bert-2.0** - 特征提取模型(放置于" +#~ "``w2v-bert-2.0/``子目录)2. " +#~ "**semantic_codec** - 语义编码/解码" +#~ "模型(放置于``semantic_codec/``" +#~ "子目录)3. **campplus** - 说话" +#~ "人识别模型(放置于``campplus/``" +#~ "子目录) 4. 
**bigvgan** - 声" +#~ "码器模型(放置于``bigvgan/``子目录" +#~ ")" + +#~ msgid "" +#~ "Assume downloaded to ``/path/to/small_models`` " +#~ "with the following structure:" +#~ msgstr "假设下载到``/path/to/small_models``目录,其结构如下:" + +#~ msgid "" +#~ "**Find your Hugging Face cache " +#~ "directory** (usually ``~/.cache/huggingface/hub/``)" +#~ msgstr "" +#~ "**查找您的Hugging Face缓存目录** " +#~ "(通常位于 ``~/.cache/huggingface/" +#~ "hub/`` )" + +#~ msgid "**Copy the required models** to your target directory:" +#~ msgstr "**将所需模型** 复制到目标目录:" + +#~ msgid "**Note about Directory Structure**" +#~ msgstr "**关于目录结构的说明**" + +#~ msgid "" +#~ "The ``snapshots/`` directories contain " +#~ "version-specific model files with hash " +#~ "names. Xinference automatically detects and" +#~ " uses the correct snapshot directory, " +#~ "so you don't need to worry about" +#~ " the exact hash values." +#~ msgstr "" +#~ "``snapshots/`` 目录包含具有哈希名称" +#~ "的特定版本模型文件。Xinference会自动检测并" +#~ "使用正确的快照目录,因此您无需担心精确" +#~ "的哈希值。" + diff --git a/doc/source/models/model_abilities/audio.rst b/doc/source/models/model_abilities/audio.rst index 56e289fd98..92aa9470cd 100644 --- a/doc/source/models/model_abilities/audio.rst +++ b/doc/source/models/model_abilities/audio.rst @@ -757,5 +757,67 @@ Here are several examples of how to use IndexTTS2: use_random=False ) +IndexTTS2 Offline Usage +~~~~~~~~~~~~~~~~~~~~~~~~ + +IndexTTS2 requires several small models that are downloaded automatically during initialization. +For offline environments, you can download these models to a single directory and specify the directory path. + +**Easy Setup Method** + +The simplest way to set up offline usage is to Use the `hf download` command to download the small model in advance: + +.. code-block:: bash + + # Create your local models directory + mkdir -p /path/to/small_models + + # Download models from Hugging Face + hf download facebook/w2v-bert-2.0 --local-dir /path/to/small_models/w2v-bert-2.0 + hf download funasr/campplus --local-dir /path/to/small_models/campplus + hf download nvidia/bigvgan_v2_22khz_80band_256x --local-dir /path/to/small_models/bigvgan + hf download amphion/MaskGCT --local-dir /path/to/small_models/MaskGCT + +The final directory structure should look like this: + +.. code-block:: text + + /path/to/small_models/ + ├── w2v-bert-2.0/ # Feature extraction model + ├── campplus/ # Speaker recognition model + ├── bigvgan/ # Vocoder model + └── MaskGCT/ # Semantic codec model + +**Required Models** + +The small models are automatically mapped as follows: + +1. **w2v-bert-2.0** (``models--facebook--w2v-bert-2.0``) - Feature extraction model +2. **campplus** (``models--funasr--campplus``) - Speaker recognition model +3. **bigvgan** (``models--nvidia--bigvgan_v2_22khz_80band_256x``) - Vocoder model +4. **semantic_codec** (``models--amphion--MaskGCT``) - Semantic encoding/decoding model + +**Launching IndexTTS2 with Offline Models** + +When launching IndexTTS2 with Web UI, you can add an additional parameter: +- ``small_models_dir`` - Path to directory containing all small models + +When launching with command line, you can add the option: + +.. code-block:: bash + + xinference launch --model-name IndexTTS2 --model-type audio \ + --small_models_dir /path/to/small_models + +When launching with Python client: + +.. 
code-block:: python + + model_uid = client.launch_model( + model_name="IndexTTS2", + model_type="audio", + small_models_dir="/path/to/small_models" + ) + diff --git a/setup.cfg b/setup.cfg index 504a13d39e..bd3d79cd39 100644 --- a/setup.cfg +++ b/setup.cfg @@ -132,7 +132,7 @@ sglang = sglang[srt]>=0.4.2.post4 ; sys_platform=='linux' mlx = mlx-lm>=0.21.5 ; sys_platform=='darwin' and platform_machine=='arm64' - mlx-vlm>=0.1.11,<0.2.0 ; sys_platform=='darwin' and platform_machine=='arm64' + mlx-vlm>=0.3.4 ; sys_platform=='darwin' and platform_machine=='arm64' mlx-whisper ; sys_platform=='darwin' and platform_machine=='arm64' f5-tts-mlx ; sys_platform=='darwin' and platform_machine=='arm64' mlx-audio ; sys_platform=='darwin' and platform_machine=='arm64' diff --git a/xinference/model/audio/indextts2.py b/xinference/model/audio/indextts2.py index 30390a1532..de3ae17865 100644 --- a/xinference/model/audio/indextts2.py +++ b/xinference/model/audio/indextts2.py @@ -56,13 +56,25 @@ def load(self): use_fp16 = self._kwargs.get("use_fp16", False) use_deepspeed = self._kwargs.get("use_deepspeed", False) - logger.info("Loading IndexTTS2 model...") + # Handle small model directory for offline deployment + small_models_config = ( + self._model_spec.default_model_config + if getattr(self._model_spec, "default_model_config", None) + else {} + ) + small_models_config.update(self._kwargs) + + small_models_dir = small_models_config.get("small_models_dir") + logger.info( + f"Loading IndexTTS2 model... (small_models_dir: {small_models_dir})" + ) self._model = IndexTTS2( cfg_path=config_path, model_dir=self._model_path, use_fp16=use_fp16, device=self._device, use_deepspeed=use_deepspeed, + small_models_dir=small_models_dir, ) def speech( diff --git a/xinference/model/audio/model_spec.json b/xinference/model/audio/model_spec.json index 1f17aefbfc..a359f999f9 100644 --- a/xinference/model/audio/model_spec.json +++ b/xinference/model/audio/model_spec.json @@ -992,6 +992,9 @@ "text2audio_emotion_control" ], "multilingual": true, + "default_model_config": { + "small_models_dir": null + }, "virtualenv": { "packages": [ "transformers==4.52.1", diff --git a/xinference/thirdparty/indextts/infer_v2.py b/xinference/thirdparty/indextts/infer_v2.py index 5058f526fc..5482fe9ba3 100644 --- a/xinference/thirdparty/indextts/infer_v2.py +++ b/xinference/thirdparty/indextts/infer_v2.py @@ -1,7 +1,9 @@ import os from subprocess import CalledProcessError -os.environ["HF_HUB_CACHE"] = "./checkpoints/hf_cache" +# Set HF_HUB_CACHE only if not already set (allow custom cache directory) +if "HF_HUB_CACHE" not in os.environ: + os.environ["HF_HUB_CACHE"] = "./checkpoints/hf_cache" import json import re import time @@ -43,6 +45,7 @@ def __init__( device=None, use_cuda_kernel=None, use_deepspeed=False, + small_models_dir=None, ): """ Args: @@ -52,7 +55,96 @@ def __init__( device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS. use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device. use_deepspeed (bool): whether to use DeepSpeed or not. + small_models_dir (str): path to directory containing small models for offline deployment. 
""" + + print(f">> IndexTTS2.__init__ called with small_models_dir: {small_models_dir}") + + def get_small_model_path(model_name): + """Helper function to get small model path from small_models_dir""" + if small_models_dir is not None and os.path.exists(small_models_dir): + import glob + + # Direct structure model names + direct_model_names = { + "w2v-bert-2.0": "w2v-bert-2.0", + "campplus": "campplus", + "bigvgan": "bigvgan", + "semantic_codec": None, # Special handling below + } + + # Special handling for semantic_codec + if model_name == "semantic_codec": + # Look for semantic_codec in any MaskGCT directory + for item in os.listdir(small_models_dir): + item_path = os.path.join(small_models_dir, item) + if os.path.isdir(item_path) and "MaskGCT" in item: + # New structure: direct semantic_codec path + semantic_path = os.path.join(item_path, "semantic_codec") + if os.path.exists(semantic_path): + return semantic_path + # Also try direct structure + direct_path = os.path.join(small_models_dir, "semantic_codec") + if os.path.exists(direct_path): + return direct_path + else: + # Try new direct structure first + direct_name = direct_model_names.get(model_name) + if direct_name: + direct_path = os.path.join(small_models_dir, direct_name) + if os.path.exists(direct_path): + return direct_path + + # Fallback to old HuggingFace structure for compatibility + old_model_mappings = { + "w2v-bert-2.0": "models--facebook--w2v-bert-2.0", + "campplus": "models--funasr--campplus", + "bigvgan": "models--nvidia--bigvgan_v2_22khz_80band_256x", + } + + # Try old structure + mapped_name = old_model_mappings.get(model_name) + if mapped_name: + mapped_base_path = os.path.join(small_models_dir, mapped_name) + + # Check if it's a HuggingFace cache structure with snapshots + snapshots_path = os.path.join(mapped_base_path, "snapshots") + if os.path.exists(snapshots_path): + # Find the first snapshot directory + for snapshot in os.listdir(snapshots_path): + snapshot_dir = os.path.join(snapshots_path, snapshot) + if os.path.isdir(snapshot_dir): + return snapshot_dir + + # Fallback to direct path if snapshots don't exist + if os.path.exists(mapped_base_path): + return mapped_base_path + + # Try other possibilities for compatibility + possible_patterns = [ + # Generic HuggingFace structure + os.path.join(small_models_dir, f"models--*--{model_name}"), + ] + + for pattern in possible_patterns: + if "*" in pattern: + matches = glob.glob(pattern) + for match in matches: + # Check for snapshots structure + snapshots_path = os.path.join(match, "snapshots") + if os.path.exists(snapshots_path): + for snapshot in os.listdir(snapshots_path): + snapshot_dir = os.path.join( + snapshots_path, snapshot + ) + if os.path.isdir(snapshot_dir): + return snapshot_dir + # Fallback to direct match + if os.path.exists(match): + return match + + return None + if device is not None: self.device = device self.use_fp16 = False if device == "cpu" else use_fp16 @@ -129,11 +221,22 @@ def __init__( print(f"{e!r}") self.use_cuda_kernel = False - self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained( - "facebook/w2v-bert-2.0" - ) + w2v_bert_path = get_small_model_path("w2v-bert-2.0") + print(f">> w2v_bert_path lookup result: {w2v_bert_path}") + if w2v_bert_path is not None: + self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained( + w2v_bert_path + ) + print(f">> w2v-bert model loaded from local path: {w2v_bert_path}") + else: + self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained( + "facebook/w2v-bert-2.0" + ) 
+ print(">> w2v-bert model loaded from huggingface: facebook/w2v-bert-2.0") self.semantic_model, self.semantic_mean, self.semantic_std = ( - build_semantic_model(os.path.join(self.model_dir, self.cfg.w2v_stat)) + build_semantic_model( + os.path.join(self.model_dir, self.cfg.w2v_stat), w2v_bert_path + ) ) self.semantic_model = self.semantic_model.to(self.device) self.semantic_model.eval() @@ -141,9 +244,24 @@ def __init__( self.semantic_std = self.semantic_std.to(self.device) semantic_codec = build_semantic_codec(self.cfg.semantic_codec) - semantic_code_ckpt = hf_hub_download( - "amphion/MaskGCT", filename="semantic_codec/model.safetensors" - ) + semantic_codec_path = get_small_model_path("semantic_codec") + print(f">> semantic_codec_path lookup result: {semantic_codec_path}") + if semantic_codec_path is not None: + semantic_code_ckpt = os.path.join(semantic_codec_path, "model.safetensors") + if not os.path.exists(semantic_code_ckpt): + raise FileNotFoundError( + f"semantic_codec model file not found: {semantic_code_ckpt}" + ) + print( + f">> semantic_codec model loaded from local path: {semantic_code_ckpt}" + ) + else: + semantic_code_ckpt = hf_hub_download( + "amphion/MaskGCT", + filename="semantic_codec/model.safetensors", + cache_dir=os.environ.get("HF_HUB_CACHE"), + ) + print(">> semantic_codec model loaded from huggingface: amphion/MaskGCT") safetensors.torch.load_model(semantic_codec, semantic_code_ckpt) self.semantic_codec = semantic_codec.to(self.device) self.semantic_codec.eval() @@ -167,9 +285,22 @@ def __init__( print(">> s2mel weights restored from:", s2mel_path) # load campplus_model - campplus_ckpt_path = hf_hub_download( - "funasr/campplus", filename="campplus_cn_common.bin" - ) + campplus_path = get_small_model_path("campplus") + print(f">> campplus_path lookup result: {campplus_path}") + if campplus_path is not None: + campplus_ckpt_path = os.path.join(campplus_path, "campplus_cn_common.bin") + if not os.path.exists(campplus_ckpt_path): + raise FileNotFoundError( + f"campplus model file not found: {campplus_ckpt_path}" + ) + print(f">> campplus model loaded from local path: {campplus_ckpt_path}") + else: + campplus_ckpt_path = hf_hub_download( + "funasr/campplus", + filename="campplus_cn_common.bin", + cache_dir=os.environ.get("HF_HUB_CACHE"), + ) + print(">> campplus model loaded from huggingface: funasr/campplus") campplus_model = CAMPPlus(feat_dim=80, embedding_size=192) campplus_model.load_state_dict( torch.load(campplus_ckpt_path, map_location="cpu") @@ -178,7 +309,14 @@ def __init__( self.campplus_model.eval() print(">> campplus_model weights restored from:", campplus_ckpt_path) - bigvgan_name = self.cfg.vocoder.name + bigvgan_path = get_small_model_path("bigvgan") + print(f">> bigvgan_path lookup result: {bigvgan_path}") + if bigvgan_path is not None: + bigvgan_name = bigvgan_path + print(f">> bigvgan model loaded from local path: {bigvgan_path}") + else: + bigvgan_name = self.cfg.vocoder.name + print(f">> bigvgan model loaded from default: {bigvgan_name}") self.bigvgan = bigvgan.BigVGAN.from_pretrained( bigvgan_name, use_cuda_kernel=self.use_cuda_kernel ) diff --git a/xinference/thirdparty/indextts/utils/maskgct_utils.py b/xinference/thirdparty/indextts/utils/maskgct_utils.py index 40b9cb0e15..a2d579402c 100644 --- a/xinference/thirdparty/indextts/utils/maskgct_utils.py +++ b/xinference/thirdparty/indextts/utils/maskgct_utils.py @@ -1,15 +1,18 @@ -import torch -import librosa +import time + import json5 -from huggingface_hub import hf_hub_download -from 
transformers import SeamlessM4TFeatureExtractor, Wav2Vec2BertModel -import safetensors +import librosa import numpy as np - +import safetensors +import torch +from huggingface_hub import hf_hub_download +from indextts.utils.maskgct.models.codec.amphion_codec.codec import ( + CodecDecoder, + CodecEncoder, +) from indextts.utils.maskgct.models.codec.kmeans.repcodec_model import RepCodec from indextts.utils.maskgct.models.tts.maskgct.maskgct_s2a import MaskGCT_S2A -from indextts.utils.maskgct.models.codec.amphion_codec.codec import CodecEncoder, CodecDecoder -import time +from transformers import SeamlessM4TFeatureExtractor, Wav2Vec2BertModel def _load_config(config_fn, lowercase=False): @@ -84,8 +87,13 @@ def __repr__(self): return self.__dict__.__repr__() -def build_semantic_model(path_='./models/tts/maskgct/ckpt/wav2vec2bert_stats.pt'): - semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0") +def build_semantic_model( + path_="./models/tts/maskgct/ckpt/wav2vec2bert_stats.pt", model_path=None +): + if model_path is not None: + semantic_model = Wav2Vec2BertModel.from_pretrained(model_path) + else: + semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0") semantic_model.eval() stat_mean_var = torch.load(path_) semantic_mean = stat_mean_var["mean"] @@ -116,18 +124,18 @@ def build_acoustic_codec(cfg, device): return codec_encoder, codec_decoder -class Inference_Pipeline(): +class Inference_Pipeline: def __init__( - self, - semantic_model, - semantic_codec, - semantic_mean, - semantic_std, - codec_encoder, - codec_decoder, - s2a_model_1layer, - s2a_model_full, - ): + self, + semantic_model, + semantic_codec, + semantic_mean, + semantic_std, + codec_encoder, + codec_decoder, + s2a_model_1layer, + s2a_model_full, + ): self.semantic_model = semantic_model self.semantic_codec = semantic_codec self.semantic_mean = semantic_mean @@ -243,7 +251,7 @@ def gt_inference( combine_semantic_code, ): speech = librosa.load(prompt_speech_path, sr=24000)[0] - ''' + """ acoustic_code = self.extract_acoustic_code( torch.tensor(speech).unsqueeze(0).to(combine_semantic_code.device) ) @@ -251,9 +259,14 @@ def gt_inference( prompt_vq_emb = self.codec_decoder.vq2emb( prompt.permute(2, 0, 1), n_quantizers=12 ) - ''' + """ - prompt_vq_emb = self.codec_encoder(torch.tensor(speech).unsqueeze(0).unsqueeze(1).to(combine_semantic_code.device)) + prompt_vq_emb = self.codec_encoder( + torch.tensor(speech) + .unsqueeze(0) + .unsqueeze(1) + .to(combine_semantic_code.device) + ) recovered_prompt_audio = self.codec_decoder(prompt_vq_emb) recovered_prompt_audio = recovered_prompt_audio[0][0].cpu().numpy() return recovered_prompt_audio
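
A minimal end-to-end sketch of the offline workflow this patch documents, assuming a Xinference
server already running at http://localhost:9997 and the four small models downloaded with
``hf download`` as shown in the new audio.rst section. The endpoint, the file paths, and the
``prompt_speech``/``use_random`` speech parameters follow the IndexTTS2 examples in audio.rst
and are placeholders rather than part of this patch:

    # Offline IndexTTS2 usage sketch. Assumptions: a running Xinference server,
    # IndexTTS2 weights available locally, and the four small models downloaded
    # to /path/to/small_models as described in the documentation above.
    from xinference.client import Client

    client = Client("http://localhost:9997")

    # `small_models_dir` is forwarded through kwargs into IndexTTS2.__init__,
    # where get_small_model_path() resolves w2v-bert-2.0, campplus, bigvgan,
    # and the MaskGCT semantic codec from the local directory before falling
    # back to the Hugging Face Hub.
    model_uid = client.launch_model(
        model_name="IndexTTS2",
        model_type="audio",
        small_models_dir="/path/to/small_models",
    )
    model = client.get_model(model_uid)

    # Voice-cloning reference audio; the path is a placeholder.
    with open("/path/to/voice_reference.wav", "rb") as f:
        prompt_speech = f.read()

    # Parameter names mirror the IndexTTS2 examples earlier in audio.rst.
    audio_bytes = model.speech(
        input="你好,欢迎使用离线部署的 IndexTTS2。",
        prompt_speech=prompt_speech,
        use_random=False,
    )
    with open("output.wav", "wb") as f:
        f.write(audio_bytes)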