Skip to content

Commit ff22934

Browse files
committed
[chore] refactor binding
1 parent 47fc794 commit ff22934

File tree

4 files changed

+52
-137
lines changed

4 files changed

+52
-137
lines changed

src/op/gemm_sp_py.cc

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
*/
66

77
#include "gemm_sp_py.h"
8+
#include "utils.h"
89

910
#include "builtin.h"
1011
#include <tvm/tir/builtin.h>
@@ -48,17 +49,19 @@ using namespace tir;
4849
* fails with an ICHECK (runtime assertion). No other validation is
4950
* performed here.
5051
*/
51-
GemmSPPy::GemmSPPy(Array<PrimExpr> args, BufferMap vmap) {
52+
GemmSPPy::GemmSPPy(Array<PrimExpr> args) {
5253
ObjectPtr<GemmSPPyNode> node = tvm::ffi::make_object<GemmSPPyNode>();
5354

54-
node->Aptr = args[0];
55-
node->Eptr = args[1];
56-
node->Bptr = args[2];
57-
node->Cptr = args[3];
58-
node->A = vmap[GetVarFromAccessPtr(node->Aptr)];
59-
node->E = vmap[GetVarFromAccessPtr(node->Eptr)];
60-
node->B = vmap[GetVarFromAccessPtr(node->Bptr)];
61-
node->C = vmap[GetVarFromAccessPtr(node->Cptr)];
55+
node->aRegion_ = NormalizeToBufferRegion(args[0]);
56+
node->eRegion_ = NormalizeToBufferRegion(args[1]);
57+
node->bRegion_ = NormalizeToBufferRegion(args[2]);
58+
node->cRegion_ = NormalizeToBufferRegion(args[3]);
59+
60+
node->A = node->aRegion_->buffer;
61+
node->E = node->eRegion_->buffer;
62+
node->B = node->bRegion_->buffer;
63+
node->C = node->cRegion_->buffer;
64+
6265
node->trans_A = args[4].as<Bool>().value();
6366
node->trans_B = args[5].as<Bool>().value();
6467
node->trans_E = args[6].as<Bool>().value();

src/op/gemm_sp_py.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class GemmSPPyNode : public TileOperatorNode {
2323
bool CheckWGMMA() const;
2424
tir::Buffer A, E, B, C;
2525
// pointer to the A, E, B, C
26-
PrimExpr Aptr, Eptr, Bptr, Cptr;
26+
BufferRegion aRegion_, eRegion_, bRegion_, cRegion_;
2727
bool trans_A, trans_B, trans_E;
2828
int M, N, K;
2929
int stride_A, stride_B;
@@ -33,6 +33,8 @@ class GemmSPPyNode : public TileOperatorNode {
3333
// only will be enabled under cdna mfma instructions
3434
int kPack = 1;
3535
int wg_wait = 0;
36+
37+
// use GemmWarp Policy here as the atom size are flexible in v2
3638
mutable GemmWarpPolicy policy;
3739

3840
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmSPPy", GemmSPPyNode,
@@ -45,10 +47,10 @@ class GemmSPPyNode : public TileOperatorNode {
4547
.def_ro("E", &GemmSPPyNode::E)
4648
.def_ro("B", &GemmSPPyNode::B)
4749
.def_ro("C", &GemmSPPyNode::C)
48-
.def_ro("Aptr", &GemmSPPyNode::Aptr)
49-
.def_ro("Eptr", &GemmSPPyNode::Eptr)
50-
.def_ro("Bptr", &GemmSPPyNode::Bptr)
51-
.def_ro("Cptr", &GemmSPPyNode::Cptr)
50+
.def_ro("aRegion", &GemmSPPyNode::aRegion_)
51+
.def_ro("eRegion", &GemmSPPyNode::eRegion_)
52+
.def_ro("bRegion", &GemmSPPyNode::bRegion_)
53+
.def_ro("cRegion", &GemmSPPyNode::cRegion_)
5254
.def_ro("trans_A", &GemmSPPyNode::trans_A)
5355
.def_ro("trans_B", &GemmSPPyNode::trans_B)
5456
.def_ro("trans_E", &GemmSPPyNode::trans_E)
@@ -82,7 +84,7 @@ class GemmSPPy : public TileOperator {
8284
public:
8385
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSPPy, TileOperator,
8486
GemmSPPyNode);
85-
TVM_DLL GemmSPPy(Array<PrimExpr> args, BufferMap vmap);
87+
TVM_DLL GemmSPPy(Array<PrimExpr> args);
8688
static const Op &Get();
8789
};
8890

tilelang/language/experimental/gemm_sp.py

Lines changed: 24 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
"""The language interface for tl programs."""
22
from __future__ import annotations
33
from tilelang.primitives.gemm.base import GemmWarpPolicy
4-
from tilelang.utils.language import get_buffer_region_from_load
54
import tilelang.language as T
65
from tvm import tir
7-
from tilelang.utils.language import to_buffer_region
6+
from tilelang.utils.language import (
7+
to_buffer_region,
8+
retrieve_shape,
9+
retrieve_stride,
10+
retrieve_offset,
11+
prim_expr_equal,
12+
)
13+
from tilelang.language.utils import (
14+
buffer_region_to_tile_region,)
815

916

1017
def gemm_sp(
@@ -144,54 +151,13 @@ def legalize_arguments(arg: tir.Buffer | tir.Var):
144151
B = legalize_arguments(B)
145152
C = legalize_arguments(C)
146153

147-
def retrieve_shape(object: tir.Buffer | tir.BufferRegion) -> list[int]:
148-
if isinstance(object, tir.Buffer):
149-
return object.shape
150-
elif isinstance(object, tir.BufferRegion):
151-
region = object.region
152-
shape = []
153-
for r in region:
154-
shape.append(r.extent)
155-
return shape
156-
elif isinstance(object, tir.BufferLoad):
157-
region = get_buffer_region_from_load(object).region
158-
shape = []
159-
for r in region:
160-
shape.append(r.extent)
161-
return shape
162-
else:
163-
raise ValueError(
164-
f"Unsupported retrieve_shape argument type: {type(object)} for buffer {object}")
165-
166-
def retrieve_stride(object: tir.Buffer | tir.BufferRegion) -> list[int]:
167-
if isinstance(object, tir.Buffer):
168-
strides = []
169-
stride = 1
170-
for s in reversed(object.shape):
171-
strides.insert(0, stride)
172-
stride *= s
173-
return strides
174-
elif isinstance(object, tir.BufferRegion):
175-
buffer, _ = object.buffer, object.region
176-
strides = []
177-
stride = 1
178-
for s in reversed(buffer.shape):
179-
strides.insert(0, stride)
180-
stride *= s
181-
return strides
182-
elif isinstance(object, tir.BufferLoad):
183-
buffer = object.buffer
184-
strides = []
185-
stride = 1
186-
for s in reversed(buffer.shape):
187-
strides.insert(0, stride)
188-
stride *= s
189-
return strides
190-
else:
191-
raise ValueError(
192-
f"Unsupported retrieve_stride argument type: {type(object)} for buffer {object}")
154+
A_region = to_buffer_region(A_sparse)
155+
E_region = to_buffer_region(E)
156+
B_region = to_buffer_region(B)
157+
C_region = to_buffer_region(C)
193158

194159
A_shape = retrieve_shape(A_sparse)
160+
E_shape = retrieve_shape(E) # nolint: F841
195161
B_shape = retrieve_shape(B)
196162
C_shape = retrieve_shape(C)
197163

@@ -213,86 +179,30 @@ def retrieve_stride(object: tir.Buffer | tir.BufferRegion) -> list[int]:
213179
M, N = C_shape
214180
K = 2 * (A_shape[-2] if transpose_A else A_shape[-1])
215181
K_B = B_shape[-1] if transpose_B else B_shape[-2]
216-
assert K == K_B, f"T.gemm_sp K shape check failed: K_A (wo sparse) = {K}, K_B = {K_B}"
182+
assert prim_expr_equal(
183+
K, K_B), f"T.gemm_sp K shape check failed: K_A (wo sparse) = {K}, K_B = {K_B}"
217184

218185
stride_a = A_stride[-2]
219186
stride_b = B_stride[-2]
220187

221-
def retrieve_ptr(object: tir.Buffer | tir.BufferRegion, access_type: str = "r") -> tir.PrimExpr:
222-
if isinstance(object, tir.Buffer):
223-
return object.access_ptr(access_type)
224-
elif isinstance(object, tir.BufferRegion):
225-
buffer, region = object.buffer, object.region
226-
indices = []
227-
for r in region:
228-
indices.append(r.min)
229-
strides = []
230-
stride = 1
231-
for s in reversed(buffer.shape):
232-
strides.insert(0, stride)
233-
stride *= s
234-
offset = 0
235-
# not offset the last two dimension
236-
for i in range(len(indices) - 2):
237-
offset += indices[i] * strides[i]
238-
return buffer.access_ptr(access_mask=access_type, offset=offset)
239-
elif isinstance(object, tir.BufferLoad):
240-
buffer = object.buffer
241-
region = get_buffer_region_from_load(object).region
242-
indices = []
243-
for r in region:
244-
indices.append(r.min)
245-
strides = []
246-
stride = 1
247-
for s in reversed(buffer.shape):
248-
strides.insert(0, stride)
249-
stride *= s
250-
offset = 0
251-
for i in range(len(indices) - 2):
252-
offset += indices[i] * strides[i]
253-
return buffer.access_ptr(access_mask=access_type, offset=offset)
254-
else:
255-
raise ValueError(
256-
f"Unsupported retrieve_ptr argument type: {type(object)} for buffer {object}")
257-
258-
def retrieve_offset(object: tir.Buffer | tir.BufferRegion) -> tir.PrimExpr:
259-
"""Retrieve the offset of the buffer or buffer region."""
260-
if isinstance(object, tir.Buffer):
261-
return [0] * len(object.shape)
262-
elif isinstance(object, tir.BufferRegion):
263-
_, region = object.buffer, object.region
264-
indices = []
265-
for r in region:
266-
indices.append(r.min)
267-
return indices
268-
elif isinstance(object, tir.BufferLoad):
269-
region = get_buffer_region_from_load(object).region
270-
indices = []
271-
for r in region:
272-
indices.append(r.min)
273-
return indices
274-
else:
275-
raise ValueError(
276-
f"Unsupported retrieve_offset argument type: {type(object)} for buffer {object}")
277-
278188
A_offset = retrieve_offset(A_sparse)
279189
B_offset = retrieve_offset(B)
280190
assert A_offset[-2] == 0, "The offset of the first dimension of A must be 0"
281191
assert B_offset[-2] == 0, "The offset of the first dimension of B must be 0"
282192
offset_a = A_offset[-1]
283193
offset_b = B_offset[-1]
284194

285-
Aptr = retrieve_ptr(A_sparse, "r")
286-
Eptr = retrieve_ptr(E, "r")
287-
Bptr = retrieve_ptr(B, "r")
288-
Cptr = retrieve_ptr(C, "rw")
195+
A_arg = buffer_region_to_tile_region(A_region, "r", [r for r in A_shape])
196+
E_arg = buffer_region_to_tile_region(E_region, "r", [r for r in E_shape])
197+
B_arg = buffer_region_to_tile_region(B_region, "r", [r for r in B_shape])
198+
C_arg = buffer_region_to_tile_region(C_region, "rw", [r for r in C_shape])
289199
return tir.call_intrin(
290200
"handle",
291201
tir.op.Op.get("tl.gemm_sp_py"),
292-
Aptr,
293-
Eptr,
294-
Bptr,
295-
Cptr,
202+
A_arg,
203+
E_arg,
204+
B_arg,
205+
C_arg,
296206
transpose_A,
297207
transpose_B,
298208
transpose_E,

tilelang/tileop/gemm_sp/gemm_sp_base.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,20 +83,20 @@ def C(self) -> tir.Buffer:
8383
return self.gemm_sp_node.C
8484

8585
@property
86-
def APtr(self) -> tir.PrimExpr:
87-
return self.gemm_sp_node.APtr
86+
def ARegion(self) -> tir.PrimExpr:
87+
return self.gemm_sp_node.ARegion
8888

8989
@property
90-
def EPtr(self) -> tir.PrimExpr:
91-
return self.gemm_sp_node.EPtr
90+
def ERegion(self) -> tir.PrimExpr:
91+
return self.gemm_sp_node.ERegion
9292

9393
@property
94-
def BPtr(self) -> tir.PrimExpr:
95-
return self.gemm_sp_node.BPtr
94+
def BRegion(self) -> tir.PrimExpr:
95+
return self.gemm_sp_node.BRegion
9696

9797
@property
98-
def CPtr(self) -> tir.PrimExpr:
99-
return self.gemm_sp_node.CPtr
98+
def CRegion(self) -> tir.PrimExpr:
99+
return self.gemm_sp_node.CRegion
100100

101101
@property
102102
def stride_A(self) -> int:

0 commit comments

Comments
 (0)