2828#include < tvm/relax/transform.h>
2929#include < tvm/relax/utils.h>
3030
31+ #include < optional>
32+
3133namespace tvm {
3234namespace relax {
3335
@@ -39,85 +41,100 @@ class DataflowBlockExtractor : public ExprMutator {
3941 Array<BindingBlock> new_blocks;
4042 Expr new_body = VisitExpr (seq->body );
4143 bool changed = !new_body.same_as (seq->body );
42- bool dataflow_streak = false ;
43- Array<Binding> dataflow_bindings;
44+
45+ // Accumulated bindings that are not going to be added to a
46+ // DataflowBlock, either because they would be illegal within a
47+ // DataflowBlock, or because there were insufficient bindings to
48+ // make a dataflowblock. Because these bindings occur prior to
49+ // `dataflow_bindings`, this array may only be accumulated into
50+ // when `dataflow_bindings` is empty.
4451 Array<Binding> non_dataflow_bindings;
4552
53+ // Current bindings that may legally be added to a DataflowBlock.
54+ Array<Binding> dataflow_bindings;
55+
56+ // If present, a DataflowBlock whose bindings are currently in
57+ // `dataflow_bindings`. Used to propagate DataflowBlock to the
58+ // output, even if it doesn't meet the minimum size.
59+ Optional<DataflowBlock> input_dataflow_block;
60+
61+ // Handle any bindings currently in `dataflow_bindings`. These
62+ // are either pushed to their own block, or to the end of
63+ // `non_dataflow_bindings`, depending on whether the bindings meet
64+ // the minimum size requirement.
65+ auto push_dataflow_bindings = [&]() {
66+ if (dataflow_bindings.empty ()) {
67+ // No Dataflow bindings, so no action required.
68+ return ;
69+ }
70+ if (dataflow_bindings.size () < min_size_ && !input_dataflow_block) {
71+ // The df block is below the minimum length, and no input
72+ // DataflowBlock needs to be preserved. Combine the blocks
73+ // and reset the dataflow collection.
74+
75+ non_dataflow_bindings.insert (non_dataflow_bindings.end (), dataflow_bindings.begin (),
76+ dataflow_bindings.end ());
77+
78+ } else {
79+ // A new DataflowBlock can be generated, with bindings that
80+ // occur after the non-dataflow bindings.
81+ new_blocks.push_back (BindingBlock (non_dataflow_bindings));
82+ new_blocks.push_back (DataflowBlock (dataflow_bindings));
83+ non_dataflow_bindings = {};
84+
85+ // Making a dataflow block doesn't imply that the function was
86+ // changed. A change requires that this either be a new
87+ // dataflow block, or have additional dataflow bindings in the
88+ // current block.
89+ changed = changed || !input_dataflow_block.defined () ||
90+ input_dataflow_block.value ()->bindings .size () != dataflow_bindings.size ();
91+ }
92+
93+ dataflow_bindings = {};
94+ input_dataflow_block = NullOpt;
95+ };
96+
4697 for (auto block : seq->blocks ) {
4798 BindingBlock new_block = this ->VisitBindingBlock (block);
4899 changed = changed || !new_block.same_as (block);
49100
50101 // For an existing dataflow block, we add to the current streak
51102 // or start a new streak in case there will be more dataflow operations
52103 // coming up
53- if (new_block.as <DataflowBlock>()) {
54- if (!dataflow_streak) {
55- dataflow_streak = true ;
56- }
104+ if (auto dataflow_block = new_block.as <DataflowBlock>()) {
57105 dataflow_bindings.insert (dataflow_bindings.end (), new_block->bindings .begin (),
58106 new_block->bindings .end ());
107+ input_dataflow_block = dataflow_block;
59108 continue ;
60109 }
61110
62111 // for a binding block, attempt to extract dataflow blocks inside
63112 auto binding_block = Downcast<BindingBlock>(new_block);
64- for (size_t i = 0 ; i < binding_block->bindings .size (); i++) {
65- auto binding = binding_block->bindings [i];
113+ for (const auto & binding : binding_block->bindings ) {
66114 Expr value = GetBoundValue (binding);
67115 // dataflow values: not an if node and not an impure call
68116 bool is_dataflow = (!value.as <IfNode>()) &&
69117 (!(value.as <CallNode>() && IsImpureCall (Downcast<Call>(value))));
70- if (!dataflow_streak) {
71- // we can start a dataflow streak
72- if (is_dataflow) {
73- dataflow_streak = true ;
74- dataflow_bindings = {binding};
75- } else {
76- non_dataflow_bindings.push_back (binding);
77- }
118+ if (is_dataflow) {
119+ // extend the streak
120+ dataflow_bindings.push_back (binding);
78121 } else {
79- if (is_dataflow) {
80- // extend the streak
81- dataflow_bindings.push_back (binding);
82- } else {
83- // this is the end of the streak
84- dataflow_streak = false ;
85-
86- // if the df block is below the minimum length, combine the blocks
87- // and reset the dataflow collection
88- if (dataflow_bindings.size () < min_size_) {
89- non_dataflow_bindings.insert (non_dataflow_bindings.end (), dataflow_bindings.begin (),
90- dataflow_bindings.end ());
91- dataflow_bindings = {};
92- } else {
93- // otherwise insert both collections
94- changed = true ;
95- new_blocks.push_back (BindingBlock (non_dataflow_bindings));
96- new_blocks.push_back (DataflowBlock (dataflow_bindings));
97- non_dataflow_bindings = {};
98- dataflow_bindings = {};
99- }
100- non_dataflow_bindings.push_back (binding);
101- }
122+ // End the streak, if one currently exists.
123+ push_dataflow_bindings ();
124+ non_dataflow_bindings.push_back (binding);
102125 }
103126 }
104127 }
105128
106129 // handle any remaining bindings
107- if (dataflow_bindings.size () < min_size_) {
108- non_dataflow_bindings.insert (non_dataflow_bindings.end (), dataflow_bindings.begin (),
109- dataflow_bindings.end ());
110- new_blocks.push_back (BindingBlock (non_dataflow_bindings));
111- } else {
112- changed = true ;
113- new_blocks.push_back (BindingBlock (non_dataflow_bindings));
114- new_blocks.push_back (DataflowBlock (dataflow_bindings));
115- }
130+ push_dataflow_bindings ();
131+ new_blocks.push_back (BindingBlock (non_dataflow_bindings));
116132
117- if (!changed) {
133+ if (changed) {
134+ return SeqExpr (new_blocks, new_body);
135+ } else {
118136 return GetRef<SeqExpr>(seq);
119137 }
120- return SeqExpr (new_blocks, new_body);
121138 }
122139
123140 private:
0 commit comments