[KWS]Add kws example on HeySnips dataset. #1558
Which paper does this correspond to?
a35e99a to fe7c4e5
@@ -0,0 +1,39 @@
data:
For the config file, follow the other egs and flatten it. Nested (hierarchical) configs are no longer recommended.
@@ -0,0 +1,5 @@
#!/bin/bash
Add a usage message and check the inputs.
This script is not meant to be run on its own; it is normally called from run.sh, so I will add the checks in run.sh.
if self.feat_type in ['kaldi_fbank', 'kaldi_mfcc']:
    waveform = paddle.to_tensor(waveform).unsqueeze(0)  # (C, T)
    record['feat'] = feat_func(
        waveform=waveform, sr=self.sample_rate, **self.feat_config)
The required parameters should be listed explicitly; using ** is not recommended.
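The feature API has too many parameters, and most of them keep their defaults. I will add a note telling users to configure them according to the kwargs in the API.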
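One possible middle ground between listing every parameter and forwarding `**self.feat_config` blindly is to validate the config keys against an allowlist before unpacking. The sketch below is illustrative only: `ALLOWED_FBANK_ARGS`, `select_feat_kwargs`, and `fake_fbank` are hypothetical names, and the argument names are assumptions rather than the actual feature API.

```python
# Hypothetical allowlist; the real feature API accepts many more arguments.
ALLOWED_FBANK_ARGS = {"n_mels", "frame_length", "frame_shift", "dither"}


def select_feat_kwargs(feat_config, allowed=ALLOWED_FBANK_ARGS):
    """Return a copy of feat_config, rejecting keys outside the allowlist."""
    unknown = sorted(set(feat_config) - set(allowed))
    if unknown:
        raise ValueError(f"unsupported feature arguments: {unknown}")
    return dict(feat_config)


def fake_fbank(waveform, sr, n_mels=80, frame_length=25.0, frame_shift=10.0,
               dither=0.0):
    """Stand-in for the real feature function; it only echoes its settings."""
    return {
        "num_samples": len(waveform),
        "sr": sr,
        "n_mels": n_mels,
        "frame_length": frame_length,
        "frame_shift": frame_shift,
        "dither": dither,
    }


# Only the overridden key is passed; everything else keeps its default.
feat = fake_fbank([0.0] * 16000, sr=16000, **select_feat_kwargs({"n_mels": 40}))
```

With this shape, a misspelled key in the config fails early with a clear message instead of surfacing as a `TypeError` deep inside the feature call.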
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--cfg_path", type=str, required=True)
Passing arguments through a config file is not recommended; each argument should be declared explicitly.
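As a rough illustration of what declaring arguments individually could look like, here is a minimal `argparse` sketch. Only `window_shift` appears elsewhere in this PR; `--checkpoint`, `--batch_size`, and all default values are assumptions made for the example.

```python
import argparse


def build_parser():
    """Build a parser where every option is declared on its own."""
    parser = argparse.ArgumentParser(description="Illustrative KWS scoring options.")
    # Hypothetical options; the names and defaults are assumptions.
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to the model checkpoint.")
    parser.add_argument("--batch_size", type=int, default=100,
                        help="Number of utterances per batch.")
    parser.add_argument("--window_shift", type=int, default=50,
                        help="Frames to skip after a detection.")
    return parser


# Parse an explicit argument list so the example does not depend on sys.argv.
args = build_parser().parse_args(
    ["--checkpoint", "model.pdparams", "--window_shift", "50"])
```

Declared this way, `--help` documents every option and a mistyped flag is rejected by the parser rather than silently ignored.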
mask = padding_mask(lengths)
num_utts = logits.shape[0]
num_keywords = logits.shape[2]
num_keywords looks odd here. The shape is B, T, D:
num_keywords = logits.shape[2]
loss = 0.0
for i in range(num_utts):
Why not call CE directly here?
CrossEntropy cannot be used directly: the score is the max taken within the valid audio length, and CE is then computed on that score.
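To make the reply above concrete, here is a minimal pure-Python sketch of a max-pooling loss for a single keyword. It is not the code from this PR: the function name, the single-keyword simplification, and the handling of negative utterances (penalising the highest-scoring valid frame) are assumptions for illustration.

```python
import math


def max_pooling_loss(probs, lengths, targets, eps=1e-8):
    """Binary cross entropy on the max-pooled score of each utterance.

    probs:   per-utterance lists of per-frame keyword probabilities (padded).
    lengths: number of valid frames in each utterance.
    targets: 1 if the utterance contains the keyword, else 0.
    """
    total = 0.0
    for frame_probs, length, target in zip(probs, lengths, targets):
        # Pool only over valid frames so padding cannot influence the score.
        p_max = max(frame_probs[:length])
        # Clamp to keep log() finite.
        p_max = min(max(p_max, eps), 1.0 - eps)
        if target == 1:
            total += -math.log(p_max)        # CE with t = 1
        else:
            total += -math.log(1.0 - p_max)  # CE with t = 0
    return total / len(probs)


# The third frame of each utterance is padding and is ignored.
loss = max_pooling_loss(
    probs=[[0.1, 0.9, 0.99], [0.2, 0.1, 0.8]],
    lengths=[2, 2],
    targets=[1, 0],
)
```

A framewise CE would need a label for every frame; pooling first means only the utterance-level label is required, which is why a plain CE call does not fit here.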
This pull request is now in conflict :(
It is given in the model section of the README: https://arxiv.org/pdf/2102.13552.pdf
while i < len(score_list):
    if score_list[i] >= threshold:
        num_false_alarm += 1
        i += args.window_shift
50*10=500ms
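The arithmetic in the comment above reads as a window shift of 50 frames at 10 ms per frame, so one detection suppresses further counts for 500 ms. A small sketch of that counting loop follows; the `else` branch that advances by one frame and the 10 ms frame shift are assumptions, since neither is visible in the excerpt.

```python
FRAME_SHIFT_MS = 10  # assumed frame shift, implied by the review comment


def count_false_alarms(score_list, threshold, window_shift=50):
    """Count threshold crossings, skipping window_shift frames after each one."""
    num_false_alarm = 0
    i = 0
    while i < len(score_list):
        if score_list[i] >= threshold:
            num_false_alarm += 1
            # Skip ahead so one long activation is not counted many times.
            i += window_shift
        else:
            # Assumed: advance one frame when there is no detection.
            i += 1
    return num_false_alarm


# 145 frames: a 30-frame burst at index 10 and a 5-frame burst at index 140.
scores = [0.1] * 10 + [0.9] * 30 + [0.1] * 100 + [0.9] * 5
alarms = count_false_alarms(scores, threshold=0.5, window_shift=50)
skip_ms = 50 * FRAME_SHIFT_MS
```

In this example the first burst is counted once even though it spans 30 frames, because the index jumps from 10 to 60 after the first hit.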
for j in range(num_keywords):
    # Add entropy loss CE = -(t * log(p) + (1 - t) * log(1 - p))
    if target[i] == j:
        # For the keyword, do max-pooling
This looks like a max_pool over T followed by Cross Entropy.
super(KWSModel, self).__init__()
self.backbone = backbone
self.linear = nn.Linear(self.backbone.hidden_dim, num_keywords)
self.activation = nn.Sigmoid()
Binary classification.
self.kernel_size = kernel_size
self.dilation = dilation
self.causal = causal
self.receptive_fields = dilation * (kernel_size - 1)
kernel_size - 1 if no_dilation else dilation * (kernel_size -1)
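The two branches in the comment above agree when no dilation corresponds to `dilation == 1`, since `1 * (kernel_size - 1)` is just `kernel_size - 1`. The sketch below checks that and sums the context over a stack of layers; the kernel size and dilation schedule are illustrative assumptions, not the configuration used in this PR.

```python
def receptive_fields(kernel_size, dilation=1):
    """Extra context frames one dilated 1-D convolution needs beyond the current frame."""
    return dilation * (kernel_size - 1)


def stack_receptive_fields(kernel_size, dilations):
    """Total extra context of a stack of stride-1 convolutions applied in sequence."""
    return sum(receptive_fields(kernel_size, d) for d in dilations)


# Without dilation the general formula reduces to kernel_size - 1.
plain = receptive_fields(kernel_size=5, dilation=1)

# Illustrative stack: kernel size 5 with dilations doubling per layer.
total = stack_receptive_fields(kernel_size=5, dilations=[1, 2, 4, 8])
```

Here `plain` is 4 and `total` is 4 + 8 + 16 + 32 = 60 frames, which is the amount of left padding a causal stack of this shape would need in order to keep the output length equal to the input length.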
LGTM
PR types
New features
PR changes
Models
Describe
Add MDTC model for KWS.