Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,21 @@ static const std::string S8_to_Bf16 =
"prmt.b32 $0, f0, f1, 0x7632; \n" // f32->bf16 + pack
"prmt.b32 $1, f2, f3, 0x7632; \n" //
"}";
// Conversions have low throughput, rely on bit tricks instead of cvt
// instruction on Hopper and later GPUs.
static const std::string S8_to_Bf16_sm90 =
"{ \n"
".reg .b32 l<3>; \n"
".reg .b32 h<3>; \n"
"prmt.b32 l0, $2, 0x43, 0x4140; \n" // Unpack to shifted bf16.
"prmt.b32 h0, $2, 0x43, 0x4342; \n"
"and.b32 l1, l0, 0xff7fff7f; \n" // Zero the least exp bit.
"and.b32 h1, h0, 0xff7fff7f; \n"
"and.b32 l2, l0, 0xff80ff80; \n" // Zero the mantissa.
"and.b32 h2, h0, 0xff80ff80; \n"
"sub.bf16x2 $0, l1, l2; \n" // Subtract the offset.
"sub.bf16x2 $1, h1, h2; \n"
"}";

typedef std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
const SmallVector<Value> &)>
Expand Down Expand Up @@ -646,9 +661,15 @@ struct FSubOpConversion
struct SIToFPOpConversion
: ElementwiseOpConversionBase<arith::SIToFPOp, SIToFPOpConversion> {
using Base = ElementwiseOpConversionBase<arith::SIToFPOp, SIToFPOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;

explicit SIToFPOpConversion(LLVMTypeConverter &typeConverter,
ModuleAxisInfoAnalysis &axisAnalysisPass,
int computeCapability,
PatternBenefit benefit = patternBenefitDefault)
: ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit),
computeCapability(computeCapability) {}

SmallVector<Value> createDestOps(arith::SIToFPOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Type elemTy, MultipleOperandsRange operands,
Expand All @@ -657,7 +678,8 @@ struct SIToFPOpConversion
Type outElemTy = getElementType(op.getOut());
if (outElemTy.isBF16() && inElemTy.isInteger(8) && operands.size() >= 4) {
auto cvtFunc = makeConverterFromPtx(
S8_to_Bf16, getTypeConverter()->convertType(inElemTy),
computeCapability >= 90 ? S8_to_Bf16_sm90 : S8_to_Bf16,
getTypeConverter()->convertType(inElemTy),
getTypeConverter()->convertType(outElemTy));
SmallVector<Value> inVals = {operands[0][0], operands[1][0],
operands[2][0], operands[3][0]};
Expand All @@ -668,6 +690,9 @@ struct SIToFPOpConversion
return {rewriter.create<LLVM::SIToFPOp>(loc, elemTy, operands[0][0])};
}
}

private:
int computeCapability;
};

struct FPToSIOpConversion
Expand Down Expand Up @@ -920,8 +945,9 @@ void mlir::triton::NVIDIA::populateElementwiseOpToLLVMPatterns(
patterns.add<ExtFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<TruncFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<FPToSIOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<SIToFPOpConversion>(typeConverter, axisInfoAnalysis, benefit);

patterns.add<SIToFPOpConversion>(typeConverter, axisInfoAnalysis,
computeCapability, benefit);
patterns.add<FpToFpOpConversion>(typeConverter, axisInfoAnalysis,
computeCapability, benefit);

Expand Down