Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions tests/models/hubert/test_modeling_tf_hubert.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,20 @@ def test_model_from_pretrained(self):
model = TFHubertModel.from_pretrained("facebook/hubert-base-ls960")
self.assertIsNotNone(model)

# We override here as passing a full batch of 13 samples results in OOM errors for CTC
def test_dataset_conversion(self):
Copy link
Contributor Author

@gante gante Feb 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one, in TFHubertModelTest, was the issue. Since it did not have the lower batch size treatment as in the other TF CTC tests, I'm assuming this is the fix.

default_batch_size = self.model_tester.batch_size
self.model_tester.batch_size = 2
super().test_dataset_conversion()
self.model_tester.batch_size = default_batch_size

# We override here as passing a full batch of 13 samples results in OOM errors for CTC
def test_keras_fit(self):
default_batch_size = self.model_tester.batch_size
self.model_tester.batch_size = 2
super().test_keras_fit()
self.model_tester.batch_size = default_batch_size


@require_tf
class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
Expand Down Expand Up @@ -431,20 +445,18 @@ def test_resize_tokens_embeddings(self):
def test_model_common_attributes(self):
pass

@slow
def test_model_from_pretrained(self):
model = TFHubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
self.assertIsNotNone(model)

# We override here as passing a full batch of 13 samples results in OOM errors for CTC
# TODO: fix me
@unittest.skip(reason="Crashing on CI, temporarily skipped")
def test_dataset_conversion(self):
default_batch_size = self.model_tester.batch_size
self.model_tester.batch_size = 2
super().test_dataset_conversion()
self.model_tester.batch_size = default_batch_size

@slow
def test_model_from_pretrained(self):
model = TFHubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
self.assertIsNotNone(model)

# We override here as passing a full batch of 13 samples results in OOM errors for CTC
def test_keras_fit(self):
default_batch_size = self.model_tester.batch_size
Expand Down
4 changes: 2 additions & 2 deletions tests/models/wav2vec2/test_modeling_tf_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def test_dataset_conversion(self):
def test_keras_fit(self):
default_batch_size = self.model_tester.batch_size
self.model_tester.batch_size = 2
super().test_dataset_conversion()
super().test_keras_fit()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ninja copy paste hehe

self.model_tester.batch_size = default_batch_size


Expand Down Expand Up @@ -528,7 +528,7 @@ def test_dataset_conversion(self):
def test_keras_fit(self):
default_batch_size = self.model_tester.batch_size
self.model_tester.batch_size = 2
super().test_dataset_conversion()
super().test_keras_fit()
self.model_tester.batch_size = default_batch_size


Expand Down