diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index 3f2b564b70bf..9db0f582e2aa 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -216,7 +216,13 @@ def load_pytorch_state_dict_in_tf2_model( tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False, output_loading_info=False ): """Load a pytorch state_dict in a TF 2.0 model.""" - from tensorflow.python.keras import backend as K + import tensorflow as tf + from packaging.version import parse + + if parse(tf.__version__) >= parse("2.11.0"): + from keras import backend as K + else: + from tensorflow.python.keras import backend as K if tf_inputs is None: tf_inputs = tf_model.dummy_inputs diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index f79e3dedff5b..e8c87d83f135 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -30,13 +30,9 @@ import h5py import numpy as np import tensorflow as tf -from tensorflow.python.keras import backend as K -from tensorflow.python.keras.engine import data_adapter -from tensorflow.python.keras.engine.keras_tensor import KerasTensor -from tensorflow.python.keras.saving import hdf5_format +from packaging.version import parse from huggingface_hub import Repository, list_repo_files -from keras.saving.hdf5_format import save_attributes_to_hdf5_group from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files from . import DataCollatorWithPadding, DefaultDataCollator @@ -68,6 +64,18 @@ ) +if parse(tf.__version__) >= parse("2.11.0"): + from keras import backend as K + from keras.engine import data_adapter + from keras.engine.keras_tensor import KerasTensor + from keras.saving.legacy import hdf5_format +else: + from tensorflow.python.keras import backend as K + from tensorflow.python.keras.engine import data_adapter + from tensorflow.python.keras.engine.keras_tensor import KerasTensor + from tensorflow.python.keras.saving import hdf5_format + + if is_safetensors_available(): from safetensors import safe_open from safetensors.tensorflow import load_file as safe_load_file @@ -2310,7 +2318,7 @@ def save_pretrained( ) param_dset[:] = layer.numpy() layers.append(layer_name.encode("utf8")) - save_attributes_to_hdf5_group(shard_file, "layer_names", layers) + hdf5_format.save_attributes_to_hdf5_group(shard_file, "layer_names", layers) if push_to_hub: self._upload_modified_files(