Skip to content

Commit

Permalink
Fix stack overflow in verifier (#112)
Browse files Browse the repository at this point in the history
  • Loading branch information
hoffmang9 authored Mar 7, 2020
1 parent bdd52e7 commit 07951ab
Showing 1 changed file with 41 additions and 31 deletions.
72 changes: 41 additions & 31 deletions lib/chiavdf/fast_vdf/verifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,24 @@
#include "proof_common.h"
#include "create_discriminant.h"

const int kMaxBytesProof = 100000;
void VerifyWesolowskiProof(integer &D, form x, form y, form proof, int iters, bool &is_valid)
{
PulmarkReducer reducer;
int int_size = (D.num_bits() + 16) >> 4;
integer L = root(-D, 4);
integer B = GetB(D, x, y);
integer r = FastPow(2, iters, B);
form f1 = FastPowFormNucomp(proof, D, B, L, reducer);
form f2 = FastPowFormNucomp(x, D, r, L, reducer);
if (f1 * f2 == y)
{
is_valid = true;
}
else
{
is_valid = false;
}
}

integer ConvertBytesToInt(uint8_t *bytes, int start_index, int end_index)
{
Expand Down Expand Up @@ -52,31 +69,12 @@ std::vector<form> DeserializeProof(uint8_t *proof_bytes, int proof_len, integer
return proof;
}

void VerifyWesolowskiProof(integer &D, form x, form y, form proof, int iters, bool &is_valid)
{
PulmarkReducer reducer;
int int_size = (D.num_bits() + 16) >> 4;
integer L = root(-D, 4);
integer B = GetB(D, x, y);
integer r = FastPow(2, iters, B);
form f1 = FastPowFormNucomp(proof, D, B, L, reducer);
form f2 = FastPowFormNucomp(x, D, r, L, reducer);
if (f1 * f2 == y)
{
is_valid = true;
}
else
{
is_valid = false;
}
}

bool CheckProofOfTimeNWesolowskiInner(integer &D, form x, uint8_t *proof_blob,
int blob_len, int iters, int int_size,
std::vector<int> iter_list, int recursion)
{
uint8_t result_bytes[kMaxBytesProof];
uint8_t proof_bytes[kMaxBytesProof];
uint8_t* result_bytes = new uint8_t[2 * int_size];
uint8_t* proof_bytes = new uint8_t[blob_len - 2 * int_size];
memcpy(result_bytes, proof_blob, 2 * int_size);
memcpy(proof_bytes, proof_blob + 2 * int_size, blob_len - 2 * int_size);
form y = DeserializeForm(D, result_bytes, int_size);
Expand All @@ -87,32 +85,44 @@ bool CheckProofOfTimeNWesolowskiInner(integer &D, form x, uint8_t *proof_blob,
{
bool is_valid;
VerifyWesolowskiProof(D, x, y, proof[0], iters, is_valid);
delete[] result_bytes;
delete[] proof_bytes;
return is_valid;
}
else
{
if (!(proof.size() % 2 == 1 && proof.size() > 2))
if (!(proof.size() % 2 == 1 && proof.size() > 2)) {
delete[] result_bytes;
delete[] proof_bytes;
return false;
}
int iters1 = iter_list[iter_list.size() - 1];
int iters2 = iters - iters1;
bool ver_outer;
std::thread t(VerifyWesolowskiProof, std::ref(D), x, proof[proof.size() - 2], proof[proof.size() - 1], iters1, std::ref(ver_outer));
uint8_t new_proof_bytes[kMaxBytesProof];
VerifyWesolowskiProof(D, x, proof[proof.size() - 2], proof[proof.size() - 1], iters1, ver_outer);
if (!ver_outer) {
delete[] result_bytes;
delete[] proof_bytes;
return false;
}
uint8_t* new_proof_bytes = new uint8_t[blob_len - 4 * int_size];
for (int i = 0; i < blob_len - 4 * int_size; i++)
new_proof_bytes[i] = proof_blob[i];
iter_list.pop_back();
bool ver_inner = CheckProofOfTimeNWesolowskiInner(D, proof[proof.size() - 2], new_proof_bytes, blob_len - 4 * int_size, iters2, int_size, iter_list, recursion - 1);
t.join();
if (ver_inner && ver_outer)
delete[] result_bytes;
delete[] proof_bytes;
delete[] new_proof_bytes;
if (ver_inner)
return true;
return false;
}
}

bool CheckProofOfTimeNWesolowski(integer &D, form x, uint8_t *proof_blob, int proof_blob_len, int iters, int recursion)
bool CheckProofOfTimeNWesolowski(integer D, form x, uint8_t *proof_blob, int proof_blob_len, int iters, int recursion)
{
int int_size = (D.num_bits() + 16) >> 4;
uint8_t new_proof_blob[kMaxBytesProof];
uint8_t* new_proof_blob = new uint8_t[proof_blob_len];
int new_cnt = 4 * int_size;
memcpy(new_proof_blob, proof_blob, new_cnt);
std::vector<int> iter_list;
Expand All @@ -126,7 +136,8 @@ bool CheckProofOfTimeNWesolowski(integer &D, form x, uint8_t *proof_blob, int pr
new_cnt += 4 * int_size;
}
bool is_valid = CheckProofOfTimeNWesolowskiInner(D, x, new_proof_blob, new_cnt, iters, int_size, iter_list, recursion);
return is_valid;
delete[] new_proof_blob;
return is_valid;
}

std::vector<uint8_t> HexToBytes(char *hex_proof)
Expand Down Expand Up @@ -180,7 +191,6 @@ bool CheckProofOfTimeType(ProofOfTimeType &proof)
form y = form::from_abd(proof.a, proof.b, discriminant);
std::vector<uint8_t> proof_blob = SerializeForm(y, int_size);
proof_blob.insert(proof_blob.end(), proof.witness.begin(), proof.witness.end());

result = CheckProofOfTimeNWesolowski(discriminant, x, proof_blob.data(), proof_blob.size(), proof.iterations_needed, proof.witness_type);
}
catch (std::exception &e)
Expand Down

0 comments on commit 07951ab

Please sign in to comment.