From 90acc1b435a7c2423390fe26a46d42f83a78b90c Mon Sep 17 00:00:00 2001 From: TianYuan Date: Thu, 20 Apr 2023 07:31:21 +0000 Subject: [PATCH 1/6] fix some trainer --- examples/vctk/vc3/conf/default.yaml | 35 ++++++++++--------- examples/vctk/vc3/local/train.sh | 4 +-- paddlespeech/t2s/datasets/am_batch_fn.py | 23 ++++++------ paddlespeech/t2s/exps/starganv2_vc/train.py | 22 ++++++------ .../starganv2_vc/starganv2_vc_updater.py | 11 +++--- 5 files changed, 49 insertions(+), 46 deletions(-) diff --git a/examples/vctk/vc3/conf/default.yaml b/examples/vctk/vc3/conf/default.yaml index b1168a40e50..eb98515ac82 100644 --- a/examples/vctk/vc3/conf/default.yaml +++ b/examples/vctk/vc3/conf/default.yaml @@ -41,7 +41,7 @@ discriminator_params: dim_in: 64 # same as dim_in in generator_params num_domains: 20 # same as num_domains in mapping_network_params max_conv_dim: 512 # same as max_conv_dim in generator_params - n_repeat: 4 + repeat_num: 4 asr_params: input_dim: 80 hidden_dim: 256 @@ -77,6 +77,7 @@ loss_params: ########################################################### batch_size: 5 # Batch size. num_workers: 2 # Number of workers in DataLoader. +max_mel_length: 192 ########################################################### # OPTIMIZER & SCHEDULER SETTING # @@ -84,47 +85,47 @@ num_workers: 2 # Number of workers in DataLoader. generator_optimizer_params: beta1: 0.0 beta2: 0.99 - weight_decay: 1e-4 - epsilon: 1e-9 + weight_decay: 1.0e-4 + epsilon: 1.0e-9 generator_scheduler_params: - max_learning_rate: 2e-4 + max_learning_rate: 2.0e-4 phase_pct: 0.0 divide_factor: 1 total_steps: 200000 # train_max_steps - end_learning_rate: 2e-4 + end_learning_rate: 2.0e-4 style_encoder_optimizer_params: beta1: 0.0 beta2: 0.99 - weight_decay: 1e-4 - epsilon: 1e-9 + weight_decay: 1.0e-4 + epsilon: 1.0e-9 style_encoder_scheduler_params: - max_learning_rate: 2e-4 + max_learning_rate: 2.0e-4 phase_pct: 0.0 divide_factor: 1 total_steps: 200000 # train_max_steps - end_learning_rate: 2e-4 + end_learning_rate: 2.0e-4 mapping_network_optimizer_params: beta1: 0.0 beta2: 0.99 - weight_decay: 1e-4 - epsilon: 1e-9 + weight_decay: 1.0e-4 + epsilon: 1.0e-9 mapping_network_scheduler_params: - max_learning_rate: 2e-6 + max_learning_rate: 2.0e-6 phase_pct: 0.0 divide_factor: 1 total_steps: 200000 # train_max_steps - end_learning_rate: 2e-6 + end_learning_rate: 2.0e-6 discriminator_optimizer_params: beta1: 0.0 beta2: 0.99 - weight_decay: 1e-4 - epsilon: 1e-9 + weight_decay: 1.0e-4 + epsilon: 1.0e-9 discriminator_scheduler_params: - max_learning_rate: 2e-4 + max_learning_rate: 2.0e-4 phase_pct: 0.0 divide_factor: 1 total_steps: 200000 # train_max_steps - end_learning_rate: 2e-4 + end_learning_rate: 2.0e-4 ########################################################### # TRAINING SETTING # diff --git a/examples/vctk/vc3/local/train.sh b/examples/vctk/vc3/local/train.sh index 3a5076505dd..bdd8deaedcd 100755 --- a/examples/vctk/vc3/local/train.sh +++ b/examples/vctk/vc3/local/train.sh @@ -8,6 +8,4 @@ python3 ${BIN_DIR}/train.py \ --dev-metadata=dump/dev/norm/metadata.jsonl \ --config=${config_path} \ --output-dir=${train_output_path} \ - --ngpu=1 \ - --phones-dict=dump/phone_id_map.txt \ - --speaker-dict=dump/speaker_id_map.txt + --ngpu=1 diff --git a/paddlespeech/t2s/datasets/am_batch_fn.py b/paddlespeech/t2s/datasets/am_batch_fn.py index 4cd5bccc72f..ae46f1e1afd 100644 --- a/paddlespeech/t2s/datasets/am_batch_fn.py +++ b/paddlespeech/t2s/datasets/am_batch_fn.py @@ -852,20 +852,21 @@ def starganv2_vc_batch_fn(self, examples): # (B,) label = paddle.to_tensor(label) ref_label = paddle.to_tensor(ref_label) - # [B, 80, T] -> [B, 1, 80, T] - mel = paddle.to_tensor(mel) - ref_mel = paddle.to_tensor(ref_mel) - ref_mel_2 = paddle.to_tensor(ref_mel_2) + # [B, T, 80] -> [B, 1, 80, T] + mel = paddle.to_tensor(mel).transpose([0, 2, 1]).unsqueeze(1) + ref_mel = paddle.to_tensor(ref_mel).transpose([0, 2, 1]).unsqueeze(1) + ref_mel_2 = paddle.to_tensor(ref_mel_2).transpose( + [0, 2, 1]).unsqueeze(1) - z_trg = paddle.randn(batch_size, self.latent_dim) - z_trg2 = paddle.randn(batch_size, self.latent_dim) + z_trg = paddle.randn([batch_size, self.latent_dim]) + z_trg2 = paddle.randn([batch_size, self.latent_dim]) batch = { - "x_real": mels, - "y_org": labels, - "x_ref": ref_mels, - "x_ref2": ref_mels_2, - "y_trg": ref_labels, + "x_real": mel, + "y_org": label, + "x_ref": ref_mel, + "x_ref2": ref_mel_2, + "y_trg": ref_label, "z_trg": z_trg, "z_trg2": z_trg2 } diff --git a/paddlespeech/t2s/exps/starganv2_vc/train.py b/paddlespeech/t2s/exps/starganv2_vc/train.py index 529f1f3ddac..616591e798f 100644 --- a/paddlespeech/t2s/exps/starganv2_vc/train.py +++ b/paddlespeech/t2s/exps/starganv2_vc/train.py @@ -29,9 +29,12 @@ from paddle.optimizer.lr import OneCycleLR from yacs.config import CfgNode -from paddlespeech.t2s.datasets.am_batch_fn import starganv2_vc_batch_fn -from paddlespeech.t2s.datasets.data_table import DataTable +from paddlespeech.cli.utils import download_and_decompress +from paddlespeech.resource.pretrained_models import StarGANv2VC_source +from paddlespeech.t2s.datasets.am_batch_fn import build_starganv2_vc_collate_fn +from paddlespeech.t2s.datasets.data_table import StarGANv2VCDataTable from paddlespeech.t2s.models.starganv2_vc import ASRCNN +from paddlespeech.t2s.models.starganv2_vc import Discriminator from paddlespeech.t2s.models.starganv2_vc import Generator from paddlespeech.t2s.models.starganv2_vc import JDCNet from paddlespeech.t2s.models.starganv2_vc import MappingNetwork @@ -66,7 +69,9 @@ def train_sp(args, config): fields = ["speech", "speech_lengths"] converters = {"speech": np.load} - collate_fn = starganv2_vc_batch_fn + collate_fn = build_starganv2_vc_collate_fn( + latent_dim=config['mapping_network_params']['latent_dim'], + max_mel_length=config['max_mel_length']) # dataloader has been too verbose logging.getLogger("DataLoader").disabled = True @@ -74,16 +79,10 @@ def train_sp(args, config): # construct dataset for training and validation with jsonlines.open(args.train_metadata, 'r') as reader: train_metadata = list(reader) - train_dataset = DataTable( - data=train_metadata, - fields=fields, - converters=converters, ) + train_dataset = StarGANv2VCDataTable(data=train_metadata) with jsonlines.open(args.dev_metadata, 'r') as reader: dev_metadata = list(reader) - dev_dataset = DataTable( - data=dev_metadata, - fields=fields, - converters=converters, ) + dev_dataset = StarGANv2VCDataTable(data=dev_metadata) # collate function and dataloader train_sampler = DistributedBatchSampler( @@ -118,6 +117,7 @@ def train_sp(args, config): generator = Generator(**config['generator_params']) mapping_network = MappingNetwork(**config['mapping_network_params']) style_encoder = StyleEncoder(**config['style_encoder_params']) + discriminator = Discriminator(**config['discriminator_params']) # load pretrained model jdc_model_dir = os.path.join(uncompress_path, 'jdcnet.pdz') diff --git a/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py b/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py index 09d4780e9e7..6a77fbb2caa 100644 --- a/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py +++ b/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py @@ -21,10 +21,13 @@ from paddle.optimizer import Optimizer from paddle.optimizer.lr import LRScheduler +from paddlespeech.t2s.models.starganv2_vc.losses import compute_d_loss +from paddlespeech.t2s.models.starganv2_vc.losses import compute_g_loss from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator from paddlespeech.t2s.training.reporter import report from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater from paddlespeech.t2s.training.updaters.standard_updater import UpdaterState + logging.basicConfig( format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s', datefmt='[%Y-%m-%d %H:%M:%S]') @@ -62,10 +65,10 @@ def __init__(self, self.models = models self.optimizers = optimizers - self.optimizer_g = optimizers['optimizer_g'] - self.optimizer_s = optimizers['optimizer_s'] - self.optimizer_m = optimizers['optimizer_m'] - self.optimizer_d = optimizers['optimizer_d'] + self.optimizer_g = optimizers['generator'] + self.optimizer_s = optimizers['style_encoder'] + self.optimizer_m = optimizers['mapping_network'] + self.optimizer_d = optimizers['discriminator'] self.schedulers = schedulers self.scheduler_g = schedulers['generator'] From dee663094517af1d99f31580e7ea99c5d0e5a13d Mon Sep 17 00:00:00 2001 From: TianYuan Date: Thu, 20 Apr 2023 08:04:06 +0000 Subject: [PATCH 2/6] add reset_parameters --- .../models/starganv2_vc/AuxiliaryASR/model.py | 13 ++++++- .../t2s/models/starganv2_vc/starganv2_vc.py | 22 +++++++++++ paddlespeech/t2s/modules/nets_utils.py | 38 +++++++++++++++++++ 3 files changed, 71 insertions(+), 2 deletions(-) diff --git a/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/model.py b/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/model.py index 251974572a3..85b3453d874 100644 --- a/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/model.py +++ b/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/model.py @@ -22,6 +22,7 @@ from .layers import ConvNorm from .layers import LinearNorm from .layers import MFCC +from paddlespeech.t2s.modules.nets_utils import _reset_parameters from paddlespeech.utils.initialize import uniform_ @@ -59,6 +60,9 @@ def __init__( hidden_dim=hidden_dim // 2, n_token=n_token) + self.reset_parameters() + self.asr_s2s.reset_parameters() + def forward(self, x: paddle.Tensor, src_key_padding_mask: paddle.Tensor=None, @@ -108,6 +112,9 @@ def get_future_mask(self, out_length: int, unmask_future_steps: int=0): index_tensor.T + unmask_future_steps) return mask + def reset_parameters(self): + self.apply(_reset_parameters) + class ASRS2S(nn.Layer): def __init__(self, @@ -118,8 +125,7 @@ def __init__(self, n_token: int=40): super().__init__() self.embedding = nn.Embedding(n_token, embedding_dim) - val_range = math.sqrt(6 / hidden_dim) - uniform_(self.embedding.weight, -val_range, val_range) + self.val_range = math.sqrt(6 / hidden_dim) self.decoder_rnn_dim = hidden_dim self.project_to_n_symbols = nn.Linear(self.decoder_rnn_dim, n_token) @@ -236,3 +242,6 @@ def parse_decoder_outputs(self, hidden = paddle.stack(hidden).transpose([1, 0, 2]) return hidden, logit, alignments + + def reset_parameters(self): + uniform_(self.embedding.weight, -self.val_range, self.val_range) diff --git a/paddlespeech/t2s/models/starganv2_vc/starganv2_vc.py b/paddlespeech/t2s/models/starganv2_vc/starganv2_vc.py index 2b6775c4824..236708906cf 100644 --- a/paddlespeech/t2s/models/starganv2_vc/starganv2_vc.py +++ b/paddlespeech/t2s/models/starganv2_vc/starganv2_vc.py @@ -25,6 +25,8 @@ import paddle.nn.functional as F from paddle import nn +from paddlespeech.t2s.modules.nets_utils import _reset_parameters + class DownSample(nn.Layer): def __init__(self, layer_type: str): @@ -355,6 +357,8 @@ def __init__(self, if w_hpf > 0: self.hpf = HighPass(w_hpf) + self.reset_parameters() + def forward(self, x: paddle.Tensor, s: paddle.Tensor, @@ -399,6 +403,9 @@ def forward(self, out = self.to_out(x) return out + def reset_parameters(self): + self.apply(_reset_parameters) + class MappingNetwork(nn.Layer): def __init__(self, @@ -427,6 +434,8 @@ def __init__(self, nn.ReLU(), nn.Linear(hidden_dim, style_dim)) ]) + self.reset_parameters() + def forward(self, z: paddle.Tensor, y: paddle.Tensor): """Calculate forward propagation. Args: @@ -449,6 +458,9 @@ def forward(self, z: paddle.Tensor, y: paddle.Tensor): s = out[idx, y] return s + def reset_parameters(self): + self.apply(_reset_parameters) + class StyleEncoder(nn.Layer): def __init__(self, @@ -490,6 +502,8 @@ def __init__(self, for _ in range(num_domains): self.unshared.append(nn.Linear(dim_out, style_dim)) + self.reset_parameters() + def forward(self, x: paddle.Tensor, y: paddle.Tensor): """Calculate forward propagation. Args: @@ -513,6 +527,9 @@ def forward(self, x: paddle.Tensor, y: paddle.Tensor): s = out[idx, y] return s + def reset_parameters(self): + self.apply(_reset_parameters) + class Discriminator(nn.Layer): def __init__(self, @@ -535,6 +552,8 @@ def __init__(self, repeat_num=repeat_num) self.num_domains = num_domains + self.reset_parameters() + def forward(self, x: paddle.Tensor, y: paddle.Tensor): out = self.dis(x, y) return out @@ -543,6 +562,9 @@ def classifier(self, x: paddle.Tensor): out = self.cls.get_feature(x) return out + def reset_parameters(self): + self.apply(_reset_parameters) + class Discriminator2D(nn.Layer): def __init__(self, diff --git a/paddlespeech/t2s/modules/nets_utils.py b/paddlespeech/t2s/modules/nets_utils.py index 99130acca17..3d1b48dec0d 100644 --- a/paddlespeech/t2s/modules/nets_utils.py +++ b/paddlespeech/t2s/modules/nets_utils.py @@ -20,6 +20,44 @@ from paddle import nn from typeguard import check_argument_types +from paddlespeech.utils.initialize import _calculate_fan_in_and_fan_out +from paddlespeech.utils.initialize import kaiming_uniform_ +from paddlespeech.utils.initialize import normal_ +from paddlespeech.utils.initialize import ones_ +from paddlespeech.utils.initialize import uniform_ +from paddlespeech.utils.initialize import zeros_ + + +# default init method of torch +# copy from https://github.com/PaddlePaddle/PaddleSpeech/blob/9cf8c1985a98bb380c183116123672976bdfe5c9/paddlespeech/t2s/models/vits/vits.py#L506 +def _reset_parameters(module): + if isinstance(module, (nn.Conv1D, nn.Conv1DTranspose, nn.Conv2D, + nn.Conv2DTranspose)): + kaiming_uniform_(module.weight, a=math.sqrt(5)) + if module.bias is not None: + fan_in, _ = _calculate_fan_in_and_fan_out(module.weight) + if fan_in != 0: + bound = 1 / math.sqrt(fan_in) + uniform_(module.bias, -bound, bound) + + if isinstance(module, + (nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm, nn.LayerNorm)): + ones_(module.weight) + zeros_(module.bias) + + if isinstance(module, nn.Linear): + kaiming_uniform_(module.weight, a=math.sqrt(5)) + if module.bias is not None: + fan_in, _ = _calculate_fan_in_and_fan_out(module.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + uniform_(module.bias, -bound, bound) + + if isinstance(module, nn.Embedding): + normal_(module.weight) + if module._padding_idx is not None: + with paddle.no_grad(): + module.weight[module._padding_idx] = 0 + def pad_list(xs, pad_value): """Perform padding for the list of tensors. From ec13243ff46aa490a3d7ffe4b16da874a9ebf649 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Thu, 20 Apr 2023 08:58:16 +0000 Subject: [PATCH 3/6] add docstring for Discriminator --- paddlespeech/t2s/models/starganv2_vc/starganv2_vc.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/paddlespeech/t2s/models/starganv2_vc/starganv2_vc.py b/paddlespeech/t2s/models/starganv2_vc/starganv2_vc.py index 236708906cf..99aeb73bfbf 100644 --- a/paddlespeech/t2s/models/starganv2_vc/starganv2_vc.py +++ b/paddlespeech/t2s/models/starganv2_vc/starganv2_vc.py @@ -555,6 +555,16 @@ def __init__(self, self.reset_parameters() def forward(self, x: paddle.Tensor, y: paddle.Tensor): + """Calculate forward propagation. + Args: + x(Tensor(float32)): + Shape (B, 1, 80, T). + y(Tensor(float32)): + Shape (B, ). + Returns: + Tensor: + Shape (B, ) + """ out = self.dis(x, y) return out From b523701867972b39d8010af0df6a05458feef3b6 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Thu, 20 Apr 2023 09:03:56 +0000 Subject: [PATCH 4/6] add typehint --- paddlespeech/t2s/models/starganv2_vc/losses.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/paddlespeech/t2s/models/starganv2_vc/losses.py b/paddlespeech/t2s/models/starganv2_vc/losses.py index f9ff39276f4..aef7559f9a3 100644 --- a/paddlespeech/t2s/models/starganv2_vc/losses.py +++ b/paddlespeech/t2s/models/starganv2_vc/losses.py @@ -27,9 +27,9 @@ def compute_d_loss(nets: Dict[str, Any], y_trg: paddle.Tensor, z_trg: paddle.Tensor=None, x_ref: paddle.Tensor=None, - use_r1_reg=True, - use_adv_cls=False, - use_con_reg=False, + use_r1_reg: bool=True, + use_adv_cls: bool=False, + use_con_reg: bool=False, lambda_reg: float=1., lambda_adv_cls: float=0.1, lambda_con_reg: float=10.): @@ -37,7 +37,6 @@ def compute_d_loss(nets: Dict[str, Any], assert (z_trg is None) != (x_ref is None) # with real audios x_real.stop_gradient = False - out = nets['discriminator'](x_real, y_org) loss_real = adv_loss(out, 1) From 2edc79f96548c9d04362820f7c7394c765c317ea Mon Sep 17 00:00:00 2001 From: TianYuan Date: Thu, 20 Apr 2023 09:51:37 +0000 Subject: [PATCH 5/6] fix clip bug --- examples/vctk/vc3/local/train.sh | 3 +- paddlespeech/t2s/datasets/am_batch_fn.py | 10 +++--- paddlespeech/t2s/exps/starganv2_vc/train.py | 17 +++++++++- .../t2s/models/starganv2_vc/losses.py | 33 ++++++++++--------- .../starganv2_vc/starganv2_vc_updater.py | 4 +-- 5 files changed, 44 insertions(+), 23 deletions(-) diff --git a/examples/vctk/vc3/local/train.sh b/examples/vctk/vc3/local/train.sh index bdd8deaedcd..d4ea02da873 100755 --- a/examples/vctk/vc3/local/train.sh +++ b/examples/vctk/vc3/local/train.sh @@ -8,4 +8,5 @@ python3 ${BIN_DIR}/train.py \ --dev-metadata=dump/dev/norm/metadata.jsonl \ --config=${config_path} \ --output-dir=${train_output_path} \ - --ngpu=1 + --ngpu=1 \ + --speaker-dict=dump/speaker_id_map.txt diff --git a/paddlespeech/t2s/datasets/am_batch_fn.py b/paddlespeech/t2s/datasets/am_batch_fn.py index ae46f1e1afd..85959aa250b 100644 --- a/paddlespeech/t2s/datasets/am_batch_fn.py +++ b/paddlespeech/t2s/datasets/am_batch_fn.py @@ -820,12 +820,13 @@ def __init__(self, latent_dim: int=16, max_mel_length: int=192): self.max_mel_length = max_mel_length def random_clip(self, mel: np.array): - # [80, T] - mel_length = mel.shape[1] + # [T, 80] + mel_length = mel.shape[0] if mel_length > self.max_mel_length: random_start = np.random.randint(0, mel_length - self.max_mel_length) - mel = mel[:, random_start:random_start + self.max_mel_length] + + mel = mel[random_start:random_start + self.max_mel_length, :] return mel def __call__(self, exmaples): @@ -843,8 +844,9 @@ def starganv2_vc_batch_fn(self, examples): mel = [self.random_clip(item["mel"]) for item in examples] ref_mel = [self.random_clip(item["ref_mel"]) for item in examples] ref_mel_2 = [self.random_clip(item["ref_mel_2"]) for item in examples] - + print("mel[0].shape after batch_sequences:", mel[0].shape) mel = batch_sequences(mel) + print("mel.shape after batch_sequences:", mel.shape) ref_mel = batch_sequences(ref_mel) ref_mel_2 = batch_sequences(ref_mel_2) diff --git a/paddlespeech/t2s/exps/starganv2_vc/train.py b/paddlespeech/t2s/exps/starganv2_vc/train.py index 616591e798f..94fa3032cee 100644 --- a/paddlespeech/t2s/exps/starganv2_vc/train.py +++ b/paddlespeech/t2s/exps/starganv2_vc/train.py @@ -113,6 +113,16 @@ def train_sp(args, config): model_version = '1.0' uncompress_path = download_and_decompress(StarGANv2VC_source[model_version], MODEL_HOME) + # 根据 speaker 的个数修改 num_domains + # 源码的预训练模型和 default.yaml 里面默认是 20 + if args.speaker_dict is not None: + with open(args.speaker_dict, 'rt', encoding='utf-8') as f: + spk_id = [line.strip().split() for line in f.readlines()] + spk_num = len(spk_id) + print("spk_num:", spk_num) + config['mapping_network_params']['num_domains'] = spk_num + config['style_encoder_params']['num_domains'] = spk_num + config['discriminator_params']['num_domains'] = spk_num generator = Generator(**config['generator_params']) mapping_network = MappingNetwork(**config['mapping_network_params']) @@ -123,7 +133,7 @@ def train_sp(args, config): jdc_model_dir = os.path.join(uncompress_path, 'jdcnet.pdz') asr_model_dir = os.path.join(uncompress_path, 'asr.pdz') - F0_model = JDCNet(num_class=1, seq_len=192) + F0_model = JDCNet(num_class=1, seq_len=config['max_mel_length']) F0_model.set_state_dict(paddle.load(jdc_model_dir)['main_params']) F0_model.eval() @@ -234,6 +244,11 @@ def main(): parser.add_argument("--output-dir", type=str, help="output dir.") parser.add_argument( "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") + parser.add_argument( + "--speaker-dict", + type=str, + default=None, + help="speaker id map file for multiple speaker model.") args = parser.parse_args() diff --git a/paddlespeech/t2s/models/starganv2_vc/losses.py b/paddlespeech/t2s/models/starganv2_vc/losses.py index aef7559f9a3..f4a308da0d0 100644 --- a/paddlespeech/t2s/models/starganv2_vc/losses.py +++ b/paddlespeech/t2s/models/starganv2_vc/losses.py @@ -21,33 +21,35 @@ # 这些都写到 updater 里 -def compute_d_loss(nets: Dict[str, Any], - x_real: paddle.Tensor, - y_org: paddle.Tensor, - y_trg: paddle.Tensor, - z_trg: paddle.Tensor=None, - x_ref: paddle.Tensor=None, - use_r1_reg: bool=True, - use_adv_cls: bool=False, - use_con_reg: bool=False, - lambda_reg: float=1., - lambda_adv_cls: float=0.1, - lambda_con_reg: float=10.): +def compute_d_loss( + nets: Dict[str, Any], + x_real: paddle.Tensor, + y_org: paddle.Tensor, + y_trg: paddle.Tensor, + z_trg: paddle.Tensor=None, + x_ref: paddle.Tensor=None, + # TODO: should be True here, but r1_reg has some bug now + use_r1_reg: bool=False, + use_adv_cls: bool=False, + use_con_reg: bool=False, + lambda_reg: float=1., + lambda_adv_cls: float=0.1, + lambda_con_reg: float=10.): assert (z_trg is None) != (x_ref is None) # with real audios x_real.stop_gradient = False out = nets['discriminator'](x_real, y_org) loss_real = adv_loss(out, 1) - # R1 regularizaition (https://arxiv.org/abs/1801.04406v4) if use_r1_reg: loss_reg = r1_reg(out, x_real) else: - loss_reg = paddle.to_tensor([0.], dtype=paddle.float32) + # loss_reg = paddle.to_tensor([0.], dtype=paddle.float32) + loss_reg = paddle.zeros([1]) # consistency regularization (bCR-GAN: https://arxiv.org/abs/2002.04724) - loss_con_reg = paddle.to_tensor([0.], dtype=paddle.float32) + loss_con_reg = paddle.zeros([1]) if use_con_reg: t = build_transforms() out_aug = nets['discriminator'](t(x_real).detach(), y_org) @@ -119,6 +121,7 @@ def compute_g_loss(nets: Dict[str, Any], # compute ASR/F0 features (real) with paddle.no_grad(): + print("x_real.shape:", x_real.shape) F0_real, GAN_F0_real, cyc_F0_real = nets['F0_model'](x_real) ASR_real = nets['asr_model'].get_feature(x_real) diff --git a/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py b/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py index 6a77fbb2caa..1b811a3f704 100644 --- a/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py +++ b/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py @@ -259,7 +259,7 @@ def evaluate_core(self, batch): y_org=y_org, y_trg=y_trg, z_trg=z_trg, - use_r1_reg=False, + use_r1_reg=self.use_r1_reg, use_adv_cls=use_adv_cls, **self.d_loss_params) @@ -269,7 +269,7 @@ def evaluate_core(self, batch): y_org=y_org, y_trg=y_trg, x_ref=x_ref, - use_r1_reg=False, + use_r1_reg=self.use_r1_reg, use_adv_cls=use_adv_cls, **self.d_loss_params) From cb8c99405803bdcc2a7f91a27f3aeb567e313378 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Sun, 23 Apr 2023 07:35:07 +0000 Subject: [PATCH 6/6] fix loss bug --- paddlespeech/t2s/datasets/am_batch_fn.py | 2 -- paddlespeech/t2s/models/starganv2_vc/losses.py | 10 +++++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/paddlespeech/t2s/datasets/am_batch_fn.py b/paddlespeech/t2s/datasets/am_batch_fn.py index 85959aa250b..fe5d977a541 100644 --- a/paddlespeech/t2s/datasets/am_batch_fn.py +++ b/paddlespeech/t2s/datasets/am_batch_fn.py @@ -844,9 +844,7 @@ def starganv2_vc_batch_fn(self, examples): mel = [self.random_clip(item["mel"]) for item in examples] ref_mel = [self.random_clip(item["ref_mel"]) for item in examples] ref_mel_2 = [self.random_clip(item["ref_mel_2"]) for item in examples] - print("mel[0].shape after batch_sequences:", mel[0].shape) mel = batch_sequences(mel) - print("mel.shape after batch_sequences:", mel.shape) ref_mel = batch_sequences(ref_mel) ref_mel_2 = batch_sequences(ref_mel_2) diff --git a/paddlespeech/t2s/models/starganv2_vc/losses.py b/paddlespeech/t2s/models/starganv2_vc/losses.py index 14534467610..d94c9342a45 100644 --- a/paddlespeech/t2s/models/starganv2_vc/losses.py +++ b/paddlespeech/t2s/models/starganv2_vc/losses.py @@ -19,9 +19,9 @@ from .transforms import build_transforms - # 这些都写到 updater 里 + def compute_d_loss( nets: Dict[str, Any], x_real: paddle.Tensor, @@ -121,10 +121,10 @@ def compute_g_loss(nets: Dict[str, Any], s_trg = nets['style_encoder'](x_ref, y_trg) # compute ASR/F0 features (real) - with paddle.no_grad(): - print("x_real.shape:", x_real.shape) - F0_real, GAN_F0_real, cyc_F0_real = nets['F0_model'](x_real) - ASR_real = nets['asr_model'].get_feature(x_real) + # 源码没有用 .eval(), 使用了 no_grad() + # 我们使用了 .eval(), 开启 with paddle.no_grad() 会报错 + F0_real, GAN_F0_real, cyc_F0_real = nets['F0_model'](x_real) + ASR_real = nets['asr_model'].get_feature(x_real) # adversarial loss x_fake = nets['generator'](x_real, s_trg, masks=None, F0=GAN_F0_real)