Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved integer square root. #4403

Merged
merged 26 commits into from
Feb 16, 2024
Merged
Changes from 5 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 36 additions & 24 deletions contracts/utils/math/Math.sol
Original file line number Diff line number Diff line change
Expand Up @@ -216,39 +216,51 @@ library Math {
/**
* @dev Returns the square root of a number. If the number is not a perfect square, the value is rounded
* towards zero.
*
* Inspired by Henry S. Warren, Jr.'s "Hacker's Delight" (Chapter 11).
*/
function sqrt(uint256 a) internal pure returns (uint256) {
if (a == 0) {
return 0;
}

// For our first guess, we get the biggest power of 2 which is smaller than the square root of the target.
//
// We know that the "msb" (most significant bit) of our target number `a` is a power of 2 such that we have
// `msb(a) <= a < 2*msb(a)`. This value can be written `msb(a)=2**k` with `k=log2(a)`.
//
// This can be rewritten `2**log2(a) <= a < 2**(log2(a) + 1)`
// → `sqrt(2**k) <= sqrt(a) < sqrt(2**(k+1))`
// → `2**(k/2) <= sqrt(a) < 2**((k+1)/2) <= 2**(k/2 + 1)`
//
// Consequently, `2**(log2(a) / 2)` is a good first approximation of `sqrt(a)` with at least 1 correct bit.
uint256 result = 1 << (log2(a) >> 1);

// At this point `result` is an estimation with one bit of precision. We know the true value is a uint128,
// since it is the square root of a uint256. Newton's method converges quadratically (precision doubles at
// every iteration). We thus need at most 7 iteration to turn our partial result with one bit of precision
// into the expected uint128 result.
unchecked {
// Take care of easy edge cases
if (a <= 1) { return a; }
// This check ensures no overflow
if (a >= ((1 << 128) - 1)**2) { return (1 << 128) - 1; }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a risk of overflow ? Can you document which part is succeptible ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using this specific method and return logic, it is possible to overflow when computing result**2 during the check result**2 <= a, because it may be that result == 2**128; thus, result**2 == 0 because this is unchecked.

Copy link
Collaborator

@Amxx Amxx Jul 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

interresting. I guess this is not necessary if you do min(result, a/result) so the current version is good in that regard.

I'd rewrite that as:

Suggested change
if (a >= ((1 << 128) - 1)**2) { return (1 << 128) - 1; }
if (a >= uint256(type(uint128).max)**2) { return type(uint128).max; }

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if the cost of that test outweight the cost of the min at the end.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

interresting. I guess this is not necessary if you do min(result, a/result) so the current version is good in that regard.

The current version has no proof of correctness. This proposed change does. See Appendix B.4.3. Have you looked at that report?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code was updated with suggestion.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For reference, the overflow check (if (a >= uint256(type(uint128).max)**2) { return type(uint128).max; }) costs 23 gas.


// If we have
//
// 2^{e-1} <= sqrt(x) < 2^{e},
//
// then at the end of initialization, we will have
//
// result == 2^{e-1} + 2^{e-2}.
//
// This ensures that
//
// abs(sqrt(x) - result) <= 2^{e-2}.
uint256 aAux = a;
uint256 result = 1;
if (aAux >= (1 << 128)) { aAux >>= 128; result <<= 64; }
if (aAux >= (1 << 64 )) { aAux >>= 64; result <<= 32; }
if (aAux >= (1 << 32 )) { aAux >>= 32; result <<= 16; }
if (aAux >= (1 << 16 )) { aAux >>= 16; result <<= 8; }
if (aAux >= (1 << 8 )) { aAux >>= 8; result <<= 4; }
if (aAux >= (1 << 4 )) { aAux >>= 4; result <<= 2; }
if (aAux >= (1 << 2 )) { result <<= 1; }
result += (result >> 1);

// Perform the 6 required Newton iteration
result = (result + a / result) >> 1;
result = (result + a / result) >> 1;
result = (result + a / result) >> 1;
result = (result + a / result) >> 1;
result = (result + a / result) >> 1;
result = (result + a / result) >> 1;
result = (result + a / result) >> 1;
return min(result, a / result);

// We either have
//
// Isqrt(x) == result or Isqrt(x) == result-1.
if (result * result <= a) {
return result;
}
return result-1;
Amxx marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand Down