From 69018734e0c5cd1f7ae7dae0214b6727526ca087 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Tue, 7 Oct 2025 09:36:59 +0200 Subject: [PATCH] Fix callable annotations --- tests/testing_utils.py | 3 ++- trl/extras/profiling.py | 6 +++--- trl/trainer/callbacks.py | 3 ++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/testing_utils.py b/tests/testing_utils.py index cbe677255b5..d012d26881c 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -16,6 +16,7 @@ import random import signal import warnings +from collections.abc import Callable import psutil import pytest @@ -73,7 +74,7 @@ def set_tmp_dir(self, tmp_path): self.tmp_dir = str(tmp_path) -def ignore_warnings(message: str = None, category: type[Warning] = Warning) -> callable: +def ignore_warnings(message: str = None, category: type[Warning] = Warning) -> Callable: """ Decorator to ignore warnings with a specific message and/or category. diff --git a/trl/extras/profiling.py b/trl/extras/profiling.py index 8e4c9e9188c..7fc7b40b5aa 100644 --- a/trl/extras/profiling.py +++ b/trl/extras/profiling.py @@ -15,7 +15,7 @@ import contextlib import functools import time -from collections.abc import Generator +from collections.abc import Callable, Generator from transformers import Trainer from transformers.integrations import is_mlflow_available, is_wandb_available @@ -68,12 +68,12 @@ def some_method(self): mlflow.log_metrics(profiling_metrics, step=trainer.state.global_step) -def profiling_decorator(func: callable) -> callable: +def profiling_decorator(func: Callable) -> Callable: """ Decorator to profile a function and log execution time using [`extras.profiling.profiling_context`]. Args: - func (`callable`): + func (`Callable`): Function to be profiled. Example: diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py index 68fb6f97b72..2c1240b0fcb 100644 --- a/trl/trainer/callbacks.py +++ b/trl/trainer/callbacks.py @@ -14,6 +14,7 @@ import logging import os +from collections.abc import Callable from typing import Optional, Union import pandas as pd @@ -583,7 +584,7 @@ def __init__( self, trainer: Trainer, project_name: Optional[str] = None, - scorers: Optional[dict[str, callable]] = None, + scorers: Optional[dict[str, Callable]] = None, generation_config: Optional[GenerationConfig] = None, num_prompts: Optional[int] = None, dataset_name: str = "eval_dataset",