From b521c6f4e4c800211c215355a9669d32b9ce042b Mon Sep 17 00:00:00 2001 From: Roman Novak Date: Wed, 27 May 2026 21:26:54 -0700 Subject: [PATCH 1/6] Make tensor layout conversion idempotent --- python/triton_kernels/tests/test_tensor.py | 12 ++++++++++++ python/triton_kernels/triton_kernels/tensor.py | 2 ++ 2 files changed, 14 insertions(+) diff --git a/python/triton_kernels/tests/test_tensor.py b/python/triton_kernels/tests/test_tensor.py index 495d40e40f33..ee25a40385aa 100644 --- a/python/triton_kernels/tests/test_tensor.py +++ b/python/triton_kernels/tests/test_tensor.py @@ -2,6 +2,7 @@ import torch from triton_kernels.tensor_details.dtype import BIT from triton_kernels.tensor import ( + convert_layout, make_ragged_tensor_metadata, make_ragged_tensor_metadata_torch, remap_ragged_tensor_metadata, @@ -11,6 +12,17 @@ wrap_torch_tensor, ) from triton_kernels.testing import assert_equal +from triton_kernels.tensor_details.layout import StridedLayout + + +@pytest.mark.parametrize(("transpose", "layout"), [(False, StridedLayout(-1)), (True, StridedLayout(-2))]) +def test_convert_layout_noop(transpose, layout): + data = torch.randn((7, 11)) + if transpose: + data = data.T + tensor = wrap_torch_tensor(data) + + assert convert_layout(tensor, layout) is tensor @pytest.mark.parametrize("n_slices", [1, 7, 33, 911, 1025]) diff --git a/python/triton_kernels/triton_kernels/tensor.py b/python/triton_kernels/triton_kernels/tensor.py index 14964289af94..366cd5768a5d 100644 --- a/python/triton_kernels/triton_kernels/tensor.py +++ b/python/triton_kernels/triton_kernels/tensor.py @@ -233,6 +233,8 @@ def wrap_torch_tensor(torch_tensor, dtype=None, shape=None, shape_max=None, layo def convert_layout(tensor: Tensor, layout: Layout, **layout_transformation_kwargs): + if tensor.storage.layout == layout: + return tensor shape = list(tensor.shape) # convert `tensor` into canonical form transformation = tensor.storage.layout.make_transformation(shape, tensor.dtype == FP4) From c80e52647cb844d834721b2dfcf4140c3a965957 Mon Sep 17 00:00:00 2001 From: Roman Novak Date: Wed, 27 May 2026 21:30:09 -0700 Subject: [PATCH 2/6] Preserve parameterized layout conversion behavior --- python/triton_kernels/tests/test_tensor.py | 7 +++++++ python/triton_kernels/triton_kernels/tensor.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/python/triton_kernels/tests/test_tensor.py b/python/triton_kernels/tests/test_tensor.py index ee25a40385aa..bf29a55ff8b2 100644 --- a/python/triton_kernels/tests/test_tensor.py +++ b/python/triton_kernels/tests/test_tensor.py @@ -25,6 +25,13 @@ def test_convert_layout_noop(transpose, layout): assert convert_layout(tensor, layout) is tensor +def test_convert_layout_noop_does_not_ignore_transformation_kwargs(): + tensor = wrap_torch_tensor(torch.randn((7, 11))) + + with pytest.raises(TypeError): + convert_layout(tensor, tensor.storage.layout, unsupported=True) + + @pytest.mark.parametrize("n_slices", [1, 7, 33, 911, 1025]) def test_make_ragged_tensor_metadata(n_slices): torch.manual_seed(0) diff --git a/python/triton_kernels/triton_kernels/tensor.py b/python/triton_kernels/triton_kernels/tensor.py index 366cd5768a5d..4d3c7f804eef 100644 --- a/python/triton_kernels/triton_kernels/tensor.py +++ b/python/triton_kernels/triton_kernels/tensor.py @@ -233,7 +233,7 @@ def wrap_torch_tensor(torch_tensor, dtype=None, shape=None, shape_max=None, layo def convert_layout(tensor: Tensor, layout: Layout, **layout_transformation_kwargs): - if tensor.storage.layout == layout: + if tensor.storage.layout == layout and not layout_transformation_kwargs: return tensor shape = list(tensor.shape) # convert `tensor` into canonical form From c064f23eb87ff3ed6eeea009bd50590e3ebb28f1 Mon Sep 17 00:00:00 2001 From: Roman Novak Date: Wed, 27 May 2026 22:22:59 -0700 Subject: [PATCH 3/6] Define layout conversion equivalence --- python/triton_kernels/tests/test_tensor.py | 7 +++++++ .../test_tensor_details/test_layout_blackwell.py | 12 +++++++++++- python/triton_kernels/triton_kernels/tensor.py | 6 +++--- .../tensor_details/layout_details/base.py | 4 ++++ .../tensor_details/layout_details/blackwell_scale.py | 3 +++ .../tensor_details/layout_details/strided.py | 3 +++ 6 files changed, 31 insertions(+), 4 deletions(-) diff --git a/python/triton_kernels/tests/test_tensor.py b/python/triton_kernels/tests/test_tensor.py index bf29a55ff8b2..25088ca026c7 100644 --- a/python/triton_kernels/tests/test_tensor.py +++ b/python/triton_kernels/tests/test_tensor.py @@ -25,6 +25,13 @@ def test_convert_layout_noop(transpose, layout): assert convert_layout(tensor, layout) is tensor +def test_convert_layout_noop_preserves_strided_view(): + tensor = wrap_torch_tensor(torch.randn((14, 11))[::2]) + + assert convert_layout(tensor, StridedLayout(-1)) is tensor + assert tensor.storage.data.stride() == (22, 1) + + def test_convert_layout_noop_does_not_ignore_transformation_kwargs(): tensor = wrap_torch_tensor(torch.randn((7, 11))) diff --git a/python/triton_kernels/tests/test_tensor_details/test_layout_blackwell.py b/python/triton_kernels/tests/test_tensor_details/test_layout_blackwell.py index 59d15ee96bdc..87714a57c480 100644 --- a/python/triton_kernels/tests/test_tensor_details/test_layout_blackwell.py +++ b/python/triton_kernels/tests/test_tensor_details/test_layout_blackwell.py @@ -1,13 +1,23 @@ import pytest import torch from triton_kernels.tensor_details.layout import BlackwellMXScaleLayout, BlackwellActMXScaleLayout, StridedLayout -from triton_kernels.tensor import make_ragged_tensor_metadata, wrap_torch_tensor, convert_layout +from triton_kernels.tensor import make_ragged_tensor_metadata, make_ragged_tensor_metadata_torch, wrap_torch_tensor, convert_layout # ------------------------------------------------------------ # Torch tests # ------------------------------------------------------------ +def test_act_scale_layout_equivalence(): + slice_sizes = torch.tensor([2, 3], dtype=torch.int32) + metadata = make_ragged_tensor_metadata_torch(slice_sizes, 5) + equivalent = BlackwellActMXScaleLayout(metadata) + reconstructed = BlackwellActMXScaleLayout(make_ragged_tensor_metadata_torch(slice_sizes, 5)) + + assert equivalent.is_equivalent_to(BlackwellActMXScaleLayout(metadata), [5, 4]) + assert not equivalent.is_equivalent_to(reconstructed, [5, 4]) + + @pytest.mark.parametrize( "shape", [ diff --git a/python/triton_kernels/triton_kernels/tensor.py b/python/triton_kernels/triton_kernels/tensor.py index 4d3c7f804eef..a82643e10510 100644 --- a/python/triton_kernels/triton_kernels/tensor.py +++ b/python/triton_kernels/triton_kernels/tensor.py @@ -225,7 +225,7 @@ def wrap_torch_tensor(torch_tensor, dtype=None, shape=None, shape_max=None, layo if shape_max is None: shape_max = list(shape) if layout is None: - # For a strided (dense) tensor we only track which dimension has unit stride. + # For a strided tensor we only track which dimension has unit stride. # This is consistent with how we expand `shape` for packed sub-byte dtypes. major_dim = torch_tensor.stride().index(1) if 1 in torch_tensor.stride() else -1 layout = StridedLayout(major_dim=major_dim - torch_tensor.ndim) @@ -233,9 +233,9 @@ def wrap_torch_tensor(torch_tensor, dtype=None, shape=None, shape_max=None, layo def convert_layout(tensor: Tensor, layout: Layout, **layout_transformation_kwargs): - if tensor.storage.layout == layout and not layout_transformation_kwargs: - return tensor shape = list(tensor.shape) + if not layout_transformation_kwargs and tensor.storage.layout.is_equivalent_to(layout, shape): + return tensor # convert `tensor` into canonical form transformation = tensor.storage.layout.make_transformation(shape, tensor.dtype == FP4) canonical_data = transformation.unswizzle_data(tensor.storage.data) diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout_details/base.py b/python/triton_kernels/triton_kernels/tensor_details/layout_details/base.py index 4335dc624466..03fe7723d62a 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout_details/base.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout_details/base.py @@ -20,6 +20,10 @@ def unswizzle_data(self, data): @dataclass(frozen=True) class Layout(ABC): + def is_equivalent_to(self, other: "Layout", shape: list[int]) -> bool: + """Whether conversion to `other` can preserve the current storage.""" + return self == other + @abstractmethod def make_transformation(self, shape: list[int]) -> LayoutTransformation: pass diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py b/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py index 81467f955165..bebdab50a818 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py @@ -34,6 +34,9 @@ class BlackwellActMXScaleLayout(Layout): ragged_metadata: RaggedTensorMetadata | None + def is_equivalent_to(self, other: Layout, shape: list[int]) -> bool: + return isinstance(other, BlackwellActMXScaleLayout) and self.ragged_metadata is other.ragged_metadata + @property def name(self): return "BLACKWELL_ACT_SCALE" diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout_details/strided.py b/python/triton_kernels/triton_kernels/tensor_details/layout_details/strided.py index 294ea01e8908..8f9d68e81d98 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout_details/strided.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout_details/strided.py @@ -32,6 +32,9 @@ def make_transformation(self, shape: list[int], is_fp4: bool) -> LayoutTransform def name(self): return "STRIDED" + def is_equivalent_to(self, other: Layout, shape: list[int]) -> bool: + return isinstance(other, StridedLayout) and self.order(len(shape)) == other.order(len(shape)) + def swizzle_block_shape(self, block_shape): return block_shape From 12e065a2a7eab94b3ee20b8fda89465af4dd9b43 Mon Sep 17 00:00:00 2001 From: Roman Novak Date: Wed, 27 May 2026 23:44:02 -0700 Subject: [PATCH 4/6] Test idempotence across tensor layouts --- python/triton_kernels/tests/test_tensor.py | 53 +++++++++++++++++++++- 1 file changed, 51 insertions(+), 2 deletions(-) diff --git a/python/triton_kernels/tests/test_tensor.py b/python/triton_kernels/tests/test_tensor.py index 25088ca026c7..a792a11d9ad0 100644 --- a/python/triton_kernels/tests/test_tensor.py +++ b/python/triton_kernels/tests/test_tensor.py @@ -1,6 +1,6 @@ import pytest import torch -from triton_kernels.tensor_details.dtype import BIT +from triton_kernels.tensor_details.dtype import BIT, FP4, UINT8 from triton_kernels.tensor import ( convert_layout, make_ragged_tensor_metadata, @@ -12,7 +12,17 @@ wrap_torch_tensor, ) from triton_kernels.testing import assert_equal -from triton_kernels.tensor_details.layout import StridedLayout +from triton_kernels.tensor_details.layout import ( + BlackwellActMXScaleLayout, + BlackwellMX4ValueShuffledLayout, + BlackwellMXScaleLayout, + BlackwellMXValueLayout, + CDNA4MXScaleLayout, + GFX1250MXScaleLayout, + HopperMXScaleLayout, + HopperMXValueLayout, + StridedLayout, +) @pytest.mark.parametrize(("transpose", "layout"), [(False, StridedLayout(-1)), (True, StridedLayout(-2))]) @@ -39,6 +49,45 @@ def test_convert_layout_noop_does_not_ignore_transformation_kwargs(): convert_layout(tensor, tensor.storage.layout, unsupported=True) +@pytest.mark.parametrize( + ("storage_shape", "logical_shape", "dtype", "layout", "equivalent_layout"), + [ + ((10, 254, 60), None, UINT8, BlackwellMXScaleLayout(), BlackwellMXScaleLayout()), + ((130, 65), None, UINT8, BlackwellActMXScaleLayout(None), BlackwellActMXScaleLayout(None)), + ((256, 64), (256, 128), FP4, BlackwellMXValueLayout(), BlackwellMXValueLayout()), + ((128, 256), (128, 512), FP4, BlackwellMX4ValueShuffledLayout(), BlackwellMX4ValueShuffledLayout()), + ((70, 65), None, UINT8, HopperMXScaleLayout(-2, 4), HopperMXScaleLayout(-2, 4)), + ((64, 64), (64, 128), FP4, HopperMXValueLayout(-2, 3), HopperMXValueLayout(-2, 3)), + ((10, 254, 60), None, UINT8, CDNA4MXScaleLayout(), CDNA4MXScaleLayout()), + ((10, 254, 60), None, UINT8, GFX1250MXScaleLayout(), GFX1250MXScaleLayout()), + ], +) +def test_convert_layout_noop_for_equivalent_layout(storage_shape, logical_shape, dtype, layout, equivalent_layout): + tensor = wrap_torch_tensor(torch.randint(0, 256, storage_shape, dtype=torch.uint8), dtype=dtype, + shape=logical_shape) + converted = convert_layout(tensor, layout) + + assert converted is not tensor + assert convert_layout(converted, equivalent_layout) is converted + + +@pytest.mark.parametrize( + ("storage_shape", "logical_shape", "dtype", "layout", "different_layout"), + [ + ((70, 65), None, UINT8, HopperMXScaleLayout(-2, 4), HopperMXScaleLayout(-2, 8)), + ((64, 64), (64, 128), FP4, HopperMXValueLayout(-2, 3), HopperMXValueLayout(-2, 2)), + ((128, 256), (128, 512), FP4, BlackwellMX4ValueShuffledLayout(), BlackwellMX4ValueShuffledLayout(block_n=128)), + ], +) +def test_convert_layout_converts_different_parameterized_layout(storage_shape, logical_shape, dtype, layout, + different_layout): + tensor = wrap_torch_tensor(torch.randint(0, 256, storage_shape, dtype=torch.uint8), dtype=dtype, + shape=logical_shape) + converted = convert_layout(tensor, layout) + + assert convert_layout(converted, different_layout) is not converted + + @pytest.mark.parametrize("n_slices", [1, 7, 33, 911, 1025]) def test_make_ragged_tensor_metadata(n_slices): torch.manual_seed(0) From d5a68ac4b574ec46c01e741b1e046ecbb2702509 Mon Sep 17 00:00:00 2001 From: Roman Novak Date: Wed, 27 May 2026 23:56:42 -0700 Subject: [PATCH 5/6] Clarify layout storage preservation contract --- python/triton_kernels/tests/test_mxfp.py | 17 +++++++++++++++++ python/triton_kernels/tests/test_tensor.py | 17 ++++++++++++++++- .../test_layout_blackwell.py | 6 +++--- python/triton_kernels/triton_kernels/tensor.py | 8 +++++++- .../tensor_details/layout_details/base.py | 4 ++-- .../layout_details/blackwell_scale.py | 2 +- .../tensor_details/layout_details/strided.py | 2 +- 7 files changed, 47 insertions(+), 9 deletions(-) diff --git a/python/triton_kernels/tests/test_mxfp.py b/python/triton_kernels/tests/test_mxfp.py index 0f1087e71fa6..40c7c016da12 100644 --- a/python/triton_kernels/tests/test_mxfp.py +++ b/python/triton_kernels/tests/test_mxfp.py @@ -17,6 +17,8 @@ ) from triton_kernels.numerics_details.mxfp_details._upcast_from_mxfp import upcast_mxfp4_tile from triton_kernels.target_info import is_cuda +from triton_kernels.tensor import convert_layout, wrap_torch_tensor +from triton_kernels.tensor_details.layout import StridedLayout from triton_kernels.testing import assert_close, assert_equal @@ -203,6 +205,21 @@ def test_mxfp_quant_dequant(src_dtype, dst_dtype, device): assert_equal(weight, dequant) +def test_downcast_to_mxfp_accepts_pitched_strided_input(device): + torch.manual_seed(0) + dense = torch.randn((64, 128), device=device, dtype=torch.bfloat16) + pitched = torch.empty_strided(dense.shape, (256, 1), device=device, dtype=dense.dtype) + pitched.copy_(dense) + + tensor = convert_layout(wrap_torch_tensor(pitched), StridedLayout(-1)) + quant, scale = downcast_to_mxfp(tensor, torch.uint8, axis=-1) + expected_quant, expected_scale = downcast_to_mxfp(dense, torch.uint8, axis=-1) + + assert tensor.storage.data.stride() == (256, 1) + assert_equal(expected_quant, quant) + assert_equal(expected_scale, scale) + + # fmt: off @pytest.mark.parametrize( "shape, axis, quant_dtype, rounding_mode, scale_dtype, microblock_size", diff --git a/python/triton_kernels/tests/test_tensor.py b/python/triton_kernels/tests/test_tensor.py index a792a11d9ad0..42a587853bfc 100644 --- a/python/triton_kernels/tests/test_tensor.py +++ b/python/triton_kernels/tests/test_tensor.py @@ -25,7 +25,15 @@ ) -@pytest.mark.parametrize(("transpose", "layout"), [(False, StridedLayout(-1)), (True, StridedLayout(-2))]) +@pytest.mark.parametrize( + ("transpose", "layout"), + [ + (False, StridedLayout(-1)), + (False, StridedLayout(1)), + (True, StridedLayout(-2)), + (True, StridedLayout(0)), + ], +) def test_convert_layout_noop(transpose, layout): data = torch.randn((7, 11)) if transpose: @@ -42,6 +50,13 @@ def test_convert_layout_noop_preserves_strided_view(): assert tensor.storage.data.stride() == (22, 1) +def test_convert_layout_rejects_strided_view_without_contiguous_dimension(): + tensor = wrap_torch_tensor(torch.randn((14, 22))[::2, ::2]) + + with pytest.raises(ValueError): + convert_layout(tensor, tensor.storage.layout) + + def test_convert_layout_noop_does_not_ignore_transformation_kwargs(): tensor = wrap_torch_tensor(torch.randn((7, 11))) diff --git a/python/triton_kernels/tests/test_tensor_details/test_layout_blackwell.py b/python/triton_kernels/tests/test_tensor_details/test_layout_blackwell.py index 87714a57c480..769f7a85ef65 100644 --- a/python/triton_kernels/tests/test_tensor_details/test_layout_blackwell.py +++ b/python/triton_kernels/tests/test_tensor_details/test_layout_blackwell.py @@ -8,14 +8,14 @@ # ------------------------------------------------------------ -def test_act_scale_layout_equivalence(): +def test_act_scale_storage_preservation(): slice_sizes = torch.tensor([2, 3], dtype=torch.int32) metadata = make_ragged_tensor_metadata_torch(slice_sizes, 5) equivalent = BlackwellActMXScaleLayout(metadata) reconstructed = BlackwellActMXScaleLayout(make_ragged_tensor_metadata_torch(slice_sizes, 5)) - assert equivalent.is_equivalent_to(BlackwellActMXScaleLayout(metadata), [5, 4]) - assert not equivalent.is_equivalent_to(reconstructed, [5, 4]) + assert equivalent.can_preserve_storage_as(BlackwellActMXScaleLayout(metadata), [5, 4]) + assert not equivalent.can_preserve_storage_as(reconstructed, [5, 4]) @pytest.mark.parametrize( diff --git a/python/triton_kernels/triton_kernels/tensor.py b/python/triton_kernels/triton_kernels/tensor.py index a82643e10510..aefc8abb4827 100644 --- a/python/triton_kernels/triton_kernels/tensor.py +++ b/python/triton_kernels/triton_kernels/tensor.py @@ -233,8 +233,14 @@ def wrap_torch_tensor(torch_tensor, dtype=None, shape=None, shape_max=None, layo def convert_layout(tensor: Tensor, layout: Layout, **layout_transformation_kwargs): + """Convert `tensor` storage encoding to `layout`. + + Returns `tensor` unchanged when its existing storage is already valid for + `layout`. This operation does not clone, densify, or canonicalize physical + strides of a tensor that is already in the requested encoding. + """ shape = list(tensor.shape) - if not layout_transformation_kwargs and tensor.storage.layout.is_equivalent_to(layout, shape): + if not layout_transformation_kwargs and tensor.storage.layout.can_preserve_storage_as(layout, shape): return tensor # convert `tensor` into canonical form transformation = tensor.storage.layout.make_transformation(shape, tensor.dtype == FP4) diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout_details/base.py b/python/triton_kernels/triton_kernels/tensor_details/layout_details/base.py index 03fe7723d62a..d760db28a77d 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout_details/base.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout_details/base.py @@ -20,8 +20,8 @@ def unswizzle_data(self, data): @dataclass(frozen=True) class Layout(ABC): - def is_equivalent_to(self, other: "Layout", shape: list[int]) -> bool: - """Whether conversion to `other` can preserve the current storage.""" + def can_preserve_storage_as(self, other: "Layout", shape: list[int]) -> bool: + """Whether existing storage is already valid for `other`.""" return self == other @abstractmethod diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py b/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py index bebdab50a818..1ce332dc3165 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py @@ -34,7 +34,7 @@ class BlackwellActMXScaleLayout(Layout): ragged_metadata: RaggedTensorMetadata | None - def is_equivalent_to(self, other: Layout, shape: list[int]) -> bool: + def can_preserve_storage_as(self, other: Layout, shape: list[int]) -> bool: return isinstance(other, BlackwellActMXScaleLayout) and self.ragged_metadata is other.ragged_metadata @property diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout_details/strided.py b/python/triton_kernels/triton_kernels/tensor_details/layout_details/strided.py index 8f9d68e81d98..929f4ced2602 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout_details/strided.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout_details/strided.py @@ -32,7 +32,7 @@ def make_transformation(self, shape: list[int], is_fp4: bool) -> LayoutTransform def name(self): return "STRIDED" - def is_equivalent_to(self, other: Layout, shape: list[int]) -> bool: + def can_preserve_storage_as(self, other: Layout, shape: list[int]) -> bool: return isinstance(other, StridedLayout) and self.order(len(shape)) == other.order(len(shape)) def swizzle_block_shape(self, block_shape): From 9e8939c16562aa435ea5fb8a08904283bd195458 Mon Sep 17 00:00:00 2001 From: Roman Novak Date: Thu, 28 May 2026 00:03:28 -0700 Subject: [PATCH 6/6] Pass tensor rank to layout preservation check --- .../tests/test_tensor_details/test_layout_blackwell.py | 4 ++-- python/triton_kernels/triton_kernels/tensor.py | 2 +- .../triton_kernels/tensor_details/layout_details/base.py | 2 +- .../tensor_details/layout_details/blackwell_scale.py | 2 +- .../triton_kernels/tensor_details/layout_details/strided.py | 4 ++-- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/triton_kernels/tests/test_tensor_details/test_layout_blackwell.py b/python/triton_kernels/tests/test_tensor_details/test_layout_blackwell.py index 769f7a85ef65..9ae43d3cb118 100644 --- a/python/triton_kernels/tests/test_tensor_details/test_layout_blackwell.py +++ b/python/triton_kernels/tests/test_tensor_details/test_layout_blackwell.py @@ -14,8 +14,8 @@ def test_act_scale_storage_preservation(): equivalent = BlackwellActMXScaleLayout(metadata) reconstructed = BlackwellActMXScaleLayout(make_ragged_tensor_metadata_torch(slice_sizes, 5)) - assert equivalent.can_preserve_storage_as(BlackwellActMXScaleLayout(metadata), [5, 4]) - assert not equivalent.can_preserve_storage_as(reconstructed, [5, 4]) + assert equivalent.can_preserve_storage_as(BlackwellActMXScaleLayout(metadata), 2) + assert not equivalent.can_preserve_storage_as(reconstructed, 2) @pytest.mark.parametrize( diff --git a/python/triton_kernels/triton_kernels/tensor.py b/python/triton_kernels/triton_kernels/tensor.py index aefc8abb4827..7f59a5d3d7d1 100644 --- a/python/triton_kernels/triton_kernels/tensor.py +++ b/python/triton_kernels/triton_kernels/tensor.py @@ -240,7 +240,7 @@ def convert_layout(tensor: Tensor, layout: Layout, **layout_transformation_kwarg strides of a tensor that is already in the requested encoding. """ shape = list(tensor.shape) - if not layout_transformation_kwargs and tensor.storage.layout.can_preserve_storage_as(layout, shape): + if not layout_transformation_kwargs and tensor.storage.layout.can_preserve_storage_as(layout, len(shape)): return tensor # convert `tensor` into canonical form transformation = tensor.storage.layout.make_transformation(shape, tensor.dtype == FP4) diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout_details/base.py b/python/triton_kernels/triton_kernels/tensor_details/layout_details/base.py index d760db28a77d..b62644acf698 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout_details/base.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout_details/base.py @@ -20,7 +20,7 @@ def unswizzle_data(self, data): @dataclass(frozen=True) class Layout(ABC): - def can_preserve_storage_as(self, other: "Layout", shape: list[int]) -> bool: + def can_preserve_storage_as(self, other: "Layout", rank: int) -> bool: """Whether existing storage is already valid for `other`.""" return self == other diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py b/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py index 1ce332dc3165..92bd5833c158 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py @@ -34,7 +34,7 @@ class BlackwellActMXScaleLayout(Layout): ragged_metadata: RaggedTensorMetadata | None - def can_preserve_storage_as(self, other: Layout, shape: list[int]) -> bool: + def can_preserve_storage_as(self, other: Layout, rank: int) -> bool: return isinstance(other, BlackwellActMXScaleLayout) and self.ragged_metadata is other.ragged_metadata @property diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout_details/strided.py b/python/triton_kernels/triton_kernels/tensor_details/layout_details/strided.py index 929f4ced2602..36496cd2ab06 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout_details/strided.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout_details/strided.py @@ -32,8 +32,8 @@ def make_transformation(self, shape: list[int], is_fp4: bool) -> LayoutTransform def name(self): return "STRIDED" - def can_preserve_storage_as(self, other: Layout, shape: list[int]) -> bool: - return isinstance(other, StridedLayout) and self.order(len(shape)) == other.order(len(shape)) + def can_preserve_storage_as(self, other: Layout, rank: int) -> bool: + return isinstance(other, StridedLayout) and self.order(rank) == other.order(rank) def swizzle_block_shape(self, block_shape): return block_shape