Skip to content

Commit

Permalink
Fix for interctc test random failure (#6644)
Browse files Browse the repository at this point in the history
Signed-off-by: Igor Gitman <[email protected]>
  • Loading branch information
Kipok committed May 26, 2023
1 parent ee61a48 commit b50ae98
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions tests/collections/asr/test_asr_interctc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def squeezeformer_encoder_config() -> Dict:


class TestInterCTCLoss:
@pytest.mark.pleasefixme
@pytest.mark.unit
@pytest.mark.parametrize(
"model_class", [EncDecCTCModel, EncDecHybridRNNTCTCModel],
Expand Down Expand Up @@ -199,11 +198,18 @@ def __len__(self):
def __getitem__(self, idx):
return self.values

# this sometimes results in all zeros in the output which breaks tests
# so using this only for the ptl calls in the bottom, but using
# processed signal directly initially to remove the chance of
# this edge-case
input_signal = torch.randn(size=(1, 512))
input_length = torch.randint(low=161, high=500, size=[1])
target = torch.randint(size=(1, input_length[0]), low=0, high=28)
target_length = torch.tensor([input_length[0]])

processed_signal = torch.randn(size=([1, 64, 12]))
processed_length = torch.tensor([8])

if len(apply_at_layers) != len(loss_weights):
# has to throw an error here
with pytest.raises(
Expand All @@ -216,7 +222,9 @@ def __getitem__(self, idx):
asr_model = model_class(cfg=model_config)
asr_model.train()
AccessMixin.set_access_enabled(access_enabled=True)
logprobs, *_ = asr_model.forward(input_signal=input_signal, input_signal_length=input_length)
logprobs, *_ = asr_model.forward(
processed_signal=processed_signal, processed_signal_length=processed_length
)
captured_tensors = asr_model.get_captured_interctc_tensors()
AccessMixin.reset_registry(asr_model)
assert len(captured_tensors) == len(apply_at_layers)
Expand Down

0 comments on commit b50ae98

Please sign in to comment.