Skip to content

Commit c2ec489

Browse files
committed
bug fix
1 parent 9a4a359 commit c2ec489

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

tilelang/language/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: List
5454
tir.Call: A region descriptor for the loaded area
5555
"""
5656
indices = load.indices
57-
print("indices", indices)
58-
print("extents", extents)
57+
5958
if len(indices) > len(extents):
6059
# (f"mismatch between indices and extents for buffer load {load}: indices = {indices}, extents = {extents}, "
6160
# f"region will be expanded in the last 2 dimensions")

tilelang/utils/language.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,16 @@ def get_buffer_region_from_load(buffer_load: tir.BufferLoad) -> Optional[tir.Buf
131131
"""
132132
buffer, indices = buffer_load.buffer, buffer_load.indices
133133
regions = []
134+
found_ramp: bool = False
134135
for indice in indices:
135136
if isinstance(indice, tir.Ramp):
136137
regions.append(ir.Range.from_min_extent(indice.base, indice.lanes))
138+
found_ramp = True
137139
elif isinstance(indice, tir.PrimExpr):
138140
regions.append(ir.Range.from_min_extent(indice, 1))
139141
else:
140142
raise ValueError("Unsupported type: ", type(indice))
141-
return tir.BufferRegion(buffer, regions)
143+
if found_ramp:
144+
return tir.BufferRegion(buffer, regions)
145+
else:
146+
return None

0 commit comments

Comments
 (0)