Improved integer square root. #4403

Merged
merged 26 commits into from
Feb 16, 2024
Changes from 15 commits
89 changes: 60 additions & 29 deletions contracts/utils/math/Math.sol
@@ -339,38 +339,69 @@ library Math {
* @dev Returns the square root of a number. If the number is not a perfect square, the value is rounded
* towards zero.
*
* Inspired by Henry S. Warren, Jr.'s "Hacker's Delight" (Chapter 11).
* This method is based on Newton's method for computing square roots; the algorithm is restricted to only
* using integer operations.
*/
function sqrt(uint256 a) internal pure returns (uint256) {
if (a == 0) {
return 0;
}

// For our first guess, we get the biggest power of 2 which is smaller than the square root of the target.
//
// We know that the "msb" (most significant bit) of our target number `a` is a power of 2 such that we have
// `msb(a) <= a < 2*msb(a)`. This value can be written `msb(a)=2**k` with `k=log2(a)`.
//
// This can be rewritten `2**log2(a) <= a < 2**(log2(a) + 1)`
// → `sqrt(2**k) <= sqrt(a) < sqrt(2**(k+1))`
// → `2**(k/2) <= sqrt(a) < 2**((k+1)/2) <= 2**(k/2 + 1)`
//
// Consequently, `2**(log2(a) / 2)` is a good first approximation of `sqrt(a)` with at least 1 correct bit.
uint256 result = 1 << (log2(a) >> 1);

// At this point `result` is an estimation with one bit of precision. We know the true value is a uint128,
// since it is the square root of a uint256. Newton's method converges quadratically (precision doubles at
// every iteration). We thus need at most 7 iteration to turn our partial result with one bit of precision
// into the expected uint128 result.
unchecked {
result = (result + a / result) >> 1;
result = (result + a / result) >> 1;
result = (result + a / result) >> 1;
result = (result + a / result) >> 1;
result = (result + a / result) >> 1;
result = (result + a / result) >> 1;
result = (result + a / result) >> 1;
return min(result, a / result);
// Take care of easy edge cases when a == 0 or a == 1
if (a <= 1) {
return a;
}

uint256 aAux = a;
uint256 result = 1;

// For the first guess of `result` (e), we get the biggest power of 2 which is smaller than sqrt(a)
// (i.e. 2^e <= sqrt(a)). We know that e is at most 127 given (2^128)^2 overflows an uint256.
// Thus, we approximate e by iterating 2^{i/2} where i starts at 128, and applying the exponent
// to e if the result is still smaller than a (up to e == 127).
if (aAux >= (1 << 128)) {
aAux >>= 128;
result <<= 64;
}
if (aAux >= (1 << 64)) {
aAux >>= 64;
result <<= 32;
}
if (aAux >= (1 << 32)) {
aAux >>= 32;
result <<= 16;
}
if (aAux >= (1 << 16)) {
aAux >>= 16;
result <<= 8;
}
if (aAux >= (1 << 8)) {
aAux >>= 8;
result <<= 4;
}
if (aAux >= (1 << 4)) {
aAux >>= 4;
result <<= 2;
}
if (aAux >= (1 << 2)) {
result <<= 1;
}
Collaborator @Amxx commented on Feb 14, 2024:

This is very similar to the code we have in log2. In fact, it's actually 1 << (log2(a) >> 1). In log2, we changed the logic to branchless, and that saved some gas. Surprisingly, the same approach doesn't appear to save much here.
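
The equivalence mentioned in this comment can be sketched as follows. This is an illustrative, self-contained library, not code from the PR: `SqrtGuessSketch`, `guessCascade`, `floorLog2`, and `guessViaLog2` are hypothetical names, and `floorLog2` is a naive loop standing in for the library's `log2`. Under these assumptions, both guess functions should return the same power of two for any `a >= 1`.

```solidity
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.20;

// Hypothetical helper library for illustration only; not part of the PR.
library SqrtGuessSketch {
    // Initial guess computed with the comparison cascade from the diff.
    function guessCascade(uint256 a) internal pure returns (uint256 result) {
        result = 1;
        if (a >= (1 << 128)) {
            a >>= 128;
            result <<= 64;
        }
        if (a >= (1 << 64)) {
            a >>= 64;
            result <<= 32;
        }
        if (a >= (1 << 32)) {
            a >>= 32;
            result <<= 16;
        }
        if (a >= (1 << 16)) {
            a >>= 16;
            result <<= 8;
        }
        if (a >= (1 << 8)) {
            a >>= 8;
            result <<= 4;
        }
        if (a >= (1 << 4)) {
            a >>= 4;
            result <<= 2;
        }
        if (a >= (1 << 2)) {
            result <<= 1;
        }
    }

    // Naive floor(log2(a)) by repeated halving; assumes a >= 1.
    function floorLog2(uint256 a) internal pure returns (uint256 k) {
        while (a > 1) {
            a >>= 1;
            ++k;
        }
    }

    // The same guess written as 1 << (log2(a) >> 1).
    function guessViaLog2(uint256 a) internal pure returns (uint256) {
        return 1 << (floorLog2(a) >> 1);
    }
}
```

For example, with a = 1000, floor(log2(1000)) = 9, so both functions give 1 << 4 = 16, and 16 <= sqrt(1000) < 32.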


// We can use the fact that 2^e <= sqrt(a) to improve the estimation
// by computing the arithmetic mean between the current estimation and
// the next one (result * 2), ensuring that result - sqrt(a) <= 2^{e-2}.
Collaborator commented:

Again, I don't think that is clear.

Collaborator @Amxx commented on Feb 14, 2024:

My understanding is that we know result <= sqrt(a) < 2 * result ... so right now we have this distance:

║ result - sqrt(a) ║ < result

However, if we take the middle of that interval, then the distance becomes

║ 3*result/2 - sqrt(a) ║ < result / 2

But that still only gives us result/2 = 2**(e-1) ... where does 2**(e-2) come from?

Member commented:

> But that still only gives us result/2 = 2**(e-1) ... where does 2**(e-2) come from?

Right. But consider we're looking for e such that 2**e + c == a and the worst case for c is 2**e - 1. If we're using e, we would be above the square root. So you should note this as result/2 = 2**(e-2).

Collaborator @Amxx commented on Feb 15, 2024:

> But consider we're looking for e such that 2 ** e + c == a and the worst case for c is 2 ** e - 1

We have e such that "2 ** e <= sqrt(a) < 2 ** (e+1)"

Looks like there is a confusion between a and sqrt(a) here.

Collaborator commented:

Actually, the confusion is in the definition of e. The comments in the code and the PDF define it differently.
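
A small worked check may help pin down the bound being debated. It assumes the definition of e used in the code comment, i.e. 2^e <= sqrt(a) < 2^(e+1) with e >= 1, and writes x_0 for the value of `result` after the `(3 * result) >> 1` step; x_0 is notation introduced here, not taken from the PR.

```latex
\begin{align*}
2^{e} \le \sqrt{a} < 2^{e+1},\quad x_0 = \tfrac{3}{2}\cdot 2^{e}
  &\;\Longrightarrow\; -2^{e-1} < x_0 - \sqrt{a} \le 2^{e-1} \\
a = 4:\quad e = 1,\ x_0 = 3
  &\;\Longrightarrow\; x_0 - \sqrt{a} = 1 = 2^{e-1} > 2^{e-2} \\
a = 15:\quad e = 1,\ x_0 = 3
  &\;\Longrightarrow\; |x_0 - \sqrt{a}| \approx 0.873 > 2^{e-2} = 0.5
\end{align*}
```

Under that definition the distance is bounded by 2^(e-1), and a = 4 shows that this bound is attained, so a 2^(e-2) bound would seem to require a different definition of e.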

result = (3 * result) >> 1;

// Each Newton iteration will have f(x) = (x + a / x) / 2.
// Given the error (ε) is defined by x - sqrt(a), then we know that
// ε + 1 == ε^2 / 2x <= ε^2 / 2 * sqrt(a).
Collaborator @Amxx commented on Feb 14, 2024:

The comments should be readable alone, without the PDF document.
Here, e+1 reads like e plus the number one, which is not how it should be interpreted.

Member commented:

I think the intention is to express e_{+1}. Would you be ok with that syntax?
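
For readers without the PDF, the recurrence behind the code comment can be written out explicitly. The derivation below is over the reals and ignores the rounding of integer division; x_n and ε_n are notation introduced here, with ε_n = x_n - sqrt(a) and the subscript denoting the iteration index.

```latex
\begin{align*}
x_{n+1} &= \frac{1}{2}\left(x_n + \frac{a}{x_n}\right) \\
\varepsilon_{n+1} = x_{n+1} - \sqrt{a}
  &= \frac{x_n^2 - 2\sqrt{a}\,x_n + a}{2x_n}
   = \frac{(x_n - \sqrt{a})^2}{2x_n}
   = \frac{\varepsilon_n^2}{2x_n} \\
x_n \ge \sqrt{a} &\;\Longrightarrow\; \varepsilon_{n+1} \le \frac{\varepsilon_n^2}{2\sqrt{a}}
\end{align*}
```

Since the middle line is non-negative, x_n >= sqrt(a) holds for every n >= 1, so the last inequality applies from the second iteration on; the first iteration has to use the known value of x_0 instead.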

result = (result + a / result) >> 1; // err := result - sqrt(a) <= 2^{e-4.5}
result = (result + a / result) >> 1; // err := result - sqrt(a) <= 2^{e-9}
result = (result + a / result) >> 1; // err := result - sqrt(a) <= 2^{e-18}
result = (result + a / result) >> 1; // err := result - sqrt(a) <= 2^{e-36}
result = (result + a / result) >> 1; // err := result - sqrt(a) <= 2^{e-72}
result = (result + a / result) >> 1; // err := result - sqrt(a) <= 2^{e-144}
Collaborator commented:

Is that correct?

Suggested change
result = (result + a / result) >> 1; // err := result - sqrt(a) <= 2^{e-4.5}
result = (result + a / result) >> 1; // err := result - sqrt(a) <= 2^{e-9}
result = (result + a / result) >> 1; // err := result - sqrt(a) <= 2^{e-18}
result = (result + a / result) >> 1; // err := result - sqrt(a) <= 2^{e-36}
result = (result + a / result) >> 1; // err := result - sqrt(a) <= 2^{e-72}
result = (result + a / result) >> 1; // err := result - sqrt(a) <= 2^{e-144}
result = (result + a / result) >> 1; // ε = ║ result - sqrt(a) ║ < 2**(e - 4.5)
result = (result + a / result) >> 1; // ε = ║ result - sqrt(a) ║ < 2**(e - 9)
result = (result + a / result) >> 1; // ε = ║ result - sqrt(a) ║ < 2**(e - 18)
result = (result + a / result) >> 1; // ε = ║ result - sqrt(a) ║ < 2**(e - 36)
result = (result + a / result) >> 1; // ε = ║ result - sqrt(a) ║ < 2**(e - 72)
result = (result + a / result) >> 1; // ε = ║ result - sqrt(a) ║ < 2**(e - 144)

Where does the initial 4.5 come from? It's not documented anywhere.

Member commented:

Yes, it's correct. Feel free to update it.

Member commented:

> Where does the initial 4.5 come from ? Its not documented anywhere.

Added a reference, but it comes from the error formula in Walter Rudin, Principles of Mathematical Analysis, 3rd ed., McGraw-Hill, New York, 1976, Exercise 3.16 (b).

Collaborator commented:

Rather than adding a reference to something that I couldn't easily find (I found a 352-page scanned document in which 3.16 is an unrelated definition, but that doesn't go in the right direction), I'd rather add a proof.

AFAIK, it's faster to read the proof than to find the reference :P

Member @ernestognw commented on Feb 15, 2024:

Note 3.16 is not the exercise but a section. I think you're looking for this:

[Screenshot: 2024-02-15, 12:45:15 p.m.]

And the solution:
[Screenshot: 2024-02-15, 12:46:06 p.m.]


// After 6 iterations, no more precision can be obtained since the max result is 127.
// result is either sqrt(a) or sqrt(a) + 1.
return result - SafeCast.toUint(result > a / result);
}
}
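
The final `return` of the new version can be read as a branchless floor correction. Below is a minimal sketch of the same step with an explicit conditional, assuming `result > 0` and, as the code comment states, that `result` is already either the floored square root or one more than it. `FloorCorrectionSketch` and `correct` are hypothetical names, not part of the PR.

```solidity
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.20;

// Hypothetical helper for illustration only; not part of the PR.
library FloorCorrectionSketch {
    // Assumes result > 0 and that result is floor(sqrt(a)) or floor(sqrt(a)) + 1.
    function correct(uint256 result, uint256 a) internal pure returns (uint256) {
        // For result > 0, `result > a / result` is mathematically equivalent to
        // result * result > a, and the division form cannot overflow.
        if (result > a / result) {
            return result - 1;
        }
        return result;
    }
}
```

For a = 15 and result = 4, a / result is 3, so the sketch returns 3; for a = 16 and result = 4, the comparison is false and it returns 4.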
