diff --git a/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h index d43edaa79d..a71eb45609 100644 --- a/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h +++ b/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h @@ -1,6 +1,10 @@ #ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_BALANCED_BINARY_SP_DECOMPOSITION_H #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_BALANCED_BINARY_SP_DECOMPOSITION_H +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include + namespace FlexFlow { std::optional diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h index 86fa1a59aa..5f17a7c677 100644 --- a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h @@ -33,6 +33,9 @@ std::unordered_set find_paths_to_leaf(PCGBinarySPDecomposition const &, parallel_layer_guid_t const &); +PCGBinarySPDecomposition pcg_binary_sp_decomposition_from_binary_sp_tree( + BinarySPDecompositionTree const &spd_tree); + } // namespace FlexFlow #endif diff --git a/lib/compiler/src/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.cc new file mode 100644 index 0000000000..cfd4d65bb4 --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.cc @@ -0,0 +1,21 @@ +#include "compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h" +#include "compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/balanced_binary_sp_tree_from_nary.h" +#include + +namespace FlexFlow { + +std::optional + get_pcg_balanced_binary_sp_decomposition( + ParallelComputationGraph const &pcg) { + std::optional spd = + get_pcg_series_parallel_decomposition(pcg); + if (!spd) { + return std::nullopt; + } + return pcg_binary_sp_decomposition_from_binary_sp_tree( + balanced_binary_sp_tree_from_nary(spd.value())); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc index 5eb993c6ef..589cbb4c0d 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc @@ -1,7 +1,10 @@ #include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" +#include "compiler/series_parallel/pcg/pcg_binary_parallel_split.h" #include "compiler/series_parallel/pcg/pcg_binary_series_split.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/get_series_parallel_decomposition.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" #include "utils/overload.h" namespace FlexFlow { @@ -67,10 +70,7 @@ BinarySPDecompositionTree }, [](PCGBinaryParallelSplit const ¶llel) -> BinarySPDecompositionTree { return BinarySPDecompositionTree{ - BinaryParallelSplit{ - binary_sp_tree_from_pcg_sp_tree(parallel.get_left_child()), - binary_sp_tree_from_pcg_sp_tree(parallel.get_right_child()), - }, + binary_parallel_split_from_pcg_parallel_split(parallel), }; }, [](parallel_layer_guid_t const &layer) -> BinarySPDecompositionTree { @@ -81,9 +81,35 @@ BinarySPDecompositionTree }); } -std::optional - get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &) { - NOT_IMPLEMENTED(); +PCGBinarySPDecomposition pcg_binary_sp_decomposition_from_binary_sp_tree( + BinarySPDecompositionTree const &spd_tree) { + return spd_tree.visit(overload{ + [](BinarySeriesSplit const &series) -> PCGBinarySPDecomposition { + return PCGBinarySPDecomposition{ + PCGBinarySeriesSplit{ + pcg_binary_sp_decomposition_from_binary_sp_tree( + series.get_left_child()), + pcg_binary_sp_decomposition_from_binary_sp_tree( + series.get_right_child()), + }, + }; + }, + [](BinaryParallelSplit const ¶llel) -> PCGBinarySPDecomposition { + return PCGBinarySPDecomposition{ + PCGBinaryParallelSplit{ + pcg_binary_sp_decomposition_from_binary_sp_tree( + parallel.get_left_child()), + pcg_binary_sp_decomposition_from_binary_sp_tree( + parallel.get_right_child()), + }, + }; + }, + [](Node const &node) -> PCGBinarySPDecomposition { + return PCGBinarySPDecomposition{ + parallel_layer_guid_t{node}, + }; + }, + }); } std::unordered_multiset diff --git a/lib/compiler/test/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc b/lib/compiler/test/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc new file mode 100644 index 0000000000..8c1d3221d3 --- /dev/null +++ b/lib/compiler/test/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc @@ -0,0 +1,92 @@ +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "test/utils/rapidcheck.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("pcg_binary_sp_decomposition_from_binary_sp_tree") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + + auto make_binary_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_binary_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; + }; + + auto make_binary_leaf = [](Node const &n) { + return BinarySPDecompositionTree{n}; + }; + + auto make_pcg_series_split = [](PCGBinarySPDecomposition const &lhs, + PCGBinarySPDecomposition const &rhs) { + return PCGBinarySPDecomposition{PCGBinarySeriesSplit{lhs, rhs}}; + }; + + auto make_pcg_parallel_split = [](PCGBinarySPDecomposition const &lhs, + PCGBinarySPDecomposition const &rhs) { + return PCGBinarySPDecomposition{PCGBinaryParallelSplit{lhs, rhs}}; + }; + + auto make_pcg_leaf = [](Node const &n) { + return PCGBinarySPDecomposition{parallel_layer_guid_t{n}}; + }; + + SUBCASE("single node") { + BinarySPDecompositionTree input = make_binary_leaf(n1); + + PCGBinarySPDecomposition result = + pcg_binary_sp_decomposition_from_binary_sp_tree(input); + + PCGBinarySPDecomposition expected = make_pcg_leaf(n1); + + CHECK(result == expected); + } + + SUBCASE("series split") { + BinarySPDecompositionTree input = + make_binary_series_split(make_binary_leaf(n1), make_binary_leaf(n2)); + + PCGBinarySPDecomposition result = + pcg_binary_sp_decomposition_from_binary_sp_tree(input); + + PCGBinarySPDecomposition expected = + make_pcg_series_split(make_pcg_leaf(n1), make_pcg_leaf(n2)); + + CHECK(result == expected); + } + + SUBCASE("parallel split") { + BinarySPDecompositionTree input = make_binary_parallel_split( + make_binary_leaf(n1), make_binary_leaf(n2)); + + PCGBinarySPDecomposition result = + pcg_binary_sp_decomposition_from_binary_sp_tree(input); + + PCGBinarySPDecomposition expected = + make_pcg_parallel_split(make_pcg_leaf(n1), make_pcg_leaf(n2)); + + CHECK(result == expected); + } + + SUBCASE("bijectiveness") { + BinarySPDecompositionTree original = make_binary_parallel_split( + make_binary_series_split(make_binary_leaf(n1), make_binary_leaf(n2)), + make_binary_leaf(n3)); + + PCGBinarySPDecomposition pcg_tree = + pcg_binary_sp_decomposition_from_binary_sp_tree(original); + BinarySPDecompositionTree converted = + binary_sp_tree_from_pcg_sp_tree(pcg_tree); + + CHECK(original == converted); + } + } +} diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/balanced_binary_sp_tree_from_nary.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/balanced_binary_sp_tree_from_nary.h new file mode 100644 index 0000000000..191e6bb1ef --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/balanced_binary_sp_tree_from_nary.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BALANCED_BINARY_SP_TREE_FROM_NARY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BALANCED_BINARY_SP_TREE_FROM_NARY_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +BinarySPDecompositionTree + balanced_binary_sp_tree_from_nary(SeriesParallelDecomposition const &nary); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h index de48cd17e9..e138ff9c60 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h @@ -23,6 +23,10 @@ std::unordered_multiset get_leaves(BinarySPDecompositionTree const &); SPDecompositionTreeNodeType get_node_type(BinarySPDecompositionTree const &); +int get_tree_height(BinarySPDecompositionTree const &); + +std::unordered_multiset get_nodes(BinarySPDecompositionTree const &tree); + } // namespace FlexFlow #endif diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/balanced_binary_sp_tree_from_nary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/balanced_binary_sp_tree_from_nary.cc new file mode 100644 index 0000000000..e65dd7376e --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/balanced_binary_sp_tree_from_nary.cc @@ -0,0 +1,92 @@ +#include "utils/containers/foldl1.h" +#include "utils/containers/get_only.h" +#include "utils/containers/subvec.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_multiset_of.h" +#include "utils/containers/vector_of.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h" +#include "utils/graph/series_parallel/parallel_split.dtg.h" +#include "utils/graph/series_parallel/series_split.dtg.h" +#include "utils/overload.h" +#include +#include + +namespace FlexFlow { + +BinarySPDecompositionTree + balanced_binary_sp_tree_from_nary(SeriesParallelDecomposition const &nary) { + std::function const &)> + from_series_child; + std::function const &)> + from_parallel_child; + + std::function + from_parallel; + std::function from_series; + + auto from_node = [](Node const &n) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{n}; + }; + + from_parallel = [&](ParallelSplit const &s) -> BinarySPDecompositionTree { + auto children = vector_of(s.get_children()); + if (children.size() == 1) { + return from_parallel_child(get_only(children)); + } else if (children.size() == 2) { + return BinarySPDecompositionTree{BinaryParallelSplit{ + from_parallel_child(children[0]), from_parallel_child(children[1])}}; + } + auto s1 = unordered_multiset_of( + subvec(children, std::nullopt, children.size() / 2)); + auto s2 = unordered_multiset_of( + subvec(children, children.size() / 2, std::nullopt)); + return BinarySPDecompositionTree{BinaryParallelSplit{ + from_parallel(ParallelSplit{s1}), from_parallel(ParallelSplit{s2})}}; + }; + + from_series = [&](SeriesSplit const &s) -> BinarySPDecompositionTree { + auto children = vector_of(s.children); + if (children.size() == 1) { + return from_series_child(get_only(children)); + } else if (children.size() == 2) { + return BinarySPDecompositionTree{BinarySeriesSplit{ + from_series_child(children[0]), from_series_child(children[1])}}; + } + auto s1 = subvec(children, std::nullopt, children.size() / 2); + auto s2 = subvec(children, children.size() / 2, std::nullopt); + return BinarySPDecompositionTree{BinarySeriesSplit{ + from_series(SeriesSplit{s1}), from_series(SeriesSplit{s2})}}; + }; + + from_parallel_child = [&](std::variant const &v) + -> BinarySPDecompositionTree { + return std::visit(overload{ + [&](Node const &n) { return from_node(n); }, + [&](SeriesSplit const &s) { return from_series(s); }, + }, + v); + }; + + from_series_child = [&](std::variant const &v) + -> BinarySPDecompositionTree { + return std::visit( + overload{ + [&](Node const &n) { return from_node(n); }, + [&](ParallelSplit const &p) { return from_parallel(p); }, + }, + v); + }; + + return nary.visit(overload{ + [&](Node const &n) { return from_node(n); }, + [&](SeriesSplit const &s) { return from_series(s); }, + [&](ParallelSplit const &p) { return from_parallel(p); }, + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc index 62489ff75f..a59e30afa5 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc @@ -1,8 +1,8 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/containers/multiset_union.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" - namespace FlexFlow { GenericBinarySPDecompositionTreeImplementation(overload{ + [](BinarySeriesSplit const &series) -> int { + int left_height = get_tree_height(series.get_left_child()); + int right_height = get_tree_height(series.get_right_child()); + return std::max(left_height, right_height) + 1; + }, + [](BinaryParallelSplit const ¶llel) -> int { + int left_height = get_tree_height(parallel.get_left_child()); + int right_height = get_tree_height(parallel.get_right_child()); + return std::max(left_height, right_height) + 1; + }, + [](Node const &) -> int { return 0; }, + }); +} + +std::unordered_multiset get_nodes(BinarySPDecompositionTree const &tree) { + return tree.visit>(overload{ + [](BinarySeriesSplit const &series) -> std::unordered_multiset { + auto left_nodes = get_nodes(series.get_left_child()); + auto right_nodes = get_nodes(series.get_right_child()); + return multiset_union(left_nodes, right_nodes); + }, + [](BinaryParallelSplit const ¶llel) -> std::unordered_multiset { + auto left_nodes = get_nodes(parallel.get_left_child()); + auto right_nodes = get_nodes(parallel.get_right_child()); + return multiset_union(left_nodes, right_nodes); + }, + [](Node const &node) -> std::unordered_multiset { return {node}; }, + }); +} } // namespace FlexFlow diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/balanced_binary_sp_tree_from_nary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/balanced_binary_sp_tree_from_nary.cc new file mode 100644 index 0000000000..b9946830ea --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/balanced_binary_sp_tree_from_nary.cc @@ -0,0 +1,92 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/balanced_binary_sp_tree_from_nary.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "test/utils/rapidcheck.h" +#include "utils/containers/contains.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("balanced_binary_sp_tree_from_nary") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + Node n4 = Node{4}; + Node n5 = Node{5}; + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + + SUBCASE("only node") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{n1}; + + BinarySPDecompositionTree result = + balanced_binary_sp_tree_from_nary(input); + BinarySPDecompositionTree correct = make_leaf(n1); + + CHECK(result == correct); + } + + SUBCASE("only serial") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + SeriesSplit{{n1, n2, n3, n4}}, + }; + + BinarySPDecompositionTree result = + balanced_binary_sp_tree_from_nary(input); + + BinarySPDecompositionTree correct = + make_series_split(make_series_split(make_leaf(n1), make_leaf(n2)), + make_series_split(make_leaf(n3), make_leaf(n4))); + + CHECK(result == correct); + } + + SUBCASE("only parallel") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + ParallelSplit{{n1, n2, n3, n4}}, + }; + BinarySPDecompositionTree result = + balanced_binary_sp_tree_from_nary(input); + + int result_height = get_tree_height(result); + int expected_height = 2; + CHECK(result_height == expected_height); + + std::unordered_multiset result_nodes = get_nodes(result); + std::unordered_multiset expected_nodes = {n1, n2, n3, n4}; + CHECK(result_nodes == expected_nodes); + } + + SUBCASE("nested") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + ParallelSplit{{SeriesSplit{{n1, n2, n3, n4}}, n5}}, + }; + + BinarySPDecompositionTree result = + balanced_binary_sp_tree_from_nary(input); + + BinarySPDecompositionTree balanced_series = + make_series_split(make_series_split(make_leaf(n1), make_leaf(n2)), + make_series_split(make_leaf(n3), make_leaf(n4))); + + std::unordered_set corrects = { + make_parallel_split(balanced_series, make_leaf(n5)), + make_parallel_split(make_leaf(n5), balanced_series)}; + + CHECK(contains(corrects, result)); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc index fee971e5e0..f53578821f 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc @@ -67,7 +67,7 @@ TEST_SUITE(FF_TEST_SUITE) { // left-associative binary SP trees CHECK(is_binary_sp_tree_left_associative(result)); - std::unordered_multiset result_nodes = get_nodes(input); + std::unordered_multiset result_nodes = get_nodes(result); std::unordered_multiset correct_nodes = {n1, n2, n3}; CHECK(result_nodes == correct_nodes); @@ -96,7 +96,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(is_binary_sp_tree_left_associative(result)); - std::unordered_multiset result_nodes = get_nodes(input); + std::unordered_multiset result_nodes = get_nodes(result); std::unordered_multiset correct_nodes = { n1, n2, n3, n3, n5, n6, n4, n5};