diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py
index f4f2813a3..e545ea466 100644
--- a/benchmarks/float8/profile_linear_float8.py
+++ b/benchmarks/float8/profile_linear_float8.py
@@ -6,6 +6,7 @@
 import copy
 import io
+import functools
 import os
 import random
 from contextlib import nullcontext, redirect_stdout
@@ -22,6 +23,11 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import (
+    checkpoint,
+    create_selective_checkpoint_contexts,
+    CheckpointPolicy,
+)
 from torchao.float8.config import (
     CastConfig,
     Float8LinearConfig,
@@ -254,6 +260,22 @@ def profile_function(
     return prof


+# set up AC for max(abs(tensor))
+# context: https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.create_selective_checkpoint_contexts
+ops_to_save = [
+    torch.ops.aten.abs.default,
+    torch.ops.aten.max.default,
+]
+
+def policy_fn(ctx, op, *args, **kwargs):
+    if op in ops_to_save:
+        return CheckpointPolicy.MUST_SAVE
+    else:
+        return CheckpointPolicy.PREFER_RECOMPUTE
+
+context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
+
+
 def main(
     profile_path_prefix: pathlib.Path,
     compile: bool = True,
@@ -265,6 +287,7 @@ def main(
     dtype_filter: str = "both",
     add_inductor_metadata_to_trace: bool = True,
     enable_sync_amax_history: bool = True,
+    enable_activation_checkpointing: bool = False,
 ):
     assert model_type in ("linear", "ln_linear", "norm_ffn_norm", "norm_ffn_norm_small"), "unsupported"
     assert dtype_filter in ("both", "float8", "bfloat16")
@@ -294,6 +317,7 @@ def main(
     print(f"Compile is set to | {compile}")
     print(f"model_type is set to | {model_type}")
     print(f"scaling_repr is set to | {scaling_repr}")
+    print(f"enable_activation_checkpointing is set to {enable_activation_checkpointing}")

     device = "cuda"
     ref_dtype = torch.bfloat16
@@ -338,11 +362,17 @@ def main(
     convert_to_float8_training(m_float8, config=config)

     def ref_forw_backward(x):
-        out = m_ref(x)
+        if enable_activation_checkpointing:
+            out = checkpoint(m_ref, x, use_reentrant=False, context_fn=context_fn)
+        else:
+            out = m_ref(x)
         out.sum().backward()

     def float8_forw(x):
-        out = m_float8(x)
+        if enable_activation_checkpointing:
+            out = checkpoint(m_float8, x, use_reentrant=False, context_fn=context_fn)
+        else:
+            out = m_float8(x)
         return out

     sync_amax_history = sync_float8_amax_and_scale_history
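
For reference, below is a minimal standalone sketch (not part of the diff) of how the selective activation checkpointing pieces added above fit together: a policy_fn marks the abs/max ops as MUST_SAVE so their results are kept across the backward pass, while everything else is recomputed. ToyModel, its shapes, and the CPU tensors are hypothetical stand-ins for the benchmark's model and inputs.

import functools

import torch
import torch.nn as nn
from torch.utils.checkpoint import (
    checkpoint,
    create_selective_checkpoint_contexts,
    CheckpointPolicy,
)

# ops whose outputs we keep instead of recomputing (amax = max(abs(tensor)))
ops_to_save = [
    torch.ops.aten.abs.default,
    torch.ops.aten.max.default,
]

def policy_fn(ctx, op, *args, **kwargs):
    # save the listed ops, recompute everything else during backward
    if op in ops_to_save:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)

# hypothetical stand-in for the benchmark's model; shapes are arbitrary
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(256, 256)

    def forward(self, x):
        y = self.lin(x)
        # amax-style reduction whose intermediates the policy keeps
        return y * y.abs().max()

m = ToyModel()
x = torch.randn(16, 256, requires_grad=True)
# context_fn only applies with the non-reentrant checkpoint implementation
out = checkpoint(m, x, use_reentrant=False, context_fn=context_fn)
out.sum().backward()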