Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
da0da51
initial ty integration
tarekziade Jan 21, 2026
09efe27
narrow ty check to utils and add more ignores
tarekziade Jan 21, 2026
dac72f4
circleci does not use the makefile, fix ty changes there
tarekziade Jan 21, 2026
21bbc0a
remove unnecessary ignores
tarekziade Jan 22, 2026
3bed4d1
removed a couple of ignores
tarekziade Jan 22, 2026
76f66f7
causal_mask can be a Tensor, BlockMask or None, lets be explicit
tarekziade Jan 22, 2026
5b07638
simplify to_json_string
tarekziade Jan 22, 2026
1b57a49
make it more readable - we know qc cannot be None here but it's best b…
tarekziade Jan 22, 2026
7d65390
explicitly assert training_tracker, do not silently ignore
tarekziade Jan 22, 2026
b634db3
make _is_package_available return type unique, and fully private to i…
tarekziade Jan 22, 2026
c226084
explicit contract at the base class level is nicer than duck typing c…
tarekziade Jan 22, 2026
050a7a8
add a comment about monkey-patching the logger
tarekziade Jan 22, 2026
448f487
added ignore
tarekziade Feb 3, 2026
6b65a46
better one
tarekziade Feb 5, 2026
2389e27
fixed test
tarekziade Feb 5, 2026
064bcec
tweaks
tarekziade Feb 5, 2026
a71340a
added a logger protocol
tarekziade Feb 5, 2026
c998e27
removed some asserts
tarekziade Feb 5, 2026
5be4580
just ignore type check on that one
tarekziade Feb 5, 2026
06ebe02
better description
tarekziade Feb 5, 2026
74fd531
not needed anymore
tarekziade Feb 6, 2026
e043bb7
yeah, callables don't always have names
tarekziade Feb 6, 2026
3556795
one more func name fix
tarekziade Feb 6, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ jobs:
path: ~/transformers/installed.txt
- run: ruff check examples tests src utils scripts benchmark benchmark_v2 setup.py conftest.py
- run: ruff format --check examples tests src utils scripts benchmark benchmark_v2 setup.py conftest.py
- run: ty check src/transformers/utils/*.py --force-exclude --exclude '**/*_pb2*.py'
- run: python utils/custom_init_isort.py --check_only
- run: python utils/sort_auto_mappings.py --check_only

Expand Down
6 changes: 6 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ export PYTHONPATH = src
check_dirs := examples tests src utils scripts benchmark benchmark_v2
exclude_folders := ""

# Helper to find all Python files in directories (ty doesn't recursively scan directories)
define get_py_files
$(shell find $(1) -name "*.py" -type f 2>/dev/null)
endef


# this runs all linting/formatting scripts, most notably ruff
style:
Expand All @@ -20,6 +25,7 @@ style:
check-repo:
ruff check $(check_dirs) setup.py conftest.py
ruff format --check $(check_dirs) setup.py conftest.py
ty check $(call get_py_files,src/transformers/utils) --force-exclude --exclude '**/*_pb2*.py'
-python utils/custom_init_isort.py --check_only
-python utils/sort_auto_mappings.py --check_only
-python -c "from transformers import *" || (echo '🚨 import failed, this means you introduced unprotected imports! 🚨'; exit 1)
Expand Down
2 changes: 1 addition & 1 deletion docker/quality.dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ USER root
RUN apt-get update && apt-get install -y time git
ENV UV_PYTHON=/usr/local/bin/python
RUN pip install uv
RUN uv pip install --no-cache-dir -U pip setuptools GitPython "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[ruff]" urllib3
RUN uv pip install --no-cache-dir -U pip setuptools GitPython "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[quality]" urllib3
RUN apt-get install -y jq curl && apt-get clean && rm -rf /var/lib/apt/lists/*
20 changes: 20 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,23 @@ env = [
# Note: 'D:' means default value from laptop or CI won't be overwritten
"D:HF_HUB_DOWNLOAD_TIMEOUT=60",
]

[tool.ty]
# ty type checker configuration
# Using default settings for comprehensive type checking

[tool.ty.rules]
# Disable specific rules that produce false positives or are too strict for this codebase
invalid-method-override = "ignore" # Parameter name differences are acceptable (e.g., x vs input, new_embeddings vs value)
not-subscriptable = "ignore" # False positives on tensor slicing (e.g., self.position_ids[:, :seq_length])
no-matching-overload = "ignore" # False positives on torch.zeros and similar functions accepting Size/tuple
unsupported-operator = "ignore" # False positives on tuple concatenation with += when properly initialized
unresolved-import = "ignore" # Optional dependencies (mlx, torch_npu, habana_frameworks, etc.) checked at runtime
call-non-callable = "ignore" # Mixin pattern issues where classes are used as both types and callables
unresolved-reference = "ignore" # Forward references with noqa: F821 that ty doesn't respect
invalid-argument-type = "ignore" # Complex type narrowing and union type issues
not-iterable = "ignore" # Complex async/Future type patterns
invalid-return-type = "ignore" # Return type mismatches that would require refactoring
deprecated = "ignore" # Deprecation warnings from dependencies
invalid-assignment = "ignore" # Low-level assignments that are runtime-safe
unused-ignore-comment = "ignore" # Ignore comments that became unnecessary after adding broader per-file-ignores
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@
"rjieba",
"rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
"ruff==0.14.10",
"ty==0.0.12",
# `sacrebleu` not used in `transformers`. However, it is needed in several tests, when a test calls
# `evaluate.load("sacrebleu")`. This metric is used in the examples that we use to test the `Trainer` with, in the
# `Trainer` tests (see references to `run_translation.py`).
Expand Down Expand Up @@ -182,7 +183,7 @@ def deps_list(*pkgs):
extras["audio"] += deps_list("kenlm")
extras["video"] = deps_list("av")
extras["timm"] = deps_list("timm")
extras["quality"] = deps_list("datasets", "ruff", "GitPython", "urllib3", "libcst", "rich")
extras["quality"] = deps_list("datasets", "ruff", "GitPython", "urllib3", "libcst", "rich", "ty")
extras["kernels"] = deps_list("kernels")
extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
extras["tiktoken"] = deps_list("tiktoken", "blobfile")
Expand Down
1 change: 1 addition & 0 deletions src/transformers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
"rjieba": "rjieba",
"rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
"ruff": "ruff==0.14.10",
"ty": "ty==0.0.12",
"sacrebleu": "sacrebleu>=1.4.12,<2.0.0",
"sacremoses": "sacremoses",
"safetensors": "safetensors>=0.4.3",
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/integrations/integration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import numpy as np
import packaging.version

from transformers.utils.import_utils import _is_package_available
from transformers.utils.import_utils import is_pynvml_available


if os.getenv("WANDB_MODE") == "offline":
Expand Down Expand Up @@ -1030,7 +1030,7 @@ def on_log(self, args, state, control, model=None, logs=None, **kwargs):
f"gpu/{device_idx}/allocated_memory": memory_allocated / (1024**3), # GB
f"gpu/{device_idx}/memory_usage": memory_allocated / total_memory, # ratio
}
if _is_package_available("pynvml"):
if is_pynvml_available():
power = torch.cuda.power_draw(device_idx)
gpu_memory_logs[f"gpu/{device_idx}/power"] = power / 1000 # Watts
if dist.is_available() and dist.is_initialized():
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/dia/convert_dia_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
DiaTokenizer,
GenerationConfig,
)
from transformers.utils.import_utils import _is_package_available
from transformers.utils.import_utils import is_tiktoken_available


# Provide just the list of layer keys you want to fix
Expand Down Expand Up @@ -180,7 +180,7 @@ def convert_dia_model_to_hf(checkpoint_path, verbose=False):
model = convert_dia_model_to_hf(args.checkpoint_path, args.verbose)
if args.convert_preprocessor:
try:
if not _is_package_available("tiktoken"):
if not is_tiktoken_available(with_blobfile=False):
raise ModuleNotFoundError(
"""`tiktoken` is not installed, use `pip install tiktoken` to convert the tokenizer"""
)
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/whisper/convert_openai_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
WhisperTokenizerFast,
)
from transformers.models.whisper.tokenization_whisper import LANGUAGES, bytes_to_unicode
from transformers.utils.import_utils import _is_package_available
from transformers.utils.import_utils import is_tiktoken_available


_MODELS = {
Expand Down Expand Up @@ -345,7 +345,7 @@ def convert_tiktoken_to_hf(

if args.convert_preprocessor:
try:
if not _is_package_available("tiktoken"):
if not is_tiktoken_available(with_blobfile=False):
raise ModuleNotFoundError(
"""`tiktoken` is not installed, use `pip install tiktoken` to convert the tokenizer"""
)
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
import shutil
import threading
import time
from collections.abc import Callable
from collections.abc import Callable, Sized
from functools import partial
from pathlib import Path
from typing import Any, NamedTuple
from typing import Any, NamedTuple, TypeGuard

import numpy as np

Expand Down Expand Up @@ -820,7 +820,7 @@ def stop_and_update_metrics(self, metrics=None):
self.update_metrics(stage, metrics)


def has_length(dataset):
def has_length(dataset: Any) -> TypeGuard[Sized]:
"""
Checks if the dataset implements __len__() and it doesn't raise an error
"""
Expand Down
123 changes: 123 additions & 0 deletions src/transformers/utils/_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import logging
from collections.abc import Mapping, MutableMapping
from typing import Any, Protocol, TypeAlias


# A few helpful type aliases
Level: TypeAlias = int
ExcInfo: TypeAlias = (
None
| bool
| BaseException
| tuple[type[BaseException], BaseException, object] # traceback is `types.TracebackType`, but keep generic here
)


class TransformersLogger(Protocol):
    """Structural type for the logger objects passed around in transformers.

    Mirrors the public surface of :class:`logging.Logger` (attributes, logging
    calls, handler/filter management, record creation) and adds the
    transformers-specific convenience methods declared at the bottom
    (``warning_advice``, ``warning_once``, ``info_once``). Because this is a
    :class:`typing.Protocol`, any object exposing these members satisfies the
    type — no inheritance from ``logging.Logger`` is required, which is what
    allows the monkey-patched transformers logger to type-check.
    """

    # ---- Core Logger identity / configuration ----
    name: str
    level: int
    parent: logging.Logger | None
    propagate: bool
    disabled: bool
    handlers: list[logging.Handler]

    # Exists on Logger; default is True. (Not heavily used, but is part of API.)
    # NOTE(review): in the stdlib, `raiseExceptions` is a *module-level* flag on
    # `logging`, not an instance attribute of Logger — confirm it belongs here.
    raiseExceptions: bool  # type: ignore[assignment]

    # ---- Standard methods ----
    def setLevel(self, level: Level) -> None: ...
    def isEnabledFor(self, level: Level) -> bool: ...
    def getEffectiveLevel(self) -> int: ...

    def getChild(self, suffix: str) -> logging.Logger: ...

    def addHandler(self, hdlr: logging.Handler) -> None: ...
    def removeHandler(self, hdlr: logging.Handler) -> None: ...
    def hasHandlers(self) -> bool: ...

    # ---- Logging calls ----
    def debug(self, msg: object, *args: object, **kwargs: object) -> None: ...
    def info(self, msg: object, *args: object, **kwargs: object) -> None: ...
    def warning(self, msg: object, *args: object, **kwargs: object) -> None: ...
    def warn(self, msg: object, *args: object, **kwargs: object) -> None: ...
    def error(self, msg: object, *args: object, **kwargs: object) -> None: ...
    def exception(self, msg: object, *args: object, exc_info: ExcInfo = True, **kwargs: object) -> None: ...
    def critical(self, msg: object, *args: object, **kwargs: object) -> None: ...
    def fatal(self, msg: object, *args: object, **kwargs: object) -> None: ...

    # The lowest-level primitive
    def log(self, level: Level, msg: object, *args: object, **kwargs: object) -> None: ...

    # ---- Record-level / formatting ----
    def makeRecord(
        self,
        name: str,
        level: Level,
        fn: str,
        lno: int,
        msg: object,
        args: tuple[object, ...] | Mapping[str, object],
        exc_info: ExcInfo,
        func: str | None = None,
        extra: Mapping[str, object] | None = None,
        sinfo: str | None = None,
    ) -> logging.LogRecord: ...

    def handle(self, record: logging.LogRecord) -> None: ...
    def findCaller(
        self,
        stack_info: bool = False,
        stacklevel: int = 1,
    ) -> tuple[str, int, str, str | None]: ...

    def callHandlers(self, record: logging.LogRecord) -> None: ...
    def getMessage(self) -> str: ...  # NOTE: actually on LogRecord; included rarely; safe to omit if you want

    def _log(
        self,
        level: Level,
        msg: object,
        args: tuple[object, ...] | Mapping[str, object],
        exc_info: ExcInfo = None,
        extra: Mapping[str, object] | None = None,
        stack_info: bool = False,
        stacklevel: int = 1,
    ) -> None: ...

    # ---- Filters ----
    def addFilter(self, filt: logging.Filter) -> None: ...
    def removeFilter(self, filt: logging.Filter) -> None: ...
    # NOTE(review): on stdlib Logger (via Filterer), `filters` is a plain list
    # attribute, not a property — a property declaration still matches
    # structurally at read sites, but confirm no caller assigns to it.
    @property
    def filters(self) -> list[logging.Filter]: ...

    def filter(self, record: logging.LogRecord) -> bool: ...

    # ---- Convenience helpers ----
    def setFormatter(self, fmt: logging.Formatter) -> None: ...  # mostly on handlers; present on adapters sometimes
    def debugStack(self, msg: object, *args: object, **kwargs: object) -> None: ...  # not std; safe no-op if absent

    # ---- stdlib dictConfig-friendly / extra storage ----
    # Logger has `manager` and can have arbitrary attributes; Protocol can't express arbitrary attrs,
    # but we can at least include `__dict__` to make "extra attributes" less painful.
    __dict__: MutableMapping[str, Any]

    # ---- Transformers logger specific methods ----
    def warning_advice(self, msg: object, *args: object, **kwargs: object) -> None: ...
    def warning_once(self, msg: object, *args: object, **kwargs: object) -> None: ...
    def info_once(self, msg: object, *args: object, **kwargs: object) -> None: ...
10 changes: 7 additions & 3 deletions src/transformers/utils/attention_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,14 @@ def visualize_attention_mask(self, input_sentence: str, suffix=""):
past_key_values=None,
)

if causal_mask is not None:
attention_mask = ~causal_mask.bool()
else:
if causal_mask is None:
# attention_mask must be a tensor here
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).expand(batch_size, 1, seq_length, seq_length)
elif isinstance(causal_mask, torch.Tensor):
attention_mask = ~causal_mask.to(dtype=torch.bool)
else:
attention_mask = ~causal_mask

top_bottom_border = "##" * (
len(f"Attention visualization for {self.config.model_type} | {self.mapped_cls}") + 4
) # Box width adjusted to text length
Expand Down
31 changes: 17 additions & 14 deletions src/transformers/utils/chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,7 @@
from datetime import datetime
from functools import lru_cache
from inspect import isfunction
from typing import (
Any,
Literal,
Union,
get_args,
get_origin,
get_type_hints,
)
from typing import Any, Literal, Union, get_args, get_origin, get_type_hints, no_type_check

from packaging import version

Expand All @@ -41,6 +34,10 @@

if is_jinja_available():
import jinja2
import jinja2.exceptions
import jinja2.ext
import jinja2.nodes
import jinja2.runtime
from jinja2.ext import Extension
from jinja2.sandbox import ImmutableSandboxedEnvironment
else:
Expand Down Expand Up @@ -181,10 +178,11 @@ def _parse_type_hint(hint: str) -> dict:
def _convert_type_hints_to_json_schema(func: Callable) -> dict:
type_hints = get_type_hints(func)
signature = inspect.signature(func)
func_name = getattr(func, "__name__", "operation")
required = []
for param_name, param in signature.parameters.items():
if param.annotation == inspect.Parameter.empty:
raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}")
raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func_name}")
if param.default == inspect.Parameter.empty:
required.append(param_name)

Expand Down Expand Up @@ -340,10 +338,10 @@ def get_json_schema(func: Callable) -> dict:
}
"""
doc = inspect.getdoc(func)
func_name = getattr(func, "__name__", "operation")

if not doc:
raise DocstringParsingException(
f"Cannot generate JSON schema for {func.__name__} because it has no docstring!"
)
raise DocstringParsingException(f"Cannot generate JSON schema for {func_name} because it has no docstring!")
doc = doc.strip()
main_doc, param_descriptions, return_doc = parse_google_format_docstring(doc)

Expand All @@ -354,7 +352,7 @@ def get_json_schema(func: Callable) -> dict:
for arg, schema in json_schema["properties"].items():
if arg not in param_descriptions:
raise DocstringParsingException(
f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'"
f"Cannot generate JSON schema for {func_name} because the docstring has no description for the argument '{arg}'"
)
desc = param_descriptions[arg]
enum_choices = re.search(r"\(choices:\s*(.*?)\)\s*$", desc, flags=re.IGNORECASE)
Expand All @@ -363,7 +361,7 @@ def get_json_schema(func: Callable) -> dict:
desc = enum_choices.string[: enum_choices.start()].strip()
schema["description"] = desc

output = {"name": func.__name__, "description": main_doc, "parameters": json_schema}
output = {"name": func_name, "description": main_doc, "parameters": json_schema}
if return_dict is not None:
output["return"] = return_dict
return {"type": "function", "function": output}
Expand All @@ -389,6 +387,11 @@ def _render_with_assistant_indices(

@lru_cache
def _compile_jinja_template(chat_template):
    """Memoized entry point for compiling a chat template string.

    ``lru_cache`` keys on the template string, so each distinct template is
    compiled once per process. NOTE(review): the actual compilation lives in
    ``_cached_compile_jinja_template`` — presumably split so its
    ``@no_type_check`` decorator applies only to the compile logic and not to
    this cached wrapper; confirm.
    """
    return _cached_compile_jinja_template(chat_template)


@no_type_check
def _cached_compile_jinja_template(chat_template):
if not is_jinja_available():
raise ImportError(
"apply_chat_template requires jinja2 to be installed. Please install it using `pip install jinja2`."
Expand Down
Loading