Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/transformers/modeling_tf_pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,10 @@ def load_pytorch_weights_in_tf2_model(
)
raise

pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
# Numpy doesn't understand bfloat16, so upcast to a dtype that doesn't lose precision
pt_state_dict = {
k: v.numpy() if v.dtype != torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()
}
return load_pytorch_state_dict_in_tf2_model(
tf_model,
pt_state_dict,
Expand Down
11 changes: 11 additions & 0 deletions tests/test_modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
PreTrainedModel,
PushToHubCallback,
RagRetriever,
TFAutoModel,
TFBertForMaskedLM,
TFBertForSequenceClassification,
TFBertModel,
Expand Down Expand Up @@ -435,6 +436,16 @@ def test_safetensors_checkpoint_sharding_local(self):
for p1, p2 in zip(model.weights, new_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

@is_pt_tf_cross_test
@require_safetensors
def test_bfloat16_torch_loading(self):
# Assert that neither of these raise an error - both repos contain bfloat16 tensors
model1 = TFAutoModel.from_pretrained("Rocketknight1/tiny-random-gpt2-bfloat16-pt", from_pt=True)
model2 = TFAutoModel.from_pretrained("Rocketknight1/tiny-random-gpt2-bfloat16") # PT-format safetensors
# Check that PT and safetensors loading paths end up with the same values
for weight1, weight2 in zip(model1.weights, model2.weights):
self.assertTrue(tf.reduce_all(weight1 == weight2))

@slow
def test_save_pretrained_signatures(self):
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
Expand Down