Skip to content

Commit

Permalink
[SYCL][NVPTX] Set default fdiv and sqrt for llvm.fpbuiltin (#16714)
Browse files Browse the repository at this point in the history
AltMathLibrary is lacking implementation for llvm.fpbuiltin intrinsics
for NVPTX target. This patch adds type-dependent mapping for
llvm.fpbuiltin.fdiv with max-error > 2.0 and llvm.fpbuiltin.sqrt with
max-error > 1.0 on nvvm intrinsics:
fp32 scalar @llvm.fpbuiltin.fdiv -> @llvm.nvvm.div.approx.f
fp32 scalar @llvm.fpbuiltin.sqrt -> @llvm.nvvm.sqrt.approx.f

vector or non-fp32 scalar llvm.fpbuiltin.fdiv -> fdiv
vector or non-fp32 scalar llvm.fpbuiltin.sqrt -> llvm.sqrt

Additionally it maps max-error=0.5 fpbuiltin.fadd, fpbuiltin.fsub.
fpbuiltin.fmul, fpbuiltin.fdiv, fpbuiltin.frem, fpbuiltin.sqrt and
fpbuiltin.ldexp intrinsic functions of LLVM's math operations or
https://llvm.org/docs/LangRef.html#standard-c-c-library-intrinsics

TODO in future patches:
- add preservation of debug info in FPBuiltinFnSelection;
- moved tests from CodeGen to Transform
- move pass to new pass manager

Signed-off-by: Sidorov, Dmitry <[email protected]>

---------

Signed-off-by: Sidorov, Dmitry <[email protected]>
  • Loading branch information
MrSidims authored Jan 30, 2025
1 parent ddea941 commit 52238e1
Show file tree
Hide file tree
Showing 3 changed files with 372 additions and 4 deletions.
63 changes: 59 additions & 4 deletions llvm/lib/Transforms/Scalar/FPBuiltinFnSelection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/FormatVariadic.h"

Expand Down Expand Up @@ -106,6 +107,51 @@ static bool replaceWithLLVMIR(FPBuiltinIntrinsic &BuiltinCall) {
return true;
}

// This function lowers llvm.fpbuiltin. intrinsic functions with max-error
// attribute to the appropriate nvvm approximate intrinsics if it's possible.
// If it's not possible - fallback to instruction or standard C/C++ library LLVM
// intrinsic.
static bool
replaceWithApproxNVPTXCallsOrFallback(FPBuiltinIntrinsic &BuiltinCall,
std::optional<float> Accuracy) {
IRBuilder<> IRBuilder(&BuiltinCall);
SmallVector<Value *> Args(BuiltinCall.args());
Value *Replacement = nullptr;
auto *Type = BuiltinCall.getType();
// For now only add lowering for fdiv and sqrt. Yet nvvm intrinsics have
// approximate variants for sin, cos, exp2 and log2.
// For vector fpbuiltins for NVPTX target we don't have nvvm intrinsics,
// fallback to instruction or standard C/C++ library LLVM intrinsic. Also
// nvvm fdiv and sqrt intrisics support only float type, so fallback in this
// case as well.
switch (BuiltinCall.getIntrinsicID()) {
case Intrinsic::fpbuiltin_fdiv:
if (Accuracy.value() < 2.0)
return false;
if (Type->isVectorTy() || !Type->getScalarType()->isFloatTy())
return replaceWithLLVMIR(BuiltinCall);
Replacement =
IRBuilder.CreateIntrinsic(Type, Intrinsic::nvvm_div_approx_f, Args);
break;
case Intrinsic::fpbuiltin_sqrt:
if (Accuracy.value() < 1.0)
return false;
if (Type->isVectorTy() || !Type->getScalarType()->isFloatTy())
return replaceWithLLVMIR(BuiltinCall);
Replacement =
IRBuilder.CreateIntrinsic(Type, Intrinsic::nvvm_sqrt_approx_f, Args);
break;
default:
return false;
}
BuiltinCall.replaceAllUsesWith(Replacement);
cast<Instruction>(Replacement)->copyFastMathFlags(&BuiltinCall);
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
<< BuiltinCall.getCalledFunction()->getName()
<< "` with equivalent IR. \n `");
return true;
}

static bool selectFnForFPBuiltinCalls(const TargetLibraryInfo &TLI,
const TargetTransformInfo &TTI,
FPBuiltinIntrinsic &BuiltinCall) {
Expand Down Expand Up @@ -136,10 +182,11 @@ static bool selectFnForFPBuiltinCalls(const TargetLibraryInfo &TLI,
return replaceWithLLVMIR(BuiltinCall);

// Several functions for "sycl" and "cuda" requires "0.5" accuracy levels,
// which means correctly rounded results. For now x86 host AltMathLibrary
// doesn't have such ability. For such accuracy level, the fpbuiltins
// should be replaced by equivalent IR operation or llvmbuiltins.
if (T.isX86() && BuiltinCall.getRequiredAccuracy().value() == 0.5) {
// which means correctly rounded results. For now x86 host and NVPTX
// AltMathLibrary doesn't have such ability. For such accuracy level, the
// fpbuiltins should be replaced by equivalent IR operation or llvmbuiltins.
if ((T.isX86() || T.isNVPTX()) &&
BuiltinCall.getRequiredAccuracy().value() == 0.5) {
switch (BuiltinCall.getIntrinsicID()) {
case Intrinsic::fpbuiltin_fadd:
case Intrinsic::fpbuiltin_fsub:
Expand All @@ -154,6 +201,14 @@ static bool selectFnForFPBuiltinCalls(const TargetLibraryInfo &TLI,
}
}

// AltMathLibrary don't have implementation for CUDA approximate precision
// builtins. Lets map them on NVPTX intrinsics. If no appropriate intrinsics
// are known - skip to emit an error.
if (T.isNVPTX() && BuiltinCall.getRequiredAccuracy().value() > 0.5)
if (replaceWithApproxNVPTXCallsOrFallback(
BuiltinCall, BuiltinCall.getRequiredAccuracy()))
return true;

/// Call TLI to select a function implementation to call
StringRef ImplName = TLI.selectFPBuiltinImplementation(&BuiltinCall);
if (ImplName.empty()) {
Expand Down
94 changes: 94 additions & 0 deletions llvm/test/CodeGen/NVPTX/fp-builtin-intrinsics-nvvm-approx.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
; RUN: opt -fpbuiltin-fn-selection -S < %s | FileCheck %s

target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
target triple = "nvptx64-nvidia-cuda"

; CHECK-LABEL: @test_fdiv
; CHECK: %{{.*}} = call float @llvm.nvvm.div.approx.f(float %{{.*}}, float %{{.*}})
; CHECK: %{{.*}} = fdiv <2 x float> %{{.*}}, %{{.*}}
define void @test_fdiv(float %d1, <2 x float> %v2d1,
float %d2, <2 x float> %v2d2) {
entry:
%t0 = call float @llvm.fpbuiltin.fdiv.f32(float %d1, float %d2) #0
%t1 = call <2 x float> @llvm.fpbuiltin.fdiv.v2f32(<2 x float> %v2d1, <2 x float> %v2d2) #0
ret void
}

; CHECK-LABEL: @test_fdiv_fast
; CHECK: %{{.*}} = call fast float @llvm.nvvm.div.approx.f(float %{{.*}}, float %{{.*}})
; CHECK: %{{.*}} = fdiv fast <2 x float> %{{.*}}, %{{.*}}
define void @test_fdiv_fast(float %d1, <2 x float> %v2d1,
float %d2, <2 x float> %v2d2) {
entry:
%t0 = call fast float @llvm.fpbuiltin.fdiv.f32(float %d1, float %d2) #0
%t1 = call fast <2 x float> @llvm.fpbuiltin.fdiv.v2f32(<2 x float> %v2d1, <2 x float> %v2d2) #0
ret void
}

; CHECK-LABEL: @test_fdiv_max_error
; CHECK: %{{.*}} = call float @llvm.nvvm.div.approx.f(float %{{.*}}, float %{{.*}})
; CHECK: %{{.*}} = fdiv <2 x float> %{{.*}}, %{{.*}}
define void @test_fdiv_max_error(float %d1, <2 x float> %v2d1,
float %d2, <2 x float> %v2d2) {
entry:
%t0 = call float @llvm.fpbuiltin.fdiv.f32(float %d1, float %d2) #2
%t1 = call <2 x float> @llvm.fpbuiltin.fdiv.v2f32(<2 x float> %v2d1, <2 x float> %v2d2) #2
ret void
}

declare float @llvm.fpbuiltin.fdiv.f32(float, float)
declare <2 x float> @llvm.fpbuiltin.fdiv.v2f32(<2 x float>, <2 x float>)

; CHECK-LABEL: @test_fdiv_double
; CHECK: %{{.*}} = fdiv double %{{.*}}, %{{.*}}
; CHECK: %{{.*}} = fdiv <2 x double> %{{.*}}, %{{.*}}
define void @test_fdiv_double(double %d1, <2 x double> %v2d1,
double %d2, <2 x double> %v2d2) {
entry:
%t0 = call double @llvm.fpbuiltin.fdiv.f64(double %d1, double %d2) #0
%t1 = call <2 x double> @llvm.fpbuiltin.fdiv.v2f64(<2 x double> %v2d1, <2 x double> %v2d2) #0
ret void
}

declare double @llvm.fpbuiltin.fdiv.f64(double, double)
declare <2 x double> @llvm.fpbuiltin.fdiv.v2f64(<2 x double>, <2 x double>)

; CHECK-LABEL: @test_sqrt
; CHECK: %{{.*}} = call float @llvm.nvvm.sqrt.approx.f(float %{{.*}})
; CHECK: %{{.*}} = call <2 x float> @llvm.sqrt.v2f32(<2 x float> %{{.*}})
define void @test_sqrt(float %d, <2 x float> %v2d, <4 x float> %v4d) {
entry:
%t0 = call float @llvm.fpbuiltin.sqrt.f32(float %d) #1
%t1 = call <2 x float> @llvm.fpbuiltin.sqrt.v2f32(<2 x float> %v2d) #1
ret void
}

; CHECK-LABEL: @test_sqrt_max_error
; CHECK: %{{.*}} = call float @llvm.nvvm.sqrt.approx.f(float %{{.*}})
; CHECK: %{{.*}} = call <2 x float> @llvm.sqrt.v2f32(<2 x float> %{{.*}})
define void @test_sqrt_max_error(float %d, <2 x float> %v2d, <4 x float> %v4d) {
entry:
%t0 = call float @llvm.fpbuiltin.sqrt.f32(float %d) #2
%t1 = call <2 x float> @llvm.fpbuiltin.sqrt.v2f32(<2 x float> %v2d) #2
ret void
}

declare float @llvm.fpbuiltin.sqrt.f32(float)
declare <2 x float> @llvm.fpbuiltin.sqrt.v2f32(<2 x float>)

; CHECK-LABEL: @test_sqrt_double
; CHECK: %{{.*}} = call double @llvm.sqrt.f64(double %{{.*}})
; CHECK: %{{.*}} = call <2 x double> @llvm.sqrt.v2f64(<2 x double> %{{.*}})
define void @test_sqrt_double(double %d, <2 x double> %v2d) {
entry:
%t0 = call double @llvm.fpbuiltin.sqrt.f64(double %d) #1
%t1 = call <2 x double> @llvm.fpbuiltin.sqrt.v2f64(<2 x double> %v2d) #1
ret void
}

declare double @llvm.fpbuiltin.sqrt.f64(double)
declare <2 x double> @llvm.fpbuiltin.sqrt.v2f64(<2 x double>)

attributes #0 = { "fpbuiltin-max-error"="2.5" }
attributes #1 = { "fpbuiltin-max-error"="3.0" }
attributes #2 = { "fpbuiltin-max-error"="10.0" }
Loading

0 comments on commit 52238e1

Please sign in to comment.