Skip to content

Commit 8a73bc6

Browse files
committed
[Relax] Default purity as unknown, instead of pure
Prior to this commit, the constructors for `relax.Function` and `relax.FuncStructInfo` provided a default annotation as a pure function. While this could be overridden, it This commit updates the IR to store purity as `Optional<Bool>` instead of `bool`, and changes the default to `NullOpt`. A function's purity can have three possible values: - `Bool(true)`: The function is known to be pure. - `Bool(false)`: The function is known to be impure. - `NullOpt`: The function's purity is unknown. In most cases `Bool(false)` and `NullOpt` should be treated equivalently, as they both indicate a function that may contain side effects. However, in the future, inference of purity may occur when the purity is `NullOpt`, based on analysis of a function's body. A function with purity of `Bool(false)` is known to be impure, and further analysis is unnecessary.
1 parent 72b75fe commit 8a73bc6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+573
-243
lines changed

include/tvm/relax/analysis.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -534,28 +534,34 @@ TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func);
534534
* \param own_name (Optional.) If we are checking a recursive function body,
535535
* the caller can pass the function's name so recursive calls
536536
* can be ignored in the check (must be a Var or GlobalVar).
537+
*
538+
* \param assumed_purity_when_unknown The purity to assume when not otherwise known.
539+
*
537540
* \return The impure expression, if one exists within the given
538541
* expression. Otherwise, NullOpt.
539542
* \note Relies on StructInfo annotations, so ensure that the module has been normalized first.
540543
* Also, an impure call in a *nested* function does *not* mean that the outer expression contains
541544
* an impure call--it only does if the nested function is *later called*.
542545
*/
543546
TVM_DLL Optional<Expr> FindImpureCall(const Expr& expr,
544-
const Optional<Expr>& own_name = Optional<Expr>(nullptr));
547+
const Optional<Expr>& own_name = Optional<Expr>(nullptr),
548+
bool assumed_purity_when_unknown = false);
545549

546550
/*!
547551
* \brief Check if the given expression (likely a function body) contains any impure calls.
548552
* \param expr The expression to be examined. If expr is a function, we check the body.
549553
* \param own_name (Optional.) If we are checking a recursive function body,
550554
* the caller can pass the function's name so recursive calls
551555
* can be ignored in the check (must be a Var or GlobalVar).
556+
* \param assumed_purity_when_unknown The purity to assume when not otherwise known
552557
* \return A boolean indicating if the expression contains any impure calls.
553558
* \note Relies on StructInfo annotations, so ensure that the module has been normalized first.
554559
* Also, an impure call in a *nested* function does *not* mean that the outer expression contains
555560
* an impure call--it only does if the nested function is *later called*.
556561
*/
557562
TVM_DLL bool ContainsImpureCall(const Expr& expr,
558-
const Optional<Expr>& own_name = Optional<Expr>(nullptr));
563+
const Optional<Expr>& own_name = Optional<Expr>(nullptr),
564+
bool assumed_purity_when_unknown = false);
559565

560566
/*!
561567
* \brief Check if the IRModule is well formed.

include/tvm/relax/expr.h

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -956,7 +956,25 @@ class FunctionNode : public BaseFuncNode {
956956
/*! \brief The return type of the function. */
957957
StructInfo ret_struct_info;
958958
/*! \brief Whether the function is annotated as pure or not. */
959-
bool is_pure;
959+
Optional<Bool> is_pure;
960+
961+
/*! \brief Whether the function is known to be pure. */
962+
bool IsPure() const {
963+
if (is_pure.defined()) {
964+
return is_pure.value()->value;
965+
} else {
966+
return false;
967+
}
968+
}
969+
970+
/*! \brief Whether the function is known to be impure. */
971+
bool IsImpure() const {
972+
if (is_pure.defined()) {
973+
return !is_pure.value()->value;
974+
} else {
975+
return false;
976+
}
977+
}
960978

961979
void VisitAttrs(AttrVisitor* v) {
962980
v->Visit("params", &params);
@@ -1015,17 +1033,36 @@ class Function : public BaseFunc {
10151033
*
10161034
* \param span The source span of the expression.
10171035
*/
1036+
TVM_DLL explicit Function(Array<Var> params, Expr body,
1037+
Optional<StructInfo> ret_struct_info = NullOpt,
1038+
Optional<Bool> is_pure = NullOpt, DictAttrs attrs = DictAttrs(),
1039+
Span span = Span());
1040+
10181041
TVM_DLL explicit Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info,
1019-
bool is_pure = true, DictAttrs attrs = DictAttrs(), Span span = Span());
1042+
bool is_pure, DictAttrs attrs = DictAttrs(), Span span = Span())
1043+
: Function(params, body, ret_struct_info, Optional<Bool>(Bool(is_pure)), attrs, span) {}
10201044

10211045
/*!
10221046
* \brief Mimics the constructor but without body Expr.
1023-
* \note ret_struct_info is required, since it can not deduced by the body.
1047+
*
1048+
* \note ret_struct_info is required, since it cannot be deduced by
1049+
* the body.
10241050
*/
10251051
TVM_DLL static Function CreateEmpty(Array<Var> params, StructInfo ret_struct_info,
1026-
bool is_pure = true, DictAttrs attrs = DictAttrs(),
1052+
Optional<Bool> is_pure, DictAttrs attrs = DictAttrs(),
10271053
Span span = Span());
10281054

1055+
/*!
1056+
* \brief Mimics the constructor but without body Expr.
1057+
*
1058+
* \note ret_struct_info is required, since it cannot be deduced by
1059+
* the body.
1060+
*/
1061+
TVM_DLL static Function CreateEmpty(Array<Var> params, StructInfo ret_struct_info, bool is_pure,
1062+
DictAttrs attrs = DictAttrs(), Span span = Span()) {
1063+
return CreateEmpty(params, ret_struct_info, Optional<Bool>(Bool(is_pure)), attrs, span);
1064+
}
1065+
10291066
TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
10301067
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode);
10311068
};

include/tvm/relax/struct_info.h

Lines changed: 86 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -320,17 +320,50 @@ class FuncStructInfoNode : public StructInfoNode {
320320
Optional<StructInfoDeriveFunc> derive_func;
321321
/*!
322322
* \brief Whether the function is pure.
323-
* \note This parameter should be set to true only if the function is pure on all inputs.
324-
* If the function _may_ have visible side effects, set it to false.
323+
*
324+
* There are three possible state for this value.
325+
*
326+
* `Bool(true)`: The function is known to be pure.
327+
* `Bool(false)`: The function is known to be impure.
328+
* `NullOpt`: The function's purity is unknown.
329+
*
330+
* In most cases `Bool(false)` and `NullOpt` should be treated
331+
* equivalently, as they both indicate a function that may contain
332+
* side effects. However, inference of purity may occur when the
333+
* purity is `NullOpt`, based on analysis of a function's body. A
334+
* function with purity of `Bool(false)` is known to be impure, and
335+
* further analysis is unnecessary.
336+
*
337+
* \note This parameter should be set to true only if the function
338+
* is pure on all inputs. If the function _may_ have visible side
339+
* effects, set it to false.
325340
*/
326-
bool purity;
341+
Optional<Bool> purity;
327342

328343
/*!
329344
* \return Whether the func struct info is opaque.
330345
* \note We define a function as opaque we have no constraints on params.
331346
*/
332347
bool IsOpaque() const { return !params.defined(); }
333348

349+
/*! \brief Whether the FuncStructInfo is known to be pure */
350+
bool IsPure() const {
351+
if (purity.defined()) {
352+
return purity.value()->value;
353+
} else {
354+
return false;
355+
}
356+
}
357+
358+
/*! \brief Whether the function is known to be impure. */
359+
bool IsImpure() const {
360+
if (purity.defined()) {
361+
return !purity.value()->value;
362+
} else {
363+
return false;
364+
}
365+
}
366+
334367
void VisitAttrs(AttrVisitor* v) {
335368
v->Visit("params", &params);
336369
v->Visit("ret", &ret);
@@ -365,42 +398,82 @@ class FuncStructInfo : public StructInfo {
365398
* \brief Constructor from parameter struct info and return value struct info.
366399
* \param params The struct info of function parameters.
367400
* \param ret The return value struct info.
368-
* \param purity The purity of the function (true by default).
401+
* \param purity The purity of the function (unknown by default).
369402
* \param span The span of the AST.
370403
*
371404
* \note If the ret contains variables(tir::Var and relax::Var), they must be deducible from
372405
* params. If you are unsure, you can always erase ret to static.
373406
*/
374-
TVM_DLL FuncStructInfo(Array<StructInfo> params, StructInfo ret, bool purity = true,
407+
TVM_DLL FuncStructInfo(Array<StructInfo> params, StructInfo ret, Optional<Bool> purity = NullOpt,
375408
Span span = Span());
376409

410+
/*!
411+
* \brief Constructor from parameter struct info and return value struct info.
412+
* \param params The struct info of function parameters.
413+
* \param ret The return value struct info.
414+
* \param purity The purity of the function.
415+
* \param span The span of the AST.
416+
*
417+
* \note If the ret contains variables(tir::Var and relax::Var), they must be deducible from
418+
* params. If you are unsure, you can always erase ret to static.
419+
*/
420+
TVM_DLL FuncStructInfo(Array<StructInfo> params, StructInfo ret, bool purity, Span span = Span())
421+
: FuncStructInfo(params, ret, Optional<Bool>(Bool(purity)), span) {}
422+
423+
/*!
424+
* \brief Constructing an opaque function struct info using derive_func.
425+
*
426+
* \param derive_func Derivation function.
427+
* \param purity The purity of the function (unknown by default).
428+
* \param span The span of the AST.
429+
*
430+
* \return The FuncStructInfo for opaque packedfunc.
431+
* \note Defaults to an derive func that always return ObjectStructInfo if not specified.
432+
*/
433+
TVM_DLL static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func,
434+
Optional<Bool> purity = NullOpt, Span span = Span());
435+
377436
/*!
378437
* \brief Constructing an opaque function struct info using derive_func.
379438
*
380439
* \param derive_func Derivation function.
381-
* \param purity The purity of the function
382-
* (false by default: most external functions are not pure).
440+
* \param purity The purity of the function.
441+
* \param span The span of the AST.
442+
*
443+
* \return The FuncStructInfo for opaque packedfunc.
444+
* \note Defaults to an derive func that always return ObjectStructInfo if not specified.
445+
*/
446+
TVM_DLL static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func, bool purity,
447+
Span span = Span()) {
448+
return OpaqueFunc(derive_func, Optional<Bool>(Bool(purity)), span);
449+
}
450+
451+
/*!
452+
* \brief Construct an opaque function using from return struct info.
453+
*
454+
* \param ret The struct info of the return value.
455+
* \param purity The purity of the function (unknown by default).
383456
* \param span The span of the AST.
384457
*
385458
* \return The FuncStructInfo for opaque packedfunc.
386459
* \note Defaults to an derive func that always return ObjectStructInfo if not specified.
387460
*/
388-
TVM_DLL static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func, bool purity = false,
389-
Span span = Span());
461+
TVM_DLL static FuncStructInfo OpaqueFunc(StructInfo ret = ObjectStructInfo(),
462+
Optional<Bool> purity = NullOpt, Span span = Span());
390463

391464
/*!
392465
* \brief Construct an opaque function using from return struct info.
393466
*
394467
* \param ret The struct info of the return value.
395-
* \param purity The purity of the function
396-
* (false by default: most external functions are not pure).
468+
* \param purity The purity of the function.
397469
* \param span The span of the AST.
398470
*
399471
* \return The FuncStructInfo for opaque packedfunc.
400472
* \note Defaults to an derive func that always return ObjectStructInfo if not specified.
401473
*/
402-
TVM_DLL static FuncStructInfo OpaqueFunc(StructInfo ret = ObjectStructInfo(), bool purity = false,
403-
Span span = Span());
474+
TVM_DLL static FuncStructInfo OpaqueFunc(StructInfo ret, bool purity, Span span = Span()) {
475+
return OpaqueFunc(ret, Optional<Bool>(Bool(purity)), span);
476+
}
404477

405478
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FuncStructInfo, StructInfo, FuncStructInfoNode);
406479
};

include/tvm/relax/utils.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
#include <tvm/relax/expr.h>
3030
#include <tvm/runtime/logging.h>
3131

32+
#include <optional>
33+
3234
namespace tvm {
3335
namespace relax {
3436

@@ -122,6 +124,19 @@ TVM_DLL bool IsLeafOrTuple(const Expr& expr);
122124
*/
123125
TVM_DLL bool IsImpureCall(const Call& call);
124126

127+
/*!
128+
* \brief Return the purity of the given Call node. If the callee is a
129+
* general expression, this simply requires checking the purity field
130+
* of the FuncStructInfo. If it is an Op, then this checks the
131+
* `fPurity` field.
132+
*
133+
* \param call The input call
134+
*
135+
* \return True if the call is known to be pure. False if the call is
136+
* known to be impure. std::nullopt if the call's purity is unknown.
137+
*/
138+
TVM_DLL std::optional<bool> GetPurity(const Call& call);
139+
125140
/*!
126141
* \brief Copy the given function. All variables that are bound inside the original function
127142
* would be copied to satisfy the restriction in the well-formed check: Variables in

include/tvm/script/ir_builder/relax/ir.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ namespace relax {
3737
* \param is_private Whether the function is annotated as private.
3838
* \return The created ir_builder Function frame.
3939
*/
40-
TVM_DLL FunctionFrame Function(const Bool& is_pure, const Bool& is_private);
40+
TVM_DLL FunctionFrame Function(const Optional<Bool>& is_pure, const Bool& is_private);
4141

4242
/*!
4343
* \brief Add a parameter to the last function frame.

python/tvm/relax/analysis/analysis.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -372,9 +372,12 @@ def has_reshape_pattern(func: tir.PrimFunc) -> bool:
372372
return _ffi_api.has_reshape_pattern(func) # type: ignore
373373

374374

375-
def contains_impure_call(expr: Expr, own_name: Optional[Union[Var, GlobalVar]] = None) -> bool:
376-
"""
377-
Check if the given expression (likely a function body) contains any impure calls.
375+
def contains_impure_call(
376+
expr: Expr,
377+
own_name: Optional[Union[Var, GlobalVar]] = None,
378+
assumed_purity_when_unknown: bool = False,
379+
) -> bool:
380+
"""Check if the given expression (likely a function body) contains any impure calls.
378381
379382
Parameters
380383
----------
@@ -385,6 +388,10 @@ def contains_impure_call(expr: Expr, own_name: Optional[Union[Var, GlobalVar]] =
385388
For a recursive function, the analysis can ignore the self-calls
386389
for checking purity.
387390
391+
assumed_purity_when_unknown: bool
392+
The purity to assume when not otherwise known. Defaults to
393+
False, treating unknown purity as impure
394+
388395
Returns
389396
-------
390397
ret : bool
@@ -397,7 +404,7 @@ def contains_impure_call(expr: Expr, own_name: Optional[Union[Var, GlobalVar]] =
397404
Also, an impure call in a *nested* function does *not* mean that the outer expression contains
398405
an impure call--it only does if the nested function is *later called*.
399406
"""
400-
return _ffi_api.contains_impure_call(expr, own_name)
407+
return _ffi_api.contains_impure_call(expr, own_name, assumed_purity_when_unknown)
401408

402409

403410
def get_var2val(func: Function) -> Dict[Var, Expr]:

python/tvm/relax/block_builder.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def function(
209209
name: str,
210210
params: Optional[Union[Var, Tuple, List[Var]]] = None,
211211
attrs: Optional[Dict[str, Object]] = None,
212-
pure: bool = True,
212+
pure: Optional[bool] = None,
213213
private: bool = False,
214214
) -> FunctionScope:
215215
"""Annotate a Relax function.
@@ -227,7 +227,7 @@ def function(
227227
attrs : Dict[str, Object], optional
228228
The function attrs
229229
230-
pure : bool, optional
230+
pure : Optional[bool]
231231
Whether the function is annotated as pure.
232232
233233
private : bool, optional
@@ -236,6 +236,7 @@ def function(
236236
If it is not private and not an inner function, then it will have
237237
a global symbol attribute (mapped to the function's name)
238238
239+
239240
Returns
240241
-------
241242
ret: FunctionScope

python/tvm/relax/distributed/global_info.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,8 @@ def __init__(
4444
):
4545
if isinstance(shape, (list, tuple)):
4646
shape = ShapeTuple(shape)
47-
device_range = None
48-
if isinstance(device_ids, Range):
49-
device_range = device_ids
50-
device_ids = []
51-
self.__init_handle_by_constructor__(
52-
ffi.DeviceMesh, shape, device_ids, device_range
53-
) # type: ignore
47+
48+
self.__init_handle_by_constructor__(ffi.DeviceMesh, shape, device_ids) # type: ignore
5449

5550

5651
def device_mesh(shape: ShapeTuple, device_ids: Union[List[int], Range]) -> DeviceMesh:

python/tvm/relax/expr.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -975,7 +975,7 @@ class Function(BaseFunc, Scriptable):
975975
params: List[Var]
976976
body: Expr
977977
ret_struct_info: StructInfo
978-
is_pure: bool
978+
is_pure: Optional[bool]
979979
attrs: tvm.ir.DictAttrs
980980
span: Optional[Span]
981981

@@ -984,7 +984,7 @@ def __init__(
984984
params: List[Var],
985985
body: Expr,
986986
ret_struct_info: Optional[StructInfo] = None,
987-
is_pure: Optional[bool] = True,
987+
is_pure: Optional[bool] = None,
988988
attrs: Optional[tvm.ir.DictAttrs] = None,
989989
span: Optional[Span] = None,
990990
) -> None:
@@ -1002,7 +1002,7 @@ def __init__(
10021002
def create_empty(
10031003
params: List[Var],
10041004
ret_struct_info: StructInfo,
1005-
is_pure: Optional[bool] = True,
1005+
is_pure: Optional[bool] = None,
10061006
attrs: Optional[tvm.ir.DictAttrs] = None,
10071007
span: Optional[Span] = None,
10081008
):

0 commit comments

Comments
 (0)