Skip to content

Commit

Permalink
Introduce a generic way to control checks for specific cases
Browse files Browse the repository at this point in the history
Sometimes we want to be able to xfail specific inputs without changing
the checked ULP for all cases or skipping the tests. There are also some
cases where we need to perform extra checks for only specific functions.

Add a trait that provides a hook for providing extra checks or skipping
existing checks on a per-function or per-input basis.
  • Loading branch information
tgross35 committed Oct 28, 2024
1 parent e74f324 commit 364e13e
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 8 deletions.
23 changes: 23 additions & 0 deletions crates/libm-test/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
pub mod gen;
mod num_traits;
mod special_case;
mod test_traits;

pub use num_traits::{Float, Hex, Int};
pub use special_case::{MaybeOverride, SpecialCase};
pub use test_traits::{CheckBasis, CheckCtx, CheckOutput, GenerateInput, TupleCall};

/// Result type for tests is usually from `anyhow`. Most times there is no success value to
Expand All @@ -11,3 +13,24 @@ pub type TestResult<T = (), E = anyhow::Error> = Result<T, E>;

// List of all files present in libm's source
include!(concat!(env!("OUT_DIR"), "/all_files.rs"));

/// Return the unsuffixed version of a function name; e.g. `abs` and `absf` both return `abs`,
/// `lgamma_r` and `lgammaf_r` both return `lgamma_r`.
pub fn canonical_name(name: &str) -> &str {
let known_mappings = &[
("erff", "erf"),
("erf", "erf"),
("lgammaf_r", "lgamma_r"),
("modff", "modf"),
("modf", "modf"),
];

match known_mappings.iter().find(|known| known.0 == name) {
Some(found) => found.1,
None => name
.strip_suffix("f")
.or_else(|| name.strip_suffix("f16"))
.or_else(|| name.strip_suffix("f128"))
.unwrap_or(name),
}
}
14 changes: 11 additions & 3 deletions crates/libm-test/src/num_traits.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::fmt;

use crate::TestResult;
use crate::{MaybeOverride, SpecialCase, TestResult};

/// Common types and methods for floating point numbers.
pub trait Float: Copy + fmt::Display + fmt::Debug + PartialEq<Self> {
Expand Down Expand Up @@ -137,13 +137,21 @@ macro_rules! impl_int {
}
}

impl<Input: Hex + fmt::Debug> $crate::CheckOutput<Input> for $ty {
impl<Input> $crate::CheckOutput<Input> for $ty
where
Input: Hex + fmt::Debug,
SpecialCase: MaybeOverride<Input>,
{
fn validate<'a>(
self,
expected: Self,
input: Input,
_ctx: &$crate::CheckCtx,
ctx: &$crate::CheckCtx,
) -> TestResult {
if let Some(res) = SpecialCase::check_int(input, self, expected, ctx) {
return res;
}

anyhow::ensure!(
self == expected,
"\
Expand Down
95 changes: 95 additions & 0 deletions crates/libm-test/src/special_case.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
//! Configuration for skipping or changing the result for individual test cases (inputs) rather
//! than ignoring entire tests.
use crate::{CheckCtx, Float, Int, TestResult};

/// Type implementing [`IgnoreCase`].
pub struct SpecialCase;

/// Don't run further validation on this test case.
const SKIP: Option<TestResult> = Some(Ok(()));

/// Return this to skip checks on a test that currently fails but shouldn't. Looks
/// the same as skip, but we keep them separate to better indicate purpose.
const XFAIL: Option<TestResult> = Some(Ok(()));

/// Allow overriding the outputs of specific test cases.
///
/// There are some cases where we want to xfail specific cases or handle certain inputs
/// differently than the rest of calls to `validate`. This provides a hook to do that.
///
/// If `None` is returned, checks will proceed as usual. If `Some(result)` is returned, checks
/// are skipped and the provided result is returned instead.
///
/// This gets implemented once per input type, then the functions provide further filtering
/// based on function name and values.
///
/// `ulp` can also be set to adjust the ULP for that specific test, even if `None` is still
/// returned.
pub trait MaybeOverride<Input> {
fn check_float<F: Float>(
_input: Input,
_actual: F,
_expected: F,
_ulp: &mut u32,
_ctx: &CheckCtx,
) -> Option<TestResult> {
None
}

fn check_int<I: Int>(
_input: Input,
_actual: I,
_expected: I,
_ctx: &CheckCtx,
) -> Option<TestResult> {
None
}
}

impl MaybeOverride<(f32,)> for SpecialCase {
fn check_float<F: Float>(
_input: (f32,),
actual: F,
expected: F,
_ulp: &mut u32,
ctx: &CheckCtx,
) -> Option<TestResult> {
maybe_check_nan_bits(actual, expected, ctx)
}
}

impl MaybeOverride<(f64,)> for SpecialCase {
fn check_float<F: Float>(
_input: (f64,),
actual: F,
expected: F,
_ulp: &mut u32,
ctx: &CheckCtx,
) -> Option<TestResult> {
maybe_check_nan_bits(actual, expected, ctx)
}
}

impl MaybeOverride<(f32, f32)> for SpecialCase {}
impl MaybeOverride<(f64, f64)> for SpecialCase {}
impl MaybeOverride<(f32, f32, f32)> for SpecialCase {}
impl MaybeOverride<(f64, f64, f64)> for SpecialCase {}
impl MaybeOverride<(i32, f32)> for SpecialCase {}
impl MaybeOverride<(i32, f64)> for SpecialCase {}
impl MaybeOverride<(f32, i32)> for SpecialCase {}
impl MaybeOverride<(f64, i32)> for SpecialCase {}

/// Check NaN bits if the function requires it
fn maybe_check_nan_bits<F: Float>(actual: F, expected: F, ctx: &CheckCtx) -> Option<TestResult> {
if !(ctx.canonical_name == "abs" || ctx.canonical_name == "copysigh") {
return None;
}

// abs and copysign require signaling NaNs to be propagated, so verify bit equality.
if actual.to_bits() == expected.to_bits() {
return SKIP;
} else {
Some(Err(anyhow::anyhow!("NaNs have different bitpatterns")))
}
}
39 changes: 34 additions & 5 deletions crates/libm-test/src/test_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::fmt;

use anyhow::{Context, bail, ensure};

use crate::{Float, Hex, Int, TestResult};
use crate::{Float, Hex, Int, MaybeOverride, SpecialCase, TestResult};

/// Implement this on types that can generate a sequence of tuples for test input.
pub trait GenerateInput<TupleArgs> {
Expand All @@ -34,10 +34,19 @@ pub struct CheckCtx {
pub ulp: u32,
/// Function name.
pub fname: &'static str,
/// Return the unsuffixed version of the function name.
pub canonical_name: &'static str,
/// Source of truth for tests.
pub basis: CheckBasis,
}

impl CheckCtx {
pub fn new(ulp: u32, fname: &'static str, basis: CheckBasis) -> Self {
let canonical_fname = crate::canonical_name(fname);
Self { ulp, fname, canonical_name: canonical_fname, basis }
}
}

/// Possible items to test against
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum CheckBasis {}
Expand Down Expand Up @@ -135,10 +144,20 @@ where
F: Float + Hex,
Input: Hex + fmt::Debug,
u32: TryFrom<F::SignedInt, Error: fmt::Debug>,
SpecialCase: MaybeOverride<Input>,
{
fn validate<'a>(self, expected: Self, input: Input, ctx: &CheckCtx) -> TestResult {
// Create a wrapper function so we only need to `.with_context` once.
let inner = || -> TestResult {
let mut allowed_ulp = ctx.ulp;

// If the tested function requires a nonstandard test, run it here.
if let Some(res) =
SpecialCase::check_float(input, self, expected, &mut allowed_ulp, ctx)
{
return res;
}

// Check when both are NaNs
if self.is_nan() && expected.is_nan() {
ensure!(self.to_bits() == expected.to_bits(), "NaNs have different bitpatterns");
Expand Down Expand Up @@ -166,7 +185,6 @@ where
let ulp_u32 = u32::try_from(ulp_diff)
.map_err(|e| anyhow::anyhow!("{e:?}: ulp of {ulp_diff} exceeds u32::MAX"))?;

let allowed_ulp = ctx.ulp;
ensure!(ulp_u32 <= allowed_ulp, "ulp {ulp_diff} > {allowed_ulp}",);

Ok(())
Expand All @@ -191,17 +209,28 @@ where
macro_rules! impl_tuples {
($(($a:ty, $b:ty);)*) => {
$(
impl<Input: Hex + fmt::Debug> CheckOutput<Input> for ($a, $b) {
impl<Input> CheckOutput<Input> for ($a, $b)
where
Input: Hex + fmt::Debug,
SpecialCase: MaybeOverride<Input>,
{
fn validate<'a>(
self,
expected: Self,
input: Input,
ctx: &CheckCtx,
) -> TestResult {
self.0.validate(expected.0, input, ctx,)
self.0.validate(expected.0, input, ctx)
.and_then(|()| self.1.validate(expected.1, input, ctx))
.with_context(|| format!(
"full input {input:?} full actual {self:?} expected {expected:?}"
"full context:\
\n input: {input:?} {ibits}\
\n expected: {expected:?} {expbits}\
\n actual: {self:?} {actbits}\
",
actbits = self.hex(),
expbits = expected.hex(),
ibits = input.hex(),
))
}
}
Expand Down

0 comments on commit 364e13e

Please sign in to comment.