Skip to content

Commit

Permalink
Merge pull request #3086 from coqui-ai/xtts_trainer
Browse files Browse the repository at this point in the history
XTTS v1.1 GPT Trainer
  • Loading branch information
erogol authored Oct 25, 2023
2 parents 1e15269 + 01839af commit 16ba377
Show file tree
Hide file tree
Showing 14 changed files with 14,009 additions and 291 deletions.
53 changes: 53 additions & 0 deletions .github/workflows/xtts_tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
name: xtts-tests

on:
push:
branches:
- main
pull_request:
types: [opened, synchronize, reopened]
jobs:
check_skip:
runs-on: ubuntu-latest
if: "! contains(github.event.head_commit.message, '[ci skip]')"
steps:
- run: echo "${{ github.event.head_commit.message }}"

test:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: [3.9, "3.10", "3.11"]
experimental: [false]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
architecture: x64
cache: 'pip'
cache-dependency-path: 'requirements*'
- name: check OS
run: cat /etc/os-release
- name: set ENV
run: export TRAINER_TELEMETRY=0
- name: Install dependencies
run: |
sudo apt-get update
sudo apt-get install -y --no-install-recommends git make gcc
sudo apt-get install espeak
sudo apt-get install espeak-ng
make system-deps
- name: Install/upgrade Python setup deps
run: python3 -m pip install --upgrade pip setuptools wheel
- name: Replace scarf urls
run: |
sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
- name: Install TTS
run: |
python3 -m pip install .[all]
python3 setup.py egg_info
- name: Unit tests
run: make test_xtts
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ test_tts: ## run tts tests.
test_tts2: ## run tts tests.
nose2 -F -v -B --with-coverage --coverage TTS tests.tts_tests2

test_xtts:
nose2 -F -v -B --with-coverage --coverage TTS tests.xtts_tests

test_aux: ## run aux tests.
nose2 -F -v -B --with-coverage --coverage TTS tests.aux_tests
./run_bash_tests.sh
Expand Down
69 changes: 40 additions & 29 deletions TTS/tts/layers/xtts/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def init_gpt_for_inference(self, kv_cache=True, use_deepspeed=False):

if use_deepspeed:
import deepspeed

self.ds_engine = deepspeed.init_inference(
model=self.gpt_inference.half(), # Transformers models
mp_size=1, # Number of GPU
Expand Down Expand Up @@ -233,6 +234,7 @@ def get_logits(
prompt=None,
get_attns=False,
return_latent=False,
attn_mask_cond=None,
attn_mask_text=None,
attn_mask_mel=None,
):
Expand All @@ -248,8 +250,11 @@ def get_logits(
if attn_mask_text is not None:
attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
if prompt is not None:
attn_mask_prompt = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
attn_mask = torch.cat([attn_mask_prompt, attn_mask], dim=1)
if attn_mask_cond is not None:
attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
else:
attn_mask_cond = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)

gpt_out = self.gpt(
inputs_embeds=emb,
Expand Down Expand Up @@ -326,7 +331,7 @@ def get_prompts(self, prompt_codes):
prompt = F.pad(prompt, (0, 1), value=self.stop_prompt_token)
return prompt

def get_style_emb(self, cond_input, cond_lens=None, cond_seg_len=None, return_latent=False, sample=True):
def get_style_emb(self, cond_input, return_latent=False):
"""
cond_input: (b, 80, s) or (b, 1, 80, s)
conds: (b, 1024, s)
Expand All @@ -335,26 +340,7 @@ def get_style_emb(self, cond_input, cond_lens=None, cond_seg_len=None, return_la
if not return_latent:
if cond_input.ndim == 4:
cond_input = cond_input.squeeze(1)
if sample:
_len_secs = random.randint(2, 6) # in secs
cond_seg_len = int((22050 / 1024) * _len_secs) # in frames
if cond_input.shape[-1] >= cond_seg_len:
new_conds = []
for i in range(cond_input.shape[0]):
cond_len = int(cond_lens[i] / 1024)
if cond_len < cond_seg_len:
start = 0
else:
start = random.randint(0, cond_len - cond_seg_len)
cond_vec = cond_input[i, :, start : start + cond_seg_len]
new_conds.append(cond_vec)
conds = torch.stack(new_conds, dim=0)
else:
cond_seg_len = 5 if cond_seg_len is None else cond_seg_len # secs
cond_frame_len = int((22050 / 1024) * cond_seg_len)
conds = cond_input[:, :, -cond_frame_len:]

conds = self.conditioning_encoder(conds)
conds = self.conditioning_encoder(cond_input)
else:
# already computed
conds = cond_input.unsqueeze(1)
Expand All @@ -366,22 +352,22 @@ def forward(
text_lengths,
audio_codes,
wav_lengths,
cond_lens=None,
cond_mels=None,
cond_idxs=None,
cond_latents=None,
loss_weights=None,
return_attentions=False,
return_latent=False,
):
"""
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
(actuated by `text_first`).
cond_mels: MEL float tensor, (b, 1, 80,s)
text_inputs: long tensor, (b,t)
text_lengths: long tensor, (b,)
mel_inputs: long tensor, (b,m)
wav_lengths: long tensor, (b,)
cond_mels: MEL float tensor, (b, 1, 80,s)
cond_idxs: cond start and end indexs, (b, 2)
If return_attentions is specified, only logits are returned.
If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
Expand All @@ -393,6 +379,11 @@ def forward(
max_text_len = text_lengths.max()
code_lengths = torch.ceil(wav_lengths / self.code_stride_len).long() + 3

if cond_idxs is not None:
# recompute cond idxs for mel lengths
for idx, l in enumerate(code_lengths):
cond_idxs[idx] = cond_idxs[idx] / self.code_stride_len

# If len(codes) + 3 is larger than maxiumum allowed length, we truncate the codes.
max_mel_len = code_lengths.max()

Expand Down Expand Up @@ -435,9 +426,16 @@ def forward(
)

# Set attn_mask
attn_mask_cond = None
attn_mask_text = None
attn_mask_mel = None
if not return_latent:
attn_mask_cond = torch.ones(
cond_mels.shape[0],
cond_mels.shape[-1],
dtype=torch.bool,
device=text_inputs.device,
)
attn_mask_text = torch.ones(
text_inputs.shape[0],
text_inputs.shape[1],
Expand All @@ -451,6 +449,11 @@ def forward(
device=audio_codes.device,
)

if cond_idxs is not None:
for idx, r in enumerate(cond_idxs):
l = r[1] - r[0]
attn_mask_cond[idx, l:] = 0.0

for idx, l in enumerate(text_lengths):
attn_mask_text[idx, l + 1 :] = 0.0

Expand All @@ -465,7 +468,7 @@ def forward(

# Compute speech conditioning input
if cond_latents is None:
cond_latents = self.get_style_emb(cond_mels, cond_lens).transpose(1, 2)
cond_latents = self.get_style_emb(cond_mels).transpose(1, 2)

# Get logits
sub = -5 # don't ask me why 😄
Expand All @@ -480,6 +483,7 @@ def forward(
prompt=cond_latents,
get_attns=return_attentions,
return_latent=return_latent,
attn_mask_cond=attn_mask_cond,
attn_mask_text=attn_mask_text,
attn_mask_mel=attn_mask_mel,
)
Expand All @@ -501,6 +505,13 @@ def forward(
0
], f" ❗ mel_targets does not contain stop token ({self.stop_audio_token}) in every row."

# ignore the loss for the segment used for conditioning
# coin flip for the segment to be ignored
if cond_idxs is not None:
cond_start = cond_idxs[idx, 0]
cond_end = cond_idxs[idx, 1]
mel_targets[idx, cond_start:cond_end] = -1

# Compute losses
loss_text = F.cross_entropy(
text_logits, text_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing
Expand Down Expand Up @@ -548,7 +559,7 @@ def generate(
bos_token_id=self.start_audio_token,
pad_token_id=self.stop_audio_token,
eos_token_id=self.stop_audio_token,
max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
max_length=self.max_mel_tokens,
**hf_generate_kwargs,
)
if "return_dict_in_generate" in hf_generate_kwargs:
Expand All @@ -561,7 +572,7 @@ def get_generator(self, fake_inputs, **hf_generate_kwargs):
bos_token_id=self.start_audio_token,
pad_token_id=self.stop_audio_token,
eos_token_id=self.stop_audio_token,
max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
max_length=self.max_mel_tokens,
do_stream=True,
**hf_generate_kwargs,
)
29 changes: 9 additions & 20 deletions TTS/tts/layers/xtts/hifigan_decoder.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import torch
import torchaudio
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
import torchaudio

from TTS.utils.io import load_fsspec


LRELU_SLOPE = 0.1


Expand Down Expand Up @@ -224,9 +223,7 @@ def __init__(
self.cond_in_each_up_layer = cond_in_each_up_layer

# initial upsampling layers
self.conv_pre = weight_norm(
Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
)
self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
resblock = ResBlock1 if resblock_type == "1" else ResBlock2
# upsampling layers
self.ups = nn.ModuleList()
Expand All @@ -246,14 +243,10 @@ def __init__(
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for _, (k, d) in enumerate(
zip(resblock_kernel_sizes, resblock_dilation_sizes)
):
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d))
# post convolution layer
self.conv_post = weight_norm(
Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)
)
self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
if cond_channels > 0:
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)

Expand Down Expand Up @@ -318,9 +311,7 @@ def inference(self, c):
Tensor: [B, 1, T]
"""
c = c.to(self.conv_pre.weight.device)
c = torch.nn.functional.pad(
c, (self.inference_padding, self.inference_padding), "replicate"
)
c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
return self.forward(c)

def remove_weight_norm(self):
Expand All @@ -342,6 +333,7 @@ def load_checkpoint(
assert not self.training
self.remove_weight_norm()


class SELayer(nn.Module):
def __init__(self, channel, reduction=8):
super(SELayer, self).__init__()
Expand Down Expand Up @@ -425,10 +417,8 @@ def forward(self, x):
return torch.nn.functional.conv1d(x, self.filter).squeeze(1)



class ResNetSpeakerEncoder(nn.Module):
"""This is copied from 🐸TTS to remove it from the dependencies.
"""
"""This is copied from 🐸TTS to remove it from the dependencies."""

# pylint: disable=W0102
def __init__(
Expand Down Expand Up @@ -620,6 +610,7 @@ def load_checkpoint(
return criterion, state["step"]
return criterion


class HifiDecoder(torch.nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -724,9 +715,7 @@ def inference(self, c, g):
"""
return self.forward(c, g=g)

def load_checkpoint(
self, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
def load_checkpoint(self, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
# remove unused keys
state = state["model"]
Expand Down
Loading

0 comments on commit 16ba377

Please sign in to comment.