Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions include/tvm/relax/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -534,28 +534,34 @@ TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func);
* \param own_name (Optional.) If we are checking a recursive function body,
* the caller can pass the function's name so recursive calls
* can be ignored in the check (must be a Var or GlobalVar).
*
* \param assumed_purity_when_unknown The purity to assume when not otherwise known.
*
* \return The impure expression, if one exists within the given
* expression. Otherwise, NullOpt.
* \note Relies on StructInfo annotations, so ensure that the module has been normalized first.
* Also, an impure call in a *nested* function does *not* mean that the outer expression contains
* an impure call--it only does if the nested function is *later called*.
*/
TVM_DLL Optional<Expr> FindImpureCall(const Expr& expr,
const Optional<Expr>& own_name = Optional<Expr>(nullptr));
const Optional<Expr>& own_name = Optional<Expr>(nullptr),
bool assumed_purity_when_unknown = false);

/*!
* \brief Check if the given expression (likely a function body) contains any impure calls.
* \param expr The expression to be examined. If expr is a function, we check the body.
* \param own_name (Optional.) If we are checking a recursive function body,
* the caller can pass the function's name so recursive calls
* can be ignored in the check (must be a Var or GlobalVar).
* \param assumed_purity_when_unknown The purity to assume when not otherwise known
* \return A boolean indicating if the expression contains any impure calls.
* \note Relies on StructInfo annotations, so ensure that the module has been normalized first.
* Also, an impure call in a *nested* function does *not* mean that the outer expression contains
* an impure call--it only does if the nested function is *later called*.
*/
TVM_DLL bool ContainsImpureCall(const Expr& expr,
const Optional<Expr>& own_name = Optional<Expr>(nullptr));
const Optional<Expr>& own_name = Optional<Expr>(nullptr),
bool assumed_purity_when_unknown = false);

/*!
* \brief Check if the IRModule is well formed.
Expand Down
45 changes: 41 additions & 4 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -956,7 +956,25 @@ class FunctionNode : public BaseFuncNode {
/*! \brief The return type of the function. */
StructInfo ret_struct_info;
/*! \brief Whether the function is annotated as pure or not. */
bool is_pure;
Optional<Bool> is_pure;

/*! \brief Whether the function is known to be pure. */
bool IsPure() const {
if (is_pure.defined()) {
return is_pure.value()->value;
} else {
return false;
}
}

/*! \brief Whether the function is known to be impure. */
bool IsImpure() const {
if (is_pure.defined()) {
return !is_pure.value()->value;
} else {
return false;
}
}

void VisitAttrs(AttrVisitor* v) {
v->Visit("params", &params);
Expand Down Expand Up @@ -1015,17 +1033,36 @@ class Function : public BaseFunc {
*
* \param span The source span of the expression.
*/
TVM_DLL explicit Function(Array<Var> params, Expr body,
Optional<StructInfo> ret_struct_info = NullOpt,
Optional<Bool> is_pure = NullOpt, DictAttrs attrs = DictAttrs(),
Span span = Span());

TVM_DLL explicit Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info,
bool is_pure = true, DictAttrs attrs = DictAttrs(), Span span = Span());
bool is_pure, DictAttrs attrs = DictAttrs(), Span span = Span())
: Function(params, body, ret_struct_info, Optional<Bool>(Bool(is_pure)), attrs, span) {}

/*!
* \brief Mimics the constructor but without body Expr.
* \note ret_struct_info is required, since it can not deduced by the body.
*
* \note ret_struct_info is required, since it cannot be deduced by
* the body.
*/
TVM_DLL static Function CreateEmpty(Array<Var> params, StructInfo ret_struct_info,
bool is_pure = true, DictAttrs attrs = DictAttrs(),
Optional<Bool> is_pure, DictAttrs attrs = DictAttrs(),
Span span = Span());

/*!
* \brief Mimics the constructor but without body Expr.
*
* \note ret_struct_info is required, since it cannot be deduced by
* the body.
*/
TVM_DLL static Function CreateEmpty(Array<Var> params, StructInfo ret_struct_info, bool is_pure,
DictAttrs attrs = DictAttrs(), Span span = Span()) {
return CreateEmpty(params, ret_struct_info, Optional<Bool>(Bool(is_pure)), attrs, span);
}

TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode);
};
Expand Down
99 changes: 86 additions & 13 deletions include/tvm/relax/struct_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -320,17 +320,50 @@ class FuncStructInfoNode : public StructInfoNode {
Optional<StructInfoDeriveFunc> derive_func;
/*!
* \brief Whether the function is pure.
* \note This parameter should be set to true only if the function is pure on all inputs.
* If the function _may_ have visible side effects, set it to false.
*
* There are three possible state for this value.
*
* `Bool(true)`: The function is known to be pure.
* `Bool(false)`: The function is known to be impure.
* `NullOpt`: The function's purity is unknown.
*
* In most cases `Bool(false)` and `NullOpt` should be treated
* equivalently, as they both indicate a function that may contain
* side effects. However, inference of purity may occur when the
* purity is `NullOpt`, based on analysis of a function's body. A
* function with purity of `Bool(false)` is known to be impure, and
* further analysis is unnecessary.
*
* \note This parameter should be set to true only if the function
* is pure on all inputs. If the function _may_ have visible side
* effects, set it to false.
*/
bool purity;
Optional<Bool> purity;

/*!
* \return Whether the func struct info is opaque.
* \note We define a function as opaque we have no constraints on params.
*/
bool IsOpaque() const { return !params.defined(); }

/*! \brief Whether the FuncStructInfo is known to be pure */
bool IsPure() const {
if (purity.defined()) {
return purity.value()->value;
} else {
return false;
}
}

/*! \brief Whether the function is known to be impure. */
bool IsImpure() const {
if (purity.defined()) {
return !purity.value()->value;
} else {
return false;
}
}

void VisitAttrs(AttrVisitor* v) {
v->Visit("params", &params);
v->Visit("ret", &ret);
Expand Down Expand Up @@ -365,42 +398,82 @@ class FuncStructInfo : public StructInfo {
* \brief Constructor from parameter struct info and return value struct info.
* \param params The struct info of function parameters.
* \param ret The return value struct info.
* \param purity The purity of the function (true by default).
* \param purity The purity of the function (unknown by default).
* \param span The span of the AST.
*
* \note If the ret contains variables(tir::Var and relax::Var), they must be deducible from
* params. If you are unsure, you can always erase ret to static.
*/
TVM_DLL FuncStructInfo(Array<StructInfo> params, StructInfo ret, bool purity = true,
TVM_DLL FuncStructInfo(Array<StructInfo> params, StructInfo ret, Optional<Bool> purity = NullOpt,
Span span = Span());

/*!
* \brief Constructor from parameter struct info and return value struct info.
* \param params The struct info of function parameters.
* \param ret The return value struct info.
* \param purity The purity of the function.
* \param span The span of the AST.
*
* \note If the ret contains variables(tir::Var and relax::Var), they must be deducible from
* params. If you are unsure, you can always erase ret to static.
*/
TVM_DLL FuncStructInfo(Array<StructInfo> params, StructInfo ret, bool purity, Span span = Span())
: FuncStructInfo(params, ret, Optional<Bool>(Bool(purity)), span) {}

/*!
* \brief Constructing an opaque function struct info using derive_func.
*
* \param derive_func Derivation function.
* \param purity The purity of the function (unknown by default).
* \param span The span of the AST.
*
* \return The FuncStructInfo for opaque packedfunc.
* \note Defaults to an derive func that always return ObjectStructInfo if not specified.
*/
TVM_DLL static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func,
Optional<Bool> purity = NullOpt, Span span = Span());

/*!
* \brief Constructing an opaque function struct info using derive_func.
*
* \param derive_func Derivation function.
* \param purity The purity of the function
* (false by default: most external functions are not pure).
* \param purity The purity of the function.
* \param span The span of the AST.
*
* \return The FuncStructInfo for opaque packedfunc.
* \note Defaults to an derive func that always return ObjectStructInfo if not specified.
*/
TVM_DLL static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func, bool purity,
Span span = Span()) {
return OpaqueFunc(derive_func, Optional<Bool>(Bool(purity)), span);
}

/*!
* \brief Construct an opaque function using from return struct info.
*
* \param ret The struct info of the return value.
* \param purity The purity of the function (unknown by default).
* \param span The span of the AST.
*
* \return The FuncStructInfo for opaque packedfunc.
* \note Defaults to an derive func that always return ObjectStructInfo if not specified.
*/
TVM_DLL static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func, bool purity = false,
Span span = Span());
TVM_DLL static FuncStructInfo OpaqueFunc(StructInfo ret = ObjectStructInfo(),
Optional<Bool> purity = NullOpt, Span span = Span());

/*!
* \brief Construct an opaque function using from return struct info.
*
* \param ret The struct info of the return value.
* \param purity The purity of the function
* (false by default: most external functions are not pure).
* \param purity The purity of the function.
* \param span The span of the AST.
*
* \return The FuncStructInfo for opaque packedfunc.
* \note Defaults to an derive func that always return ObjectStructInfo if not specified.
*/
TVM_DLL static FuncStructInfo OpaqueFunc(StructInfo ret = ObjectStructInfo(), bool purity = false,
Span span = Span());
TVM_DLL static FuncStructInfo OpaqueFunc(StructInfo ret, bool purity, Span span = Span()) {
return OpaqueFunc(ret, Optional<Bool>(Bool(purity)), span);
}

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FuncStructInfo, StructInfo, FuncStructInfoNode);
};
Expand Down
15 changes: 15 additions & 0 deletions include/tvm/relax/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
#include <tvm/relax/expr.h>
#include <tvm/runtime/logging.h>

#include <optional>

namespace tvm {
namespace relax {

Expand Down Expand Up @@ -122,6 +124,19 @@ TVM_DLL bool IsLeafOrTuple(const Expr& expr);
*/
TVM_DLL bool IsImpureCall(const Call& call);

/*!
* \brief Return the purity of the given Call node. If the callee is a
* general expression, this simply requires checking the purity field
* of the FuncStructInfo. If it is an Op, then this checks the
* `fPurity` field.
*
* \param call The input call
*
* \return True if the call is known to be pure. False if the call is
* known to be impure. std::nullopt if the call's purity is unknown.
*/
TVM_DLL std::optional<bool> GetPurity(const Call& call);

/*!
* \brief Copy the given function. All variables that are bound inside the original function
* would be copied to satisfy the restriction in the well-formed check: Variables in
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/script/ir_builder/relax/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ namespace relax {
* \param is_private Whether the function is annotated as private.
* \return The created ir_builder Function frame.
*/
TVM_DLL FunctionFrame Function(const Bool& is_pure, const Bool& is_private);
TVM_DLL FunctionFrame Function(const Optional<Bool>& is_pure, const Bool& is_private);

/*!
* \brief Add a parameter to the last function frame.
Expand Down
15 changes: 11 additions & 4 deletions python/tvm/relax/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,12 @@ def has_reshape_pattern(func: tir.PrimFunc) -> bool:
return _ffi_api.has_reshape_pattern(func) # type: ignore


def contains_impure_call(expr: Expr, own_name: Optional[Union[Var, GlobalVar]] = None) -> bool:
"""
Check if the given expression (likely a function body) contains any impure calls.
def contains_impure_call(
expr: Expr,
own_name: Optional[Union[Var, GlobalVar]] = None,
assumed_purity_when_unknown: bool = False,
) -> bool:
"""Check if the given expression (likely a function body) contains any impure calls.

Parameters
----------
Expand All @@ -385,6 +388,10 @@ def contains_impure_call(expr: Expr, own_name: Optional[Union[Var, GlobalVar]] =
For a recursive function, the analysis can ignore the self-calls
for checking purity.

assumed_purity_when_unknown: bool
The purity to assume when not otherwise known. Defaults to
False, treating unknown purity as impure

Returns
-------
ret : bool
Expand All @@ -397,7 +404,7 @@ def contains_impure_call(expr: Expr, own_name: Optional[Union[Var, GlobalVar]] =
Also, an impure call in a *nested* function does *not* mean that the outer expression contains
an impure call--it only does if the nested function is *later called*.
"""
return _ffi_api.contains_impure_call(expr, own_name)
return _ffi_api.contains_impure_call(expr, own_name, assumed_purity_when_unknown)


def get_var2val(func: Function) -> Dict[Var, Expr]:
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/relax/block_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def function(
name: str,
params: Optional[Union[Var, Tuple, List[Var]]] = None,
attrs: Optional[Dict[str, Object]] = None,
pure: bool = True,
pure: Optional[bool] = None,
private: bool = False,
) -> FunctionScope:
"""Annotate a Relax function.
Expand All @@ -227,7 +227,7 @@ def function(
attrs : Dict[str, Object], optional
The function attrs

pure : bool, optional
pure : Optional[bool]
Whether the function is annotated as pure.

private : bool, optional
Expand All @@ -236,6 +236,7 @@ def function(
If it is not private and not an inner function, then it will have
a global symbol attribute (mapped to the function's name)


Returns
-------
ret: FunctionScope
Expand Down
9 changes: 2 additions & 7 deletions python/tvm/relax/distributed/global_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,8 @@ def __init__(
):
if isinstance(shape, (list, tuple)):
shape = ShapeTuple(shape)
device_range = None
if isinstance(device_ids, Range):
device_range = device_ids
device_ids = []
self.__init_handle_by_constructor__(
ffi.DeviceMesh, shape, device_ids, device_range
) # type: ignore

self.__init_handle_by_constructor__(ffi.DeviceMesh, shape, device_ids) # type: ignore


def device_mesh(shape: ShapeTuple, device_ids: Union[List[int], Range]) -> DeviceMesh:
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,7 +975,7 @@ class Function(BaseFunc, Scriptable):
params: List[Var]
body: Expr
ret_struct_info: StructInfo
is_pure: bool
is_pure: Optional[bool]
attrs: tvm.ir.DictAttrs
span: Optional[Span]

Expand All @@ -984,7 +984,7 @@ def __init__(
params: List[Var],
body: Expr,
ret_struct_info: Optional[StructInfo] = None,
is_pure: Optional[bool] = True,
is_pure: Optional[bool] = None,
attrs: Optional[tvm.ir.DictAttrs] = None,
span: Optional[Span] = None,
) -> None:
Expand All @@ -1002,7 +1002,7 @@ def __init__(
def create_empty(
params: List[Var],
ret_struct_info: StructInfo,
is_pure: Optional[bool] = True,
is_pure: Optional[bool] = None,
attrs: Optional[tvm.ir.DictAttrs] = None,
span: Optional[Span] = None,
):
Expand Down
Loading