Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions tilelang/intrinsics/mma_sm70_macro_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import tilelang.language as T
from typing import Literal, Callable
from tvm import DataType
from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion
from tvm import tir
from tvm.ir import Range
from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion, BufferLoad
from tilelang import tvm as tvm
from tvm.runtime import convert
from tilelang.utils import is_fragment
from tilelang.utils import is_fragment, get_buffer_region_from_load
from tilelang.intrinsics.mma_sm70_layout import (
shared_16x4_to_mma_a_32x4_layout,
shared_4x16_to_mma_b_32x4_layout,
Expand Down Expand Up @@ -493,3 +495,30 @@ def forward_index(i: int, j: int) -> int:
forward_thread_fn=forward_thread,
forward_index_fn=forward_index,
)

@staticmethod
def _legalize_to_buffer_region(obj: Buffer | BufferLoad | BufferRegion) -> BufferRegion:
"""
Convert Buffer/BufferRegion/BufferLoad to a BufferRegion.

- Buffer -> full-region BufferRegion covering entire shape
- BufferRegion -> returned as-is
- BufferLoad -> best-effort convert via get_buffer_region_from_load;
if scalar, fall back to 1-sized ranges at given indices
"""
if isinstance(obj, BufferRegion):
return obj
if isinstance(obj, Buffer):
mins = [tir.IntImm("int32", 0) for _ in obj.shape]
ranges = [Range.from_min_extent(m, e) for m, e in zip(mins, obj.shape)]
return BufferRegion(obj, ranges)
if isinstance(obj, BufferLoad):
region = get_buffer_region_from_load(obj)
if region is not None:
return region
# Fallback: scalar load -> 1-sized ranges at indices
mins = [idx for idx in obj.indices]
ones = [tir.IntImm("int32", 1) for _ in obj.indices]
ranges = [Range.from_min_extent(m, e) for m, e in zip(mins, ones)]
return BufferRegion(obj.buffer, ranges)
raise ValueError(f"Unsupported argument type for BufferRegion: {type(obj)}")
Loading