Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/transformers/modeling_flax_pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
# numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision
if v.dtype == bfloat16:
v = v.float()
pt_state_dict[k] = v.numpy()
pt_state_dict[k] = v.cpu().numpy()

model_prefix = flax_model.base_model_prefix

Expand Down
11 changes: 7 additions & 4 deletions tests/models/clip/test_modeling_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,7 @@ def test_equivalence_pt_to_flax(self):
with self.subTest(model_class.__name__):
# load PyTorch class
pt_model = model_class(config).eval()
pt_model.to(torch_device)
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
Expand Down Expand Up @@ -881,7 +882,7 @@ def test_equivalence_pt_to_flax(self):
fx_outputs = fx_model(**fx_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)

with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
Expand All @@ -892,7 +893,7 @@ def test_equivalence_pt_to_flax(self):
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)

# overwrite from common since FlaxCLIPModel returns nested output
# which is not supported in the common test
Expand Down Expand Up @@ -921,6 +922,7 @@ def test_equivalence_flax_to_pt(self):
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()

pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
pt_model.to(torch_device)

# make sure weights are tied in PyTorch
pt_model.tie_weights()
Expand All @@ -940,11 +942,12 @@ def test_equivalence_flax_to_pt(self):
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")

for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)

with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
pt_model_loaded.to(torch_device)

with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
Expand All @@ -953,7 +956,7 @@ def test_equivalence_flax_to_pt(self):
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)

@slow
def test_model_from_pretrained(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,15 +297,15 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):

# prepare inputs
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}

with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()

fx_outputs = fx_model(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5)
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)

# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
Expand All @@ -315,7 +315,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5)
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)

# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
Expand All @@ -330,7 +330,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):

self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5)
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)

def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
Expand Down
2 changes: 1 addition & 1 deletion tests/models/informer/test_modeling_informer.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def check_encoder_decoder_model_standalone(self, config, inputs_dict):

embed_positions = InformerSinusoidalPositionalEmbedding(
config.context_length + config.prediction_length, config.d_model
)
).to(torch_device)
self.parent.assertTrue(torch.equal(model.encoder.embed_positions.weight, embed_positions.weight))
self.parent.assertTrue(torch.equal(model.decoder.embed_positions.weight, embed_positions.weight))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -412,15 +412,15 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):

# prepare inputs
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}

with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()

fx_outputs = fx_model(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5)
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)

# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
Expand All @@ -430,7 +430,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5)
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)

# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
Expand All @@ -445,7 +445,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):

self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5)
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)

def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,15 +241,15 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):

# prepare inputs
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}

with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()

fx_outputs = fx_model(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5)
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)

# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
Expand All @@ -259,7 +259,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5)
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)

# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
Expand All @@ -274,7 +274,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):

self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5)
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)

def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,15 +160,15 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):

# prepare inputs
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}

with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()

fx_outputs = fx_model(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)

# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
Expand All @@ -178,7 +178,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)

# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
Expand All @@ -193,7 +193,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):

self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 4e-2)
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 4e-2)

def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,15 +179,15 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, input_ids, attention_mas
# prepare inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values}
pt_inputs = inputs_dict
flax_inputs = {k: v.numpy() for k, v in pt_inputs.items()}
flax_inputs = {k: v.numpy(force=True) for k, v in pt_inputs.items()}

with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()

fx_outputs = fx_model(**flax_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)

# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
Expand All @@ -197,7 +197,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, input_ids, attention_mas
fx_outputs_loaded = fx_model_loaded(**flax_inputs).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)

# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
Expand All @@ -212,7 +212,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, input_ids, attention_mas

self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 4e-2)
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 4e-2)

def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
Expand Down