[SelectionDAG] Add expansion for llvm.convert.from.arbitrary.fp #179318
The expansion converts an arbitrary-precision FP value, represented as an integer, using the following algorithm:
1. Extract the sign, exponent, and mantissa bit fields via masks and shifts.
2. Classify the input (zero, denormal, normal, Inf, NaN) using the exponent and mantissa fields.
3. Normal path: adjust the exponent bias and left-shift the mantissa to fit the wider destination format.
4. Denormal path: normalize by finding the MSB position of the mantissa (via count-leading-zeros), computing the correct exponent from that position, stripping the implicit leading 1, and shifting the fraction into the destination mantissa field.
5. Assemble the destination IEEE bit pattern (sign | exponent | mantissa) and select among the normal, denormal, and special-value results.
Currently only conversions from OCP floats are covered; in LLVM terms these are: Float8E5M2, Float8E4M3FN, Float6E3M2FN, Float6E2M3FN, Float4E2M1FN.
OCP spec: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
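For readers who want to follow the five steps outside of SelectionDAG, here is a minimal scalar sketch in plain C++ that targets binary32 only. It is not the code from the patch: `SmallFPFormat`, `NonFinite`, and `decodeToFloat` are hypothetical names introduced for illustration, and the MSB search uses a loop where the expansion uses CTLZ.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical descriptor for a small source format; not an LLVM type.
enum class NonFinite { IEEE754, NanOnlyAllOnes, FiniteOnly };
struct SmallFPFormat {
  unsigned ExpBits;
  unsigned MantBits;
  int Bias;
  NonFinite NF;
};

// Field widths and biases of three of the OCP formats named above.
const SmallFPFormat E5M2{5, 2, 15, NonFinite::IEEE754};
const SmallFPFormat E4M3FN{4, 3, 7, NonFinite::NanOnlyAllOnes};
const SmallFPFormat E2M1FN{2, 1, 1, NonFinite::FiniteOnly};

float decodeToFloat(uint32_t Src, const SmallFPFormat &F) {
  const unsigned DstMant = 23; // binary32 mantissa bits
  const int DstBias = 127;     // binary32 exponent bias
  const uint32_t DstExpAllOnes = 0xFFu;
  assert(F.MantBits >= 1 && F.MantBits <= DstMant && "unsupported format");

  // 1. Extract the sign, exponent, and mantissa fields.
  const unsigned SrcBits = 1 + F.ExpBits + F.MantBits;
  const uint32_t MantMask = (1u << F.MantBits) - 1;
  const uint32_t ExpMask = (1u << F.ExpBits) - 1;
  const uint32_t Mant = Src & MantMask;
  const uint32_t Exp = (Src >> F.MantBits) & ExpMask;
  const uint32_t Sign = ((Src >> (SrcBits - 1)) & 1u) << 31;

  // 2. Classify according to the format's non-finite behavior.
  const bool ExpAllOnes = Exp == ExpMask;
  const bool IsNaN =
      (F.NF == NonFinite::IEEE754 && ExpAllOnes && Mant != 0) ||
      (F.NF == NonFinite::NanOnlyAllOnes && ExpAllOnes && Mant == MantMask);
  const bool IsInf = F.NF == NonFinite::IEEE754 && ExpAllOnes && Mant == 0;

  uint32_t Bits;
  if (IsNaN) {
    // Canonical quiet NaN with the sign cleared.
    Bits = (DstExpAllOnes << DstMant) | (1u << (DstMant - 1));
  } else if (IsInf) {
    Bits = Sign | (DstExpAllOnes << DstMant);
  } else if (Exp == 0 && Mant == 0) {
    Bits = Sign; // signed zero
  } else if (Exp == 0) {
    // 4. Denormal path: find the MSB, rebias, strip the leading 1.
    unsigned MSB = 0;
    while ((Mant >> (MSB + 1)) != 0)
      ++MSB;
    const uint32_t Frac = Mant ^ (1u << MSB);
    const uint32_t DstExp = static_cast<uint32_t>(
        static_cast<int>(MSB) + 1 - F.Bias - static_cast<int>(F.MantBits) +
        DstBias);
    Bits = Sign | (DstExp << DstMant) | (Frac << (DstMant - MSB));
  } else {
    // 3. Normal path: rebias the exponent, widen the mantissa.
    const uint32_t DstExp =
        static_cast<uint32_t>(static_cast<int>(Exp) + DstBias - F.Bias);
    Bits = Sign | (DstExp << DstMant) | (Mant << (DstMant - F.MantBits));
  }

  // 5. Reinterpret the assembled bit pattern as a float.
  float Out;
  std::memcpy(&Out, &Bits, sizeof(Out));
  return Out;
}
```

For example, `decodeToFloat(0x3C, E5M2)` takes the normal path and `decodeToFloat(0x01, E5M2)` takes the denormal path. The sketch branches for readability, whereas the expansion computes every candidate result and picks one with selects.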
As a sanity check for the patch, I've tested E2E FP8 conversions at runtime on X86 using a toy builtin that serves as a wrapper around the intrinsic (MrSidims@75916aa). A test that compares the results of the expansion against a C implementation can be found here: https://gist.github.com/MrSidims/70a7e645cf14b15b9e8ff064c52319e3
UPD Feb 10th: a quick draft of the Clang builtin (which is not going to be published, as it's incomplete and AFAIK not really desired by anybody) and the library-like reference implementation in C were done with the help of AI.
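The kind of cross-check described above can be approximated without the toy builtin. The sketch below is not the linked gist: it compares a hand-written bit-manipulation decoder for Float8E5M2 against an arithmetic reference over all 256 encodings, and the names `refE5M2`, `bitsE5M2`, and `countMismatchesE5M2` are hypothetical.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

// Arithmetic reference for Float8E5M2 (1 sign, 5 exponent, 2 mantissa bits,
// bias 15), written with ldexp instead of bit manipulation.
float refE5M2(uint8_t B) {
  const int Sign = B >> 7;
  const int Exp = (B >> 2) & 0x1F;
  const int Mant = B & 0x3;
  float Mag;
  if (Exp == 0x1F)
    Mag = Mant ? NAN : INFINITY;
  else if (Exp == 0)
    Mag = std::ldexp(static_cast<float>(Mant), -16); // Mant * 2^(1 - 15 - 2)
  else
    Mag = std::ldexp(1.0f + Mant / 4.0f, Exp - 15);
  return Sign ? -Mag : Mag;
}

// Bit-manipulation decoder in the spirit of the expansion, for binary32.
float bitsE5M2(uint8_t B) {
  const uint32_t Sign = static_cast<uint32_t>(B >> 7) << 31;
  const uint32_t Exp = (B >> 2) & 0x1Fu;
  const uint32_t Mant = B & 0x3u;
  uint32_t Out;
  if (Exp == 0x1F) {
    Out = Mant ? 0x7FC00000u : (Sign | 0x7F800000u); // quiet NaN or Inf
  } else if (Exp == 0 && Mant == 0) {
    Out = Sign; // signed zero
  } else if (Exp == 0) {
    const unsigned MSB = Mant >> 1; // Mant is 1, 2 or 3: MSB index 0, 1, 1
    const uint32_t Frac = Mant ^ (1u << MSB);
    // Exponent field: MSB + 1 - 15 - 2 + 127 = MSB + 111.
    Out = Sign | ((MSB + 111u) << 23) | (Frac << (23 - MSB));
  } else {
    Out = Sign | ((Exp + 112u) << 23) | (Mant << 21); // 112 = 127 - 15
  }
  float F;
  std::memcpy(&F, &Out, sizeof(F));
  return F;
}

// Counts encodings on which the two decoders disagree. NaNs compare equal to
// each other here; everything else must match bit for bit, including -0.0.
int countMismatchesE5M2() {
  int Mismatches = 0;
  for (int I = 0; I < 256; ++I) {
    const float A = refE5M2(static_cast<uint8_t>(I));
    const float B = bitsE5M2(static_cast<uint8_t>(I));
    if (std::isnan(A) || std::isnan(B)) {
      Mismatches += !(std::isnan(A) && std::isnan(B));
      continue;
    }
    uint32_t ABits, BBits;
    std::memcpy(&ABits, &A, sizeof(ABits));
    std::memcpy(&BBits, &B, sizeof(BBits));
    Mismatches += (ABits != BBits);
  }
  return Mismatches;
}
```

An exhaustive sweep is cheap for 8-bit and smaller formats, which is why comparing every encoding is a reasonable strategy here.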
                     N->getOperand(1));
}

SDValue DAGTypeLegalizer::ScalarizeVecRes_CONVERT_FROM_ARBITRARY_FP(SDNode *N) {
It's basically the same as ScalarizeVecRes_FP_ROUND. To my taste, despite the duplicated code, it's better to leave it as is: two separate functions make it easier to search the repository for the exact intrinsic expansion, and merging them might worsen that experience.
@llvm/pr-subscribers-llvm-support @llvm/pr-subscribers-llvm-adt
Author: Dmitry Sidorov (MrSidims)
Changes: The expansion converts an arbitrary-precision FP value, represented as an integer, using the algorithm given in the PR description.
Currently only conversions from OCP floats are covered; in LLVM terms these are: Float8E5M2, Float8E4M3FN, Float6E3M2FN, Float6E2M3FN, Float4E2M1FN.
OCP spec: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
Patch is 103.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/179318.diff
14 Files Affected:
diff --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h
index 44fa3919962c4..287c7bf51ede6 100644
--- a/llvm/include/llvm/ADT/APFloat.h
+++ b/llvm/include/llvm/ADT/APFloat.h
@@ -412,6 +412,10 @@ class APFloatBase {
/// format interpretation for llvm.convert.to.arbitrary.fp and
/// llvm.convert.from.arbitrary.fp intrinsics.
LLVM_ABI static bool isValidArbitraryFPFormat(StringRef Format);
+
+ /// Returns the fltSemantics for a given arbitrary FP format string,
+ /// or nullptr if invalid.
+ LLVM_ABI static const fltSemantics *getArbitraryFPSemantics(StringRef Format);
};
namespace detail {
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 2ebd2641944f5..7818b54c7cccb 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1014,6 +1014,12 @@ enum NodeType {
STRICT_BF16_TO_FP,
STRICT_FP_TO_BF16,
+ /// CONVERT_FROM_ARBITRARY_FP - This operator converts from an arbitrary
+ /// floating-point represented as an integer to a native FP type.
+ /// The first operand is the integer containing the source FP bits.
+ /// The second operand is a constant indicating the source FP semantics.
+ CONVERT_FROM_ARBITRARY_FP,
+
/// Perform various unary floating-point operations inspired by libm. For
/// FPOWI, the result is undefined if the integer operand doesn't fit into
/// sizeof(int).
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index d9a2409b35e4c..c708a429bf79d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -3495,6 +3495,254 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
Results.push_back(Op);
break;
}
+ case ISD::CONVERT_FROM_ARBITRARY_FP: {
+ // Expand conversion from arbitrary FP format stored in an integer to a
+ // native IEEE float type using integer bit manipulation.
+ //
+ // TODO: currently only conversions from FP4, FP6 and FP8 formats from OCP
+ // specification are expanded. Remaining arbitrary FP types: Float8E4M3,
+ // Float8E3M4, Float8E5M2FNUZ, Float8E4M3FNUZ, Float8E4M3B11FNUZ,
+ // Float8E8M0FNU.
+ EVT DstVT = Node->getValueType(0);
+
+ // For vector types, unroll into scalar operations.
+ if (DstVT.isVector()) {
+ Results.push_back(DAG.UnrollVectorOp(Node));
+ break;
+ }
+
+ SDValue IntVal = Node->getOperand(0);
+ const uint64_t SemEnum = Node->getConstantOperandVal(1);
+ const auto Sem = static_cast<APFloatBase::Semantics>(SemEnum);
+
+ // Supported source formats.
+ switch (Sem) {
+ case APFloatBase::S_Float8E5M2:
+ case APFloatBase::S_Float8E4M3FN:
+ case APFloatBase::S_Float6E3M2FN:
+ case APFloatBase::S_Float6E2M3FN:
+ case APFloatBase::S_Float4E2M1FN:
+ break;
+ default:
+ report_fatal_error("CONVERT_FROM_ARBITRARY_FP: unsupported source "
+ "format (semantics enum " +
+ Twine(SemEnum) + ")");
+ }
+
+ const fltSemantics &SrcSem = APFloatBase::EnumToSemantics(Sem);
+
+ const unsigned SrcBits = APFloat::getSizeInBits(SrcSem);
+ const unsigned SrcPrecision = APFloat::semanticsPrecision(SrcSem);
+ const unsigned SrcMant = SrcPrecision - 1;
+ const unsigned SrcExp = SrcBits - SrcMant - 1;
+ const int SrcBias = 1 - APFloat::semanticsMinExponent(SrcSem);
+
+ const fltNonfiniteBehavior NFBehavior = SrcSem.nonFiniteBehavior;
+ const fltNanEncoding NanEnc = SrcSem.nanEncoding;
+
+ // Destination format parameters.
+ const fltSemantics *DstSem;
+ if (DstVT == MVT::f16)
+ DstSem = &APFloat::IEEEhalf();
+ else if (DstVT == MVT::bf16)
+ DstSem = &APFloat::BFloat();
+ else if (DstVT == MVT::f32)
+ DstSem = &APFloat::IEEEsingle();
+ else if (DstVT == MVT::f64)
+ DstSem = &APFloat::IEEEdouble();
+ else
+ llvm_unreachable("Unsupported destination float type");
+
+ const unsigned DstBits = APFloat::getSizeInBits(*DstSem);
+ const unsigned DstMant = APFloat::semanticsPrecision(*DstSem) - 1;
+ const unsigned DstExpBits = DstBits - DstMant - 1;
+ const int DstMinExp = APFloat::semanticsMinExponent(*DstSem);
+ const int DstBias = 1 - DstMinExp;
+ const uint64_t DstExpAllOnes = (1ULL << DstExpBits) - 1;
+
+ // Work in an integer type matching the destination float width.
+ // Use zero-extend to preserve the raw bit-pattern.
+ EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), DstBits);
+ SDValue Src = DAG.getZExtOrTrunc(IntVal, dl, IntVT);
+
+ EVT SetCCVT =
+ TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), IntVT);
+ SDValue Zero = DAG.getConstant(0, dl, IntVT);
+ SDValue One = DAG.getConstant(1, dl, IntVT);
+
+ // Extract bit fields.
+ const uint64_t MantMask = (SrcMant > 0) ? ((1ULL << SrcMant) - 1) : 0;
+ const uint64_t ExpMask = (1ULL << SrcExp) - 1;
+
+ SDValue MantField = DAG.getNode(ISD::AND, dl, IntVT, Src,
+ DAG.getConstant(MantMask, dl, IntVT));
+
+ SDValue ExpField =
+ DAG.getNode(ISD::AND, dl, IntVT,
+ DAG.getNode(ISD::SRL, dl, IntVT, Src,
+ DAG.getShiftAmountConstant(SrcMant, IntVT, dl)),
+ DAG.getConstant(ExpMask, dl, IntVT));
+
+ SDValue SignBit =
+ DAG.getNode(ISD::SRL, dl, IntVT, Src,
+ DAG.getShiftAmountConstant(SrcBits - 1, IntVT, dl));
+
+ // Precompute sign shifted to MSB of destination.
+ SDValue SignShifted =
+ DAG.getNode(ISD::SHL, dl, IntVT, SignBit,
+ DAG.getShiftAmountConstant(DstBits - 1, IntVT, dl));
+
+ // Classify the input value based on compile-time format properties.
+ SDValue ExpAllOnes = DAG.getConstant(ExpMask, dl, IntVT);
+ SDValue IsExpAllOnes =
+ DAG.getSetCC(dl, SetCCVT, ExpField, ExpAllOnes, ISD::SETEQ);
+ SDValue IsExpZero = DAG.getSetCC(dl, SetCCVT, ExpField, Zero, ISD::SETEQ);
+ SDValue IsMantZero = DAG.getSetCC(dl, SetCCVT, MantField, Zero, ISD::SETEQ);
+ SDValue IsMantNonZero =
+ DAG.getSetCC(dl, SetCCVT, MantField, Zero, ISD::SETNE);
+
+ // NaN detection.
+ SDValue IsNaN;
+ if (NFBehavior == fltNonfiniteBehavior::FiniteOnly) {
+ // FiniteOnly formats (E2M1FN, E3M2FN, E2M3FN) never produce NaN.
+ IsNaN = DAG.getBoolConstant(false, dl, SetCCVT, IntVT);
+ } else if (NFBehavior == fltNonfiniteBehavior::IEEE754) {
+ // E5M2 produces NaN when exp == all-ones AND mantissa != 0.
+ IsNaN = DAG.getNode(ISD::AND, dl, SetCCVT, IsExpAllOnes, IsMantNonZero);
+ } else {
+ // NanOnly + AllOnes (E4M3FN): NaN when all exp and mantissa bits are 1.
+ assert(NanEnc == fltNanEncoding::AllOnes);
+ SDValue MantAllOnes = DAG.getConstant(MantMask, dl, IntVT);
+ SDValue IsMantAllOnes =
+ DAG.getSetCC(dl, SetCCVT, MantField, MantAllOnes, ISD::SETEQ);
+ IsNaN = DAG.getNode(ISD::AND, dl, SetCCVT, IsExpAllOnes, IsMantAllOnes);
+ }
+
+ // Inf detection.
+ SDValue IsInf;
+ if (NFBehavior == fltNonfiniteBehavior::IEEE754) {
+ // E5M2: Inf when exp == all-ones AND mantissa == 0.
+ IsInf = DAG.getNode(ISD::AND, dl, SetCCVT, IsExpAllOnes, IsMantZero);
+ } else {
+ // NanOnly and FiniteOnly formats have no Inf representation.
+ IsInf = DAG.getBoolConstant(false, dl, SetCCVT, IntVT);
+ }
+
+ // Zero detection.
+ SDValue IsZero = DAG.getNode(ISD::AND, dl, SetCCVT, IsExpZero, IsMantZero);
+
+ // Denorm detection: exp == 0 AND mant != 0.
+ SDValue IsDenorm =
+ DAG.getNode(ISD::AND, dl, SetCCVT, IsExpZero, IsMantNonZero);
+
+ // Normal value conversion.
+ // dst_exp = exp_field + (DstBias - SrcBias)
+ // dst_mant = mant << (DstMant - SrcMant)
+ const int BiasAdjust = DstBias - SrcBias;
+ SDValue NormDstExp = DAG.getNode(
+ ISD::ADD, dl, IntVT, ExpField,
+ DAG.getConstant(APInt(DstBits, BiasAdjust, true), dl, IntVT));
+
+ SDValue NormDstMant;
+ if (DstMant > SrcMant)
+ NormDstMant =
+ DAG.getNode(ISD::SHL, dl, IntVT, MantField,
+ DAG.getShiftAmountConstant(DstMant - SrcMant, IntVT, dl));
+ else
+ NormDstMant = MantField;
+
+ // Assemble normal result.
+ SDValue NormExpShifted =
+ DAG.getNode(ISD::SHL, dl, IntVT, NormDstExp,
+ DAG.getShiftAmountConstant(DstMant, IntVT, dl));
+ SDValue NormResult = DAG.getNode(
+ ISD::OR, dl, IntVT,
+ DAG.getNode(ISD::OR, dl, IntVT, SignShifted, NormExpShifted),
+ NormDstMant);
+
+ // Denormal value conversion.
+ // For a denormal source (exp_field == 0, mant != 0), normalize by finding
+ // the MSB position of mant using CTLZ, then compute the correct
+ // exponent and mantissa for the destination format.
+ SDValue DenormResult;
+ {
+ const unsigned IntVTBits = DstBits;
+ SDValue LeadingZeros =
+ DAG.getNode(ISD::CTLZ_ZERO_UNDEF, dl, IntVT, MantField);
+
+ // dst_exp_denorm = (IntVTBits + DstBias - SrcBias - SrcMant) -
+ // LeadingZeros
+ const int DenormExpConst =
+ (int)IntVTBits + DstBias - SrcBias - (int)SrcMant;
+ SDValue DenormDstExp = DAG.getNode(
+ ISD::SUB, dl, IntVT,
+ DAG.getConstant(APInt(DstBits, DenormExpConst, true), dl, IntVT),
+ LeadingZeros);
+
+ // MSB position of the mantissa (0-indexed from LSB).
+ SDValue MantMSB =
+ DAG.getNode(ISD::SUB, dl, IntVT,
+ DAG.getConstant(IntVTBits - 1, dl, IntVT), LeadingZeros);
+
+ // leading_one = 1 << MantMSB
+ SDValue LeadingOne = DAG.getNode(ISD::SHL, dl, IntVT, One, MantMSB);
+
+ // frac = mant XOR leading_one (strip the implicit 1)
+ SDValue Frac = DAG.getNode(ISD::XOR, dl, IntVT, MantField, LeadingOne);
+
+ // shift_amount = DstMant - MantMSB
+ // = DstMant - (IntVTBits - 1 - LeadingZeros)
+ // = LeadingZeros - (IntVTBits - 1 - DstMant)
+ const unsigned ShiftSub = IntVTBits - 1 - DstMant; // always >= 0
+ SDValue ShiftAmount = DAG.getNode(ISD::SUB, dl, IntVT, LeadingZeros,
+ DAG.getConstant(ShiftSub, dl, IntVT));
+
+ SDValue DenormDstMant =
+ DAG.getNode(ISD::SHL, dl, IntVT, Frac, ShiftAmount);
+
+ // Assemble denorm as sign | (denorm_dst_exp << DstMant) | denorm_dst_mant
+ SDValue DenormExpShifted =
+ DAG.getNode(ISD::SHL, dl, IntVT, DenormDstExp,
+ DAG.getShiftAmountConstant(DstMant, IntVT, dl));
+ DenormResult = DAG.getNode(
+ ISD::OR, dl, IntVT,
+ DAG.getNode(ISD::OR, dl, IntVT, SignShifted, DenormExpShifted),
+ DenormDstMant);
+ }
+
+ // Select between normal and denorm paths.
+ SDValue FiniteResult =
+ DAG.getSelect(dl, IntVT, IsDenorm, DenormResult, NormResult);
+
+ // Build special-value results.
+ // NaN -> canonical quiet NaN: sign=0, exp=all-ones, qNaN bit set.
+ // Encoding: (DstExpAllOnes << DstMant) | (1 << (DstMant - 1))
+ const uint64_t QNaNBit = (DstMant > 0) ? (1ULL << (DstMant - 1)) : 0;
+ SDValue NaNResult =
+ DAG.getConstant((DstExpAllOnes << DstMant) | QNaNBit, dl, IntVT);
+
+ // Inf -> destination Inf.
+ // sign | (DstExpAllOnes << DstMant)
+ SDValue InfResult =
+ DAG.getNode(ISD::OR, dl, IntVT, SignShifted,
+ DAG.getConstant(DstExpAllOnes << DstMant, dl, IntVT));
+
+ // Zero -> signed zero.
+ // Sign bit only.
+ SDValue ZeroResult = SignShifted;
+
+ // Final selection goes in order: NaN takes priority, then Inf, then Zero.
+ SDValue Result = FiniteResult;
+ Result = DAG.getSelect(dl, IntVT, IsZero, ZeroResult, Result);
+ Result = DAG.getSelect(dl, IntVT, IsInf, InfResult, Result);
+ Result = DAG.getSelect(dl, IntVT, IsNaN, NaNResult, Result);
+
+ // Bitcast integer result to destination float type.
+ Result = DAG.getNode(ISD::BITCAST, dl, DstVT, Result);
+
+ Results.push_back(Result);
+ break;
+ }
case ISD::FCANONICALIZE: {
// This implements llvm.canonicalize.f* by multiplication with 1.0, as
// suggested in
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index 16453f220bb50..0acb510a9550d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -2763,6 +2763,9 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) {
case ISD::STRICT_UINT_TO_FP:
case ISD::SINT_TO_FP:
case ISD::UINT_TO_FP: R = SoftPromoteHalfRes_XINT_TO_FP(N); break;
+ case ISD::CONVERT_FROM_ARBITRARY_FP:
+ R = SoftPromoteHalfRes_CONVERT_FROM_ARBITRARY_FP(N);
+ break;
case ISD::POISON:
case ISD::UNDEF: R = SoftPromoteHalfRes_UNDEF(N); break;
case ISD::ATOMIC_SWAP: R = BitcastToInt_ATOMIC_SWAP(N); break;
@@ -3050,6 +3053,19 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_XINT_TO_FP(SDNode *N) {
return DAG.getNode(GetPromotionOpcode(NVT, OVT), dl, MVT::i16, Res);
}
+SDValue
+DAGTypeLegalizer::SoftPromoteHalfRes_CONVERT_FROM_ARBITRARY_FP(SDNode *N) {
+ EVT OVT = N->getValueType(0);
+ EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), OVT);
+ SDLoc dl(N);
+
+ SDValue Res = DAG.getNode(ISD::CONVERT_FROM_ARBITRARY_FP, dl, NVT,
+ N->getOperand(0), N->getOperand(1));
+
+ // Round the value to the softened type.
+ return DAG.getNode(GetPromotionOpcode(NVT, OVT), dl, MVT::i16, Res);
+}
+
SDValue DAGTypeLegalizer::SoftPromoteHalfRes_UNDEF(SDNode *N) {
return DAG.getUNDEF(MVT::i16);
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 5b32c5f945a75..e768677675863 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -2099,6 +2099,9 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
case ISD::FP16_TO_FP:
case ISD::VP_UINT_TO_FP:
case ISD::UINT_TO_FP: Res = PromoteIntOp_UINT_TO_FP(N); break;
+ case ISD::CONVERT_FROM_ARBITRARY_FP:
+ Res = PromoteIntOp_CONVERT_FROM_ARBITRARY_FP(N);
+ break;
case ISD::STRICT_FP16_TO_FP:
case ISD::STRICT_UINT_TO_FP: Res = PromoteIntOp_STRICT_UINT_TO_FP(N); break;
case ISD::ZERO_EXTEND: Res = PromoteIntOp_ZERO_EXTEND(N); break;
@@ -2698,6 +2701,12 @@ SDValue DAGTypeLegalizer::PromoteIntOp_UINT_TO_FP(SDNode *N) {
ZExtPromotedInteger(N->getOperand(0))), 0);
}
+SDValue DAGTypeLegalizer::PromoteIntOp_CONVERT_FROM_ARBITRARY_FP(SDNode *N) {
+ return SDValue(DAG.UpdateNodeOperands(N, GetPromotedInteger(N->getOperand(0)),
+ N->getOperand(1)),
+ 0);
+}
+
SDValue DAGTypeLegalizer::PromoteIntOp_STRICT_UINT_TO_FP(SDNode *N) {
return SDValue(DAG.UpdateNodeOperands(N, N->getOperand(0),
ZExtPromotedInteger(N->getOperand(1))), 0);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index f10b6dfa902ec..cdfc0a8df1dac 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -417,6 +417,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteIntOp_TRUNCATE(SDNode *N);
SDValue PromoteIntOp_UINT_TO_FP(SDNode *N);
SDValue PromoteIntOp_STRICT_UINT_TO_FP(SDNode *N);
+ SDValue PromoteIntOp_CONVERT_FROM_ARBITRARY_FP(SDNode *N);
SDValue PromoteIntOp_ZERO_EXTEND(SDNode *N);
SDValue PromoteIntOp_VP_ZERO_EXTEND(SDNode *N);
SDValue PromoteIntOp_MSTORE(MaskedStoreSDNode *N, unsigned OpNo);
@@ -806,6 +807,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue SoftPromoteHalfRes_FNEG(SDNode *N);
SDValue SoftPromoteHalfRes_AssertNoFPClass(SDNode *N);
SDValue SoftPromoteHalfRes_XINT_TO_FP(SDNode *N);
+ SDValue SoftPromoteHalfRes_CONVERT_FROM_ARBITRARY_FP(SDNode *N);
SDValue SoftPromoteHalfRes_UNDEF(SDNode *N);
SDValue SoftPromoteHalfRes_VECREDUCE(SDNode *N);
SDValue SoftPromoteHalfRes_VECREDUCE_SEQ(SDNode *N);
@@ -857,6 +859,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue ScalarizeVecRes_BUILD_VECTOR(SDNode *N);
SDValue ScalarizeVecRes_EXTRACT_SUBVECTOR(SDNode *N);
SDValue ScalarizeVecRes_FP_ROUND(SDNode *N);
+ SDValue ScalarizeVecRes_CONVERT_FROM_ARBITRARY_FP(SDNode *N);
SDValue ScalarizeVecRes_UnaryOpWithExtraInput(SDNode *N);
SDValue ScalarizeVecRes_INSERT_VECTOR_ELT(SDNode *N);
SDValue ScalarizeVecRes_LOAD(LoadSDNode *N);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 4db9c5d009ae8..d9affb9e2cbe5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -62,6 +62,9 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) {
case ISD::BUILD_VECTOR: R = ScalarizeVecRes_BUILD_VECTOR(N); break;
case ISD::EXTRACT_SUBVECTOR: R = ScalarizeVecRes_EXTRACT_SUBVECTOR(N); break;
case ISD::FP_ROUND: R = ScalarizeVecRes_FP_ROUND(N); break;
+ case ISD::CONVERT_FROM_ARBITRARY_FP:
+ R = ScalarizeVecRes_CONVERT_FROM_ARBITRARY_FP(N);
+ break;
case ISD::AssertZext:
case ISD::AssertSext:
case ISD::FPOWI:
@@ -478,6 +481,23 @@ SDValue DAGTypeLegalizer::ScalarizeVecRes_FP_ROUND(SDNode *N) {
N->getOperand(1));
}
+SDValue DAGTypeLegalizer::ScalarizeVecRes_CONVERT_FROM_ARBITRARY_FP(SDNode *N) {
+ SDLoc DL(N);
+ SDValue Op = N->getOperand(0);
+ EVT OpVT = Op.getValueType();
+ // The result needs scalarizing, but it's not a given that the source does.
+ // See similar logic in ScalarizeVecRes_UnaryOp.
+ if (getTypeAction(OpVT) == TargetLowering::TypeScalarizeVector) {
+ Op = GetScalarizedVector(Op);
+ } else {
+ EVT VT = OpVT.getVectorElementType();
+ Op = DAG.getExtractVectorElt(DL, VT, Op, 0);
+ }
+ return DAG.getNode(ISD::CONVERT_FROM_ARBITRARY_FP, DL,
+ N->getValueType(0).getVectorElementType(), Op,
+ N->getOperand(1));
+}
+
SDValue DAGTypeLegalizer::ScalarizeVecRes_UnaryOpWithExtraInput(SDNode *N) {
SDValue Op = GetScalarizedVector(N->getOperand(0));
return DAG.getNode(N->getOpcode(), SDLoc(N), Op.getValueType(), Op,
@@ -818,6 +838,7 @@ bool DAGTypeLegalizer::ScalarizeVectorOperand(SDNode *N, unsigned OpNo) {
break;
case ISD::FP_TO_SINT_SAT:
case ISD::FP_TO_UINT_SAT:
+ case ISD::CONVERT_FROM_ARBITRARY_FP:
Res = ScalarizeVecOp_UnaryOpWithExtraInput(N);
break;
case ISD::STRICT_SINT_TO_FP:
@@ -1367,6 +1388,7 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
case ISD::VP_UINT_TO_FP:
case ISD::FCANONICALIZE:
case ISD::AssertNoFPClass:
+ case ISD::CONVERT_FROM_ARBITRARY_FP:
SplitVecRes_UnaryOp(N, Lo, Hi);
break;
case ISD::ADDRSPACECAST:
@@ -2768,7 +2790,8 @@ void DAGTypeLegalizer::SplitVecRes_UnaryOp(SDNode *N, SDValue &Lo,
const SDNodeFlags Flags = N->getFlags();
unsigned Opcode = N->getOpcode();
if (N->getNumOperands() <= 2) {
- if (Opcode == ISD::FP_ROUND || Opcode == ISD::AssertNoFPClass) {
+ if (Opcode == ISD::FP_ROUND || Opcode == ISD::AssertNoFPClass ||
+ Opcode == ISD::CONVERT_FROM_ARBITRARY_FP) {
Lo = DAG.getNode(Opcode, dl, LoVT, Lo, N->getOperand(1), Flags);
Hi = DAG.getNode(Opcode, dl, HiVT, Hi, N->getOperand(1), Flags);
} else {
@@ -3578,7 +3601,10 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
break;
case ISD::STRICT_FP_ROUND:
case ISD::VP_FP_ROUND:
- case ISD::FP_ROUND: Res = SplitVecOp_FP_ROUND(N); break;
+ case ISD::FP_ROUND:
+ case ISD::CONVERT_FROM_ARBITRARY_FP:
+ Res = SplitVecOp_FP_ROUND(N);
+ break;
case ISD::FCOPYSIGN: Res = SplitVecOp_FPOpDifferentTypes(N); break;
case ISD::STORE:
Res = SplitVecOp_STORE(cast<StoreSDNode>(N), OpNo);
@@ -4686,8 +4712,8 @@ SDValue DAGTypeLegalizer::SplitVecOp_FP_ROUND(SDNode *N) {
Lo = DAG.getNode(ISD::VP_FP_ROUND, DL, OutVT, Lo, MaskLo, EVLLo);
Hi = DAG.getNode(ISD::VP_FP...
[truncated]
    DAG.getConstant(APInt(DstBits, BiasAdjust, true), dl, IntVT));

SDValue NormDstMant;
if (DstMant > SrcMant)
Braces and temp variable to avoid ugly line breaks
else if (DstVT == MVT::f64)
  DstSem = &APFloat::IEEEdouble();
else
  llvm_unreachable("Unsupported destination float type");
It will work. I just had SPIR-V capabilities in mind (where fp128 is not supported), and that is the wrong way to write the patch. Implicitly fixed by applying #179318 (comment)
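As a rough illustration of why the destination type does not have to be special-cased, the parameters the expansion needs (mantissa width, exponent width, bias, all-ones exponent) all follow from the precision and exponent range of the type. The sketch below derives them for `float` and `double` from `std::numeric_limits`; `DstParams` and `dstParamsFor` are hypothetical names, and the bias is taken as `max_exponent - 1` because the standard library's exponent convention differs from APFloat's `semanticsMinExponent`.

```cpp
#include <climits>
#include <cstdint>
#include <limits>
#include <type_traits>

// Hypothetical bundle of the destination-format parameters used above.
struct DstParams {
  unsigned Bits;       // total width of the format
  unsigned Mant;       // stored mantissa bits (precision - 1)
  unsigned ExpBits;    // exponent field width
  int Bias;            // exponent bias
  uint64_t ExpAllOnes; // all-ones exponent field (Inf/NaN)
};

template <typename T> DstParams dstParamsFor() {
  // Restricted to types whose object size equals the format width.
  static_assert(std::is_same<T, float>::value ||
                    std::is_same<T, double>::value,
                "sketch only covers float and double");
  static_assert(std::numeric_limits<T>::is_iec559, "IEEE 754 type expected");
  DstParams P;
  P.Bits = sizeof(T) * CHAR_BIT;
  P.Mant = std::numeric_limits<T>::digits - 1;
  P.ExpBits = P.Bits - P.Mant - 1;
  P.Bias = std::numeric_limits<T>::max_exponent - 1;
  P.ExpAllOnes = (1ULL << P.ExpBits) - 1;
  return P;
}
```

With these definitions `dstParamsFor<float>()` yields 23 mantissa bits, 8 exponent bits and a bias of 127, the same values the patch computes from `APFloat::IEEEsingle()`.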
@arsenm @RKSimon @kuhar @efriedma-quic @nikic friendly ping in case you missed it (there was a GitHub CI outage the day I pushed it). Also, there may be other people who would be interested; if so, please call them into the PR.
Ping. Also adding some folks from the MLIR world who might be interested in the intrinsics and their expansion.
@benvanik @matthias-springer @hanhanW as the three most recent folks I remember being involved in floating-point conversion code that could be rewritten to use this
  Results.push_back(Op);
  break;
}
case ISD::CONVERT_FROM_ARBITRARY_FP: {
- Can this be an LLVM-level rewrite (see something like LowerBufferFatPointers) instead? That'd allow for more optimization of what's already somewhat complex bit manipulation. If there's a target-specific semantic you want to let through, I'd add TargetTransformInfo::isLegalArbitraryFpConversion
- Especially if that gets moved out of DAG->DAG, it should account for fast-math
I will add target-specific lowering as well for the AMDGPU and SPIR-V backends in follow-up patches. In SPIR-V there is a public multi-vendor extension, SPV_EXT_float8, that introduces fp8 <-> IEEE float conversions. For AMDGPU I see quite a few entries in VOP3Instructions.td, but I need to do some research to see which instructions there can be mapped to the llvm.convert.from.arbitrary.fp and llvm.convert.to.arbitrary.fp intrinsics (the latter will be added right after this PR is merged, as it reuses some of the utility functions).
Also, I believe there are NVPTX capabilities for this, and there is the SPV_INTEL_float4 extension, but I guess those will be covered by the appropriate folks.
Can this be a LLVM-level rewrite (see something like LowerBufferFatPointers) instead
It depends on where such a rewrite is placed. If it lives in the middle end, the pass would have to check the module's target triple in order to skip targets that have native support. I personally don't mind such a skip, but I know some folks would object that it's not the LLVM way.
I have mixed views about IR expansions. I simultaneously think they're a hack and that we'd be better off if we did more legalization on IR than in SelectionDAG/GISel. These cases do not require control flow, so they can follow SOP and do the DAG expansion. The downside of the DAG expansion is that the DAG combiner will always be worse than any IR optimizations.
They also add an unstructured ordering property to the "modular IR".
Re the fp8 operations, I can advise on those - I did a lot of the plumbing up in MLIR.
This also means we'll probably have an awkward impedance mismatch, where the AMDGPU target really dislikes APIs that have the form <N x i8>, but that'll be the best form for this sort of convert.from.arbitrary intrinsic to represent vectors.
(I'm also going to flag the interesting note that many of these operations come in vector flavors - stuff like "take 16 bits, treat that as 2 x fp8, and expand that to 2 x float" - or, really, "take a specified half of 32 bits". That's all stuff that can be done in the optimizer, but it's a long-term hazard.)
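To make the packed flavor concrete, here is a small scalar sketch of "take a specified half of 32 bits, treat it as 2 x fp8, and expand to 2 x float". It models no particular AMDGPU instruction. `decodeE4M3FN` and `unpackPair` are hypothetical names, Float8E4M3FN is picked arbitrarily as the element format, and placing element 0 in the low byte is an assumption.

```cpp
#include <array>
#include <cmath>
#include <cstdint>

// Arithmetic decoder for Float8E4M3FN (1 sign, 4 exponent, 3 mantissa bits,
// bias 7, no Inf, NaN only when exponent and mantissa are all ones).
float decodeE4M3FN(uint8_t B) {
  const int Exp = (B >> 3) & 0xF;
  const int Mant = B & 0x7;
  float Mag;
  if (Exp == 0xF && Mant == 0x7)
    Mag = NAN;
  else if (Exp == 0)
    Mag = std::ldexp(static_cast<float>(Mant), -9); // Mant * 2^(1 - 7 - 3)
  else
    Mag = std::ldexp(1.0f + Mant / 8.0f, Exp - 7);
  return (B & 0x80) ? -Mag : Mag;
}

// Selects the low or high 16 bits of Word and expands the two fp8 bytes in it.
// Assumption: element 0 lives in the low byte of the selected half.
std::array<float, 2> unpackPair(uint32_t Word, bool HighHalf) {
  const uint32_t Half = HighHalf ? (Word >> 16) : (Word & 0xFFFFu);
  std::array<float, 2> Out;
  Out[0] = decodeE4M3FN(static_cast<uint8_t>(Half & 0xFFu));
  Out[1] = decodeE4M3FN(static_cast<uint8_t>((Half >> 8) & 0xFFu));
  return Out;
}
```

In IR terms this corresponds roughly to a bitcast to a byte vector, a subvector extract, and a vector convert, which is the kind of pattern the optimizer would have to recognize.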
if (DstVT.isVector()) {
  Results.push_back(DAG.UnrollVectorOp(Node));
  break;
}
Can you not do this? I don't see anything in the code here that shouldn't just naturally work out for vectors.
Well, it would require a bit of patch rewriting; let me check what I can do.
Unrolling is required in the current version because the expansion creates SETCC + vselect nodes for the NaN/Inf/etc. classification. These produce vector i1 types that are (AFAIU) not legal on AMDGPU.
Okay, with CONVERT_FROM_ARBITRARY_FP registered in LegalizeVectorOps.cpp, UnrollVectorOp is not needed.
StringRef FormatStr = cast<MDString>(MD)->getString();
const fltSemantics *SrcSem =
    APFloatBase::getArbitraryFPSemantics(FormatStr);
if (!SrcSem) {
I'd expect the unhandled cases to be caught by the IR verifier
These cases should and will be handled. I'm intentionally skipping the FNUZ formats for now, as they have different denormal semantics, to simplify both the implementation and the review rounds. IMHO this mustn't be moved to the verifier; instead, this very error on line 7139 will be removed once everything is implemented.
// TODO: extend to remaining arbitrary FP types: Float8E4M3, Float8E3M4,
// Float8E5M2FNUZ, Float8E4M3FNUZ, Float8E4M3B11FNUZ, Float8E8M0FNU.
return StringSwitch<const fltSemantics *>(Format)
    .Case("Float8E5M2", &semFloat8E5M2)
I'd expect these to be all lowercase (although maybe that mistake was in the original intrinsic patch)
I don't mind changing this, though I interpret the result of #164252 (comment) as: let's use the same names as https://llvm.org/doxygen/classllvm_1_1APFloatBase.html#ae28d826c1042631ac188d8295949ff52 (maybe without the S_ prefix?). There was even a test added to check the case: https://github.com/llvm/llvm-project/pull/164252/changes#diff-d2cab71d92a8fec9056c6892be1c84291deb22ddab5b0e6113a7d49df2405067R10204
The expansion converts an arbitrary-precision FP value, represented as an integer, using the algorithm given in the PR description.
Currently only conversions from OCP floats are covered; in LLVM terms these are: Float8E5M2, Float8E4M3FN, Float6E3M2FN, Float6E2M3FN, Float4E2M1FN.
OCP spec:
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
AI has assisted in X86 E2E testing.