Skip to content

Commit 2dab926

Browse files
committed
update
1 parent 5b719a8 commit 2dab926

File tree

1 file changed

+42
-11
lines changed

1 file changed

+42
-11
lines changed

tilelang/language/atomic.py

Lines changed: 42 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
import tilelang.language as T
66
from tvm import ir
77
from tvm.tir import PrimExpr, Buffer, BufferRegion, Var, op
8+
from tvm.tir.expr import BufferLoad, Ramp
89
from typing import Optional
910

1011
_MEMORY_ORDER_ID_MAP = {
@@ -175,13 +176,41 @@ def get_extent(data):
175176
return data.shape
176177
elif isinstance(data, BufferRegion):
177178
return [x.extent for x in data.region]
179+
elif isinstance(data, BufferLoad):
180+
ret = []
181+
new_indices = []
182+
for x in data.indices:
183+
if isinstance(x, Ramp):
184+
ret.append(x.lanes)
185+
new_indices.append(x.base)
186+
else:
187+
ret.append(1)
188+
new_indices.append(x)
189+
return ret
178190
else:
179191
return None
180192

181-
src_extent = get_extent(value)
182-
dst_extent = get_extent(dst)
193+
src_extent = list(get_extent(value))
194+
dst_extent = list(get_extent(dst))
195+
legal = True
196+
197+
if (dst_extent is None and src_extent is None) or len(dst_extent) < len(src_extent):
198+
legal = False
199+
elif (dst_extent and src_extent):
200+
if len(dst_extent) > len(src_extent):
201+
dst_extent_dims = [x for x in dst_extent if x != 1]
202+
if dst_extent_dims != src_extent:
203+
legal = False
204+
else:
205+
if dst_extent != src_extent:
206+
legal = False
207+
else:
208+
src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
209+
dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
210+
extent = max(dst_extent, src_extent)
211+
dst_extent = src_extent = extent
183212

184-
if dst_extent is None and src_extent is None:
213+
if not legal:
185214
func_name = "AtomicAddRet" if return_prev else "AtomicAdd"
186215
return_type = dst.dtype if return_prev else "handle"
187216

@@ -194,12 +223,7 @@ def get_extent(data):
194223
if isinstance(dst, Buffer) and isinstance(value, Buffer):
195224
ir.assert_structural_equal(dst.shape, value.shape)
196225

197-
assert src_extent or dst_extent, "Can't deduce atomicadd extents from args"
198-
src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
199-
dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
200-
extent = max(src_extent, dst_extent)
201-
202-
def _to_region(data, access_type):
226+
def _to_region(data, extent, access_type):
203227
from .customize import buffer_to_tile_region, buffer_region_to_tile_region, buffer_load_to_tile_region
204228

205229
if isinstance(data, Var) and T.has_let_value(data):
@@ -209,10 +233,17 @@ def _to_region(data, access_type):
209233
elif isinstance(data, BufferRegion):
210234
return buffer_region_to_tile_region(data, access_type, extent)
211235
else:
236+
new_indices = []
237+
for x in data.indices:
238+
if isinstance(x, Ramp):
239+
new_indices.append(x.base)
240+
else:
241+
new_indices.append(x)
242+
data = T.BufferLoad(data.buffer, new_indices)
212243
return buffer_load_to_tile_region(data, access_type, extent)
213244

214-
value = _to_region(value, "r")
215-
dst = _to_region(dst, "w")
245+
value = _to_region(value, src_extent, "r")
246+
dst = _to_region(dst, dst_extent, "w")
216247

217248
# Note: tile-region-based atomic operations don't support return_prev yet
218249
# This would need to be implemented in the tile runtime

0 commit comments

Comments
 (0)