From 4280bde0b35810139b4bb0b30ec99eae4433b3c1 Mon Sep 17 00:00:00 2001 From: Adnan Akhundov Date: Sat, 16 Nov 2024 17:02:05 -0800 Subject: [PATCH] [BACKEND] Fix ProgramPoint passing in AxisInfoAnalysis --- lib/Analysis/AxisInfo.cpp | 6 +++--- test/TritonGPU/coalesce.mlir | 29 +++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index f0c5ae3167ec..717df8d1bd5a 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -1084,9 +1084,9 @@ LogicalResult AxisInfoAnalysis::visitOperation( void AxisInfoAnalysis::visitForOpInductionVar( scf::ForOp op, ArrayRef *> argLattices) { - ProgramPoint programPoint(op); - auto lb = getLatticeElementFor(&programPoint, op.getLowerBound())->getValue(); - auto step = getLatticeElementFor(&programPoint, op.getStep())->getValue(); + ProgramPoint *programPoint = getProgramPointAfter(op); + auto lb = getLatticeElementFor(programPoint, op.getLowerBound())->getValue(); + auto step = getLatticeElementFor(programPoint, op.getStep())->getValue(); AxisInfo::DimVectorT knownContiguity(1, 1); AxisInfo::DimVectorT knownDivisibility(1, 1); diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir index cf93c37b840d..5d35f43e9eed 100644 --- a/test/TritonGPU/coalesce.mlir +++ b/test/TritonGPU/coalesce.mlir @@ -131,3 +131,32 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war tt.return } } + +// ----- + +// COM: Reproducer for issue #5122 +// CHECK-LABEL: @test_5122 +module { + tt.func public @test_5122(%arg0: i32) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %0 = arith.cmpi sgt, %arg0, %c1_i32 : i32 + scf.if %0 { + %1 = scf.if %0 -> (i32) { + scf.yield %c1_i32 : i32 + } else { + scf.yield %c1_i32 : i32 + } + %2 = arith.cmpi sgt, %1, %c1_i32 : i32 + %3 = scf.if %2 -> (i32) { + scf.yield %c1_i32 : i32 + } else { + scf.yield %c1_i32 : i32 + } + %4 = scf.for %arg1 = %1 to %1 step %c1_i32 iter_args(%arg2 = %3) -> (i32) : i32 { + %5 = arith.addi %arg2, %c1_i32 : i32 + scf.yield %5 : i32 + } + } + tt.return + } +}