From 9e07e7ccedd4dbb68ef22a64ba51a281ba0df9c6 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 17 Dec 2025 19:58:22 +0100 Subject: [PATCH 1/6] Run CI tests for draft PR --- .github/workflows/tests.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9d571aec883..20b2579a415 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -25,7 +25,7 @@ jobs: check_code_quality: name: Check code quality runs-on: ubuntu-latest - if: github.event.pull_request.draft == false +# if: github.event.pull_request.draft == false steps: - uses: actions/checkout@v4 - name: Set up Python 3.12 @@ -40,7 +40,7 @@ jobs: name: Tests strategy: matrix: - python-version: ['3.10', '3.11', '3.12', '3.13'] + python-version: ['3.10'] # , '3.11', '3.12', '3.13'] fail-fast: false runs-on: group: aws-g4dn-2xlarge @@ -50,7 +50,7 @@ jobs: defaults: run: shell: bash - if: github.event.pull_request.draft == false +# if: github.event.pull_request.draft == false steps: - name: Git checkout uses: actions/checkout@v4 From a8c742189bdb830d60024ebc02ed2ccb976796ca Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Thu, 18 Dec 2025 12:51:06 +0100 Subject: [PATCH 2/6] Create ProfilingContext --- trl/extras/profiling.py | 137 ++++++++++++++++++++++++++++++++++------ 1 file changed, 119 insertions(+), 18 deletions(-) diff --git a/trl/extras/profiling.py b/trl/extras/profiling.py index 7fc7b40b5aa..0d9eef133b8 100644 --- a/trl/extras/profiling.py +++ b/trl/extras/profiling.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib import functools import time -from collections.abc import Callable, Generator +from collections.abc import Callable from transformers import Trainer from transformers.integrations import is_mlflow_available, is_wandb_available @@ -28,17 +27,116 @@ import mlflow -@contextlib.contextmanager -def profiling_context(trainer: Trainer, name: str) -> Generator[None, None, None]: +class ProfilingContext: """ - A context manager function for profiling a block of code. Results are logged to Weights & Biases or MLflow - depending on the trainer's configuration. + Context manager for profiling code blocks with configurable logging. + + This class handles timing of code execution and logging metrics to various backends + (Weights & Biases, MLflow) without being coupled to the Trainer class. + + Args: + name (`str`): + Name of the profiling context. Used in the metric name. + report_to (`list` of `str`): + List of integrations to report metrics to (e.g., ["wandb", "mlflow"]). + is_main_process (`bool`, *optional*, defaults to `True`): + Whether this is the main process in distributed training. Metrics are only + logged from the main process. + step (`int` or `None`, *optional*): + Training step to associate with the logged metrics. + metric_prefix (`str`, *optional*, defaults to `"profiling/Time taken"`): + Prefix for the metric name in logs. + + Example: + ```python + # Direct usage + from trl.extras.profiling import ProfilingContext + + with ProfilingContext( + name="MyClass.expensive_operation", + report_to=["wandb"], + is_main_process=True, + step=100 + ): + # Code to profile + result = expensive_computation() + + # With Trainer (backwards compatible via profiling_context function) + from transformers import Trainer + from trl.extras.profiling import profiling_context + + class MyTrainer(Trainer): + def some_method(self): + with profiling_context(self, "matrix_multiplication"): + result = matrix_multiply() + ``` + """ + + def __init__( + self, + name: str, + report_to: list[str], + is_main_process: bool = True, + step: int | None = None, + metric_prefix: str = "profiling/Time taken", + ): + self.name = name + self.report_to = report_to + self.is_main_process = is_main_process + self.step = step + self.metric_prefix = metric_prefix + self._start_time = None + + def __enter__(self): + """Start timing when entering the context.""" + self._start_time = time.perf_counter() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Stop timing and log metrics when exiting the context.""" + if self._start_time is not None: + duration = time.perf_counter() - self._start_time + self._log_metrics(duration) + return False + + def _log_metrics(self, duration: float) -> None: + """ + Log profiling metrics to configured backends. + + Args: + duration (`float`): + Execution time in seconds. + """ + if not self.is_main_process: + return + + metric_name = f"{self.metric_prefix}: {self.name}" + metrics = {metric_name: duration} + + # Log to Weights & Biases if configured + if "wandb" in self.report_to and is_wandb_available() and wandb.run is not None: + wandb.log(metrics, step=self.step) + + # Log to MLflow if configured + if "mlflow" in self.report_to and is_mlflow_available() and mlflow.active_run() is not None: + mlflow.log_metrics(metrics, step=self.step) + + +def profiling_context(trainer: Trainer, name: str) -> ProfilingContext: + """ + Factory function to create a ProfilingContext from a Trainer instance. + + This function maintains backwards compatibility with existing code while using + the decoupled ProfilingContext class internally. Args: trainer (`~transformers.Trainer`): - Trainer object. + Trainer object containing configuration for logging. name (`str`): - Name of the block to be profiled. Used as a key in the logged dictionary. + Name of the block to be profiled. Will be prefixed with the trainer class name. + + Returns: + `ProfilingContext`: A configured profiling context manager. Example: ```python @@ -55,27 +153,30 @@ def some_method(self): result = A @ B # Matrix multiplication ``` """ - start_time = time.perf_counter() - yield - end_time = time.perf_counter() - duration = end_time - start_time - - profiling_metrics = {f"profiling/Time taken: {trainer.__class__.__name__}.{name}": duration} - if "wandb" in trainer.args.report_to and wandb.run is not None and trainer.accelerator.is_main_process: - wandb.log(profiling_metrics) + context_name = f"{trainer.__class__.__name__}.{name}" + step = trainer.state.global_step if hasattr(trainer, "state") else None - if "mlflow" in trainer.args.report_to and mlflow.run is not None and trainer.accelerator.is_main_process: - mlflow.log_metrics(profiling_metrics, step=trainer.state.global_step) + return ProfilingContext( + name=context_name, + report_to=trainer.args.report_to, + is_main_process=trainer.accelerator.is_main_process, + step=step, + ) def profiling_decorator(func: Callable) -> Callable: """ Decorator to profile a function and log execution time using [`extras.profiling.profiling_context`]. + This decorator works with methods that have access to a trainer instance (typically as `self`). + Args: func (`Callable`): Function to be profiled. + Returns: + `Callable`: Wrapped function that profiles execution time. + Example: ```python from transformers import Trainer From f469ae2f34bbb4ecc50929bbdaf7ebd3b5037fb4 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Fri, 19 Dec 2025 08:02:12 +0100 Subject: [PATCH 3/6] Remove step from wandb.log because it was not present before --- trl/extras/profiling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/extras/profiling.py b/trl/extras/profiling.py index 0d9eef133b8..0db8d68a769 100644 --- a/trl/extras/profiling.py +++ b/trl/extras/profiling.py @@ -115,7 +115,7 @@ def _log_metrics(self, duration: float) -> None: # Log to Weights & Biases if configured if "wandb" in self.report_to and is_wandb_available() and wandb.run is not None: - wandb.log(metrics, step=self.step) + wandb.log(metrics) # Log to MLflow if configured if "mlflow" in self.report_to and is_mlflow_available() and mlflow.active_run() is not None: From aba3978449033803b06a03c1a61b97ecc5fb8138 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Fri, 19 Dec 2025 08:02:19 +0100 Subject: [PATCH 4/6] Revert "Run CI tests for draft PR" This reverts commit 9e07e7ccedd4dbb68ef22a64ba51a281ba0df9c6. --- .github/workflows/tests.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 20b2579a415..9d571aec883 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -25,7 +25,7 @@ jobs: check_code_quality: name: Check code quality runs-on: ubuntu-latest -# if: github.event.pull_request.draft == false + if: github.event.pull_request.draft == false steps: - uses: actions/checkout@v4 - name: Set up Python 3.12 @@ -40,7 +40,7 @@ jobs: name: Tests strategy: matrix: - python-version: ['3.10'] # , '3.11', '3.12', '3.13'] + python-version: ['3.10', '3.11', '3.12', '3.13'] fail-fast: false runs-on: group: aws-g4dn-2xlarge @@ -50,7 +50,7 @@ jobs: defaults: run: shell: bash -# if: github.event.pull_request.draft == false + if: github.event.pull_request.draft == false steps: - name: Git checkout uses: actions/checkout@v4 From 736cbe26afe195733e6910d15b72bd9cd9a0465d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 19 Dec 2025 16:31:10 +0000 Subject: [PATCH 5/6] apply style --- trl/extras/profiling.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/trl/extras/profiling.py b/trl/extras/profiling.py index 0db8d68a769..c92459b6055 100644 --- a/trl/extras/profiling.py +++ b/trl/extras/profiling.py @@ -31,8 +31,8 @@ class ProfilingContext: """ Context manager for profiling code blocks with configurable logging. - This class handles timing of code execution and logging metrics to various backends - (Weights & Biases, MLflow) without being coupled to the Trainer class. + This class handles timing of code execution and logging metrics to various backends (Weights & Biases, MLflow) + without being coupled to the Trainer class. Args: name (`str`): @@ -40,8 +40,7 @@ class ProfilingContext: report_to (`list` of `str`): List of integrations to report metrics to (e.g., ["wandb", "mlflow"]). is_main_process (`bool`, *optional*, defaults to `True`): - Whether this is the main process in distributed training. Metrics are only - logged from the main process. + Whether this is the main process in distributed training. Metrics are only logged from the main process. step (`int` or `None`, *optional*): Training step to associate with the logged metrics. metric_prefix (`str`, *optional*, defaults to `"profiling/Time taken"`): @@ -56,7 +55,7 @@ class ProfilingContext: name="MyClass.expensive_operation", report_to=["wandb"], is_main_process=True, - step=100 + step=100, ): # Code to profile result = expensive_computation() @@ -65,6 +64,7 @@ class ProfilingContext: from transformers import Trainer from trl.extras.profiling import profiling_context + class MyTrainer(Trainer): def some_method(self): with profiling_context(self, "matrix_multiplication"): @@ -126,8 +126,8 @@ def profiling_context(trainer: Trainer, name: str) -> ProfilingContext: """ Factory function to create a ProfilingContext from a Trainer instance. - This function maintains backwards compatibility with existing code while using - the decoupled ProfilingContext class internally. + This function maintains backwards compatibility with existing code while using the decoupled ProfilingContext class + internally. Args: trainer (`~transformers.Trainer`): From 691d458930f6042f0d881f02630deb19a0d6d992 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Fri, 19 Dec 2025 17:45:50 +0100 Subject: [PATCH 6/6] Address requested change --- trl/extras/profiling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/extras/profiling.py b/trl/extras/profiling.py index c92459b6055..a4076dc0977 100644 --- a/trl/extras/profiling.py +++ b/trl/extras/profiling.py @@ -154,7 +154,7 @@ def some_method(self): ``` """ context_name = f"{trainer.__class__.__name__}.{name}" - step = trainer.state.global_step if hasattr(trainer, "state") else None + step = trainer.state.global_step return ProfilingContext( name=context_name,