[low-bit optim] Add coat for float8 optimizer #1231

Draft · MirMustafaAli wants to merge 52 commits into main from add_coat_optimizer

Changes from 37 commits (52 commits total)
c62fcd3
added dynamic range expansion
MirMustafaAli Nov 6, 2024
b11b4f6
created optimstate with DRE class
MirMustafaAli Nov 6, 2024
b887e69
implement copy_.default for OptimStateFp8WithDynamicRangeExpansion cl…
MirMustafaAli Nov 6, 2024
ab02605
implements _to_copy
MirMustafaAli Nov 6, 2024
43f5c08
removed implemented classes
MirMustafaAli Nov 6, 2024
7d98b15
dynamic_range_expansion -> apply_dynamic_range_expansion
MirMustafaAli Nov 6, 2024
e03c79c
add DRE flags to class
MirMustafaAli Nov 6, 2024
5faa7de
implementing contraction for dequantize
MirMustafaAli Nov 6, 2024
b43c88f
copy k values as well for copy method
MirMustafaAli Nov 6, 2024
ac41627
added dynamic range expansion
MirMustafaAli Nov 6, 2024
79c9461
created optimstate with DRE class
MirMustafaAli Nov 6, 2024
47a7bb0
implement copy_.default for OptimStateFp8WithDynamicRangeExpansion cl…
MirMustafaAli Nov 6, 2024
a162f94
implements _to_copy
MirMustafaAli Nov 6, 2024
5c1a3f4
removed implemented classes
MirMustafaAli Nov 6, 2024
1458d65
dynamic_range_expansion -> apply_dynamic_range_expansion
MirMustafaAli Nov 6, 2024
42fbb09
add DRE flags to class
MirMustafaAli Nov 6, 2024
9d1c00c
implementing contraction for dequantize
MirMustafaAli Nov 6, 2024
7bc6ea4
copy k values as well for copy method
MirMustafaAli Nov 6, 2024
7be5a6b
Merge branch 'add_coat_optimizer' of https://github.com/MirMustafaAli…
MirMustafaAli Nov 8, 2024
70937c8
combine range_expansion into quantize_fp8 function
MirMustafaAli Nov 8, 2024
3583de7
passing apply_range_expansion to quantize_fp8
MirMustafaAli Nov 8, 2024
c754893
remove apply_dynamic_range_expansion method
MirMustafaAli Nov 8, 2024
c47b987
pass destination's dynamic range expasnsion variable to quantize fp8
MirMustafaAli Nov 8, 2024
7a754ce
change type annotation to optional
MirMustafaAli Nov 8, 2024
4d37d86
k is none when dynamic range expansion is False
MirMustafaAli Nov 9, 2024
2d1834a
referencing paper for calculation of dynamic range expansion
MirMustafaAli Nov 9, 2024
3d0d5d6
replaced condition check using variable k
MirMustafaAli Nov 9, 2024
c413ac4
added parameter dynamic_range_expansion
MirMustafaAli Nov 9, 2024
c3f5d29
pass bool condition for quantizing src tensor
MirMustafaAli Nov 9, 2024
1ec9335
readded the torchversion safe_global exports
MirMustafaAli Nov 9, 2024
122530e
initialize k to none and later assign value if dynamic range expansio…
MirMustafaAli Nov 9, 2024
77e1371
conditional statement by checking if k is None instead of directly ap…
MirMustafaAli Nov 9, 2024
366743c
checking if k is available in dst to copy it
MirMustafaAli Nov 9, 2024
38951ae
matching parameters counts with constructor of optimStateFp8
MirMustafaAli Nov 12, 2024
4b3fb6b
copy to k tensor only if k is not None
MirMustafaAli Nov 12, 2024
7185b00
passing k tensor if values are available
MirMustafaAli Nov 12, 2024
0d7edae
providing dynamic range expansion to the adamfloat8 class
MirMustafaAli Nov 12, 2024
58ff635
change of _subclass_zeros from static method to normal class method
MirMustafaAli Nov 12, 2024
6c536a9
added dynamic range expansion to adamwfp8
MirMustafaAli Nov 12, 2024
767ccab
adding smoke test for additional parameters for float8 optimizers
MirMustafaAli Nov 12, 2024
8fa5e3d
added new line
MirMustafaAli Nov 13, 2024
f34bfdd
remove newline
MirMustafaAli Nov 13, 2024
41598a0
removed optim_addon parameter
MirMustafaAli Nov 13, 2024
c189dc7
rename test_optim_addon to test_optim_fp8_coat_smoke
MirMustafaAli Nov 13, 2024
6bb49ea
code formatting
MirMustafaAli Nov 13, 2024
6707425
Merge branch 'main' into add_coat_optimizer
MirMustafaAli Nov 13, 2024
b1aea26
Moved device compatibility check for FP8 optimizer tests from pytest …
MirMustafaAli Nov 13, 2024
92ca7b2
formatting for `ruff check F,I`
MirMustafaAli Nov 13, 2024
861423d
removing duplicate
MirMustafaAli Nov 13, 2024
7661b61
checking if device is cuda before calling device capability
MirMustafaAli Nov 13, 2024
e1fa683
Updating Readme with dynamic range Expansion and Reference to Paper
MirMustafaAli Nov 13, 2024
62eac8b
Merge branch 'main' into add_coat_optimizer
MirMustafaAli Nov 15, 2024
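The commits above add dynamic range expansion (DRE) from the COAT paper (https://arxiv.org/abs/2410.19313) to the FP8 optimizer state: before a block of optimizer state is quantized to E4M3, every value is raised to a per-block exponent k chosen so that the block's dynamic range matches the dynamic range of the FP8 format, and dequantization applies the inverse exponent 1/k. Below is a minimal, self-contained sketch of that expand → quantize → dequantize → contract round trip on a single block. It mirrors the formulas in the diff, but the format's dynamic range is taken here as the ratio of the largest to the smallest positive normal E4M3 value (an assumption on my part), and the helper name and toy tensor are illustrative only.

```python
import torch

DTYPE = torch.float8_e4m3fn  # same storage dtype the PR targets


def dre_roundtrip(block: torch.Tensor) -> torch.Tensor:
    """Expand -> quantize to FP8 -> dequantize -> contract a single 1-D block."""
    finfo = torch.finfo(DTYPE)
    # Dynamic range of the FP8 format; taken here as max / smallest positive normal.
    r_dtype = finfo.max / finfo.smallest_normal
    # Dynamic range of the block itself (clipped to avoid division by zero).
    r_x = block.abs().max().clamp(min=1e-12) / block.abs().min().clamp(min=1e-12)
    # Per-block expansion exponent k = log(R_dtype) / log(R_x), as in the paper.
    k = torch.log(torch.tensor(r_dtype)) / torch.log(r_x)
    expanded = block.sign() * block.abs() ** k

    # Plain per-block absmax FP8 quantization, mirroring quantize_fp8 below.
    scale = expanded.abs().max().clamp(min=1e-12) / finfo.max
    codes = (expanded / scale).to(DTYPE)

    # Dequantize, then undo the expansion with exponent 1/k.
    deq = codes.float() * scale
    return deq.sign() * deq.abs() ** (1.0 / k)


block = torch.randn(256) * 1e-3  # toy optimizer-state block
print((dre_roundtrip(block) - block).abs().max())  # small reconstruction error
```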
22 changes: 20 additions & 2 deletions torchao/prototype/low_bit_optim/adam.py
@@ -253,6 +253,7 @@ def __init__(
*,
block_size=256,
bf16_stochastic_round=False,
dynamic_range_expansion=False,
) -> None:
super().__init__(
params,
@@ -265,11 +266,28 @@
bf16_stochastic_round=bf16_stochastic_round,
is_adamw=False,
)
self.dynamic_range_expansion = dynamic_range_expansion

@staticmethod
def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
return OptimStateFp8.zeros(p.shape, block_size, p.device)
def _subclass_zeros(p: Tensor, signed: bool, block_size: int, dynamic_range_expansion: bool):
return OptimStateFp8.zeros(p.shape, block_size, p.device, dynamic_range_expansion)

def _new_buffer(self, p: Tensor, signed: bool):
if p.numel() >= 4096 and p.numel() % self.block_size == 0:
if isinstance(p, DTensor):
out = DTensor.from_local(
local_tensor=self._subclass_zeros(
p.to_local(), signed, self.block_size, self.dynamic_range_expansion
),
device_mesh=p.device_mesh,
placements=p.placements,
run_check=False,
)
else:
out = self._subclass_zeros(p, signed, self.block_size, self.dynamic_range_expansion)
else:
out = torch.zeros_like(p)
return out

class AdamW8bit(_AdamBase):
def __init__(
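For context, here is how the new flag would be passed through from user code. The diff above only touches the shared `_AdamBase` constructor and `_new_buffer`; the optimizer class name `AdamFp8`, its import path, and the CUDA device are assumptions based on the existing low-bit-optim prototype and the commit messages ("providing dynamic range expansion to the adamfloat8 class"), so treat this as a sketch rather than the PR's documented API.

```python
import torch
from torchao.prototype.low_bit_optim import AdamFp8  # class name assumed from the commits above

model = torch.nn.Linear(4096, 4096).cuda()

# With dynamic_range_expansion=True the FP8 optimizer state stores a per-block
# exponent k (COAT dynamic range expansion) alongside the usual codes and scale.
optim = AdamFp8(model.parameters(), lr=1e-3, dynamic_range_expansion=True)

loss = model(torch.randn(8, 4096, device="cuda")).sum()
loss.backward()
optim.step()
optim.zero_grad()
```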
51 changes: 41 additions & 10 deletions torchao/prototype/low_bit_optim/subclass_fp8.py
@@ -1,5 +1,6 @@
import torch
from torch import Tensor
from typing import Optional
from torch.utils._python_dispatch import return_and_correct_aliasing

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor
@@ -11,13 +12,27 @@
DTYPE = torch.float8_e4m3fn


def quantize_fp8(input: Tensor, block_size: int):
def quantize_fp8(input: Tensor, block_size: int, dynamic_range_expansion: bool):

shape = input.shape
input = input.view(-1, block_size)
k = None

if dynamic_range_expansion:
# NOTE: the calculation is from the paper https://arxiv.org/abs/2410.19313
# The idea is to align optimizer state distributions more closely
# with the FP8 representation range, reducing the quantization error.
k = aten.ones(input.shape[0], dtype=DTYPE, device=input.device)
Rdtype = torch.finfo(DTYPE).max / torch.finfo(DTYPE).min  # dynamic range of the FP8 dtype
Rx = input.abs().amax(-1).clip(1e-12) / input.abs().amin(-1).clip(1e-12)  # per-block ratio of input max to input min
k = torch.log(Rdtype) / torch.log(Rx)  # per-block expansion exponent k
input = input.sign() * (input.abs() ** k.view(-1, 1))

scale = input.abs().amax(-1).clip(1e-12) / torch.finfo(DTYPE).max
input = input / scale.view(-1, 1)
codes = input.to(DTYPE).view(-1)
return codes.view(shape), scale

return codes.view(shape), scale, k


# NOTE: FP8 sign bit is redundant for unsigned optim state.
Expand All @@ -27,10 +42,10 @@ class OptimStateFp8(TorchAOBaseTensor):
tensor_attrs = ["codes", "scale"]

@staticmethod
def __new__(cls, codes: Tensor, scale: Tensor):
def __new__(cls, codes: Tensor, scale: Tensor, k: Optional[Tensor] = None):
return Tensor._make_wrapper_subclass(cls, codes.shape, device=codes.device)

def __init__(self, codes: Tensor, scale: Tensor):
def __init__(self, codes: Tensor, scale: Tensor, k: Optional[Tensor] = None):
"""Create quantized FP8 optimizer state.

Args
Expand All @@ -45,6 +60,7 @@ def __init__(self, codes: Tensor, scale: Tensor):
assert scale.ndim == 1
self.codes = codes
self.scale = scale
self.k = k
self.block_size = codes.numel() // scale.numel()

def __tensor_flatten__(self):
Expand All @@ -62,15 +78,22 @@ def dequantize(self, output_dtype=None):
float_data = self.codes.float()
float_data = float_data.view(-1, self.block_size) * self.scale.view(-1, 1)

if self.k is not None:
float_data = float_data.view(-1, self.block_size)
float_data = float_data ** (1 / self.k.view(-1, 1))

if output_dtype is not None:
float_data = float_data.to(output_dtype)

return float_data.view(self.codes.shape)

@classmethod
def zeros(cls, shape, block_size: int = 256, device=None):
def zeros(cls, shape, block_size: int = 256, device=None, dynamic_range_expansion: bool = False):

codes = torch.zeros(shape, dtype=DTYPE, device=device)
scale = torch.zeros(codes.numel() // block_size, device=device)
return cls(codes, scale)
k = torch.ones(codes.numel() // block_size, device=device) if dynamic_range_expansion else None
return cls(codes, scale, k)

def __repr__(self):
return (
Expand All @@ -88,12 +111,18 @@ def _(func, types, args, kwargs):
assert dst.block_size == src.block_size
dst.codes.copy_(src.codes)
dst.scale.copy_(src.scale)
if dst.k is not None:
dst.k.copy_(src.k)

elif isinstance(dst, OptimStateFp8):
codes, scale = quantize_fp8(src, dst.block_size)

codes, scale, k = quantize_fp8(src, dst.block_size, dst.k is not None)

dst.codes.copy_(codes)
dst.scale.copy_(scale)


if dst.k is not None:
dst.k.copy_(k)
else:
dst.copy_(src.dequantize())

Expand All @@ -107,6 +136,7 @@ def _(func, types, args, kwargs):
out = OptimStateFp8(
args[0].codes.to(device=device),
args[0].scale.to(device=device),
args[0].k.to(device=device) if args[0].k is not None else None
)
return return_and_correct_aliasing(func, args, kwargs, out)

Expand All @@ -121,7 +151,7 @@ def _(func, types, args, kwargs):
@OptimStateFp8.implements(aten.view.default)
def _(func, types, args, kwargs):
x, shape = args
return OptimStateFp8(x.codes.view(shape), x.scale)
return OptimStateFp8(x.codes.view(shape), x.scale, x.k)


# this is needed for DTensor.full_tensor()
Expand All @@ -142,10 +172,11 @@ def _(func, types, args, kwargs):
return OptimStateFp8(
func(x.codes, *args[1:], **kwargs),
func(x.scale, *args[1:], **kwargs),
func(x.k, *args[1:], **kwargs) if x.k is not None else None
)


if TORCH_VERSION_AT_LEAST_2_5:
from torch.serialization import add_safe_globals

add_safe_globals([OptimStateFp8])
add_safe_globals([OptimStateFp8])
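The `add_safe_globals([OptimStateFp8])` registration at the end of the diff is what lets checkpoints containing this tensor subclass be loaded with `torch.load(..., weights_only=True)` on PyTorch >= 2.5. A minimal sketch of the save/load round trip that registration enables is shown below; the optimizer class name and the checkpoint filename are assumptions, as above.

```python
import torch
from torchao.prototype.low_bit_optim import AdamFp8  # assumed, as above

model = torch.nn.Linear(4096, 4096).cuda()
optim = AdamFp8(model.parameters(), dynamic_range_expansion=True)

model(torch.randn(2, 4096, device="cuda")).sum().backward()
optim.step()

# codes, scale and (with DRE) k travel inside OptimStateFp8 tensors in the state dict.
torch.save(optim.state_dict(), "optim.pt")

# weights_only=True only deserializes OptimStateFp8 because it was registered
# via add_safe_globals above (PyTorch >= 2.5).
state = torch.load("optim.pt", weights_only=True)
optim.load_state_dict(state)
```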