This repository was archived by the owner on Jun 4, 2025. It is now read-only.
Merged
47 changes: 43 additions & 4 deletions src/transformers/hf_argparser.py
@@ -21,9 +21,16 @@
from inspect import isclass
from pathlib import Path
from typing import Any, Dict, Iterable, NewType, Optional, Tuple, Union, get_type_hints

import os
import yaml

from sparsezoo import Model

from .utils.logging import get_logger


logger = get_logger(__name__)


DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)
@@ -229,12 +236,17 @@ def parse_args_into_dataclasses(
# additional namespace.
outputs.append(namespace)
if return_remaining_strings:
return (*outputs, remaining_args)
return (
*[_download_dataclass_zoo_stub_files(output) for output in outputs],
remaining_args,
)
else:
if remaining_args:
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")

return (*outputs,)
return tuple(
[_download_dataclass_zoo_stub_files(output) for output in outputs]
)

def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
"""
@@ -262,7 +274,9 @@ def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tu
outputs.append(obj)
if not allow_extra_keys and unused_keys:
raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
return tuple(outputs)
return tuple(
[_download_dataclass_zoo_stub_files(output) for output in outputs]
)
Comment on lines +277 to +279

not sure if i like this better but thought i'd add

Suggested change
return tuple(
[_download_dataclass_zoo_stub_files(output) for output in outputs]
)
return tuple(map(_download_dataclass_zoo_stub_files, outputs))


def parse_json_file(self, json_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
"""
Expand Down Expand Up @@ -305,3 +319,28 @@ def parse_yaml_file(self, yaml_file: str, allow_extra_keys: bool = False) -> Tup
"""
outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys)
return tuple(outputs)

def _download_dataclass_zoo_stub_files(data_class: DataClass):
for name, val in data_class.__dict__.items():
if not isinstance(val, str) or "recipe" in name or not val.startswith("zoo:"):
continue

logger.info(f"Downloading framework files for SparseZoo stub: {val}")

zoo_model = Model(val)
framework_file_paths = [file.path for file in zoo_model.training.default.files]
assert framework_file_paths, f"Unable to download any framework files for SparseZoo stub {val}"
framework_file_names = [os.path.basename(path) for path in framework_file_paths]
if "pytorch_model.bin" not in framework_file_names or ("config.json" not in framework_file_names):
raise RuntimeError(
"Unable to find 'pytorch_model.bin' and 'config.json' in framework "
f"files downloaded from {val}. Found {framework_file_names}. Check "
"if the given stub is for a transformers repo model"
)
framework_dir_path = Path(framework_file_paths[0]).parent.absolute()

Will all the framework_file_paths have the same parent? Is that why we can just use the 1st one?


logger.info(f"Overwriting argument {name} to downloaded {framework_dir_path}")

data_class.__dict__[name] = str(framework_dir_path)

return data_class
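
A minimal usage sketch of the helper above (the dataclass, field name, and stub string are hypothetical; sparsezoo must be installed): any string field that starts with "zoo:" and does not have "recipe" in its name is rewritten to the local directory holding the downloaded pytorch_model.bin and config.json.

from dataclasses import dataclass

@dataclass
class ModelArguments:
    # hypothetical field; treated as a SparseZoo stub because it starts with "zoo:"
    model_name_or_path: str = "zoo:some/transformers/model/stub"

args = _download_dataclass_zoo_stub_files(ModelArguments())
# args.model_name_or_path now points at the local framework-files directory,
# e.g. <sparsezoo cache>/<model>/training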
25 changes: 23 additions & 2 deletions src/transformers/models/bert/modeling_bert.py
@@ -240,6 +240,22 @@ def forward(
return embeddings


class QATMatMul(nn.Module):
def __init__(self):
super().__init__()

# behaves like normal torch.matmul unless a SparseML QuantizationModifier
# is initialized
self.wrap_qat = True
self.qat_wrapper_kwargs = {
"num_inputs": 2,
"input_qconfigs": ["asymmetric", "symmetric"],
}

def forward(self, a: torch.Tensor, b: torch.Tensor):
return torch.matmul(a, b)
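
A quick sanity sketch (outside the diff, assuming no SparseML QuantizationModifier has been applied): until a modifier wraps the module, QATMatMul is numerically identical to torch.matmul, and wrap_qat / qat_wrapper_kwargs are inert attributes.

import torch

qat_matmul = QATMatMul()
a = torch.randn(2, 4, 8)
b = torch.randn(2, 8, 4)
# with no quantization modifier initialized this is exactly torch.matmul
assert torch.allclose(qat_matmul(a, b), torch.matmul(a, b))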


class BertSelfAttention(nn.Module):

@corey-nm corey-nm Oct 14, 2022

Yeah I'm almost positive we could hot patch this stuff in instead of having repo forks.

Something like

import transformers

from transformers.models.bert.modeling_bert import BertSelfAttention as _BertSelfAttention

class PatchedBertSelfAttention(_BertSelfAttention):
    def __init__(self, *args, **kwargs):
        self.attention_scores_matmul = ...
    def forward(self):
        ...

transformers.models.bert.modeling_bert.BertSelfAttention = PatchedBertSelfAttention

Author

@KSGulin KSGulin Oct 14, 2022

Yeah I think now that most of the implementation has been moved to the sparseml side, there's definitely potential to explore here. One thing to keep in mind is that every time we've upgraded, the HF repo had changes that broke our integration, so the hot patch would need to be easy to debug and amend. But in general I'm all for this
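
A slightly fuller sketch of the hot-patch idea discussed above (hypothetical, not part of this PR; it assumes QATMatMul is importable and that the patch is applied before any model is instantiated):

import transformers
from transformers.models.bert.modeling_bert import BertSelfAttention as _BertSelfAttention

class PatchedBertSelfAttention(_BertSelfAttention):
    def __init__(self, config, position_embedding_type=None):
        super().__init__(config, position_embedding_type=position_embedding_type)
        # swap in QAT-wrappable matmul modules after upstream init
        self.attention_scores_matmul = QATMatMul()
        self.context_layer_matmul = QATMatMul()

    # forward() would also need to be overridden to route both matmuls through
    # the modules above, mirroring the changes in this diff (omitted here)

transformers.models.bert.modeling_bert.BertSelfAttention = PatchedBertSelfAttention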

def __init__(self, config, position_embedding_type=None):
super().__init__()
@@ -257,6 +273,11 @@ def __init__(self, config, position_embedding_type=None):
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)

# non-parameterized matmuls will behave as normal torch.matmul ops unless
# Quantization-Aware-Training is invoked
self.attention_scores_matmul = QATMatMul()
self.context_layer_matmul = QATMatMul()

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
@@ -320,7 +341,7 @@ def forward(
past_key_value = (key_layer, value_layer)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = self.attention_scores_matmul(query_layer, key_layer.transpose(-1, -2))

if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
seq_length = hidden_states.size()[1]
@@ -354,7 +375,7 @@ def forward(
if head_mask is not None:
attention_probs = attention_probs * head_mask

context_layer = torch.matmul(attention_probs, value_layer)
context_layer = self.context_layer_matmul(attention_probs, value_layer)

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
42 changes: 39 additions & 3 deletions src/transformers/models/distilbert/modeling_distilbert.py
@@ -89,6 +89,38 @@ def _create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):
out.detach_()


class QATAttentionScores(nn.Module):
def __init__(self):
super().__init__()

# behaves like normal torch.matmul unless a SparseML QuantizationModifier
# is initialized
self.wrap_qat = True
self.qat_wrapper_kwargs = {
"num_inputs": 2,
"input_qconfigs": ["asymmetric", "symmetric"],
}

def forward(self, a: torch.Tensor, b: torch.Tensor):
return torch.matmul(a, b)

class QATContextLayer(nn.Module):
def __init__(self):
super().__init__()

# behaves like normal torch.matmul unless a SparseML QuantizationModifier
# is initialized
self.wrap_qat = True
self.qat_wrapper_kwargs = {
"num_inputs": 2,
"num_outputs": 0,

Is this supposed to be here? Don't see it in others

"input_qconfigs": ["asymmetric", "symmetric"],
}

def forward(self, a: torch.Tensor, b: torch.Tensor):
return torch.matmul(a, b)


class Embeddings(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
@@ -150,6 +182,11 @@ def __init__(self, config: PretrainedConfig):

self.pruned_heads: Set[int] = set()

# non-parameterized matmuls will behave as normal torch.matmul ops unless
# Quantization-Aware-Training is invoked
self.attention_scores_matmul = QATAttentionScores()
self.context_layer_matmul = QATContextLayer()

def prune_heads(self, heads: List[int]):
attention_head_size = self.dim // self.n_heads
if len(heads) == 0:
@@ -207,7 +244,7 @@ def unshape(x: torch.Tensor) -> torch.Tensor:
v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head)

q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head)
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)
scores = self.attention_scores_matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)
mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length)
scores = scores.masked_fill(
mask, torch.tensor(torch.finfo(scores.dtype).min)
@@ -220,7 +257,7 @@ def unshape(x: torch.Tensor) -> torch.Tensor:
if head_mask is not None:
weights = weights * head_mask

context = torch.matmul(weights, v) # (bs, n_heads, q_length, dim_per_head)
context = self.context_layer_matmul(weights, v) # (bs, n_heads, q_length, dim_per_head)
context = unshape(context) # (bs, q_length, dim)
context = self.out_lin(context) # (bs, q_length, dim)

@@ -645,7 +682,6 @@ def forward(
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

dlbrt_output = self.distilbert(
input_ids=input_ids,
attention_mask=attention_mask,
19 changes: 18 additions & 1 deletion src/transformers/models/mobilebert/modeling_mobilebert.py
@@ -170,6 +170,23 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:

NORM2FN = {"layer_norm": nn.LayerNorm, "no_norm": NoNorm}

class QATEmbeddingTransformation(nn.Module):
def __init__(self, embedded_input_size, hidden_size):
super().__init__()

# Behaves like normal Linear module unless a SparseML QuantizationModifier
# is initialized.
# When initialized, does not quantize inputs.
# Only weights are quantized (inputs come quantized from embeddings)
self.linear = nn.Linear(embedded_input_size, hidden_size)
self.wrap_qat = True
self.qat_wrapper_kwargs = {
"num_inputs": 0,
"num_outputs": 1,
}

def forward(self, x: torch.Tensor):
return self.linear(x)
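
As with the matmul wrappers, a small sanity sketch (illustrative sizes, assuming QATEmbeddingTransformation is importable and no quantization modifier has been applied): the module behaves exactly like the nn.Linear it wraps.

import torch

embed = QATEmbeddingTransformation(embedded_input_size=384, hidden_size=512)
x = torch.randn(2, 128, 384)  # e.g. embedding_size 128 * 3 for trigram input
assert embed(x).shape == (2, 128, 512)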

class MobileBertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
@@ -186,7 +203,7 @@ def __init__(self, config):

embed_dim_multiplier = 3 if self.trigram_input else 1
embedded_input_size = self.embedding_size * embed_dim_multiplier
self.embedding_transformation = nn.Linear(embedded_input_size, config.hidden_size)
self.embedding_transformation = QATEmbeddingTransformation(embedded_input_size, config.hidden_size)

self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
38 changes: 29 additions & 9 deletions src/transformers/trainer.py
@@ -1687,6 +1687,10 @@ def _inner_training_loop(
_ = list(train_dataloader.sampler)

for epoch in range(epochs_trained, num_train_epochs):
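# if quantization-aware training (QAT) becomes active this epoch, disable the
# fp16 grad scaler; the autocast contexts below key off scaler.is_enabled()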
if self.use_cuda_amp and hasattr(self, "qat_active") and callable(self.qat_active) and self.qat_active(epoch):
logger.info("entering QAT phase, disabling FP16 training")
self.scaler._enabled = False

if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
train_dataloader.sampler.set_epoch(epoch)
elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard):
@@ -2167,7 +2171,12 @@ def _save_checkpoint(self, model, trial, metrics=None):
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))

# Determine the new best metric / best model checkpoint
if metrics is not None and self.args.metric_for_best_model is not None:
if (
metrics is not None
and self.args.metric_for_best_model is not None
and self.args.best_model_after_epoch is not None
and self.state.epoch > self.args.best_model_after_epoch
):
metric_to_check = self.args.metric_for_best_model
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
@@ -2421,14 +2430,14 @@ def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[s

return inputs

def compute_loss_context_manager(self):
def compute_loss_context_manager(self, enabled):
"""
A helper wrapper to group together context managers.
"""
return ContextManagers(
[
self.torchdynamo_smart_context_manager(),
self.autocast_smart_context_manager(),
self.autocast_smart_context_manager(enabled=enabled),
]
)

@@ -2438,7 +2447,7 @@ def torchdynamo_smart_context_manager(self):
"""
return self.ctx_manager_torchdynamo

def autocast_smart_context_manager(self):
def autocast_smart_context_manager(self, enabled):
"""
A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
arguments, depending on the situation.
@@ -2448,10 +2457,10 @@ def autocast_smart_context_manager(self):
ctx_manager = (
torch.cpu.amp.autocast(dtype=self.amp_dtype)
if self.use_cpu_amp
else torch.cuda.amp.autocast(dtype=self.amp_dtype)
else torch.cuda.amp.autocast(dtype=self.amp_dtype, enabled=enabled)
)
else:
ctx_manager = torch.cuda.amp.autocast()
ctx_manager = torch.cuda.amp.autocast(enabled=enabled)
else:
ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()
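
A minimal sketch of why the new enabled flag matters (assumed torch behavior, outside this diff): torch.cuda.amp.autocast(enabled=False) is a no-op context, so threading scaler.is_enabled() through compute_loss_context_manager lets the QAT phase, which disables the scaler in _inner_training_loop, also run its forward passes in full precision.

import torch

scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
scaler._enabled = False  # mirrors "entering QAT phase, disabling FP16 training"
with torch.cuda.amp.autocast(enabled=scaler.is_enabled()):
    pass  # a forward pass here would run in full precision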

@@ -2482,7 +2491,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
return loss_mb.reduce_mean().detach().to(self.args.device)

with self.compute_loss_context_manager():
with self.compute_loss_context_manager(enabled=hasattr(self, "scaler") and self.scaler.is_enabled()):
loss = self.compute_loss(model, inputs)

if self.args.n_gpu > 1:
@@ -2939,7 +2948,14 @@ def evaluation_loop(

observed_num_examples = 0
# Main evaluation loop
module_forward_fn = model.module.forward if isinstance(model, nn.DataParallel) else model.forward
for step, inputs in enumerate(dataloader):
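# keep only the keys that the model's forward() accepts, so extra dataset
# columns do not raise unexpected-keyword errors during evaluation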
inputs = {
k: inputs[k]
for k in inputs
if k in list(inspect.signature(module_forward_fn).parameters.keys())
}

# Update the observed num examples
observed_batch_size = find_batch_size(inputs)
if observed_batch_size is not None:
@@ -3191,7 +3207,9 @@ def prediction_step(
logits = smp_nested_concat(logits_mb)
else:
if has_labels:
with self.compute_loss_context_manager():
with self.compute_loss_context_manager(
enabled=hasattr(self, "scaler") and self.scaler.is_enabled()
):
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
loss = loss.mean().detach()

Expand All @@ -3201,7 +3219,9 @@ def prediction_step(
logits = outputs[1:]
else:
loss = None
with self.compute_loss_context_manager():
with self.compute_loss_context_manager(
enabled=hasattr(self, "scaler") and self.scaler.is_enabled()
):
outputs = model(**inputs)
if isinstance(outputs, dict):
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
2 changes: 1 addition & 1 deletion src/transformers/trainer_seq2seq.py
@@ -208,7 +208,7 @@ def prediction_step(
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)

with torch.no_grad():
with self.compute_loss_context_manager():
with self.compute_loss_context_manager(enabled=hasattr(self, "scaler") and self.scaler.is_enabled()):
outputs = model(**inputs)
if has_labels:
if self.label_smoother is not None:
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -165,6 +165,7 @@
CONFIG_NAME = "config.json"
FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
MODEL_CARD_NAME = "modelcard.json"
RECIPE_NAME = "recipe.yaml"

SENTENCEPIECE_UNDERLINE = "▁"
SPIECE_UNDERLINE = SENTENCEPIECE_UNDERLINE # Kept for backward compatibility
3 changes: 3 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -1016,6 +1016,9 @@ class _LazyModule(ModuleType):
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""

# flag to signal NM integration is active
NM_INTEGRATED = True

# Very heavily inspired by optuna.integration._IntegrationModule
# https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):