Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Iterative Tasking Infrastructure #917

Merged
merged 5 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
- [[PR 885]](https://github.com/parthenon-hpc-lab/parthenon/pull/885) Expose PackDescriptor and use uids in SparsePacks

### Fixed (not changing behavior/API/variables/...)
- [[PR 917]](https://github.com/parthenon-hpc-lab/parthenon/pull/917) Update Iterative Tasking Infrastructure
- [[PR 890]](https://github.com/parthenon-hpc-lab/parthenon/pull/890) Fix bugs in sparse communication and prolongation

### Infrastructure (changes irrelevant to downstream codes)
Expand Down
4 changes: 2 additions & 2 deletions src/basic_types.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//========================================================================================
// (C) (or copyright) 2021-2022. Triad National Security, LLC. All rights reserved.
// (C) (or copyright) 2021-2023. Triad National Security, LLC. All rights reserved.
//
// This program was produced under U.S. Government contract 89233218CNA000001 for Los
// Alamos National Laboratory (LANL), which is operated by Triad National Security, LLC
Expand Down Expand Up @@ -36,7 +36,7 @@ using Real = double;
#endif
#endif

enum class TaskStatus { fail, complete, incomplete, iterate, skip };
enum class TaskStatus { fail, complete, incomplete, iterate, skip, waiting };
enum class AmrTag : int { derefine = -1, same = 0, refine = 1 };
enum class RefinementOp_t { Prolongation, Restriction, None };

Expand Down
191 changes: 131 additions & 60 deletions src/tasks/task_list.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//========================================================================================
// (C) (or copyright) 2021. Triad National Security, LLC. All rights reserved.
// (C) (or copyright) 2023. Triad National Security, LLC. All rights reserved.
//
// This program was produced under U.S. Government contract 89233218CNA000001 for Los
// Alamos National Laboratory (LANL), which is operated by Triad National Security, LLC
Expand All @@ -24,6 +24,7 @@
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

Expand All @@ -44,17 +45,18 @@ TaskID AddTaskHelper(TaskList *, Task);

class IterativeTasks {
public:
IterativeTasks() = default;
IterativeTasks(TaskList *tl, int key) : tl_(tl), key_(key) {
max_iterations_ = std::numeric_limits<int>::max();
}

// overload to add member functions of class T to task list
// NOTE: we must capture the object pointer
template <class T, class U, class... Args>
TaskID AddTask(TaskID const &dep, TaskStatus (T::*func)(Args...), U *obj,
Args &&...args) {
template <class T, class U, class... Args1, class... Args2>
TaskID AddTask(TaskID const &dep, TaskStatus (T::*func)(Args1...), U *obj,
Args2 &&...args) {
return this->AddTask_(TaskType::iterative, 1, dep, [=]() mutable -> TaskStatus {
return (obj->*func)(std::forward<Args>(args)...);
return (obj->*func)(std::forward<Args2>(args)...);
Comment on lines +55 to +59
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't immediately see why we need two separate variadic template parameters here. I am mostly just curious what issue the single list was causing. Maybe it was something to do with default arguments?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, I can't remember exactly what situation led to this change, but I can give you the basic idea. I wrote a task with what I thought was a reasonable set of arguments, but the compiler complained about type mismatches when I actually tried to add it to a task list. Don't hold me to this, but I think it was a const vs. non-const issue. I convinced myself at the time that there was no reason we should be preventing people from writing tasks the way I had written the problematic one, and this transformation generalized the interface to allow it.

});
}

Expand Down Expand Up @@ -94,6 +96,8 @@ class IterativeTasks {
// Returns the stored max_iterations_ value (the constructor initializes it to
// std::numeric_limits<int>::max()).
int GetMaxIterations() const { return max_iterations_; }
// Returns the current value of the iteration counter count_.
int GetIterationCount() const { return count_; }
// Advances the iteration counter count_ by one.
void IncrementCount() { count_++; }
// Sets the iteration counter count_ back to zero.
void ResetCount() { count_ = 0; }
// Debug helper: writes the pointer value held in tl_ (the TaskList* given to the
// constructor) to std::cout; std::endl also flushes the stream.
void PrintList() { std::cout << "tl_ = " << tl_ << std::endl; }

private:
template <class F, class... Args>
Expand Down Expand Up @@ -139,7 +143,8 @@ class TaskList {
for (auto &task : task_list_) {
if (task.GetID() == id) {
return (task.GetStatus() != TaskStatus::incomplete &&
task.GetStatus() != TaskStatus::skip);
task.GetStatus() != TaskStatus::skip &&
task.GetStatus() != TaskStatus::waiting);
}
}
return false;
Expand Down Expand Up @@ -184,6 +189,7 @@ class TaskList {
++task;
}
}
iter_tasks[key].ResetCount();
}
void ResetIteration(const int key) {
PARTHENON_REQUIRE_THROWS(key < iter_tasks.size(), "Invalid iteration key");
Expand All @@ -197,6 +203,11 @@ class TaskList {
PARTHENON_WARN("Iteration " + iter_labels[key] +
" reached maximum allowed cycles without convergence.");
}
for (auto &task : task_list_) {
if (task.GetKey() == key && task.GetType() == TaskType::completion_criteria) {
MarkTaskComplete(task.GetID());
}
}
ClearIteration(key);
return;
}
Expand All @@ -219,36 +230,47 @@ class TaskList {
}
}
}
void CompleteIfNeeded(const TaskID &id) {
bool CompleteIfNeeded(const TaskID &id) {
MarkTaskComplete(id);
auto task = task_list_.begin();
while (task != task_list_.end()) {
if (task->GetID() == id) {
if (task->GetType() == TaskType::completion_criteria) {
ClearIteration(task->GetKey());
return true;
} else if (task->GetType() == TaskType::single) {
task_list_.erase(task);
task = task_list_.erase(task);
} else {
task->SetStatus(TaskStatus::waiting);
}
break;
} else {
++task;
}
}
return false;
}
void DoAvailable() {
for (auto &task : task_list_) {
auto task = task_list_.begin();
while (task != task_list_.end()) {
// first skip task if it's complete. Possible for iterative tasks
if (task.GetStatus() != TaskStatus::incomplete) continue;
auto dep = task.GetDependency();
if (task->GetStatus() != TaskStatus::incomplete) {
++task;
continue;
}
auto dep = task->GetDependency();
if (CheckDependencies(dep)) {
task();
if (task.GetStatus() == TaskStatus::complete && !task.IsRegional()) {
MarkTaskComplete(task.GetID());
} else if (task.GetStatus() == TaskStatus::skip &&
task.GetType() == TaskType::completion_criteria) {
ResetIteration(task.GetKey());
(*task)();
if (task->GetStatus() == TaskStatus::complete && !task->IsRegional()) {
MarkTaskComplete(task->GetID());
} else if (task->GetStatus() == TaskStatus::skip &&
task->GetType() == TaskType::completion_criteria) {
ResetIteration(task->GetKey());
} else if (task->GetStatus() == TaskStatus::iterate && !task->IsRegional()) {
ResetIteration(task->GetKey());
}
}
++task;
}
ClearComplete();
}
Expand All @@ -274,21 +296,21 @@ class TaskList {
return valid;
}

TaskID AddTask(Task tsk) {
TaskID AddTask(Task &tsk) {
TaskID id(tasks_added_ + 1);
tsk.SetID(id);
task_list_.push_back(tsk);
task_list_.push_back(std::move(tsk));
tasks_added_++;
return id;
}

// overload to add member functions of class T to task list
// NOTE: we must capture the object pointer
template <class T, class U, class... Args>
TaskID AddTask(TaskID const &dep, TaskStatus (T::*func)(Args...), U *obj,
Args &&...args) {
template <class T, class U, class... Args1, class... Args2>
TaskID AddTask(TaskID const &dep, TaskStatus (T::*func)(Args1...), U *obj,
Args2 &&...args) {
return this->AddTask(dep, [=]() mutable -> TaskStatus {
return (obj->*func)(std::forward<Args>(args)...);
return (obj->*func)(std::forward<Args2>(args)...);
});
}

Expand All @@ -305,18 +327,18 @@ class TaskList {

IterativeTasks &AddIteration(const std::string &label) {
int key = iter_tasks.size();
iter_tasks.push_back(IterativeTasks(this, key));
iter_labels.push_back(label);
return iter_tasks.back();
iter_tasks[key] = IterativeTasks(this, key);
iter_labels[key] = label;
return iter_tasks[key];
}

void Print() {
int i = 0;
std::cout << "TaskList::Print():" << std::endl;
for (auto &t : task_list_) {
std::cout << " " << i << " " << t.GetID().to_string() << " "
<< t.GetDependency().to_string() << " "
<< (t.GetStatus() == TaskStatus::incomplete)
<< t.GetDependency().to_string() << " " << tasks_completed_.to_string()
<< " " << (t.GetStatus() == TaskStatus::incomplete)
<< (t.GetStatus() == TaskStatus::complete)
<< (t.GetStatus() == TaskStatus::skip)
<< (t.GetStatus() == TaskStatus::iterate)
Expand All @@ -327,8 +349,8 @@ class TaskList {
}

protected:
std::vector<IterativeTasks> iter_tasks;
std::vector<std::string> iter_labels;
std::map<int, IterativeTasks> iter_tasks;
std::map<int, std::string> iter_labels;
std::list<Task> task_list_;
int tasks_added_ = 0;
TaskID tasks_completed_;
Expand All @@ -340,41 +362,92 @@ namespace task_list_impl {
// Thin free-function wrapper: hands tsk to tl->AddTask and returns the TaskID it yields.
inline TaskID AddTaskHelper(TaskList *tl, Task tsk) { return tl->AddTask(tsk); }
} // namespace task_list_impl

// Produces a sequence of string labels that share a common prefix: each call to
// ID() yields "<base><n>", with n starting at 0 and growing by one per call.
class RegionCounter {
 public:
  // base: prefix copied into every label this counter hands out.
  explicit RegionCounter(const std::string &base) : base_(base), cnt_(0) {}

  // Returns the label for the current counter value, then advances the counter.
  std::string ID() {
    const int current = cnt_;
    ++cnt_;
    std::string label(base_);
    label += std::to_string(current);
    return label;
  }

 private:
  const std::string base_;  // fixed prefix
  int cnt_;                 // number of labels issued so far
};

class TaskRegion {
public:
explicit TaskRegion(const int size) : lists(size) {}
void AddRegionalDependencies(const int reg_dep_id, const int list_index, TaskID id) {
auto task_pair = std::make_pair(list_index, id);
id_for_reg[reg_dep_id].push_back(task_pair);
lists[list_index].MarkRegional(id);
all_done[reg_dep_id].val = 0;
// Integer-keyed convenience overload: converts reg_dep_id to its decimal string
// and forwards to the string-keyed AddRegionalDependencies.
// NOTE(review): an integer id N therefore shares its key with the string label "N".
void AddRegionalDependencies(const int reg_dep_id, const int list_index,
const TaskID &id) {
AddRegionalDependencies(std::to_string(reg_dep_id), list_index, id);
}
// Registers task `id` of list `list_index` under the dependency label reg_dep_id
// (via AddDependencies) and records the label as non-global (global[...] = false).
// NOTE(review): the same `global` map is written by AddGlobalDependencies, so the
// most recent call for a given label decides its flag.
void AddRegionalDependencies(const std::string &reg_dep_id, const int list_index,
const TaskID &id) {
AddDependencies(reg_dep_id, list_index, id);
global[reg_dep_id] = false;
}
// Integer-keyed convenience overload: converts reg_dep_id to its decimal string
// and forwards to the string-keyed AddGlobalDependencies.
// NOTE(review): an integer id N therefore shares its key with the string label "N".
void AddGlobalDependencies(const int reg_dep_id, const int list_index,
const TaskID &id) {
AddGlobalDependencies(std::to_string(reg_dep_id), list_index, id);
}
// Registers task `id` of list `list_index` under the dependency label reg_dep_id
// (via AddDependencies) and records the label as global (global[...] = true).
// NOTE(review): the same `global` map is written by AddRegionalDependencies, so the
// most recent call for a given label decides its flag.
void AddGlobalDependencies(const std::string &reg_dep_id, const int list_index,
const TaskID &id) {
AddDependencies(reg_dep_id, list_index, id);
global[reg_dep_id] = true;
}

// Returns a mutable reference to the i-th task list of this region; the index is
// passed straight to lists[i] with no range check here.
TaskList &operator[](int i) { return lists[i]; }

// Number of task lists held by this region (lists.size() converted to int).
int size() const { return lists.size(); }

// Performs one pass over the region: every task list that does not yet report
// IsComplete() gets a DoAvailable() call, in list order. The return value is
// whatever CheckAndUpdate() reports after that pass.
bool Execute() {
  for (auto &task_list : lists) {
    if (task_list.IsComplete()) continue;
    task_list.DoAvailable();
  }
  const bool region_done = CheckAndUpdate();
  return region_done;
}

bool CheckAndUpdate() {
for (auto &reg_dep : id_for_reg) {
auto reg_id = reg_dep.first;
auto it = id_for_reg.begin();
while (it != id_for_reg.end()) {
auto &reg_id = it->first;
bool check = false;
if (HasRun(reg_id) && !all_done[reg_id].active) {
all_done[reg_id].val = IsComplete(reg_id);
all_done[reg_id].StartReduce(MPI_MIN);
if (global[reg_id]) {
all_done[reg_id].StartReduce(MPI_MIN);
} else {
check = true;
}
}
if (all_done[reg_id].active) {
if (global[reg_id] && all_done[reg_id].active) {
auto status = all_done[reg_id].CheckReduce();
if (status == TaskStatus::complete) {
if (all_done[reg_id].val) {
for (auto &lst : reg_dep.second) {
lists[lst.first].CompleteIfNeeded(lst.second);
}
check = true;
}
}
if (check) {
if (all_done[reg_id].val) {
bool clear = false;
for (auto &lst : it->second) {
clear = lists[lst.first].CompleteIfNeeded(lst.second);
}
if (clear) {
all_done.erase(reg_id);
global.erase(reg_id);
it = id_for_reg.erase(it);
} else {
for (auto &lst : reg_dep.second) {
lists[lst.first].ResetIfNeeded(lst.second);
}
all_done[reg_id].val = 0;
++it;
}
} else {
for (auto &lst : it->second) {
lists[lst.first].ResetIfNeeded(lst.second);
}
all_done[reg_id].val = 0;
++it;
}
} else {
++it;
}
}
int complete_cnt = 0;
Expand All @@ -393,7 +466,12 @@ class TaskRegion {
}

private:
bool HasRun(const int reg_id) {
// Shared bookkeeping for the regional/global registration overloads:
//  - stores tid as the task of list list_id under dependency label `label`,
//  - marks tid as regional in that list,
//  - zeroes the reduction value associated with the label.
void AddDependencies(const std::string &label, const int list_id, const TaskID &tid) {
  auto &tasks_by_list = id_for_reg[label];
  tasks_by_list[list_id] = tid;

  auto &owning_list = lists[list_id];
  owning_list.MarkRegional(tid);

  auto &reduction = all_done[label];
  reduction.val = 0;
}
bool HasRun(const std::string &reg_id) {
auto &lvec = id_for_reg[reg_id];
int n_to_run = lvec.size();
int n_ran = 0;
Expand All @@ -406,7 +484,7 @@ class TaskRegion {
}
return n_ran == n_to_run;
}
bool IsComplete(const int reg_id) {
bool IsComplete(const std::string &reg_id) {
auto &lvec = id_for_reg[reg_id];
int n_to_finish = lvec.size();
int n_finished = 0;
Expand All @@ -420,10 +498,10 @@ class TaskRegion {
return n_finished == n_to_finish;
}

// id_for_reg[region_id] = std::pair<>(task_list_index, task_id_of_regional_task)
std::map<int, std::vector<std::pair<int, TaskID>>> id_for_reg;
std::unordered_map<std::string, std::map<int, TaskID>> id_for_reg;
std::vector<TaskList> lists;
std::map<int, AllReduce<int>> all_done;
std::unordered_map<std::string, AllReduce<int>> all_done;
std::unordered_map<std::string, bool> global;
};

class TaskCollection {
Expand All @@ -437,15 +515,8 @@ class TaskCollection {
assert(Validate());
for (auto &region : regions) {
bool complete = false;
auto num_lists = region.size();
while (!complete) {
// TODO(pgrete): need to let Kokkos::PartitionManager handle this
for (auto i = 0; i < num_lists; ++i) {
if (!region[i].IsComplete()) {
region[i].DoAvailable();
}
}
complete = region.CheckAndUpdate();
complete = region.Execute();
}
}
return TaskListStatus::complete;
Expand Down