Skip to content

Commit

Permalink
Open up the Z3 api.
Browse files Browse the repository at this point in the history
  • Loading branch information
fruffy committed Mar 1, 2024
1 parent 1fba085 commit 55cd213
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 129 deletions.
187 changes: 59 additions & 128 deletions backends/p4tools/common/core/z3_solver.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "backends/p4tools/common/core/z3_solver.h"

#include <z3++.h>
#include <z3_api.h>

#include <algorithm>
Expand Down Expand Up @@ -49,101 +50,6 @@ const char *toString(z3::model m) { return Z3_model_to_string(m.ctx(), m); }
#define Z3_LOG(...)
#endif // NDEBUG

/// Translates P4 expressions into Z3. Any variables encountered are declared to a Z3 instance.
class Z3Translator : public virtual Inspector {
public:
/// Creates a Z3 translator. Any variables encountered during translation will be declared to
/// the Z3 instance encapsulated within the given solver.
explicit Z3Translator(Z3Solver &solver);

/// Handles unexpected nodes.
bool preorder(const IR::Node *node) override;

/// Translates casts.
bool preorder(const IR::Cast *cast) override;

/// Translates constants.
bool preorder(const IR::Constant *constant) override;
bool preorder(const IR::BoolLiteral *boolLiteral) override;
bool preorder(const IR::StringLiteral *stringLiteral) override;

/// Translates variables.
bool preorder(const IR::SymbolicVariable *var) override;

// Translations for unary operations.
bool preorder(const IR::Neg *op) override;
bool preorder(const IR::Cmpl *op) override;
bool preorder(const IR::LNot *op) override;

// Translations for binary operations.
bool preorder(const IR::Equ *op) override;
bool preorder(const IR::Neq *op) override;
bool preorder(const IR::Lss *op) override;
bool preorder(const IR::Leq *op) override;
bool preorder(const IR::Grt *op) override;
bool preorder(const IR::Geq *op) override;
bool preorder(const IR::Mod *op) override;
bool preorder(const IR::Add *op) override;
bool preorder(const IR::Sub *op) override;
bool preorder(const IR::Mul *op) override;
bool preorder(const IR::Div *op) override;
bool preorder(const IR::Shl *op) override;
bool preorder(const IR::Shr *op) override;
bool preorder(const IR::BAnd *op) override;
bool preorder(const IR::BOr *op) override;
bool preorder(const IR::BXor *op) override;
bool preorder(const IR::LAnd *op) override;
bool preorder(const IR::LOr *op) override;
bool preorder(const IR::Concat *op) override;

// Translations for ternary operations.
bool preorder(const IR::Mux *op) override;
bool preorder(const IR::Slice *op) override;

/// @returns the result of the translation.
z3::expr getResult() { return result; }

private:
/// Function type for a unary operator.
using Z3UnaryOp = z3::expr (*)(const z3::expr &);

/// Function type for a binary operator.
using Z3BinaryOp = z3::expr (*)(const z3::expr &, const z3::expr &);

/// Function type for a ternary operator.
using Z3TernaryOp = z3::expr (*)(const z3::expr &, const z3::expr &, const z3::expr &);

/// Handles recursion into unary operations.
///
/// @returns false.
bool recurseUnary(const IR::Operation_Unary *unary, Z3UnaryOp f);

/// Handles recursion into binary operations.
///
/// @returns false.
bool recurseBinary(const IR::Operation_Binary *binary, Z3BinaryOp f);

/// Handles recursion into ternary operations.
///
/// @returns false.
bool recurseTernary(const IR::Operation_Ternary *ternary, Z3TernaryOp f);

/// Rewrites a shift operation so that the type of the shift amount matches that of the number
/// being shifted.
///
/// P4 allows shift operands to have different types: when the number being shifted is a bit
/// vector, the shift amount can be an infinite-precision integer. This rewrites such
/// expressions so that the shift amount is a bit vector.
template <class ShiftType>
const ShiftType *rewriteShift(const ShiftType *shift) const;

/// The output of the translation.
z3::expr result;

/// The Z3 solver instance, to which variables will be declared as they are encountered.
Z3Solver &solver;
};

z3::sort Z3Solver::toSort(const IR::Type *type) {
BUG_CHECK(type, "Z3Solver::toSort with empty pointer");

Expand Down Expand Up @@ -262,6 +168,31 @@ void Z3Solver::timeout(unsigned tm) {
timeout_ = tm;
}

std::optional<bool> Z3Solver::interpretSolverResult(z3::check_result result) {
switch (result) {
case z3::sat:
Z3_LOG("result:%s", "sat");
return true;
case z3::unsat:
Z3_LOG("result:%s", "unsat");
return false;

default: // unknown
Z3_LOG("result:%s", "unknown");
return std::nullopt;
}
}

std::optional<bool> Z3Solver::checkSat() {
Util::ScopedTimer ctCheckSat("checkSat");
return interpretSolverResult(z3solver.check());
}

std::optional<bool> Z3Solver::checkSat(const z3::expr_vector &asserts) {
Util::ScopedTimer ctCheckSat("checkSat");
return interpretSolverResult(z3solver.check(asserts));
}

std::optional<bool> Z3Solver::checkSat(const std::vector<const Constraint *> &asserts) {
Util::ScopedTimer ctZ3("z3");
if (isIncremental) {
Expand All @@ -285,38 +216,27 @@ std::optional<bool> Z3Solver::checkSat(const std::vector<const Constraint *> &as
}
Z3_LOG("checking satisfiability for %d assertions",
isIncremental ? z3solver.assertions().size() : z3Assertions.size());
Util::ScopedTimer ctCheckSat("checkSat");
z3::check_result result = isIncremental ? z3solver.check() : z3solver.check(z3Assertions);
switch (result) {
case z3::sat:
Z3_LOG("result:%s", "sat");
return true;
case z3::unsat:
Z3_LOG("result:%s", "unsat");
return false;

default: // unknown
Z3_LOG("result:%s", "unknown");
return std::nullopt;
}
return isIncremental ? checkSat() : checkSat(z3Assertions);
}

void Z3Solver::asrt(const Constraint *assertion) {
CHECK_NULL(assertion);
try {
Z3Translator z3translator(*this);
assertion->apply(z3translator);
auto expr = z3translator.getResult();
Z3Translator z3translator(*this);
auto expr = z3translator.translate(assertion);
asrt(expr);
p4Assertions.push_back(assertion);
BUG_CHECK(isIncremental || z3Assertions.size() == p4Assertions.size(),
"Number of assertion in P4 and Z3 formats aren't equal");
}

Z3_LOG("add assertion '%s'", toString(expr));
void Z3Solver::asrt(const z3::expr &assertion) {
try {
Z3_LOG("add assertion '%s'", toString(assertion));
if (isIncremental) {
z3solver.add(expr);
z3solver.add(assertion);
} else {
z3Assertions.push_back(expr);
z3Assertions.push_back(assertion);
}
p4Assertions.push_back(assertion);
BUG_CHECK(isIncremental || z3Assertions.size() == p4Assertions.size(),
"Number of assertion in P4 and Z3 formats aren't equal");
} catch (z3::exception &e) {
BUG("Z3Solver: Z3 exception: %1%\nAssertion %2%", e.msg(), assertion);
}
Expand Down Expand Up @@ -471,8 +391,8 @@ bool Z3Translator::preorder(const IR::Cast *cast) {
exprSize = exprType->width_bits();
} else if (castExtrType->is<IR::Type_Boolean>()) {
exprSize = 1;
auto trueVal = solver.ctx().bv_val(1, exprSize);
auto falseVal = solver.ctx().bv_val(0, exprSize);
auto trueVal = solver.get().ctx().bv_val(1, exprSize);
auto falseVal = solver.get().ctx().bv_val(0, exprSize);
castExpr = z3::ite(castExpr, trueVal, falseVal);
} else if (const auto *exprType = castExtrType->to<IR::Extracted_Varbits>()) {
exprSize = exprType->width_bits();
Expand All @@ -496,7 +416,7 @@ bool Z3Translator::preorder(const IR::Cast *cast) {
if (cast->destType->is<IR::Type_Boolean>()) {
if (const auto *exprType = castExtrType->to<IR::Type_Bits>()) {
if (exprType->width_bits() == 1) {
castExpr = z3::operator==(castExpr, solver.ctx().bv_val(1, 1));
castExpr = z3::operator==(castExpr, solver.get().ctx().bv_val(1, 1));
} else {
BUG("Cast expression type %1% is not bit<1> : %2%", exprType, castExpr);
}
Expand All @@ -514,36 +434,36 @@ bool Z3Translator::preorder(const IR::Cast *cast) {
bool Z3Translator::preorder(const IR::Constant *constant) {
// Handle infinite-integer constants.
if (constant->type->is<IR::Type_InfInt>()) {
result = solver.ctx().int_val(constant->value.str().c_str());
result = solver.get().ctx().int_val(constant->value.str().c_str());
return false;
}

// Handle bit<n> constants.
if (const auto *bits = constant->type->to<IR::Type_Bits>()) {
result = solver.ctx().bv_val(constant->value.str().c_str(), bits->size);
result = solver.get().ctx().bv_val(constant->value.str().c_str(), bits->size);
return false;
}

if (const auto *bits = constant->type->to<IR::Extracted_Varbits>()) {
result = solver.ctx().bv_val(constant->value.str().c_str(), bits->width_bits());
result = solver.get().ctx().bv_val(constant->value.str().c_str(), bits->width_bits());
return false;
}

BUG("Z3Translator: unsupported type for constant %1%", constant);
}

bool Z3Translator::preorder(const IR::BoolLiteral *boolLiteral) {
result = solver.ctx().bool_val(boolLiteral->value);
result = solver.get().ctx().bool_val(boolLiteral->value);
return false;
}

bool Z3Translator::preorder(const IR::StringLiteral *stringLiteral) {
result = solver.ctx().string_val(stringLiteral->value);
result = solver.get().ctx().string_val(stringLiteral->value);
return false;
}

bool Z3Translator::preorder(const IR::SymbolicVariable *var) {
result = solver.declareVar(*var);
result = solver.get().declareVar(*var);
return false;
}

Expand Down Expand Up @@ -694,4 +614,15 @@ bool Z3Translator::recurseTernary(const IR::Operation_Ternary *ternary, Z3Ternar
return false;
}

z3::expr Z3Translator::getResult() { return result; }

z3::expr Z3Translator::translate(const IR::Expression *expression) {
try {
expression->apply(*this);
} catch (z3::exception &e) {
BUG("Z3Translator: Z3 exception: %1%\nExpression %2%", e.msg(), expression);
}
return result;
}

} // namespace P4Tools
Loading

0 comments on commit 55cd213

Please sign in to comment.