[tts] finetune add frozen #2385

Merged · 5 commits · Sep 15, 2022
9 changes: 9 additions & 0 deletions examples/other/tts_finetune/tts3/README.md
@@ -75,6 +75,15 @@ When "Prepare" is done, the structure of the current directory is listed below.

```

### Set finetune.yaml
`finetune.yaml` contains the configuration options for fine-tuning. You can try various values to get a better fine-tuning result; the override semantics are sketched below.
Arguments:
- `batch_size`: fine-tuning batch size. Default: -1, which means 64, the same as the pretrained model.
- `learning_rate`: learning rate. Default: 0.0001.
- `num_snapshots`: number of saved checkpoints. Default: -1, which means 5, the same as the pretrained model.
- `frozen_layers`: layers to freeze. Must be a list; if you don't want to freeze any layer, set it to [].
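
In `finetune.py`, each option with a value greater than 0 overrides the pretrained model's config, while -1 keeps the pretrained value. A minimal sketch of that merge logic (the helper name is ours for illustration; the option names come from this PR):

```python
import yaml
from yacs.config import CfgNode


def apply_finetune_overrides(config: CfgNode, finetune_yaml: str) -> CfgNode:
    """Override pretrained settings only where finetune.yaml sets a value > 0."""
    with open(finetune_yaml) as f:
        ft = CfgNode(yaml.safe_load(f))
    if ft.batch_size > 0:
        config.batch_size = ft.batch_size
    if ft.learning_rate > 0:
        config.optimizer.learning_rate = ft.learning_rate
    if ft.num_snapshots > 0:
        config.num_snapshots = ft.num_snapshots
    assert isinstance(ft.frozen_layers, list), "frozen_layers must be a list"
    return config
```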



## Get Started
Run the command below to
43 changes: 33 additions & 10 deletions examples/other/tts_finetune/tts3/finetune.py
@@ -14,17 +14,18 @@
import argparse
import os
from pathlib import Path
from typing import List
from typing import Union

import yaml
from local.check_oov import get_check_result
from local.extract import extract_feature
from local.label_process import get_single_label
from local.prepare_env import generate_finetune_env
from local.train import train_sp
from paddle import distributed as dist
from yacs.config import CfgNode

from paddlespeech.t2s.exps.fastspeech2.train import train_sp
from utils.gen_duration_from_textgrid import gen_duration_from_textgrid

DICT_EN = 'tools/aligner/cmudict-0.7b'
@@ -38,15 +39,24 @@


class TrainArgs():
def __init__(self, ngpu, config_file, dump_dir: Path, output_dir: Path):
def __init__(self,
ngpu,
config_file,
dump_dir: Path,
output_dir: Path,
frozen_layers: List[str]):
# config: fastspeech2 config file.
self.config = str(config_file)
self.train_metadata = str(dump_dir / "train/norm/metadata.jsonl")
self.dev_metadata = str(dump_dir / "dev/norm/metadata.jsonl")
# model output dir.
self.output_dir = str(output_dir)
self.ngpu = ngpu
self.phones_dict = str(dump_dir / "phone_id_map.txt")
self.speaker_dict = str(dump_dir / "speaker_id_map.txt")
self.voice_cloning = False
# frozen layers
self.frozen_layers = frozen_layers


def get_mfa_result(
@@ -122,12 +132,11 @@ def get_mfa_result(
"--ngpu", type=int, default=2, help="if ngpu=0, use cpu.")

parser.add_argument("--epoch", type=int, default=100, help="finetune epoch")

parser.add_argument(
"--batch_size",
type=int,
default=-1,
help="batch size, default -1 means same as pretrained model")
"--finetune_config",
type=str,
default="./finetune.yaml",
help="Path to finetune config file")

args = parser.parse_args()

@@ -147,8 +156,14 @@ def get_mfa_result(
with open(config_file) as f:
config = CfgNode(yaml.safe_load(f))
config.max_epoch = config.max_epoch + args.epoch
if args.batch_size > 0:
config.batch_size = args.batch_size

with open(args.finetune_config) as f2:
finetune_config = CfgNode(yaml.safe_load(f2))
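# values > 0 in finetune.yaml override the pretrained config;
# -1 (the default) keeps the pretrained model's setting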
config.batch_size = finetune_config.batch_size if finetune_config.batch_size > 0 else config.batch_size
config.optimizer.learning_rate = finetune_config.learning_rate if finetune_config.learning_rate > 0 else config.optimizer.learning_rate
config.num_snapshots = finetune_config.num_snapshots if finetune_config.num_snapshots > 0 else config.num_snapshots
frozen_layers = finetune_config.frozen_layers
assert type(frozen_layers) == list, "frozen_layers should be a list."

if args.lang == 'en':
lexicon_file = DICT_EN
@@ -158,6 +173,13 @@ def get_mfa_result(
mfa_phone_file = MFA_PHONE_ZH
else:
print('please input the right lang!!')

print(f"finetune max_epoch: {config.max_epoch}")
print(f"finetune batch_size: {config.batch_size}")
print(f"finetune learning_rate: {config.optimizer.learning_rate}")
print(f"finetune num_snapshots: {config.num_snapshots}")
print(f"finetune frozen_layers: {frozen_layers}")

am_phone_file = pretrained_model_dir / "phone_id_map.txt"
label_file = input_dir / "labels.txt"

@@ -181,7 +203,8 @@ def get_mfa_result(
generate_finetune_env(output_dir, pretrained_model_dir)

# create a new args for training
train_args = TrainArgs(args.ngpu, config_file, dump_dir, output_dir)
train_args = TrainArgs(args.ngpu, config_file, dump_dir, output_dir,
frozen_layers)

# finetune models
# dispatch
12 changes: 12 additions & 0 deletions examples/other/tts_finetune/tts3/finetune.yaml
@@ -0,0 +1,12 @@
###########################################################
# PARAS SETTING #
###########################################################
# Set to -1 to indicate that the parameter is the same as the pretrained model configuration

batch_size: -1
learning_rate: 0.0001 # learning rate
num_snapshots: -1

# frozen_layers should be a list
# if you don't need to freeze any layers, set frozen_layers to []
frozen_layers: ["encoder", "duration_predictor"]
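
The names in `frozen_layers` must match sublayer attributes of the FastSpeech2 model, because `local/train.py` resolves each entry as `model.<name>`. If in doubt, a quick way to list the candidates (a sketch; it assumes a `model` constructed as in `local/train.py`):

```python
# Print the attribute names of the model's direct sublayers;
# any of these names can appear in frozen_layers.
for name, _ in model.named_children():
    print(name)
```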
7 changes: 3 additions & 4 deletions examples/other/tts_finetune/tts3/local/extract.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
import os
from operator import itemgetter
from pathlib import Path
@@ -211,9 +210,9 @@ def extract_feature(duration_file: str,
mel_extractor, pitch_extractor, energy_extractor = get_extractor(config)

wav_files = sorted(list((input_dir).rglob("*.wav")))
# split data into 3 sections, train: 80%, dev: 10%, test: 10%
num_train = math.ceil(len(wav_files) * 0.8)
num_dev = math.ceil(len(wav_files) * 0.1)
# split data into 3 sections, train: len(wav_files) - 2, dev: 1, test: 1
num_train = len(wav_files) - 2
num_dev = 1
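# e.g. 100 wav files -> 98 train, 1 dev, 1 test (fine-tuning sets are small)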
print(num_train, num_dev)

train_wav_files = wav_files[:num_train]
178 changes: 178 additions & 0 deletions examples/other/tts_finetune/tts3/local/train.py
@@ -0,0 +1,178 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import shutil
from pathlib import Path
from typing import List

import jsonlines
import numpy as np
import paddle
from paddle import DataParallel
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler

from paddlespeech.t2s.datasets.am_batch_fn import fastspeech2_multi_spk_batch_fn
from paddlespeech.t2s.datasets.am_batch_fn import fastspeech2_single_spk_batch_fn
from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Evaluator
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Updater
from paddlespeech.t2s.training.extensions.snapshot import Snapshot
from paddlespeech.t2s.training.extensions.visualizer import VisualDL
from paddlespeech.t2s.training.optimizer import build_optimizers
from paddlespeech.t2s.training.seeding import seed_everything
from paddlespeech.t2s.training.trainer import Trainer


def freeze_layer(model, layers: List[str]):
"""Freeze the given sublayers of the model.

Args:
model (paddle.nn.Layer): model whose sublayers will be frozen
layers (List[str]): names of the sublayers to freeze
"""
for layer in layers:
for param in eval("model." + layer + ".parameters()"):
param.trainable = False
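
For reference, the same freezing can be done without `eval()` by resolving the layer name with `getattr`; a sketch (not part of this PR) that also handles dotted attribute paths:

```python
from typing import List

import paddle


def freeze_layer_via_getattr(model: paddle.nn.Layer, layers: List[str]):
    """Freeze the sublayers named in `layers`, resolving dotted paths via getattr."""
    for name in layers:
        sublayer = model
        for attr in name.split("."):
            sublayer = getattr(sublayer, attr)
        for param in sublayer.parameters():
            param.trainable = False
```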


def train_sp(args, config):
# decides device type and whether to run in parallel
# setup running environment correctly
if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
paddle.set_device("cpu")
else:
paddle.set_device("gpu")
world_size = paddle.distributed.get_world_size()
if world_size > 1:
paddle.distributed.init_parallel_env()

# set the random seed, it is a must for multiprocess training
seed_everything(config.seed)

print(
f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
)
fields = [
"text", "text_lengths", "speech", "speech_lengths", "durations",
"pitch", "energy"
]
converters = {"speech": np.load, "pitch": np.load, "energy": np.load}
spk_num = None
if args.speaker_dict is not None:
print("multiple speaker fastspeech2!")
collate_fn = fastspeech2_multi_spk_batch_fn
with open(args.speaker_dict, 'rt') as f:
spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id)
fields += ["spk_id"]
elif args.voice_cloning:
print("Training voice cloning!")
collate_fn = fastspeech2_multi_spk_batch_fn
fields += ["spk_emb"]
converters["spk_emb"] = np.load
else:
print("single speaker fastspeech2!")
collate_fn = fastspeech2_single_spk_batch_fn
print("spk_num:", spk_num)

# dataloader has been too verbose
logging.getLogger("DataLoader").disabled = True

# construct dataset for training and validation
with jsonlines.open(args.train_metadata, 'r') as reader:
train_metadata = list(reader)
train_dataset = DataTable(
data=train_metadata,
fields=fields,
converters=converters, )
with jsonlines.open(args.dev_metadata, 'r') as reader:
dev_metadata = list(reader)
dev_dataset = DataTable(
data=dev_metadata,
fields=fields,
converters=converters, )

# collate function and dataloader

train_sampler = DistributedBatchSampler(
train_dataset,
batch_size=config.batch_size,
shuffle=True,
drop_last=True)

print("samplers done!")

train_dataloader = DataLoader(
train_dataset,
batch_sampler=train_sampler,
collate_fn=collate_fn,
num_workers=config.num_workers)

dev_dataloader = DataLoader(
dev_dataset,
shuffle=False,
drop_last=False,
batch_size=config.batch_size,
collate_fn=collate_fn,
num_workers=config.num_workers)
print("dataloaders done!")

with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id)
print("vocab_size:", vocab_size)

odim = config.n_mels
model = FastSpeech2(
idim=vocab_size, odim=odim, spk_num=spk_num, **config["model"])

# freeze layer
if args.frozen_layers != []:
freeze_layer(model, args.frozen_layers)

if world_size > 1:
model = DataParallel(model)
print("model done!")

optimizer = build_optimizers(model, **config["optimizer"])
print("optimizer done!")

output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
if dist.get_rank() == 0:
config_name = args.config.split("/")[-1]
# copy conf to output_dir
shutil.copyfile(args.config, output_dir / config_name)

updater = FastSpeech2Updater(
model=model,
optimizer=optimizer,
dataloader=train_dataloader,
output_dir=output_dir,
**config["updater"])

trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)

evaluator = FastSpeech2Evaluator(
model, dev_dataloader, output_dir=output_dir, **config["updater"])

if dist.get_rank() == 0:
trainer.extend(evaluator, trigger=(1, "epoch"))
trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
trainer.extend(
Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
trainer.run()
12 changes: 7 additions & 5 deletions examples/other/tts_finetune/tts3/run.sh
@@ -10,11 +10,12 @@ mfa_dir=./mfa_result
dump_dir=./dump
output_dir=./exp/default
lang=zh
ngpu=2
ngpu=1
finetune_config=./finetune.yaml

ckpt=snapshot_iter_96600
ckpt=snapshot_iter_96699

gpus=0,1
gpus=1
CUDA_VISIBLE_DEVICES=${gpus}
stage=0
stop_stage=100
@@ -35,7 +36,8 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
--output_dir=${output_dir} \
--lang=${lang} \
--ngpu=${ngpu} \
--epoch=100
--epoch=100 \
--finetune_config=${finetune_config}
fi


@@ -54,7 +56,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
--voc_stat=pretrained_models/hifigan_aishell3_ckpt_0.2.0/feats_stats.npy \
--lang=zh \
--text=${BIN_DIR}/../sentences.txt \
--output_dir=./test_e2e \
--output_dir=./test_e2e/ \
--phones_dict=${dump_dir}/phone_id_map.txt \
--speaker_dict=${dump_dir}/speaker_id_map.txt \
--spk_id=0