diff --git a/src/math/Interpolation.cpp b/src/math/Interpolation.cpp index 290f55f415..a6cfff0a04 100644 --- a/src/math/Interpolation.cpp +++ b/src/math/Interpolation.cpp @@ -98,6 +98,8 @@ double interpolateRecursive( double v2 = interpolateRecursive(queryPoint, points, nextLowerBounds, nextUpperBounds, dimension + 1); + // TODO: Handle cases where we clamp one and not the other + // Distance should be based on the clamped value, not the true query point return v1 * (1 - t) + v2 * t; } @@ -115,8 +117,12 @@ double interpolate(const std::vector &queryPoint, points.numDimensions); for (size_t i = 0; i < points.numDimensions; ++i) { - clampedQueryPoint[i] = clamp(queryPoint[i], points.uniqueValues[i].front(), - points.uniqueValues[i].back()); + double clampMax = + std::max(points.uniqueValues[i].front(), points.uniqueValues[i].back()); + double clampMin = + std::min(points.uniqueValues[i].front(), points.uniqueValues[i].back()); + + clampedQueryPoint[i] = clamp(queryPoint[i], clampMin, clampMax); lowerBounds[i] = std::lower_bound(points.uniqueValues[i].begin(), @@ -128,6 +134,7 @@ double interpolate(const std::vector &queryPoint, upperBounds[i] = std::next(lowerBounds[i]); if (upperBounds[i] == points.uniqueValues[i].end()) { + upperBounds[i] = lowerBounds[i]; } } diff --git a/tests/unit_tests/InterpolationTest.h b/tests/unit_tests/InterpolationTest.h index 2b4579a1c1..6acdf3ed3e 100644 --- a/tests/unit_tests/InterpolationTest.h +++ b/tests/unit_tests/InterpolationTest.h @@ -309,4 +309,52 @@ class InterpolationTest : public CxxTest::TestSuite { std::cout << "Finished testOutOfBoundsInterpolation" << std::endl; std::cout << "#########################################\n" << std::endl; } + + void testClampingIssue() { + std::cout << "\n#########################################" << std::endl; + std::cout << "Starting testClampingIssue\n" << std::endl; + + // Create a 4D grid with more points per dimension + PointCloud points; + points.numDimensions = 4; + points.uniqueValues = { + {0.0, 0.5, 1.0}, // x-values + {0.0, 0.5, 1.0}, // y-values + {0.0, 0.5, 1.0}, // z-values + {-1.0, -0.5, 0.0} // zz-values + }; + + // Define a non-linear function: f(x, y, z) = x * y + y * z + z * x + auto nonLinearFunction = [](const std::vector &x) { + return x[0] * x[1] + x[1] * x[2] + x[2] * x[0] + x[3]; + }; + + // Populate the point map with function values + for (double x : points.uniqueValues[0]) { + for (double y : points.uniqueValues[1]) { + for (double z : points.uniqueValues[2]) { + for (double zz : points.uniqueValues[3]) { + std::vector point = {x, y, z, zz}; + points.pointMap[point] = nonLinearFunction(point); + } + } + } + } + + // Choose a query point that requires interpolation + std::vector queryPoint = {0.25, 0.75, 0.5, -0.25}; + + // Expected value calculated directly + double expectedValue = nonLinearFunction(queryPoint); + + // Perform interpolation + double interpolatedValue = interpolate(queryPoint, points); + + // Use a reasonable tolerance + double tolerance = 1e-4; + TS_ASSERT_DELTA(interpolatedValue, expectedValue, tolerance); + + std::cout << "\nFinished testClampingIssue" << std::endl; + std::cout << "#########################################\n" << std::endl; + } };