Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cleanup][2/x] split float8 mm by delayed vs dynamic #1461

Merged
merged 9 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 132 additions & 138 deletions torchao/float8/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,77 +29,86 @@
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor


@torch._dynamo.allow_in_graph
class manual_float8_matmul_with_args_in_float8(torch.autograd.Function):
"""
Like torch.matmul, but with the arguments in float8

Note: this function requires all arguments to already be Float8Tensor objects,
which only supports tensorwise scaling granularity. The reason we didn't just make this
function support axiswise scaling granularity is because that would need very
careful testing of delayed scaling, as delayed scaling modifies buffers inplace.

In the future we'll probably have to unify, just postponing that until a future PR.
"""

@staticmethod
def forward(
ctx,
input_fp8,
weight_fp8_t,
):
ctx.save_for_backward(input_fp8, weight_fp8_t)
# the reshapes are needed in order to make the shapes compatible with
# torch.mm
orig_shape = input_fp8.shape
input_fp8_reshaped = input_fp8.reshape(-1, orig_shape[-1])
res_bits = torch.mm(input_fp8_reshaped, weight_fp8_t)
res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
return res_bits

@staticmethod
def backward(ctx, grad_output_fp8):
input_fp8, weight_fp8_t = ctx.saved_tensors

# the reshapes are needed in order to make the shapes compatible with
# torch.mm
grad_output_fp8_orig_shape = grad_output_fp8.shape
grad_output_fp8_reshaped = grad_output_fp8.reshape(
-1, grad_output_fp8_orig_shape[-1]
)

# calculate grad_input
grad_input = torch.mm(
grad_output_fp8_reshaped,
weight_fp8_t.t(),
)
grad_input = grad_input.reshape(
*grad_output_fp8_orig_shape[:-1], grad_input.shape[-1]
def _cast_input_to_float8(
input: torch.Tensor,
scaling_type_input: ScalingType,
config: Float8LinearConfig,
linear_mm_config: LinearMMConfig,
) -> torch.Tensor:
# Duplicate the autocast logic for F.linear, so that the output
# of our module has the right original precision
if torch.is_autocast_enabled():
# For now, hardcode to GPU's autocast dtype
# if we need CPU support in the future, we can add it
autocast_dtype = torch.get_autocast_gpu_dtype()
input = input.to(autocast_dtype)

if tensor_already_casted_to_fp8(input):
input_fp8 = input
else:
assert scaling_type_input is ScalingType.DYNAMIC
input_fp8 = hp_tensor_to_float8_dynamic(
input,
config.cast_config_input.target_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
)

input_fp8_orig_shape = input_fp8.shape
input_fp8_reshaped = input_fp8.reshape(-1, input_fp8_orig_shape[-1])

# calculate grad_weight
# Note: the variant below is slightly faster on LLaMa 3 8B pretraining
# compared to calculating `grad_weight_t = input_fp8_t @ grad_output_fp8_reshaped`
grad_weight = torch.mm(
grad_output_fp8_reshaped.t(),
input_fp8_reshaped,
)

return grad_input, grad_weight.t()
return input_fp8


def _get_weight_scale(
weight: torch.Tensor,
scaling_type_weight: ScalingType,
config: Float8LinearConfig,
) -> Optional[torch.Tensor]:
if tensor_already_casted_to_fp8(weight):
return None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I figured it would have just returned the scale on the fp8 tensor

assert scaling_type_weight is ScalingType.DYNAMIC
return tensor_to_scale(weight, config.cast_config_weight.target_dtype)


def _cast_weight_to_float8_t(
    weight: torch.Tensor,
    config: Float8LinearConfig,
    linear_mm_config: LinearMMConfig,
    weight_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Return the transposed float8 version of ``weight``.

    A weight that is already in float8 is only transposed. Otherwise it is
    first cast with ``hp_tensor_and_scale_to_float8`` using the supplied
    ``weight_scale`` and the target dtype from ``config.cast_config_weight``,
    tagged with the WEIGHT gemm input role, and then transposed.
    """
    if not tensor_already_casted_to_fp8(weight):
        # High precision weight: cast with the provided scale first.
        weight = hp_tensor_and_scale_to_float8(
            weight,
            weight_scale,
            config.cast_config_weight.target_dtype,
            linear_mm_config,
            gemm_input_role=GemmInputRole.WEIGHT,
        )
    return weight.t()


def _cast_output_to_float8_in_bw(
    output: torch.Tensor,
    scaling_type_grad_output,
    linear_mm_config: LinearMMConfig,
    config: Float8LinearConfig,
) -> torch.Tensor:
    """Route ``output`` through ``NoopFwToFloat8BwDynamic``.

    Only dynamic scaling of grad_output is supported; any other scaling type
    trips the assertion. The target dtype is taken from
    ``config.cast_config_grad_output``.
    """
    assert scaling_type_grad_output is ScalingType.DYNAMIC
    grad_output_dtype = config.cast_config_grad_output.target_dtype
    return NoopFwToFloat8BwDynamic.apply(
        output,
        linear_mm_config,
        grad_output_dtype,
    )


@torch._dynamo.allow_in_graph
class manual_float8_matmul_with_args_in_hp(torch.autograd.Function):
class matmul_with_hp_or_float8_args(torch.autograd.Function):
"""
Like torch.matmul, but with the arguments in high precision and the cast to float8
defined inside of this function.
Like torch.matmul, but with the arguments in either high precision or float8.
* if the arguments are in high precision, they are cast to float8 according
to the specified config
* if the arguments are in float8, we assume the cast honored the config

Note: this function currently only supports dynamic scaling type and
axiswise granularity. We will have to unify this with other scaling types
and other granularities in a separate PR.
Only supports dynamic scaling, does not support delayed/static scaling.
"""

@staticmethod
Expand All @@ -116,7 +125,9 @@ def forward(

c = config

if c.cast_config_input.scaling_type is ScalingType.DISABLED:
if tensor_already_casted_to_fp8(input_hp):
input_maybe_fp8 = input_hp
elif c.cast_config_input.scaling_type is ScalingType.DISABLED:
input_maybe_fp8 = input_hp
else:
input_maybe_fp8 = hp_tensor_to_float8_dynamic(
Expand All @@ -130,7 +141,9 @@ def forward(
),
)

if c.cast_config_weight.scaling_type is ScalingType.DISABLED:
if tensor_already_casted_to_fp8(weight_hp_t):
weight_maybe_fp8_t = weight_hp_t
elif c.cast_config_weight.scaling_type is ScalingType.DISABLED:
weight_maybe_fp8_t = weight_hp_t
else:
weight_maybe_fp8_t = hp_tensor_to_float8_dynamic(
Expand Down Expand Up @@ -166,7 +179,10 @@ def backward(ctx, grad_output):
# calculate grad_input
#

if c.cast_config_grad_output.scaling_type is ScalingType.DISABLED:
if tensor_already_casted_to_fp8(grad_output_reshaped):
# TODO(future PR): this var name is axiswise-specific, fix it
grad_output_reshaped_maybe_fp8_dim0 = grad_output_reshaped
elif c.cast_config_grad_output.scaling_type is ScalingType.DISABLED:
grad_output_reshaped_maybe_fp8_dim0 = grad_output_reshaped
else:
grad_output_reshaped_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic(
Expand All @@ -180,7 +196,10 @@ def backward(ctx, grad_output):
),
)

if c.cast_config_weight_for_grad_input.scaling_type is ScalingType.DISABLED:
if tensor_already_casted_to_fp8(weight_hp_t):
# TODO(future PR): var name is axiswise specific, fix it
weight_t_maybe_fp8_dim0 = weight_hp_t
elif c.cast_config_weight_for_grad_input.scaling_type is ScalingType.DISABLED:
weight_t_maybe_fp8_dim0 = weight_hp_t
else:
# Note: we need https://github.com/pytorch/pytorch/issues/136267
Expand Down Expand Up @@ -213,7 +232,10 @@ def backward(ctx, grad_output):
# calculate grad_weight
#

if (
if tensor_already_casted_to_fp8(grad_output_reshaped):
# TODO(future PR): var name is axiswise specific, fix it
grad_output_reshaped_maybe_fp8_dim1 = grad_output_reshaped
elif (
c.cast_config_grad_output_for_grad_weight.scaling_type
is ScalingType.DISABLED
):
Expand All @@ -230,7 +252,10 @@ def backward(ctx, grad_output):
),
)

if c.cast_config_input_for_grad_weight.scaling_type is ScalingType.DISABLED:
if tensor_already_casted_to_fp8(input_hp_reshaped):
# TODO(future PR): var name is axiswise specific, fix it
input_reshaped_maybe_fp8_dim1 = input_hp_reshaped
elif c.cast_config_input_for_grad_weight.scaling_type is ScalingType.DISABLED:
input_reshaped_maybe_fp8_dim1 = input_hp_reshaped
else:
input_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic(
Expand Down Expand Up @@ -303,58 +328,6 @@ def __init__(self, *args, **kwargs):
),
)

def cast_input_to_float8(self, input: torch.Tensor) -> torch.Tensor:
    """Return ``input`` as a float8 tensor, casting dynamically if needed.

    When autocast is enabled the input is first converted to the GPU autocast
    dtype, mirroring what F.linear does, so that the module output keeps the
    expected precision. An input that is already float8 is returned as-is;
    otherwise the input scaling type must be DYNAMIC.
    """
    if torch.is_autocast_enabled():
        # Only the GPU autocast dtype is handled for now; CPU support can be
        # added later if it is ever needed.
        input = input.to(torch.get_autocast_gpu_dtype())

    if tensor_already_casted_to_fp8(input):
        return input

    assert self.scaling_type_input is ScalingType.DYNAMIC
    return hp_tensor_to_float8_dynamic(
        input,
        self.config.cast_config_input.target_dtype,
        self.linear_mm_config,
        gemm_input_role=GemmInputRole.INPUT,
    )

def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
    """Compute the scale for a high precision ``weight``.

    Returns ``None`` when the weight is already in float8. Otherwise the
    weight scaling type must be DYNAMIC and the scale comes from
    ``tensor_to_scale`` for the configured weight target dtype.
    """
    if not tensor_already_casted_to_fp8(weight):
        assert self.scaling_type_weight is ScalingType.DYNAMIC
        target_dtype = self.config.cast_config_weight.target_dtype
        return tensor_to_scale(weight, target_dtype)
    return None

def cast_weight_to_float8_t(
    self,
    weight: torch.Tensor,
    weight_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Return the transposed float8 version of ``weight``.

    A weight that is already in float8 is only transposed. Otherwise it is
    cast with ``hp_tensor_and_scale_to_float8`` using ``weight_scale`` and the
    module's weight cast config, tagged with the WEIGHT gemm input role, and
    then transposed.
    """
    if not tensor_already_casted_to_fp8(weight):
        # High precision weight: cast with the provided scale first.
        weight = hp_tensor_and_scale_to_float8(
            weight,
            weight_scale,
            self.config.cast_config_weight.target_dtype,
            self.linear_mm_config,
            gemm_input_role=GemmInputRole.WEIGHT,
        )
    return weight.t()

def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
    """Route ``output`` through ``NoopFwToFloat8BwDynamic``.

    Only dynamic scaling of grad_output is supported; any other scaling type
    trips the assertion. The target dtype is taken from the module's
    grad_output cast config.
    """
    assert self.scaling_type_grad_output is ScalingType.DYNAMIC
    grad_output_dtype = self.config.cast_config_grad_output.target_dtype
    return NoopFwToFloat8BwDynamic.apply(
        output,
        self.linear_mm_config,
        grad_output_dtype,
    )

def forward(self, input: torch.Tensor) -> torch.Tensor:
has_any_axiswise_scaling = any(
cc.scaling_granularity is ScalingGranularity.AXISWISE
Expand All @@ -368,34 +341,55 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
]
)

input_maybe_fp8 = input
weight_maybe_fp8_t = self.weight.t()

# TODO(future PR): check for axiswise scaling for input, weight,
# grad_output separately instead of together
if not has_any_axiswise_scaling:
input_fp8 = self.cast_input_to_float8(input)
input_fp8 = _cast_input_to_float8(
input,
self.scaling_type_input,
self.config,
self.linear_mm_config,
)
# If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight,
# weight_scale should be saved.
weight_scale = self.get_weight_scale(self.weight)
weight_scale = _get_weight_scale(
self.weight, self.scaling_type_weight, self.config
)

if self.config.force_recompute_fp8_weight_in_bwd:
weight_fp8_t = checkpoint.checkpoint(
self.cast_weight_to_float8_t,
_cast_weight_to_float8_t,
self.weight,
self.config,
self.linear_mm_config,
weight_scale,
)
else:
weight_fp8_t = self.cast_weight_to_float8_t(self.weight, weight_scale)
weight_fp8_t = _cast_weight_to_float8_t(
self.weight,
self.config,
self.linear_mm_config,
weight_scale,
)

output = manual_float8_matmul_with_args_in_float8.apply(
input_fp8, weight_fp8_t
)
input_maybe_fp8 = input_fp8
weight_maybe_fp8_t = weight_fp8_t

# Cast grad_output to float8_e5m2 during backward
output = self.cast_output_to_float8_in_bw(output)
output = matmul_with_hp_or_float8_args.apply(
input_maybe_fp8,
weight_maybe_fp8_t,
self.linear_mm_config,
self.config,
)

else:
# for now, axiswise path is separate
# TODO(future PR): unify to support mix and match
output = manual_float8_matmul_with_args_in_hp.apply(
input,
self.weight.t(),
if not has_any_axiswise_scaling:
# Cast grad_output to float8_e5m2 during backward
output = _cast_output_to_float8_in_bw(
output,
self.scaling_type_grad_output,
self.linear_mm_config,
self.config,
)
Expand Down
Loading
Loading