Skip to content

Commit 5a02ea8

Browse files
committed
[Bugfix]: Fix variable identification error in AtomicAdd auto-vectorization
1 parent ec24561 commit 5a02ea8

File tree

2 files changed

+40
-26
lines changed

2 files changed

+40
-26
lines changed

src/transform/atomicadd_vectorize.cc

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,33 @@ static int GetVectorizeSizeMax(int compute_capability, DataType dtype) {
319319
For VectorizeAtomicAdd(const For &for_node, const Var &thread_var,
320320
const Range &thread_bounds, int compute_capability) {
321321

322+
// Parse one index expression that is expected to contain exactly one
// `var * imm` (or `imm * var`) product.
//
// idx        - index expression to inspect.
// var_out    - receives the Var operand of the product on success; reset to an
//              undefined PrimExpr on failure.
// stride_out - receives the constant factor on success; reset to -1 on
//              failure.
//
// Returns true only when `idx` contains exactly one Mul node and that node
// multiplies a bare Var by an IntImm whose value is representable as `int`.
auto ParseIndex = [](const PrimExpr &idx, PrimExpr &var_out,
                     int &stride_out) -> bool {
  int mul_count = 0, legal_mul_count = 0;
  stride_out = -1;
  var_out = PrimExpr();
  PostOrderVisit(idx, [&](const ObjectRef &obj) {
    const MulNode *mul = obj.as<MulNode>();
    if (!mul)
      return;
    mul_count++;
    // Accept either operand order: `var * imm` or `imm * var`.
    const IntImmNode *imm = nullptr;
    PrimExpr var_expr;
    if (mul->a.as<VarNode>() && (imm = mul->b.as<IntImmNode>())) {
      var_expr = mul->a;
    } else if (mul->b.as<VarNode>() && (imm = mul->a.as<IntImmNode>())) {
      var_expr = mul->b;
    } else {
      return;
    }
    // The stride is stored as `int`, while the immediate is presumably a wider
    // integer (NOTE(review): confirm IntImmNode::value is 64-bit). Reject
    // constants that would be truncated instead of recording a wrong stride.
    const int stride = static_cast<int>(imm->value);
    if (stride != imm->value)
      return;
    var_out = var_expr;
    stride_out = stride;
    legal_mul_count++;
  });
  if (mul_count == 1 && legal_mul_count == 1)
    return true;
  // A partial match (e.g. several Mul nodes, one of them legal) must not leak
  // to the caller: restore the "not found" sentinels before reporting failure.
  stride_out = -1;
  var_out = PrimExpr();
  return false;
};
348+
322349
int vectorize_size_max = 1;
323350
int stride_x = -1, stride_y = -1;
324351
PrimExpr bx_var, by_var;
@@ -327,33 +354,22 @@ For VectorizeAtomicAdd(const For &for_node, const Var &thread_var,
327354
if (const auto *call = obj.as<CallNode>()) {
328355
if (call->op == builtin::call_extern() && call->args.size() >= 2) {
329356
const auto *func_name = call->args[0].as<StringImmNode>();
330-
if (func_name->value == "AtomicAdd") {
331-
DataType dtype = call->args[1].as<BufferLoadNode>()->dtype;
357+
if (func_name && func_name->value == "AtomicAdd") {
358+
const auto *bufload = call->args[1].as<BufferLoadNode>();
359+
if (!bufload || bufload->indices.size() != 2)
360+
return;
361+
362+
DataType dtype = bufload->dtype;
332363
vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
333-
}
334-
}
335-
}
336-
if (const MulNode *mul = obj.as<MulNode>()) {
337-
const VarNode *var = nullptr;
338-
const IntImmNode *imm = nullptr;
339-
PrimExpr var_expr;
340-
if ((var = mul->a.as<VarNode>()) && (imm = mul->b.as<IntImmNode>())) {
341-
var_expr = mul->a;
342-
} else if ((var = mul->b.as<VarNode>()) &&
343-
(imm = mul->a.as<IntImmNode>())) {
344-
var_expr = mul->b;
345-
}
346-
if (var && imm) {
347-
if (var->name_hint == "bx") {
348-
stride_x = imm->value;
349-
bx_var = var_expr;
350-
} else if (var->name_hint == "by") {
351-
stride_y = imm->value;
352-
by_var = var_expr;
364+
if (!ParseIndex(bufload->indices[0], by_var, stride_y))
365+
return;
366+
if (!ParseIndex(bufload->indices[1], bx_var, stride_x))
367+
return;
353368
}
354369
}
355370
}
356371
});
372+
357373
if (vectorize_size_max != 1) {
358374
int vectorize_hint = vectorize_size_max;
359375
AtomicAddVectorizePlanResult res = {1, false, 0};

testing/python/language/test_tilelang_language_atomic_add.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -372,10 +372,8 @@ def test_atomic_return_prev():
372372
run_atomic_return_prev(32, 32, 8, 8)
373373

374374

375-
# TODO(lei): test failed and this is experimental
376-
# CC @dyq
377-
# def test_tile_atomic_add():
378-
# run_tile_atomic_add(8, 128, 128, 32, 32)
375+
def test_tile_atomic_add():
376+
run_tile_atomic_add(8, 128, 128, 32, 32)
379377

380378
if __name__ == "__main__":
381379
tilelang.testing.main()

0 commit comments

Comments
 (0)