Skip to content

Commit

Permalink
Enhancing autotuning search with nested contexts.
Browse files Browse the repository at this point in the history
When doing autotuning with Kokkos kernels, it is possible to have nested
search contexts. When that happens, we want to make sure that we explore
all branches of all possible scenarios. For example, the "idk_jmm" test
in the https://github.com/khuck/apex-kokkos-tuning repository has one
context to choose between a team policy and a mdrange policy, and each
of those has tunable parameters. Because the outer context has only two
choices, it can converge very quickly unless we prevent it from
converging until the search has completed for both team and mdrange
policy implementations. That's what this change does - if we have nested
search contexts, the outer context(s) won't converge until the inner
context searches converge. I also reduced the max_iterations limit for
random, genetic_search and simulated_annealing to 500 from 1000.
  • Loading branch information
khuck committed Apr 24, 2024
1 parent 997ef7f commit 1447ff0
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 8 deletions.
61 changes: 60 additions & 1 deletion src/apex/apex_kokkos_tuning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,46 @@ class Variable {
}
};

/* this class helps us with nested contexts. We don't want to
* stop searching a higher level branch until all of the subtrees
* have also converged - the global optimum might be at the bottom
* of a branch with many options, but we won't get there if we
* have found local minimum in a simpler branch. */
class TreeNode {
static std::map<std::string,TreeNode*> allContexts;
public:
TreeNode(std::string _name) :
name(_name), _hasConverged(false) {}
static TreeNode* find(std::string name, TreeNode* parent) {
auto node = allContexts.find(name);
if (node == allContexts.end()) {
TreeNode *tmp = new TreeNode(name);
if (parent != nullptr) {
parent->children.insert(tmp);
}
allContexts.insert(std::pair<std::string, TreeNode*>(name,tmp));
return tmp;
}
return node->second;
}
std::string name;
std::set<TreeNode*> children;
bool _hasConverged;
bool haveChildren(void) {
return (children.size() > 0);
}
bool childrenConverged(void) {
// yes children? have they all converged?
for (auto child : children) {
if (!child->childrenConverged()) { return false; }
if (!child->_hasConverged) { return false; }
}
return true;
}
};

std::map<std::string,TreeNode*> TreeNode::allContexts;

class KokkosSession {
private:
// EXHAUSTIVE, RANDOM, NELDER_MEAD, PARALLEL_RANK_ORDER
Expand Down Expand Up @@ -258,6 +298,7 @@ class KokkosSession {
std::unordered_map<size_t, std::string> active_requests;
std::set<size_t> used_history;
std::unordered_map<size_t, uint64_t> context_starts;
std::stack<TreeNode*> contextStack;
void writeCache();
bool checkForCache();
void readCache();
Expand Down Expand Up @@ -989,6 +1030,9 @@ bool handle_start(const std::string & name, const size_t vars,
void handle_stop(const std::string & name) {
KokkosSession& session = KokkosSession::getSession();
auto search = session.requests.find(name);
/* We want to check if this search context has child contexts, and if so,
* have they converged? If not converged, we want to pass in false. */
bool childrenConverged = TreeNode::find(name, nullptr)->childrenConverged();
if(search == session.requests.end()) {
std::cerr << "ERROR: No data for " << name << std::endl;
} else {
Expand All @@ -998,8 +1042,11 @@ void handle_stop(const std::string & name) {
profile->calls >= session.window)) {
//std::cout << "Num calls: " << profile->calls << std::endl;
std::shared_ptr<apex_tuning_request> request = search->second;
/* If we are in a nested context, and this is the outermost
* context, we want to not allow it to converge until all of
* the inner contexts have also converged! */
// Evaluate the results
apex::custom_event(request->get_trigger(), NULL);
apex::custom_event(request->get_trigger(), &childrenConverged);
// Reset counter so each measurement is fresh.
apex::reset(name);
}
Expand Down Expand Up @@ -1104,6 +1151,12 @@ void kokkosp_request_values(
std::cout << std::string(getDepth(), ' ');
printContext(numContextVariables, name);
}
// push our context on the stack
if(session.contextStack.size() > 0) {
session.contextStack.push(TreeNode::find(name, session.contextStack.top()));
} else {
session.contextStack.push(TreeNode::find(name, nullptr));
}
// check if we have a cached result
bool success{false};
if (session.use_history) {
Expand All @@ -1118,6 +1171,9 @@ void kokkosp_request_values(
// throw away the time spent setting up tuning
//session.context_starts[contextId] = session.context_starts[contextId] + delta;
}
if(converged) {
TreeNode::find(name, nullptr)->_hasConverged = true;
}
//if (!converged) {
// add this name to our map of active contexts
session.active_requests.insert(
Expand Down Expand Up @@ -1178,6 +1234,9 @@ void kokkosp_end_context(const size_t contextId) {
session.active_requests.erase(contextId);
}
session.context_starts.erase(contextId);
if (session.contextStack.size() > 0) {
session.contextStack.pop();
}
}

} // extern "C"
Expand Down
36 changes: 32 additions & 4 deletions src/apex/apex_policies.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,14 @@ int apex_sa_policy(shared_ptr<apex_tuning_session> tuning_session,
APEX_UNUSED(context);
if (apex_final) return APEX_NOERROR; // we terminated
std::unique_lock<std::mutex> l{shutdown_mutex};
if (tuning_session->sa_session.converged()) {
/* If we are doing nested search contexts, allow us to keep searching
* on outer contexts until all inner contexts have converged! */
bool force{true};
if (context.data != nullptr) {
// the context data is a pointer to a boolean value
force = *((bool*)(context.data));
}
if (tuning_session->sa_session.converged() && force) {
if (!tuning_session->converged_message) {
tuning_session->converged_message = true;
cout << "APEX: Tuning has converged for session " << tuning_session->id
Expand Down Expand Up @@ -912,7 +919,14 @@ int apex_genetic_policy(shared_ptr<apex_tuning_session> tuning_session,
APEX_UNUSED(context);
if (apex_final) return APEX_NOERROR; // we terminated
std::unique_lock<std::mutex> l{shutdown_mutex};
if (tuning_session->genetic_session.converged()) {
/* If we are doing nested search contexts, allow us to keep searching
* on outer contexts until all inner contexts have converged! */
bool force{true};
if (context.data != nullptr) {
// the context data is a pointer to a boolean value
force = *((bool*)(context.data));
}
if (tuning_session->genetic_session.converged() && force) {
if (!tuning_session->converged_message) {
tuning_session->converged_message = true;
cout << "APEX: Tuning has converged for session " << tuning_session->id
Expand Down Expand Up @@ -941,7 +955,14 @@ int apex_exhaustive_policy(shared_ptr<apex_tuning_session> tuning_session,
APEX_UNUSED(context);
if (apex_final) return APEX_NOERROR; // we terminated
std::unique_lock<std::mutex> l{shutdown_mutex};
if (tuning_session->exhaustive_session.converged()) {
/* If we are doing nested search contexts, allow us to keep searching
* on outer contexts until all inner contexts have converged! */
bool force{true};
if (context.data != nullptr) {
// the context data is a pointer to a boolean value
force = *((bool*)(context.data));
}
if (tuning_session->exhaustive_session.converged() && force) {
if (!tuning_session->converged_message) {
tuning_session->converged_message = true;
cout << "APEX: Tuning has converged for session " << tuning_session->id
Expand Down Expand Up @@ -970,7 +991,14 @@ int apex_random_policy(shared_ptr<apex_tuning_session> tuning_session,
APEX_UNUSED(context);
if (apex_final) return APEX_NOERROR; // we terminated
std::unique_lock<std::mutex> l{shutdown_mutex};
if (tuning_session->random_session.converged()) {
/* If we are doing nested search contexts, allow us to keep searching
* on outer contexts until all inner contexts have converged! */
bool force{true};
if (context.data != nullptr) {
// the context data is a pointer to a boolean value
force = *((bool*)(context.data));
}
if (tuning_session->random_session.converged() && force) {
if (!tuning_session->converged_message) {
tuning_session->converged_message = true;
cout << "APEX: Tuning has converged for session " << tuning_session->id
Expand Down
1 change: 1 addition & 0 deletions src/apex/exhaustive.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ size_t Exhaustive::get_max_iterations() {
}
}
// want to see multiple values of each one
//std::cout << "Max iterations: " << max_iter << std::endl;
return max_iter;
}

Expand Down
2 changes: 1 addition & 1 deletion src/apex/genetic_search.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class GeneticSearch {
size_t kmax;
size_t k;
std::map<std::string, Variable> vars;
const size_t max_iterations{1000};
const size_t max_iterations{500};
const size_t min_iterations{16};
const size_t population_size{32};
const size_t crossover{16}; // half population
Expand Down
2 changes: 1 addition & 1 deletion src/apex/random.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class Random {
size_t kmax;
size_t k;
std::map<std::string, Variable> vars;
const size_t max_iterations{1000};
const size_t max_iterations{500};
const size_t min_iterations{100};
public:
void evaluate(double new_cost);
Expand Down
2 changes: 1 addition & 1 deletion src/apex/simulated_annealing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class SimulatedAnnealing {
size_t kmax;
size_t k;
std::map<std::string, Variable> vars;
const size_t max_iterations{1000};
const size_t max_iterations{500};
const size_t min_iterations{100};
public:
void evaluate(double new_cost);
Expand Down

0 comments on commit 1447ff0

Please sign in to comment.