Skip to content

Commit

Permalink
Make sure P4 expression optimization does not strip away types (#4300)
Browse files Browse the repository at this point in the history
and check typing sanity in ExecutionState::set
  • Loading branch information
vlstill authored Jan 8, 2024
1 parent b082745 commit d7dfee1
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 46 deletions.
4 changes: 2 additions & 2 deletions backends/p4tools/common/compiler/convert_varbits.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

namespace P4Tools {

/// Converts all existing Type_Varbit types in the program into a custom Size_Type_Varbit type.
/// Sized_Type_Varbit also contains information about the width that was assigned to the type by
/// Converts all existing Type_Varbit types in the program into a custom Extracted_Varbit type.
/// Extracted_Varbit also contains information about the width that was assigned to the type by
/// the extract call.
class ConvertVarbits : public Transform {
public:
Expand Down
5 changes: 3 additions & 2 deletions backends/p4tools/common/lib/symbolic_env.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include <boost/container/vector.hpp>

#include "backends/p4tools/common/lib/model.h"
#include "frontends/p4/optimizeExpressions.h"
#include "ir/indexed_vector.h"
#include "ir/vector.h"
#include "ir/visitor.h"
Expand All @@ -28,7 +27,9 @@ const IR::Expression *SymbolicEnv::get(const IR::StateVariable &var) const {
bool SymbolicEnv::exists(const IR::StateVariable &var) const { return map.find(var) != map.end(); }

void SymbolicEnv::set(const IR::StateVariable &var, const IR::Expression *value) {
map[var] = P4::optimizeExpression(value);
BUG_CHECK(value->type && !value->type->is<IR::Type_Unknown>(),
"Cannot set value with unspecified type: %1%", value);
map[var] = value;
}

const IR::Expression *SymbolicEnv::subst(const IR::Expression *expr) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ std::vector<std::pair<IR::StateVariable, const IR::Expression *>> ExprStepper::s
ExecutionState &nextState, const std::vector<IR::StateVariable> &flatFields,
int varBitFieldSize) {
std::vector<std::pair<IR::StateVariable, const IR::Expression *>> fields;
for (const auto &fieldRef : flatFields) {
// Make a copy of the StateVariable so it can be modified in the varbit case (and it is just a
// pointer wrapper anyway).
for (IR::StateVariable fieldRef : flatFields) {
const auto *fieldType = fieldRef->type;
// If the header had a varbit, the header needs to be updated.
// We assign @param varbitFeldSize to the varbit field.
Expand All @@ -62,15 +64,12 @@ std::vector<std::pair<IR::StateVariable, const IR::Expression *>> ExprStepper::s
// We need to cast the generated variable to the appropriate type.
if (fieldType->is<IR::Extracted_Varbits>()) {
pktVar = new IR::Cast(fieldType, pktVar);
// Update the field and add the field to the return list.
// TODO: Better way to handle varbits here?
auto *newRef = fieldRef->clone();
newRef->type = fieldType;
nextState.set(fieldRef, pktVar);
fields.emplace_back(fieldRef, pktVar);
continue;
}
if (const auto *bits = fieldType->to<IR::Type_Bits>()) {
// Rewrite the type of the field so it matches the extracted varbit type.
// TODO: is there a better way to do this?
auto *newRefExpr = fieldRef->clone();
newRefExpr->type = fieldType;
fieldRef.ref = newRefExpr;
} else if (const auto *bits = fieldType->to<IR::Type_Bits>()) {
if (bits->isSigned) {
pktVar = new IR::Cast(fieldType, pktVar);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ const IR::Expression *TableStepper::computeHit(TableMatchMap *matches) {
const IR::StringLiteral *TableStepper::getTableActionString(
const IR::MethodCallExpression *actionCall) {
cstring actionName = actionCall->method->toString();
return new IR::StringLiteral(actionName);
return new IR::StringLiteral(IR::Type_String::get(), actionName);
}

const IR::Expression *TableStepper::evalTableConstEntries() {
Expand Down
30 changes: 29 additions & 1 deletion backends/p4tools/modules/testgen/lib/execution_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "backends/p4tools/common/lib/taint.h"
#include "backends/p4tools/common/lib/trace_event.h"
#include "backends/p4tools/common/lib/variables.h"
#include "frontends/p4/optimizeExpressions.h"
#include "ir/id.h"
#include "ir/indexed_vector.h"
#include "ir/irutils.h"
Expand Down Expand Up @@ -178,10 +179,37 @@ void ExecutionState::markVisited(const IR::Node *node) {

const P4::Coverage::CoverageSet &ExecutionState::getVisited() const { return visitedNodes; }

/// Compare types, considering Extracted_Varbit and bits equal if the (real/extracted) sizes are
/// equal. This is because the packet expression can be something like 0 ++
/// (Extracted_Varbit<N>)pkt_var. This expression is typed as bit<N>, but the optimizer removes the
/// 0 ++ and makes it into Extracted_Varbit type.
/// TODO: Maybe there is a better way to handle varbit that could allow us to avoid this.
static bool typeEquivSansVarbit(const IR::Type *a, const IR::Type *b) {
if (a->equiv(*b)) {
return true;
}
const auto *abit = a->to<IR::Type_Bits>();
const auto *avar = a->to<IR::Extracted_Varbits>();
const auto *bbit = b->to<IR::Type_Bits>();
const auto *bvar = b->to<IR::Extracted_Varbits>();
return (abit && bvar && abit->width_bits() == bvar->width_bits()) ||
(avar && bbit && avar->width_bits() == bbit->width_bits());
}

void ExecutionState::set(const IR::StateVariable &var, const IR::Expression *value) {
const auto *type = value->type;
BUG_CHECK(type && !type->is<IR::Type_Unknown>(), "Cannot set value with unspecified type: %1%",
value);
if (getProperty<bool>("inUndefinedState")) {
// If we are in an undefined state, the variable we set is tainted.
value = ToolsVariables::getTaintExpression(value->type);
value = ToolsVariables::getTaintExpression(type);
} else {
value = P4::optimizeExpression(value);
BUG_CHECK(value->type && !value->type->is<IR::Type_Unknown>(),
"The P4 expression optimizer stripped a type of %1% (was %2%)", value, type);
BUG_CHECK(typeEquivSansVarbit(type, value->type),
"The P4 expression optimizer had changed type of %1% (%2% -> %3%)", value, type,
value->type);
}
env.set(var, value);
}
Expand Down
7 changes: 4 additions & 3 deletions frontends/common/constantFolding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,9 +425,9 @@ const IR::Node *DoConstantFolding::compare(const IR::Operation_Binary *e) {
auto ri = rlist->components.at(i);
const IR::Operation_Binary *tmp;
if (eqTest)
tmp = new IR::Equ(li, ri);
tmp = new IR::Equ(IR::Type_Boolean::get(), li, ri);
else
tmp = new IR::Neq(li, ri);
tmp = new IR::Neq(IR::Type_Boolean::get(), li, ri);
auto cmp = compare(tmp);
auto boolLit = cmp->to<IR::BoolLiteral>();
if (boolLit == nullptr) return e;
Expand Down Expand Up @@ -960,7 +960,8 @@ const IR::Node *DoConstantFolding::postorder(IR::SelectExpression *expression) {
finished = true;
if (someUnknown) {
if (!c->keyset->is<IR::DefaultExpression>()) changes = true;
auto newc = new IR::SelectCase(c->srcInfo, new IR::DefaultExpression(), c->state);
auto newc = new IR::SelectCase(
c->srcInfo, new IR::DefaultExpression(expression->select->type), c->state);
cases.push_back(newc);
} else {
// This is the result.
Expand Down
68 changes: 41 additions & 27 deletions frontends/p4/strengthReduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ const IR::Node *DoStrengthReduction::postorder(IR::BAnd *expr) {
auto l = expr->left->to<IR::Cmpl>();
auto r = expr->right->to<IR::Cmpl>();
if (l && r)
return new IR::Cmpl(expr->type, new IR::BOr(expr->srcInfo, expr->type, l->expr, r->expr));
return new IR::Cmpl(expr->srcInfo, expr->type,
new IR::BOr(expr->srcInfo, expr->type, l->expr, r->expr));

if (hasSideEffects(expr)) return expr;
if (isZero(expr->left)) return expr->left;
Expand All @@ -95,7 +96,9 @@ const IR::Node *DoStrengthReduction::postorder(IR::BOr *expr) {
if (isZero(expr->right)) return expr->left;
auto l = expr->left->to<IR::Cmpl>();
auto r = expr->right->to<IR::Cmpl>();
if (l && r) return new IR::Cmpl(new IR::BAnd(expr->srcInfo, expr->type, l->expr, r->expr));
if (l && r)
return new IR::Cmpl(expr->srcInfo, expr->type,
new IR::BAnd(expr->srcInfo, expr->type, l->expr, r->expr));
if (hasSideEffects(expr)) return expr;
if (expr->left->equiv(*expr->right)) return expr->left;
return expr;
Expand Down Expand Up @@ -143,15 +146,15 @@ const IR::Node *DoStrengthReduction::postorder(IR::Equ *expr) {
if (isTrue(expr->left)) return expr->right;
if (isTrue(expr->right)) return expr->left;
// a == false is the same as !a
if (isFalse(expr->left)) return new IR::LNot(expr->right);
if (isFalse(expr->right)) return new IR::LNot(expr->left);
if (isFalse(expr->left)) return new IR::LNot(expr->srcInfo, expr->type, expr->right);
if (isFalse(expr->right)) return new IR::LNot(expr->srcInfo, expr->type, expr->left);
return expr;
}

const IR::Node *DoStrengthReduction::postorder(IR::Neq *expr) {
// a != true is the same as !a
if (isTrue(expr->left)) return new IR::LNot(expr->right);
if (isTrue(expr->right)) return new IR::LNot(expr->left);
if (isTrue(expr->left)) return new IR::LNot(expr->srcInfo, expr->type, expr->right);
if (isTrue(expr->right)) return new IR::LNot(expr->srcInfo, expr->type, expr->left);
// a != false is the same as a
if (isFalse(expr->left)) return expr->right;
if (isFalse(expr->right)) return expr->left;
Expand All @@ -160,12 +163,18 @@ const IR::Node *DoStrengthReduction::postorder(IR::Neq *expr) {

const IR::Node *DoStrengthReduction::postorder(IR::LNot *expr) {
if (auto e = expr->expr->to<IR::LNot>()) return e->expr;
if (auto e = expr->expr->to<IR::Equ>()) return new IR::Neq(e->left, e->right);
if (auto e = expr->expr->to<IR::Neq>()) return new IR::Equ(e->left, e->right);
if (auto e = expr->expr->to<IR::Leq>()) return new IR::Grt(e->left, e->right);
if (auto e = expr->expr->to<IR::Geq>()) return new IR::Lss(e->left, e->right);
if (auto e = expr->expr->to<IR::Lss>()) return new IR::Geq(e->left, e->right);
if (auto e = expr->expr->to<IR::Grt>()) return new IR::Leq(e->left, e->right);
if (auto e = expr->expr->to<IR::Equ>())
return new IR::Neq(expr->srcInfo, expr->type, e->left, e->right);
if (auto e = expr->expr->to<IR::Neq>())
return new IR::Equ(expr->srcInfo, expr->type, e->left, e->right);
if (auto e = expr->expr->to<IR::Leq>())
return new IR::Grt(expr->srcInfo, expr->type, e->left, e->right);
if (auto e = expr->expr->to<IR::Geq>())
return new IR::Lss(expr->srcInfo, expr->type, e->left, e->right);
if (auto e = expr->expr->to<IR::Lss>())
return new IR::Geq(expr->srcInfo, expr->type, e->left, e->right);
if (auto e = expr->expr->to<IR::Grt>())
return new IR::Leq(expr->srcInfo, expr->type, e->left, e->right);
return expr;
}

Expand All @@ -176,13 +185,13 @@ const IR::Node *DoStrengthReduction::postorder(IR::Sub *expr) {
if (expr->right->is<IR::Constant>()) {
auto cst = expr->right->to<IR::Constant>();
auto neg = new IR::Constant(cst->srcInfo, cst->type, -cst->value, cst->base, true);
auto result = new IR::Add(expr->srcInfo, expr->left, neg);
auto result = new IR::Add(expr->srcInfo, expr->type, expr->left, neg);
return result;
}
if (hasSideEffects(expr)) return expr;
if (expr->left->equiv(*expr->right) && expr->left->type &&
!expr->left->type->is<IR::Type_Unknown>())
return new IR::Constant(expr->left->type, 0);
return new IR::Constant(expr->srcInfo, expr->left->type, 0);
return expr;
}

Expand Down Expand Up @@ -230,14 +239,14 @@ const IR::Node *DoStrengthReduction::postorder(IR::Mul *expr) {
if (isOne(expr->right)) return expr->left;
auto exp = isPowerOf2(expr->left);
if (exp >= 0) {
auto amt = new IR::Constant(exp);
auto sh = new IR::Shl(expr->srcInfo, expr->right, amt);
auto amt = new IR::Constant(expr->left->srcInfo, exp);
auto sh = new IR::Shl(expr->srcInfo, expr->type, expr->right, amt);
return sh;
}
exp = isPowerOf2(expr->right);
if (exp >= 0) {
auto amt = new IR::Constant(exp);
auto sh = new IR::Shl(expr->srcInfo, expr->left, amt);
auto amt = new IR::Constant(expr->right->srcInfo, exp);
auto sh = new IR::Shl(expr->srcInfo, expr->type, expr->left, amt);
return sh;
}
if (hasSideEffects(expr)) return expr;
Expand All @@ -254,8 +263,8 @@ const IR::Node *DoStrengthReduction::postorder(IR::Div *expr) {
if (isOne(expr->right)) return expr->left;
auto exp = isPowerOf2(expr->right);
if (exp >= 0) {
auto amt = new IR::Constant(exp);
auto sh = new IR::Shr(expr->srcInfo, expr->left, amt);
auto amt = new IR::Constant(expr->right->srcInfo, exp);
auto sh = new IR::Shr(expr->srcInfo, expr->type, expr->left, amt);
return sh;
}
if (isZero(expr->left) && !hasSideEffects(expr->right)) return expr->left;
Expand All @@ -272,8 +281,9 @@ const IR::Node *DoStrengthReduction::postorder(IR::Mod *expr) {
if (exp >= 0) {
big_int mask = 1;
mask = (mask << exp) - 1;
auto amt = new IR::Constant(expr->right->to<IR::Constant>()->type, mask);
auto sh = new IR::BAnd(expr->srcInfo, expr->left, amt);
auto amt =
new IR::Constant(expr->right->srcInfo, expr->right->to<IR::Constant>()->type, mask);
auto sh = new IR::BAnd(expr->srcInfo, expr->type, expr->left, amt);
return sh;
}
return expr;
Expand Down Expand Up @@ -301,7 +311,7 @@ const IR::Node *DoStrengthReduction::postorder(IR::Mux *expr) {
if (isTrue(expr->e1) && isFalse(expr->e2))
return expr->e0;
else if (isFalse(expr->e1) && isTrue(expr->e2))
return new IR::LNot(expr->e0);
return new IR::LNot(expr->srcInfo, expr->type, expr->e0);
else if (const auto *lnot = expr->e0->to<IR::LNot>()) {
expr->e0 = lnot->expr;
const auto *tmp = expr->e1;
Expand Down Expand Up @@ -372,8 +382,9 @@ const IR::Node *DoStrengthReduction::postorder(IR::Slice *expr) {
expr->e0 = shift_of;
expr->e1 = new IR::Constant(hi + shift_amt);
expr->e2 = new IR::Constant(0);
return new IR::Concat(expr->type, expr,
new IR::Constant(IR::Type_Bits::get(-(lo + shift_amt)), 0));
return new IR::Concat(
expr->srcInfo, expr->type, expr,
new IR::Constant(expr->srcInfo, IR::Type_Bits::get(-(lo + shift_amt)), 0));
}
}

Expand All @@ -393,8 +404,11 @@ const IR::Node *DoStrengthReduction::postorder(IR::Slice *expr) {
else
break;
} else {
return new IR::Concat(expr->type, new IR::Slice(cat->left, expr->getH() - rwidth, 0),
new IR::Slice(cat->right, rwidth - 1, expr->getL()));
return new IR::Concat(
expr->srcInfo, expr->type,
// type of slice is calculated by its constructor
new IR::Slice(cat->left->srcInfo, cat->left, expr->getH() - rwidth, 0),
new IR::Slice(cat->right->srcInfo, cat->right, rwidth - 1, expr->getL()));
}
}

Expand Down

0 comments on commit d7dfee1

Please sign in to comment.