Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[P4Testgen] Make more functions of the P4Tools Z3 API accessible. #4322

Merged
merged 1 commit into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 59 additions & 128 deletions backends/p4tools/common/core/z3_solver.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "backends/p4tools/common/core/z3_solver.h"

#include <z3++.h>
#include <z3_api.h>

#include <algorithm>
Expand Down Expand Up @@ -49,101 +50,6 @@ const char *toString(z3::model m) { return Z3_model_to_string(m.ctx(), m); }
#define Z3_LOG(...)
#endif // NDEBUG

/// Translates P4 expressions into Z3. Any variables encountered are declared to a Z3 instance.
class Z3Translator : public virtual Inspector {
public:
/// Creates a Z3 translator. Any variables encountered during translation will be declared to
/// the Z3 instance encapsulated within the given solver.
explicit Z3Translator(Z3Solver &solver);

/// Handles unexpected nodes.
bool preorder(const IR::Node *node) override;

/// Translates casts.
bool preorder(const IR::Cast *cast) override;

/// Translates constants.
bool preorder(const IR::Constant *constant) override;
bool preorder(const IR::BoolLiteral *boolLiteral) override;
bool preorder(const IR::StringLiteral *stringLiteral) override;

/// Translates variables.
bool preorder(const IR::SymbolicVariable *var) override;

// Translations for unary operations.
bool preorder(const IR::Neg *op) override;
bool preorder(const IR::Cmpl *op) override;
bool preorder(const IR::LNot *op) override;

// Translations for binary operations.
bool preorder(const IR::Equ *op) override;
bool preorder(const IR::Neq *op) override;
bool preorder(const IR::Lss *op) override;
bool preorder(const IR::Leq *op) override;
bool preorder(const IR::Grt *op) override;
bool preorder(const IR::Geq *op) override;
bool preorder(const IR::Mod *op) override;
bool preorder(const IR::Add *op) override;
bool preorder(const IR::Sub *op) override;
bool preorder(const IR::Mul *op) override;
bool preorder(const IR::Div *op) override;
bool preorder(const IR::Shl *op) override;
bool preorder(const IR::Shr *op) override;
bool preorder(const IR::BAnd *op) override;
bool preorder(const IR::BOr *op) override;
bool preorder(const IR::BXor *op) override;
bool preorder(const IR::LAnd *op) override;
bool preorder(const IR::LOr *op) override;
bool preorder(const IR::Concat *op) override;

// Translations for ternary operations.
bool preorder(const IR::Mux *op) override;
bool preorder(const IR::Slice *op) override;

/// @returns the result of the translation.
z3::expr getResult() { return result; }

private:
/// Function type for a unary operator.
using Z3UnaryOp = z3::expr (*)(const z3::expr &);

/// Function type for a binary operator.
using Z3BinaryOp = z3::expr (*)(const z3::expr &, const z3::expr &);

/// Function type for a ternary operator.
using Z3TernaryOp = z3::expr (*)(const z3::expr &, const z3::expr &, const z3::expr &);

/// Handles recursion into unary operations.
///
/// @returns false.
bool recurseUnary(const IR::Operation_Unary *unary, Z3UnaryOp f);

/// Handles recursion into binary operations.
///
/// @returns false.
bool recurseBinary(const IR::Operation_Binary *binary, Z3BinaryOp f);

/// Handles recursion into ternary operations.
///
/// @returns false.
bool recurseTernary(const IR::Operation_Ternary *ternary, Z3TernaryOp f);

/// Rewrites a shift operation so that the type of the shift amount matches that of the number
/// being shifted.
///
/// P4 allows shift operands to have different types: when the number being shifted is a bit
/// vector, the shift amount can be an infinite-precision integer. This rewrites such
/// expressions so that the shift amount is a bit vector.
template <class ShiftType>
const ShiftType *rewriteShift(const ShiftType *shift) const;

/// The output of the translation.
z3::expr result;

/// The Z3 solver instance, to which variables will be declared as they are encountered.
Z3Solver &solver;
};

z3::sort Z3Solver::toSort(const IR::Type *type) {
BUG_CHECK(type, "Z3Solver::toSort with empty pointer");

Expand Down Expand Up @@ -262,6 +168,31 @@ void Z3Solver::timeout(unsigned tm) {
timeout_ = tm;
}

std::optional<bool> Z3Solver::interpretSolverResult(z3::check_result result) {
switch (result) {
case z3::sat:
Z3_LOG("result:%s", "sat");
return true;
case z3::unsat:
Z3_LOG("result:%s", "unsat");
return false;

default: // unknown
Z3_LOG("result:%s", "unknown");
return std::nullopt;
}
}

std::optional<bool> Z3Solver::checkSat() {
Util::ScopedTimer ctCheckSat("checkSat");
return interpretSolverResult(z3solver.check());
}

std::optional<bool> Z3Solver::checkSat(const z3::expr_vector &asserts) {
Util::ScopedTimer ctCheckSat("checkSat");
return interpretSolverResult(z3solver.check(asserts));
}

std::optional<bool> Z3Solver::checkSat(const std::vector<const Constraint *> &asserts) {
Util::ScopedTimer ctZ3("z3");
if (isIncremental) {
Expand All @@ -285,38 +216,27 @@ std::optional<bool> Z3Solver::checkSat(const std::vector<const Constraint *> &as
}
Z3_LOG("checking satisfiability for %d assertions",
isIncremental ? z3solver.assertions().size() : z3Assertions.size());
Util::ScopedTimer ctCheckSat("checkSat");
z3::check_result result = isIncremental ? z3solver.check() : z3solver.check(z3Assertions);
switch (result) {
case z3::sat:
Z3_LOG("result:%s", "sat");
return true;
case z3::unsat:
Z3_LOG("result:%s", "unsat");
return false;

default: // unknown
Z3_LOG("result:%s", "unknown");
return std::nullopt;
}
return isIncremental ? checkSat() : checkSat(z3Assertions);
}

void Z3Solver::asrt(const Constraint *assertion) {
CHECK_NULL(assertion);
try {
Z3Translator z3translator(*this);
assertion->apply(z3translator);
auto expr = z3translator.getResult();
Z3Translator z3translator(*this);
auto expr = z3translator.translate(assertion);
asrt(expr);
p4Assertions.push_back(assertion);
BUG_CHECK(isIncremental || z3Assertions.size() == p4Assertions.size(),
"Number of assertion in P4 and Z3 formats aren't equal");
}

Z3_LOG("add assertion '%s'", toString(expr));
void Z3Solver::asrt(const z3::expr &assertion) {
try {
Z3_LOG("add assertion '%s'", toString(assertion));
if (isIncremental) {
z3solver.add(expr);
z3solver.add(assertion);
} else {
z3Assertions.push_back(expr);
z3Assertions.push_back(assertion);
}
p4Assertions.push_back(assertion);
BUG_CHECK(isIncremental || z3Assertions.size() == p4Assertions.size(),
"Number of assertion in P4 and Z3 formats aren't equal");
} catch (z3::exception &e) {
BUG("Z3Solver: Z3 exception: %1%\nAssertion %2%", e.msg(), assertion);
}
Expand Down Expand Up @@ -471,8 +391,8 @@ bool Z3Translator::preorder(const IR::Cast *cast) {
exprSize = exprType->width_bits();
} else if (castExtrType->is<IR::Type_Boolean>()) {
exprSize = 1;
auto trueVal = solver.ctx().bv_val(1, exprSize);
auto falseVal = solver.ctx().bv_val(0, exprSize);
auto trueVal = solver.get().ctx().bv_val(1, exprSize);
auto falseVal = solver.get().ctx().bv_val(0, exprSize);
castExpr = z3::ite(castExpr, trueVal, falseVal);
} else if (const auto *exprType = castExtrType->to<IR::Extracted_Varbits>()) {
exprSize = exprType->width_bits();
Expand All @@ -496,7 +416,7 @@ bool Z3Translator::preorder(const IR::Cast *cast) {
if (cast->destType->is<IR::Type_Boolean>()) {
if (const auto *exprType = castExtrType->to<IR::Type_Bits>()) {
if (exprType->width_bits() == 1) {
castExpr = z3::operator==(castExpr, solver.ctx().bv_val(1, 1));
castExpr = z3::operator==(castExpr, solver.get().ctx().bv_val(1, 1));
} else {
BUG("Cast expression type %1% is not bit<1> : %2%", exprType, castExpr);
}
Expand All @@ -514,36 +434,36 @@ bool Z3Translator::preorder(const IR::Cast *cast) {
bool Z3Translator::preorder(const IR::Constant *constant) {
// Handle infinite-integer constants.
if (constant->type->is<IR::Type_InfInt>()) {
result = solver.ctx().int_val(constant->value.str().c_str());
result = solver.get().ctx().int_val(constant->value.str().c_str());
return false;
}

// Handle bit<n> constants.
if (const auto *bits = constant->type->to<IR::Type_Bits>()) {
result = solver.ctx().bv_val(constant->value.str().c_str(), bits->size);
result = solver.get().ctx().bv_val(constant->value.str().c_str(), bits->size);
return false;
}

if (const auto *bits = constant->type->to<IR::Extracted_Varbits>()) {
result = solver.ctx().bv_val(constant->value.str().c_str(), bits->width_bits());
result = solver.get().ctx().bv_val(constant->value.str().c_str(), bits->width_bits());
return false;
}

BUG("Z3Translator: unsupported type for constant %1%", constant);
}

bool Z3Translator::preorder(const IR::BoolLiteral *boolLiteral) {
result = solver.ctx().bool_val(boolLiteral->value);
result = solver.get().ctx().bool_val(boolLiteral->value);
return false;
}

bool Z3Translator::preorder(const IR::StringLiteral *stringLiteral) {
result = solver.ctx().string_val(stringLiteral->value);
result = solver.get().ctx().string_val(stringLiteral->value);
return false;
}

bool Z3Translator::preorder(const IR::SymbolicVariable *var) {
result = solver.declareVar(*var);
result = solver.get().declareVar(*var);
return false;
}

Expand Down Expand Up @@ -694,4 +614,15 @@ bool Z3Translator::recurseTernary(const IR::Operation_Ternary *ternary, Z3Ternar
return false;
}

z3::expr Z3Translator::getResult() { return result; }

z3::expr Z3Translator::translate(const IR::Expression *expression) {
try {
expression->apply(*this);
} catch (z3::exception &e) {
BUG("Z3Translator: Z3 exception: %1%\nExpression %2%", e.msg(), expression);
}
return result;
}

} // namespace P4Tools
Loading
Loading