Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions python/triton_kernels/tests/test_mxfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
)
from triton_kernels.numerics_details.mxfp_details._upcast_from_mxfp import upcast_mxfp4_tile
from triton_kernels.target_info import is_cuda
from triton_kernels.tensor import convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details.layout import StridedLayout
from triton_kernels.testing import assert_close, assert_equal


Expand Down Expand Up @@ -203,6 +205,21 @@ def test_mxfp_quant_dequant(src_dtype, dst_dtype, device):
assert_equal(weight, dequant)


def test_downcast_to_mxfp_accepts_pitched_strided_input(device):
    """Quantizing a row-pitched tensor must match quantizing its dense copy.

    The input is wrapped and passed through `convert_layout` with
    `StridedLayout(-1)`; the test also checks that the wrapped storage still
    has the padded row stride afterwards.
    """
    torch.manual_seed(0)
    n_rows, n_cols = 64, 128
    row_pitch = 256  # each row is padded to twice its logical width
    reference = torch.randn((n_rows, n_cols), device=device, dtype=torch.bfloat16)
    padded = torch.empty_strided(reference.shape, (row_pitch, 1), device=device, dtype=reference.dtype)
    padded.copy_(reference)

    wrapped = wrap_torch_tensor(padded)
    tensor = convert_layout(wrapped, StridedLayout(-1))
    actual = downcast_to_mxfp(tensor, torch.uint8, axis=-1)
    expected = downcast_to_mxfp(reference, torch.uint8, axis=-1)
    actual_quant, actual_scale = actual
    expected_quant, expected_scale = expected

    # The pitched storage must not have been densified along the way.
    assert tensor.storage.data.stride() == (row_pitch, 1)
    assert_equal(expected_quant, actual_quant)
    assert_equal(expected_scale, actual_scale)


# fmt: off
@pytest.mark.parametrize(
"shape, axis, quant_dtype, rounding_mode, scale_dtype, microblock_size",
Expand Down
92 changes: 91 additions & 1 deletion python/triton_kernels/tests/test_tensor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import pytest
import torch
from triton_kernels.tensor_details.dtype import BIT
from triton_kernels.tensor_details.dtype import BIT, FP4, UINT8
from triton_kernels.tensor import (
convert_layout,
make_ragged_tensor_metadata,
make_ragged_tensor_metadata_torch,
remap_ragged_tensor_metadata,
Expand All @@ -11,6 +12,95 @@
wrap_torch_tensor,
)
from triton_kernels.testing import assert_equal
from triton_kernels.tensor_details.layout import (
BlackwellActMXScaleLayout,
BlackwellMX4ValueShuffledLayout,
BlackwellMXScaleLayout,
BlackwellMXValueLayout,
CDNA4MXScaleLayout,
GFX1250MXScaleLayout,
HopperMXScaleLayout,
HopperMXValueLayout,
StridedLayout,
)


@pytest.mark.parametrize(
    ("transpose", "layout"),
    [
        (False, StridedLayout(-1)),
        (False, StridedLayout(1)),
        (True, StridedLayout(-2)),
        (True, StridedLayout(0)),
    ],
)
def test_convert_layout_noop(transpose, layout):
    """Each parametrized (data, layout) pair must hand back the same Tensor object."""
    base = torch.randn((7, 11))
    source = base.T if transpose else base
    tensor = wrap_torch_tensor(source)

    result = convert_layout(tensor, layout)

    assert result is tensor


def test_convert_layout_noop_preserves_strided_view():
    """A row-skipping view converted to `StridedLayout(-1)` is returned as-is."""
    every_other_row = torch.randn((14, 11))[::2]
    tensor = wrap_torch_tensor(every_other_row)

    result = convert_layout(tensor, StridedLayout(-1))

    assert result is tensor
    # Row stride stays 2 * 11: the view was not made contiguous.
    assert tensor.storage.data.stride() == (22, 1)


def test_convert_layout_rejects_strided_view_without_contiguous_dimension():
    """Converting a view that skips elements along both axes must raise ValueError."""
    checkerboard_view = torch.randn((14, 22))[::2, ::2]
    tensor = wrap_torch_tensor(checkerboard_view)

    with pytest.raises(ValueError):
        convert_layout(tensor, tensor.storage.layout)


def test_convert_layout_noop_does_not_ignore_transformation_kwargs():
    """An unknown transformation kwarg must raise TypeError even for the tensor's own layout."""
    dense = torch.randn((7, 11))
    tensor = wrap_torch_tensor(dense)

    with pytest.raises(TypeError):
        convert_layout(tensor, tensor.storage.layout, unsupported=True)


@pytest.mark.parametrize(
    ("storage_shape", "logical_shape", "dtype", "layout", "equivalent_layout"),
    [
        ((10, 254, 60), None, UINT8, BlackwellMXScaleLayout(), BlackwellMXScaleLayout()),
        ((130, 65), None, UINT8, BlackwellActMXScaleLayout(None), BlackwellActMXScaleLayout(None)),
        ((256, 64), (256, 128), FP4, BlackwellMXValueLayout(), BlackwellMXValueLayout()),
        ((128, 256), (128, 512), FP4, BlackwellMX4ValueShuffledLayout(), BlackwellMX4ValueShuffledLayout()),
        ((70, 65), None, UINT8, HopperMXScaleLayout(-2, 4), HopperMXScaleLayout(-2, 4)),
        ((64, 64), (64, 128), FP4, HopperMXValueLayout(-2, 3), HopperMXValueLayout(-2, 3)),
        ((10, 254, 60), None, UINT8, CDNA4MXScaleLayout(), CDNA4MXScaleLayout()),
        ((10, 254, 60), None, UINT8, GFX1250MXScaleLayout(), GFX1250MXScaleLayout()),
    ],
)
def test_convert_layout_noop_for_equivalent_layout(storage_shape, logical_shape, dtype, layout, equivalent_layout):
    """After converting to `layout`, converting to `equivalent_layout` must be a no-op.

    The first conversion is expected to produce a new Tensor; the second must
    return that converted Tensor object itself.
    """
    raw_bytes = torch.randint(0, 256, storage_shape, dtype=torch.uint8)
    tensor = wrap_torch_tensor(raw_bytes, dtype=dtype, shape=logical_shape)
    converted = convert_layout(tensor, layout)

    assert converted is not tensor
    reconverted = convert_layout(converted, equivalent_layout)
    assert reconverted is converted


@pytest.mark.parametrize(
    ("storage_shape", "logical_shape", "dtype", "layout", "different_layout"),
    [
        ((70, 65), None, UINT8, HopperMXScaleLayout(-2, 4), HopperMXScaleLayout(-2, 8)),
        ((64, 64), (64, 128), FP4, HopperMXValueLayout(-2, 3), HopperMXValueLayout(-2, 2)),
        ((128, 256), (128, 512), FP4, BlackwellMX4ValueShuffledLayout(), BlackwellMX4ValueShuffledLayout(block_n=128)),
    ],
)
def test_convert_layout_converts_different_parameterized_layout(storage_shape, logical_shape, dtype, layout,
                                                                different_layout):
    """Converting to a layout of the same class but other parameters must not return the input object."""
    raw_bytes = torch.randint(0, 256, storage_shape, dtype=torch.uint8)
    tensor = wrap_torch_tensor(raw_bytes, dtype=dtype, shape=logical_shape)
    converted = convert_layout(tensor, layout)

    reconverted = convert_layout(converted, different_layout)

    assert reconverted is not converted


@pytest.mark.parametrize("n_slices", [1, 7, 33, 911, 1025])
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
import pytest
import torch
from triton_kernels.tensor_details.layout import BlackwellMXScaleLayout, BlackwellActMXScaleLayout, StridedLayout
from triton_kernels.tensor import make_ragged_tensor_metadata, wrap_torch_tensor, convert_layout
from triton_kernels.tensor import make_ragged_tensor_metadata, make_ragged_tensor_metadata_torch, wrap_torch_tensor, convert_layout

# ------------------------------------------------------------
# Torch tests
# ------------------------------------------------------------


def test_act_scale_storage_preservation():
    """Storage is preserved only when both layouts hold the same metadata object.

    A layout built from metadata recomputed from the same slice sizes is
    expected NOT to be treated as storage-preserving.
    """
    slice_sizes = torch.tensor([2, 3], dtype=torch.int32)
    shared_metadata = make_ragged_tensor_metadata_torch(slice_sizes, 5)
    current = BlackwellActMXScaleLayout(shared_metadata)
    rebuilt_metadata = make_ragged_tensor_metadata_torch(slice_sizes, 5)
    reconstructed = BlackwellActMXScaleLayout(rebuilt_metadata)

    same_metadata_layout = BlackwellActMXScaleLayout(shared_metadata)
    assert current.can_preserve_storage_as(same_metadata_layout, 2)
    assert not current.can_preserve_storage_as(reconstructed, 2)


@pytest.mark.parametrize(
"shape",
[
Expand Down
10 changes: 9 additions & 1 deletion python/triton_kernels/triton_kernels/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,15 +225,23 @@ def wrap_torch_tensor(torch_tensor, dtype=None, shape=None, shape_max=None, layo
if shape_max is None:
shape_max = list(shape)
if layout is None:
# For a strided (dense) tensor we only track which dimension has unit stride.
# For a strided tensor we only track which dimension has unit stride.
# This is consistent with how we expand `shape` for packed sub-byte dtypes.
major_dim = torch_tensor.stride().index(1) if 1 in torch_tensor.stride() else -1
layout = StridedLayout(major_dim=major_dim - torch_tensor.ndim)
return Tensor(Storage(torch_tensor, layout), dtype=dtype, shape=shape, shape_max=shape_max)


def convert_layout(tensor: Tensor, layout: Layout, **layout_transformation_kwargs):
"""Convert `tensor` storage encoding to `layout`.

Returns `tensor` unchanged when its existing storage is already valid for
`layout`. This operation does not clone, densify, or canonicalize physical
strides of a tensor that is already in the requested encoding.
"""
shape = list(tensor.shape)
if not layout_transformation_kwargs and tensor.storage.layout.can_preserve_storage_as(layout, len(shape)):
return tensor
# convert `tensor` into canonical form
transformation = tensor.storage.layout.make_transformation(shape, tensor.dtype == FP4)
canonical_data = transformation.unswizzle_data(tensor.storage.data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ def unswizzle_data(self, data):
@dataclass(frozen=True)
class Layout(ABC):

def can_preserve_storage_as(self, other: "Layout", rank: int) -> bool:
    """Report whether storage in this layout is already valid for `other`.

    This base implementation compares the two layouts with `==` and does not
    use `rank`.
    """
    layouts_match = self == other
    return layouts_match

@abstractmethod
def make_transformation(self, shape: list[int]) -> LayoutTransformation:
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class BlackwellActMXScaleLayout(Layout):

ragged_metadata: RaggedTensorMetadata | None

def can_preserve_storage_as(self, other: Layout, rank: int) -> bool:
    """Report whether storage in this layout is already valid for `other`.

    True only when `other` is a `BlackwellActMXScaleLayout` whose
    `ragged_metadata` is the very same object as this layout's (identity
    comparison, not equality). `rank` is not used.
    """
    if not isinstance(other, BlackwellActMXScaleLayout):
        return False
    return self.ragged_metadata is other.ragged_metadata

@property
def name(self):
return "BLACKWELL_ACT_SCALE"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def make_transformation(self, shape: list[int], is_fp4: bool) -> LayoutTransform
def name(self):
return "STRIDED"

def can_preserve_storage_as(self, other: Layout, rank: int) -> bool:
    """Report whether storage in this layout is already valid for `other`.

    True when `other` is also a `StridedLayout` and both layouts give the same
    `order(rank)` for a tensor of the given rank.
    """
    if not isinstance(other, StridedLayout):
        return False
    return self.order(rank) == other.order(rank)

def swizzle_block_shape(self, block_shape):
return block_shape

Expand Down
Loading