Voice tone cloning with OpenVoice and OpenVINO#
This Jupyter notebook can be launched online in an interactive browser environment, or installed and run locally.
OpenVoice is a versatile instant voice cloning approach that needs only a short audio clip from the reference speaker to replicate its tone color and generate speech in multiple languages. OpenVoice has three main features: (i) high-quality tone color replication across multiple languages and accents; (ii) fine-grained control over voice styles, including emotion and accent, as well as other parameters such as rhythm, pauses, and intonation; (iii) zero-shot cross-lingual voice cloning, eliminating the need for the generated speech and the reference speech to be part of a massive-speaker multilingual training dataset.
More details about the model can be found on the project web page, in the paper, and in the official repository.
This notebook provides an example of converting the PyTorch OpenVoice models to OpenVINO IR and running them with OpenVINO.
Installation Instructions#
This is a self-contained example that relies solely on its own code.
We recommend running the notebook in a virtual environment. You only need a Jupyter server to start. For details, please refer to Installation Guide.
Clone repository and install requirements#
import sys
from pathlib import Path

repo_dir = Path("OpenVoice")

if not repo_dir.exists():
    !git clone https://github.com/myshell-ai/OpenVoice

    orig_english_path = Path("OpenVoice/openvoice/text/_orig_english.py")
    english_path = Path("OpenVoice/openvoice/text/english.py")
    english_path.rename(orig_english_path)

    with orig_english_path.open("r") as f:
        data = f.read()

    data = data.replace("unidecode", "anyascii")

    with english_path.open("w") as out_f:
        out_f.write(data)

# append to sys.path so that modules from the repo could be imported
sys.path.append(str(repo_dir))

# fix a problem with silero downloading and installing
with Path("OpenVoice/openvoice/se_extractor.py").open("r") as orig_file:
    data = orig_file.read()

data = data.replace("method=\"silero\"", "method=\"silero:3.0\"")

with Path("OpenVoice/openvoice/se_extractor.py").open("w") as out_f:
    out_f.write(data)
%pip install -q "librosa>=0.8.1" "wavmark>=0.0.3" "faster-whisper>=0.9.0" "pydub>=0.25.1" "whisper-timestamped>=1.14.2" "tqdm" "inflect>=7.0.0" "eng_to_ipa>=0.0.2" "pypinyin>=0.50.0" \
"cn2an>=0.5.22" "jieba>=0.42.1" "langid>=1.1.6" "gradio>=4.15" "ipywebrtc" "anyascii" "openvino>=2023.3" "torch>=2.1" "nncf>=2.11.0"
Cloning into 'OpenVoice'...
remote: Enumerating objects: 435, done.
remote: Counting objects: 100% (235/235), done.
remote: Compressing objects: 100% (110/110), done.
remote: Total 435 (delta 176), reused 126 (delta 125), pack-reused 200
Receiving objects: 100% (435/435), 3.82 MiB | 19.57 MiB/s, done.
Resolving deltas: 100% (219/219), done.
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
mobileclip 0.1.0 requires torch==1.13.1, but you have torch 2.3.1 which is incompatible.
mobileclip 0.1.0 requires torchvision==0.14.1, but you have torchvision 0.17.2+cpu which is incompatible.
torchvision 0.17.2+cpu requires torch==2.2.2, but you have torch 2.3.1 which is incompatible.
Note: you may need to restart the kernel to use updated packages.
Download checkpoints and load PyTorch model#
import os
import torch
import openvino as ov
import ipywidgets as widgets
from IPython.display import Audio
core = ov.Core()
from openvoice.api import BaseSpeakerTTS, ToneColorConverter, OpenVoiceBaseClass
import openvoice.se_extractor as se_extractor
Importing the dtw module. When using in academic works please cite:
T. Giorgino. Computing and Visualizing Dynamic Time Warping Alignments in R: The dtw Package.
J. Stat. Soft., doi:10.18637/jss.v031.i07.
CKPT_BASE_PATH = Path("checkpoints")
en_suffix = CKPT_BASE_PATH / "base_speakers/EN"
zh_suffix = CKPT_BASE_PATH / "base_speakers/ZH"
converter_suffix = CKPT_BASE_PATH / "converter"
To keep the notebook lightweight, the model for Chinese speech is not activated by default. To turn it on, set the enable_chinese_lang flag to True.
enable_chinese_lang = False
def download_from_hf_hub(filename, local_dir="./"):
    from huggingface_hub import hf_hub_download

    local_path = Path(local_dir)
    local_path.mkdir(exist_ok=True)
    hf_hub_download(repo_id="myshell-ai/OpenVoice", filename=filename, local_dir=local_path)


download_from_hf_hub(f"{converter_suffix.as_posix()}/checkpoint.pth")
download_from_hf_hub(f"{converter_suffix.as_posix()}/config.json")

download_from_hf_hub(f"{en_suffix.as_posix()}/checkpoint.pth")
download_from_hf_hub(f"{en_suffix.as_posix()}/config.json")
download_from_hf_hub(f"{en_suffix.as_posix()}/en_default_se.pth")
download_from_hf_hub(f"{en_suffix.as_posix()}/en_style_se.pth")

if enable_chinese_lang:
    download_from_hf_hub(f"{zh_suffix.as_posix()}/checkpoint.pth")
    download_from_hf_hub(f"{zh_suffix.as_posix()}/config.json")
    download_from_hf_hub(f"{zh_suffix.as_posix()}/zh_default_se.pth")
checkpoint.pth: 0%| | 0.00/131M [00:00<?, ?B/s]
checkpoints/converter/config.json: 0%| | 0.00/850 [00:00<?, ?B/s]
checkpoint.pth: 0%| | 0.00/160M [00:00<?, ?B/s]
checkpoints/base_speakers/EN/config.json: 0%| | 0.00/1.97k [00:00<?, ?B/s]
en_default_se.pth: 0%| | 0.00/1.79k [00:00<?, ?B/s]
en_style_se.pth: 0%| | 0.00/1.78k [00:00<?, ?B/s]
pt_device = "cpu"
en_base_speaker_tts = BaseSpeakerTTS(en_suffix / "config.json", device=pt_device)
en_base_speaker_tts.load_ckpt(en_suffix / "checkpoint.pth")
tone_color_converter = ToneColorConverter(converter_suffix / "config.json", device=pt_device)
tone_color_converter.load_ckpt(converter_suffix / "checkpoint.pth")
if enable_chinese_lang:
zh_base_speaker_tts = BaseSpeakerTTS(zh_suffix / "config.json", device=pt_device)
zh_base_speaker_tts.load_ckpt(zh_suffix/ "checkpoint.pth")
else:
zh_base_speaker_tts = None
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-744/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/torch/nn/utils/weight_norm.py:28: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.
warnings.warn("torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.")
Loaded checkpoint 'checkpoints/base_speakers/EN/checkpoint.pth'
missing/unexpected keys: [] []
Loaded checkpoint 'checkpoints/converter/checkpoint.pth'
missing/unexpected keys: [] []
Convert models to OpenVINO IR#
OpenVoice contains two models: BaseSpeakerTTS, which is responsible for speech generation, and ToneColorConverter, which imposes an arbitrary voice tone on the original speech. To convert them to OpenVINO IR format, we first need suitable torch.nn.Module objects. Instead of self.forward, BaseSpeakerTTS and ToneColorConverter use the custom infer and voice_conversion methods as their main entry points, so we wrap them with custom classes that inherit from torch.nn.Module.
class OVOpenVoiceBase(torch.nn.Module):
    """
    Base class for both the TTS and the voice tone conversion model: the constructor is the same for both of them.
    """

    def __init__(self, voice_model: OpenVoiceBaseClass):
        super().__init__()
        self.voice_model = voice_model
        for par in voice_model.model.parameters():
            par.requires_grad = False
class OVOpenVoiceTTS(OVOpenVoiceBase):
    """
    Constructor of this class accepts a BaseSpeakerTTS object for speech generation and wraps its 'infer' method with forward.
    """

    def get_example_input(self):
        stn_tst = self.voice_model.get_text("this is original text", self.voice_model.hps, False)
        x_tst = stn_tst.unsqueeze(0)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)])
        speaker_id = torch.LongTensor([1])
        noise_scale = torch.tensor(0.667)
        length_scale = torch.tensor(1.0)
        noise_scale_w = torch.tensor(0.6)
        return (
            x_tst,
            x_tst_lengths,
            speaker_id,
            noise_scale,
            length_scale,
            noise_scale_w,
        )

    def forward(self, x, x_lengths, sid, noise_scale, length_scale, noise_scale_w):
        return self.voice_model.model.infer(x, x_lengths, sid, noise_scale, length_scale, noise_scale_w)
class OVOpenVoiceConverter(OVOpenVoiceBase):
    """
    Constructor of this class accepts a ToneColorConverter object for voice tone conversion and wraps its 'voice_conversion' method with forward.
    """

    def get_example_input(self):
        y = torch.randn([1, 513, 238], dtype=torch.float32)
        y_lengths = torch.LongTensor([y.size(-1)])
        target_se = torch.randn(*(1, 256, 1))
        source_se = torch.randn(*(1, 256, 1))
        tau = torch.tensor(0.3)
        return (y, y_lengths, source_se, target_se, tau)

    def forward(self, y, y_lengths, sid_src, sid_tgt, tau):
        return self.voice_model.model.voice_conversion(y, y_lengths, sid_src, sid_tgt, tau)
Convert the models to OpenVINO IR and save them to the IRS_PATH folder for future use. If the IRs already exist, conversion is skipped and they are read directly.
To reduce memory consumption, weight compression can be applied using NNCF. Weight compression aims to reduce the memory footprint of a model. Models that require extensive memory to store the weights during inference can benefit from weight compression in the following ways:
enabling the inference of exceptionally large models that cannot be accommodated in the memory of the device;
improving the inference performance of the models by reducing the latency of the memory access when computing the operations with weights, for example, Linear layers.
Neural Network Compression Framework (NNCF) provides 4-bit / 8-bit mixed weight quantization as a compression method. The main difference between weight compression and full model quantization (post-training quantization) is that activations remain floating-point in the case of weight compression, which leads to better accuracy. In addition, weight compression is data-free and does not require a calibration dataset, making it easy to use.
The nncf.compress_weights function can be used to perform weight compression. The function accepts an OpenVINO model and other compression parameters.
More details about weight compression can be found in the OpenVINO documentation.
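For illustration only, the sketch below shows how additional compression parameters could be passed to nncf.compress_weights. The conversion cell that follows uses the default call, which applies 8-bit compression; the mode, ratio, and group_size values here are example settings for mixed 4-bit/8-bit compression, not the ones used in this notebook.

# Illustrative sketch only: weight compression with explicit parameters.
# The cell below uses the default nncf.compress_weights(ov_model), which applies 8-bit compression.
import nncf


def compress_int4_example(ov_model):
    # Mixed 4-bit/8-bit compression: `ratio` is the share of weights quantized to 4 bits
    # (the rest stays in 8 bits); `group_size` sets the quantization group for 4-bit weights.
    return nncf.compress_weights(
        ov_model,
        mode=nncf.CompressWeightsMode.INT4_SYM,
        ratio=0.8,
        group_size=64,
    )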
import nncf

IRS_PATH = Path("openvino_irs/")
EN_TTS_IR = IRS_PATH / "openvoice_en_tts.xml"
ZH_TTS_IR = IRS_PATH / "openvoice_zh_tts.xml"
VOICE_CONVERTER_IR = IRS_PATH / "openvoice_tone_conversion.xml"

paths = [EN_TTS_IR, VOICE_CONVERTER_IR]
models = [
    OVOpenVoiceTTS(en_base_speaker_tts),
    OVOpenVoiceConverter(tone_color_converter),
]

if enable_chinese_lang:
    models.append(OVOpenVoiceTTS(zh_base_speaker_tts))
    paths.append(ZH_TTS_IR)

ov_models = []
for model, path in zip(models, paths):
    if not path.exists():
        ov_model = ov.convert_model(model, example_input=model.get_example_input())
        ov_model = nncf.compress_weights(ov_model)
        ov.save_model(ov_model, path)
    else:
        ov_model = core.read_model(path)
    ov_models.append(ov_model)

ov_en_tts, ov_voice_conversion = ov_models[:2]
if enable_chinese_lang:
    ov_zh_tts = ov_models[-1]
INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, tensorflow, onnx, openvino
this is original text.
length:22
length:21
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-744/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/torch/cuda/__init__.py:118: UserWarning: CUDA initialization: The NVIDIA driver on your system is too old (found version 11040). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:108.)
return torch._C._cuda_getDeviceCount() > 0
No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-744/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/torch/cuda/__init__.py:619: UserWarning: Can't initialize NVML
warnings.warn("Can't initialize NVML")
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-744/.workspace/scm/ov-notebook/notebooks/openvoice/OpenVoice/openvoice/attentions.py:283: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert (
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-744/.workspace/scm/ov-notebook/notebooks/openvoice/OpenVoice/openvoice/attentions.py:346: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
pad_length = max(length - (self.window_size + 1), 0)
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-744/.workspace/scm/ov-notebook/notebooks/openvoice/OpenVoice/openvoice/attentions.py:347: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
slice_start_position = max((self.window_size + 1) - length, 0)
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-744/.workspace/scm/ov-notebook/notebooks/openvoice/OpenVoice/openvoice/attentions.py:349: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if pad_length > 0:
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-744/.workspace/scm/ov-notebook/notebooks/openvoice/OpenVoice/openvoice/transforms.py:114: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if torch.min(inputs) < left or torch.max(inputs) > right:
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-744/.workspace/scm/ov-notebook/notebooks/openvoice/OpenVoice/openvoice/transforms.py:119: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if min_bin_width * num_bins > 1.0:
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-744/.workspace/scm/ov-notebook/notebooks/openvoice/OpenVoice/openvoice/transforms.py:121: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if min_bin_height * num_bins > 1.0:
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-744/.workspace/scm/ov-notebook/notebooks/openvoice/OpenVoice/openvoice/transforms.py:171: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert (discriminant >= 0).all()
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-744/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/torch/jit/_trace.py:1116: TracerWarning: Trace had nondeterministic nodes. Did you forget call .eval() on your model? Nodes:
%3293 : Float(1, 2, 43, strides=[86, 43, 1], requires_grad=0, device=cpu) = aten::randn(%3288, %3289, %3290, %3291, %3292) # /opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-744/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/nncf/torch/dynamic_graph/wrappers.py:86:0
%5559 : Float(1, 192, 154, strides=[29568, 1, 192], requires_grad=0, device=cpu) = aten::randn_like(%m_p, %5554, %5555, %5556, %5557, %5558) # /opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-744/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/nncf/torch/dynamic_graph/wrappers.py:86:0
This may cause errors in trace checking. To disable trace checking, pass check_trace=False to torch.jit.trace()
_check_trace(
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-744/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/torch/jit/_trace.py:1116: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
The values for attribute 'shape' do not match: torch.Size([1, 1, 38912]) != torch.Size([1, 1, 38656]).
_check_trace(
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-744/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/torch/jit/_trace.py:1116: TracerWarning: Output nr 2. of the traced function does not match the corresponding output of the Python function. Detailed error:
The values for attribute 'shape' do not match: torch.Size([1, 1, 152, 43]) != torch.Size([1, 1, 151, 43]).
_check_trace(
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-744/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/torch/jit/_trace.py:1116: TracerWarning: Output nr 3. of the traced function does not match the corresponding output of the Python function. Detailed error:
The values for attribute 'shape' do not match: torch.Size([1, 1, 152]) != torch.Size([1, 1, 151]).
_check_trace(
2024-08-07 02:06:50.138216: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
INFO:nncf:Statistics of the bitwidth distribution:
┍━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┑
│ Num bits (N) │ % all parameters (layers) │ % ratio-defining parameters (layers) │
┝━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┥
│ 8 │ 100% (199 / 199) │ 100% (199 / 199) │
┕━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┙
Output()
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.base has been moved to tensorflow.python.trackable.base. The old module will be deleted in version 2.11.
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-744/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/torch/jit/_trace.py:1116: TracerWarning: Trace had nondeterministic nodes. Did you forget call .eval() on your model? Nodes:
%1596 : Float(1, 192, 238, strides=[91392, 238, 1], requires_grad=0, device=cpu) = aten::randn_like(%m, %1591, %1592, %1593, %1594, %1595) # /opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-744/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/nncf/torch/dynamic_graph/wrappers.py:86:0
This may cause errors in trace checking. To disable trace checking, pass check_trace=False to torch.jit.trace()
_check_trace(
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-744/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/torch/jit/_trace.py:1116: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
Tensor-likes are not close!
Mismatched elements: 25090 / 60928 (41.2%)
Greatest absolute difference: 0.012138169724494219 at index (0, 0, 29268) (up to 1e-05 allowed)
Greatest relative difference: 14193.028616852147 at index (0, 0, 49972) (up to 1e-05 allowed)
_check_trace(
INFO:nncf:Statistics of the bitwidth distribution:
┍━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┑
│ Num bits (N) │ % all parameters (layers) │ % ratio-defining parameters (layers) │
┝━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┥
│ 8 │ 100% (194 / 194) │ 100% (194 / 194) │
┕━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┙
Output()
Inference#
Select inference device#
core = ov.Core()

device = widgets.Dropdown(
    options=core.available_devices + ["AUTO"],
    value="AUTO",
    description="Device:",
    disabled=False,
)

device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')
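The selected device is used when compiling the converted models with OpenVINO. The compilation step itself is not shown in this section; the snippet below is a minimal sketch, and the compiled_en_tts / compiled_voice_conversion names are illustrative only.

# Minimal sketch (illustrative variable names): compile the converted IRs for the selected device.
compiled_en_tts = core.compile_model(ov_en_tts, device.value)
compiled_voice_conversion = core.compile_model(ov_voice_conversion, device.value)
if enable_chinese_lang:
    compiled_zh_tts = core.compile_model(ov_zh_tts, device.value)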
Select reference tone#
First of all, select the reference tone of voice to which the generated text will be converted: you can choose from the existing samples, record your own by selecting record_manually, or upload your own file by selecting load_manually.
REFERENCE_VOICES_PATH = f"{repo_dir}/resources/"
reference_speakers = [
    *[path for path in os.listdir(REFERENCE_VOICES_PATH) if os.path.splitext(path)[-1] == ".mp3"],
    "record_manually",
    "load_manually",
]

ref_speaker = widgets.Dropdown(
    options=reference_speakers,
    value=reference_speakers[0],
    description="reference voice from which tone color will be copied",
    disabled=False,
)

ref_speaker
Dropdown(description='reference voice from which tone color will be copied', options=('demo_speaker2.mp3', 'de…
OUTPUT_DIR = Path("outputs/")
OUTPUT_DIR.mkdir(exist_ok=True)

ref_speaker_path = f"{REFERENCE_VOICES_PATH}/{ref_speaker.value}"
allowed_audio_types = ".mp4,.mp3,.wav,.wma,.aac,.m4a,.m4b,.webm"

if ref_speaker.value == "record_manually":
    ref_speaker_path = OUTPUT_DIR / "custom_example_sample.webm"
    from ipywebrtc import AudioRecorder, CameraStream

    camera = CameraStream(constraints={"audio": True, "video": False})
    recorder = AudioRecorder(stream=camera, filename=ref_speaker_path, autosave=True)
    display(recorder)
elif ref_speaker.value == "load_manually":
    upload_ref = widgets.FileUpload(
        accept=allowed_audio_types,
        multiple=False,
        description="Select audio with reference voice",
    )
    display(upload_ref)
Play the reference voice sample before cloning its tone to other speech.
def save_audio(voice_source: widgets.FileUpload, out_path: str):
    with open(out_path, "wb") as output_file:
        assert len(voice_source.value) > 0, "Please select audio file"
        output_file.write(voice_source.value[0]["content"])


if ref_speaker.value == "load_manually":
    ref_speaker_path = f"{OUTPUT_DIR}/{upload_ref.value[0].name}"
    save_audio(upload_ref, ref_speaker_path)
Audio(ref_speaker_path)
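A natural next step, not shown in this section, is to extract the tone color embedding of the selected reference speaker with the se_extractor module imported earlier. The sketch below assumes the get_se helper from the cloned OpenVoice repository keeps its usual (audio_path, converter, target_dir, vad) signature and returns the embedding together with a processed audio name.

# Sketch, assuming the OpenVoice se_extractor.get_se helper keeps its usual signature.
# target_se is the tone color embedding that the converter later imposes on generated speech.
target_se, audio_name = se_extractor.get_se(str(ref_speaker_path), tone_color_converter, target_dir=str(OUTPUT_DIR), vad=True)
print(f"Extracted tone color embedding of shape {tuple(target_se.shape)} from {audio_name}")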