tests/test_modeling_flax_bert.py (52 changes: 25 additions & 27 deletions)
@@ -1,6 +1,5 @@
 import unittest
 
-import pytest
 from numpy import ndarray
 
 from transformers import BertTokenizerFast, TensorType, is_flax_available, is_torch_available
@@ -24,6 +23,10 @@
 @require_flax
 @require_torch
 class FlaxBertModelTest(unittest.TestCase):
+    def assert_almost_equals(self, a: ndarray, b: ndarray, tol: float):
+        diff = (a - b).sum()
+        self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol})")
+
     def test_from_pytorch(self):
         with torch.no_grad():
             with self.subTest("bert-base-cased"):
@@ -40,32 +43,27 @@ def test_from_pytorch(self):
                 self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
 
                 for fx_output, pt_output in zip(fx_outputs, pt_outputs):
-                    self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-4)
+                    self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3)
 
-    def assert_almost_equals(self, a: ndarray, b: ndarray, tol: float):
-        diff = (a - b).sum()
-        self.assertLessEqual(diff, tol, "Difference between torch and flax is {} (>= {})".format(diff, tol))
+    def test_multiple_sequences(self):
+        tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
+        model = FlaxBertModel.from_pretrained("bert-base-cased")
 
+        sequences = ["this is an example sentence", "this is another", "and a third one"]
+        encodings = tokenizer(sequences, return_tensors=TensorType.JAX, padding=True, truncation=True)
 
-@require_flax
-@require_torch
-@pytest.mark.parametrize("jit", ["disable_jit", "enable_jit"])
-def test_multiple_sentences(jit):
-    tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
-    model = FlaxBertModel.from_pretrained("bert-base-cased")
-
-    sentences = ["this is an example sentence", "this is another", "and a third one"]
-    encodings = tokenizer(sentences, return_tensors=TensorType.JAX, padding=True, truncation=True)
-
-    @jax.jit
-    def model_jitted(input_ids, attention_mask, token_type_ids):
-        return model(input_ids, attention_mask, token_type_ids)
-
-    if jit == "disable_jit":
-        with jax.disable_jit():
-            tokens, pooled = model_jitted(**encodings)
-    else:
-        tokens, pooled = model_jitted(**encodings)
-
-    assert tokens.shape == (3, 7, 768)
-    assert pooled.shape == (3, 768)
+        @jax.jit
+        def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
+            return model(input_ids, attention_mask, token_type_ids)
+
+        with self.subTest("JIT Disabled"):
+            with jax.disable_jit():
+                tokens, pooled = model_jitted(**encodings)
+                self.assertEqual(tokens.shape, (3, 7, 768))
+                self.assertEqual(pooled.shape, (3, 768))
+
+        with self.subTest("JIT Enabled"):
+            jitted_tokens, jitted_pooled = model_jitted(**encodings)
+
+            self.assertEqual(jitted_tokens.shape, (3, 7, 768))
+            self.assertEqual(jitted_pooled.shape, (3, 768))
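
Note on the pattern the rewritten test relies on: jax.disable_jit() is a context manager that makes @jax.jit-decorated functions run eagerly, op by op, so the same call can be exercised both with and without XLA compilation. A minimal standalone sketch of that check follows; the function scaled_sum and its input are invented purely for illustration.

import jax
import jax.numpy as jnp


@jax.jit
def scaled_sum(x):
    # Hypothetical stand-in for the model call; traced and compiled by XLA
    # on the first invocation with this input shape/dtype.
    return (x * 2.0).sum()


x = jnp.arange(6.0).reshape(2, 3)

# Eager path: inside jax.disable_jit(), the jitted function runs op by op,
# which is slower but much easier to step through in a debugger.
with jax.disable_jit():
    eager_result = scaled_sum(x)

# Compiled path: the same call outside the context uses the XLA-compiled version.
compiled_result = scaled_sum(x)

assert jnp.allclose(eager_result, compiled_result)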
tests/test_modeling_flax_roberta.py (52 changes: 25 additions & 27 deletions)
@@ -1,6 +1,5 @@
 import unittest
 
-import pytest
 from numpy import ndarray
 
 from transformers import RobertaTokenizerFast, TensorType, is_flax_available, is_torch_available
@@ -24,6 +23,10 @@
 @require_flax
 @require_torch
 class FlaxRobertaModelTest(unittest.TestCase):
+    def assert_almost_equals(self, a: ndarray, b: ndarray, tol: float):
+        diff = (a - b).sum()
+        self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol})")
+
     def test_from_pytorch(self):
         with torch.no_grad():
             with self.subTest("roberta-base"):
@@ -40,32 +43,27 @@ def test_from_pytorch(self):
                 self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
 
                 for fx_output, pt_output in zip(fx_outputs, pt_outputs.to_tuple()):
-                    self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-4)
+                    self.assert_almost_equals(fx_output, pt_output.numpy(), 6e-4)
 
-    def assert_almost_equals(self, a: ndarray, b: ndarray, tol: float):
-        diff = (a - b).sum()
-        self.assertLessEqual(diff, tol, "Difference between torch and flax is {} (>= {})".format(diff, tol))
+    def test_multiple_sequences(self):
+        tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
+        model = FlaxRobertaModel.from_pretrained("roberta-base")
 
+        sequences = ["this is an example sentence", "this is another", "and a third one"]
+        encodings = tokenizer(sequences, return_tensors=TensorType.JAX, padding=True, truncation=True)
 
-@require_flax
-@require_torch
-@pytest.mark.parametrize("jit", ["disable_jit", "enable_jit"])
-def test_multiple_sentences(jit):
-    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
-    model = FlaxRobertaModel.from_pretrained("roberta-base")
-
-    sentences = ["this is an example sentence", "this is another", "and a third one"]
-    encodings = tokenizer(sentences, return_tensors=TensorType.JAX, padding=True, truncation=True)
-
-    @jax.jit
-    def model_jitted(input_ids, attention_mask):
-        return model(input_ids, attention_mask)
-
-    if jit == "disable_jit":
-        with jax.disable_jit():
-            tokens, pooled = model_jitted(**encodings)
-    else:
-        tokens, pooled = model_jitted(**encodings)
-
-    assert tokens.shape == (3, 7, 768)
-    assert pooled.shape == (3, 768)
+        @jax.jit
+        def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
+            return model(input_ids, attention_mask, token_type_ids)
+
+        with self.subTest("JIT Disabled"):
+            with jax.disable_jit():
+                tokens, pooled = model_jitted(**encodings)
+                self.assertEqual(tokens.shape, (3, 7, 768))
+                self.assertEqual(pooled.shape, (3, 768))
+
+        with self.subTest("JIT Enabled"):
+            jitted_tokens, jitted_pooled = model_jitted(**encodings)
+
+            self.assertEqual(jitted_tokens.shape, (3, 7, 768))
+            self.assertEqual(jitted_pooled.shape, (3, 768))
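
An aside on the assert_almost_equals helper both new test classes add: because it compares a signed sum of differences, positive and negative elementwise errors can cancel, letting the check pass even when individual values disagree. The short sketch below (arrays invented for illustration) shows that failure mode alongside the stricter max-of-absolute-differences comparison.

import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([2.0, 1.0, 3.0])  # two elements are each off by 1.0

# Signed sum: the +1.0 and -1.0 errors cancel to 0.0, so a tolerance
# check against this value passes despite real disagreement.
signed_diff = (a - b).sum()
print(signed_diff)  # 0.0

# Max absolute difference reports the worst elementwise error instead.
max_abs_diff = np.abs(a - b).max()
print(max_abs_diff)  # 1.0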