diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp
index 2fabb598e99d..91e9a4bbf888 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -234,6 +234,21 @@ static const std::string S8_to_Bf16 =
     "prmt.b32 $0, f0, f1, 0x7632; \n" // f32->bf16 + pack
     "prmt.b32 $1, f2, f3, 0x7632; \n" //
     "}";
+// Conversions have low throughput; rely on bit tricks instead of the cvt
+// instruction on Hopper and later GPUs.
+static const std::string S8_to_Bf16_sm90 =
+    "{ \n"
+    ".reg .b32 l<3>; \n"
+    ".reg .b32 h<3>; \n"
+    "prmt.b32 l0, $2, 0x43, 0x4140; \n" // Unpack to shifted bf16.
+    "prmt.b32 h0, $2, 0x43, 0x4342; \n"
+    "and.b32 l1, l0, 0xff7fff7f; \n" // Zero the least exp bit.
+    "and.b32 h1, h0, 0xff7fff7f; \n"
+    "and.b32 l2, l0, 0xff80ff80; \n" // Zero the mantissa.
+    "and.b32 h2, h0, 0xff80ff80; \n"
+    "sub.bf16x2 $0, l1, l2; \n" // Subtract the offset.
+    "sub.bf16x2 $1, h1, h2; \n"
+    "}";
 
 typedef std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
                                          const SmallVector<Value> &)>
@@ -646,9 +661,15 @@ struct FSubOpConversion
 struct SIToFPOpConversion
     : ElementwiseOpConversionBase<arith::SIToFPOp, SIToFPOpConversion> {
   using Base = ElementwiseOpConversionBase<arith::SIToFPOp, SIToFPOpConversion>;
-  using Base::Base;
   using Adaptor = typename Base::OpAdaptor;
 
+  explicit SIToFPOpConversion(LLVMTypeConverter &typeConverter,
+                              ModuleAxisInfoAnalysis &axisAnalysisPass,
+                              int computeCapability,
+                              PatternBenefit benefit = patternBenefitDefault)
+      : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit),
+        computeCapability(computeCapability) {}
+
   SmallVector<Value> createDestOps(arith::SIToFPOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter,
                                    Type elemTy, MultipleOperandsRange operands,
@@ -657,7 +678,8 @@ struct SIToFPOpConversion
     Type outElemTy = getElementType(op.getOut());
     if (outElemTy.isBF16() && inElemTy.isInteger(8) && operands.size() >= 4) {
       auto cvtFunc = makeConverterFromPtx(
-          S8_to_Bf16, getTypeConverter()->convertType(inElemTy),
+          computeCapability >= 90 ? S8_to_Bf16_sm90 : S8_to_Bf16,
+          getTypeConverter()->convertType(inElemTy),
           getTypeConverter()->convertType(outElemTy));
       SmallVector<Value> inVals = {operands[0][0], operands[1][0],
                                    operands[2][0], operands[3][0]};
@@ -668,6 +690,9 @@ struct SIToFPOpConversion
       return {rewriter.create<LLVM::SIToFPOp>(loc, elemTy, operands[0][0])};
     }
   }
+
+private:
+  int computeCapability;
 };
 
 struct FPToSIOpConversion
@@ -920,8 +945,9 @@ void mlir::triton::NVIDIA::populateElementwiseOpToLLVMPatterns(
   patterns.add(typeConverter, axisInfoAnalysis, benefit);
   patterns.add(typeConverter, axisInfoAnalysis, benefit);
   patterns.add(typeConverter, axisInfoAnalysis, benefit);
-  patterns.add<SIToFPOpConversion>(typeConverter, axisInfoAnalysis, benefit);
+  patterns.add<SIToFPOpConversion>(typeConverter, axisInfoAnalysis,
+                                   computeCapability, benefit);
   patterns.add(typeConverter, axisInfoAnalysis, computeCapability, benefit);
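
For reviewers, here is a minimal host-side sketch (not part of the patch) of what the new `S8_to_Bf16_sm90` asm computes for a single bf16 lane, assuming a C++20 compiler; the helper `bf16_bits_to_float` and the `main` harness are illustrative only, not Triton code. The `prmt` places each int8 byte under a 0x43 exponent byte, the two `and` masks split that pattern into a biased value and its offset, and `sub.bf16x2` recovers the signed integer value.

```cpp
#include <bit>
#include <cassert>
#include <cstdint>
#include <cstdio>

// bf16 occupies the top 16 bits of an fp32 pattern.
static float bf16_bits_to_float(uint16_t bits) {
  return std::bit_cast<float>(static_cast<uint32_t>(bits) << 16);
}

int main() {
  for (int v = -128; v < 128; ++v) {
    uint8_t byte = static_cast<uint8_t>(v);
    // prmt.b32 ..., $2, 0x43, ...: place the int8 byte in the low half of the
    // lane and 0x43 in the high half, i.e. the bf16 pattern 0x43'byte'.
    uint16_t lane = static_cast<uint16_t>(0x4300 | byte);
    // and.b32 ..., 0xff7fff7f: clear the exponent LSB (bit 7 of each lane).
    uint16_t biased = lane & 0xff7f;
    // and.b32 ..., 0xff80ff80: clear the mantissa, keeping only the offset.
    uint16_t offset = lane & 0xff80;
    // sub.bf16x2: biased - offset; the result is exact, so plain float math
    // models the bf16 subtraction faithfully.
    float result = bf16_bits_to_float(biased) - bf16_bits_to_float(offset);
    assert(result == static_cast<float>(v));
  }
  std::printf("bit trick matches int8 -> bf16 for all 256 inputs\n");
}
```

Both intermediates (the biased value and the 128/256 offset) and their difference fit in bf16's 8-bit significand, so the subtraction introduces no rounding and the sketch's assert holds for every input, matching a reference int8 -> bf16 conversion.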