[wip] store inv_scale on Float8Tensor
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 6ab4322ffb9456f244a1c3987842e10f3ecc2a83
ghstack-comment-id: 2273795212
Pull Request resolved: #628
vkuzo committed Aug 7, 2024
1 parent 4f8bee4 commit 0e634b9
Showing 7 changed files with 62 additions and 31 deletions.
6 changes: 5 additions & 1 deletion benchmarks/float8/profile_linear_float8.py
@@ -209,6 +209,7 @@ def main(
scaling_type_grad_output: str = "dynamic",
model_type: str = "linear",
dtype_filter: str = "both",
skip_amax_sync: bool = False,
):
assert model_type in ("linear", "ln_linear", "norm_ffn_norm", "norm_ffn_norm_small"), "unsupported"
assert dtype_filter in ("both", "float8", "bfloat16")
@@ -220,6 +221,9 @@ def main(
cast_config_input=CastConfig(scaling_type=scaling_type_input),
cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
# for now we don't care about amax init for performance profiling
enable_amax_init=False,
enable_pre_and_post_forward = not skip_amax_sync,
)
scaling_repr = "_".join(
[
@@ -290,7 +294,7 @@ def float8_forw_backward_wrapper(x):
# inspection of the fw+bw torch.compile without the scale
# syncing code
# TODO(future): make this better
if linear_requires_sync(config):
if linear_requires_sync(config) and not skip_amax_sync:
with record_function("scale_amax_and_scales"):
sync_amax_history(m_float8)
out = float8_forw(x)
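For context, here is a minimal sketch of how the new `skip_amax_sync` flag maps onto the config built in the hunk above. The field names come from this diff; the `Float8LinearConfig` class name, the `ScalingType` enum, and the `torchao.float8` import path are assumptions about the surrounding code rather than part of this commit.

```python
from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType

skip_amax_sync = True  # hypothetical value; skips amax syncing in the profiled region

config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
    # for now we don't care about amax init for performance profiling
    enable_amax_init=False,
    enable_pre_and_post_forward=not skip_amax_sync,
)
```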
7 changes: 4 additions & 3 deletions test/float8/test_base.py
@@ -128,6 +128,7 @@ def test_copy_(self):
fp8_b = Float8Tensor(
torch.empty(16, dtype=torch.float8_e4m3fn),
scale_a,
scale_a.reciprocal(),
torch.bfloat16,
fp8_a._linear_mm_config,
)
@@ -417,14 +418,14 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):

out_scaled_mm = addmm_float8_unwrapped(
a_fp8._data,
a_fp8._scale,
a_fp8._inv_scale,
b_fp8._data,
b_fp8._scale,
b_fp8._inv_scale,
output_dtype=output_dtype,
use_fast_accum=use_fast_accum,
)
out_emulated = torch.ops.aten.mm_float8_emulated(
a_fp8._data, a_fp8._scale, b_fp8._data, b_fp8._scale, output_dtype
a_fp8._data, a_fp8._inv_scale, b_fp8._data, b_fp8._inv_scale, output_dtype
)

if output_dtype != base_dtype:
12 changes: 6 additions & 6 deletions torchao/float8/float8_aten_api.py
@@ -15,14 +15,14 @@

def mm_float8_emulated(
m1, # input 1 data
s1, # input 1 scale
inv_s1, # input 1 inverse scale
m2, # input 2 data
s2, # input 2 scale
inv_s2, # input 2 inverse scale
dtype3, # output dtype
):
# naive implementation: dq -> op -> q
m1_fp32 = m1.float() / s1
m2_fp32 = m2.float() / s2
m1_fp32 = m1.float() * inv_s1
m2_fp32 = m2.float() * inv_s2
m3_fp32 = torch.mm(m1_fp32, m2_fp32)

return m3_fp32.to(dtype3)
@@ -37,13 +37,13 @@ def mm_float8_emulated(
lib = Library("aten", "FRAGMENT")

lib.define(
"mm_float8_emulated(Tensor m1, Tensor s1, Tensor m2, Tensor s2, ScalarType dtype3) -> Tensor"
"mm_float8_emulated(Tensor m1, Tensor inv_s1, Tensor m2, Tensor inv_s2, ScalarType dtype3) -> Tensor"
)
lib.impl("mm_float8_emulated", mm_float8_emulated, "CPU")
lib.impl("mm_float8_emulated", mm_float8_emulated, "CUDA")


@torch.library.impl(lib, "mm_float8_emulated", "Meta")
def _mm_float8_emulated_meta(m1, s1, m2, s2, dtype3):
def _mm_float8_emulated_meta(m1, inv_s1, m2, inv_s2, dtype3):
out = torch.mm(m1.float(), m2.float()).to(dtype3)
return out
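A small usage sketch of the emulated op after this change. The tensor shapes, the e4m3 dtype, and the example values are illustrative assumptions; importing the module registers the op via the `lib.define`/`lib.impl` calls above.

```python
import torch

import torchao.float8.float8_aten_api  # noqa: F401 - registers mm_float8_emulated

a = torch.randn(16, 32)
b = torch.randn(32, 8)
fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448 for e4m3fn
a_scale = fp8_max / a.abs().max()
b_scale = fp8_max / b.abs().max()
a_fp8 = (a * a_scale).to(torch.float8_e4m3fn)
b_fp8 = (b * b_scale).to(torch.float8_e4m3fn)

# Callers now pass reciprocals computed once up front, instead of the scales themselves.
out = torch.ops.aten.mm_float8_emulated(
    a_fp8, a_scale.reciprocal(), b_fp8, b_scale.reciprocal(), torch.bfloat16
)
print(out.shape, out.dtype)  # torch.Size([16, 8]) torch.bfloat16
```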
31 changes: 20 additions & 11 deletions torchao/float8/float8_ops.py
@@ -48,6 +48,7 @@ def float8_desugar_op(aten_op, args, kwargs=None):
return Float8Tensor(
new_data,
args[0]._scale,
args[0]._inv_scale,
args[0]._orig_dtype,
args[0]._linear_mm_config,
args[0]._gemm_input_role,
@@ -62,6 +63,7 @@ def make_float8(data):
return Float8Tensor(
data,
args[0]._scale,
args[0]._inv_scale,
args[0]._orig_dtype,
args[0]._linear_mm_config,
args[0]._gemm_input_role,
@@ -78,6 +80,7 @@ def float8_cat(aten_op, args, kwargs=None):

orig_dtype = chunked_tensors[0]._orig_dtype
scale = chunked_tensors[0]._scale
inv_scale = chunked_tensors[0]._inv_scale
mm_config = chunked_tensors[0]._linear_mm_config
fp8_dtype = chunked_tensors[0]._data.dtype
gemm_input_role = chunked_tensors[0]._gemm_input_role
@@ -105,7 +108,7 @@ def float8_cat(aten_op, args, kwargs=None):

new_data = aten_op(chunk_data, *args[1:], **kwargs)
new_data = new_data.view(fp8_dtype)
return Float8Tensor(new_data, scale, orig_dtype, mm_config, gemm_input_role)
return Float8Tensor(new_data, scale, inv_scale, orig_dtype, mm_config, gemm_input_role)


@implements([aten.sum.dim_IntList])
@@ -130,7 +133,7 @@ def unwrap(x):

def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
a_data = a._data
a_scale = a._scale
a_inv_scale = a._inv_scale
b_data = b._data

scaled_mm_config = choose_scaled_mm_config(
@@ -151,8 +154,8 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
a_data = a_data.contiguous()
if is_row_major(b_data.stride()):
b_data = b_data.t().contiguous().t()
b_scale = b._scale
return a_data, a_scale, b_data, b_scale
b_inv_scale = b._inv_scale
return a_data, a_inv_scale, b_data, b_inv_scale


@implements([aten.mm.default, aten.matmul.default])
@@ -165,7 +168,7 @@ def float8_mm(aten_op, args, kwargs=None):
), "Expecting both Float8Tensor for mm inputs but found {} and {}".format(
type(a), type(b)
)
a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
a_data, a_inv_scale, b_data, b_inv_scale = preprocess_addmm(a, b)
output_dtype = a._orig_dtype
scaled_mm_config = choose_scaled_mm_config(
a._gemm_input_role,
@@ -175,13 +178,13 @@ def float8_mm(aten_op, args, kwargs=None):
)
if scaled_mm_config.emulate:
return torch.ops.aten.mm_float8_emulated(
a._data, a._scale, b._data, b._scale, output_dtype
a._data, a._inv_scale, b._data, b._inv_scale, output_dtype
)
tensor_out = addmm_float8_unwrapped(
a_data,
a_scale,
a_inv_scale,
b_data,
b_scale,
b_inv_scale,
output_dtype,
output_scale=None,
bias=None,
@@ -200,7 +203,7 @@ def float8_addmm(aten_op, args, kwargs=None):
bias = args[0]
a = args[1]
b = args[2]
a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
a_data, a_inv_scale, b_data, b_inv_scale = preprocess_addmm(a, b)
output_dtype = a._orig_dtype
assert bias.dtype == output_dtype, "bias dtype must match output dtype"
scaled_mm_config = choose_scaled_mm_config(
@@ -210,15 +213,16 @@ def float8_addmm(aten_op, args, kwargs=None):
b._linear_mm_config,
)
if scaled_mm_config.emulate:
# TODO inv scale here
out = torch.ops.aten.mm_float8_emulated(
a._data, a._scale, b._data, b._scale, output_dtype
)
return out + bias
tensor_out = addmm_float8_unwrapped(
a_data,
a_scale,
a_inv_scale,
b_data,
b_scale,
b_inv_scale,
output_dtype,
output_scale=None,
bias=bias,
@@ -249,6 +253,7 @@ def autocast_to_copy(aten_op, args, kwargs=None):
return Float8Tensor(
args[0]._data,
args[0]._scale,
args[0]._inv_scale,
kwargs["dtype"],
args[0]._linear_mm_config,
args[0]._gemm_input_role,
@@ -276,6 +281,7 @@ def allgather_fp8(aten_op, args, kwargs=None):
return Float8Tensor(
fp8_out,
fp8_input._scale,
fp8_input._inv_scale,
fp8_input._orig_dtype,
fp8_input._linear_mm_config,
fp8_input._gemm_input_role,
@@ -292,6 +298,7 @@ def wait_tensor_fp8(aten_op, args, kwargs=None):
return Float8Tensor(
fp8_out,
fp8_input._scale,
fp8_input._inv_scale,
fp8_input._orig_dtype,
fp8_input._linear_mm_config,
fp8_input._gemm_input_role,
@@ -314,6 +321,7 @@ def index_put_fp8(aten_op, args, kwargs=None):
return Float8Tensor(
fp8_out,
fp8_self._scale,
fp8_self._inv_scale,
fp8_self._orig_dtype,
fp8_self._linear_mm_config,
fp8_self._gemm_input_role,
@@ -355,6 +363,7 @@ def copy_fp8(aten_op, args, kwargs=None):
return Float8Tensor(
fp8_out,
self._scale,
self._inv_scale,
self._orig_dtype,
self._linear_mm_config,
self._gemm_input_role,
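To illustrate what this plumbing buys, here is a hypothetical check, not part of the commit, that the cached inverse scale survives a desugared op such as `view`. It assumes `Float8Tensor` and `LinearMMConfig` are importable from the module path shown in this diff and that `aten.view.default` is routed through `float8_desugar_op` above.

```python
import torch

from torchao.float8.float8_tensor import Float8Tensor, LinearMMConfig

x = torch.randn(4, 4)
scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
data = (x * scale).to(torch.float8_e4m3fn)
t = Float8Tensor(data, scale, scale.reciprocal(), torch.float32, LinearMMConfig())

t2 = t.view(16)  # rebuilds the wrapper, carrying _inv_scale along with _scale
assert torch.equal(t2._inv_scale, t._inv_scale)
```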
10 changes: 6 additions & 4 deletions torchao/float8/float8_python_api.py
@@ -23,9 +23,9 @@
# For output going from fp32 -> fp8 we multiply by the scale
def addmm_float8_unwrapped(
a_data: torch.Tensor,
a_scale: torch.Tensor,
a_inv_scale: torch.Tensor,
b_data: torch.Tensor,
b_scale: torch.tensor,
b_inv_scale: torch.tensor,
output_dtype: torch.dtype,
output_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
@@ -36,8 +36,10 @@ def addmm_float8_unwrapped(
as inputs. This is used to standardize the logic between subclassed and non subclassed
versions of the linear module.
"""
a_inverse_scale = a_scale.reciprocal()
b_inverse_scale = b_scale.reciprocal()
# a_inverse_scale = a_scale.reciprocal()
# b_inverse_scale = b_scale.reciprocal()
a_inverse_scale = a_inv_scale
b_inverse_scale = b_inv_scale
if output_dtype == torch.float32 and bias is not None:
# Bias is not supported by _scaled_mm when output is fp32
output = torch._scaled_mm(
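The reason `torch._scaled_mm` (and the emulated path) consume inverse scales is that dequantization is a multiply by `1/scale`; caching that reciprocal where the tensor is created avoids recomputing it inside every matmul call. A self-contained numeric sketch of the round trip, not taken from the commit:

```python
import torch

x = torch.randn(8, 8, dtype=torch.float32)
fp8_max = torch.finfo(torch.float8_e4m3fn).max   # 448 for e4m3fn
scale = fp8_max / x.abs().max()                  # maps the fp32 range onto the fp8 range
inv_scale = scale.reciprocal()                   # computed once, cached on the tensor
x_fp8 = (x * scale).to(torch.float8_e4m3fn)      # quantize: multiply by scale
x_back = x_fp8.float() * inv_scale               # dequantize: multiply by inv_scale
print((x - x_back).abs().max())                  # small quantization error
```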
16 changes: 14 additions & 2 deletions torchao/float8/float8_tensor.py
@@ -157,6 +157,7 @@ def forward(
"""
tensor_scaled = tensor * scale
bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
inv_scale = scale.reciprocal()

if isinstance(bits_fp8, DTensor):
assert isinstance(
@@ -166,9 +167,11 @@
bits_placements = bits_fp8.placements
local_bits = bits_fp8.to_local()
local_scale = scale.to_local()
local_inv_scale = inv_scale.to_local()
inner_float8_tensor = Float8Tensor(
local_bits,
local_scale,
local_inv_scale,
tensor.dtype,
linear_mm_config=linear_mm_config,
gemm_input_role=gemm_input_role,
@@ -185,6 +188,7 @@ def forward(
return Float8Tensor(
bits_fp8,
scale,
inv_scale,
tensor.dtype,
linear_mm_config=linear_mm_config,
gemm_input_role=gemm_input_role,
@@ -251,6 +255,11 @@ class Float8Tensor(torch.Tensor):
* `_scale`: the scale used to scale the original fp32 tensor. We multiply
by scale to go from fp32 range to fp8 range, and divide by scale to go
from fp8 range to fp32 range.
* `_inv_scale`: the inverse of `_scale`. We need this because the
`torch._scaled_mm` function requires inverse scales, and torch.compile
does not reliably fuse this into preceding ops, which can lead to extra
GPU kernel launches. If we calculate the inverse scale colocated with
creating the `Float8Tensor` instance, we don't see the extra GPU kernels.
* `_orig_dtype`: the original dtype of the tensor used to create this
tensor.
* `_emulate`: if true using fp32 emulation for the matmuls, helpful
@@ -275,6 +284,7 @@ def __new__(
cls,
data: torch.Tensor,
scale: torch.Tensor,
inv_scale: torch.Tensor,
orig_dtype: torch.dtype,
linear_mm_config: Optional[LinearMMConfig],
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
@@ -297,6 +307,7 @@ def __new__(
)
self._data = data
self._scale = scale
self._inv_scale = inv_scale
self._orig_dtype = orig_dtype
self._linear_mm_config = (
linear_mm_config if linear_mm_config is not None else LinearMMConfig()
@@ -314,14 +325,15 @@ def __tensor_flatten__(self):
"_linear_mm_config": self._linear_mm_config,
"_gemm_input_role": self._gemm_input_role,
}
return ["_data", "_scale"], ctx
return ["_data", "_scale", "_inv_scale"], ctx

@staticmethod
def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride):
assert len(inner_tensors) == 2
assert len(inner_tensors) == 3
return Float8Tensor(
inner_tensors["_data"],
inner_tensors["_scale"],
inner_tensors["_inv_scale"],
metadata["_orig_dtype"],
metadata["_linear_mm_config"],
metadata["_gemm_input_role"],
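A minimal construction sketch under the new signature: the inverse scale is now a positional argument right after the scale, and `__tensor_flatten__` reports three inner tensors. Values are illustrative; the import path follows the file layout in this diff.

```python
import torch

from torchao.float8.float8_tensor import Float8Tensor, LinearMMConfig

hp = torch.randn(16, 16, dtype=torch.bfloat16)
scale = torch.finfo(torch.float8_e4m3fn).max / hp.abs().max().float()
data = (hp.float() * scale).to(torch.float8_e4m3fn)

t = Float8Tensor(
    data,
    scale,
    scale.reciprocal(),   # new: cached inverse scale
    hp.dtype,             # orig_dtype
    LinearMMConfig(),
)
inner_names, _ctx = t.__tensor_flatten__()
print(inner_names)        # ['_data', '_scale', '_inv_scale']
```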
11 changes: 7 additions & 4 deletions torchao/float8/fsdp_utils.py
@@ -217,7 +217,7 @@ def fsdp_pre_all_gather(self, mesh):
reduce_amax=True,
gemm_input_role=GemmInputRole.WEIGHT,
)
return (float8_tensor._data,), (float8_tensor._scale,)
return (float8_tensor._data,), (float8_tensor._scale, float8_tensor._inv_scale)

def fsdp_post_all_gather(
self,
@@ -228,7 +228,7 @@ def fsdp_post_all_gather(
out: Optional[torch.Tensor] = None,
):
(data,) = all_gather_outputs
(scale,) = metadata
(scale, inv_scale) = metadata
if out is not None:
from torch.distributed._tensor import DTensor
if isinstance(out, Float8Tensor):
@@ -245,6 +245,7 @@ def fsdp_post_all_gather(
return Float8Tensor(
data,
scale,
inv_scale,
param_dtype,
self._linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
@@ -407,7 +408,7 @@ def fsdp_pre_all_gather(self, mesh):
self._linear_mm_config,
GemmInputRole.WEIGHT,
)
return (float8_tensor._data,), (float8_tensor._scale,)
return (float8_tensor._data,), (float8_tensor._scale, float8_tensor._inv_scale)

def fsdp_post_all_gather(
self,
@@ -418,14 +419,16 @@ def fsdp_post_all_gather(
out: Optional[torch.Tensor] = None,
):
(data,) = all_gather_outputs
(scale,) = metadata
(scale, inv_scale) = metadata
if out is not None:
assert isinstance(out, Float8Tensor), f"{type(out)}"
out._scale = scale
out._inv_scale = inv_scale
return
return Float8Tensor(
data,
scale,
inv_scale,
param_dtype,
self._linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
