15 changes: 6 additions & 9 deletions examples/research_projects/lxmert/demo.ipynb
@@ -23,21 +23,18 @@
}
],
"source": [
"from IPython.display import clear_output, Image, display\n",
"import PIL.Image\n",
"import io\n",
"import json\n",
"import torch\n",
"\n",
"import numpy as np\n",
"import PIL.Image\n",
"from IPython.display import Image, display\n",
"from modeling_frcnn import GeneralizedRCNN\n",
"from processing_image import Preprocess\n",
"from visualizing_image import SingleImageViz\n",
"from modeling_frcnn import GeneralizedRCNN\n",
"from utils import Config\n",
"\n",
"import utils\n",
"from transformers import LxmertForQuestionAnswering, LxmertTokenizer\n",
"import wget\n",
"import pickle\n",
"import os\n",
"from utils import Config\n",
"\n",
"\n",
"# URL = \"https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/images/input.jpg\",\n",
@@ -31,19 +31,19 @@
"source": [
"# Includes\n",
"\n",
"import h5py\n",
"import os\n",
"import json\n",
"import os\n",
"from collections import OrderedDict\n",
"\n",
"from scipy import sparse\n",
"import h5py\n",
"import numpy as np\n",
"\n",
"import torch\n",
"from scipy import sparse\n",
"from torch import nn\n",
"\n",
"from transformers import *\n",
"\n",
"\n",
"os.chdir(\"../../\")"
]
},
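The two import cells above are re-ordered only; there is no functional change. A minimal, illustrative sketch of the grouping convention the new ordering follows (standard library first, then third-party packages, then local helper modules from the example directory); the module names come from the cells above, but the exact grouping is an assumption about the import sorter's configuration:

```python
# Standard library modules first, alphabetically.
import json
import os

# Third-party packages second.
import numpy as np
import torch

# Local helper modules from the example directory last.
from modeling_frcnn import GeneralizedRCNN
from utils import Config
```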
179 changes: 91 additions & 88 deletions examples/research_projects/visual_bert/demo.ipynb

Large diffs are not rendered by default.

107 changes: 76 additions & 31 deletions src/transformers/modeling_utils.py
@@ -830,7 +830,8 @@ def _load_state_dict_into_meta_model(
is_safetensors=False,
keep_in_fp32_modules=None,
unexpected_keys=None, # passing `unexpected` for cleanup from quantization items
pretrained_model_name_or_path=None, # for flagging the user when the model contains renamed keys
# for flagging the user when the model contains renamed keys
pretrained_model_name_or_path=None,
):
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@@ -1202,9 +1203,11 @@ def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
# We can specify head_mask for each layer
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
head_mask = head_mask.to(dtype=self.dtype) # switch to float if need + fp16 compatibility
# switch to float if need + fp16 compatibility
head_mask = head_mask.to(dtype=self.dtype)
return head_mask

def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
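For reference, a small shape check of the conversion above (illustrative sizes, not part of the diff): both the 1-D and the 2-D input end up as a 5-D mask of shape `(num_hidden_layers, 1, num_heads, 1, 1)`, whose singleton dimensions broadcast over the batch and the two sequence dimensions of the attention scores.

```python
import torch

num_hidden_layers, num_heads = 12, 8

# 1-D mask: one value per head, shared across all layers.
head_mask = torch.ones(num_heads)
mask_5d = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
mask_5d = mask_5d.expand(num_hidden_layers, -1, -1, -1, -1)
print(mask_5d.shape)  # torch.Size([12, 1, 8, 1, 1])

# 2-D mask: one value per layer and per head.
head_mask = torch.ones(num_hidden_layers, num_heads)
mask_5d = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
print(mask_5d.shape)  # torch.Size([12, 1, 8, 1, 1])
```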
@@ -1489,7 +1492,8 @@ def _from_config(cls, config, **kwargs):
if torch_dtype is not None:
dtype_orig = cls._set_default_torch_dtype(torch_dtype)

config = copy.deepcopy(config) # We do not want to modify the config inplace in _from_config.
# We do not want to modify the config inplace in _from_config.
config = copy.deepcopy(config)

if config._attn_implementation_internal is not None:
# In this case, the config has been created with the attn_implementation set by the user, which we
@@ -2369,7 +2373,8 @@ def prune_heads(self, heads_to_prune: Dict[int, List[int]]):
# save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
for layer, heads in heads_to_prune.items():
union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON
# Unfortunately we have to store it as list for JSON
self.config.pruned_heads[layer] = list(union_heads)

self.base_model._prune_heads(heads_to_prune)
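A small illustration of the bookkeeping above (values are made up): previously pruned heads and newly requested heads are merged per layer, and the union goes back into the config as a list because sets are not JSON-serializable.

```python
# Previously pruned heads, as stored in the config (lists, for JSON).
pruned_heads = {1: [0, 2]}

# Newly requested heads to prune.
heads_to_prune = {1: [2, 3], 5: [1]}

for layer, heads in heads_to_prune.items():
    union_heads = set(pruned_heads.get(layer, [])) | set(heads)
    pruned_heads[layer] = list(union_heads)

print(pruned_heads)  # e.g. {1: [0, 2, 3], 5: [1]}
```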

@@ -3866,7 +3871,8 @@ def from_pretrained(
elif low_cpu_mem_usage:
init_contexts.append(init_empty_weights())

config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
# We do not want to modify the config inplace in from_pretrained.
config = copy.deepcopy(config)
config = cls._autoset_attn_implementation(
config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map
)
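As in `_from_config` above, `from_pretrained` deep-copies the config before adjusting it, so that changes such as the attention-implementation selection do not leak into the object the caller passed in. A minimal sketch of the pattern, using a stand-in config class rather than the real `PretrainedConfig`:

```python
import copy


class DummyConfig:
    def __init__(self):
        self.attn_implementation = None


def load_with_copy(config):
    # Work on a private copy so the caller's object is untouched.
    config = copy.deepcopy(config)
    config.attn_implementation = "sdpa"
    return config


user_config = DummyConfig()
load_with_copy(user_config)
print(user_config.attn_implementation)  # still None
```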
@@ -4155,27 +4161,54 @@ def _load_pretrained_model(
expected_keys = list(model_state_dict.keys())
prefix = model.base_model_prefix

def _fix_key(key):
if "beta" in key:
return key.replace("beta", "bias")
if "gamma" in key:
return key.replace("gamma", "weight")
error_msgs = []
old_keys = []
new_keys = []
renamed_keys = {}
warning_msg = f"This model {type(model)}"

# Preserve the original_loaded_keys reference without modifying it
original_loaded_keys = loaded_keys # Do not copy, just refer to the original

# Create a new list to hold the updated keys
updated_loaded_keys = []

# Single loop for processing keys
for key in loaded_keys:
new_key = key

if "gamma" in key:
new_key = key.replace("gamma", "weight")
elif "beta" in key:
new_key = key.replace("beta", "bias")
# to avoid logging parametrized weight norm renaming
if hasattr(nn.utils.parametrizations, "weight_norm"):
if "weight_g" in key:
return key.replace("weight_g", "parametrizations.weight.original0")
new_key = key.replace("weight_g", "parametrizations.weight.original0")
if "weight_v" in key:
return key.replace("weight_v", "parametrizations.weight.original1")
new_key = key.replace("weight_v", "parametrizations.weight.original1")
else:
if "parametrizations.weight.original0" in key:
return key.replace("parametrizations.weight.original0", "weight_g")
new_key = key.replace("parametrizations.weight.original0", "weight_g")
if "parametrizations.weight.original1" in key:
return key.replace("parametrizations.weight.original1", "weight_v")
return key

original_loaded_keys = loaded_keys
loaded_keys = [_fix_key(key) for key in loaded_keys]
new_key = key.replace("parametrizations.weight.original1", "weight_v")
if new_key != key:
old_keys.append(key)
new_keys.append(new_key)
renamed_keys[key] = new_key
# Add the new (or unchanged) key
updated_loaded_keys.append(new_key)

if renamed_keys:
warning_msg += ' contains parameters that have been renamed internally ("gamma" and "beta" in parameters or parametrized weight norm) (a few are listed below but more are present in the model):\n'
for old_key, new_key in renamed_keys.items():
warning_msg += f"* {old_key} -> {new_key}\n"
warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
logger.info(warning_msg)

# Now assign the updated list to loaded_keys
loaded_keys = updated_loaded_keys

if len(prefix) > 0:
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
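A condensed, standalone sketch of the key-renaming pass introduced above (same mapping rules, but not the exact library code): legacy "gamma"/"beta" names and, depending on the torch version, weight-norm parametrization names are rewritten, and every rename is recorded so it can be reported to the user afterwards.

```python
from torch import nn


def rename_loaded_keys(loaded_keys):
    """Return (updated_keys, renamed) for a list of state-dict key names."""
    renamed = {}
    updated = []
    has_parametrized_wn = hasattr(nn.utils.parametrizations, "weight_norm")
    for key in loaded_keys:
        new_key = key
        if "gamma" in key:
            new_key = key.replace("gamma", "weight")
        elif "beta" in key:
            new_key = key.replace("beta", "bias")
        if has_parametrized_wn:
            # Newer torch stores weight norm as a parametrization.
            if "weight_g" in key:
                new_key = key.replace("weight_g", "parametrizations.weight.original0")
            if "weight_v" in key:
                new_key = key.replace("weight_v", "parametrizations.weight.original1")
        else:
            # Older torch uses the weight_g / weight_v naming.
            if "parametrizations.weight.original0" in key:
                new_key = key.replace("parametrizations.weight.original0", "weight_g")
            if "parametrizations.weight.original1" in key:
                new_key = key.replace("parametrizations.weight.original1", "weight_v")
        if new_key != key:
            renamed[key] = new_key
        updated.append(new_key)
    return updated, renamed


keys, renamed = rename_loaded_keys(["encoder.layer.0.norm.gamma", "conv.weight_g"])
# With parametrized weight norm available:
# ['encoder.layer.0.norm.weight', 'conv.parametrizations.weight.original0']
print(keys)
print(renamed)
```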
@@ -4643,7 +4676,8 @@ def _load_pretrained_model_low_mem(

_move_model_to_meta(model, loaded_state_dict_keys, start_prefix)
state_dict = load_state_dict(resolved_archive_file)
expected_keys = loaded_state_dict_keys # plug for missing expected_keys. TODO: replace with proper keys
# plug for missing expected_keys. TODO: replace with proper keys
expected_keys = loaded_state_dict_keys
error_msgs = _load_state_dict_into_meta_model(
model,
state_dict,
@@ -4873,9 +4907,12 @@ def forward(
), "One of start_states, start_positions should be not None"
if start_positions is not None:
slen, hsz = hidden_states.shape[-2:]
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
# shape (bsz, 1, hsz)
start_positions = start_positions[:, None, None].expand(-1, -1, hsz)
# shape (bsz, 1, hsz)
start_states = hidden_states.gather(-2, start_positions)
# shape (bsz, slen, hsz)
start_states = start_states.expand(-1, slen, -1)

x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
x = self.activation(x)
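A small shape walk-through of the gather/expand pattern above (illustrative sizes): the hidden state at each example's start position is picked out, broadcast along the sequence dimension, and concatenated with every position's hidden state.

```python
import torch

bsz, slen, hsz = 2, 5, 4
hidden_states = torch.randn(bsz, slen, hsz)
start_positions = torch.tensor([1, 3])  # one start index per example

# (bsz,) -> (bsz, 1, hsz) so it can index the sequence dimension.
idx = start_positions[:, None, None].expand(-1, -1, hsz)
start_states = hidden_states.gather(-2, idx)           # (bsz, 1, hsz)
start_states = start_states.expand(-1, slen, -1)       # (bsz, slen, hsz)

x = torch.cat([hidden_states, start_states], dim=-1)   # (bsz, slen, 2 * hsz)
print(x.shape)
```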
@@ -4940,12 +4977,16 @@ def forward(
start_states is not None or start_positions is not None
), "One of start_states, start_positions should be not None"
if start_positions is not None:
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)
# shape (bsz, 1, hsz)
start_positions = start_positions[:, None, None].expand(-1, -1, hsz)
# shape (bsz, hsz)
start_states = hidden_states.gather(-2, start_positions).squeeze(-2)

if cls_index is not None:
cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
# shape (bsz, 1, hsz)
cls_index = cls_index[:, None, None].expand(-1, -1, hsz)
# shape (bsz, hsz)
cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2)
else:
cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)

@@ -5072,9 +5113,12 @@ def forward(
start_top_log_probs, start_top_index = torch.topk(
start_log_probs, self.start_n_top, dim=-1
) # shape (bsz, start_n_top)
start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
# shape (bsz, start_n_top, hsz)
start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)
# shape (bsz, start_n_top, hsz)
start_states = torch.gather(hidden_states, -2, start_top_index_exp)
# shape (bsz, slen, start_n_top, hsz)
start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)

hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
start_states
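A shape sketch of the beam-search step above (illustrative sizes): the hidden states of the `start_n_top` most likely start positions are gathered and broadcast so that every (sequence position, start candidate) pair can be scored jointly.

```python
import torch

bsz, slen, hsz, start_n_top = 2, 6, 4, 3
hidden_states = torch.randn(bsz, slen, hsz)
start_log_probs = torch.randn(bsz, slen)

# Pick the top start candidates per example.
start_top_log_probs, start_top_index = torch.topk(start_log_probs, start_n_top, dim=-1)

# (bsz, start_n_top) -> (bsz, start_n_top, hsz) to gather hidden states.
idx = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)
start_states = torch.gather(hidden_states, -2, idx)                 # (bsz, start_n_top, hsz)
start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)   # (bsz, slen, start_n_top, hsz)

hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states)
print(start_states.shape, hidden_states_expanded.shape)
```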
@@ -5191,7 +5235,8 @@ def forward(
cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
# shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
# shape (bsz, XX, hidden_size)
output = hidden_states.gather(-2, cls_index).squeeze(-2)
elif self.summary_type == "attn":
raise NotImplementedError
