From b72c9414a548eb6165a1cd163ae01808f0913770 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 27 Aug 2024 21:46:28 +0800 Subject: [PATCH 1/6] update doc on torch version --- test/prototype/test_low_bit_optim.py | 10 +++++----- torchao/prototype/low_bit_optim/README.md | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index afeefa2239..37d284a2a1 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -14,7 +14,7 @@ from torch.testing._internal.common_fsdp import FSDPTest from torchao.prototype import low_bit_optim from torchao.prototype.low_bit_optim.quant_utils import quantize_8bit_with_qmap, quantize_4bit_with_qmap -from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5 try: import bitsandbytes as bnb @@ -75,7 +75,7 @@ def test_quantize_4bit_with_qmap_compile(self, device): class TestOptim(TestCase): - @pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_3, reason="torch.compile() fails for PyTorch < 2.3") + @pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3") @parametrize("optim_name", ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"]) @parametrize("dtype", [torch.float32, torch.bfloat16]) @parametrize("device", _DEVICES) @@ -83,7 +83,7 @@ def test_optim_smoke(self, optim_name, dtype, device): if optim_name.endswith("Fp8") and device == "cuda" and torch.cuda.get_device_capability() < (8, 9): pytest.skip("FP8 requires compute capability >= 8.9") if optim_name.endswith("4bit") and not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("4-bit Adam requires PyTorch > 2.4") + pytest.skip("4-bit Adam requires PyTorch >= 2.5") # reset cache to avoid hitting cache_size_limit, since the function will re-compile for each test torch._dynamo.reset_code_caches() @@ -100,7 +100,7 @@ def test_optim_smoke(self, optim_name, dtype, device): @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not availablle") @pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA") - @pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_3, reason="torch.compile() fails for PyTorch < 2.3") + @pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3") @parametrize("optim_name", ["Adam8bit", "AdamW8bit"]) def test_optim_8bit_correctness(self, optim_name): device = "cuda" @@ -128,7 +128,7 @@ def test_optim_8bit_correctness(self, optim_name): @pytest.mark.skipif(lpmm is None, reason="lpmm is not availablle") @pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA") - @pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_3, reason="torch.compile() fails for PyTorch < 2.3") + @pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_5, reason="requires PyTorch >= 2.5") @parametrize("optim_name", ["Adam4bit", "AdamW4bit"]) def test_optim_4bit_correctness(self, optim_name): device = "cuda" diff --git a/torchao/prototype/low_bit_optim/README.md b/torchao/prototype/low_bit_optim/README.md index 5968b2a79b..d201105197 100644 --- a/torchao/prototype/low_bit_optim/README.md +++ b/torchao/prototype/low_bit_optim/README.md @@ -24,7 +24,7 @@ To use 4-bit Adam, replace the above with `Adam4bit`. Similarly for `AdamFp8`. Y **Other optimizers**: AdamW is also available as `AdamW8bit`, `AdamW4bit`, and `AdamWFp8`. 
Other optimizers can be added based on demand. NOTE: -- The low-bit optimizers require PyTorch >= 2.3. FP8 optimizers require CUDA compute capability >= 8.9. +- 8-bit optimizers require PyTorch >= 2.3. 4-bit optimizers require PyTorch >= 2.5. FP8 optimizers require CUDA compute capability >= 8.9. - For 4-bit optimizers, we don't implement rank-1 normalization for quantizing 2nd moment as originally done in the paper. - The first training step is expected to be slow since the optimizer needs to be compiled. From 5fe3cfe905502dbbe8653ec3145f713ad1e43531 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 27 Aug 2024 21:51:47 +0800 Subject: [PATCH 2/6] update doc --- torchao/prototype/low_bit_optim/README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchao/prototype/low_bit_optim/README.md b/torchao/prototype/low_bit_optim/README.md index d201105197..963c2e45eb 100644 --- a/torchao/prototype/low_bit_optim/README.md +++ b/torchao/prototype/low_bit_optim/README.md @@ -24,8 +24,9 @@ To use 4-bit Adam, replace the above with `Adam4bit`. Similarly for `AdamFp8`. Y **Other optimizers**: AdamW is also available as `AdamW8bit`, `AdamW4bit`, and `AdamWFp8`. Other optimizers can be added based on demand. NOTE: -- 8-bit optimizers require PyTorch >= 2.3. 4-bit optimizers require PyTorch >= 2.5. FP8 optimizers require CUDA compute capability >= 8.9. -- For 4-bit optimizers, we don't implement rank-1 normalization for quantizing 2nd moment as originally done in the paper. +- 8-bit optimizers: PyTorch >= 2.3 is required. +- 4-bit optimizers: PyTorch >= 2.5 is required. Additionally, we don't implement rank-1 normalization for quantizing 2nd moment as originally done in the paper. +- FP8 optimizers: PyTorch >= 2.3 and CUDA compute capability >= 8.9 are required. - The first training step is expected to be slow since the optimizer needs to be compiled. 
## Benchmarks From 3fcad81b94c6ee9bf8f573c7fb9754cbf5f12458 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 27 Aug 2024 22:45:16 +0800 Subject: [PATCH 3/6] update --- test/prototype/test_low_bit_optim.py | 15 +++++++++------ torchao/prototype/low_bit_optim/README.md | 2 +- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index 37d284a2a1..d637fa59b8 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -14,7 +14,7 @@ from torch.testing._internal.common_fsdp import FSDPTest from torchao.prototype import low_bit_optim from torchao.prototype.low_bit_optim.quant_utils import quantize_8bit_with_qmap, quantize_4bit_with_qmap -from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5 try: import bitsandbytes as bnb @@ -75,13 +75,16 @@ def test_quantize_4bit_with_qmap_compile(self, device): class TestOptim(TestCase): - @pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3") + @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3") @parametrize("optim_name", ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"]) @parametrize("dtype", [torch.float32, torch.bfloat16]) @parametrize("device", _DEVICES) def test_optim_smoke(self, optim_name, dtype, device): - if optim_name.endswith("Fp8") and device == "cuda" and torch.cuda.get_device_capability() < (8, 9): - pytest.skip("FP8 requires compute capability >= 8.9") + if optim_name.endswith("Fp8") and device == "cuda": + if not TORCH_VERSION_AT_LEAST_2_4: + pytest.skip("FP8 CUDA requires PyTorch >= 2.4") + if torch.cuda.get_device_capability() < (8, 9): + pytest.skip("FP8 requires compute capability >= 8.9") if optim_name.endswith("4bit") and not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("4-bit Adam requires PyTorch >= 2.5") @@ -100,7 +103,7 @@ def test_optim_smoke(self, optim_name, dtype, device): @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not availablle") @pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA") - @pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3") + @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3") @parametrize("optim_name", ["Adam8bit", "AdamW8bit"]) def test_optim_8bit_correctness(self, optim_name): device = "cuda" @@ -128,7 +131,7 @@ def test_optim_8bit_correctness(self, optim_name): @pytest.mark.skipif(lpmm is None, reason="lpmm is not availablle") @pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA") - @pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_5, reason="requires PyTorch >= 2.5") + @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires PyTorch >= 2.5") @parametrize("optim_name", ["Adam4bit", "AdamW4bit"]) def test_optim_4bit_correctness(self, optim_name): device = "cuda" diff --git a/torchao/prototype/low_bit_optim/README.md b/torchao/prototype/low_bit_optim/README.md index 963c2e45eb..74504dfe19 100644 --- a/torchao/prototype/low_bit_optim/README.md +++ b/torchao/prototype/low_bit_optim/README.md @@ -26,7 +26,7 @@ To use 4-bit Adam, replace the above with `Adam4bit`. Similarly for `AdamFp8`. Y NOTE: - 8-bit optimizers: PyTorch >= 2.3 is required. - 4-bit optimizers: PyTorch >= 2.5 is required. 
Additionally, we don't implement rank-1 normalization for quantizing 2nd moment as originally done in the paper. -- FP8 optimizers: PyTorch >= 2.3 and CUDA compute capability >= 8.9 are required. +- FP8 optimizers: For CPU, PyTorch >= 2.3 is required. For CUDA, PyTorch >= 2.4 and CUDA compute capability >= 8.9 are required. - The first training step is expected to be slow since the optimizer needs to be compiled. ## Benchmarks From 8298a35613f9ddbdc0c61398f184949e6fb5f99f Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 27 Aug 2024 23:05:58 +0800 Subject: [PATCH 4/6] fix 4-bit problem --- test/prototype/test_low_bit_optim.py | 3 --- torchao/prototype/low_bit_optim/adam.py | 26 +++++++++++++++++-------- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index d637fa59b8..f27f680107 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -85,8 +85,6 @@ def test_optim_smoke(self, optim_name, dtype, device): pytest.skip("FP8 CUDA requires PyTorch >= 2.4") if torch.cuda.get_device_capability() < (8, 9): pytest.skip("FP8 requires compute capability >= 8.9") - if optim_name.endswith("4bit") and not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("4-bit Adam requires PyTorch >= 2.5") # reset cache to avoid hitting cache_size_limit, since the function will re-compile for each test torch._dynamo.reset_code_caches() @@ -131,7 +129,6 @@ def test_optim_8bit_correctness(self, optim_name): @pytest.mark.skipif(lpmm is None, reason="lpmm is not availablle") @pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA") - @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires PyTorch >= 2.5") @parametrize("optim_name", ["Adam4bit", "AdamW4bit"]) def test_optim_4bit_correctness(self, optim_name): device = "cuda" diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py index 4b0b295343..e609696bd4 100644 --- a/torchao/prototype/low_bit_optim/adam.py +++ b/torchao/prototype/low_bit_optim/adam.py @@ -198,7 +198,7 @@ def __init__( @staticmethod def _subclass_zeros(p: Tensor, signed: bool, block_size: int): - return OptimState4bit.zeros(p.shape, signed, block_size, p.device) + return OptimState4bit.zeros(p.view(-1).shape, signed, block_size, p.device) @staticmethod def _unwrap_dtensor(p: Tensor): @@ -216,6 +216,11 @@ def step(self, closure=None): # NOTE: right now, torch.compile(param_groups_adam) will have excessive memory usage for 4-bit optim. # thus, as a workaround, we use torch.compile(single_param_adam) and call it for each param. + # NOTE: we have to create flattened optimizer states since torch.compile() will fail otherwise for + # PyTorch 2.3 and 2.4 + # calling exp_avg.view(-1) will fail torch.compile(single_param_adam) even if we implement the op + # correctly for the tensor subclass. 
+ # unwrap DTensor since DTensor does not work well with dynamic compile # flatten p, grad, and optim state to avoid recompilation for group, lr, (beta1, beta2), weight_decay, eps in param_groups: @@ -227,9 +232,9 @@ def step(self, closure=None): self._unwrap_dtensor(p).view(-1), self._unwrap_dtensor(grad).view(-1), step, - self._unwrap_dtensor(exp_avg).view(-1), - self._unwrap_dtensor(exp_avg_sq).view(-1), - self._unwrap_dtensor(max_exp_avg_sq).view(-1) if max_exp_avg_sq is not None else None, + self._unwrap_dtensor(exp_avg), + self._unwrap_dtensor(exp_avg_sq), + self._unwrap_dtensor(max_exp_avg_sq) if max_exp_avg_sq is not None else None, lr, beta1, beta2, @@ -296,7 +301,7 @@ def __init__( @staticmethod def _subclass_zeros(p: Tensor, signed: bool, block_size: int): - return OptimState4bit.zeros(p.shape, signed, block_size, p.device) + return OptimState4bit.zeros(p.view(-1).shape, signed, block_size, p.device) @staticmethod def _unwrap_dtensor(p: Tensor): @@ -314,6 +319,11 @@ def step(self, closure=None): # NOTE: right now, torch.compile(param_groups_adam) will have excessive memory usage for 4-bit optim. # thus, as a workaround, we use torch.compile(single_param_adam) and call it for each param. + # NOTE: we have to create flattened optimizer states since torch.compile() will fail otherwise for + # PyTorch 2.3 and 2.4 + # calling exp_avg.view(-1) will fail torch.compile(single_param_adam) even if we implement the op + # correctly for the tensor subclass. + # unwrap DTensor since DTensor does not work well with dynamic compile # flatten p, grad, and optim state to avoid recompilation for group, lr, (beta1, beta2), weight_decay, eps in param_groups: @@ -325,9 +335,9 @@ def step(self, closure=None): self._unwrap_dtensor(p).view(-1), self._unwrap_dtensor(grad).view(-1), step, - self._unwrap_dtensor(exp_avg).view(-1), - self._unwrap_dtensor(exp_avg_sq).view(-1), - self._unwrap_dtensor(max_exp_avg_sq).view(-1) if max_exp_avg_sq is not None else None, + self._unwrap_dtensor(exp_avg), + self._unwrap_dtensor(exp_avg_sq), + self._unwrap_dtensor(max_exp_avg_sq) if max_exp_avg_sq is not None else None, lr, beta1, beta2, From ed4aa715d00ed1727f5dc54c6d9cb89b5598cccf Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 27 Aug 2024 23:11:58 +0800 Subject: [PATCH 5/6] update doc --- torchao/prototype/low_bit_optim/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchao/prototype/low_bit_optim/README.md b/torchao/prototype/low_bit_optim/README.md index 74504dfe19..b1f955e656 100644 --- a/torchao/prototype/low_bit_optim/README.md +++ b/torchao/prototype/low_bit_optim/README.md @@ -24,9 +24,9 @@ To use 4-bit Adam, replace the above with `Adam4bit`. Similarly for `AdamFp8`. Y **Other optimizers**: AdamW is also available as `AdamW8bit`, `AdamW4bit`, and `AdamWFp8`. Other optimizers can be added based on demand. NOTE: -- 8-bit optimizers: PyTorch >= 2.3 is required. -- 4-bit optimizers: PyTorch >= 2.5 is required. Additionally, we don't implement rank-1 normalization for quantizing 2nd moment as originally done in the paper. -- FP8 optimizers: For CPU, PyTorch >= 2.3 is required. For CUDA, PyTorch >= 2.4 and CUDA compute capability >= 8.9 are required. +- The low-bit optimizers require PyTorch >= 2.3 +- For FP8 optimizers on CUDA, PyTorch >= 2.4 and CUDA compute capability >= 8.9 are required. +- For 4-bit optimizers, we don't implement rank-1 normalization for quantizing 2nd moment as originally done in the paper. 
- The first training step is expected to be slow since the optimizer needs to be compiled. ## Benchmarks From 4cea083bc7b020dc48c18d18e0fedca65cc15b2e Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 27 Aug 2024 23:16:10 +0800 Subject: [PATCH 6/6] update --- test/prototype/test_low_bit_optim.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index f27f680107..701d90e22c 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -127,8 +127,10 @@ def test_optim_8bit_correctness(self, optim_name): for p1, p2 in zip(model1.parameters(), model2.parameters()): torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5) + # this will not run in CI because we can't install lpmm @pytest.mark.skipif(lpmm is None, reason="lpmm is not availablle") @pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA") + @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3") @parametrize("optim_name", ["Adam4bit", "AdamW4bit"]) def test_optim_4bit_correctness(self, optim_name): device = "cuda"
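
A minimal usage sketch of the optimizers these patches document, assuming only the `torchao.prototype.low_bit_optim` module and the class names exercised in the tests above; the toy model, tensor shapes, learning rate, and device handling are illustrative placeholders rather than torchao's own example.

```python
# Hedged sketch: model, sizes, lr, and device handling are placeholders.
import torch
from torchao.prototype import low_bit_optim

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(128, 128, device=device)

# Swap in Adam8bit / Adam4bit / AdamFp8 (or the AdamW variants) as needed;
# see the version notes in the README hunks above for what each requires.
optimizer = low_bit_optim.AdamW8bit(model.parameters(), lr=1e-3)

for _ in range(3):
    x = torch.randn(32, 128, device=device)
    loss = model(x).sum()
    loss.backward()
    optimizer.step()       # the first step is slow: the optimizer update compiles here
    optimizer.zero_grad()
```

Per the final README wording in PATCH 5, all variants need PyTorch >= 2.3, FP8 on CUDA additionally needs PyTorch >= 2.4 and compute capability >= 8.9, and the first step is slow because the optimizer is compiled then.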
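
The comments added in PATCH 4 explain why the 4-bit state is now allocated with `p.view(-1).shape` and handed to the compiled step without `.view(-1)`. Below is a toy sketch of that calling pattern only, using plain tensors and a simplified Adam-style update as stand-ins for `OptimState4bit` and the real `single_param_adam` (both of which live outside this diff).

```python
# Toy illustration only: plain tensors stand in for OptimState4bit, and the
# update below is a simplified Adam-style step, not torchao's single_param_adam.
import torch

@torch.compile
def single_param_step(p, grad, exp_avg, exp_avg_sq, lr, beta1, beta2, eps):
    # p, grad, and the states all arrive flat, so nothing here needs .view(-1)
    exp_avg.lerp_(grad, 1 - beta1)
    exp_avg_sq.lerp_(grad * grad, 1 - beta2)
    p.add_(exp_avg / (exp_avg_sq.sqrt() + eps), alpha=-lr)

p = torch.randn(4, 4)
grad = torch.randn_like(p)

# As in PATCH 4: the optimizer states are created flat up front, so the
# (in the real code, quantized) state is never reshaped around the compiled step.
exp_avg = torch.zeros(p.view(-1).shape)
exp_avg_sq = torch.zeros(p.view(-1).shape)

single_param_step(p.view(-1), grad.view(-1), exp_avg, exp_avg_sq, 1e-3, 0.9, 0.999, 1e-8)
```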