Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_BALANCED_BINARY_SP_DECOMPOSITION_H
#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_BALANCED_BINARY_SP_DECOMPOSITION_H

#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h"
#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h"
#include <optional>

namespace FlexFlow {

std::optional<PCGBinarySPDecomposition>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ std::unordered_set<BinaryTreePath>
find_paths_to_leaf(PCGBinarySPDecomposition const &,
parallel_layer_guid_t const &);

PCGBinarySPDecomposition pcg_binary_sp_decomposition_from_binary_sp_tree(
BinarySPDecompositionTree const &spd_tree);

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#include "compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h"
#include "compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.h"
#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/balanced_binary_sp_tree_from_nary.h"
#include <optional>

namespace FlexFlow {

std::optional<PCGBinarySPDecomposition>
get_pcg_balanced_binary_sp_decomposition(
ParallelComputationGraph const &pcg) {
std::optional<SeriesParallelDecomposition> spd =
get_pcg_series_parallel_decomposition(pcg);
if (!spd) {
return std::nullopt;
}
return pcg_binary_sp_decomposition_from_binary_sp_tree(
balanced_binary_sp_tree_from_nary(spd.value()));
}

} // namespace FlexFlow
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h"
#include "compiler/series_parallel/pcg/pcg_binary_parallel_split.h"
#include "compiler/series_parallel/pcg/pcg_binary_series_split.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h"
#include "utils/graph/series_parallel/get_series_parallel_decomposition.h"
#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h"
#include "utils/overload.h"

namespace FlexFlow {
Expand Down Expand Up @@ -67,10 +70,7 @@ BinarySPDecompositionTree
},
[](PCGBinaryParallelSplit const &parallel) -> BinarySPDecompositionTree {
return BinarySPDecompositionTree{
BinaryParallelSplit{
binary_sp_tree_from_pcg_sp_tree(parallel.get_left_child()),
binary_sp_tree_from_pcg_sp_tree(parallel.get_right_child()),
},
binary_parallel_split_from_pcg_parallel_split(parallel),
};
},
[](parallel_layer_guid_t const &layer) -> BinarySPDecompositionTree {
Expand All @@ -81,9 +81,35 @@ BinarySPDecompositionTree
});
}

std::optional<PCGBinarySPDecomposition>
get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &) {
NOT_IMPLEMENTED();
PCGBinarySPDecomposition pcg_binary_sp_decomposition_from_binary_sp_tree(
BinarySPDecompositionTree const &spd_tree) {
return spd_tree.visit<PCGBinarySPDecomposition>(overload{
[](BinarySeriesSplit const &series) -> PCGBinarySPDecomposition {
return PCGBinarySPDecomposition{
PCGBinarySeriesSplit{
pcg_binary_sp_decomposition_from_binary_sp_tree(
series.get_left_child()),
pcg_binary_sp_decomposition_from_binary_sp_tree(
series.get_right_child()),
},
};
},
[](BinaryParallelSplit const &parallel) -> PCGBinarySPDecomposition {
return PCGBinarySPDecomposition{
PCGBinaryParallelSplit{
pcg_binary_sp_decomposition_from_binary_sp_tree(
parallel.get_left_child()),
pcg_binary_sp_decomposition_from_binary_sp_tree(
parallel.get_right_child()),
},
};
},
[](Node const &node) -> PCGBinarySPDecomposition {
return PCGBinarySPDecomposition{
parallel_layer_guid_t{node},
};
},
});
}

std::unordered_multiset<parallel_layer_guid_t>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h"
#include "test/utils/doctest/fmt/unordered_multiset.h"
#include "test/utils/rapidcheck.h"
#include <doctest/doctest.h>

using namespace ::FlexFlow;

TEST_SUITE(FF_TEST_SUITE) {
TEST_CASE("pcg_binary_sp_decomposition_from_binary_sp_tree") {
Node n1 = Node{1};
Node n2 = Node{2};
Node n3 = Node{3};

auto make_binary_series_split = [](BinarySPDecompositionTree const &lhs,
BinarySPDecompositionTree const &rhs) {
return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}};
};

auto make_binary_parallel_split = [](BinarySPDecompositionTree const &lhs,
BinarySPDecompositionTree const &rhs) {
return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}};
};

auto make_binary_leaf = [](Node const &n) {
return BinarySPDecompositionTree{n};
};

auto make_pcg_series_split = [](PCGBinarySPDecomposition const &lhs,
PCGBinarySPDecomposition const &rhs) {
return PCGBinarySPDecomposition{PCGBinarySeriesSplit{lhs, rhs}};
};

auto make_pcg_parallel_split = [](PCGBinarySPDecomposition const &lhs,
PCGBinarySPDecomposition const &rhs) {
return PCGBinarySPDecomposition{PCGBinaryParallelSplit{lhs, rhs}};
};

auto make_pcg_leaf = [](Node const &n) {
return PCGBinarySPDecomposition{parallel_layer_guid_t{n}};
};

SUBCASE("single node") {
BinarySPDecompositionTree input = make_binary_leaf(n1);

PCGBinarySPDecomposition result =
pcg_binary_sp_decomposition_from_binary_sp_tree(input);

PCGBinarySPDecomposition expected = make_pcg_leaf(n1);

CHECK(result == expected);
}

SUBCASE("series split") {
BinarySPDecompositionTree input =
make_binary_series_split(make_binary_leaf(n1), make_binary_leaf(n2));

PCGBinarySPDecomposition result =
pcg_binary_sp_decomposition_from_binary_sp_tree(input);

PCGBinarySPDecomposition expected =
make_pcg_series_split(make_pcg_leaf(n1), make_pcg_leaf(n2));

CHECK(result == expected);
}

SUBCASE("parallel split") {
BinarySPDecompositionTree input = make_binary_parallel_split(
make_binary_leaf(n1), make_binary_leaf(n2));

PCGBinarySPDecomposition result =
pcg_binary_sp_decomposition_from_binary_sp_tree(input);

PCGBinarySPDecomposition expected =
make_pcg_parallel_split(make_pcg_leaf(n1), make_pcg_leaf(n2));

CHECK(result == expected);
}

SUBCASE("bijectiveness") {
BinarySPDecompositionTree original = make_binary_parallel_split(
make_binary_series_split(make_binary_leaf(n1), make_binary_leaf(n2)),
make_binary_leaf(n3));

PCGBinarySPDecomposition pcg_tree =
pcg_binary_sp_decomposition_from_binary_sp_tree(original);
BinarySPDecompositionTree converted =
binary_sp_tree_from_pcg_sp_tree(pcg_tree);

CHECK(original == converted);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BALANCED_BINARY_SP_TREE_FROM_NARY_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BALANCED_BINARY_SP_TREE_FROM_NARY_H

#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h"
#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h"

namespace FlexFlow {

BinarySPDecompositionTree
balanced_binary_sp_tree_from_nary(SeriesParallelDecomposition const &nary);

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ std::unordered_multiset<Node> get_leaves(BinarySPDecompositionTree const &);

SPDecompositionTreeNodeType get_node_type(BinarySPDecompositionTree const &);

int get_tree_height(BinarySPDecompositionTree const &);

std::unordered_multiset<Node> get_nodes(BinarySPDecompositionTree const &tree);

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#include "utils/containers/foldl1.h"
#include "utils/containers/get_only.h"
#include "utils/containers/subvec.h"
#include "utils/containers/transform.h"
#include "utils/containers/unordered_multiset_of.h"
#include "utils/containers/vector_of.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h"
#include "utils/graph/series_parallel/parallel_split.dtg.h"
#include "utils/graph/series_parallel/series_split.dtg.h"
#include "utils/overload.h"
#include <functional>
#include <optional>

namespace FlexFlow {

BinarySPDecompositionTree
balanced_binary_sp_tree_from_nary(SeriesParallelDecomposition const &nary) {
std::function<BinarySPDecompositionTree(
std::variant<ParallelSplit, Node> const &)>
from_series_child;
std::function<BinarySPDecompositionTree(
std::variant<SeriesSplit, Node> const &)>
from_parallel_child;

std::function<BinarySPDecompositionTree(ParallelSplit const &p)>
from_parallel;
std::function<BinarySPDecompositionTree(SeriesSplit const &p)> from_series;

auto from_node = [](Node const &n) -> BinarySPDecompositionTree {
return BinarySPDecompositionTree{n};
};

from_parallel = [&](ParallelSplit const &s) -> BinarySPDecompositionTree {
auto children = vector_of(s.get_children());
if (children.size() == 1) {
return from_parallel_child(get_only(children));
} else if (children.size() == 2) {
return BinarySPDecompositionTree{BinaryParallelSplit{
from_parallel_child(children[0]), from_parallel_child(children[1])}};
}
auto s1 = unordered_multiset_of(
subvec(children, std::nullopt, children.size() / 2));
auto s2 = unordered_multiset_of(
subvec(children, children.size() / 2, std::nullopt));
return BinarySPDecompositionTree{BinaryParallelSplit{
from_parallel(ParallelSplit{s1}), from_parallel(ParallelSplit{s2})}};
};

from_series = [&](SeriesSplit const &s) -> BinarySPDecompositionTree {
auto children = vector_of(s.children);
if (children.size() == 1) {
return from_series_child(get_only(children));
} else if (children.size() == 2) {
return BinarySPDecompositionTree{BinarySeriesSplit{
from_series_child(children[0]), from_series_child(children[1])}};
}
auto s1 = subvec(children, std::nullopt, children.size() / 2);
auto s2 = subvec(children, children.size() / 2, std::nullopt);
return BinarySPDecompositionTree{BinarySeriesSplit{
from_series(SeriesSplit{s1}), from_series(SeriesSplit{s2})}};
};

from_parallel_child = [&](std::variant<SeriesSplit, Node> const &v)
-> BinarySPDecompositionTree {
return std::visit(overload{
[&](Node const &n) { return from_node(n); },
[&](SeriesSplit const &s) { return from_series(s); },
},
v);
};

from_series_child = [&](std::variant<ParallelSplit, Node> const &v)
-> BinarySPDecompositionTree {
return std::visit(
overload{
[&](Node const &n) { return from_node(n); },
[&](ParallelSplit const &p) { return from_parallel(p); },
},
v);
};

return nary.visit<BinarySPDecompositionTree>(overload{
[&](Node const &n) { return from_node(n); },
[&](SeriesSplit const &s) { return from_series(s); },
[&](ParallelSplit const &p) { return from_parallel(p); },
});
}

} // namespace FlexFlow
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h"
#include "utils/containers/multiset_union.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h"

namespace FlexFlow {

GenericBinarySPDecompositionTreeImplementation<BinarySPDecompositionTree,
Expand Down Expand Up @@ -82,4 +82,35 @@ SPDecompositionTreeNodeType
});
}

int get_tree_height(BinarySPDecompositionTree const &tree) {
return tree.visit<int>(overload{
[](BinarySeriesSplit const &series) -> int {
int left_height = get_tree_height(series.get_left_child());
int right_height = get_tree_height(series.get_right_child());
return std::max(left_height, right_height) + 1;
},
[](BinaryParallelSplit const &parallel) -> int {
int left_height = get_tree_height(parallel.get_left_child());
int right_height = get_tree_height(parallel.get_right_child());
return std::max(left_height, right_height) + 1;
},
[](Node const &) -> int { return 0; },
});
}

std::unordered_multiset<Node> get_nodes(BinarySPDecompositionTree const &tree) {
return tree.visit<std::unordered_multiset<Node>>(overload{
[](BinarySeriesSplit const &series) -> std::unordered_multiset<Node> {
auto left_nodes = get_nodes(series.get_left_child());
auto right_nodes = get_nodes(series.get_right_child());
return multiset_union(left_nodes, right_nodes);
},
[](BinaryParallelSplit const &parallel) -> std::unordered_multiset<Node> {
auto left_nodes = get_nodes(parallel.get_left_child());
auto right_nodes = get_nodes(parallel.get_right_child());
return multiset_union(left_nodes, right_nodes);
},
[](Node const &node) -> std::unordered_multiset<Node> { return {node}; },
});
}
} // namespace FlexFlow
Loading
Loading