diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py
index a1663796dec2..163178929f98 100644
--- a/src/transformers/modeling_tf_pytorch_utils.py
+++ b/src/transformers/modeling_tf_pytorch_utils.py
@@ -249,7 +249,10 @@ def load_pytorch_weights_in_tf2_model(
         )
         raise
 
-    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
+    # Numpy doesn't understand bfloat16, so upcast to a dtype that doesn't lose precision
+    pt_state_dict = {
+        k: v.numpy() if v.dtype != torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()
+    }
     return load_pytorch_state_dict_in_tf2_model(
         tf_model,
         pt_state_dict,
diff --git a/tests/test_modeling_tf_utils.py b/tests/test_modeling_tf_utils.py
index ff11a8a556bb..8a281761333d 100644
--- a/tests/test_modeling_tf_utils.py
+++ b/tests/test_modeling_tf_utils.py
@@ -63,6 +63,7 @@
         PreTrainedModel,
         PushToHubCallback,
         RagRetriever,
+        TFAutoModel,
         TFBertForMaskedLM,
         TFBertForSequenceClassification,
         TFBertModel,
@@ -435,6 +436,16 @@ def test_safetensors_checkpoint_sharding_local(self):
         for p1, p2 in zip(model.weights, new_model.weights):
             self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
 
+    @is_pt_tf_cross_test
+    @require_safetensors
+    def test_bfloat16_torch_loading(self):
+        # Assert that neither of these raise an error - both repos contain bfloat16 tensors
+        model1 = TFAutoModel.from_pretrained("Rocketknight1/tiny-random-gpt2-bfloat16-pt", from_pt=True)
+        model2 = TFAutoModel.from_pretrained("Rocketknight1/tiny-random-gpt2-bfloat16")  # PT-format safetensors
+        # Check that PT and safetensors loading paths end up with the same values
+        for weight1, weight2 in zip(model1.weights, model2.weights):
+            self.assertTrue(tf.reduce_all(weight1 == weight2))
+
     @slow
     def test_save_pretrained_signatures(self):
         model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
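
For context on the change above, a minimal standalone sketch (not part of the patch, assuming only torch is installed) of why the upcast is needed: NumPy has no native bfloat16 dtype, so calling .numpy() on a bfloat16 tensor raises a TypeError, while upcasting to float32 first is lossless because float32 covers bfloat16's full exponent range and widens its mantissa.

import torch

# Illustrative only: a tiny bfloat16 tensor standing in for a checkpoint weight.
t = torch.ones(2, 2, dtype=torch.bfloat16)

try:
    t.numpy()  # fails: NumPy has no native bfloat16 dtype
except TypeError as err:
    print(f"direct conversion failed: {err}")

# Upcast to float32 first, as the patched load_pytorch_weights_in_tf2_model does;
# every bfloat16 value is exactly representable in float32, so no precision is lost.
arr = t.float().numpy()
print(arr.dtype)  # float32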