4 changes: 2 additions & 2 deletions src/transformers/generation/candidate_generator.py
@@ -524,7 +524,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
self.assistant_kwargs.pop("attention_mask", None)

assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
- new_target_ids = self._process_assistant_outputs(input_ids, assistant_output.sequences, assistant_input_ids)
+ new_target_ids = self._process_assistant_outputs(input_ids, assistant_output.sequences)

# Update state
self.prev_target_ids_len = input_ids.shape[1]
@@ -583,7 +583,7 @@ def _prepare_assistant_input_ids(self, input_ids: torch.LongTensor) -> tuple[tor
return assistant_input_ids, remove_from_pkv

def _process_assistant_outputs(
- self, input_ids: torch.LongTensor, assistant_sequences: torch.LongTensor, assistant_input_ids: torch.LongTensor
+ self, input_ids: torch.LongTensor, assistant_sequences: torch.LongTensor
) -> torch.LongTensor:
"""Processes assistant outputs to obtain target input IDs."""
num_prev_assistant = self.prev_assistant_ids.shape[1]
4 changes: 1 addition & 3 deletions src/transformers/modeling_utils.py
@@ -1412,7 +1412,6 @@ def _find_missing_and_unexpected_keys(
checkpoint_keys: list[str],
loading_base_model_from_task_state_dict: bool,
hf_quantizer: Optional[HfQuantizer],
- device_map: dict,
) -> tuple[list[str], list[str]]:
"""Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys
(keys found in the loaded state dict keys, but that are NOT part of the model parameters)
@@ -2713,7 +2712,7 @@ def _check_and_adjust_attn_implementation(
try:
self._sdpa_can_dispatch(is_init_check)
applicable_attn_implementation = "sdpa"
- except (ValueError, ImportError) as e:
+ except (ValueError, ImportError):
applicable_attn_implementation = "eager"
else:
applicable_attn_implementation = self.get_correct_attn_implementation(
@@ -5318,7 +5317,6 @@ def _load_pretrained_model(
checkpoint_keys,
loading_base_model_from_task_state_dict,
hf_quantizer,
- device_map,
)
# Find all the keys with shape mismatch (if we ignore the mismatch, the weights need to be newly initialized the
# same way as missing keys)
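For readers unfamiliar with the helper being slimmed down here: its docstring above describes missing keys and unexpected keys as two complementary comparisons between the model's parameter names and the checkpoint's keys. A minimal sketch of that idea in isolation (hypothetical function name; it ignores the quantizer and base-model handling the real helper performs):

def find_missing_and_unexpected(model_keys: list[str], checkpoint_keys: list[str]) -> tuple[list[str], list[str]]:
    checkpoint_set = set(checkpoint_keys)
    model_set = set(model_keys)
    # Keys the model defines but the checkpoint lacks -> these weights get newly initialized.
    missing = [k for k in model_keys if k not in checkpoint_set]
    # Keys the checkpoint carries but the model has no parameter for -> these are not loaded.
    unexpected = [k for k in checkpoint_keys if k not in model_set]
    return missing, unexpected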
2 changes: 1 addition & 1 deletion src/transformers/models/sew_d/modeling_sew_d.py
@@ -509,7 +509,7 @@ def forward(ctx, input, mask, dim):
@staticmethod
def backward(ctx, grad_output):
(output,) = ctx.saved_tensors
- inputGrad = softmax_backward_data(ctx, grad_output, output, ctx.dim, output)
+ inputGrad = softmax_backward_data(ctx, grad_output, output)
return inputGrad, None, None

@staticmethod
4 changes: 2 additions & 2 deletions src/transformers/pipelines/fill_mask.py
@@ -163,7 +163,7 @@ def postprocess(self, model_outputs, top_k=5, target_ids=None):
return result[0]
return result

- def get_target_ids(self, targets, top_k=None):
+ def get_target_ids(self, targets):
if isinstance(targets, str):
targets = [targets]
try:
@@ -213,7 +213,7 @@ def _sanitize_parameters(self, top_k=None, targets=None, tokenizer_kwargs=None):
postprocess_params = {}

if targets is not None:
- target_ids = self.get_target_ids(targets, top_k)
+ target_ids = self.get_target_ids(targets)
postprocess_params["target_ids"] = target_ids

if top_k is not None:
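Since `get_target_ids` no longer receives `top_k`, its job reduces to resolving target words to vocabulary ids, which `postprocess` then uses to restrict scoring. A rough, hypothetical sketch of that lookup (not the pipeline's actual implementation; it assumes a tokenizer exposing `get_vocab()` and `encode()`):

import numpy as np

def get_target_ids_sketch(tokenizer, targets):
    if isinstance(targets, str):
        targets = [targets]
    vocab = tokenizer.get_vocab()
    ids = []
    for target in targets:
        if target in vocab:
            ids.append(vocab[target])
        else:
            # Fall back to the first sub-token when the word is not a single vocab entry.
            ids.append(tokenizer.encode(target, add_special_tokens=False)[0])
    # Deduplicate while preserving order; postprocess() later scores only these ids.
    return np.array(list(dict.fromkeys(ids)))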
2 changes: 1 addition & 1 deletion src/transformers/pipelines/table_question_answering.py
@@ -306,7 +306,7 @@ def _sanitize_parameters(self, sequential=None, padding=None, truncation=None, *

return preprocess_params, forward_params, {}

- def preprocess(self, pipeline_input, sequential=None, padding=True, truncation=None):
+ def preprocess(self, pipeline_input, padding=True, truncation=None):
if truncation is None:
if self.type == "tapas":
truncation = "drop_rows_to_fit"
4 changes: 2 additions & 2 deletions src/transformers/pytorch_utils.py
@@ -50,15 +50,15 @@
_torch_distributed_available = torch.distributed.is_available()


- def softmax_backward_data(parent, grad_output, output, dim, self):
+ def softmax_backward_data(parent, grad_output, output):
"""
A function that calls the internal `_softmax_backward_data` PyTorch method and that adjusts the arguments according
to the torch version detected.
"""

from torch import _softmax_backward_data

- return _softmax_backward_data(grad_output, output, parent.dim, self.dtype)
+ return _softmax_backward_data(grad_output, output, parent.dim, output.dtype)

Comment on lines -53 to 62

Member:
Can we maybe just get rid of this entirely? It seems like it doesn't do anything anymore, and I don't see anywhere else in the library that's using it.

Contributor (Author):
grep softmax_backward_data -r src/
src/transformers/models/sew_d/modeling_sew_d.py:from ...pytorch_utils import softmax_backward_data
src/transformers/models/sew_d/modeling_sew_d.py:        inputGrad = softmax_backward_data(ctx, grad_output, output)
src/transformers/pytorch_utils.py:def softmax_backward_data(parent, grad_output, output):
src/transformers/pytorch_utils.py:    A function that calls the internal `_softmax_backward_data` PyTorch method and that adjusts the arguments according
src/transformers/pytorch_utils.py:    from torch import _softmax_backward_data
src/transformers/pytorch_utils.py:    return _softmax_backward_data(grad_output, output, parent.dim, output.dtype)

It is in use.

Contributor (Author):
In the sew_d case it is used in a manually implemented attention. I can replace it with SDPA, but that change is better left for another PR.

Member:
Yeah, I missed that one somehow! Let's leave it, since it's actually in use.


def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
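As the discussion above settles, the two-argument wrapper stays because sew_d's hand-written masked softmax calls it from a custom autograd Function. A minimal usage sketch of that pattern, written against the simplified signature in this diff (an illustration, not the model's actual code):

import torch
from transformers.pytorch_utils import softmax_backward_data

class MaskedSoftmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, mask, dim):
        ctx.dim = dim
        # Positions where `mask` is True are pushed to the dtype's minimum before softmax.
        output = x.masked_fill(mask, torch.finfo(x.dtype).min).softmax(dim)
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (output,) = ctx.saved_tensors
        # `ctx` supplies the softmax dim; the saved output tensor supplies the dtype.
        return softmax_backward_data(ctx, grad_output, output), None, None

Calling MaskedSoftmax.apply(scores, mask, -1) then behaves like a softmax whose masked positions receive near-zero probability, with the backward pass routed through the PyTorch internal.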
4 changes: 2 additions & 2 deletions src/transformers/trainer_pt_utils.py
@@ -929,7 +929,7 @@ def _secs2timedelta(secs):
return f"{datetime.timedelta(seconds=int(secs))}.{msec:02d}"


- def metrics_format(self, metrics: dict[str, float]) -> dict[str, float]:
+ def metrics_format(metrics: dict[str, float]) -> dict[str, float]:
"""
Reformat Trainer metrics values to a human-readable format.

@@ -1038,7 +1038,7 @@ def log_metrics(self, split, metrics):
return

print(f"***** {split} metrics *****")
- metrics_formatted = self.metrics_format(metrics)
+ metrics_formatted = metrics_format(metrics)
k_width = max(len(str(x)) for x in metrics_formatted)
v_width = max(len(str(x)) for x in metrics_formatted.values())
for key in sorted(metrics_formatted.keys()):
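With `metrics_format` demoted to a module-level function, `log_metrics` calls it without `self`, as the hunk above shows. Roughly what the formatting amounts to, followed by the aligned printing in the spirit of `log_metrics` (a sketch with assumed rules, not the library's exact ones):

import datetime

def metrics_format_sketch(metrics: dict[str, float]) -> dict:
    formatted = {}
    for key, value in metrics.items():
        if "_runtime" in key:
            # Mirrors _secs2timedelta above: seconds become a human-readable duration.
            formatted[key] = str(datetime.timedelta(seconds=int(value)))
        elif isinstance(value, float):
            formatted[key] = round(value, 4)
        else:
            formatted[key] = value
    return formatted

metrics = {"eval_loss": 1.234567, "eval_runtime": 75.2}
formatted = metrics_format_sketch(metrics)
k_width = max(len(str(k)) for k in formatted)
v_width = max(len(str(v)) for v in formatted.values())
for key in sorted(formatted):
    print(f"  {key: <{k_width}} = {str(formatted[key]):>{v_width}}")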
15 changes: 5 additions & 10 deletions src/transformers/utils/auto_docstring.py
@@ -1227,7 +1227,7 @@ def get_checkpoint_from_config_class(config_class):
return checkpoint


- def add_intro_docstring(func, class_name, parent_class=None, indent_level=0):
+ def add_intro_docstring(func, class_name, indent_level=0):
intro_docstring = ""
if func.__name__ == "forward":
intro_docstring = rf"""The [`{class_name}`] forward method, overrides the `__call__` special method.
@@ -1469,17 +1469,14 @@ def find_sig_line(lines, line_end):
return sig_line_end


- def _process_kwargs_parameters(
- sig, func, parent_class, model_name_lowercase, documented_kwargs, indent_level, undocumented_parameters
- ):
+ def _process_kwargs_parameters(sig, func, parent_class, documented_kwargs, indent_level, undocumented_parameters):
"""
Process **kwargs parameters if needed.

Args:
sig (`inspect.Signature`): Function signature
func (`function`): Function the parameters belong to
parent_class (`class`): Parent class of the function
- model_name_lowercase (`str`): Lowercase model name
documented_kwargs (`dict`): Dictionary of kwargs that are already documented
indent_level (`int`): Indentation level
undocumented_parameters (`list`): List to append undocumented parameters to
@@ -1510,7 +1507,7 @@ def _process_kwargs_parameters(
# Extract documentation for kwargs
kwargs_documentation = kwarg_param.annotation.__args__[0].__doc__
if kwargs_documentation is not None:
- documented_kwargs, _ = parse_docstring(kwargs_documentation)
+ documented_kwargs = parse_docstring(kwargs_documentation)[0]

# Process each kwarg parameter
for param_name, param_type_annotation in kwarg_param.annotation.__args__[0].__annotations__.items():
@@ -1597,7 +1594,7 @@ def _process_parameters_section(

# Process **kwargs parameters if needed
kwargs_docstring = _process_kwargs_parameters(
- sig, func, parent_class, model_name_lowercase, documented_kwargs, indent_level, undocumented_parameters
+ sig, func, parent_class, documented_kwargs, indent_level, undocumented_parameters
)
docstring += kwargs_docstring

@@ -1757,9 +1754,7 @@ def auto_method_docstring(
if not docstring.strip().endswith("\n"):
docstring += "\n"
else:
- docstring = add_intro_docstring(
- func, class_name=class_name, parent_class=parent_class, indent_level=indent_level
- )
+ docstring = add_intro_docstring(func, class_name=class_name, indent_level=indent_level)

# Process Parameters section
docstring += _process_parameters_section(
2 changes: 1 addition & 1 deletion src/transformers/utils/import_utils.py
@@ -762,7 +762,7 @@ def is_torch_npu_available(check_device=False) -> bool:


@lru_cache
- def is_torch_mlu_available(check_device=False) -> bool:
+ def is_torch_mlu_available() -> bool:
"""
Checks if `mlu` is available via an `cndev-based` check which won't trigger the drivers and leave mlu
uninitialized.
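A side note on why an unused flag like `check_device` is worth deleting rather than just ignoring (a practical observation, not necessarily this PR's stated motivation): under `@lru_cache`, each distinct argument combination gets its own cache entry, so a vestigial parameter can fragment the cache. A tiny illustration with a hypothetical stand-in function:

from functools import lru_cache

@lru_cache
def backend_available() -> bool:  # hypothetical stand-in for is_torch_mlu_available
    print("probing hardware")     # with no arguments, this runs exactly once
    return False

backend_available()
backend_available()               # second call is served from the cache; nothing is printed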
5 changes: 1 addition & 4 deletions src/transformers/video_utils.py
@@ -714,7 +714,6 @@ def sample_indices_fn_func(metadata, **fn_kwargs):

def convert_to_rgb(
video: np.ndarray,
- data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
@@ -723,15 +722,13 @@ def convert_to_rgb(
Args:
video (`np.array`):
The video to convert.
- data_format (`ChannelDimension`, *optional*):
- The channel dimension format of the output video. If unset, will use the inferred format from the input.
input_data_format (`ChannelDimension`, *optional*):
The channel dimension format of the input video. If unset, will use the inferred format from the input.
"""
if not isinstance(video, np.ndarray):
raise TypeError(f"Video has to be a numpy array to convert to RGB format, but found {type(video)}")

- # np.array usually comes with ChannelDimension.LAST so leet's convert it
+ # np.array usually comes with ChannelDimension.LAST so let's convert it
if input_data_format is None:
input_data_format = infer_channel_dimension_format(video)
video = to_channel_dimension_format(video, ChannelDimension.FIRST, input_channel_dim=input_data_format)
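The surviving `input_data_format` logic above infers the channel layout when the caller doesn't provide it and then moves the video to channels-first. An illustrative sketch of that inference and conversion (assumed shapes and a simplified heuristic, not the library's `infer_channel_dimension_format` implementation):

import numpy as np

def infer_channel_dimension_sketch(video: np.ndarray) -> str:
    # A trailing axis of size 1, 3 or 4 is taken to be the channel axis (channels-last);
    # otherwise the array is assumed to already be channels-first.
    return "channels_last" if video.shape[-1] in (1, 3, 4) else "channels_first"

video = np.zeros((8, 224, 224, 3), dtype=np.uint8)   # (num_frames, height, width, channels)
if infer_channel_dimension_sketch(video) == "channels_last":
    video = np.moveaxis(video, -1, 1)                # -> (num_frames, channels, height, width)
assert video.shape == (8, 3, 224, 224)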