Skip to content

Commit

Permalink
[RFC][WIP][AMDGPU] Use bf16 instead of i16 for bfloat
Browse files Browse the repository at this point in the history
Currently it looks like we generally use `i16` to represent `bf16` in those tablegen
files. I'm not sure of the reason behind it. My wild guess is the type `bf16` was
not available when we enabled the support. This patch is trying to use `bf16`
directly in those tablegen files, aiming at fixing #79369. Of course for #79369
a workaround can be to treat all `INT16` variants as `BFloat` in `getOpFltSemantics`,
but it doesn't look good IMHO.

Since I'm fairly new to AMDGPU backend, I'd appreciate it if you can point out
where I don't understand correctly.
  • Loading branch information
shiltian committed Feb 16, 2024
1 parent c098f2d commit d95e99e
Show file tree
Hide file tree
Showing 16 changed files with 356 additions and 34 deletions.
5 changes: 4 additions & 1 deletion clang/test/CodeGenOpenCL/builtins-amdgcn-dl-insts-gfx11.cl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ typedef unsigned short __attribute__((ext_vector_type(2))) ushort2;
// CHECK: call float @llvm.amdgcn.fdot2(<2 x half> %v2hA, <2 x half> %v2hB, float %fC, i1 false)
// CHECK: call float @llvm.amdgcn.fdot2(<2 x half> %v2hA, <2 x half> %v2hB, float %fC, i1 true)
// CHECK: call half @llvm.amdgcn.fdot2.f16.f16(<2 x half> %v2hA, <2 x half> %v2hB, half %hC)
// CHECK: call i16 @llvm.amdgcn.fdot2.bf16.bf16(<2 x i16> %v2ssA, <2 x i16> %v2ssB, i16 %sC)
// CHECK: [[s1:%[0-9]+]] = bitcast <2 x i16> %v2ssA to <2 x bfloat>
// CHECK-NEXT: [[s2:%[0-9]+]] = bitcast <2 x i16> %v2ssB to <2 x bfloat>
// CHECK-NEXT: [[s3:%[0-9]+]] = bitcast i16 %sC to bfloat
// CHECK-NEXT: [[d:%[0-9]+]] = tail call bfloat @llvm.amdgcn.fdot2.bf16.bf16(<2 x bfloat> [[s1]], <2 x bfloat> [[s2]], bfloat [[s3]])
// CHECK: call float @llvm.amdgcn.fdot2.f32.bf16(<2 x i16> %v2ssA, <2 x i16> %v2ssB, float %fC, i1 false)
// CHECK: call float @llvm.amdgcn.fdot2.f32.bf16(<2 x i16> %v2ssA, <2 x i16> %v2ssB, float %fC, i1 true)
// CHECK: call i32 @llvm.amdgcn.udot4(i32 %uiA, i32 %uiB, i32 %uiC, i1 false)
Expand Down
8 changes: 4 additions & 4 deletions llvm/include/llvm/IR/IntrinsicsAMDGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -2819,11 +2819,11 @@ def int_amdgcn_fdot2_f16_f16 :
def int_amdgcn_fdot2_bf16_bf16 :
ClangBuiltin<"__builtin_amdgcn_fdot2_bf16_bf16">,
DefaultAttrsIntrinsic<
[llvm_i16_ty], // %r
[llvm_bfloat_ty], // %r
[
llvm_v2i16_ty, // %a
llvm_v2i16_ty, // %b
llvm_i16_ty // %c
llvm_v2bf16_ty, // %a
llvm_v2bf16_ty, // %b
llvm_bfloat_ty // %c
],
[IntrNoMem, IntrSpeculatable]
>;
Expand Down
92 changes: 92 additions & 0 deletions llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {

bool isSSrcF64() const { return isSCSrc_b64() || isLiteralImm(MVT::f64); }

bool isSSrc_bf16() const { return isSCSrcB16() || isLiteralImm(MVT::bf16); }

bool isSSrc_f16() const { return isSCSrcB16() || isLiteralImm(MVT::f16); }

bool isSSrcV2F16() const {
Expand Down Expand Up @@ -541,22 +543,40 @@ class AMDGPUOperand : public MCParsedAsmOperand {
return isRegOrInlineNoMods(AMDGPU::VS_64RegClassID, MVT::f64);
}

bool isVCSrcTBF16() const {
return isRegOrInlineNoMods(AMDGPU::VS_16RegClassID, MVT::bf16);
}

bool isVCSrcTF16() const {
return isRegOrInlineNoMods(AMDGPU::VS_16RegClassID, MVT::f16);
}

bool isVCSrcTBF16_Lo128() const {
return isRegOrInlineNoMods(AMDGPU::VS_16_Lo128RegClassID, MVT::bf16);
}

bool isVCSrcTF16_Lo128() const {
return isRegOrInlineNoMods(AMDGPU::VS_16_Lo128RegClassID, MVT::f16);
}

bool isVCSrcFake16BF16_Lo128() const {
return isRegOrInlineNoMods(AMDGPU::VS_32_Lo128RegClassID, MVT::bf16);
}

bool isVCSrcFake16F16_Lo128() const {
return isRegOrInlineNoMods(AMDGPU::VS_32_Lo128RegClassID, MVT::f16);
}

bool isVCSrc_bf16() const {
return isRegOrInlineNoMods(AMDGPU::VS_32RegClassID, MVT::bf16);
}

bool isVCSrc_f16() const {
return isRegOrInlineNoMods(AMDGPU::VS_32RegClassID, MVT::f16);
}

bool isVCSrc_v2bf16() const { return isVCSrc_bf16(); }

bool isVCSrc_v2f16() const { return isVCSrc_f16(); }

bool isVSrc_b32() const {
Expand Down Expand Up @@ -597,18 +617,34 @@ class AMDGPUOperand : public MCParsedAsmOperand {

bool isVSrc_f64() const { return isVCSrcF64() || isLiteralImm(MVT::f64); }

bool isVSrcT_bf16() const { return isVCSrcTBF16() || isLiteralImm(MVT::bf16); }

bool isVSrcT_f16() const { return isVCSrcTF16() || isLiteralImm(MVT::f16); }

bool isVSrcT_bf16_Lo128() const {
return isVCSrcTBF16_Lo128() || isLiteralImm(MVT::bf16);
}

bool isVSrcT_f16_Lo128() const {
return isVCSrcTF16_Lo128() || isLiteralImm(MVT::f16);
}

bool isVSrcFake16_bf16_Lo128() const {
return isVCSrcFake16BF16_Lo128() || isLiteralImm(MVT::bf16);
}

bool isVSrcFake16_f16_Lo128() const {
return isVCSrcFake16F16_Lo128() || isLiteralImm(MVT::f16);
}

bool isVSrc_bf16() const { return isVCSrc_bf16() || isLiteralImm(MVT::bf16); }

bool isVSrc_f16() const { return isVCSrc_f16() || isLiteralImm(MVT::f16); }

bool isVSrc_v2bf16() const {
return isVSrc_bf16() || isLiteralImm(MVT::v2bf16);
}

bool isVSrc_v2f16() const { return isVSrc_f16() || isLiteralImm(MVT::v2f16); }

bool isVISrcB32() const {
Expand All @@ -635,6 +671,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
return isVISrcF16() || isVISrcB32();
}

bool isVISrc_64_bf16() const {
return isRegOrInlineNoMods(AMDGPU::VReg_64RegClassID, MVT::bf16);
}

bool isVISrc_64_f16() const {
return isRegOrInlineNoMods(AMDGPU::VReg_64RegClassID, MVT::f16);
}
Expand Down Expand Up @@ -803,6 +843,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
return isAISrc_128F16() || isAISrc_128_b32();
}

bool isVISrc_128_bf16() const {
return isRegOrInlineNoMods(AMDGPU::VReg_128RegClassID, MVT::bf16);
}

bool isVISrc_128_f16() const {
return isRegOrInlineNoMods(AMDGPU::VReg_128RegClassID, MVT::f16);
}
Expand Down Expand Up @@ -1890,6 +1934,14 @@ static const fltSemantics *getOpFltSemantics(uint8_t OperandType) {
case AMDGPU::OPERAND_REG_IMM_V2FP16:
case AMDGPU::OPERAND_KIMM16:
return &APFloat::IEEEhalf();
case AMDGPU::OPERAND_REG_IMM_BF16:
case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
case AMDGPU::OPERAND_REG_INLINE_C_BF16:
case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
case AMDGPU::OPERAND_REG_IMM_V2BF16:
return &APFloat::BFloat();
default:
llvm_unreachable("unsupported fp type");
}
Expand Down Expand Up @@ -2186,17 +2238,24 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
case AMDGPU::OPERAND_REG_INLINE_AC_INT32:
case AMDGPU::OPERAND_REG_INLINE_AC_FP32:
case AMDGPU::OPERAND_REG_IMM_INT16:
case AMDGPU::OPERAND_REG_IMM_BF16:
case AMDGPU::OPERAND_REG_IMM_FP16:
case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
case AMDGPU::OPERAND_REG_INLINE_C_INT16:
case AMDGPU::OPERAND_REG_INLINE_C_BF16:
case AMDGPU::OPERAND_REG_INLINE_C_FP16:
case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
case AMDGPU::OPERAND_REG_IMM_V2INT16:
case AMDGPU::OPERAND_REG_IMM_V2BF16:
case AMDGPU::OPERAND_REG_IMM_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP32:
case AMDGPU::OPERAND_REG_IMM_V2FP32:
Expand Down Expand Up @@ -2240,6 +2299,7 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
case AMDGPU::OPERAND_REG_INLINE_AC_INT32:
case AMDGPU::OPERAND_REG_INLINE_AC_FP32:
case AMDGPU::OPERAND_REG_IMM_V2INT16:
case AMDGPU::OPERAND_REG_IMM_V2BF16:
case AMDGPU::OPERAND_REG_IMM_V2FP16:
case AMDGPU::OPERAND_REG_IMM_V2FP32:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP32:
Expand Down Expand Up @@ -2295,6 +2355,22 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
setImmKindLiteral();
return;

case AMDGPU::OPERAND_REG_IMM_BF16:
case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
case AMDGPU::OPERAND_REG_INLINE_C_BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
if (isSafeTruncation(Val, 16) &&
AMDGPU::isInlinableLiteralBF16(static_cast<int16_t>(Val),
AsmParser->hasInv2PiInlineImm())) {
Inst.addOperand(MCOperand::createImm(Val));
setImmKindConst();
return;
}

Inst.addOperand(MCOperand::createImm(Val & 0xffff));
setImmKindLiteral();
return;

case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
Expand All @@ -2306,6 +2382,17 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
Inst.addOperand(MCOperand::createImm(Val));
return;
}

case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16: {
assert(isSafeTruncation(Val, 16));
assert(AMDGPU::isInlinableLiteralBF16(static_cast<int16_t>(Val),
AsmParser->hasInv2PiInlineImm()));

Inst.addOperand(MCOperand::createImm(Val));
return;
}

case AMDGPU::OPERAND_KIMM32:
Inst.addOperand(MCOperand::createImm(Literal.getLoBits(32).getZExtValue()));
setImmKindMandatoryLiteral();
Expand Down Expand Up @@ -3429,6 +3516,11 @@ bool AMDGPUAsmParser::isInlineConstant(const MCInst &Inst,
OperandType == AMDGPU::OPERAND_REG_IMM_V2FP16)
return AMDGPU::isInlinableLiteralV2F16(Val);

if (OperandType == AMDGPU::OPERAND_REG_INLINE_C_V2BF16 ||
OperandType == AMDGPU::OPERAND_REG_INLINE_AC_V2BF16 ||
OperandType == AMDGPU::OPERAND_REG_IMM_V2BF16)
return AMDGPU::isInlinableLiteralV2BF16(Val);

return AMDGPU::isInlinableLiteral16(Val, hasInv2PiInlineImm());
}
default:
Expand Down
57 changes: 57 additions & 0 deletions llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,47 @@ static bool printImmediateFloat16(uint32_t Imm, const MCSubtargetInfo &STI,
return true;
}

static bool printImmediateBFloat16(uint32_t Imm, const MCSubtargetInfo &STI,
raw_ostream &O) {
if (Imm == 0x3F80)
O << "1.0";
else if (Imm == 0xBF80)
O << "-1.0";
else if (Imm == 0x3F00)
O << "0.5";
else if (Imm == 0xBF00)
O << "-0.5";
else if (Imm == 0x4000)
O << "2.0";
else if (Imm == 0xC000)
O << "-2.0";
else if (Imm == 0x4080)
O << "4.0";
else if (Imm == 0xC080)
O << "-4.0";
else if (Imm == 0x3E22 && STI.hasFeature(AMDGPU::FeatureInv2PiInlineImm))
O << "0.15915494";
else
return false;

return true;
}

void AMDGPUInstPrinter::printImmediateBF16(uint32_t Imm,
const MCSubtargetInfo &STI,
raw_ostream &O) {
int16_t SImm = static_cast<int16_t>(Imm);
if (isInlinableIntLiteral(SImm)) {
O << SImm;
return;
}

if (printImmediateBFloat16(static_cast<uint16_t>(Imm), STI, O))
return;

O << formatHex(static_cast<uint64_t>(Imm));
}

void AMDGPUInstPrinter::printImmediate16(uint32_t Imm,
const MCSubtargetInfo &STI,
raw_ostream &O) {
Expand Down Expand Up @@ -528,6 +569,13 @@ void AMDGPUInstPrinter::printImmediateV216(uint32_t Imm, uint8_t OpType,
printImmediateFloat16(static_cast<uint16_t>(Imm), STI, O))
return;
break;
case AMDGPU::OPERAND_REG_IMM_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
if (isUInt<16>(Imm) &&
printImmediateBFloat16(static_cast<uint16_t>(Imm), STI, O))
return;
break;
default:
llvm_unreachable("bad operand type");
}
Expand Down Expand Up @@ -799,11 +847,20 @@ void AMDGPUInstPrinter::printRegularOperand(const MCInst *MI, unsigned OpNo,
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
printImmediate16(Op.getImm(), STI, O);
break;
case AMDGPU::OPERAND_REG_INLINE_C_BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
case AMDGPU::OPERAND_REG_IMM_BF16:
case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
printImmediateBF16(Op.getImm(), STI, O);
break;
case AMDGPU::OPERAND_REG_IMM_V2INT16:
case AMDGPU::OPERAND_REG_IMM_V2BF16:
case AMDGPU::OPERAND_REG_IMM_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
printImmediateV216(Op.getImm(), OpTy, STI, O);
break;
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ class AMDGPUInstPrinter : public MCInstPrinter {
raw_ostream &O);
void printImmediate16(uint32_t Imm, const MCSubtargetInfo &STI,
raw_ostream &O);
void printImmediateBF16(uint32_t Imm, const MCSubtargetInfo &STI,
raw_ostream &O);
void printImmediateV216(uint32_t Imm, uint8_t OpType,
const MCSubtargetInfo &STI, raw_ostream &O);
bool printImmediateFloat32(uint32_t Imm, const MCSubtargetInfo &STI,
Expand Down
39 changes: 39 additions & 0 deletions llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,27 @@ static uint32_t getLit16Encoding(uint16_t Val, const MCSubtargetInfo &STI) {
return 255;
}

static uint32_t getLitBF16Encoding(uint16_t Val) {
uint16_t IntImm = getIntInlineImmEncoding(static_cast<int16_t>(Val));
if (IntImm != 0)
return IntImm;

// clang-format off
switch (Val) {
case 0x3F00: return 240; // 0.5
case 0xBF00: return 241; // -0.5
case 0x3F80: return 242; // 1.0
case 0xBF80: return 243; // -1.0
case 0x4000: return 244; // 2.0
case 0xC000: return 245; // -2.0
case 0x4080: return 246; // 4.0
case 0xC080: return 247; // -4.0
case 0x3E22: return 248; // 1.0 / (2.0 * pi)
default: return 255;
}
// clang-format on
}

static uint32_t getLit32Encoding(uint32_t Val, const MCSubtargetInfo &STI) {
uint32_t IntImm = getIntInlineImmEncoding(static_cast<int32_t>(Val));
if (IntImm != 0)
Expand Down Expand Up @@ -276,23 +297,41 @@ AMDGPUMCCodeEmitter::getLitEncoding(const MCOperand &MO,
case AMDGPU::OPERAND_REG_INLINE_C_INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
return getLit16IntEncoding(static_cast<uint16_t>(Imm), STI);

case AMDGPU::OPERAND_REG_IMM_FP16:
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
case AMDGPU::OPERAND_REG_INLINE_C_FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
// FIXME Is this correct? What do inline immediates do on SI for f16 src
// which does not have f16 support?
return getLit16Encoding(static_cast<uint16_t>(Imm), STI);

case AMDGPU::OPERAND_REG_IMM_BF16:
case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
case AMDGPU::OPERAND_REG_INLINE_C_BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
// We don't actually need to check Inv2Pi here because BF16 instructions can
// only be emitted for targets that already support the feature.
return getLitBF16Encoding(static_cast<uint16_t>(Imm));

case AMDGPU::OPERAND_REG_IMM_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
return AMDGPU::getInlineEncodingV2I16(static_cast<uint32_t>(Imm))
.value_or(255);

case AMDGPU::OPERAND_REG_IMM_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
return AMDGPU::getInlineEncodingV2F16(static_cast<uint32_t>(Imm))
.value_or(255);

case AMDGPU::OPERAND_REG_IMM_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
return AMDGPU::getInlineEncodingV2BF16(static_cast<uint32_t>(Imm))
.value_or(255);

case AMDGPU::OPERAND_KIMM32:
case AMDGPU::OPERAND_KIMM16:
return MO.getImm();
Expand Down
Loading

0 comments on commit d95e99e

Please sign in to comment.