-
Notifications
You must be signed in to change notification settings - Fork 540
[PyTorch] Implement Selective Activation Checkpointing for LayerNormMLP with checkpoint flag #2311
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
jaimec00
wants to merge
39
commits into
NVIDIA:main
Choose a base branch
from
jaimec00:features/SLNMLP
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 36 commits
Commits
Show all changes
39 commits
Select commit
Hold shift + click to select a range
f335cc7
custom tests for selective activation checkpointing for layernorm mlp
jaimec00 e349f46
add selective layernorm mlp to te.pytorch
jaimec00 aa18e74
update test and fix SLNMLP bug
jaimec00 8f50f4a
implement slnmlp
jaimec00 f6f034b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 00841c2
fix tests pointed out by greptile app bot, still pass
jaimec00 955f068
minor formatting change in tests/pytorch/selective_layernorm_mlp/dist…
jaimec00 5e47706
remove duplicate import in test/pytorch/selective_layernorm_mlp/test_…
jaimec00 9a69a6c
clean up tests, remove unused imports
jaimec00 ea8270d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] f896579
remove unused paths in test_deffered_init
jaimec00 9ee2df8
fix issue with zero_centered_gamma in test_numerics reference impleme…
jaimec00 05d3908
clean up tests
jaimec00 435fe9c
make comparison.py more extensive, cleaner output
jaimec00 903f37e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 0a31a70
fix small typo in tests/pytorch/selective_layernorm_mlp/compare.py
jaimec00 418dce6
fix typo by grepbot in compare.py
jaimec00 31cdd9d
make selectiuve activation checkpointing optional in slnmlp via check…
jaimec00 fae6052
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] a6a927e
add comments to clarify logic
jaimec00 16b816b
add checkpoint param to pytests, change compare.py to compare checkpp…
jaimec00 f623124
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] ff6f58f
refactor tests to call modified LayerNormMLP
jaimec00 8cbdb91
refactor to implement selective activation checkpointing directly int…
jaimec00 c46ad4c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] b068c5f
fix skip explanation for cuda_graphs.py
jaimec00 f0670ed
make _recompute deal with lists instead of tuples
jaimec00 5a34186
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] e12fa7c
fix MOST cuda graph failures by initializing identical quantizers dur…
jaimec00 9b29e49
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] cc52db5
fix cuda graphs issue, all tests pass now
jaimec00 e94ef33
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] ebd2329
fix small logic bugs, clean up
jaimec00 212fadb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 402e5f9
integrate tests into main testing scripts
jaimec00 483bbf6
incorporate rng state tracking in checkpointing
jaimec00 643a3c8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 0d0255f
clean up tests
jaimec00 d86bc00
fix return type mismatches
jaimec00 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
156 changes: 156 additions & 0 deletions
156
tests/pytorch/layernorm_mlp/test_selective_activation_checkpoint.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,156 @@ | ||
| import torch | ||
| from transformer_engine.pytorch import LayerNormMLP | ||
| import pytest | ||
|
|
||
| torch.manual_seed(1234) | ||
| device = torch.device("cuda") | ||
|
|
||
|
|
||
| class _Sequential(torch.nn.Sequential): | ||
| """Sequential model that forwards keyword arguments to modules""" | ||
|
|
||
| def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor: | ||
| x = input_ | ||
| for module in self: | ||
| x = module(x, **kwargs) | ||
| return x | ||
|
|
||
|
|
||
| class ModelConfig: | ||
| def __init__( | ||
| self, | ||
| hidden_size: int = 128, | ||
| ffn_hidden_size: int = 512, | ||
| layers: int = 1, | ||
| ): | ||
| self._hidden_size = hidden_size | ||
| self._ffn_hidden_size = ffn_hidden_size | ||
| self._layers = layers | ||
|
|
||
| def build(self): | ||
|
|
||
| ln_list, sln_list = [], [] | ||
| for _ in range(self._layers): | ||
| ln = LayerNormMLP(self._hidden_size, self._ffn_hidden_size, checkpoint=False).to(device) | ||
| sln = LayerNormMLP(self._hidden_size, self._ffn_hidden_size, checkpoint=True).to(device) | ||
| with torch.no_grad(): | ||
| sln.layer_norm_weight = torch.nn.Parameter(ln.layer_norm_weight.clone()) | ||
| sln.layer_norm_bias = torch.nn.Parameter(ln.layer_norm_bias.clone()) | ||
| sln.fc1_weight = torch.nn.Parameter(ln.fc1_weight.clone()) | ||
| sln.fc2_weight = torch.nn.Parameter(ln.fc2_weight.clone()) | ||
| sln.fc1_bias = torch.nn.Parameter(ln.fc1_bias.clone()) | ||
| sln.fc2_bias = torch.nn.Parameter(ln.fc2_bias.clone()) | ||
| ln_list.append(ln) | ||
| sln_list.append(sln) | ||
|
|
||
| ln_model = _Sequential(*ln_list) | ||
| sln_model = _Sequential(*sln_list) | ||
|
|
||
| return ln_model, sln_model | ||
|
|
||
| config = { | ||
| "small": ModelConfig(128, 512, 12), | ||
| "medium": ModelConfig(512, 2048, 12), | ||
| "large": ModelConfig(1024, 4096, 12), | ||
| "huge": ModelConfig(2048, 8192, 12), | ||
| } | ||
|
|
||
| seq_sizes = [2**7, 2**10, 2**14, 2**16] | ||
|
|
||
| def _warmup(model, tensor): | ||
| for _ in range(10): | ||
| model(tensor).sum().backward() | ||
|
|
||
| def _run_fwd(model, tensor): | ||
|
|
||
| torch.cuda.reset_peak_memory_stats(device) | ||
| start_time, end_time = torch.cuda.Event(enable_timing=True), torch.cuda.Event( | ||
| enable_timing=True | ||
| ) | ||
|
|
||
| torch.cuda.synchronize() | ||
| start_mem = torch.cuda.memory_allocated(device) | ||
| start_time.record() | ||
| out = model(tensor) | ||
| end_time.record() | ||
| end_time.synchronize() | ||
| elapsed = start_time.elapsed_time(end_time) | ||
| peak_mem = torch.cuda.max_memory_allocated(device) | ||
| mem = float(peak_mem - start_mem) | ||
|
|
||
| return out, elapsed, mem | ||
|
|
||
| def _run_bwd(model, out): | ||
|
|
||
| model.zero_grad(set_to_none=False) | ||
| loss = out.sum() | ||
|
|
||
| torch.cuda.reset_peak_memory_stats(device) | ||
| start_time, end_time = torch.cuda.Event(enable_timing=True), torch.cuda.Event( | ||
| enable_timing=True | ||
| ) | ||
|
|
||
| torch.cuda.synchronize() | ||
| start_mem = torch.cuda.memory_allocated(device) | ||
| start_time.record() | ||
| loss.backward() | ||
| end_time.record() | ||
| end_time.synchronize() | ||
| elapsed = start_time.elapsed_time(end_time) | ||
| peak_mem = torch.cuda.max_memory_allocated(device) | ||
| mem = float(peak_mem - start_mem) | ||
|
|
||
| param_grads = _collect_param_grads(model) | ||
| return param_grads, elapsed, mem | ||
|
|
||
| def _max_diff(ref, other): | ||
| """Return max absolute difference between two tensors or collections.""" | ||
| if ref is None or other is None: | ||
| return 0.0 | ||
| if isinstance(ref, (list, tuple)): | ||
| diffs = [_max_diff(r, o) for r, o in zip(ref, other)] | ||
| return max(diffs) if diffs else 0.0 | ||
| return torch.max(torch.abs(ref.detach() - other.detach())).item() | ||
|
|
||
| def _collect_param_grads(model): | ||
| grads = {} | ||
| for name, param in model.named_parameters(): | ||
| if param.grad is None: | ||
| continue | ||
| key = _param_key(name) | ||
| if key is not None: | ||
| grads[key] = param.grad.detach().clone() | ||
| return grads | ||
|
|
||
| def _param_key(name): | ||
| return name.split(".")[-1] | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("size", config.keys()) | ||
| @pytest.mark.parametrize("seq_size", seq_sizes) | ||
| def test_selective_activation_checkpoint(size, seq_size): | ||
|
|
||
| ln_model, sln_model = config[size].build() | ||
| data = torch.randn((seq_size, config[size]._hidden_size), device=device) | ||
|
|
||
| _warmup(ln_model, data.clone()) | ||
| ln_fwd_out, ln_fwd_time, ln_fwd_mem = _run_fwd(ln_model, data.clone()) | ||
| ln_grads, ln_bwd_time, ln_bwd_mem = _run_bwd(ln_model, ln_fwd_out) | ||
|
|
||
| _warmup(sln_model, data.clone()) | ||
| sln_fwd_out, sln_fwd_time, sln_fwd_mem = _run_fwd(sln_model, data.clone()) | ||
| sln_grads, sln_bwd_time, sln_bwd_mem = _run_bwd(sln_model, sln_fwd_out) | ||
|
|
||
| assert ln_fwd_mem > 6*sln_fwd_mem, "" | ||
| assert ln_bwd_time < sln_bwd_time, "" | ||
| assert _max_diff(ln_fwd_out, sln_fwd_out)==0.0, "outputs are not equal!" | ||
| for key in [ | ||
| "layer_norm_weight", | ||
| "layer_norm_bias", | ||
| "fc1_weight", | ||
| "fc1_bias", | ||
| "fc2_weight", | ||
| "fc2_bias", | ||
| ]: | ||
| assert _max_diff(ln_grads[key], sln_grads[key])==0.0, f"gradients for {key} are not equal!" | ||
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,156 @@ | ||
| import torch | ||
| from transformer_engine.pytorch import LayerNormMLP | ||
| import pytest | ||
|
|
||
| torch.manual_seed(1234) | ||
| device = torch.device("cuda") | ||
|
|
||
|
|
||
| class _Sequential(torch.nn.Sequential): | ||
| """Sequential model that forwards keyword arguments to modules""" | ||
|
|
||
| def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor: | ||
| x = input_ | ||
| for module in self: | ||
| x = module(x, **kwargs) | ||
| return x | ||
|
|
||
|
|
||
| class ModelConfig: | ||
| def __init__( | ||
| self, | ||
| hidden_size: int = 128, | ||
| ffn_hidden_size: int = 512, | ||
| layers: int = 1, | ||
| ): | ||
| self._hidden_size = hidden_size | ||
| self._ffn_hidden_size = ffn_hidden_size | ||
| self._layers = layers | ||
|
|
||
| def build(self): | ||
|
|
||
| ln_list, sln_list = [], [] | ||
| for _ in range(self._layers): | ||
| ln = LayerNormMLP(self._hidden_size, self._ffn_hidden_size, checkpoint=False).to(device) | ||
| sln = LayerNormMLP(self._hidden_size, self._ffn_hidden_size, checkpoint=True).to(device) | ||
| with torch.no_grad(): | ||
| sln.layer_norm_weight = torch.nn.Parameter(ln.layer_norm_weight.clone()) | ||
| sln.layer_norm_bias = torch.nn.Parameter(ln.layer_norm_bias.clone()) | ||
| sln.fc1_weight = torch.nn.Parameter(ln.fc1_weight.clone()) | ||
| sln.fc2_weight = torch.nn.Parameter(ln.fc2_weight.clone()) | ||
| sln.fc1_bias = torch.nn.Parameter(ln.fc1_bias.clone()) | ||
| sln.fc2_bias = torch.nn.Parameter(ln.fc2_bias.clone()) | ||
ptrendx marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ln_list.append(ln) | ||
| sln_list.append(sln) | ||
|
|
||
| ln_model = _Sequential(*ln_list) | ||
| sln_model = _Sequential(*sln_list) | ||
|
|
||
| return ln_model, sln_model | ||
|
|
||
| config = { | ||
| "small": ModelConfig(128, 512, 12), | ||
| "medium": ModelConfig(512, 2048, 12), | ||
| "large": ModelConfig(1024, 4096, 12), | ||
| "huge": ModelConfig(2048, 8192, 12), | ||
| } | ||
|
|
||
| seq_sizes = [2**7, 2**10, 2**14, 2**16] | ||
|
|
||
| def _warmup(model, tensor): | ||
| for _ in range(3): | ||
| model(tensor).sum().backward() | ||
|
|
||
| def _run_fwd(model, tensor): | ||
|
|
||
| torch.cuda.reset_peak_memory_stats(device) | ||
| start_time, end_time = torch.cuda.Event(enable_timing=True), torch.cuda.Event( | ||
| enable_timing=True | ||
| ) | ||
|
|
||
| torch.cuda.synchronize() | ||
| start_mem = torch.cuda.memory_allocated(device) | ||
| start_time.record() | ||
| out = model(tensor) | ||
| end_time.record() | ||
| end_time.synchronize() | ||
| elapsed = start_time.elapsed_time(end_time) | ||
| peak_mem = torch.cuda.max_memory_allocated(device) | ||
| mem = float(peak_mem - start_mem) | ||
|
|
||
| return out, elapsed, mem | ||
|
|
||
| def _run_bwd(model, out): | ||
|
|
||
| model.zero_grad(set_to_none=False) | ||
| loss = out.sum() | ||
|
|
||
| torch.cuda.reset_peak_memory_stats(device) | ||
| start_time, end_time = torch.cuda.Event(enable_timing=True), torch.cuda.Event( | ||
| enable_timing=True | ||
| ) | ||
|
|
||
| torch.cuda.synchronize() | ||
| start_mem = torch.cuda.memory_allocated(device) | ||
| start_time.record() | ||
| loss.backward() | ||
| end_time.record() | ||
| end_time.synchronize() | ||
| elapsed = start_time.elapsed_time(end_time) | ||
| peak_mem = torch.cuda.max_memory_allocated(device) | ||
| mem = float(peak_mem - start_mem) | ||
|
|
||
| param_grads = _collect_param_grads(model) | ||
| return param_grads, elapsed, mem | ||
|
|
||
| def _max_diff(ref, other): | ||
| """Return max absolute difference between two tensors or collections.""" | ||
| if ref is None or other is None: | ||
| return 0.0 | ||
| if isinstance(ref, (list, tuple)): | ||
| diffs = [_max_diff(r, o) for r, o in zip(ref, other)] | ||
| return max(diffs) if diffs else 0.0 | ||
| return torch.max(torch.abs(ref.detach() - other.detach())).item() | ||
|
|
||
| def _collect_param_grads(model): | ||
| grads = {} | ||
| for name, param in model.named_parameters(): | ||
| if param.grad is None: | ||
| continue | ||
| key = _param_key(name) | ||
| if key is not None: | ||
| grads[key] = param.grad.detach().clone() | ||
| return grads | ||
|
|
||
| def _param_key(name): | ||
| return name.split(".")[-1] | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("size", config.keys()) | ||
| @pytest.mark.parametrize("seq_size", seq_sizes) | ||
| def test_selective_activation_checkpoint(size, seq_size): | ||
|
|
||
| ln_model, sln_model = config[size].build() | ||
| data = torch.randn((seq_size, config[size]._hidden_size), device=device) | ||
|
|
||
| _warmup(ln_model, data.clone()) | ||
| ln_fwd_out, ln_fwd_time, ln_fwd_mem = _run_fwd(ln_model, data.clone()) | ||
| ln_grads, ln_bwd_time, ln_bwd_mem = _run_bwd(ln_model, ln_fwd_out) | ||
|
|
||
| _warmup(sln_model, data.clone()) | ||
| sln_fwd_out, sln_fwd_time, sln_fwd_mem = _run_fwd(sln_model, data.clone()) | ||
| sln_grads, sln_bwd_time, sln_bwd_mem = _run_bwd(sln_model, sln_fwd_out) | ||
|
|
||
| assert ln_fwd_mem > 6*sln_fwd_mem, f"selective activation checkpointing does not reduce forward memory by 6X, only by {ln_fwd_mem/sln_fwd_mem}!" | ||
| assert ln_bwd_time < sln_bwd_time, "selective activation activation checkpointing backward pass is slower than native!" | ||
| assert _max_diff(ln_fwd_out, sln_fwd_out)==0.0, "outputs are not equal!" | ||
| for key in [ | ||
| "layer_norm_weight", | ||
| "layer_norm_bias", | ||
| "fc1_weight", | ||
| "fc1_bias", | ||
| "fc2_weight", | ||
| "fc2_bias", | ||
| ]: | ||
| assert _max_diff(ln_grads[key], sln_grads[key])==0.0, f"gradients for {key} are not equal!" | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A general comment about this file - it is really nice, but it is not a test - it doesn't actually test anything, it just measures. We could introduce some test functionality here by e.g. ensuring that the error between the checkpointed LayerNormMLP is zero (since this shouldn't affect numerics) or that the memory used is lower (ideally we would quantify the expected memory usage and test against that, but for now even just making sure that the memory usage goes down would be good.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good, I converted it into a test for checking that memory goes down at least 6X in the forward pass. I also asserted that checkpointing is slower than not checkpointing in the backward pass (not sure if this is helpful, but let me know), and that the differences are 0. I put this test in tests/pytorch/layernorm_mlp/test_selective_activation_checkpointing.py because I wasn't sure where it fit in the rest of the testing scripts, but let me know if this test would be better elsewhere!