From 228702332a11d4600cd5b22c58f71050ed5e1d75 Mon Sep 17 00:00:00 2001
From: color <1044187534@qq.com>
Date: Wed, 4 Feb 2026 12:53:03 +0800
Subject: [PATCH] fix(intrinsics): add missing _legalize_to_buffer_region in SM70 emitter

---
 .../intrinsics/mma_sm70_macro_generator.py    | 33 +++++++++++++++++++++++++++++++--
 1 file changed, 31 insertions(+), 2 deletions(-)

diff --git a/tilelang/intrinsics/mma_sm70_macro_generator.py b/tilelang/intrinsics/mma_sm70_macro_generator.py
index 6acc40a4c..52679b169 100644
--- a/tilelang/intrinsics/mma_sm70_macro_generator.py
+++ b/tilelang/intrinsics/mma_sm70_macro_generator.py
@@ -2,10 +2,12 @@
 import tilelang.language as T
 from typing import Literal, Callable
 from tvm import DataType
-from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion
+from tvm import tir
+from tvm.ir import Range
+from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion, BufferLoad
 from tilelang import tvm as tvm
 from tvm.runtime import convert
-from tilelang.utils import is_fragment
+from tilelang.utils import is_fragment, get_buffer_region_from_load
 from tilelang.intrinsics.mma_sm70_layout import (
     shared_16x4_to_mma_a_32x4_layout,
     shared_4x16_to_mma_b_32x4_layout,
@@ -493,3 +495,30 @@ def forward_index(i: int, j: int) -> int:
             forward_thread_fn=forward_thread,
             forward_index_fn=forward_index,
         )
+
+    @staticmethod
+    def _legalize_to_buffer_region(obj: Buffer | BufferLoad | BufferRegion) -> BufferRegion:
+        """Normalize a buffer-like operand into a ``BufferRegion``.
+
+        - ``BufferRegion``: returned unchanged.
+        - ``Buffer``: region starting at 0 and spanning every dimension of its shape.
+        - ``BufferLoad``: the result of ``get_buffer_region_from_load`` unless that
+          is ``None``; then an extent-1 range at each index of the load.
+
+        Raises ``ValueError`` for any other argument type.
+        """
+        if isinstance(obj, BufferRegion):
+            return obj
+        if isinstance(obj, Buffer):
+            whole = [Range.from_min_extent(tir.IntImm("int32", 0), dim) for dim in obj.shape]
+            return BufferRegion(obj, whole)
+        if not isinstance(obj, BufferLoad):
+            raise ValueError(f"Unsupported argument type for BufferRegion: {type(obj)}")
+
+        converted = get_buffer_region_from_load(obj)
+        if converted is not None:
+            return converted
+        # No region could be derived: treat the load as a scalar access and
+        # cover exactly one element at each of its indices.
+        points = [Range.from_min_extent(index, tir.IntImm("int32", 1)) for index in obj.indices]
+        return BufferRegion(obj.buffer, points)