39 commits
f335cc7
custom tests for selective activation checkpointing for layernorm mlp
jaimec00 Oct 27, 2025
e349f46
add selective layernorm mlp to te.pytorch
jaimec00 Oct 27, 2025
aa18e74
update test and fix SLNMLP bug
jaimec00 Oct 27, 2025
8f50f4a
implement slnmlp
jaimec00 Oct 28, 2025
f6f034b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
00841c2
fix tests pointed out by greptile app bot, still pass
jaimec00 Oct 28, 2025
955f068
minor formatting change in tests/pytorch/selective_layernorm_mlp/dist…
jaimec00 Oct 28, 2025
5e47706
remove duplicate import in test/pytorch/selective_layernorm_mlp/test_…
jaimec00 Oct 28, 2025
9a69a6c
clean up tests, remove unused imports
jaimec00 Oct 28, 2025
ea8270d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
f896579
remove unused paths in test_deffered_init
jaimec00 Oct 28, 2025
9ee2df8
fix issue with zero_centered_gamma in test_numerics reference impleme…
jaimec00 Oct 28, 2025
05d3908
clean up tests
jaimec00 Oct 28, 2025
435fe9c
make comparison.py more extensive, cleaner output
jaimec00 Oct 28, 2025
903f37e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
0a31a70
fix small typo in tests/pytorch/selective_layernorm_mlp/compare.py
jaimec00 Oct 28, 2025
418dce6
fix typo by grepbot in compare.py
jaimec00 Oct 28, 2025
31cdd9d
make selectiuve activation checkpointing optional in slnmlp via check…
jaimec00 Oct 28, 2025
fae6052
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
a6a927e
add comments to clarify logic
jaimec00 Oct 29, 2025
16b816b
add checkpoint param to pytests, change compare.py to compare checkpp…
jaimec00 Oct 29, 2025
f623124
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2025
ff6f58f
refactor tests to call modified LayerNormMLP
jaimec00 Oct 29, 2025
8cbdb91
refactor to implement selective activation checkpointing directly int…
jaimec00 Oct 29, 2025
c46ad4c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2025
b068c5f
fix skip explanation for cuda_graphs.py
jaimec00 Oct 29, 2025
f0670ed
make _recompute deal with lists instead of tuples
jaimec00 Oct 29, 2025
5a34186
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2025
e12fa7c
fix MOST cuda graph failures by initializing identical quantizers dur…
jaimec00 Oct 30, 2025
9b29e49
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 30, 2025
cc52db5
fix cuda graphs issue, all tests pass now
jaimec00 Oct 31, 2025
e94ef33
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2025
ebd2329
fix small logic bugs, clean up
jaimec00 Nov 1, 2025
212fadb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2025
402e5f9
integrate tests into main testing scripts
jaimec00 Nov 5, 2025
483bbf6
incorporate rng state tracking in checkpointing
jaimec00 Nov 5, 2025
643a3c8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 5, 2025
0d0255f
clean up tests
jaimec00 Nov 5, 2025
d86bc00
fix return type mismatches
jaimec00 Nov 5, 2025
1 change: 1 addition & 0 deletions tests/pytorch/distributed/run_numerics.py
@@ -1030,6 +1030,7 @@ def test_layernorm_mlp():
{"return_bias": True},
{"return_layernorm_output": True},
{"delay_wgrad_compute": True},
{"checkpoint": True},
]

for kwargs in kwargs_list:
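Note on the new entry: {"checkpoint": True} runs the existing LayerNormMLP numerics configuration with selective activation checkpointing enabled, trading extra recompute in the backward pass for lower activation memory in the forward pass. A minimal sketch of what that flag does, with illustrative sizes that are not taken from the test:

import torch
from transformer_engine.pytorch import LayerNormMLP

# Illustrative sizes; run_numerics.py pulls these from its own model config.
hidden_size, ffn_hidden_size, seq_len = 128, 512, 256

# checkpoint=True enables selective activation checkpointing inside the module.
mlp = LayerNormMLP(hidden_size, ffn_hidden_size, checkpoint=True, device="cuda")

x = torch.randn(seq_len, hidden_size, device="cuda")
out = mlp(x)          # forward saves fewer intermediate activations
out.sum().backward()  # backward recomputes the dropped activations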
2 changes: 1 addition & 1 deletion tests/pytorch/distributed/test_numerics.py
@@ -13,7 +13,7 @@
"""
Distributed numerics tests

These tests test the numerical corectness of the TransformerEngine layers.
These tests test the numerical correctness of the TransformerEngine layers.
Tests are parametrized by the layer and fp8 precision.
One test consists of running multiple configurations from file run_numerics.py
Such design is due to the fact the initialization of one test is long
169 changes: 169 additions & 0 deletions tests/pytorch/layernorm_mlp/test_selective_activation_checkpoint.py
@@ -0,0 +1,169 @@
import torch
from transformer_engine.pytorch import LayerNormMLP
import pytest

torch.manual_seed(1234)
device = torch.device("cuda")


class _Sequential(torch.nn.Sequential):
"""Sequential model that forwards keyword arguments to modules"""

def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
x = input_
for module in self:
x = module(x, **kwargs)
return x


class ModelConfig:
def __init__(
self,
hidden_size: int = 128,
ffn_hidden_size: int = 512,
layers: int = 1,
):
self._hidden_size = hidden_size
self._ffn_hidden_size = ffn_hidden_size
self._layers = layers

def build(self):

ln_list, sln_list = [], []
for _ in range(self._layers):
ln = LayerNormMLP(self._hidden_size, self._ffn_hidden_size, checkpoint=False).to(device)
sln = LayerNormMLP(self._hidden_size, self._ffn_hidden_size, checkpoint=True).to(device)
with torch.no_grad():
sln.layer_norm_weight = torch.nn.Parameter(ln.layer_norm_weight.clone())
sln.layer_norm_bias = torch.nn.Parameter(ln.layer_norm_bias.clone())
sln.fc1_weight = torch.nn.Parameter(ln.fc1_weight.clone())
sln.fc2_weight = torch.nn.Parameter(ln.fc2_weight.clone())
sln.fc1_bias = torch.nn.Parameter(ln.fc1_bias.clone())
sln.fc2_bias = torch.nn.Parameter(ln.fc2_bias.clone())
ln_list.append(ln)
sln_list.append(sln)

ln_model = _Sequential(*ln_list)
sln_model = _Sequential(*sln_list)

return ln_model, sln_model


config = {
"small": ModelConfig(128, 512, 12),
"medium": ModelConfig(512, 2048, 12),
"large": ModelConfig(1024, 4096, 12),
"huge": ModelConfig(2048, 8192, 12),
}

seq_sizes = [2**7, 2**10, 2**14, 2**16]


def _warmup(model, tensor):
for _ in range(3):
model(tensor).sum().backward()


def _run_fwd(model, tensor):

torch.cuda.reset_peak_memory_stats(device)
start_time, end_time = torch.cuda.Event(enable_timing=True), torch.cuda.Event(
enable_timing=True
)

torch.cuda.synchronize()
start_mem = torch.cuda.memory_allocated(device)
start_time.record()
out = model(tensor)
end_time.record()
end_time.synchronize()
elapsed = start_time.elapsed_time(end_time)
peak_mem = torch.cuda.max_memory_allocated(device)
mem = float(peak_mem - start_mem)

return out, elapsed, mem


def _run_bwd(model, out):

model.zero_grad(set_to_none=False)
loss = out.sum()

torch.cuda.reset_peak_memory_stats(device)
start_time, end_time = torch.cuda.Event(enable_timing=True), torch.cuda.Event(
enable_timing=True
)

torch.cuda.synchronize()
start_mem = torch.cuda.memory_allocated(device)
start_time.record()
loss.backward()
end_time.record()
end_time.synchronize()
elapsed = start_time.elapsed_time(end_time)
peak_mem = torch.cuda.max_memory_allocated(device)
mem = float(peak_mem - start_mem)

param_grads = _collect_param_grads(model)
return param_grads, elapsed, mem


def _max_diff(ref, other):
"""Return max absolute difference between two tensors or collections."""
if ref is None or other is None:
return 0.0
if isinstance(ref, (list, tuple)):
diffs = [_max_diff(r, o) for r, o in zip(ref, other)]
return max(diffs) if diffs else 0.0
return torch.max(torch.abs(ref.detach() - other.detach())).item()


def _collect_param_grads(model):
grads = {}
for name, param in model.named_parameters():
if param.grad is None:
continue
key = _param_key(name)
if key is not None:
grads[key] = param.grad.detach().clone()
return grads


def _param_key(name):
return name.split(".")[-1]


@pytest.mark.parametrize("size", config.keys())
@pytest.mark.parametrize("seq_size", seq_sizes)
def test_selective_activation_checkpoint(size, seq_size):

ln_model, sln_model = config[size].build()
data = torch.randn((seq_size, config[size]._hidden_size), device=device)

_warmup(ln_model, data)
ln_fwd_out, ln_fwd_time, ln_fwd_mem = _run_fwd(ln_model, data)
ln_grads, ln_bwd_time, ln_bwd_mem = _run_bwd(ln_model, ln_fwd_out)

_warmup(sln_model, data)
sln_fwd_out, sln_fwd_time, sln_fwd_mem = _run_fwd(sln_model, data)
sln_grads, sln_bwd_time, sln_bwd_mem = _run_bwd(sln_model, sln_fwd_out)

assert ln_fwd_mem > 6 * sln_fwd_mem, (
"selective activation checkpointing does not reduce forward memory by 6X, only by"
f" {ln_fwd_mem/sln_fwd_mem}!"
)
assert (
ln_bwd_time < sln_bwd_time
), "selective activation activation checkpointing backward pass is slower than native!"
assert _max_diff(ln_fwd_out, sln_fwd_out) == 0.0, "outputs are not equal!"
for key in [
"layer_norm_weight",
"layer_norm_bias",
"fc1_weight",
"fc1_bias",
"fc2_weight",
"fc2_bias",
]:
assert (
_max_diff(ln_grads[key], sln_grads[key]) == 0.0
), f"gradients for {key} are not equal!"
19 changes: 16 additions & 3 deletions tests/pytorch/test_cuda_graphs.py
@@ -176,7 +176,8 @@ def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
# creating TMA descriptor for MXFP8 quantization.
"linear",
"transformer",
"layernorm_mlp",
"layernorm_mlp_nocheckpoint",
"layernorm_mlp_checkpoint",
"layernorm_linear",
"mha",
"linear_op",
@@ -218,12 +219,23 @@ def _test_cuda_graphs(
)
for _ in range(num_layers)
]
elif module == "layernorm_mlp":
elif module == "layernorm_mlp_nocheckpoint":
modules = [
LayerNormMLP(
model_config.hidden_size,
model_config.hidden_size,
params_dtype=dtype,
checkpoint=False,
)
for _ in range(num_layers)
]
elif module == "layernorm_mlp_checkpoint":
modules = [
LayerNormMLP(
model_config.hidden_size,
model_config.hidden_size,
params_dtype=dtype,
checkpoint=True,
)
for _ in range(num_layers)
]
Expand Down Expand Up @@ -383,7 +395,8 @@ def test_make_graphed_callables(

_test_make_graphed_callables_with_fp8_weight_caching_modules = [
"transformer",
"layernorm_mlp",
"layernorm_mlp_nocheckpoint",
"layernorm_mlp_checkpoint",
"layernorm_linear",
"linear",
"mha",
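For context, a rough sketch of how one of the new layernorm_mlp_checkpoint modules would be captured into a CUDA graph in this test. The make_graphed_callables call is assumed to follow the torch.cuda.make_graphed_callables-style (callable, sample_args) convention; it is not copied from the collapsed part of the diff, and the sizes are illustrative:

import torch
import transformer_engine.pytorch as te

hidden_size, dtype = 128, torch.bfloat16  # illustrative values

module = te.LayerNormMLP(hidden_size, hidden_size, params_dtype=dtype, checkpoint=True).cuda()
sample_input = torch.randn(8, hidden_size, dtype=dtype, device="cuda", requires_grad=True)

# Capture forward/backward into a CUDA graph and use the graphed module as a drop-in.
graphed = te.make_graphed_callables(module, (sample_input,))
out = graphed(sample_input)
out.sum().backward()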
14 changes: 10 additions & 4 deletions tests/pytorch/test_numerics.py
@@ -186,7 +186,7 @@ def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
return dict(rtol=1e-3, atol=1e-5)
if dtype == torch.bfloat16:
return dict(rtol=1.6e-2, atol=1e-5)
raise ValueError(f"Unsuppored dtype ({dtype})")
raise ValueError(f"Unsupported dtype ({dtype})")


def assert_allclose(
@@ -1364,7 +1364,7 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, recipe=recipe)
te_outputs_ref = _test_granular_accuracy(te_linear_ref, bs, dtype, config, recipe=recipe)

# Shoule be bit-wise match
# Should be bit-wise match
for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)

@@ -1622,7 +1622,10 @@ def test_layernorm_linear_accuracy_delay_wgrad_compute(
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("return_bias", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, return_bias, bias):
@pytest.mark.parametrize("checkpoint", all_boolean)
def test_layernorm_mlp_accuracy(
dtype, bs, model, activation, normalization, return_bias, bias, checkpoint
):
config = model_configs[model]

te_ln_mlp = TestReturnBiasModule(
@@ -1635,6 +1638,7 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, ret
return_bias=return_bias,
bias=bias,
device="cuda",
checkpoint=checkpoint,
)

torch_ln_mlp = (
@@ -1696,8 +1700,9 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, ret
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
@pytest.mark.parametrize("checkpoint", all_boolean)
def test_layernorm_mlp_accuracy_delay_wgrad_compute(
dtype, bs, model, bias, fuse_wgrad_accumulation
dtype, bs, model, bias, fuse_wgrad_accumulation, checkpoint
):
config = model_configs[model]

@@ -1708,6 +1713,7 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
bias=bias,
params_dtype=dtype,
device="cuda",
checkpoint=checkpoint,
delay_wgrad_compute=True,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
10 changes: 7 additions & 3 deletions tests/pytorch/test_recipe.py
@@ -29,7 +29,6 @@
)
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling
import transformer_engine_torch as tex

# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
@@ -481,7 +480,8 @@ def test_dynamic_recipe_update(
[
Linear,
LayerNormLinear,
LayerNormMLP,
(LayerNormMLP, False), # (module, checkpoint=False)
(LayerNormMLP, True), # (module, checkpoint=True)
GroupedLinear,
],
)
@@ -495,7 +495,11 @@ def test_quantizer_update(self, module_class):
if module_class == GroupedLinear:
module = module_class(1, in_features, out_features).cuda()
else:
module = module_class(in_features, out_features).cuda()
if isinstance(module_class, tuple) and module_class[0] == LayerNormMLP:
module_class, checkpoint = module_class
module = module_class(in_features, out_features, checkpoint=checkpoint).cuda()
else:
module = module_class(in_features, out_features).cuda()

x = torch.randn(batch_size, in_features, device="cuda")
recipe = DelayedScaling(amax_history_len=1)
3 changes: 3 additions & 0 deletions tests/pytorch/test_sanity.py
@@ -525,6 +525,7 @@ def test_sanity_grouped_linear(
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
@pytest.mark.parametrize("checkpoint", all_boolean)
def test_sanity_layernorm_mlp(
dtype,
fp8_recipe,
Expand All @@ -535,6 +536,7 @@ def test_sanity_layernorm_mlp(
activation,
normalization,
microbatching,
checkpoint,
):
config = model_configs[model]

@@ -535,6 +536,7 @@
normalization=normalization,
params_dtype=dtype,
device="cuda",
checkpoint=checkpoint,
)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
