# spin.yaml
# Interspeech 2023 version

# Training data
data:
  json_dir: /data/sls/r/u/hengjui/home/scratch/spin_test/data
  splits:
    - train-clean-100
  sample_rate: 16000
  min_audio_len: 40000      # minimum audio samples per utterance
  random_crop_len: 272000   # maximum audio samples per utterance
  spk2info: /data/sls/r/u/hengjui/home/scratch/dataset/libri_util/spk2info.dict
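  # At the 16 kHz sample rate above, min_audio_len: 40000 corresponds to 2.5 s and
  # random_crop_len: 272000 to 17 s of audio (samples / sample_rate).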

# Validation data (not used for checkpointing, just for monitoring training progress)
val_data:
  json_dir: /data/sls/r/u/hengjui/home/scratch/spin_test/data
  phn_dir: /data/sls/r/u/hengjui/home/scratch/spin_test/data
  splits:
    - dev-clean
    - dev-other
  sample_rate: 16000
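# Note: the json_dir / phn_dir / spk2info paths above are specific to the original
# author's environment and would need to be replaced with locally prepared data.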

# SpinModel config
model:
  encoder:
    type: HuBERT            # `HuBERT` / `WavLM`
    use_layer: 12           # the layer whose representations are used for clustering
    normalize: False
    feat_select: x
    randomize_all: False
    randomize_layers: []
    freeze_all: False
    freeze_layers: ["pos", 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]  # `pos`: positional encoding, `0`: CNN extractor
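    # Sketch of the WavLM variant listed in the `type` comment above (an assumption,
    # not part of this run); presumably only the encoder type changes while the other
    # options stay the same:
    #   type: WavLM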
  pred_head:
    type: DNN
    hid_dims: [256]
    dropout: 0
    activation: ReLU
  loss:
    type: SwavVQDisentangle
    num_vars: 256           # cluster size
    epsilon: 0.02
    sinkhorn_iters: 3
    temp: 0.1
    l2_norm: True
    prob_ratio: 1.0
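    # The fields above presumably follow the SwAV-style objective: `num_vars` is the
    # codebook size, `epsilon` and `sinkhorn_iters` control the Sinkhorn-Knopp
    # assignment step, and `temp` is the softmax temperature of the swapped prediction.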

# Optimization
optim:
  optimizer:
    name: Adam
    args:
      lr: 1.e-4
      weight_decay: 1.e-6
  scheduler:
    name: linear_warmup_decay  # `linear_warmup_decay` / `linear_warmup_cosine_scheduler` / `noam_scheduler`
    args:
      warmup: 2500
      max_step: 5000
      final_lr: 1.e-6
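  # Alternative schedule (a sketch based on the names in the comment above, assuming
  # the same arguments apply; not used for this run):
  # scheduler:
  #   name: linear_warmup_cosine_scheduler
  #   args:
  #     warmup: 2500
  #     max_step: 5000
  #     final_lr: 1.e-6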

hparam:
  batch_len: 4096000        # audio samples per GPU (256 secs ~ batch_size = 12.8k)
  val_batch_size: 8
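  # 4,096,000 samples / 16,000 Hz = 256 s of audio per GPU; at the encoder's 20 ms
  # frame rate this is roughly 12.8k frames, matching the note above.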

# pytorch_lightning.Trainer
# ref: https://lightning.ai/docs/pytorch/latest/common/trainer.html
trainer:
  max_steps: 5000
  gradient_clip_val: 10
  accumulate_grad_batches: 1
  precision: 16
  logger: wandb             # use `False` to disable logging
  log_every_n_steps: 100
  default_root_dir: exp/tmp
  accelerator: gpu
  # strategy: ddp           # uncomment this line to enable DDP training
  num_sanity_val_steps: 0
  val_check_interval: 1000
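  # Possible multi-GPU setup (an assumption, not part of the original run):
  # strategy: ddp
  # devices: 2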

# pytorch_lightning.callbacks.ModelCheckpoint
# ref: https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.ModelCheckpoint.html
checkpoint:
  filename: "{epoch}-{step}"
  every_n_train_steps: 5000
  save_last: true
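  # every_n_train_steps equals trainer.max_steps, so presumably a single checkpoint is
  # written at the end of training, plus `last.ckpt` from save_last.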

# pytorch_lightning.loggers.WandbLogger
# ref: https://lightning.ai/docs/pytorch/latest/extensions/generated/lightning.pytorch.loggers.WandbLogger.html
logger:
  project: spin_is2023