[train] u2++-lite training support #2202
@@ -0,0 +1,91 @@
# network architecture
# encoder related
encoder: conformer
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
    normalize_before: true
    cnn_module_kernel: 8
    use_cnn_module: True
    activation_type: 'swish'
    pos_enc_layer_type: 'rel_pos'
    selfattention_layer_type: 'rel_selfattn'
    causal: true
    use_dynamic_chunk: true
    cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
    use_dynamic_left_chunk: false

# decoder related
decoder: bitransformer
decoder_conf:
    attention_heads: 4
    linear_units: 1024
    num_blocks: 3
    r_num_blocks: 3
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.1
    src_attention_dropout_rate: 0.1

# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false
    reverse_weight: 0.3
    apply_non_blank_embedding: true # warning: better to use a well-trained model as the init model

dataset_conf:
    filter_conf:
        max_length: 40960
        min_length: 0
        token_max_length: 200
        token_min_length: 1
    resample_conf:
        resample_rate: 16000
    speed_perturb: true
    fbank_conf:
        num_mel_bins: 80
        frame_shift: 10
        frame_length: 25
        dither: 1.0
    spec_aug: true
    spec_aug_conf:
        num_t_mask: 2
        num_f_mask: 2
        max_t: 50
        max_f: 10
    spec_sub: true
    spec_sub_conf:
        num_t_sub: 3
        max_t: 30
    spec_trim: false
    spec_trim_conf:
        max_t: 50
    shuffle: true
    shuffle_conf:
        shuffle_size: 1500
    sort: true
    sort_conf:
        sort_size: 500  # sort_size should be less than shuffle_size
    batch_conf:
        batch_type: 'static' # static or dynamic
        batch_size: 16

grad_clip: 5
accum_grad: 1
max_epoch: 360
log_interval: 100

optim: adam
optim_conf:
    lr: 0.001
scheduler: warmuplr     # pytorch v1.1.0+ required
scheduler_conf:
    warmup_steps: 25000
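
As a quick sanity check before launching training, the config above can be loaded with PyYAML and the u2++-lite-specific options inspected. The snippet below is only an illustrative sketch; the file path is a placeholder, not a path defined by this PR.

import yaml

# Placeholder path; substitute the actual location of the config above.
with open("train_u2pp_lite_conformer.yaml", "r") as fin:
    configs = yaml.safe_load(fin)

model_conf = configs["model_conf"]
print("ctc_weight:", model_conf["ctc_weight"])                                # 0.3
print("reverse_weight:", model_conf["reverse_weight"])                        # 0.3
print("apply_non_blank_embedding:", model_conf["apply_non_blank_embedding"])  # True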
@@ -66,12 +66,15 @@ def add_model_args(parser):
                         default=None,
                         type=str,
                         help="Pre-trained model to initialize encoder")
-    parser.add_argument(
-        "--enc_init_mods",
-        default="encoder.",
-        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
-        help="List of encoder modules \
-             to initialize, separated by a comma")
+    parser.add_argument('--enc_init_mods',
+                        default="encoder.",
+                        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
+                        help="List of encoder modules \
+                        to initialize, separated by a comma")
+    parser.add_argument('--freeze_modules',
+                        default="",
+                        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
+                        help='freeze module names')
     parser.add_argument('--lfmmi_dir',
                         default='',
                         required=False,
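
For reference, the lambda passed as the argparse type above turns the comma-separated flag value into a list of module-name prefixes. The small standalone sketch below (a toy parser with example module names, not the full add_model_args) shows the parsed result:

import argparse

# Toy parser reusing the same type lambda as --freeze_modules above.
parser = argparse.ArgumentParser()
parser.add_argument('--freeze_modules',
                    default="",
                    type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
                    help='freeze module names')

args = parser.parse_args(['--freeze_modules', 'encoder.,ctc.'])
print(args.freeze_modules)  # ['encoder.', 'ctc.']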
@@ -239,6 +242,12 @@ def check_modify_and_save_config(args, configs, symbol_table):
         data = yaml.dump(configs)
         fout.write(data)
 
+    if configs["model_conf"]["apply_non_blank_embedding"]:
+        logging.warn(
+            'Had better load a well trained model '
+            'if apply_non_blank_embedding is true !!!'
+        )
+
     return configs
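
The warning above implies that apply_non_blank_embedding should be enabled only when training starts from an already well-trained checkpoint. Below is a generic, hedged sketch of that kind of initialization; the checkpoint file name and the tiny stand-in model are purely illustrative and do not reflect wenet's own loading helpers.

import torch

# Stand-in for a pretrained u2++ checkpoint (illustrative only).
pretrained = torch.nn.Linear(8, 8)
torch.save(pretrained.state_dict(), "init_model.pt")

# Initialize the new model from the well-trained checkpoint before
# turning on apply_non_blank_embedding in the config.
model = torch.nn.Linear(8, 8)
missing, unexpected = model.load_state_dict(
    torch.load("init_model.pt", map_location="cpu"), strict=False)
print("missing keys:", missing)        # []
print("unexpected keys:", unexpected)  # []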
@@ -601,3 +610,10 @@ def log_per_epoch(writer, info_dict):
     if int(os.environ.get('RANK', 0)) == 0:
         writer.add_scalar('epoch/cv_loss', info_dict["cv_loss"], epoch)
         writer.add_scalar('epoch/lr', info_dict["lr"], epoch)
+
+
+def freeze_modules(model, args):
+    for name, param in model.named_parameters():
+        for module_name in args.freeze_modules:
+            if module_name in name:
+                param.requires_grad = False
+                logging.debug("{} module is frozen".format(name))
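
To make the effect of freeze_modules concrete, here is a self-contained sketch with a toy two-branch model and an argparse.Namespace standing in for the parsed args; every parameter whose name contains one of the configured prefixes gets requires_grad set to False, exactly as in the loop above.

import argparse
import torch

# Toy model whose parameter names mimic the encoder/decoder prefixes.
model = torch.nn.ModuleDict({
    "encoder": torch.nn.Linear(4, 4),
    "decoder": torch.nn.Linear(4, 4),
})
args = argparse.Namespace(freeze_modules=["encoder."])

# Same matching rule as freeze_modules above: substring match on the name.
for name, param in model.named_parameters():
    for module_name in args.freeze_modules:
        if module_name in name:
            param.requires_grad = False

print([(name, p.requires_grad) for name, p in model.named_parameters()])
# [('encoder.weight', False), ('encoder.bias', False),
#  ('decoder.weight', True), ('decoder.bias', True)]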
Review comments:

Just curious: does freezing actually give better results than not freezing?

Reply: Without freezing, multi-GPU training runs into problems, and the alignment also changes.

Reply: Got it. What error does multi-GPU training report?

On the CTC change: the CTC forward now returns two values, so the CTC calls in the k2 and paraformer code also need to be updated to handle the new return values, otherwise they will raise errors.