Skip to content

Commit

Permalink
[WebAssembly] Autovec support for dot (#123207)
Browse files Browse the repository at this point in the history
Enable the use of partial.reduce.add that we can lower to dot or a tree
of (add (extmul_low_u, extmul_high_u)) for the unsigned case. We support
both v8i16 and v16i8 inputs.
  • Loading branch information
sparker-arm authored Feb 3, 2025
1 parent f7f3dfc commit df2de13
Show file tree
Hide file tree
Showing 7 changed files with 340 additions and 1 deletion.
1 change: 1 addition & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISD.def
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ HANDLE_NODETYPE(Wrapper)
HANDLE_NODETYPE(WrapperREL)
HANDLE_NODETYPE(BR_IF)
HANDLE_NODETYPE(BR_TABLE)
HANDLE_NODETYPE(DOT)
HANDLE_NODETYPE(SHUFFLE)
HANDLE_NODETYPE(SWIZZLE)
HANDLE_NODETYPE(VEC_SHL)
Expand Down
124 changes: 124 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/DiagnosticPrinter.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsWebAssembly.h"
#include "llvm/Support/ErrorHandling.h"
Expand Down Expand Up @@ -177,6 +178,10 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(

// SIMD-specific configuration
if (Subtarget->hasSIMD128()) {

// Combine partial.reduce.add before legalization gets confused.
setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);

// Combine vector mask reductions into alltrue/anytrue
setTargetDAGCombine(ISD::SETCC);

Expand Down Expand Up @@ -406,6 +411,35 @@ MVT WebAssemblyTargetLowering::getPointerMemTy(const DataLayout &DL,
return TargetLowering::getPointerMemTy(DL, AS);
}

bool WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic(
const IntrinsicInst *I) const {
if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add)
return true;

EVT VT = EVT::getEVT(I->getType());
auto Op1 = I->getOperand(1);

if (auto *InputInst = dyn_cast<Instruction>(Op1)) {
if (InstructionOpcodeToISD(InputInst->getOpcode()) != ISD::MUL)
return true;

if (isa<Instruction>(InputInst->getOperand(0)) &&
isa<Instruction>(InputInst->getOperand(1))) {
// dot only supports signed inputs but also support lowering unsigned.
if (cast<Instruction>(InputInst->getOperand(0))->getOpcode() !=
cast<Instruction>(InputInst->getOperand(1))->getOpcode())
return true;

EVT Op1VT = EVT::getEVT(Op1->getType());
if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
((VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount()) ||
(VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount())))
return false;
}
}
return true;
}

TargetLowering::AtomicExpansionKind
WebAssemblyTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
// We have wasm instructions for these
Expand Down Expand Up @@ -2030,6 +2064,94 @@ SDValue WebAssemblyTargetLowering::LowerVASTART(SDValue Op,
MachinePointerInfo(SV));
}

// DAG combine for the partial.reduce.add intrinsic. Signed inputs are lowered
// to dot; unsigned inputs are lowered to a tree of extends, multiplies and
// adds. Any other intrinsic is left alone by returning an empty SDValue.
SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN);
  const bool IsPartialReduceAdd =
      N->getConstantOperandVal(0) ==
      Intrinsic::experimental_vector_partial_reduce_add;
  if (!IsPartialReduceAdd)
    return SDValue();

  assert(N->getValueType(0) == MVT::v4i32 && "can only support v4i32");
  SDLoc DL(N);
  SDValue Mul = N->getOperand(2);
  assert(Mul->getOpcode() == ISD::MUL && "expected mul input");

  SDValue ExtA = Mul->getOperand(0);
  SDValue ExtB = Mul->getOperand(1);
  assert((ISD::isExtOpcode(ExtA.getOpcode()) &&
          ISD::isExtOpcode(ExtB.getOpcode())) &&
         "expected widening mul");
  assert(ExtA.getOpcode() == ExtB.getOpcode() &&
         "expected mul to use the same extend for both operands");

  // The narrow values feeding the two extends.
  SDValue SrcA = ExtA->getOperand(0);
  SDValue SrcB = ExtB->getOperand(0);
  const bool Signed = ExtA->getOpcode() == ISD::SIGN_EXTEND;

  // Small builders so the trees below read as expressions.
  auto Unary = [&](unsigned Opc, EVT VT, SDValue V) {
    return DAG.getNode(Opc, DL, VT, V);
  };
  auto Binary = [&](unsigned Opc, EVT VT, SDValue A, SDValue B) {
    return DAG.getNode(Opc, DL, VT, A, B);
  };
  // Every lowering ends by adding the reduced products to operand 1.
  auto AddToAccumulator = [&](SDValue Sum) {
    return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Sum);
  };

  if (SrcA->getValueType(0) == MVT::v8i16) {
    // Signed: a single i32x4.dot_i16x8_s.
    if (Signed)
      return AddToAccumulator(
          Binary(WebAssemblyISD::DOT, MVT::v4i32, SrcA, SrcB));

    // Unsigned: widen each half to v4i32, multiply, then add the halves.
    SDValue LoA = Unary(WebAssemblyISD::EXTEND_LOW_U, MVT::v4i32, SrcA);
    SDValue LoB = Unary(WebAssemblyISD::EXTEND_LOW_U, MVT::v4i32, SrcB);
    SDValue HiA = Unary(WebAssemblyISD::EXTEND_HIGH_U, MVT::v4i32, SrcA);
    SDValue HiB = Unary(WebAssemblyISD::EXTEND_HIGH_U, MVT::v4i32, SrcB);

    SDValue ProdLo = Binary(ISD::MUL, MVT::v4i32, LoA, LoB);
    SDValue ProdHi = Binary(ISD::MUL, MVT::v4i32, HiA, HiB);
    return AddToAccumulator(Binary(ISD::ADD, MVT::v4i32, ProdLo, ProdHi));
  }

  assert(SrcA->getValueType(0) == MVT::v16i8 &&
         "expected v16i8 input types");

  // v16i8 needs a tree twice as wide. Both signedness cases start by widening
  // each half of the inputs to v8i16.
  const unsigned LowOpc =
      Signed ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U;
  const unsigned HighOpc =
      Signed ? WebAssemblyISD::EXTEND_HIGH_S : WebAssemblyISD::EXTEND_HIGH_U;
  SDValue LoA = Unary(LowOpc, MVT::v8i16, SrcA);
  SDValue LoB = Unary(LowOpc, MVT::v8i16, SrcB);
  SDValue HiA = Unary(HighOpc, MVT::v8i16, SrcA);
  SDValue HiB = Unary(HighOpc, MVT::v8i16, SrcB);

  if (Signed) {
    // Signed: one dot per half, then add the two.
    SDValue DotLo = Binary(WebAssemblyISD::DOT, MVT::v4i32, LoA, LoB);
    SDValue DotHi = Binary(WebAssemblyISD::DOT, MVT::v4i32, HiA, HiB);
    return AddToAccumulator(Binary(ISD::ADD, MVT::v4i32, DotLo, DotHi));
  }

  // Unsigned: multiply at v8i16, widen both halves of each product to v4i32,
  // and sum the four resulting vectors.
  SDValue ProdLo = Binary(ISD::MUL, MVT::v8i16, LoA, LoB);
  SDValue ProdHi = Binary(ISD::MUL, MVT::v8i16, HiA, HiB);

  SDValue LoOfProdLo = Unary(LowOpc, MVT::v4i32, ProdLo);
  SDValue LoOfProdHi = Unary(LowOpc, MVT::v4i32, ProdHi);
  SDValue HiOfProdLo = Unary(HighOpc, MVT::v4i32, ProdLo);
  SDValue HiOfProdHi = Unary(HighOpc, MVT::v4i32, ProdHi);

  SDValue SumLo = Binary(ISD::ADD, MVT::v4i32, LoOfProdLo, HiOfProdLo);
  SDValue SumHi = Binary(ISD::ADD, MVT::v4i32, LoOfProdHi, HiOfProdHi);
  return AddToAccumulator(Binary(ISD::ADD, MVT::v4i32, SumLo, SumHi));
}

SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op,
SelectionDAG &DAG) const {
MachineFunction &MF = DAG.getMachineFunction();
Expand Down Expand Up @@ -3126,5 +3248,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
return performVectorTruncZeroCombine(N, DCI);
case ISD::TRUNCATE:
return performTruncateCombine(N, DCI);
case ISD::INTRINSIC_WO_CHAIN:
return performLowerPartialReduction(N, DCI.DAG);
}
}
2 changes: 2 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class WebAssemblyTargetLowering final : public TargetLowering {
/// right decision when generating code for different targets.
const WebAssemblySubtarget *Subtarget;

/// Returns true when the partial-reduction intrinsic \p I should be expanded
/// rather than lowered by this target. Only some
/// experimental_vector_partial_reduce_add calls whose reduced operand is a
/// multiply return false.
bool
shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
AtomicExpansionKind shouldExpandAtomicRMWInIR(AtomicRMWInst *) const override;
bool shouldScalarizeBinop(SDValue VecOp) const override;
FastISel *createFastISel(FunctionLoweringInfo &FuncInfo,
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
Original file line number Diff line number Diff line change
Expand Up @@ -1147,11 +1147,15 @@ def : Pat<(wasm_shr_u
}

// Widening dot product: i32x4.dot_i16x8_s
// Type profile of the target dot node: one v4i32 result, two v8i16 operands.
def dot_t : SDTypeProfile<1, 2, [SDTCisVT<0, v4i32>, SDTCisVT<1, v8i16>, SDTCisVT<2, v8i16>]>;
// SelectionDAG node for WebAssemblyISD::DOT, typed by dot_t.
def wasm_dot : SDNode<"WebAssemblyISD::DOT", dot_t>;
// The instruction itself is matched from the int_wasm_dot intrinsic.
let isCommutable = 1 in
defm DOT : SIMD_I<(outs V128:$dst), (ins V128:$lhs, V128:$rhs), (outs), (ins),
[(set V128:$dst, (int_wasm_dot V128:$lhs, V128:$rhs))],
"i32x4.dot_i16x8_s\t$dst, $lhs, $rhs", "i32x4.dot_i16x8_s",
186>;
// Also select the wasm_dot node to the same DOT instruction.
def : Pat<(wasm_dot V128:$lhs, V128:$rhs),
(DOT $lhs, $rhs)>;

// Extending multiplication: extmul_{low,high}_P, extmul_high
def extend_t : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>;
Expand Down
47 changes: 47 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,53 @@ WebAssemblyTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
return Cost;
}

/// Cost of a partial reduction. Only an add of a multiply is costed, and only
/// when both inputs share a type and an extend kind, the accumulator is i32,
/// and the inputs are 8 x i16 or 16 x i8 on a SIMD128 subtarget. Everything
/// else gets an invalid cost.
InstructionCost WebAssemblyTTIImpl::getPartialReductionCost(
    unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
    ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
    TTI::PartialReductionExtendKind OpBExtend,
    std::optional<unsigned> BinOp) const {
  const InstructionCost Unsupported = InstructionCost::getInvalid();

  if (!VF.isFixed() || !ST->hasSIMD128())
    return Unsupported;

  // Candidate instructions:
  //  - i16x8.extadd_pairwise_i8x16_sx
  //  - i32x4.extadd_pairwise_i16x8_sx
  //  - i32x4.dot_i16x8_s
  // For now only the dot form is modelled, i.e. add(mul(ext, ext)) with
  // matching operands.
  const bool IsAddOfMul = Opcode == Instruction::Add && BinOp &&
                          *BinOp == Instruction::Mul;
  if (!IsAddOfMul || InputTypeA != InputTypeB || OpAExtend != OpBExtend)
    return Unsupported;

  const EVT InputEVT = EVT::getEVT(InputTypeA);
  const EVT AccumEVT = EVT::getEVT(AccumType);

  // TODO: Add i64 accumulator.
  if (AccumEVT != MVT::i32)
    return Unsupported;

  // Sequence length relative to a single dot: i16 inputs need one unit,
  // i8 inputs need twice as many.
  const unsigned Lanes = VF.getFixedValue();
  int Scale;
  if (InputEVT == MVT::i16 && Lanes == 8)
    Scale = 1;
  else if (InputEVT == MVT::i8 && Lanes == 16)
    Scale = 2;
  else
    return Unsupported;

  // Only signed inputs can use dot; anything else doubles the sequence.
  if (OpAExtend != TTI::PR_SignExtend)
    Scale *= 2;

  return InstructionCost(TTI::TCC_Basic) * Scale;
}

TTI::ReductionShuffle WebAssemblyTTIImpl::getPreferredExpandedReductionShuffle(
const IntrinsicInst *II) const {

Expand Down
7 changes: 6 additions & 1 deletion llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,12 @@ class WebAssemblyTTIImpl final : public BasicTTIImplBase<WebAssemblyTTIImpl> {
InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
TTI::TargetCostKind CostKind,
unsigned Index, Value *Op0, Value *Op1);

/// Cost of a partial reduction. The result is invalid unless this is an add
/// of a multiply with matching input types and extend kinds, an i32
/// accumulator, and 8 x i16 or 16 x i8 inputs on a SIMD128 subtarget.
InstructionCost
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
Type *AccumType, ElementCount VF,
TTI::PartialReductionExtendKind OpAExtend,
TTI::PartialReductionExtendKind OpBExtend,
std::optional<unsigned> BinOp = std::nullopt) const;
TTI::ReductionShuffle
getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const;

Expand Down
Loading

0 comments on commit df2de13

Please sign in to comment.