diff --git a/src/transform/atomicadd_vectorize.cc b/src/transform/atomicadd_vectorize.cc index a6b12f7e9..cd63c9583 100644 --- a/src/transform/atomicadd_vectorize.cc +++ b/src/transform/atomicadd_vectorize.cc @@ -231,21 +231,25 @@ class AtomicAddVectorizeRewriter : public StmtExprMutator { // Ref: src/tl_templates/cuda/atomic.h::AtomicAdd const IntImm memory_order = node->args.size() >= 3 ? Downcast(node->args[2]) : IntImm(0); - + Array new_args; Call address_of_dst = Call(DataType::Handle(), builtin::address_of(), {dst_node}); Call address_of_value = Call(DataType::Handle(), builtin::address_of(), {value_node}); - Array new_args; if (vector_size_ == 4) { new_args.push_back(StringImm("AtomicAddx4")); + new_args.push_back(address_of_dst); + new_args.push_back(address_of_value); } else if (vector_size_ == 2) { new_args.push_back(StringImm("AtomicAddx2")); + new_args.push_back(address_of_dst); + new_args.push_back(address_of_value); } else { new_args.push_back(StringImm("AtomicAdd")); + new_args.push_back(dst_node); + new_args.push_back(value_node); } - new_args.push_back(address_of_dst); - new_args.push_back(address_of_value); + new_args.push_back(memory_order); Call new_call =