Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,8 @@ void PatternGrouper::VisitExprs() {
}

void PatternGrouper::CreateGroup(const Expr& expr) {
VLOG(1) << "Creating group for:" << std::endl << PrettyPrint(expr);

int var_number = 0;

auto node_map = matcher_->GetMemo();
Expand Down Expand Up @@ -696,6 +698,7 @@ void PatternGrouper::CreateGroup(const Expr& expr) {
auto body = extractor.Mutate(expr);

group.function = Function(params, body, NullValue<Type>(), Array<TypeVar>());
VLOG(1) << "Candidate extracted function:" << std::endl << PrettyPrint(group.function);
group.name = extractor.GetName();
// Check to make sure we aren't overlapping with another group or creating an invalid fusion
// The MatchExtractor will create a new graph by replacing nodes that match the inputs of the
Expand All @@ -708,6 +711,10 @@ void PatternGrouper::CreateGroup(const Expr& expr) {
  // Similarly, if interior nodes in a group are used outside of the group fusing to a single
  // output would create an invalid graph transformation, so we block the creation of such groups.
auto memo = extractor.GetMemo();
for (auto kv : memo) {
VLOG(1) << "matched index " << matcher_->expr_to_node(kv.first)->index_;
}

for (auto kv : memo) {
// Check to ensure that this node isn't an input or a global
if (inputs.count(kv.first) == 0 && kv.first.as<OpNode>() == nullptr &&
Expand All @@ -720,16 +727,19 @@ void PatternGrouper::CreateGroup(const Expr& expr) {
// if the node isn't the output of the group
auto node = matcher_->expr_to_node(kv.first);
for (auto* output : node->outputs_) {
// and the node is used by nodes outside of the group
if (memo.count(output->ref()) == 0) {
// TODO(mbs): This condition used to also include the following test, which since
// the dominators relation is used back-to-front was always vacuously true. So the
// code is just rejecting the match if a strictly internal node happened to connect
// to an outside node.
ICHECK(!matcher_->expr_to_node(expr)->Dominates(output));
// Exit because nodes in this pattern's body are used outside the pattern, fusing it
// would be invalid
return;
// A node inside the matched group contributes an output to nodes outside of the matched
// group...
auto root = matcher_->expr_to_node(expr);
if (!root->Dominates(output)) {
// ...and the outside dataflow does not come back to the root of the matched group.
// So reject the match since it would create a cycle.
VLOG(1) << "Rejecting group since would create a cycle with output " << output->index_
<< " for root " << root->index_ << " in graph:" << std::endl
<< matcher_->expr_graph().ToString();
return;
}
// else: We'll allow the output to be included in the matched group.
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/relay/ir/dataflow_matcher_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Ex
const std::unordered_map<DFPattern, Array<Expr>, ObjectPtrHash, ObjectPtrEqual>& memo() const {
return memo_;
}
const IndexedGraph<Expr>& expr_graph() const { return *expr_graph_; }

protected:
bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
Expand Down
2 changes: 1 addition & 1 deletion tests/python/contrib/test_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,4 +941,4 @@ def test_conv2d_bwd():


if __name__ == "__main__":
pytest.main([__file__])
tvm.testing.main()
52 changes: 51 additions & 1 deletion tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1458,7 +1458,6 @@ def concat(*args):


def test_partition_fuzzy_function_args():

func_pattern = FunctionPattern(None, wildcard() + wildcard())(None) + wildcard()
x = relay.var("x")
y = relay.var("y")
Expand Down Expand Up @@ -1790,5 +1789,56 @@ def callback(self, pre, post, node_map):
assert tvm.ir.structural_equal(out, expected)


def test_matched_outside_but_dominated():
    """In this example the pattern matches the nn.conv2d/add/multiply flow. Even though the
    add output is consumed by the sigmoid, the sigmoid itself is dominated by the multiply.
    So partitioning can proceed, albeit with a duplication of the add."""
    # Original module: conv2d -> add -> (sigmoid, multiply); the sigmoid consumes the add
    # output but feeds back into the multiply, so no cycle is created by partitioning.
    in_mod = tvm.parser.parse(
        """
        #[version = "0.0.5"]
        def @main(%data: Tensor[(16, 16, 32, 32), float16], %weight: Tensor[(32, 16, 3, 3), float16], %bias: Tensor[(32), float32]) -> Tensor[(16, 32, 32, 32), float32] {
          %0 = layout_transform(%data, src_layout="NCHW", dst_layout="NHWC");
          %1 = layout_transform(%weight, src_layout="OIHW", dst_layout="OHWI");
          %2 = expand_dims(%bias, axis=1, num_newaxis=2);
          %3 = expand_dims(%2, axis=0);
          %4 = nn.conv2d(%0, %1, padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="OHWI", out_dtype="float32");
          %5 = layout_transform(%3, src_layout="NCHW", dst_layout="NHWC");
          %6 = add(%4, %5);
          %7 = sigmoid(%6);
          %8 = multiply(%6, %7);
          layout_transform(%8, src_layout="NHWC", dst_layout="NCHW")
        }
        """
    )
    # Expected result: the conv2d/add/multiply chain is extracted into a partitioned
    # function (%10). Note the add is duplicated — it appears both inside the extracted
    # function and (via %8) outside, feeding the sigmoid, which in turn becomes an
    # argument (%9) of the partitioned call.
    expected_mod = tvm.parser.parse(
        """
        #[version = "0.0.5"]
        def @main(%data: Tensor[(16, 16, 32, 32), float16], %weight: Tensor[(32, 16, 3, 3), float16], %bias: Tensor[(32), float32]) -> Tensor[(16, 32, 32, 32), float32] {
          %2 = expand_dims(%bias, axis=1, num_newaxis=2);
          %3 = expand_dims(%2, axis=0);
          %4 = layout_transform(%data, src_layout="NCHW", dst_layout="NHWC");
          %5 = layout_transform(%weight, src_layout="OIHW", dst_layout="OHWI");
          %6 = nn.conv2d(%4, %5, padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="OHWI", out_dtype="float32");
          %7 = layout_transform(%3, src_layout="NCHW", dst_layout="NHWC");
          %8 = add(%6, %7);
          %9 = sigmoid(%8);
          %10 = fn (%FunctionVar_0_0, %FunctionVar_0_1, %FunctionVar_0_2, %FunctionVar_0_3, PartitionedFromPattern="nn.conv2d_add_multiply_") {
            %0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="OHWI", out_dtype="float32");
            %1 = add(%0, %FunctionVar_0_2);
            multiply(%1, %FunctionVar_0_3)
          };
          %11 = %10(%4, %5, %7, %9);
          layout_transform(%11, src_layout="NHWC", dst_layout="NCHW")
        }
        """
    )
    # Pattern: multiply(add(conv2d(*, *), *), *) — matches the three-op chain above.
    pattern = is_op("multiply")(
        is_op("add")(is_op("nn.conv2d")(wildcard(), wildcard()), wildcard()), wildcard()
    )
    actual_mod = tvm.IRModule.from_expr(pattern.partition(in_mod["main"]))
    # InferType so the structural comparison is against a fully-typed module.
    actual_mod = relay.transform.InferType()(actual_mod)
    tvm.ir.assert_structural_equal(actual_mod, expected_mod)


if __name__ == "__main__":
tvm.testing.main()