From 7683f61aa3157f9f30f38e0aa9310ef10c0928bd Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Mon, 14 Oct 2024 20:41:08 +0800
Subject: [PATCH 1/2] fix dtype

---
 torchao/prototype/low_bit_optim/adam.py | 29 +++++++++++++------------
 1 file changed, 15 insertions(+), 14 deletions(-)

diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py
index b087fc8888..980b60f9a1 100644
--- a/torchao/prototype/low_bit_optim/adam.py
+++ b/torchao/prototype/low_bit_optim/adam.py
@@ -109,9 +109,6 @@ def step(self, closure=None):
 
 # this will work with any optim state tensor subclass that implements aten.lerp.Scalar and aten.copy_.default
 # and param tensor subclass that implements aten.add_.Tensor, and aten.addcdiv_.default
-# NOTE: right now all of our optimizer state subclasses will dequant to FP32, thus adam computation
-# will be done in FP32 (not purposely). we should explicitly cast all inputs to FP32 to ensure FP32
-# computation. will need to benchmark to ensure no slowdown.
 def single_param_adam(
     p: Tensor,
     grad: Tensor,
@@ -126,32 +123,36 @@
     eps: float,
     is_adamw: bool,
 ):
+    # compute in FP32 for accurate calculations
+    p_f32 = p.float()
+    grad_f32 = grad.float()
+
     if not is_adamw:
-        grad = grad.add(p, alpha=weight_decay)
+        grad_f32 = grad_f32.add(p_f32, alpha=weight_decay)
 
     bias_correction1 = 1 - beta1**step
     bias_correction2 = 1 - beta2**step
 
     # keep high precision copy for param update
-    new_exp_avg = exp_avg.lerp(grad, 1 - beta1)
-    new_exp_avg_sq = exp_avg_sq.lerp(grad.square(), 1 - beta2)
+    exp_avg_f32 = exp_avg.float().lerp(grad_f32, 1 - beta1)
+    exp_avg_sq_f32 = exp_avg_sq.float().lerp(grad_f32.square(), 1 - beta2)
 
-    exp_avg.copy_(new_exp_avg)
-    exp_avg_sq.copy_(new_exp_avg_sq)
+    exp_avg.copy_(exp_avg_f32)
+    exp_avg_sq.copy_(exp_avg_sq_f32)
 
     if max_exp_avg_sq is not None:
-        new_max_exp_avg_sq = torch.maximum(max_exp_avg_sq, new_exp_avg_sq)
-        max_exp_avg_sq.copy_(new_max_exp_avg_sq)
-        denom = (new_max_exp_avg_sq.sqrt() / bias_correction2.sqrt()).add_(eps)
+        max_exp_avg_sq_f32 = torch.maximum(max_exp_avg_sq.float(), exp_avg_sq_f32)
+        max_exp_avg_sq.copy_(max_exp_avg_sq_f32)
+        denom = (max_exp_avg_sq_f32.sqrt() / bias_correction2.sqrt()).add_(eps)
     else:
-        denom = (new_exp_avg_sq.sqrt() / bias_correction2.sqrt()).add_(eps)
+        denom = (exp_avg_sq_f32.sqrt() / bias_correction2.sqrt()).add_(eps)
 
     step_size = lr / bias_correction1
     if is_adamw:
         # merge weight decay and param update in a single .add_() to make this work with quantized param
-        p.add_(-lr * weight_decay * p - step_size * new_exp_avg / denom)
+        p.add_(-lr * weight_decay * p_f32 - step_size * exp_avg_f32 / denom)
     else:
-        p.addcdiv_(new_exp_avg, denom, value=-step_size)
+        p.addcdiv_(exp_avg_f32, denom, value=-step_size)
 
 
 class Adam8bit(_AdamBase):

From 72ca83409af6ad7b1489902f9ade40b4ba69913e Mon Sep 17 00:00:00 2001
From: Mark Saroufim
Date: Mon, 14 Oct 2024 11:50:13 -0700
Subject: [PATCH 2/2] Update regression_test.yml

---
 .github/workflows/regression_test.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml
index 876a957ede..4b23adfab6 100644
--- a/.github/workflows/regression_test.yml
+++ b/.github/workflows/regression_test.yml
@@ -67,7 +67,7 @@ jobs:
     uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
     with:
-      timeout: 60
+      timeout: 120
       runner: ${{ matrix.runs-on }}
       gpu-arch-type: ${{ matrix.gpu-arch-type }}
       gpu-arch-version: ${{ matrix.gpu-arch-version }}
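
For context on the first patch, here is a minimal standalone sketch of the FP32 single-param Adam step it introduces, using plain tensors instead of torchao's quantized tensor subclasses. The name `single_param_adam_sketch`, the plain `int` step (and hence `bias_correction2**0.5` in place of `.sqrt()`), and the smoke-test values are illustrative assumptions, not the upstream torchao API.

```python
import torch
from torch import Tensor
from typing import Optional


def single_param_adam_sketch(
    p: Tensor,
    grad: Tensor,
    step: int,
    exp_avg: Tensor,
    exp_avg_sq: Tensor,
    max_exp_avg_sq: Optional[Tensor],
    lr: float,
    beta1: float,
    beta2: float,
    weight_decay: float,
    eps: float,
    is_adamw: bool,
) -> None:
    # cast param and grad to FP32 so the Adam math runs in full precision
    # even when the stored tensors are BF16 / FP16 / quantized
    p_f32 = p.float()
    grad_f32 = grad.float()

    if not is_adamw:
        # classic Adam folds weight decay into the gradient
        grad_f32 = grad_f32.add(p_f32, alpha=weight_decay)

    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step

    # update the moments in FP32, then write them back into the
    # (possibly low-precision) optimizer state via copy_
    exp_avg_f32 = exp_avg.float().lerp(grad_f32, 1 - beta1)
    exp_avg_sq_f32 = exp_avg_sq.float().lerp(grad_f32.square(), 1 - beta2)
    exp_avg.copy_(exp_avg_f32)
    exp_avg_sq.copy_(exp_avg_sq_f32)

    if max_exp_avg_sq is not None:  # amsgrad variant
        max_exp_avg_sq_f32 = torch.maximum(max_exp_avg_sq.float(), exp_avg_sq_f32)
        max_exp_avg_sq.copy_(max_exp_avg_sq_f32)
        denom = (max_exp_avg_sq_f32.sqrt() / bias_correction2**0.5).add_(eps)
    else:
        denom = (exp_avg_sq_f32.sqrt() / bias_correction2**0.5).add_(eps)

    step_size = lr / bias_correction1
    if is_adamw:
        # decoupled weight decay merged with the param update in a single add_
        p.add_(-lr * weight_decay * p_f32 - step_size * exp_avg_f32 / denom)
    else:
        p.addcdiv_(exp_avg_f32, denom, value=-step_size)


# smoke test: BF16 param and state, FP32 math inside the step
p = torch.randn(8, dtype=torch.bfloat16)
grad = torch.randn_like(p)
exp_avg = torch.zeros_like(p)
exp_avg_sq = torch.zeros_like(p)
single_param_adam_sketch(p, grad, 1, exp_avg, exp_avg_sq, None,
                         lr=1e-3, beta1=0.9, beta2=0.999,
                         weight_decay=1e-2, eps=1e-8, is_adamw=True)
print(p.dtype, exp_avg.dtype)  # both remain torch.bfloat16
```

The in-place `add_`, `addcdiv_`, and `copy_` calls cast the FP32 results back to the storage dtype, which is why the patch can do the arithmetic in FP32 without changing how the param and state tensors are stored.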