11"""The language interface for tl programs."""
22from __future__ import annotations
33from tilelang .primitives .gemm .base import GemmWarpPolicy
4- from tilelang .utils .language import get_buffer_region_from_load
54import tilelang .language as T
65from tvm import tir
7- from tilelang .utils .language import to_buffer_region
6+ from tilelang .utils .language import (
7+ to_buffer_region ,
8+ retrieve_shape ,
9+ retrieve_stride ,
10+ retrieve_offset ,
11+ prim_expr_equal ,
12+ )
13+ from tilelang .language .utils import (
14+ buffer_region_to_tile_region ,)
815
916
1017def gemm_sp (
@@ -144,54 +151,13 @@ def legalize_arguments(arg: tir.Buffer | tir.Var):
144151 B = legalize_arguments (B )
145152 C = legalize_arguments (C )
146153
147- def retrieve_shape (object : tir .Buffer | tir .BufferRegion ) -> list [int ]:
148- if isinstance (object , tir .Buffer ):
149- return object .shape
150- elif isinstance (object , tir .BufferRegion ):
151- region = object .region
152- shape = []
153- for r in region :
154- shape .append (r .extent )
155- return shape
156- elif isinstance (object , tir .BufferLoad ):
157- region = get_buffer_region_from_load (object ).region
158- shape = []
159- for r in region :
160- shape .append (r .extent )
161- return shape
162- else :
163- raise ValueError (
164- f"Unsupported retrieve_shape argument type: { type (object )} for buffer { object } " )
165-
166- def retrieve_stride (object : tir .Buffer | tir .BufferRegion ) -> list [int ]:
167- if isinstance (object , tir .Buffer ):
168- strides = []
169- stride = 1
170- for s in reversed (object .shape ):
171- strides .insert (0 , stride )
172- stride *= s
173- return strides
174- elif isinstance (object , tir .BufferRegion ):
175- buffer , _ = object .buffer , object .region
176- strides = []
177- stride = 1
178- for s in reversed (buffer .shape ):
179- strides .insert (0 , stride )
180- stride *= s
181- return strides
182- elif isinstance (object , tir .BufferLoad ):
183- buffer = object .buffer
184- strides = []
185- stride = 1
186- for s in reversed (buffer .shape ):
187- strides .insert (0 , stride )
188- stride *= s
189- return strides
190- else :
191- raise ValueError (
192- f"Unsupported retrieve_stride argument type: { type (object )} for buffer { object } " )
154+ A_region = to_buffer_region (A_sparse )
155+ E_region = to_buffer_region (E )
156+ B_region = to_buffer_region (B )
157+ C_region = to_buffer_region (C )
193158
194159 A_shape = retrieve_shape (A_sparse )
160+ E_shape = retrieve_shape (E ) # nolint: F841
195161 B_shape = retrieve_shape (B )
196162 C_shape = retrieve_shape (C )
197163
@@ -213,86 +179,30 @@ def retrieve_stride(object: tir.Buffer | tir.BufferRegion) -> list[int]:
213179 M , N = C_shape
214180 K = 2 * (A_shape [- 2 ] if transpose_A else A_shape [- 1 ])
215181 K_B = B_shape [- 1 ] if transpose_B else B_shape [- 2 ]
216- assert K == K_B , f"T.gemm_sp K shape check failed: K_A (wo sparse) = { K } , K_B = { K_B } "
182+ assert prim_expr_equal (
183+ K , K_B ), f"T.gemm_sp K shape check failed: K_A (wo sparse) = { K } , K_B = { K_B } "
217184
218185 stride_a = A_stride [- 2 ]
219186 stride_b = B_stride [- 2 ]
220187
221- def retrieve_ptr (object : tir .Buffer | tir .BufferRegion , access_type : str = "r" ) -> tir .PrimExpr :
222- if isinstance (object , tir .Buffer ):
223- return object .access_ptr (access_type )
224- elif isinstance (object , tir .BufferRegion ):
225- buffer , region = object .buffer , object .region
226- indices = []
227- for r in region :
228- indices .append (r .min )
229- strides = []
230- stride = 1
231- for s in reversed (buffer .shape ):
232- strides .insert (0 , stride )
233- stride *= s
234- offset = 0
235- # not offset the last two dimension
236- for i in range (len (indices ) - 2 ):
237- offset += indices [i ] * strides [i ]
238- return buffer .access_ptr (access_mask = access_type , offset = offset )
239- elif isinstance (object , tir .BufferLoad ):
240- buffer = object .buffer
241- region = get_buffer_region_from_load (object ).region
242- indices = []
243- for r in region :
244- indices .append (r .min )
245- strides = []
246- stride = 1
247- for s in reversed (buffer .shape ):
248- strides .insert (0 , stride )
249- stride *= s
250- offset = 0
251- for i in range (len (indices ) - 2 ):
252- offset += indices [i ] * strides [i ]
253- return buffer .access_ptr (access_mask = access_type , offset = offset )
254- else :
255- raise ValueError (
256- f"Unsupported retrieve_ptr argument type: { type (object )} for buffer { object } " )
257-
258- def retrieve_offset (object : tir .Buffer | tir .BufferRegion ) -> tir .PrimExpr :
259- """Retrieve the offset of the buffer or buffer region."""
260- if isinstance (object , tir .Buffer ):
261- return [0 ] * len (object .shape )
262- elif isinstance (object , tir .BufferRegion ):
263- _ , region = object .buffer , object .region
264- indices = []
265- for r in region :
266- indices .append (r .min )
267- return indices
268- elif isinstance (object , tir .BufferLoad ):
269- region = get_buffer_region_from_load (object ).region
270- indices = []
271- for r in region :
272- indices .append (r .min )
273- return indices
274- else :
275- raise ValueError (
276- f"Unsupported retrieve_offset argument type: { type (object )} for buffer { object } " )
277-
278188 A_offset = retrieve_offset (A_sparse )
279189 B_offset = retrieve_offset (B )
280190 assert A_offset [- 2 ] == 0 , "The offset of the first dimension of A must be 0"
281191 assert B_offset [- 2 ] == 0 , "The offset of the first dimension of B must be 0"
282192 offset_a = A_offset [- 1 ]
283193 offset_b = B_offset [- 1 ]
284194
285- Aptr = retrieve_ptr ( A_sparse , "r" )
286- Eptr = retrieve_ptr ( E , "r" )
287- Bptr = retrieve_ptr ( B , "r" )
288- Cptr = retrieve_ptr ( C , "rw" )
195+ A_arg = buffer_region_to_tile_region ( A_region , "r" , [ r for r in A_shape ] )
196+ E_arg = buffer_region_to_tile_region ( E_region , "r" , [ r for r in E_shape ] )
197+ B_arg = buffer_region_to_tile_region ( B_region , "r" , [ r for r in B_shape ] )
198+ C_arg = buffer_region_to_tile_region ( C_region , "rw" , [ r for r in C_shape ] )
289199 return tir .call_intrin (
290200 "handle" ,
291201 tir .op .Op .get ("tl.gemm_sp_py" ),
292- Aptr ,
293- Eptr ,
294- Bptr ,
295- Cptr ,
202+ A_arg ,
203+ E_arg ,
204+ B_arg ,
205+ C_arg ,
296206 transpose_A ,
297207 transpose_B ,
298208 transpose_E ,
0 commit comments