Skip to content

Commit 5a2cfd8

Browse files
author
Joey Tsai
committed
[SpanFillingCommonAPI]
- Expose Relay Expr WithFields APIs to python side - Change the APIs in _SpanFiller from creating new instance to WithFields. - Change the control of frontend span filler from env var to the passcontext config
1 parent 3a873d9 commit 5a2cfd8

File tree

7 files changed

+228
-47
lines changed

7 files changed

+228
-47
lines changed

python/tvm/relay/expr.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,21 @@ def __init__(self, data, span=None):
180180
self.__init_handle_by_constructor__(_ffi_api.Constant, data, span)
181181

182182

183+
@tvm._ffi.register_func("relay.ConstantWithFields")
184+
def ConstantWithFields(
185+
constant,
186+
data=None,
187+
virtual_device=None,
188+
span=None,
189+
):
190+
"""
191+
Returns constant with the given properties. A None property denotes 'no change'.
192+
Returns constant if all properties are unchanged. Otherwise, returns a copy with the new
193+
fields.
194+
"""
195+
return _ffi_api.ConstantWithFields(constant, data, virtual_device, span)
196+
197+
183198
@tvm._ffi.register_object("relay.Tuple")
184199
class Tuple(ExprWithOp):
185200
"""Tuple expression that groups several fields together.
@@ -208,6 +223,16 @@ def astype(self, _):
208223
raise TypeError("astype cannot be used on tuple")
209224

210225

226+
@tvm._ffi.register_func("relay.TupleWithFields")
227+
def TupleWithFields(tup, fields=None, virtual_device=None, span=None):
228+
"""
229+
Returns tuple with the given properties. A None property denotes 'no change'.
230+
Returns tuple if all properties are unchanged. Otherwise, returns a copy with the new
231+
fields.
232+
"""
233+
return _ffi_api.TupleWithFields(tup, fields, virtual_device, span)
234+
235+
211236
@tvm._ffi.register_object("relay.Var")
212237
class Var(ExprWithOp):
213238
"""A local variable in Relay.
@@ -239,6 +264,16 @@ def name_hint(self):
239264
return name
240265

241266

267+
@tvm._ffi.register_func("relay.VarWithFields")
268+
def VarWithFields(variable, vid=None, type_annotation=None, virtual_device=None, span=None):
269+
"""
270+
Returns var with the given properties. A None property denotes 'no change'.
271+
Returns var if all properties are unchanged. Otherwise, returns a copy with the new
272+
fields.
273+
"""
274+
return _ffi_api.VarWithFields(variable, vid, type_annotation, virtual_device, span)
275+
276+
242277
@tvm._ffi.register_object("relay.Call")
243278
class Call(ExprWithOp):
244279
"""Function call node in Relay.
@@ -271,6 +306,18 @@ def __init__(self, op, args, attrs=None, type_args=None, span=None):
271306
self.__init_handle_by_constructor__(_ffi_api.Call, op, args, attrs, type_args, span)
272307

273308

309+
@tvm._ffi.register_func("relay.CallWithFields")
310+
def CallWithFields(
311+
call, op=None, args=None, attrs=None, type_args=None, virtual_device=None, span=None
312+
):
313+
"""
314+
Returns call with the given properties. A None property denotes 'no change'.
315+
Returns call if all properties are unchanged. Otherwise, returns a copy with the new
316+
fields.
317+
"""
318+
return _ffi_api.CallWithFields(call, op, args, attrs, type_args, virtual_device, span)
319+
320+
274321
@tvm._ffi.register_object("relay.Let")
275322
class Let(ExprWithOp):
276323
"""Let variable binding expression.
@@ -294,6 +341,16 @@ def __init__(self, variable, value, body, span=None):
294341
self.__init_handle_by_constructor__(_ffi_api.Let, variable, value, body, span)
295342

296343

344+
@tvm._ffi.register_func("relay.LetWithFields")
345+
def LetWithFields(let, variable=None, value=None, body=None, virtual_device=None, span=None):
346+
"""
347+
Returns let with the given properties. A None property denotes 'no change'.
348+
Returns let if all properties are unchanged. Otherwise, returns a copy with the new
349+
fields.
350+
"""
351+
return _ffi_api.LetWithFields(let, variable, value, body, virtual_device, span)
352+
353+
297354
@tvm._ffi.register_object("relay.If")
298355
class If(ExprWithOp):
299356
"""A conditional expression in Relay.
@@ -317,6 +374,18 @@ def __init__(self, cond, true_branch, false_branch, span=None):
317374
self.__init_handle_by_constructor__(_ffi_api.If, cond, true_branch, false_branch, span)
318375

319376

377+
@tvm._ffi.register_func("relay.IfWithFields")
378+
def IfWithFields(
379+
if_expr, cond=None, true_branch=None, false_branch=None, virtual_device=None, span=None
380+
):
381+
"""
382+
Returns if with the given properties. A None property denotes 'no change'.
383+
Returns if if all properties are unchanged. Otherwise, returns a copy with the new
384+
fields.
385+
"""
386+
return _ffi_api.IfWithFields(if_expr, cond, true_branch, false_branch, virtual_device, span)
387+
388+
320389
@tvm._ffi.register_object("relay.TupleGetItem")
321390
class TupleGetItem(ExprWithOp):
322391
"""Get index-th item from a tuple.
@@ -337,6 +406,18 @@ def __init__(self, tuple_value, index, span=None):
337406
self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index, span)
338407

339408

409+
@tvm._ffi.register_func("relay.TupleGetItemWithFields")
410+
def TupleGetItemWithFields(
411+
tuple_get_item, tuple_value=None, index=None, virtual_device=None, span=None
412+
):
413+
"""
414+
Returns tuple_get_item with the given properties. A None property denotes 'no change'.
415+
Returns tuple_get_item if all properties are unchanged. Otherwise, returns a copy with the new
416+
fields.
417+
"""
418+
return _ffi_api.TupleGetItemWithFields(tuple_get_item, tuple_value, index, virtual_device, span)
419+
420+
340421
@tvm._ffi.register_object("relay.RefCreate")
341422
class RefCreate(ExprWithOp):
342423
"""Create a new reference from initial value.
@@ -353,6 +434,21 @@ def __init__(self, value, span=None):
353434
self.__init_handle_by_constructor__(_ffi_api.RefCreate, value, span)
354435

355436

437+
@tvm._ffi.register_func("relay.RefCreateWithFields")
438+
def RefCreateWithFields(
439+
ref_create,
440+
value=None,
441+
virtual_device=None,
442+
span=None,
443+
):
444+
"""
445+
Returns ref_create with the given properties. A None property denotes 'no change'.
446+
Returns ref_create if all properties are unchanged. Otherwise, returns a copy with the new
447+
fields.
448+
"""
449+
return _ffi_api.RefCreateWithFields(ref_create, value, virtual_device, span)
450+
451+
356452
@tvm._ffi.register_object("relay.RefRead")
357453
class RefRead(ExprWithOp):
358454
"""Get the value inside the reference.
@@ -369,6 +465,21 @@ def __init__(self, ref, span=None):
369465
self.__init_handle_by_constructor__(_ffi_api.RefRead, ref, span)
370466

371467

468+
@tvm._ffi.register_func("relay.RefReadWithFields")
469+
def RefReadWithFields(
470+
ref_read,
471+
ref=None,
472+
virtual_device=None,
473+
span=None,
474+
):
475+
"""
476+
Returns ref_read with the given properties. A None property denotes 'no change'.
477+
Returns ref_read if all properties are unchanged. Otherwise, returns a copy with the new
478+
fields.
479+
"""
480+
return _ffi_api.RefReadWithFields(ref_read, ref, virtual_device, span)
481+
482+
372483
@tvm._ffi.register_object("relay.RefWrite")
373484
class RefWrite(ExprWithOp):
374485
"""
@@ -390,6 +501,22 @@ def __init__(self, ref, value, span=None):
390501
self.__init_handle_by_constructor__(_ffi_api.RefWrite, ref, value, span)
391502

392503

504+
@tvm._ffi.register_func("relay.RefWriteWithFields")
505+
def RefWriteWithFields(
506+
ref_write,
507+
ref=None,
508+
value=None,
509+
virtual_device=None,
510+
span=None,
511+
):
512+
"""
513+
Returns ref_write with the given properties. A None property denotes 'no change'.
514+
Returns ref_write if all properties are unchanged. Otherwise, returns a copy with the new
515+
fields.
516+
"""
517+
return _ffi_api.RefWriteWithFields(ref_write, ref, value, virtual_device, span)
518+
519+
393520
class TempExpr(ExprWithOp):
394521
"""Baseclass of all TempExpr.
395522

python/tvm/relay/frontend/common.py

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
"""Common utilities"""
1919
from __future__ import absolute_import as _abs
2020
import logging
21-
import os
2221
import numpy as np
2322

2423
import tvm
@@ -1027,42 +1026,50 @@ def visit(self, expr):
10271026
def visit_function(self, fn):
10281027
new_params = [self.visit(x) for x in fn.params]
10291028
new_body = self.visit(fn.body)
1030-
return _function.Function(
1031-
list(new_params), new_body, fn.ret_type, fn.type_params, fn.attrs, self._span
1029+
return _function.FunctionWithFields(
1030+
fn, list(new_params), new_body, fn.ret_type, fn.type_params, fn.attrs, None, self._span
10321031
)
10331032

10341033
def visit_let(self, let):
10351034
new_variable = self.visit(let.var)
10361035
new_value = self.visit(let.value)
10371036
new_body = self.visit(let.body)
1038-
return _expr.Let(new_variable, new_value, new_body, self._span)
1037+
return _expr.LetWithFields(let, new_variable, new_value, new_body, None, self._span)
10391038

10401039
def visit_call(self, call):
10411040
new_args = [self.visit(arg) for arg in call.args]
10421041
# call.op might be RelayExpr or Op type
10431042
# ExprMutator will return directly if subject belongs to Op type
10441043
new_op = self.visit(call.op)
1045-
return _expr.Call(new_op, new_args, call.attrs, call.type_args, self._span)
1044+
return _expr.CallWithFields(
1045+
call, new_op, new_args, call.attrs, call.type_args, None, self._span
1046+
)
10461047

10471048
def visit_var(self, var):
1048-
return _expr.Var(var.name_hint, var.type_annotation, self._span)
1049+
return _expr.VarWithFields(var, var.vid, var.type_annotation, None, self._span)
10491050

10501051
def visit_if(self, ite):
1051-
return _expr.If(
1052+
return _expr.IfWithFields(
1053+
ite,
10521054
self.visit(ite.cond),
10531055
self.visit(ite.true_branch),
10541056
self.visit(ite.false_branch),
1057+
None,
10551058
self._span,
10561059
)
10571060

10581061
def visit_tuple(self, tup):
1059-
return _expr.Tuple([self.visit(field) for field in tup.fields], self._span)
1062+
return _expr.TupleWithFields(
1063+
tup, [self.visit(field) for field in tup.fields], None, self._span
1064+
)
10601065

10611066
def visit_tuple_getitem(self, op):
1062-
return _expr.TupleGetItem(self.visit(op.tuple_value), op.index, self._span)
1067+
return _expr.TupleGetItemWithFields(
1068+
op, self.visit(op.tuple_value), op.index, None, self._span
1069+
)
10631070

10641071
def visit_constant(self, const):
1065-
return _expr.Constant(const.data, self._span)
1072+
return _expr.ConstantWithFields(const, const.data, None, self._span)
10661073

10671074
# TODO: Frontend model translation could not use following relay expressions so far,
10681075
# enable them when new models/impls leverage these kinds of relay expressions.
@@ -1115,23 +1122,10 @@ def fill(self, sym):
11151122
raise RuntimeError(f"unsupported type {type(sym)}")
11161123

11171124

1118-
def _should_fill_span():
1119-
should_fill_span = os.environ.get("TVM_SPANFILLING", "1")
1120-
1121-
try:
1122-
should_fill_span = bool(int(should_fill_span))
1123-
except ValueError:
1124-
raise ValueError(
1125-
f"invalid value for TVM_SPANFILLING {should_fill_span}, please set to 0 or 1."
1126-
)
1127-
1128-
return should_fill_span
1129-
1130-
11311125
def set_span(sym, span):
11321126
"""
11331127
Recursively tag the span to the symbol. Stop when it encounters a span-tagged expr. Disabled
1134-
when setting the environment variable "TVM_SPANFILLING" as 0.
1128+
when setting the "relay.frontend.fill_span" as False to the config of PassContext
11351129
11361130
Parameters
11371131
----------
@@ -1163,6 +1157,6 @@ def set_span(sym, span):
11631157
#}
11641158
"""
11651159

1166-
if _should_fill_span():
1160+
if tvm.transform.PassContext.current().config.get("relay.frontend.fill_span", True):
11671161
return _SpanFiller(span).fill(sym)
11681162
return sym

python/tvm/testing/utils.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2085,17 +2085,14 @@ def pprint(name, obj):
20852085

20862086
class _control_span_filling:
20872087
def __init__(self, on=True):
2088-
self._old_state = os.environ["TVM_SPANFILLING"] if "TVM_SPANFILLING" in os.environ else None
20892088
self._on = on
2089+
self._pass_ctx = tvm.transform.PassContext(config={"relay.frontend.fill_span": self._on})
20902090

20912091
def __enter__(self):
2092-
os.environ["TVM_SPANFILLING"] = str(int(self._on))
2092+
self._pass_ctx.__enter__()
20932093

20942094
def __exit__(self, exc_type, exc_val, exc_tb):
2095-
if self._old_state:
2096-
os.environ["TVM_SPANFILLING"] = self._old_state
2097-
else:
2098-
del os.environ["TVM_SPANFILLING"]
2095+
self._pass_ctx.__exit__(exc_type, exc_val, exc_tb)
20992096

21002097

21012098
class enable_span_filling(_control_span_filling):

src/ir/span.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,17 @@
2020
* \file span.cc
2121
* \brief The span data structure.
2222
*/
23+
#include <tvm/ir/expr.h>
2324
#include <tvm/ir/span.h>
25+
#include <tvm/ir/transform.h>
2426
#include <tvm/runtime/registry.h>
2527

2628
#include <algorithm>
2729

2830
namespace tvm {
2931

32+
TVM_REGISTER_PASS_CONFIG_OPTION("relay.frontend.fill_span", Bool);
33+
3034
ObjectPtr<Object> GetSourceNameNode(const String& name) {
3135
// always return pointer as the reference can change as map re-allocate.
3236
// or use another level of indirection by creating a unique_ptr

0 commit comments

Comments
 (0)