Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[KWS]Add kws example on HeySnips dataset. #1558

Merged
merged 6 commits into from
Apr 25, 2022

Conversation

KPatr1ck
Copy link
Contributor

PR types

New features

PR changes

Models

Describe

Add mdtc model for kws.

@KPatr1ck KPatr1ck added this to the r0.2.0 milestone Mar 11, 2022
Copy link
Collaborator

@zh794390558 zh794390558 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

对应的paper是?

@zh794390558 zh794390558 marked this pull request as draft March 16, 2022 07:21
@zh794390558 zh794390558 modified the milestones: r0.2.0, r1.0.0 Apr 1, 2022
@KPatr1ck KPatr1ck force-pushed the kws branch 2 times, most recently from a35e99a to fe7c4e5 Compare April 19, 2022 09:46
@KPatr1ck KPatr1ck marked this pull request as ready for review April 19, 2022 09:46
@@ -0,0 +1,39 @@
data:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

配置文件参考其他egs,展开吧。不建议用层级的了。

@@ -0,0 +1,5 @@
#!/bin/bash

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

加上usage,并check输入

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个脚本不推荐用户单独使用,正常都是从run.sh调用,我在run里加检查

if self.feat_type in ['kaldi_fbank', 'kaldi_mfcc']:
waveform = paddle.to_tensor(waveform).unsqueeze(0) # (C, T)
record['feat'] = feat_func(
waveform=waveform, sr=self.sample_rate, **self.feat_config)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需要的参数需要对应指出,不建议使用**

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

feature api的参数太多,且大部分为默认,我加个说明让用户根据api中的kwargs去配置。


# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--cfg_path", type=str, required=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不建议用配置文件传入参数,需要单独指出

paddlespeech/kws/exps/mdtc/compute_det.py Outdated Show resolved Hide resolved
paddlespeech/kws/exps/mdtc/train.py Show resolved Hide resolved
paddlespeech/kws/exps/mdtc/train.py Show resolved Hide resolved
paddlespeech/kws/models/loss.py Outdated Show resolved Hide resolved

mask = padding_mask(lengths)
num_utts = logits.shape[0]
num_keywords = logits.shape[2]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

num_keywords 好奇怪。
B,T,D

num_keywords = logits.shape[2]

loss = 0.0
for i in range(num_utts):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么这个不直接调用CE?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不可以直接用CrossEntropy,需要在有效的音频长度内取max作为得分再算CE。

@KPatr1ck KPatr1ck changed the title [KWS]Add mdtc model. [KWS]Add kws example on HeySnips dataset. Apr 19, 2022
@PaddlePaddle PaddlePaddle deleted a comment from KPatr1ck Apr 20, 2022
@mergify
Copy link

mergify bot commented Apr 22, 2022

This pull request is now in conflict :(

@mergify mergify bot added the conflicts label Apr 22, 2022
@KPatr1ck
Copy link
Contributor Author

对应的paper是?

在README中的模型处给出了,https://arxiv.org/pdf/2102.13552.pdf

while i < len(score_list):
if score_list[i] >= threshold:
num_false_alarm += 1
i += args.window_shift
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

50*10=500ms

for j in range(num_keywords):
# Add entropy loss CE = -(t * log(p) + (1 - t) * log(1 - p))
if target[i] == j:
# For the keyword, do max-polling
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

看着是在T上做了max_pool后算Cross Entropy

super(KWSModel, self).__init__()
self.backbone = backbone
self.linear = nn.Linear(self.backbone.hidden_dim, num_keywords)
self.activation = nn.Sigmoid()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

二分类

self.kernel_size = kernel_size
self.dilation = dilation
self.causal = causal
self.receptive_fields = dilation * (kernel_size - 1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kernel_size - 1 if no_dilation else dilation * (kernel_size -1)

Copy link
Collaborator

@zh794390558 zh794390558 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zh794390558 zh794390558 merged commit 962a278 into PaddlePaddle:develop Apr 25, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants