Skip to content

Commit 6a2459f

Browse files
authored
[Unity][Doc] Document passes that depend on DataflowBlocks and encourage using ConvertToDataflow (#16514)
* Indicate in doc comments which passes need dataflow blocks * Also encourage users to use ConvertToDataflow * Whitespace
1 parent cb6e4ee commit 6a2459f

File tree

3 files changed

+42
-12
lines changed

3 files changed

+42
-12
lines changed

include/tvm/relax/transform.h

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,9 @@ TVM_DLL Pass BindSymbolicVars(Map<ObjectRef, PrimExpr> binding_map,
214214
Optional<String> func_name = NullOpt);
215215

216216
/*!
217-
* \brief Fold constant expressions.
217+
* \brief Fold constant expressions within dataflow blocks.
218+
*
219+
* \note ConvertToDataflow may need to be called first to provide dataflow blocks.
218220
*
219221
* \return The Pass.
220222
*/
@@ -458,6 +460,8 @@ class PatternCheckContext : public ObjectRef {
458460
* of the return value as the target. If it is not specified, the first return value will be the
459461
* target.
460462
* \return The Pass.
463+
*
464+
* \note ConvertToDataflow may need to be called first to provide dataflow blocks.
461465
*/
462466
TVM_DLL Pass Gradient(String func_name, Optional<Array<Var>> require_grads = NullOpt,
463467
int target_index = 0);
@@ -477,6 +481,8 @@ TVM_DLL Pass Gradient(String func_name, Optional<Array<Var>> require_grads = Nul
477481
* This must be True if the created composite functions are intended to be offloaded to
478482
* an external backend without using the MergeCompositeFunctions pass.
479483
* \return The Pass.
484+
*
485+
* \note Only operates within dataflow blocks. ConvertToDataflow may need to be called first.
480486
*/
481487
TVM_DLL Pass FuseOpsByPattern(const tvm::Array<FusionPattern>& patterns, bool bind_constants = true,
482488
bool annotate_codegen = false);
@@ -548,6 +554,7 @@ TVM_DLL Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
548554
* \brief Layout conversion pass.
549555
* \param desired_layouts The desired layouts for some operators.
550556
* \return The Pass.
557+
* \note Operates only on dataflow blocks. ConvertToDataflow may need to be called first.
551558
*/
552559
TVM_DLL Pass ConvertLayout(Map<String, Array<String>> desired_layouts);
553560

@@ -564,10 +571,13 @@ TVM_DLL Pass ConvertToDataflow(int min_size = 2);
564571
* \brief Dead code elimination.
565572
* \sa RemoveAllUnused
566573
* Currently it removes:
567-
* 1. Unused local VarBindings in a DataflowBlock.
568-
* 2. Unused DataflowBlocks in a function.
569-
* 3. Unused Relax functions in the module.
574+
* 1. Unused local VarBindings
575+
* (those where the bound var is unused and no impure operation is used).
576+
* 2. Unused Relax functions in the module.
570577
* We detect the call chain from the entry function, and remove all unused functions.
578+
*
579+
* Any binding blocks that are left empty will be removed by the normalizer.
580+
*
571581
* \return The Pass.
572582
*/
573583
TVM_DLL Pass DeadCodeElimination(Array<runtime::String> entry_functions);
@@ -578,6 +588,7 @@ TVM_DLL Pass DeadCodeElimination(Array<runtime::String> entry_functions);
578588
* Supported operators will be replaced by calls to `call_tir_inplace` that invoke in-place
579589
* PrimFunc implementations of those operators (which are based on the legalizations of those
580590
* operators).
591+
* \note ConvertToDataflow may need to be called first to provide dataflow blocks.
581592
* \return The pass.
582593
*/
583594
TVM_DLL Pass DataflowUseInplaceCalls();
@@ -589,6 +600,8 @@ TVM_DLL Pass DataflowUseInplaceCalls();
589600
* \param fp16_input_names The names of function parameters whose dtype should become fp16. The
590601
* function signature would change accordingly.
591602
* \return The Pass.
603+
*
604+
* \note Mainly operates within dataflow blocks. ConvertToDataflow may need to be called first.
592605
*/
593606
TVM_DLL Pass ToMixedPrecision(const DataType& out_dtype,
594607
Optional<Array<String>> fp16_input_names = NullOpt);

python/tvm/relax/transform/transform.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def Gradient(
5252
"""Reverse-mode automatic differentiation.
5353
5454
This pass will differentiate one function in the IRModule. Now the input function must have only
55-
one dataflow block.
55+
one dataflow block (ConvertToDataflow may need to be called first).
5656
5757
For a given function specified by `func_name`, it generates a new function with the name
5858
`func_name + "_adjoint"`. The new function computes the gradient of the **differentiation
@@ -260,6 +260,8 @@ def DataflowUseInplaceCalls() -> tvm.ir.transform.Pass:
260260
in-place PrimFunc implementations of those operators (which are based on the legalizations of
261261
those operators).
262262
263+
Note: ConvertToDataflow may need to be called first to provide dataflow blocks.
264+
263265
Returns
264266
-------
265267
ret: tvm.ir.transform.Pass
@@ -282,6 +284,8 @@ def ConvertToDataflow(min_size: int = 2) -> tvm.ir.transform.Pass:
282284
"""A pass that converts consecutive dataflow operations
283285
inside binding blocks into dataflow blocks.
284286
287+
Note: ConvertToDataflow may need to be called first.
288+
285289
Params
286290
------
287291
min_size: int
@@ -395,6 +399,8 @@ def RewriteDataflowReshape() -> tvm.ir.transform.Pass:
395399
operation at runtime, instead of doing real data copy.
396400
Here "reshape-like" includes reshape, expand_dims, flatten, etc.
397401
402+
Note: Operates only in dataflow blocks. ConvertToDataflow may need to be called first.
403+
398404
Returns
399405
-------
400406
ret : tvm.ir.transform.Pass
@@ -584,7 +590,9 @@ def RunCodegen(
584590

585591

586592
def FoldConstant() -> tvm.ir.transform.Pass:
587-
"""Fold constant expressions.
593+
"""Fold constant expressions within dataflow blocks.
594+
595+
Note: ConvertToDataflow may need to be called first to provide dataflow blocks.
588596
589597
Returns
590598
-------
@@ -651,6 +659,8 @@ def FuseOps(fuse_opt_level=-1) -> tvm.ir.transform.Pass:
651659
652660
A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each grouped function.
653661
662+
Note: ConvertToDataflow may need to be called first to provide dataflow blocks.
663+
654664
Parameters
655665
----------
656666
fuse_opt_level : int
@@ -764,6 +774,8 @@ def FuseOpsByPattern(
764774
765775
The end result is similar to FuseOps, but fusion is driven completely by the provided patterns.
766776
777+
Note: Only operates within dataflow blocks. ConvertToDataflow may need to be called first.
778+
767779
Parameters
768780
----------
769781
patterns : List[Union[FusionPattern, Tuple]]
@@ -1172,11 +1184,12 @@ def DeadCodeElimination(entry_functions: Optional[List[str]] = None) -> tvm.ir.t
11721184
"""Remove dead code in the IRModule.
11731185
Currently it removes:
11741186
1175-
1. Unused local VarBindings in a DataflowBlock.
1176-
2. Unused DataflowBlocks in a function.
1177-
3. Unused Relax functions in the module.
1187+
1. Unused local VarBindings
1188+
(those where the bound var is unused and no impure operation is used).
1189+
2. Unused Relax functions in the module.
11781190
We detect the call chain from the entry function, and remove all unused functions.
11791191
1192+
Any binding blocks that are left empty will be removed by the normalizer.
11801193
11811194
Notes
11821195
-----
@@ -1203,6 +1216,8 @@ def ToMixedPrecision(
12031216
"""Automatic mixed precision pass. Currently the pass assumes the input module to be fp32
12041217
only, and will automatically cast fp32 to fp16 for certain ops.
12051218
1219+
Note: Mainly operates within dataflow blocks. ConvertToDataflow may need to be called first.
1220+
12061221
Parameters
12071222
----------
12081223
out_dtype : str

src/relax/transform/dead_code_elimination.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,12 @@
2424
* \sa tvm/relax/ir/binding_rewrite.cc
2525
*
2626
* Currently it removes:
27-
* 1. Unused local VarBindings in a DataflowBlock.
28-
* 2. Unused DataflowBlocks in a function.
29-
* 3. Unused Relax functions in the module.
27+
* 1. Unused local VarBindings
28+
* (those where the bound var is unused and no impure operation is used).
29+
* 2. Unused Relax functions in the module.
3030
* We detect the call chain from the entry function, and remove all unused functions.
31+
*
32+
* Any binding blocks that are left empty will be removed by the normalizer.
3133
*/
3234

3335
#include <tvm/relax/analysis.h>

0 commit comments

Comments
 (0)