Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions math/core/sources/internal/common.move
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,115 @@ public(package) fun clz(val: u256, bit_width: u16): u16 {

count
}

/// Returns the square root of a number. If the number is not a perfect square, the value is rounded
/// towards zero.
///
/// This method is based on Newton's method for computing square roots. The algorithm is restricted to only
/// using integer operations.
public(package) fun sqrt_floor(a: u256): u256 {
// Take care of easy edge cases: sqrt(0) = 0 and sqrt(1) = 1
if (a <= 1) {
return a
};
let mut aa = a;
let mut xn = 1;

// In this function, we use Newton's method to get a root of `f(x) := x² - a`. It involves building a
// sequence x_n that converges toward sqrt(a). For each iteration x_n, we also define the error between
// the current value as `ε_n = | x_n - sqrt(a) |`.
//
// For our first estimation, we consider `e` the smallest power of 2 which is bigger than the square root
// of the target. (i.e. `2**(e-1) ≤ sqrt(a) < 2**e`). We know that `e ≤ 128` because `(2¹²⁸)² = 2²⁵⁶` is
// bigger than any uint256.
//
// By noticing that
// `2**(e-1) ≤ sqrt(a) < 2**e → (2**(e-1))² ≤ a < (2**e)² → 2**(2*e-2) ≤ a < 2**(2*e)`
// we can deduce that `e - 1` is `log2(a) / 2`. We can thus compute `x_n = 2**(e-1)` using a method similar
// to the msb function.
if (aa >= (1 << 128)) {
aa = aa >> 128;
xn = xn << 64;
};
if (aa >= (1 << 64)) {
aa = aa >> 64;
xn = xn << 32;
};
if (aa >= (1 << 32)) {
aa = aa >> 32;
xn = xn << 16;
};
if (aa >= (1 << 16)) {
aa = aa >> 16;
xn = xn << 8;
};
if (aa >= (1 << 8)) {
aa = aa >> 8;
xn = xn << 4;
};
if (aa >= (1 << 4)) {
aa = aa >> 4;
xn = xn << 2;
};
if (aa >= (1 << 2)) {
xn = xn << 1;
};

// We now have x_n such that `x_n = 2**(e-1) ≤ sqrt(a) < 2**e = 2 * x_n`. This implies ε_n ≤ 2**(e-1).
//
// We can refine our estimation by noticing that the middle of that interval minimizes the error.
// If we move x_n to equal 2**(e-1) + 2**(e-2), then we reduce the error to ε_n ≤ 2**(e-2).
// This is going to be our x_0 (and ε_0).
xn = (3 * xn) >> 1; // ε_0 := | x_0 - sqrt(a) | ≤ 2**(e-2)

// From here, Newton's method give us:
// x_{n+1} = (x_n + a / x_n) / 2
//
// One should note that:
// x_{n+1}² - a = ((x_n + a / x_n) / 2)² - a
// = ((x_n² + a) / (2 * x_n))² - a
// = (x_n⁴ + 2 * a * x_n² + a²) / (4 * x_n²) - a
// = (x_n⁴ + 2 * a * x_n² + a² - 4 * a * x_n²) / (4 * x_n²)
// = (x_n⁴ - 2 * a * x_n² + a²) / (4 * x_n²)
// = (x_n² - a)² / (2 * x_n)²
// = ((x_n² - a) / (2 * x_n))²
// ≥ 0
// Which proves that for all n ≥ 1, sqrt(a) ≤ x_n
//
// This gives us the proof of quadratic convergence of the sequence:
// ε_{n+1} = | x_{n+1} - sqrt(a) |
// = | (x_n + a / x_n) / 2 - sqrt(a) |
// = | (x_n² + a - 2*x_n*sqrt(a)) / (2 * x_n) |
// = | (x_n - sqrt(a))² / (2 * x_n) |
// = | ε_n² / (2 * x_n) |
// = ε_n² / | (2 * x_n) |
//
// For the first iteration, we have a special case where x_0 is known:
// ε_1 = ε_0² / | (2 * x_0) |
// ≤ (2**(e-2))² / (2 * (2**(e-1) + 2**(e-2)))
// ≤ 2**(2*e-4) / (3 * 2**(e-1))
// ≤ 2**(e-3) / 3
// ≤ 2**(e-3-log2(3))
// ≤ 2**(e-4.5)
//
// For the following iterations, we use the fact that, 2**(e-1) ≤ sqrt(a) ≤ x_n:
// ε_{n+1} = ε_n² / | (2 * x_n) |
// ≤ (2**(e-k))² / (2 * 2**(e-1))
// ≤ 2**(2*e-2*k) / 2**e
// ≤ 2**(e-2*k)
xn = (xn + a / xn) >> 1; // ε_1 := | x_1 - sqrt(a) | ≤ 2**(e-4.5) -- special case, see above
xn = (xn + a / xn) >> 1; // ε_2 := | x_2 - sqrt(a) | ≤ 2**(e-9) -- general case with k = 4.5
xn = (xn + a / xn) >> 1; // ε_3 := | x_3 - sqrt(a) | ≤ 2**(e-18) -- general case with k = 9
xn = (xn + a / xn) >> 1; // ε_4 := | x_4 - sqrt(a) | ≤ 2**(e-36) -- general case with k = 18
xn = (xn + a / xn) >> 1; // ε_5 := | x_5 - sqrt(a) | ≤ 2**(e-72) -- general case with k = 36
xn = (xn + a / xn) >> 1; // ε_6 := | x_6 - sqrt(a) | ≤ 2**(e-144) -- general case with k = 72

// Because e ≤ 128 (as discussed during the first estimation phase), we know have reached a precision
// ε_6 ≤ 2**(e-144) < 1. Given we're operating on integers, then we can ensure that xn is now either
// sqrt(a) or sqrt(a) + 1.
if (xn > a / xn) {
xn - 1
} else {
xn
}
}
73 changes: 73 additions & 0 deletions math/core/sources/internal/macros.move
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,28 @@ public(package) fun mul_shr_u256_wide(
(false, result)
}

/// Compute the square root of an unsigned integer with configurable rounding.
///
/// This macro provides a uniform API for `sqrt` across all unsigned integer widths. It normalises
/// the input to `u256`, calculates the integer square root, and applies the requested rounding mode.
/// The algorithm uses a binary search to find the floor of the square root, then determines whether
/// to round up based on the rounding mode.
///
/// #### Generics
/// - `$Int`: Any unsigned integer type (`u8`, `u16`, `u32`, `u64`, `u128`, or `u256`).
///
/// #### Parameters
/// - `$value`: The unsigned integer to calculate the square root of.
/// - `$rounding_mode`: Rounding strategy drawn from `rounding::RoundingMode`.
///
/// #### Returns
/// The square root of `$value` rounded according to `$rounding_mode`, cast back to `$Int`.
public(package) macro fun sqrt<$Int>($value: $Int, $rounding_mode: RoundingMode): $Int {
let (value, rounding_mode) = ($value as u256, $rounding_mode);
let floor_res = common::sqrt_floor(value);
round_sqrt_result(value, floor_res, rounding_mode) as $Int
}

/// Determine whether rounding up is required after dividing and apply it to `result`.
/// Returns `(overflow, result)` where `overflow` is `true` if the rounded value cannot be represented as `u256`.
public(package) fun round_division_result(
Expand Down Expand Up @@ -586,3 +608,54 @@ public(package) fun log256_should_round_up(value: u256, floor_log: u16): bool {
let threshold = 1 << (threshold_exp as u8);
value >= threshold
}

/// Apply rounding mode to the floor result of a square root calculation.
///
/// For nearest rounding, compares the distance from `value` to `floor²` versus the distance
/// to `ceil²` where `ceil = floor + 1`.
///
/// Given:
/// - `distance_to_floor = value - floor²`
/// - `distance_to_ceil = (floor + 1)² - value = floor² + 2·floor + 1 - value`
///
/// We want to round down if `distance_to_floor < distance_to_ceil`, which expands to:
/// ```
/// value - floor² < floor² + 2·floor + 1 - value
/// 2·value < 2·floor² + 2·floor + 1
/// 2·(value - floor²) < 2·floor + 1
/// ```
///
/// Since we're working with integers, dividing both sides by 2 gives us:
/// `value - floor² <= floor`
///
/// Considering that `sqrt(u256::MAX)` < `2^128`, all arithmetic operations in the function
/// are guaranteed to not overflow or underflow.
///
/// #### Parameters
/// - `value`: The original value whose square root was calculated.
/// - `floor_result`: The floor of the square root.
/// - `rounding_mode`: Rounding strategy drawn from `rounding::RoundingMode`.
///
/// #### Returns
/// The square root rounded according to the specified mode.
public(package) fun round_sqrt_result(
value: u256,
floor_result: u256,
rounding_mode: RoundingMode,
): u256 {
if (rounding_mode == rounding::down()) {
return floor_result
};

let floor_squared = floor_result * floor_result;
if (floor_squared == value) {
// Perfect square, no rounding needed
floor_result
} else if (rounding_mode == rounding::up()) {
floor_result + 1
} else if (value - floor_squared <= floor_result) {
floor_result
} else {
floor_result + 1
}
}
7 changes: 7 additions & 0 deletions math/core/sources/u128.move
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,10 @@ public fun log2(value: u128, rounding_mode: RoundingMode): u8 {
public fun log256(value: u128, rounding_mode: RoundingMode): u8 {
macros::log256!(value, BIT_WIDTH as u16, rounding_mode)
}

/// Compute the square root of a value with configurable rounding.
///
/// Returns 0 if given 0.
public fun sqrt(value: u128, rounding_mode: RoundingMode): u128 {
macros::sqrt!(value, rounding_mode)
}
7 changes: 7 additions & 0 deletions math/core/sources/u16.move
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,10 @@ public fun log2(value: u16, rounding_mode: RoundingMode): u8 {
public fun log256(value: u16, rounding_mode: RoundingMode): u8 {
macros::log256!(value, BIT_WIDTH as u16, rounding_mode)
}

/// Compute the square root of a value with configurable rounding.
///
/// Returns 0 if given 0.
public fun sqrt(value: u16, rounding_mode: RoundingMode): u16 {
macros::sqrt!(value, rounding_mode)
}
7 changes: 7 additions & 0 deletions math/core/sources/u256.move
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,10 @@ public fun log2(value: u256, rounding_mode: RoundingMode): u16 {
public fun log256(value: u256, rounding_mode: RoundingMode): u8 {
macros::log256!(value, BIT_WIDTH, rounding_mode)
}

/// Compute the square root of a value with configurable rounding.
///
/// Returns 0 if given 0.
public fun sqrt(value: u256, rounding_mode: RoundingMode): u256 {
macros::sqrt!(value, rounding_mode)
}
7 changes: 7 additions & 0 deletions math/core/sources/u32.move
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,10 @@ public fun log2(value: u32, rounding_mode: RoundingMode): u8 {
public fun log256(value: u32, rounding_mode: RoundingMode): u8 {
macros::log256!(value, BIT_WIDTH as u16, rounding_mode)
}

/// Compute the square root of a value with configurable rounding.
///
/// Returns 0 if given 0.
public fun sqrt(value: u32, rounding_mode: RoundingMode): u32 {
macros::sqrt!(value, rounding_mode)
}
7 changes: 7 additions & 0 deletions math/core/sources/u64.move
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,10 @@ public fun log2(value: u64, rounding_mode: RoundingMode): u8 {
public fun log256(value: u64, rounding_mode: RoundingMode): u8 {
macros::log256!(value, BIT_WIDTH as u16, rounding_mode)
}

/// Compute the square root of a value with configurable rounding.
///
/// Returns 0 if given 0.
public fun sqrt(value: u64, rounding_mode: RoundingMode): u64 {
macros::sqrt!(value, rounding_mode)
}
7 changes: 7 additions & 0 deletions math/core/sources/u8.move
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,10 @@ public fun log2(value: u8, rounding_mode: RoundingMode): u8 {
public fun log256(value: u8, rounding_mode: RoundingMode): u8 {
macros::log256!(value, BIT_WIDTH as u16, rounding_mode)
}

/// Compute the square root of a value with configurable rounding.
///
/// Returns 0 if given 0.
public fun sqrt(value: u8, rounding_mode: RoundingMode): u8 {
macros::sqrt!(value, rounding_mode)
}
106 changes: 106 additions & 0 deletions math/core/tests/common_tests.move
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,109 @@ fun clz_handles_u128_values() {
let mid_bit: u128 = 1u128 << 40;
assert_eq!(common::clz(mid_bit as u256, 128), 87);
}

// === sqrt ===

#[test]
fun sqrt_returns_zero_for_zero() {
assert_eq!(common::sqrt_floor(0), 0);
}

#[test]
fun sqrt_handles_perfect_squares() {
assert_eq!(common::sqrt_floor(4), 2);
assert_eq!(common::sqrt_floor(9), 3);
assert_eq!(common::sqrt_floor(16), 4);
assert_eq!(common::sqrt_floor(25), 5);
assert_eq!(common::sqrt_floor(100), 10);
assert_eq!(common::sqrt_floor(256), 16);
assert_eq!(common::sqrt_floor(65536), 256);
assert_eq!(common::sqrt_floor(1 << 64), 1 << 32);
assert_eq!(common::sqrt_floor(1 << 128), 1 << 64);
}

#[test]
fun sqrt_floors_non_perfect_squares() {
assert_eq!(common::sqrt_floor(2), 1); // 1.414... → 1
assert_eq!(common::sqrt_floor(3), 1); // 1.732... → 1
assert_eq!(common::sqrt_floor(5), 2); // 2.236... → 2
assert_eq!(common::sqrt_floor(8), 2); // 2.828... → 2
assert_eq!(common::sqrt_floor(15), 3); // 3.873... → 3
assert_eq!(common::sqrt_floor(99), 9); // 9.950... → 9
assert_eq!(common::sqrt_floor(255), 15); // 15.969... → 15
}

#[test]
fun sqrt_handles_u8_values() {
let zero: u8 = 0;
assert_eq!(common::sqrt_floor(zero as u256), 0);

let one: u8 = 1;
assert_eq!(common::sqrt_floor(one as u256), 1);

let four: u8 = 4;
assert_eq!(common::sqrt_floor(four as u256), 2);

let max: u8 = std::u8::max_value!();
assert_eq!(common::sqrt_floor(max as u256), 15);
}

#[test]
fun sqrt_handles_u16_values() {
let perfect: u16 = 256;
assert_eq!(common::sqrt_floor(perfect as u256), 16);

let non_perfect: u16 = 1000;
assert_eq!(common::sqrt_floor(non_perfect as u256), 31);

let max: u16 = std::u16::max_value!();
assert_eq!(common::sqrt_floor(max as u256), 255);
}

#[test]
fun sqrt_handles_u32_values() {
let perfect: u32 = 65536;
assert_eq!(common::sqrt_floor(perfect as u256), 256);

let non_perfect: u32 = 1000000;
assert_eq!(common::sqrt_floor(non_perfect as u256), 1000);

let max: u32 = std::u32::max_value!();
assert_eq!(common::sqrt_floor(max as u256), 65535);
}

#[test]
fun sqrt_handles_u64_values() {
let perfect: u64 = 1 << 32;
assert_eq!(common::sqrt_floor(perfect as u256), 1 << 16);

let non_perfect: u64 = 100000000;
assert_eq!(common::sqrt_floor(non_perfect as u256), 10000);

let max: u64 = std::u64::max_value!();
assert_eq!(common::sqrt_floor(max as u256), 4294967295);
}

#[test]
fun sqrt_handles_u128_values() {
let perfect: u128 = 1 << 64;
assert_eq!(common::sqrt_floor(perfect as u256), 1 << 32);

let large: u128 = 1000000000000000000;
assert_eq!(common::sqrt_floor(large as u256), 1000000000);

let max: u128 = std::u128::max_value!();
assert_eq!(common::sqrt_floor(max as u256), std::u64::max_value!() as u256);
}

#[test]
fun sqrt_handles_u256_values() {
let perfect: u256 = 1 << 128;
assert_eq!(common::sqrt_floor(perfect), 1 << 64);

let large: u256 = 1 << 200;
assert_eq!(common::sqrt_floor(large), 1 << 100);

let max: u256 = std::u256::max_value!();
assert_eq!(common::sqrt_floor(max), std::u128::max_value!() as u256);
}
Loading
Loading