[DAGCombiner] Fold select into partial.reduce.add operands.#167857
[DAGCombiner] Fold select into partial.reduce.add operands.#167857 — sdesmalen-arm merged 5 commits into main from
Conversation
|
@llvm/pr-subscribers-backend-risc-v @llvm/pr-subscribers-llvm-selectiondag Author: Sander de Smalen (sdesmalen-arm) Changes: This generates more optimal codegen when using partial reductions with predication. Full diff: https://github.com/llvm/llvm-project/pull/167857.diff 2 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index df353c4d91b1a..c21b79225323e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -13018,22 +13018,37 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
return SDValue();
}
-// partial_reduce_*mla(acc, mul(ext(a), ext(b)), splat(1))
+// partial_reduce_*mla(acc, mul(*ext(a), *ext(b)), splat(1))
// -> partial_reduce_*mla(acc, a, b)
//
-// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
-// -> partial_reduce_*mla(acc, x, C)
+// partial_reduce_*mla(acc, mul(*ext(x), splat(C)), splat(1))
+// -> partial_reduce_*mla(acc, x, splat(C))
//
-// partial_reduce_fmla(acc, fmul(fpext(a), fpext(b)), splat(1.0))
-// -> partial_reduce_fmla(acc, a, b)
+// partial_reduce_*mla(acc, sel(p, mul(*ext(a), *ext(b)), splat(0)), splat(1))
+// -> partial_reduce_*mla(acc, sel(p, a, splat(0)), b)
+//
+// partial_reduce_*mla(acc, sel(p, mul(*ext(a), splat(C)), splat(0)), splat(1))
+// -> partial_reduce_*mla(acc, sel(p, a, splat(0)), splat(C))
SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDLoc DL(N);
auto *Context = DAG.getContext();
SDValue Acc = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
SDValue Op2 = N->getOperand(2);
-
unsigned Opc = Op1->getOpcode();
+
+ // Handle predication by moving the SELECT into the operand of the MUL.
+ SDValue Pred;
+ if (Opc == ISD::VSELECT) {
+ APInt C;
+ if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) ||
+ !C.isZero())
+ return SDValue();
+ Pred = Op1->getOperand(0);
+ Op1 = Op1->getOperand(1);
+ Opc = Op1->getOpcode();
+ }
+
if (Opc != ISD::MUL && Opc != ISD::FMUL && Opc != ISD::SHL)
return SDValue();
@@ -13068,6 +13083,18 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDValue LHSExtOp = LHS->getOperand(0);
EVT LHSExtOpVT = LHSExtOp.getValueType();
+  // Return 'select(P, Op, splat(0))' if P is non-null,
+  // or 'Op' otherwise.
+ auto tryPredicate = [&](SDValue P, SDValue Op) {
+ if (!P)
+ return Op;
+ EVT OpVT = Op.getValueType();
+ SDValue Zero = OpVT.isFloatingPoint()
+ ? DAG.getConstantFP(0.0, DL, OpVT)
+ : DAG.getConstant(0, DL, OpVT);
+ return DAG.getSelect(DL, OpVT, P, Op, Zero);
+ };
+
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
// -> partial_reduce_*mla(acc, x, C)
APInt C;
@@ -13090,8 +13117,10 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();
+ SDValue Constant =
+ tryPredicate(Pred, DAG.getConstant(CTrunc, DL, LHSExtOpVT));
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
- DAG.getConstant(CTrunc, DL, LHSExtOpVT));
+ Constant);
}
unsigned RHSOpcode = RHS->getOpcode();
@@ -13132,17 +13161,17 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();
+ RHSExtOp = tryPredicate(Pred, RHSExtOp);
return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
}
-// partial.reduce.umla(acc, zext(op), splat(1))
-// -> partial.reduce.umla(acc, op, splat(trunc(1)))
-// partial.reduce.smla(acc, sext(op), splat(1))
-// -> partial.reduce.smla(acc, op, splat(trunc(1)))
+// partial.reduce.*mla(acc, *ext(op), splat(1))
+// -> partial.reduce.*mla(acc, op, splat(trunc(1)))
// partial.reduce.sumla(acc, sext(op), splat(1))
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
-// partial.reduce.fmla(acc, fpext(op), splat(1.0))
-// -> partial.reduce.fmla(acc, op, splat(1.0))
+//
+// partial.reduce.*mla(acc, sel(p, *ext(op), splat(0)), splat(1))
+// -> partial.reduce.*mla(acc, sel(p, op, splat(0)), splat(trunc(1)))
SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
SDLoc DL(N);
SDValue Acc = N->getOperand(0);
@@ -13152,7 +13181,18 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
if (!llvm::isOneOrOneSplat(Op2) && !llvm::isOneOrOneSplatFP(Op2))
return SDValue();
+ SDValue Pred;
unsigned Op1Opcode = Op1.getOpcode();
+ if (Op1Opcode == ISD::VSELECT) {
+ APInt C;
+ if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) ||
+ !C.isZero())
+ return SDValue();
+ Pred = Op1->getOperand(0);
+ Op1 = Op1->getOperand(1);
+ Op1Opcode = Op1->getOpcode();
+ }
+
if (!ISD::isExtOpcode(Op1Opcode) && Op1Opcode != ISD::FP_EXTEND)
return SDValue();
@@ -13181,6 +13221,12 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
? DAG.getConstantFP(1, DL, UnextOp1VT)
: DAG.getConstant(1, DL, UnextOp1VT);
+ if (Pred) {
+ SDValue Zero = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
+ ? DAG.getConstantFP(0, DL, UnextOp1VT)
+ : DAG.getConstant(0, DL, UnextOp1VT);
+ Constant = DAG.getSelect(DL, UnextOp1VT, Pred, Constant, Zero);
+ }
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
Constant);
}
diff --git a/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll b/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll
new file mode 100644
index 0000000000000..24cdd0a852222
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll
@@ -0,0 +1,159 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s | FileCheck %s
+
+target triple = "aarch64"
+
+define <4 x i32> @predicate_dot_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a, <16 x i8> %b) #0 {
+; CHECK-LABEL: predicate_dot_fixed_length:
+; CHECK: // %bb.0:
+; CHECK-NEXT: shl v1.16b, v1.16b, #7
+; CHECK-NEXT: cmlt v1.16b, v1.16b, #0
+; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT: sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT: ret
+ %ext.1 = sext <16 x i8> %a to <16 x i32>
+ %ext.2 = sext <16 x i8> %b to <16 x i32>
+ %mul = mul nsw <16 x i32> %ext.1, %ext.2
+ %sel = select <16 x i1> %p, <16 x i32> %mul, <16 x i32> zeroinitializer
+ %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
+ ret <4 x i32> %red
+}
+
+define <4 x i32> @predicate_dot_by_C_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_dot_by_C_fixed_length:
+; CHECK: // %bb.0:
+; CHECK-NEXT: shl v1.16b, v1.16b, #7
+; CHECK-NEXT: movi v3.16b, #127
+; CHECK-NEXT: cmlt v1.16b, v1.16b, #0
+; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT: sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT: ret
+ %ext.1 = sext <16 x i8> %a to <16 x i32>
+ %mul = mul nsw <16 x i32> %ext.1, splat(i32 127)
+ %sel = select <16 x i1> %p, <16 x i32> %mul, <16 x i32> zeroinitializer
+ %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
+ ret <4 x i32> %red
+}
+
+define <vscale x 4 x i32> @predicate_dot_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+; CHECK-LABEL: predicate_dot_scalable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: movi v3.2d, #0000000000000000
+; CHECK-NEXT: sel z2.b, p0, z2.b, z3.b
+; CHECK-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+ %ext.1 = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %ext.2 = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mul = mul nsw <vscale x 16 x i32> %ext.1, %ext.2
+ %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %mul, <vscale x 16 x i32> zeroinitializer
+ %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
+ ret <vscale x 4 x i32> %red
+}
+
+define <vscale x 4 x i32> @predicate_dot_by_C_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_dot_by_C_scalable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z2.b, p0/z, #127 // =0x7f
+; CHECK-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+ %ext.1 = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %mul = mul nsw <vscale x 16 x i32> %ext.1, splat(i32 127)
+ %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %mul, <vscale x 16 x i32> zeroinitializer
+ %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
+ ret <vscale x 4 x i32> %red
+}
+
+define <4 x i32> @predicate_ext_mul_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_ext_mul_fixed_length:
+; CHECK: // %bb.0:
+; CHECK-NEXT: movi v3.16b, #1
+; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT: sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT: ret
+ %ext = sext <16 x i8> %a to <16 x i32>
+ %sel = select <16 x i1> %p, <16 x i32> %ext, <16 x i32> zeroinitializer
+ %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
+ ret <4 x i32> %red
+}
+
+define <vscale x 4 x i32> @predicate_ext_mul_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_ext_mul_scalable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z2.b, p0/z, #1 // =0x1
+; CHECK-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+ %ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %ext, <vscale x 16 x i32> zeroinitializer
+ %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
+ ret <vscale x 4 x i32> %red
+}
+
+define <4 x float> @predicated_fdot_fixed_length(<4 x float> %acc, <8 x i1> %p, <8 x half> %a, <8 x half> %b) #1 {
+; CHECK-LABEL: predicated_fdot_fixed_length:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: shl v1.8h, v1.8h, #15
+; CHECK-NEXT: cmlt v1.8h, v1.8h, #0
+; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT: fdot z0.s, z2.h, z1.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+ %ext.1 = fpext <8 x half> %a to <8 x float>
+ %ext.2 = fpext <8 x half> %b to <8 x float>
+ %mul = fmul <8 x float> %ext.1, %ext.2
+ %sel = select <8 x i1> %p, <8 x float> %mul, <8 x float> zeroinitializer
+ %red = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %sel)
+ ret <4 x float> %red
+}
+
+define <vscale x 4 x float> @predicated_fdot_scalable(<vscale x 4 x float> %acc, <vscale x 8 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b) #1 {
+; CHECK-LABEL: predicated_fdot_scalable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: movi v3.2d, #0000000000000000
+; CHECK-NEXT: sel z2.h, p0, z2.h, z3.h
+; CHECK-NEXT: fdot z0.s, z1.h, z2.h
+; CHECK-NEXT: ret
+ %ext.1 = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
+ %ext.2 = fpext <vscale x 8 x half> %b to <vscale x 8 x float>
+ %mul = fmul <vscale x 8 x float> %ext.1, %ext.2
+ %sel = select <vscale x 8 x i1> %p, <vscale x 8 x float> %mul, <vscale x 8 x float> zeroinitializer
+ %red = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %sel)
+ ret <vscale x 4 x float> %red
+}
+
+define <4 x float> @predicated_fpext_fmul_fixed_length(<4 x float> %acc, <8 x i1> %p, <8 x half> %a) #1 {
+; CHECK-LABEL: predicated_fpext_fmul_fixed_length:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-NEXT: movi v3.8h, #60, lsl #8
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: shl v1.8h, v1.8h, #15
+; CHECK-NEXT: cmlt v1.8h, v1.8h, #0
+; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT: fdot z0.s, z2.h, z1.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+ %ext = fpext <8 x half> %a to <8 x float>
+ %sel = select <8 x i1> %p, <8 x float> %ext, <8 x float> zeroinitializer
+ %red = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %sel)
+ ret <4 x float> %red
+}
+
+define <vscale x 4 x float> @predicated_fpext_fmul_scalable(<vscale x 4 x float> %acc, <vscale x 8 x i1> %p, <vscale x 8 x half> %a) #1 {
+; CHECK-LABEL: predicated_fpext_fmul_scalable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: movi v2.2d, #0000000000000000
+; CHECK-NEXT: fmov z2.h, p0/m, #1.00000000
+; CHECK-NEXT: fdot z0.s, z1.h, z2.h
+; CHECK-NEXT: ret
+ %ext = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
+ %sel = select <vscale x 8 x i1> %p, <vscale x 8 x float> %ext, <vscale x 8 x float> zeroinitializer
+ %red = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %sel)
+ ret <vscale x 4 x float> %red
+}
+
+attributes #0 = { nounwind "target-features"="+sve,+dotprod" }
+attributes #1 = { nounwind "target-features"="+sve2p1,+dotprod" }
|
|
@llvm/pr-subscribers-backend-aarch64 Author: Sander de Smalen (sdesmalen-arm) Changes: This generates more optimal codegen when using partial reductions with predication. Full diff: https://github.com/llvm/llvm-project/pull/167857.diff 2 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index df353c4d91b1a..c21b79225323e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -13018,22 +13018,37 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
return SDValue();
}
-// partial_reduce_*mla(acc, mul(ext(a), ext(b)), splat(1))
+// partial_reduce_*mla(acc, mul(*ext(a), *ext(b)), splat(1))
// -> partial_reduce_*mla(acc, a, b)
//
-// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
-// -> partial_reduce_*mla(acc, x, C)
+// partial_reduce_*mla(acc, mul(*ext(x), splat(C)), splat(1))
+// -> partial_reduce_*mla(acc, x, splat(C))
//
-// partial_reduce_fmla(acc, fmul(fpext(a), fpext(b)), splat(1.0))
-// -> partial_reduce_fmla(acc, a, b)
+// partial_reduce_*mla(acc, sel(p, mul(*ext(a), *ext(b)), splat(0)), splat(1))
+// -> partial_reduce_*mla(acc, sel(p, a, splat(0)), b)
+//
+// partial_reduce_*mla(acc, sel(p, mul(*ext(a), splat(C)), splat(0)), splat(1))
+// -> partial_reduce_*mla(acc, sel(p, a, splat(0)), splat(C))
SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDLoc DL(N);
auto *Context = DAG.getContext();
SDValue Acc = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
SDValue Op2 = N->getOperand(2);
-
unsigned Opc = Op1->getOpcode();
+
+ // Handle predication by moving the SELECT into the operand of the MUL.
+ SDValue Pred;
+ if (Opc == ISD::VSELECT) {
+ APInt C;
+ if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) ||
+ !C.isZero())
+ return SDValue();
+ Pred = Op1->getOperand(0);
+ Op1 = Op1->getOperand(1);
+ Opc = Op1->getOpcode();
+ }
+
if (Opc != ISD::MUL && Opc != ISD::FMUL && Opc != ISD::SHL)
return SDValue();
@@ -13068,6 +13083,18 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDValue LHSExtOp = LHS->getOperand(0);
EVT LHSExtOpVT = LHSExtOp.getValueType();
+  // Return 'select(P, Op, splat(0))' if P is non-null,
+  // or 'Op' otherwise.
+ auto tryPredicate = [&](SDValue P, SDValue Op) {
+ if (!P)
+ return Op;
+ EVT OpVT = Op.getValueType();
+ SDValue Zero = OpVT.isFloatingPoint()
+ ? DAG.getConstantFP(0.0, DL, OpVT)
+ : DAG.getConstant(0, DL, OpVT);
+ return DAG.getSelect(DL, OpVT, P, Op, Zero);
+ };
+
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
// -> partial_reduce_*mla(acc, x, C)
APInt C;
@@ -13090,8 +13117,10 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();
+ SDValue Constant =
+ tryPredicate(Pred, DAG.getConstant(CTrunc, DL, LHSExtOpVT));
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
- DAG.getConstant(CTrunc, DL, LHSExtOpVT));
+ Constant);
}
unsigned RHSOpcode = RHS->getOpcode();
@@ -13132,17 +13161,17 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();
+ RHSExtOp = tryPredicate(Pred, RHSExtOp);
return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
}
-// partial.reduce.umla(acc, zext(op), splat(1))
-// -> partial.reduce.umla(acc, op, splat(trunc(1)))
-// partial.reduce.smla(acc, sext(op), splat(1))
-// -> partial.reduce.smla(acc, op, splat(trunc(1)))
+// partial.reduce.*mla(acc, *ext(op), splat(1))
+// -> partial.reduce.*mla(acc, op, splat(trunc(1)))
// partial.reduce.sumla(acc, sext(op), splat(1))
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
-// partial.reduce.fmla(acc, fpext(op), splat(1.0))
-// -> partial.reduce.fmla(acc, op, splat(1.0))
+//
+// partial.reduce.*mla(acc, sel(p, *ext(op), splat(0)), splat(1))
+// -> partial.reduce.*mla(acc, sel(p, op, splat(0)), splat(trunc(1)))
SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
SDLoc DL(N);
SDValue Acc = N->getOperand(0);
@@ -13152,7 +13181,18 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
if (!llvm::isOneOrOneSplat(Op2) && !llvm::isOneOrOneSplatFP(Op2))
return SDValue();
+ SDValue Pred;
unsigned Op1Opcode = Op1.getOpcode();
+ if (Op1Opcode == ISD::VSELECT) {
+ APInt C;
+ if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) ||
+ !C.isZero())
+ return SDValue();
+ Pred = Op1->getOperand(0);
+ Op1 = Op1->getOperand(1);
+ Op1Opcode = Op1->getOpcode();
+ }
+
if (!ISD::isExtOpcode(Op1Opcode) && Op1Opcode != ISD::FP_EXTEND)
return SDValue();
@@ -13181,6 +13221,12 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
? DAG.getConstantFP(1, DL, UnextOp1VT)
: DAG.getConstant(1, DL, UnextOp1VT);
+ if (Pred) {
+ SDValue Zero = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
+ ? DAG.getConstantFP(0, DL, UnextOp1VT)
+ : DAG.getConstant(0, DL, UnextOp1VT);
+ Constant = DAG.getSelect(DL, UnextOp1VT, Pred, Constant, Zero);
+ }
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
Constant);
}
diff --git a/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll b/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll
new file mode 100644
index 0000000000000..24cdd0a852222
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll
@@ -0,0 +1,159 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s | FileCheck %s
+
+target triple = "aarch64"
+
+define <4 x i32> @predicate_dot_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a, <16 x i8> %b) #0 {
+; CHECK-LABEL: predicate_dot_fixed_length:
+; CHECK: // %bb.0:
+; CHECK-NEXT: shl v1.16b, v1.16b, #7
+; CHECK-NEXT: cmlt v1.16b, v1.16b, #0
+; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT: sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT: ret
+ %ext.1 = sext <16 x i8> %a to <16 x i32>
+ %ext.2 = sext <16 x i8> %b to <16 x i32>
+ %mul = mul nsw <16 x i32> %ext.1, %ext.2
+ %sel = select <16 x i1> %p, <16 x i32> %mul, <16 x i32> zeroinitializer
+ %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
+ ret <4 x i32> %red
+}
+
+define <4 x i32> @predicate_dot_by_C_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_dot_by_C_fixed_length:
+; CHECK: // %bb.0:
+; CHECK-NEXT: shl v1.16b, v1.16b, #7
+; CHECK-NEXT: movi v3.16b, #127
+; CHECK-NEXT: cmlt v1.16b, v1.16b, #0
+; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT: sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT: ret
+ %ext.1 = sext <16 x i8> %a to <16 x i32>
+ %mul = mul nsw <16 x i32> %ext.1, splat(i32 127)
+ %sel = select <16 x i1> %p, <16 x i32> %mul, <16 x i32> zeroinitializer
+ %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
+ ret <4 x i32> %red
+}
+
+define <vscale x 4 x i32> @predicate_dot_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+; CHECK-LABEL: predicate_dot_scalable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: movi v3.2d, #0000000000000000
+; CHECK-NEXT: sel z2.b, p0, z2.b, z3.b
+; CHECK-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+ %ext.1 = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %ext.2 = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mul = mul nsw <vscale x 16 x i32> %ext.1, %ext.2
+ %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %mul, <vscale x 16 x i32> zeroinitializer
+ %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
+ ret <vscale x 4 x i32> %red
+}
+
+define <vscale x 4 x i32> @predicate_dot_by_C_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_dot_by_C_scalable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z2.b, p0/z, #127 // =0x7f
+; CHECK-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+ %ext.1 = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %mul = mul nsw <vscale x 16 x i32> %ext.1, splat(i32 127)
+ %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %mul, <vscale x 16 x i32> zeroinitializer
+ %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
+ ret <vscale x 4 x i32> %red
+}
+
+define <4 x i32> @predicate_ext_mul_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_ext_mul_fixed_length:
+; CHECK: // %bb.0:
+; CHECK-NEXT: movi v3.16b, #1
+; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT: sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT: ret
+ %ext = sext <16 x i8> %a to <16 x i32>
+ %sel = select <16 x i1> %p, <16 x i32> %ext, <16 x i32> zeroinitializer
+ %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
+ ret <4 x i32> %red
+}
+
+define <vscale x 4 x i32> @predicate_ext_mul_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_ext_mul_scalable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z2.b, p0/z, #1 // =0x1
+; CHECK-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+ %ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %ext, <vscale x 16 x i32> zeroinitializer
+ %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
+ ret <vscale x 4 x i32> %red
+}
+
+define <4 x float> @predicated_fdot_fixed_length(<4 x float> %acc, <8 x i1> %p, <8 x half> %a, <8 x half> %b) #1 {
+; CHECK-LABEL: predicated_fdot_fixed_length:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: shl v1.8h, v1.8h, #15
+; CHECK-NEXT: cmlt v1.8h, v1.8h, #0
+; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT: fdot z0.s, z2.h, z1.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+ %ext.1 = fpext <8 x half> %a to <8 x float>
+ %ext.2 = fpext <8 x half> %b to <8 x float>
+ %mul = fmul <8 x float> %ext.1, %ext.2
+ %sel = select <8 x i1> %p, <8 x float> %mul, <8 x float> zeroinitializer
+ %red = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %sel)
+ ret <4 x float> %red
+}
+
+define <vscale x 4 x float> @predicated_fdot_scalable(<vscale x 4 x float> %acc, <vscale x 8 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b) #1 {
+; CHECK-LABEL: predicated_fdot_scalable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: movi v3.2d, #0000000000000000
+; CHECK-NEXT: sel z2.h, p0, z2.h, z3.h
+; CHECK-NEXT: fdot z0.s, z1.h, z2.h
+; CHECK-NEXT: ret
+ %ext.1 = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
+ %ext.2 = fpext <vscale x 8 x half> %b to <vscale x 8 x float>
+ %mul = fmul <vscale x 8 x float> %ext.1, %ext.2
+ %sel = select <vscale x 8 x i1> %p, <vscale x 8 x float> %mul, <vscale x 8 x float> zeroinitializer
+ %red = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %sel)
+ ret <vscale x 4 x float> %red
+}
+
+define <4 x float> @predicated_fpext_fmul_fixed_length(<4 x float> %acc, <8 x i1> %p, <8 x half> %a) #1 {
+; CHECK-LABEL: predicated_fpext_fmul_fixed_length:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-NEXT: movi v3.8h, #60, lsl #8
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: shl v1.8h, v1.8h, #15
+; CHECK-NEXT: cmlt v1.8h, v1.8h, #0
+; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT: fdot z0.s, z2.h, z1.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+ %ext = fpext <8 x half> %a to <8 x float>
+ %sel = select <8 x i1> %p, <8 x float> %ext, <8 x float> zeroinitializer
+ %red = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %sel)
+ ret <4 x float> %red
+}
+
+define <vscale x 4 x float> @predicated_fpext_fmul_scalable(<vscale x 4 x float> %acc, <vscale x 8 x i1> %p, <vscale x 8 x half> %a) #1 {
+; CHECK-LABEL: predicated_fpext_fmul_scalable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: movi v2.2d, #0000000000000000
+; CHECK-NEXT: fmov z2.h, p0/m, #1.00000000
+; CHECK-NEXT: fdot z0.s, z1.h, z2.h
+; CHECK-NEXT: ret
+ %ext = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
+ %sel = select <vscale x 8 x i1> %p, <vscale x 8 x float> %ext, <vscale x 8 x float> zeroinitializer
+ %red = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %sel)
+ ret <vscale x 4 x float> %red
+}
+
+attributes #0 = { nounwind "target-features"="+sve,+dotprod" }
+attributes #1 = { nounwind "target-features"="+sve2p1,+dotprod" }
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
14e935c to
a126fe7
Compare
RISC-V test coverage for #167857
|
I've precommitted a RISC-V test case here 2a53949, can you rebase this PR on top of it? |
This generates more optimal codegen when using partial reductions with predication. partial_reduce_*mla(acc, sel(p, mul(*ext(a), *ext(b)), splat(0)), splat(1)) -> partial_reduce_*mla(acc, sel(p, a, splat(0)), b) partial.reduce.*mla(acc, sel(p, *ext(op), splat(0)), splat(1)) -> partial.reduce.*mla(acc, sel(p, op, splat(0)), splat(trunc(1)))
a126fe7 to
7506ebc
Compare
| if (Opc == ISD::VSELECT) { | ||
| APInt C; | ||
| if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) || | ||
| !C.isZero()) | ||
| return SDValue(); |
There was a problem hiding this comment.
| if (Opc == ISD::VSELECT) { | |
| APInt C; | |
| if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) || | |
| !C.isZero()) | |
| return SDValue(); | |
| if (Opc == ISD::VSELECT && isNullOrNullSplat(Op1->getOperand(2))) { |
| if (Op1Opcode == ISD::VSELECT) { | ||
| APInt C; | ||
| if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) || | ||
| !C.isZero()) | ||
| return SDValue(); |
There was a problem hiding this comment.
| if (Op1Opcode == ISD::VSELECT) { | |
| APInt C; | |
| if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) || | |
| !C.isZero()) | |
| return SDValue(); | |
| if (Op1Opcode == ISD::VSELECT && isNullOrNullSplat(Op1->getOperand(2))) { |
There was a problem hiding this comment.
It seems isNullOrNullSplat does not recognise FP constants, so I've added a similar FP variant (similar to what was done for isOneOrOneSplatFP)
…-partial-reduce-add' into predicated-partial-reduce-add
) This generates more optimal codegen when using partial reductions with predication. ``` partial_reduce_*mla(acc, sel(p, mul(*ext(a), *ext(b)), splat(0)), splat(1)) -> partial_reduce_*mla(acc, sel(p, a, splat(0)), b) partial.reduce.*mla(acc, sel(p, *ext(op), splat(0)), splat(1)) -> partial.reduce.*mla(acc, sel(p, op, splat(0)), splat(trunc(1))) ```
This generates more optimal codegen when using partial reductions with predication.