Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 35 additions & 24 deletions torchao/dtypes/floatx/cutlass_semi_sparse_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,46 +191,57 @@ def _linear_fp8_act_fp8_weight_sparse_cutlass_check(input_tensor, weight_tensor,
from torchao.dtypes.floatx import Float8Layout

base_check = (
isinstance(input_tensor, AffineQuantizedTensor) and
isinstance(input_tensor._layout, Float8Layout) and
input_tensor.dtype in (torch.float16, torch.bfloat16) and
len(input_tensor.shape) >= 2 and
input_tensor.tensor_impl.scale.dtype == torch.float32 and
isinstance(weight_tensor, AffineQuantizedTensor) and
isinstance(weight_tensor._layout, CutlassSemiSparseLayout) and
weight_tensor.dtype == input_tensor.dtype and
len(weight_tensor.shape) == 2 and
weight_tensor.tensor_impl.scale.dtype == torch.float32 and
(bias is None or bias.dtype == input_tensor.dtype) and
(bias is None or len(bias.shape) == 1)
isinstance(input_tensor, AffineQuantizedTensor)
and isinstance(input_tensor._layout, Float8Layout)
and input_tensor.dtype in (torch.float16, torch.bfloat16)
and len(input_tensor.shape) >= 2
and input_tensor.tensor_impl.scale.dtype == torch.float32
and isinstance(weight_tensor, AffineQuantizedTensor)
and isinstance(weight_tensor._layout, CutlassSemiSparseLayout)
and weight_tensor.dtype == input_tensor.dtype
and len(weight_tensor.shape) == 2
and weight_tensor.tensor_impl.scale.dtype == torch.float32
and (bias is None or bias.dtype == input_tensor.dtype)
and (bias is None or len(bias.shape) == 1)
)

if base_check:

# do extra check and reshape if needed
input_tensor_squeezed = False
if len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) and \
len(input_tensor.tensor_impl.scale.shape) > 1 and \
input_tensor.tensor_impl.scale.shape[-1] == 1:
input_tensor.tensor_impl.scale = torch.squeeze(input_tensor.tensor_impl.scale, dim=-1)
if (
len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape)
and len(input_tensor.tensor_impl.scale.shape) > 1
and input_tensor.tensor_impl.scale.shape[-1] == 1
):
input_tensor.tensor_impl.scale = torch.squeeze(
input_tensor.tensor_impl.scale, dim=-1
)
input_tensor_squeezed = True

weight_tensor_squeezed = False
if len(weight_tensor.tensor_impl.scale.shape) == 2 and \
weight_tensor.tensor_impl.scale.shape[-1] == 1:
weight_tensor.tensor_impl.scale = torch.squeeze(weight_tensor.tensor_impl.scale, dim=-1)
if (
len(weight_tensor.tensor_impl.scale.shape) == 2
and weight_tensor.tensor_impl.scale.shape[-1] == 1
):
weight_tensor.tensor_impl.scale = torch.squeeze(
weight_tensor.tensor_impl.scale, dim=-1
)
weight_tensor_squeezed = True

extra_check = (
len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1
and len(weight_tensor.tensor_impl.scale.shape) == 1
)

if not extra_check: # revert if extra check failed
if not extra_check: # revert if extra check failed
if input_tensor_squeezed:
input_tensor.tensor_impl.scale = torch.unsqueeze(input_tensor.tensor_impl.scale, dim=-1)
input_tensor.tensor_impl.scale = torch.unsqueeze(
input_tensor.tensor_impl.scale, dim=-1
)
if weight_tensor_squeezed:
weight_tensor.tensor_impl.scale = torch.unsqueeze(weight_tensor.tensor_impl.scale, dim=-1)
weight_tensor.tensor_impl.scale = torch.unsqueeze(
weight_tensor.tensor_impl.scale, dim=-1
)

return extra_check

Expand Down
Loading