Skip to content

Commit

Permalink
Add dist err functions
Browse files Browse the repository at this point in the history
  • Loading branch information
KmolYuan committed May 23, 2024
1 parent 4393227 commit f7f8c5a
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "efd"
description = "1D/2D/3D Elliptical Fourier Descriptor (EFD) implementation in Rust."
version = "10.0.2"
version = "10.1.0"
authors = ["KmolYuan <[email protected]>"]
edition = "2021"
license = "MIT"
Expand Down
16 changes: 5 additions & 11 deletions src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,14 @@ use crate::*;
use alloc::{vec, vec::Vec};
#[cfg(test)]
use approx::assert_abs_diff_eq;
#[cfg(test)]
use core::iter::zip;
pub use util::dist_err as curve_diff;

/// Epsilon for curve difference.
pub const EPS: f64 = 2.2e-14;
pub const RES: usize = 1000;

/// Error between two curves, the length of the curves must be the same.
pub fn curve_diff<const D: usize>(a: impl Curve<D>, b: impl Curve<D>) -> f64 {
zip(a.as_curve(), b.as_curve())
.map(|(a, b)| a.l2_err(b))
.sum::<f64>()
/ a.len().min(b.len()) as f64
}

#[test]
fn error() {
let coeff = vec![
Expand Down Expand Up @@ -62,7 +56,7 @@ fn efd2d() {
// Test reconstruction
let target = efd.recon_norm_by(&get_norm_t(CURVE2D, false));
let curve = efd.as_geo().inverse().transform(CURVE2D);
assert_abs_diff_eq!(curve_diff(target, curve), 0., epsilon = 0.01695);
assert_abs_diff_eq!(curve_diff(target, curve), 0., epsilon = 0.0029);
}

#[test]
Expand All @@ -86,7 +80,7 @@ fn efd2d_open() {
// Test reconstruction
let target = efd.recon_norm_by(&get_norm_t(CURVE2D_OPEN, true));
let curve = efd.as_geo().inverse().transform(CURVE2D_OPEN);
assert_abs_diff_eq!(curve_diff(target, curve), 0., epsilon = 0.0143);
assert_abs_diff_eq!(curve_diff(target, curve), 0., epsilon = 0.0411);
}

#[test]
Expand Down Expand Up @@ -122,7 +116,7 @@ fn efd3d() {
// Test reconstruction
let target = efd.recon_norm_by(&get_norm_t(CURVE3D, false));
let curve = efd.as_geo().inverse().transform(CURVE3D);
assert_abs_diff_eq!(curve_diff(target, curve), 0., epsilon = 0.00412);
assert_abs_diff_eq!(curve_diff(target, curve), 0., epsilon = 0.0013);
}

#[test]
Expand Down
80 changes: 79 additions & 1 deletion src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,88 @@ where
}

/// Check if the curve is valid.
///
/// See also [`is_valid_curve()`].
pub fn valid_curve<C, const D: usize>(curve: C) -> Option<C>
where
C: Curve<D>,
{
is_valid_curve(curve.as_curve()).then_some(curve)
}

/// Return true if the curve is valid.
#[inline]
pub fn is_valid_curve<C, const D: usize>(curve: C) -> bool
where
C: Curve<D>,
{
let c = curve.as_curve();
(c.len() > 2 && c.iter().flatten().all(|x| x.is_finite())).then_some(curve)
c.len() > 2 && c.iter().flatten().all(|x| x.is_finite())
}

/// Return the zipped average distance error between two curves.
///
/// Returns 0 if either curve is empty.
///
/// See also [`dist_err()`] for a more general case where the curves are not
/// corresponded. (`curve1[i] !== curve2[i]`)
pub fn dist_err_zipped<const D: usize>(curve1: impl Curve<D>, curve2: impl Curve<D>) -> f64 {
let len = curve1.len().min(curve2.len());
if len == 0 {
0.
} else {
core::iter::zip(curve1.as_curve(), curve2.as_curve())
.map(|(a, b)| a.l2_err(b))
.sum::<f64>()
/ len as f64
}
}

/// Return the average distance error between two curves.
///
/// In this algorithm, a curve is assumed to be longer or equal to another, and
/// the distance error is mapped to the nearest point in the shorter curve.
///
/// Returns 0 if either curve is empty.
///
/// See also [`dist_err_zipped()`] for faster computation if the curve points
/// are corresponded.
pub fn dist_err<const D: usize>(curve1: impl Curve<D>, curve2: impl Curve<D>) -> f64 {
if curve1.is_empty() || curve2.is_empty() {
return 0.;
}
let (mut iter1, iter2, len) = {
let iter1 = curve1.as_curve().iter();
let iter2 = curve2.as_curve().iter();
if curve1.len() >= curve2.len() {
(iter1, iter2, curve2.len())
} else {
(iter2, iter1, curve1.len())
}
};
let mut total = 0.;
let mut prev_err = None;
let mut last_pt1 = None; // a pointer indicating the last point of curve1
for pt2 in iter2 {
loop {
if let Some(pt1) = iter1.next() {
last_pt1 = Some(pt1);
let err = pt1.l2_err(pt2);
match prev_err {
// The previous error is the nearest
Some(prev_err) if err > prev_err => {
total += prev_err;
break;
}
// The previous error is not the nearest or unset
_ => prev_err = Some(err),
}
} else {
// curve1 is exhausted, compare the last point
total += last_pt1.unwrap().l2_err(pt2);
break;
}
}
}
total / len as f64
}

0 comments on commit f7f8c5a

Please sign in to comment.