Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 19 additions & 31 deletions nemo/collections/asr/parts/submodules/rnnt_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,39 +817,27 @@ def compute_confidence(self, hypotheses_list: List[Hypothesis]) -> List[Hypothes
(lambda x: self._aggregate_confidence(x)) if self.tdt_include_duration_confidence else (lambda x: x)
)
for hyp in hypotheses_list:
frame_confidence = []
for frame_confs in hyp.frame_confidence:
for frame_conf in frame_confs:
if len(frame_conf) > 0:
frame_confidence.append(maybe_pre_aggregate(frame_conf))
assert len(frame_confidence) == len(hyp.alignment_labels), (
f"Length mismatch: frame_confidence has {len(frame_confidence)} elements, "
f"but hyp.alignment_labels has {len(hyp.alignment_labels)} elements"
)

token_confidence = []
# trying to recover frame_confidence according to alignments
subsequent_blank_confidence = []
# going backwards since <blank> tokens are considered belonging to the last non-blank token.
for fc, fa in zip(hyp.frame_confidence[::-1], hyp.alignments[::-1]):
# there is only one score per frame most of the time
if len(fa) > 1:
for i, a in reversed(list(enumerate(fa))):
if a[-1] == self.blank_id:
if not self.exclude_blank_from_confidence:
subsequent_blank_confidence.append(maybe_pre_aggregate(fc[i]))
elif not subsequent_blank_confidence:
token_confidence.append(maybe_pre_aggregate(fc[i]))
else:
token_confidence.append(
self._aggregate_confidence(
[maybe_pre_aggregate(fc[i])] + subsequent_blank_confidence
)
)
subsequent_blank_confidence = []
for label, frame_conf in zip(hyp.alignment_labels, frame_confidence):
if label == self.blank_id:
if self.exclude_blank_from_confidence:
# skip if blank tokens are to be excluded from confidence calculation
continue
elif len(token_confidence) > 0:
# aggregate blank confidence to the previous token if any
token_confidence[-1] = self._aggregate_confidence([token_confidence[-1], frame_conf])
else:
i, a = 0, fa[0]
if a[-1] == self.blank_id:
if not self.exclude_blank_from_confidence:
subsequent_blank_confidence.append(maybe_pre_aggregate(fc[i]))
elif not subsequent_blank_confidence:
token_confidence.append(maybe_pre_aggregate(fc[i]))
else:
token_confidence.append(
self._aggregate_confidence([maybe_pre_aggregate(fc[i])] + subsequent_blank_confidence)
)
subsequent_blank_confidence = []
token_confidence = token_confidence[::-1]
token_confidence.append(frame_conf)
hyp.token_confidence = token_confidence
else:
if self.exclude_blank_from_confidence:
Expand Down
2 changes: 2 additions & 0 deletions nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2623,6 +2623,7 @@ def _greedy_decode(

if self.preserve_frame_confidence:
hypothesis.frame_confidence = [[]]
hypothesis.alignment_labels = []

time_idx = 0
while time_idx < out_len:
Expand Down Expand Up @@ -2678,6 +2679,7 @@ def _greedy_decode(
if self.include_duration_confidence
else self._get_confidence_tensor(logp)
)
hypothesis.alignment_labels.append(k)

del logp

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,7 @@ def cuda_graphs_impl(
# to avoid any manipulations with allocated memory outside the decoder
return (
self.state.batched_hyps.clone(),
self.state.alignments.clone() if self.preserve_alignments else None,
self.state.alignments.clone() if (self.preserve_alignments or self.preserve_frame_confidence) else None,
decoding_state,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,6 @@ def __init__(
self.max_symbols = max_symbols_per_step
self.preserve_alignments = preserve_alignments
self.preserve_frame_confidence = preserve_frame_confidence
self.preserve_alignments = preserve_alignments or preserve_frame_confidence
self.allow_cuda_graphs = allow_cuda_graphs
self.include_duration = include_duration
self.include_duration_confidence = include_duration_confidence
Expand Down Expand Up @@ -427,7 +426,7 @@ def torch_impl(
active_mask=active_mask,
time_indices=time_indices_current_labels,
logits=logits if self.preserve_alignments else None,
labels=labels if self.preserve_alignments else None,
labels=labels if (self.preserve_alignments or self.preserve_frame_confidence) else None,
confidence=self._get_frame_confidence(logits=logits, num_durations=num_durations),
)

Expand Down Expand Up @@ -481,7 +480,7 @@ def torch_impl(
active_mask=advance_mask,
time_indices=time_indices_current_labels,
logits=logits if self.preserve_alignments else None,
labels=more_labels if self.preserve_alignments else None,
labels=more_labels if (self.preserve_alignments or self.preserve_frame_confidence) else None,
confidence=self._get_frame_confidence(logits=logits, num_durations=num_durations),
)

Expand Down Expand Up @@ -835,7 +834,7 @@ def cuda_graphs_impl(
# to avoid any manipulations with allocated memory outside the decoder
return (
self.state.batched_hyps.clone(),
self.state.alignments.clone() if self.preserve_alignments else None,
self.state.alignments.clone() if (self.preserve_alignments or self.preserve_frame_confidence) else None,
decoding_state,
)

Expand Down Expand Up @@ -1185,7 +1184,7 @@ def _before_inner_loop_get_joint_output(self):
active_mask=self.state.active_mask,
time_indices=self.state.time_indices_current_labels,
logits=logits if self.preserve_alignments else None,
labels=self.state.labels if self.preserve_alignments else None,
labels=self.state.labels if (self.preserve_alignments or self.preserve_frame_confidence) else None,
confidence=self._get_frame_confidence(
logits=logits, num_durations=self.state.model_durations.shape[0]
),
Expand Down Expand Up @@ -1252,7 +1251,7 @@ def _inner_loop_step_find_next_non_blank(self):
active_mask=self.state.advance_mask,
time_indices=self.state.time_indices_current_labels,
logits=logits if self.preserve_alignments else None,
labels=more_labels if self.preserve_alignments else None,
labels=more_labels if (self.preserve_alignments or self.preserve_frame_confidence) else None,
confidence=self._get_frame_confidence(
logits=logits, num_durations=self.state.model_durations.shape[0]
),
Expand Down
31 changes: 23 additions & 8 deletions nemo/collections/asr/parts/utils/rnnt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ class Hypothesis:
token_duration: Optional[torch.Tensor] = None
last_frame: Optional[int] = None

alignment_labels: Optional[List[int]] = None # labels corresponding to alignments (can contain blanks)

@property
def non_blank_frame_confidence(self) -> List[float]:
"""Get per-frame confidence for non-blank tokens according to self.timestamp
Expand Down Expand Up @@ -589,10 +591,13 @@ def __init__(
# empty tensors instead of None to make torch.jit.script happy
self.logits = torch.zeros(0, device=device, dtype=float_dtype)
self.labels = torch.zeros(0, device=device, dtype=torch.long)
if self.with_alignments or self.with_frame_confidence:
# labels; labels can contain <blank>, different from BatchedHyps
# are used during token confidence calculation for TDT models
self.labels = torch.zeros((batch_size, self._max_length), device=device, dtype=torch.long)
if self.with_alignments:
# logits and labels; labels can contain <blank>, different from BatchedHyps
# logits; labels can contain <blank>, different from BatchedHyps
self.logits = torch.zeros((batch_size, self._max_length, logits_dim), device=device, dtype=float_dtype)
self.labels = torch.zeros((batch_size, self._max_length), device=device, dtype=torch.long)

# empty tensor instead of None to make torch.jit.script happy
self.frame_confidence = torch.zeros(0, device=device, dtype=float_dtype)
Expand Down Expand Up @@ -621,9 +626,10 @@ def _allocate_more(self):
to maintain O(1) insertion time complexity
"""
self.timestamps = torch.cat((self.timestamps, torch.zeros_like(self.timestamps)), dim=-1)
if self.with_alignments or self.with_frame_confidence:
self.labels = torch.cat((self.labels, torch.zeros_like(self.labels)), dim=-1)
if self.with_alignments:
self.logits = torch.cat((self.logits, torch.zeros_like(self.logits)), dim=1)
self.labels = torch.cat((self.labels, torch.zeros_like(self.labels)), dim=-1)
if self.with_frame_confidence:
self.frame_confidence = torch.cat((self.frame_confidence, torch.zeros_like(self.frame_confidence)), dim=1)
self._max_length *= 2
Expand Down Expand Up @@ -658,10 +664,12 @@ def add_results_(
# store timestamps - same for alignments / confidence
self.timestamps[active_indices, active_lengths] = time_indices

if self.with_alignments and logits is not None and labels is not None:
self.logits[active_indices, active_lengths] = logits
if self.with_alignments or self.with_frame_confidence and labels is not None:
self.labels[active_indices, active_lengths] = labels

if self.with_alignments and logits is not None:
self.logits[active_indices, active_lengths] = logits

if self.with_frame_confidence and confidence is not None:
self.frame_confidence[active_indices, active_lengths] = confidence
# increase lengths
Expand Down Expand Up @@ -714,10 +722,12 @@ def add_results_masked_no_checks_(
# store timestamps - same for alignments / confidence
self.timestamps[self._batch_indices, self.current_lengths] = time_indices

if self.with_alignments and logits is not None and labels is not None:
if self.with_alignments or self.with_frame_confidence and labels is not None:
self.labels[self._batch_indices, self.current_lengths] = labels

if self.with_alignments and logits is not None:
self.timestamps[self._batch_indices, self.current_lengths] = time_indices
self.logits[self._batch_indices, self.current_lengths] = logits
self.labels[self._batch_indices, self.current_lengths] = labels

if self.with_frame_confidence and confidence is not None:
self.frame_confidence[self._batch_indices, self.current_lengths] = confidence
Expand Down Expand Up @@ -785,9 +795,10 @@ def batched_hyps_to_hypotheses(
if alignments is not None:
# move all data to cpu to avoid overhead with moving data by chunks
alignment_lengths = alignments.current_lengths.cpu().tolist()
if alignments.with_alignments or alignments.with_frame_confidence:
alignment_labels = alignments.labels.cpu()
if alignments.with_alignments:
alignment_logits = alignments.logits.cpu()
alignment_labels = alignments.labels.cpu()
if alignments.with_frame_confidence:
frame_confidence = alignments.frame_confidence.cpu()

Expand All @@ -813,4 +824,8 @@ def batched_hyps_to_hypotheses(
[frame_confidence[i, start + j] for j in range(timestamp_cnt)]
)
start += timestamp_cnt

if alignments.with_frame_confidence:
hypotheses[i].alignment_labels = alignment_labels[i][: alignment_lengths[i]]

return hypotheses
Loading