Skip to content

Commit

Permalink
Map written LocationSets to program locations (loc_t) instead of IR::…
Browse files Browse the repository at this point in the history
…Expression*s

Signed-off-by: Kyle Cripps <[email protected]>
  • Loading branch information
kfcripps committed Jul 12, 2024
1 parent f24aacc commit 407aa2d
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 72 deletions.
45 changes: 39 additions & 6 deletions frontends/p4/def_use.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -638,8 +638,10 @@ bool ComputeWriteSet::preorder(const IR::SelectExpression *expression) {
visit(expression->select);
visit(&expression->selectCases);
auto l = getWrites(expression->select);
for (auto c : expression->selectCases) {
auto s = getWrites(c->keyset);
const loc_t *selectCasesLoc = getLoc(&expression->selectCases, getChildContext());
for (auto *c : expression->selectCases) {
const loc_t *selectCaseLoc = getLoc(c, selectCasesLoc);
auto s = getWrites(c->keyset, selectCaseLoc);
l = l->join(s);
}
expressionWrites(expression, l);
Expand Down Expand Up @@ -673,7 +675,8 @@ bool ComputeWriteSet::preorder(const IR::MethodCallExpression *expression) {
lhs = save;
auto mi = MethodInstance::resolve(expression, storageMap->refMap, storageMap->typeMap);
if (auto bim = mi->to<BuiltInMethod>()) {
auto base = getWrites(bim->appliedTo);
const loc_t *methodLoc = getLoc(expression->method, getChildContext());
auto base = getWrites(bim->appliedTo, methodLoc);
cstring name = bim->name.name;
if (name == IR::Type_Header::setInvalid || name == IR::Type_Header::setValid) {
// modifies only the valid field.
Expand Down Expand Up @@ -712,7 +715,7 @@ bool ComputeWriteSet::preorder(const IR::MethodCallExpression *expression) {
LOG3("Analyzing callees of " << expression << DBPrint::Brief << callees << DBPrint::Reset
<< indent);
ProgramPoint pt(callingContext, expression);
ComputeWriteSet cw(this, pt, currentDefinitions);
ComputeWriteSet cw(this, pt, currentDefinitions, cached_locs);
cw.setCalledBy(this);
for (auto c : callees) (void)c->getNode()->apply(cw);
currentDefinitions = cw.currentDefinitions;
Expand All @@ -735,7 +738,8 @@ bool ComputeWriteSet::preorder(const IR::MethodCallExpression *expression) {
visit(arg);
lhs = save;
if (p->direction == IR::Direction::Out || p->direction == IR::Direction::InOut) {
auto val = getWrites(arg->expression);
const loc_t *argLoc = getLoc(arg, getChildContext());
auto val = getWrites(arg->expression, argLoc);
result = result->join(val);
}
}
Expand All @@ -759,6 +763,35 @@ void ComputeWriteSet::visitVirtualMethods(const IR::IndexedVector<IR::Declaratio
}
}

// Returns program location of n, given the program location of n's direct parent.
// Use to get loc if n is indirect child (e.g. grandchild) of currently being visited node.
// In this case parentLoc is the loc of n's direct parent.
const P4::ComputeWriteSet::loc_t *ComputeWriteSet::getLoc(const IR::Node *n,
const loc_t *parentLoc) {
loc_t tmp{n, parentLoc};
return &*cached_locs.insert(tmp).first;
}

// Returns program location given the context of the currently being visited node.
// Use to get loc of currently being visited node.
const P4::ComputeWriteSet::loc_t *ComputeWriteSet::getLoc(const Visitor::Context *ctxt) {
if (!ctxt) return nullptr;
loc_t tmp{ctxt->node, getLoc(ctxt->parent)};
return &*cached_locs.insert(tmp).first;
}

// Returns program location of a child node n, given the context of the
// currently being visited node.
// Use to get loc if n is direct child of currently being visited node.
const P4::ComputeWriteSet::loc_t *ComputeWriteSet::getLoc(const IR::Node *n,
const Visitor::Context *ctxt) {
for (auto *p = ctxt; p; p = p->parent)
if (p->node == n) return getLoc(p);
auto rv = getLoc(ctxt);
loc_t tmp{n, rv};
return &*cached_locs.insert(tmp).first;
}

// Symbolic execution of the parser
bool ComputeWriteSet::preorder(const IR::P4Parser *parser) {
LOG3("CWS Visiting " << dbp(parser));
Expand All @@ -784,7 +817,7 @@ bool ComputeWriteSet::preorder(const IR::P4Parser *parser) {
// but we use the same data structures
ProgramPoint pt(state);
currentDefinitions = allDefinitions->getDefinitions(pt);
ComputeWriteSet cws(this, pt, currentDefinitions);
ComputeWriteSet cws(this, pt, currentDefinitions, cached_locs);
cws.setCalledBy(this);
(void)state->apply(cws);

Expand Down
169 changes: 103 additions & 66 deletions frontends/p4/def_use.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include "lib/alloc_trace.h"
#include "lib/hash.h"
#include "lib/hvec_map.h"
#include "lib/ordered_map.h"
#include "lib/ordered_set.h"
#include "typeMap.h"

Expand Down Expand Up @@ -476,6 +477,78 @@ class AllDefinitions : public IHasDbPrint {
*/

class ComputeWriteSet : public Inspector, public IHasDbPrint {
public:
// A location in the program. Includes the context from the visitor, which needs to
// be copied out of the Visitor::Context objects, as they are allocated on the stack and
// will become invalid as the IR traversal continues
struct loc_t {
const IR::Node *node;
const loc_t *parent;
bool operator<(const loc_t &a) const {
if (node != a.node) return node->id < a.node->id;
if (!parent || !a.parent) return parent != nullptr;
return *parent < *a.parent;
}
};

explicit ComputeWriteSet(AllDefinitions *allDefinitions)
: allDefinitions(allDefinitions),
currentDefinitions(nullptr),
returnedDefinitions(nullptr),
exitDefinitions(new Definitions()),
storageMap(allDefinitions->storageMap),
lhs(false),
virtualMethod(false),
cached_locs(*new std::set<loc_t>) {
CHECK_NULL(allDefinitions);
visitDagOnce = false;
}

// expressions
bool preorder(const IR::Literal *expression) override;
bool preorder(const IR::Slice *expression) override;
bool preorder(const IR::TypeNameExpression *expression) override;
bool preorder(const IR::PathExpression *expression) override;
bool preorder(const IR::Member *expression) override;
bool preorder(const IR::ArrayIndex *expression) override;
bool preorder(const IR::Operation_Binary *expression) override;
bool preorder(const IR::Mux *expression) override;
bool preorder(const IR::SelectExpression *expression) override;
bool preorder(const IR::ListExpression *expression) override;
bool preorder(const IR::Operation_Unary *expression) override;
bool preorder(const IR::MethodCallExpression *expression) override;
bool preorder(const IR::DefaultExpression *expression) override;
bool preorder(const IR::Expression *expression) override;
bool preorder(const IR::InvalidHeader *expression) override;
bool preorder(const IR::InvalidHeaderUnion *expression) override;
bool preorder(const IR::P4ListExpression *expression) override;
bool preorder(const IR::HeaderStackExpression *expression) override;
bool preorder(const IR::StructExpression *expression) override;
// statements
bool preorder(const IR::P4Parser *parser) override;
bool preorder(const IR::P4Control *control) override;
bool preorder(const IR::P4Action *action) override;
bool preorder(const IR::P4Table *table) override;
bool preorder(const IR::Function *function) override;
bool preorder(const IR::AssignmentStatement *statement) override;
bool preorder(const IR::ReturnStatement *statement) override;
bool preorder(const IR::ExitStatement *statement) override;
bool preorder(const IR::BreakStatement *statement) override;
bool handleJump(const char *tok, Definitions *&defs);
bool preorder(const IR::ContinueStatement *statement) override;
bool preorder(const IR::IfStatement *statement) override;
bool preorder(const IR::ForStatement *statement) override;
bool preorder(const IR::ForInStatement *statement) override;
bool preorder(const IR::BlockStatement *statement) override;
bool preorder(const IR::SwitchStatement *statement) override;
bool preorder(const IR::EmptyStatement *statement) override;
bool preorder(const IR::MethodCallStatement *statement) override;

const LocationSet *writtenLocations(const IR::Expression *expression) {
expression->apply(*this);
return getWrites(expression);
}

protected:
AllDefinitions *allDefinitions; /// Result computed by this pass.
Definitions *currentDefinitions; /// Before statement currently processed.
Expand All @@ -487,16 +560,17 @@ class ComputeWriteSet : public Inspector, public IHasDbPrint {
const StorageMap *storageMap;
/// if true we are processing an expression on the lhs of an assignment
bool lhs;
/// For each expression the location set it writes
hvec_map<const IR::Expression *, const LocationSet *> writes;
/// For each program location the location set it writes
ordered_map<loc_t, const LocationSet *> writes;
bool virtualMethod; /// True if we are analyzing a virtual method
AllocTrace memuse;
alloc_trace_cb_t nested_trace;
static int nest_count;

/// Creates new visitor, but with same underlying data structures.
/// Needed to visit some program fragments repeatedly.
ComputeWriteSet(const ComputeWriteSet *source, ProgramPoint context, Definitions *definitions)
ComputeWriteSet(const ComputeWriteSet *source, ProgramPoint context, Definitions *definitions,
std::set<loc_t> &cached_locs)
: allDefinitions(source->allDefinitions),
currentDefinitions(definitions),
returnedDefinitions(nullptr),
Expand All @@ -506,10 +580,14 @@ class ComputeWriteSet : public Inspector, public IHasDbPrint {
callingContext(context),
storageMap(source->storageMap),
lhs(false),
virtualMethod(false) {
virtualMethod(false),
cached_locs(cached_locs) {
visitDagOnce = false;
}
void visitVirtualMethods(const IR::IndexedVector<IR::Declaration> &locals);
const loc_t *getLoc(const IR::Node *n, const loc_t *parentLoc);
const loc_t *getLoc(const Visitor::Context *ctxt);
const loc_t *getLoc(const IR::Node *n, const Visitor::Context *ctxt);
void enterScope(const IR::ParameterList *parameters,
const IR::IndexedVector<IR::Declaration> *locals, ProgramPoint startPoint,
bool clear = true);
Expand All @@ -518,25 +596,39 @@ class ComputeWriteSet : public Inspector, public IHasDbPrint {
Definitions *getDefinitionsAfter(const IR::ParserState *state);
bool setDefinitions(Definitions *defs, const IR::Node *who = nullptr, bool overwrite = false);
ProgramPoint getProgramPoint(const IR::Node *node = nullptr) const;
const LocationSet *getWrites(const IR::Expression *expression) const {
auto result = ::get(writes, expression);
// Get writes of a node that is a direct child of the currently being visited node.
const LocationSet *getWrites(const IR::Expression *expression) {
const loc_t &exprLoc = *getLoc(expression, getChildContext());
auto result = ::get(writes, exprLoc);
BUG_CHECK(result != nullptr, "No location set known for %1%", expression);
return result;
}
// Get writes of a node that is not a direct child of the currently being visited node.
// In this case, parentLoc is the loc of expression's direct parent node.
const LocationSet *getWrites(const IR::Expression *expression, const loc_t *parentLoc) {
const loc_t &exprLoc = *getLoc(expression, parentLoc);
auto result = ::get(writes, exprLoc);
BUG_CHECK(result != nullptr, "No location set known for %1%", expression);
return result;
}
// Register writes of expression, which is expected to be the currently visited node.
void expressionWrites(const IR::Expression *expression, const LocationSet *loc) {
CHECK_NULL(expression);
CHECK_NULL(loc);
LOG3(expression << dbp(expression) << " writes " << loc);
if (auto it = writes.find(expression); it != writes.end()) {
const Context *ctx = getChildContext();
BUG_CHECK(ctx->node == expression, "Expected ctx->node == expression.");
const loc_t &exprLoc = *getLoc(ctx);
if (auto it = writes.find(exprLoc); it != writes.end()) {
BUG_CHECK(*it->second == *loc || expression->is<IR::Literal>(),
"Expression %1% write set already set", expression);
} else {
writes.emplace(expression, loc);
writes.emplace(exprLoc, loc);
}
}
void dbprint(std::ostream &out) const override {
if (writes.empty()) out << "No writes";
for (auto &it : writes) out << it.first << " writes " << it.second << Log::endl;
for (auto &it : writes) out << it.first.node << " writes " << it.second << Log::endl;
}
profile_t init_apply(const IR::Node *root) override {
auto rv = Inspector::init_apply(root);
Expand All @@ -555,63 +647,8 @@ class ComputeWriteSet : public Inspector, public IHasDbPrint {
}
}

public:
explicit ComputeWriteSet(AllDefinitions *allDefinitions)
: allDefinitions(allDefinitions),
currentDefinitions(nullptr),
returnedDefinitions(nullptr),
exitDefinitions(new Definitions()),
storageMap(allDefinitions->storageMap),
lhs(false),
virtualMethod(false) {
CHECK_NULL(allDefinitions);
visitDagOnce = false;
}

// expressions
bool preorder(const IR::Literal *expression) override;
bool preorder(const IR::Slice *expression) override;
bool preorder(const IR::TypeNameExpression *expression) override;
bool preorder(const IR::PathExpression *expression) override;
bool preorder(const IR::Member *expression) override;
bool preorder(const IR::ArrayIndex *expression) override;
bool preorder(const IR::Operation_Binary *expression) override;
bool preorder(const IR::Mux *expression) override;
bool preorder(const IR::SelectExpression *expression) override;
bool preorder(const IR::ListExpression *expression) override;
bool preorder(const IR::Operation_Unary *expression) override;
bool preorder(const IR::MethodCallExpression *expression) override;
bool preorder(const IR::DefaultExpression *expression) override;
bool preorder(const IR::Expression *expression) override;
bool preorder(const IR::InvalidHeader *expression) override;
bool preorder(const IR::InvalidHeaderUnion *expression) override;
bool preorder(const IR::P4ListExpression *expression) override;
bool preorder(const IR::HeaderStackExpression *expression) override;
bool preorder(const IR::StructExpression *expression) override;
// statements
bool preorder(const IR::P4Parser *parser) override;
bool preorder(const IR::P4Control *control) override;
bool preorder(const IR::P4Action *action) override;
bool preorder(const IR::P4Table *table) override;
bool preorder(const IR::Function *function) override;
bool preorder(const IR::AssignmentStatement *statement) override;
bool preorder(const IR::ReturnStatement *statement) override;
bool preorder(const IR::ExitStatement *statement) override;
bool preorder(const IR::BreakStatement *statement) override;
bool handleJump(const char *tok, Definitions *&defs);
bool preorder(const IR::ContinueStatement *statement) override;
bool preorder(const IR::IfStatement *statement) override;
bool preorder(const IR::ForStatement *statement) override;
bool preorder(const IR::ForInStatement *statement) override;
bool preorder(const IR::BlockStatement *statement) override;
bool preorder(const IR::SwitchStatement *statement) override;
bool preorder(const IR::EmptyStatement *statement) override;
bool preorder(const IR::MethodCallStatement *statement) override;

const LocationSet *writtenLocations(const IR::Expression *expression) {
expression->apply(*this);
return getWrites(expression);
}
private:
std::set<loc_t> &cached_locs;
};

} // namespace P4
Expand Down

0 comments on commit 407aa2d

Please sign in to comment.