Skip to content

Commit

Permalink
[WIP][AMDGPU] Split isInlinableLiteral16 into three and call the sp…
Browse files Browse the repository at this point in the history
…ecific version if possible

The current implementation of `isInlinableLiteral16` assumes, a 16-bit inlinable
literal is either an i16 or a fp16. This is not always true because of bf16.
However, we can't tell fp16 and bf16 apart by just looking at the value. This
patch tries to split `isInlinableLiteral16` into three versions, i16, fp16, bf16
respectively, and call the corresponding version.

This patch is based on #81282. The current status is, only two uses of original
`isInlinableLiteral16` are still there. We need to add an extra argument to indicate
the type of the operand the immediate corresponds to. This will also require the
change of the function signature of the two callers.
  • Loading branch information
shiltian committed Feb 13, 2024
1 parent e9a5322 commit 7df617d
Show file tree
Hide file tree
Showing 12 changed files with 257 additions and 115 deletions.
2 changes: 1 addition & 1 deletion llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4106,7 +4106,7 @@ InstructionSelector::ComplexRendererFns
AMDGPUInstructionSelector::selectWMMAVISrc(MachineOperand &Root) const {
std::optional<FPValueAndVReg> FPValReg;
if (mi_match(Root.getReg(), *MRI, m_GFCstOrSplat(FPValReg))) {
if (TII.isInlineConstant(FPValReg->Value.bitcastToAPInt())) {
if (TII.isInlineConstant(FPValReg->Value)) {
return {{[=](MachineInstrBuilder &MIB) {
MIB.addImm(FPValReg->Value.bitcastToAPInt().getSExtValue());
}}};
Expand Down
48 changes: 37 additions & 11 deletions llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1927,8 +1927,12 @@ static bool isInlineableLiteralOp16(int64_t Val, MVT VT, bool HasInv2Pi) {
return isInlinableIntLiteral(Val);
}

// f16/v2f16 operands work correctly for all values.
return AMDGPU::isInlinableLiteral16(Val, HasInv2Pi);
if (VT.getScalarType() == MVT::f16)
return AMDGPU::isInlinableLiteralFP16(Val, HasInv2Pi);

assert(VT.getScalarType() == MVT::bf16);

return AMDGPU::isInlinableLiteralBF16(Val, HasInv2Pi);
}

bool AMDGPUOperand::isInlinableImm(MVT type) const {
Expand Down Expand Up @@ -2277,15 +2281,26 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
return;

case AMDGPU::OPERAND_REG_IMM_INT16:
case AMDGPU::OPERAND_REG_IMM_FP16:
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
case AMDGPU::OPERAND_REG_INLINE_C_INT16:
case AMDGPU::OPERAND_REG_INLINE_C_FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
if (isSafeTruncation(Val, 16) &&
AMDGPU::isInlinableIntLiteral(static_cast<int16_t>(Val))) {
Inst.addOperand(MCOperand::createImm(Val));
setImmKindConst();
return;
}

Inst.addOperand(MCOperand::createImm(Val & 0xffff));
setImmKindLiteral();
return;

case AMDGPU::OPERAND_REG_INLINE_C_FP16:
case AMDGPU::OPERAND_REG_IMM_FP16:
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
if (isSafeTruncation(Val, 16) &&
AMDGPU::isInlinableLiteral16(static_cast<int16_t>(Val),
AsmParser->hasInv2PiInlineImm())) {
AMDGPU::isInlinableLiteralFP16(static_cast<int16_t>(Val),
AsmParser->hasInv2PiInlineImm())) {
Inst.addOperand(MCOperand::createImm(Val));
setImmKindConst();
return;
Expand All @@ -2296,12 +2311,17 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
return;

case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16: {
assert(isSafeTruncation(Val, 16));
assert(AMDGPU::isInlinableIntLiteral(static_cast<int16_t>(Val)));
Inst.addOperand(MCOperand::createImm(Val));
return;
}
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16: {
assert(isSafeTruncation(Val, 16));
assert(AMDGPU::isInlinableLiteral16(static_cast<int16_t>(Val),
AsmParser->hasInv2PiInlineImm()));
assert(AMDGPU::isInlinableLiteralFP16(static_cast<int16_t>(Val),
AsmParser->hasInv2PiInlineImm()));

Inst.addOperand(MCOperand::createImm(Val));
return;
Expand Down Expand Up @@ -3429,7 +3449,13 @@ bool AMDGPUAsmParser::isInlineConstant(const MCInst &Inst,
OperandType == AMDGPU::OPERAND_REG_IMM_V2FP16)
return AMDGPU::isInlinableLiteralV2F16(Val);

return AMDGPU::isInlinableLiteral16(Val, hasInv2PiInlineImm());
if (OperandType == AMDGPU::OPERAND_REG_IMM_FP16 ||
OperandType == AMDGPU::OPERAND_REG_INLINE_C_FP16 ||
OperandType == AMDGPU::OPERAND_REG_INLINE_AC_FP16 ||
OperandType == AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED)
return AMDGPU::isInlinableLiteralFP16(Val, hasInv2PiInlineImm());

llvm_unreachable("invalid operand type");
}
default:
llvm_unreachable("invalid operand size");
Expand Down
23 changes: 16 additions & 7 deletions llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -462,8 +462,8 @@ void AMDGPUInstPrinter::printImmediateInt16(uint32_t Imm,

// This must accept a 32-bit immediate value to correctly handle packed 16-bit
// operations.
static bool printImmediateFloat16(uint32_t Imm, const MCSubtargetInfo &STI,
raw_ostream &O) {
static bool printImmediateFP16(uint32_t Imm, const MCSubtargetInfo &STI,
raw_ostream &O) {
if (Imm == 0x3C00)
O << "1.0";
else if (Imm == 0xBC00)
Expand All @@ -488,7 +488,7 @@ static bool printImmediateFloat16(uint32_t Imm, const MCSubtargetInfo &STI,
return true;
}

void AMDGPUInstPrinter::printImmediate16(uint32_t Imm,
void AMDGPUInstPrinter::printImmediate16(uint32_t Imm, uint8_t OpType,
const MCSubtargetInfo &STI,
raw_ostream &O) {
int16_t SImm = static_cast<int16_t>(Imm);
Expand All @@ -498,8 +498,17 @@ void AMDGPUInstPrinter::printImmediate16(uint32_t Imm,
}

uint16_t HImm = static_cast<uint16_t>(Imm);
if (printImmediateFloat16(HImm, STI, O))
return;
switch (OpType) {
case AMDGPU::OPERAND_REG_IMM_FP16:
case AMDGPU::OPERAND_REG_INLINE_C_FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
if (printImmediateFP16(HImm, STI, O))
return;
break;
default:
llvm_unreachable("bad operand type");
}

uint64_t Imm16 = static_cast<uint16_t>(Imm);
O << formatHex(Imm16);
Expand All @@ -525,7 +534,7 @@ void AMDGPUInstPrinter::printImmediateV216(uint32_t Imm, uint8_t OpType,
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
if (isUInt<16>(Imm) &&
printImmediateFloat16(static_cast<uint16_t>(Imm), STI, O))
printImmediateFP16(static_cast<uint16_t>(Imm), STI, O))
return;
break;
default:
Expand Down Expand Up @@ -797,7 +806,7 @@ void AMDGPUInstPrinter::printRegularOperand(const MCInst *MI, unsigned OpNo,
case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
case AMDGPU::OPERAND_REG_IMM_FP16:
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
printImmediate16(Op.getImm(), STI, O);
printImmediate16(Op.getImm(), OpTy, STI, O);
break;
case AMDGPU::OPERAND_REG_IMM_V2INT16:
case AMDGPU::OPERAND_REG_IMM_V2FP16:
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ class AMDGPUInstPrinter : public MCInstPrinter {
raw_ostream &O);
void printImmediateInt16(uint32_t Imm, const MCSubtargetInfo &STI,
raw_ostream &O);
void printImmediate16(uint32_t Imm, const MCSubtargetInfo &STI,
raw_ostream &O);
void printImmediate16(uint32_t Imm, uint8_t OpType,
const MCSubtargetInfo &STI, raw_ostream &O);
void printImmediateV216(uint32_t Imm, uint8_t OpType,
const MCSubtargetInfo &STI, raw_ostream &O);
bool printImmediateFloat32(uint32_t Imm, const MCSubtargetInfo &STI,
Expand Down
24 changes: 14 additions & 10 deletions llvm/lib/Target/AMDGPU/SIISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12965,10 +12965,8 @@ SDValue SITargetLowering::performFPMed3ImmCombine(SelectionDAG &DAG,

const SIInstrInfo *TII = getSubtarget()->getInstrInfo();

if ((!K0->hasOneUse() ||
TII->isInlineConstant(K0->getValueAPF().bitcastToAPInt())) &&
(!K1->hasOneUse() ||
TII->isInlineConstant(K1->getValueAPF().bitcastToAPInt()))) {
if ((!K0->hasOneUse() || TII->isInlineConstant(K0->getValueAPF())) &&
(!K1->hasOneUse() || TII->isInlineConstant(K1->getValueAPF()))) {
return DAG.getNode(AMDGPUISD::FMED3, SL, K0->getValueType(0),
Var, SDValue(K0, 0), SDValue(K1, 0));
}
Expand Down Expand Up @@ -15391,16 +15389,22 @@ bool SITargetLowering::checkAsmConstraintVal(SDValue Op, StringRef Constraint,
llvm_unreachable("Invalid asm constraint");
}

bool SITargetLowering::checkAsmConstraintValA(SDValue Op,
uint64_t Val,
bool SITargetLowering::checkAsmConstraintValA(SDValue Op, uint64_t Val,
unsigned MaxSize) const {
unsigned Size = std::min<unsigned>(Op.getScalarValueSizeInBits(), MaxSize);
bool HasInv2Pi = Subtarget->hasInv2PiInlineImm();
if ((Size == 16 && AMDGPU::isInlinableLiteral16(Val, HasInv2Pi)) ||
(Size == 32 && AMDGPU::isInlinableLiteral32(Val, HasInv2Pi)) ||
(Size == 64 && AMDGPU::isInlinableLiteral64(Val, HasInv2Pi))) {
return true;
if (Size == 16) {
MVT VT = Op.getSimpleValueType();
if (VT == MVT::i16 && AMDGPU::isInlinableLiteralI16(Val, HasInv2Pi))
return true;
if (VT == MVT::f16 && AMDGPU::isInlinableLiteralFP16(Val, HasInv2Pi))
return true;
if (VT == MVT::bf16 && AMDGPU::isInlinableLiteralBF16(Val, HasInv2Pi))
return true;
}
if ((Size == 32 && AMDGPU::isInlinableLiteral32(Val, HasInv2Pi)) ||
(Size == 64 && AMDGPU::isInlinableLiteral64(Val, HasInv2Pi)))
return true;
return false;
}

Expand Down
25 changes: 22 additions & 3 deletions llvm/lib/Target/AMDGPU/SIInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4121,8 +4121,27 @@ bool SIInstrInfo::isInlineConstant(const APInt &Imm) const {
ST.hasInv2PiInlineImm());
case 16:
return ST.has16BitInsts() &&
AMDGPU::isInlinableLiteral16(Imm.getSExtValue(),
ST.hasInv2PiInlineImm());
AMDGPU::isInlinableLiteralI16(Imm.getSExtValue(),
ST.hasInv2PiInlineImm());
default:
llvm_unreachable("invalid bitwidth");
}
}

bool SIInstrInfo::isInlineConstant(const APFloat &Imm) const {
APInt IntImm = Imm.bitcastToAPInt();
bool HasInv2Pi = ST.hasInv2PiInlineImm();
switch (IntImm.getBitWidth()) {
case 32:
case 64:
return isInlineConstant(IntImm);
case 16:
if (Imm.isIEEE())
return ST.has16BitInsts() &&
AMDGPU::isInlinableLiteralFP16(IntImm.getSExtValue(), HasInv2Pi);
else
return ST.has16BitInsts() &&
AMDGPU::isInlinableLiteralBF16(IntImm.getSExtValue(), HasInv2Pi);
default:
llvm_unreachable("invalid bitwidth");
}
Expand Down Expand Up @@ -4196,7 +4215,7 @@ bool SIInstrInfo::isInlineConstant(const MachineOperand &MO,
// constants in these cases
int16_t Trunc = static_cast<int16_t>(Imm);
return ST.has16BitInsts() &&
AMDGPU::isInlinableLiteral16(Trunc, ST.hasInv2PiInlineImm());
AMDGPU::isInlinableLiteralFP16(Trunc, ST.hasInv2PiInlineImm());
}

return false;
Expand Down
4 changes: 1 addition & 3 deletions llvm/lib/Target/AMDGPU/SIInstrInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -966,9 +966,7 @@ class SIInstrInfo final : public AMDGPUGenInstrInfo {

bool isInlineConstant(const APInt &Imm) const;

bool isInlineConstant(const APFloat &Imm) const {
return isInlineConstant(Imm.bitcastToAPInt());
}
bool isInlineConstant(const APFloat &Imm) const;

// Returns true if this non-register operand definitely does not need to be
// encoded as a 32-bit literal. Note that this function handles all kinds of
Expand Down
38 changes: 35 additions & 3 deletions llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2652,13 +2652,28 @@ bool isInlinableLiteral32(int32_t Literal, bool HasInv2Pi) {
(Val == 0x3e22f983 && HasInv2Pi);
}

bool isInlinableLiteral16(int16_t Literal, bool HasInv2Pi) {
bool isInlinableLiteralI16(int16_t Literal, bool HasInv2Pi) {
if (!HasInv2Pi)
return false;
if (isInlinableIntLiteral(Literal))
return true;
return (Literal == static_cast<int16_t>(llvm::bit_cast<uint32_t>(0.0f))) ||
(Literal == static_cast<int16_t>(llvm::bit_cast<uint32_t>(1.0f))) ||
(Literal == static_cast<int16_t>(llvm::bit_cast<uint32_t>(-1.0f))) ||
(Literal == static_cast<int16_t>(llvm::bit_cast<uint32_t>(0.5f))) ||
(Literal == static_cast<int16_t>(llvm::bit_cast<uint32_t>(-0.5f))) ||
(Literal == static_cast<int16_t>(llvm::bit_cast<uint32_t>(2.0f))) ||
(Literal == static_cast<int16_t>(llvm::bit_cast<uint32_t>(-2.0f))) ||
(Literal == static_cast<int16_t>(llvm::bit_cast<uint32_t>(4.0f))) ||
(Literal == static_cast<int16_t>(llvm::bit_cast<uint32_t>(-4.0f))) ||
(Literal == static_cast<int16_t>(0x3e22f983));
}

bool isInlinableLiteralFP16(int16_t Literal, bool HasInv2Pi) {
if (!HasInv2Pi)
return false;

if (isInlinableIntLiteral(Literal))
return true;

uint16_t Val = static_cast<uint16_t>(Literal);
return Val == 0x3C00 || // 1.0
Val == 0xBC00 || // -1.0
Expand All @@ -2671,6 +2686,23 @@ bool isInlinableLiteral16(int16_t Literal, bool HasInv2Pi) {
Val == 0x3118; // 1/2pi
}

bool isInlinableLiteralBF16(int16_t Literal, bool HasInv2Pi) {
if (!HasInv2Pi)
return false;
if (isInlinableIntLiteral(Literal))
return true;
uint16_t Val = static_cast<uint16_t>(Literal);
return Val == 0x3F00 || // 0.5
Val == 0xBF00 || // -0.5
Val == 0x3F80 || // 1.0
Val == 0xBF80 || // -1.0
Val == 0x4000 || // 2.0
Val == 0xC000 || // -2.0
Val == 0x4080 || // 4.0
Val == 0xC080 || // -4.0
Val == 0x3E22; // 1.0 / (2.0 * pi)
}

std::optional<unsigned> getInlineEncodingV216(bool IsFloat, uint32_t Literal) {
// Unfortunately, the Instruction Set Architecture Reference Guide is
// misleading about how the inline operands work for (packed) 16-bit
Expand Down
8 changes: 7 additions & 1 deletion llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -1374,7 +1374,13 @@ LLVM_READNONE
bool isInlinableLiteral32(int32_t Literal, bool HasInv2Pi);

LLVM_READNONE
bool isInlinableLiteral16(int16_t Literal, bool HasInv2Pi);
bool isInlinableLiteralFP16(int16_t Literal, bool HasInv2Pi);

LLVM_READNONE
bool isInlinableLiteralBF16(int16_t Literal, bool HasInv2Pi);

LLVM_READNONE
bool isInlinableLiteralI16(int16_t Literal, bool HasInv2Pi);

LLVM_READNONE
std::optional<unsigned> getInlineEncodingV2I16(uint32_t Literal);
Expand Down
Loading

0 comments on commit 7df617d

Please sign in to comment.