|
13 | 13 | #include "../target/utils.h" |
14 | 14 | #include "../transform/atomicadd_vectorize.h" |
15 | 15 | #include "../transform/common/loop_fusion_utils.h" |
| 16 | +#include "../transform/common/loop_parallel_transform_utils.h" |
16 | 17 | #include "../transform/loop_partition.h" |
17 | 18 | #include "builtin.h" |
18 | 19 |
|
@@ -225,34 +226,6 @@ PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer, |
225 | 226 | } |
226 | 227 | } |
227 | 228 |
|
228 | | -/** |
229 | | - * @brief Build a SIMT-style loop nest that performs element-wise atomic |
230 | | - * additions from src to dst. |
231 | | - * |
232 | | - * Constructs a nested loop (parallelized per iter var) that loads a value from |
233 | | - * the source buffer, optionally casts it to the destination dtype, and performs |
234 | | - * an extern atomic add into the destination buffer address. For scalar |
235 | | - * (zero-dimensional) operations a trivial serial For with a single BufferStore |
236 | | - * is returned. |
237 | | - * |
238 | | - * The method: |
239 | | - * - Creates iter vars for all non-singleton extents and binds them into the |
240 | | - * provided analyzer. |
241 | | - * - Validates loop variable counts against src/dst ranges (ICHECK on mismatch). |
242 | | - * - Computes indexed accesses and emits optional bound predicates; |
243 | | - * out-of-bounds accesses are masked to zero when predicates are uncertain. |
244 | | - * - Emits an extern `call_extern("AtomicAdd", address_of(dst_value), |
245 | | - * src_value)` call wrapped in an Evaluate statement. |
246 | | - * - Wraps the body with a parallel For at each loop level. If `coalesced_width` |
247 | | - * is defined it is attached as the "coalesced_width" annotation on each loop. |
248 | | - * |
249 | | - * Note: This function mutates the analyzer binding state by binding loop |
250 | | - * variables and may fail via ICHECK if internal assumptions about shapes are |
251 | | - * violated. |
252 | | - * |
253 | | - * @return A nested For loop (parallel loops) implementing the atomic-add |
254 | | - * kernel. For scalar cases a serial For of extent 1 is returned. |
255 | | - */ |
256 | 229 | For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { |
257 | 230 | Array<IterVar> loop_vars = MakeIterVars(); |
258 | 231 | bool is_scalar = loop_vars.empty(); |
@@ -418,6 +391,152 @@ LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T, |
418 | 391 | return par_op_->InferLayout(T, level); |
419 | 392 | } |
420 | 393 |
|
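| | +/**
| | + * @brief Lower the atomic add into a thread-partitioned, vectorized loop nest.
| | + *
| | + * Builds the element-wise SIMT loop, fuses and canonicalizes its parallel
| | + * loops, infers a loop layout (from a fragment operand when one is
| | + * available, otherwise via loop-partition planning), partitions the loop
| | + * over the thread range, and finally vectorizes the atomic adds for the
| | + * target architecture.
| | + */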
| 394 | +Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { |
| 395 | + Target target = T.target; |
| 396 | + auto simt_loop = MakeSIMTLoop(analyzer); |
| 397 | + auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop)); |
| 398 | + auto transformed_loop = |
| 399 | + Downcast<For>(ParallelLoopTransformer::Substitute(fused_loop)); |
| 400 | + DLOG(INFO) << "[AtomicAdd::Lower] transformed loop: " << transformed_loop; |
| 401 | + |
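| | + // Extract the numeric SM architecture from the target's "arch" attribute
| | + // (e.g. "sm_90" -> 90); returns 0 if the attribute is missing or not of
| | + // the form sm_<N>.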
| 402 | + auto GetArchInt = [&](const Target &tgt) -> int { |
| 403 | + int arch_int = 0; |
| 404 | + if (auto s = tgt->GetAttr<String>("arch")) { |
| 405 | + std::string arch = s.value(); |
| 406 | + if (arch.rfind("sm_", 0) == 0) |
| 407 | + arch_int = std::stoi(arch.substr(3)); |
| 408 | + } |
| 409 | + return arch_int; |
| 410 | + }; |
| 411 | + |
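| | + // Collects the parallel loop iteration variables and the index expressions
| | + // of every "local.fragment" buffer accessed inside the loop nest.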
| 412 | + struct AtomicLoopNestCollector : tir::StmtExprVisitor { |
| 413 | + Array<IterVar> loop_vars; |
| 414 | + Map<Buffer, Array<PrimExpr>> indice_map; |
| 415 | + std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> writes; |
| 416 | + arith::Analyzer analyzer; |
| 417 | + |
| 418 | + void Run(const Stmt &s) { StmtExprVisitor::VisitStmt(s); } |
| 419 | + |
| 420 | + void VisitStmt_(const ForNode *op) final { |
| 421 | + if (op->kind == ForKind::kParallel) { |
| 422 | + loop_vars.push_back(IterVar(Range(op->min, op->extent), op->loop_var, |
| 423 | + IterVarType::kDataPar)); |
| 424 | + } |
| 425 | + analyzer.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); |
| 426 | + StmtExprVisitor::VisitStmt_(op); |
| 427 | + } |
| 428 | + void VisitStmt_(const BufferStoreNode *op) final { |
| 429 | + if (op->buffer.scope() == "local.fragment") { |
| 430 | + indice_map.Set(op->buffer, op->indices); |
| 431 | + writes.insert(op->buffer); |
| 432 | + } |
| 433 | + StmtExprVisitor::VisitStmt_(op); |
| 434 | + } |
| 435 | + void VisitExpr_(const BufferLoadNode *op) final { |
| 436 | + if (op->buffer.scope() == "local.fragment") { |
| 437 | + indice_map.Set(op->buffer, op->indices); |
| 438 | + } |
| 439 | + StmtExprVisitor::VisitExpr_(op); |
| 440 | + } |
| 441 | + }; |
| 442 | + |
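| | + // Derive a loop layout from a fragment operand: apply the fragment's
| | + // ForwardThread mapping to the access indices and bind the thread range.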
| 443 | + auto ComputeLoopLayoutFromBuffer = |
| 444 | + [&](const Buffer &buf, const Array<PrimExpr> &indices, |
| 445 | + const LayoutMap &layout_map, const Range &thread_bounds, |
| 446 | + const Array<IterVar> &loop_vars) -> Fragment { |
| 447 | + Fragment src = layout_map[buf].as<Fragment>().value(); |
| 448 | + Var rep("rep"); |
| 449 | + auto rep_iter = |
| 450 | + IterVar(Range(0, src->ReplicateExtent()), rep, IterVarType::kDataPar); |
| 451 | + PrimExpr fth = src->ForwardThread(indices, rep); |
| 452 | + fth = analyzer->Simplify(fth); |
| 453 | + Fragment out = Fragment(loop_vars, /*forward_index=*/{}, fth, rep_iter) |
| 454 | + ->BindThreadRange(thread_bounds); |
| 455 | + return out; |
| 456 | + }; |
| 457 | + |
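| | + // Layout inference result: the fragment layout used to partition the loop,
| | + // plus an optional predicate guarding dynamically out-of-bounds lanes.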
| 458 | + struct AtomicInferResult { |
| 459 | + Fragment loop_layout; |
| 460 | + Optional<PrimExpr> predicate; |
| 461 | + }; |
| 462 | + |
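| | + // Infer the loop layout for the atomic-add loop nest: prefer the
| | + // highest-rank fragment operand as the layout source, then plan the
| | + // vectorization width for the target architecture.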
| 463 | + auto AtomicAddInferLayout = |
| 464 | + [&](const For &loop, const LayoutInferArgs &args) -> AtomicInferResult { |
| 465 | + AtomicLoopNestCollector C; |
| 466 | + C.Run(loop); |
| 467 | + Optional<Buffer> read_src; |
| 468 | + int best_rank = -1; |
| 469 | + for (auto kv : C.indice_map) { |
| 470 | + const Buffer &buf = kv.first; |
| 471 | + if (buf.scope() != "local.fragment") |
| 472 | + continue; |
| 473 | + if (!args.layout_map.count(buf)) |
| 474 | + continue; |
| 475 | + int rank = static_cast<int>(kv.second.size()); |
| 476 | + if (rank > best_rank) { |
| 477 | + best_rank = rank; |
| 478 | + read_src = buf; |
| 479 | + } |
| 480 | + } |
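| | + // Plan the vector width; an explicit "coalesced_width" annotation, if
| | + // present, must divide the planned width and overrides it.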
| 481 | + AtomicAddVectorizePlanner planner; |
| 482 | + int sm = GetArchInt(target); |
| 483 | + auto plan = planner.Plan(loop, sm); |
| 484 | + int vec = std::max(plan.vector_size, 1); |
| 485 | + if (auto cw = loop->annotations.Get("coalesced_width")) { |
| 486 | + if (const auto *imm = cw->as<IntImmNode>()) { |
| 487 | + int expected = imm->value; |
| 488 | + ICHECK_GT(expected, 0); |
| 489 | + ICHECK(vec % expected == 0) |
| 490 | + << "vector_size " << vec << " not divisible by coalesced_width " |
| 491 | + << expected; |
| 492 | + vec = expected; |
| 493 | + } else { |
| 494 | + LOG(FATAL) << "coalesced_width should be IntImmNode."; |
| 495 | + } |
| 496 | + } |
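| | + // Shrink the vector width until thread_extent * vec evenly divides the
| | + // flattened extent of the loop nest.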
| 497 | + PrimExpr total = 1; |
| 498 | + for (Stmt s = loop; s.as<For>().has_value(); s = s.as<For>().value()->body) |
| 499 | + total = total * s.as<For>().value()->extent; |
| 500 | + PrimExpr denom = args.thread_bounds->extent * vec; |
| 501 | + while (!analyzer->CanProve(floormod(total, denom) == 0) && vec > 1) { |
| 502 | + vec >>= 1; |
| 503 | + denom = args.thread_bounds->extent * vec; |
| 504 | + } |
| 505 | + if (vec < 1) |
| 506 | + vec = 1; |
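| | + // Use the fragment operand's layout when one was found; otherwise fall
| | + // back to planning a fresh loop partition.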
| 507 | + Fragment loop_layout; |
| 508 | + if (read_src) { |
| 509 | + loop_layout = ComputeLoopLayoutFromBuffer( |
| 510 | + read_src.value(), C.indice_map[read_src.value()], args.layout_map, |
| 511 | + args.thread_bounds, C.loop_vars); |
| 512 | + } else { |
| 513 | + For remapped = loop; // simplified handling: no fragment source to copy a layout from |
| 514 | + loop_layout = PlanLoopPartition(remapped, vec, args.thread_bounds); |
| 515 | + } |
| 516 | + |
| 517 | + Optional<PrimExpr> pred; |
| 518 | + if (plan.dynamic && plan.condition.defined()) { |
| 519 | + pred = plan.condition; |
| 520 | + } |
| 521 | + DLOG(INFO) << "[AtomicAddInferLayout] vec=" << vec |
| 522 | + << " loop_layout=" << loop_layout->DebugOutput(); |
| 523 | + return {loop_layout, pred}; |
| 524 | + }; |
| 525 | + |
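| | + // Infer the loop layout, partition the loop over the thread variable, and
| | + // vectorize the resulting atomic adds for the target architecture.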
| 526 | + auto ret = AtomicAddInferLayout(transformed_loop, |
| 527 | + {T.target, T.thread_bounds, T.layout_map, |
| 528 | + analyzer, false, T.buffer_remap}); |
| 529 | + Fragment loop_layout = ret.loop_layout; |
| 530 | + DLOG(INFO) << "[AtomicAdd::Lower] loop_layout=" << loop_layout->DebugOutput(); |
| 531 | + auto thread_loop = |
| 532 | + PartitionLoop(transformed_loop, T.thread_var, analyzer, loop_layout); |
| 533 | + DLOG(INFO) << "[AtomicAdd::Lower] thread loop: " << thread_loop; |
| 534 | + auto vectorized_thread_loop = |
| 535 | + VectorizeAtomicAdd(thread_loop, GetArchInt(target)); |
| 536 | + DLOG(INFO) << "[AtomicAdd::Lower] vectorized loop: " << vectorized_thread_loop; |
| 537 | + return vectorized_thread_loop; |
| 538 | +} |
| 539 | + |
421 | 540 | TIR_REGISTER_TL_OP(AtomicAdd, atomicadd) |
422 | 541 | .set_num_inputs(2) |
423 | 542 | .set_attr<TCallEffectKind>("TCallEffectKind", |
|