diff --git a/src/op/reduce.cc b/src/op/reduce.cc
index ccb8a2ead..7148cc076 100644
--- a/src/op/reduce.cc
+++ b/src/op/reduce.cc
@@ -172,7 +172,7 @@ std::string ReduceOpNode::MakeCodegenReducer() const {
  * - Detects parallel thread splitting from the normalized iterator sum and
  *   emits a call to a templated `tl::AllReduce<...>::run` (or `run_hopper`)
  *   via `builtin::call_extern`. For sufficiently large reducing thread counts
- *   (>= 32) a workspace is allocated via T.AddWorkspace and passed to the
+ *   (> 32) a workspace is allocated via T.AddWorkspace and passed to the
  *   AllReduce call.
  * - The final body is wrapped in parallel loops over the destination spatial
  *   dimensions and partitioned by the lowering thread variable. If a temporary
@@ -322,7 +322,7 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
       }
       Array thread_reduce_args = {
           StringImm(ss.str()), BufferLoad(clear_buffer, dst_indices)};
-      if (reducing_threads >= 32) {
+      if (reducing_threads > 32) {
         PrimExpr workspace = T.AddWorkspace(
             *as_const_int(T.thread_bounds->extent), clear_buffer->dtype);
         thread_reduce_args.push_back(workspace);