
Token/words confidence #1092

Closed
Zeiny96 opened this issue May 24, 2023 · 5 comments

Zeiny96 commented May 24, 2023

Is there any available way to get each token's/word's confidence while using the modified beam search decoding method?


Zeiny96 commented May 28, 2023

I successfully added it, but the confidences for out-of-domain sets are very low even though the result itself is fair. Is this normal?

csukuangfj (Collaborator) commented:
Could you describe how you get the scores?


Zeiny96 commented May 29, 2023

I use the log_probs in the modified beam search and apply the exponential function to convert them out of log scale, as follows:

for k in range(len(topk_hyp_indexes)):
    ...
    # Carry over the confidences accumulated so far for this hypothesis.
    new_confidences = hyp.confidences[:]
    if new_token not in (blank_id, unk_id):
        ...
        new_confidences.append(torch.exp(topk_log_probs[k]).item())

    new_log_prob = topk_log_probs[k]
    new_hyp = Hypothesis(
        ys=new_ys,
        log_prob=new_log_prob,
        timestamp=new_timestamp,
        confidences=new_confidences,
    )
    B[i].add(new_hyp)

...
sorted_confidences = [h.confidences for h in best_hyps]
...
ans_confidences = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
    ans.append(sorted_ans[unsorted_indices[i]])
    ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
    ans_confidences.append(sorted_confidences[unsorted_indices[i]])

return DecodingResults(
    tokens=ans,
    timestamps=ans_timestamps,
    confidences=ans_confidences,
)

Then, to get each word's confidence, I average the probability values of all of its tokens:

confidence = sum(tokens_per_word_probs) / len(tokens_per_word_probs)
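As a minimal sketch of that averaging step (the function and variable names here, e.g. `word_confidences` and `word_ids`, are hypothetical and not part of icefall):

```python
from collections import defaultdict


def word_confidences(token_probs, word_ids):
    """Average per-token probabilities into per-word confidences.

    token_probs: per-token probabilities (already exp'd out of log scale).
    word_ids:    same length as token_probs, mapping each token to its word index.
    """
    grouped = defaultdict(list)
    for prob, wid in zip(token_probs, word_ids):
        grouped[wid].append(prob)
    # Simple arithmetic mean per word, as described above.
    return [sum(ps) / len(ps) for _, ps in sorted(grouped.items())]


# Two words: tokens 0-1 belong to word 0, token 2 to word 1.
print(word_confidences([0.9, 0.7, 0.5], [0, 0, 1]))  # ≈ [0.8, 0.5]
```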

csukuangfj (Collaborator) commented:

> new_confidences.append(torch.exp(topk_log_probs[k]).item())

topk_log_probs is cumulative. You have to subtract ys_log_probs from it to get the log_prob of the current frame.
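A small numeric sketch of this point, using plain floats in place of the tensors (the variable names `prev_log_prob` and `cum_log_prob` are illustrative, not from the icefall code):

```python
import math

# Suppose the hypothesis had cumulative log-prob -1.0, and after appending
# one more token the top-k cumulative log-prob is -1.5.
prev_log_prob = -1.0  # hyp.log_prob (the ys_log_probs entry) before this token
cum_log_prob = -1.5   # topk_log_probs[k], cumulative over the whole hypothesis

# Exponentiating the cumulative value shrinks toward 0 as the hypothesis
# grows, which would explain the very low confidences reported above.
wrong = math.exp(cum_log_prob)

# Subtracting the previous cumulative value first isolates this token.
token_log_prob = cum_log_prob - prev_log_prob  # -0.5
confidence = math.exp(token_log_prob)          # ≈ 0.607
```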


Zeiny96 commented May 30, 2023

Yeah, that was my bad, it is fixed now:

new_confidences.append(torch.exp(topk_log_probs[k] - hyp.log_prob).item())

Thanks
