@@ -319,6 +319,33 @@ static int GetVectorizeSizeMax(int compute_capability, DataType dtype) {
319319For VectorizeAtomicAdd (const For &for_node, const Var &thread_var,
320320 const Range &thread_bounds, int compute_capability) {
321321
322+ auto ParseIndex = [](const PrimExpr &idx, PrimExpr &var_out,
323+ int &stride_out) -> bool {
324+ int mul_count = 0 , legal_mul_count = 0 ;
325+ stride_out = -1 ;
326+ var_out = PrimExpr ();
327+ PostOrderVisit (idx, [&](const ObjectRef &obj) {
328+ if (const MulNode *mul = obj.as <MulNode>()) {
329+ mul_count++;
330+ const VarNode *var = nullptr ;
331+ const IntImmNode *imm = nullptr ;
332+ if ((var = mul->a .as <VarNode>()) && (imm = mul->b .as <IntImmNode>())) {
333+ var_out = mul->a ;
334+ stride_out = imm->value ;
335+ legal_mul_count++;
336+ } else if ((var = mul->b .as <VarNode>()) &&
337+ (imm = mul->a .as <IntImmNode>())) {
338+ var_out = mul->b ;
339+ stride_out = imm->value ;
340+ legal_mul_count++;
341+ }
342+ }
343+ });
344+ if (mul_count == 1 && legal_mul_count == 1 )
345+ return true ;
346+ return false ;
347+ };
348+
322349 int vectorize_size_max = 1 ;
323350 int stride_x = -1 , stride_y = -1 ;
324351 PrimExpr bx_var, by_var;
@@ -327,33 +354,22 @@ For VectorizeAtomicAdd(const For &for_node, const Var &thread_var,
327354 if (const auto *call = obj.as <CallNode>()) {
328355 if (call->op == builtin::call_extern () && call->args .size () >= 2 ) {
329356 const auto *func_name = call->args [0 ].as <StringImmNode>();
330- if (func_name->value == " AtomicAdd" ) {
331- DataType dtype = call->args [1 ].as <BufferLoadNode>()->dtype ;
357+ if (func_name && func_name->value == " AtomicAdd" ) {
358+ const auto *bufload = call->args [1 ].as <BufferLoadNode>();
359+ if (!bufload || bufload->indices .size () != 2 )
360+ return ;
361+
362+ DataType dtype = bufload->dtype ;
332363 vectorize_size_max = GetVectorizeSizeMax (compute_capability, dtype);
333- }
334- }
335- }
336- if (const MulNode *mul = obj.as <MulNode>()) {
337- const VarNode *var = nullptr ;
338- const IntImmNode *imm = nullptr ;
339- PrimExpr var_expr;
340- if ((var = mul->a .as <VarNode>()) && (imm = mul->b .as <IntImmNode>())) {
341- var_expr = mul->a ;
342- } else if ((var = mul->b .as <VarNode>()) &&
343- (imm = mul->a .as <IntImmNode>())) {
344- var_expr = mul->b ;
345- }
346- if (var && imm) {
347- if (var->name_hint == " bx" ) {
348- stride_x = imm->value ;
349- bx_var = var_expr;
350- } else if (var->name_hint == " by" ) {
351- stride_y = imm->value ;
352- by_var = var_expr;
364+ if (!ParseIndex (bufload->indices [0 ], by_var, stride_y))
365+ return ;
366+ if (!ParseIndex (bufload->indices [1 ], bx_var, stride_x))
367+ return ;
353368 }
354369 }
355370 }
356371 });
372+
357373 if (vectorize_size_max != 1 ) {
358374 int vectorize_hint = vectorize_size_max;
359375 AtomicAddVectorizePlanResult res = {1 , false , 0 };
0 commit comments