Commit 2505d18

feat(train): support deepspeed (#1849)
* feat(train): support deepspeed
* feat(train): fix bug
* feat(train): enable deepspeed in run.sh
* feat(train): fix step bug for tensorboard logging
* feat(train): recover cv log tensorboard logging
* feat(train): make save_states configurable
* feat(deepspeed): support fp16/bf16 and DeepSpeedCPUAdam + custom LR scheduler
* feat(deepspeed): fix lint
* feat(deepspeed): fix lint
* feat(deepspeed): update stage2 config
* feat(deepspeed): avoid re-generating the filtered list if it exists
* feat(deepspeed): add 1.8B model
* feat(deepspeed): make workers & prefetch configurable
* feat(deepspeed): refine comment
* feat(deepspeed): fix saving yaml
* feat(deepspeed): refine if-else
* feat(deepspeed): refine if-else
1 parent ac9a261 commit 2505d18

File tree

7 files changed: +480 -62

examples/aishell/s0/conf/ds_stage2.json

+57
@@ -0,0 +1,57 @@
+{
+    "train_micro_batch_size_per_gpu": 1,
+    "gradient_accumulation_steps": 1,
+    "steps_per_print": 100,
+    "gradient_clipping": 0.0001,
+    "fp16": {
+        "enabled": false,
+        "auto_cast": false,
+        "loss_scale": 0,
+        "initial_scale_power": 8,
+        "loss_scale_window": 1000,
+        "hysteresis": 2,
+        "min_loss_scale": 1
+    },
+    "bf16": {
+        "enabled": false
+    },
+    "zero_force_ds_cpu_optimizer": false,
+    "zero_optimization": {
+        "stage": 2,
+        "offload_optimizer": {
+            "device": "none",
+            "pin_memory": true
+        },
+        "offload_param": {
+            "device": "none",
+            "pin_memory": true
+        },
+        "allgather_partitions": true,
+        "allgather_bucket_size": 1e7,
+        "overlap_comm": true,
+        "reduce_scatter": true,
+        "reduce_bucket_size": 1e7,
+        "contiguous_gradients": true
+    },
+    "activation_checkpointing": {
+        "partition_activations": false,
+        "cpu_checkpointing": false,
+        "contiguous_memory_optimization": false,
+        "number_checkpoints": null,
+        "synchronize_checkpoint_boundary": false,
+        "profile": true
+    },
+    "flops_profiler": {
+        "enabled": false,
+        "profile_step": 100,
+        "module_depth": -1,
+        "top_modules": 1,
+        "detailed": true,
+        "output_file": null
+    },
+    "tensorboard": {
+        "enabled": true,
+        "output_path": "tensorboard/ds_logs/",
+        "job_name": "deepspeed"
+    }
+}
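This is a standard DeepSpeed config: ZeRO stage 2, optimizer/parameter offload disabled, fp16/bf16 off by default. Since it carries no `optimizer` section, the client is expected to pass its own optimizer to `deepspeed.initialize` (consistent with `zero_force_ds_cpu_optimizer: false`). A minimal sketch of how such a config is consumed, with a toy stand-in model — illustrative only, not code from this commit:

```python
# Illustrative sketch (not from this commit) of consuming a DeepSpeed
# JSON config like the one above. Must run under the `deepspeed`
# launcher, which sets up the distributed environment.
import torch
import deepspeed

model = torch.nn.Linear(80, 256)  # toy stand-in for the acoustic model
# Client-side optimizer, since the JSON defines no "optimizer" section.
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

# The returned engine owns ZeRO partitioning, gradient clipping, and
# loss scaling according to conf/ds_stage2.json.
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    optimizer=opt,
    config="conf/ds_stage2.json",
)

x = torch.randn(1, 80).to(engine.device)  # train_micro_batch_size_per_gpu=1
loss = engine(x).sum()
engine.backward(loss)  # applies accumulation/scaling per the config
engine.step()          # optimizer step + gradient zeroing
```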
examples/aishell/s0/conf/train_u2++_conformer_1.8B.yaml

+90

@@ -0,0 +1,90 @@
+# network architecture
+# encoder related
+encoder: conformer
+encoder_conf:
+    output_size: 2048    # dimension of attention
+    attention_heads: 16
+    linear_units: 8192  # the number of units of position-wise feed forward
+    num_blocks: 12      # the number of encoder blocks
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.1
+    input_layer: conv2d8 # encoder input type, you can choose conv2d, conv2d6 and conv2d8
+    normalize_before: true
+    cnn_module_kernel: 8
+    use_cnn_module: True
+    activation_type: 'swish'
+    pos_enc_layer_type: 'rel_pos'
+    selfattention_layer_type: 'rel_selfattn'
+    causal: true
+    use_dynamic_chunk: true
+    cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
+    use_dynamic_left_chunk: false
+
+# decoder related
+decoder: bitransformer
+decoder_conf:
+    attention_heads: 16
+    linear_units: 8192
+    num_blocks: 3
+    r_num_blocks: 3
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.1
+    src_attention_dropout_rate: 0.1
+
+# hybrid CTC/attention
+model_conf:
+    ctc_weight: 0.3
+    lsm_weight: 0.1     # label smoothing option
+    length_normalized_loss: false
+    reverse_weight: 0.3
+
+dataset_conf:
+    filter_conf:
+        max_length: 40960
+        min_length: 0
+        token_max_length: 200
+        token_min_length: 1
+    resample_conf:
+        resample_rate: 16000
+    speed_perturb: true
+    fbank_conf:
+        num_mel_bins: 80
+        frame_shift: 10
+        frame_length: 25
+        dither: 1.0
+    spec_aug: true
+    spec_aug_conf:
+        num_t_mask: 2
+        num_f_mask: 2
+        max_t: 50
+        max_f: 10
+    spec_sub: true
+    spec_sub_conf:
+        num_t_sub: 3
+        max_t: 30
+    spec_trim: false
+    spec_trim_conf:
+        max_t: 50
+    shuffle: true
+    shuffle_conf:
+        shuffle_size: 1500
+    sort: true
+    sort_conf:
+        sort_size: 500  # sort_size should be less than shuffle_size
+    batch_conf:
+        batch_type: 'static' # static or dynamic
+        batch_size: 16
+
+grad_clip: 5
+accum_grad: 1
+max_epoch: 100
+log_interval: 100
+
+optim: adam
+optim_conf:
+    lr: 0.001
+scheduler: warmuplr     # pytorch v1.1.0+ required
+scheduler_conf:
+    warmup_steps: 25000
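For a quick sanity check of such a config before a long run, one can load it and verify the invariant noted in its comments. A minimal sketch — the path is a placeholder, and this is not wenet's actual config loader:

```python
# Illustrative sketch (not wenet's loader): read the training YAML and
# check the invariant stated in the inline comments.
import yaml

with open("conf/train_u2++_conformer_1.8B.yaml") as f:  # placeholder path
    conf = yaml.safe_load(f)

enc = conf["encoder_conf"]
print(enc["output_size"], enc["attention_heads"], enc["num_blocks"])

ds = conf["dataset_conf"]
# "sort_size should be less than shuffle_size" per the config comment
assert ds["sort_conf"]["sort_size"] < ds["shuffle_conf"]["shuffle_size"]
```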

examples/aishell/s0/run.sh

+64 -27
@@ -48,13 +48,19 @@ train_config=conf/train_conformer.yaml
 cmvn=true
 dir=exp/conformer
 checkpoint=
+num_workers=8
+prefetch=500

 # use average_checkpoint will get better result
 average_checkpoint=true
 decode_checkpoint=$dir/final.pt
 average_num=30
 decode_modes="ctc_greedy_search ctc_prefix_beam_search attention attention_rescoring"

+deepspeed=false
+deepspeed_config=conf/ds_stage2.json
+deepspeed_save_states="model_only"
+
 . tools/parse_options.sh || exit 1;

 if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
@@ -116,11 +122,12 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
   # You have to rm `INIT_FILE` manually when you resume or restart a
   # multi-machine training.
   INIT_FILE=$dir/ddp_init
+  rm -f ${INIT_FILE} # remove previous INIT_FILE
   init_method=file://$(readlink -f $INIT_FILE)
   echo "$0: init method is $init_method"
   num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
   # Use "nccl" if it works, otherwise use "gloo"
-  dist_backend="gloo"
+  dist_backend="nccl"
   world_size=`expr $num_gpus \* $num_nodes`
   echo "total gpus is: $world_size"
   cmvn_opts=
@@ -130,30 +137,60 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
   # train.py rewrite $train_config to $dir/train.yaml with model input
   # and output dimension, and $dir/train.yaml will be used for inference
   # and export.
-  for ((i = 0; i < $num_gpus; ++i)); do
-  {
-    gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
-    # Rank of each gpu/process used for knowing whether it is
-    # the master of a worker.
-    rank=`expr $node_rank \* $num_gpus + $i`
-    python wenet/bin/train.py --gpu $gpu_id \
-      --config $train_config \
-      --data_type $data_type \
-      --symbol_table $dict \
-      --train_data data/$train_set/data.list \
-      --cv_data data/dev/data.list \
-      ${checkpoint:+--checkpoint $checkpoint} \
-      --model_dir $dir \
-      --ddp.init_method $init_method \
-      --ddp.world_size $world_size \
-      --ddp.rank $rank \
-      --ddp.dist_backend $dist_backend \
-      --num_workers 1 \
-      $cmvn_opts \
-      --pin_memory
-  } &
-  done
-  wait
+  if [ ${deepspeed} == true ]; then
+    echo "using deepspeed"
+    # NOTE(xcsong): deepspeed fails with gloo, see
+    # https://github.com/microsoft/DeepSpeed/issues/2818
+    dist_backend="nccl"
+    [ ! -f data/$train_set/data.list.filter ] && \
+      python tools/filter_uneven_data.py data/$train_set/data.list \
+        $data_type $num_gpus $num_utts_per_shard data/$train_set/data.list.filter
+    deepspeed --include localhost:$CUDA_VISIBLE_DEVICES \
+      wenet/bin/train.py \
+        --deepspeed \
+        --deepspeed_config ${deepspeed_config} \
+        --deepspeed.save_states ${deepspeed_save_states} \
+        --ddp.dist_backend $dist_backend \
+        --ddp.init_method $init_method \
+        --data_type $data_type \
+        --config $train_config \
+        --symbol_table data/dict/lang_char.txt \
+        --train_data data/$train_set/data.list.filter \
+        --cv_data data/dev/data.list \
+        ${checkpoint:+--checkpoint $checkpoint} \
+        --model_dir $dir \
+        --num_workers ${num_workers} \
+        --prefetch ${prefetch} \
+        $cmvn_opts \
+        --pin_memory
+  else
+    echo "using torch ddp"
+    for ((i = 0; i < $num_gpus; ++i)); do
+    {
+      gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
+      # Rank of each gpu/process used for knowing whether it is
+      # the master of a worker.
+      rank=`expr $node_rank \* $num_gpus + $i`
+      python wenet/bin/train.py --gpu $gpu_id \
+        --config $train_config \
+        --data_type $data_type \
+        --symbol_table $dict \
+        --train_data data/$train_set/data.list \
+        --cv_data data/dev/data.list \
+        ${checkpoint:+--checkpoint $checkpoint} \
+        --model_dir $dir \
+        --ddp.init_method $init_method \
+        --ddp.world_size $world_size \
+        --ddp.rank $rank \
+        --ddp.dist_backend $dist_backend \
+        --num_workers ${num_workers} \
+        --prefetch ${prefetch} \
+        $cmvn_opts \
+        --pin_memory
+    } &
+    done
+    wait
+  fi
 fi

 if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
@@ -171,8 +208,8 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
   # non-streaming model. The default value is -1, which is full chunk
   # for non-streaming inference.
   decoding_chunk_size=
-  ctc_weight=0.5
-  reverse_weight=0.0
+  ctc_weight=0.3
+  reverse_weight=0.5
   for mode in ${decode_modes}; do
   {
     test_dir=$dir/test_${mode}
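In the torch-ddp branch above, each process derives its global rank as `node_rank * num_gpus + i` and the world size as `num_nodes * num_gpus`. A small worked example of that arithmetic, with made-up values:

```python
# Illustrative sketch of the rank arithmetic in the torch-ddp branch:
#   rank = node_rank * num_gpus + local_gpu_index
#   world_size = num_nodes * num_gpus
num_nodes, num_gpus = 2, 4  # made-up values for illustration

for node_rank in range(num_nodes):
    for i in range(num_gpus):
        rank = node_rank * num_gpus + i
        print(f"node {node_rank}, local gpu {i} -> global rank {rank}")

print("world_size =", num_nodes * num_gpus)  # 8
```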

requirements.txt

+1
@@ -15,3 +15,4 @@ pycodestyle==2.6.0
 pyflakes==2.2.0
 torch==1.13.0
 torchaudio==0.13.0
+deepspeed

tools/filter_uneven_data.py

+51
@@ -0,0 +1,51 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright [2023-04-27] <[email protected], Xingchen Song>
+
+import os
+import random
+import tarfile
+
+random.seed(1024)
+
+# parse arg from command line
+datalist = os.sys.argv[1]
+datatype = os.sys.argv[2]
+num_gpus = int(os.sys.argv[3])
+num_samples_per_tar = int(os.sys.argv[4])  # only used in shard mode
+new_datalist = os.sys.argv[5]
+
+assert datatype in ["shard", "raw"]
+
+
+filtered_list = []
+with open(datalist, "r") as f:
+    lines = f.readlines()
+    lines = [l.strip() for l in lines]
+    if datatype == "raw":
+        valid_num = len(lines) // num_gpus * num_gpus
+        random.shuffle(lines)
+        filtered_list = lines[:valid_num]
+    else:
+        for line in lines:
+            cnt = 0
+            with open(line, "rb") as tar:
+                stream = tarfile.open(fileobj=tar, mode="r|*")
+                for tarinfo in stream:
+                    name = tarinfo.name
+                    pos = name.rfind('.')
+                    assert pos > 0
+                    prefix, postfix = name[:pos], name[pos + 1:]
+                    if postfix == 'txt':
+                        cnt += 1
+            if cnt == num_samples_per_tar:
+                filtered_list.append(line)
+        valid_num = len(filtered_list) // num_gpus * num_gpus
+        random.shuffle(filtered_list)
+        filtered_list = filtered_list[:valid_num]
+        filtered_list.sort()
+print("before filter: {} after filter: {}".format(len(lines), len(filtered_list)))
+
+with open(new_datalist, "w") as f:
+    for line in filtered_list:
+        f.writelines("{}\n".format(line))
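The key step is rounding the list length down to a multiple of `num_gpus` (after dropping partially filled tars in shard mode), presumably so every rank receives the same number of entries. A toy illustration of that truncation, not code from the commit:

```python
# Toy illustration (not from the commit) of the truncation arithmetic:
# the list is rounded down to a multiple of num_gpus so that every
# rank can be assigned the same number of entries.
lines = ["shard_%02d.tar" % i for i in range(10)]
num_gpus = 4

valid_num = len(lines) // num_gpus * num_gpus  # 10 // 4 * 4 == 8
print(valid_num)          # 8
print(lines[:valid_num])  # 8 entries kept, 2 dropped
```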
