Skip to content

Commit 1698f50

Browse files
committed
[Relax] Default purity as unknown, instead of pure
Prior to this commit, the constructors for `relax.Function` and `relax.FuncStructInfo` provided a default annotation as a pure function. While this could be overridden, it This commit updates the IR to store purity as `Optional<Bool>` instead of `bool`, and changes the default to `NullOpt`. A function's purity can have three possible values: - `Bool(true)`: The function is known to be pure. - `Bool(false)`: The function is known to be impure. - `NullOpt`: The function's purity is unknown. In most cases `Bool(false)` and `NullOpt` should be treated equivalently, as they both indicate a function that may contain side effects. However, in the future, inference of purity may occur when the purity is `NullOpt`, based on analysis of a function's body. A function with purity of `Bool(false)` is known to be impure, and further analysis is unnecessary.
1 parent b2204ae commit 1698f50

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+563
-237
lines changed

include/tvm/relax/analysis.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -521,28 +521,34 @@ TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func);
521521
* \param own_name (Optional.) If we are checking a recursive function body,
522522
* the caller can pass the function's name so recursive calls
523523
* can be ignored in the check (must be a Var or GlobalVar).
524+
*
525+
* \param assumed_purity_when_unknown The purity to assume when not otherwise known.
526+
*
524527
* \return The impure expression, if one exists within the given
525528
* expression. Otherwise, NullOpt.
526529
* \note Relies on StructInfo annotations, so ensure that the module has been normalized first.
527530
* Also, an impure call in a *nested* function does *not* mean that the outer expression contains
528531
* an impure call--it only does if the nested function is *later called*.
529532
*/
530533
TVM_DLL Optional<Expr> FindImpureCall(const Expr& expr,
531-
const Optional<Expr>& own_name = Optional<Expr>(nullptr));
534+
const Optional<Expr>& own_name = Optional<Expr>(nullptr),
535+
bool assumed_purity_when_unknown = false);
532536

533537
/*!
534538
* \brief Check if the given expression (likely a function body) contains any impure calls.
535539
* \param expr The expression to be examined. If expr is a function, we check the body.
536540
* \param own_name (Optional.) If we are checking a recursive function body,
537541
* the caller can pass the function's name so recursive calls
538542
* can be ignored in the check (must be a Var or GlobalVar).
543+
* \param assumed_purity_when_unknown The purity to assume when not otherwise known
539544
* \return A boolean indicating if the expression contains any impure calls.
540545
* \note Relies on StructInfo annotations, so ensure that the module has been normalized first.
541546
* Also, an impure call in a *nested* function does *not* mean that the outer expression contains
542547
* an impure call--it only does if the nested function is *later called*.
543548
*/
544549
TVM_DLL bool ContainsImpureCall(const Expr& expr,
545-
const Optional<Expr>& own_name = Optional<Expr>(nullptr));
550+
const Optional<Expr>& own_name = Optional<Expr>(nullptr),
551+
bool assumed_purity_when_unknown = false);
546552

547553
/*!
548554
* \brief Check if the IRModule is well formed.

include/tvm/relax/expr.h

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -944,7 +944,25 @@ class FunctionNode : public BaseFuncNode {
944944
/*! \brief The return type of the function. */
945945
StructInfo ret_struct_info;
946946
/*! \brief Whether the function is annotated as pure or not. */
947-
bool is_pure;
947+
Optional<Bool> is_pure;
948+
949+
/*! \brief Whether the function is known to be pure. */
950+
bool IsPure() const {
951+
if (is_pure.defined()) {
952+
return is_pure.value()->value;
953+
} else {
954+
return false;
955+
}
956+
}
957+
958+
/*! \brief Whether the function is known to be impure. */
959+
bool IsImpure() const {
960+
if (is_pure.defined()) {
961+
return !is_pure.value()->value;
962+
} else {
963+
return false;
964+
}
965+
}
948966

949967
void VisitAttrs(AttrVisitor* v) {
950968
v->Visit("params", &params);
@@ -982,17 +1000,36 @@ class FunctionNode : public BaseFuncNode {
9821000

9831001
class Function : public BaseFunc {
9841002
public:
1003+
TVM_DLL explicit Function(Array<Var> params, Expr body,
1004+
Optional<StructInfo> ret_struct_info = NullOpt,
1005+
Optional<Bool> is_pure = NullOpt, DictAttrs attrs = DictAttrs(),
1006+
Span span = Span());
1007+
9851008
TVM_DLL explicit Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info,
986-
bool is_pure = true, DictAttrs attrs = DictAttrs(), Span span = Span());
1009+
bool is_pure, DictAttrs attrs = DictAttrs(), Span span = Span())
1010+
: Function(params, body, ret_struct_info, Optional<Bool>(Bool(is_pure)), attrs, span) {}
9871011

9881012
/*!
9891013
* \brief Mimics the constructor but without body Expr.
990-
* \note ret_struct_info is required, since it can not deduced by the body.
1014+
*
1015+
* \note ret_struct_info is required, since it cannot be deduced by
1016+
* the body.
9911017
*/
9921018
TVM_DLL static Function CreateEmpty(Array<Var> params, StructInfo ret_struct_info,
993-
bool is_pure = true, DictAttrs attrs = DictAttrs(),
1019+
Optional<Bool> is_pure, DictAttrs attrs = DictAttrs(),
9941020
Span span = Span());
9951021

1022+
/*!
1023+
* \brief Mimics the constructor but without body Expr.
1024+
*
1025+
* \note ret_struct_info is required, since it cannot be deduced by
1026+
* the body.
1027+
*/
1028+
TVM_DLL static Function CreateEmpty(Array<Var> params, StructInfo ret_struct_info, bool is_pure,
1029+
DictAttrs attrs = DictAttrs(), Span span = Span()) {
1030+
return CreateEmpty(params, ret_struct_info, Optional<Bool>(Bool(is_pure)), attrs, span);
1031+
}
1032+
9961033
TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
9971034
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode);
9981035
};

include/tvm/relax/struct_info.h

Lines changed: 86 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -320,17 +320,50 @@ class FuncStructInfoNode : public StructInfoNode {
320320
Optional<StructInfoDeriveFunc> derive_func;
321321
/*!
322322
* \brief Whether the function is pure.
323-
* \note This parameter should be set to true only if the function is pure on all inputs.
324-
* If the function _may_ have visible side effects, set it to false.
323+
*
324+
* There are three possible state for this value.
325+
*
326+
* `Bool(true)`: The function is known to be pure.
327+
* `Bool(false)`: The function is known to be impure.
328+
* `NullOpt`: The function's purity is unknown.
329+
*
330+
* In most cases `Bool(false)` and `NullOpt` should be treated
331+
* equivalently, as they both indicate a function that may contain
332+
* side effects. However, inference of purity may occur when the
333+
* purity is `NullOpt`, based on analysis of a function's body. A
334+
* function with purity of `Bool(false)` is known to be impure, and
335+
* further analysis is unnecessary.
336+
*
337+
* \note This parameter should be set to true only if the function
338+
* is pure on all inputs. If the function _may_ have visible side
339+
* effects, set it to false.
325340
*/
326-
bool purity;
341+
Optional<Bool> purity;
327342

328343
/*!
329344
* \return Whether the func struct info is opaque.
330345
* \note We define a function as opaque we have no constraints on params.
331346
*/
332347
bool IsOpaque() const { return !params.defined(); }
333348

349+
/*! \brief Whether the FuncStructInfo is known to be pure */
350+
bool IsPure() const {
351+
if (purity.defined()) {
352+
return purity.value()->value;
353+
} else {
354+
return false;
355+
}
356+
}
357+
358+
/*! \brief Whether the function is known to be impure. */
359+
bool IsImpure() const {
360+
if (purity.defined()) {
361+
return !purity.value()->value;
362+
} else {
363+
return false;
364+
}
365+
}
366+
334367
void VisitAttrs(AttrVisitor* v) {
335368
v->Visit("params", &params);
336369
v->Visit("ret", &ret);
@@ -365,42 +398,82 @@ class FuncStructInfo : public StructInfo {
365398
* \brief Constructor from parameter struct info and return value struct info.
366399
* \param params The struct info of function parameters.
367400
* \param ret The return value struct info.
368-
* \param purity The purity of the function (true by default).
401+
* \param purity The purity of the function (unknown by default).
369402
* \param span The span of the AST.
370403
*
371404
* \note If the ret contains variables(tir::Var and relax::Var), they must be deducible from
372405
* params. If you are unsure, you can always erase ret to static.
373406
*/
374-
TVM_DLL FuncStructInfo(Array<StructInfo> params, StructInfo ret, bool purity = true,
407+
TVM_DLL FuncStructInfo(Array<StructInfo> params, StructInfo ret, Optional<Bool> purity = NullOpt,
375408
Span span = Span());
376409

410+
/*!
411+
* \brief Constructor from parameter struct info and return value struct info.
412+
* \param params The struct info of function parameters.
413+
* \param ret The return value struct info.
414+
* \param purity The purity of the function.
415+
* \param span The span of the AST.
416+
*
417+
* \note If the ret contains variables(tir::Var and relax::Var), they must be deducible from
418+
* params. If you are unsure, you can always erase ret to static.
419+
*/
420+
TVM_DLL FuncStructInfo(Array<StructInfo> params, StructInfo ret, bool purity, Span span = Span())
421+
: FuncStructInfo(params, ret, Optional<Bool>(Bool(purity)), span) {}
422+
423+
/*!
424+
* \brief Constructing an opaque function struct info using derive_func.
425+
*
426+
* \param derive_func Derivation function.
427+
* \param purity The purity of the function (unknown by default).
428+
* \param span The span of the AST.
429+
*
430+
* \return The FuncStructInfo for opaque packedfunc.
431+
* \note Defaults to an derive func that always return ObjectStructInfo if not specified.
432+
*/
433+
TVM_DLL static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func,
434+
Optional<Bool> purity = NullOpt, Span span = Span());
435+
377436
/*!
378437
* \brief Constructing an opaque function struct info using derive_func.
379438
*
380439
* \param derive_func Derivation function.
381-
* \param purity The purity of the function
382-
* (false by default: most external functions are not pure).
440+
* \param purity The purity of the function.
441+
* \param span The span of the AST.
442+
*
443+
* \return The FuncStructInfo for opaque packedfunc.
444+
* \note Defaults to an derive func that always return ObjectStructInfo if not specified.
445+
*/
446+
TVM_DLL static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func, bool purity,
447+
Span span = Span()) {
448+
return OpaqueFunc(derive_func, Optional<Bool>(Bool(purity)), span);
449+
}
450+
451+
/*!
452+
* \brief Construct an opaque function using from return struct info.
453+
*
454+
* \param ret The struct info of the return value.
455+
* \param purity The purity of the function (unknown by default).
383456
* \param span The span of the AST.
384457
*
385458
* \return The FuncStructInfo for opaque packedfunc.
386459
* \note Defaults to an derive func that always return ObjectStructInfo if not specified.
387460
*/
388-
TVM_DLL static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func, bool purity = false,
389-
Span span = Span());
461+
TVM_DLL static FuncStructInfo OpaqueFunc(StructInfo ret = ObjectStructInfo(),
462+
Optional<Bool> purity = NullOpt, Span span = Span());
390463

391464
/*!
392465
* \brief Construct an opaque function using from return struct info.
393466
*
394467
* \param ret The struct info of the return value.
395-
* \param purity The purity of the function
396-
* (false by default: most external functions are not pure).
468+
* \param purity The purity of the function.
397469
* \param span The span of the AST.
398470
*
399471
* \return The FuncStructInfo for opaque packedfunc.
400472
* \note Defaults to an derive func that always return ObjectStructInfo if not specified.
401473
*/
402-
TVM_DLL static FuncStructInfo OpaqueFunc(StructInfo ret = ObjectStructInfo(), bool purity = false,
403-
Span span = Span());
474+
TVM_DLL static FuncStructInfo OpaqueFunc(StructInfo ret, bool purity, Span span = Span()) {
475+
return OpaqueFunc(ret, Optional<Bool>(Bool(purity)), span);
476+
}
404477

405478
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FuncStructInfo, StructInfo, FuncStructInfoNode);
406479
};

include/tvm/relax/utils.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
#include <tvm/relax/expr.h>
3030
#include <tvm/runtime/logging.h>
3131

32+
#include <optional>
33+
3234
namespace tvm {
3335
namespace relax {
3436

@@ -115,6 +117,19 @@ TVM_DLL bool IsLeafOrTuple(const Expr& expr);
115117
*/
116118
TVM_DLL bool IsImpureCall(const Call& call);
117119

120+
/*!
121+
* \brief Return the purity of the given Call node. If the callee is a
122+
* general expression, this simply requires checking the purity field
123+
* of the FuncStructInfo. If it is an Op, then this checks the
124+
* `fPurity` field.
125+
*
126+
* \param call The input call
127+
*
128+
* \return True if the call is known to be pure. False if the call is
129+
* known to be impure. std::nullopt if the call's purity is unknown.
130+
*/
131+
TVM_DLL std::optional<bool> GetPurity(const Call& call);
132+
118133
/*!
119134
* \brief Copy the given function. All variables that are bound inside the original function
120135
* would be copied to satisfy the restriction in the well-formed check: Variables in

include/tvm/script/ir_builder/relax/ir.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ namespace relax {
3737
* \param is_private Whether the function is annotated as private.
3838
* \return The created ir_builder Function frame.
3939
*/
40-
TVM_DLL FunctionFrame Function(const Bool& is_pure, const Bool& is_private);
40+
TVM_DLL FunctionFrame Function(const Optional<Bool>& is_pure, const Bool& is_private);
4141

4242
/*!
4343
* \brief Add a parameter to the last function frame.

python/tvm/relax/analysis/analysis.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -345,9 +345,12 @@ def has_reshape_pattern(func: tir.PrimFunc) -> bool:
345345
return _ffi_api.has_reshape_pattern(func) # type: ignore
346346

347347

348-
def contains_impure_call(expr: Expr, own_name: Optional[Union[Var, GlobalVar]] = None) -> bool:
349-
"""
350-
Check if the given expression (likely a function body) contains any impure calls.
348+
def contains_impure_call(
349+
expr: Expr,
350+
own_name: Optional[Union[Var, GlobalVar]] = None,
351+
assumed_purity_when_unknown: bool = False,
352+
) -> bool:
353+
"""Check if the given expression (likely a function body) contains any impure calls.
351354
352355
Parameters
353356
----------
@@ -358,6 +361,10 @@ def contains_impure_call(expr: Expr, own_name: Optional[Union[Var, GlobalVar]] =
358361
For a recursive function, the analysis can ignore the self-calls
359362
for checking purity.
360363
364+
assumed_purity_when_unknown: bool
365+
The purity to assume when not otherwise known. Defaults to
366+
False, treating unknown purity as impure
367+
361368
Returns
362369
-------
363370
ret : bool
@@ -370,7 +377,7 @@ def contains_impure_call(expr: Expr, own_name: Optional[Union[Var, GlobalVar]] =
370377
Also, an impure call in a *nested* function does *not* mean that the outer expression contains
371378
an impure call--it only does if the nested function is *later called*.
372379
"""
373-
return _ffi_api.contains_impure_call(expr, own_name)
380+
return _ffi_api.contains_impure_call(expr, own_name, assumed_purity_when_unknown)
374381

375382

376383
def get_var2val(func: Function) -> Dict[Var, Expr]:

python/tvm/relax/block_builder.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def function(
209209
name: str,
210210
params: Optional[Union[Var, Tuple, List[Var]]] = None,
211211
attrs: Optional[Dict[str, Object]] = None,
212-
pure: bool = True,
212+
pure: Optional[bool] = None,
213213
private: bool = False,
214214
) -> FunctionScope:
215215
"""Annotate a Relax function.
@@ -227,7 +227,7 @@ def function(
227227
attrs : Dict[str, Object], optional
228228
The function attrs
229229
230-
pure : bool, optional
230+
pure : Optional[bool]
231231
Whether the function is annotated as pure.
232232
233233
private : bool, optional
@@ -236,6 +236,7 @@ def function(
236236
If it is not private and not an inner function, then it will have
237237
a global symbol attribute (mapped to the function's name)
238238
239+
239240
Returns
240241
-------
241242
ret: FunctionScope

python/tvm/relax/distributed/global_info.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,8 @@ def __init__(
4444
):
4545
if isinstance(shape, (list, tuple)):
4646
shape = ShapeTuple(shape)
47-
device_range = None
48-
if isinstance(device_ids, Range):
49-
device_range = device_ids
50-
device_ids = []
51-
self.__init_handle_by_constructor__(
52-
ffi.DeviceMesh, shape, device_ids, device_range
53-
) # type: ignore
47+
48+
self.__init_handle_by_constructor__(ffi.DeviceMesh, shape, device_ids) # type: ignore
5449

5550

5651
def device_mesh(shape: ShapeTuple, device_ids: Union[List[int], Range]) -> DeviceMesh:

python/tvm/relax/expr.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -878,7 +878,7 @@ class Function(BaseFunc, Scriptable):
878878
params: List[Var]
879879
body: Expr
880880
ret_struct_info: StructInfo
881-
is_pure: bool
881+
is_pure: Optional[bool]
882882
attrs: tvm.ir.DictAttrs
883883
span: Optional[Span]
884884

@@ -887,7 +887,7 @@ def __init__(
887887
params: List[Var],
888888
body: Expr,
889889
ret_struct_info: Optional[StructInfo] = None,
890-
is_pure: Optional[bool] = True,
890+
is_pure: Optional[bool] = None,
891891
attrs: Optional[tvm.ir.DictAttrs] = None,
892892
span: Optional[Span] = None,
893893
) -> None:
@@ -905,7 +905,7 @@ def __init__(
905905
def create_empty(
906906
params: List[Var],
907907
ret_struct_info: StructInfo,
908-
is_pure: Optional[bool] = True,
908+
is_pure: Optional[bool] = None,
909909
attrs: Optional[tvm.ir.DictAttrs] = None,
910910
span: Optional[Span] = None,
911911
):

0 commit comments

Comments
 (0)