Skip to content

Commit 2ed3dec

Browse files
authored
feat: sqrt implementation and tests (#53)
* Implement core logic of sqrt function * Add sqrt macro and helper to round result * Add sqrt functions for all uint modules * Add sqrt tests * Format and lint files * Fix typo * Format files
1 parent b548509 commit 2ed3dec

File tree

16 files changed

+1154
-0
lines changed

16 files changed

+1154
-0
lines changed

math/core/sources/internal/common.move

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,115 @@ public(package) fun msb(val: u256, bit_width: u16): u8 {
4343
// clz result for non-zero is guaranteed to be less than bit_width, so the subtraction is safe
4444
(bit_width - 1 - clz(val, bit_width)) as u8
4545
}
46+
47+
/// Returns the square root of a number. If the number is not a perfect square, the value is rounded
48+
/// towards zero.
49+
///
50+
/// This method is based on Newton's method for computing square roots. The algorithm is restricted to only
51+
/// using integer operations.
52+
public(package) fun sqrt_floor(a: u256): u256 {
53+
// Take care of easy edge cases: sqrt(0) = 0 and sqrt(1) = 1
54+
if (a <= 1) {
55+
return a
56+
};
57+
let mut aa = a;
58+
let mut xn = 1;
59+
60+
// In this function, we use Newton's method to get a root of `f(x) := x² - a`. It involves building a
61+
// sequence x_n that converges toward sqrt(a). For each iteration x_n, we also define the error between
62+
// the current value as `ε_n = | x_n - sqrt(a) |`.
63+
//
64+
// For our first estimation, we consider `e` the smallest power of 2 which is bigger than the square root
65+
// of the target. (i.e. `2**(e-1) ≤ sqrt(a) < 2**e`). We know that `e ≤ 128` because `(2¹²⁸)² = 2²⁵⁶` is
66+
// bigger than any uint256.
67+
//
68+
// By noticing that
69+
// `2**(e-1) ≤ sqrt(a) < 2**e → (2**(e-1))² ≤ a < (2**e)² → 2**(2*e-2) ≤ a < 2**(2*e)`
70+
// we can deduce that `e - 1` is `log2(a) / 2`. We can thus compute `x_n = 2**(e-1)` using a method similar
71+
// to the msb function.
72+
if (aa >= (1 << 128)) {
73+
aa = aa >> 128;
74+
xn = xn << 64;
75+
};
76+
if (aa >= (1 << 64)) {
77+
aa = aa >> 64;
78+
xn = xn << 32;
79+
};
80+
if (aa >= (1 << 32)) {
81+
aa = aa >> 32;
82+
xn = xn << 16;
83+
};
84+
if (aa >= (1 << 16)) {
85+
aa = aa >> 16;
86+
xn = xn << 8;
87+
};
88+
if (aa >= (1 << 8)) {
89+
aa = aa >> 8;
90+
xn = xn << 4;
91+
};
92+
if (aa >= (1 << 4)) {
93+
aa = aa >> 4;
94+
xn = xn << 2;
95+
};
96+
if (aa >= (1 << 2)) {
97+
xn = xn << 1;
98+
};
99+
100+
// We now have x_n such that `x_n = 2**(e-1) ≤ sqrt(a) < 2**e = 2 * x_n`. This implies ε_n ≤ 2**(e-1).
101+
//
102+
// We can refine our estimation by noticing that the middle of that interval minimizes the error.
103+
// If we move x_n to equal 2**(e-1) + 2**(e-2), then we reduce the error to ε_n ≤ 2**(e-2).
104+
// This is going to be our x_0 (and ε_0).
105+
xn = (3 * xn) >> 1; // ε_0 := | x_0 - sqrt(a) | ≤ 2**(e-2)
106+
107+
// From here, Newton's method give us:
108+
// x_{n+1} = (x_n + a / x_n) / 2
109+
//
110+
// One should note that:
111+
// x_{n+1}² - a = ((x_n + a / x_n) / 2)² - a
112+
// = ((x_n² + a) / (2 * x_n))² - a
113+
// = (x_n⁴ + 2 * a * x_n² + a²) / (4 * x_n²) - a
114+
// = (x_n⁴ + 2 * a * x_n² + a² - 4 * a * x_n²) / (4 * x_n²)
115+
// = (x_n⁴ - 2 * a * x_n² + a²) / (4 * x_n²)
116+
// = (x_n² - a)² / (2 * x_n)²
117+
// = ((x_n² - a) / (2 * x_n))²
118+
// ≥ 0
119+
// Which proves that for all n ≥ 1, sqrt(a) ≤ x_n
120+
//
121+
// This gives us the proof of quadratic convergence of the sequence:
122+
// ε_{n+1} = | x_{n+1} - sqrt(a) |
123+
// = | (x_n + a / x_n) / 2 - sqrt(a) |
124+
// = | (x_n² + a - 2*x_n*sqrt(a)) / (2 * x_n) |
125+
// = | (x_n - sqrt(a))² / (2 * x_n) |
126+
// = | ε_n² / (2 * x_n) |
127+
// = ε_n² / | (2 * x_n) |
128+
//
129+
// For the first iteration, we have a special case where x_0 is known:
130+
// ε_1 = ε_0² / | (2 * x_0) |
131+
// ≤ (2**(e-2))² / (2 * (2**(e-1) + 2**(e-2)))
132+
// ≤ 2**(2*e-4) / (3 * 2**(e-1))
133+
// ≤ 2**(e-3) / 3
134+
// ≤ 2**(e-3-log2(3))
135+
// ≤ 2**(e-4.5)
136+
//
137+
// For the following iterations, we use the fact that, 2**(e-1) ≤ sqrt(a) ≤ x_n:
138+
// ε_{n+1} = ε_n² / | (2 * x_n) |
139+
// ≤ (2**(e-k))² / (2 * 2**(e-1))
140+
// ≤ 2**(2*e-2*k) / 2**e
141+
// ≤ 2**(e-2*k)
142+
xn = (xn + a / xn) >> 1; // ε_1 := | x_1 - sqrt(a) | ≤ 2**(e-4.5) -- special case, see above
143+
xn = (xn + a / xn) >> 1; // ε_2 := | x_2 - sqrt(a) | ≤ 2**(e-9) -- general case with k = 4.5
144+
xn = (xn + a / xn) >> 1; // ε_3 := | x_3 - sqrt(a) | ≤ 2**(e-18) -- general case with k = 9
145+
xn = (xn + a / xn) >> 1; // ε_4 := | x_4 - sqrt(a) | ≤ 2**(e-36) -- general case with k = 18
146+
xn = (xn + a / xn) >> 1; // ε_5 := | x_5 - sqrt(a) | ≤ 2**(e-72) -- general case with k = 36
147+
xn = (xn + a / xn) >> 1; // ε_6 := | x_6 - sqrt(a) | ≤ 2**(e-144) -- general case with k = 72
148+
149+
// Because e ≤ 128 (as discussed during the first estimation phase), we now have reached a precision
150+
// ε_6 ≤ 2**(e-144) < 1. Given we're operating on integers, then we can ensure that xn is now either
151+
// sqrt(a) or sqrt(a) + 1.
152+
if (xn > a / xn) {
153+
xn - 1
154+
} else {
155+
xn
156+
}
157+
}

math/core/sources/internal/macros.move

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,28 @@ public(package) fun mul_shr_u256_wide(
505505
(false, result)
506506
}
507507

508+
/// Compute the square root of an unsigned integer with configurable rounding.
509+
///
510+
/// This macro provides a uniform API for `sqrt` across all unsigned integer widths. It normalises
511+
/// the input to `u256`, calculates the integer square root, and applies the requested rounding mode.
512+
/// The algorithm uses a binary search to find the floor of the square root, then determines whether
513+
/// to round up based on the rounding mode.
514+
///
515+
/// #### Generics
516+
/// - `$Int`: Any unsigned integer type (`u8`, `u16`, `u32`, `u64`, `u128`, or `u256`).
517+
///
518+
/// #### Parameters
519+
/// - `$value`: The unsigned integer to calculate the square root of.
520+
/// - `$rounding_mode`: Rounding strategy drawn from `rounding::RoundingMode`.
521+
///
522+
/// #### Returns
523+
/// The square root of `$value` rounded according to `$rounding_mode`, cast back to `$Int`.
524+
public(package) macro fun sqrt<$Int>($value: $Int, $rounding_mode: RoundingMode): $Int {
525+
let (value, rounding_mode) = ($value as u256, $rounding_mode);
526+
let floor_res = common::sqrt_floor(value);
527+
round_sqrt_result(value, floor_res, rounding_mode) as $Int
528+
}
529+
508530
/// Determine whether rounding up is required after dividing and apply it to `result`.
509531
/// Returns `(overflow, result)` where `overflow` is `true` if the rounded value cannot be represented as `u256`.
510532
public(package) fun round_division_result(
@@ -603,3 +625,54 @@ public(package) fun log256_should_round_up(value: u256, floor_log: u16): bool {
603625
let threshold = 1 << (threshold_exp as u8);
604626
value >= threshold
605627
}
628+
629+
/// Apply rounding mode to the floor result of a square root calculation.
630+
///
631+
/// For nearest rounding, compares the distance from `value` to `floor²` versus the distance
632+
/// to `ceil²` where `ceil = floor + 1`.
633+
///
634+
/// Given:
635+
/// - `distance_to_floor = value - floor²`
636+
/// - `distance_to_ceil = (floor + 1)² - value = floor² + 2·floor + 1 - value`
637+
///
638+
/// We want to round down if `distance_to_floor < distance_to_ceil`, which expands to:
639+
/// ```
640+
/// value - floor² < floor² + 2·floor + 1 - value
641+
/// 2·value < 2·floor² + 2·floor + 1
642+
/// 2·(value - floor²) < 2·floor + 1
643+
/// ```
644+
///
645+
/// Since we're working with integers, dividing both sides by 2 gives us:
646+
/// `value - floor² <= floor`
647+
///
648+
/// Considering that `sqrt(u256::MAX)` < `2^128`, all arithmetic operations in the function
649+
/// are guaranteed to not overflow or underflow.
650+
///
651+
/// #### Parameters
652+
/// - `value`: The original value whose square root was calculated.
653+
/// - `floor_result`: The floor of the square root.
654+
/// - `rounding_mode`: Rounding strategy drawn from `rounding::RoundingMode`.
655+
///
656+
/// #### Returns
657+
/// The square root rounded according to the specified mode.
658+
public(package) fun round_sqrt_result(
659+
value: u256,
660+
floor_result: u256,
661+
rounding_mode: RoundingMode,
662+
): u256 {
663+
if (rounding_mode == rounding::down()) {
664+
return floor_result
665+
};
666+
667+
let floor_squared = floor_result * floor_result;
668+
if (floor_squared == value) {
669+
// Perfect square, no rounding needed
670+
floor_result
671+
} else if (rounding_mode == rounding::up()) {
672+
floor_result + 1
673+
} else if (value - floor_squared <= floor_result) {
674+
floor_result
675+
} else {
676+
floor_result + 1
677+
}
678+
}

math/core/sources/u128.move

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,10 @@ public fun log2(value: u128, rounding_mode: RoundingMode): u8 {
8686
public fun log256(value: u128, rounding_mode: RoundingMode): u8 {
8787
macros::log256!(value, BIT_WIDTH as u16, rounding_mode)
8888
}
89+
90+
/// Compute the square root of a value with configurable rounding.
91+
///
92+
/// Returns 0 if given 0.
93+
public fun sqrt(value: u128, rounding_mode: RoundingMode): u128 {
94+
macros::sqrt!(value, rounding_mode)
95+
}

math/core/sources/u16.move

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,10 @@ public fun log2(value: u16, rounding_mode: RoundingMode): u8 {
8686
public fun log256(value: u16, rounding_mode: RoundingMode): u8 {
8787
macros::log256!(value, BIT_WIDTH as u16, rounding_mode)
8888
}
89+
90+
/// Compute the square root of a value with configurable rounding.
91+
///
92+
/// Returns 0 if given 0.
93+
public fun sqrt(value: u16, rounding_mode: RoundingMode): u16 {
94+
macros::sqrt!(value, rounding_mode)
95+
}

math/core/sources/u256.move

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,10 @@ public fun log2(value: u256, rounding_mode: RoundingMode): u16 {
8080
public fun log256(value: u256, rounding_mode: RoundingMode): u8 {
8181
macros::log256!(value, BIT_WIDTH, rounding_mode)
8282
}
83+
84+
/// Compute the square root of a value with configurable rounding.
85+
///
86+
/// Returns 0 if given 0.
87+
public fun sqrt(value: u256, rounding_mode: RoundingMode): u256 {
88+
macros::sqrt!(value, rounding_mode)
89+
}

math/core/sources/u32.move

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,10 @@ public fun log2(value: u32, rounding_mode: RoundingMode): u8 {
8686
public fun log256(value: u32, rounding_mode: RoundingMode): u8 {
8787
macros::log256!(value, BIT_WIDTH as u16, rounding_mode)
8888
}
89+
90+
/// Compute the square root of a value with configurable rounding.
91+
///
92+
/// Returns 0 if given 0.
93+
public fun sqrt(value: u32, rounding_mode: RoundingMode): u32 {
94+
macros::sqrt!(value, rounding_mode)
95+
}

math/core/sources/u64.move

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,10 @@ public fun log2(value: u64, rounding_mode: RoundingMode): u8 {
8686
public fun log256(value: u64, rounding_mode: RoundingMode): u8 {
8787
macros::log256!(value, BIT_WIDTH as u16, rounding_mode)
8888
}
89+
90+
/// Compute the square root of a value with configurable rounding.
91+
///
92+
/// Returns 0 if given 0.
93+
public fun sqrt(value: u64, rounding_mode: RoundingMode): u64 {
94+
macros::sqrt!(value, rounding_mode)
95+
}

math/core/sources/u8.move

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,10 @@ public fun log2(value: u8, rounding_mode: RoundingMode): u8 {
8787
public fun log256(value: u8, rounding_mode: RoundingMode): u8 {
8888
macros::log256!(value, BIT_WIDTH as u16, rounding_mode)
8989
}
90+
91+
/// Compute the square root of a value with configurable rounding.
92+
///
93+
/// Returns 0 if given 0.
94+
public fun sqrt(value: u8, rounding_mode: RoundingMode): u8 {
95+
macros::sqrt!(value, rounding_mode)
96+
}

math/core/tests/common_tests.move

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,109 @@ fun msb_handles_u128_values() {
152152
let mid_bit: u128 = 1u128 << 40;
153153
assert_eq!(common::msb(mid_bit as u256, 128), 40);
154154
}
155+
156+
// === sqrt ===
157+
158+
#[test]
159+
fun sqrt_returns_zero_for_zero() {
160+
assert_eq!(common::sqrt_floor(0), 0);
161+
}
162+
163+
#[test]
164+
fun sqrt_handles_perfect_squares() {
165+
assert_eq!(common::sqrt_floor(4), 2);
166+
assert_eq!(common::sqrt_floor(9), 3);
167+
assert_eq!(common::sqrt_floor(16), 4);
168+
assert_eq!(common::sqrt_floor(25), 5);
169+
assert_eq!(common::sqrt_floor(100), 10);
170+
assert_eq!(common::sqrt_floor(256), 16);
171+
assert_eq!(common::sqrt_floor(65536), 256);
172+
assert_eq!(common::sqrt_floor(1 << 64), 1 << 32);
173+
assert_eq!(common::sqrt_floor(1 << 128), 1 << 64);
174+
}
175+
176+
#[test]
177+
fun sqrt_floors_non_perfect_squares() {
178+
assert_eq!(common::sqrt_floor(2), 1); // 1.414... → 1
179+
assert_eq!(common::sqrt_floor(3), 1); // 1.732... → 1
180+
assert_eq!(common::sqrt_floor(5), 2); // 2.236... → 2
181+
assert_eq!(common::sqrt_floor(8), 2); // 2.828... → 2
182+
assert_eq!(common::sqrt_floor(15), 3); // 3.873... → 3
183+
assert_eq!(common::sqrt_floor(99), 9); // 9.950... → 9
184+
assert_eq!(common::sqrt_floor(255), 15); // 15.969... → 15
185+
}
186+
187+
#[test]
188+
fun sqrt_handles_u8_values() {
189+
let zero: u8 = 0;
190+
assert_eq!(common::sqrt_floor(zero as u256), 0);
191+
192+
let one: u8 = 1;
193+
assert_eq!(common::sqrt_floor(one as u256), 1);
194+
195+
let four: u8 = 4;
196+
assert_eq!(common::sqrt_floor(four as u256), 2);
197+
198+
let max: u8 = std::u8::max_value!();
199+
assert_eq!(common::sqrt_floor(max as u256), 15);
200+
}
201+
202+
#[test]
203+
fun sqrt_handles_u16_values() {
204+
let perfect: u16 = 256;
205+
assert_eq!(common::sqrt_floor(perfect as u256), 16);
206+
207+
let non_perfect: u16 = 1000;
208+
assert_eq!(common::sqrt_floor(non_perfect as u256), 31);
209+
210+
let max: u16 = std::u16::max_value!();
211+
assert_eq!(common::sqrt_floor(max as u256), 255);
212+
}
213+
214+
#[test]
215+
fun sqrt_handles_u32_values() {
216+
let perfect: u32 = 65536;
217+
assert_eq!(common::sqrt_floor(perfect as u256), 256);
218+
219+
let non_perfect: u32 = 1000000;
220+
assert_eq!(common::sqrt_floor(non_perfect as u256), 1000);
221+
222+
let max: u32 = std::u32::max_value!();
223+
assert_eq!(common::sqrt_floor(max as u256), 65535);
224+
}
225+
226+
#[test]
227+
fun sqrt_handles_u64_values() {
228+
let perfect: u64 = 1 << 32;
229+
assert_eq!(common::sqrt_floor(perfect as u256), 1 << 16);
230+
231+
let non_perfect: u64 = 100000000;
232+
assert_eq!(common::sqrt_floor(non_perfect as u256), 10000);
233+
234+
let max: u64 = std::u64::max_value!();
235+
assert_eq!(common::sqrt_floor(max as u256), 4294967295);
236+
}
237+
238+
#[test]
239+
fun sqrt_handles_u128_values() {
240+
let perfect: u128 = 1 << 64;
241+
assert_eq!(common::sqrt_floor(perfect as u256), 1 << 32);
242+
243+
let large: u128 = 1000000000000000000;
244+
assert_eq!(common::sqrt_floor(large as u256), 1000000000);
245+
246+
let max: u128 = std::u128::max_value!();
247+
assert_eq!(common::sqrt_floor(max as u256), std::u64::max_value!() as u256);
248+
}
249+
250+
#[test]
251+
fun sqrt_handles_u256_values() {
252+
let perfect: u256 = 1 << 128;
253+
assert_eq!(common::sqrt_floor(perfect), 1 << 64);
254+
255+
let large: u256 = 1 << 200;
256+
assert_eq!(common::sqrt_floor(large), 1 << 100);
257+
258+
let max: u256 = std::u256::max_value!();
259+
assert_eq!(common::sqrt_floor(max), std::u128::max_value!() as u256);
260+
}

0 commit comments

Comments
 (0)