Skip to content
137 changes: 119 additions & 18 deletions trl/extras/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import functools
import time
from collections.abc import Callable, Generator
from collections.abc import Callable

from transformers import Trainer
from transformers.integrations import is_mlflow_available, is_wandb_available
Expand All @@ -28,17 +27,116 @@
import mlflow


@contextlib.contextmanager
def profiling_context(trainer: Trainer, name: str) -> Generator[None, None, None]:
class ProfilingContext:
"""
A context manager function for profiling a block of code. Results are logged to Weights & Biases or MLflow
depending on the trainer's configuration.
Context manager for profiling code blocks with configurable logging.

This class handles timing of code execution and logging metrics to various backends (Weights & Biases, MLflow)
without being coupled to the Trainer class.

Args:
name (`str`):
Name of the profiling context. Used in the metric name.
report_to (`list` of `str`):
List of integrations to report metrics to (e.g., ["wandb", "mlflow"]).
is_main_process (`bool`, *optional*, defaults to `True`):
Whether this is the main process in distributed training. Metrics are only logged from the main process.
step (`int` or `None`, *optional*):
Training step to associate with the logged metrics.
metric_prefix (`str`, *optional*, defaults to `"profiling/Time taken"`):
Prefix for the metric name in logs.

Example:
```python
# Direct usage
from trl.extras.profiling import ProfilingContext

with ProfilingContext(
name="MyClass.expensive_operation",
report_to=["wandb"],
is_main_process=True,
step=100,
):
# Code to profile
result = expensive_computation()

# With Trainer (backwards compatible via profiling_context function)
from transformers import Trainer
from trl.extras.profiling import profiling_context


class MyTrainer(Trainer):
def some_method(self):
with profiling_context(self, "matrix_multiplication"):
result = matrix_multiply()
```
"""

def __init__(
self,
name: str,
report_to: list[str],
is_main_process: bool = True,
step: int | None = None,
metric_prefix: str = "profiling/Time taken",
):
self.name = name
self.report_to = report_to
self.is_main_process = is_main_process
self.step = step
self.metric_prefix = metric_prefix
self._start_time = None

def __enter__(self):
"""Start timing when entering the context."""
self._start_time = time.perf_counter()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
"""Stop timing and log metrics when exiting the context."""
if self._start_time is not None:
duration = time.perf_counter() - self._start_time
self._log_metrics(duration)
return False

def _log_metrics(self, duration: float) -> None:
"""
Log profiling metrics to configured backends.

Args:
duration (`float`):
Execution time in seconds.
"""
if not self.is_main_process:
return

metric_name = f"{self.metric_prefix}: {self.name}"
metrics = {metric_name: duration}

# Log to Weights & Biases if configured
if "wandb" in self.report_to and is_wandb_available() and wandb.run is not None:
wandb.log(metrics)

# Log to MLflow if configured
if "mlflow" in self.report_to and is_mlflow_available() and mlflow.active_run() is not None:
mlflow.log_metrics(metrics, step=self.step)


def profiling_context(trainer: Trainer, name: str) -> ProfilingContext:
    """
    Build a [`ProfilingContext`] from a Trainer instance.

    Kept for backwards compatibility: existing `with profiling_context(trainer, name):` call sites keep working,
    while the timing and logging are delegated to the decoupled [`ProfilingContext`] class.

    Args:
        trainer (`~transformers.Trainer`):
            Trainer object. Its `args.report_to`, `accelerator.is_main_process` and `state.global_step` are read to
            configure the returned context.
        name (`str`):
            Name of the block to be profiled. It is prefixed with the trainer class name to form the context name.

    Returns:
        `ProfilingContext`: A configured profiling context manager.

    Example:
    ```python
    from transformers import Trainer
    from trl.extras.profiling import profiling_context


    class MyTrainer(Trainer):
        def some_method(self):
            with profiling_context(self, "matrix_multiplication"):
                result = matrix_multiply()
    ```
    """
    # The context name has the form "<TrainerClassName>.<name>".
    qualified_name = f"{trainer.__class__.__name__}.{name}"
    current_step = trainer.state.global_step

    context_kwargs = {
        "name": qualified_name,
        "report_to": trainer.args.report_to,
        "is_main_process": trainer.accelerator.is_main_process,
        "step": current_step,
    }
    return ProfilingContext(**context_kwargs)


def profiling_decorator(func: Callable) -> Callable:
"""
Decorator to profile a function and log execution time using [`extras.profiling.profiling_context`].

This decorator works with methods that have access to a trainer instance (typically as `self`).

Args:
func (`Callable`):
Function to be profiled.

Returns:
`Callable`: Wrapped function that profiles execution time.

Example:
```python
from transformers import Trainer
Expand Down
Loading