diff --git a/python/triton_kernels/tests/test_mxfp.py b/python/triton_kernels/tests/test_mxfp.py index 0f1087e71fa6..40c7c016da12 100644 --- a/python/triton_kernels/tests/test_mxfp.py +++ b/python/triton_kernels/tests/test_mxfp.py @@ -17,6 +17,8 @@ ) from triton_kernels.numerics_details.mxfp_details._upcast_from_mxfp import upcast_mxfp4_tile from triton_kernels.target_info import is_cuda +from triton_kernels.tensor import convert_layout, wrap_torch_tensor +from triton_kernels.tensor_details.layout import StridedLayout from triton_kernels.testing import assert_close, assert_equal @@ -203,6 +205,21 @@ def test_mxfp_quant_dequant(src_dtype, dst_dtype, device): assert_equal(weight, dequant) +def test_downcast_to_mxfp_accepts_pitched_strided_input(device): + torch.manual_seed(0) + dense = torch.randn((64, 128), device=device, dtype=torch.bfloat16) + pitched = torch.empty_strided(dense.shape, (256, 1), device=device, dtype=dense.dtype) + pitched.copy_(dense) + + tensor = convert_layout(wrap_torch_tensor(pitched), StridedLayout(-1)) + quant, scale = downcast_to_mxfp(tensor, torch.uint8, axis=-1) + expected_quant, expected_scale = downcast_to_mxfp(dense, torch.uint8, axis=-1) + + assert tensor.storage.data.stride() == (256, 1) + assert_equal(expected_quant, quant) + assert_equal(expected_scale, scale) + + # fmt: off @pytest.mark.parametrize( "shape, axis, quant_dtype, rounding_mode, scale_dtype, microblock_size", diff --git a/python/triton_kernels/tests/test_tensor.py b/python/triton_kernels/tests/test_tensor.py index 495d40e40f33..42a587853bfc 100644 --- a/python/triton_kernels/tests/test_tensor.py +++ b/python/triton_kernels/tests/test_tensor.py @@ -1,7 +1,8 @@ import pytest import torch -from triton_kernels.tensor_details.dtype import BIT +from triton_kernels.tensor_details.dtype import BIT, FP4, UINT8 from triton_kernels.tensor import ( + convert_layout, make_ragged_tensor_metadata, make_ragged_tensor_metadata_torch, remap_ragged_tensor_metadata, @@ -11,6 +12,95 @@ wrap_torch_tensor, ) from triton_kernels.testing import assert_equal +from triton_kernels.tensor_details.layout import ( + BlackwellActMXScaleLayout, + BlackwellMX4ValueShuffledLayout, + BlackwellMXScaleLayout, + BlackwellMXValueLayout, + CDNA4MXScaleLayout, + GFX1250MXScaleLayout, + HopperMXScaleLayout, + HopperMXValueLayout, + StridedLayout, +) + + +@pytest.mark.parametrize( + ("transpose", "layout"), + [ + (False, StridedLayout(-1)), + (False, StridedLayout(1)), + (True, StridedLayout(-2)), + (True, StridedLayout(0)), + ], +) +def test_convert_layout_noop(transpose, layout): + data = torch.randn((7, 11)) + if transpose: + data = data.T + tensor = wrap_torch_tensor(data) + + assert convert_layout(tensor, layout) is tensor + + +def test_convert_layout_noop_preserves_strided_view(): + tensor = wrap_torch_tensor(torch.randn((14, 11))[::2]) + + assert convert_layout(tensor, StridedLayout(-1)) is tensor + assert tensor.storage.data.stride() == (22, 1) + + +def test_convert_layout_rejects_strided_view_without_contiguous_dimension(): + tensor = wrap_torch_tensor(torch.randn((14, 22))[::2, ::2]) + + with pytest.raises(ValueError): + convert_layout(tensor, tensor.storage.layout) + + +def test_convert_layout_noop_does_not_ignore_transformation_kwargs(): + tensor = wrap_torch_tensor(torch.randn((7, 11))) + + with pytest.raises(TypeError): + convert_layout(tensor, tensor.storage.layout, unsupported=True) + + +@pytest.mark.parametrize( + ("storage_shape", "logical_shape", "dtype", "layout", "equivalent_layout"), + [ + ((10, 254, 60), None, UINT8, BlackwellMXScaleLayout(), BlackwellMXScaleLayout()), + ((130, 65), None, UINT8, BlackwellActMXScaleLayout(None), BlackwellActMXScaleLayout(None)), + ((256, 64), (256, 128), FP4, BlackwellMXValueLayout(), BlackwellMXValueLayout()), + ((128, 256), (128, 512), FP4, BlackwellMX4ValueShuffledLayout(), BlackwellMX4ValueShuffledLayout()), + ((70, 65), None, UINT8, HopperMXScaleLayout(-2, 4), HopperMXScaleLayout(-2, 4)), + ((64, 64), (64, 128), FP4, HopperMXValueLayout(-2, 3), HopperMXValueLayout(-2, 3)), + ((10, 254, 60), None, UINT8, CDNA4MXScaleLayout(), CDNA4MXScaleLayout()), + ((10, 254, 60), None, UINT8, GFX1250MXScaleLayout(), GFX1250MXScaleLayout()), + ], +) +def test_convert_layout_noop_for_equivalent_layout(storage_shape, logical_shape, dtype, layout, equivalent_layout): + tensor = wrap_torch_tensor(torch.randint(0, 256, storage_shape, dtype=torch.uint8), dtype=dtype, + shape=logical_shape) + converted = convert_layout(tensor, layout) + + assert converted is not tensor + assert convert_layout(converted, equivalent_layout) is converted + + +@pytest.mark.parametrize( + ("storage_shape", "logical_shape", "dtype", "layout", "different_layout"), + [ + ((70, 65), None, UINT8, HopperMXScaleLayout(-2, 4), HopperMXScaleLayout(-2, 8)), + ((64, 64), (64, 128), FP4, HopperMXValueLayout(-2, 3), HopperMXValueLayout(-2, 2)), + ((128, 256), (128, 512), FP4, BlackwellMX4ValueShuffledLayout(), BlackwellMX4ValueShuffledLayout(block_n=128)), + ], +) +def test_convert_layout_converts_different_parameterized_layout(storage_shape, logical_shape, dtype, layout, + different_layout): + tensor = wrap_torch_tensor(torch.randint(0, 256, storage_shape, dtype=torch.uint8), dtype=dtype, + shape=logical_shape) + converted = convert_layout(tensor, layout) + + assert convert_layout(converted, different_layout) is not converted @pytest.mark.parametrize("n_slices", [1, 7, 33, 911, 1025]) diff --git a/python/triton_kernels/tests/test_tensor_details/test_layout_blackwell.py b/python/triton_kernels/tests/test_tensor_details/test_layout_blackwell.py index 59d15ee96bdc..9ae43d3cb118 100644 --- a/python/triton_kernels/tests/test_tensor_details/test_layout_blackwell.py +++ b/python/triton_kernels/tests/test_tensor_details/test_layout_blackwell.py @@ -1,13 +1,23 @@ import pytest import torch from triton_kernels.tensor_details.layout import BlackwellMXScaleLayout, BlackwellActMXScaleLayout, StridedLayout -from triton_kernels.tensor import make_ragged_tensor_metadata, wrap_torch_tensor, convert_layout +from triton_kernels.tensor import make_ragged_tensor_metadata, make_ragged_tensor_metadata_torch, wrap_torch_tensor, convert_layout # ------------------------------------------------------------ # Torch tests # ------------------------------------------------------------ +def test_act_scale_storage_preservation(): + slice_sizes = torch.tensor([2, 3], dtype=torch.int32) + metadata = make_ragged_tensor_metadata_torch(slice_sizes, 5) + equivalent = BlackwellActMXScaleLayout(metadata) + reconstructed = BlackwellActMXScaleLayout(make_ragged_tensor_metadata_torch(slice_sizes, 5)) + + assert equivalent.can_preserve_storage_as(BlackwellActMXScaleLayout(metadata), 2) + assert not equivalent.can_preserve_storage_as(reconstructed, 2) + + @pytest.mark.parametrize( "shape", [ diff --git a/python/triton_kernels/triton_kernels/tensor.py b/python/triton_kernels/triton_kernels/tensor.py index 14964289af94..7f59a5d3d7d1 100644 --- a/python/triton_kernels/triton_kernels/tensor.py +++ b/python/triton_kernels/triton_kernels/tensor.py @@ -225,7 +225,7 @@ def wrap_torch_tensor(torch_tensor, dtype=None, shape=None, shape_max=None, layo if shape_max is None: shape_max = list(shape) if layout is None: - # For a strided (dense) tensor we only track which dimension has unit stride. + # For a strided tensor we only track which dimension has unit stride. # This is consistent with how we expand `shape` for packed sub-byte dtypes. major_dim = torch_tensor.stride().index(1) if 1 in torch_tensor.stride() else -1 layout = StridedLayout(major_dim=major_dim - torch_tensor.ndim) @@ -233,7 +233,15 @@ def wrap_torch_tensor(torch_tensor, dtype=None, shape=None, shape_max=None, layo def convert_layout(tensor: Tensor, layout: Layout, **layout_transformation_kwargs): + """Convert `tensor` storage encoding to `layout`. + + Returns `tensor` unchanged when its existing storage is already valid for + `layout`. This operation does not clone, densify, or canonicalize physical + strides of a tensor that is already in the requested encoding. + """ shape = list(tensor.shape) + if not layout_transformation_kwargs and tensor.storage.layout.can_preserve_storage_as(layout, len(shape)): + return tensor # convert `tensor` into canonical form transformation = tensor.storage.layout.make_transformation(shape, tensor.dtype == FP4) canonical_data = transformation.unswizzle_data(tensor.storage.data) diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout_details/base.py b/python/triton_kernels/triton_kernels/tensor_details/layout_details/base.py index 4335dc624466..b62644acf698 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout_details/base.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout_details/base.py @@ -20,6 +20,10 @@ def unswizzle_data(self, data): @dataclass(frozen=True) class Layout(ABC): + def can_preserve_storage_as(self, other: "Layout", rank: int) -> bool: + """Whether existing storage is already valid for `other`.""" + return self == other + @abstractmethod def make_transformation(self, shape: list[int]) -> LayoutTransformation: pass diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py b/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py index 81467f955165..92bd5833c158 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py @@ -34,6 +34,9 @@ class BlackwellActMXScaleLayout(Layout): ragged_metadata: RaggedTensorMetadata | None + def can_preserve_storage_as(self, other: Layout, rank: int) -> bool: + return isinstance(other, BlackwellActMXScaleLayout) and self.ragged_metadata is other.ragged_metadata + @property def name(self): return "BLACKWELL_ACT_SCALE" diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout_details/strided.py b/python/triton_kernels/triton_kernels/tensor_details/layout_details/strided.py index 294ea01e8908..36496cd2ab06 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout_details/strided.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout_details/strided.py @@ -32,6 +32,9 @@ def make_transformation(self, shape: list[int], is_fp4: bool) -> LayoutTransform def name(self): return "STRIDED" + def can_preserve_storage_as(self, other: Layout, rank: int) -> bool: + return isinstance(other, StridedLayout) and self.order(rank) == other.order(rank) + def swizzle_block_shape(self, block_shape): return block_shape