Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shallow diffusion and aux decoder #128

Merged
merged 38 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
7da2d65
Add shallow diffusion API
yqzhishen Aug 6, 2023
4f3d765
Support aux decoder training
yqzhishen Aug 6, 2023
2380f88
Support shallow diffusion inference
yqzhishen Aug 6, 2023
e386348
add shallow framework
Aug 7, 2023
0f5cd68
Merge remote-tracking branch 'origin/shallow-diffusion' into shallow-…
Aug 7, 2023
6ad8fd2
add shallow framework
Aug 7, 2023
6d93610
Support lambda for aux mel loss
yqzhishen Aug 7, 2023
b16c066
Move config key
yqzhishen Aug 7, 2023
8f3a622
add shallow framework
Aug 7, 2023
2f1c600
Merge remote-tracking branch 'origin/shallow-diffusion' into shallow-…
Aug 7, 2023
ac29eeb
add shallow framework
Aug 7, 2023
5c687eb
add denorm
Aug 7, 2023
a47f9ae
add shallow model training switch
Aug 7, 2023
4708692
Limit gradient from aux decoder
yqzhishen Aug 7, 2023
f449e04
Improve loss calculation control flow
yqzhishen Aug 7, 2023
28a67ae
add independent encoder in shallow
Aug 8, 2023
144c776
Adjust lambda
yqzhishen Aug 8, 2023
e269708
Implement shallow diffusion
yqzhishen Aug 8, 2023
52b3125
Fix missing depth assignment
yqzhishen Aug 8, 2023
030223b
fix bugs of shallow diffusion inference
yxlllc Aug 8, 2023
39bdcb8
Fix errors and remove debug code
yqzhishen Aug 8, 2023
eb114d6
Support K_step < timesteps (shallow-only diffusion)
yqzhishen Aug 8, 2023
348e7cc
Fix argument passing
yqzhishen Aug 8, 2023
554c4ac
Add missing checks
yqzhishen Aug 9, 2023
b04b039
add glow decoder
Aug 17, 2023
3a4e77a
add glow decoder
Aug 17, 2023
a564622
Merge remote-tracking branch 'origin/shallow-diffusion' into shallow-…
Aug 17, 2023
4f6e50f
add convnext glow decoder
Aug 18, 2023
bf1d62c
fix fs2
Aug 20, 2023
6bbdf6c
Support using gt mel as source during validation
yqzhishen Aug 21, 2023
f0c95b1
Merge branch 'main' into shallow-diffusion
yqzhishen Sep 21, 2023
f1cc641
Clean files and configs
yqzhishen Sep 21, 2023
20d5bb5
Clean and refactor aux decoder
yqzhishen Sep 21, 2023
ef87664
Fix KeyError
yqzhishen Sep 21, 2023
2986c88
Support exporting shallow diffusion to ONNX
yqzhishen Sep 21, 2023
6e213c7
Merge branch 'main' into shallow-diffusion
yqzhishen Sep 21, 2023
1a8fb72
Add missing logic to ONNX
yqzhishen Sep 21, 2023
acf00e4
Rename `diff_depth` to `K_step_infer`
yqzhishen Sep 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions configs/acoustic.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,24 @@ diff_decoder_type: 'wavenet'
diff_loss_type: l2
schedule_type: 'linear'

# shallow diffusion
use_shallow_diffusion: false
K_step_infer: 400

shallow_diffusion_args:
train_aux_decoder: true
train_diffusion: true
val_gt_start: false
aux_decoder_arch: convnext
aux_decoder_args:
num_channels: 512
num_layers: 6
kernel_size: 7
dropout_rate: 0.1
aux_decoder_grad: 0.1

lambda_aux_mel_loss: 0.2

# train and eval
num_sanity_val_steps: 1
optimizer_args:
Expand Down
17 changes: 17 additions & 0 deletions configs/templates/config_acoustic.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,26 @@ augmentation_args:
domain: log # or linear
scale: 1.0

K_step: 1000
residual_channels: 512
residual_layers: 20

# shallow diffusion
use_shallow_diffusion: false
K_step_infer: 400
shallow_diffusion_args:
train_aux_decoder: true
train_diffusion: true
val_gt_start: false
aux_decoder_arch: convnext
aux_decoder_args:
num_channels: 512
num_layers: 6
kernel_size: 7
dropout_rate: 0.1
aux_decoder_grad: 0.1
lambda_aux_mel_loss: 0.2

optimizer_args:
lr: 0.0004
lr_scheduler_args:
Expand Down
88 changes: 64 additions & 24 deletions deployment/exporters/acoustic_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,22 @@ def __init__(
self.spk_map: dict = self.build_spk_map()
self.vocab = TokenTextEncoder(vocab_list=build_phoneme_list())
self.model = self.build_model()
self.fs2_cache_path = self.cache_dir / 'fs2.onnx'
self.fs2_aux_cache_path = self.cache_dir / (
'fs2_aux.onnx' if self.model.use_shallow_diffusion else 'fs2.onnx'
)
self.diffusion_cache_path = self.cache_dir / 'diffusion.onnx'

# Attributes for logging
self.model_class_name = remove_suffix(self.model.__class__.__name__, 'ONNX')
self.fs2_class_name = remove_suffix(self.model.fs2.__class__.__name__, 'ONNX')
fs2_aux_cls_logging = [remove_suffix(self.model.fs2.__class__.__name__, 'ONNX')]
if self.model.use_shallow_diffusion:
fs2_aux_cls_logging.append(remove_suffix(
self.model.aux_decoder.decoder.__class__.__name__, 'ONNX'
))
self.fs2_aux_class_name = ', '.join(fs2_aux_cls_logging)
self.aux_decoder_class_name = remove_suffix(
self.model.aux_decoder.decoder.__class__.__name__, 'ONNX'
) if self.model.use_shallow_diffusion else None
self.denoiser_class_name = remove_suffix(self.model.diffusion.denoise_fn.__class__.__name__, 'ONNX')
self.diffusion_class_name = remove_suffix(self.model.diffusion.__class__.__name__, 'ONNX')

Expand Down Expand Up @@ -86,11 +96,11 @@ def export(self, path: Path):

def export_model(self, path: Path):
self._torch_export_model()
fs2_onnx = self._optimize_fs2_graph(onnx.load(self.fs2_cache_path))
fs2_aux_onnx = self._optimize_fs2_aux_graph(onnx.load(self.fs2_aux_cache_path))
diffusion_onnx = self._optimize_diffusion_graph(onnx.load(self.diffusion_cache_path))
model_onnx = self._merge_fs2_diffusion_graphs(fs2_onnx, diffusion_onnx)
model_onnx = self._merge_fs2_aux_diffusion_graphs(fs2_aux_onnx, diffusion_onnx)
onnx.save(model_onnx, path)
self.fs2_cache_path.unlink()
self.fs2_aux_cache_path.unlink()
self.diffusion_cache_path.unlink()
print(f'| export model => {path}')

Expand All @@ -105,7 +115,7 @@ def export_attachments(self, path: Path):

@torch.no_grad()
def _torch_export_model(self):
# Prepare inputs for FastSpeech2 tracing
# Prepare inputs for FastSpeech2 and aux decoder tracing
n_frames = 10
tokens = torch.LongTensor([[1]]).to(self.device)
durations = torch.LongTensor([[n_frames]]).to(self.device)
Expand Down Expand Up @@ -161,22 +171,30 @@ def _torch_export_model(self):
1: 'n_frames'
}

# PyTorch ONNX export for FastSpeech2
print(f'Exporting {self.fs2_class_name}...')
# PyTorch ONNX export for FastSpeech2 and aux decoder
output_names = ['condition']
if self.model.use_shallow_diffusion:
output_names.append('aux_mel')
dynamix_axes['aux_mel'] = {
1: 'n_frames'
}
print(f'Exporting {self.fs2_aux_class_name}...')
torch.onnx.export(
self.model.view_as_fs2(),
self.model.view_as_fs2_aux(),
arguments,
self.fs2_cache_path,
self.fs2_aux_cache_path,
input_names=input_names,
output_names=['condition'],
output_names=output_names,
dynamic_axes=dynamix_axes,
opset_version=15
)

condition = torch.rand((1, n_frames, hparams['hidden_size']), device=self.device)

# Prepare inputs for denoiser tracing and GaussianDiffusion scripting
shape = (1, 1, hparams['audio_num_mel_bins'], n_frames)
noise = torch.randn(shape, device=self.device)
condition = torch.rand((1, hparams['hidden_size'], n_frames), device=self.device)
x_start = torch.randn((1, n_frames, hparams['audio_num_mel_bins']),device=self.device)
step = (torch.rand((1,), device=self.device) * hparams['K_step']).long()

print(f'Tracing {self.denoiser_class_name} denoiser...')
Expand All @@ -186,20 +204,24 @@ def _torch_export_model(self):
(
noise,
step,
condition
condition.transpose(1, 2)
)
)

print(f'Scripting {self.diffusion_class_name}...')
diffusion_inputs = [
condition,
*([x_start, 100] if self.model.use_shallow_diffusion else [])
]
diffusion = torch.jit.script(
diffusion,
example_inputs=[
(
condition.transpose(1, 2),
*diffusion_inputs,
1 # p_sample branch
),
(
condition.transpose(1, 2),
*diffusion_inputs,
200 # p_sample_plms branch
)
]
Expand All @@ -210,12 +232,14 @@ def _torch_export_model(self):
torch.onnx.export(
diffusion,
(
condition.transpose(1, 2),
*diffusion_inputs,
200
),
self.diffusion_cache_path,
input_names=[
'condition', 'speedup'
'condition',
*(['x_start', 'depth'] if self.model.use_shallow_diffusion else []),
'speedup'
],
output_names=[
'mel'
Expand All @@ -224,6 +248,7 @@ def _torch_export_model(self):
'condition': {
1: 'n_frames'
},
**({'x_start': {1: 'n_frames'}} if self.model.use_shallow_diffusion else {}),
'mel': {
1: 'n_frames'
}
Expand Down Expand Up @@ -252,11 +277,11 @@ def _perform_spk_mix(self, spk_mix: Dict[str, float]):
) # => [1, H]
return spk_mix_embed

def _optimize_fs2_graph(self, fs2: onnx.ModelProto) -> onnx.ModelProto:
print(f'Running ONNX Simplifier on {self.fs2_class_name}...')
def _optimize_fs2_aux_graph(self, fs2: onnx.ModelProto) -> onnx.ModelProto:
print(f'Running ONNX Simplifier on {self.fs2_aux_class_name}...')
fs2, check = onnxsim.simplify(fs2, include_subgraph=True)
assert check, 'Simplified ONNX model could not be validated'
print(f'| optimize graph: {self.fs2_class_name}')
print(f'| optimize graph: {self.fs2_aux_class_name}')
return fs2

def _optimize_diffusion_graph(self, diffusion: onnx.ModelProto) -> onnx.ModelProto:
Expand All @@ -282,18 +307,33 @@ def _optimize_diffusion_graph(self, diffusion: onnx.ModelProto) -> onnx.ModelPro
print(f'| optimize graph: {self.diffusion_class_name}')
return diffusion

def _merge_fs2_diffusion_graphs(self, fs2: onnx.ModelProto, diffusion: onnx.ModelProto) -> onnx.ModelProto:
onnx_helper.model_add_prefixes(fs2, dim_prefix='fs2.', ignored_pattern=r'(n_tokens)|(n_frames)')
def _merge_fs2_aux_diffusion_graphs(self, fs2: onnx.ModelProto, diffusion: onnx.ModelProto) -> onnx.ModelProto:
onnx_helper.model_add_prefixes(
fs2, dim_prefix=('fs2aux.' if self.model.use_shallow_diffusion else 'fs2.'),
ignored_pattern=r'(n_tokens)|(n_frames)'
)
onnx_helper.model_add_prefixes(diffusion, dim_prefix='diffusion.', ignored_pattern='n_frames')
print(f'Merging {self.fs2_class_name} and {self.diffusion_class_name} '
print(f'Merging {self.fs2_aux_class_name} and {self.diffusion_class_name} '
f'back into {self.model_class_name}...')
merged = onnx.compose.merge_models(
fs2, diffusion, io_map=[('condition', 'condition')],
fs2, diffusion, io_map=[
('condition', 'condition'),
*([('aux_mel', 'x_start')] if self.model.use_shallow_diffusion else []),
],
prefix1='', prefix2='', doc_string='',
producer_name=fs2.producer_name, producer_version=fs2.producer_version,
domain=fs2.domain, model_version=fs2.model_version
)
merged.graph.name = fs2.graph.name

print(f'Running ONNX Simplifier on {self.model_class_name}...')
merged, check = onnxsim.simplify(
merged,
include_subgraph=True
)
assert check, 'Simplified ONNX model could not be validated'
print(f'| optimize graph: {self.model_class_name}')

return merged

# noinspection PyMethodMayBeStatic
Expand Down
39 changes: 33 additions & 6 deletions deployment/modules/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ def extract(a, t):

# noinspection PyMethodOverriding
class GaussianDiffusionONNX(GaussianDiffusion):
# Mix the clean input `x_start` with `noise` at step `t`: each term is scaled by a
# per-step coefficient that `extract` looks up from the corresponding buffer.
# `forward` uses this to build the starting sample when 0 < depth < timesteps.
# NOTE(review): `extract` is defined outside this hunk; presumably it gathers a[t]
# in a shape that broadcasts against the inputs -- confirm against its definition.
def q_sample(self, x_start, t, noise):
return (
# weight applied to the clean signal
extract(self.sqrt_alphas_cumprod, t) * x_start +
# weight applied to the noise
extract(self.sqrt_one_minus_alphas_cumprod, t) * noise
)

def p_sample(self, x, t, cond):
x_pred = self.denoise_fn(x, t, cond)
x_recon = (
Expand Down Expand Up @@ -74,18 +80,39 @@ def p_sample_plms(self, x_prev, t, interval: int, cond, noise_list: List[Tensor]
x_prev = self.plms_get_x_pred(x_prev, noise_pred_prime, t, t_prev)
return noise_pred, x_prev

# Linearly rescale `x` so that `spec_min` maps to -1 and `spec_max` maps to 1.
# Counterpart of `denorm_spec`, which computes `x * k + b` with the same k and b.
def norm_spec(self, x):
k = (self.spec_max - self.spec_min) / 2.  # half of the [spec_min, spec_max] range
b = (self.spec_max + self.spec_min) / 2.  # midpoint of the range
return (x - b) / k

def denorm_spec(self, x):
d = (self.spec_max - self.spec_min) / 2.
m = (self.spec_max + self.spec_min) / 2.
return x * d + m
k = (self.spec_max - self.spec_min) / 2.
b = (self.spec_max + self.spec_min) / 2.
return x * k + b

def forward(self, condition, speedup: int):
def forward(self, condition, x_start=None, depth: int = 1000, speedup: int = 1):
condition = condition.transpose(1, 2) # [1, T, H] => [1, H, T]
device = condition.device
n_frames = condition.shape[2]

step_range = torch.arange(0, self.k_step, speedup, dtype=torch.long, device=device).flip(0)[:, None]
x = torch.randn((1, self.num_feats, self.out_dims, n_frames), device=device)
noise = torch.randn((1, self.num_feats, self.out_dims, n_frames), device=device)
if x_start is None:
step_range = torch.arange(0, self.k_step, speedup, dtype=torch.long, device=device).flip(0)[:, None]
x = noise
else:
depth = min(depth, self.k_step)
step_range = torch.arange(0, depth, speedup, dtype=torch.long, device=device).flip(0)[:, None]
x_start = self.norm_spec(x_start).transpose(-2, -1)
if self.num_feats == 1:
x_start = x_start[:, None, :, :]
if depth >= self.timesteps:
x = noise
elif depth > 0:
x = self.q_sample(
x_start, torch.full((1,), depth - 1, device=device, dtype=torch.long), noise
)
else:
x = x_start

if speedup > 1:
for t in step_range:
Expand Down
49 changes: 24 additions & 25 deletions deployment/modules/toplevel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -37,7 +37,7 @@ def __init__(self, vocab_size, out_dims):
spec_max=hparams['spec_max']
)

def forward_fs2(
def forward_fs2_aux(
self,
tokens: Tensor,
durations: Tensor,
Expand All @@ -46,41 +46,40 @@ def forward_fs2(
gender: Tensor = None,
velocity: Tensor = None,
spk_embed: Tensor = None
) -> Tensor:
return self.fs2(
):
condition = self.fs2(
tokens, durations, f0, variances=variances,
gender=gender, velocity=velocity, spk_embed=spk_embed
)
if self.use_shallow_diffusion:
aux_mel_pred = self.aux_decoder(condition, infer=True)
return condition, aux_mel_pred
else:
return condition

# Forward entry point for the shallow-diffusion view of the model: passes the
# condition through to `self.diffusion` together with the starting sample
# `x_start`, the `depth` and the `speedup` factor, all as keyword arguments.
# `view_as_diffusion` installs this as `model.forward` when
# `use_shallow_diffusion` is set.
def forward_shallow_diffusion(
self, condition: Tensor, x_start: Tensor,
depth: int, speedup: int
) -> Tensor:
return self.diffusion(condition, x_start=x_start, depth=depth, speedup=speedup)

def forward_diffusion(self, condition: Tensor, speedup: int) -> Tensor:
return self.diffusion(condition, speedup)
def forward_diffusion(self, condition: Tensor, speedup: int):
return self.diffusion(condition, speedup=speedup)

def view_as_fs2(self) -> nn.Module:
def view_as_fs2_aux(self) -> nn.Module:
model = copy.deepcopy(self)
try:
del model.variance_embeds
del model.variance_adaptor
except AttributeError:
pass
del model.diffusion
model.forward = model.forward_fs2
model.forward = model.forward_fs2_aux
return model

def view_as_adaptor(self) -> nn.Module:
model = copy.deepcopy(self)
del model.fs2
del model.diffusion
raise NotImplementedError()

def view_as_diffusion(self) -> nn.Module:
model = copy.deepcopy(self)
del model.fs2
try:
del model.variance_embeds
del model.variance_adaptor
except AttributeError:
pass
model.forward = model.forward_diffusion
if self.use_shallow_diffusion:
del model.aux_decoder
model.forward = model.forward_shallow_diffusion
else:
model.forward = model.forward_diffusion
return model


Expand Down
Loading