Skip to content

Commit

Permalink
Changes for for loops (#4562)
Browse files Browse the repository at this point in the history
* Language grammar changes to add for loop

* IR classes for For and ForInStatements

* For statement namespace fixes

* Recognize "for" token

* Support for for-loops in toP4

* forloop initial test for parsing

* Add support for break / continue in loops. No sema checks

* Add semantic checks for continue and break statements

* Minimize stderr diffs

- due to ';' now being no longer part of the declaration, one character
  difference in a bunch of stderr traces

* sideEffects support for For statements

- keep a PathExpression of the var in ForInStatements, so that the
  symbol can be moved to the top level (from Anton Korobeynikov
  <[email protected]>)

* loop fixes to get through midend/p4test

- allow ForIn to refer to decl in the enclosing scope
- ensure valid P4 code output for loops in toP4
- fix testcase to no deadcode elim everything

* def_use for loops

* Split loop visit_children to separate source file

* loop support in ControlFlowVisitor

* loop support for LocalCopyprop

* loop support for midend ComputeDefUse

* Fix local_copyprop to not copyprop illegally into loops

* Disable ActionSynthesis for statments in for init/update

- Some hacks here to figure out which child is which for a ForStatement
- Should be part of ActionSynthesis policy somehow?

* loop support in FindUninitialized

* Allow annotations on for statements

* clang-format

* Added testcases

* Minor typos fixes for loops

- 'for' instead of 'foreach' in comments/error messages
- fix constant in loop-visitor

* Redo loop flow analysis -- fix def_use and ControlFlowVisitor

- flow state after loop needs to be union of state after the condition
  check (not bottom of loop) and all break states.
- condition of for..in is before setting index var

* Fix toP4 for break/continue + more tests

* Unsupported error for loops in BMV2

* Comment typos/improvements

* Insert break; when removing return/exit from loop

* initial UnrollLoops pass

- only handles simple for v in k1..k2 loops; general for TBD

* GlobalCopyprop support for loops

* UnrollLoops for ForStatement

* Repeat UnrollLoops + constfold + copyprop to fixed point.

* Fix and generalize ForStatement unrolling

- allow more patterns of tests and increments
- deal properly with updates in the presence of break&continue
- single test to skip rest of loop after break rather than rechecking
  the flag every time.
- remove redundant inits of flags

* UnrollLoops fix for break after continue

* Testcases for expected errors and nested loops with return

* Typecheck/inference into ForIn loop ranges

---------

Co-authored-by: Andy Fingerhut <[email protected]>
Co-authored-by: Anton Korobeynikov <[email protected]>
  • Loading branch information
3 people authored May 17, 2024
1 parent 6343eab commit b51b341
Show file tree
Hide file tree
Showing 135 changed files with 5,563 additions and 329 deletions.
1 change: 1 addition & 0 deletions backends/bmv2/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ set (BMV2_BACKEND_COMMON_HDRS
common/action.h
common/annotations.h
common/backend.h
common/check_unsupported.h
common/control.h
common/controlFlowGraph.h
common/deparser.h
Expand Down
40 changes: 40 additions & 0 deletions backends/bmv2/common/check_unsupported.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
Copyright 2013-present Barefoot Networks, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#ifndef BACKENDS_BMV2_COMMON_CHECK_UNSUPPORTED_H_
#define BACKENDS_BMV2_COMMON_CHECK_UNSUPPORTED_H_

#include "frontends/common/options.h"
#include "ir/ir.h"
#include "lower.h"
#include "midend/convertEnums.h"

namespace BMV2 {

class CheckUnsupported : public Inspector {
bool preorder(const IR::ForStatement *fs) override {
error(ErrorType::ERR_UNSUPPORTED, "%sBMV2 does not support loops", fs->srcInfo);
return false;
}
bool preorder(const IR::ForInStatement *fs) override {
error(ErrorType::ERR_UNSUPPORTED, "%sBMV2 does not support loops", fs->srcInfo);
return false;
}
};

} // namespace BMV2

#endif /* BACKENDS_BMV2_COMMON_CHECK_UNSUPPORTED_H_ */
3 changes: 3 additions & 0 deletions backends/bmv2/psa_switch/midend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.

#include "midend.h"

#include "backends/bmv2/common/check_unsupported.h"
#include "backends/bmv2/psa_switch/options.h"
#include "frontends/common/constantFolding.h"
#include "frontends/common/resolveReferences/resolveReferences.h"
Expand Down Expand Up @@ -107,6 +108,7 @@ PsaSwitchMidEnd::PsaSwitchMidEnd(CompilerOptions &options, std::ostream *outStre
if (BMV2::PsaSwitchContext::get().options().loadIRFromJson == false) {
addPasses({
options.ndebug ? new P4::RemoveAssertAssume(&refMap, &typeMap) : nullptr,
new CheckUnsupported(),
new P4::RemoveMiss(&refMap, &typeMap),
new P4::EliminateNewtype(&refMap, &typeMap),
new P4::EliminateInvalidHeaders(&refMap, &typeMap),
Expand Down Expand Up @@ -167,6 +169,7 @@ PsaSwitchMidEnd::PsaSwitchMidEnd(CompilerOptions &options, std::ostream *outStre
addPasses({
new P4::ResolveReferences(&refMap),
new P4::TypeChecking(&refMap, &typeMap),
new CheckUnsupported(),
fillEnumMap,
[this, fillEnumMap]() { enumMap = fillEnumMap->repr; },
evaluator,
Expand Down
2 changes: 2 additions & 0 deletions backends/bmv2/simple_switch/midend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.

#include "midend.h"

#include "backends/bmv2/common/check_unsupported.h"
#include "backends/bmv2/simple_switch/options.h"
#include "frontends/common/constantFolding.h"
#include "frontends/common/resolveReferences/resolveReferences.h"
Expand Down Expand Up @@ -75,6 +76,7 @@ SimpleSwitchMidEnd::SimpleSwitchMidEnd(CompilerOptions &options, std::ostream *o
addPasses(
{options.ndebug ? new P4::RemoveAssertAssume(&refMap, &typeMap) : nullptr,
new P4::CheckTableSize(),
new CheckUnsupported(),
new P4::RemoveMiss(&refMap, &typeMap),
new P4::EliminateNewtype(&refMap, &typeMap),
new P4::EliminateInvalidHeaders(&refMap, &typeMap),
Expand Down
11 changes: 10 additions & 1 deletion backends/p4test/midend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ limitations under the License.
#include "midend/simplifySelectCases.h"
#include "midend/simplifySelectList.h"
#include "midend/tableHit.h"
#include "midend/unrollLoops.h"

namespace P4Test {

Expand All @@ -81,6 +82,7 @@ MidEnd::MidEnd(CompilerOptions &options, std::ostream *outStream) {
setName("MidEnd");

auto v1controls = new std::set<cstring>();
auto defUse = new P4::ComputeDefUse;

addPasses(
{options.ndebug ? new P4::RemoveAssertAssume(&refMap, &typeMap) : nullptr,
Expand Down Expand Up @@ -125,7 +127,14 @@ MidEnd::MidEnd(CompilerOptions &options, std::ostream *outStream) {
new P4::EliminateSwitch(&refMap, &typeMap),
new P4::ResolveReferences(&refMap),
new P4::TypeChecking(&refMap, &typeMap, true), // update types before ComputeDefUse
new P4::ComputeDefUse, // present for testing
new PassRepeated({
defUse,
new P4::UnrollLoops(refMap, defUse),
new P4::LocalCopyPropagation(&refMap, &typeMap),
new P4::ConstantFolding(&refMap, &typeMap),
new P4::StrengthReduction(&refMap, &typeMap),
}),
new P4::MoveDeclarations(), // more may have been introduced
evaluator,
[v1controls, evaluator](const IR::Node *root) -> const IR::Node * {
auto toplevel = evaluator->getToplevelBlock();
Expand Down
113 changes: 98 additions & 15 deletions frontends/p4/def_use.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,15 @@ bool LocationSet::overlaps(const LocationSet *other) const {
return false;
}

bool LocationSet::operator==(const LocationSet &other) const {
auto it = other.begin();
for (auto s : locations) {
if (it == other.end() || *it != s) return false;
++it;
}
return it == other.end();
}

void ProgramPoints::add(const ProgramPoints *from) {
points.insert(from->points.begin(), from->points.end());
}
Expand Down Expand Up @@ -406,7 +415,7 @@ void ComputeWriteSet::enterScope(const IR::ParameterList *parameters,
}
}
}
allDefinitions->setDefinitionsAt(entryPoint, defs, false);
allDefinitions->setDefinitionsAt(entryPoint, defs, true);
currentDefinitions = defs;
if (LOGGING(5))
LOG5("CWS Entered scope " << entryPoint << " definitions are " << Log::endl << defs);
Expand Down Expand Up @@ -460,6 +469,9 @@ bool ComputeWriteSet::setDefinitions(Definitions *defs, const IR::Node *node, bo
// overwriting always in parser states. In this case we actually expect
// that the definitions are monotonically increasing.
if (findContext<IR::ParserState>()) overwrite = true;
// Loop bodies get visited repeatedly until a fixed point, so we likewise
// expect monotonically increasing write sets.
if (continueDefinitions != nullptr) overwrite = true; // in a loop
allDefinitions->setDefinitionsAt(point, currentDefinitions, overwrite);
if (LOGGING(5))
LOG5("CWS Definitions at " << point << " are " << Log::endl << defs);
Expand Down Expand Up @@ -820,34 +832,105 @@ bool ComputeWriteSet::preorder(const IR::IfStatement *statement) {
return setDefinitions(result);
}

bool ComputeWriteSet::preorder(const IR::ForStatement *statement) {
LOG3("CWS Visiting " << dbp(statement));
if (currentDefinitions->isUnreachable()) return setDefinitions(currentDefinitions);
visit(statement->init, "init");

auto saveBreak = breakDefinitions;
auto saveContinue = continueDefinitions;
breakDefinitions = new Definitions();
continueDefinitions = new Definitions();
Definitions *startDefs = nullptr;
Definitions *exitDefs = nullptr;

do {
startDefs = currentDefinitions;
visit(statement->condition, "condition");
auto cond = getWrites(statement->condition);
// exitDefs are the definitions after evaluating the condition
exitDefs = currentDefinitions->writes(getProgramPoint(), cond);
(void)setDefinitions(exitDefs, statement->condition, true);
visit(statement->body, "body");
currentDefinitions = currentDefinitions->joinDefinitions(continueDefinitions);
visit(statement->updates, "updates");
currentDefinitions = currentDefinitions->joinDefinitions(startDefs);
} while (!(*startDefs == *currentDefinitions));

exitDefs = exitDefs->joinDefinitions(breakDefinitions);
breakDefinitions = saveBreak;
continueDefinitions = saveContinue;
return setDefinitions(exitDefs);
}

bool ComputeWriteSet::preorder(const IR::ForInStatement *statement) {
LOG3("CWS Visiting " << dbp(statement));
if (currentDefinitions->isUnreachable()) return setDefinitions(currentDefinitions);
visit(statement->collection, "collection");

auto saveBreak = breakDefinitions;
auto saveContinue = continueDefinitions;
breakDefinitions = new Definitions();
continueDefinitions = new Definitions();
Definitions *startDefs = nullptr;
Definitions *exitDefs = currentDefinitions; // in case collection is empty;

do {
startDefs = currentDefinitions;
lhs = true;
visit(statement->ref, "ref");
lhs = false;
auto cond = getWrites(statement->ref);
auto defs = currentDefinitions->writes(getProgramPoint(), cond);
(void)setDefinitions(defs, statement->ref, true);
visit(statement->body, "body");
currentDefinitions = currentDefinitions->joinDefinitions(continueDefinitions);
currentDefinitions = currentDefinitions->joinDefinitions(startDefs);
} while (!(*startDefs == *currentDefinitions));

exitDefs = exitDefs->joinDefinitions(currentDefinitions);
exitDefs = exitDefs->joinDefinitions(breakDefinitions);
breakDefinitions = saveBreak;
continueDefinitions = saveContinue;
return setDefinitions(exitDefs);
}

bool ComputeWriteSet::preorder(const IR::BlockStatement *statement) {
if (currentDefinitions->isUnreachable()) return setDefinitions(currentDefinitions);
visit(statement->components, "components");
return setDefinitions(currentDefinitions);
}

bool ComputeWriteSet::preorder(const IR::ReturnStatement *statement) {
if (currentDefinitions->isUnreachable()) return setDefinitions(currentDefinitions);
if (statement->expression != nullptr) visit(statement->expression);
returnedDefinitions = returnedDefinitions->joinDefinitions(currentDefinitions);
if (LOGGING(5))
LOG5("Return definitions " << returnedDefinitions);
else
LOG3("Return " << returnedDefinitions->size() << " definitions");
auto defs = currentDefinitions->cloneDefinitions();
defs->setUnreachable();
return setDefinitions(defs);
return handleJump("Return", returnedDefinitions);
}

bool ComputeWriteSet::preorder(const IR::ExitStatement *) {
if (currentDefinitions->isUnreachable()) return setDefinitions(currentDefinitions);
exitDefinitions = exitDefinitions->joinDefinitions(currentDefinitions);
return handleJump("Exit", exitDefinitions);
}

bool ComputeWriteSet::preorder(const IR::BreakStatement *) {
if (currentDefinitions->isUnreachable()) return setDefinitions(currentDefinitions);
return handleJump("Break", breakDefinitions);
}

bool ComputeWriteSet::preorder(const IR::ContinueStatement *) {
if (currentDefinitions->isUnreachable()) return setDefinitions(currentDefinitions);
return handleJump("Continue", continueDefinitions);
}

bool ComputeWriteSet::handleJump(const char *tok, Definitions *&defs) {
defs = defs->joinDefinitions(currentDefinitions);
if (LOGGING(5))
LOG5("Exit definitions " << exitDefinitions);
LOG5(tok << " definitions " << defs);
else
LOG3("Exit with " << exitDefinitions->size() << " definitions");
auto defs = currentDefinitions->cloneDefinitions();
defs->setUnreachable();
return setDefinitions(defs);
LOG3(tok << " with " << defs->size() << " definitions");
auto after = currentDefinitions->cloneDefinitions();
after->setUnreachable();
return setDefinitions(after);
}

bool ComputeWriteSet::preorder(const IR::EmptyStatement *) {
Expand Down
27 changes: 20 additions & 7 deletions frontends/p4/def_use.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ class LocationSet : public IHasDbPrint {
}
// only defined for canonical representations
bool overlaps(const LocationSet *other) const;
bool operator==(const LocationSet &other) const;
bool isEmpty() const { return locations.empty(); }
};

Expand Down Expand Up @@ -476,10 +477,12 @@ class AllDefinitions : public IHasDbPrint {

class ComputeWriteSet : public Inspector, public IHasDbPrint {
protected:
AllDefinitions *allDefinitions; /// Result computed by this pass.
Definitions *currentDefinitions; /// Before statement currently processed.
Definitions *returnedDefinitions; /// Definitions after return statements.
Definitions *exitDefinitions; /// Definitions after exit statements.
AllDefinitions *allDefinitions; /// Result computed by this pass.
Definitions *currentDefinitions; /// Before statement currently processed.
Definitions *returnedDefinitions; /// Definitions after return statements.
Definitions *exitDefinitions; /// Definitions after exit statements.
Definitions *breakDefinitions = nullptr; /// Definitions at break statements.
Definitions *continueDefinitions = nullptr; /// Definitions at continue statements.
ProgramPoint callingContext;
const StorageMap *storageMap;
/// if true we are processing an expression on the lhs of an assignment
Expand All @@ -498,6 +501,8 @@ class ComputeWriteSet : public Inspector, public IHasDbPrint {
currentDefinitions(definitions),
returnedDefinitions(nullptr),
exitDefinitions(source->exitDefinitions),
breakDefinitions(source->breakDefinitions),
continueDefinitions(source->continueDefinitions),
callingContext(context),
storageMap(source->storageMap),
lhs(false),
Expand All @@ -522,9 +527,12 @@ class ComputeWriteSet : public Inspector, public IHasDbPrint {
CHECK_NULL(expression);
CHECK_NULL(loc);
LOG3(expression << dbp(expression) << " writes " << loc);
BUG_CHECK(writes.find(expression) == writes.end() || expression->is<IR::Literal>(),
"Expression %1% write set already set", expression);
writes.emplace(expression, loc);
if (auto it = writes.find(expression); it != writes.end()) {
BUG_CHECK(*it->second == *loc || expression->is<IR::Literal>(),
"Expression %1% write set already set", expression);
} else {
writes.emplace(expression, loc);
}
}
void dbprint(std::ostream &out) const override {
if (writes.empty()) out << "No writes";
Expand Down Expand Up @@ -589,7 +597,12 @@ class ComputeWriteSet : public Inspector, public IHasDbPrint {
bool preorder(const IR::AssignmentStatement *statement) override;
bool preorder(const IR::ReturnStatement *statement) override;
bool preorder(const IR::ExitStatement *statement) override;
bool preorder(const IR::BreakStatement *statement) override;
bool handleJump(const char *tok, Definitions *&defs);
bool preorder(const IR::ContinueStatement *statement) override;
bool preorder(const IR::IfStatement *statement) override;
bool preorder(const IR::ForStatement *statement) override;
bool preorder(const IR::ForInStatement *statement) override;
bool preorder(const IR::BlockStatement *statement) override;
bool preorder(const IR::SwitchStatement *statement) override;
bool preorder(const IR::EmptyStatement *statement) override;
Expand Down
24 changes: 21 additions & 3 deletions frontends/p4/removeReturns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,15 @@ const IR::Node *DoRemoveReturns::preorder(IR::ReturnStatement *statement) {
set(TernaryBool::Yes);
auto vec = new IR::IndexedVector<IR::StatOrDecl>();

auto left = new IR::PathExpression(returnVar);
auto left = new IR::PathExpression(IR::Type::Boolean::get(), returnVar);
vec->push_back(
new IR::AssignmentStatement(statement->srcInfo, left, new IR::BoolLiteral(true)));
if (statement->expression != nullptr) {
left = new IR::PathExpression(returnedValue);
left = new IR::PathExpression(statement->expression->type, returnedValue);
vec->push_back(
new IR::AssignmentStatement(statement->srcInfo, left, statement->expression));
}
if (findContext<IR::LoopStatement>()) vec->push_back(new IR::BreakStatement);
return new IR::BlockStatement(*vec);
}

Expand All @@ -158,7 +159,7 @@ const IR::Node *DoRemoveReturns::preorder(IR::BlockStatement *statement) {
break;
} else if (r == TernaryBool::Maybe) {
auto newBlock = new IR::BlockStatement;
auto path = new IR::PathExpression(returnVar);
auto path = new IR::PathExpression(IR::Type::Boolean::get(), returnVar);
auto condition = new IR::LNot(path);
auto ifstat = new IR::IfStatement(condition, newBlock, nullptr);
block->push_back(ifstat);
Expand Down Expand Up @@ -211,4 +212,21 @@ const IR::Node *DoRemoveReturns::preorder(IR::SwitchStatement *statement) {
return statement;
}

const IR::Node *DoRemoveReturns::postorder(IR::LoopStatement *loop) {
// loop body might not (all) execute, so can't be certain if it returns
if (hasReturned() == TernaryBool::Yes) set(TernaryBool::Maybe);

// only need to add an extra check for nested loops
if (!findContext<IR::LoopStatement>()) return loop;
// only if the inner loop may have returned
if (hasReturned() == TernaryBool::No) return loop;

// break out of the outer loop if the inner loop returned
auto rv = new IR::BlockStatement();
rv->push_back(loop);
rv->push_back(new IR::IfStatement(new IR::PathExpression(IR::Type::Boolean::get(), returnVar),
new IR::BreakStatement(), nullptr));
return rv;
}

} // namespace P4
2 changes: 2 additions & 0 deletions frontends/p4/removeReturns.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class DoRemoveReturns : public Transform {
prune();
return parser;
}

const IR::Node *postorder(IR::LoopStatement *loop) override;
};

class RemoveReturns : public PassManager {
Expand Down
Loading

0 comments on commit b51b341

Please sign in to comment.