diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py
index 0979dd5cf2..ccf925a3fd 100644
--- a/test/prototype/test_low_bit_optim.py
+++ b/test/prototype/test_low_bit_optim.py
@@ -75,7 +75,7 @@ def test_quantize_4bit_with_qmap_compile(self, device):
 class TestOptim(TestCase):
-    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.3")
+    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
     @parametrize("optim_name", ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"])
     @parametrize("dtype", [torch.float32, torch.bfloat16])
     @parametrize("device", _DEVICES)
@@ -84,10 +84,7 @@ def test_optim_smoke(self, optim_name, dtype, device):
             if not TORCH_VERSION_AT_LEAST_2_4:
                 pytest.skip("FP8 CUDA requires PyTorch >= 2.4")
             if torch.cuda.get_device_capability() < (8, 9):
-                pytest.skip("FP8 requires compute capability >= 8.9")
-
-        # reset cache to avoid hitting cache_size_limit, since the function will re-compile for each test
-        torch._dynamo.reset_code_caches()
+                pytest.skip("FP8 CUDA requires compute capability >= 8.9")
 
         model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32))
         model.to(device=device, dtype=dtype)
@@ -232,12 +229,11 @@ def world_size(self) -> int:
         return 2
 
     @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="OptimState8bit dispatch: attempting to run unimplemented operator/function: aten.as_strided.default")
-    @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="https://github.com/pytorch/ao/issues/652")
     @skip_if_lt_x_gpu(2)
     def test_fsdp2(self):
-        optim_classes = [low_bit_optim.Adam8bit, low_bit_optim.Adam4bit]
+        optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit]
         if torch.cuda.get_device_capability() >= (8, 9):
-            optim_classes.append(low_bit_optim.AdamFp8)
+            optim_classes.append(low_bit_optim.AdamWFp8)
 
         self.run_subtests(
             {"optim_cls": optim_classes},
@@ -252,9 +248,6 @@ def _test_fsdp2(self, optim_cls):
             TransformerBlock,
         )
 
-        # seems like cache_size_limit is shared between FSDP processes?
-        torch._dynamo.config.cache_size_limit = 8 * self.world_size
-
         batch_size = 3
         vocab_size = 1024
         seq_len = 64
diff --git a/torchao/prototype/low_bit_optim/README.md b/torchao/prototype/low_bit_optim/README.md
index 9a423d424c..7781386bdd 100644
--- a/torchao/prototype/low_bit_optim/README.md
+++ b/torchao/prototype/low_bit_optim/README.md
@@ -27,45 +27,35 @@ NOTE:
 - The low-bit optimizers require PyTorch >= 2.3
 - For FP8 optimizers on CUDA, PyTorch >= 2.4 and CUDA compute capability >= 8.9 are required.
 - For 4-bit optimizers, we don't implement rank-1 normalization for quantizing 2nd moment as originally done in the paper.
-- The first training step is expected to be slow since the optimizer needs to be compiled.
 
 ## Benchmarks
 
-Fine-tune [timm](https://github.com/huggingface/pytorch-image-models)'s ViT-H (630M params) on [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset. BF16 AMP, 1 epoch, batch size 8, cosine LR scheduler, 4070Ti SUPER, fixed random seed. Benchmark script is available at [benchmarks/benchmark_low_bit_adam.py](../../../benchmarks/benchmark_low_bit_adam.py).
+Fine-tune [timm](https://github.com/huggingface/pytorch-image-models)'s [ViT-H](https://huggingface.co/timm/vit_huge_patch14_224.orig_in21k) (630M params) on [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset. PyTorch 2.4, BF16 AMP, compiled model, 1 epoch, batch size 8, cosine LR scheduler, 4070Ti SUPER, fixed random seed. Benchmark script is available at [benchmarks/benchmark_low_bit_adam.py](../../../benchmarks/benchmark_low_bit_adam.py).
 
-AdamW impl      | Max memory (GB) | imgs/s | accuracy
-----------------|-----------------|--------|----------
-PyTorch (fused) | 12.23           | 41.8   | 94.38
-bnb 8-bit       | 8.32            | 43.6   | 94.18
-ao 8-bit        | 8.33            | 42.6   | 94.25
-ao FP8 E4M3     | 9.27            | 44.1   | 94.40
-lpmm 4-bit      | 7.72            | 46.0   | 94.29
-ao 4-bit        | 7.72            | 40.0   | 94.03
-lpmm 4-bit (*)  | 7.74            | 26.6   | 94.25
+AdamW impl      | Peak memory allocated (GB) | imgs/s | accuracy
+----------------|----------------------------|--------|----------
+PyTorch (fused) | 12.23                      | 41.9   | 94.52
+bnb 8-bit       | 8.32                       | 43.6   | 94.54
+ao 8-bit        | 8.33                       | 42.5   | 94.30
+ao FP8 E4M3     | 8.33                       | 43.2   | 94.13
+lpmm 4-bit      | 7.72                       | 46.1   | 94.40
+ao 4-bit        | 7.72                       | 42.4   | 94.13
+lpmm 4-bit (*)  | 7.74                       | 26.7   | 94.10
 
 (*) means rank-1 normalization is used for 2nd optimizer state. Refer to [paper](https://arxiv.org/abs/2309.01507) for more details.
 
-Fine-tune [Llama2-7B](https://huggingface.co/meta-llama/Llama-2-7b) on [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca) dataset. Full BF16, 1 epoch, A100, fixed random seed. Benchmark is done with [torchtune](https://github.com/pytorch/torchtune). See [#746](https://github.com/pytorch/ao/pull/746) for more details.
+Fine-tune [Llama2-7B](https://huggingface.co/meta-llama/Llama-2-7b) on [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca) dataset. PyTorch 2.4, full BF16, 1 epoch, A100, fixed random seed. Benchmark is done with [torchtune 52d1b838](https://github.com/pytorch/torchtune/tree/52d1b838c1c35b5e75fddf8776be400adc36dff5). See [#812](https://github.com/pytorch/ao/pull/812) for more details.
 
-AdamW impl       | Max memory (GB) | toks/s | `truthfulqa_mc2` acc | Compile time
------------------|-----------------|--------|----------------------|-------------
-Not fine-tuned   | -               | -      | 38.95                | -
-PyTorch (fused)  | 52              | ~4500  | 42.12                | ~4 min
-bnb 8-bit        | 39              | ~4000  | 41.98                | ~4 min
-ao 8-bit         | 39              | ~4000  | 42.41                | ~12 min
-ao 4-bit         | 33              | ~3600  | 42.34                | ~4 min
+AdamW impl       | Peak memory allocated (GB) | toks/s | `truthfulqa_mc2` acc
+-----------------|----------------------------|--------|----------------------
+Not fine-tuned   | -                          | -      | 38.95
+PyTorch (fused)  | 51.6                       | 3200   | 42.61
+bnb 8-bit        | 39.3                       | 3000   | 42.75
+ao 8-bit         | 39.1                       | 2900   | 41.50
+ao 4-bit         | 33.2                       | 2900   | 42.27
 
 NOTE: lpmm's 4-bit AdamW does not support BF16 weights.
 
-### Note on compile times
-
-There are 2 approaches to compile optimizer step in low-bit optim:
-
-1. Compile optim step for single param i.e. `torch.compile(single_param_adam)`
-2. Compile optim step for all params i.e. `torch.compile(param_groups_adam)`
-
-Currently Adam8bit and AdamFp8 use approach (2) (with static shape) since it is faster (but compile much slower), while Adam4bit uses approach (1) (with dynamic shape) since there are excessive memory usage for "Adam4bit + approach (2)". Approach (1) requires dynamic shape to avoid hitting recompiles limit.
-
 ## Optimizer CPU offload
 
 This folder also implements optimizer CPU offload (i.e. ZeRO-Offload) for single GPU training. For multi-GPU training, you can use FSDP's built-in CPU offload.
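For orientation, here is a minimal usage sketch of the optimizers touched by this PR. The import path and the `AdamW8bit` constructor mirroring `torch.optim.AdamW` are assumptions inferred from the test changes above, not an excerpt from the README; treat it as illustrative only.

```python
# Hedged usage sketch (assumed API, mirroring test_low_bit_optim.py above):
# drop-in replacement for torch.optim.AdamW with 8-bit optimizer states.
import torch
import torch.nn as nn
from torchao.prototype import low_bit_optim

model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32)).cuda()
optim = low_bit_optim.AdamW8bit(model.parameters(), lr=1e-4)

x = torch.randn(8, 32, device="cuda")
model(x).sum().backward()
optim.step()
optim.zero_grad()

# lr is kept as a tensor internally; per the error message in adam.py below,
# update it in place rather than rebinding it:
# optim.param_groups[0]["lr"].fill_(1e-5)
```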
diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py
index e609696bd4..7f0d47854b 100644
--- a/torchao/prototype/low_bit_optim/adam.py
+++ b/torchao/prototype/low_bit_optim/adam.py
@@ -52,52 +52,6 @@ def _new_buffer(self, p: Tensor, signed: bool):
             out = torch.zeros_like(p)
         return out
 
-    def _prepare_param_groups(self):
-        param_groups = []
-
-        for group in self.param_groups:
-            _group = []
-
-            for p in group["params"]:
-                if p.grad is None:
-                    continue
-
-                grad = p.grad
-                if grad.is_sparse:
-                    raise RuntimeError("Sparse gradient is not supported")
-
-                state = self.state[p]
-
-                # State initialization
-                if len(state) == 0:
-                    state["step"] = torch.tensor(0.0)
-                    state["exp_avg"] = self._new_buffer(p, True)
-                    state["exp_avg_sq"] = self._new_buffer(p, False)
-                    if group["amsgrad"]:
-                        state["max_exp_avg_sq"] = self._new_buffer(p, False)
-
-                state["step"] += 1
-
-                if not isinstance(group["lr"], Tensor):
-                    raise RuntimeError(
-                        "lr was changed to a non-Tensor object. If you want to update lr, please use "
-                        "optim.param_groups[0]['lr'].fill_(new_lr)"
-                    )
-
-                p_grad_state = (
-                    p,
-                    grad,
-                    state["step"],
-                    state["exp_avg"],
-                    state["exp_avg_sq"],
-                    state.get("max_exp_avg_sq", None),
-                )
-                _group.append(p_grad_state)
-
-            param_groups.append((_group, group["lr"], group["betas"], group["weight_decay"], group["eps"]))
-
-        return param_groups
-
     @torch.no_grad()
     def step(self, closure=None):
         loss = None
@@ -105,22 +59,56 @@ def step(self, closure=None):
             with torch.enable_grad():
                 loss = closure()
 
-        param_groups = self._prepare_param_groups()
+        # for a given model, the number of different argument combinations to single_param_adam() is fixed.
+        # thus, it is safe to disable the cache limit without the risk of always re-compiling.
+        with torch._dynamo.utils.disable_cache_limit():
+            for group in self.param_groups:
+                for p in group["params"]:
+                    if p.grad is None:
+                        continue
+
+                    grad = p.grad
+                    if grad.is_sparse:
+                        raise RuntimeError("Sparse gradient is not supported")
+
+                    state = self.state[p]
+
+                    # State initialization
+                    if len(state) == 0:
+                        state["step"] = torch.tensor(0.0)
+                        state["exp_avg"] = self._new_buffer(p, True)
+                        state["exp_avg_sq"] = self._new_buffer(p, False)
+                        if group["amsgrad"]:
+                            state["max_exp_avg_sq"] = self._new_buffer(p, False)
+
+                    state["step"] += 1
+
+                    if not isinstance(group["lr"], Tensor):
+                        raise RuntimeError(
+                            "lr was changed to a non-Tensor object. If you want to update lr, please use "
+                            "optim.param_groups[0]['lr'].fill_(new_lr)"
+                        )
+
+                    torch.compile(single_param_adam, fullgraph=True, dynamic=False)(
+                        p,
+                        grad,
+                        state["step"],
+                        state["exp_avg"],
+                        state["exp_avg_sq"],
+                        state.get("max_exp_avg_sq", None),
+                        group["lr"],
+                        group["betas"][0],
+                        group["betas"][1],
+                        group["weight_decay"],
+                        group["eps"],
+                        self.is_adamw,
+                    )
 
-        # static compile optim step for all params in a single graph
-        torch.compile(param_groups_adam, fullgraph=True)(param_groups, self.is_adamw)
         return loss
 
 
-def param_groups_adam(param_groups, is_adamw):
-    for group, lr, (beta1, beta2), weight_decay, eps in param_groups:
-        for p, grad, step, exp_avg, exp_avg_sq, max_exp_avg_sq in group:
-            single_param_adam(
-                p, grad, step, exp_avg, exp_avg_sq, max_exp_avg_sq, lr, beta1, beta2, weight_decay, eps, is_adamw
-            )
-
-
 # this will work with any optim state tensor subclass that implements aten.lerp.Scalar and aten.copy_.default
+# and any param tensor subclass that implements aten.add_.Tensor and aten.addcdiv_.default
 def single_param_adam(
     p: Tensor,
     grad: Tensor,
@@ -198,53 +186,7 @@ def __init__(
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
-        return OptimState4bit.zeros(p.view(-1).shape, signed, block_size, p.device)
-
-    @staticmethod
-    def _unwrap_dtensor(p: Tensor):
-        return p._local_tensor if isinstance(p, DTensor) else p
-
-    @torch.no_grad()
-    def step(self, closure=None):
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
-        param_groups = self._prepare_param_groups()
-
-        # NOTE: right now, torch.compile(param_groups_adam) will have excessive memory usage for 4-bit optim.
-        # thus, as a workaround, we use torch.compile(single_param_adam) and call it for each param.
-
-        # NOTE: we have to create flattened optimizer states since torch.compile() will fail otherwise for
-        # PyTorch 2.3 and 2.4
-        # calling exp_avg.view(-1) will fail torch.compile(single_param_adam) even if we implement the op
-        # correctly for the tensor subclass.
-
-        # unwrap DTensor since DTensor does not work well with dynamic compile
-        # flatten p, grad, and optim state to avoid recompilation
-        for group, lr, (beta1, beta2), weight_decay, eps in param_groups:
-            for p, grad, step, exp_avg, exp_avg_sq, max_exp_avg_sq in group:
-                # DTensor._local_tensor has .requires_grad = False
-                # to avoid recompilation, set p.requires_grad = False and restore it after optim step
-                p.requires_grad_(False)
-                torch.compile(single_param_adam, fullgraph=True, dynamic=True)(
-                    self._unwrap_dtensor(p).view(-1),
-                    self._unwrap_dtensor(grad).view(-1),
-                    step,
-                    self._unwrap_dtensor(exp_avg),
-                    self._unwrap_dtensor(exp_avg_sq),
-                    self._unwrap_dtensor(max_exp_avg_sq) if max_exp_avg_sq is not None else None,
-                    lr,
-                    beta1,
-                    beta2,
-                    weight_decay,
-                    eps,
-                    self.is_adamw,
-                )
-                p.requires_grad_(True)
-
-        return loss
+        return OptimState4bit.zeros(p.shape, signed, block_size, p.device)
 
 
 class AdamFp8(_AdamBase):
@@ -301,53 +243,7 @@ def __init__(
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
-        return OptimState4bit.zeros(p.view(-1).shape, signed, block_size, p.device)
-
-    @staticmethod
-    def _unwrap_dtensor(p: Tensor):
-        return p._local_tensor if isinstance(p, DTensor) else p
-
-    @torch.no_grad()
-    def step(self, closure=None):
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
-        param_groups = self._prepare_param_groups()
-
-        # NOTE: right now, torch.compile(param_groups_adam) will have excessive memory usage for 4-bit optim.
-        # thus, as a workaround, we use torch.compile(single_param_adam) and call it for each param.
-
-        # NOTE: we have to create flattened optimizer states since torch.compile() will fail otherwise for
-        # PyTorch 2.3 and 2.4
-        # calling exp_avg.view(-1) will fail torch.compile(single_param_adam) even if we implement the op
-        # correctly for the tensor subclass.
-
-        # unwrap DTensor since DTensor does not work well with dynamic compile
-        # flatten p, grad, and optim state to avoid recompilation
-        for group, lr, (beta1, beta2), weight_decay, eps in param_groups:
-            for p, grad, step, exp_avg, exp_avg_sq, max_exp_avg_sq in group:
-                # DTensor._local_tensor has .requires_grad = False
-                # to avoid recompilation, set p.requires_grad = False and restore it after optim step
-                p.requires_grad_(False)
-                torch.compile(single_param_adam, fullgraph=True, dynamic=True)(
-                    self._unwrap_dtensor(p).view(-1),
-                    self._unwrap_dtensor(grad).view(-1),
-                    step,
-                    self._unwrap_dtensor(exp_avg),
-                    self._unwrap_dtensor(exp_avg_sq),
-                    self._unwrap_dtensor(max_exp_avg_sq) if max_exp_avg_sq is not None else None,
-                    lr,
-                    beta1,
-                    beta2,
-                    weight_decay,
-                    eps,
-                    self.is_adamw,
-                )
-                p.requires_grad_(True)
-
-        return loss
+        return OptimState4bit.zeros(p.shape, signed, block_size, p.device)
 
 
 class AdamWFp8(_AdamBase):
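The rewritten `step()` above leans on two things: dynamo caches compiled code per function, so wrapping `single_param_adam` with `torch.compile` inside the parameter loop reuses compiled kernels for parameters with matching shapes and dtypes, and `torch._dynamo.utils.disable_cache_limit()` (used by the patch itself, so a recent PyTorch is assumed) keeps models with many distinct parameter shapes from tripping the recompile cap. Below is a minimal, self-contained sketch of that pattern, using a hypothetical `toy_update` in place of `single_param_adam`; it is an illustration of the compile strategy, not code from this PR.

```python
# Sketch of the per-parameter compile pattern used in step() above,
# with a toy update function standing in for single_param_adam.
import torch


def toy_update(p: torch.Tensor, grad: torch.Tensor, lr: float) -> None:
    # simple in-place SGD-like update
    p.add_(grad, alpha=-lr)


params = [torch.zeros(16), torch.zeros(16), torch.zeros(32)]
grads = [torch.ones_like(p) for p in params]

# torch.compile keys its cache on toy_update's code object plus guards
# (shape, dtype, ...), so re-wrapping inside the loop reuses compiled code
# for parameters with matching signatures. Disabling the cache limit avoids
# hitting the recompile cap when a model has many distinct parameter shapes.
with torch._dynamo.utils.disable_cache_limit():
    for p, g in zip(params, grads):
        torch.compile(toy_update, fullgraph=True, dynamic=False)(p, g, 1e-3)

print(params[0][:4])  # each param has taken one update step
```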