diff --git a/barretenberg/cpp/src/barretenberg/acir_formal_proofs/formal_proofs.cpp b/barretenberg/cpp/src/barretenberg/acir_formal_proofs/formal_proofs.cpp index 4c8bc12ee6b9..342b4f9ccf08 100644 --- a/barretenberg/cpp/src/barretenberg/acir_formal_proofs/formal_proofs.cpp +++ b/barretenberg/cpp/src/barretenberg/acir_formal_proofs/formal_proofs.cpp @@ -269,4 +269,19 @@ bool verify_gt(smt_solver::Solver* solver, smt_circuit::UltraCircuit circuit) debug_solution(solver, terms); } return res; +} + +bool verify_idiv(smt_solver::Solver* solver, smt_circuit::UltraCircuit circuit, uint32_t bit_size) +{ + auto a = circuit["a"]; + auto b = circuit["b"]; + auto c = circuit["c"]; + auto cr = idiv(a, b, bit_size, solver); + c != cr; + bool res = solver->check(); + if (res) { + std::unordered_map terms({ { "a", a }, { "b", b }, { "c", c }, { "cr", cr } }); + debug_solution(solver, terms); + } + return res; } \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/acir_formal_proofs/helpers.cpp b/barretenberg/cpp/src/barretenberg/acir_formal_proofs/helpers.cpp index a3ab02b5f584..90bf02991f96 100644 --- a/barretenberg/cpp/src/barretenberg/acir_formal_proofs/helpers.cpp +++ b/barretenberg/cpp/src/barretenberg/acir_formal_proofs/helpers.cpp @@ -47,17 +47,37 @@ smt_circuit::STerm shr(smt_circuit::STerm v0, smt_circuit::STerm v1, smt_solver: smt_circuit::STerm shl64(smt_circuit::STerm v0, smt_circuit::STerm v1, smt_solver::Solver* solver) { auto shifted = shl(v0, v1, solver); - // 2^64 - 1 - auto mask = smt_terms::BVConst("18446744073709551615", solver, 10); - auto res = shifted & mask; + auto res = shifted.truncate(63); return res; } smt_circuit::STerm shl32(smt_circuit::STerm v0, smt_circuit::STerm v1, smt_solver::Solver* solver) { auto shifted = shl(v0, v1, solver); - // 2^32 - 1 - auto mask = smt_terms::BVConst("4294967295", solver, 10); - auto res = shifted & mask; + auto res = shifted.truncate(31); return res; } + +smt_circuit::STerm idiv(smt_circuit::STerm v0, smt_circuit::STerm v1, uint32_t bit_size, smt_solver::Solver* solver) +{ + // highest bit of v0 and v1 is sign bit + smt_circuit::STerm exponent = smt_terms::BVConst(std::to_string(bit_size), solver, 10); + auto sign_bit_v0 = v0.extract_bit(bit_size - 1); + auto sign_bit_v1 = v1.extract_bit(bit_size - 1); + auto res_sign_bit = sign_bit_v0 ^ sign_bit_v1; + res_sign_bit <<= bit_size - 1; + auto abs_value_v0 = v0.truncate(bit_size - 2); + auto abs_value_v1 = v1.truncate(bit_size - 2); + auto abs_res = abs_value_v0 / abs_value_v1; + + // if abs_value_v0 == 0 then res = 0 + // in our context we use idiv only once, so static name for the division result okay. + auto res = smt_terms::BVVar("res_signed_division", solver); + auto condition = smt_terms::Bool(abs_value_v0, solver) == smt_terms::Bool(smt_terms::BVConst("0", solver, 10)); + auto eq1 = condition & (smt_terms::Bool(res, solver) == smt_terms::Bool(smt_terms::BVConst("0", solver, 10))); + auto eq2 = !condition & (smt_terms::Bool(res, solver) == smt_terms::Bool(res_sign_bit | abs_res, solver)); + + (eq1 | eq2).assert_term(); + + return res; +} \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/acir_formal_proofs/helpers.hpp b/barretenberg/cpp/src/barretenberg/acir_formal_proofs/helpers.hpp index d1e6fe4ab45b..a1d3c99a7cb3 100644 --- a/barretenberg/cpp/src/barretenberg/acir_formal_proofs/helpers.hpp +++ b/barretenberg/cpp/src/barretenberg/acir_formal_proofs/helpers.hpp @@ -44,4 +44,14 @@ smt_circuit::STerm shr(smt_circuit::STerm v0, smt_circuit::STerm v1, smt_solver: * @param solver SMT solver instance * @return Result of (v0 << v1) without truncation */ -smt_circuit::STerm shl(smt_circuit::STerm v0, smt_circuit::STerm v1, smt_solver::Solver* solver); \ No newline at end of file +smt_circuit::STerm shl(smt_circuit::STerm v0, smt_circuit::STerm v1, smt_solver::Solver* solver); + +/** + * @brief Signed division in noir-style + * @param v0 Numerator + * @param v1 Denominator + * @param bit_size bit sizes of numerator and denominator + * @param solver SMT solver instance + * @return Result of (v0 / v1) + */ +smt_circuit::STerm idiv(smt_circuit::STerm v0, smt_circuit::STerm v1, uint32_t bit_size, smt_solver::Solver* solver); \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/acir_formal_proofs/helpers.test.cpp b/barretenberg/cpp/src/barretenberg/acir_formal_proofs/helpers.test.cpp index 9918dcb6947e..ddfc4da15072 100644 --- a/barretenberg/cpp/src/barretenberg/acir_formal_proofs/helpers.test.cpp +++ b/barretenberg/cpp/src/barretenberg/acir_formal_proofs/helpers.test.cpp @@ -12,9 +12,16 @@ using uint_ct = stdlib::uint32; using namespace smt_terms; +/** + * @brief Test left shift operation + * Tests that 5 << 1 = 10 using SMT solver + */ TEST(helpers, shl) { - Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", default_solver_config, 16, 32); + Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", + default_solver_config, + /*base=*/16, + /*bvsize=*/32); STerm x = BVVar("x", &s); STerm y = BVVar("y", &s); @@ -32,9 +39,16 @@ TEST(helpers, shl) EXPECT_TRUE(vals["z"] == "00000000000000000000000000001010"); } +/** + * @brief Test right shift operation + * Tests that 5 >> 1 = 2 using SMT solver + */ TEST(helpers, shr) { - Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", default_solver_config, 16, 32); + Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", + default_solver_config, + /*base=*/16, + /*bvsize=*/32); STerm x = BVVar("x", &s); STerm y = BVVar("y", &s); @@ -52,11 +66,18 @@ TEST(helpers, shr) EXPECT_TRUE(vals["z"] == "00000000000000000000000000000010"); } +/** + * @brief Test edge case for right shift operation + * Tests that 1879048194 >> 16 = 28672 using SMT solver + */ TEST(helpers, buggy_shr) { // using smt solver i found that 1879048194 >> 16 == 0 // its strange... - Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", default_solver_config, 16, 32); + Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", + default_solver_config, + /*base=*/16, + /*bvsize=*/32); STerm x = BVVar("x", &s); STerm y = BVVar("y", &s); @@ -74,9 +95,16 @@ TEST(helpers, buggy_shr) EXPECT_TRUE(vals["z"] == "00000000000000000111000000000000"); } +/** + * @brief Test power of 2 calculation + * Tests that 2^11 = 2048 using SMT solver + */ TEST(helpers, pow2) { - Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", default_solver_config, 16, 32); + Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", + default_solver_config, + /*base=*/16, + /*bvsize=*/32); STerm x = BVVar("x", &s); STerm z = pow2_8(x, &s); @@ -89,4 +117,111 @@ TEST(helpers, pow2) info("z = ", vals["z"]); // z == 2048 in binary EXPECT_TRUE(vals["z"] == "00000000000000000000100000000000"); +} + +/** + * @brief Test signed division with zero dividend + * Tests that 0 / -1 = 0 using SMT solver + */ +TEST(helpers, signed_div) +{ + Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", + default_solver_config, + /*base=*/16, + /*bvsize=*/32); + + STerm x = BVVar("x", &s); + STerm y = BVVar("y", &s); + STerm z = idiv(x, y, 2, &s); + // 00 == 0 + x == 0; + // 11 == -1 + y == 3; + s.check(); + std::unordered_map terms({ { "x", x }, { "y", y }, { "z", z } }); + std::unordered_map vals = s.model(terms); + info("x = ", vals["x"]); + info("y = ", vals["y"]); + info("z = ", vals["z"]); + EXPECT_TRUE(vals["z"] == "00000000000000000000000000000000"); +} + +/** + * @brief Test signed division with positive dividend and negative divisor + * Tests that 1 / -1 = -1 using SMT solver + */ +TEST(helpers, signed_div_1) +{ + Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", + default_solver_config, + /*base=*/16, + /*bvsize=*/32); + + STerm x = BVVar("x", &s); + STerm y = BVVar("y", &s); + STerm z = idiv(x, y, 2, &s); + // 01 == 1 + x == 1; + // 11 == -1 + y == 3; + s.check(); + std::unordered_map terms({ { "x", x }, { "y", y }, { "z", z } }); + std::unordered_map vals = s.model(terms); + info("x = ", vals["x"]); + info("y = ", vals["y"]); + info("z = ", vals["z"]); + EXPECT_TRUE(vals["z"] == "00000000000000000000000000000011"); +} + +/** + * @brief Test signed division with positive numbers + * Tests that 7 / 2 = 3 using SMT solver + */ +TEST(helpers, signed_div_2) +{ + Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", + default_solver_config, + /*base=*/16, + /*bvsize=*/32); + + STerm x = BVVar("x", &s); + STerm y = BVVar("y", &s); + STerm z = idiv(x, y, 4, &s); + // 0111 == 7 + x == 7; + // 0010 == 2 + y == 2; + s.check(); + std::unordered_map terms({ { "x", x }, { "y", y }, { "z", z } }); + std::unordered_map vals = s.model(terms); + info("x = ", vals["x"]); + info("y = ", vals["y"]); + info("z = ", vals["z"]); + EXPECT_TRUE(vals["z"] == "00000000000000000000000000000011"); +} + +/** + * @brief Test left shift overflow behavior + * Tests that 1 << 50 = 0 (due to overflow) using SMT solver + */ +TEST(helpers, shl_overflow) +{ + Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", + default_solver_config, + /*base=*/16, + /*bvsize=*/32); + + STerm x = BVVar("x", &s); + STerm y = BVVar("y", &s); + STerm z = shl32(x, y, &s); + x == 1; + y == 50; + s.check(); + std::unordered_map terms({ { "x", x }, { "y", y }, { "z", z } }); + std::unordered_map vals = s.model(terms); + info("x = ", vals["x"]); + info("y = ", vals["y"]); + info("z = ", vals["z"]); + // z == 1010 in binary + EXPECT_TRUE(vals["z"] == "00000000000000000000000000000000"); } \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/README.md b/barretenberg/cpp/src/barretenberg/smt_verification/README.md index c62f07516fef..5246142d6929 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/README.md +++ b/barretenberg/cpp/src/barretenberg/smt_verification/README.md @@ -24,7 +24,7 @@ To store it on the disk just do the following ```c++ msgpack::sbuffer buffer = circuit.export_circuit(); - + std::fstream myfile; myfile.open("fname.pack", std::ios::out | std::ios::trunc | std::ios::binary); @@ -44,7 +44,7 @@ To store it on the disk just do the following 2. Initialize the Solver: There's an `smt_solver::SolverConfiguration` structure: - + ```cpp struct SolverConfiguration { bool produce_models; @@ -90,7 +90,7 @@ To store it on the disk just do the following All the tables are exoported directly from circuit, but if you want to create your own table, there're two methods for this: - - `Solver::create_table(vector& table)` - creates a set of values. + - `Solver::create_table(vector& table)` - creates a set of values. - `Solver::create_lookup_table(vector>& table)` - creates a table with three columns. ```c++ @@ -101,13 +101,13 @@ To store it on the disk just do the following There is more on `FFConst` in the following sections. -3. Initialize the Circuit +3. Initialize the Circuit - From now on we will use `smt_terms::STerm` and `smt_terms::Bool` types to operate inside the solver. + From now on we will use `smt_terms::STerm` and `smt_terms::Bool` types to operate inside the solver. You can choose the behaviour of symbolic variables by providing the specific type to `STerm` or `Circuit` constructor: - - `smt_terms::TermType::FFTerm` - symbolic variables that simulate finite field arithmetic. + - `smt_terms::TermType::FFTerm` - symbolic variables that simulate finite field arithmetic. - `smt_terms::TermType::FFITerm` - symbolic variables that simulate integer elements which behave like finite field ones. Useful, when you want to create range constraints. Bad, when you try multiplication. - `smt_terms::TermType::ITerm` - symbolic variables that simulate ordinary integer elements. Useful, when you want to create range constraints and operate with signed values that are not shrinked modulo smth. - `smt_terms::TermType::BVTerm` - symbolic variables that simulate $\pmod{2^n}$ arithmetic. Useful, when you test uint circuits. Supports range constraints and bitwise operations. Doesn't behave like finite field element. @@ -117,13 +117,13 @@ To store it on the disk just do the following `Bool` - simulates the boolean values and mostly will be useful to simulate complex `if` statements if needed. Now we can create symbolic circuit - + - ```smt_circuit::StandardCircuit circuit(CircuitSchema c_info, Solver* s, TermType type, str tag="", bool optimizations=true)``` - ```smt_circuit::UltraCircuit circuit(CircuitSchema c_info, Solver* s, TermType type, str tag="", bool optimizations=true)``` - + It will generate all the symbolic values of the circuit wires, add all the gate constrains, create a map `term_name->STerm` and the inverse of it. Where `term_name` is the name you provided earlier. - In case you want to create two similar circuits with the same `solver` and `schema`, then you should specify the `tag`(name) of a circuit. + In case you want to create two similar circuits with the same `solver` and `schema`, then you should specify the `tag`(name) of a circuit. **Advanced** If you don't want the circuit optimizations to be applied then you should set `optimizations` to `false`. Optimizations interchange the complex circuits like bitwise XOR with simple XOR operation. More on optimizations can be found [standard_circuit.cpp](circuit/standard_circuit.cpp) @@ -145,15 +145,15 @@ To store it on the disk just do the following You can add, subtract and multiply these variables(including `+=`, `-=`, etc); Also there are two functions: - `batch_add(std::vector& terms)` - - `batch_mul(std::vector& terms)` + - `batch_mul(std::vector& terms)` to create an addition/multiplication Term in one call `FFITerm` also can be used to create range constraints. e.g. `x <= bb::fr(2).pow(10) - 1;` - `BVTerm` can be used to create bitwise constraints. e.g. `STerm y = x^z` or `STerm y = x.rotr(10)`. And range constraints too. + `BVTerm` can be used to create bitwise constraints. e.g. `STerm y = x^z` or `STerm y = x.rotr(10)`. And range constraints too. Also there are `truncate` and `extract_bit` methods, e.g. `x.truncate(9)` will truncate to last 10 bits (starting from 0th bit), `x.extract_bit(10)` will extract 10th bit. - You can create a constraint `==` or `!=` that will be included directly into solver. e.g. `x == y;` + You can create a constraint `==` or `!=` that will be included directly into solver. e.g. `x == y;` **!Note: In this case these are not comparison operators** @@ -169,7 +169,7 @@ To store it on the disk just do the following You can `|, &, ==, !=, !` these variables and also `batch_or`, `batch_and` them. To create a constraint you should call `Bool::assert_term()` method. - + The way I see the use of Bool types is to create terms like `(a == b && c == 1) || (a != b && c == 0)`, `(a!=1)||(b!=2)|(c!=3)` and of course more sophisticated ones. **!Note that constraint like `(Bool(STerm a) == Bool(STerm b)).assert_term()`, where a has `FFTerm` type and b has `FFITerm` type, won't work, since their types differ.** @@ -181,8 +181,8 @@ After generating all the constrains you should call `bool res = solver.check()` In case you expected `false` but `true` was returned you can then check what went wrong. You should generate an unordered map with `str->term` values and ask the solver to obtain `unoredered_map res = solver.model(unordered_map terms)`. Or you can provide a vector of terms that you want to check and the return map will contain their symbolic names that are given during initialization. Specifically either it's the name that you set or `var_{i}`. - -Now you have the values of the specified terms, which resulted into `true` result. + +Now you have the values of the specified terms, which resulted into `true` result. **!Note that the return values are decimal strings/binary strings**, so if you want to use them later you should use `FFConst` with base 10, etc. Also, there is a header file "barretenberg/smt_verification/utl/smt_util.hpp" that contains two useful functions: @@ -191,22 +191,22 @@ Also, there is a header file "barretenberg/smt_verification/utl/smt_util.hpp" th These functions will write witness variables in c-like array format into file named `fname`. The vector of `special_names` is the values that you want ot see in stdout. -`pack` argument tells this function to save an `msgpack` buffer of the witness on disk. Name of the file will be `fname`.pack +`pack` argument tells this function to save an `msgpack` buffer of the witness on disk. Name of the file will be `fname`.pack You can then import the saved witness using one of the following functions: - `vec> import_witness(str fname)` - `vec import_witness_single(str fname)` - + ## 4. Automated verification of a unique witness -There's a static member of `StandardCircuit` and `UltraCircuit` +There's a static member of `StandardCircuit` and `UltraCircuit` - `pair StandardCircuit::unique_wintes(CircuitSchema circuit_info, Solver*, TermType type, vector equal, bool optimizations)` - `pair UltraCircuit::unique_wintes(CircuitSchema circuit_info, Solver*, TermType type, vector equal, bool optimizations)` They will create two separate circuits, constrain variables with names from `equal` to be equal acrosss the circuits, and set all the other variables to be not equal at the same time. -Another one is +Another one is - `pair StandardCircuit::unique_witness_ext(CircuitSchema circuit_info, Solver* s, TermType type, vector equal_variables, vector nequal_variables, vector at_least_one_equal_variable, vector at_least_one_nequal_variable)` that does the same but provides you with more flexible settings. - Same in `UltraCircuit` @@ -372,9 +372,9 @@ void model_variables(SymCircuit& c, Solver* s, FFTerm& evaluation) ``` -More examples can be found in +More examples can be found in - [terms/ffterm.test.cpp](terms/ffterm.test.cpp), [terms/ffiterm.test.cpp](terms/ffiterm.test.cpp), [terms/bvterm.test.cpp](terms/bvterm.test.cpp), [terms/iterm.test.cpp](terms/iterm.test.cpp) -- [circuit/standard_circuit.test.cpp](circuit/standard_circuit.test.cpp), [circuit/ultra_circuit](circuit/ultra_circuit.test.cpp) +- [circuit/standard_circuit.test.cpp](circuit/standard_circuit.test.cpp), [circuit/ultra_circuit](circuit/ultra_circuit.test.cpp) - [smt_polynomials.test.cpp](smt_polynomials.test.cpp), [smt_examples.test.cpp](smt_examples.test.cpp) - [bb_tests](bb_tests) diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/ultra_circuit.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/ultra_circuit.cpp index 155a17c8c63b..a186b6ff2db8 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/ultra_circuit.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/ultra_circuit.cpp @@ -30,18 +30,18 @@ UltraCircuit::UltraCircuit( // add gate in its normal state to solver size_t arith_cursor = 0; - while (arith_cursor < this->selectors[1].size()) { - arith_cursor = this->handle_arithmetic_relation(arith_cursor, 1); + while (arith_cursor < this->selectors[2].size()) { + arith_cursor = this->handle_arithmetic_relation(arith_cursor, 2); } size_t lookup_cursor = 0; - while (lookup_cursor < this->selectors[5].size()) { - lookup_cursor = this->handle_lookup_relation(lookup_cursor, 5); + while (lookup_cursor < this->selectors[1].size()) { + lookup_cursor = this->handle_lookup_relation(lookup_cursor, 1); } size_t elliptic_cursor = 0; - while (elliptic_cursor < this->selectors[3].size()) { - elliptic_cursor = this->handle_elliptic_relation(elliptic_cursor, 3); + while (elliptic_cursor < this->selectors[4].size()) { + elliptic_cursor = this->handle_elliptic_relation(elliptic_cursor, 4); } // size_t delta_range_cursor = 0; @@ -88,7 +88,7 @@ size_t UltraCircuit::handle_arithmetic_relation(size_t cursor, size_t idx) std::vector boolean_gate = { 1, -1, 0, 0, 0, 0, 1, 0, 0, 0, 0 }; bool boolean_gate_flag = - (boolean_gate == selectors[1][cursor]) && (w_l_idx == w_r_idx) && (w_o_idx == 0) && (w_4_idx == 0); + (boolean_gate == selectors[idx][cursor]) && (w_l_idx == w_r_idx) && (w_o_idx == 0) && (w_4_idx == 0); if (boolean_gate_flag) { (Bool(w_l) == Bool(STerm(0, this->solver, this->type)) | Bool(w_l) == Bool(STerm(1, this->solver, this->type))) .assert_term(); @@ -292,7 +292,7 @@ size_t UltraCircuit::handle_elliptic_relation(size_t cursor, size_t idx) y_add_identity == 0; // scaling_factor = 1 } - bb::fr curve_b = this->selectors[3][cursor][11]; + bb::fr curve_b = this->selectors[idx][cursor][11]; auto x_pow_4 = (y1_sqr - curve_b) * x_1; auto y1_sqr_mul_4 = y1_sqr + y1_sqr; y1_sqr_mul_4 += y1_sqr_mul_4; diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/ultra_circuit.hpp b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/ultra_circuit.hpp index 5fd3be1fd4d8..f8f504337584 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/ultra_circuit.hpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/ultra_circuit.hpp @@ -15,11 +15,11 @@ class UltraCircuit : public CircuitBase { public: // TODO(alex): check that there's no actual pub_inputs block std::vector>> selectors; // all selectors from the circuit - // 1st entry are arithmetic selectors - // 2nd entry are delta_range selectors - // 3rd entry are elliptic selectors - // 4th entry are aux selectors - // 5th entry are lookup selectors + // 1st entry are lookup selectors + // 2nd entry are arithmetic selectors + // 3rd entry are delta_range selectors + // 4th entry are elliptic selectors + // 5th entry are aux selectors std::vector>> wires_idxs; // values of the gates' wires idxs std::vector>> lookup_tables; diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/terms/bvterm.test.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/terms/bvterm.test.cpp index bb7de54a063e..14642ffe5cc0 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/terms/bvterm.test.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/terms/bvterm.test.cpp @@ -1,3 +1,4 @@ +#include #include #include "barretenberg/stdlib/primitives/uint/uint.hpp" @@ -317,6 +318,60 @@ TEST(BVTerm, shl) ASSERT_TRUE(s.check()); + std::string xvals = s.getValue(y.term).getBitVectorValue(); + STerm bval = STerm(b.get_value(), &s, TermType::BVTerm); + std::string bvals = s.getValue(bval.term).getBitVectorValue(); + ASSERT_EQ(bvals, xvals); +} + +TEST(BVTerm, truncate) +{ + StandardCircuitBuilder builder; + uint_ct a = witness_ct(&builder, engine.get_random_uint32()); + unsigned int mask = (1 << 10) - 1; + uint_ct b = a & mask; + + uint32_t modulus_base = 16; + uint32_t bitvector_size = 32; + Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", + default_solver_config, + modulus_base, + bitvector_size); + + STerm x = BVVar("x", &s); + STerm y = x.truncate(9); + + x == a.get_value(); + + ASSERT_TRUE(s.check()); + + std::string xvals = s.getValue(y.term).getBitVectorValue(); + STerm bval = STerm(b.get_value(), &s, TermType::BVTerm); + std::string bvals = s.getValue(bval.term).getBitVectorValue(); + ASSERT_EQ(bvals, xvals); +} + +TEST(BVTerm, extract_bit) +{ + StandardCircuitBuilder builder; + uint_ct a = witness_ct(&builder, engine.get_random_uint32()); + unsigned int mask = (1 << 10); + uint_ct b = a & mask; + + uint32_t modulus_base = 16; + uint32_t bitvector_size = 32; + Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", + default_solver_config, + modulus_base, + bitvector_size); + + STerm x = BVVar("x", &s); + STerm y = x.extract_bit(10); + + x == a.get_value(); + + ASSERT_TRUE(s.check()); + std::string xvals = s.getValue(y.term).getBitVectorValue(); STerm bval = STerm(b.get_value(), &s, TermType::BVTerm); std::string bvals = s.getValue(bval.term).getBitVectorValue(); diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/terms/term.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/terms/term.cpp index cb55e0eed65c..995eaae47fe7 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/terms/term.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/terms/term.cpp @@ -429,6 +429,35 @@ STerm STerm::rotl(const uint32_t& n) const return { res, this->solver, this->type }; } +STerm STerm::truncate(const uint32_t& to_size) +{ + if (!this->operations.contains(OpType::EXTRACT) || !this->operations.contains(OpType::BITVEC_PAD)) { + info("EXTRACT is not compatible with ", this->type); + return *this; + } + cvc5::Op extraction = solver->term_manager.mkOp(this->operations.at(OpType::EXTRACT), { to_size, 0 }); + cvc5::Term temp = solver->term_manager.mkTerm(extraction, { this->term }); + cvc5::Op padding = + solver->term_manager.mkOp(this->operations.at(OpType::BITVEC_PAD), + { this->solver->bv_sort.getBitVectorSize() - temp.getSort().getBitVectorSize() }); + cvc5::Term res = solver->term_manager.mkTerm(padding, { temp }); + return { res, this->solver, this->type }; +} + +STerm STerm::extract_bit(const uint32_t& bit_index) +{ + if (!this->operations.contains(OpType::EXTRACT) || !this->operations.contains(OpType::BITVEC_PAD)) { + info("EXTRACT is not compatible with ", this->type); + return *this; + } + cvc5::Op extraction = solver->term_manager.mkOp(this->operations.at(OpType::EXTRACT), { bit_index, bit_index }); + cvc5::Term temp = solver->term_manager.mkTerm(extraction, { this->term }); + cvc5::Op padding = + solver->term_manager.mkOp(this->operations.at(OpType::BITVEC_PAD), + { this->solver->bv_sort.getBitVectorSize() - temp.getSort().getBitVectorSize() }); + cvc5::Term res = solver->term_manager.mkTerm(padding, { temp }); + return { res, this->solver, this->type }; +} /** * @brief Create an inclusion constraint * diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/terms/term.hpp b/barretenberg/cpp/src/barretenberg/smt_verification/terms/term.hpp index a9ac3b56a474..d5a7bd5d0003 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/terms/term.hpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/terms/term.hpp @@ -15,7 +15,28 @@ using namespace smt_solver; enum class TermType { FFTerm, FFITerm, BVTerm, ITerm }; std::ostream& operator<<(std::ostream& os, TermType type); -enum class OpType : int32_t { ADD, SUB, MUL, DIV, NEG, XOR, AND, OR, GT, GE, LT, LE, MOD, RSH, LSH, ROTR, ROTL, NOT }; +enum class OpType : int32_t { + ADD, + SUB, + MUL, + DIV, + NEG, + XOR, + AND, + OR, + GT, + GE, + LT, + LE, + MOD, + RSH, + LSH, + ROTR, + ROTL, + NOT, + EXTRACT, + BITVEC_PAD +}; /** * @brief precomputed map that contains allowed @@ -75,6 +96,8 @@ const std::unordered_map> typed { OpType::MOD, cvc5::Kind::BITVECTOR_UREM }, { OpType::DIV, cvc5::Kind::BITVECTOR_UDIV }, { OpType::NOT, cvc5::Kind::BITVECTOR_NOT }, + { OpType::EXTRACT, cvc5::Kind::BITVECTOR_EXTRACT }, + { OpType::BITVEC_PAD, cvc5::Kind::BITVECTOR_ZERO_EXTEND }, } } }; @@ -174,6 +197,17 @@ class STerm { STerm rotr(const uint32_t& n) const; STerm rotl(const uint32_t& n) const; + /** + * @brief Returns last `to_size` bits of variable + * @param to_size number of bits to be extracted + */ + STerm truncate(const uint32_t& to_size); + /** + * @brief Returns ith bit of variable + * @param bit_index index of bit to be extracted + */ + STerm extract_bit(const uint32_t& bit_index); + void in(const cvc5::Term& table) const; operator std::string() const { return this->solver->stringify_term(term); }; diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.cpp index cb4bcb2b5e16..50219769ed66 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.cpp @@ -33,6 +33,12 @@ bb::fr string_to_fr(const std::string& number, int base, size_t step) res += std::strtoull(slice.data(), &ptr, base); } res = number[0] == '-' ? -res : res; + + if (base == 2 && number[0] == '1') { + auto max = bb::fr(uint256_t(1) << number.length()); + res -= max; + } + return res; }