Skip to content

Commit

Permalink
[P4Testgen] Fix problems with the reachability pass. (#4789)
Browse files Browse the repository at this point in the history
* Fix problems with the reachability pass.

Signed-off-by: fruffy <[email protected]>

* Review comments.

Signed-off-by: fruffy <[email protected]>

---------

Signed-off-by: fruffy <[email protected]>
  • Loading branch information
fruffy authored Jul 31, 2024
1 parent d2c1827 commit 9d4e3c1
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 128 deletions.
217 changes: 131 additions & 86 deletions backends/p4tools/common/compiler/reachability.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#include "backends/p4tools/common/compiler/reachability.h"

#include <cstddef>
#include <functional>
#include <iostream>
#include <list>
#include <string>
#include <utility>
#include <vector>

#include "backends/p4tools/common/lib/table_utils.h"
#include "ir/declaration.h"
#include "ir/indexed_vector.h"
#include "ir/vector.h"
Expand Down Expand Up @@ -54,57 +54,79 @@ bool P4ProgramDCGCreator::preorder(const IR::ConstructorCallExpression *callExpr
return true;
}

bool P4ProgramDCGCreator::preorder(const IR::MethodCallExpression *method) {
// Check for application of a table.
CHECK_NULL(method->method);
if (const auto *path = method->method->to<IR::PathExpression>()) {
const auto *currentControl = findOrigCtxt<IR::P4Control>();
if (currentControl != nullptr) {
const auto *decl = currentControl->getDeclByName(path->path->name.name);
if (decl != nullptr) {
if (const auto *action = decl->to<IR::P4Action>()) {
for (const auto *arg : *method->arguments) {
visit(arg);
}
addEdge(method);
visit(action);
return false;
}
}
}
}
if (!method->method->is<IR::Member>()) {
return true;
}
const auto *member = method->method->to<IR::Member>();
for (const auto *arg : *method->arguments) {
bool P4ProgramDCGCreator::preorder(const IR::MethodCallExpression *call) {
CHECK_NULL(call->method);
for (const auto *arg : *call->arguments) {
visit(arg);
}
// Do not anylyse tables apply and value set index methods.
if (member->member != IR::IApply::applyMethodName && member->member.originalName != "index") {
return true;
if (call->method->type->is<IR::Type_Action>()) {
const auto *path = call->method->checkedTo<IR::PathExpression>();
const auto *action = getDeclaration(path->path, true)->checkedTo<IR::P4Action>();
addEdge(call);
visit(action);
return false;
}
if (const auto *pathExpr = member->expr->to<IR::PathExpression>()) {
const auto *currentControl = findOrigCtxt<IR::P4Control>();
if (currentControl != nullptr) {
const auto *decl = currentControl->getDeclByName(pathExpr->path->name.name);
visit(decl->checkedTo<IR::P4Table>());
if (call->method->type->is<IR::Type_Method>()) {
if (const auto *path = call->method->to<IR::PathExpression>()) {
const auto *method = getDeclaration(path->path, true)->checkedTo<IR::Method>();
visit(method);
return false;
}
const auto *parser = findOrigCtxt<IR::P4Parser>();
if (parser != nullptr) {
const auto *decl = parser->getDeclByName(pathExpr->path->name.name);
visit(decl->checkedTo<IR::Declaration_Instance>());
return true;
if (const auto *method = call->method->to<IR::Member>()) {
// Case where call->method is a Member expression. For table invocations, the
// qualifier of the member determines the table being invoked. For extern calls,
// the qualifier determines the extern object containing the method being invoked.
BUG_CHECK(method->expr, "Method call has unexpected format: %1%", call);

// Handle table calls.
if (method->expr->type->is<IR::Type_Table>()) {
const auto *tableDecl =
getDeclaration(method->expr->checkedTo<IR::PathExpression>()->path, true)
->checkedTo<IR::P4Table>();
visit(tableDecl);
return false;
}

// Handle extern calls. They may also be of Type_SpecializedCanonical.
if (method->expr->type->is<IR::Type_Extern>() ||
method->expr->type->is<IR::Type_SpecializedCanonical>()) {
// TODO: This is the wrong place to analyze parser value sets.
// They should be handled in the select expression.
if (method->member.originalName == "index") {
const auto *decl =
getDeclaration(method->expr->checkedTo<IR::PathExpression>()->path, true)
->checkedTo<IR::Declaration_Instance>();
visit(decl);
}
return false;
}

// Handle calls to header methods.
if (method->expr->type->is<IR::Type_Header>() ||
method->expr->type->is<IR::Type_HeaderUnion>()) {
if (method->member == IR::Type_Header::isValid || IR::Type_Header::setInvalid ||
method->member == IR::Type_Header::setValid) {
return false;
}
BUG("Unknown method call on header instance: %1%", call);
}

if (method->expr->type->is<IR::Type_Stack>()) {
if (method->member == IR::Type_Stack::push_front ||
method->member == IR::Type_Stack::pop_front) {
return false;
}
BUG("Unknown method call on stack instance: %1%", call);
}

BUG("Unknown method member expression: %1% of type %2%", method->expr,
method->expr->type);
}
return true;
}
const auto *type = member->expr->type;
if (const auto *tableType = type->to<IR::Type_Table>()) {
visit(tableType->table);
return false;

BUG("Unknown method call: %1% of type %2%", call->method, call->method->node_type_name());
}
return true;

BUG("Unsupported method call type for %1%: %2%", call, call->method->type);
}

bool P4ProgramDCGCreator::preorder(const IR::MethodCallStatement *method) {
Expand Down Expand Up @@ -133,7 +155,6 @@ bool P4ProgramDCGCreator::preorder(const IR::P4Parser *parser) {

bool P4ProgramDCGCreator::preorder(const IR::P4Table *table) {
addEdge(table, table->name);
DCGVertexTypeSet prevSet;
if (table->annotations != nullptr) {
for (const auto *annotation : table->annotations->annotations) {
visit(annotation);
Expand All @@ -148,20 +169,40 @@ bool P4ProgramDCGCreator::preorder(const IR::P4Table *table) {
}
}
}
TableUtils::TableProperties properties;
TableUtils::checkTableImmutability(*table, properties);

auto storedSet = prev;
const auto *entryList = table->getEntries();
if (entryList != nullptr) {
for (const auto *entry : entryList->entries) {
DCGVertexTypeSet prevSet;
if (properties.tableIsImmutable) {
// We can only match on entries when there are keys present.
if (table->getKey() != nullptr) {
const auto *entryList = table->getEntries();
if (entryList != nullptr) {
for (const auto *entry : entryList->entries) {
prev = storedSet;
visit(entry);
prevSet.insert(prev.begin(), prev.end());
}
} else if (wasImplementations) {
prevSet.insert(prev.begin(), prev.end());
}
}
// If the default action is immutable, we can only match on the default action.
if (properties.defaultIsImmutable) {
prev = storedSet;
visit(entry);
visit(table->getDefaultAction());
prevSet.insert(prev.begin(), prev.end());
prev = prevSet;
return false;
}
} else if (wasImplementations) {
}
for (const auto *action : table->getActionList()->actionList) {
prev = storedSet;
visit(action);
prevSet.insert(prev.begin(), prev.end());
}
prev = storedSet;
visit(table->getDefaultAction());
prevSet.insert(prev.begin(), prev.end());

prev = prevSet;
return false;
}
Expand All @@ -179,16 +220,18 @@ bool P4ProgramDCGCreator::preorder(const IR::ParserState *parserState) {
}
if (parserState->selectExpression != nullptr) {
if (const auto *pathExpr = parserState->selectExpression->to<IR::PathExpression>()) {
if (pathExpr->path->name.name == IR::ParserState::accept ||
pathExpr->path->name.name == IR::ParserState::reject) {
addEdge(parserState->selectExpression);
return true;
const auto *declaration = getDeclaration(pathExpr->path)->checkedTo<IR::ParserState>();

BUG_CHECK(declaration != nullptr, "Parser state not found: %1%",
pathExpr->path->name.name);
if (visited.count(declaration) != 0U) {
addEdge(declaration, declaration->name);
return false;
}
visit(declaration);
} else {
visit(parserState->selectExpression);
}
if (visited.count(parserState->selectExpression) != 0U) {
return false;
}
visit(parserState->selectExpression);
}
return false;
}
Expand Down Expand Up @@ -220,8 +263,9 @@ bool P4ProgramDCGCreator::preorder(const IR::P4Program *program) {
// declaration instance.
auto filter = [pathExpr](const IR::IDeclaration *d) {
CHECK_NULL(d);
if (const auto *decl = d->to<IR::Declaration_Instance>())
if (const auto *decl = d->to<IR::Declaration_Instance>()) {
return pathExpr->path->name == decl->name;
}
return false;
};
// Convert the declaration instance into a constructor-call expression.
Expand All @@ -234,7 +278,7 @@ bool P4ProgramDCGCreator::preorder(const IR::P4Program *program) {
}
this->prev = {program};
for (const auto *arg : v) {
// Apply to the arguments.,
// Visit the blocks in order of the constructor arguments.
visit(arg);
}
return false;
Expand All @@ -248,7 +292,7 @@ bool P4ProgramDCGCreator::preorder(const IR::P4ValueSet *valueSet) {
bool P4ProgramDCGCreator::preorder(const IR::SelectExpression *selectExpression) {
visit(selectExpression->select);
DCGVertexTypeSet prevSet;
const auto *currentParser = findOrigCtxt<IR::P4Parser>();
const auto *currentParser = findContext<IR::P4Parser>();
BUG_CHECK(currentParser != nullptr, "Null parser pointer");
auto storedSet = prev;
for (const auto *selectCase : selectExpression->selectCases) {
Expand Down Expand Up @@ -278,6 +322,8 @@ bool P4ProgramDCGCreator::preorder(const IR::IfStatement *ifStatement) {
prev = storedSet;
visit(ifStatement->ifFalse);
next.insert(prev.begin(), prev.end());
} else {
next.insert(storedSet.begin(), storedSet.end());
}
prev = next;
return false;
Expand Down Expand Up @@ -317,7 +363,7 @@ bool P4ProgramDCGCreator::preorder(const IR::StatOrDecl *statOrDecl) {
return true;
}

void P4ProgramDCGCreator::addEdge(const DCGVertexType *vertex, IR::ID vertexName) {
void P4ProgramDCGCreator::addEdge(DCGVertexType vertex, const IR::ID &vertexName) {
LOG1("Add edge : " << prev.size() << "(" << *prev.begin() << ") : " << vertex);
for (const auto *p : prev) {
dcg->calls(p, vertex);
Expand All @@ -341,13 +387,13 @@ ReachabilityEngineState *ReachabilityEngineState::copy() {
return newState;
}

std::list<const DCGVertexType *> ReachabilityEngineState::getState() { return state; }
std::list<DCGVertexType> ReachabilityEngineState::getState() { return state; }

void ReachabilityEngineState::setState(std::list<const DCGVertexType *> ls) { state = ls; }
void ReachabilityEngineState::setState(const std::list<DCGVertexType> &ls) { state = ls; }

const DCGVertexType *ReachabilityEngineState::getPrevNode() { return prevNode; }
DCGVertexType ReachabilityEngineState::getPrevNode() { return prevNode; }

void ReachabilityEngineState::setPrevNode(const DCGVertexType *n) { prevNode = n; }
void ReachabilityEngineState::setPrevNode(DCGVertexType n) { prevNode = n; }

bool ReachabilityEngineState::isEmpty() { return state.empty(); }

Expand All @@ -357,15 +403,15 @@ ReachabilityEngine::ReachabilityEngine(const NodesCallGraph &dcg,
std::string reachabilityExpression,
bool eliminateAnnotations)
: dcg(dcg), hash(dcg.getHash()) {
std::list<const DCGVertexType *> start;
std::list<DCGVertexType> start;
start.push_back(nullptr);
size_t i = 0;
size_t j = 0;
reachabilityExpression += ";";
while ((i = reachabilityExpression.find(';')) != std::string::npos) {
auto addSubExpr = reachabilityExpression.substr(0, i);
addSubExpr += "+";
std::list<const DCGVertexType *> newStart;
std::list<DCGVertexType> newStart;
while ((j = addSubExpr.find('+')) != std::string::npos) {
auto dotSubExpr = addSubExpr.substr(0, j);
while (dotSubExpr[0] == ' ') {
Expand All @@ -377,7 +423,7 @@ ReachabilityEngine::ReachabilityEngine(const NodesCallGraph &dcg,
}
auto currentNames = getName(dotSubExpr);
if (eliminateAnnotations) {
std::unordered_set<const DCGVertexType *> result;
std::unordered_set<DCGVertexType> result;
for (auto i : currentNames) {
if (!i->is<IR::Annotation>()) {
result.insert(i);
Expand All @@ -403,9 +449,9 @@ ReachabilityEngine::ReachabilityEngine(const NodesCallGraph &dcg,
}
}

void ReachabilityEngine::annotationToStatements(const DCGVertexType *node,
std::unordered_set<const DCGVertexType *> &s) {
std::list<const DCGVertexType *> l = {node};
void ReachabilityEngine::annotationToStatements(DCGVertexType node,
std::unordered_set<DCGVertexType> &s) {
std::list<DCGVertexType> l = {node};
while (!l.empty()) {
const auto *nd = l.front();
l.pop_front();
Expand All @@ -427,12 +473,12 @@ void ReachabilityEngine::annotationToStatements(const DCGVertexType *node,
}
}

void ReachabilityEngine::addTransition(const DCGVertexType *left,
const std::unordered_set<const DCGVertexType *> &rightSet) {
void ReachabilityEngine::addTransition(DCGVertexType left,
const std::unordered_set<DCGVertexType> &rightSet) {
for (const auto *right : rightSet) {
auto i = userTransitions.find(left);
if (i == userTransitions.end()) {
std::list<const DCGVertexType *> l;
std::list<DCGVertexType> l;
l.push_back(right);
userTransitions.emplace(left, l);
} else {
Expand All @@ -442,15 +488,15 @@ void ReachabilityEngine::addTransition(const DCGVertexType *left,
}

const IR::Expression *ReachabilityEngine::addCondition(const IR::Expression *prev,
const DCGVertexType *currentState) {
DCGVertexType currentState) {
const auto *newCond = getCondition(currentState);
if (newCond == nullptr) {
return prev;
}
return new IR::BOr(IR::Type_Boolean::get(), newCond, prev);
}

std::unordered_set<const DCGVertexType *> ReachabilityEngine::getName(std::string name) {
std::unordered_set<DCGVertexType> ReachabilityEngine::getName(std::string name) {
std::string params;
if ((name.length() != 0U) && name[name.length() - 1] == ')') {
size_t n = name.find('(');
Expand Down Expand Up @@ -483,8 +529,7 @@ std::unordered_set<const DCGVertexType *> ReachabilityEngine::getName(std::strin
return i->second;
}

ReachabilityResult ReachabilityEngine::next(ReachabilityEngineState *state,
const DCGVertexType *next) {
ReachabilityResult ReachabilityEngine::next(ReachabilityEngineState *state, DCGVertexType next) {
CHECK_NULL(state);
CHECK_NULL(next);
if (forbiddenVertexes.count(next)) {
Expand All @@ -502,8 +547,8 @@ ReachabilityResult ReachabilityEngine::next(ReachabilityEngineState *state,
return std::make_pair(false, nullptr);
}
const IR::Expression *expr = nullptr;
std::list<const DCGVertexType *> newState;
std::list<const DCGVertexType *> currentState = state->getState();
std::list<DCGVertexType> newState;
std::list<DCGVertexType> currentState = state->getState();
for (const auto *i : currentState) {
if (i == nullptr) {
// Start from intial.
Expand Down Expand Up @@ -553,7 +598,7 @@ ReachabilityResult ReachabilityEngine::next(ReachabilityEngineState *state,

const NodesCallGraph &ReachabilityEngine::getDCG() { return dcg; }

const IR::Expression *ReachabilityEngine::getCondition(const DCGVertexType *n) {
const IR::Expression *ReachabilityEngine::getCondition(DCGVertexType n) {
auto i = conditions.find(n);
if (i != conditions.end()) {
return i->second;
Expand Down
Loading

0 comments on commit 9d4e3c1

Please sign in to comment.