Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions python/tvm/script/parser/relax/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,14 @@ def bind_assign_value(
"Expected the same dtype for TIR vars "
f"but got {value.dtype} vs {prev_value.dtype}",
)
return prev_value
if not isinstance(value, type(prev_value)):
self.report_error(
node,
f"Expected the same IR type for TIR vars "
f"but existing value {type(value)} is mismatched "
f"to previous {type(prev_value)}",
)
value = prev_value
IRBuilder.name(var_name, value)
return value

Expand Down Expand Up @@ -144,18 +151,47 @@ def is_recursive(node: doc.FunctionDef) -> bool:
return False


def collect_symbolic_var_from_prelude(
self: Parser, node: doc.FunctionDef, symbolic_vars: Dict[str, tir.Var]
) -> Dict[str, tir.Var]:
prelude_vars = {}
for stmt in node.body:
if isinstance(stmt, doc.Assign) and all(
isinstance(target, doc.Name) and target.id in symbolic_vars for target in stmt.targets
):
values = self.eval_expr(stmt.value)

try:
iter(values)
except TypeError:
values = [values]

assert len(stmt.targets) == len(values)
for target, value in zip(stmt.targets, values):
name = target.id
prelude_vars[name] = value

return {**symbolic_vars, **prelude_vars}


def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) -> None:
# Collect symbolic vars from parameters
symbolic_vars = set()
symbolic_vars = {}
for arg in node.args.args:
if arg.annotation is None:
self.report_error(arg, "Type annotation is required for function parameters.")
param_sinfo_proxy = eval_struct_info_proxy(self, arg.annotation)
symbolic_vars.update(param_sinfo_proxy.get_symbolic_vars())

for var_name in param_sinfo_proxy.get_symbolic_vars():
if var_name not in symbolic_vars:
symbolic_vars[var_name] = tir.Var(var_name, "int64")

# Update symbolic vars based on
symbolic_vars = collect_symbolic_var_from_prelude(self, node, symbolic_vars)

# Define symbolic vars to the current var_table frame
for var_name in symbolic_vars:
self.var_table.add(var_name, tir.Var(var_name, "int64"), allow_shadowing=False)
for var_name, var in symbolic_vars.items():
self.var_table.add(var_name, var, allow_shadowing=False)


@dispatch.register(token="relax", type_name="FunctionDef")
Expand Down
3 changes: 2 additions & 1 deletion src/script/printer/relax/tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
#include <tvm/ir/expr.h>

#include "../tir/utils.h"
#include "./utils.h"

namespace tvm {
Expand Down Expand Up @@ -59,7 +60,7 @@ Doc PrintTIRVar(tir::Var n, ObjectPath n_p, IRDocsifier d) {
}
IdDoc var = d->Define(n, GetRef<Frame>(f), n->name_hint.empty() ? "v" : n->name_hint);
var->source_paths.push_back(n_p);
f->stmts.push_back(AssignDoc(var, TIR(d, DType2Str(n->dtype))->Call({}), NullOpt));
f->stmts.push_back(AssignDoc(var, PrintVarCreation(n, n_p, d), NullOpt));
}
if (Optional<ExprDoc> doc = d->GetVarDoc(n)) {
return doc.value();
Expand Down
28 changes: 28 additions & 0 deletions tests/python/tvmscript/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -4088,6 +4088,32 @@ def func(A: R.Object):
yield make_ir_generator(subclass)


def relax_symbolic_size_var():
"""Relax symbolic variables may be SizeVar"""
N = tvm.tir.SizeVar("N", "int64")

@R.function
def func(A: R.Tensor([N], "float16")):
B: R.Tensor([N], "float16") = A
return B

return func


def relax_float_symbolic_var():
"""Relax symbolic variables may hold any dtype"""

@R.function
def func(A: R.Tensor(["N"], "float16"), _: R.Prim(value="threshold")):
N = T.int64()
threshold = T.float16()

B = A >= R.prim_value(threshold / T.cast(N, "float16"))
return B

return func


ir_generator = tvm.testing.parameter(
launch_env_thread,
opt_gemm_normalize,
Expand Down Expand Up @@ -4174,6 +4200,8 @@ def func(A: R.Object):
return_zero_private_with_attr,
*op_of_literal(),
*relax_match_cast_struct_info_proxy(),
relax_symbolic_size_var,
relax_float_symbolic_var,
)

relax_ir_generator = tvm.testing.parameter(
Expand Down