[RFC][WIP][AMDGPU] Use bf16 instead of i16 for bfloat
Currently we generally use `i16` to represent `bf16` in these TableGen files. I'm
not sure of the reason behind that; my guess is that the `bf16` type was not yet
available when this support was first enabled. This patch uses `bf16` directly in
those TableGen files, aiming to fix #79369. A workaround for #79369 would be to
treat all `INT16` variants as `BFloat` in `getOpFltSemantics`, but that doesn't
look good IMHO.

Since I'm fairly new to the AMDGPU backend, I'd appreciate it if you could point
out anything I've misunderstood.
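For reference, here is a minimal standalone illustration (not part of the patch) of why bf16 cannot simply reuse the f16 literal path: the two formats assign different bit patterns to the same value, which is exactly what `getOpFltSemantics` has to distinguish.

```cpp
#include "llvm/ADT/APFloat.h"
#include <cassert>

int main() {
  using namespace llvm;
  // 1.0 encoded as IEEE half (5-bit exponent) vs. bfloat16 (8-bit exponent).
  APFloat Half(APFloat::IEEEhalf(), "1.0");
  APFloat BF16(APFloat::BFloat(), "1.0");
  assert(Half.bitcastToAPInt().getZExtValue() == 0x3C00);
  assert(BF16.bitcastToAPInt().getZExtValue() == 0x3F80);
  return 0;
}
```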
shiltian committed Feb 14, 2024
1 parent 09e9895 commit 7a517ee
Showing 17 changed files with 379 additions and 68 deletions.
4 changes: 0 additions & 4 deletions clang/lib/CodeGen/CGBuiltin.cpp
@@ -5912,8 +5912,6 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
}
}

-assert(ArgValue->getType()->canLosslesslyBitCastTo(PTy) &&
-       "Must be able to losslessly bit cast to param");
// Cast vector type (e.g., v256i32) to x86_amx; this only happens
// in amx intrinsics.
if (PTy->isX86_AMXTy())
@@ -5943,8 +5941,6 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
}
}

-assert(V->getType()->canLosslesslyBitCastTo(RetTy) &&
-       "Must be able to losslessly bit cast result type");
// Cast x86_amx to vector type (e.g., v256i32); this only happens
// in amx intrinsics.
if (V->getType()->isX86_AMXTy())
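Note: with the asserts gone, the builtin codegen path can emit the bitcasts that the retyped intrinsic implies. Below is a hedged sketch of that lowering — `emitFDot2BF16` and its value parameters are hypothetical names for illustration, not clang's actual code; the updated CHECK lines in the test that follows show the resulting IR.

```cpp
#include "llvm/IR/IRBuilder.h"
using namespace llvm;

// Hypothetical helper: bitcast i16-based storage to bfloat types before
// calling the retyped @llvm.amdgcn.fdot2.bf16.bf16 intrinsic.
static Value *emitFDot2BF16(IRBuilder<> &Builder, FunctionCallee FDot2BF16,
                            Value *V2ssA, Value *V2ssB, Value *SC) {
  Type *BF16Ty = Builder.getBFloatTy();
  Value *A = Builder.CreateBitCast(V2ssA, FixedVectorType::get(BF16Ty, 2));
  Value *B = Builder.CreateBitCast(V2ssB, FixedVectorType::get(BF16Ty, 2));
  Value *C = Builder.CreateBitCast(SC, BF16Ty);
  return Builder.CreateCall(FDot2BF16, {A, B, C}); // result is now bfloat
}
```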
5 changes: 4 additions & 1 deletion clang/test/CodeGenOpenCL/builtins-amdgcn-dl-insts-gfx11.cl
@@ -11,7 +11,10 @@ typedef unsigned short __attribute__((ext_vector_type(2))) ushort2;
// CHECK: call float @llvm.amdgcn.fdot2(<2 x half> %v2hA, <2 x half> %v2hB, float %fC, i1 false)
// CHECK: call float @llvm.amdgcn.fdot2(<2 x half> %v2hA, <2 x half> %v2hB, float %fC, i1 true)
// CHECK: call half @llvm.amdgcn.fdot2.f16.f16(<2 x half> %v2hA, <2 x half> %v2hB, half %hC)
-// CHECK: call i16 @llvm.amdgcn.fdot2.bf16.bf16(<2 x i16> %v2ssA, <2 x i16> %v2ssB, i16 %sC)
+// CHECK: [[s1:%[0-9]+]] = bitcast <2 x i16> %v2ssA to <2 x bfloat>
+// CHECK-NEXT: [[s2:%[0-9]+]] = bitcast <2 x i16> %v2ssB to <2 x bfloat>
+// CHECK-NEXT: [[s3:%[0-9]+]] = bitcast i16 %sC to bfloat
+// CHECK-NEXT: [[d:%[0-9]+]] = tail call bfloat @llvm.amdgcn.fdot2.bf16.bf16(<2 x bfloat> [[s1]], <2 x bfloat> [[s2]], bfloat [[s3]])
// CHECK: call float @llvm.amdgcn.fdot2.f32.bf16(<2 x i16> %v2ssA, <2 x i16> %v2ssB, float %fC, i1 false)
// CHECK: call float @llvm.amdgcn.fdot2.f32.bf16(<2 x i16> %v2ssA, <2 x i16> %v2ssB, float %fC, i1 true)
// CHECK: call i32 @llvm.amdgcn.udot4(i32 %uiA, i32 %uiB, i32 %uiC, i1 false)
8 changes: 4 additions & 4 deletions llvm/include/llvm/IR/IntrinsicsAMDGPU.td
@@ -2819,11 +2819,11 @@ def int_amdgcn_fdot2_f16_f16 :
def int_amdgcn_fdot2_bf16_bf16 :
ClangBuiltin<"__builtin_amdgcn_fdot2_bf16_bf16">,
DefaultAttrsIntrinsic<
-  [llvm_i16_ty],    // %r
+  [llvm_bfloat_ty], // %r
[
-    llvm_v2i16_ty,  // %a
-    llvm_v2i16_ty,  // %b
-    llvm_i16_ty     // %c
+    llvm_v2bf16_ty, // %a
+    llvm_v2bf16_ty, // %b
+    llvm_bfloat_ty  // %c
],
[IntrNoMem, IntrSpeculatable]
>;
92 changes: 92 additions & 0 deletions llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -475,6 +475,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {

bool isSSrcF64() const { return isSCSrc_b64() || isLiteralImm(MVT::f64); }

bool isSSrc_bf16() const { return isSCSrcB16() || isLiteralImm(MVT::bf16); }

bool isSSrc_f16() const { return isSCSrcB16() || isLiteralImm(MVT::f16); }

bool isSSrcV2F16() const {
@@ -541,22 +543,40 @@ class AMDGPUOperand : public MCParsedAsmOperand {
return isRegOrInlineNoMods(AMDGPU::VS_64RegClassID, MVT::f64);
}

bool isVCSrcTBF16() const {
return isRegOrInlineNoMods(AMDGPU::VS_16RegClassID, MVT::bf16);
}

bool isVCSrcTF16() const {
return isRegOrInlineNoMods(AMDGPU::VS_16RegClassID, MVT::f16);
}

bool isVCSrcTBF16_Lo128() const {
return isRegOrInlineNoMods(AMDGPU::VS_16_Lo128RegClassID, MVT::bf16);
}

bool isVCSrcTF16_Lo128() const {
return isRegOrInlineNoMods(AMDGPU::VS_16_Lo128RegClassID, MVT::f16);
}

bool isVCSrcFake16BF16_Lo128() const {
return isRegOrInlineNoMods(AMDGPU::VS_32_Lo128RegClassID, MVT::bf16);
}

bool isVCSrcFake16F16_Lo128() const {
return isRegOrInlineNoMods(AMDGPU::VS_32_Lo128RegClassID, MVT::f16);
}

bool isVCSrc_bf16() const {
return isRegOrInlineNoMods(AMDGPU::VS_32RegClassID, MVT::bf16);
}

bool isVCSrc_f16() const {
return isRegOrInlineNoMods(AMDGPU::VS_32RegClassID, MVT::f16);
}

bool isVCSrc_v2bf16() const { return isVCSrc_bf16(); }

bool isVCSrc_v2f16() const { return isVCSrc_f16(); }

bool isVSrc_b32() const {
@@ -597,18 +617,34 @@ class AMDGPUOperand : public MCParsedAsmOperand {

bool isVSrc_f64() const { return isVCSrcF64() || isLiteralImm(MVT::f64); }

bool isVSrcT_bf16() const { return isVCSrcTBF16() || isLiteralImm(MVT::bf16); }

bool isVSrcT_f16() const { return isVCSrcTF16() || isLiteralImm(MVT::f16); }

bool isVSrcT_bf16_Lo128() const {
return isVCSrcTBF16_Lo128() || isLiteralImm(MVT::bf16);
}

bool isVSrcT_f16_Lo128() const {
return isVCSrcTF16_Lo128() || isLiteralImm(MVT::f16);
}

bool isVSrcFake16_bf16_Lo128() const {
return isVCSrcFake16BF16_Lo128() || isLiteralImm(MVT::bf16);
}

bool isVSrcFake16_f16_Lo128() const {
return isVCSrcFake16F16_Lo128() || isLiteralImm(MVT::f16);
}

bool isVSrc_bf16() const { return isVCSrc_bf16() || isLiteralImm(MVT::bf16); }

bool isVSrc_f16() const { return isVCSrc_f16() || isLiteralImm(MVT::f16); }

bool isVSrc_v2bf16() const {
return isVSrc_bf16() || isLiteralImm(MVT::v2bf16);
}

bool isVSrc_v2f16() const { return isVSrc_f16() || isLiteralImm(MVT::v2f16); }

bool isVISrcB32() const {
@@ -635,6 +671,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
return isVISrcF16() || isVISrcB32();
}

bool isVISrc_64_bf16() const {
return isRegOrInlineNoMods(AMDGPU::VReg_64RegClassID, MVT::bf16);
}

bool isVISrc_64_f16() const {
return isRegOrInlineNoMods(AMDGPU::VReg_64RegClassID, MVT::f16);
}
@@ -803,6 +843,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
return isAISrc_128F16() || isAISrc_128_b32();
}

bool isVISrc_128_bf16() const {
return isRegOrInlineNoMods(AMDGPU::VReg_128RegClassID, MVT::bf16);
}

bool isVISrc_128_f16() const {
return isRegOrInlineNoMods(AMDGPU::VReg_128RegClassID, MVT::f16);
}
@@ -1890,6 +1934,14 @@ static const fltSemantics *getOpFltSemantics(uint8_t OperandType) {
case AMDGPU::OPERAND_REG_IMM_V2FP16:
case AMDGPU::OPERAND_KIMM16:
return &APFloat::IEEEhalf();
case AMDGPU::OPERAND_REG_IMM_BF16:
case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
case AMDGPU::OPERAND_REG_INLINE_C_BF16:
case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
case AMDGPU::OPERAND_REG_IMM_V2BF16:
return &APFloat::BFloat();
default:
llvm_unreachable("unsupported fp type");
}
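Note: the practical effect of the new `APFloat::BFloat()` branch is that literals destined for BF16 operands are parsed with bfloat semantics instead of half semantics. A small sketch of the conversion this enables — `encodeBF16Literal` is an illustrative name, not a function in the patch:

```cpp
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/StringRef.h"
#include <cstdint>

// Parse a decimal string with the semantics getOpFltSemantics now selects
// for BF16 operands and return the 16-bit encoding of the literal.
static uint16_t encodeBF16Literal(llvm::StringRef Str) {
  // e.g. "1.0" -> 0x3F80, "0.5" -> 0x3F00, "-2.0" -> 0xC000
  llvm::APFloat Value(llvm::APFloat::BFloat(), Str);
  return static_cast<uint16_t>(Value.bitcastToAPInt().getZExtValue());
}
```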
@@ -2186,17 +2238,24 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
case AMDGPU::OPERAND_REG_INLINE_AC_INT32:
case AMDGPU::OPERAND_REG_INLINE_AC_FP32:
case AMDGPU::OPERAND_REG_IMM_INT16:
case AMDGPU::OPERAND_REG_IMM_BF16:
case AMDGPU::OPERAND_REG_IMM_FP16:
case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
case AMDGPU::OPERAND_REG_INLINE_C_INT16:
case AMDGPU::OPERAND_REG_INLINE_C_BF16:
case AMDGPU::OPERAND_REG_INLINE_C_FP16:
case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
case AMDGPU::OPERAND_REG_IMM_V2INT16:
case AMDGPU::OPERAND_REG_IMM_V2BF16:
case AMDGPU::OPERAND_REG_IMM_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP32:
case AMDGPU::OPERAND_REG_IMM_V2FP32:
@@ -2240,6 +2299,7 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
case AMDGPU::OPERAND_REG_INLINE_AC_INT32:
case AMDGPU::OPERAND_REG_INLINE_AC_FP32:
case AMDGPU::OPERAND_REG_IMM_V2INT16:
case AMDGPU::OPERAND_REG_IMM_V2BF16:
case AMDGPU::OPERAND_REG_IMM_V2FP16:
case AMDGPU::OPERAND_REG_IMM_V2FP32:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP32:
@@ -2295,6 +2355,22 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
setImmKindLiteral();
return;

case AMDGPU::OPERAND_REG_IMM_BF16:
case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
case AMDGPU::OPERAND_REG_INLINE_C_BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
if (isSafeTruncation(Val, 16) &&
AMDGPU::isInlinableLiteralBF16(static_cast<int16_t>(Val),
AsmParser->hasInv2PiInlineImm())) {
Inst.addOperand(MCOperand::createImm(Val));
setImmKindConst();
return;
}

Inst.addOperand(MCOperand::createImm(Val & 0xffff));
setImmKindLiteral();
return;

case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
@@ -2306,6 +2382,17 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
Inst.addOperand(MCOperand::createImm(Val));
return;
}

case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16: {
assert(isSafeTruncation(Val, 16));
assert(AMDGPU::isInlinableLiteralBF16(static_cast<int16_t>(Val),
AsmParser->hasInv2PiInlineImm()));

Inst.addOperand(MCOperand::createImm(Val));
return;
}

case AMDGPU::OPERAND_KIMM32:
Inst.addOperand(MCOperand::createImm(Literal.getLoBits(32).getZExtValue()));
setImmKindMandatoryLiteral();
@@ -3429,6 +3516,11 @@ bool AMDGPUAsmParser::isInlineConstant(const MCInst &Inst,
OperandType == AMDGPU::OPERAND_REG_IMM_V2FP16)
return AMDGPU::isInlinableLiteralV2F16(Val);

if (OperandType == AMDGPU::OPERAND_REG_INLINE_C_V2BF16 ||
OperandType == AMDGPU::OPERAND_REG_INLINE_AC_V2BF16 ||
OperandType == AMDGPU::OPERAND_REG_IMM_V2BF16)
return AMDGPU::isInlinableLiteralV2BF16(Val);

return AMDGPU::isInlinableLiteral16(Val, hasInv2PiInlineImm());
}
default:
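Note: a hedged sketch of the shape of the `isInlinableLiteralBF16` predicate the parser now calls into — the real helper lives in AMDGPUBaseInfo and may differ in detail, but the value table matches `printImmediateBFloat16` in the printer changes below:

```cpp
#include <cstdint>

// Sketch only: inline constants are the small integers plus a fixed table
// of bf16 float encodings (1/(2*pi) only with the inv-2pi feature).
static bool isInlinableLiteralBF16Sketch(int16_t Literal, bool HasInv2Pi) {
  if (Literal >= -16 && Literal <= 64)
    return true; // inline integer range
  uint16_t Val = static_cast<uint16_t>(Literal);
  return Val == 0x3F00 /* 0.5*/ || Val == 0xBF00 /*-0.5*/ ||
         Val == 0x3F80 /* 1.0*/ || Val == 0xBF80 /*-1.0*/ ||
         Val == 0x4000 /* 2.0*/ || Val == 0xC000 /*-2.0*/ ||
         Val == 0x4080 /* 4.0*/ || Val == 0xC080 /*-4.0*/ ||
         (Val == 0x3E22 /*1/(2*pi)*/ && HasInv2Pi);
}
```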
57 changes: 57 additions & 0 deletions llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
@@ -488,6 +488,47 @@ static bool printImmediateFloat16(uint32_t Imm, const MCSubtargetInfo &STI,
return true;
}

static bool printImmediateBFloat16(uint32_t Imm, const MCSubtargetInfo &STI,
raw_ostream &O) {
if (Imm == 0x3F80)
O << "1.0";
else if (Imm == 0xBF80)
O << "-1.0";
else if (Imm == 0x3F00)
O << "0.5";
else if (Imm == 0xBF00)
O << "-0.5";
else if (Imm == 0x4000)
O << "2.0";
else if (Imm == 0xC000)
O << "-2.0";
else if (Imm == 0x4080)
O << "4.0";
else if (Imm == 0xC080)
O << "-4.0";
else if (Imm == 0x3E22 && STI.hasFeature(AMDGPU::FeatureInv2PiInlineImm))
O << "0.15915494";
else
return false;

return true;
}

void AMDGPUInstPrinter::printImmediateBF16(uint32_t Imm,
const MCSubtargetInfo &STI,
raw_ostream &O) {
int16_t SImm = static_cast<int16_t>(Imm);
if (isInlinableIntLiteral(SImm)) {
O << SImm;
return;
}

if (printImmediateBFloat16(static_cast<uint16_t>(Imm), STI, O))
return;

O << formatHex(static_cast<uint64_t>(Imm));
}

void AMDGPUInstPrinter::printImmediate16(uint32_t Imm,
const MCSubtargetInfo &STI,
raw_ostream &O) {
@@ -528,6 +569,13 @@ void AMDGPUInstPrinter::printImmediateV216(uint32_t Imm, uint8_t OpType,
printImmediateFloat16(static_cast<uint16_t>(Imm), STI, O))
return;
break;
case AMDGPU::OPERAND_REG_IMM_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
if (isUInt<16>(Imm) &&
printImmediateBFloat16(static_cast<uint16_t>(Imm), STI, O))
return;
break;
default:
llvm_unreachable("bad operand type");
}
@@ -799,11 +847,20 @@ void AMDGPUInstPrinter::printRegularOperand(const MCInst *MI, unsigned OpNo,
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
printImmediate16(Op.getImm(), STI, O);
break;
case AMDGPU::OPERAND_REG_INLINE_C_BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
case AMDGPU::OPERAND_REG_IMM_BF16:
case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
printImmediateBF16(Op.getImm(), STI, O);
break;
case AMDGPU::OPERAND_REG_IMM_V2INT16:
case AMDGPU::OPERAND_REG_IMM_V2BF16:
case AMDGPU::OPERAND_REG_IMM_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
printImmediateV216(Op.getImm(), OpTy, STI, O);
break;
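Note: the hex constants in `printImmediateBFloat16` above follow from bf16 being the high half of an IEEE f32 bit pattern; truncating the low 16 mantissa bits reproduces every entry in the table. A quick standalone check (not part of the patch):

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

static uint16_t f32ToBF16Truncate(float F) {
  uint32_t Bits;
  std::memcpy(&Bits, &F, sizeof(Bits)); // type-pun safely via memcpy
  return static_cast<uint16_t>(Bits >> 16);
}

int main() {
  assert(f32ToBF16Truncate(1.0f) == 0x3F80);        // 0x3F800000 >> 16
  assert(f32ToBF16Truncate(-0.5f) == 0xBF00);
  assert(f32ToBF16Truncate(4.0f) == 0x4080);
  assert(f32ToBF16Truncate(0.15915494f) == 0x3E22); // ~1/(2*pi)
  return 0;
}
```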
2 changes: 2 additions & 0 deletions llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h
@@ -88,6 +88,8 @@ class AMDGPUInstPrinter : public MCInstPrinter {
raw_ostream &O);
void printImmediate16(uint32_t Imm, const MCSubtargetInfo &STI,
raw_ostream &O);
void printImmediateBF16(uint32_t Imm, const MCSubtargetInfo &STI,
raw_ostream &O);
void printImmediateV216(uint32_t Imm, uint8_t OpType,
const MCSubtargetInfo &STI, raw_ostream &O);
bool printImmediateFloat32(uint32_t Imm, const MCSubtargetInfo &STI,
