Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 180 additions & 22 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,28 @@ class Constant(ExprWithOp):
----------
data : tvm.nd.NDArray
The data content of the constant expression.

span: Optional[tvm.relay.Span]
Span that points to original source code.
"""

def __init__(self, data):
self.__init_handle_by_constructor__(_ffi_api.Constant, data)
def __init__(self, data, span=None):
self.__init_handle_by_constructor__(_ffi_api.Constant, data, span)


@tvm._ffi.register_func("relay.ConstantWithFields")
def ConstantWithFields(
constant,
data=None,
virtual_device=None,
span=None,
):
"""
Returns constant with the given properties. A None property denotes 'no change'.
Returns constant if all properties are unchanged. Otherwise, returns a copy with the new
fields.
"""
return _ffi_api.ConstantWithFields(constant, data, virtual_device, span)


@tvm._ffi.register_object("relay.Tuple")
Expand All @@ -187,7 +205,7 @@ class Tuple(ExprWithOp):
The fields in the tuple.

span: Optional[tvm.relay.Span]
Span that points to original source code
Span that points to original source code.
"""

def __init__(self, fields, span=None):
Expand All @@ -205,6 +223,16 @@ def astype(self, _):
raise TypeError("astype cannot be used on tuple")


@tvm._ffi.register_func("relay.TupleWithFields")
def TupleWithFields(tup, fields=None, virtual_device=None, span=None):
"""
Returns tuple with the given properties. A None property denotes 'no change'.
Returns tuple if all properties are unchanged. Otherwise, returns a copy with the new
fields.
"""
return _ffi_api.TupleWithFields(tup, fields, virtual_device, span)


@tvm._ffi.register_object("relay.Var")
class Var(ExprWithOp):
"""A local variable in Relay.
Expand All @@ -221,10 +249,13 @@ class Var(ExprWithOp):

type_annotation: tvm.relay.Type, optional
The type annotation on the variable.

span: Optional[tvm.relay.Span]
Span that points to original source code.
"""

def __init__(self, name_hint, type_annotation=None):
self.__init_handle_by_constructor__(_ffi_api.Var, name_hint, type_annotation)
def __init__(self, name_hint, type_annotation=None, span=None):
self.__init_handle_by_constructor__(_ffi_api.Var, name_hint, type_annotation, span)

@property
def name_hint(self):
Expand All @@ -233,6 +264,16 @@ def name_hint(self):
return name


@tvm._ffi.register_func("relay.VarWithFields")
def VarWithFields(variable, vid=None, type_annotation=None, virtual_device=None, span=None):
"""
Returns var with the given properties. A None property denotes 'no change'.
Returns var if all properties are unchanged. Otherwise, returns a copy with the new
fields.
"""
return _ffi_api.VarWithFields(variable, vid, type_annotation, virtual_device, span)


@tvm._ffi.register_object("relay.Call")
class Call(ExprWithOp):
"""Function call node in Relay.
Expand All @@ -256,7 +297,7 @@ class Call(ExprWithOp):
used in advanced usecase of template functions.

span: Optional[tvm.relay.Span]
Span that points to original source code
Span that points to original source code.
"""

def __init__(self, op, args, attrs=None, type_args=None, span=None):
Expand All @@ -265,6 +306,18 @@ def __init__(self, op, args, attrs=None, type_args=None, span=None):
self.__init_handle_by_constructor__(_ffi_api.Call, op, args, attrs, type_args, span)


@tvm._ffi.register_func("relay.CallWithFields")
def CallWithFields(
call, op=None, args=None, attrs=None, type_args=None, virtual_device=None, span=None
):
"""
Returns call with the given properties. A None property denotes 'no change'.
Returns call if all properties are unchanged. Otherwise, returns a copy with the new
fields.
"""
return _ffi_api.CallWithFields(call, op, args, attrs, type_args, virtual_device, span)


@tvm._ffi.register_object("relay.Let")
class Let(ExprWithOp):
"""Let variable binding expression.
Expand All @@ -279,10 +332,23 @@ class Let(ExprWithOp):

body: tvm.relay.Expr
The body of the let binding.

span: Optional[tvm.relay.Span]
Span that points to original source code.
"""

def __init__(self, variable, value, body):
self.__init_handle_by_constructor__(_ffi_api.Let, variable, value, body)
def __init__(self, variable, value, body, span=None):
self.__init_handle_by_constructor__(_ffi_api.Let, variable, value, body, span)


@tvm._ffi.register_func("relay.LetWithFields")
def LetWithFields(let, variable=None, value=None, body=None, virtual_device=None, span=None):
"""
Returns let with the given properties. A None property denotes 'no change'.
Returns let if all properties are unchanged. Otherwise, returns a copy with the new
fields.
"""
return _ffi_api.LetWithFields(let, variable, value, body, virtual_device, span)


@tvm._ffi.register_object("relay.If")
Expand All @@ -299,10 +365,25 @@ class If(ExprWithOp):

false_branch: tvm.relay.Expr
The expression evaluated when condition is false.

span: Optional[tvm.relay.Span]
Span that points to original source code.
"""

def __init__(self, cond, true_branch, false_branch):
self.__init_handle_by_constructor__(_ffi_api.If, cond, true_branch, false_branch)
def __init__(self, cond, true_branch, false_branch, span=None):
self.__init_handle_by_constructor__(_ffi_api.If, cond, true_branch, false_branch, span)


@tvm._ffi.register_func("relay.IfWithFields")
def IfWithFields(
if_expr, cond=None, true_branch=None, false_branch=None, virtual_device=None, span=None
):
"""
Returns if with the given properties. A None property denotes 'no change'.
Returns if if all properties are unchanged. Otherwise, returns a copy with the new
fields.
"""
return _ffi_api.IfWithFields(if_expr, cond, true_branch, false_branch, virtual_device, span)


@tvm._ffi.register_object("relay.TupleGetItem")
Expand All @@ -316,10 +397,25 @@ class TupleGetItem(ExprWithOp):

index: int
The index.

span: Optional[tvm.relay.Span]
Span that points to original source code.
"""

def __init__(self, tuple_value, index):
self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index)
def __init__(self, tuple_value, index, span=None):
self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index, span)


@tvm._ffi.register_func("relay.TupleGetItemWithFields")
def TupleGetItemWithFields(
tuple_get_item, tuple_value=None, index=None, virtual_device=None, span=None
):
"""
Returns tuple_get_item with the given properties. A None property denotes 'no change'.
Returns tuple_get_item if all properties are unchanged. Otherwise, returns a copy with the new
fields.
"""
return _ffi_api.TupleGetItemWithFields(tuple_get_item, tuple_value, index, virtual_device, span)


@tvm._ffi.register_object("relay.RefCreate")
Expand All @@ -329,10 +425,28 @@ class RefCreate(ExprWithOp):
----------
value: tvm.relay.Expr
The initial value.

span: Optional[tvm.relay.Span]
Span that points to original source code.
"""

def __init__(self, value):
self.__init_handle_by_constructor__(_ffi_api.RefCreate, value)
def __init__(self, value, span=None):
self.__init_handle_by_constructor__(_ffi_api.RefCreate, value, span)


@tvm._ffi.register_func("relay.RefCreateWithFields")
def RefCreateWithFields(
ref_create,
value=None,
virtual_device=None,
span=None,
):
"""
Returns ref_create with the given properties. A None property denotes 'no change'.
Returns ref_create if all properties are unchanged. Otherwise, returns a copy with the new
fields.
"""
return _ffi_api.RefCreateWithFields(ref_create, value, virtual_device, span)


@tvm._ffi.register_object("relay.RefRead")
Expand All @@ -342,10 +456,28 @@ class RefRead(ExprWithOp):
----------
ref: tvm.relay.Expr
The reference.

span: Optional[tvm.relay.Span]
Span that points to original source code.
"""

def __init__(self, ref):
self.__init_handle_by_constructor__(_ffi_api.RefRead, ref)
def __init__(self, ref, span=None):
self.__init_handle_by_constructor__(_ffi_api.RefRead, ref, span)


@tvm._ffi.register_func("relay.RefReadWithFields")
def RefReadWithFields(
ref_read,
ref=None,
virtual_device=None,
span=None,
):
"""
Returns ref_read with the given properties. A None property denotes 'no change'.
Returns ref_read if all properties are unchanged. Otherwise, returns a copy with the new
fields.
"""
return _ffi_api.RefReadWithFields(ref_read, ref, virtual_device, span)


@tvm._ffi.register_object("relay.RefWrite")
Expand All @@ -357,12 +489,32 @@ class RefWrite(ExprWithOp):
----------
ref: tvm.relay.Expr
The reference.

value: tvm.relay.Expr
The new value.

span: Optional[tvm.relay.Span]
Span that points to original source code.
"""

def __init__(self, ref, value):
self.__init_handle_by_constructor__(_ffi_api.RefWrite, ref, value)
def __init__(self, ref, value, span=None):
self.__init_handle_by_constructor__(_ffi_api.RefWrite, ref, value, span)


@tvm._ffi.register_func("relay.RefWriteWithFields")
def RefWriteWithFields(
ref_write,
ref=None,
value=None,
virtual_device=None,
span=None,
):
"""
Returns ref_write with the given properties. A None property denotes 'no change'.
Returns ref_write if all properties are unchanged. Otherwise, returns a copy with the new
fields.
"""
return _ffi_api.RefWriteWithFields(ref_write, ref, value, virtual_device, span)


class TempExpr(ExprWithOp):
Expand Down Expand Up @@ -433,7 +585,7 @@ def astype(self, _):
raise TypeError("astype cannot be used on tuple")


def var(name_hint, type_annotation=None, shape=None, dtype="float32"):
def var(name_hint, type_annotation=None, shape=None, dtype="float32", span=None):
"""Create a new tvm.relay.Var.

This is a simple wrapper function that allows specify
Expand All @@ -456,6 +608,9 @@ def var(name_hint, type_annotation=None, shape=None, dtype="float32"):
dtype: str, optional
The data type of the tensor.

span: Optional[tvm.relay.Span]
Span that points to original source code.

Examples
--------
.. code-block:: python
Expand All @@ -476,10 +631,10 @@ def var(name_hint, type_annotation=None, shape=None, dtype="float32"):
type_annotation = _ty.TensorType(shape, dtype)
elif isinstance(type_annotation, str):
type_annotation = _ty.TensorType((), type_annotation)
return Var(name_hint, type_annotation)
return Var(name_hint, type_annotation, span)


def const(value, dtype=None):
def const(value, dtype=None, span=None):
"""Create a constant value.

Parameters
Expand All @@ -490,6 +645,9 @@ def const(value, dtype=None):
dtype: str, optional
The data type of the resulting constant.

span: Optional[tvm.relay.Span]
Span that points to original source code.

Note
----
When dtype is None, we use the following rule:
Expand All @@ -516,7 +674,7 @@ def const(value, dtype=None):
if not isinstance(value, _nd.NDArray):
raise ValueError("value has to be scalar or NDArray")

return Constant(value)
return Constant(value, span)


def bind(expr, binds):
Expand Down
Loading