Skip to content

Commit

Permalink
PGIS: Use getOrCreateMD instead of plain get to work around current i…
Browse files Browse the repository at this point in the history
…ssue.
  • Loading branch information
pearzt authored and jplehr committed Jan 17, 2024
1 parent 2cf688a commit 35559d0
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 45 deletions.
8 changes: 4 additions & 4 deletions pgis/lib/src/CgHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,29 +76,29 @@ bool isOnCycle(metacg::CgNode *node, const metacg::Callgraph *const graph) {
Statements visitNodeForInclusiveStatements(metacg::CgNode *node, CgNodeRawPtrUSet *visitedNodes,
const metacg::Callgraph *const graph) {
if (visitedNodes->find(node) != visitedNodes->end()) {
return node->get<LoadImbalance::LIMetaData>()->getNumberOfInclusiveStatements();
return node->getOrCreateMD<LoadImbalance::LIMetaData>()->getNumberOfInclusiveStatements();
}
visitedNodes->insert(node);

CgNodeRawPtrUSet visistedChilds;

Statements inclusiveStatements = node->get<PiraOneData>()->getNumberOfStatements();
Statements inclusiveStatements = node->getOrCreateMD<PiraOneData>()->getNumberOfStatements();
for (auto childNode : graph->getCallees(node)) {
// prevent double processing
if (visistedChilds.find(childNode) != visistedChilds.end())
continue;
visistedChilds.insert(childNode);

// approximate statements of a abstract function with maximum of its children (potential call targets)
if (node->get<LoadImbalance::LIMetaData>()->isVirtual() && node->get<PiraOneData>()->getNumberOfStatements() == 0) {
if (node->getOrCreateMD<LoadImbalance::LIMetaData>()->isVirtual() && node->getOrCreateMD<PiraOneData>()->getNumberOfStatements() == 0) {
inclusiveStatements =
std::max(inclusiveStatements, visitNodeForInclusiveStatements(childNode, visitedNodes, graph));
} else {
inclusiveStatements += visitNodeForInclusiveStatements(childNode, visitedNodes, graph);
}
}

node->get<LoadImbalance::LIMetaData>()->setNumberOfInclusiveStatements(inclusiveStatements);
node->getOrCreateMD<LoadImbalance::LIMetaData>()->setNumberOfInclusiveStatements(inclusiveStatements);

metacg::MCGLogger::instance().getConsole()->trace("Visiting node " + node->getFunctionName() +
". Result = " + std::to_string(inclusiveStatements));
Expand Down
4 changes: 2 additions & 2 deletions pgis/lib/src/ExtrapEstimatorPhase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ void ExtrapLocalEstimatorPhaseBase::modifyGraph(metacg::CgNode *mainNode) {
if (shouldInstr) {
auto useCSInstr =
pgis::config::GlobalConfig::get().getAs<bool>(pgis::options::useCallSiteInstrumentation.cliName);
if (useCSInstr && !n->get<PiraOneData>()->getHasBody()) {
if (useCSInstr && !n->getOrCreateMD<PiraOneData>()->getHasBody()) {
// If no definition, use call-site instrumentation
pgis::instrumentPathNode(n);
// n->setState(CgNodeState::INSTRUMENT_PATH);
Expand Down Expand Up @@ -142,7 +142,7 @@ void ExtrapLocalEstimatorPhaseSingleValueExpander::modifyGraph(metacg::CgNode *m
console->trace("Running ExtrapLocalEstimatorPhaseExpander::modifyGraph on {}", n->getFunctionName());
auto [shouldInstr, funcRtVal] = shouldInstrument(n);
if (shouldInstr) {
if (!n->get<PiraOneData>()->getHasBody() && n->get<BaseProfileData>()->getRuntimeInSeconds() == .0) {
if (!n->getOrCreateMD<PiraOneData>()->getHasBody() && n->get<BaseProfileData>()->getRuntimeInSeconds() == .0) {
// If no definition, use call-site instrumentation
// n->setState(CgNodeState::INSTRUMENT_PATH);
pgis::instrumentPathNode(n);
Expand Down
30 changes: 15 additions & 15 deletions pgis/lib/src/IPCGEstimatorPhase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ void StatementCountEstimatorPhase::estimateStatementCount(metacg::CgNode *startN

while (!workQueue.empty()) {
auto node = workQueue.front();
const auto nodePOD = node->get<PiraOneData>();
const auto nodePOD = node->getOrCreateMD<PiraOneData>();
workQueue.pop();

if (const auto [it, inserted] = visitedNodes.insert(node); inserted) {
Expand All @@ -107,7 +107,7 @@ void StatementCountEstimatorPhase::estimateStatementCount(metacg::CgNode *startN
inclStmtCounts[startNode] = inclStmtCount;
} else {
// EXCLUSIVE
const auto snPOD = startNode->get<PiraOneData>();
const auto snPOD = startNode->getOrCreateMD<PiraOneData>();
inclStmtCount = snPOD->getNumberOfStatements();
}

Expand Down Expand Up @@ -183,7 +183,7 @@ void RuntimeEstimatorPhase::modifyGraph(metacg::CgNode *mainMethod) {
spdlog::get("console")->debug("The runtime is threshold is computed as: {}", runTimeThreshold);

// The main method is always in the dominant runtime path
mainMethod->get<PiraOneData>()->setDominantRuntime();
mainMethod->getOrCreateMD<PiraOneData>()->setDominantRuntime();

std::queue<metacg::CgNode *> workQueue;
workQueue.push(mainMethod);
Expand Down Expand Up @@ -216,7 +216,7 @@ void RuntimeEstimatorPhase::estimateRuntime(metacg::CgNode *startNode) {
// Skip not leave nodes
const auto &childs = graph->getCallees(startNode);
if (calls > 0 && std::none_of(childs.begin(), childs.end(), [](const auto cnode) {
return cnode->template get<PiraOneData>()->comesFromCube();
return cnode->template getOrCreateMD<PiraOneData>()->comesFromCube();
})) {
exclusiveCalls.emplace(calls, startNode);
}
Expand Down Expand Up @@ -257,7 +257,7 @@ void RuntimeEstimatorPhase::doInstrumentation(metacg::CgNode *startNode, metacg:
maxStmts = childStmts[childNode];
}

if (childNode->get<PiraOneData>()->comesFromCube()) {
if (childNode->getOrCreateMD<PiraOneData>()->comesFromCube()) {
if (childNode->get<BaseProfileData>()->getInclusiveRuntimeInSeconds() < runTimeThreshold) {
continue;
}
Expand Down Expand Up @@ -322,10 +322,10 @@ void RuntimeEstimatorPhase::doInstrumentation(metacg::CgNode *startNode, metacg:
spdlog::get("console")->debug("This is the dominant runtime path");
// maxRtChild->setState(CgNodeState::INSTRUMENT_WITNESS);
pgis::instrumentNode(maxRtChild);
maxRtChild->get<PiraOneData>()->setDominantRuntime();
maxRtChild->getOrCreateMD<PiraOneData>()->setDominantRuntime();
} else {
spdlog::get("console")->debug("This is the non-dominant runtime path");
if (startNode->get<PiraOneData>()->isDominantRuntime()) {
if (startNode->getOrCreateMD<PiraOneData>()->isDominantRuntime()) {
spdlog::get("console")->debug("\tPrincipal: {}", startNode->getFunctionName());
for (auto child : graph->getCallees(startNode)) {
spdlog::get("console")->trace("\tEvaluating {} with {} [stmt threshold: {}]", child->getFunctionName(),
Expand Down Expand Up @@ -353,7 +353,7 @@ InstumentationInfo RuntimeEstimatorPhase::getEstimatedInfoForInstrumentedNode(Cg
unsigned long exclusiveStmtCount = 0;
while (!workQueue.empty()) {
auto wnode = workQueue.front();
const auto nodePOD = wnode->get<PiraOneData>();
const auto nodePOD = wnode->getOrCreateMD<PiraOneData>();
workQueue.pop();
if (visitedNodes.find(wnode) == visitedNodes.end()) {
visitedNodes.insert(wnode);
Expand Down Expand Up @@ -404,7 +404,7 @@ void RuntimeEstimatorPhase::modifyGraphOverhead(metacg::CgNode *mainMethod) {
}
for (const auto &elem : graph->getNodes()) {
const auto &node = elem.second.get();
if (node->get<PiraOneData>()->comesFromCube()) {
if (node->getOrCreateMD<PiraOneData>()->comesFromCube()) {
if (node->getHasBody() || isMPIFunction(node) || (!onlyEligibleNodes && !useCSInstrumentation)) {
pgis::instrumentNode(node);
} else if (useCSInstrumentation) {
Expand Down Expand Up @@ -474,7 +474,7 @@ void RuntimeEstimatorPhase::modifyGraphOverhead(metacg::CgNode *mainMethod) {
// Find direct childs nodes to potentially instrument
for (const auto &elem : graph->getNodes()) {
const auto &node = elem.second.get();
if (node->get<PiraOneData>()->comesFromCube()) {
if (node->getOrCreateMD<PiraOneData>()->comesFromCube()) {
// TODO: Do we need this check
if (!node->get<TemporaryInstrumentationDecisionMetadata>()->isKicked) {
for (const auto &C : graph->getCallees(node)) {
Expand Down Expand Up @@ -906,11 +906,11 @@ void StatisticsEstimatorPhase::modifyGraph(metacg::CgNode *mainMethod) {
}

numReachableFunctions++;
auto numStmts = node->get<PiraOneData>()->getNumberOfStatements();
auto numStmts = node->getOrCreateMD<PiraOneData>()->getNumberOfStatements();
if (pgis::isInstrumented(node)) {
stmtsCoveredWithInstr += numStmts;
}
if (node->get<PiraOneData>()->comesFromCube()) {
if (node->getOrCreateMD<PiraOneData>()->comesFromCube()) {
stmtsActuallyCovered += numStmts;
}
auto &histElem = stmtHist[numStmts];
Expand Down Expand Up @@ -1299,7 +1299,7 @@ void AttachInstrumentationResultsEstimatorPhase::modifyGraph(CgNode *mainMethod)
// We have pretty exact information about a function if we instrument it and all of its childs
// This information should not change between iterations, so we do not need to overwrite/recalculate it
const auto instResult = node->checkAndGet<InstrumentationResultMetaData>();
if (instResult.first && (instResult.second->shouldBeInstrumented && !node->get<PiraOneData>()->comesFromCube())) {
if (instResult.first && (instResult.second->shouldBeInstrumented && !node->getOrCreateMD<PiraOneData>()->comesFromCube())) {
// A node should have been instrumented, but was not in the cube file. This means it has zero calls and zero
// runtime
instResult.second->isExclusiveRuntime = true;
Expand All @@ -1315,7 +1315,7 @@ void AttachInstrumentationResultsEstimatorPhase::modifyGraph(CgNode *mainMethod)
continue;
}

if (node->get<PiraOneData>()->comesFromCube()) {
if (node->getOrCreateMD<PiraOneData>()->comesFromCube()) {
// We already calculated an exclusive result before. Do not change it, as there could be small measurement
// differences that could cause flickering between different instrumentation states
bool hasPrevExclusive = instResult.first && instResult.second->isExclusiveRuntime;
Expand All @@ -1326,7 +1326,7 @@ void AttachInstrumentationResultsEstimatorPhase::modifyGraph(CgNode *mainMethod)
const auto callsFromParents = node->get<BaseProfileData>()->getCallsFromParents();
const auto &childs = graph->getCallees(node);
const bool isExclusive = std::none_of(childs.begin(), childs.end(), [](const auto &child) {
return !child->template get<PiraOneData>()->comesFromCube();
return !child->template getOrCreateMD<PiraOneData>()->comesFromCube();
});
// Calculate the summing inclusive runtime
std::queue<CgNode *> workQueue;
Expand Down
6 changes: 3 additions & 3 deletions pgis/lib/src/LegacyMCGReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,14 @@ void VersionOneMetaCGReader::read(metacg::graph::MCGManager &cgManager) {
if (opt_f.has_value()) {
metacg::CgNode *node = opt_f.value();
node->getOrCreateMD<LoadImbalance::LIMetaData>();
node->get<LoadImbalance::LIMetaData>()->setVirtual(pfi.second.isVirtual);
node->getOrCreateMD<LoadImbalance::LIMetaData>()->setVirtual(pfi.second.isVirtual);

if (pfi.second.visited) {
node->get<LoadImbalance::LIMetaData>()->flag(LoadImbalance::FlagType::Visited);
node->getOrCreateMD<LoadImbalance::LIMetaData>()->flag(LoadImbalance::FlagType::Visited);
}

if (pfi.second.irrelevant) {
node->get<LoadImbalance::LIMetaData>()->flag(LoadImbalance::FlagType::Irrelevant);
node->getOrCreateMD<LoadImbalance::LIMetaData>()->flag(LoadImbalance::FlagType::Irrelevant);
}
}
}
Expand Down
42 changes: 21 additions & 21 deletions pgis/lib/src/loadImbalance/LIEstimatorPhase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ LIEstimatorPhase::LIEstimatorPhase(std::unique_ptr<LIConfig> &&config, metacg::C
LIEstimatorPhase::~LIEstimatorPhase() { delete this->metric; }

void LIEstimatorPhase::modifyGraph(metacg::CgNode *mainMethod) {
double totalRuntime = mainMethod->get<pira::BaseProfileData>()->getInclusiveRuntimeInSeconds();
double totalRuntime = mainMethod->getOrCreateMD<pira::BaseProfileData>()->getInclusiveRuntimeInSeconds();

// make sure no node is marked for instrumentation yet
for (const auto &elem : graph->getNodes()) {
Expand All @@ -59,14 +59,14 @@ void LIEstimatorPhase::modifyGraph(metacg::CgNode *mainMethod) {
spdlog::get("console")->debug("LIEstimatorPhase: Processing node " + n->getFunctionName());

// flag node as visited
n->get<LIMetaData>()->flag(FlagType::Visited);
n->getOrCreateMD<LIMetaData>()->flag(FlagType::Visited);

double runtime = n->get<pira::BaseProfileData>()->getInclusiveRuntimeInSeconds();
double runtime = n->getOrCreateMD<pira::BaseProfileData>()->getInclusiveRuntimeInSeconds();

std::ostringstream debugString;

debugString << "Visiting node " << n->getFunctionName() << " ("
<< n->get<LoadImbalance::LIMetaData>()->getNumberOfInclusiveStatements() << "): ";
<< n->getOrCreateMD<LoadImbalance::LIMetaData>()->getNumberOfInclusiveStatements() << "): ";

// check whether node is sufficiently important
if (runtime / totalRuntime >= c->relevanceThreshold) {
Expand All @@ -79,12 +79,12 @@ void LIEstimatorPhase::modifyGraph(metacg::CgNode *mainMethod) {
statementThreshold = c->childConstantThreshold;
} else if (c->childRelevanceStrategy == ChildRelevanceStrategy::RelativeToParent) {
statementThreshold =
std::max((pira::Statements)(n->get<LoadImbalance::LIMetaData>()->getNumberOfInclusiveStatements() *
std::max((pira::Statements)(n->getOrCreateMD<LoadImbalance::LIMetaData>()->getNumberOfInclusiveStatements() *
c->childFraction),
c->childConstantThreshold);
} else if (c->childRelevanceStrategy == ChildRelevanceStrategy::RelativeToMain) {
statementThreshold = std::max(
(pira::Statements)(mainMethod->get<LoadImbalance::LIMetaData>()->getNumberOfInclusiveStatements() *
(pira::Statements)(mainMethod->getOrCreateMD<LoadImbalance::LIMetaData>()->getNumberOfInclusiveStatements() *
c->childFraction),
c->childConstantThreshold);
}
Expand All @@ -99,21 +99,21 @@ void LIEstimatorPhase::modifyGraph(metacg::CgNode *mainMethod) {
debugString << " -> " << m;
if (m >= c->imbalanceThreshold) {
debugString << " => imbalanced";
n->get<LoadImbalance::LIMetaData>()->setAssessment(m);
n->get<LoadImbalance::LIMetaData>()->flag(FlagType::Imbalanced);
n->getOrCreateMD<LoadImbalance::LIMetaData>()->setAssessment(m);
n->getOrCreateMD<LoadImbalance::LIMetaData>()->flag(FlagType::Imbalanced);
imbalancedNodeSet.push_back(n);

instrument(n); // make sure imbalanced functions stays instrumented

} else {
debugString << " => balanced";
// mark as irrelevant
n->get<LIMetaData>()->flag(FlagType::Irrelevant);
n->getOrCreateMD<LIMetaData>()->flag(FlagType::Irrelevant);
}
} else {
debugString << "ignored (" << runtime << " / " << totalRuntime << " = " << runtime / totalRuntime << ")";
// mark as irrelevant
n->get<LIMetaData>()->flag(FlagType::Irrelevant);
n->getOrCreateMD<LIMetaData>()->flag(FlagType::Irrelevant);
}
spdlog::get("console")->debug(debugString.str());
}
Expand All @@ -133,8 +133,8 @@ void LIEstimatorPhase::modifyGraph(metacg::CgNode *mainMethod) {
std::ostringstream imbalancedNames;
for (const auto &i : imbalancedNodeSet) {
imbalancedNames << i->getFunctionName();
imbalancedNames << " load imbalance assessment: " << i->get<LoadImbalance::LIMetaData>()->getAssessment().value();
imbalancedNames << " incl. runtime: " << i->get<pira::BaseProfileData>()->getInclusiveRuntimeInSeconds() << " sec.";
imbalancedNames << " load imbalance assessment: " << i->getOrCreateMD<LoadImbalance::LIMetaData>()->getAssessment().value();
imbalancedNames << " incl. runtime: " << i->getOrCreateMD<pira::BaseProfileData>()->getInclusiveRuntimeInSeconds() << " sec.";
imbalancedNames << "\n";
}
spdlog::get("console")->info("Load imbalance summary: " + imbalancedNames.str());
Expand All @@ -160,24 +160,24 @@ void LIEstimatorPhase::instrumentRelevantChildren(metacg::CgNode *node, pira::St
visitedSet.insert(child);

// process grandchilds (as possible implementations of virtual functions are children of those)
if (child->get<LoadImbalance::LIMetaData>()->isVirtual()) {
if (child->getOrCreateMD<LoadImbalance::LIMetaData>()->isVirtual()) {
for (metacg::CgNode *gc : graph->getCallees(child)) {
workQueue.push(gc);
}
}

if (child->get<LoadImbalance::LIMetaData>()->getNumberOfInclusiveStatements() >= statementThreshold) {
if (!child->get<LIMetaData>()->isFlagged(FlagType::Irrelevant)) {
if (child->getOrCreateMD<LoadImbalance::LIMetaData>()->getNumberOfInclusiveStatements() >= statementThreshold) {
if (!child->getOrCreateMD<LIMetaData>()->isFlagged(FlagType::Irrelevant)) {
instrument(child);
debugString << child->getFunctionName() << " ("
<< child->get<LoadImbalance::LIMetaData>()->getNumberOfInclusiveStatements() << ") ";
<< child->getOrCreateMD<LoadImbalance::LIMetaData>()->getNumberOfInclusiveStatements() << ") ";
} else {
debugString << "-" << child->getFunctionName() << "- ("
<< child->get<LoadImbalance::LIMetaData>()->getNumberOfInclusiveStatements() << ") ";
<< child->getOrCreateMD<LoadImbalance::LIMetaData>()->getNumberOfInclusiveStatements() << ") ";
}
} else {
debugString << "/" << child->getFunctionName() << "\\ ("
<< child->get<LoadImbalance::LIMetaData>()->getNumberOfInclusiveStatements() << ") ";
<< child->getOrCreateMD<LoadImbalance::LIMetaData>()->getNumberOfInclusiveStatements() << ") ";
}
}
}
Expand Down Expand Up @@ -205,7 +205,7 @@ void LoadImbalance::LIEstimatorPhase::contextHandling(metacg::CgNode *n, metacg:
if (c->contextStrategy == ContextStrategy::MajorPathsToMain ||
c->contextStrategy == ContextStrategy::MajorParentSteps) {
for (metacg::CgNode *x : nodesOnPathToMain) {
if (!x->get<LIMetaData>()->isFlagged(FlagType::Visited)) {
if (!x->getOrCreateMD<LIMetaData>()->isFlagged(FlagType::Visited)) {
relevantPaths.erase(x);
}
}
Expand Down Expand Up @@ -273,8 +273,8 @@ void LIEstimatorPhase::findSyncPoints(CgNode *node) {

// process all parents which are balanced + visisted
for (CgNode *parent : graph->getCallers(node)) {
if (!parent->get<LoadImbalance::LIMetaData>()->isFlagged(FlagType::Imbalanced) &&
parent->get<LoadImbalance::LIMetaData>()->isFlagged(FlagType::Visited)) {
if (!parent->getOrCreateMD<LoadImbalance::LIMetaData>()->isFlagged(FlagType::Imbalanced) &&
parent->getOrCreateMD<LoadImbalance::LIMetaData>()->isFlagged(FlagType::Visited)) {
// instrument all descendant synchronization routines
instrumentByPattern(
parent, [](CgNode *nodeInQuestion) { return nodeInQuestion->getFunctionName().rfind("MPI_", 0) == 0; },
Expand Down

0 comments on commit 35559d0

Please sign in to comment.