Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 53 additions & 22 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,13 @@ class Constant(ExprWithOp):
----------
data : tvm.nd.NDArray
The data content of the constant expression.

span: Optional[tvm.relay.Span]
Span that points to original source code.
"""

def __init__(self, data):
self.__init_handle_by_constructor__(_ffi_api.Constant, data)
def __init__(self, data, span=None):
self.__init_handle_by_constructor__(_ffi_api.Constant, data, span)


@tvm._ffi.register_object("relay.Tuple")
Expand All @@ -187,7 +190,7 @@ class Tuple(ExprWithOp):
The fields in the tuple.

span: Optional[tvm.relay.Span]
Span that points to original source code
Span that points to original source code.
"""

def __init__(self, fields, span=None):
Expand Down Expand Up @@ -221,10 +224,13 @@ class Var(ExprWithOp):

type_annotation: tvm.relay.Type, optional
The type annotation on the variable.

span: Optional[tvm.relay.Span]
Span that points to original source code.
"""

def __init__(self, name_hint, type_annotation=None):
self.__init_handle_by_constructor__(_ffi_api.Var, name_hint, type_annotation)
def __init__(self, name_hint, type_annotation=None, span=None):
self.__init_handle_by_constructor__(_ffi_api.Var, name_hint, type_annotation, span)

@property
def name_hint(self):
Expand Down Expand Up @@ -256,7 +262,7 @@ class Call(ExprWithOp):
used in advanced usecase of template functions.

span: Optional[tvm.relay.Span]
Span that points to original source code
Span that points to original source code.
"""

def __init__(self, op, args, attrs=None, type_args=None, span=None):
Expand All @@ -279,10 +285,13 @@ class Let(ExprWithOp):

body: tvm.relay.Expr
The body of the let binding.

span: Optional[tvm.relay.Span]
Span that points to original source code.
"""

def __init__(self, variable, value, body):
self.__init_handle_by_constructor__(_ffi_api.Let, variable, value, body)
def __init__(self, variable, value, body, span=None):
self.__init_handle_by_constructor__(_ffi_api.Let, variable, value, body, span)


@tvm._ffi.register_object("relay.If")
Expand All @@ -299,10 +308,13 @@ class If(ExprWithOp):

false_branch: tvm.relay.Expr
The expression evaluated when condition is false.

span: Optional[tvm.relay.Span]
Span that points to original source code.
"""

def __init__(self, cond, true_branch, false_branch):
self.__init_handle_by_constructor__(_ffi_api.If, cond, true_branch, false_branch)
def __init__(self, cond, true_branch, false_branch, span=None):
self.__init_handle_by_constructor__(_ffi_api.If, cond, true_branch, false_branch, span)


@tvm._ffi.register_object("relay.TupleGetItem")
Expand All @@ -316,10 +328,13 @@ class TupleGetItem(ExprWithOp):

index: int
The index.

span: Optional[tvm.relay.Span]
Span that points to original source code.
"""

def __init__(self, tuple_value, index):
self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index)
def __init__(self, tuple_value, index, span=None):
self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index, span)


@tvm._ffi.register_object("relay.RefCreate")
Expand All @@ -329,10 +344,13 @@ class RefCreate(ExprWithOp):
----------
value: tvm.relay.Expr
The initial value.

span: Optional[tvm.relay.Span]
Span that points to original source code.
"""

def __init__(self, value):
self.__init_handle_by_constructor__(_ffi_api.RefCreate, value)
def __init__(self, value, span=None):
self.__init_handle_by_constructor__(_ffi_api.RefCreate, value, span)


@tvm._ffi.register_object("relay.RefRead")
Expand All @@ -342,10 +360,13 @@ class RefRead(ExprWithOp):
----------
ref: tvm.relay.Expr
The reference.

span: Optional[tvm.relay.Span]
Span that points to original source code.
"""

def __init__(self, ref):
self.__init_handle_by_constructor__(_ffi_api.RefRead, ref)
def __init__(self, ref, span=None):
self.__init_handle_by_constructor__(_ffi_api.RefRead, ref, span)


@tvm._ffi.register_object("relay.RefWrite")
Expand All @@ -357,12 +378,16 @@ class RefWrite(ExprWithOp):
----------
ref: tvm.relay.Expr
The reference.

value: tvm.relay.Expr
The new value.

span: Optional[tvm.relay.Span]
Span that points to original source code.
"""

def __init__(self, ref, value):
self.__init_handle_by_constructor__(_ffi_api.RefWrite, ref, value)
def __init__(self, ref, value, span=None):
self.__init_handle_by_constructor__(_ffi_api.RefWrite, ref, value, span)


class TempExpr(ExprWithOp):
Expand Down Expand Up @@ -433,7 +458,7 @@ def astype(self, _):
raise TypeError("astype cannot be used on tuple")


def var(name_hint, type_annotation=None, shape=None, dtype="float32"):
def var(name_hint, type_annotation=None, shape=None, dtype="float32", span=None):
"""Create a new tvm.relay.Var.

This is a simple wrapper function that allows specify
Expand All @@ -456,6 +481,9 @@ def var(name_hint, type_annotation=None, shape=None, dtype="float32"):
dtype: str, optional
The data type of the tensor.

span: Optional[tvm.relay.Span]
Span that points to original source code.

Examples
--------
.. code-block:: python
Expand All @@ -476,10 +504,10 @@ def var(name_hint, type_annotation=None, shape=None, dtype="float32"):
type_annotation = _ty.TensorType(shape, dtype)
elif isinstance(type_annotation, str):
type_annotation = _ty.TensorType((), type_annotation)
return Var(name_hint, type_annotation)
return Var(name_hint, type_annotation, span)


def const(value, dtype=None):
def const(value, dtype=None, span=None):
"""Create a constant value.

Parameters
Expand All @@ -490,6 +518,9 @@ def const(value, dtype=None):
dtype: str, optional
The data type of the resulting constant.

span: Optional[tvm.relay.Span]
Span that points to original source code.

Note
----
When dtype is None, we use the following rule:
Expand All @@ -516,7 +547,7 @@ def const(value, dtype=None):
if not isinstance(value, _nd.NDArray):
raise ValueError("value has to be scalar or NDArray")

return Constant(value)
return Constant(value, span)


def bind(expr, binds):
Expand Down
Loading