Skip to content

Commit

Permalink
Update Iterative Tasking Infrastructure (#917)
Browse files Browse the repository at this point in the history
* Add lroberts36/merge-sparse-with-jdolence-sparse's changes to iterative tasking

* Update copyrights

* Add unordered map include

* Update changelog

* Fix function name typos in task_list.hpp

---------

Co-authored-by: Patrick Mullen <[email protected]>
Co-authored-by: Patrick Mullen <[email protected]>
  • Loading branch information
3 people authored Aug 11, 2023
1 parent 94caccd commit ab9d167
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 62 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
- [[PR 885]](https://github.com/parthenon-hpc-lab/parthenon/pull/885) Expose PackDescriptor and use uids in SparsePacks

### Fixed (not changing behavior/API/variables/...)
- [[PR 917]](https://github.com/parthenon-hpc-lab/parthenon/pull/917) Update Iterative Tasking Infrastructure
- [[PR 890]](https://github.com/parthenon-hpc-lab/parthenon/pull/890) Fix bugs in sparse communication and prolongation

### Infrastructure (changes irrelevant to downstream codes)
Expand Down
4 changes: 2 additions & 2 deletions src/basic_types.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//========================================================================================
// (C) (or copyright) 2021-2022. Triad National Security, LLC. All rights reserved.
// (C) (or copyright) 2021-2023. Triad National Security, LLC. All rights reserved.
//
// This program was produced under U.S. Government contract 89233218CNA000001 for Los
// Alamos National Laboratory (LANL), which is operated by Triad National Security, LLC
Expand Down Expand Up @@ -36,7 +36,7 @@ using Real = double;
#endif
#endif

enum class TaskStatus { fail, complete, incomplete, iterate, skip };
enum class TaskStatus { fail, complete, incomplete, iterate, skip, waiting };
enum class AmrTag : int { derefine = -1, same = 0, refine = 1 };
enum class RefinementOp_t { Prolongation, Restriction, None };

Expand Down
191 changes: 131 additions & 60 deletions src/tasks/task_list.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//========================================================================================
// (C) (or copyright) 2021. Triad National Security, LLC. All rights reserved.
// (C) (or copyright) 2023. Triad National Security, LLC. All rights reserved.
//
// This program was produced under U.S. Government contract 89233218CNA000001 for Los
// Alamos National Laboratory (LANL), which is operated by Triad National Security, LLC
Expand All @@ -24,6 +24,7 @@
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

Expand All @@ -44,17 +45,18 @@ TaskID AddTaskHelper(TaskList *, Task);

class IterativeTasks {
public:
IterativeTasks() = default;
// Binds this group of iterative tasks to the task list `tl` and the integer `key`.
// The iteration cap defaults to the largest representable int.
IterativeTasks(TaskList *tl, int key) : tl_(tl), key_(key) {
max_iterations_ = std::numeric_limits<int>::max();
}

// overload to add member functions of class T to task list
// NOTE: we must capture the object pointer
template <class T, class U, class... Args>
TaskID AddTask(TaskID const &dep, TaskStatus (T::*func)(Args...), U *obj,
Args &&...args) {
template <class T, class U, class... Args1, class... Args2>
TaskID AddTask(TaskID const &dep, TaskStatus (T::*func)(Args1...), U *obj,
Args2 &&...args) {
return this->AddTask_(TaskType::iterative, 1, dep, [=]() mutable -> TaskStatus {
return (obj->*func)(std::forward<Args>(args)...);
return (obj->*func)(std::forward<Args2>(args)...);
});
}

Expand Down Expand Up @@ -94,6 +96,8 @@ class IterativeTasks {
// Returns the iteration cap stored in max_iterations_.
int GetMaxIterations() const { return max_iterations_; }
// Returns the current value of the iteration counter count_.
int GetIterationCount() const { return count_; }
// Advances the iteration counter by one.
void IncrementCount() { count_++; }
// Sets the iteration counter back to zero.
void ResetCount() { count_ = 0; }
// Debug helper: writes the pointer held in tl_ to stdout (flushes via std::endl).
void PrintList() { std::cout << "tl_ = " << tl_ << std::endl; }

private:
template <class F, class... Args>
Expand Down Expand Up @@ -139,7 +143,8 @@ class TaskList {
for (auto &task : task_list_) {
if (task.GetID() == id) {
return (task.GetStatus() != TaskStatus::incomplete &&
task.GetStatus() != TaskStatus::skip);
task.GetStatus() != TaskStatus::skip &&
task.GetStatus() != TaskStatus::waiting);
}
}
return false;
Expand Down Expand Up @@ -184,6 +189,7 @@ class TaskList {
++task;
}
}
iter_tasks[key].ResetCount();
}
void ResetIteration(const int key) {
PARTHENON_REQUIRE_THROWS(key < iter_tasks.size(), "Invalid iteration key");
Expand All @@ -197,6 +203,11 @@ class TaskList {
PARTHENON_WARN("Iteration " + iter_labels[key] +
" reached maximum allowed cycles without convergence.");
}
for (auto &task : task_list_) {
if (task.GetKey() == key && task.GetType() == TaskType::completion_criteria) {
MarkTaskComplete(task.GetID());
}
}
ClearIteration(key);
return;
}
Expand All @@ -219,36 +230,47 @@ class TaskList {
}
}
}
void CompleteIfNeeded(const TaskID &id) {
bool CompleteIfNeeded(const TaskID &id) {
MarkTaskComplete(id);
auto task = task_list_.begin();
while (task != task_list_.end()) {
if (task->GetID() == id) {
if (task->GetType() == TaskType::completion_criteria) {
ClearIteration(task->GetKey());
return true;
} else if (task->GetType() == TaskType::single) {
task_list_.erase(task);
task = task_list_.erase(task);
} else {
task->SetStatus(TaskStatus::waiting);
}
break;
} else {
++task;
}
}
return false;
}
void DoAvailable() {
for (auto &task : task_list_) {
auto task = task_list_.begin();
while (task != task_list_.end()) {
// first skip task if it's complete. Possible for iterative tasks
if (task.GetStatus() != TaskStatus::incomplete) continue;
auto dep = task.GetDependency();
if (task->GetStatus() != TaskStatus::incomplete) {
++task;
continue;
}
auto dep = task->GetDependency();
if (CheckDependencies(dep)) {
task();
if (task.GetStatus() == TaskStatus::complete && !task.IsRegional()) {
MarkTaskComplete(task.GetID());
} else if (task.GetStatus() == TaskStatus::skip &&
task.GetType() == TaskType::completion_criteria) {
ResetIteration(task.GetKey());
(*task)();
if (task->GetStatus() == TaskStatus::complete && !task->IsRegional()) {
MarkTaskComplete(task->GetID());
} else if (task->GetStatus() == TaskStatus::skip &&
task->GetType() == TaskType::completion_criteria) {
ResetIteration(task->GetKey());
} else if (task->GetStatus() == TaskStatus::iterate && !task->IsRegional()) {
ResetIteration(task->GetKey());
}
}
++task;
}
ClearComplete();
}
Expand All @@ -274,21 +296,21 @@ class TaskList {
return valid;
}

TaskID AddTask(Task tsk) {
TaskID AddTask(Task &tsk) {
TaskID id(tasks_added_ + 1);
tsk.SetID(id);
task_list_.push_back(tsk);
task_list_.push_back(std::move(tsk));
tasks_added_++;
return id;
}

// overload to add member functions of class T to task list
// NOTE: we must capture the object pointer
template <class T, class U, class... Args>
TaskID AddTask(TaskID const &dep, TaskStatus (T::*func)(Args...), U *obj,
Args &&...args) {
template <class T, class U, class... Args1, class... Args2>
TaskID AddTask(TaskID const &dep, TaskStatus (T::*func)(Args1...), U *obj,
Args2 &&...args) {
return this->AddTask(dep, [=]() mutable -> TaskStatus {
return (obj->*func)(std::forward<Args>(args)...);
return (obj->*func)(std::forward<Args2>(args)...);
});
}

Expand All @@ -305,18 +327,18 @@ class TaskList {

IterativeTasks &AddIteration(const std::string &label) {
int key = iter_tasks.size();
iter_tasks.push_back(IterativeTasks(this, key));
iter_labels.push_back(label);
return iter_tasks.back();
iter_tasks[key] = IterativeTasks(this, key);
iter_labels[key] = label;
return iter_tasks[key];
}

void Print() {
int i = 0;
std::cout << "TaskList::Print():" << std::endl;
for (auto &t : task_list_) {
std::cout << " " << i << " " << t.GetID().to_string() << " "
<< t.GetDependency().to_string() << " "
<< (t.GetStatus() == TaskStatus::incomplete)
<< t.GetDependency().to_string() << " " << tasks_completed_.to_string()
<< " " << (t.GetStatus() == TaskStatus::incomplete)
<< (t.GetStatus() == TaskStatus::complete)
<< (t.GetStatus() == TaskStatus::skip)
<< (t.GetStatus() == TaskStatus::iterate)
Expand All @@ -327,8 +349,8 @@ class TaskList {
}

protected:
std::vector<IterativeTasks> iter_tasks;
std::vector<std::string> iter_labels;
std::map<int, IterativeTasks> iter_tasks;
std::map<int, std::string> iter_labels;
std::list<Task> task_list_;
int tasks_added_ = 0;
TaskID tasks_completed_;
Expand All @@ -340,41 +362,92 @@ namespace task_list_impl {
// Free-function shim: adds `tsk` to the task list `tl` through TaskList::AddTask and
// returns the TaskID that call produces.
inline TaskID AddTaskHelper(TaskList *tl, Task tsk) { return tl->AddTask(tsk); }
} // namespace task_list_impl

// Produces a running sequence of string labels that share a common prefix:
// "<base>0", "<base>1", "<base>2", ...
class RegionCounter {
 public:
  // `base` is the prefix prepended to every generated label; numbering starts at 0.
  explicit RegionCounter(const std::string &base) : base_{base}, cnt_{0} {}

  // Returns the prefix followed by the current counter value, then advances the counter
  // so the next call yields the following number.
  std::string ID() {
    std::string label{base_};
    label += std::to_string(cnt_);
    ++cnt_;
    return label;
  }

 private:
  const std::string base_;
  int cnt_;
};

class TaskRegion {
public:
explicit TaskRegion(const int size) : lists(size) {}
void AddRegionalDependencies(const int reg_dep_id, const int list_index, TaskID id) {
auto task_pair = std::make_pair(list_index, id);
id_for_reg[reg_dep_id].push_back(task_pair);
lists[list_index].MarkRegional(id);
all_done[reg_dep_id].val = 0;
// Integer-id convenience overload: converts reg_dep_id to its decimal string form and
// forwards to the string-keyed overload below.
void AddRegionalDependencies(const int reg_dep_id, const int list_index,
const TaskID &id) {
AddRegionalDependencies(std::to_string(reg_dep_id), list_index, id);
}
// Registers task `id` of list `list_index` under the dependency label `reg_dep_id` and
// flags that label as non-global (global[reg_dep_id] = false), which makes
// CheckAndUpdate act on the locally computed completion value without starting a reduction.
void AddRegionalDependencies(const std::string &reg_dep_id, const int list_index,
const TaskID &id) {
AddDependencies(reg_dep_id, list_index, id);
global[reg_dep_id] = false;
}
// Integer-id convenience overload: converts reg_dep_id to its decimal string form and
// forwards to the string-keyed overload below.
void AddGlobalDependencies(const int reg_dep_id, const int list_index,
const TaskID &id) {
AddGlobalDependencies(std::to_string(reg_dep_id), list_index, id);
}
// Registers task `id` of list `list_index` under the dependency label `reg_dep_id` and
// flags that label as global (global[reg_dep_id] = true), which makes CheckAndUpdate
// start an MPI_MIN reduction on all_done[reg_dep_id] before acting on completion.
void AddGlobalDependencies(const std::string &reg_dep_id, const int list_index,
const TaskID &id) {
AddDependencies(reg_dep_id, list_index, id);
global[reg_dep_id] = true;
}

// Returns a mutable reference to the i-th task list of this region (no bounds check).
TaskList &operator[](int i) { return lists[i]; }

// Number of task lists held by this region.
int size() const { return lists.size(); }

// Performs one sweep over the region: every task list that is not yet complete gets a
// chance to run its currently available tasks, after which the regional dependencies
// are checked and updated. Returns whatever CheckAndUpdate() reports.
bool Execute() {
  for (auto &task_list : lists) {
    if (task_list.IsComplete()) continue;
    task_list.DoAvailable();
  }
  const bool region_done = CheckAndUpdate();
  return region_done;
}

bool CheckAndUpdate() {
for (auto &reg_dep : id_for_reg) {
auto reg_id = reg_dep.first;
auto it = id_for_reg.begin();
while (it != id_for_reg.end()) {
auto &reg_id = it->first;
bool check = false;
if (HasRun(reg_id) && !all_done[reg_id].active) {
all_done[reg_id].val = IsComplete(reg_id);
all_done[reg_id].StartReduce(MPI_MIN);
if (global[reg_id]) {
all_done[reg_id].StartReduce(MPI_MIN);
} else {
check = true;
}
}
if (all_done[reg_id].active) {
if (global[reg_id] && all_done[reg_id].active) {
auto status = all_done[reg_id].CheckReduce();
if (status == TaskStatus::complete) {
if (all_done[reg_id].val) {
for (auto &lst : reg_dep.second) {
lists[lst.first].CompleteIfNeeded(lst.second);
}
check = true;
}
}
if (check) {
if (all_done[reg_id].val) {
bool clear = false;
for (auto &lst : it->second) {
clear = lists[lst.first].CompleteIfNeeded(lst.second);
}
if (clear) {
all_done.erase(reg_id);
global.erase(reg_id);
it = id_for_reg.erase(it);
} else {
for (auto &lst : reg_dep.second) {
lists[lst.first].ResetIfNeeded(lst.second);
}
all_done[reg_id].val = 0;
++it;
}
} else {
for (auto &lst : it->second) {
lists[lst.first].ResetIfNeeded(lst.second);
}
all_done[reg_id].val = 0;
++it;
}
} else {
++it;
}
}
int complete_cnt = 0;
Expand All @@ -393,7 +466,12 @@ class TaskRegion {
}

private:
bool HasRun(const int reg_id) {
// Shared bookkeeping behind AddRegionalDependencies / AddGlobalDependencies: records `tid`
// as the task of list `list_id` under `label`, marks that task as regional in its own
// list, and zeroes the value stored in all_done[label].
void AddDependencies(const std::string &label, const int list_id, const TaskID &tid) {
id_for_reg[label][list_id] = tid;
lists[list_id].MarkRegional(tid);
all_done[label].val = 0;
}
bool HasRun(const std::string &reg_id) {
auto &lvec = id_for_reg[reg_id];
int n_to_run = lvec.size();
int n_ran = 0;
Expand All @@ -406,7 +484,7 @@ class TaskRegion {
}
return n_ran == n_to_run;
}
bool IsComplete(const int reg_id) {
bool IsComplete(const std::string &reg_id) {
auto &lvec = id_for_reg[reg_id];
int n_to_finish = lvec.size();
int n_finished = 0;
Expand All @@ -420,10 +498,10 @@ class TaskRegion {
return n_finished == n_to_finish;
}

// id_for_reg[region_id] = std::pair<>(task_list_index, task_id_of_regional_task)
std::map<int, std::vector<std::pair<int, TaskID>>> id_for_reg;
std::unordered_map<std::string, std::map<int, TaskID>> id_for_reg;
std::vector<TaskList> lists;
std::map<int, AllReduce<int>> all_done;
std::unordered_map<std::string, AllReduce<int>> all_done;
std::unordered_map<std::string, bool> global;
};

class TaskCollection {
Expand All @@ -437,15 +515,8 @@ class TaskCollection {
assert(Validate());
for (auto &region : regions) {
bool complete = false;
auto num_lists = region.size();
while (!complete) {
// TODO(pgrete): need to let Kokkos::PartitionManager handle this
for (auto i = 0; i < num_lists; ++i) {
if (!region[i].IsComplete()) {
region[i].DoAvailable();
}
}
complete = region.CheckAndUpdate();
complete = region.Execute();
}
}
return TaskListStatus::complete;
Expand Down

0 comments on commit ab9d167

Please sign in to comment.