Shallow diffusion and aux decoder (#128)
* Add shallow diffusion API

* Support aux decoder training

* Support shallow diffusion inference

* Add shallow framework

* Add shallow framework

* Support lambda for aux mel loss

* Move config key

* Add shallow framework

* Add shallow framework

* Add denorm

* Add shallow model training switch

* Limit gradient from aux decoder

* Improve loss calculation control flow

* Add independent encoder in shallow diffusion

* Adjust lambda

* Implement shallow diffusion
There are some issues to resolve in DPM-Solver++ and UniPC

* Fix missing depth assignment

* Fix bugs in shallow diffusion inference

* Fix errors and remove debug code

* Support K_step < timesteps (shallow-only diffusion)

* Fix argument passing

* Add missing checks

* Add glow decoder

* Add glow decoder

* Add ConvNeXt glow decoder

* Fix fs2

* Support using gt mel as source during validation

* Clean files and configs

* Clean and refactor aux decoder

* Fix KeyError

* Support exporting shallow diffusion to ONNX

* Add missing logic to ONNX

* Rename `diff_depth` to `K_step_infer`

---------

Co-authored-by: yqzhishen <[email protected]>
Co-authored-by: autumn <2>
Co-authored-by: [email protected] <[email protected]>
3 people authored Sep 22, 2023
1 parent bc6b0dd commit 0ec98bd
Showing 13 changed files with 492 additions and 117 deletions.
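Taken together, the changed files implement a standard shallow-diffusion pipeline: an auxiliary decoder predicts a rough mel-spectrogram from the FastSpeech2 condition, and the diffusion model refines it starting from an intermediate noise level instead of from pure Gaussian noise. A minimal sketch of the resulting inference flow — the glue function and its argument names are hypothetical, but the `diffusion(condition, x_start, depth, speedup)` call mirrors the new `GaussianDiffusionONNX.forward` below:

```python
def shallow_infer(fs2, aux_decoder, diffusion, tokens, durations, f0,
                  k_step_infer=400, speedup=10):
    """Hypothetical glue code illustrating the shallow-diffusion data flow."""
    condition = fs2(tokens, durations, f0)  # linguistic/acoustic condition, [B, T, H]
    aux_mel = aux_decoder(condition)        # rough mel from the aux decoder, [B, T, M]
    # Instead of denoising from pure noise over all K steps, noise aux_mel
    # up to step k_step_infer - 1 and denoise only that many steps back.
    return diffusion(condition, x_start=aux_mel, depth=k_step_infer, speedup=speedup)
```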
18 changes: 18 additions & 0 deletions configs/acoustic.yaml
@@ -76,6 +76,24 @@ diff_decoder_type: 'wavenet'
diff_loss_type: l2
schedule_type: 'linear'

# shallow diffusion
use_shallow_diffusion: false
K_step_infer: 400

shallow_diffusion_args:
train_aux_decoder: true
train_diffusion: true
val_gt_start: false
aux_decoder_arch: convnext
aux_decoder_args:
num_channels: 512
num_layers: 6
kernel_size: 7
dropout_rate: 0.1
aux_decoder_grad: 0.1

lambda_aux_mel_loss: 0.2

# train and eval
num_sanity_val_steps: 1
optimizer_args:
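The new keys split training between the two branches: `train_aux_decoder`/`train_diffusion` toggle each loss, `lambda_aux_mel_loss` weights the auxiliary mel loss against the diffusion loss, and `aux_decoder_grad` limits how much gradient the aux decoder pushes back into the shared encoder (the "Limit gradient from aux decoder" commit). A hedged sketch of how these knobs could combine — the actual loss code is not in the hunks shown here, and `diffusion_loss` is a stand-in name:

```python
import torch.nn.functional as F

def acoustic_losses(condition, gt_mel, aux_decoder, diffusion_loss,
                    lambda_aux_mel_loss=0.2, aux_decoder_grad=0.1):
    # Gradient limiting: the forward value is unchanged, but the gradient that
    # flows from the aux decoder back into `condition` is scaled by aux_decoder_grad.
    cond_limited = condition * aux_decoder_grad + condition.detach() * (1 - aux_decoder_grad)
    aux_mel = aux_decoder(cond_limited)
    aux_loss = F.l1_loss(aux_mel, gt_mel)
    return diffusion_loss(condition, gt_mel) + lambda_aux_mel_loss * aux_loss
```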
17 changes: 17 additions & 0 deletions configs/templates/config_acoustic.yaml
@@ -40,9 +40,26 @@ augmentation_args:
domain: log # or linear
scale: 1.0

K_step: 1000
residual_channels: 512
residual_layers: 20

# shallow diffusion
use_shallow_diffusion: false
K_step_infer: 400
shallow_diffusion_args:
train_aux_decoder: true
train_diffusion: true
val_gt_start: false
aux_decoder_arch: convnext
aux_decoder_args:
num_channels: 512
num_layers: 6
kernel_size: 7
dropout_rate: 0.1
aux_decoder_grad: 0.1
lambda_aux_mel_loss: 0.2

optimizer_args:
lr: 0.0004
lr_scheduler_args:
88 changes: 64 additions & 24 deletions deployment/exporters/acoustic_exporter.py
@@ -33,12 +33,22 @@ def __init__(
self.spk_map: dict = self.build_spk_map()
self.vocab = TokenTextEncoder(vocab_list=build_phoneme_list())
self.model = self.build_model()
self.fs2_cache_path = self.cache_dir / 'fs2.onnx'
self.fs2_aux_cache_path = self.cache_dir / (
'fs2_aux.onnx' if self.model.use_shallow_diffusion else 'fs2.onnx'
)
self.diffusion_cache_path = self.cache_dir / 'diffusion.onnx'

# Attributes for logging
self.model_class_name = remove_suffix(self.model.__class__.__name__, 'ONNX')
self.fs2_class_name = remove_suffix(self.model.fs2.__class__.__name__, 'ONNX')
fs2_aux_cls_logging = [remove_suffix(self.model.fs2.__class__.__name__, 'ONNX')]
if self.model.use_shallow_diffusion:
fs2_aux_cls_logging.append(remove_suffix(
self.model.aux_decoder.decoder.__class__.__name__, 'ONNX'
))
self.fs2_aux_class_name = ', '.join(fs2_aux_cls_logging)
self.aux_decoder_class_name = remove_suffix(
self.model.aux_decoder.decoder.__class__.__name__, 'ONNX'
) if self.model.use_shallow_diffusion else None
self.denoiser_class_name = remove_suffix(self.model.diffusion.denoise_fn.__class__.__name__, 'ONNX')
self.diffusion_class_name = remove_suffix(self.model.diffusion.__class__.__name__, 'ONNX')

@@ -86,11 +96,11 @@ def export(self, path: Path):

def export_model(self, path: Path):
self._torch_export_model()
fs2_onnx = self._optimize_fs2_graph(onnx.load(self.fs2_cache_path))
fs2_aux_onnx = self._optimize_fs2_aux_graph(onnx.load(self.fs2_aux_cache_path))
diffusion_onnx = self._optimize_diffusion_graph(onnx.load(self.diffusion_cache_path))
model_onnx = self._merge_fs2_diffusion_graphs(fs2_onnx, diffusion_onnx)
model_onnx = self._merge_fs2_aux_diffusion_graphs(fs2_aux_onnx, diffusion_onnx)
onnx.save(model_onnx, path)
self.fs2_cache_path.unlink()
self.fs2_aux_cache_path.unlink()
self.diffusion_cache_path.unlink()
print(f'| export model => {path}')

@@ -105,7 +115,7 @@ def export_attachments(self, path: Path):

@torch.no_grad()
def _torch_export_model(self):
# Prepare inputs for FastSpeech2 tracing
# Prepare inputs for FastSpeech2 and aux decoder tracing
n_frames = 10
tokens = torch.LongTensor([[1]]).to(self.device)
durations = torch.LongTensor([[n_frames]]).to(self.device)
@@ -161,22 +171,30 @@ def _torch_export_model(self):
1: 'n_frames'
}

# PyTorch ONNX export for FastSpeech2
print(f'Exporting {self.fs2_class_name}...')
# PyTorch ONNX export for FastSpeech2 and aux decoder
output_names = ['condition']
if self.model.use_shallow_diffusion:
output_names.append('aux_mel')
dynamix_axes['aux_mel'] = {
1: 'n_frames'
}
print(f'Exporting {self.fs2_aux_class_name}...')
torch.onnx.export(
self.model.view_as_fs2(),
self.model.view_as_fs2_aux(),
arguments,
self.fs2_cache_path,
self.fs2_aux_cache_path,
input_names=input_names,
output_names=['condition'],
output_names=output_names,
dynamic_axes=dynamix_axes,
opset_version=15
)

condition = torch.rand((1, n_frames, hparams['hidden_size']), device=self.device)

# Prepare inputs for denoiser tracing and GaussianDiffusion scripting
shape = (1, 1, hparams['audio_num_mel_bins'], n_frames)
noise = torch.randn(shape, device=self.device)
condition = torch.rand((1, hparams['hidden_size'], n_frames), device=self.device)
x_start = torch.randn((1, n_frames, hparams['audio_num_mel_bins']), device=self.device)
step = (torch.rand((1,), device=self.device) * hparams['K_step']).long()

print(f'Tracing {self.denoiser_class_name} denoiser...')
@@ -186,20 +204,24 @@ def _torch_export_model(self):
(
noise,
step,
condition
condition.transpose(1, 2)
)
)

print(f'Scripting {self.diffusion_class_name}...')
diffusion_inputs = [
condition,
*([x_start, 100] if self.model.use_shallow_diffusion else [])
]
diffusion = torch.jit.script(
diffusion,
example_inputs=[
(
condition.transpose(1, 2),
*diffusion_inputs,
1 # p_sample branch
),
(
condition.transpose(1, 2),
*diffusion_inputs,
200 # p_sample_plms branch
)
]
@@ -210,12 +232,14 @@ def _torch_export_model(self):
torch.onnx.export(
diffusion,
(
condition.transpose(1, 2),
*diffusion_inputs,
200
),
self.diffusion_cache_path,
input_names=[
'condition', 'speedup'
'condition',
*(['x_start', 'depth'] if self.model.use_shallow_diffusion else []),
'speedup'
],
output_names=[
'mel'
@@ -224,6 +248,7 @@ def _torch_export_model(self):
'condition': {
1: 'n_frames'
},
**({'x_start': {1: 'n_frames'}} if self.model.use_shallow_diffusion else {}),
'mel': {
1: 'n_frames'
}
@@ -252,11 +277,11 @@ def _perform_spk_mix(self, spk_mix: Dict[str, float]):
) # => [1, H]
return spk_mix_embed

def _optimize_fs2_graph(self, fs2: onnx.ModelProto) -> onnx.ModelProto:
print(f'Running ONNX Simplifier on {self.fs2_class_name}...')
def _optimize_fs2_aux_graph(self, fs2: onnx.ModelProto) -> onnx.ModelProto:
print(f'Running ONNX Simplifier on {self.fs2_aux_class_name}...')
fs2, check = onnxsim.simplify(fs2, include_subgraph=True)
assert check, 'Simplified ONNX model could not be validated'
print(f'| optimize graph: {self.fs2_class_name}')
print(f'| optimize graph: {self.fs2_aux_class_name}')
return fs2

def _optimize_diffusion_graph(self, diffusion: onnx.ModelProto) -> onnx.ModelProto:
@@ -282,18 +307,33 @@ def _optimize_diffusion_graph(self, diffusion: onnx.ModelProto) -> onnx.ModelProto:
print(f'| optimize graph: {self.diffusion_class_name}')
return diffusion

def _merge_fs2_diffusion_graphs(self, fs2: onnx.ModelProto, diffusion: onnx.ModelProto) -> onnx.ModelProto:
onnx_helper.model_add_prefixes(fs2, dim_prefix='fs2.', ignored_pattern=r'(n_tokens)|(n_frames)')
def _merge_fs2_aux_diffusion_graphs(self, fs2: onnx.ModelProto, diffusion: onnx.ModelProto) -> onnx.ModelProto:
onnx_helper.model_add_prefixes(
fs2, dim_prefix=('fs2aux.' if self.model.use_shallow_diffusion else 'fs2.'),
ignored_pattern=r'(n_tokens)|(n_frames)'
)
onnx_helper.model_add_prefixes(diffusion, dim_prefix='diffusion.', ignored_pattern='n_frames')
print(f'Merging {self.fs2_class_name} and {self.diffusion_class_name} '
print(f'Merging {self.fs2_aux_class_name} and {self.diffusion_class_name} '
f'back into {self.model_class_name}...')
merged = onnx.compose.merge_models(
fs2, diffusion, io_map=[('condition', 'condition')],
fs2, diffusion, io_map=[
('condition', 'condition'),
*([('aux_mel', 'x_start')] if self.model.use_shallow_diffusion else []),
],
prefix1='', prefix2='', doc_string='',
producer_name=fs2.producer_name, producer_version=fs2.producer_version,
domain=fs2.domain, model_version=fs2.model_version
)
merged.graph.name = fs2.graph.name

print(f'Running ONNX Simplifier on {self.model_class_name}...')
merged, check = onnxsim.simplify(
merged,
include_subgraph=True
)
assert check, 'Simplified ONNX model could not be validated'
print(f'| optimize graph: {self.model_class_name}')

return merged

# noinspection PyMethodMayBeStatic
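Once merged, the graph wires the `aux_mel` output of the FastSpeech2+aux subgraph into the diffusion subgraph's `x_start`, so a consumer only sees the encoder inputs plus the scalar `depth` and `speedup`. A hedged onnxruntime sketch — `tokens` and `durations` match the tracing inputs above, while `f0`, the file name, and the shapes are assumptions (a real model also takes its configured variance inputs):

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession('acoustic.onnx')  # hypothetical exported file
mel, = session.run(['mel'], {
    'tokens': np.array([[1, 2, 3]], dtype=np.int64),
    'durations': np.array([[10, 10, 10]], dtype=np.int64),  # frames per token
    'f0': np.full((1, 30), 440.0, dtype=np.float32),        # assumed input name/shape
    'depth': np.array(400, dtype=np.int64),                 # K_step_infer
    'speedup': np.array(10, dtype=np.int64),
})
```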
39 changes: 33 additions & 6 deletions deployment/modules/diffusion.py
@@ -16,6 +16,12 @@ def extract(a, t):

# noinspection PyMethodOverriding
class GaussianDiffusionONNX(GaussianDiffusion):
def q_sample(self, x_start, t, noise):
return (
extract(self.sqrt_alphas_cumprod, t) * x_start +
extract(self.sqrt_one_minus_alphas_cumprod, t) * noise
)

def p_sample(self, x, t, cond):
x_pred = self.denoise_fn(x, t, cond)
x_recon = (
@@ -74,18 +80,39 @@ def p_sample_plms(self, x_prev, t, interval: int, cond, noise_list: List[Tensor]):
x_prev = self.plms_get_x_pred(x_prev, noise_pred_prime, t, t_prev)
return noise_pred, x_prev

def norm_spec(self, x):
k = (self.spec_max - self.spec_min) / 2.
b = (self.spec_max + self.spec_min) / 2.
return (x - b) / k

def denorm_spec(self, x):
d = (self.spec_max - self.spec_min) / 2.
m = (self.spec_max + self.spec_min) / 2.
return x * d + m
k = (self.spec_max - self.spec_min) / 2.
b = (self.spec_max + self.spec_min) / 2.
return x * k + b

def forward(self, condition, speedup: int):
def forward(self, condition, x_start=None, depth: int = 1000, speedup: int = 1):
condition = condition.transpose(1, 2) # [1, T, H] => [1, H, T]
device = condition.device
n_frames = condition.shape[2]

step_range = torch.arange(0, self.k_step, speedup, dtype=torch.long, device=device).flip(0)[:, None]
x = torch.randn((1, self.num_feats, self.out_dims, n_frames), device=device)
noise = torch.randn((1, self.num_feats, self.out_dims, n_frames), device=device)
if x_start is None:
step_range = torch.arange(0, self.k_step, speedup, dtype=torch.long, device=device).flip(0)[:, None]
x = noise
else:
depth = min(depth, self.k_step)
step_range = torch.arange(0, depth, speedup, dtype=torch.long, device=device).flip(0)[:, None]
x_start = self.norm_spec(x_start).transpose(-2, -1)
if self.num_feats == 1:
x_start = x_start[:, None, :, :]
if depth >= self.timesteps:
x = noise
elif depth > 0:
x = self.q_sample(
x_start, torch.full((1,), depth - 1, device=device, dtype=torch.long), noise
)
else:
x = x_start

if speedup > 1:
for t in step_range:
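The new `q_sample` is the closed-form forward-diffusion jump q(x_t | x_0) = sqrt(ᾱ_t) · x_0 + sqrt(1 − ᾱ_t) · ε, which is what lets `forward` noise the auxiliary mel directly to step `depth − 1` instead of running the chain step by step. A self-contained numeric sketch under an assumed linear beta schedule (illustrative values, not necessarily the repository's):

```python
import torch

timesteps, depth = 1000, 400
betas = torch.linspace(1e-4, 0.02, timesteps)      # assumed linear schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

x_start = torch.randn(1, 1, 128, 10)               # stands in for a normalized aux mel
noise = torch.randn_like(x_start)
t = depth - 1
# One-shot jump to step t; equivalent in distribution to applying the
# single-step forward kernel t + 1 times.
x_t = alphas_cumprod[t].sqrt() * x_start + (1.0 - alphas_cumprod[t]).sqrt() * noise
```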
49 changes: 24 additions & 25 deletions deployment/modules/toplevel.py
@@ -1,6 +1,6 @@
import numpy as np
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -37,7 +37,7 @@ def __init__(self, vocab_size, out_dims):
spec_max=hparams['spec_max']
)

def forward_fs2(
def forward_fs2_aux(
self,
tokens: Tensor,
durations: Tensor,
@@ -46,41 +46,40 @@ def forward_fs2(
gender: Tensor = None,
velocity: Tensor = None,
spk_embed: Tensor = None
) -> Tensor:
return self.fs2(
):
condition = self.fs2(
tokens, durations, f0, variances=variances,
gender=gender, velocity=velocity, spk_embed=spk_embed
)
if self.use_shallow_diffusion:
aux_mel_pred = self.aux_decoder(condition, infer=True)
return condition, aux_mel_pred
else:
return condition

def forward_shallow_diffusion(
self, condition: Tensor, x_start: Tensor,
depth: int, speedup: int
) -> Tensor:
return self.diffusion(condition, x_start=x_start, depth=depth, speedup=speedup)

def forward_diffusion(self, condition: Tensor, speedup: int) -> Tensor:
return self.diffusion(condition, speedup)
def forward_diffusion(self, condition: Tensor, speedup: int):
return self.diffusion(condition, speedup=speedup)

def view_as_fs2(self) -> nn.Module:
def view_as_fs2_aux(self) -> nn.Module:
model = copy.deepcopy(self)
try:
del model.variance_embeds
del model.variance_adaptor
except AttributeError:
pass
del model.diffusion
model.forward = model.forward_fs2
model.forward = model.forward_fs2_aux
return model

def view_as_adaptor(self) -> nn.Module:
model = copy.deepcopy(self)
del model.fs2
del model.diffusion
raise NotImplementedError()

def view_as_diffusion(self) -> nn.Module:
model = copy.deepcopy(self)
del model.fs2
try:
del model.variance_embeds
del model.variance_adaptor
except AttributeError:
pass
model.forward = model.forward_diffusion
if self.use_shallow_diffusion:
del model.aux_decoder
model.forward = model.forward_shallow_diffusion
else:
model.forward = model.forward_diffusion
return model


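The two `view_as_*` helpers let `_torch_export_model` trace and script each half of the checkpoint independently: `view_as_fs2_aux` drops the diffusion weights and returns `(condition, aux_mel)` when shallow diffusion is enabled, while `view_as_diffusion` drops the encoder side and consumes those outputs. A hedged sketch of the pairing — `model` is assumed to be a loaded `DiffSingerAcousticONNX`-style module with shallow diffusion on, and the dummy shapes are illustrative:

```python
import torch

def run_views(model):
    fs2_aux = model.view_as_fs2_aux()      # forward => (condition, aux_mel)
    diffusion = model.view_as_diffusion()  # forward => forward_shallow_diffusion

    tokens = torch.LongTensor([[1]])
    durations = torch.LongTensor([[10]])
    f0 = torch.full((1, 10), 440.0)

    condition, aux_mel = fs2_aux(tokens, durations, f0, variances={})
    return diffusion(condition, aux_mel, 400, 10)  # x_start, depth, speedup
```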
