Skip to content

Commit

Permalink
add doc about how to convert mms-tts models to sherpa-onnx (#519)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Dec 14, 2023
1 parent 18fcf08 commit 66193b5
Show file tree
Hide file tree
Showing 5 changed files with 277 additions and 0 deletions.
Binary file added docs/source/_static/mms/mms-eng.wav
Binary file not shown.
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,4 +161,5 @@ def get_version():
.. _aishell3: https://www.openslr.org/93/
.. _lessac_blizzard2013: https://www.cstr.ed.ac.uk/projects/blizzard/2013/lessac_blizzard2013/
.. _OpenFst: https://www.openfst.org/
.. _MMS: https://huggingface.co/spaces/mms-meta/MMS
"""
150 changes: 150 additions & 0 deletions docs/source/onnx/tts/code/vits-mms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
#!/usr/bin/env python3

import collections
import os
from typing import Any, Dict

import onnx
import torch
from vits import commons, utils
from vits.models import SynthesizerTrn


class OnnxModel(torch.nn.Module):
def __init__(self, model: SynthesizerTrn):
super().__init__()
self.model = model

def forward(
self,
x,
x_lengths,
noise_scale=0.667,
length_scale=1.0,
noise_scale_w=0.8,
):
return self.model.infer(
x=x,
x_lengths=x_lengths,
noise_scale=noise_scale,
length_scale=length_scale,
noise_scale_w=noise_scale_w,
)[0]


def add_meta_data(filename: str, meta_data: Dict[str, Any]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = str(value)

onnx.save(model, filename)


def load_vocab():
return [
x.replace("\n", "") for x in open("vocab.txt", encoding="utf-8").readlines()
]


@torch.no_grad()
def main():
hps = utils.get_hparams_from_file("config.json")
is_uroman = hps.data.training_files.split(".")[-1] == "uroman"
if is_uroman:
raise ValueError("We don't support uroman!")

symbols = load_vocab()

# Now generate tokens.txt
all_upper_tokens = [i.upper() for i in symbols]
duplicate = set(
[
item
for item, count in collections.Counter(all_upper_tokens).items()
if count > 1
]
)

print("generate tokens.txt")

with open("tokens.txt", "w", encoding="utf-8") as f:
for idx, token in enumerate(symbols):
f.write(f"{token} {idx}\n")

# both upper case and lower case correspond to the same ID
if (
token.lower() != token.upper()
and len(token.upper()) == 1
and token.upper() not in duplicate
):
f.write(f"{token.upper()} {idx}\n")

net_g = SynthesizerTrn(
len(symbols),
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
**hps.model,
)
net_g.cpu()
_ = net_g.eval()

_ = utils.load_checkpoint("G_100000.pth", net_g, None)

model = OnnxModel(net_g)

x = torch.randint(low=1, high=10, size=(50,), dtype=torch.int64)
x = x.unsqueeze(0)

x_length = torch.tensor([x.shape[1]], dtype=torch.int64)
noise_scale = torch.tensor([1], dtype=torch.float32)
length_scale = torch.tensor([1], dtype=torch.float32)
noise_scale_w = torch.tensor([1], dtype=torch.float32)

opset_version = 13

filename = "model.onnx"

torch.onnx.export(
model,
(x, x_length, noise_scale, length_scale, noise_scale_w),
filename,
opset_version=opset_version,
input_names=[
"x",
"x_length",
"noise_scale",
"length_scale",
"noise_scale_w",
],
output_names=["y"],
dynamic_axes={
"x": {0: "N", 1: "L"}, # n_audio is also known as batch_size
"x_length": {0: "N"},
"y": {0: "N", 2: "L"},
},
)
meta_data = {
"model_type": "vits",
"comment": "mms",
"url": "https://huggingface.co/facebook/mms-tts/tree/main",
"add_blank": int(hps.data.add_blank),
"language": os.environ.get("language", "unknown"),
"frontend": "characters",
"n_speakers": int(hps.data.n_speakers),
"sample_rate": hps.data.sampling_rate,
}
print("meta_data", meta_data)
add_meta_data(filename=filename, meta_data=meta_data)


main()
1 change: 1 addition & 0 deletions docs/source/onnx/tts/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ to install `sherpa-onnx`_ before you continue.
./hf-space.rst
./pretrained_models/index
./piper
./mms
./faq
125 changes: 125 additions & 0 deletions docs/source/onnx/tts/mms.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
MMS
===

This section describes how to convert models
from `<https://huggingface.co/facebook/mms-tts/tree/main>`_
to `sherpa-onnx`_.

Note that `facebook/mms-tts <https://huggingface.co/facebook/mms-tts/tree/main>`_
supports more than 1000 languages. You can try models from
`facebook/mms-tts <https://huggingface.co/facebook/mms-tts/tree/main>`_ at
the huggingface space `<https://huggingface.co/spaces/mms-meta/MMS>`_.

You can try the converted models by visiting `<https://huggingface.co/spaces/k2-fsa/text-to-speech>`_.
To download the converted models, please visit `<https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models>`_.
If a filename contains ``vits-mms``, it means the model is from
`facebook/mms-tts <https://huggingface.co/facebook/mms-tts/tree/main>`_.

Install dependencies
--------------------

.. code-block:: bash
pip install -qq onnx scipy Cython
pip install -qq torch==1.13.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
Download the model file
-----------------------

Suppose that we want to convert the English model, we need to
use the following commands to download the model:

.. code-block:: bash
name=eng
wget -q https://huggingface.co/facebook/mms-tts/resolve/main/models/$name/G_100000.pth
wget -q https://huggingface.co/facebook/mms-tts/resolve/main/models/$name/config.json
wget -q https://huggingface.co/facebook/mms-tts/resolve/main/models/$name/vocab.txt
Download MMS source code
------------------------

.. code-block:: bash
git clone https://huggingface.co/spaces/mms-meta/MMS
export PYTHONPATH=$PWD/MMS:$PYTHONPATH
export PYTHONPATH=$PWD/MMS/vits:$PYTHONPATH
pushd MMS/vits/monotonic_align
python3 setup.py build
ls -lh build/
ls -lh build/lib*/
ls -lh build/lib*/*/
cp build/lib*/vits/monotonic_align/core*.so .
sed -i.bak s/.monotonic_align.core/.core/g ./__init__.py
popd
Convert the model
-----------------

Please save the following code into a file with name ``./vits-mms.py``:

.. literalinclude:: ./code/vits-mms.py

The you can run it with:

.. code-block:: bash
export PYTHONPATH=$PWD/MMS:$PYTHONPATH
export PYTHONPATH=$PWD/MMS/vits:$PYTHONPATH
export lang=eng
python3 ./vits-mms.py
It will generate the following two files:

- ``model.onnx``
- ``tokens.txt``

Use the converted model
-----------------------

We can use the converted model with the following command after installing
`sherpa-onnx`_.

.. code-block:: bash
./build/bin/sherpa-onnx-offline-tts \
--vits-model=./model.onnx \
--vits-tokens=./tokens.txt \
--debug=1 \
--output-filename=./mms-eng.wav \
"How are you doing today? This is a text-to-speech application using models from facebook with next generation Kaldi"
The above command should generate a wave file ``mms-eng.wav``.

.. raw:: html

<table>
<tr>
<th>Wave filename</th>
<th>Content</th>
<th>Text</th>
</tr>
<tr>
<td>mms-eng.wav</td>
<td>
<audio title="Generated ./mms-eng.wav" controls="controls">
<source src="/sherpa/_static/mms/mms-eng.wav" type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
<td>
How are you doing today? This is a text-to-speech application using models from facebook with next generation Kaldi
</td>
</tr>
</table>


Congratulations! You have successfully converted a model from `MMS`_ and run it with `sherpa-onnx`_.

We are using ``eng`` in this section as an example, you can replace it with other languages, such as
``deu`` for German, ``fra`` for French, etc.

0 comments on commit 66193b5

Please sign in to comment.