Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev pr2 : handle multi-speaker and GST in synthesizer class #5

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions TTS/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from TTS.utils.synthesizer import Synthesizer
from TTS.utils.manage import ModelManager
from TTS.utils.io import load_config
from TTS.utils.generic_utils import style_wav_uri_to_dict


def create_argparser():
Expand Down Expand Up @@ -75,14 +76,22 @@ def convert_boolean(x):
if not args.vocoder_config and os.path.isfile(vocoder_config_file):
args.vocoder_config = vocoder_config_file


synthesizer = Synthesizer(args.tts_checkpoint, args.tts_config, args.vocoder_checkpoint, args.vocoder_config, args.use_cuda)

use_speaker_embedding = synthesizer.tts_config.get("use_external_speaker_embedding_file", False)
use_gst = synthesizer.tts_config.get("use_gst", False)
app = Flask(__name__)


@app.route('/')
def index():
return render_template('index.html', show_details=args.show_details)
return render_template(
'index.html',
show_details=args.show_details,
use_speaker_embedding=use_speaker_embedding,
use_gst = use_gst
)

@app.route('/details')
def details():
Expand All @@ -102,8 +111,12 @@ def details():
@app.route('/api/tts', methods=['GET'])
def tts():
text = request.args.get('text')
speaker_json_key = request.args.get('speaker', "")
style_wav = request.args.get('style-wav', "")

style_wav = style_wav_uri_to_dict(style_wav)
print(" > Model input: {}".format(text))
wavs = synthesizer.tts(text)
wavs = synthesizer.tts(text, speaker_json_key=speaker_json_key, style_wav=style_wav)
out = io.BytesIO()
synthesizer.save_wav(wavs, out)
return send_file(out, mimetype='audio/wav')
Expand Down
25 changes: 21 additions & 4 deletions TTS/server/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@

<ul class="list-unstyled">
</ul>
{%if use_speaker_embedding%}
<input id="speaker-json-key" placeholder="speaker json key.." size=45 type="text" name="speaker-json-key">
{%endif%}

{%if use_gst%}
<input value='{"0": 0.1}' id="style-wav" placeholder="style wav (dict or path to wav).." size=45 type="text" name="style-wav">
{%endif%}

<input id="text" placeholder="Type here..." size=45 type="text" name="text">
<button id="speak-button" name="speak">Speak</button><br/><br/>
{%if show_details%}
Expand All @@ -73,15 +81,24 @@

<!-- Bootstrap core JavaScript -->
<script>
// Return the value of the input matched by textId, or "" when the
// element is absent from the DOM (e.g. the speaker / style inputs are
// only rendered when the model supports those features).
function getTextValue(textId) {
    const el = q(textId)
    return el ? el.value : ""
}
// Shorthand for document.querySelector.
function q(selector) {return document.querySelector(selector)}
// Focus the main text input on page load.
q('#text').focus()
function do_tts(e) {
text = q('#text').value
const text = q('#text').value
const speakerJsonKey = getTextValue('#speaker-json-key')
const styleWav = getTextValue('#style-wav')
if (text) {
q('#message').textContent = 'Synthesizing...'
q('#speak-button').disabled = true
q('#audio').hidden = true
synthesize(text)
synthesize(text, speakerJsonKey, styleWav)
}
e.preventDefault()
return false
Expand All @@ -92,8 +109,8 @@
do_tts(e)
}
})
function synthesize(text) {
fetch('/api/tts?text=' + encodeURIComponent(text), {cache: 'no-cache'})
function synthesize(text, speakerJsonKey="", styleWav="") {
fetch(`/api/tts?text=${encodeURIComponent(text)}&speaker=${encodeURIComponent(speakerJsonKey)}&style-wav=${encodeURIComponent(styleWav)}` , {cache: 'no-cache'})
.then(function(res) {
if (!res.ok) throw Error(res.statusText)
return res.blob()
Expand Down
23 changes: 11 additions & 12 deletions TTS/tts/utils/synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,22 @@ def compute_style_mel(style_wav, ap, cuda=False):

def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None):
if 'tacotron' in CONFIG.model.lower():
if CONFIG.use_gst:
if not CONFIG.use_gst:
style_mel = None

if truncated:
decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
else:
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
else:
if truncated:
decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
else:
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
elif 'glow' in CONFIG.model.lower():
inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) # pylint: disable=not-callable
if hasattr(model, 'module'):
# distributed model
postnet_output, _, _, _, alignments, _, _ = model.module.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings)
postnet_output, _, _, _, alignments, _, _ = model.module.inference(inputs, inputs_lengths, speaker_embedding_g)
else:
postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings)
postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths, speaker_embedding_g)
postnet_output = postnet_output.permute(0, 2, 1)
# these only belong to tacotron models.
decoder_output = None
Expand All @@ -77,9 +76,9 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel
inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) # pylint: disable=not-callable
if hasattr(model, 'module'):
# distributed model
postnet_output, alignments= model.module.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings)
postnet_output, alignments= model.module.inference(inputs, inputs_lengths, speaker_embedding_g)
else:
postnet_output, alignments= model.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings)
postnet_output, alignments= model.inference(inputs, inputs_lengths, speaker_embedding_g)
postnet_output = postnet_output.permute(0, 2, 1)
# these only belong to tacotron models.
decoder_output = None
Expand Down
19 changes: 19 additions & 0 deletions TTS/utils/generic_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import datetime
import glob
import json
import os
import shutil
import subprocess
import sys
from pathlib import Path
from typing import Union


def get_git_branch():
Expand Down Expand Up @@ -173,3 +175,20 @@ def check_argument(name, c, enum_list=None, max_val=None, min_val=None, restrict
assert is_valid or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}'
elif val_type:
assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}'


def style_wav_uri_to_dict(style_wav: str) -> Union[str, dict, None]:
    """Transform a ``style_wav`` URI into either a string (path to a wav
    file to be used for style transfer) or a dict (GST token/weight pairs
    to be used for styling).

    Args:
        style_wav (str): URI — either a server-side path to a ``.wav``
            file, a JSON-encoded dict of GST token weights, or an empty
            string (the server's default when the query parameter is
            absent).

    Returns:
        Union[str, dict, None]: path to file (str), GST style dict, or
        ``None`` when ``style_wav`` is empty.

    Raises:
        json.JSONDecodeError: if ``style_wav`` is non-empty but is
            neither an existing ``.wav`` file nor valid JSON.
    """
    # The HTTP layer passes "" when no style was requested; json.loads("")
    # would raise, so treat empty input as "no style".
    if not style_wav:
        return None

    if os.path.isfile(style_wav) and style_wav.endswith(".wav"):
        return style_wav  # style_wav is a .wav file located on the server

    # style_wav is a GST dictionary: {token1_id: token1_weight, ...}
    return json.loads(style_wav)
Loading