Upgrade to transformers release V4.23.1 #62
Changed file: HfArgumentParser

@@ -21,9 +21,16 @@
 from inspect import isclass
 from pathlib import Path
 from typing import Any, Dict, Iterable, NewType, Optional, Tuple, Union, get_type_hints

+import os
+import yaml

+from sparsezoo import Model

+from .utils.logging import get_logger

+logger = get_logger(__name__)


 DataClass = NewType("DataClass", Any)
 DataClassType = NewType("DataClassType", Any)
@@ -229,12 +236,17 @@ def parse_args_into_dataclasses(
             # additional namespace.
             outputs.append(namespace)
         if return_remaining_strings:
-            return (*outputs, remaining_args)
+            return (
+                *[_download_dataclass_zoo_stub_files(output) for output in outputs],
+                remaining_args,
+            )
         else:
             if remaining_args:
                 raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")

-            return (*outputs,)
+            return tuple(
+                [_download_dataclass_zoo_stub_files(output) for output in outputs]
+            )

     def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
         """
@@ -262,7 +274,9 @@ def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
             outputs.append(obj)
         if not allow_extra_keys and unused_keys:
             raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
-        return tuple(outputs)
+        return tuple(
+            [_download_dataclass_zoo_stub_files(output) for output in outputs]
+        )

     def parse_json_file(self, json_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
         """
@@ -305,3 +319,28 @@ def parse_yaml_file(self, yaml_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
         """
         outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys)
         return tuple(outputs)


+def _download_dataclass_zoo_stub_files(data_class: DataClass):
+    for name, val in data_class.__dict__.items():
+        if not isinstance(val, str) or "recipe" in name or not val.startswith("zoo:"):
+            continue

+        logger.info(f"Downloading framework files for SparseZoo stub: {val}")

+        zoo_model = Model(val)
+        framework_file_paths = [file.path for file in zoo_model.training.default.files]
+        assert framework_file_paths, f"Unable to download any framework files for SparseZoo stub {val}"
+        framework_file_names = [os.path.basename(path) for path in framework_file_paths]
+        if "pytorch_model.bin" not in framework_file_names or ("config.json" not in framework_file_names):
+            raise RuntimeError(
+                "Unable to find 'pytorch_model.bin' and 'config.json' in framework "
+                f"files downloaded from {val}. Found {framework_file_names}. Check "
+                "if the given stub is for a transformers repo model"
+            )
+        framework_dir_path = Path(framework_file_paths[0]).parent.absolute()

Review comment: Will all the framework_file_paths have the same parent? Is that why we can just use the 1st one?

+        logger.info(f"Overwriting argument {name} to downloaded {framework_dir_path}")

+        data_class.__dict__[name] = str(framework_dir_path)

+    return data_class
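
To make the new helper's behavior concrete, here is a minimal usage sketch (not part of the diff; the dataclass, its field names, and the stub string are hypothetical stand-ins). The helper only rewrites string fields whose value starts with "zoo:" and whose name does not contain "recipe", replacing the stub with the local directory produced by the SparseZoo download; parse_args_into_dataclasses and parse_dict (and therefore parse_yaml_file, which delegates to parse_dict) route their outputs through it.

from dataclasses import dataclass

@dataclass
class ModelArguments:  # hypothetical dataclass
    model_name_or_path: str = "zoo:example/stub"  # hypothetical stub; would be rewritten
    recipe: str = "zoo:example/stub"              # left alone: field name contains "recipe"
    num_train_epochs: int = 3                     # left alone: not a string

args = ModelArguments()
# Requires network access and a valid SparseZoo stub: Model(val) fetches the
# framework files (pytorch_model.bin, config.json, ...), and the field is then
# overwritten with the absolute path of the directory containing them.
args = _download_dataclass_zoo_stub_files(args)
print(args.model_name_or_path)  # now a local directory path instead of the "zoo:" stub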
Changed file: BERT modeling (BertSelfAttention)
@@ -240,6 +240,22 @@ def forward(
         return embeddings


+class QATMatMul(nn.Module):
+    def __init__(self):
+        super().__init__()

+        # behaves like normal torch.matmul unless a SparseML QuantizationModifier
+        # is initialized
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 2,
+            "input_qconfigs": ["asymmetric", "symmetric"],
+        }

+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return torch.matmul(a, b)


 class BertSelfAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
@@ -257,6 +273,11 @@ def __init__(self, config, position_embedding_type=None):
         self.key = nn.Linear(config.hidden_size, self.all_head_size)
         self.value = nn.Linear(config.hidden_size, self.all_head_size)

+        # non-parameterized matmuls will behave as normal torch.matmul ops unless
+        # Quantization-Aware-Training is invoked
+        self.attention_scores_matmul = QATMatMul()
+        self.context_layer_matmul = QATMatMul()

         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
         self.position_embedding_type = position_embedding_type or getattr(
             config, "position_embedding_type", "absolute"
@@ -320,7 +341,7 @@ def forward(
             past_key_value = (key_layer, value_layer)

         # Take the dot product between "query" and "key" to get the raw attention scores.
-        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+        attention_scores = self.attention_scores_matmul(query_layer, key_layer.transpose(-1, -2))

         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
             seq_length = hidden_states.size()[1]
@@ -354,7 +375,7 @@ def forward(
         if head_mask is not None:
             attention_probs = attention_probs * head_mask

-        context_layer = torch.matmul(attention_probs, value_layer)
+        context_layer = self.context_layer_matmul(attention_probs, value_layer)

         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
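
A quick sanity check on the new wrapper (a sketch, not part of the diff; it assumes QATMatMul is importable from the patched BERT modeling module): until a SparseML QuantizationModifier wraps it, the module computes exactly torch.matmul, and wrap_qat plus qat_wrapper_kwargs are plain attributes left for the quantization tooling to read, so swapping the attention matmuls leaves FP32 behavior unchanged.

import torch

matmul = QATMatMul()  # assumed importable from the patched BERT modeling module
a = torch.randn(2, 12, 128, 64)
b = torch.randn(2, 12, 64, 128)

# Numerically identical to the plain op while no quantization modifier is applied.
assert torch.equal(matmul(a, b), torch.matmul(a, b))

# Metadata intended for the quantization tooling, per the in-code comment.
print(matmul.wrap_qat)            # True
print(matmul.qat_wrapper_kwargs)  # {'num_inputs': 2, 'input_qconfigs': ['asymmetric', 'symmetric']}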
Changed file: DistilBERT modeling
@@ -89,6 +89,38 @@ def _create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):
     out.detach_()


+class QATAttentionScores(nn.Module):
+    def __init__(self):
+        super().__init__()

+        # behaves like normal torch.matmul unless a SparseML QuantizationModifier
+        # is initialized
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 2,
+            "input_qconfigs": ["asymmetric", "symmetric"],
+        }

+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return torch.matmul(a, b)


+class QATContextLayer(nn.Module):
+    def __init__(self):
+        super().__init__()

+        # behaves like normal torch.matmul unless a SparseML QuantizationModifier
+        # is initialized
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 2,
+            "num_outputs": 0,

Review comment (on the "num_outputs": 0 line): Is this supposed to be here? Don't see it in others

| "input_qconfigs": ["asymmetric", "symmetric"], | ||
| } | ||
|
|
||
| def forward(self, a: torch.Tensor, b: torch.Tensor): | ||
| return torch.matmul(a, b) | ||
|
|
||
|
|
||
| class Embeddings(nn.Module): | ||
| def __init__(self, config: PretrainedConfig): | ||
| super().__init__() | ||
|
|
@@ -150,6 +182,11 @@ def __init__(self, config: PretrainedConfig):

         self.pruned_heads: Set[int] = set()

+        # non-parameterized matmuls will behave as normal torch.matmul ops unless
+        # Quantization-Aware-Training is invoked
+        self.attention_scores_matmul = QATAttentionScores()
+        self.context_layer_matmul = QATContextLayer()

     def prune_heads(self, heads: List[int]):
         attention_head_size = self.dim // self.n_heads
         if len(heads) == 0:
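
A related sketch (not part of the diff; it assumes this patched fork of transformers is installed, and the checkpoint name is only an example): because every wrapper sets wrap_qat, the quantizable matmul sites in a built model can be discovered generically, without matching on class names.

from transformers import AutoModel

model = AutoModel.from_pretrained("distilbert-base-uncased")  # example checkpoint

# Collect every module that advertises itself as a QAT-wrappable matmul.
qat_sites = {
    name: type(module).__name__
    for name, module in model.named_modules()
    if getattr(module, "wrap_qat", False)
}
print(qat_sites)
# Expected (assumption) to include entries such as
# 'transformer.layer.0.attention.attention_scores_matmul': 'QATAttentionScores'
# 'transformer.layer.0.attention.context_layer_matmul': 'QATContextLayer'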
@@ -207,7 +244,7 @@ def unshape(x: torch.Tensor) -> torch.Tensor:
         v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)

         q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
-        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
+        scores = self.attention_scores_matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
         mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)
         scores = scores.masked_fill(
             mask, torch.tensor(torch.finfo(scores.dtype).min)
@@ -220,7 +257,7 @@ def unshape(x: torch.Tensor) -> torch.Tensor:
         if head_mask is not None:
             weights = weights * head_mask

-        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
+        context = self.context_layer_matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
         context = unshape(context)  # (bs, q_length, dim)
         context = self.out_lin(context)  # (bs, q_length, dim)
@@ -645,7 +682,6 @@ def forward(
             loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         dlbrt_output = self.distilbert(
             input_ids=input_ids,
             attention_mask=attention_mask,

Review comment (on the removed blank line above): not sure if i like this better but thought i'd add