python/tvm/relax/frontend/nn/core.py (17 changes: 1 addition & 16 deletions)
@@ -591,22 +591,7 @@ def wrap_nested(expr: rx.Expr, name: str) -> Union[Tensor, Sequence[Tensor]]:
The computed result.
"""
if not isinstance(expr, rx.DataflowVar):
block_builder = BlockBuilder.current()
if block_builder is None:
# Normalize to make sure we have valid StructInfo, but
# wait until we are actually building the function to
# flatten nested expressions.
#
# TODO(Lunderberg): Make this easier to call. Inferring
# struct info for a nested expression should be doable in
# a free function, without requiring an active
# BlockBuilder and an active FunctionFrame.
builder = BlockBuilder()
with builder.function("dummy_scope", params=[]):
expr = builder.normalize(expr)
builder.emit_func_output([])
else:
expr = BlockBuilder.current().emit(expr, name)
expr = BlockBuilder.current().emit(expr, name)
if isinstance(expr.struct_info_, TensorStructInfo):
return Tensor(_expr=expr)
if isinstance(expr.struct_info_, TupleStructInfo):
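With the fallback removed, `wrap_nested` always emits through `BlockBuilder.current()`, which exists whenever an `nn.Module` method is being traced by the exporter. A minimal usage sketch of that path, assuming only the public `nn` frontend API (the model and spec names below are illustrative):

```python
from tvm.relax.frontend import nn

class TinyModel(nn.Module):  # illustrative model, not part of this change
    def __init__(self):
        self.proj = nn.Linear(8, 8)

    def forward(self, x: nn.Tensor) -> nn.Tensor:
        # Each nn.op call reaches wrap_nested while the exporter's
        # BlockBuilder is active, so the expression is emitted directly
        # into the current dataflow block.
        return nn.op.silu(self.proj(x))

ir_mod, params = TinyModel().export_tvm(
    spec={"forward": {"x": nn.spec.Tensor(["batch", 8], "float32")}}
)
print(ir_mod)
```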
python/tvm/relax/frontend/nn/exporter.py (40 changes: 21 additions & 19 deletions)
@@ -111,8 +111,7 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]:
return result

# pylint: enable=protected-access

params = _params()
params = None
effects = _effects()
ext_mods = self.extern_mods
with self:
@@ -122,6 +121,7 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]:
outputs = _emit_effect_init(self.builder, effects)
self.builder.emit_func_output(outputs, params=[])
for method_name, method_spec in zip(spec.method_names, spec.method_specs):
params = _params() # Re-initialize so symbolic shapes not shared across methods
len_args = len(method_spec.arg_specs)
len_effects = {
"packed": 1,
@@ -135,18 +135,9 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]:
with self.builder.dataflow():
outputs, inputs = _emit_method(self.builder, method_spec, params, effects)
self.builder.emit_func_output(outputs, inputs)

# TODO(Lunderberg): Make a `ir.transform.ConvertSSA`,
# similar to the existing `tir.transform.ConvertSSA`,
# that converts an entire module to SSA, including TIR
# variable definitions used in either TIR or Relax.
mod = self.builder.get()
mod[method_name] = rx.utils.copy_with_new_vars(mod[method_name])

mod = self.builder.finalize()
assert rx.analysis.well_formed(mod)

mod = rx.transform.CanonicalizeBindings()(mod)
return mod, params, ext_mods


@@ -170,6 +161,8 @@ def _emit_method( # pylint: disable=too-many-locals,too-many-branches,too-many-
effects: typing.Optional[typing.List[typing.Tuple[str, core.Effect]]],
):
# pylint: disable=protected-access
# Maps each symbolic shape name to its tir.Var so the variable is reused across parameters
str2var_params: typing.Dict[str, tir.Var] = {}

def _unwrap_ret(expr: typing.Any) -> typing.Any:
if isinstance(expr, (core.Tensor, core.Object)):
@@ -183,26 +176,35 @@ def _unwrap_ret(expr: typing.Any) -> typing.Any:
def _convert_input(arg):
if isinstance(arg, tir.Var):
return rx.Var(arg.name, struct_info=ShapeStructInfo(values=[arg]))
elif isinstance(arg, (core.Tensor, core.Object)):
if isinstance(arg, (core.Tensor, core.Object)):
return arg._expr # pylint: disable=protected-access
elif isinstance(arg, _spec.Tuple):
if isinstance(arg, _spec.Tuple):
return rx.Var(
arg.name,
struct_info=TupleStructInfo(
[_convert_input(arg_i).struct_info for arg_i in arg.elements]
),
)
elif isinstance(arg, rx.Expr):
return arg
else:
raise TypeError(f"Unsupported input type: {type(arg)}")
raise TypeError(f"Unsupported input type: {type(arg)}")

def _params(mode: str) -> typing.List[rx.Var]:
inputs: typing.List[rx.Var] = []

for name, param in params:
inputs.append(param._expr)
def _get_var(shape_var: tir.Var) -> tir.Var:
name = shape_var.name
if name in str2var_params:
return str2var_params[name]
var = tir.Var(name, "int64")
str2var_params[name] = var
return var

for name, param in params:
# Make sure a symbolic shape is not re-registered (same as _method_spec_to_inputs)
# e.g. we do not see `vocab_size` for `lm_head` and `vocab_size_1` for `embed_tokens`
new_shape = [_get_var(x) if isinstance(x, tir.Var) else x for x in param.shape]
var = core.Tensor.placeholder(new_shape, param.dtype, name)._expr
inputs.append(var)
param._expr = var
if mode == "none":
return []
if mode == "plain":
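The `_get_var` helper added above is a plain name-to-variable cache: the first time a symbolic dimension name appears it creates a fresh `tir.Var`, and every later parameter naming the same dimension reuses it, so `vocab_size` in `lm_head` and in `embed_tokens` resolve to a single variable instead of `vocab_size` and `vocab_size_1`. A self-contained sketch of the same pattern (standalone names, not the exporter's API):

```python
from typing import Dict

from tvm import tir

str2var: Dict[str, tir.Var] = {}

def get_or_create_var(name: str) -> tir.Var:
    """Return one shared tir.Var per symbolic dimension name."""
    if name not in str2var:
        str2var[name] = tir.Var(name, "int64")
    return str2var[name]

# Two parameters naming the same dimension share a single variable.
embed_rows = get_or_create_var("vocab_size")
lm_head_rows = get_or_create_var("vocab_size")
assert embed_rows.same_as(lm_head_rows)
```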