diff --git a/.envrc b/.envrc new file mode 100644 index 0000000000..2797f0f929 --- /dev/null +++ b/.envrc @@ -0,0 +1,3 @@ +source_up_if_exists + +use flake diff --git a/.vimrc b/.vimrc new file mode 100644 index 0000000000..4c8a8a8279 --- /dev/null +++ b/.vimrc @@ -0,0 +1,8 @@ +" example search path configuration +set path=lib/runtime/**,lib/** + +" set build target +" let g:target = "pcg" + +" set test target +" let g:test_target = "utils-test" diff --git a/lib/substitutions/include/substitutions/unity_substitution_set.h b/lib/substitutions/include/substitutions/unity_substitution_set.h index 183f76ac8a..7c257dee60 100644 --- a/lib/substitutions/include/substitutions/unity_substitution_set.h +++ b/lib/substitutions/include/substitutions/unity_substitution_set.h @@ -10,18 +10,11 @@ namespace FlexFlow { std::vector get_substitution_set(MachineSpecification const &resources); -Substitution create_combine_inception(nonnegative_int num_convs, - nonnegative_int num_dims, - nonnegative_int degree); -Substitution create_combine_concat(nonnegative_int num_inputs, - nonnegative_int num_dims, - nonnegative_int degree); Substitution create_replicate_linear_combine(nonnegative_int num_dims, nonnegative_int degree, bool use_bias); Substitution create_partition_linear_combine(nonnegative_int num_dims, nonnegative_int degree, - Activation activation, bool use_bias); Substitution create_partition_conv2d_combine(nonnegative_int num_dims, nonnegative_int degree); @@ -33,10 +26,6 @@ Substitution create_partition_add_combine(ff_dim_t parallel_dim, nonnegative_int degree); Substitution create_partition_relu_combine(ff_dim_t parallel_dim, nonnegative_int degree); -Substitution create_partition_concat_combine(nonnegative_int num_inputs, - ff_dim_t concat_dim, - ff_dim_t parallel_dim, - nonnegative_int degree); Substitution create_partition_softmax_combine(ff_dim_t softmax_dim, ff_dim_t partition_dim, nonnegative_int degree); diff --git 
a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc index cb733e16ff..a7575ae837 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc @@ -83,6 +83,8 @@ std::optional get_attribute(ConcatAttrs const &p, std::optional get_attribute(Conv2DAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OUT_CHANNELS: + return OperatorAttributeValue{p.out_channels}; case OperatorAttributeKey::OP_TYPE: return OperatorAttributeValue{get_op_type(p)}; case OperatorAttributeKey::KERNEL_H: @@ -113,6 +115,12 @@ std::optional get_attribute(ElementBinaryAttrs const &p, switch (key) { case OperatorAttributeKey::OP_TYPE: return OperatorAttributeValue{get_op_type(p)}; + case OperatorAttributeKey::DATA_TYPE: + return OperatorAttributeValue{p.compute_type}; + case OperatorAttributeKey::SHOULD_BROADCAST_LHS: + return OperatorAttributeValue{p.should_broadcast_lhs}; + case OperatorAttributeKey::SHOULD_BROADCAST_RHS: + return OperatorAttributeValue{p.should_broadcast_rhs}; default: return std::nullopt; } @@ -123,6 +131,8 @@ std::optional get_attribute(ElementUnaryAttrs const &p, switch (key) { case OperatorAttributeKey::OP_TYPE: return OperatorAttributeValue{get_op_type(p)}; + case OperatorAttributeKey::SCALAR: + return OperatorAttributeValue{p.scalar}; default: return std::nullopt; } @@ -227,10 +237,20 @@ std::optional switch (key) { case OperatorAttributeKey::OP_TYPE: return OperatorAttributeValue{get_op_type(p)}; + case OperatorAttributeKey::EMBED_DIM: + return OperatorAttributeValue{p.embed_dim}; + case OperatorAttributeKey::KDIM: + return OperatorAttributeValue{p.kdim}; + case OperatorAttributeKey::VDIM: + return OperatorAttributeValue{p.vdim}; case OperatorAttributeKey::NUM_HEADS: return OperatorAttributeValue{p.num_heads}; - case OperatorAttributeKey::USE_BIAS: + 
case OperatorAttributeKey::BIAS: return OperatorAttributeValue{p.bias}; + case OperatorAttributeKey::ADD_BIAS_KV: + return OperatorAttributeValue{p.add_bias_kv}; + case OperatorAttributeKey::ADD_ZERO_ATTN: + return OperatorAttributeValue{p.add_zero_attn}; case OperatorAttributeKey::DROPOUT: return OperatorAttributeValue{p.dropout}; default: diff --git a/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc b/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc index 4f11b343f8..9d312abefd 100644 --- a/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc +++ b/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc @@ -63,6 +63,22 @@ PCGOperatorAttrs materialize_operator_from_attrs_map( case OperatorType::INPUT: case OperatorType::WEIGHT: case OperatorType::CONV2D: + return PCGOperatorAttrs{Conv2DAttrs{ + /*out_channels=*/acc.get( + OperatorAttributeKey::OUT_CHANNELS), + /*kernel_h=*/acc.get(OperatorAttributeKey::KERNEL_H), + /*kernel_w=*/acc.get(OperatorAttributeKey::KERNEL_W), + /*stride_h=*/acc.get(OperatorAttributeKey::STRIDE_H), + /*stride_w=*/acc.get(OperatorAttributeKey::STRIDE_W), + /*padding_h=*/ + acc.get(OperatorAttributeKey::PADDING_H), + /*padding_w=*/ + acc.get(OperatorAttributeKey::PADDING_W), + /*groups=*/acc.get(OperatorAttributeKey::GROUPS), + /*activation=*/ + acc.get>(OperatorAttributeKey::ACTIVATION), + /*use_bias=*/acc.get(OperatorAttributeKey::USE_BIAS), + }}; case OperatorType::DROPOUT: case OperatorType::LINEAR: return PCGOperatorAttrs{LinearAttrs{ @@ -76,19 +92,56 @@ PCGOperatorAttrs materialize_operator_from_attrs_map( acc.get>( OperatorAttributeKey::REGULARIZER), }}; + case OperatorType::REPLICATE: + return PCGOperatorAttrs{ReplicateAttrs{ + /*replicate_degree=*/acc.get( + OperatorAttributeKey::PARALLEL_DEGREE), + }}; + case OperatorType::REPARTITION: + return PCGOperatorAttrs{RepartitionAttrs{ + 
/*repartition_dim=*/acc.get( + OperatorAttributeKey::PARALLEL_DIM), + /*repartition_degree=*/ + acc.get(OperatorAttributeKey::PARALLEL_DEGREE), + }}; + case OperatorType::COMBINE: + return PCGOperatorAttrs{CombineAttrs{ + /*combine_dim=*/acc.get(OperatorAttributeKey::PARALLEL_DIM), + /*combine_degree=*/ + acc.get(OperatorAttributeKey::PARALLEL_DEGREE), + }}; + + case OperatorType::EW_ADD: + return PCGOperatorAttrs{ElementBinaryAttrs{ + acc.get(OperatorAttributeKey::OP_TYPE), + acc.get(OperatorAttributeKey::DATA_TYPE), + acc.get(OperatorAttributeKey::SHOULD_BROADCAST_LHS), + acc.get(OperatorAttributeKey::SHOULD_BROADCAST_RHS), + }}; + case OperatorType::RELU: + return PCGOperatorAttrs{ElementUnaryAttrs{ + acc.get(OperatorAttributeKey::OP_TYPE), + acc.get>(OperatorAttributeKey::SCALAR), + }}; + case OperatorType::REDUCTION: + return PCGOperatorAttrs{ReductionAttrs{ + acc.get(OperatorAttributeKey::PARALLEL_DEGREE), + }}; + case OperatorType::SOFTMAX: + return PCGOperatorAttrs{SoftmaxAttrs{ + acc.get(OperatorAttributeKey::AXIS), + }}; case OperatorType::BATCHMATMUL: case OperatorType::SCALAR_MULTIPLY: case OperatorType::SCALAR_ADD: case OperatorType::SCALAR_FLOOR_DIV: case OperatorType::SCALAR_TRUE_DIV: case OperatorType::SCALAR_SUB: - case OperatorType::RELU: case OperatorType::IDENTITY: case OperatorType::SIGMOID: case OperatorType::TANH: case OperatorType::ELU: case OperatorType::FLAT: - case OperatorType::SOFTMAX: case OperatorType::BATCHNORM: case OperatorType::CONCAT: case OperatorType::SPLIT: @@ -97,7 +150,6 @@ PCGOperatorAttrs materialize_operator_from_attrs_map( case OperatorType::RESHAPE: case OperatorType::REVERSE: case OperatorType::TRANSPOSE: - case OperatorType::EW_ADD: case OperatorType::EW_MUL: case OperatorType::MATMUL: case OperatorType::MUL: @@ -144,10 +196,6 @@ PCGOperatorAttrs materialize_operator_from_attrs_map( case OperatorType::LAYERNORM: case OperatorType::GATHER: case OperatorType::BROADCAST: - case OperatorType::REPARTITION: - case 
OperatorType::COMBINE: - case OperatorType::REPLICATE: - case OperatorType::REDUCTION: case OperatorType::BATCH: case OperatorType::PIPELINE: case OperatorType::FUSED_PARALLEL: diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc index 83df74f21b..0c673f0a8a 100644 --- a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc @@ -188,34 +188,33 @@ bool sub_pcgs_are_isomorphic(SubParallelComputationGraph const &lhs, } std::string as_dot(SubParallelComputationGraph const &spcg) { - NOT_IMPLEMENTED(); - // std::function get_node_label = - // [](ParallelLayerAttrs const &a) -> std::string { - // RecordFormatter r = as_dot(a.op_attrs); - // - // if (a.name.has_value()) { - // RecordFormatter rr; - // rr << "Name" << a.name.value(); - // r << rr; - // } - // - // std::ostringstream oss; - // oss << r; - // return oss.str(); - // }; - // - // std::function get_input_label = - // [](ParallelTensorAttrs const &a) -> std::string { - // RecordFormatter r; - // - // r << fmt::to_string(a.shape); - // - // std::ostringstream oss; - // oss << r; - // return oss.str(); - // }; - // - // return as_dot(spcg.raw_graph, get_node_label, get_input_label); + std::function get_node_label = + [](ParallelLayerAttrs const &a) -> std::string { + RecordFormatter r = as_dot(a.op_attrs); + + if (a.name.has_value()) { + RecordFormatter rr; + rr << "Name" << a.name.value(); + r << rr; + } + + std::ostringstream oss; + oss << r; + return oss.str(); + }; + + std::function get_input_label = + [](ParallelTensorAttrs const &a) -> std::string { + RecordFormatter r; + + r << fmt::to_string(a.shape); + + std::ostringstream oss; + oss << r; + return oss.str(); + }; + + return as_dot(spcg.raw_graph, get_node_label, get_input_label); } void debug_print_dot(SubParallelComputationGraph const &spcg) { diff --git 
a/lib/substitutions/src/substitutions/unity_substitution_set.cc b/lib/substitutions/src/substitutions/unity_substitution_set.cc index 4b00cdd95f..1f4dbae1cf 100644 --- a/lib/substitutions/src/substitutions/unity_substitution_set.cc +++ b/lib/substitutions/src/substitutions/unity_substitution_set.cc @@ -13,14 +13,42 @@ namespace FlexFlow { std::vector get_substitution_set(MachineSpecification const &resources) { std::vector substitutions; - for (nonnegative_int num_dims : + for (nonnegative_int dim : nonnegative_range(1_n, nonnegative_int{MAX_TENSOR_DIM})) { for (nonnegative_int degree = 1_n; degree <= get_num_gpus(resources); degree *= 2_n) { substitutions.push_back( - create_replicate_linear_combine(num_dims, degree, true)); + create_replicate_linear_combine(dim, degree, true)); substitutions.push_back( - create_replicate_linear_combine(num_dims, degree, false)); + create_replicate_linear_combine(dim, degree, false)); + substitutions.push_back( + create_partition_linear_combine(dim, degree, true)); + substitutions.push_back( + create_partition_linear_combine(dim, degree, false)); + substitutions.push_back( + create_partition_relu_combine(ff_dim_t{dim}, degree)); + substitutions.push_back( + create_partition_add_combine(ff_dim_t{dim}, degree)); + substitutions.push_back(create_partition_attention_combine(dim, degree)); + substitutions.push_back(create_replicate_attention_reduce(dim, degree)); + } + } + for (nonnegative_int degree = 1_n; degree <= get_num_gpus(resources); + degree *= 2_n) { + substitutions.push_back(create_partition_conv2d_combine(4_n, degree)); + } + + for (nonnegative_int partition_dim : + nonnegative_range(1_n, nonnegative_int{MAX_TENSOR_DIM})) { + for (nonnegative_int softmax_dim : + nonnegative_range(1_n, nonnegative_int{MAX_TENSOR_DIM})) { + for (nonnegative_int degree = 1_n; degree <= get_num_gpus(resources); + degree *= 2_n) { + if (partition_dim != softmax_dim) { + substitutions.push_back(create_partition_softmax_combine( + 
ff_dim_t{partition_dim}, ff_dim_t{softmax_dim}, degree)); + } + } } } substitutions.push_back(create_fuse_linear_activation(Activation::RELU)); @@ -30,18 +58,6 @@ std::vector return substitutions; } -Substitution create_combine_inception(nonnegative_int num_convs, - nonnegative_int num_dims, - nonnegative_int degree) { - NOT_IMPLEMENTED(); -} - -Substitution create_combine_concat(nonnegative_int num_inputs, - nonnegative_int num_dims, - nonnegative_int degree) { - NOT_IMPLEMENTED(); -} - Substitution create_replicate_linear_combine(nonnegative_int num_dims, nonnegative_int degree, bool use_bias) { @@ -63,15 +79,14 @@ Substitution create_replicate_linear_combine(nonnegative_int num_dims, op_type_equals_constraint(OperatorType::LINEAR), op_attr_key_equals(OperatorAttributeKey::BIAS, OperatorAttributeValue{use_bias}), - op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, - nonnegative_int{degree}), + op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), }}; - PatternValue p_linear_output = get_only(b.add_pattern_node( - linear_pattern, - p_inputs, - {tensor_attr_pattern_require_num_dims(nonnegative_int{num_dims})}, - "linear")); + PatternValue p_linear_output = get_only( + b.add_pattern_node(linear_pattern, + p_inputs, + {tensor_attr_pattern_require_num_dims(num_dims)}, + "linear")); OutputOperatorAttrsAssignment replicate_input_expr = OutputOperatorAttrsAssignment{ @@ -146,47 +161,545 @@ Substitution create_replicate_linear_combine(nonnegative_int num_dims, Substitution create_partition_linear_combine(nonnegative_int num_dims, nonnegative_int degree, - Activation activation, bool use_bias) { - NOT_IMPLEMENTED(); + SubstitutionBuilder b; + + auto [p_input, o_input] = b.add_input(tensor_attribute_pattern_match_all()); + auto [p_weight, o_weight] = b.add_input(tensor_attribute_pattern_match_all()); + std::vector p_inputs = {p_input, p_weight}; + + std::optional o_bias = std::nullopt; + if (use_bias) { + std::pair bias = + 
b.add_input(tensor_attribute_pattern_match_all()); + p_inputs.push_back(bias.first); + o_bias = bias.second; + } + + OperatorAttributePattern linear_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::LINEAR), + op_attr_key_equals(OperatorAttributeKey::BIAS, + OperatorAttributeValue{use_bias}), + op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), + }}; + + PatternValue p_linear_output = get_only( + b.add_pattern_node(linear_pattern, + p_inputs, + {tensor_attr_pattern_require_num_dims(num_dims)}, + "linear")); + + OutputOperatorAttrsAssignment partition_input_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::REPARTITION), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, + OperatorAttributeValue{ff_dim_t{0_n}}), + }}; + OutputGraphExprValue o_partition_input_output = + get_only(b.add_output_graph_node(partition_input_expr, {o_input}, 1_n)); + + OutputOperatorAttrsAssignment replicate_weights_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::REPLICATE), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + }}; + OutputGraphExprValue o_replicate_weights_output = get_only( + b.add_output_graph_node(replicate_weights_expr, {o_weight}, 1_n)); + + std::vector o_linear_inputs = { + o_partition_input_output, o_replicate_weights_output}; + + if (use_bias) { + OutputOperatorAttrsAssignment replicate_bias_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::REPLICATE), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + }}; + OutputGraphExprValue o_replicate_bias_output = get_only( + b.add_output_graph_node(replicate_bias_expr, {o_bias.value()}, 1_n)); + o_linear_inputs.push_back(o_replicate_bias_output); + } + + 
OutputOperatorAttrsAssignment linear_expr = OutputOperatorAttrsAssignment{ + b.pattern_node_named("linear"), + {}, + }; + OutputGraphExprValue o_linear_output = + get_only(b.add_output_graph_node(linear_expr, o_linear_inputs, 1_n)); + + OutputOperatorAttrsAssignment combine_expr = OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::COMBINE), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant( + OperatorAttributeKey::PARALLEL_DIM, + OperatorAttributeValue{ff_dim_t{ + nonnegative_int{num_dims.unwrap_nonnegative() - 1}, + }}), + }, + }; + OutputGraphExprValue o_combine_output = + get_only(b.add_output_graph_node(combine_expr, {o_linear_output}, 1_n)); + + b.equate_outputs(p_linear_output, o_combine_output); + + return b.get_substitution(); } Substitution create_partition_conv2d_combine(nonnegative_int num_dims, nonnegative_int degree) { - NOT_IMPLEMENTED(); + if (num_dims != 4) { + throw mk_runtime_error(fmt::format("num_dims must be 4, not {}", num_dims)); + } + + SubstitutionBuilder b; + + auto [p_input, o_input] = b.add_input(tensor_attribute_pattern_match_all()); + auto [p_weight, o_weight] = b.add_input(tensor_attribute_pattern_match_all()); + std::vector p_inputs = {p_input, p_weight}; + + OperatorAttributePattern conv2d_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::CONV2D), + op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), + }}; + + PatternValue p_conv2d_output = get_only( + b.add_pattern_node(conv2d_pattern, + p_inputs, + {tensor_attr_pattern_require_num_dims(num_dims)}, + "conv2d")); + + OutputOperatorAttrsAssignment partition_input_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::REPARTITION), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, + 
OperatorAttributeValue{ff_dim_t{0_n}}), + }}; + + OutputGraphExprValue o_partition_input_output = + get_only(b.add_output_graph_node(partition_input_expr, {o_input}, 1_n)); + + OutputOperatorAttrsAssignment replicate_weights_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::REPLICATE), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + }}; + OutputGraphExprValue o_replicate_weights_output = get_only( + b.add_output_graph_node(replicate_weights_expr, {o_weight}, 1_n)); + + std::vector o_conv2d_inputs = { + o_partition_input_output, o_replicate_weights_output}; + + OutputOperatorAttrsAssignment conv2d_expr = OutputOperatorAttrsAssignment{ + b.pattern_node_named("conv2d"), + {}, + }; + OutputGraphExprValue o_conv2d_output = + get_only(b.add_output_graph_node(conv2d_expr, o_conv2d_inputs, 1_n)); + + OutputOperatorAttrsAssignment combine_expr = OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::COMBINE), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant( + OperatorAttributeKey::PARALLEL_DIM, + OperatorAttributeValue{ff_dim_t{ + nonnegative_int{num_dims.unwrap_nonnegative() - 1}, + }}), + }, + }; + OutputGraphExprValue o_combine_output = + get_only(b.add_output_graph_node(combine_expr, {o_conv2d_output}, 1_n)); + + b.equate_outputs(p_conv2d_output, o_combine_output); + + return b.get_substitution(); } Substitution create_partition_attention_combine(nonnegative_int num_heads, nonnegative_int degree) { - NOT_IMPLEMENTED(); + + SubstitutionBuilder b; + + auto [p_query_input, o_query_input] = + b.add_input(tensor_attribute_pattern_match_all()); + auto [p_key_input, o_key_input] = + b.add_input(tensor_attribute_pattern_match_all()); + auto [p_value_input, o_value_input] = + b.add_input(tensor_attribute_pattern_match_all()); + auto [p_weights, o_weights] = + 
b.add_input(tensor_attribute_pattern_match_all()); + std::vector p_inputs = { + p_query_input, p_key_input, p_value_input, p_weights}; + + OperatorAttributePattern attention_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::MULTIHEAD_ATTENTION), + op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), + op_attr_key_divisible_by(OperatorAttributeKey::NUM_HEADS, num_heads), + }}; + + PatternValue p_attention_output = + get_only(b.add_pattern_node(attention_pattern, + p_inputs, + {tensor_attr_pattern_require_num_dims(3_n)}, + "attention")); + + OutputOperatorAttrsAssignment partition_input_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::REPARTITION), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, + OperatorAttributeValue{ff_dim_t{0_n}}), + }}; + + OutputGraphExprValue o_partition_query_input_output = get_only( + b.add_output_graph_node(partition_input_expr, {o_query_input}, 1_n)); + + OutputGraphExprValue o_partition_key_input_output = get_only( + b.add_output_graph_node(partition_input_expr, {o_key_input}, 1_n)); + + OutputGraphExprValue o_partition_value_input_output = get_only( + b.add_output_graph_node(partition_input_expr, {o_value_input}, 1_n)); + + OutputOperatorAttrsAssignment replicate_weight_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::REPLICATE), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + }}; + + OutputGraphExprValue o_replicate_weight_output = get_only( + b.add_output_graph_node(replicate_weight_expr, {o_weights}, 1_n)); + + std::vector o_attention_inputs = { + o_partition_query_input_output, + o_partition_key_input_output, + o_partition_value_input_output, + o_replicate_weight_output}; + + OutputOperatorAttrsAssignment attention_expr = 
OutputOperatorAttrsAssignment{ + b.pattern_node_named("attention"), + {}, + }; + OutputGraphExprValue o_attention_output = get_only( + b.add_output_graph_node(attention_expr, o_attention_inputs, 1_n)); + + OutputOperatorAttrsAssignment combine_expr = OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::COMBINE), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, + OperatorAttributeValue{ff_dim_t{ + 2_n, + }}), + }, + }; + OutputGraphExprValue o_combine_output = get_only( + b.add_output_graph_node(combine_expr, {o_attention_output}, 1_n)); + + b.equate_outputs(p_attention_output, o_combine_output); + + return b.get_substitution(); } Substitution create_replicate_attention_reduce(nonnegative_int num_heads, nonnegative_int degree) { - NOT_IMPLEMENTED(); + + SubstitutionBuilder b; + + auto [p_query_input, o_query_input] = + b.add_input(tensor_attribute_pattern_match_all()); + auto [p_key_input, o_key_input] = + b.add_input(tensor_attribute_pattern_match_all()); + auto [p_value_input, o_value_input] = + b.add_input(tensor_attribute_pattern_match_all()); + auto [p_weights, o_weights] = + b.add_input(tensor_attribute_pattern_match_all()); + std::vector p_inputs = { + p_query_input, p_key_input, p_value_input, p_weights}; + + OperatorAttributePattern attention_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::MULTIHEAD_ATTENTION), + op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), + op_attr_key_divisible_by(OperatorAttributeKey::NUM_HEADS, num_heads), + }}; + + PatternValue p_attention_output = + get_only(b.add_pattern_node(attention_pattern, + p_inputs, + {tensor_attr_pattern_require_num_dims(3_n)}, + "attention")); + + OutputOperatorAttrsAssignment replicate_input_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::REPLICATE), + 
set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + }}; + + OutputGraphExprValue o_replicate_query_input_output = get_only( + b.add_output_graph_node(replicate_input_expr, {o_query_input}, 1_n)); + + OutputGraphExprValue o_replicate_key_input_output = get_only( + b.add_output_graph_node(replicate_input_expr, {o_key_input}, 1_n)); + + OutputGraphExprValue o_replicate_value_input_output = get_only( + b.add_output_graph_node(replicate_input_expr, {o_value_input}, 1_n)); + + OutputOperatorAttrsAssignment partition_weight_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::REPARTITION), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, + OperatorAttributeValue{ff_dim_t{1_n}}), + }}; + + OutputGraphExprValue o_partition_weight_output = get_only( + b.add_output_graph_node(partition_weight_expr, {o_weights}, 1_n)); + + std::vector o_attention_inputs = { + o_replicate_query_input_output, + o_replicate_key_input_output, + o_replicate_value_input_output, + o_partition_weight_output}; + + OutputOperatorAttrsAssignment attention_expr = OutputOperatorAttrsAssignment{ + b.pattern_node_named("attention"), + {}, + }; + OutputGraphExprValue o_attention_output = get_only( + b.add_output_graph_node(attention_expr, o_attention_inputs, 1_n)); + + OutputOperatorAttrsAssignment reduce_expr = OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::REDUCTION), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + }, + }; + OutputGraphExprValue o_reduce_output = + get_only(b.add_output_graph_node(reduce_expr, {o_attention_output}, 1_n)); + + b.equate_outputs(p_attention_output, o_reduce_output); + + return b.get_substitution(); +} + +Substitution create_partition_softmax_combine(ff_dim_t softmax_dim, + ff_dim_t partition_dim, + 
nonnegative_int degree) { + if (partition_dim == softmax_dim) { + throw mk_runtime_error( + fmt::format("partition dim {} must not be equal to softmax dim {}", + partition_dim, + softmax_dim)); + } + SubstitutionBuilder b; + + auto [p_input, o_input] = b.add_input(tensor_attribute_pattern_match_all()); + std::vector p_inputs = {p_input}; + + OperatorAttributePattern softmax_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::SOFTMAX), + op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), + op_attr_key_divisible_by(OperatorAttributeKey::SOFTMAX_DIM, + softmax_dim.value), + }}; + + PatternValue p_softmax_output = + get_only(b.add_pattern_node(softmax_pattern, + p_inputs, + {tensor_attribute_pattern_match_all()}, + "softmax")); + + OutputOperatorAttrsAssignment partition_input_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::REPARTITION), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, + OperatorAttributeValue{partition_dim}), + }}; + + OutputGraphExprValue o_partition_input_output = + get_only(b.add_output_graph_node(partition_input_expr, {o_input}, 1_n)); + + std::vector o_softmax_inputs = { + o_partition_input_output}; + + OutputOperatorAttrsAssignment softmax_expr = OutputOperatorAttrsAssignment{ + b.pattern_node_named("softmax"), + {}, + }; + OutputGraphExprValue o_softmax_output = + get_only(b.add_output_graph_node(softmax_expr, o_softmax_inputs, 1_n)); + + OutputOperatorAttrsAssignment combine_expr = OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::COMBINE), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, + OperatorAttributeValue{partition_dim}), + }, + }; + OutputGraphExprValue o_combine_output = + 
get_only(b.add_output_graph_node(combine_expr, {o_softmax_output}, 1_n)); + + b.equate_outputs(p_softmax_output, o_combine_output); + + return b.get_substitution(); } Substitution create_partition_add_combine(ff_dim_t parallel_dim, nonnegative_int degree) { - NOT_IMPLEMENTED(); + SubstitutionBuilder b; + + auto [p_input1, o_input1] = b.add_input(tensor_attribute_pattern_match_all()); + auto [p_input2, o_input2] = b.add_input(tensor_attribute_pattern_match_all()); + std::vector p_inputs = {p_input1, p_input2}; + + OperatorAttributePattern add_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::EW_ADD), + op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), + }}; + + PatternValue p_add_output = get_only(b.add_pattern_node( + add_pattern, p_inputs, {tensor_attribute_pattern_match_all()}, "add")); + + OutputOperatorAttrsAssignment partition_input_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::REPARTITION), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, + OperatorAttributeValue{parallel_dim}), + }}; + + OutputGraphExprValue o_partition_input1_output = + get_only(b.add_output_graph_node(partition_input_expr, {o_input1}, 1_n)); + + OutputGraphExprValue o_partition_input2_output = + get_only(b.add_output_graph_node(partition_input_expr, {o_input2}, 1_n)); + + std::vector o_add_inputs = {o_partition_input1_output, + o_partition_input2_output}; + + OutputOperatorAttrsAssignment add_expr = OutputOperatorAttrsAssignment{ + b.pattern_node_named("add"), + {}, + }; + OutputGraphExprValue o_add_output = + get_only(b.add_output_graph_node(add_expr, o_add_inputs, 1_n)); + + OutputOperatorAttrsAssignment combine_expr = OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::COMBINE), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + 
OperatorAttributeValue{degree}), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, + OperatorAttributeValue{parallel_dim}), + }, + }; + OutputGraphExprValue o_combine_output = + get_only(b.add_output_graph_node(combine_expr, {o_add_output}, 1_n)); + + b.equate_outputs(p_add_output, o_combine_output); + + return b.get_substitution(); } Substitution create_partition_relu_combine(ff_dim_t parallel_dim, nonnegative_int degree) { - NOT_IMPLEMENTED(); -} + SubstitutionBuilder b; -Substitution create_partition_concat_combine(nonnegative_int num_inputs, - ff_dim_t concat_dim, - ff_dim_t parallel_dim, - nonnegative_int degree) { - NOT_IMPLEMENTED(); -} + auto [p_input, o_input] = b.add_input(tensor_attribute_pattern_match_all()); -Substitution create_partition_softmax_combine(ff_dim_t softmax_dim, - ff_dim_t partition_dim, - nonnegative_int degree) { - NOT_IMPLEMENTED(); + OperatorAttributePattern relu_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::RELU), + op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), + }}; + + PatternValue p_relu_output = get_only(b.add_pattern_node( + relu_pattern, {p_input}, {tensor_attribute_pattern_match_all()}, "relu")); + + OutputOperatorAttrsAssignment partition_input_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::REPARTITION), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, + OperatorAttributeValue{parallel_dim}), + }}; + + OutputGraphExprValue o_partition_input_output = + get_only(b.add_output_graph_node(partition_input_expr, {o_input}, 1_n)); + + OutputOperatorAttrsAssignment relu_expr = OutputOperatorAttrsAssignment{ + b.pattern_node_named("relu"), + {}, + }; + OutputGraphExprValue o_relu_output = get_only( + b.add_output_graph_node(relu_expr, {o_partition_input_output}, 1_n)); + + OutputOperatorAttrsAssignment combine_expr = 
OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::COMBINE), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, + OperatorAttributeValue{parallel_dim}), + }, + }; + OutputGraphExprValue o_combine_output = + get_only(b.add_output_graph_node(combine_expr, {o_relu_output}, 1_n)); + + b.equate_outputs(p_relu_output, o_combine_output); + + return b.get_substitution(); } Substitution create_fuse_linear_activation(Activation activation) { diff --git a/lib/substitutions/test/src/substitutions/unity_substitution_set.cc b/lib/substitutions/test/src/substitutions/unity_substitution_set.cc index 804fa99bef..21c628ff0b 100644 --- a/lib/substitutions/test/src/substitutions/unity_substitution_set.cc +++ b/lib/substitutions/test/src/substitutions/unity_substitution_set.cc @@ -1,8 +1,37 @@ #include "substitutions/unity_substitution_set.h" +#include "op-attrs/computation_graph_op_attrs.h" +#include "op-attrs/operator_type.h" +#include "op-attrs/ops/attention.h" +#include "op-attrs/ops/combine.h" +#include "op-attrs/ops/conv_2d.h" +#include "op-attrs/ops/element_binary.h" +#include "op-attrs/ops/element_unary.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/repartition.h" +#include "op-attrs/ops/replicate.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "substitutions/apply_substitution/apply_substitution.h" +#include "substitutions/open_parallel_tensor_guid_t.h" +#include "substitutions/pcg_pattern.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "substitutions/substitution_builder.h" +#include "utils/containers/get_only.h" #include using namespace ::FlexFlow; +template +static ParallelLayerAttrs make_layer_attrs( + T const &op_attrs, + std::optional 
const &maybe_name = std::nullopt) { + return ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{op_attrs}, + /*name=*/maybe_name, + }; +}; + TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_substitution_set") { MachineSpecification machine_spec = MachineSpecification{ @@ -15,6 +44,1507 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector result = get_substitution_set(machine_spec); - CHECK(result.size() == 36); + CHECK(result.size() == 184); + } + + TEST_CASE("create_replicate_linear_combine, use_bias = false") { + nonnegative_int num_dims = 1_n; + nonnegative_int degree = 1_n; + std::string linear_match = "linear_match"; + + Substitution sub = create_replicate_linear_combine(num_dims, degree, false); + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 10_n, + 12_n, + }, + }, + DataType::FLOAT, + }; + + LinearAttrs linear_attrs = LinearAttrs{ + /*out_channels=*/12_n, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*activation=*/std::nullopt, + /*regularizer=*/std::nullopt, + }; + + ReplicateAttrs replicate_input_attrs = ReplicateAttrs{ + /*replicate_degree=*/degree, + }; + + WeightAttrs projection_weight_attrs = WeightAttrs{ + /*tensor_shape=*/throw_if_unexpected( + get_projection_shape(linear_attrs, input_shape)), + /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, + }; + + RepartitionAttrs partition_projection_attrs = RepartitionAttrs{ + /*repartition_dim=*/ff_dim_t{1_n}, + /*repartition_degree=*/degree, + }; + + CombineAttrs combine_op_attrs = CombineAttrs{ + /*combine_dim=*/ff_dim_t{ + nonnegative_int{num_dims.unwrap_nonnegative() - 1}}, + /*combine_degree=*/degree, + }; + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + ParallelLayerAddedResult projection_weight_added = add_parallel_layer( + pcg, 
make_layer_attrs(projection_weight_attrs), {}, {}); + parallel_tensor_guid_t t_projection_weight = + get_only(projection_weight_added.outputs); + + ParallelLayerAddedResult linear_added = + add_parallel_layer(pcg, + make_layer_attrs(linear_attrs, linear_match), + {t_input}, + {t_projection_weight}); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, linear_match); + open_parallel_tensor_guid_t match_layer_input_activations = + get_layer_inputs(original_pcg, match_layer).at(0); + open_parallel_tensor_guid_t match_layer_input_weights = + get_layer_inputs(original_pcg, match_layer).at(1); + + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{ + { + PatternInput{DataflowGraphInput{0}}, + match_layer_input_activations, + }, + { + PatternInput{DataflowGraphInput{2}}, + match_layer_input_weights, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + ParallelLayerAddedResult replicate_input_added = add_parallel_layer( + pcg, make_layer_attrs(replicate_input_attrs), {t_input}, {}); + parallel_tensor_guid_t t_replicated_input = + get_only(replicate_input_added.outputs); + + ParallelLayerAddedResult projection_weight_added = add_parallel_layer( + pcg, make_layer_attrs(projection_weight_attrs), {}, {}); + parallel_tensor_guid_t t_projection_weight = + get_only(projection_weight_added.outputs); + + ParallelLayerAddedResult partition_projection_added = + add_parallel_layer(pcg, + make_layer_attrs(partition_projection_attrs), + {t_projection_weight}, + {}); + parallel_tensor_guid_t 
t_partitioned_projection_weight = + get_only(partition_projection_added.outputs); + + ParallelLayerAddedResult replicate_linear_added = + add_parallel_layer(pcg, + make_layer_attrs(linear_attrs), + {t_replicated_input}, + {t_partitioned_projection_weight}); + parallel_tensor_guid_t t_replicated_linear = + get_only(replicate_linear_added.outputs); + + ParallelLayerAddedResult combine_added = add_parallel_layer( + pcg, make_layer_attrs(combine_op_attrs), {t_replicated_linear}, {}); + parallel_tensor_guid_t t_combine = get_only(combine_added.outputs); + + return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + + TEST_CASE("create_replicate_linear_combine, use_bias = true") { + nonnegative_int num_dims = 1_n; + nonnegative_int degree = 1_n; + std::string linear_match = "linear_match"; + + Substitution sub = create_replicate_linear_combine(num_dims, degree, true); + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 10_n, + 12_n, + }, + }, + DataType::FLOAT, + }; + + LinearAttrs linear_attrs = LinearAttrs{ + /*out_channels=*/12_n, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*activation=*/std::nullopt, + /*regularizer=*/std::nullopt, + }; + + ReplicateAttrs replicate_input_attrs = ReplicateAttrs{ + /*replicate_degree=*/degree, + }; + + WeightAttrs projection_weight_attrs = WeightAttrs{ + /*tensor_shape=*/throw_if_unexpected( + get_projection_shape(linear_attrs, input_shape)), + /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, + }; + + WeightAttrs bias_attrs = WeightAttrs{ + /*tensor_shape=*/throw_if_unexpected( + get_bias_shape(linear_attrs, input_shape)), + /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, + }; + + RepartitionAttrs partition_projection_attrs = RepartitionAttrs{ + /*repartition_dim=*/ff_dim_t{1_n}, + /*repartition_degree=*/degree, + }; + + CombineAttrs combine_op_attrs = CombineAttrs{ + /*combine_dim=*/ff_dim_t{ + nonnegative_int{num_dims.unwrap_nonnegative() - 
1}}, + /*combine_degree=*/degree, + }; + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + ParallelLayerAddedResult projection_weight_added = add_parallel_layer( + pcg, make_layer_attrs(projection_weight_attrs), {}, {}); + parallel_tensor_guid_t t_projection_weight = + get_only(projection_weight_added.outputs); + + ParallelLayerAddedResult bias_added = + add_parallel_layer(pcg, make_layer_attrs(bias_attrs), {}, {}); + parallel_tensor_guid_t t_bias = get_only(bias_added.outputs); + + ParallelLayerAddedResult linear_added = + add_parallel_layer(pcg, + make_layer_attrs(linear_attrs, linear_match), + {t_input}, + {t_projection_weight, t_bias}); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, linear_match); + open_parallel_tensor_guid_t match_layer_input_activations = + get_layer_inputs(original_pcg, match_layer).at(0); + open_parallel_tensor_guid_t match_layer_input_weights = + get_layer_inputs(original_pcg, match_layer).at(1); + open_parallel_tensor_guid_t match_layer_input_bias = + get_layer_inputs(original_pcg, match_layer).at(2); + + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{ + { + PatternInput{DataflowGraphInput{0}}, + match_layer_input_activations, + }, + { + PatternInput{DataflowGraphInput{2}}, + match_layer_input_weights, + }, + { + PatternInput{DataflowGraphInput{4}}, + match_layer_input_bias, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + 
pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + ParallelLayerAddedResult replicate_input_added = add_parallel_layer( + pcg, make_layer_attrs(replicate_input_attrs), {t_input}, {}); + parallel_tensor_guid_t t_replicated_input = + get_only(replicate_input_added.outputs); + + ParallelLayerAddedResult projection_weight_added = add_parallel_layer( + pcg, make_layer_attrs(projection_weight_attrs), {}, {}); + parallel_tensor_guid_t t_projection_weight = + get_only(projection_weight_added.outputs); + + ParallelLayerAddedResult partition_projection_added = + add_parallel_layer(pcg, + make_layer_attrs(partition_projection_attrs), + {t_projection_weight}, + {}); + parallel_tensor_guid_t t_partitioned_projection_weight = + get_only(partition_projection_added.outputs); + + ParallelLayerAddedResult bias_added = + add_parallel_layer(pcg, make_layer_attrs(bias_attrs), {}, {}); + parallel_tensor_guid_t t_bias = get_only(bias_added.outputs); + + ParallelLayerAddedResult partition_bias_added = add_parallel_layer( + pcg, make_layer_attrs(partition_projection_attrs), {t_bias}, {}); + parallel_tensor_guid_t t_partitioned_bias = + get_only(partition_bias_added.outputs); + + ParallelLayerAddedResult replicate_linear_added = add_parallel_layer( + pcg, + make_layer_attrs(linear_attrs), + {t_replicated_input}, + {t_partitioned_projection_weight, t_partitioned_bias}); + parallel_tensor_guid_t t_replicated_linear = + get_only(replicate_linear_added.outputs); + + ParallelLayerAddedResult combine_added = add_parallel_layer( + pcg, make_layer_attrs(combine_op_attrs), {t_replicated_linear}, {}); + parallel_tensor_guid_t t_combine = get_only(combine_added.outputs); + + return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + + TEST_CASE("create_partition_linear_combine, use_bias = false") { + nonnegative_int num_dims = 1_n; + nonnegative_int degree = 2_n; + std::string linear_match = 
"linear_match"; + + Substitution sub = create_partition_linear_combine(num_dims, degree, false); + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 10_n, + 12_n, + }, + }, + DataType::FLOAT, + }; + + LinearAttrs linear_attrs = LinearAttrs{ + /*out_channels=*/12_n, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*activation=*/std::nullopt, + /*regularizer=*/std::nullopt, + }; + + RepartitionAttrs partition_input_attrs = RepartitionAttrs{ + /*repartition_dim=*/ff_dim_t{0_n}, + /*repartition_degree=*/degree, + }; + + WeightAttrs projection_weight_attrs = WeightAttrs{ + /*tensor_shape=*/throw_if_unexpected( + get_projection_shape(linear_attrs, input_shape)), + /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, + }; + + ReplicateAttrs replicate_projection_attrs = ReplicateAttrs{ + /*replicate_degree=*/degree, + }; + + CombineAttrs combine_op_attrs = CombineAttrs{ + /*combine_dim=*/ff_dim_t{ + nonnegative_int{num_dims.unwrap_nonnegative() - 1}}, + /*combine_degree=*/degree, + }; + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + ParallelLayerAddedResult projection_weight_added = add_parallel_layer( + pcg, make_layer_attrs(projection_weight_attrs), {}, {}); + parallel_tensor_guid_t t_projection_weight = + get_only(projection_weight_added.outputs); + + ParallelLayerAddedResult linear_added = + add_parallel_layer(pcg, + make_layer_attrs(linear_attrs, linear_match), + {t_input}, + {t_projection_weight}); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, linear_match); + open_parallel_tensor_guid_t match_layer_input_activations = + get_layer_inputs(original_pcg, match_layer).at(0); + 
open_parallel_tensor_guid_t match_layer_input_weights = + get_layer_inputs(original_pcg, match_layer).at(1); + + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{ + { + PatternInput{DataflowGraphInput{0}}, + match_layer_input_activations, + }, + { + PatternInput{DataflowGraphInput{2}}, + match_layer_input_weights, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + ParallelLayerAddedResult partition_input_added = add_parallel_layer( + pcg, make_layer_attrs(partition_input_attrs), {t_input}, {}); + parallel_tensor_guid_t t_partitioned_input = + get_only(partition_input_added.outputs); + + ParallelLayerAddedResult projection_weight_added = add_parallel_layer( + pcg, make_layer_attrs(projection_weight_attrs), {}, {}); + parallel_tensor_guid_t t_projection_weight = + get_only(projection_weight_added.outputs); + + ParallelLayerAddedResult replicate_projection_added = + add_parallel_layer(pcg, + make_layer_attrs(replicate_projection_attrs), + {t_projection_weight}, + {}); + parallel_tensor_guid_t t_replicated_projection_weight = + get_only(replicate_projection_added.outputs); + + ParallelLayerAddedResult partition_linear_added = + add_parallel_layer(pcg, + make_layer_attrs(linear_attrs), + {t_partitioned_input}, + {t_replicated_projection_weight}); + parallel_tensor_guid_t t_partitioned_linear = + get_only(partition_linear_added.outputs); + + ParallelLayerAddedResult combine_added = add_parallel_layer( + pcg, make_layer_attrs(combine_op_attrs), {t_partitioned_linear}, {}); + parallel_tensor_guid_t t_combine = get_only(combine_added.outputs); + + return sub_pcg_from_full_pcg(pcg); + }(); + + 
CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + + TEST_CASE("create_partition_linear_combine, use_bias = true") { + nonnegative_int num_dims = 1_n; + nonnegative_int degree = 2_n; + std::string linear_match = "linear_match"; + + Substitution sub = create_partition_linear_combine(num_dims, degree, true); + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 10_n, + 12_n, + }, + }, + DataType::FLOAT, + }; + + LinearAttrs linear_attrs = LinearAttrs{ + /*out_channels=*/12_n, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*activation=*/std::nullopt, + /*regularizer=*/std::nullopt, + }; + + RepartitionAttrs partition_input_attrs = RepartitionAttrs{ + /*repartition_dim=*/ff_dim_t{0_n}, + /*repartition_degree=*/degree, + }; + + WeightAttrs projection_weight_attrs = WeightAttrs{ + /*tensor_shape=*/throw_if_unexpected( + get_projection_shape(linear_attrs, input_shape)), + /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, + }; + + WeightAttrs bias_attrs = WeightAttrs{ + /*tensor_shape=*/throw_if_unexpected( + get_bias_shape(linear_attrs, input_shape)), + /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, + }; + + ReplicateAttrs replicate_projection_attrs = ReplicateAttrs{ + /*replicate_degree=*/degree, + }; + + CombineAttrs combine_op_attrs = CombineAttrs{ + /*combine_dim=*/ff_dim_t{ + nonnegative_int{num_dims.unwrap_nonnegative() - 1}}, + /*combine_degree=*/degree, + }; + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + ParallelLayerAddedResult projection_weight_added = add_parallel_layer( + pcg, make_layer_attrs(projection_weight_attrs), {}, {}); + parallel_tensor_guid_t t_projection_weight = + get_only(projection_weight_added.outputs); + + ParallelLayerAddedResult bias_added = + 
add_parallel_layer(pcg, make_layer_attrs(bias_attrs), {}, {}); + parallel_tensor_guid_t t_bias = get_only(bias_added.outputs); + + ParallelLayerAddedResult linear_added = + add_parallel_layer(pcg, + make_layer_attrs(linear_attrs, linear_match), + {t_input}, + {t_projection_weight, t_bias}); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, linear_match); + std::cout << get_layer_inputs(original_pcg, match_layer) << std::endl; + open_parallel_tensor_guid_t match_layer_input_activations = + get_layer_inputs(original_pcg, match_layer).at(0); + open_parallel_tensor_guid_t match_layer_input_weights = + get_layer_inputs(original_pcg, match_layer).at(1); + open_parallel_tensor_guid_t match_layer_input_bias = + get_layer_inputs(original_pcg, match_layer).at(2); + + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{ + { + PatternInput{DataflowGraphInput{0}}, + match_layer_input_activations, + }, + { + PatternInput{DataflowGraphInput{2}}, + match_layer_input_weights, + }, + { + PatternInput{DataflowGraphInput{4}}, + match_layer_input_bias, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + ParallelLayerAddedResult partition_input_added = add_parallel_layer( + pcg, make_layer_attrs(partition_input_attrs), {t_input}, {}); + parallel_tensor_guid_t t_partitioned_input = + get_only(partition_input_added.outputs); + + ParallelLayerAddedResult projection_weight_added = add_parallel_layer( + pcg, make_layer_attrs(projection_weight_attrs), {}, {}); + parallel_tensor_guid_t t_projection_weight = + 
get_only(projection_weight_added.outputs); + + ParallelLayerAddedResult replicate_projection_added = + add_parallel_layer(pcg, + make_layer_attrs(replicate_projection_attrs), + {t_projection_weight}, + {}); + parallel_tensor_guid_t t_replicated_projection_weight = + get_only(replicate_projection_added.outputs); + + ParallelLayerAddedResult bias_added = + add_parallel_layer(pcg, make_layer_attrs(bias_attrs), {}, {}); + parallel_tensor_guid_t t_bias = get_only(bias_added.outputs); + + ParallelLayerAddedResult replicate_bias_added = add_parallel_layer( + pcg, make_layer_attrs(replicate_projection_attrs), {t_bias}, {}); + parallel_tensor_guid_t t_replicated_bias = + get_only(replicate_bias_added.outputs); + + ParallelLayerAddedResult partition_linear_added = add_parallel_layer( + pcg, + make_layer_attrs(linear_attrs), + {t_partitioned_input}, + {t_replicated_projection_weight, t_replicated_bias}); + parallel_tensor_guid_t t_partitioned_linear = + get_only(partition_linear_added.outputs); + + ParallelLayerAddedResult combine_added = add_parallel_layer( + pcg, make_layer_attrs(combine_op_attrs), {t_partitioned_linear}, {}); + parallel_tensor_guid_t t_combine = get_only(combine_added.outputs); + + return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + + TEST_CASE("create_partition_conv2d_combine") { + nonnegative_int outChannels = 6_n; + nonnegative_int kernelH = 5_n; + nonnegative_int kernelW = 4_n; + nonnegative_int strideH = 3_n; + nonnegative_int strideW = 2_n; + nonnegative_int paddingH = 1_n; + nonnegative_int paddingW = 0_n; + nonnegative_int num_dims = 4_n; + nonnegative_int degree = 1_n; + std::string conv2d_match = "conv2d_match"; + + Substitution sub = create_partition_conv2d_combine(num_dims, degree); + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 12_n, + 3_n, + 10_n, + 10_n, + }, + }, + DataType::FLOAT, + }; + + Conv2DAttrs conv2d_attrs = Conv2DAttrs{/*outChannels=*/outChannels, + 
/*kernelH=*/kernelH, + /*kernelW=*/kernelW, + /*strideH=*/strideH, + /*strideW=*/strideW, + /*paddingH=*/paddingH, + /*paddingW=*/paddingW, + /*groups=*/1_n, + /*activation=*/std::nullopt, + /*use_bias=*/false}; + + RepartitionAttrs partition_input_attrs = RepartitionAttrs{ + /*repartition_dim=*/ff_dim_t{0_n}, + /*repartition_degree=*/degree, + }; + + ReplicateAttrs replicate_weight_attrs = ReplicateAttrs{ + /*replicate_degree=*/degree, + }; + + CombineAttrs combine_attrs = CombineAttrs{ + /*combine_dim=*/ff_dim_t{ + nonnegative_int{num_dims.unwrap_nonnegative() - 1}}, + /*combine_degree=*/degree, + }; + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + TensorShape casted_input_shape = + get_reduced_shape(get_parallel_tensor_shape(pcg, t_input)); + + WeightAttrs projection_weight_attrs = WeightAttrs{ + /*tensor_shape=*/ + get_weight_shapes(conv2d_attrs, casted_input_shape).at(0), + /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, + }; + + ParallelLayerAddedResult projection_weight_added = add_parallel_layer( + pcg, make_layer_attrs(projection_weight_attrs), {}, {}); + parallel_tensor_guid_t t_projection_weight = + get_only(projection_weight_added.outputs); + + ParallelLayerAddedResult conv_2d_added = + add_parallel_layer(pcg, + make_layer_attrs(conv2d_attrs, conv2d_match), + {t_input}, + {t_projection_weight}); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, conv2d_match); + open_parallel_tensor_guid_t match_layer_input_activations = + get_layer_inputs(original_pcg, match_layer).at(0); + open_parallel_tensor_guid_t match_layer_input_weights = + get_layer_inputs(original_pcg, match_layer).at(1); + + return PCGPatternMatch{ + 
bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{ + { + PatternInput{DataflowGraphInput{0}}, + match_layer_input_activations, + }, + { + PatternInput{DataflowGraphInput{2}}, + match_layer_input_weights, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + ParallelLayerAddedResult partition_input_added = add_parallel_layer( + pcg, make_layer_attrs(partition_input_attrs), {t_input}, {}); + parallel_tensor_guid_t t_partitioned_input = + get_only(partition_input_added.outputs); + + TensorShape casted_input_shape = + get_reduced_shape(get_parallel_tensor_shape(pcg, t_input)); + + WeightAttrs weight_attrs = WeightAttrs{ + /*tensor_shape=*/ + get_weight_shapes(conv2d_attrs, casted_input_shape).at(0), + /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, + }; + + ParallelLayerAddedResult weight_added = + add_parallel_layer(pcg, make_layer_attrs(weight_attrs), {}, {}); + parallel_tensor_guid_t t_weight = get_only(weight_added.outputs); + + ParallelLayerAddedResult replicate_weight_added = add_parallel_layer( + pcg, make_layer_attrs(replicate_weight_attrs), {t_weight}, {}); + parallel_tensor_guid_t t_replicated_weight = + get_only(replicate_weight_added.outputs); + + ParallelLayerAddedResult partition_conv2d_added = + add_parallel_layer(pcg, + make_layer_attrs(conv2d_attrs), + {t_partitioned_input}, + {t_replicated_weight}); + parallel_tensor_guid_t t_partitioned_conv2d = + get_only(partition_conv2d_added.outputs); + + ParallelLayerAddedResult combine_added = add_parallel_layer( + pcg, make_layer_attrs(combine_attrs), {t_partitioned_conv2d}, {}); + parallel_tensor_guid_t t_combine = get_only(combine_added.outputs); + + 
return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + + TEST_CASE("create_partition_attention_combine") { + nonnegative_int embed_dim = 8_n; + nonnegative_int num_heads = 6_n; + nonnegative_int degree = 1_n; + std::string attention_match = "attention_match"; + + Substitution sub = create_partition_attention_combine(num_heads, degree); + + TensorShape query_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 12_n, + 16_n, + 10_n, + }, + }, + DataType::FLOAT, + }; + TensorShape key_shape = query_shape; + TensorShape value_shape = query_shape; + + MultiHeadAttentionAttrs attention_attrs = MultiHeadAttentionAttrs{ + /*embed_dim=*/embed_dim, + /*num_heads=*/num_heads, + /*kdim=*/embed_dim, + /*vdim=*/embed_dim, + /*dropout=*/0, + /*bias=*/false, + /*add_bias_kv=*/false, + /*add_zero_attn=*/false, + }; + + RepartitionAttrs partition_input_attrs = RepartitionAttrs{ + /*repartition_dim=*/ff_dim_t{0_n}, + /*repartition_degree=*/degree, + }; + + WeightAttrs weight_attrs = WeightAttrs{ + /*tensor_shape=*/ + throw_if_unexpected(get_weights_shape( + attention_attrs, query_shape, key_shape, value_shape)), + /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, + }; + + ReplicateAttrs replicate_weight_attrs = ReplicateAttrs{ + /*replicate_degree=*/degree, + }; + + CombineAttrs combine_attrs = CombineAttrs{ + /*combine_dim=*/ff_dim_t{2_n}, + /*combine_degree=*/degree, + }; + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult query_added = + pcg_add_input_layer(pcg, query_shape); + parallel_tensor_guid_t t_query = get_only(query_added.outputs); + + ParallelLayerAddedResult key_added = pcg_add_input_layer(pcg, key_shape); + parallel_tensor_guid_t t_key = get_only(key_added.outputs); + + ParallelLayerAddedResult value_added = + pcg_add_input_layer(pcg, value_shape); + parallel_tensor_guid_t t_value = get_only(value_added.outputs); + + 
ParallelLayerAddedResult weight_added = + add_parallel_layer(pcg, make_layer_attrs(weight_attrs), {}, {}); + parallel_tensor_guid_t t_weight = get_only(weight_added.outputs); + + ParallelLayerAddedResult attention_added = + add_parallel_layer(pcg, + make_layer_attrs(attention_attrs, attention_match), + {t_query, t_key, t_value}, + {t_weight}); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, attention_match); + open_parallel_tensor_guid_t match_layer_query = + get_layer_inputs(original_pcg, match_layer).at(0); + open_parallel_tensor_guid_t match_layer_key = + get_layer_inputs(original_pcg, match_layer).at(1); + open_parallel_tensor_guid_t match_layer_value = + get_layer_inputs(original_pcg, match_layer).at(2); + open_parallel_tensor_guid_t match_layer_input_weights = + get_layer_inputs(original_pcg, match_layer).at(3); + + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{ + { + PatternInput{DataflowGraphInput{0}}, + match_layer_query, + }, + { + PatternInput{DataflowGraphInput{2}}, + match_layer_key, + }, + { + PatternInput{DataflowGraphInput{4}}, + match_layer_value, + }, + { + PatternInput{DataflowGraphInput{6}}, + match_layer_input_weights, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult query_added = + pcg_add_input_layer(pcg, query_shape); + parallel_tensor_guid_t t_query = get_only(query_added.outputs); + + ParallelLayerAddedResult key_added = pcg_add_input_layer(pcg, key_shape); + parallel_tensor_guid_t t_key = get_only(key_added.outputs); + + ParallelLayerAddedResult value_added = + pcg_add_input_layer(pcg, value_shape); + parallel_tensor_guid_t t_value = get_only(value_added.outputs); + + 
ParallelLayerAddedResult weight_added = + add_parallel_layer(pcg, make_layer_attrs(weight_attrs), {}, {}); + parallel_tensor_guid_t t_weight = get_only(weight_added.outputs); + + ParallelLayerAddedResult partition_query_added = add_parallel_layer( + pcg, make_layer_attrs(partition_input_attrs), {t_query}, {}); + parallel_tensor_guid_t t_partitioned_query = + get_only(partition_query_added.outputs); + + ParallelLayerAddedResult partition_key_added = add_parallel_layer( + pcg, make_layer_attrs(partition_input_attrs), {t_key}, {}); + parallel_tensor_guid_t t_partitioned_key = + get_only(partition_key_added.outputs); + + ParallelLayerAddedResult partition_value_added = add_parallel_layer( + pcg, make_layer_attrs(partition_input_attrs), {t_value}, {}); + parallel_tensor_guid_t t_partitioned_value = + get_only(partition_value_added.outputs); + + ParallelLayerAddedResult replicate_weight_added = add_parallel_layer( + pcg, make_layer_attrs(replicate_weight_attrs), {t_weight}, {}); + parallel_tensor_guid_t t_replicated_weight = + get_only(replicate_weight_added.outputs); + + ParallelLayerAddedResult partition_attention_added = add_parallel_layer( + pcg, + make_layer_attrs(attention_attrs), + {t_partitioned_query, t_partitioned_key, t_partitioned_value}, + {t_replicated_weight}); + parallel_tensor_guid_t t_partitioned_attention = + get_only(partition_attention_added.outputs); + + ParallelLayerAddedResult combine_added = add_parallel_layer( + pcg, make_layer_attrs(combine_attrs), {t_partitioned_attention}, {}); + parallel_tensor_guid_t t_combine = get_only(combine_added.outputs); + + return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + + TEST_CASE("create_replicate_attention_reduce") { + nonnegative_int embed_dim = 8_n; + nonnegative_int num_heads = 6_n; + nonnegative_int degree = 1_n; + std::string attention_match = "attention_match"; + + Substitution sub = create_replicate_attention_reduce(num_heads, degree); + + TensorShape 
query_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 12_n, + 16_n, + 10_n, + }, + }, + DataType::FLOAT, + }; + TensorShape key_shape = query_shape; + TensorShape value_shape = query_shape; + + MultiHeadAttentionAttrs attention_attrs = MultiHeadAttentionAttrs{ + /*embed_dim=*/embed_dim, + /*num_heads=*/num_heads, + /*kdim=*/embed_dim, + /*vdim=*/embed_dim, + /*dropout=*/0, + /*bias=*/false, + /*add_bias_kv=*/false, + /*add_zero_attn=*/false, + }; + + ReplicateAttrs replicate_input_attrs = ReplicateAttrs{ + /*replicate_degree=*/degree, + }; + + WeightAttrs weight_attrs = WeightAttrs{ + /*tensor_shape=*/ + throw_if_unexpected(get_weights_shape( + attention_attrs, query_shape, key_shape, value_shape)), + /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, + }; + + RepartitionAttrs partition_weight_attrs = RepartitionAttrs{ + /*repartition_dim=*/ff_dim_t{1_n}, + /*repartition_degree=*/degree, + }; + + ReductionAttrs reduction_attrs = ReductionAttrs{ + /*reduction_degree=*/degree, + }; + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult query_added = + pcg_add_input_layer(pcg, query_shape); + parallel_tensor_guid_t t_query = get_only(query_added.outputs); + + ParallelLayerAddedResult key_added = pcg_add_input_layer(pcg, key_shape); + parallel_tensor_guid_t t_key = get_only(key_added.outputs); + + ParallelLayerAddedResult value_added = + pcg_add_input_layer(pcg, value_shape); + parallel_tensor_guid_t t_value = get_only(value_added.outputs); + + ParallelLayerAddedResult weight_added = + add_parallel_layer(pcg, make_layer_attrs(weight_attrs), {}, {}); + parallel_tensor_guid_t t_weight = get_only(weight_added.outputs); + + ParallelLayerAddedResult attention_added = + add_parallel_layer(pcg, + make_layer_attrs(attention_attrs, attention_match), + {t_query, t_key, t_value}, + {t_weight}); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + 
parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, attention_match); + open_parallel_tensor_guid_t match_layer_query = + get_layer_inputs(original_pcg, match_layer).at(0); + open_parallel_tensor_guid_t match_layer_key = + get_layer_inputs(original_pcg, match_layer).at(1); + open_parallel_tensor_guid_t match_layer_value = + get_layer_inputs(original_pcg, match_layer).at(2); + open_parallel_tensor_guid_t match_layer_input_weights = + get_layer_inputs(original_pcg, match_layer).at(3); + + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{ + { + PatternInput{DataflowGraphInput{0}}, + match_layer_query, + }, + { + PatternInput{DataflowGraphInput{2}}, + match_layer_key, + }, + { + PatternInput{DataflowGraphInput{4}}, + match_layer_value, + }, + { + PatternInput{DataflowGraphInput{6}}, + match_layer_input_weights, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult query_added = + pcg_add_input_layer(pcg, query_shape); + parallel_tensor_guid_t t_query = get_only(query_added.outputs); + + ParallelLayerAddedResult key_added = pcg_add_input_layer(pcg, key_shape); + parallel_tensor_guid_t t_key = get_only(key_added.outputs); + + ParallelLayerAddedResult value_added = + pcg_add_input_layer(pcg, value_shape); + parallel_tensor_guid_t t_value = get_only(value_added.outputs); + + ParallelLayerAddedResult weight_added = + add_parallel_layer(pcg, make_layer_attrs(weight_attrs), {}, {}); + parallel_tensor_guid_t t_weight = get_only(weight_added.outputs); + + ParallelLayerAddedResult replicate_query_added = add_parallel_layer( + pcg, make_layer_attrs(replicate_input_attrs), {t_query}, {}); + parallel_tensor_guid_t t_replicated_query = + get_only(replicate_query_added.outputs); + + ParallelLayerAddedResult 
replicate_key_added = add_parallel_layer( + pcg, make_layer_attrs(replicate_input_attrs), {t_key}, {}); + parallel_tensor_guid_t t_replicated_key = + get_only(replicate_key_added.outputs); + + ParallelLayerAddedResult replicate_value_added = add_parallel_layer( + pcg, make_layer_attrs(replicate_input_attrs), {t_value}, {}); + parallel_tensor_guid_t t_replicated_value = + get_only(replicate_value_added.outputs); + + ParallelLayerAddedResult partition_weight_added = add_parallel_layer( + pcg, make_layer_attrs(partition_weight_attrs), {t_weight}, {}); + parallel_tensor_guid_t t_partitioned_weight = + get_only(partition_weight_added.outputs); + + ParallelLayerAddedResult replicate_attention_added = add_parallel_layer( + pcg, + make_layer_attrs(attention_attrs), + {t_replicated_query, t_replicated_key, t_replicated_value}, + {t_partitioned_weight}); + parallel_tensor_guid_t t_replicated_attention = + get_only(replicate_attention_added.outputs); + + ParallelLayerAddedResult reduce_added = add_parallel_layer( + pcg, make_layer_attrs(reduction_attrs), {t_replicated_attention}, {}); + parallel_tensor_guid_t t_reduction = get_only(reduce_added.outputs); + + return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + + TEST_CASE("create_partition_softmax_combine") { + nonnegative_int degree = 1_n; + ff_dim_t softmax_dim = ff_dim_t{1_n}; + ff_dim_t partition_dim = ff_dim_t{0_n}; + std::string softmax_match = "softmax_match"; + + Substitution sub = + create_partition_softmax_combine(softmax_dim, partition_dim, degree); + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 10_n, + 10_n, + }, + }, + DataType::FLOAT, + }; + + SoftmaxAttrs softmax_attrs = SoftmaxAttrs{ + /*softmax_dim=*/softmax_dim, + }; + + RepartitionAttrs partition_input_attrs = RepartitionAttrs{ + /*repartition_dim=*/partition_dim, + /*repartition_degree=*/degree, + }; + + CombineAttrs combine_attrs = CombineAttrs{ + 
/*combine_dim=*/ff_dim_t{partition_dim}, + /*combine_degree=*/degree, + }; + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + ParallelLayerAddedResult softmax_added = add_parallel_layer( + pcg, make_layer_attrs(softmax_attrs, softmax_match), {t_input}, {}); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, softmax_match); + open_parallel_tensor_guid_t match_layer_input = + get_layer_inputs(original_pcg, match_layer).at(0); + + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{{ + PatternInput{DataflowGraphInput{0}}, + match_layer_input, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + ParallelLayerAddedResult partition_input_added = add_parallel_layer( + pcg, make_layer_attrs(partition_input_attrs), {t_input}, {}); + parallel_tensor_guid_t t_partitioned_input = + get_only(partition_input_added.outputs); + + ParallelLayerAddedResult partition_softmax_added = add_parallel_layer( + pcg, make_layer_attrs(softmax_attrs), {t_partitioned_input}, {}); + parallel_tensor_guid_t t_partitioned_softmax = + get_only(partition_softmax_added.outputs); + + ParallelLayerAddedResult combine_added = add_parallel_layer( + pcg, make_layer_attrs(combine_attrs), {t_partitioned_softmax}, {}); + parallel_tensor_guid_t t_combine = get_only(combine_added.outputs); + + return 
sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + + TEST_CASE("create_partition_add_combine") { + nonnegative_int degree = 1_n; + ff_dim_t parallel_dim = ff_dim_t{1_n}; + std::string add_match = "add_match"; + + Substitution sub = create_partition_add_combine(parallel_dim, degree); + + TensorShape lhs_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 10_n, + 15_n, + }, + }, + DataType::FLOAT, + }; + + TensorShape rhs_shape = lhs_shape; + + ElementBinaryAttrs add_attrs = ElementBinaryAttrs{ + OperatorType::EW_ADD, + DataType::FLOAT, + false, + false, + }; + + RepartitionAttrs partition_input_attrs = RepartitionAttrs{ + /*repartition_dim=*/parallel_dim, + /*repartition_degree=*/degree, + }; + + CombineAttrs combine_attrs = CombineAttrs{ + /*combine_dim=*/parallel_dim, + /*combine_degree=*/degree, + }; + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult lhs_added = pcg_add_input_layer(pcg, lhs_shape); + parallel_tensor_guid_t t_lhs = get_only(lhs_added.outputs); + + ParallelLayerAddedResult rhs_added = pcg_add_input_layer(pcg, rhs_shape); + parallel_tensor_guid_t t_rhs = get_only(rhs_added.outputs); + + ParallelLayerAddedResult output_added = add_parallel_layer( + pcg, make_layer_attrs(add_attrs, add_match), {t_lhs, t_rhs}, {}); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, add_match); + open_parallel_tensor_guid_t add_match_layer_lhs = + get_layer_inputs(original_pcg, match_layer).at(0); + open_parallel_tensor_guid_t add_match_layer_rhs = + get_layer_inputs(original_pcg, match_layer).at(1); + + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{ + { + PatternInput{DataflowGraphInput{0}}, + add_match_layer_lhs, + }, + { + PatternInput{DataflowGraphInput{2}}, + 
add_match_layer_rhs, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult lhs_added = pcg_add_input_layer(pcg, lhs_shape); + parallel_tensor_guid_t t_lhs = get_only(lhs_added.outputs); + + ParallelLayerAddedResult rhs_added = pcg_add_input_layer(pcg, rhs_shape); + parallel_tensor_guid_t t_rhs = get_only(rhs_added.outputs); + + ParallelLayerAddedResult partition_lhs_added = add_parallel_layer( + pcg, make_layer_attrs(partition_input_attrs), {t_lhs}, {}); + parallel_tensor_guid_t t_partitioned_lhs = + get_only(partition_lhs_added.outputs); + + ParallelLayerAddedResult partition_rhs_added = add_parallel_layer( + pcg, make_layer_attrs(partition_input_attrs), {t_rhs}, {}); + parallel_tensor_guid_t t_partitioned_rhs = + get_only(partition_rhs_added.outputs); + + ParallelLayerAddedResult partition_add_added = + add_parallel_layer(pcg, + make_layer_attrs(add_attrs, add_match), + {t_partitioned_lhs, t_partitioned_rhs}, + {}); + parallel_tensor_guid_t t_partitioned_add = + get_only(partition_add_added.outputs); + + ParallelLayerAddedResult combine_added = add_parallel_layer( + pcg, make_layer_attrs(combine_attrs), {t_partitioned_add}, {}); + parallel_tensor_guid_t t_combine = get_only(combine_added.outputs); + + return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + + TEST_CASE("create_partition_relu_combine") { + nonnegative_int degree = 1_n; + ff_dim_t parallel_dim = ff_dim_t{1_n}; + std::string relu_match = "relu_match"; + + Substitution sub = create_partition_relu_combine(parallel_dim, degree); + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 10_n, + 10_n, + }, + }, + DataType::FLOAT, + }; + + ElementUnaryAttrs relu_attrs = ElementUnaryAttrs{ + OperatorType::RELU, + std::nullopt, + }; + + RepartitionAttrs 
partition_input_attrs = RepartitionAttrs{ + /*repartition_dim=*/parallel_dim, + /*repartition_degree=*/degree, + }; + + CombineAttrs combine_attrs = CombineAttrs{ + /*combine_dim=*/ff_dim_t{parallel_dim}, + /*combine_degree=*/degree, + }; + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + ParallelLayerAddedResult relu_added = add_parallel_layer( + pcg, make_layer_attrs(relu_attrs, relu_match), {t_input}, {}); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, relu_match); + open_parallel_tensor_guid_t match_layer_input = + get_layer_inputs(original_pcg, match_layer).at(0); + + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{{ + PatternInput{DataflowGraphInput{0}}, + match_layer_input, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + ParallelLayerAddedResult partition_input_added = add_parallel_layer( + pcg, make_layer_attrs(partition_input_attrs), {t_input}, {}); + parallel_tensor_guid_t t_partitioned_input = + get_only(partition_input_added.outputs); + + ParallelLayerAddedResult partition_relu_added = add_parallel_layer( + pcg, make_layer_attrs(relu_attrs), {t_partitioned_input}, {}); + parallel_tensor_guid_t t_partitioned_relu = + get_only(partition_relu_added.outputs); + + ParallelLayerAddedResult combine_added = add_parallel_layer( + pcg, 
make_layer_attrs(combine_attrs), {t_partitioned_relu}, {}); + parallel_tensor_guid_t t_combine = get_only(combine_added.outputs); + + return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + + TEST_CASE("create_fuse_linear_activation") { + Substitution sub = create_fuse_linear_activation(Activation::SIGMOID); + nonnegative_int in_channels = 24_n; + nonnegative_int batch_size = 4_n; + nonnegative_int batch_degree = 2_n; + std::string mm_match = "mm_match"; + std::string relu_match = "relu_match"; + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 4_n, + 10_n, + }, + }, + DataType::FLOAT, + }; + + SubParallelComputationGraph pcg = [&] { + ParallelComputationGraphBuilder b; + parallel_tensor_guid_t t = b.create_input_tensor(input_shape); + t = b.dense(t, + /*outDim=*/4_n, + /*activation=*/std::nullopt, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/mm_match); + t = b.relu(t, + /*name=*/relu_match); + + return sub_pcg_from_full_pcg(b.pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t mm_match_layer = + get_parallel_layer_by_name(pcg, mm_match); + parallel_layer_guid_t relu_match_layer = + get_parallel_layer_by_name(pcg, relu_match); + open_parallel_tensor_guid_t mm_match_layer_input_activations = + get_layer_inputs(pcg, mm_match_layer).at(0); + open_parallel_tensor_guid_t mm_match_layer_input_weights = + get_layer_inputs(pcg, mm_match_layer).at(1); + + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, mm_match_layer}, + {PatternNode{Node{1}}, relu_match_layer}, + }, + std::unordered_map{ + { + PatternInput{DataflowGraphInput{0}}, + mm_match_layer_input_activations, + }, + { + PatternInput{DataflowGraphInput{2}}, + mm_match_layer_input_weights, + }}, + }; + }(); + + SubParallelComputationGraph result = apply_substitution(pcg, sub, match); + + SubParallelComputationGraph correct = [&] 
{ + ParallelComputationGraphBuilder b; + parallel_tensor_guid_t t = b.create_input_tensor(input_shape); + t = b.dense(t, + /*outDim=*/4_n, + /*activation=*/Activation::SIGMOID, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/std::nullopt); + + return sub_pcg_from_full_pcg(b.pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); } }