55import tilelang .language as T
66from tvm import ir
77from tvm .tir import PrimExpr , Buffer , BufferRegion , Var , op
8+ from tvm .tir .expr import BufferLoad , Ramp
89from typing import Optional
910
1011_MEMORY_ORDER_ID_MAP = {
@@ -175,13 +176,41 @@ def get_extent(data):
175176 return data .shape
176177 elif isinstance (data , BufferRegion ):
177178 return [x .extent for x in data .region ]
179+ elif isinstance (data , BufferLoad ):
180+ ret = []
181+ new_indices = []
182+ for x in data .indices :
183+ if isinstance (x , Ramp ):
184+ ret .append (x .lanes )
185+ new_indices .append (x .base )
186+ else :
187+ ret .append (1 )
188+ new_indices .append (x )
189+ return ret
178190 else :
179191 return None
180192
181- src_extent = get_extent (value )
182- dst_extent = get_extent (dst )
193+ src_extent = list (get_extent (value ))
194+ dst_extent = list (get_extent (dst ))
195+ legal = True
196+
197+ if (dst_extent is None and src_extent is None ) or len (dst_extent ) < len (src_extent ):
198+ legal = False
199+ elif (dst_extent and src_extent ):
200+ if len (dst_extent ) > len (src_extent ):
201+ dst_extent_dims = [x for x in dst_extent if x != 1 ]
202+ if dst_extent_dims != src_extent :
203+ legal = False
204+ else :
205+ if dst_extent != src_extent :
206+ legal = False
207+ else :
208+ src_extent = list (src_extent ) if src_extent else [1 ] * len (dst_extent )
209+ dst_extent = list (dst_extent ) if dst_extent else [1 ] * len (src_extent )
210+ extent = max (dst_extent , src_extent )
211+ dst_extent = src_extent = extent
183212
184- if dst_extent is None and src_extent is None :
213+ if not legal :
185214 func_name = "AtomicAddRet" if return_prev else "AtomicAdd"
186215 return_type = dst .dtype if return_prev else "handle"
187216
@@ -194,12 +223,7 @@ def get_extent(data):
194223 if isinstance (dst , Buffer ) and isinstance (value , Buffer ):
195224 ir .assert_structural_equal (dst .shape , value .shape )
196225
197- assert src_extent or dst_extent , "Can't deduce atomicadd extents from args"
198- src_extent = list (src_extent ) if src_extent else [1 ] * len (dst_extent )
199- dst_extent = list (dst_extent ) if dst_extent else [1 ] * len (src_extent )
200- extent = max (src_extent , dst_extent )
201-
202- def _to_region (data , access_type ):
226+ def _to_region (data , extent , access_type ):
203227 from .customize import buffer_to_tile_region , buffer_region_to_tile_region , buffer_load_to_tile_region
204228
205229 if isinstance (data , Var ) and T .has_let_value (data ):
@@ -209,10 +233,17 @@ def _to_region(data, access_type):
209233 elif isinstance (data , BufferRegion ):
210234 return buffer_region_to_tile_region (data , access_type , extent )
211235 else :
236+ new_indices = []
237+ for x in data .indices :
238+ if isinstance (x , Ramp ):
239+ new_indices .append (x .base )
240+ else :
241+ new_indices .append (x )
242+ data = T .BufferLoad (data .buffer , new_indices )
212243 return buffer_load_to_tile_region (data , access_type , extent )
213244
214- value = _to_region (value , "r" )
215- dst = _to_region (dst , "w" )
245+ value = _to_region (value , src_extent , "r" )
246+ dst = _to_region (dst , dst_extent , "w" )
216247
217248 # Note: tile-region-based atomic operations don't support return_prev yet
218249 # This would need to be implemented in the tile runtime
0 commit comments