[train] u2++-lite training support #2202
Conversation
What is the functionality of 'apply_non_blank_embedding'? Are there any reference materials available for learning?
Force-pushed from a2f0674 to 0a5ee16.
It is a new feature.
Force-pushed from db2be51 to 84f9221.
u2++ lite is used to reduce rescoring latency; runtime and latency results will be checked in soon.
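For readers asking what `apply_non_blank_embedding` does, here is a minimal sketch of the underlying idea (the helper name `select_non_blank_frames` and blank id 0 are illustrative assumptions, not this PR's actual API): frames whose CTC top-1 label is blank carry no token hypothesis, so they can be dropped before attention rescoring, shrinking the sequence the decoder attends over.

```python
import torch

def select_non_blank_frames(encoder_out: torch.Tensor,
                            ctc_probs: torch.Tensor,
                            blank_id: int = 0) -> torch.Tensor:
    """Keep only encoder frames whose CTC top-1 label is non-blank.

    encoder_out: (B, T, D) encoder output
    ctc_probs:   (B, T, V) per-frame CTC (log-)posteriors
    """
    top1 = ctc_probs.argmax(dim=2)          # (B, T) greedy labels
    kept = []
    for b in range(top1.size(0)):
        mask = top1[b] != blank_id          # frames with a non-blank top-1
        kept.append(encoder_out[b][mask])   # (T_b', D), usually T_b' << T
    # re-pad the surviving frames into a batch for the attention decoder
    return torch.nn.utils.rnn.pad_sequence(kept, batch_first=True)
```

Since attention-rescoring cost grows with the number of encoder frames, keeping only non-blank frames directly reduces rescoring latency.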
wenet/transformer/asr_model.py
Outdated
@@ -133,6 +143,34 @@ def _forward_ctc(self, encoder_out: torch.Tensor,
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, text_lengths)
        return loss_ctc
@@ -63,7 +67,8 @@ def forward(self, hs_pad: torch.Tensor, hlens: torch.Tensor,
         loss = self.ctc_loss(ys_hat, ys_pad, hlens, ys_lens)
         # Batch-size average
         loss = loss / ys_hat.size(1)
-        return loss
+        ys_hat = ys_hat.transpose(0, 1)
+        return loss, ys_hat
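Returning `ys_hat` alongside the loss presumably lets the caller reuse the CTC head's frame-level output, e.g. for the blank-frame selection sketched above, without a second projection pass. A hedged usage sketch:

```python
# sketch only: assumes the modified CTC forward above
loss_ctc, ys_hat = self.ctc(encoder_out, encoder_out_lens, text, text_lengths)
top1 = ys_hat.argmax(dim=2)  # (B, T) greedy labels; blank frames can now be masked
```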
wenet/utils/executor.py
Outdated
if info_dict["model_conf"]["apply_non_blank_embedding"]:
    logging.warn(
        'Had better load a well trained model if '
        'apply_non_blank_embedding is true !!!'
    )
Could this be moved into train_utils.py::check_modify_and_save_config? Reasons:
- Placed in executor::train, it gets printed every epoch.
- check_modify_and_save_config is the function dedicated to checking the config, which matches the intent of this log.
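A minimal sketch of the suggested relocation, assuming `check_modify_and_save_config` takes the parsed args and configs (only the warning placement is the point here):

```python
import logging

def check_modify_and_save_config(args, configs):
    # ... existing config checks ...
    if configs["model_conf"].get("apply_non_blank_embedding", False):
        # warn once at config-check time instead of every epoch in executor::train
        logging.warning('Had better load a well trained model if '
                        'apply_non_blank_embedding is true !!!')
    return configs
```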
for module_name in args.freeze_modules:
    if module_name in name:
        param.requires_grad = False
        logging.debug("{} module is frozen".format(name))
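For context, a self-contained version of this loop (the enclosing `named_parameters()` iteration is inferred from the snippet, not shown in the hunk):

```python
import logging
import torch

def apply_freeze_modules(model: torch.nn.Module, freeze_modules):
    """Disable gradients for every parameter whose name matches a frozen module."""
    for name, param in model.named_parameters():
        for module_name in freeze_modules:
            if module_name in name:
                param.requires_grad = False
                logging.debug("{} module is frozen".format(name))
```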
Just curious: does freezing give better results than not freezing?
Without freezing, multi-GPU training runs into problems, and the alignment also changes.
> Without freezing, multi-GPU training runs into problems, and the alignment also changes.

Got it. What error does multi-GPU training report?
wenet/transformer/asr_model.py
Outdated
maxlen = encoder_out.size(1)
top1_index = torch.argmax(ctc_probs, dim=2)
indices = []
for j in range(topk_prob.size(0)):
topk_prob is undefined
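A likely fix, mirroring the quoted hunk and assuming the loop is meant to run over the batch dimension of the greedy labels; this follows the reviewer's remark and is not a confirmed patch:

```python
# sketch only: topk_prob does not exist in this scope; the batch
# dimension of top1_index is presumably what the loop should use
maxlen = encoder_out.size(1)
top1_index = torch.argmax(ctc_probs, dim=2)   # (B, T)
indices = []
for j in range(top1_index.size(0)):
    ...  # per-utterance frame selection continues here
```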
Force-pushed from 78915c3 to f82a4c9.
[train] add instructions for use
Force-pushed from f82a4c9 to bfaa8a3.
Have you set up pre-commit?
Is this the frame-reduction / blank-skipping idea Google used in RNN-T, applied here to an AED architecture?
I wasn't aware of that work when I did this; I just searched and the k2 team has similar work as well. As I understand it, the ideas are all much the same: reduce the amount of computation and cut latency. Here it is mainly to reduce inference latency.