Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -269,4 +269,19 @@ bool verify_gt(smt_solver::Solver* solver, smt_circuit::UltraCircuit circuit)
debug_solution(solver, terms);
}
return res;
}

bool verify_idiv(smt_solver::Solver* solver, smt_circuit::UltraCircuit circuit, uint32_t bit_size)
{
auto a = circuit["a"];
auto b = circuit["b"];
auto c = circuit["c"];
auto cr = idiv(a, b, bit_size, solver);
c != cr;
bool res = solver->check();
if (res) {
std::unordered_map<std::string, cvc5::Term> terms({ { "a", a }, { "b", b }, { "c", c }, { "cr", cr } });
debug_solution(solver, terms);
}
return res;
}
32 changes: 26 additions & 6 deletions barretenberg/cpp/src/barretenberg/acir_formal_proofs/helpers.cpp
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regarding the pow2_8, you should use bit_extraction rather then direct & 1 value

Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,37 @@ smt_circuit::STerm shr(smt_circuit::STerm v0, smt_circuit::STerm v1, smt_solver:
smt_circuit::STerm shl64(smt_circuit::STerm v0, smt_circuit::STerm v1, smt_solver::Solver* solver)
{
auto shifted = shl(v0, v1, solver);
// 2^64 - 1
auto mask = smt_terms::BVConst("18446744073709551615", solver, 10);
auto res = shifted & mask;
auto res = shifted.truncate(63);
return res;
}

smt_circuit::STerm shl32(smt_circuit::STerm v0, smt_circuit::STerm v1, smt_solver::Solver* solver)
{
auto shifted = shl(v0, v1, solver);
// 2^32 - 1
auto mask = smt_terms::BVConst("4294967295", solver, 10);
auto res = shifted & mask;
auto res = shifted.truncate(31);
return res;
}

smt_circuit::STerm idiv(smt_circuit::STerm v0, smt_circuit::STerm v1, uint32_t bit_size, smt_solver::Solver* solver)
{
// highest bit of v0 and v1 is sign bit
smt_circuit::STerm exponent = smt_terms::BVConst(std::to_string(bit_size), solver, 10);
auto sign_bit_v0 = v0.extract_bit(bit_size - 1);
auto sign_bit_v1 = v1.extract_bit(bit_size - 1);
auto res_sign_bit = sign_bit_v0 ^ sign_bit_v1;
res_sign_bit <<= bit_size - 1;
auto abs_value_v0 = v0.truncate(bit_size - 2);
auto abs_value_v1 = v1.truncate(bit_size - 2);
auto abs_res = abs_value_v0 / abs_value_v1;

// if abs_value_v0 == 0 then res = 0
// in our context we use idiv only once, so static name for the division result okay.
auto res = smt_terms::BVVar("res_signed_division", solver);
auto condition = smt_terms::Bool(abs_value_v0, solver) == smt_terms::Bool(smt_terms::BVConst("0", solver, 10));
auto eq1 = condition & (smt_terms::Bool(res, solver) == smt_terms::Bool(smt_terms::BVConst("0", solver, 10)));
auto eq2 = !condition & (smt_terms::Bool(res, solver) == smt_terms::Bool(res_sign_bit | abs_res, solver));

(eq1 | eq2).assert_term();

return res;
}
12 changes: 11 additions & 1 deletion barretenberg/cpp/src/barretenberg/acir_formal_proofs/helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,14 @@ smt_circuit::STerm shr(smt_circuit::STerm v0, smt_circuit::STerm v1, smt_solver:
* @param solver SMT solver instance
* @return Result of (v0 << v1) without truncation
*/
smt_circuit::STerm shl(smt_circuit::STerm v0, smt_circuit::STerm v1, smt_solver::Solver* solver);
smt_circuit::STerm shl(smt_circuit::STerm v0, smt_circuit::STerm v1, smt_solver::Solver* solver);

/**
* @brief Signed division in noir-style
* @param v0 Numerator
* @param v1 Denominator
* @param bit_size bit sizes of numerator and denominator
* @param solver SMT solver instance
* @return Result of (v0 / v1)
*/
smt_circuit::STerm idiv(smt_circuit::STerm v0, smt_circuit::STerm v1, uint32_t bit_size, smt_solver::Solver* solver);
143 changes: 139 additions & 4 deletions barretenberg/cpp/src/barretenberg/acir_formal_proofs/helpers.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,16 @@ using uint_ct = stdlib::uint32<StandardCircuitBuilder>;

using namespace smt_terms;

/**
* @brief Test left shift operation
* Tests that 5 << 1 = 10 using SMT solver
*/
TEST(helpers, shl)
{
Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", default_solver_config, 16, 32);
Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001",
default_solver_config,
/*base=*/16,
/*bvsize=*/32);

STerm x = BVVar("x", &s);
STerm y = BVVar("y", &s);
Expand All @@ -32,9 +39,16 @@ TEST(helpers, shl)
EXPECT_TRUE(vals["z"] == "00000000000000000000000000001010");
}

/**
* @brief Test right shift operation
* Tests that 5 >> 1 = 2 using SMT solver
*/
TEST(helpers, shr)
{
Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", default_solver_config, 16, 32);
Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001",
default_solver_config,
/*base=*/16,
/*bvsize=*/32);

STerm x = BVVar("x", &s);
STerm y = BVVar("y", &s);
Expand All @@ -52,11 +66,18 @@ TEST(helpers, shr)
EXPECT_TRUE(vals["z"] == "00000000000000000000000000000010");
}

/**
* @brief Test edge case for right shift operation
* Tests that 1879048194 >> 16 = 28672 using SMT solver
*/
TEST(helpers, buggy_shr)
{
// using smt solver i found that 1879048194 >> 16 == 0
// its strange...
Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", default_solver_config, 16, 32);
Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001",
default_solver_config,
/*base=*/16,
/*bvsize=*/32);

STerm x = BVVar("x", &s);
STerm y = BVVar("y", &s);
Expand All @@ -74,9 +95,16 @@ TEST(helpers, buggy_shr)
EXPECT_TRUE(vals["z"] == "00000000000000000111000000000000");
}

/**
* @brief Test power of 2 calculation
* Tests that 2^11 = 2048 using SMT solver
*/
TEST(helpers, pow2)
{
Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", default_solver_config, 16, 32);
Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001",
default_solver_config,
/*base=*/16,
/*bvsize=*/32);

STerm x = BVVar("x", &s);
STerm z = pow2_8(x, &s);
Expand All @@ -89,4 +117,111 @@ TEST(helpers, pow2)
info("z = ", vals["z"]);
// z == 2048 in binary
EXPECT_TRUE(vals["z"] == "00000000000000000000100000000000");
}

/**
* @brief Test signed division with zero dividend
* Tests that 0 / -1 = 0 using SMT solver
*/
TEST(helpers, signed_div)
{
Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001",
default_solver_config,
/*base=*/16,
/*bvsize=*/32);

STerm x = BVVar("x", &s);
STerm y = BVVar("y", &s);
STerm z = idiv(x, y, 2, &s);
// 00 == 0
x == 0;
// 11 == -1
y == 3;
s.check();
std::unordered_map<std::string, cvc5::Term> terms({ { "x", x }, { "y", y }, { "z", z } });
std::unordered_map<std::string, std::string> vals = s.model(terms);
info("x = ", vals["x"]);
info("y = ", vals["y"]);
info("z = ", vals["z"]);
EXPECT_TRUE(vals["z"] == "00000000000000000000000000000000");
}

/**
* @brief Test signed division with positive dividend and negative divisor
* Tests that 1 / -1 = -1 using SMT solver
*/
TEST(helpers, signed_div_1)
{
Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001",
default_solver_config,
/*base=*/16,
/*bvsize=*/32);

STerm x = BVVar("x", &s);
STerm y = BVVar("y", &s);
STerm z = idiv(x, y, 2, &s);
// 01 == 1
x == 1;
// 11 == -1
y == 3;
s.check();
std::unordered_map<std::string, cvc5::Term> terms({ { "x", x }, { "y", y }, { "z", z } });
std::unordered_map<std::string, std::string> vals = s.model(terms);
info("x = ", vals["x"]);
info("y = ", vals["y"]);
info("z = ", vals["z"]);
EXPECT_TRUE(vals["z"] == "00000000000000000000000000000011");
}

/**
* @brief Test signed division with positive numbers
* Tests that 7 / 2 = 3 using SMT solver
*/
TEST(helpers, signed_div_2)
{
Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001",
default_solver_config,
/*base=*/16,
/*bvsize=*/32);

STerm x = BVVar("x", &s);
STerm y = BVVar("y", &s);
STerm z = idiv(x, y, 4, &s);
// 0111 == 7
x == 7;
// 0010 == 2
y == 2;
s.check();
std::unordered_map<std::string, cvc5::Term> terms({ { "x", x }, { "y", y }, { "z", z } });
std::unordered_map<std::string, std::string> vals = s.model(terms);
info("x = ", vals["x"]);
info("y = ", vals["y"]);
info("z = ", vals["z"]);
EXPECT_TRUE(vals["z"] == "00000000000000000000000000000011");
}

/**
* @brief Test left shift overflow behavior
* Tests that 1 << 50 = 0 (due to overflow) using SMT solver
*/
TEST(helpers, shl_overflow)
{
Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001",
default_solver_config,
/*base=*/16,
/*bvsize=*/32);

STerm x = BVVar("x", &s);
STerm y = BVVar("y", &s);
STerm z = shl32(x, y, &s);
x == 1;
y == 50;
s.check();
std::unordered_map<std::string, cvc5::Term> terms({ { "x", x }, { "y", y }, { "z", z } });
std::unordered_map<std::string, std::string> vals = s.model(terms);
info("x = ", vals["x"]);
info("y = ", vals["y"]);
info("z = ", vals["z"]);
// z == 1010 in binary
EXPECT_TRUE(vals["z"] == "00000000000000000000000000000000");
}
Loading