Skip to content

Commit

Permalink
multispeaker
Browse files Browse the repository at this point in the history
  • Loading branch information
twerkmeister committed Jun 26, 2019
1 parent 118fe61 commit d172a3d
Show file tree
Hide file tree
Showing 14 changed files with 231 additions and 81 deletions.
1 change: 1 addition & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.git/
22 changes: 8 additions & 14 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,23 +1,17 @@
FROM nvidia/cuda:9.0-base-ubuntu16.04 as base
FROM pytorch/pytorch:1.0.1-cuda10.0-cudnn7-runtime

WORKDIR /srv/app

RUN apt-get update && \
apt-get install -y git software-properties-common wget vim build-essential libsndfile1 && \
add-apt-repository ppa:deadsnakes/ppa && \
apt-get update && \
apt-get install -y python3.6 python3.6-dev python3.6-tk && \
# Install pip manually
wget https://bootstrap.pypa.io/get-pip.py && \
python3.6 get-pip.py && \
rm get-pip.py && \
# Used by the server in server/synthesizer.py
pip install soundfile
apt-get install -y libsndfile1 espeak && \
apt-get clean && \
rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*

ADD . /srv/app
# Copy Source later to enable dependency caching
COPY requirements.txt /srv/app/
RUN pip install -r requirements.txt

# Setup for development
RUN python3.6 setup.py develop
COPY . /srv/app

# http://bugs.python.org/issue19846
# > At the moment, setting "LANG=C" on a Linux system *fundamentally breaks Python 3*, and that's not OK.
Expand Down
12 changes: 9 additions & 3 deletions datasets/TTSDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def __init__(self,
ap (TTS.utils.AudioProcessor): audio processor object.
preprocessor (dataset.preprocess.Class): preprocessor for the dataset.
Create your own if you need to run a new dataset.
speaker_id_cache_path (str): path where the speaker name to id
mapping is stored
batch_group_size (int): (0) range of batch randomization after sorting
sequences by length.
min_seq_len (int): (0) minimum sequence length to be processed
Expand Down Expand Up @@ -105,7 +107,7 @@ def load_phoneme_sequence(self, wav_file, text):
return text

def load_data(self, idx):
text, wav_file = self.items[idx]
text, wav_file, speaker_name = self.items[idx]
wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)

if self.use_phonemes:
Expand All @@ -120,7 +122,8 @@ def load_data(self, idx):
sample = {
'text': text,
'wav': wav,
'item_idx': self.items[idx][1]
'item_idx': self.items[idx][1],
'speaker_name': speaker_name
}
return sample

Expand Down Expand Up @@ -182,6 +185,8 @@ def collate_fn(self, batch):
batch[idx]['item_idx'] for idx in ids_sorted_decreasing
]
text = [batch[idx]['text'] for idx in ids_sorted_decreasing]
speaker_name = [batch[idx]['speaker_name']
for idx in ids_sorted_decreasing]

mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
Expand Down Expand Up @@ -219,7 +224,8 @@ def collate_fn(self, batch):
mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets)

return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs
return text, text_lenghts, speaker_name, linear, mel, mel_lengths, \
stop_targets, item_idxs

raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
found {}".format(type(batch[0]))))
26 changes: 18 additions & 8 deletions datasets/preprocess.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from glob import glob
import re


def tweb(root_path, meta_file):
    """Normalize TWEB meta data to TTS format.

    Each meta line is "<wav id>\t<text>". The wav id is resolved to a
    .wav path under root_path, and every item is tagged with the fixed
    speaker name "tweb" (single-speaker dataset).

    Args:
        root_path (str): dataset root directory.
        meta_file (str): meta file name, relative to root_path.

    Returns:
        list: entries of the form [text, wav_file_path, speaker_name].
    """
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "tweb"
    with open(txt_file, 'r') as ttf:
        for line in ttf:
            cols = line.split('\t')
            wav_file = os.path.join(root_path, cols[0] + '.wav')
            # NOTE(review): text keeps its trailing newline (no strip) —
            # preserved as-is for compatibility with downstream cleaning.
            text = cols[1]
            items.append([text, wav_file, speaker_name])
    return items


def mozilla_old(root_path, meta_file):
    """Normalize old-style Mozilla meta data files to TTS format.

    Each meta line is "<text>|<batch>_<clip>.wav"; the numeric batch
    prefix selects the "batch<N>/wavs_no_processing" folder. All items
    are tagged with the fixed speaker name "mozilla_old".

    Args:
        root_path (str): dataset root directory.
        meta_file (str): meta file name, relative to root_path.

    Returns:
        list: entries of the form [text, wav_file_path, speaker_name].
    """
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "mozilla_old"
    with open(txt_file, 'r') as ttf:
        for line in ttf:
            cols = line.split('|')
            # e.g. "3_0001.wav" -> batch 3 -> folder "batch3"
            batch_no = int(cols[1].strip().split("_")[0])
            wav_folder = "batch{}".format(batch_no)
            wav_file = os.path.join(root_path, wav_folder,
                                    "wavs_no_processing", cols[1].strip())
            text = cols[0].strip()
            items.append([text, wav_file, speaker_name])
    return items


def mozilla(root_path, meta_file):
    """Normalize Mozilla meta data files to TTS format.

    Each meta line is "<text>|<wav name>"; wav files live under the
    "wavs" folder. All items are tagged with the fixed speaker name
    "mozilla" (single-speaker dataset).

    Args:
        root_path (str): dataset root directory.
        meta_file (str): meta file name, relative to root_path.

    Returns:
        list: entries of the form [text, wav_file_path, speaker_name].
    """
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "mozilla"
    with open(txt_file, 'r') as ttf:
        for line in ttf:
            cols = line.split('|')
            wav_file = os.path.join(root_path, "wavs", cols[1].strip())
            text = cols[0].strip()
            items.append([text, wav_file, speaker_name])
    return items


def mailabs(root_path, meta_files):
"""Normalizes M-AI-Labs meta data files to TTS format"""
speaker_regex = re.compile("by_book/(male|female|mix)/(?P<speaker_name>[^/]+)/")
if meta_files is None:
csv_files = glob(root_path+"/**/metadata.csv", recursive=True)
folders = [os.path.dirname(f) for f in csv_files]
else:
csv_files = meta_files
folders = [f.strip().split("by_book")[1][1:] for f in csv_file]
folders = [f.strip().split("by_book")[1][1:] for f in csv_files]
# meta_files = [f.strip() for f in meta_files.split(",")]
items = []
for idx, csv_file in enumerate(csv_files):
# determine speaker based on folder structure...
speaker_name = speaker_regex.search(csv_file).group("speaker_name")
print(" | > {}".format(csv_file))
folder = folders[idx]
txt_file = os.path.join(root_path, csv_file)
Expand All @@ -82,7 +89,7 @@ def mailabs(root_path, meta_files):
wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), 'wavs', cols[0] + '.wav')
if os.path.isfile(wav_file):
text = cols[1].strip()
items.append([text, wav_file])
items.append([text, wav_file, speaker_name])
else:
raise RuntimeError("> File %s is not exist!"%(wav_file))
return items
def ljspeech(root_path, meta_file):
    """Normalize the LJSpeech meta data file to TTS format.

    Each meta line is "<wav id>|<raw text>|<normalized text>"; the raw
    text column (cols[1]) is used. All items are tagged with the fixed
    speaker name "ljspeech" (single-speaker dataset).

    Args:
        root_path (str): dataset root directory.
        meta_file (str): meta file name, relative to root_path.

    Returns:
        list: entries of the form [text, wav_file_path, speaker_name].
    """
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "ljspeech"
    with open(txt_file, 'r') as ttf:
        for line in ttf:
            cols = line.split('|')
            wav_file = os.path.join(root_path, 'wavs', cols[0] + '.wav')
            text = cols[1]
            items.append([text, wav_file, speaker_name])
    return items


def nancy(root_path, meta_file):
    """Normalize the Nancy meta data file to TTS format.

    Each meta line looks like '( <utt id> "<text>" )'; the utterance id
    names a wav file under the "wavn" folder. All items are tagged with
    the fixed speaker name "nancy" (single-speaker dataset).

    Args:
        root_path (str): dataset root directory.
        meta_file (str): meta file name, relative to root_path.

    Returns:
        list: entries of the form [text, wav_file_path, speaker_name].
    """
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "nancy"
    with open(txt_file, 'r') as ttf:
        for line in ttf:
            # second whitespace-separated token is the utterance id
            utt_id = line.split()[1]
            # text between the first and last double quote.
            # NOTE(review): the "- 1" drops the character just before the
            # closing quote (looks like an off-by-one); preserved as-is
            # for behavioral compatibility — confirm against the corpus.
            text = line[line.find('"') + 1:line.rfind('"') - 1]
            wav_file = os.path.join(root_path, "wavn", utt_id + ".wav")
            items.append([text, wav_file, speaker_name])
    return items


Expand All @@ -124,6 +133,7 @@ def common_voice(root_path, meta_file):
continue
cols = line.split("\t")
text = cols[2]
speaker_name = cols[0]
wav_file = os.path.join(root_path, "clips", cols[1] + ".wav")
items.append([text, wav_file])
items.append([text, wav_file, speaker_name])
return items
22 changes: 20 additions & 2 deletions models/tacotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
class Tacotron(nn.Module):
def __init__(self,
num_chars,
num_speakers,
r=5,
linear_dim=1025,
mel_dim=80,
Expand All @@ -28,6 +29,9 @@ def __init__(self,
self.linear_dim = linear_dim
self.embedding = nn.Embedding(num_chars, 256)
self.embedding.weight.data.normal_(0, 0.3)
self.speaker_embedding = nn.Embedding(num_speakers,
256)
self.speaker_embedding.weight.data.normal_(0, 0.3)
self.encoder = Encoder(256)
self.decoder = Decoder(256, mel_dim, r, memory_size, attn_win,
attn_norm, prenet_type, prenet_dropout,
Expand All @@ -38,22 +42,36 @@ def __init__(self,
nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
nn.Sigmoid())

def forward(self, characters, speaker_ids, text_lengths, mel_specs):
    """Teacher-forced forward pass of Tacotron.

    Args:
        characters: int tensor of character ids; presumably
            (batch, max_text_len) — confirm against the data loader.
        speaker_ids: int tensor of per-sample speaker ids, one per
            batch element.
        text_lengths: per-sample lengths of the character sequences,
            used to mask padded positions.
        mel_specs: ground-truth mel frames for teacher forcing.

    Returns:
        (mel_outputs, linear_outputs, alignments, stop_tokens).
    """
    B = characters.size(0)
    # mask padded character positions for attention
    mask = sequence_mask(text_lengths).to(characters.device)
    inputs = self.embedding(characters)
    encoder_outputs = self.encoder(inputs)
    # look up one embedding per speaker and broadcast it over every
    # encoder timestep, then add it to the encoder output
    speaker_embeddings = self.speaker_embedding(speaker_ids)
    speaker_embeddings.unsqueeze_(1)
    speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
                                                   encoder_outputs.size(1),
                                                   -1)
    encoder_outputs += speaker_embeddings
    mel_outputs, alignments, stop_tokens = self.decoder(
        encoder_outputs, mel_specs, mask)
    # reshape grouped decoder frames: (B, T/r, r*mel_dim) -> (B, T, mel_dim)
    mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
    linear_outputs = self.postnet(mel_outputs)
    linear_outputs = self.last_linear(linear_outputs)
    return mel_outputs, linear_outputs, alignments, stop_tokens

def inference(self, characters):
def inference(self, characters, speaker_ids):
B = characters.size(0)
inputs = self.embedding(characters)
encoder_outputs = self.encoder(inputs)
speaker_embeddings = self.speaker_embedding(speaker_ids)

speaker_embeddings.unsqueeze_(1)
speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
encoder_outputs.size(1),
-1)
encoder_outputs += speaker_embeddings
mel_outputs, alignments, stop_tokens = self.decoder.inference(
encoder_outputs)
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
Expand Down
31 changes: 28 additions & 3 deletions models/tacotron2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
class Tacotron2(nn.Module):
def __init__(self,
num_chars,
num_speakers,
r,
attn_win=False,
attn_norm="softmax",
Expand All @@ -28,6 +29,8 @@ def __init__(self,
std = sqrt(2.0 / (num_chars + 512))
val = sqrt(3.0) * std # uniform bounds for std
self.embedding.weight.data.uniform_(-val, val)
self.speaker_embedding = nn.Embedding(num_speakers, 512)
self.speaker_embedding.weight.data.normal_(0, 0.3)
self.encoder = Encoder(512)
self.decoder = Decoder(512, self.n_mel_channels, r, attn_win,
attn_norm, prenet_type, prenet_dropout,
Expand All @@ -40,11 +43,19 @@ def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):
mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
return mel_outputs, mel_outputs_postnet, alignments

def forward(self, text, speaker_ids, text_lengths, mel_specs=None):
    """Teacher-forced forward pass of Tacotron2.

    Args:
        text: int tensor of character ids; presumably
            (batch, max_text_len) — confirm against the data loader.
        speaker_ids: int tensor of per-sample speaker ids.
        text_lengths: per-sample lengths of the character sequences,
            used to mask padded positions.
        mel_specs: ground-truth mel frames for teacher forcing.

    Returns:
        (mel_outputs, mel_outputs_postnet, alignments, stop_tokens).
    """
    # compute mask for padding
    mask = sequence_mask(text_lengths).to(text.device)
    embedded_inputs = self.embedding(text).transpose(1, 2)
    encoder_outputs = self.encoder(embedded_inputs, text_lengths)
    # broadcast one speaker embedding over every encoder timestep and
    # add it to the encoder output
    speaker_embeddings = self.speaker_embedding(speaker_ids)
    speaker_embeddings.unsqueeze_(1)
    speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
                                                   encoder_outputs.size(1),
                                                   -1)
    encoder_outputs += speaker_embeddings
    mel_outputs, stop_tokens, alignments = self.decoder(
        encoder_outputs, mel_specs, mask)
    # postnet predicts a residual correction on top of the decoder mels
    mel_outputs_postnet = self.postnet(mel_outputs)
    mel_outputs_postnet = mel_outputs + mel_outputs_postnet
    mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
        mel_outputs, mel_outputs_postnet, alignments)
    return mel_outputs, mel_outputs_postnet, alignments, stop_tokens

def inference(self, text, speaker_ids):
    """Inference pass of Tacotron2 (no teacher forcing, no mask).

    Args:
        text: int tensor of character ids for a batch of inputs.
        speaker_ids: int tensor of per-sample speaker ids.

    Returns:
        (mel_outputs, mel_outputs_postnet, alignments, stop_tokens).
    """
    embedded_inputs = self.embedding(text).transpose(1, 2)
    encoder_outputs = self.encoder.inference(embedded_inputs)
    # broadcast one speaker embedding over every encoder timestep and
    # add it to the encoder output (mirrors forward())
    speaker_embeddings = self.speaker_embedding(speaker_ids)
    speaker_embeddings.unsqueeze_(1)
    speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
                                                   encoder_outputs.size(1),
                                                   -1)
    encoder_outputs += speaker_embeddings
    mel_outputs, stop_tokens, alignments = self.decoder.inference(
        encoder_outputs)
    # postnet predicts a residual correction on top of the decoder mels
    mel_outputs_postnet = self.postnet(mel_outputs)
    mel_outputs_postnet = mel_outputs + mel_outputs_postnet
    mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
        mel_outputs, mel_outputs_postnet, alignments)
    return mel_outputs, mel_outputs_postnet, alignments, stop_tokens

def inference_truncated(self, text):
def inference_truncated(self, text, speaker_ids):
"""
Preserve model states for continuous inference
"""
embedded_inputs = self.embedding(text).transpose(1, 2)
encoder_outputs = self.encoder.inference_truncated(embedded_inputs)
speaker_embeddings = self.speaker_embedding(speaker_ids)

speaker_embeddings.unsqueeze_(1)
speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
encoder_outputs.size(1),
-1)
encoder_outputs += speaker_embeddings
mel_outputs, stop_tokens, alignments = self.decoder.inference_truncated(
encoder_outputs)
mel_outputs_postnet = self.postnet(mel_outputs)
Expand Down
Loading

0 comments on commit d172a3d

Please sign in to comment.