diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index 7ab04a7..8f6cec4 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -16,12 +16,42 @@ import time from typing import Generator from tqdm import tqdm from hyperpyyaml import load_hyperpyyaml -from modelscope import snapshot_download import torch from cosyvoice.cli.frontend import CosyVoiceFrontEnd from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model from cosyvoice.utils.file_utils import logging from cosyvoice.utils.class_utils import get_model_type +from cosyvoice.utils.device import get_preferred_device, has_accelerator, supports_trt, supports_vllm + +try: + from modelscope import snapshot_download as modelscope_snapshot_download +except ImportError: + modelscope_snapshot_download = None + + +HUGGINGFACE_MODEL_ALIASES = { + 'iic/CosyVoice2-0.5B': 'FunAudioLLM/CosyVoice2-0.5B', + 'iic/CosyVoice-300M': 'FunAudioLLM/CosyVoice-300M', + 'iic/CosyVoice-300M-SFT': 'FunAudioLLM/CosyVoice-300M-SFT', + 'iic/CosyVoice-300M-Instruct': 'FunAudioLLM/CosyVoice-300M-Instruct', + 'FunAudioLLM/Fun-CosyVoice3-0.5B': 'FunAudioLLM/Fun-CosyVoice3-0.5B-2512', +} + + +def snapshot_download(model_dir): + errors = [] + if modelscope_snapshot_download is not None: + try: + return modelscope_snapshot_download(model_dir) + except Exception as exc: + errors.append(f'modelscope: {exc}') + try: + from huggingface_hub import snapshot_download as huggingface_snapshot_download + repo_id = HUGGINGFACE_MODEL_ALIASES.get(model_dir, model_dir) + return huggingface_snapshot_download(repo_id=repo_id, local_dir_use_symlinks=False) + except Exception as exc: + errors.append(f'huggingface: {exc}') + raise RuntimeError(f'failed to download model {model_dir}: {"; ".join(errors)}') class CosyVoice: @@ -29,6 +59,7 @@ class CosyVoice: def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1): self.model_dir = model_dir self.fp16 = fp16 + self.device = get_preferred_device() if not os.path.exists(model_dir): model_dir = snapshot_download(model_dir) hyper_yaml_path = '{}/cosyvoice.yaml'.format(model_dir) @@ -44,9 +75,12 @@ class CosyVoice: '{}/spk2info.pt'.format(model_dir), configs['allowed_special']) self.sample_rate = configs['sample_rate'] - if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True): + if has_accelerator() is False and (load_jit is True or load_trt is True or fp16 is True): load_jit, load_trt, fp16 = False, False, False - logging.warning('no cuda device, set load_jit/load_trt/fp16 to False') + logging.warning('no accelerator device, set load_jit/load_trt/fp16 to False') + if supports_trt(self.device) is False and load_trt is True: + load_trt = False + logging.warning('TensorRT requires CUDA, set load_trt to False') self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16) self.model.load('{}/llm.pt'.format(model_dir), '{}/flow.pt'.format(model_dir), @@ -141,6 +175,7 @@ class CosyVoice2(CosyVoice): def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1): self.model_dir = model_dir self.fp16 = fp16 + self.device = get_preferred_device() if not os.path.exists(model_dir): model_dir = snapshot_download(model_dir) hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir) @@ -156,9 +191,15 @@ class CosyVoice2(CosyVoice): '{}/spk2info.pt'.format(model_dir), configs['allowed_special']) self.sample_rate = configs['sample_rate'] - if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or load_vllm is True or fp16 is True): + if has_accelerator() is False and (load_jit is True or load_trt is True or load_vllm is True or fp16 is True): load_jit, load_trt, load_vllm, fp16 = False, False, False, False - logging.warning('no cuda device, set load_jit/load_trt/load_vllm/fp16 to False') + logging.warning('no accelerator device, set load_jit/load_trt/load_vllm/fp16 to False') + if supports_trt(self.device) is False and load_trt is True: + load_trt = False + logging.warning('TensorRT requires CUDA, set load_trt to False') + if supports_vllm(self.device) is False and load_vllm is True: + load_vllm = False + logging.warning('vLLM support in this workspace is limited to CUDA, set load_vllm to False') self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16) self.model.load('{}/llm.pt'.format(model_dir), '{}/flow.pt'.format(model_dir), @@ -191,6 +232,7 @@ class CosyVoice3(CosyVoice2): def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1): self.model_dir = model_dir self.fp16 = fp16 + self.device = get_preferred_device() if not os.path.exists(model_dir): model_dir = snapshot_download(model_dir) hyper_yaml_path = '{}/cosyvoice3.yaml'.format(model_dir) @@ -206,9 +248,15 @@ class CosyVoice3(CosyVoice2): '{}/spk2info.pt'.format(model_dir), configs['allowed_special']) self.sample_rate = configs['sample_rate'] - if torch.cuda.is_available() is False and (load_trt is True or fp16 is True): + if has_accelerator() is False and (load_trt is True or fp16 is True): load_trt, fp16 = False, False - logging.warning('no cuda device, set load_trt/fp16 to False') + logging.warning('no accelerator device, set load_trt/fp16 to False') + if supports_trt(self.device) is False and load_trt is True: + load_trt = False + logging.warning('TensorRT requires CUDA, set load_trt to False') + if supports_vllm(self.device) is False and load_vllm is True: + load_vllm = False + logging.warning('vLLM support in this workspace is limited to CUDA, set load_vllm to False') self.model = CosyVoice3Model(configs['llm'], configs['flow'], configs['hift'], fp16) self.model.load('{}/llm.pt'.format(model_dir), '{}/flow.pt'.format(model_dir), diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py index 6d397cc..9265a7c 100644 --- a/cosyvoice/cli/frontend.py +++ b/cosyvoice/cli/frontend.py @@ -25,6 +25,7 @@ import re import inflect from cosyvoice.utils.file_utils import logging, load_wav from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation +from cosyvoice.utils.device import get_preferred_device, speech_tokenizer_provider class CosyVoiceFrontEnd: @@ -38,14 +39,13 @@ class CosyVoiceFrontEnd: allowed_special: str = 'all'): self.tokenizer = get_tokenizer() self.feat_extractor = feat_extractor - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.device = get_preferred_device() option = onnxruntime.SessionOptions() option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL option.intra_op_num_threads = 1 self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"]) self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option, - providers=["CUDAExecutionProvider" if torch.cuda.is_available() else - "CPUExecutionProvider"]) + providers=[speech_tokenizer_provider()]) if os.path.exists(spk2info): self.spk2info = torch.load(spk2info, map_location=self.device, weights_only=True) else: diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 92a15d9..026eaf0 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -19,11 +19,18 @@ import numpy as np import threading import time from torch.nn import functional as F -from contextlib import nullcontext import uuid from cosyvoice.utils.common import fade_in_out from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm from cosyvoice.utils.common import TrtContextWrapper +from cosyvoice.utils.device import ( + autocast_context, + empty_cache, + get_preferred_device, + stream_context, + supports_trt, + synchronize, +) class CosyVoiceModel: @@ -33,11 +40,11 @@ class CosyVoiceModel: flow: torch.nn.Module, hift: torch.nn.Module, fp16: bool = False): - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.llm = llm self.flow = flow self.hift = hift self.fp16 = fp16 + self.device = get_preferred_device() self.token_min_hop_len = 2 * self.flow.input_frame_rate self.token_max_hop_len = 4 * self.flow.input_frame_rate self.token_overlap_len = 20 @@ -52,7 +59,7 @@ class CosyVoiceModel: # rtf and decoding related self.stream_scale_factor = 1 assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf' - self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext() + self.llm_context = stream_context(self.device) self.lock = threading.Lock() # dict used to store session related variable self.tts_speech_token_dict = {} @@ -81,7 +88,7 @@ class CosyVoiceModel: self.flow.encoder = flow_encoder def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16): - assert torch.cuda.is_available(), 'tensorrt only supports gpu!' + assert supports_trt(self.device), 'tensorrt only supports CUDA!' if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0: convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16) del self.flow.decoder.estimator @@ -100,7 +107,7 @@ class CosyVoiceModel: def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid): cur_silent_token_num, max_silent_token_num = 0, 5 - with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False): + with self.llm_context, autocast_context(self.device, enabled=self.fp16 is True and hasattr(self.llm, 'vllm') is False): if isinstance(text, Generator): assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2/3 and do not support vllm!' token_generator = self.llm.inference_bistream(text=text, @@ -133,7 +140,7 @@ class CosyVoiceModel: self.llm_end_dict[uuid] = True def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0): - with torch.cuda.amp.autocast(self.fp16): + with autocast_context(self.device, enabled=self.fp16): tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device, dtype=torch.int32), token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), prompt_token=prompt_token.to(self.device), @@ -237,9 +244,8 @@ class CosyVoiceModel: self.mel_overlap_dict.pop(this_uuid) self.hift_cache_dict.pop(this_uuid) self.flow_cache_dict.pop(this_uuid) - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.current_stream().synchronize() + empty_cache(self.device) + synchronize(self.device) class CosyVoice2Model(CosyVoiceModel): @@ -249,11 +255,11 @@ class CosyVoice2Model(CosyVoiceModel): flow: torch.nn.Module, hift: torch.nn.Module, fp16: bool = False): - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.llm = llm self.flow = flow self.hift = hift self.fp16 = fp16 + self.device = get_preferred_device() # NOTE must matching training static_chunk_size self.token_hop_len = 25 # NOTE increase token_hop_len incrementally to avoid duplicate inference @@ -266,7 +272,7 @@ class CosyVoice2Model(CosyVoiceModel): # speech fade in out self.speech_window = np.hamming(2 * self.source_cache_len) # rtf and decoding related - self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext() + self.llm_context = stream_context(self.device) self.lock = threading.Lock() # dict used to store session related variable self.tts_speech_token_dict = {} @@ -290,7 +296,7 @@ class CosyVoice2Model(CosyVoiceModel): del self.llm.llm.model.model.layers def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0): - with torch.cuda.amp.autocast(self.fp16): + with autocast_context(self.device, enabled=self.fp16): tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32), token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), prompt_token=prompt_token.to(self.device), @@ -389,9 +395,8 @@ class CosyVoice2Model(CosyVoiceModel): self.tts_speech_token_dict.pop(this_uuid) self.llm_end_dict.pop(this_uuid) self.hift_cache_dict.pop(this_uuid) - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.current_stream().synchronize() + empty_cache(self.device) + synchronize(self.device) class CosyVoice3Model(CosyVoice2Model): @@ -401,11 +406,11 @@ class CosyVoice3Model(CosyVoice2Model): flow: torch.nn.Module, hift: torch.nn.Module, fp16: bool = False): - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.llm = llm self.flow = flow self.hift = hift self.fp16 = fp16 + self.device = get_preferred_device() # NOTE must matching training static_chunk_size self.token_hop_len = 25 # NOTE increase token_hop_len incrementally to avoid duplicate inference @@ -413,7 +418,7 @@ class CosyVoice3Model(CosyVoice2Model): self.stream_scale_factor = 2 assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf' # rtf and decoding related - self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext() + self.llm_context = stream_context(self.device) self.lock = threading.Lock() # dict used to store session related variable self.tts_speech_token_dict = {} @@ -423,7 +428,7 @@ class CosyVoice3Model(CosyVoice2Model): self.silent_tokens = [1, 2, 28, 29, 55, 248, 494, 2241, 2242, 2322, 2323] def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0): - with torch.cuda.amp.autocast(self.fp16): + with autocast_context(self.device, enabled=self.fp16): tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32), token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), prompt_token=prompt_token.to(self.device), diff --git a/cosyvoice/utils/common.py b/cosyvoice/utils/common.py index 3f235a6..9857bf7 100644 --- a/cosyvoice/utils/common.py +++ b/cosyvoice/utils/common.py @@ -22,6 +22,7 @@ from typing import List import numpy as np import torch +from cosyvoice.utils.device import seed_all IGNORE_ID = -1 @@ -181,8 +182,7 @@ def fade_in_out(fade_in_mel, fade_out_mel, window): def set_all_random_seed(seed): random.seed(seed) np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) + seed_all(seed) def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: diff --git a/cosyvoice/utils/file_utils.py b/cosyvoice/utils/file_utils.py index b173ef2..a169a92 100644 --- a/cosyvoice/utils/file_utils.py +++ b/cosyvoice/utils/file_utils.py @@ -18,6 +18,7 @@ import os import json import torch import torchaudio +import soundfile as sf import logging logging.getLogger('matplotlib').setLevel(logging.WARNING) logging.basicConfig(level=logging.DEBUG, @@ -42,11 +43,11 @@ def read_json_lists(list_file): def load_wav(wav, target_sr, min_sr=16000): - speech, sample_rate = torchaudio.load(wav, backend='soundfile') - speech = speech.mean(dim=0, keepdim=True) + speech, sample_rate = sf.read(wav, always_2d=True, dtype='float32') + speech = torch.from_numpy(speech.T).mean(dim=0, keepdim=True) if sample_rate != target_sr: assert sample_rate >= min_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr) - speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech) + speech = torchaudio.functional.resample(speech, orig_freq=sample_rate, new_freq=target_sr) return speech diff --git a/cosyvoice/utils/device.py b/cosyvoice/utils/device.py new file mode 100644 index 0000000..9aa3aa7 --- /dev/null +++ b/cosyvoice/utils/device.py @@ -0,0 +1,87 @@ +from contextlib import nullcontext + +import torch + + +def is_xpu_available(): + return hasattr(torch, 'xpu') and torch.xpu.is_available() + + +def get_preferred_device(): + if torch.cuda.is_available(): + return torch.device('cuda') + if is_xpu_available(): + return torch.device('xpu') + return torch.device('cpu') + + +def get_device_module(device=None): + device = torch.device(device or get_preferred_device()) + if device.type == 'cuda': + return torch.cuda + if device.type == 'xpu' and hasattr(torch, 'xpu'): + return torch.xpu + return None + + +def has_accelerator(): + return get_preferred_device().type != 'cpu' + + +def supports_trt(device=None): + return torch.device(device or get_preferred_device()).type == 'cuda' + + +def supports_vllm(device=None): + return torch.device(device or get_preferred_device()).type == 'cuda' + + +def get_autocast_dtype(device=None): + device = torch.device(device or get_preferred_device()) + if device.type == 'cuda': + return torch.float16 + if device.type == 'xpu': + return torch.float16 + return None + + +def autocast_context(device=None, enabled=False): + device = torch.device(device or get_preferred_device()) + dtype = get_autocast_dtype(device) + if enabled is False or dtype is None: + return nullcontext() + return torch.autocast(device_type=device.type, dtype=dtype) + + +def stream_context(device=None): + device = torch.device(device or get_preferred_device()) + module = get_device_module(device) + if module is None: + return nullcontext() + return module.stream(module.Stream(device=device)) + + +def empty_cache(device=None): + module = get_device_module(device) + if module is not None and hasattr(module, 'empty_cache'): + module.empty_cache() + + +def synchronize(device=None): + module = get_device_module(device) + if module is not None and hasattr(module, 'current_stream'): + module.current_stream(device=device).synchronize() + + +def seed_all(seed): + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + if is_xpu_available(): + torch.xpu.manual_seed_all(seed) + + +def speech_tokenizer_provider(): + if torch.cuda.is_available(): + return 'CUDAExecutionProvider' + return 'CPUExecutionProvider'