Skip to content

Commit eb011c7

Browse files
authored
[Bugfix][Relax] Preserve existing DataflowBlock in ConvertToDataflow (#17148)
The `relax.transform.ConvertToDataflow` identifies portions of a Relax function that satisfy the requirements of a `relax::DataflowBlock`, and converts those portions to a new `DataflowBlock`, provided they are at least some minimum number of operations. Prior to this commit, if a function contained a region that would be converted to a `DataflowBlock`, but also contains existing `DataflowBlock`s that were smaller than the size required for creating a `DataflowBlock`, those existing blocks would be erroneously converted to non-dataflow. This commit updates the `ConvertToDataflow` pass to preserve all existing `DataflowBlock` present in the input.
1 parent 3755571 commit eb011c7

File tree

2 files changed

+173
-50
lines changed

2 files changed

+173
-50
lines changed

src/relax/transform/convert_dataflow.cc

Lines changed: 67 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
#include <tvm/relax/transform.h>
2929
#include <tvm/relax/utils.h>
3030

31+
#include <optional>
32+
3133
namespace tvm {
3234
namespace relax {
3335

@@ -39,85 +41,100 @@ class DataflowBlockExtractor : public ExprMutator {
3941
Array<BindingBlock> new_blocks;
4042
Expr new_body = VisitExpr(seq->body);
4143
bool changed = !new_body.same_as(seq->body);
42-
bool dataflow_streak = false;
43-
Array<Binding> dataflow_bindings;
44+
45+
// Accumulated bindings that are not going to be added to a
46+
// DataflowBlock, either because they would be illegal within a
47+
// DataflowBlock, or because there were insufficient bindings to
48+
// make a dataflowblock. Because these bindings occur prior to
49+
// `dataflow_bindings`, this array may only be accumulated into
50+
// when `dataflow_bindings` is empty.
4451
Array<Binding> non_dataflow_bindings;
4552

53+
// Current bindings that may legally be added to a DataflowBlock.
54+
Array<Binding> dataflow_bindings;
55+
56+
// If present, a DataflowBlock whose bindings are currently in
57+
// `dataflow_bindings`. Used to propagate DataflowBlock to the
58+
// output, even if it doesn't meet the minimum size.
59+
Optional<DataflowBlock> input_dataflow_block;
60+
61+
// Handle any bindings currently in `dataflow_bindings`. These
62+
// are either pushed to their own block, or to the end of
63+
// `non_dataflow_bindings`, depending on whether the bindings meet
64+
// the minimum size requirement.
65+
auto push_dataflow_bindings = [&]() {
66+
if (dataflow_bindings.empty()) {
67+
// No Dataflow bindings, so no action required.
68+
return;
69+
}
70+
if (dataflow_bindings.size() < min_size_ && !input_dataflow_block) {
71+
// The df block is below the minimum length, and no input
72+
// DataflowBlock needs to be preserved. Combine the blocks
73+
// and reset the dataflow collection.
74+
75+
non_dataflow_bindings.insert(non_dataflow_bindings.end(), dataflow_bindings.begin(),
76+
dataflow_bindings.end());
77+
78+
} else {
79+
// A new DataflowBlock can be generated, with bindings that
80+
// occur after the non-dataflow bindings.
81+
new_blocks.push_back(BindingBlock(non_dataflow_bindings));
82+
new_blocks.push_back(DataflowBlock(dataflow_bindings));
83+
non_dataflow_bindings = {};
84+
85+
// Making a dataflow block doesn't imply that the function was
86+
// changed. A change requires that this either be a new
87+
// dataflow block, or have additional dataflow bindings in the
88+
// current block.
89+
changed = changed || !input_dataflow_block.defined() ||
90+
input_dataflow_block.value()->bindings.size() != dataflow_bindings.size();
91+
}
92+
93+
dataflow_bindings = {};
94+
input_dataflow_block = NullOpt;
95+
};
96+
4697
for (auto block : seq->blocks) {
4798
BindingBlock new_block = this->VisitBindingBlock(block);
4899
changed = changed || !new_block.same_as(block);
49100

50101
// For an existing dataflow block, we add to the current streak
51102
// or start a new streak in case there will be more dataflow operations
52103
// coming up
53-
if (new_block.as<DataflowBlock>()) {
54-
if (!dataflow_streak) {
55-
dataflow_streak = true;
56-
}
104+
if (auto dataflow_block = new_block.as<DataflowBlock>()) {
57105
dataflow_bindings.insert(dataflow_bindings.end(), new_block->bindings.begin(),
58106
new_block->bindings.end());
107+
input_dataflow_block = dataflow_block;
59108
continue;
60109
}
61110

62111
// for a binding block, attempt to extract dataflow blocks inside
63112
auto binding_block = Downcast<BindingBlock>(new_block);
64-
for (size_t i = 0; i < binding_block->bindings.size(); i++) {
65-
auto binding = binding_block->bindings[i];
113+
for (const auto& binding : binding_block->bindings) {
66114
Expr value = GetBoundValue(binding);
67115
// dataflow values: not an if node and not an impure call
68116
bool is_dataflow = (!value.as<IfNode>()) &&
69117
(!(value.as<CallNode>() && IsImpureCall(Downcast<Call>(value))));
70-
if (!dataflow_streak) {
71-
// we can start a dataflow streak
72-
if (is_dataflow) {
73-
dataflow_streak = true;
74-
dataflow_bindings = {binding};
75-
} else {
76-
non_dataflow_bindings.push_back(binding);
77-
}
118+
if (is_dataflow) {
119+
// extend the streak
120+
dataflow_bindings.push_back(binding);
78121
} else {
79-
if (is_dataflow) {
80-
// extend the streak
81-
dataflow_bindings.push_back(binding);
82-
} else {
83-
// this is the end of the streak
84-
dataflow_streak = false;
85-
86-
// if the df block is below the minimum length, combine the blocks
87-
// and reset the dataflow collection
88-
if (dataflow_bindings.size() < min_size_) {
89-
non_dataflow_bindings.insert(non_dataflow_bindings.end(), dataflow_bindings.begin(),
90-
dataflow_bindings.end());
91-
dataflow_bindings = {};
92-
} else {
93-
// otherwise insert both collections
94-
changed = true;
95-
new_blocks.push_back(BindingBlock(non_dataflow_bindings));
96-
new_blocks.push_back(DataflowBlock(dataflow_bindings));
97-
non_dataflow_bindings = {};
98-
dataflow_bindings = {};
99-
}
100-
non_dataflow_bindings.push_back(binding);
101-
}
122+
// End the streak, if one currently exists.
123+
push_dataflow_bindings();
124+
non_dataflow_bindings.push_back(binding);
102125
}
103126
}
104127
}
105128

106129
// handle any remaining bindings
107-
if (dataflow_bindings.size() < min_size_) {
108-
non_dataflow_bindings.insert(non_dataflow_bindings.end(), dataflow_bindings.begin(),
109-
dataflow_bindings.end());
110-
new_blocks.push_back(BindingBlock(non_dataflow_bindings));
111-
} else {
112-
changed = true;
113-
new_blocks.push_back(BindingBlock(non_dataflow_bindings));
114-
new_blocks.push_back(DataflowBlock(dataflow_bindings));
115-
}
130+
push_dataflow_bindings();
131+
new_blocks.push_back(BindingBlock(non_dataflow_bindings));
116132

117-
if (!changed) {
133+
if (changed) {
134+
return SeqExpr(new_blocks, new_body);
135+
} else {
118136
return GetRef<SeqExpr>(seq);
119137
}
120-
return SeqExpr(new_blocks, new_body);
121138
}
122139

123140
private:

tests/python/relax/test_transform_convert_dataflow.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,5 +489,111 @@ def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
489489
return v
490490

491491

492+
class TestPreserveExistingDataflowBlocksAtBeginning(ExtractCompare):
493+
"""Preserve existing DataflowBlocks
494+
495+
This is a regression test. In previous implementations, a
496+
DataflowBlock in the input, without enough bindings to become a
497+
new dataflow block, could be accidentally ommitted.
498+
499+
This test is identical to
500+
`TestPreserveExistingDataflowBlocksAtEnd`, except that the
501+
existing dataflow block is at the beginning of the function.
502+
503+
"""
504+
505+
@I.ir_module
506+
class Before:
507+
@R.function(pure=False)
508+
def main(A0: R.Tensor, B0: R.Tensor):
509+
# This DataflowBlock is below the minimum size for a new
510+
# block, but already exists in the input IRModule.
511+
with R.dataflow():
512+
A1 = R.add(A0, A0)
513+
R.output(A1)
514+
515+
R.print(format="impure_function")
516+
517+
# This sequence is large enough that it may be converted
518+
# to a DataflowBlock.
519+
B1 = R.add(B0, B0)
520+
B2 = R.add(B1, B1)
521+
B3 = R.add(B2, B2)
522+
523+
return (A1, B3)
524+
525+
@I.ir_module
526+
class Expected:
527+
@R.function(pure=False)
528+
def main(A0: R.Tensor, B0: R.Tensor):
529+
# This dataflow block should be preserved in the output.
530+
with R.dataflow():
531+
A1 = R.add(A0, A0)
532+
R.output(A1)
533+
534+
R.print(format="impure_function")
535+
536+
with R.dataflow():
537+
B1 = R.add(B0, B0)
538+
B2 = R.add(B1, B1)
539+
B3 = R.add(B2, B2)
540+
R.output(B3)
541+
542+
return (A1, B3)
543+
544+
545+
class TestPreserveExistingDataflowBlocksAtEnd(ExtractCompare):
546+
"""Preserve existing DataflowBlocks
547+
548+
This is a regression test. In previous implementations, a
549+
DataflowBlock in the input, without enough bindings to become a
550+
new dataflow block, could be accidentally ommitted.
551+
552+
This test is identical to
553+
`TestPreserveExistingDataflowBlocksAtBeginning`, except that the
554+
existing dataflow block is at the end of the function.
555+
556+
"""
557+
558+
@I.ir_module
559+
class Before:
560+
@R.function(pure=False)
561+
def main(A0: R.Tensor, B0: R.Tensor):
562+
# This sequence is large enough that it may be converted
563+
# to a DataflowBlock.
564+
B1 = R.add(B0, B0)
565+
B2 = R.add(B1, B1)
566+
B3 = R.add(B2, B2)
567+
568+
R.print(format="impure_function")
569+
570+
# This DataflowBlock is below the minimum size for a new
571+
# block, but already exists in the input IRModule.
572+
with R.dataflow():
573+
A1 = R.add(A0, A0)
574+
R.output(A1)
575+
576+
return (A1, B3)
577+
578+
@I.ir_module
579+
class Expected:
580+
@R.function(pure=False)
581+
def main(A0: R.Tensor, B0: R.Tensor):
582+
with R.dataflow():
583+
B1 = R.add(B0, B0)
584+
B2 = R.add(B1, B1)
585+
B3 = R.add(B2, B2)
586+
R.output(B3)
587+
588+
R.print(format="impure_function")
589+
590+
# This dataflow block should be preserved in the output.
591+
with R.dataflow():
592+
A1 = R.add(A0, A0)
593+
R.output(A1)
594+
595+
return (A1, B3)
596+
597+
492598
if __name__ == "__main__":
493599
tvm.testing.main()

0 commit comments

Comments
 (0)