Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 82 additions & 11 deletions src/transformers/modeling_flax_pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,16 @@ def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool:
if pt_tuple_key[-1] in ["weight", "gamma"] and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key):
return renamed_pt_tuple_key, pt_tensor

# batch norm layer mean
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("mean",)
if pt_tuple_key[-1] == "running_mean" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
return renamed_pt_tuple_key, pt_tensor

# batch norm layer var
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("var",)
if pt_tuple_key[-1] == "running_var" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
return renamed_pt_tuple_key, pt_tensor

# embedding
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
if pt_tuple_key[-1] == "weight" and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key):
Expand Down Expand Up @@ -118,13 +128,25 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

model_prefix = flax_model.base_model_prefix
random_flax_state_dict = flatten_dict(flax_model.params)

# use params dict if the model contains batch norm layers
if "params" in flax_model.params:
flax_model_params = flax_model.params["params"]
else:
flax_model_params = flax_model.params
random_flax_state_dict = flatten_dict(flax_model_params)

# add batch_stats keys,values to dict
if "batch_stats" in flax_model.params:
flax_batch_stats = flatten_dict(flax_model.params["batch_stats"])
random_flax_state_dict.update(flax_batch_stats)

flax_state_dict = {}

load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
load_model_with_head_into_base_model = (model_prefix not in flax_model_params) and (
model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()}
)
load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
load_base_model_into_model_with_head = (model_prefix in flax_model_params) and (
model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()}
)

Expand Down Expand Up @@ -154,8 +176,22 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
)

# also add unexpected weight so that warning is thrown
flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
# add batch stats if the model contains batchnorm layers
if "batch_stats" in flax_model.params:
if "mean" in flax_key[-1] or "var" in flax_key[-1]:
flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor)
continue
# remove num_batches_tracked key
if "num_batches_tracked" in flax_key[-1]:
flax_state_dict.pop(flax_key, None)
continue

# also add unexpected weight so that warning is thrown
flax_state_dict[("params",) + flax_key] = jnp.asarray(flax_tensor)

else:
# also add unexpected weight so that warning is thrown
flax_state_dict[flax_key] = jnp.asarray(flax_tensor)

return unflatten_dict(flax_state_dict)

Expand All @@ -176,12 +212,21 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

model_prefix = flax_model.base_model_prefix
random_flax_state_dict = flatten_dict(flax_model.params)

load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
# use params dict if the model contains batch norm layers and then add batch_stats keys,values to dict
if "batch_stats" in flax_model.params:
flax_model_params = flax_model.params["params"]

random_flax_state_dict = flatten_dict(flax_model_params)
random_flax_state_dict.update(flatten_dict(flax_model.params["batch_stats"]))
else:
flax_model_params = flax_model.params
random_flax_state_dict = flatten_dict(flax_model_params)

load_model_with_head_into_base_model = (model_prefix not in flax_model_params) and (
model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()}
)
load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
load_base_model_into_model_with_head = (model_prefix in flax_model_params) and (
model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()}
)
# Need to change some parameters name to match Flax names
Expand Down Expand Up @@ -209,8 +254,25 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
)

# also add unexpected weight so that warning is thrown
flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
# add batch stats if the model contains batchnorm layers
if "batch_stats" in flax_model.params:
if "mean" in flax_key[-1]:
flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor)
continue
if "var" in flax_key[-1]:
flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor)
continue
# remove num_batches_tracked key
if "num_batches_tracked" in flax_key[-1]:
flax_state_dict.pop(flax_key, None)
continue

# also add unexpected weight so that warning is thrown
flax_state_dict[("params",) + flax_key] = jnp.asarray(flax_tensor)

else:
# also add unexpected weight so that warning is thrown
flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
return unflatten_dict(flax_state_dict)


Expand Down Expand Up @@ -299,7 +361,16 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
elif flax_key_tuple[-1] in ["scale", "embedding"]:
flax_key_tuple = flax_key_tuple[:-1] + ("weight",)

flax_key = ".".join(flax_key_tuple)
# adding batch stats from flax batch norm to pt
elif "mean" in flax_key_tuple[-1]:
flax_key_tuple = flax_key_tuple[:-1] + ("running_mean",)
elif "var" in flax_key_tuple[-1]:
flax_key_tuple = flax_key_tuple[:-1] + ("running_var",)

if "batch_stats" in flax_state:
flax_key = ".".join(flax_key_tuple[1:]) # Remove the params/batch_stats header
else:
flax_key = ".".join(flax_key_tuple)

if flax_key in pt_model_dict:
if flax_tensor.shape != pt_model_dict[flax_key].shape:
Expand Down
40 changes: 33 additions & 7 deletions src/transformers/modeling_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,14 +837,35 @@ def from_pretrained(
# keep the params on CPU if we don't want to initialize
state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)

# if model is base model only use model_prefix key
if cls.base_model_prefix not in dict(model.params_shape_tree) and cls.base_model_prefix in state:
state = state[cls.base_model_prefix]
if "batch_stats" in state: # if flax model contains batch norm layers
# if model is base model only use model_prefix key
if (
cls.base_model_prefix not in dict(model.params_shape_tree["params"])
and cls.base_model_prefix in state["params"]
):
state["params"] = state["params"][cls.base_model_prefix]
state["batch_stats"] = state["batch_stats"][cls.base_model_prefix]

# if model is head model and we are loading weights from base model
# we initialize new params dict with base_model_prefix
if (
cls.base_model_prefix in dict(model.params_shape_tree["params"])
and cls.base_model_prefix not in state["params"]
):
state = {
"params": {cls.base_model_prefix: state["params"]},
"batch_stats": {cls.base_model_prefix: state["batch_stats"]},
}

else:
# if model is base model only use model_prefix key
if cls.base_model_prefix not in dict(model.params_shape_tree) and cls.base_model_prefix in state:
state = state[cls.base_model_prefix]

# if model is head model and we are loading weights from base model
# we initialize new params dict with base_model_prefix
if cls.base_model_prefix in dict(model.params_shape_tree) and cls.base_model_prefix not in state:
state = {cls.base_model_prefix: state}
# if model is head model and we are loading weights from base model
# we initialize new params dict with base_model_prefix
if cls.base_model_prefix in dict(model.params_shape_tree) and cls.base_model_prefix not in state:
state = {cls.base_model_prefix: state}

# flatten dicts
state = flatten_dict(state)
Expand All @@ -854,6 +875,11 @@ def from_pretrained(
missing_keys = model.required_params - set(state.keys())
unexpected_keys = set(state.keys()) - model.required_params

# Disable the warning when porting PyTorch weights to Flax: Flax does not use num_batches_tracked
for unexpected_key in unexpected_keys.copy():
if "num_batches_tracked" in unexpected_key[-1]:
unexpected_keys.remove(unexpected_key)

if missing_keys and not _do_init:
logger.warning(
f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
Expand Down
44 changes: 34 additions & 10 deletions tests/test_modeling_flax_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,30 @@ def random_attention_mask(shape, rng=None):
return attn_mask


def get_params(params, from_head_prefix=None):
    """Flatten a model parameter tree, merging in batch-norm statistics when present.

    Args:
        params: The model parameters. Either a plain tree of weights, or a
            mapping holding a "params" entry alongside a "batch_stats" entry.
        from_head_prefix: Optional key. When given, only the sub-tree stored
            under this key is extracted (from both "params" and "batch_stats"
            when the latter exists).

    Returns:
        The dict produced by `flatten_dict` for the selected weights, updated
        with the flattened batch statistics when `params` contains them.
    """
    # No batch statistics: the tree itself holds the weights.
    if "batch_stats" not in params:
        weights = params if from_head_prefix is None else params[from_head_prefix]
        return flatten_dict(unfreeze(weights))

    # Weights and batch statistics live under separate top-level entries.
    weights = params["params"]
    stats = params["batch_stats"]
    if from_head_prefix is not None:
        # Narrow both trees to the requested sub-tree.
        weights = weights[from_head_prefix]
        stats = stats[from_head_prefix]

    flat = flatten_dict(unfreeze(weights))
    flat.update(flatten_dict(stats))
    return flat


@require_flax
class FlaxModelTesterMixin:
model_tester = None
Expand Down Expand Up @@ -426,14 +450,14 @@ def test_save_load_from_base(self):
continue

model = base_class(config)
base_params = flatten_dict(unfreeze(model.params))
base_params = get_params(model.params)

# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
head_model = model_class.from_pretrained(tmpdirname)

base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))
base_param_from_head = get_params(head_model.params, from_head_prefix=head_model.base_model_prefix)

for key in base_param_from_head.keys():
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
Expand All @@ -448,14 +472,14 @@ def test_save_load_to_base(self):
continue

model = model_class(config)
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix)

# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname)

base_params = flatten_dict(unfreeze(base_model.params))
base_params = get_params(base_model.params)

for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
Expand All @@ -471,7 +495,7 @@ def test_save_load_from_base_pt(self):
continue

model = base_class(config)
base_params = flatten_dict(unfreeze(model.params))
base_params = get_params(model.params)

# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning
Expand All @@ -484,7 +508,7 @@ def test_save_load_from_base_pt(self):
pt_model.save_pretrained(tmpdirname)
head_model = model_class.from_pretrained(tmpdirname, from_pt=True)

base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))
base_param_from_head = get_params(head_model.params, from_head_prefix=head_model.base_model_prefix)

for key in base_param_from_head.keys():
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
Expand All @@ -500,7 +524,7 @@ def test_save_load_to_base_pt(self):
continue

model = model_class(config)
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix)

# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
Expand All @@ -512,7 +536,7 @@ def test_save_load_to_base_pt(self):
pt_model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

base_params = flatten_dict(unfreeze(base_model.params))
base_params = get_params(base_model.params)

for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
Expand All @@ -529,7 +553,7 @@ def test_save_load_bf16_to_base_pt(self):

model = model_class(config)
model.params = model.to_bf16(model.params)
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix)

# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
Expand All @@ -541,7 +565,7 @@ def test_save_load_bf16_to_base_pt(self):
pt_model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

base_params = flatten_dict(unfreeze(base_model.params))
base_params = get_params(base_model.params)

for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
Expand Down