From 0ec98bdbf9ebc80bac9bfd9fb6ebd6ef14445114 Mon Sep 17 00:00:00 2001 From: autumn-2-net <109412646+autumn-2-net@users.noreply.github.com> Date: Sat, 23 Sep 2023 00:50:02 +0800 Subject: [PATCH] Shallow diffusion and aux decoder (#128) * Add shallow diffusion API * Support aux decoder training * Support shallow diffusion inference * add shallow farmwork * add shallow farmwork * Support lambda for aux mel loss * Move config key * add shallow farmework * add shallow farmework * add denorm * add shallow model training switch * Limit gradient from aux decoder * Improve loss calculation control flow * add independent encoder in shallow * Adjust lambda * Implement shallow diffusion There are some issues to resolve in DPM-Solver++ and UniPC * Fix missing depth assignment * fix bugs of shallow diffusion inference * Fix errors and remove debug code * Support K_step < timesteps (shallow-only diffusion) * Fix argument passing * Add missing checks * add glow decoder * add glow decoder * add convnext glow decoder * fix fs2 * Support using gt mel as source during validation * Clean files and configs * Clean and refactor aux decoder * Fix KeyError * Support exporting shallow diffusion to ONNX * Add missing logic to ONNX * Rename `diff_depth` to `K_step_infer` --------- Co-authored-by: yqzhishen Co-authored-by: autumn <2> Co-authored-by: llc1995@sina.com --- configs/acoustic.yaml | 18 +++++ configs/templates/config_acoustic.yaml | 17 +++++ deployment/exporters/acoustic_exporter.py | 88 ++++++++++++++++------- deployment/modules/diffusion.py | 39 ++++++++-- deployment/modules/toplevel.py | 49 +++++++------ inference/ds_acoustic.py | 6 +- modules/aux_decoder/__init__.py | 70 ++++++++++++++++++ modules/aux_decoder/convnext.py | 87 ++++++++++++++++++++++ modules/diffusion/ddpm.py | 60 +++++++++++----- modules/toplevel.py | 57 +++++++++++++-- scripts/infer.py | 12 +++- training/acoustic_task.py | 64 ++++++++++++----- utils/onnx_helper.py | 42 ++++++----- 13 files changed, 492 insertions(+), 117 deletions(-) create mode 100644 modules/aux_decoder/__init__.py create mode 100644 modules/aux_decoder/convnext.py diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index 88ae1b12..92c7aa7f 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -76,6 +76,24 @@ diff_decoder_type: 'wavenet' diff_loss_type: l2 schedule_type: 'linear' +# shallow diffusion +use_shallow_diffusion: false +K_step_infer: 400 + +shallow_diffusion_args: + train_aux_decoder: true + train_diffusion: true + val_gt_start: false + aux_decoder_arch: convnext + aux_decoder_args: + num_channels: 512 + num_layers: 6 + kernel_size: 7 + dropout_rate: 0.1 + aux_decoder_grad: 0.1 + +lambda_aux_mel_loss: 0.2 + # train and eval num_sanity_val_steps: 1 optimizer_args: diff --git a/configs/templates/config_acoustic.yaml b/configs/templates/config_acoustic.yaml index 0291177a..72e3c2df 100644 --- a/configs/templates/config_acoustic.yaml +++ b/configs/templates/config_acoustic.yaml @@ -40,9 +40,26 @@ augmentation_args: domain: log # or linear scale: 1.0 +K_step: 1000 residual_channels: 512 residual_layers: 20 +# shallow diffusion +use_shallow_diffusion: false +K_step_infer: 400 +shallow_diffusion_args: + train_aux_decoder: true + train_diffusion: true + val_gt_start: false + aux_decoder_arch: convnext + aux_decoder_args: + num_channels: 512 + num_layers: 6 + kernel_size: 7 + dropout_rate: 0.1 + aux_decoder_grad: 0.1 +lambda_aux_mel_loss: 0.2 + optimizer_args: lr: 0.0004 lr_scheduler_args: diff --git a/deployment/exporters/acoustic_exporter.py 
b/deployment/exporters/acoustic_exporter.py index 34cf2a01..ebfd75a1 100644 --- a/deployment/exporters/acoustic_exporter.py +++ b/deployment/exporters/acoustic_exporter.py @@ -33,12 +33,22 @@ def __init__( self.spk_map: dict = self.build_spk_map() self.vocab = TokenTextEncoder(vocab_list=build_phoneme_list()) self.model = self.build_model() - self.fs2_cache_path = self.cache_dir / 'fs2.onnx' + self.fs2_aux_cache_path = self.cache_dir / ( + 'fs2_aux.onnx' if self.model.use_shallow_diffusion else 'fs2.onnx' + ) self.diffusion_cache_path = self.cache_dir / 'diffusion.onnx' # Attributes for logging self.model_class_name = remove_suffix(self.model.__class__.__name__, 'ONNX') - self.fs2_class_name = remove_suffix(self.model.fs2.__class__.__name__, 'ONNX') + fs2_aux_cls_logging = [remove_suffix(self.model.fs2.__class__.__name__, 'ONNX')] + if self.model.use_shallow_diffusion: + fs2_aux_cls_logging.append(remove_suffix( + self.model.aux_decoder.decoder.__class__.__name__, 'ONNX' + )) + self.fs2_aux_class_name = ', '.join(fs2_aux_cls_logging) + self.aux_decoder_class_name = remove_suffix( + self.model.aux_decoder.decoder.__class__.__name__, 'ONNX' + ) if self.model.use_shallow_diffusion else None self.denoiser_class_name = remove_suffix(self.model.diffusion.denoise_fn.__class__.__name__, 'ONNX') self.diffusion_class_name = remove_suffix(self.model.diffusion.__class__.__name__, 'ONNX') @@ -86,11 +96,11 @@ def export(self, path: Path): def export_model(self, path: Path): self._torch_export_model() - fs2_onnx = self._optimize_fs2_graph(onnx.load(self.fs2_cache_path)) + fs2_aux_onnx = self._optimize_fs2_aux_graph(onnx.load(self.fs2_aux_cache_path)) diffusion_onnx = self._optimize_diffusion_graph(onnx.load(self.diffusion_cache_path)) - model_onnx = self._merge_fs2_diffusion_graphs(fs2_onnx, diffusion_onnx) + model_onnx = self._merge_fs2_aux_diffusion_graphs(fs2_aux_onnx, diffusion_onnx) onnx.save(model_onnx, path) - self.fs2_cache_path.unlink() + self.fs2_aux_cache_path.unlink() self.diffusion_cache_path.unlink() print(f'| export model => {path}') @@ -105,7 +115,7 @@ def export_attachments(self, path: Path): @torch.no_grad() def _torch_export_model(self): - # Prepare inputs for FastSpeech2 tracing + # Prepare inputs for FastSpeech2 and aux decoder tracing n_frames = 10 tokens = torch.LongTensor([[1]]).to(self.device) durations = torch.LongTensor([[n_frames]]).to(self.device) @@ -161,22 +171,30 @@ def _torch_export_model(self): 1: 'n_frames' } - # PyTorch ONNX export for FastSpeech2 - print(f'Exporting {self.fs2_class_name}...') + # PyTorch ONNX export for FastSpeech2 and aux decoder + output_names = ['condition'] + if self.model.use_shallow_diffusion: + output_names.append('aux_mel') + dynamix_axes['aux_mel'] = { + 1: 'n_frames' + } + print(f'Exporting {self.fs2_aux_class_name}...') torch.onnx.export( - self.model.view_as_fs2(), + self.model.view_as_fs2_aux(), arguments, - self.fs2_cache_path, + self.fs2_aux_cache_path, input_names=input_names, - output_names=['condition'], + output_names=output_names, dynamic_axes=dynamix_axes, opset_version=15 ) + condition = torch.rand((1, n_frames, hparams['hidden_size']), device=self.device) + # Prepare inputs for denoiser tracing and GaussianDiffusion scripting shape = (1, 1, hparams['audio_num_mel_bins'], n_frames) noise = torch.randn(shape, device=self.device) - condition = torch.rand((1, hparams['hidden_size'], n_frames), device=self.device) + x_start = torch.randn((1, n_frames, hparams['audio_num_mel_bins']),device=self.device) step = (torch.rand((1,), 
device=self.device) * hparams['K_step']).long() print(f'Tracing {self.denoiser_class_name} denoiser...') @@ -186,20 +204,24 @@ def _torch_export_model(self): ( noise, step, - condition + condition.transpose(1, 2) ) ) print(f'Scripting {self.diffusion_class_name}...') + diffusion_inputs = [ + condition, + *([x_start, 100] if self.model.use_shallow_diffusion else []) + ] diffusion = torch.jit.script( diffusion, example_inputs=[ ( - condition.transpose(1, 2), + *diffusion_inputs, 1 # p_sample branch ), ( - condition.transpose(1, 2), + *diffusion_inputs, 200 # p_sample_plms branch ) ] @@ -210,12 +232,14 @@ def _torch_export_model(self): torch.onnx.export( diffusion, ( - condition.transpose(1, 2), + *diffusion_inputs, 200 ), self.diffusion_cache_path, input_names=[ - 'condition', 'speedup' + 'condition', + *(['x_start', 'depth'] if self.model.use_shallow_diffusion else []), + 'speedup' ], output_names=[ 'mel' @@ -224,6 +248,7 @@ def _torch_export_model(self): 'condition': { 1: 'n_frames' }, + **({'x_start': {1: 'n_frames'}} if self.model.use_shallow_diffusion else {}), 'mel': { 1: 'n_frames' } @@ -252,11 +277,11 @@ def _perform_spk_mix(self, spk_mix: Dict[str, float]): ) # => [1, H] return spk_mix_embed - def _optimize_fs2_graph(self, fs2: onnx.ModelProto) -> onnx.ModelProto: - print(f'Running ONNX Simplifier on {self.fs2_class_name}...') + def _optimize_fs2_aux_graph(self, fs2: onnx.ModelProto) -> onnx.ModelProto: + print(f'Running ONNX Simplifier on {self.fs2_aux_class_name}...') fs2, check = onnxsim.simplify(fs2, include_subgraph=True) assert check, 'Simplified ONNX model could not be validated' - print(f'| optimize graph: {self.fs2_class_name}') + print(f'| optimize graph: {self.fs2_aux_class_name}') return fs2 def _optimize_diffusion_graph(self, diffusion: onnx.ModelProto) -> onnx.ModelProto: @@ -282,18 +307,33 @@ def _optimize_diffusion_graph(self, diffusion: onnx.ModelProto) -> onnx.ModelPro print(f'| optimize graph: {self.diffusion_class_name}') return diffusion - def _merge_fs2_diffusion_graphs(self, fs2: onnx.ModelProto, diffusion: onnx.ModelProto) -> onnx.ModelProto: - onnx_helper.model_add_prefixes(fs2, dim_prefix='fs2.', ignored_pattern=r'(n_tokens)|(n_frames)') + def _merge_fs2_aux_diffusion_graphs(self, fs2: onnx.ModelProto, diffusion: onnx.ModelProto) -> onnx.ModelProto: + onnx_helper.model_add_prefixes( + fs2, dim_prefix=('fs2aux.' 
if self.model.use_shallow_diffusion else 'fs2.'), + ignored_pattern=r'(n_tokens)|(n_frames)' + ) onnx_helper.model_add_prefixes(diffusion, dim_prefix='diffusion.', ignored_pattern='n_frames') - print(f'Merging {self.fs2_class_name} and {self.diffusion_class_name} ' + print(f'Merging {self.fs2_aux_class_name} and {self.diffusion_class_name} ' f'back into {self.model_class_name}...') merged = onnx.compose.merge_models( - fs2, diffusion, io_map=[('condition', 'condition')], + fs2, diffusion, io_map=[ + ('condition', 'condition'), + *([('aux_mel', 'x_start')] if self.model.use_shallow_diffusion else []), + ], prefix1='', prefix2='', doc_string='', producer_name=fs2.producer_name, producer_version=fs2.producer_version, domain=fs2.domain, model_version=fs2.model_version ) merged.graph.name = fs2.graph.name + + print(f'Running ONNX Simplifier on {self.model_class_name}...') + merged, check = onnxsim.simplify( + merged, + include_subgraph=True + ) + assert check, 'Simplified ONNX model could not be validated' + print(f'| optimize graph: {self.model_class_name}') + return merged # noinspection PyMethodMayBeStatic diff --git a/deployment/modules/diffusion.py b/deployment/modules/diffusion.py index 8905bebd..c8a03fe5 100644 --- a/deployment/modules/diffusion.py +++ b/deployment/modules/diffusion.py @@ -16,6 +16,12 @@ def extract(a, t): # noinspection PyMethodOverriding class GaussianDiffusionONNX(GaussianDiffusion): + def q_sample(self, x_start, t, noise): + return ( + extract(self.sqrt_alphas_cumprod, t) * x_start + + extract(self.sqrt_one_minus_alphas_cumprod, t) * noise + ) + def p_sample(self, x, t, cond): x_pred = self.denoise_fn(x, t, cond) x_recon = ( @@ -74,18 +80,39 @@ def p_sample_plms(self, x_prev, t, interval: int, cond, noise_list: List[Tensor] x_prev = self.plms_get_x_pred(x_prev, noise_pred_prime, t, t_prev) return noise_pred, x_prev + def norm_spec(self, x): + k = (self.spec_max - self.spec_min) / 2. + b = (self.spec_max + self.spec_min) / 2. + return (x - b) / k + def denorm_spec(self, x): - d = (self.spec_max - self.spec_min) / 2. - m = (self.spec_max + self.spec_min) / 2. - return x * d + m + k = (self.spec_max - self.spec_min) / 2. + b = (self.spec_max + self.spec_min) / 2. 
+ return x * k + b - def forward(self, condition, speedup: int): + def forward(self, condition, x_start=None, depth: int = 1000, speedup: int = 1): condition = condition.transpose(1, 2) # [1, T, H] => [1, H, T] device = condition.device n_frames = condition.shape[2] - step_range = torch.arange(0, self.k_step, speedup, dtype=torch.long, device=device).flip(0)[:, None] - x = torch.randn((1, self.num_feats, self.out_dims, n_frames), device=device) + noise = torch.randn((1, self.num_feats, self.out_dims, n_frames), device=device) + if x_start is None: + step_range = torch.arange(0, self.k_step, speedup, dtype=torch.long, device=device).flip(0)[:, None] + x = noise + else: + depth = min(depth, self.k_step) + step_range = torch.arange(0, depth, speedup, dtype=torch.long, device=device).flip(0)[:, None] + x_start = self.norm_spec(x_start).transpose(-2, -1) + if self.num_feats == 1: + x_start = x_start[:, None, :, :] + if depth >= self.timesteps: + x = noise + elif depth > 0: + x = self.q_sample( + x_start, torch.full((1,), depth - 1, device=device, dtype=torch.long), noise + ) + else: + x = x_start if speedup > 1: for t in step_range: diff --git a/deployment/modules/toplevel.py b/deployment/modules/toplevel.py index 2cbbda8f..02799721 100644 --- a/deployment/modules/toplevel.py +++ b/deployment/modules/toplevel.py @@ -1,6 +1,6 @@ -import numpy as np import copy +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -37,7 +37,7 @@ def __init__(self, vocab_size, out_dims): spec_max=hparams['spec_max'] ) - def forward_fs2( + def forward_fs2_aux( self, tokens: Tensor, durations: Tensor, @@ -46,41 +46,40 @@ def forward_fs2( gender: Tensor = None, velocity: Tensor = None, spk_embed: Tensor = None - ) -> Tensor: - return self.fs2( + ): + condition = self.fs2( tokens, durations, f0, variances=variances, gender=gender, velocity=velocity, spk_embed=spk_embed ) + if self.use_shallow_diffusion: + aux_mel_pred = self.aux_decoder(condition, infer=True) + return condition, aux_mel_pred + else: + return condition + + def forward_shallow_diffusion( + self, condition: Tensor, x_start: Tensor, + depth: int, speedup: int + ) -> Tensor: + return self.diffusion(condition, x_start=x_start, depth=depth, speedup=speedup) - def forward_diffusion(self, condition: Tensor, speedup: int) -> Tensor: - return self.diffusion(condition, speedup) + def forward_diffusion(self, condition: Tensor, speedup: int): + return self.diffusion(condition, speedup=speedup) - def view_as_fs2(self) -> nn.Module: + def view_as_fs2_aux(self) -> nn.Module: model = copy.deepcopy(self) - try: - del model.variance_embeds - del model.variance_adaptor - except AttributeError: - pass del model.diffusion - model.forward = model.forward_fs2 + model.forward = model.forward_fs2_aux return model - def view_as_adaptor(self) -> nn.Module: - model = copy.deepcopy(self) - del model.fs2 - del model.diffusion - raise NotImplementedError() - def view_as_diffusion(self) -> nn.Module: model = copy.deepcopy(self) del model.fs2 - try: - del model.variance_embeds - del model.variance_adaptor - except AttributeError: - pass - model.forward = model.forward_diffusion + if self.use_shallow_diffusion: + del model.aux_decoder + model.forward = model.forward_shallow_diffusion + else: + model.forward = model.forward_diffusion return model diff --git a/inference/ds_acoustic.py b/inference/ds_acoustic.py index b37727da..b3254046 100644 --- a/inference/ds_acoustic.py +++ b/inference/ds_acoustic.py @@ -11,7 +11,7 @@ from basics.base_svs_infer 
import BaseSVSInfer from modules.fastspeech.param_adaptor import VARIANCE_CHECKLIST from modules.fastspeech.tts_modules import LengthRegulator -from modules.toplevel import DiffSingerAcoustic +from modules.toplevel import DiffSingerAcoustic, ShallowDiffusionOutput from modules.vocoders.registry import VOCODERS from utils import load_ckpt from utils.hparams import hparams @@ -170,12 +170,12 @@ def forward_model(self, sample): ) # => [B, T, H] else: spk_mix_embed = None - mel_pred = self.model( + mel_pred: ShallowDiffusionOutput = self.model( txt_tokens, mel2ph=sample['mel2ph'], f0=sample['f0'], **variances, key_shift=sample.get('key_shift'), speed=sample.get('speed'), spk_mix_embed=spk_mix_embed, infer=True ) - return mel_pred + return mel_pred.diff_out @torch.no_grad() def run_vocoder(self, spec, **kwargs): diff --git a/modules/aux_decoder/__init__.py b/modules/aux_decoder/__init__.py new file mode 100644 index 00000000..54ceb211 --- /dev/null +++ b/modules/aux_decoder/__init__.py @@ -0,0 +1,70 @@ +import torch.nn +from torch import nn + +from .convnext import ConvNeXtDecoder +from utils import filter_kwargs + +AUX_DECODERS = { + 'convnext': ConvNeXtDecoder +} +AUX_LOSSES = { + 'convnext': nn.L1Loss +} + + +def build_aux_decoder( + in_dims: int, out_dims: int, + aux_decoder_arch: str, aux_decoder_args: dict +) -> torch.nn.Module: + decoder_cls = AUX_DECODERS[aux_decoder_arch] + kwargs = filter_kwargs(aux_decoder_args, decoder_cls) + return AUX_DECODERS[aux_decoder_arch](in_dims, out_dims, **kwargs) + + +def build_aux_loss(aux_decoder_arch): + return AUX_LOSSES[aux_decoder_arch]() + + +class AuxDecoderAdaptor(nn.Module): + def __init__(self, in_dims: int, out_dims: int, num_feats: int, + spec_min: list, spec_max: list, + aux_decoder_arch: str, aux_decoder_args: dict): + super().__init__() + self.decoder = build_aux_decoder( + in_dims=in_dims, out_dims=out_dims * num_feats, + aux_decoder_arch=aux_decoder_arch, + aux_decoder_args=aux_decoder_args + ) + self.out_dims = out_dims + self.n_feats = num_feats + if spec_min is not None and spec_max is not None: + # spec: [B, T, M] or [B, F, T, M] + # spec_min and spec_max: [1, 1, M] or [1, 1, F, M] => transpose(-3, -2) => [1, 1, M] or [1, F, 1, M] + spec_min = torch.FloatTensor(spec_min)[None, None, :].transpose(-3, -2) + spec_max = torch.FloatTensor(spec_max)[None, None, :].transpose(-3, -2) + self.register_buffer('spec_min', spec_min, persistent=False) + self.register_buffer('spec_max', spec_max, persistent=False) + + def norm_spec(self, x): + k = (self.spec_max - self.spec_min) / 2. + b = (self.spec_max + self.spec_min) / 2. + return (x - b) / k + + def denorm_spec(self, x): + k = (self.spec_max - self.spec_min) / 2. + b = (self.spec_max + self.spec_min) / 2. 
+ return x * k + b + + def forward(self, condition, infer=False): + x = self.decoder(condition, infer=infer) # [B, T, F x C] + + if self.n_feats > 1: + # This is the temporary solution since PyTorch 1.13 + # does not support exporting aten::unflatten to ONNX + # x = x.unflatten(dim=2, sizes=(self.n_feats, self.in_dims)) + x = x.reshape(-1, x.shape[1], self.n_feats, self.out_dims) # [B, T, F, C] + x = x.transpose(1, 2) # [B, F, T, C] + if infer: + x = self.denorm_spec(x) + + return x # [B, T, C] or [B, F, T, C] diff --git a/modules/aux_decoder/convnext.py b/modules/aux_decoder/convnext.py new file mode 100644 index 00000000..a03959dd --- /dev/null +++ b/modules/aux_decoder/convnext.py @@ -0,0 +1,87 @@ +from typing import Optional + +import torch +import torch.nn as nn + + +class ConvNeXtBlock(nn.Module): + """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. + + Args: + dim (int): Number of input channels. + intermediate_dim (int): Dimensionality of the intermediate layer. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. + """ + + def __init__( + self, + dim: int, + intermediate_dim: int, + layer_scale_init_value: Optional[float] = None, drop_out: float = 0.0 + + ): + super().__init__() + self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv + + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(intermediate_dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path = nn.Identity() + self.dropout = nn.Dropout(drop_out) if drop_out > 0. 
else nn.Identity() + + def forward(self, x: torch.Tensor, ) -> torch.Tensor: + residual = x + x = self.dwconv(x) + x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) + + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) + x = self.dropout(x) + + x = residual + self.drop_path(x) + return x + + +class ConvNeXtDecoder(nn.Module): + def __init__( + self, in_dims, out_dims, /, *, + num_channels=512, num_layers=6, kernel_size=7, dropout_rate=0.1 + ): + super().__init__() + self.inconv = nn.Conv1d( + in_dims, num_channels, kernel_size, + stride=1, padding=(kernel_size - 1) // 2 + ) + self.conv = nn.ModuleList( + ConvNeXtBlock( + dim=num_channels, intermediate_dim=num_channels * 4, + layer_scale_init_value=1e-6, drop_out=dropout_rate + ) for _ in range(num_layers) + ) + self.outconv = nn.Conv1d( + num_channels, out_dims, kernel_size, + stride=1, padding=(kernel_size - 1) // 2 + ) + + # noinspection PyUnusedLocal + def forward(self, x, infer=False): + x = x.transpose(1, 2) + x = self.inconv(x) + for conv in self.conv: + x = conv(x) + x = self.outconv(x) + x = x.transpose(1, 2) + return x diff --git a/modules/diffusion/ddpm.py b/modules/diffusion/ddpm.py index 46c3eacc..d8bdc444 100644 --- a/modules/diffusion/ddpm.py +++ b/modules/diffusion/ddpm.py @@ -81,6 +81,12 @@ def __init__(self, out_dims, num_feats=1, timesteps=1000, k_step=1000, alphas_cumprod = np.cumprod(alphas, axis=0) alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + self.use_shallow_diffusion = hparams.get('use_shallow_diffusion', False) + if self.use_shallow_diffusion: + assert k_step <= timesteps, 'K_step should not be larger than timesteps.' + else: + assert k_step == timesteps, 'K_step must equal timesteps if use_shallow_diffusion is False.' + self.timesteps = timesteps self.k_step = k_step self.noise_list = deque(maxlen=4) @@ -216,16 +222,31 @@ def p_losses(self, x_start, t, cond, noise=None): return x_recon, noise - def inference(self, cond, b=1, device=None): - t = self.k_step - shape = (b, self.num_feats, self.out_dims, cond.shape[2]) - x = torch.randn(shape, device=device) - if hparams.get('pndm_speedup') and hparams['pndm_speedup'] > 1: + def inference(self, cond, b=1, x_start=None, device=None): + depth = hparams.get('K_step_infer', self.k_step) + noise = torch.randn(b, self.num_feats, self.out_dims, cond.shape[2], device=device) + if self.use_shallow_diffusion: + t_max = min(depth, self.k_step) + else: + t_max = self.k_step + + if t_max >= self.timesteps: + x = noise + elif t_max > 0: + assert x_start is not None, 'Missing shallow diffusion source.' + x = self.q_sample( + x_start, torch.full((b,), t_max - 1, device=device, dtype=torch.long), noise + ) + else: + assert x_start is not None, 'Missing shallow diffusion source.' + x = x_start + + if hparams.get('pndm_speedup') and hparams['pndm_speedup'] > 1 and t_max > 0: algorithm = hparams.get('diff_accelerator', 'ddim') if algorithm == 'dpm-solver': from inference.dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver # 1. Define the noise schedule. - noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas) + noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas[:t_max]) # 2. Convert your discrete-time `model` to the continuous-time # noise prediction model. Here is an example for a diffusion model @@ -251,7 +272,7 @@ def wrapped(x, t, **kwargs): # costs and the sample quality. 
dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") - steps = t // hparams["pndm_speedup"] + steps = t_max // hparams["pndm_speedup"] self.bar = tqdm(desc="sample time step", total=steps, disable=not hparams['infer'], leave=False) x = dpm_solver.sample( x, @@ -264,7 +285,7 @@ def wrapped(x, t, **kwargs): elif algorithm == 'unipc': from inference.uni_pc import NoiseScheduleVP, model_wrapper, UniPC # 1. Define the noise schedule. - noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas) + noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas[:t_max]) # 2. Convert your discrete-time `model` to the continuous-time # noise prediction model. Here is an example for a diffusion model @@ -289,7 +310,7 @@ def wrapped(x, t, **kwargs): # costs and the sample quality. uni_pc = UniPC(model_fn, noise_schedule, variant='bh2') - steps = t // hparams["pndm_speedup"] + steps = t_max // hparams["pndm_speedup"] self.bar = tqdm(desc="sample time step", total=steps, disable=not hparams['infer'], leave=False) x = uni_pc.sample( x, @@ -303,8 +324,8 @@ def wrapped(x, t, **kwargs): self.noise_list = deque(maxlen=4) iteration_interval = hparams['pndm_speedup'] for i in tqdm( - reversed(range(0, t, iteration_interval)), desc='sample time step', - total=t // iteration_interval, disable=not hparams['infer'], leave=False + reversed(range(0, t_max, iteration_interval)), desc='sample time step', + total=t_max // iteration_interval, disable=not hparams['infer'], leave=False ): x = self.p_sample_plms( x, torch.full((b,), i, device=device, dtype=torch.long), @@ -313,8 +334,8 @@ def wrapped(x, t, **kwargs): elif algorithm == 'ddim': iteration_interval = hparams['pndm_speedup'] for i in tqdm( - reversed(range(0, t, iteration_interval)), desc='sample time step', - total=t // iteration_interval, disable=not hparams['infer'], leave=False + reversed(range(0, t_max, iteration_interval)), desc='sample time step', + total=t_max // iteration_interval, disable=not hparams['infer'], leave=False ): x = self.p_sample_ddim( x, torch.full((b,), i, device=device, dtype=torch.long), @@ -323,13 +344,13 @@ def wrapped(x, t, **kwargs): else: raise NotImplementedError(algorithm) else: - for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t, + for i in tqdm(reversed(range(0, t_max)), desc='sample time step', total=t_max, disable=not hparams['infer'], leave=False): x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond) x = x.transpose(2, 3).squeeze(1) # [B, F, M, T] => [B, T, M] or [B, F, T, M] return x - def forward(self, condition, gt_spec=None, infer=True): + def forward(self, condition, gt_spec=None, src_spec=None, infer=True): """ conditioning diffusion, use fastspeech2 encoder output as the condition """ @@ -344,7 +365,14 @@ def forward(self, condition, gt_spec=None, infer=True): t = torch.randint(0, self.k_step, (b,), device=device).long() return self.p_losses(spec, t, cond=cond) else: - x = self.inference(cond, b=b, device=device) + # src_spec: [B, T, M] or [B, F, T, M] + if src_spec is not None: + spec = self.norm_spec(src_spec).transpose(-2, -1) + if self.num_feats == 1: + spec = spec[:, None, :, :] + else: + spec = None + x = self.inference(cond, b=b, x_start=spec, device=device) return self.denorm_spec(x) def norm_spec(self, x): diff --git a/modules/toplevel.py b/modules/toplevel.py index e90725b2..a1a7647a 100644 --- a/modules/toplevel.py +++ b/modules/toplevel.py @@ -6,6 +6,7 @@ from torch import Tensor from basics.base_module import 
CategorizedModule +from modules.aux_decoder import AuxDecoderAdaptor from modules.commons.common_layers import ( XavierUniformInitLinear as Linear, NormalInitEmbedding as Embedding @@ -20,6 +21,12 @@ from utils.hparams import hparams +class ShallowDiffusionOutput: + def __init__(self, *, aux_out=None, diff_out=None): + self.aux_out = aux_out + self.diff_out = diff_out + + class DiffSingerAcoustic(CategorizedModule, ParameterAdaptorModule): @property def category(self): @@ -32,6 +39,19 @@ def __init__(self, vocab_size, out_dims): vocab_size=vocab_size ) + self.use_shallow_diffusion = hparams.get('use_shallow_diffusion', False) + self.shallow_args = hparams.get('shallow_diffusion_args', {}) + if self.use_shallow_diffusion: + self.train_aux_decoder = self.shallow_args['train_aux_decoder'] + self.train_diffusion = self.shallow_args['train_diffusion'] + self.aux_decoder_grad = self.shallow_args['aux_decoder_grad'] + self.aux_decoder = AuxDecoderAdaptor( + in_dims=hparams['hidden_size'], out_dims=out_dims, num_feats=1, + spec_min=hparams['spec_min'], spec_max=hparams['spec_max'], + aux_decoder_arch=self.shallow_args['aux_decoder_arch'], + aux_decoder_args=self.shallow_args['aux_decoder_args'] + ) + self.diffusion = GaussianDiffusion( out_dims=out_dims, num_feats=1, @@ -50,19 +70,42 @@ def __init__(self, vocab_size, out_dims): def forward( self, txt_tokens, mel2ph, f0, key_shift=None, speed=None, spk_embed_id=None, gt_mel=None, infer=True, **kwargs - ): + ) -> ShallowDiffusionOutput: condition = self.fs2( txt_tokens, mel2ph, f0, key_shift=key_shift, speed=speed, spk_embed_id=spk_embed_id, **kwargs ) - if infer: - mel_pred = self.diffusion(condition, infer=True) + if self.use_shallow_diffusion: + aux_mel_pred = self.aux_decoder(condition, infer=True) + aux_mel_pred *= ((mel2ph > 0).float()[:, :, None]) + if gt_mel is not None and self.shallow_args['val_gt_start']: + src_mel = gt_mel + else: + src_mel = aux_mel_pred + else: + aux_mel_pred = src_mel = None + mel_pred = self.diffusion(condition, src_spec=src_mel, infer=True) mel_pred *= ((mel2ph > 0).float()[:, :, None]) - return mel_pred + return ShallowDiffusionOutput(aux_out=aux_mel_pred, diff_out=mel_pred) else: - x_recon, noise = self.diffusion(condition, gt_spec=gt_mel, infer=False) - return x_recon, noise + if self.use_shallow_diffusion: + if self.train_aux_decoder: + aux_cond = condition * self.aux_decoder_grad + condition.detach() * (1 - self.aux_decoder_grad) + aux_out = self.aux_decoder(aux_cond, infer=False) + else: + aux_out = None + if self.train_diffusion: + x_recon, noise = self.diffusion(condition, gt_spec=gt_mel, infer=False) + diff_out = (x_recon, noise) + else: + diff_out = None + return ShallowDiffusionOutput(aux_out=aux_out, diff_out=diff_out) + + else: + aux_out = None + x_recon, noise = self.diffusion(condition, gt_spec=gt_mel, infer=False) + return ShallowDiffusionOutput(aux_out=aux_out, diff_out=(x_recon, noise)) class DiffSingerVariance(CategorizedModule, ParameterAdaptorModule): @@ -194,7 +237,7 @@ def forward( ] condition += torch.stack(variance_embeds, dim=-1).sum(-1) - variance_outputs = self.variance_predictor(condition, variance_inputs, infer) + variance_outputs = self.variance_predictor(condition, variance_inputs, infer=infer) if infer: variances_pred_out = self.collect_variance_outputs(variance_outputs) diff --git a/scripts/infer.py b/scripts/infer.py index d83bb931..8c6e6e83 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -47,6 +47,7 @@ def main(): @click.option('--key', type=int, required=False, default=0, 
help='Key transition of pitch') @click.option('--gender', type=float, required=False, help='Formant shifting (gender control)') @click.option('--seed', type=int, required=False, default=-1, help='Random seed of the inference') +@click.option('--depth', type=int, required=False, default=-1, help='Shallow diffusion depth') @click.option('--speedup', type=int, required=False, default=0, help='Diffusion acceleration ratio') @click.option('--mel', is_flag=True, help='Save intermediate mel format instead of waveform') def acoustic( @@ -60,6 +61,7 @@ def acoustic( key: int, gender: float, seed: int, + depth: int, speedup: int, mel: bool ): @@ -107,8 +109,16 @@ def acoustic( f'Vocoder ckpt \'{hparams["vocoder_ckpt"]}\' not found. ' \ f'Please put it to the checkpoints directory to run inference.' + if depth >= 0: + assert depth <= hparams['K_step'], f'Diffusion depth should not be larger than K_step {hparams["K_step"]}.' + hparams['K_step_infer'] = depth + elif hparams.get('use_shallow_diffusion', False): + depth = hparams['K_step_infer'] + else: + depth = hparams['K_step'] # gaussian start (full depth diffusion) + if speedup > 0: - assert hparams['K_step'] % speedup == 0, f'Acceleration ratio must be factor of K_step {hparams["K_step"]}.' + assert depth % speedup == 0, f'Acceleration ratio must be factor of diffusion depth {depth}.' hparams['pndm_speedup'] = speedup spk_mix = parse_commandline_spk_mix(spk) if hparams['use_spk_id'] and spk is not None else None diff --git a/training/acoustic_task.py b/training/acoustic_task.py index b0723912..04dedb65 100644 --- a/training/acoustic_task.py +++ b/training/acoustic_task.py @@ -9,8 +9,9 @@ from basics.base_dataset import BaseDataset from basics.base_task import BaseTask from basics.base_vocoder import BaseVocoder +from modules.aux_decoder import build_aux_loss from modules.losses.diff_loss import DiffusionNoiseLoss -from modules.toplevel import DiffSingerAcoustic +from modules.toplevel import DiffSingerAcoustic, ShallowDiffusionOutput from modules.vocoders.registry import get_vocoder_cls from utils.hparams import hparams from utils.plot import spec_to_figure, curve_to_figure @@ -60,6 +61,12 @@ class AcousticTask(BaseTask): def __init__(self): super().__init__() self.dataset_cls = AcousticDataset + self.use_shallow_diffusion = hparams['use_shallow_diffusion'] + if self.use_shallow_diffusion: + self.shallow_args = hparams['shallow_diffusion_args'] + self.train_aux_decoder = self.shallow_args['train_aux_decoder'] + self.train_diffusion = self.shallow_args['train_diffusion'] + self.use_vocoder = hparams['infer'] or hparams['val_with_vocoder'] if self.use_vocoder: self.vocoder: BaseVocoder = get_vocoder_cls(hparams)() @@ -78,6 +85,9 @@ def build_model(self): # noinspection PyAttributeOutsideInit def build_losses_and_metrics(self): + if self.use_shallow_diffusion: + self.aux_mel_loss = build_aux_loss(self.shallow_args['aux_decoder_arch']) + self.lambda_aux_mel_loss = hparams['lambda_aux_mel_loss'] self.mel_loss = DiffusionNoiseLoss(loss_type=hparams['diff_loss_type']) def run_model(self, sample, infer=False): @@ -96,20 +106,27 @@ def run_model(self, sample, infer=False): spk_embed_id = sample['spk_ids'] else: spk_embed_id = None - output = self.model( + output: ShallowDiffusionOutput = self.model( txt_tokens, mel2ph=mel2ph, f0=f0, **variances, key_shift=key_shift, speed=speed, spk_embed_id=spk_embed_id, gt_mel=target, infer=infer ) if infer: - return output # mel_pred + return output else: - x_recon, x_noise = output - mel_loss = self.mel_loss(x_recon, 
x_noise, nonpadding=(mel2ph > 0).unsqueeze(-1).float()) - losses = { - 'mel_loss': mel_loss - } + losses = {} + + if output.aux_out is not None: + aux_out = output.aux_out + norm_gt = self.model.aux_decoder.norm_spec(target) + aux_mel_loss = self.lambda_aux_mel_loss * self.aux_mel_loss(aux_out, norm_gt) + losses['aux_mel_loss'] = aux_mel_loss + + if output.diff_out is not None: + x_recon, x_noise = output.diff_out + mel_loss = self.mel_loss(x_recon, x_noise, nonpadding=(mel2ph > 0).unsqueeze(-1).float()) + losses['mel_loss'] = mel_loss return losses @@ -126,29 +143,44 @@ def _validation_step(self, sample, batch_idx): if batch_idx < hparams['num_valid_plots'] \ and (self.trainer.distributed_sampler_kwargs or {}).get('rank', 0) == 0: - mel_pred = self.run_model(sample, infer=True) + mel_out: ShallowDiffusionOutput = self.run_model(sample, infer=True) if self.use_vocoder: - self.plot_wav(batch_idx, sample['mel'], mel_pred, f0=sample['f0']) - self.plot_mel(batch_idx, sample['mel'], mel_pred, name=f'diffmel_{batch_idx}') + self.plot_wav( + batch_idx, gt_mel=sample['mel'], + aux_mel=mel_out.aux_out, diff_mel=mel_out.diff_out, + f0=sample['f0'] + ) + if mel_out.aux_out is not None: + self.plot_mel(batch_idx, sample['mel'], mel_out.aux_out, name=f'auxmel_{batch_idx}') + if mel_out.diff_out is not None: + self.plot_mel(batch_idx, sample['mel'], mel_out.diff_out, name=f'diffmel_{batch_idx}') return losses, sample['size'] ############ # validation plots ############ - def plot_wav(self, batch_idx, gt_mel, pred_mel, f0=None): + def plot_wav(self, batch_idx, gt_mel, aux_mel=None, diff_mel=None, f0=None): gt_mel = gt_mel[0].cpu().numpy() - pred_mel = pred_mel[0].cpu().numpy() + if aux_mel is not None: + aux_mel = aux_mel[0].cpu().numpy() + if diff_mel is not None: + diff_mel = diff_mel[0].cpu().numpy() f0 = f0[0].cpu().numpy() if batch_idx not in self.logged_gt_wav: gt_wav = self.vocoder.spec2wav(gt_mel, f0=f0) self.logger.experiment.add_audio(f'gt_{batch_idx}', gt_wav, sample_rate=hparams['audio_sample_rate'], global_step=self.global_step) self.logged_gt_wav.add(batch_idx) - pred_wav = self.vocoder.spec2wav(pred_mel, f0=f0) - self.logger.experiment.add_audio(f'pred_{batch_idx}', pred_wav, sample_rate=hparams['audio_sample_rate'], - global_step=self.global_step) + if aux_mel is not None: + aux_wav = self.vocoder.spec2wav(aux_mel, f0=f0) + self.logger.experiment.add_audio(f'aux_{batch_idx}', aux_wav, sample_rate=hparams['audio_sample_rate'], + global_step=self.global_step) + if diff_mel is not None: + diff_wav = self.vocoder.spec2wav(diff_mel, f0=f0) + self.logger.experiment.add_audio(f'diff_{batch_idx}', diff_wav, sample_rate=hparams['audio_sample_rate'], + global_step=self.global_step) def plot_mel(self, batch_idx, spec, spec_out, name=None): name = f'mel_{batch_idx}' if name is None else name diff --git a/utils/onnx_helper.py b/utils/onnx_helper.py index 9fc3f6fa..bebe9756 100644 --- a/utils/onnx_helper.py +++ b/utils/onnx_helper.py @@ -277,27 +277,31 @@ def _extract_conv_nodes_recursive(subgraph: GraphProto): to_be_removed.append(sub_node) [subgraph.node.remove(_n) for _n in to_be_removed] + toplevel_if_idx = toplevel_if_node = None + # Find the **last** If node in toplevel graph for i, n in enumerate(graph.node): if n.op_type == 'If': - for a in n.attribute: - b = onnx.helper.get_attribute_value(a) - _extract_conv_nodes_recursive(b) - # Insert the extracted nodes before the first 'If' node which carries the main denoising loop. 
- for key in reversed(node_dict): - alias, node = node_dict[key] - # Rename output of the node. - out_name = node.output[0] - node.output.remove(node.output[0]) - node.output.insert(0, alias) - # Insert node into the main graph. - graph.node.insert(i, node) - # Rename value info of the output. - for v in graph.value_info: - if v.name == out_name: - v.name = alias - break - _verbose(f'| extract conditioner projection: \'{node.name}\'') - break + toplevel_if_idx = i + toplevel_if_node = n + if toplevel_if_node is not None: + for a in toplevel_if_node.attribute: + b = onnx.helper.get_attribute_value(a) + _extract_conv_nodes_recursive(b) + # Insert the extracted nodes before the first 'If' node which carries the main denoising loop. + for key in reversed(node_dict): + alias, node = node_dict[key] + # Rename output of the node. + out_name = node.output[0] + node.output.remove(node.output[0]) + node.output.insert(0, alias) + # Insert node into the main graph. + graph.node.insert(toplevel_if_idx, node) + # Rename value info of the output. + for v in graph.value_info: + if v.name == out_name: + v.name = alias + break + _verbose(f'| extract conditioner projection: \'{node.name}\'') def graph_remove_unused_values(graph: GraphProto):