diff --git a/.circleci/config.yml b/.circleci/config.yml index 7875cdc368f5..ab63a3823c2f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -156,6 +156,7 @@ jobs: path: ~/transformers/installed.txt - run: ruff check examples tests src utils scripts benchmark benchmark_v2 setup.py conftest.py - run: ruff format --check examples tests src utils scripts benchmark benchmark_v2 setup.py conftest.py + - run: ty check src/transformers/utils/*.py --force-exclude --exclude '**/*_pb2*.py' - run: python utils/custom_init_isort.py --check_only - run: python utils/sort_auto_mappings.py --check_only diff --git a/Makefile b/Makefile index 8b3b4dc2acba..ba78e2a4d461 100644 --- a/Makefile +++ b/Makefile @@ -6,6 +6,11 @@ export PYTHONPATH = src check_dirs := examples tests src utils scripts benchmark benchmark_v2 exclude_folders := "" +# Helper to find all Python files in directories (ty doesn't recursively scan directories) +define get_py_files +$(shell find $(1) -name "*.py" -type f 2>/dev/null) +endef + # this runs all linting/formatting scripts, most notably ruff style: @@ -20,6 +25,7 @@ style: check-repo: ruff check $(check_dirs) setup.py conftest.py ruff format --check $(check_dirs) setup.py conftest.py + ty check $(call get_py_files,src/transformers/utils) --force-exclude --exclude '**/*_pb2*.py' -python utils/custom_init_isort.py --check_only -python utils/sort_auto_mappings.py --check_only -python -c "from transformers import *" || (echo '🚨 import failed, this means you introduced unprotected imports! 🚨'; exit 1) diff --git a/docker/quality.dockerfile b/docker/quality.dockerfile index 6455a27d642b..97987b0d098d 100644 --- a/docker/quality.dockerfile +++ b/docker/quality.dockerfile @@ -5,5 +5,5 @@ USER root RUN apt-get update && apt-get install -y time git ENV UV_PYTHON=/usr/local/bin/python RUN pip install uv -RUN uv pip install --no-cache-dir -U pip setuptools GitPython "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[ruff]" urllib3 +RUN uv pip install --no-cache-dir -U pip setuptools GitPython "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[quality]" urllib3 RUN apt-get install -y jq curl && apt-get clean && rm -rf /var/lib/apt/lists/* diff --git a/pyproject.toml b/pyproject.toml index 2705851dd49a..c138b905cd21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,3 +92,23 @@ env = [ # Note: 'D:' means default value from laptop or CI won't be overwritten "D:HF_HUB_DOWNLOAD_TIMEOUT=60", ] + +[tool.ty] +# ty type checker configuration +# Using default settings for comprehensive type checking + +[tool.ty.rules] +# Disable specific rules that produce false positives or are too strict for this codebase +invalid-method-override = "ignore" # Parameter name differences are acceptable (e.g., x vs input, new_embeddings vs value) +not-subscriptable = "ignore" # False positives on tensor slicing (e.g., self.position_ids[:, :seq_length]) +no-matching-overload = "ignore" # False positives on torch.zeros and similar functions accepting Size/tuple +unsupported-operator = "ignore" # False positives on tuple concatenation with += when properly initialized +unresolved-import = "ignore" # Optional dependencies (mlx, torch_npu, habana_frameworks, etc.) 
checked at runtime +call-non-callable = "ignore" # Mixin pattern issues where classes are used as both types and callables +unresolved-reference = "ignore" # Forward references with noqa: F821 that ty doesn't respect +invalid-argument-type = "ignore" # Complex type narrowing and union type issues +not-iterable = "ignore" # Complex async/Future type patterns +invalid-return-type = "ignore" # Return type mismatches that would require refactoring +deprecated = "ignore" # Deprecation warnings from dependencies +invalid-assignment = "ignore" # Low-level assignments that are runtime-safe +unused-ignore-comment = "ignore" # Ignore comments that became unnecessary after adding broader per-file-ignores diff --git a/setup.py b/setup.py index 616e16e67f2e..f0b41cb3dc3c 100644 --- a/setup.py +++ b/setup.py @@ -126,6 +126,7 @@ "rjieba", "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1", "ruff==0.14.10", + "ty==0.0.12", # `sacrebleu` not used in `transformers`. However, it is needed in several tests, when a test calls # `evaluate.load("sacrebleu")`. This metric is used in the examples that we use to test the `Trainer` with, in the # `Trainer` tests (see references to `run_translation.py`). @@ -182,7 +183,7 @@ def deps_list(*pkgs): extras["audio"] += deps_list("kenlm") extras["video"] = deps_list("av") extras["timm"] = deps_list("timm") -extras["quality"] = deps_list("datasets", "ruff", "GitPython", "urllib3", "libcst", "rich") +extras["quality"] = deps_list("datasets", "ruff", "GitPython", "urllib3", "libcst", "rich", "ty") extras["kernels"] = deps_list("kernels") extras["sentencepiece"] = deps_list("sentencepiece", "protobuf") extras["tiktoken"] = deps_list("tiktoken", "blobfile") diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 1b38c0a702d3..a68a55c869ea 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -58,6 +58,7 @@ "rjieba": "rjieba", "rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1", "ruff": "ruff==0.14.10", + "ty": "ty==0.0.12", "sacrebleu": "sacrebleu>=1.4.12,<2.0.0", "sacremoses": "sacremoses", "safetensors": "safetensors>=0.4.3", diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index aabfd0bbe268..4f065ab7a748 100755 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -35,7 +35,7 @@ import numpy as np import packaging.version -from transformers.utils.import_utils import _is_package_available +from transformers.utils.import_utils import is_pynvml_available if os.getenv("WANDB_MODE") == "offline": @@ -1030,7 +1030,7 @@ def on_log(self, args, state, control, model=None, logs=None, **kwargs): f"gpu/{device_idx}/allocated_memory": memory_allocated / (1024**3), # GB f"gpu/{device_idx}/memory_usage": memory_allocated / total_memory, # ratio } - if _is_package_available("pynvml"): + if is_pynvml_available(): power = torch.cuda.power_draw(device_idx) gpu_memory_logs[f"gpu/{device_idx}/power"] = power / 1000 # Watts if dist.is_available() and dist.is_initialized(): diff --git a/src/transformers/models/dia/convert_dia_to_hf.py b/src/transformers/models/dia/convert_dia_to_hf.py index 732e71b54e32..067f176e1404 100644 --- a/src/transformers/models/dia/convert_dia_to_hf.py +++ b/src/transformers/models/dia/convert_dia_to_hf.py @@ -30,7 +30,7 @@ DiaTokenizer, GenerationConfig, ) -from transformers.utils.import_utils import _is_package_available +from 
transformers.utils.import_utils import is_tiktoken_available # Provide just the list of layer keys you want to fix @@ -180,7 +180,7 @@ def convert_dia_model_to_hf(checkpoint_path, verbose=False): model = convert_dia_model_to_hf(args.checkpoint_path, args.verbose) if args.convert_preprocessor: try: - if not _is_package_available("tiktoken"): + if not is_tiktoken_available(with_blobfile=False): raise ModuleNotFoundError( """`tiktoken` is not installed, use `pip install tiktoken` to convert the tokenizer""" ) diff --git a/src/transformers/models/whisper/convert_openai_to_hf.py b/src/transformers/models/whisper/convert_openai_to_hf.py index b30caab9f261..9a26ddb3a0f2 100755 --- a/src/transformers/models/whisper/convert_openai_to_hf.py +++ b/src/transformers/models/whisper/convert_openai_to_hf.py @@ -38,7 +38,7 @@ WhisperTokenizerFast, ) from transformers.models.whisper.tokenization_whisper import LANGUAGES, bytes_to_unicode -from transformers.utils.import_utils import _is_package_available +from transformers.utils.import_utils import is_tiktoken_available _MODELS = { @@ -345,7 +345,7 @@ def convert_tiktoken_to_hf( if args.convert_preprocessor: try: - if not _is_package_available("tiktoken"): + if not is_tiktoken_available(with_blobfile=False): raise ModuleNotFoundError( """`tiktoken` is not installed, use `pip install tiktoken` to convert the tokenizer""" ) diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 62918bd277ed..a857beeb76ae 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -26,10 +26,10 @@ import shutil import threading import time -from collections.abc import Callable +from collections.abc import Callable, Sized from functools import partial from pathlib import Path -from typing import Any, NamedTuple +from typing import Any, NamedTuple, TypeGuard import numpy as np @@ -820,7 +820,7 @@ def stop_and_update_metrics(self, metrics=None): self.update_metrics(stage, metrics) -def has_length(dataset): +def has_length(dataset: Any) -> TypeGuard[Sized]: """ Checks if the dataset implements __len__() and it doesn't raise an error """ diff --git a/src/transformers/utils/_typing.py b/src/transformers/utils/_typing.py new file mode 100644 index 000000000000..c98703340ee1 --- /dev/null +++ b/src/transformers/utils/_typing.py @@ -0,0 +1,123 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
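+#
+# This Protocol mirrors the `logging.Logger` API plus the `warning_advice`,
+# `warning_once`, and `info_once` helpers that transformers monkey-patches onto
+# `logging.Logger`, so `get_logger` can expose them to static checkers such as ty.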
+from __future__ import annotations
+
+import logging
+from collections.abc import Mapping, MutableMapping
+from typing import Any, Protocol, TypeAlias
+
+
+# A few helpful type aliases
+Level: TypeAlias = int
+ExcInfo: TypeAlias = (
+    None
+    | bool
+    | BaseException
+    | tuple[type[BaseException], BaseException, object]  # traceback is `types.TracebackType`, but keep generic here
+)
+
+
+class TransformersLogger(Protocol):
+    # ---- Core Logger identity / configuration ----
+    name: str
+    level: int
+    parent: logging.Logger | None
+    propagate: bool
+    disabled: bool
+    handlers: list[logging.Handler]
+
+    # Mirrors the module-level `logging.raiseExceptions` flag (default True); not a standard Logger attribute.
+    raiseExceptions: bool  # type: ignore[assignment]
+
+    # ---- Standard methods ----
+    def setLevel(self, level: Level) -> None: ...
+    def isEnabledFor(self, level: Level) -> bool: ...
+    def getEffectiveLevel(self) -> int: ...
+
+    def getChild(self, suffix: str) -> logging.Logger: ...
+
+    def addHandler(self, hdlr: logging.Handler) -> None: ...
+    def removeHandler(self, hdlr: logging.Handler) -> None: ...
+    def hasHandlers(self) -> bool: ...
+
+    # ---- Logging calls ----
+    def debug(self, msg: object, *args: object, **kwargs: object) -> None: ...
+    def info(self, msg: object, *args: object, **kwargs: object) -> None: ...
+    def warning(self, msg: object, *args: object, **kwargs: object) -> None: ...
+    def warn(self, msg: object, *args: object, **kwargs: object) -> None: ...
+    def error(self, msg: object, *args: object, **kwargs: object) -> None: ...
+    def exception(self, msg: object, *args: object, exc_info: ExcInfo = True, **kwargs: object) -> None: ...
+    def critical(self, msg: object, *args: object, **kwargs: object) -> None: ...
+    def fatal(self, msg: object, *args: object, **kwargs: object) -> None: ...
+
+    # The lowest-level primitive
+    def log(self, level: Level, msg: object, *args: object, **kwargs: object) -> None: ...
+
+    # ---- Record-level / formatting ----
+    def makeRecord(
+        self,
+        name: str,
+        level: Level,
+        fn: str,
+        lno: int,
+        msg: object,
+        args: tuple[object, ...] | Mapping[str, object],
+        exc_info: ExcInfo,
+        func: str | None = None,
+        extra: Mapping[str, object] | None = None,
+        sinfo: str | None = None,
+    ) -> logging.LogRecord: ...
+
+    def handle(self, record: logging.LogRecord) -> None: ...
+    def findCaller(
+        self,
+        stack_info: bool = False,
+        stacklevel: int = 1,
+    ) -> tuple[str, int, str, str | None]: ...
+
+    def callHandlers(self, record: logging.LogRecord) -> None: ...
+    def getMessage(self) -> str: ...  # NOTE: actually defined on LogRecord, not Logger; harmless to keep on the Protocol
+
+    def _log(
+        self,
+        level: Level,
+        msg: object,
+        args: tuple[object, ...] | Mapping[str, object],
+        exc_info: ExcInfo = None,
+        extra: Mapping[str, object] | None = None,
+        stack_info: bool = False,
+        stacklevel: int = 1,
+    ) -> None: ...
+
+    # ---- Filters ----
+    def addFilter(self, filt: logging.Filter) -> None: ...
+    def removeFilter(self, filt: logging.Filter) -> None: ...
+    @property
+    def filters(self) -> list[logging.Filter]: ...
+
+    def filter(self, record: logging.LogRecord) -> bool: ...
+
+    # ---- Convenience helpers ----
+    def setFormatter(self, fmt: logging.Formatter) -> None: ...  # mostly on handlers; present on adapters sometimes
+    def debugStack(self, msg: object, *args: object, **kwargs: object) -> None: ...
# not std; safe no-op if absent + + # ---- stdlib dictConfig-friendly / extra storage ---- + # Logger has `manager` and can have arbitrary attributes; Protocol can't express arbitrary attrs, + # but we can at least include `__dict__` to make "extra attributes" less painful. + __dict__: MutableMapping[str, Any] + + # ---- Transformers logger specific methods ---- + def warning_advice(self, msg: object, *args: object, **kwargs: object) -> None: ... + def warning_once(self, msg: object, *args: object, **kwargs: object) -> None: ... + def info_once(self, msg: object, *args: object, **kwargs: object) -> None: ... diff --git a/src/transformers/utils/attention_visualizer.py b/src/transformers/utils/attention_visualizer.py index 592c283758e4..5d04ec247014 100644 --- a/src/transformers/utils/attention_visualizer.py +++ b/src/transformers/utils/attention_visualizer.py @@ -221,10 +221,14 @@ def visualize_attention_mask(self, input_sentence: str, suffix=""): past_key_values=None, ) - if causal_mask is not None: - attention_mask = ~causal_mask.bool() - else: + if causal_mask is None: + # attention_mask must be a tensor here attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).expand(batch_size, 1, seq_length, seq_length) + elif isinstance(causal_mask, torch.Tensor): + attention_mask = ~causal_mask.to(dtype=torch.bool) + else: + attention_mask = ~causal_mask + top_bottom_border = "##" * ( len(f"Attention visualization for {self.config.model_type} | {self.mapped_cls}") + 4 ) # Box width adjusted to text length diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index ed3a6daee73e..b30dfa40bf34 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -22,14 +22,7 @@ from datetime import datetime from functools import lru_cache from inspect import isfunction -from typing import ( - Any, - Literal, - Union, - get_args, - get_origin, - get_type_hints, -) +from typing import Any, Literal, Union, get_args, get_origin, get_type_hints, no_type_check from packaging import version @@ -41,6 +34,10 @@ if is_jinja_available(): import jinja2 + import jinja2.exceptions + import jinja2.ext + import jinja2.nodes + import jinja2.runtime from jinja2.ext import Extension from jinja2.sandbox import ImmutableSandboxedEnvironment else: @@ -181,10 +178,11 @@ def _parse_type_hint(hint: str) -> dict: def _convert_type_hints_to_json_schema(func: Callable) -> dict: type_hints = get_type_hints(func) signature = inspect.signature(func) + func_name = getattr(func, "__name__", "operation") required = [] for param_name, param in signature.parameters.items(): if param.annotation == inspect.Parameter.empty: - raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}") + raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func_name}") if param.default == inspect.Parameter.empty: required.append(param_name) @@ -340,10 +338,10 @@ def get_json_schema(func: Callable) -> dict: } """ doc = inspect.getdoc(func) + func_name = getattr(func, "__name__", "operation") + if not doc: - raise DocstringParsingException( - f"Cannot generate JSON schema for {func.__name__} because it has no docstring!" 
- ) + raise DocstringParsingException(f"Cannot generate JSON schema for {func_name} because it has no docstring!") doc = doc.strip() main_doc, param_descriptions, return_doc = parse_google_format_docstring(doc) @@ -354,7 +352,7 @@ def get_json_schema(func: Callable) -> dict: for arg, schema in json_schema["properties"].items(): if arg not in param_descriptions: raise DocstringParsingException( - f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'" + f"Cannot generate JSON schema for {func_name} because the docstring has no description for the argument '{arg}'" ) desc = param_descriptions[arg] enum_choices = re.search(r"\(choices:\s*(.*?)\)\s*$", desc, flags=re.IGNORECASE) @@ -363,7 +361,7 @@ def get_json_schema(func: Callable) -> dict: desc = enum_choices.string[: enum_choices.start()].strip() schema["description"] = desc - output = {"name": func.__name__, "description": main_doc, "parameters": json_schema} + output = {"name": func_name, "description": main_doc, "parameters": json_schema} if return_dict is not None: output["return"] = return_dict return {"type": "function", "function": output} @@ -389,6 +387,11 @@ def _render_with_assistant_indices( @lru_cache def _compile_jinja_template(chat_template): + return _cached_compile_jinja_template(chat_template) + + +@no_type_check +def _cached_compile_jinja_template(chat_template): if not is_jinja_available(): raise ImportError( "apply_chat_template requires jinja2 to be installed. Please install it using `pip install jinja2`." diff --git a/src/transformers/utils/doc.py b/src/transformers/utils/doc.py index eb648a205ccc..4e46e4230bac 100644 --- a/src/transformers/utils/doc.py +++ b/src/transformers/utils/doc.py @@ -21,6 +21,7 @@ import textwrap import types from collections import OrderedDict +from typing import cast def get_docstring_indentation_level(func): @@ -1091,6 +1092,6 @@ def copy_func(f): """Returns a copy of a function f.""" # Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard) g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__) - g = functools.update_wrapper(g, f) + g = cast(types.FunctionType, functools.update_wrapper(g, f)) g.__kwdefaults__ = f.__kwdefaults__ return g diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 6b76adf073cc..17a396e170aa 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -457,7 +457,7 @@ def _model_output_flatten(output: ModelOutput) -> tuple[list[Any], "_torch_pytre def _model_output_unflatten( values: Iterable[Any], context: "_torch_pytree.Context", - output_type=None, + output_type: type[ModelOutput] | None = None, ) -> ModelOutput: return output_type(**dict(zip(context, values))) diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index 127ae9bdc595..b3d5c19f2984 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -168,7 +168,8 @@ def define_sagemaker_information(): sagemaker_params = json.loads(os.getenv("SM_FRAMEWORK_PARAMS", "{}")) runs_distributed_training = "sagemaker_distributed_dataparallel_enabled" in sagemaker_params - account_id = os.getenv("TRAINING_JOB_ARN").split(":")[4] if "TRAINING_JOB_ARN" in os.environ else None + training_job_arn = os.getenv("TRAINING_JOB_ARN") + account_id = training_job_arn.split(":")[4] if training_job_arn is not None else None sagemaker_object = { "sm_framework": os.getenv("SM_FRAMEWORK_MODULE", 
None),
@@ -295,7 +296,7 @@ def cached_files(
     _raise_exceptions_for_connection_errors: bool = True,
     _commit_hash: str | None = None,
     **deprecated_kwargs,
-) -> str | None:
+) -> list[str] | None:
     """
     Tries to locate several files in a local folder and repo, downloads and cache them if necessary.

@@ -708,6 +709,10 @@ def _upload_modified_files(
             revision=revision,
         )

+    def save_pretrained(self, *args, **kwargs):
+        # Explicit contract: subclasses must override this for `push_to_hub` to work.
+        raise NotImplementedError(f"{self.__class__.__name__} must implement `save_pretrained` to use `push_to_hub`.")
+
     def push_to_hub(
         self,
         repo_id: str,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 00864bdcc35b..dc82b8e903fb 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -45,7 +45,7 @@
 PACKAGE_DISTRIBUTION_MAPPING = importlib.metadata.packages_distributions()


-def _is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
+def _is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str | None]:
     """Check if `pkg_name` exist, and optionally try to get its version"""
     spec = importlib.util.find_spec(pkg_name)
     package_exists = spec is not None
@@ -71,10 +71,11 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> tuple[
     package = importlib.import_module(pkg_name)
     package_version = getattr(package, "__version__", "N/A")
     logger.debug(f"Detected {pkg_name} version: {package_version}")
+
     if return_version:
         return package_exists, package_version
     else:
-        return package_exists
+        return package_exists, None


 def is_env_variable_true(env_variable: str) -> bool:
@@ -222,7 +223,7 @@ def is_torch_mps_available(min_version: str | None = None) -> bool:
 @lru_cache
 def is_torch_npu_available(check_device=False) -> bool:
     "Checks if `torch_npu` is installed and potentially if a NPU is in the environment"
-    if not is_torch_available() or not _is_package_available("torch_npu"):
+    if not is_torch_available() or not _is_package_available("torch_npu")[0]:
         return False

     import torch

@@ -231,8 +232,10 @@ def is_torch_npu_available(check_device=False) -> bool:
     if check_device:
         try:
             # Will raise a RuntimeError if no NPU is found
-            _ = torch.npu.device_count()
-            return torch.npu.is_available()
+            if hasattr(torch, "npu"):
+                _ = torch.npu.device_count()
+                return torch.npu.is_available()
+            return False
         except RuntimeError:
             return False
     return hasattr(torch, "npu") and torch.npu.is_available()
@@ -269,7 +272,7 @@ def is_torch_mlu_available() -> bool:
     Checks if `mlu` is available via an `cndev-based` check which won't trigger the drivers and leave mlu uninitialized.
""" - if not is_torch_available() or not _is_package_available("torch_mlu"): + if not is_torch_available() or not _is_package_available("torch_mlu")[0]: return False import torch @@ -278,7 +281,7 @@ def is_torch_mlu_available() -> bool: pytorch_cndev_based_mlu_check_previous_value = os.environ.get("PYTORCH_CNDEV_BASED_MLU_CHECK") try: os.environ["PYTORCH_CNDEV_BASED_MLU_CHECK"] = str(1) - available = torch.mlu.is_available() + available = torch.mlu.is_available() if hasattr(torch, "mlu") else False finally: if pytorch_cndev_based_mlu_check_previous_value: os.environ["PYTORCH_CNDEV_BASED_MLU_CHECK"] = pytorch_cndev_based_mlu_check_previous_value @@ -291,7 +294,7 @@ def is_torch_mlu_available() -> bool: @lru_cache def is_torch_musa_available(check_device=False) -> bool: "Checks if `torch_musa` is installed and potentially if a MUSA is in the environment" - if not is_torch_available() or not _is_package_available("torch_musa"): + if not is_torch_available() or not _is_package_available("torch_musa")[0]: return False import torch @@ -305,8 +308,10 @@ def is_torch_musa_available(check_device=False) -> bool: if check_device: try: # Will raise a RuntimeError if no MUSA is found - _ = torch.musa.device_count() - return torch.musa.is_available() + if hasattr(torch, "musa"): + _ = torch.musa.device_count() + return torch.musa.is_available() + return False except RuntimeError: return False return hasattr(torch, "musa") and torch.musa.is_available() @@ -320,7 +325,7 @@ def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False) -> bool: """ assert not (check_is_tpu and check_is_gpu), "The check_is_tpu and check_is_gpu cannot both be true." - torch_xla_available = USE_TORCH_XLA in ENV_VARS_TRUE_VALUES and _is_package_available("torch_xla") + torch_xla_available = USE_TORCH_XLA in ENV_VARS_TRUE_VALUES and _is_package_available("torch_xla")[0] if not torch_xla_available: return False @@ -339,8 +344,8 @@ def is_torch_hpu_available() -> bool: "Checks if `torch.hpu` is available and potentially if a HPU is in the environment" if ( not is_torch_available() - or not _is_package_available("habana_frameworks") - or not _is_package_available("habana_frameworks.torch") + or not _is_package_available("habana_frameworks")[0] + or not _is_package_available("habana_frameworks.torch")[0] ): return False @@ -444,12 +449,12 @@ def is_torch_bf16_gpu_available() -> bool: if is_torch_hpu_available(): return True if is_torch_npu_available(): - return torch.npu.is_bf16_supported() + return torch.npu.is_bf16_supported() if hasattr(torch, "npu") else False if is_torch_mps_available(): # Note: Emulated in software by Metal using fp32 for hardware without native support (like M1/M2) return torch.backends.mps.is_macos_or_newer(14, 0) if is_torch_musa_available(): - return torch.musa.is_bf16_supported() + return torch.musa.is_bf16_supported() if hasattr(torch, "musa") else False return False @@ -509,9 +514,10 @@ def is_torch_tf32_available() -> bool: import torch if is_torch_musa_available(): - device_info = torch.musa.get_device_properties(torch.musa.current_device()) - if f"{device_info.major}{device_info.minor}" >= "22": - return True + if hasattr(torch, "musa"): + device_info = torch.musa.get_device_properties(torch.musa.current_device()) + if f"{device_info.major}{device_info.minor}" >= "22": + return True return False if not torch.cuda.is_available() or torch.version.cuda is None: return False @@ -534,10 +540,12 @@ def enable_tf32(enable: bool) -> None: pytorch_version = version.parse(get_torch_version()) if 
pytorch_version >= version.parse("2.9.0"): precision_mode = "tf32" if enable else "ieee" - torch.backends.fp32_precision = precision_mode + if hasattr(torch.backends, "fp32_precision"): + torch.backends.fp32_precision = precision_mode else: if is_torch_musa_available(): - torch.backends.mudnn.allow_tf32 = enable + if hasattr(torch.backends, "mudnn"): + torch.backends.mudnn.allow_tf32 = enable else: torch.backends.cuda.matmul.allow_tf32 = enable torch.backends.cudnn.allow_tf32 = enable @@ -555,7 +563,7 @@ def is_grouped_mm_available() -> bool: @lru_cache def is_kenlm_available() -> bool: - return _is_package_available("kenlm") + return _is_package_available("kenlm")[0] @lru_cache @@ -566,17 +574,17 @@ def is_kernels_available(MIN_VERSION: str = KERNELS_MIN_VERSION) -> bool: @lru_cache def is_cv2_available() -> bool: - return _is_package_available("cv2") + return _is_package_available("cv2")[0] @lru_cache def is_yt_dlp_available() -> bool: - return _is_package_available("yt_dlp") + return _is_package_available("yt_dlp")[0] @lru_cache def is_libcst_available() -> bool: - return _is_package_available("libcst") + return _is_package_available("libcst")[0] @lru_cache @@ -593,7 +601,7 @@ def is_triton_available(min_version: str = TRITON_MIN_VERSION) -> bool: @lru_cache def is_hadamard_available() -> bool: - return _is_package_available("fast_hadamard_transform") + return _is_package_available("fast_hadamard_transform")[0] @lru_cache @@ -604,12 +612,12 @@ def is_hqq_available(min_version: str = HQQ_MIN_VERSION) -> bool: @lru_cache def is_pygments_available() -> bool: - return _is_package_available("pygments") + return _is_package_available("pygments")[0] @lru_cache def is_torchvision_available() -> bool: - return _is_package_available("torchvision") + return _is_package_available("torchvision")[0] @lru_cache @@ -619,27 +627,27 @@ def is_torchvision_v2_available() -> bool: @lru_cache def is_galore_torch_available() -> bool: - return _is_package_available("galore_torch") + return _is_package_available("galore_torch")[0] @lru_cache def is_apollo_torch_available() -> bool: - return _is_package_available("apollo_torch") + return _is_package_available("apollo_torch")[0] @lru_cache def is_torch_optimi_available() -> bool: - return _is_package_available("optimi") + return _is_package_available("optimi")[0] @lru_cache def is_lomo_available() -> bool: - return _is_package_available("lomo_optim") + return _is_package_available("lomo_optim")[0] @lru_cache def is_grokadamw_available() -> bool: - return _is_package_available("grokadamw") + return _is_package_available("grokadamw")[0] @lru_cache @@ -650,47 +658,47 @@ def is_schedulefree_available(min_version: str = SCHEDULEFREE_MIN_VERSION) -> bo @lru_cache def is_pyctcdecode_available() -> bool: - return _is_package_available("pyctcdecode") + return _is_package_available("pyctcdecode")[0] @lru_cache def is_librosa_available() -> bool: - return _is_package_available("librosa") + return _is_package_available("librosa")[0] @lru_cache def is_essentia_available() -> bool: - return _is_package_available("essentia") + return _is_package_available("essentia")[0] @lru_cache def is_pydantic_available() -> bool: - return _is_package_available("pydantic") + return _is_package_available("pydantic")[0] @lru_cache def is_fastapi_available() -> bool: - return _is_package_available("fastapi") + return _is_package_available("fastapi")[0] @lru_cache def is_uvicorn_available() -> bool: - return _is_package_available("uvicorn") + return _is_package_available("uvicorn")[0] @lru_cache 
def is_openai_available() -> bool: - return _is_package_available("openai") + return _is_package_available("openai")[0] @lru_cache def is_pretty_midi_available() -> bool: - return _is_package_available("pretty_midi") + return _is_package_available("pretty_midi")[0] @lru_cache def is_mamba_ssm_available() -> bool: - return is_torch_cuda_available() and _is_package_available("mamba_ssm") + return is_torch_cuda_available() and _is_package_available("mamba_ssm")[0] @lru_cache @@ -707,37 +715,37 @@ def is_flash_linear_attention_available(): @lru_cache def is_causal_conv1d_available() -> bool: - return is_torch_cuda_available() and _is_package_available("causal_conv1d") + return is_torch_cuda_available() and _is_package_available("causal_conv1d")[0] @lru_cache def is_xlstm_available() -> bool: - return is_torch_available() and _is_package_available("xlstm") + return is_torch_available() and _is_package_available("xlstm")[0] @lru_cache def is_mambapy_available() -> bool: - return is_torch_available() and _is_package_available("mambapy") + return is_torch_available() and _is_package_available("mambapy")[0] @lru_cache def is_peft_available() -> bool: - return _is_package_available("peft") + return _is_package_available("peft")[0] @lru_cache def is_bs4_available() -> bool: - return _is_package_available("bs4") + return _is_package_available("bs4")[0] @lru_cache def is_coloredlogs_available() -> bool: - return _is_package_available("coloredlogs") + return _is_package_available("coloredlogs")[0] @lru_cache def is_onnx_available() -> bool: - return _is_package_available("onnx") + return _is_package_available("onnx")[0] @lru_cache @@ -748,22 +756,22 @@ def is_flute_available() -> bool: @lru_cache def is_g2p_en_available() -> bool: - return _is_package_available("g2p_en") + return _is_package_available("g2p_en")[0] @lru_cache def is_torch_neuroncore_available(check_device=True) -> bool: - return is_torch_xla_available() and _is_package_available("torch_neuronx") + return is_torch_xla_available() and _is_package_available("torch_neuronx")[0] @lru_cache def is_torch_tensorrt_fx_available() -> bool: - return _is_package_available("torch_tensorrt") and _is_package_available("torch_tensorrt.fx") + return _is_package_available("torch_tensorrt")[0] and _is_package_available("torch_tensorrt.fx")[0] @lru_cache def is_datasets_available() -> bool: - return _is_package_available("datasets") + return _is_package_available("datasets")[0] @lru_cache @@ -781,32 +789,32 @@ def is_detectron2_available() -> bool: @lru_cache def is_rjieba_available() -> bool: - return _is_package_available("rjieba") + return _is_package_available("rjieba")[0] @lru_cache def is_psutil_available() -> bool: - return _is_package_available("psutil") + return _is_package_available("psutil")[0] @lru_cache def is_py3nvml_available() -> bool: - return _is_package_available("py3nvml") + return _is_package_available("py3nvml")[0] @lru_cache def is_sacremoses_available() -> bool: - return _is_package_available("sacremoses") + return _is_package_available("sacremoses")[0] @lru_cache def is_apex_available() -> bool: - return _is_package_available("apex") + return _is_package_available("apex")[0] @lru_cache def is_aqlm_available() -> bool: - return _is_package_available("aqlm") + return _is_package_available("aqlm")[0] @lru_cache @@ -817,17 +825,17 @@ def is_vptq_available(min_version: str = VPTQ_MIN_VERSION) -> bool: @lru_cache def is_av_available() -> bool: - return _is_package_available("av") + return _is_package_available("av")[0] @lru_cache def 
is_decord_available() -> bool: - return _is_package_available("decord") + return _is_package_available("decord")[0] @lru_cache def is_torchcodec_available() -> bool: - return _is_package_available("torchcodec") + return _is_package_available("torchcodec")[0] @lru_cache @@ -874,7 +882,7 @@ def is_flash_attn_2_available() -> bool: @lru_cache def is_flash_attn_3_available() -> bool: - return is_torch_cuda_available() and _is_package_available("flash_attn_3") + return is_torch_cuda_available() and _is_package_available("flash_attn_3")[0] @lru_cache @@ -925,32 +933,32 @@ def is_quanto_greater(library_version: str, accept_dev: bool = False) -> bool: @lru_cache def is_torchdistx_available(): - return _is_package_available("torchdistx") + return _is_package_available("torchdistx")[0] @lru_cache def is_faiss_available() -> bool: - return _is_package_available("faiss") + return _is_package_available("faiss")[0] @lru_cache def is_scipy_available() -> bool: - return _is_package_available("scipy") + return _is_package_available("scipy")[0] @lru_cache def is_sklearn_available() -> bool: - return _is_package_available("sklearn") + return _is_package_available("sklearn")[0] @lru_cache def is_sentencepiece_available() -> bool: - return _is_package_available("sentencepiece") + return _is_package_available("sentencepiece")[0] @lru_cache def is_seqio_available() -> bool: - return _is_package_available("seqio") + return _is_package_available("seqio")[0] @lru_cache @@ -961,7 +969,7 @@ def is_gguf_available(min_version: str = GGUF_MIN_VERSION) -> bool: @lru_cache def is_protobuf_available() -> bool: - return _is_package_available("google") and _is_package_available("google.protobuf") + return _is_package_available("google")[0] and _is_package_available("google.protobuf")[0] @lru_cache @@ -971,12 +979,12 @@ def is_fsdp_available(min_version: str = FSDP_MIN_VERSION) -> bool: @lru_cache def is_optimum_available() -> bool: - return _is_package_available("optimum") + return _is_package_available("optimum")[0] @lru_cache def is_llm_awq_available() -> bool: - return _is_package_available("awq") + return _is_package_available("awq")[0] @lru_cache @@ -987,12 +995,12 @@ def is_auto_round_available(min_version: str = AUTOROUND_MIN_VERSION) -> bool: @lru_cache def is_optimum_quanto_available(): - return is_optimum_available() and _is_package_available("optimum.quanto") + return is_optimum_available() and _is_package_available("optimum.quanto")[0] @lru_cache def is_quark_available() -> bool: - return _is_package_available("quark") + return _is_package_available("quark")[0] @lru_cache @@ -1009,92 +1017,92 @@ def is_qutlass_available(): @lru_cache def is_compressed_tensors_available() -> bool: - return _is_package_available("compressed_tensors") + return _is_package_available("compressed_tensors")[0] @lru_cache def is_gptqmodel_available() -> bool: - return _is_package_available("gptqmodel") + return _is_package_available("gptqmodel")[0] @lru_cache def is_fbgemm_gpu_available() -> bool: - return _is_package_available("fbgemm_gpu") + return _is_package_available("fbgemm_gpu")[0] @lru_cache def is_levenshtein_available() -> bool: - return _is_package_available("Levenshtein") + return _is_package_available("Levenshtein")[0] @lru_cache def is_optimum_neuron_available() -> bool: - return is_optimum_available() and _is_package_available("optimum.neuron") + return is_optimum_available() and _is_package_available("optimum.neuron")[0] @lru_cache def is_tokenizers_available() -> bool: - return _is_package_available("tokenizers") + 
return _is_package_available("tokenizers")[0] @lru_cache def is_vision_available() -> bool: - return _is_package_available("PIL") + return _is_package_available("PIL")[0] @lru_cache def is_pytesseract_available() -> bool: - return _is_package_available("pytesseract") + return _is_package_available("pytesseract")[0] @lru_cache def is_pytest_available() -> bool: - return _is_package_available("pytest") + return _is_package_available("pytest")[0] @lru_cache def is_pytest_order_available() -> bool: - return is_pytest_available() and _is_package_available("pytest_order") + return is_pytest_available() and _is_package_available("pytest_order")[0] @lru_cache def is_spacy_available() -> bool: - return _is_package_available("spacy") + return _is_package_available("spacy")[0] @lru_cache def is_pytorch_quantization_available() -> bool: - return _is_package_available("pytorch_quantization") + return _is_package_available("pytorch_quantization")[0] @lru_cache def is_pandas_available() -> bool: - return _is_package_available("pandas") + return _is_package_available("pandas")[0] @lru_cache def is_soundfile_available() -> bool: - return _is_package_available("soundfile") + return _is_package_available("soundfile")[0] @lru_cache def is_timm_available() -> bool: - return _is_package_available("timm") + return _is_package_available("timm")[0] @lru_cache def is_natten_available() -> bool: - return _is_package_available("natten") + return _is_package_available("natten")[0] @lru_cache def is_nltk_available() -> bool: - return _is_package_available("nltk") + return _is_package_available("nltk")[0] @lru_cache def is_numba_available() -> bool: - is_available = _is_package_available("numba") + is_available = _is_package_available("numba")[0] if not is_available: return False @@ -1104,7 +1112,7 @@ def is_numba_available() -> bool: @lru_cache def is_torchaudio_available() -> bool: - return _is_package_available("torchaudio") + return _is_package_available("torchaudio")[0] @lru_cache @@ -1121,22 +1129,22 @@ def is_speech_available() -> bool: @lru_cache def is_spqr_available() -> bool: - return _is_package_available("spqr_quant") + return _is_package_available("spqr_quant")[0] @lru_cache def is_phonemizer_available() -> bool: - return _is_package_available("phonemizer") + return _is_package_available("phonemizer")[0] @lru_cache def is_uroman_available() -> bool: - return _is_package_available("uroman") + return _is_package_available("uroman")[0] @lru_cache def is_sudachi_available() -> bool: - return _is_package_available("sudachipy") + return _is_package_available("sudachipy")[0] @lru_cache @@ -1147,37 +1155,39 @@ def is_sudachi_projection_available() -> bool: @lru_cache def is_jumanpp_available() -> bool: - return _is_package_available("rhoknp") and shutil.which("jumanpp") is not None + return _is_package_available("rhoknp")[0] and shutil.which("jumanpp") is not None @lru_cache def is_cython_available() -> bool: - return _is_package_available("pyximport") + return _is_package_available("pyximport")[0] @lru_cache def is_jinja_available() -> bool: - return _is_package_available("jinja2") + return _is_package_available("jinja2")[0] @lru_cache def is_jmespath_available() -> bool: - return _is_package_available("jmespath") + return _is_package_available("jmespath")[0] @lru_cache def is_mlx_available() -> bool: - return _is_package_available("mlx") + return _is_package_available("mlx")[0] @lru_cache def is_num2words_available() -> bool: - return _is_package_available("num2words") + return 
_is_package_available("num2words")[0] @lru_cache -def is_tiktoken_available() -> bool: - return _is_package_available("tiktoken") and _is_package_available("blobfile") +def is_tiktoken_available(with_blobfile: bool = True) -> bool: + if not _is_package_available("tiktoken")[0]: + return False + return with_blobfile and _is_package_available("blobfile")[0] or True @lru_cache @@ -1188,29 +1198,34 @@ def is_liger_kernel_available() -> bool: @lru_cache def is_rich_available() -> bool: - return _is_package_available("rich") + return _is_package_available("rich")[0] @lru_cache def is_matplotlib_available() -> bool: - return _is_package_available("matplotlib") + return _is_package_available("matplotlib")[0] @lru_cache def is_mistral_common_available() -> bool: - return _is_package_available("mistral_common") + return _is_package_available("mistral_common")[0] @lru_cache def is_opentelemetry_available() -> bool: try: - return _is_package_available("opentelemetry") and version.parse( + return _is_package_available("opentelemetry")[0] and version.parse( importlib.metadata.version("opentelemetry-api") ) >= version.parse("1.30.0") except Exception as _: return False +@lru_cache +def is_pynvml_available() -> bool: + return _is_package_available("pynvml")[0] + + def check_torch_load_is_safe() -> None: if not is_torch_greater_or_equal("2.6"): raise ValueError( @@ -1430,7 +1445,7 @@ def is_sagemaker_dp_enabled() -> bool: except json.JSONDecodeError: return False # Lastly, check if the `smdistributed` module is present. - return _is_package_available("smdistributed") + return _is_package_available("smdistributed")[0] def is_sagemaker_mp_enabled() -> bool: @@ -1454,7 +1469,7 @@ def is_sagemaker_mp_enabled() -> bool: except json.JSONDecodeError: return False # Lastly, check if the `smdistributed` module is present. - return _is_package_available("smdistributed") + return _is_package_available("smdistributed")[0] def is_training_run_on_sagemaker() -> bool: @@ -2005,7 +2020,7 @@ def __init__( # Needed for autocompletion in an IDE def __dir__(self): - result = super().__dir__() + result = list(super().__dir__()) # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir. 
for attr in self.__all__: @@ -2259,10 +2274,12 @@ def direct_transformers_import(path: str, file="__init__.py") -> ModuleType: name = "transformers" location = os.path.join(path, file) spec = importlib.util.spec_from_file_location(name, location, submodule_search_locations=[path]) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - module = sys.modules[name] - return module + if spec is not None and spec.loader is not None: + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + module = sys.modules[name] + return module + raise ImportError(f"Could not load module {name} from {location}") class VersionComparison(Enum): @@ -2276,13 +2293,13 @@ class VersionComparison(Enum): @staticmethod def from_string(version_string: str) -> "VersionComparison": string_to_operator = { - "=": VersionComparison.EQUAL.value, - "==": VersionComparison.EQUAL.value, - "!=": VersionComparison.NOT_EQUAL.value, - ">": VersionComparison.GREATER_THAN.value, - "<": VersionComparison.LESS_THAN.value, - ">=": VersionComparison.GREATER_THAN_OR_EQUAL.value, - "<=": VersionComparison.LESS_THAN_OR_EQUAL.value, + "=": VersionComparison.EQUAL, + "==": VersionComparison.EQUAL, + "!=": VersionComparison.NOT_EQUAL, + ">": VersionComparison.GREATER_THAN, + "<": VersionComparison.LESS_THAN, + ">=": VersionComparison.GREATER_THAN_OR_EQUAL, + "<=": VersionComparison.LESS_THAN_OR_EQUAL, } return string_to_operator[version_string] diff --git a/src/transformers/utils/loading_report.py b/src/transformers/utils/loading_report.py index c613e3bc1acf..0e5810abd8c1 100644 --- a/src/transformers/utils/loading_report.py +++ b/src/transformers/utils/loading_report.py @@ -50,7 +50,7 @@ def update_key_name(mapping: dict[str, Any]) -> dict[str, Any]: mapping = {k: k for k in mapping} not_mapping = True - bucket: dict[tuple[str, Any], list[set[int]]] = defaultdict(list) + bucket: dict[str, list[set[int] | Any]] = defaultdict(list) for key, val in mapping.items(): digs = _DIGIT_RX.findall(key) patt = _pattern_of(key) diff --git a/src/transformers/utils/logging.py b/src/transformers/utils/logging.py index c9fc19f26dd7..4b38b824ede7 100644 --- a/src/transformers/utils/logging.py +++ b/src/transformers/utils/logging.py @@ -33,6 +33,8 @@ import huggingface_hub.utils as hf_hub_utils from tqdm import auto as tqdm_lib +from ._typing import TransformersLogger + _lock = threading.Lock() _default_handler: logging.Handler | None = None @@ -99,7 +101,8 @@ def _configure_library_root_logger() -> None: formatter = logging.Formatter("[%(levelname)s|%(pathname)s:%(lineno)s] %(asctime)s >> %(message)s") _default_handler.setFormatter(formatter) - is_ci = os.getenv("CI") is not None and os.getenv("CI").upper() in {"1", "ON", "YES", "TRUE"} + ci = os.getenv("CI") + is_ci = ci is not None and ci.upper() in {"1", "ON", "YES", "TRUE"} library_root_logger.propagate = is_ci @@ -143,7 +146,7 @@ def captureWarnings(capture): _captureWarnings(capture) -def get_logger(name: str | None = None) -> logging.Logger: +def get_logger(name: str | None = None) -> TransformersLogger: """ Return a logger with the specified name. 
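A minimal usage sketch of the new annotation (all names come from this diff): because `get_logger` is typed as `TransformersLogger`, the monkey-patched helpers below now type-check at call sites without ignores:

    from transformers.utils import logging

    logger = logging.get_logger(__name__)
    logger.warning_once("emitted once per distinct message")  # declared on the Protocol
    logger.info_once("likewise deduplicated via functools.lru_cache")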
@@ -312,7 +315,7 @@ def warning_advice(self, *args, **kwargs): self.warning(*args, **kwargs) -logging.Logger.warning_advice = warning_advice +logging.Logger.warning_advice = warning_advice # type: ignore[unresolved-attribute] @functools.lru_cache(None) @@ -327,7 +330,7 @@ def warning_once(self, *args, **kwargs): self.warning(*args, **kwargs) -logging.Logger.warning_once = warning_once +logging.Logger.warning_once = warning_once # type: ignore[unresolved-attribute] @functools.lru_cache(None) @@ -342,7 +345,7 @@ def info_once(self, *args, **kwargs): self.info(*args, **kwargs) -logging.Logger.info_once = info_once +logging.Logger.info_once = info_once # type: ignore[unresolved-attribute] class EmptyTqdm: diff --git a/src/transformers/utils/notebook.py b/src/transformers/utils/notebook.py index 1660d546ed1e..ecbe8271fe13 100644 --- a/src/transformers/utils/notebook.py +++ b/src/transformers/utils/notebook.py @@ -15,7 +15,7 @@ import os import re import time -from typing import Optional +from typing import Optional, TypeVar import IPython.display as disp @@ -23,6 +23,15 @@ from ..trainer_utils import IntervalStrategy, has_length +_T = TypeVar("_T") + + +def _require(x: _T | None, msg: str) -> _T: + if x is None: + raise RuntimeError(msg) + return x + + def format_time(t): "Format `t` (in seconds) to (h):mm:ss" t = int(t) @@ -307,7 +316,8 @@ def on_train_begin(self, args, state, control, **kwargs): def on_step_end(self, args, state, control, **kwargs): epoch = int(state.epoch) if int(state.epoch) == state.epoch else f"{state.epoch:.2f}" - self.training_tracker.update( + tt = _require(self.training_tracker, "on_train_begin must be called before on_step_end") + tt.update( state.global_step + 1, comment=f"Epoch {epoch}/{state.num_train_epochs}", force_update=self._force_next_update, @@ -334,47 +344,52 @@ def on_predict(self, args, state, control, **kwargs): def on_log(self, args, state, control, logs=None, **kwargs): # Only for when there is no evaluation if args.eval_strategy == IntervalStrategy.NO and "loss" in logs: + tt = _require(self.training_tracker, "on_train_begin must be called before on_log") values = {"Training Loss": logs["loss"]} # First column is necessarily Step sine we're not in epoch eval strategy values["Step"] = state.global_step - self.training_tracker.write_line(values) + tt.write_line(values) def on_evaluate(self, args, state, control, metrics=None, **kwargs): - if self.training_tracker is not None: - values = {"Training Loss": "No log", "Validation Loss": "No log"} - for log in reversed(state.log_history): - if "loss" in log: - values["Training Loss"] = log["loss"] - break - - if self.first_column == "Epoch": - values["Epoch"] = int(state.epoch) - else: - values["Step"] = state.global_step - metric_key_prefix = "eval" - for k in metrics: - if k.endswith("_loss"): - metric_key_prefix = re.sub(r"\_loss$", "", k) - _ = metrics.pop("total_flos", None) - _ = metrics.pop("epoch", None) - _ = metrics.pop(f"{metric_key_prefix}_runtime", None) - _ = metrics.pop(f"{metric_key_prefix}_samples_per_second", None) - _ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None) - for k, v in metrics.items(): - splits = k.split("_") - name = " ".join([part.capitalize() for part in splits[1:]]) - if name == "Loss": - # Single dataset - name = "Validation Loss" - values[name] = v - self.training_tracker.write_line(values) - self.training_tracker.remove_child() - self.prediction_bar = None - # Evaluation takes a long time so we should force the next update. 
- self._force_next_update = True + tt = _require(self.training_tracker, "on_train_begin must be called before on_evaluate") + + values = {"Training Loss": "No log", "Validation Loss": "No log"} + for log in reversed(state.log_history): + if "loss" in log: + values["Training Loss"] = log["loss"] + break + + if self.first_column == "Epoch": + values["Epoch"] = int(state.epoch) + else: + values["Step"] = state.global_step + if metrics is None: + metrics = {} + metric_key_prefix = "eval" + for k in metrics: + if k.endswith("_loss"): + metric_key_prefix = re.sub(r"\_loss$", "", k) + _ = metrics.pop("total_flos", None) + _ = metrics.pop("epoch", None) + _ = metrics.pop(f"{metric_key_prefix}_runtime", None) + _ = metrics.pop(f"{metric_key_prefix}_samples_per_second", None) + _ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None) + for k, v in metrics.items(): + splits = k.split("_") + name = " ".join([part.capitalize() for part in splits[1:]]) + if name == "Loss": + # Single dataset + name = "Validation Loss" + values[name] = v + tt.write_line(values) + tt.remove_child() + self.prediction_bar = None + # Evaluation takes a long time so we should force the next update. + self._force_next_update = True def on_train_end(self, args, state, control, **kwargs): - self.training_tracker.update( + tt = _require(self.training_tracker, "on_train_begin must be called before on_train_end") + tt.update( state.global_step, comment=f"Epoch {int(state.epoch)}/{state.num_train_epochs}", force_update=True, diff --git a/src/transformers/utils/peft_utils.py b/src/transformers/utils/peft_utils.py index 99062bf6502f..a1ec093b4e67 100644 --- a/src/transformers/utils/peft_utils.py +++ b/src/transformers/utils/peft_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import importlib +import importlib.metadata import os from packaging import version diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 9b6fb9e6a86b..80551e96f771 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -158,6 +158,12 @@ def __iter__(self): def __repr__(self): return f"{self.__class__.__name__} {self.to_json_string()}" + def to_diff_dict(self) -> dict[str, Any]: + """ + Default behavior: no diffing implemented for this config. + """ + return self.to_dict() + def to_json_string(self, use_diff: bool = True) -> str: """ Serializes this instance to a JSON string. @@ -170,10 +176,7 @@ def to_json_string(self, use_diff: bool = True) -> str: Returns: `str`: String containing all the attributes that make up this configuration instance in JSON format. 
""" - if use_diff is True: - config_dict = self.to_diff_dict() - else: - config_dict = self.to_dict() + config_dict = self.to_diff_dict() if use_diff else self.to_dict() return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" def update(self, **kwargs): @@ -1254,7 +1257,8 @@ def is_quantized(self): def is_quantization_compressed(self): from compressed_tensors.quantization import QuantizationStatus - return self.is_quantized and self.quantization_config.quantization_status == QuantizationStatus.COMPRESSED + qc = self.quantization_config + return self.is_quantized and (qc is not None and qc.quantization_status == QuantizationStatus.COMPRESSED) @property def is_sparsification_compressed(self): diff --git a/src/transformers/utils/type_validators.py b/src/transformers/utils/type_validators.py index 8775150ece22..2600998277a6 100644 --- a/src/transformers/utils/type_validators.py +++ b/src/transformers/utils/type_validators.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Union +from typing import Any, Union, cast from ..tokenization_utils_base import PaddingStrategy, TruncationStrategy from ..video_utils import VideoMetadataType @@ -92,7 +92,7 @@ def video_metadata_validator(value: VideoMetadataType | None = None): valid_keys = ["total_num_frames", "fps", "width", "height", "duration", "video_backend", "frames_indices"] - def check_dict_keys(d: dict) -> bool: + def check_dict_keys(d: dict[str, Any]) -> bool: return all(key in valid_keys for key in d.keys()) if isinstance(value, Sequence) and isinstance(value[0], Sequence) and isinstance(value[0][0], dict): @@ -107,7 +107,7 @@ def check_dict_keys(d: dict) -> bool: for item in value: if not check_dict_keys(item): raise ValueError( - f"Invalid keys found in video metadata. Valid keys: {valid_keys} got: {list(item.keys())}" + f"Invalid keys found in video metadata. Valid keys: {valid_keys} got: {list(cast(dict, item).keys())}" ) elif isinstance(value, dict):