Voice tone cloning with OpenVoice and OpenVINO#
This Jupyter notebook can be launched on-line, opening an interactive environment in a browser window. You can also make a local installation. Choose one of the following options:
OpenVoice is a versatile instant voice cloning approach that transfers voice tone and generates speech in various languages from just a brief audio snippet of the source speaker. OpenVoice has three main features: (i) high-quality tone color replication across multiple languages and accents; (ii) fine-grained control over voice styles, including emotions and accents, as well as other parameters such as rhythm, pauses, and intonation; (iii) zero-shot cross-lingual voice cloning, which removes the need for the generated speech and the reference speech to be part of a massive-speaker multilingual training dataset.
image#
More details about the model can be found on the project web page, in the paper, and in the official repository.
This notebook provides an example of converting the PyTorch OpenVoice model to OpenVINO IR. In this tutorial we will explore how to convert and run OpenVoice using OpenVINO.
Table of contents:
Installation Instructions#
This is a self-contained example that relies solely on its own code.
We recommend running the notebook in a virtual environment. You only need a Jupyter server to start. For details, please refer to Installation Guide.
Clone repository and install requirements#
# Fetch `notebook_utils` module (and the other shared OpenVINO notebook helpers).
import requests
from pathlib import Path

_UTILS_BASE_URL = "https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils"

# Download each helper only if it is not already present next to the notebook.
for _helper_name in ("notebook_utils.py", "cmd_helper.py", "pip_helper.py"):
    if not Path(_helper_name).exists():
        r = requests.get(url=f"{_UTILS_BASE_URL}/{_helper_name}")
        # Fail loudly instead of saving an HTTP error page as a Python module.
        r.raise_for_status()
        # `with` closes the handle; UTF-8 matches Python's default source encoding.
        with open(_helper_name, "w", encoding="utf-8") as helper_file:
            helper_file.write(r.text)

# Read more about telemetry collection at https://github.com/openvinotoolkit/openvino_notebooks?tab=readme-ov-file#-telemetry
from notebook_utils import collect_telemetry

collect_telemetry("openvoice.ipynb")
from pathlib import Path

from cmd_helper import clone_repo
from pip_helper import pip_install
import platform


repo_dir = Path("OpenVoice")

if not repo_dir.exists():
    clone_repo("https://github.com/myshell-ai/OpenVoice")

orig_english_path = Path("OpenVoice/openvoice/text/_orig_english.py")
english_path = Path("OpenVoice/openvoice/text/english.py")
se_extractor_path = Path("OpenVoice/openvoice/se_extractor.py")

# Replace the `unidecode` dependency with `anyascii` in the English text frontend.
# The pristine file is kept as _orig_english.py. Its presence means the rename was
# already done, so re-running this cell (or resuming after an interrupted run) is safe.
# UTF-8 is explicit because the file contains non-ASCII (IPA) characters.
if not orig_english_path.exists():
    english_path.rename(orig_english_path)
data = orig_english_path.read_text(encoding="utf-8")
english_path.write_text(data.replace("unidecode", "anyascii"), encoding="utf-8")

# fix a problem with silero downloading and installing (pin the VAD to v3.0).
# Idempotent: 'method="silero:3.0"' does not contain 'method="silero"'.
data = se_extractor_path.read_text(encoding="utf-8")
se_extractor_path.write_text(data.replace('method="silero"', 'method="silero:3.0"'), encoding="utf-8")

pip_install("librosa>=0.8.1", "pydub>=0.25.1", "tqdm", "inflect>=7.0.0", "pypinyin>=0.50.0", "openvino>=2023.3", "gradio>=4.15")
pip_install(
    "--extra-index-url",
    "https://download.pytorch.org/whl/cpu",
    "wavmark>=0.0.3",
    "faster-whisper>=0.9.0",
    "eng_to_ipa>=0.0.2",
    "cn2an>=0.5.22",
    "jieba>=0.42.1",
    "langid>=1.1.6",
    "ipywebrtc",
    "anyascii",
    "torch>=2.1",
    "nncf>=2.11.0",
    "dtw-python",
    "more-itertools",
    "tiktoken",
)
pip_install("--no-deps", "whisper-timestamped>=1.14.2", "openai-whisper")

if platform.system() == "Darwin":
    pip_install("numpy<2.0")
Download checkpoints and load PyTorch model#
import os
import torch
import openvino as ov
import ipywidgets as widgets
from IPython.display import Audio
from notebook_utils import download_file, device_widget

# OpenVINO runtime object; used later in the notebook to read and compile models.
core = ov.Core()

# OpenVoice package from the repository cloned above.
# NOTE(review): importability presumably relies on clone_repo putting the repo on sys.path — verify.
from openvoice.api import BaseSpeakerTTS, ToneColorConverter, OpenVoiceBaseClass
import openvoice.se_extractor as se_extractor
Importing the dtw module. When using in academic works please cite:
T. Giorgino. Computing and Visualizing Dynamic Time Warping Alignments in R: The dtw Package.
J. Stat. Soft., doi:10.18637/jss.v031.i07.
# Local checkpoint layout. The same relative paths are used as file names inside the
# Hugging Face repo, so downloads land in matching local folders.
CKPT_BASE_PATH = Path("checkpoints")
en_suffix, zh_suffix, converter_suffix = (
    CKPT_BASE_PATH / "base_speakers" / "EN",
    CKPT_BASE_PATH / "base_speakers" / "ZH",
    CKPT_BASE_PATH / "converter",
)
To keep the notebook lightweight, the model for Chinese speech is not activated by default. To turn it on, set the `enable_chinese_lang` flag to `True`.
# Set to True to also download, load, convert and enable the Chinese base speaker TTS model.
enable_chinese_lang = False
def download_from_hf_hub(filename, local_dir="./", repo_id="myshell-ai/OpenVoice"):
    """Download one file from a Hugging Face Hub repository into a local directory.

    Args:
        filename: path of the file inside the repo, e.g. ``"checkpoints/converter/config.json"``.
        local_dir: local target directory; created (including parents) if missing.
        repo_id: Hub repository to download from; defaults to the OpenVoice checkpoints repo.

    Returns:
        The local file path returned by ``huggingface_hub.hf_hub_download``.
    """
    # Imported lazily so the dependency is only needed when downloading.
    from huggingface_hub import hf_hub_download

    local_path = Path(local_dir)
    # parents=True so nested target directories work too.
    local_path.mkdir(parents=True, exist_ok=True)
    return hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_path)
# Same download order as before: converter files, English base speaker files,
# then (optionally) the Chinese base speaker files.
files_to_fetch = [
    *(f"{converter_suffix.as_posix()}/{name}" for name in ("checkpoint.pth", "config.json")),
    *(f"{en_suffix.as_posix()}/{name}" for name in ("checkpoint.pth", "config.json", "en_default_se.pth", "en_style_se.pth")),
]
if enable_chinese_lang:
    files_to_fetch += [f"{zh_suffix.as_posix()}/{name}" for name in ("checkpoint.pth", "config.json", "zh_default_se.pth")]

for hub_filename in files_to_fetch:
    download_from_hf_hub(hub_filename)
pt_device = "cpu"


def _build_from_dir(model_cls, ckpt_dir):
    """Instantiate *model_cls* from ``ckpt_dir/config.json`` and load ``ckpt_dir/checkpoint.pth``."""
    instance = model_cls(ckpt_dir / "config.json", device=pt_device)
    instance.load_ckpt(ckpt_dir / "checkpoint.pth")
    return instance


en_base_speaker_tts = _build_from_dir(BaseSpeakerTTS, en_suffix)
tone_color_converter = _build_from_dir(ToneColorConverter, converter_suffix)
# None marks the Chinese speaker as unavailable when the flag is off.
zh_base_speaker_tts = _build_from_dir(BaseSpeakerTTS, zh_suffix) if enable_chinese_lang else None
/home/ea/work/my_optimum_intel/optimum_env/lib/python3.8/site-packages/torch/nn/utils/weight_norm.py:28: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.
warnings.warn("torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.")
Loaded checkpoint 'checkpoints/base_speakers/EN/checkpoint.pth'
missing/unexpected keys: [] []
Loaded checkpoint 'checkpoints/converter/checkpoint.pth'
missing/unexpected keys: [] []
Convert models to OpenVINO IR#
There are 2 models in OpenVoice: the first one, `BaseSpeakerTTS`, is responsible for speech generation, and the second one, `ToneColorConverter`, imposes an arbitrary voice tone on the original speech. To convert them to OpenVINO IR format, we first need an acceptable `torch.nn.Module` object. Instead of using `self.forward` as the main entry point, the models inside `BaseSpeakerTTS` and `ToneColorConverter` use the custom `infer` and `voice_conversion` methods respectively. We therefore need to wrap them in a custom class that inherits from `torch.nn.Module`.
class OVOpenVoiceBase(torch.nn.Module):
    """
    Common torch.nn.Module wrapper shared by the TTS and tone-conversion wrappers.

    Keeps a reference to the OpenVoice object and turns off gradient tracking for
    every parameter of its inner ``model``.
    """

    def __init__(self, voice_model: OpenVoiceBaseClass):
        super().__init__()
        self.voice_model = voice_model
        # Freeze the weights: conversion only needs inference.
        for weight in voice_model.model.parameters():
            weight.requires_grad_(False)
class OVOpenVoiceTTS(OVOpenVoiceBase):
    """
    Traceable wrapper around a BaseSpeakerTTS object: ``forward`` delegates to the
    inner model's ``infer`` method so the TTS network can be converted to OpenVINO IR.
    """

    def get_example_input(self):
        """Return a sample argument tuple for ``forward``, built from a short English sentence."""
        tokens = self.voice_model.get_text("this is original text", self.voice_model.hps, False)
        return (
            tokens.unsqueeze(0),  # x (leading batch dimension added)
            torch.LongTensor([tokens.size(0)]),  # x_lengths
            torch.LongTensor([1]),  # sid
            torch.tensor(0.667),  # noise_scale
            torch.tensor(1.0),  # length_scale
            torch.tensor(0.6),  # noise_scale_w
        )

    def forward(self, x, x_lengths, sid, noise_scale, length_scale, noise_scale_w):
        """Pass all arguments straight through to the wrapped model's ``infer``."""
        return self.voice_model.model.infer(x, x_lengths, sid, noise_scale, length_scale, noise_scale_w)
class OVOpenVoiceConverter(OVOpenVoiceBase):
    """
    Traceable wrapper around a ToneColorConverter object: ``forward`` delegates to the
    inner model's ``voice_conversion`` method so it can be converted to OpenVINO IR.
    """

    def get_example_input(self):
        """Return random sample inputs for tracing, in ``forward``'s argument order."""
        y_example = torch.randn([1, 513, 238], dtype=torch.float32)
        y_example_len = torch.LongTensor([y_example.size(-1)])
        # The target embedding is drawn before the source one; the order affects the RNG stream.
        tgt_embedding = torch.randn(1, 256, 1)
        src_embedding = torch.randn(1, 256, 1)
        return y_example, y_example_len, src_embedding, tgt_embedding, torch.tensor(0.3)

    def forward(self, y, y_lengths, sid_src, sid_tgt, tau):
        """Pass all arguments straight through to the wrapped model's ``voice_conversion``."""
        return self.voice_model.model.voice_conversion(y, y_lengths, sid_src, sid_tgt, tau)
Convert the models to OpenVINO IR and save them to the `IRS_PATH` folder for future use. If the IRs already exist, skip conversion and read them directly.
To reduce memory consumption, weight compression can be applied using NNCF. Weight compression aims to reduce the memory footprint of a model. Models that require extensive memory to store their weights during inference can benefit from weight compression in the following ways:
enabling the inference of exceptionally large models that cannot be accommodated in the memory of the device;
improving the inference performance of the models by reducing the latency of the memory access when computing the operations with weights, for example, Linear layers.
Neural Network Compression Framework (NNCF) provides 4-bit / 8-bit mixed weight quantization as a compression method. The main difference between weights compression and full model quantization (post-training quantization) is that activations remain floating-point in the case of weights compression which leads to a better accuracy. In addition, weight compression is data-free and does not require a calibration dataset, making it easy to use.
nncf.compress_weights
function can be used for performing weights
compression. The function accepts an OpenVINO model and other
compression parameters.
More details about weights compression can be found in OpenVINO documentation.
import nncf

IRS_PATH = Path("openvino_irs/")
EN_TTS_IR = IRS_PATH / "openvoice_en_tts.xml"
ZH_TTS_IR = IRS_PATH / "openvoice_zh_tts.xml"
VOICE_CONVERTER_IR = IRS_PATH / "openvoice_tone_conversion.xml"


def _load_or_convert(wrapper, ir_path):
    """Read the IR at *ir_path* if present; otherwise trace *wrapper*, compress its weights with NNCF and save it there."""
    if ir_path.exists():
        return core.read_model(ir_path)
    converted = ov.convert_model(wrapper, example_input=wrapper.get_example_input())
    converted = nncf.compress_weights(converted)
    ov.save_model(converted, ir_path)
    return converted


paths = [EN_TTS_IR, VOICE_CONVERTER_IR]
models = [OVOpenVoiceTTS(en_base_speaker_tts), OVOpenVoiceConverter(tone_color_converter)]
if enable_chinese_lang:
    models.append(OVOpenVoiceTTS(zh_base_speaker_tts))
    paths.append(ZH_TTS_IR)

ov_models = [_load_or_convert(wrapper, ir_path) for wrapper, ir_path in zip(models, paths)]

ov_en_tts, ov_voice_conversion = ov_models[:2]
if enable_chinese_lang:
    ov_zh_tts = ov_models[-1]
INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, tensorflow, onnx, openvino
this is original text.
length:22
length:21
/home/ea/work/openvino_notebooks_new_clone/openvino_notebooks/notebooks/openvoice/OpenVoice/openvoice/attentions.py:283: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert (
/home/ea/work/openvino_notebooks_new_clone/openvino_notebooks/notebooks/openvoice/OpenVoice/openvoice/attentions.py:346: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
pad_length = max(length - (self.window_size + 1), 0)
/home/ea/work/openvino_notebooks_new_clone/openvino_notebooks/notebooks/openvoice/OpenVoice/openvoice/attentions.py:347: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
slice_start_position = max((self.window_size + 1) - length, 0)
/home/ea/work/openvino_notebooks_new_clone/openvino_notebooks/notebooks/openvoice/OpenVoice/openvoice/attentions.py:349: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if pad_length > 0:
/home/ea/work/openvino_notebooks_new_clone/openvino_notebooks/notebooks/openvoice/OpenVoice/openvoice/transforms.py:114: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if torch.min(inputs) < left or torch.max(inputs) > right:
/home/ea/work/openvino_notebooks_new_clone/openvino_notebooks/notebooks/openvoice/OpenVoice/openvoice/transforms.py:119: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if min_bin_width * num_bins > 1.0:
/home/ea/work/openvino_notebooks_new_clone/openvino_notebooks/notebooks/openvoice/OpenVoice/openvoice/transforms.py:121: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if min_bin_height * num_bins > 1.0:
/home/ea/work/openvino_notebooks_new_clone/openvino_notebooks/notebooks/openvoice/OpenVoice/openvoice/transforms.py:171: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert (discriminant >= 0).all()
/home/ea/work/my_optimum_intel/optimum_env/lib/python3.8/site-packages/torch/jit/_trace.py:1116: TracerWarning: Trace had nondeterministic nodes. Did you forget call .eval() on your model? Nodes:
%3293 : Float(1, 2, 43, strides=[86, 43, 1], requires_grad=0, device=cpu) = aten::randn(%3288, %3289, %3290, %3291, %3292) # /home/ea/work/my_optimum_intel/optimum_env/lib/python3.8/site-packages/nncf/torch/dynamic_graph/wrappers.py:86:0
%5559 : Float(1, 192, 151, strides=[28992, 1, 192], requires_grad=0, device=cpu) = aten::randn_like(%m_p, %5554, %5555, %5556, %5557, %5558) # /home/ea/work/my_optimum_intel/optimum_env/lib/python3.8/site-packages/nncf/torch/dynamic_graph/wrappers.py:86:0
This may cause errors in trace checking. To disable trace checking, pass check_trace=False to torch.jit.trace()
_check_trace(
/home/ea/work/my_optimum_intel/optimum_env/lib/python3.8/site-packages/torch/jit/_trace.py:1116: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
The values for attribute 'shape' do not match: torch.Size([1, 1, 39168]) != torch.Size([1, 1, 39424]).
_check_trace(
/home/ea/work/my_optimum_intel/optimum_env/lib/python3.8/site-packages/torch/jit/_trace.py:1116: TracerWarning: Output nr 2. of the traced function does not match the corresponding output of the Python function. Detailed error:
The values for attribute 'shape' do not match: torch.Size([1, 1, 153, 43]) != torch.Size([1, 1, 154, 43]).
_check_trace(
/home/ea/work/my_optimum_intel/optimum_env/lib/python3.8/site-packages/torch/jit/_trace.py:1116: TracerWarning: Output nr 3. of the traced function does not match the corresponding output of the Python function. Detailed error:
The values for attribute 'shape' do not match: torch.Size([1, 1, 153]) != torch.Size([1, 1, 154]).
_check_trace(
2024-07-31 19:08:02.879488: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
INFO:nncf:Statistics of the bitwidth distribution:
┍━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┑
│ Num bits (N) │ % all parameters (layers) │ % ratio-defining parameters (layers) │
┝━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┥
│ 8 │ 100% (199 / 199) │ 100% (199 / 199) │
┕━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┙
Output()
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.base has been moved to tensorflow.python.trackable.base. The old module will be deleted in version 2.11.
/home/ea/work/my_optimum_intel/optimum_env/lib/python3.8/site-packages/torch/jit/_trace.py:1116: TracerWarning: Trace had nondeterministic nodes. Did you forget call .eval() on your model? Nodes:
%1596 : Float(1, 192, 238, strides=[91392, 238, 1], requires_grad=0, device=cpu) = aten::randn_like(%m, %1591, %1592, %1593, %1594, %1595) # /home/ea/work/my_optimum_intel/optimum_env/lib/python3.8/site-packages/nncf/torch/dynamic_graph/wrappers.py:86:0
This may cause errors in trace checking. To disable trace checking, pass check_trace=False to torch.jit.trace()
_check_trace(
/home/ea/work/my_optimum_intel/optimum_env/lib/python3.8/site-packages/torch/jit/_trace.py:1116: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
Tensor-likes are not close!
Mismatched elements: 18501 / 60928 (30.4%)
Greatest absolute difference: 0.010825839242897928 at index (0, 0, 27067) (up to 1e-05 allowed)
Greatest relative difference: 9452.158227848102 at index (0, 0, 45473) (up to 1e-05 allowed)
_check_trace(
INFO:nncf:Statistics of the bitwidth distribution:
┍━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┑
│ Num bits (N) │ % all parameters (layers) │ % ratio-defining parameters (layers) │
┝━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┥
│ 8 │ 100% (194 / 194) │ 100% (194 / 194) │
┕━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┙
Output()
Inference#
Select inference device#
core = ov.Core()

# Dropdown for the OpenVINO device; its value is used when compiling the models below.
device = device_widget()

device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')
Select reference tone#
First of all, select the reference voice whose tone the generated text will be converted to. You can pick one of the existing samples, record your own by selecting `record_manually`, or upload your own file by selecting `load_manually`.
REFERENCE_VOICES_PATH = f"{repo_dir}/resources/"
reference_speakers = [
    # os.listdir order is arbitrary; sort so the default choice is reproducible.
    *sorted(path for path in os.listdir(REFERENCE_VOICES_PATH) if os.path.splitext(path)[-1] == ".mp3"),
    "record_manually",
    "load_manually",
]

# Dropdown of bundled .mp3 samples plus the two manual options handled in the next cell.
ref_speaker = widgets.Dropdown(
    options=reference_speakers,
    value=reference_speakers[0],
    description="reference voice from which tone color will be copied",
    disabled=False,
)

ref_speaker
Dropdown(description='reference voice from which tone color will be copied', options=('demo_speaker2.mp3', 'ex…
OUTPUT_DIR = Path("outputs/")
OUTPUT_DIR.mkdir(exist_ok=True)

# os.path.join avoids the double slash produced by concatenating onto the trailing "/".
ref_speaker_path = os.path.join(REFERENCE_VOICES_PATH, ref_speaker.value)
allowed_audio_types = ".mp4,.mp3,.wav,.wma,.aac,.m4a,.m4b,.webm"

if ref_speaker.value == "record_manually":
    ref_speaker_path = OUTPUT_DIR / "custom_example_sample.webm"
    from ipywebrtc import AudioRecorder, CameraStream

    camera = CameraStream(constraints={"audio": True, "video": False})
    # ipywebrtc declares `filename` as a string (Unicode) trait, so a Path object
    # would fail trait validation; pass the path as str.
    recorder = AudioRecorder(stream=camera, filename=str(ref_speaker_path), autosave=True)
    display(recorder)
elif ref_speaker.value == "load_manually":
    upload_ref = widgets.FileUpload(
        accept=allowed_audio_types,
        multiple=False,
        description="Select audio with reference voice",
    )
    display(upload_ref)
Play the reference voice sample before cloning its tone onto other speech.
def save_audio(voice_source: "widgets.FileUpload", out_path: str):
    """Write the first file uploaded through *voice_source* to *out_path*.

    Args:
        voice_source: upload widget. Only its ``value`` sequence is read, and the
            ``"content"`` of the first entry is written.
        out_path: destination file path (overwritten if it exists).

    Raises:
        ValueError: if nothing has been uploaded. The check runs before the output file
            is opened, so no empty file is left behind.
    """
    # Validate first: opening in "wb" mode would otherwise create an empty file.
    if not voice_source.value:
        raise ValueError("Please select audio file")
    with open(out_path, "wb") as output_file:
        output_file.write(voice_source.value[0]["content"])
# For an uploaded reference, save it under OUTPUT_DIR (keeping the uploaded file
# name) so later steps can read it from disk.
if ref_speaker.value == "load_manually":
    ref_speaker_path = f"{OUTPUT_DIR}/{upload_ref.value[0].name}"
    save_audio(upload_ref, ref_speaker_path)

# Audio player for the selected / recorded / uploaded reference voice.
Audio(ref_speaker_path)
torch_hub_local = Path("torch_hub_local/")
# IPython magic: set TORCH_HOME so torch caches (e.g. torch.hub) live in a local folder.
%env TORCH_HOME={str(torch_hub_local.absolute())}
Load speaker embeddings
# second step to fix a problem with silero downloading and installing:
# pre-populate the torch.hub cache with silero-vad v3.0 so it does not need to be fetched at runtime.
import os
import zipfile

url = "https://github.com/snakers4/silero-vad/zipball/v3.0"

torch_hub_dir = torch_hub_local / "hub"
torch.hub.set_dir(torch_hub_dir.as_posix())

zip_filename = "v3.0.zip"
output_path = torch_hub_dir / "v3.0"
# Target directory name; presumably the `<owner>_<repo>_<ref>` layout torch.hub looks up — verify.
silero_hub_dir = torch_hub_dir / "snakers4_silero-vad_v3.0"

# Skip download and extraction entirely once the repo is in place.
if not silero_hub_dir.exists():
    if not (torch_hub_dir / zip_filename).exists():
        download_file(url, directory=torch_hub_dir, filename=zip_filename)
    # Context manager closes the archive even if extraction fails.
    with zipfile.ZipFile((torch_hub_dir / zip_filename).as_posix(), "r") as zip_ref:
        zip_ref.extractall(path=output_path.as_posix())
    v3_dirs = [d for d in output_path.iterdir() if "snakers4-silero-vad" in d.as_posix()]
    if v3_dirs:
        os.rename(str(v3_dirs[0]), silero_hub_dir.as_posix())

# Tone color embeddings of the base speakers (sources) and of the reference voice (target).
en_source_default_se = torch.load(f"{en_suffix}/en_default_se.pth")
en_source_style_se = torch.load(f"{en_suffix}/en_style_se.pth")
zh_source_se = torch.load(f"{zh_suffix}/zh_default_se.pth") if enable_chinese_lang else None

target_se, audio_name = se_extractor.get_se(ref_speaker_path, tone_color_converter, target_dir=OUTPUT_DIR, vad=True)
OpenVoice version: v1
/home/ea/work/my_optimum_intel/optimum_env/lib/python3.8/site-packages/librosa/util/decorators.py:88: UserWarning: PySoundFile failed. Trying audioread instead. return f(*args, **kwargs) /home/ea/work/my_optimum_intel/optimum_env/lib/python3.8/site-packages/torch/hub.py:293: UserWarning: You are about to download and run code from an untrusted repository. In a future release, this won't be allowed. To add the repository to your trusted list, change the command to {calling_fn}(..., trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, or load(..., trust_repo=True), which will assume that the prompt is to be answered with 'yes'. You can also use load(..., trust_repo='check') which will only prompt for confirmation if the repo is not already trusted. This will eventually be the default behaviour warnings.warn( Downloading: "https://github.com/snakers4/silero-vad/zipball/v3.0" to /home/ea/.cache/torch/hub/v3.0.zip /home/ea/work/my_optimum_intel/optimum_env/lib/python3.8/site-packages/torch/nn/modules/module.py:1541: UserWarning: A window was not provided. A rectangular window will be applied,which is known to cause spectral leakage. Other windows such as torch.hann_window or torch.hamming_window can are recommended to reduce spectral leakage.To suppress this warning and use a rectangular window, explicitly set window=torch.ones(n_fft, device=<device>). (Triggered internally at ../aten/src/ATen/native/SpectralOps.cpp:836.) return forward_call(*args, **kwargs)
[(0.0, 8.21), (9.292, 13.106), (13.228, 16.466), (16.684, 29.49225)]
after vad: dur = 28.07
/home/ea/work/my_optimum_intel/optimum_env/lib/python3.8/site-packages/torch/functional.py:665: UserWarning: stft with return_complex=False is deprecated. In a future pytorch release, stft will return complex tensors for all inputs, and return_complex=False will raise an error.
Note: you can still call torch.view_as_real on the complex output to recover the old return format. (Triggered internally at ../aten/src/ATen/native/SpectralOps.cpp:873.)
return _VF.stft(input, n_fft, hop_length, win_length, window, # type: ignore[attr-defined]
Replace original infer methods of OpenVoiceBaseClass
with optimized
OpenVINO inference.
There are pre- and post-processing steps that are not traceable and cannot be offloaded to OpenVINO. Instead of rewriting that processing ourselves, we rely on the existing implementation and only replace the `infer` and `voice_conversion` functions of the underlying models, so that the most computationally expensive part runs in OpenVINO.
def _first_output_as_torch(compiled_model, inputs) -> tuple:
    """Run *compiled_model* on the positional *inputs* and return its first output as a 1-tuple of torch tensors."""
    ov_output = compiled_model(inputs)
    # torch.tensor copies the data out of the OpenVINO result buffer.
    return (torch.tensor(ov_output[0]),)


def get_pathched_infer(ov_model: ov.Model, device: str) -> callable:
    """Compile the TTS *ov_model* on *device* and return a replacement for ``model.infer``.

    The returned function keeps the original ``infer`` parameter names, in case callers
    pass them by keyword, and returns ``(first_output,)``. The misspelled name is kept
    for existing callers; ``get_patched_infer`` is a correctly spelled alias.
    """
    compiled_model = core.compile_model(ov_model, device)

    def infer_impl(x, x_lengths, sid, noise_scale, length_scale, noise_scale_w):
        return _first_output_as_torch(compiled_model, (x, x_lengths, sid, noise_scale, length_scale, noise_scale_w))

    return infer_impl


# Correctly spelled alias of get_pathched_infer.
get_patched_infer = get_pathched_infer


def get_patched_voice_conversion(ov_model: ov.Model, device: str) -> callable:
    """Compile the tone-conversion *ov_model* on *device* and return a replacement for ``model.voice_conversion``.

    The returned function keeps the original parameter names and returns ``(first_output,)``.
    """
    compiled_model = core.compile_model(ov_model, device)

    def voice_conversion_impl(y, y_lengths, sid_src, sid_tgt, tau):
        return _first_output_as_torch(compiled_model, (y, y_lengths, sid_src, sid_tgt, tau))

    return voice_conversion_impl
# Route the expensive model calls through OpenVINO-compiled models on the chosen device.
selected_device = device.value
en_base_speaker_tts.model.infer = get_pathched_infer(ov_en_tts, selected_device)
tone_color_converter.model.voice_conversion = get_patched_voice_conversion(ov_voice_conversion, selected_device)
if enable_chinese_lang:
    zh_base_speaker_tts.model.infer = get_pathched_infer(ov_zh_tts, selected_device)
Run inference#
# Where the speech to re-voice comes from: synthesized by the base TTS ("use TTS")
# or an uploaded recording ("choose_manually").
voice_source = widgets.Dropdown(
    options=["use TTS", "choose_manually"],
    value="use TTS",
    description="Voice source",
    disabled=False,
)

voice_source
Dropdown(description='Voice source', options=('use TTS', 'choose_manually'), value='use TTS')
# Show an upload widget only when the user chose to provide the source speech manually.
if voice_source.value == "choose_manually":
    upload_orig_voice = widgets.FileUpload(
        accept=allowed_audio_types,
        multiple=False,
        description="audio whose tone will be replaced",
    )
    display(upload_orig_voice)
# Produce the source speech (orig_voice_path) and its tone color embedding (source_se).
if voice_source.value == "choose_manually":
    # Clear error instead of an IndexError when nothing has been uploaded yet.
    if not upload_orig_voice.value:
        raise ValueError("Please upload the audio whose tone should be replaced before running this cell")
    orig_voice_path = f"{OUTPUT_DIR}/{upload_orig_voice.value[0].name}"
    save_audio(upload_orig_voice, orig_voice_path)
    source_se, _ = se_extractor.get_se(orig_voice_path, tone_color_converter, target_dir=OUTPUT_DIR, vad=True)
else:
    text = """
OpenVINO toolkit is a comprehensive toolkit for quickly developing applications and solutions that solve
a variety of tasks including emulation of human vision, automatic speech recognition, natural language processing,
recommendation systems, and many others.
"""
    # Synthesized with the English default speaker, so its default embedding is the source.
    source_se = en_source_default_se
    orig_voice_path = OUTPUT_DIR / "tmp.wav"
    en_base_speaker_tts.tts(text, orig_voice_path, speaker="default", language="English")
> Text splitted to sentences.
OpenVINO toolkit is a comprehensive toolkit for quickly developing applications and solutions that solve a variety of tasks including emulation of human vision,
automatic speech recognition, natural language processing, recommendation systems, and many others.
> ===========================
ˈoʊpən vino* toolkit* ɪz ə ˌkɑmpɹiˈhɛnsɪv toolkit* fəɹ kˈwɪkli dɪˈvɛləpɪŋ ˌæpləˈkeɪʃənz ənd səˈluʃənz ðət sɑɫv ə vəɹˈaɪəti əv tæsks ˌɪnˈkludɪŋ ˌɛmjəˈleɪʃən əv ˈjumən ˈvɪʒən,
length:173
length:173
ˌɔtəˈmætɪk spitʃ ˌɹɛkɪgˈnɪʃən, ˈnætʃəɹəɫ ˈlæŋgwɪdʒ ˈpɹɑsɛsɪŋ, ˌɹɛkəmənˈdeɪʃən ˈsɪstəmz, ənd ˈmɛni ˈəðəɹz.
length:105
length:105
And finally, run voice tone conversion with OpenVINO optimized model
# `tau` value (0.01–2.0) passed to tone_color_converter.convert in the next cell.
tau_slider = widgets.FloatSlider(
    value=0.3,
    min=0.01,
    max=2.0,
    step=0.01,
    description="tau",
    disabled=False,
    readout_format=".2f",
)

tau_slider
FloatSlider(value=0.3, description='tau', max=2.0, min=0.01, step=0.01)
resulting_voice_path = OUTPUT_DIR / "output_with_cloned_voice_tone.wav"

# Re-voice the source speech with the reference speaker's tone color (OpenVINO-backed).
tone_color_converter.convert(
    audio_src_path=orig_voice_path,
    src_se=source_se,
    tgt_se=target_se,
    output_path=resulting_voice_path,
    tau=tau_slider.value,
    message="@MyShell",
)

# Explicit display() so both players render even when run in one cell
# (a bare expression is only rendered if it is the cell's last value).
display(Audio(orig_voice_path))
display(Audio(resulting_voice_path))
Run OpenVoice Gradio interactive demo#
We can also use Gradio app to run TTS and voice tone conversion online.
import gradio as gr
import langid

# Language codes accepted by the demo; compared against langid's detection in predict_impl.
supported_languages = ["zh", "en"]
def predict_impl(
    prompt,
    style,
    audio_file_pth,
    agree,
    output_dir,
    tone_color_converter,
    en_tts_model,
    zh_tts_model,
    en_source_default_se,
    en_source_style_se,
    zh_source_se,
):
    """Synthesize *prompt* and re-voice it with the tone color of a reference recording.

    The text is synthesized with the English or Chinese base speaker (picked from the
    ``langid`` detection). The target tone color is extracted from *audio_file_pth*, and
    the synthesized audio is converted to that tone.

    Args:
        prompt: text to synthesize (must be 2..200 characters).
        style: speaker style; only "default" for Chinese, one of the listed styles for English.
        audio_file_pth: path of the reference audio whose tone color is copied.
        agree: whether the user accepted the terms; nothing runs otherwise.
        output_dir: directory for the intermediate/final wav files and se_extractor outputs.
        tone_color_converter: ToneColorConverter used for embedding extraction and conversion.
        en_tts_model, zh_tts_model: base speaker TTS objects (``zh_tts_model`` may be None).
        en_source_default_se, en_source_style_se, zh_source_se: source speaker embeddings;
            ``en_source_style_se`` is used for non-default English styles.

    Returns:
        ``(text_hint, src_path, save_path)``. Every error path returns
        ``(text_hint, None, None)``, the same 3-element shape as the success path.
    """
    text_hint = ""
    if not agree:
        text_hint += "[ERROR] Please accept the Terms & Condition!\n"
        gr.Warning("Please accept the Terms & Condition!")
        return (text_hint, None, None)

    language_predicted = langid.classify(prompt)[0].strip()
    print(f"Detected language:{language_predicted}")

    if language_predicted not in supported_languages:
        text_hint += f"[ERROR] The detected language {language_predicted} for your input text is not in our Supported Languages: {supported_languages}\n"
        gr.Warning(f"The detected language {language_predicted} for your input text is not in our Supported Languages: {supported_languages}")
        return (text_hint, None, None)

    if language_predicted == "zh":
        tts_model = zh_tts_model
        if zh_tts_model is None:
            text_hint += "[ERROR] TTS model for Chinese language was not loaded, please set `enable_chinese_lang=True`\n"
            gr.Warning("TTS model for Chinese language was not loaded, please set `enable_chinese_lang=True`")
            return (text_hint, None, None)
        source_se = zh_source_se
        language = "Chinese"
        if style not in ["default"]:
            text_hint += f"[ERROR] The style {style} is not supported for Chinese, which should be in ['default']\n"
            gr.Warning(f"The style {style} is not supported for Chinese, which should be in ['default']")
            return (text_hint, None, None)
    else:
        tts_model = en_tts_model
        if style == "default":
            source_se = en_source_default_se
        else:
            source_se = en_source_style_se
        language = "English"
        supported_styles = [
            "default",
            "whispering",
            "shouting",
            "excited",
            "cheerful",
            "terrified",
            "angry",
            "sad",
            "friendly",
        ]
        if style not in supported_styles:
            text_hint += f"[ERROR] The style {style} is not supported for English, which should be in {*supported_styles,}\n"
            gr.Warning(f"The style {style} is not supported for English, which should be in {*supported_styles,}")
            return (text_hint, None, None)

    speaker_wav = audio_file_pth

    if len(prompt) < 2:
        text_hint += "[ERROR] Please give a longer prompt text \n"
        gr.Warning("Please give a longer prompt text")
        return (text_hint, None, None)
    if len(prompt) > 200:
        text_hint += (
            "[ERROR] Text length limited to 200 characters for this demo, please try shorter text. You can clone our open-source repo and try for your usage \n"
        )
        gr.Warning("Text length limited to 200 characters for this demo, please try shorter text. You can clone our open-source repo for your usage")
        return (text_hint, None, None)

    # Extract the target tone color from the reference audio. The broad except is
    # deliberate at this UI boundary: any failure is reported back to the user.
    try:
        target_se, _ = se_extractor.get_se(speaker_wav, tone_color_converter, target_dir=output_dir, vad=True)
    except Exception as e:
        text_hint += f"[ERROR] Get target tone color error {str(e)} \n"
        gr.Warning(f"[ERROR] Get target tone color error {str(e)} \n")
        return (text_hint, None, None)

    # Base-speaker TTS first, then convert that audio to the target tone color.
    src_path = f"{output_dir}/tmp.wav"
    tts_model.tts(prompt, src_path, speaker=style, language=language)

    save_path = f"{output_dir}/output.wav"
    encode_message = "@MyShell"
    tone_color_converter.convert(
        audio_src_path=src_path,
        src_se=source_se,
        tgt_se=target_se,
        output_path=save_path,
        message=encode_message,
    )

    text_hint += "Get response successfully \n"

    return (
        text_hint,
        src_path,
        save_path,
    )
from functools import partial

# Bind the output folder, loaded models and speaker embeddings, so the Gradio callback
# only receives the remaining UI inputs (prompt, style, audio_file_pth, agree).
predict = partial(
    predict_impl,
    output_dir=OUTPUT_DIR,
    tone_color_converter=tone_color_converter,
    en_tts_model=en_base_speaker_tts,
    zh_tts_model=zh_base_speaker_tts,
    en_source_default_se=en_source_default_se,
    en_source_style_se=en_source_style_se,
    zh_source_se=zh_source_se,
)
# Fetch the notebook's Gradio UI helper once.
if not Path("gradio_helper.py").exists():
    r = requests.get(url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/notebooks/openvoice/gradio_helper.py")
    # Fail loudly instead of saving an HTTP error page as a Python module.
    r.raise_for_status()
    with open("gradio_helper.py", "w", encoding="utf-8") as helper_file:
        helper_file.write(r.text)

from gradio_helper import make_demo

demo = make_demo(fn=predict)

try:
    demo.queue(max_size=2).launch(debug=True, height=1000)
except Exception:
    # If the local launch fails, retry with a public share link.
    demo.queue(max_size=2).launch(share=True, debug=True, height=1000)
# if you are launching remotely, specify server_name and server_port
# demo.launch(server_name='your server name', server_port='server port in int')
# Read more in the docs: https://gradio.app/docs/
# please uncomment and run this cell for stopping gradio interface
# demo.close()
Cleanup#
# import shutil
# shutil.rmtree(CKPT_BASE_PATH)
# shutil.rmtree(IRS_PATH)
# shutil.rmtree(OUTPUT_DIR)