Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
87 commits
Select commit Hold shift + click to select a range
510c2d9
Start on pcg builder
lockshaw Jun 4, 2024
7b55ed1
Add tests and some implementation for pcg builder
lockshaw Jun 4, 2024
c379efd
Add pcg tests, make dtgen constructors explicit to fix bug
lockshaw Jun 10, 2024
35fa653
Add remainder of PCG tests
lockshaw Jun 10, 2024
865a28e
Merge remote-tracking branch 'origin/repo-refactor' into pcg-builder
lockshaw Jun 10, 2024
f379539
Fix build issues in local-execution
lockshaw Jun 10, 2024
2dbb3b9
Format
lockshaw Jun 10, 2024
4050c99
Address Reyna comments, add topological_order function for PCG
lockshaw Jun 17, 2024
42c1968
Pre multidigraph refactor
lockshaw Jun 19, 2024
3be816f
Removing visitable from sp code
lockshaw Jun 21, 2024
6d68324
Add open dataflow graph, start to replace pcg dataflow graph
lockshaw Jun 23, 2024
64a3403
Start refactoring substitutions
lockshaw Jun 24, 2024
7d4c7be
Add utility functions to support pattern matching
lockshaw Jun 25, 2024
9ab9eb2
Pre-refactor inputs
lockshaw Jun 26, 2024
7ae7c65
Merge remote-tracking branch 'origin/repo-refactor' into dataflow-graph
lockshaw Jun 26, 2024
f9b129e
Fix proj url
lockshaw Jun 26, 2024
cf73f08
Get back to substitutions, now with unordered graph inputs
lockshaw Jul 7, 2024
5fd666d
Get substitutions building
lockshaw Jul 13, 2024
5f0c88a
substitutions-tests now builds
lockshaw Jul 13, 2024
3228f2d
Fix bug in filter, pass some initial substitution tests
lockshaw Jul 14, 2024
5f4cc01
Add tests for fmt::to_string, fix some substitutions bugs
lockshaw Jul 15, 2024
ad60be0
Pass initial unit tests for find_pattern_matches
lockshaw Jul 15, 2024
a972da2
Start on unit tests for pcg pattern
lockshaw Jul 15, 2024
bcf776e
Pass initial test for find_pattern_matches
lockshaw Jul 19, 2024
e28400e
Merge remote-tracking branch 'origin/repo-refactor' into dataflow-graph
lockshaw Jul 19, 2024
fe6d65d
Fix small build issue in tests
lockshaw Jul 19, 2024
e647af7
Format
lockshaw Jul 19, 2024
8b58760
Sync tests in CI with tests in proj
lockshaw Jul 19, 2024
1fafb9d
Fix minor build errors in kernels and local-execution
lockshaw Jul 19, 2024
0804314
Format
lockshaw Jul 19, 2024
dd5465c
Remove outdated code
lockshaw Jul 20, 2024
29ec5b8
More outdated code removal
lockshaw Jul 20, 2024
ff41743
More cleanup, add test for sp decomposition
lockshaw Jul 20, 2024
e71d200
Pull apart containers.h
lockshaw Jul 21, 2024
c06710c
More sp testing and fixes
lockshaw Jul 21, 2024
2f75566
Break up graph algorithms.h
lockshaw Jul 21, 2024
c81d3a4
Pre- full SP algo commit
lockshaw Jul 23, 2024
2a11c7e
Add initial implementation and tests for cbc decomposition and invers…
lockshaw Jul 23, 2024
71a9e0f
Pass test for get_inverse_line_graph
lockshaw Jul 24, 2024
25eb1db
Add new multidigraph
lockshaw Jul 24, 2024
64f1932
Fix get_inverse_line_graph to return a MultiDiGraph instead of a DiGraph
lockshaw Jul 24, 2024
31c8d17
Add tests for parallel and series reduction finding
lockshaw Jul 24, 2024
19e7e28
Add really rough implementation of valdez sp decomposition
lockshaw Jul 24, 2024
3791e86
Fix local-execution build
lockshaw Jul 25, 2024
267b72d
Add implementations and tests for applying series/parallel reductions
lockshaw Jul 25, 2024
bb2769a
Format
lockshaw Jul 26, 2024
39cb7b3
Clean up sp decomposition interface and tests
lockshaw Jul 27, 2024
ce0234d
Format
lockshaw Jul 27, 2024
3dc3ec6
Add comments for top-level substitutions functions, add proj doxygen …
lockshaw Jul 31, 2024
ee518c2
Start sketching out substitutions code
lockshaw Jul 31, 2024
f69b95a
Merge branch 'dataflow-graph' into substitutions-fix
lockshaw Jul 31, 2024
3c06b88
Fix build errors
lockshaw Aug 1, 2024
3d6f681
Add ability to permute node ids
lockshaw Aug 1, 2024
098a9d1
Cleanup and start to test new substitutions code
lockshaw Aug 4, 2024
9bd4f14
Add test case for evaluate_substitution_output
lockshaw Aug 5, 2024
101083b
Add naive isomorphism detection code
lockshaw Aug 5, 2024
9fec50c
Add graph inputs to open dataflow graph isomorphism
lockshaw Aug 6, 2024
7c60736
Add input permutation to evaluate_substitution_output
lockshaw Aug 6, 2024
cb6eab2
Fix permute_node_ids
lockshaw Aug 8, 2024
2f3d67a
Add test for permute_input_ids
lockshaw Aug 8, 2024
03cbd02
Migrate over to mutable implementation of apply_substitution
lockshaw Aug 23, 2024
4a8deae
Add fast isomorphism checking and an initial implementation of full s…
lockshaw Aug 24, 2024
0757e94
Pass initial full substitutions test
lockshaw Aug 24, 2024
ba0a174
Cleanup old isomorphism checking code
lockshaw Aug 24, 2024
4dfa403
Merge remote-tracking branch 'origin/repo-refactor' into substitution…
lockshaw Aug 24, 2024
f156f96
Fix post-merge bugs
lockshaw Aug 24, 2024
5f09298
Fix broken pcg builder test
lockshaw Aug 26, 2024
deff4f8
Format
lockshaw Aug 26, 2024
d71d24f
Reorganize code and remove some outdated code pre-code-review
lockshaw Aug 26, 2024
1a63f90
Format
lockshaw Aug 26, 2024
1d4ab09
Restarting work on this after working on export-model-arch
lockshaw Sep 10, 2024
7928864
Merge remote-tracking branch 'flexflow/repo-refactor' into substituti…
lockshaw Sep 10, 2024
f2c8e7b
Merge branch 'substitution-builder' into master
Jan 9, 2025
030b0e8
Merge branch 'master' into merge-substitution-builder
victorli2002 Jan 15, 2025
f5c49c7
Adding in a simple function to get the currently available subst…
Jan 17, 2025
3e4c357
Merge branch 'master' into merge-substitution-builder
victorli2002 Jan 17, 2025
30f2b6e
Merge branch 'master' into merge-substitution-builder
lockshaw Jan 20, 2025
47cc58a
nonnegative_int additions, code cleanup, etc.
Jan 24, 2025
f8df37e
Merge remote-tracking branch 'origin/master' into victor-substitution…
lockshaw Jan 25, 2025
3728251
A bunch more moving over to nonnegative_int
lockshaw Jan 28, 2025
f27d31b
Even more nonnegative_int updating
lockshaw Jan 31, 2025
9f8762e
Fix build
lockshaw Jan 31, 2025
5edb6f0
Fix failing tests
lockshaw Feb 1, 2025
97338c7
Merge branch 'master' into merge-substitution-builder
lockshaw Feb 1, 2025
88370c0
Format
lockshaw Feb 1, 2025
9c2007e
Merge remote-tracking branch 'refs/remotes/victorli2002/merge-substit…
lockshaw Feb 1, 2025
600e074
Format
lockshaw Feb 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
17 changes: 9 additions & 8 deletions bin/export-model-arch/src/export_model_arch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "utils/cli/cli_parse.h"
#include "utils/cli/cli_parse_result.h"
#include "utils/cli/cli_spec.h"
#include "utils/graph/open_dataflow_graph/algorithms/as_dot.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h"
#include "utils/graph/series_parallel/get_series_parallel_decomposition.h"

Expand All @@ -21,11 +22,11 @@ using namespace ::FlexFlow;
ComputationGraph get_single_operator_computation_graph() {
ComputationGraphBuilder b;

size_t batch_size = 8;
size_t in_channels = 16;
size_t out_channels = 12;
nonnegative_int batch_size = 8_n;
nonnegative_int in_channels = 16_n;
nonnegative_int out_channels = 12_n;
TensorShape input_shape = TensorShape{
TensorDims{FFOrdered<size_t>{
TensorDims{FFOrdered<nonnegative_int>{
batch_size,
in_channels,
out_channels,
Expand Down Expand Up @@ -69,7 +70,7 @@ tl::expected<ComputationGraph, std::string>
} else if (model_name == "bert") {
return get_bert_computation_graph(get_default_bert_config());
} else if (model_name == "split_test") {
int batch_size = 8;
nonnegative_int batch_size = 8_n;
return get_split_test_computation_graph(batch_size);
} else if (model_name == "single_operator") {
return get_single_operator_computation_graph();
Expand Down Expand Up @@ -100,10 +101,10 @@ tl::expected<JsonSPModelExport, std::string>
result.value();
});

std::pair<V1ComputationGraph, bidict<int, layer_guid_t>> v1_result =
to_v1_including_node_numbering(computation_graph);
std::pair<V1ComputationGraph, bidict<nonnegative_int, layer_guid_t>>
v1_result = to_v1_including_node_numbering(computation_graph);
V1ComputationGraph v1_cg = v1_result.first;
bidict<int, layer_guid_t> layer_numbering = v1_result.second;
bidict<nonnegative_int, layer_guid_t> layer_numbering = v1_result.second;
V1BinarySPDecomposition v1_sp_decomposition =
to_v1(sp_decomposition, layer_numbering);

Expand Down
14 changes: 13 additions & 1 deletion cmake/flexflow-utils.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ function(define_ff_vars target)
MAX_TENSOR_DIM=${FF_MAX_DIM}
MAX_NUM_TASK_REGIONS=${FF_MAX_NUM_TASK_REGIONS}
MAX_NUM_TASK_ARGUMENTS=${FF_MAX_NUM_TASK_ARGUMENTS}
# _FORTIFY_SOURCE=0
)

if (FF_GPU_BACKEND STREQUAL "cuda")
Expand All @@ -39,7 +40,18 @@ function(ff_set_cxx_properties target)
CXX_EXTENSIONS NO
)
target_compile_options(${target}
PRIVATE $<$<COMPILE_LANGUAGE:CXX>:> "-ffile-prefix-map=${CMAKE_SOURCE_DIR}=." # add C++ compile flags here
PUBLIC
$<$<COMPILE_LANGUAGE:CXX>:>
"-ffile-prefix-map=${CMAKE_SOURCE_DIR}=."
"-fsanitize=undefined"
"-fno-sanitize-recover=all"
# add C++ compile flags here
)
target_link_options(${target}
PUBLIC
$<$<COMPILE_LANGUAGE:CXX>:>
"-fsanitize=undefined"
"-fno-sanitize-recover=all"
)
endfunction()

Expand Down
14 changes: 12 additions & 2 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,15 @@
};
lib = pkgs.lib;

mkShell = pkgs.mkShell.override {
mkShell = attrs: pkgs.mkShell.override {
stdenv = pkgs.cudaPackages.backendStdenv;
};
} (attrs // {
hardeningDisable = ["all"]; # disable nixpkgs default compiler arguments, otherwise ubsan doesn't catch
# signed overflows due to the signedoverflow hardening setting.
# for more details, see the following (long-running) nixpkgs github issues:
# - https://github.com/NixOS/nixpkgs/issues/18995
# - https://github.com/NixOS/nixpkgs/issues/60919
});

proj = proj-repo.packages.${system}.proj;
in
Expand Down Expand Up @@ -121,6 +127,8 @@

gpu-ci = mkShell {
inputsFrom = [ ci ];
hardeningDisable = [ "all" ];

buildInputs = builtins.concatLists [
(with nixGL.packages.${system}; [
nixGLDefault
Expand All @@ -135,6 +143,8 @@
"${proj-repo.packages.${system}.proj-nvim}"
];

hardeningDisable = [ "all" ];

buildInputs = builtins.concatLists [
(with pkgs; [
clang-tools
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ bool is_right_associative(ComputationGraphBinarySPDecomposition const &);
std::unordered_multiset<layer_guid_t>
get_layers(ComputationGraphBinarySPDecomposition const &);

V1BinarySPDecomposition to_v1(ComputationGraphBinarySPDecomposition const &,
bidict<int, layer_guid_t> const &layer_numbering);
V1BinarySPDecomposition
to_v1(ComputationGraphBinarySPDecomposition const &,
bidict<nonnegative_int, layer_guid_t> const &layer_numbering);

} // namespace FlexFlow

Expand Down
41 changes: 25 additions & 16 deletions lib/compiler/src/compiler/allowed_machine_views.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
#include "utils/containers/map_from_keys_and_values.h"
#include "utils/containers/product.h"
#include "utils/containers/range.h"
#include "utils/containers/replicate.h"
#include "utils/containers/repeat_element.h"
#include "utils/containers/sorted.h"
#include "utils/containers/transform.h"
#include "utils/containers/unordered_multiset_of.h"
#include "utils/containers/unordered_set_of.h"
#include "utils/containers/zip.h"
#include "utils/nonnegative_int/ceildiv.h"
#include "utils/nonnegative_int/nonnegative_range.h"
#include "utils/nonnegative_int/num_elements.h"
#include "utils/overload.h"

namespace FlexFlow {
Expand Down Expand Up @@ -47,24 +50,29 @@ static std::unordered_set<MachineView>
OperatorTaskSpace const &task,
DeviceType const &device_type) {

auto get_max_stride_upper_bound = [](std::vector<int> const &tensor_dims,
int total_devices) -> int {
int min_num_devices_with_full_stride_volume = product(transform(
tensor_dims, [](int const &num_devices) { return num_devices - 1; }));
return std::ceil(total_devices / min_num_devices_with_full_stride_volume);
auto get_max_stride_upper_bound =
[](std::vector<nonnegative_int> const &tensor_dims,
nonnegative_int total_devices) -> nonnegative_int {
nonnegative_int min_num_devices_with_full_stride_volume =
product(transform(tensor_dims, [](nonnegative_int num_devices) {
return nonnegative_int{num_devices.unwrap_nonnegative() - 1};
}));
return ceildiv(total_devices, min_num_devices_with_full_stride_volume);
};

auto candidate_strides = [&](std::vector<int> const &tensor_dims,
int total_devices)
auto candidate_strides = [&](std::vector<nonnegative_int> const &tensor_dims,
nonnegative_int total_devices)
-> std::unordered_multiset<MultiDimensionalStride> {
int max_stride_upper_bound =
nonnegative_int max_stride_upper_bound =
get_max_stride_upper_bound(tensor_dims, total_devices);

std::vector<stride_t> single_stride_range =
transform(range(1, max_stride_upper_bound + 1),
[](int stride) { return stride_t{stride}; });
transform(nonnegative_range(1_n, max_stride_upper_bound + 1_n),
[](nonnegative_int stride) { return stride_t{stride}; });
std::unordered_multiset<std::vector<stride_t>> raw_stride_vectors =
cartesian_product(replicate(tensor_dims.size(), single_stride_range));
cartesian_product(
repeat_element(/*num_times=*/num_elements(tensor_dims),
/*element=*/single_stride_range));
std::unordered_multiset<MultiDimensionalStride> strides =
transform(raw_stride_vectors, [](auto const &stride_vec) {
return MultiDimensionalStride{stride_vec};
Expand All @@ -75,8 +83,9 @@ static std::unordered_set<MachineView>
auto candidate_starts = [](MachineSpecification const &ms,
DeviceType const &device_type) {
std::unordered_set<MachineSpaceCoordinate> result;
for (int node_idx : range(ms.num_nodes)) {
for (int device_idx : range(get_num_devices_per_node(ms, device_type))) {
for (nonnegative_int node_idx : nonnegative_range(ms.num_nodes)) {
for (nonnegative_int device_idx :
nonnegative_range(get_num_devices_per_node(ms, device_type))) {
result.insert(
MachineSpaceCoordinate{node_idx, device_idx, device_type});
}
Expand All @@ -91,8 +100,8 @@ static std::unordered_set<MachineView>
return get_all_permutations_with_repetition(options, num_dims(task));
};

std::vector<int> tensor_dims = task.degrees;
int total_devices = get_num_devices(machine_spec, device_type);
std::vector<nonnegative_int> tensor_dims = task.degrees;
nonnegative_int total_devices = get_num_devices(machine_spec, device_type);

std::unordered_set<MachineView> machine_views;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,19 @@ std::unordered_set<std::pair<MachineSpecification, MachineSpecification>>
for (int i = 1; i < resource.num_nodes; i *= 2) {
MachineSpecification sub_resource1 = resource;
MachineSpecification sub_resource2 = resource;
sub_resource1.num_nodes = i;
sub_resource2.num_nodes = resource.num_nodes - i;
sub_resource1.num_nodes = nonnegative_int{i};
sub_resource2.num_nodes =
nonnegative_int{resource.num_nodes.unwrap_nonnegative() - i};
result.insert(std::make_pair(sub_resource1, sub_resource2));
result.insert(std::make_pair(sub_resource2, sub_resource1));
}

for (int i = 1; i < resource.num_gpus_per_node; i *= 2) {
MachineSpecification sub_resource1 = resource;
MachineSpecification sub_resource2 = resource;
sub_resource1.num_gpus_per_node = i;
sub_resource2.num_gpus_per_node = resource.num_gpus_per_node - i;
sub_resource1.num_gpus_per_node = nonnegative_int{i};
sub_resource2.num_gpus_per_node =
nonnegative_int{resource.num_gpus_per_node.unwrap_nonnegative() - i};
result.insert(std::make_pair(sub_resource1, sub_resource2));
result.insert(std::make_pair(sub_resource2, sub_resource1));
}
Expand Down
10 changes: 2 additions & 8 deletions lib/compiler/src/compiler/machine_mapping/machine_mapping.cc
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
#include "compiler/machine_mapping/machine_mapping.h"
#include "pcg/machine_specification.h"
#include "pcg/machine_view.h"
#include "pcg/operator_task_space.dtg.h"
#include "pcg/operator_task_space.h"
#include "pcg/parallel_computation_graph/parallel_computation_graph.h"
#include "utils/containers/are_disjoint.h"
#include "utils/containers/get_one_of.h"
#include "utils/containers/keys.h"
#include "utils/containers/map_values.h"
#include "utils/containers/merge_maps.h"

namespace FlexFlow {

MachineMapping combine_disjoint_mappings(MachineMapping const &m1,
MachineMapping const &m2) {
return MachineMapping{merge_maps(m1.machine_views, m2.machine_views)};
return MachineMapping{
merge_disjoint_maps(m1.machine_views, m2.machine_views)};
}

bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ ParallelLayerGuidObliviousMachineMapping binary_combine_mappings(
ParallelLayerGuidObliviousMachineMapping const &lhs,
ParallelLayerGuidObliviousMachineMapping const &rhs) {
return ParallelLayerGuidObliviousMachineMapping{
merge_maps(map_keys(lhs.raw_mapping, nest_inside_left_child),
map_keys(rhs.raw_mapping, nest_inside_right_child)),
merge_disjoint_maps(map_keys(lhs.raw_mapping, nest_inside_left_child),
map_keys(rhs.raw_mapping, nest_inside_right_child)),
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ std::unordered_multiset<layer_guid_t>

V1BinarySPDecomposition
to_v1(ComputationGraphBinarySPDecomposition const &tree,
bidict<int, layer_guid_t> const &layer_numbering) {
bidict<nonnegative_int, layer_guid_t> const &layer_numbering) {
return tree.visit<V1BinarySPDecomposition>(
overload{[&](ComputationGraphBinarySeriesSplit const &series) {
return V1BinarySPDecomposition{
Expand Down
60 changes: 33 additions & 27 deletions lib/compiler/test/src/allowed_machine_views.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,39 +15,39 @@ TEST_SUITE(FF_TEST_SUITE) {

SUBCASE("1 degree of parallelism") {
MachineSpecification ms = MachineSpecification{
/*num_nodes=*/1,
/*num_cpus_per_node=*/5,
/*num_gpus_per_node=*/5,
/*num_nodes=*/1_n,
/*num_cpus_per_node=*/5_n,
/*num_gpus_per_node=*/5_n,
/*inter_node_bandwidth=*/0,
/*intra_node_bandwidth=*/0,
};

OperatorTaskSpace task = OperatorTaskSpace{{3}};
OperatorTaskSpace task = OperatorTaskSpace{{3_n}};

std::unordered_set<MachineView> correct = {
MachineView{
MachineSpaceCoordinate{
/*node_idx=*/0, /*device_idx=*/0, DeviceType::GPU},
{MachineViewDimension{stride_t{1},
/*node_idx=*/0_n, /*device_idx=*/0_n, DeviceType::GPU},
{MachineViewDimension{stride_t{1_n},
MachineSpecificationDimension::INTRA_NODE}},
},

MachineView{
MachineSpaceCoordinate{
/*node_idx=*/0, /*device_idx=*/1, DeviceType::GPU},
{MachineViewDimension{stride_t{1},
/*node_idx=*/0_n, /*device_idx=*/1_n, DeviceType::GPU},
{MachineViewDimension{stride_t{1_n},
MachineSpecificationDimension::INTRA_NODE}},
},
MachineView{
MachineSpaceCoordinate{
/*node_idx=*/0, /*device_idx=*/2, DeviceType::GPU},
{MachineViewDimension{stride_t{1},
/*node_idx=*/0_n, /*device_idx=*/2_n, DeviceType::GPU},
{MachineViewDimension{stride_t{1_n},
MachineSpecificationDimension::INTRA_NODE}},
},
MachineView{
MachineSpaceCoordinate{
/*node_idx=*/0, /*device_idx=*/0, DeviceType::GPU},
{MachineViewDimension{stride_t{2},
/*node_idx=*/0_n, /*device_idx=*/0_n, DeviceType::GPU},
{MachineViewDimension{stride_t{2_n},
MachineSpecificationDimension::INTRA_NODE}},
},
};
Expand All @@ -61,18 +61,18 @@ TEST_SUITE(FF_TEST_SUITE) {
SUBCASE("2 degrees of parallelism") {

MachineSpecification ms = MachineSpecification{
/*num_nodes=*/3,
/*num_cpus_per_node=*/3,
/*num_gpus_per_node=*/3,
/*num_nodes=*/3_n,
/*num_cpus_per_node=*/3_n,
/*num_gpus_per_node=*/3_n,
/*inter_node_bandwidth=*/0,
/*intra_node_bandwidth=*/0,
};
OperatorTaskSpace task = OperatorTaskSpace{{2, 3}};
OperatorTaskSpace task = OperatorTaskSpace{{2_n, 3_n}};

auto make_2d_view = [&](int start_node_idx,
int start_device_idx,
int stride1,
int stride2,
auto make_2d_view = [&](nonnegative_int start_node_idx,
nonnegative_int start_device_idx,
nonnegative_int stride1,
nonnegative_int stride2,
MachineSpecificationDimension m1,
MachineSpecificationDimension m2) {
return MachineView{
Expand All @@ -86,13 +86,19 @@ TEST_SUITE(FF_TEST_SUITE) {
auto intra = MachineSpecificationDimension::INTRA_NODE;
auto inter = MachineSpecificationDimension::INTER_NODE;
std::unordered_set<MachineView> correct = {
make_2d_view(0, 0, /*stride1=*/1, /*stride2=*/1, inter, intra),
make_2d_view(1, 0, /*stride1=*/1, /*stride2=*/1, inter, intra),
make_2d_view(0, 0, /*stride1=*/2, /*stride2=*/1, inter, intra),

make_2d_view(0, 0, /*stride1=*/1, /*stride2=*/1, intra, inter),
make_2d_view(0, 1, /*stride1=*/1, /*stride2=*/1, intra, inter),
make_2d_view(0, 0, /*stride1=*/2, /*stride2=*/1, intra, inter),
make_2d_view(
0_n, 0_n, /*stride1=*/1_n, /*stride2=*/1_n, inter, intra),
make_2d_view(
1_n, 0_n, /*stride1=*/1_n, /*stride2=*/1_n, inter, intra),
make_2d_view(
0_n, 0_n, /*stride1=*/2_n, /*stride2=*/1_n, inter, intra),

make_2d_view(
0_n, 0_n, /*stride1=*/1_n, /*stride2=*/1_n, intra, inter),
make_2d_view(
0_n, 1_n, /*stride1=*/1_n, /*stride2=*/1_n, intra, inter),
make_2d_view(
0_n, 0_n, /*stride1=*/2_n, /*stride2=*/1_n, intra, inter),
};

std::unordered_set<MachineView> result =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ TEST_SUITE(FF_TEST_SUITE) {
ParallelTensorShape input_shape = ParallelTensorShape{
ParallelTensorDims{
FFOrdered<ShardParallelDim>{
ShardParallelDim{10, 2},
ShardParallelDim{12, 1},
ShardParallelDim{10_n, 2_n},
ShardParallelDim{12_n, 1_n},
},
ReplicaParallelDimSet{
SumDegree{1},
DiscardCopyDegree{1},
SumDegree{1_n},
DiscardCopyDegree{1_n},
},
},
DataType::FLOAT,
Expand Down
Loading
Loading