Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[P4Testgen] Fix problems with the reachability pass. #4789

Merged
merged 2 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 130 additions & 86 deletions backends/p4tools/common/compiler/reachability.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#include "backends/p4tools/common/compiler/reachability.h"

#include <cstddef>
#include <functional>
#include <iostream>
#include <list>
#include <string>
#include <utility>
#include <vector>

#include "backends/p4tools/common/lib/table_utils.h"
#include "ir/declaration.h"
#include "ir/indexed_vector.h"
#include "ir/vector.h"
Expand Down Expand Up @@ -54,57 +54,78 @@ bool P4ProgramDCGCreator::preorder(const IR::ConstructorCallExpression *callExpr
return true;
}

bool P4ProgramDCGCreator::preorder(const IR::MethodCallExpression *method) {
// Check for application of a table.
CHECK_NULL(method->method);
if (const auto *path = method->method->to<IR::PathExpression>()) {
const auto *currentControl = findOrigCtxt<IR::P4Control>();
if (currentControl != nullptr) {
const auto *decl = currentControl->getDeclByName(path->path->name.name);
if (decl != nullptr) {
if (const auto *action = decl->to<IR::P4Action>()) {
for (const auto *arg : *method->arguments) {
visit(arg);
}
addEdge(method);
visit(action);
return false;
}
}
}
}
if (!method->method->is<IR::Member>()) {
return true;
}
const auto *member = method->method->to<IR::Member>();
for (const auto *arg : *method->arguments) {
bool P4ProgramDCGCreator::preorder(const IR::MethodCallExpression *call) {
CHECK_NULL(call->method);
for (const auto *arg : *call->arguments) {
visit(arg);
}
// Do not anylyse tables apply and value set index methods.
if (member->member != IR::IApply::applyMethodName && member->member.originalName != "index") {
return true;
if (call->method->type->is<IR::Type_Action>()) {
const auto *path = call->method->checkedTo<IR::PathExpression>();
const auto *action = getDeclaration(path->path, true)->checkedTo<IR::P4Action>();
addEdge(call);
visit(action);
return false;
}
if (const auto *pathExpr = member->expr->to<IR::PathExpression>()) {
const auto *currentControl = findOrigCtxt<IR::P4Control>();
if (currentControl != nullptr) {
const auto *decl = currentControl->getDeclByName(pathExpr->path->name.name);
visit(decl->checkedTo<IR::P4Table>());
if (call->method->type->is<IR::Type_Method>()) {
if (const auto *path = call->method->to<IR::PathExpression>()) {
const auto *method = getDeclaration(path->path, true)->checkedTo<IR::Method>();
visit(method);
return false;
}
const auto *parser = findOrigCtxt<IR::P4Parser>();
if (parser != nullptr) {
const auto *decl = parser->getDeclByName(pathExpr->path->name.name);
visit(decl->checkedTo<IR::Declaration_Instance>());
return true;
if (const auto *method = call->method->to<IR::Member>()) {
// Case where call->method is a Member expression. For table invocations, the
// qualifier of the member determines the table being invoked. For extern calls,
// the qualifier determines the extern object containing the method being invoked.
BUG_CHECK(method->expr, "Method call has unexpected format: %1%", call);

// Handle table calls.
if (method->expr->type->is<IR::Type_Table>()) {
const auto *tableDecl =
getDeclaration(method->expr->checkedTo<IR::PathExpression>()->path, true)
->checkedTo<IR::P4Table>();
visit(tableDecl);
return false;
}

// Handle extern calls. They may also be of Type_SpecializedCanonical.
if (method->expr->type->is<IR::Type_Extern>() ||
method->expr->type->is<IR::Type_SpecializedCanonical>()) {
// TODO: This is the wrong place to analyze parser value sets.
// They should be handled in the select expression.
if (method->member.originalName == "index") {
const auto *decl =
getDeclaration(method->expr->checkedTo<IR::PathExpression>()->path, true)
->checkedTo<IR::Declaration_Instance>();
visit(decl);
}
return false;
}

// Handle calls to header methods.
if (method->expr->type->is<IR::Type_Header>() ||
method->expr->type->is<IR::Type_HeaderUnion>()) {
if (method->member == "isValid" || method->member == "setInvalid" ||
method->member == "setValid") {
vlstill marked this conversation as resolved.
Show resolved Hide resolved
return false;
}
BUG("Unknown method call on header instance: %1%", call);
}

if (method->expr->type->is<IR::Type_Stack>()) {
if (method->member == "push_front" || method->member == "pop_front") {
vlstill marked this conversation as resolved.
Show resolved Hide resolved
return false;
}
BUG("Unknown method call on stack instance: %1%", call);
}

BUG("Unknown method member expression: %1% of type %2%", method->expr,
method->expr->type);
}
return true;
}
const auto *type = member->expr->type;
if (const auto *tableType = type->to<IR::Type_Table>()) {
visit(tableType->table);
return false;

BUG("Unknown method call: %1% of type %2%", call->method, call->method->node_type_name());
}
return true;

BUG("Unsupported method call type for %1%: %2%", call, call->method->type);
}

bool P4ProgramDCGCreator::preorder(const IR::MethodCallStatement *method) {
Expand Down Expand Up @@ -133,7 +154,6 @@ bool P4ProgramDCGCreator::preorder(const IR::P4Parser *parser) {

bool P4ProgramDCGCreator::preorder(const IR::P4Table *table) {
addEdge(table, table->name);
DCGVertexTypeSet prevSet;
if (table->annotations != nullptr) {
for (const auto *annotation : table->annotations->annotations) {
visit(annotation);
Expand All @@ -148,20 +168,40 @@ bool P4ProgramDCGCreator::preorder(const IR::P4Table *table) {
}
}
}
TableUtils::TableProperties properties;
TableUtils::checkTableImmutability(*table, properties);

auto storedSet = prev;
const auto *entryList = table->getEntries();
if (entryList != nullptr) {
for (const auto *entry : entryList->entries) {
DCGVertexTypeSet prevSet;
if (properties.tableIsImmutable) {
// We can only match on entries when there are keys present.
if (table->getKey() != nullptr) {
const auto *entryList = table->getEntries();
if (entryList != nullptr) {
for (const auto *entry : entryList->entries) {
prev = storedSet;
visit(entry);
prevSet.insert(prev.begin(), prev.end());
}
} else if (wasImplementations) {
prevSet.insert(prev.begin(), prev.end());
}
}
// If the default action is immutable, we can only match on the default action.
if (properties.defaultIsImmutable) {
prev = storedSet;
visit(entry);
visit(table->getDefaultAction());
prevSet.insert(prev.begin(), prev.end());
prev = prevSet;
return false;
}
} else if (wasImplementations) {
}
for (const auto *action : table->getActionList()->actionList) {
prev = storedSet;
visit(action);
prevSet.insert(prev.begin(), prev.end());
}
prev = storedSet;
visit(table->getDefaultAction());
prevSet.insert(prev.begin(), prev.end());

prev = prevSet;
return false;
}
Expand All @@ -179,16 +219,18 @@ bool P4ProgramDCGCreator::preorder(const IR::ParserState *parserState) {
}
if (parserState->selectExpression != nullptr) {
if (const auto *pathExpr = parserState->selectExpression->to<IR::PathExpression>()) {
if (pathExpr->path->name.name == IR::ParserState::accept ||
pathExpr->path->name.name == IR::ParserState::reject) {
addEdge(parserState->selectExpression);
return true;
const auto *declaration = getDeclaration(pathExpr->path)->checkedTo<IR::ParserState>();

BUG_CHECK(declaration != nullptr, "Parser state not found: %1%",
pathExpr->path->name.name);
if (visited.count(declaration) != 0U) {
addEdge(declaration, declaration->name);
return false;
}
visit(declaration);
} else {
visit(parserState->selectExpression);
}
if (visited.count(parserState->selectExpression) != 0U) {
return false;
}
visit(parserState->selectExpression);
}
return false;
}
Expand Down Expand Up @@ -220,8 +262,9 @@ bool P4ProgramDCGCreator::preorder(const IR::P4Program *program) {
// declaration instance.
auto filter = [pathExpr](const IR::IDeclaration *d) {
CHECK_NULL(d);
if (const auto *decl = d->to<IR::Declaration_Instance>())
if (const auto *decl = d->to<IR::Declaration_Instance>()) {
return pathExpr->path->name == decl->name;
}
return false;
};
// Convert the declaration instance into a constructor-call expression.
Expand All @@ -234,7 +277,7 @@ bool P4ProgramDCGCreator::preorder(const IR::P4Program *program) {
}
this->prev = {program};
for (const auto *arg : v) {
// Apply to the arguments.,
// Visit the blocks in order of the constructor arguments.
visit(arg);
}
return false;
Expand All @@ -248,7 +291,7 @@ bool P4ProgramDCGCreator::preorder(const IR::P4ValueSet *valueSet) {
bool P4ProgramDCGCreator::preorder(const IR::SelectExpression *selectExpression) {
visit(selectExpression->select);
DCGVertexTypeSet prevSet;
const auto *currentParser = findOrigCtxt<IR::P4Parser>();
const auto *currentParser = findContext<IR::P4Parser>();
BUG_CHECK(currentParser != nullptr, "Null parser pointer");
auto storedSet = prev;
for (const auto *selectCase : selectExpression->selectCases) {
Expand Down Expand Up @@ -278,6 +321,8 @@ bool P4ProgramDCGCreator::preorder(const IR::IfStatement *ifStatement) {
prev = storedSet;
visit(ifStatement->ifFalse);
next.insert(prev.begin(), prev.end());
} else {
next.insert(storedSet.begin(), storedSet.end());
}
prev = next;
return false;
Expand Down Expand Up @@ -317,7 +362,7 @@ bool P4ProgramDCGCreator::preorder(const IR::StatOrDecl *statOrDecl) {
return true;
}

void P4ProgramDCGCreator::addEdge(const DCGVertexType *vertex, IR::ID vertexName) {
void P4ProgramDCGCreator::addEdge(DCGVertexType vertex, const IR::ID &vertexName) {
LOG1("Add edge : " << prev.size() << "(" << *prev.begin() << ") : " << vertex);
for (const auto *p : prev) {
dcg->calls(p, vertex);
Expand All @@ -341,13 +386,13 @@ ReachabilityEngineState *ReachabilityEngineState::copy() {
return newState;
}

std::list<const DCGVertexType *> ReachabilityEngineState::getState() { return state; }
std::list<DCGVertexType> ReachabilityEngineState::getState() { return state; }

void ReachabilityEngineState::setState(std::list<const DCGVertexType *> ls) { state = ls; }
void ReachabilityEngineState::setState(const std::list<DCGVertexType> &ls) { state = ls; }

const DCGVertexType *ReachabilityEngineState::getPrevNode() { return prevNode; }
DCGVertexType ReachabilityEngineState::getPrevNode() { return prevNode; }

void ReachabilityEngineState::setPrevNode(const DCGVertexType *n) { prevNode = n; }
void ReachabilityEngineState::setPrevNode(DCGVertexType n) { prevNode = n; }

bool ReachabilityEngineState::isEmpty() { return state.empty(); }

Expand All @@ -357,15 +402,15 @@ ReachabilityEngine::ReachabilityEngine(const NodesCallGraph &dcg,
std::string reachabilityExpression,
bool eliminateAnnotations)
: dcg(dcg), hash(dcg.getHash()) {
std::list<const DCGVertexType *> start;
std::list<DCGVertexType> start;
start.push_back(nullptr);
size_t i = 0;
size_t j = 0;
reachabilityExpression += ";";
while ((i = reachabilityExpression.find(';')) != std::string::npos) {
auto addSubExpr = reachabilityExpression.substr(0, i);
addSubExpr += "+";
std::list<const DCGVertexType *> newStart;
std::list<DCGVertexType> newStart;
while ((j = addSubExpr.find('+')) != std::string::npos) {
auto dotSubExpr = addSubExpr.substr(0, j);
while (dotSubExpr[0] == ' ') {
Expand All @@ -377,7 +422,7 @@ ReachabilityEngine::ReachabilityEngine(const NodesCallGraph &dcg,
}
auto currentNames = getName(dotSubExpr);
if (eliminateAnnotations) {
std::unordered_set<const DCGVertexType *> result;
std::unordered_set<DCGVertexType> result;
for (auto i : currentNames) {
if (!i->is<IR::Annotation>()) {
result.insert(i);
Expand All @@ -403,9 +448,9 @@ ReachabilityEngine::ReachabilityEngine(const NodesCallGraph &dcg,
}
}

void ReachabilityEngine::annotationToStatements(const DCGVertexType *node,
std::unordered_set<const DCGVertexType *> &s) {
std::list<const DCGVertexType *> l = {node};
void ReachabilityEngine::annotationToStatements(DCGVertexType node,
std::unordered_set<DCGVertexType> &s) {
std::list<DCGVertexType> l = {node};
while (!l.empty()) {
const auto *nd = l.front();
l.pop_front();
Expand All @@ -427,12 +472,12 @@ void ReachabilityEngine::annotationToStatements(const DCGVertexType *node,
}
}

void ReachabilityEngine::addTransition(const DCGVertexType *left,
const std::unordered_set<const DCGVertexType *> &rightSet) {
void ReachabilityEngine::addTransition(DCGVertexType left,
const std::unordered_set<DCGVertexType> &rightSet) {
for (const auto *right : rightSet) {
auto i = userTransitions.find(left);
if (i == userTransitions.end()) {
std::list<const DCGVertexType *> l;
std::list<DCGVertexType> l;
l.push_back(right);
userTransitions.emplace(left, l);
} else {
Expand All @@ -442,15 +487,15 @@ void ReachabilityEngine::addTransition(const DCGVertexType *left,
}

const IR::Expression *ReachabilityEngine::addCondition(const IR::Expression *prev,
const DCGVertexType *currentState) {
DCGVertexType currentState) {
const auto *newCond = getCondition(currentState);
if (newCond == nullptr) {
return prev;
}
return new IR::BOr(IR::Type_Boolean::get(), newCond, prev);
}

std::unordered_set<const DCGVertexType *> ReachabilityEngine::getName(std::string name) {
std::unordered_set<DCGVertexType> ReachabilityEngine::getName(std::string name) {
std::string params;
if ((name.length() != 0U) && name[name.length() - 1] == ')') {
size_t n = name.find('(');
Expand Down Expand Up @@ -483,8 +528,7 @@ std::unordered_set<const DCGVertexType *> ReachabilityEngine::getName(std::strin
return i->second;
}

ReachabilityResult ReachabilityEngine::next(ReachabilityEngineState *state,
const DCGVertexType *next) {
ReachabilityResult ReachabilityEngine::next(ReachabilityEngineState *state, DCGVertexType next) {
CHECK_NULL(state);
CHECK_NULL(next);
if (forbiddenVertexes.count(next)) {
Expand All @@ -502,8 +546,8 @@ ReachabilityResult ReachabilityEngine::next(ReachabilityEngineState *state,
return std::make_pair(false, nullptr);
}
const IR::Expression *expr = nullptr;
std::list<const DCGVertexType *> newState;
std::list<const DCGVertexType *> currentState = state->getState();
std::list<DCGVertexType> newState;
std::list<DCGVertexType> currentState = state->getState();
for (const auto *i : currentState) {
if (i == nullptr) {
// Start from intial.
Expand Down Expand Up @@ -553,7 +597,7 @@ ReachabilityResult ReachabilityEngine::next(ReachabilityEngineState *state,

const NodesCallGraph &ReachabilityEngine::getDCG() { return dcg; }

const IR::Expression *ReachabilityEngine::getCondition(const DCGVertexType *n) {
const IR::Expression *ReachabilityEngine::getCondition(DCGVertexType n) {
auto i = conditions.find(n);
if (i != conditions.end()) {
return i->second;
Expand Down
Loading
Loading