From 0d8a978e8a55b08778ec6ee861c2c5ed6703eb6c Mon Sep 17 00:00:00 2001 From: Jed Brown Date: Fri, 5 Jan 2024 21:04:41 -0700 Subject: [PATCH] intrinsics.fmuladdf{16,32,64,128}: expose llvm.fmuladd.* semantics Add intrinsics `fmuladd{f16,f32,f64,f128}`. This computes `(a * b) + c`, to be fused if the code generator determines that (i) the target instruction set has support for a fused operation, and (ii) that the fused operation is more efficient than the equivalent, separate pair of `mul` and `add` instructions. https://llvm.org/docs/LangRef.html#llvm-fmuladd-intrinsic MIRI support is included for f32 and f64. The codegen_cranelift uses the `fma` function from libc, which is a correct implementation, but without the desired performance semantic. I think this requires an update to cranelift to expose a suitable instruction in its IR. I have not tested with codegen_gcc, but it should behave the same way (using `fma` from libc). --- .../src/intrinsics/mod.rs | 5 +- .../rustc_codegen_gcc/src/intrinsic/mod.rs | 3 ++ compiler/rustc_codegen_llvm/src/context.rs | 5 ++ compiler/rustc_codegen_llvm/src/intrinsic.rs | 5 ++ .../rustc_hir_analysis/src/check/intrinsic.rs | 13 +++++ compiler/rustc_span/src/symbol.rs | 4 ++ library/core/src/intrinsics.rs | 53 +++++++++++++++++++ src/tools/miri/src/intrinsics/mod.rs | 31 +++++++++++ src/tools/miri/tests/pass/float.rs | 18 +++++++ .../intrinsics/fmuladd_nondeterministic.rs | 44 +++++++++++++++ tests/ui/intrinsics/intrinsic-fmuladd.rs | 42 +++++++++++++++ 11 files changed, 222 insertions(+), 1 deletion(-) create mode 100644 src/tools/miri/tests/pass/intrinsics/fmuladd_nondeterministic.rs create mode 100644 tests/ui/intrinsics/intrinsic-fmuladd.rs diff --git a/compiler/rustc_codegen_cranelift/src/intrinsics/mod.rs b/compiler/rustc_codegen_cranelift/src/intrinsics/mod.rs index 19e5adc253854..35f0ccff3f99e 100644 --- a/compiler/rustc_codegen_cranelift/src/intrinsics/mod.rs +++ b/compiler/rustc_codegen_cranelift/src/intrinsics/mod.rs @@ -328,6 +328,9 @@ fn codegen_float_intrinsic_call<'tcx>( sym::fabsf64 => ("fabs", 1, fx.tcx.types.f64, types::F64), sym::fmaf32 => ("fmaf", 3, fx.tcx.types.f32, types::F32), sym::fmaf64 => ("fma", 3, fx.tcx.types.f64, types::F64), + // FIXME: calling `fma` from libc without FMA target feature uses expensive sofware emulation + sym::fmuladdf32 => ("fmaf", 3, fx.tcx.types.f32, types::F32), // TODO: use cranelift intrinsic analogous to llvm.fmuladd.f32 + sym::fmuladdf64 => ("fma", 3, fx.tcx.types.f64, types::F64), // TODO: use cranelift intrinsic analogous to llvm.fmuladd.f64 sym::copysignf32 => ("copysignf", 2, fx.tcx.types.f32, types::F32), sym::copysignf64 => ("copysign", 2, fx.tcx.types.f64, types::F64), sym::floorf32 => ("floorf", 1, fx.tcx.types.f32, types::F32), @@ -381,7 +384,7 @@ fn codegen_float_intrinsic_call<'tcx>( let layout = fx.layout_of(ty); let res = match intrinsic { - sym::fmaf32 | sym::fmaf64 => { + sym::fmaf32 | sym::fmaf64 | sym::fmuladdf32 | sym::fmuladdf64 => { CValue::by_val(fx.bcx.ins().fma(args[0], args[1], args[2]), layout) } sym::copysignf32 | sym::copysignf64 => { diff --git a/compiler/rustc_codegen_gcc/src/intrinsic/mod.rs b/compiler/rustc_codegen_gcc/src/intrinsic/mod.rs index 945eedf555667..972d66321403d 100644 --- a/compiler/rustc_codegen_gcc/src/intrinsic/mod.rs +++ b/compiler/rustc_codegen_gcc/src/intrinsic/mod.rs @@ -66,6 +66,9 @@ fn get_simple_intrinsic<'gcc, 'tcx>( sym::log2f64 => "log2", sym::fmaf32 => "fmaf", sym::fmaf64 => "fma", + // FIXME: calling `fma` from libc without FMA target feature uses expensive sofware emulation + sym::fmuladdf32 => "fmaf", // TODO: use gcc intrinsic analogous to llvm.fmuladd.f32 + sym::fmuladdf64 => "fma", // TODO: use gcc intrinsic analogous to llvm.fmuladd.f64 sym::fabsf32 => "fabsf", sym::fabsf64 => "fabs", sym::minnumf32 => "fminf", diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs index 0a116971e0783..c836dd5473f3c 100644 --- a/compiler/rustc_codegen_llvm/src/context.rs +++ b/compiler/rustc_codegen_llvm/src/context.rs @@ -884,6 +884,11 @@ impl<'ll> CodegenCx<'ll, '_> { ifn!("llvm.fma.f64", fn(t_f64, t_f64, t_f64) -> t_f64); ifn!("llvm.fma.f128", fn(t_f128, t_f128, t_f128) -> t_f128); + ifn!("llvm.fmuladd.f16", fn(t_f16, t_f16, t_f16) -> t_f16); + ifn!("llvm.fmuladd.f32", fn(t_f32, t_f32, t_f32) -> t_f32); + ifn!("llvm.fmuladd.f64", fn(t_f64, t_f64, t_f64) -> t_f64); + ifn!("llvm.fmuladd.f128", fn(t_f128, t_f128, t_f128) -> t_f128); + ifn!("llvm.fabs.f16", fn(t_f16) -> t_f16); ifn!("llvm.fabs.f32", fn(t_f32) -> t_f32); ifn!("llvm.fabs.f64", fn(t_f64) -> t_f64); diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 30c6f08e894b7..bfe623e7fc354 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -86,6 +86,11 @@ fn get_simple_intrinsic<'ll>( sym::fmaf64 => "llvm.fma.f64", sym::fmaf128 => "llvm.fma.f128", + sym::fmuladdf16 => "llvm.fmuladd.f16", + sym::fmuladdf32 => "llvm.fmuladd.f32", + sym::fmuladdf64 => "llvm.fmuladd.f64", + sym::fmuladdf128 => "llvm.fmuladd.f128", + sym::fabsf16 => "llvm.fabs.f16", sym::fabsf32 => "llvm.fabs.f32", sym::fabsf64 => "llvm.fabs.f64", diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs index 25e219ef3f23a..06317a3b3049c 100644 --- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs +++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs @@ -357,6 +357,19 @@ pub fn check_intrinsic_type( (0, 0, vec![tcx.types.f128, tcx.types.f128, tcx.types.f128], tcx.types.f128) } + sym::fmuladdf16 => { + (0, 0, vec![tcx.types.f16, tcx.types.f16, tcx.types.f16], tcx.types.f16) + } + sym::fmuladdf32 => { + (0, 0, vec![tcx.types.f32, tcx.types.f32, tcx.types.f32], tcx.types.f32) + } + sym::fmuladdf64 => { + (0, 0, vec![tcx.types.f64, tcx.types.f64, tcx.types.f64], tcx.types.f64) + } + sym::fmuladdf128 => { + (0, 0, vec![tcx.types.f128, tcx.types.f128, tcx.types.f128], tcx.types.f128) + } + sym::fabsf16 => (0, 0, vec![tcx.types.f16], tcx.types.f16), sym::fabsf32 => (0, 0, vec![tcx.types.f32], tcx.types.f32), sym::fabsf64 => (0, 0, vec![tcx.types.f64], tcx.types.f64), diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index 8e0009695db69..cc3bda99a117b 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -914,6 +914,10 @@ symbols! { fmt_debug, fmul_algebraic, fmul_fast, + fmuladdf128, + fmuladdf16, + fmuladdf32, + fmuladdf64, fn_align, fn_delegation, fn_must_use, diff --git a/library/core/src/intrinsics.rs b/library/core/src/intrinsics.rs index d7a2f1909cabe..061fba9a1f785 100644 --- a/library/core/src/intrinsics.rs +++ b/library/core/src/intrinsics.rs @@ -1795,6 +1795,59 @@ extern "rust-intrinsic" { #[rustc_nounwind] pub fn fmaf128(a: f128, b: f128, c: f128) -> f128; + /// Returns `a * b + c` for `f16` values, non-deterministically executing + /// either a fused multiply-add or two operations with rounding of the + /// intermediate result. + /// + /// The operation is fused if the code generator determines that target + /// instruction set has support for a fused operation, and that the fused + /// operation is more efficient than the equivalent, separate pair of mul + /// and add instructions. It is unspecified whether or not a fused operation + /// is selected, and that may depend on optimization level and context, for + /// example. + #[rustc_nounwind] + #[cfg(not(bootstrap))] + pub fn fmuladdf16(a: f16, b: f16, c: f16) -> f16; + /// Returns `a * b + c` for `f32` values, non-deterministically executing + /// either a fused multiply-add or two operations with rounding of the + /// intermediate result. + /// + /// The operation is fused if the code generator determines that target + /// instruction set has support for a fused operation, and that the fused + /// operation is more efficient than the equivalent, separate pair of mul + /// and add instructions. It is unspecified whether or not a fused operation + /// is selected, and that may depend on optimization level and context, for + /// example. + #[rustc_nounwind] + #[cfg(not(bootstrap))] + pub fn fmuladdf32(a: f32, b: f32, c: f32) -> f32; + /// Returns `a * b + c` for `f64` values, non-deterministically executing + /// either a fused multiply-add or two operations with rounding of the + /// intermediate result. + /// + /// The operation is fused if the code generator determines that target + /// instruction set has support for a fused operation, and that the fused + /// operation is more efficient than the equivalent, separate pair of mul + /// and add instructions. It is unspecified whether or not a fused operation + /// is selected, and that may depend on optimization level and context, for + /// example. + #[rustc_nounwind] + #[cfg(not(bootstrap))] + pub fn fmuladdf64(a: f64, b: f64, c: f64) -> f64; + /// Returns `a * b + c` for `f128` values, non-deterministically executing + /// either a fused multiply-add or two operations with rounding of the + /// intermediate result. + /// + /// The operation is fused if the code generator determines that target + /// instruction set has support for a fused operation, and that the fused + /// operation is more efficient than the equivalent, separate pair of mul + /// and add instructions. It is unspecified whether or not a fused operation + /// is selected, and that may depend on optimization level and context, for + /// example. + #[rustc_nounwind] + #[cfg(not(bootstrap))] + pub fn fmuladdf128(a: f128, b: f128, c: f128) -> f128; + /// Returns the absolute value of an `f16`. /// /// The stabilized version of this intrinsic is diff --git a/src/tools/miri/src/intrinsics/mod.rs b/src/tools/miri/src/intrinsics/mod.rs index 665dd7c441a55..9f772cfa982d8 100644 --- a/src/tools/miri/src/intrinsics/mod.rs +++ b/src/tools/miri/src/intrinsics/mod.rs @@ -295,6 +295,37 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { this.write_scalar(res, dest)?; } + "fmuladdf32" => { + let [a, b, c] = check_arg_count(args)?; + let a = this.read_scalar(a)?.to_f32()?; + let b = this.read_scalar(b)?.to_f32()?; + let c = this.read_scalar(c)?.to_f32()?; + let fuse: bool = this.machine.rng.get_mut().gen(); + let res = if fuse { + // FIXME: Using host floats, to work around https://github.com/rust-lang/rustc_apfloat/issues/11 + a.to_host().mul_add(b.to_host(), c.to_host()).to_soft() + } else { + ((a * b).value + c).value + }; + let res = this.adjust_nan(res, &[a, b, c]); + this.write_scalar(res, dest)?; + } + "fmuladdf64" => { + let [a, b, c] = check_arg_count(args)?; + let a = this.read_scalar(a)?.to_f64()?; + let b = this.read_scalar(b)?.to_f64()?; + let c = this.read_scalar(c)?.to_f64()?; + let fuse: bool = this.machine.rng.get_mut().gen(); + let res = if fuse { + // FIXME: Using host floats, to work around https://github.com/rust-lang/rustc_apfloat/issues/11 + a.to_host().mul_add(b.to_host(), c.to_host()).to_soft() + } else { + ((a * b).value + c).value + }; + let res = this.adjust_nan(res, &[a, b, c]); + this.write_scalar(res, dest)?; + } + "powf32" => { let [f1, f2] = check_arg_count(args)?; let f1 = this.read_scalar(f1)?.to_f32()?; diff --git a/src/tools/miri/tests/pass/float.rs b/src/tools/miri/tests/pass/float.rs index 6ab18a5345ea2..853d3e80517da 100644 --- a/src/tools/miri/tests/pass/float.rs +++ b/src/tools/miri/tests/pass/float.rs @@ -30,6 +30,7 @@ fn main() { libm(); test_fast(); test_algebraic(); + test_fmuladd(); } trait Float: Copy + PartialEq + Debug { @@ -1041,3 +1042,20 @@ fn test_algebraic() { test_operations_f32(11., 2.); test_operations_f32(10., 15.); } + +fn test_fmuladd() { + use std::intrinsics::{fmuladdf32, fmuladdf64}; + + #[inline(never)] + pub fn test_operations_f32(a: f32, b: f32, c: f32) { + assert_approx_eq!(unsafe { fmuladdf32(a, b, c) }, a * b + c); + } + + #[inline(never)] + pub fn test_operations_f64(a: f64, b: f64, c: f64) { + assert_approx_eq!(unsafe { fmuladdf64(a, b, c) }, a * b + c); + } + + test_operations_f32(0.1, 0.2, 0.3); + test_operations_f64(1.1, 1.2, 1.3); +} diff --git a/src/tools/miri/tests/pass/intrinsics/fmuladd_nondeterministic.rs b/src/tools/miri/tests/pass/intrinsics/fmuladd_nondeterministic.rs new file mode 100644 index 0000000000000..b46cf1ddf65df --- /dev/null +++ b/src/tools/miri/tests/pass/intrinsics/fmuladd_nondeterministic.rs @@ -0,0 +1,44 @@ +#![feature(core_intrinsics)] +use std::intrinsics::{fmuladdf32, fmuladdf64}; + +fn main() { + let mut saw_zero = false; + let mut saw_nonzero = false; + for _ in 0..50 { + let a = std::hint::black_box(0.1_f64); + let b = std::hint::black_box(0.2); + let c = std::hint::black_box(-a * b); + // It is unspecified whether the following operation is fused or not. The + // following evaluates to 0.0 if unfused, and nonzero (-1.66e-18) if fused. + let x = unsafe { fmuladdf64(a, b, c) }; + if x == 0.0 { + saw_zero = true; + } else { + saw_nonzero = true; + } + } + assert!( + saw_zero && saw_nonzero, + "`fmuladdf64` failed to be evaluated as both fused and unfused" + ); + + let mut saw_zero = false; + let mut saw_nonzero = false; + for _ in 0..50 { + let a = std::hint::black_box(0.1_f32); + let b = std::hint::black_box(0.2); + let c = std::hint::black_box(-a * b); + // It is unspecified whether the following operation is fused or not. The + // following evaluates to 0.0 if unfused, and nonzero (-8.1956386e-10) if fused. + let x = unsafe { fmuladdf32(a, b, c) }; + if x == 0.0 { + saw_zero = true; + } else { + saw_nonzero = true; + } + } + assert!( + saw_zero && saw_nonzero, + "`fmuladdf32` failed to be evaluated as both fused and unfused" + ); +} diff --git a/tests/ui/intrinsics/intrinsic-fmuladd.rs b/tests/ui/intrinsics/intrinsic-fmuladd.rs new file mode 100644 index 0000000000000..d03297884f791 --- /dev/null +++ b/tests/ui/intrinsics/intrinsic-fmuladd.rs @@ -0,0 +1,42 @@ +//@ run-pass +#![feature(core_intrinsics)] + +use std::intrinsics::*; + +macro_rules! assert_approx_eq { + ($a:expr, $b:expr) => {{ + let (a, b) = (&$a, &$b); + assert!((*a - *b).abs() < 1.0e-6, "{} is not approximately equal to {}", *a, *b); + }}; +} + +fn main() { + unsafe { + let nan: f32 = f32::NAN; + let inf: f32 = f32::INFINITY; + let neg_inf: f32 = f32::NEG_INFINITY; + assert_approx_eq!(fmuladdf32(1.23, 4.5, 0.67), 6.205); + assert_approx_eq!(fmuladdf32(-1.23, -4.5, -0.67), 4.865); + assert_approx_eq!(fmuladdf32(0.0, 8.9, 1.2), 1.2); + assert_approx_eq!(fmuladdf32(3.4, -0.0, 5.6), 5.6); + assert!(fmuladdf32(nan, 7.8, 9.0).is_nan()); + assert_eq!(fmuladdf32(inf, 7.8, 9.0), inf); + assert_eq!(fmuladdf32(neg_inf, 7.8, 9.0), neg_inf); + assert_eq!(fmuladdf32(8.9, inf, 3.2), inf); + assert_eq!(fmuladdf32(-3.2, 2.4, neg_inf), neg_inf); + } + unsafe { + let nan: f64 = f64::NAN; + let inf: f64 = f64::INFINITY; + let neg_inf: f64 = f64::NEG_INFINITY; + assert_approx_eq!(fmuladdf64(1.23, 4.5, 0.67), 6.205); + assert_approx_eq!(fmuladdf64(-1.23, -4.5, -0.67), 4.865); + assert_approx_eq!(fmuladdf64(0.0, 8.9, 1.2), 1.2); + assert_approx_eq!(fmuladdf64(3.4, -0.0, 5.6), 5.6); + assert!(fmuladdf64(nan, 7.8, 9.0).is_nan()); + assert_eq!(fmuladdf64(inf, 7.8, 9.0), inf); + assert_eq!(fmuladdf64(neg_inf, 7.8, 9.0), neg_inf); + assert_eq!(fmuladdf64(8.9, inf, 3.2), inf); + assert_eq!(fmuladdf64(-3.2, 2.4, neg_inf), neg_inf); + } +}