Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions tests/models/hubert/test_modeling_tf_hubert.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,19 +450,15 @@ def test_model_from_pretrained(self):
model = TFHubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
self.assertIsNotNone(model)

# We override here as passing a full batch of 13 samples results in OOM errors for CTC
@unittest.skip(reason="Fix me! Hubert hits OOM errors when loss is computed on full batch")
def test_dataset_conversion(self):
default_batch_size = self.model_tester.batch_size
self.model_tester.batch_size = 2
super().test_dataset_conversion()
self.model_tester.batch_size = default_batch_size
# TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC
pass

# We override here as passing a full batch of 13 samples results in OOM errors for CTC
@unittest.skip(reason="Fix me! Hubert hits OOM errors when loss is computed on full batch")
def test_keras_fit(self):
default_batch_size = self.model_tester.batch_size
self.model_tester.batch_size = 2
super().test_keras_fit()
self.model_tester.batch_size = default_batch_size
# TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC
pass


@require_tf
Expand Down
17 changes: 6 additions & 11 deletions tests/models/wav2vec2/test_modeling_tf_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,20 +385,15 @@ def test_model_from_pretrained(self):
model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
self.assertIsNotNone(model)

# We override here as passing a full batch of 13 samples results in OOM errors for CTC
@unittest.skip("Fix me!")
@unittest.skip(reason="Fix me! Wav2Vec2 hits OOM errors when loss is computed on full batch")
def test_dataset_conversion(self):
default_batch_size = self.model_tester.batch_size
self.model_tester.batch_size = 2
super().test_dataset_conversion()
self.model_tester.batch_size = default_batch_size
# TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC
pass

# We override here as passing a full batch of 13 samples results in OOM errors for CTC
@unittest.skip(reason="Fix me! Wav2Vec2 hits OOM errors when loss is computed on full batch")
def test_keras_fit(self):
default_batch_size = self.model_tester.batch_size
self.model_tester.batch_size = 2
super().test_keras_fit()
self.model_tester.batch_size = default_batch_size
# TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC
pass


@require_tf
Expand Down