diff --git a/include/tvm/script/ir_builder/relax/frame.h b/include/tvm/script/ir_builder/relax/frame.h index 9a8f835e819b..1ad681388912 100644 --- a/include/tvm/script/ir_builder/relax/frame.h +++ b/include/tvm/script/ir_builder/relax/frame.h @@ -99,6 +99,8 @@ class FunctionFrameNode : public SeqExprFrameNode { Optional ret_struct_info; /*! \brief Whether the function is annotated as pure */ Optional is_pure; + /*! \brief Whether the function is annotated as private */ + Optional is_private; /*! \brief The function attributes. */ Map attrs; /*! \brief The block builder to create Relax function. */ diff --git a/include/tvm/script/ir_builder/relax/ir.h b/include/tvm/script/ir_builder/relax/ir.h index 1cf30b491957..d160ad090e48 100644 --- a/include/tvm/script/ir_builder/relax/ir.h +++ b/include/tvm/script/ir_builder/relax/ir.h @@ -34,9 +34,10 @@ namespace relax { /*! * \brief Start a function frame. * \param is_pure Whether the function is annotated as pure. + * \param is_private Whether the function is annotated as private. * \return The created ir_builder Function frame. */ -TVM_DLL FunctionFrame Function(const Bool& is_pure); +TVM_DLL FunctionFrame Function(const Bool& is_pure, const Bool& is_private); /*! * \brief Add a parameter to the last function frame. diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py index 80edb31efbf0..502073edf207 100644 --- a/python/tvm/relax/block_builder.py +++ b/python/tvm/relax/block_builder.py @@ -199,6 +199,7 @@ def function( name: str, params: Optional[Union[Var, Tuple, List[Var]]] = None, attrs: Optional[Dict[str, Object]] = None, + private: bool = False, ) -> FunctionScope: """Annotate a Relax function. @@ -215,6 +216,12 @@ def function( attrs : Dict[str, Object], optional The function attrs + private : bool, optional + Whether the function is annotated as private. + If the function is private, it will not have a global symbol attribute. + If it is not private and not an inner function, then it will have + a global symbol attribute (mapped to the function's name) + Returns ------- ret: FunctionScope @@ -233,6 +240,11 @@ def function( ) if attrs is None: attrs = {} + # The block builder does not permit nesting functions, per above comment, + # so no further check should be needed + if not private: + attrs["global_symbol"] = name + return FunctionScope(self, name, params, attrs) def testing_scope(self, def_vars: List[tir.Var]) -> TestingScope: diff --git a/python/tvm/relax/frontend/torch/dynamo.py b/python/tvm/relax/frontend/torch/dynamo.py index 3015f77428fb..abdf7b8862fe 100644 --- a/python/tvm/relax/frontend/torch/dynamo.py +++ b/python/tvm/relax/frontend/torch/dynamo.py @@ -154,7 +154,8 @@ def _capture(graph_module: fx.GraphModule, example_inputs): keep_params_as_input=keep_params_as_input, unwrap_unit_return_tuple=True, ) - mod[f"subgraph_{len(mod.get_global_vars())}"] = mod_["main"] + new_name = f"subgraph_{len(mod.get_global_vars())}" + mod[new_name] = mod_["main"].with_attr("global_symbol", new_name) return graph_module.forward dynamo.reset() diff --git a/python/tvm/relax/training/setup_trainer.py b/python/tvm/relax/training/setup_trainer.py index 81ecaf4ea5d7..2e2057086904 100644 --- a/python/tvm/relax/training/setup_trainer.py +++ b/python/tvm/relax/training/setup_trainer.py @@ -198,7 +198,10 @@ def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) -> IRM # Add optimizer function. self._optimizer.init(params) - mod[self.OPTIMIZER_FUNC] = self._optimizer.get_function() + # Need the global symbol to match the function's name + mod[self.OPTIMIZER_FUNC] = self._optimizer.get_function().with_attr( + "global_symbol", self.OPTIMIZER_FUNC + ) # Module attrs mod = mod.with_attrs( diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index e54e4aa07b3b..b06d9547acdb 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -165,19 +165,24 @@ ############################### Function ################################ -def function(is_pure: bool = True) -> frame.FunctionFrame: +def function(is_pure: bool = True, is_private: bool = False) -> frame.FunctionFrame: """Start a function frame. Parameters ---------- is_pure: bool Whether the function is annotated as pure. + is_private : bool + Whether the function is annotated as private. + Returns ------- frame: FunctionFrame The constructed function frame. """ - return _ffi_api.Function(is_pure) # type: ignore[attr-defined] # pylint: disable=no-member + return _ffi_api.Function( # type: ignore[attr-defined] # pylint: disable=no-member + is_pure, is_private + ) def arg(name: py_str, struct_info: StructInfo) -> Var: diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index 9275924466d5..69e262b1d327 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -240,6 +240,7 @@ class Parser(doc.NodeVisitor): dispatch_tokens: List[str] function_annotations: Optional[Dict[str, Dict[str, Any]]] var_table: VarTable + inside_function: bool # whether we are within a function def __init__( self, @@ -250,6 +251,7 @@ def __init__( self.dispatch_tokens = ["default"] self.function_annotations = function_annotations self.var_table = VarTable() + self.inside_function = False def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any: """The main parse method for parser. diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py index 2711e855dddf..ff237a5600e7 100644 --- a/python/tvm/script/parser/relax/entry.py +++ b/python/tvm/script/parser/relax/entry.py @@ -45,9 +45,11 @@ # this formulation allows us to support having @R.function # appear as a decorator by itself or to have optional arguments # like @R.function(pure=False) -def function(f: Optional[FType] = None, pure: bool = True) -> Union[Function, FType]: +def function( + f: Optional[FType] = None, pure: bool = True, private: bool = False +) -> Union[Function, FType]: # pylint: disable=unused-argument - # (pure isn't used here, but is used later in parsing) + # (pure and private aren't used here, but are used later in parsing) # need to inspect the stack first because is_defined_in_class expects the outer class # to be in a particular position in the stack diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index 427c56bcc8e3..863c249975a7 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -160,6 +160,9 @@ def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) -> Non @dispatch.register(token="relax", type_name="FunctionDef") def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: + is_inner_function = self.inside_function + self.inside_function = True + # reserve a var for local function func_val = self.var_table.get().get(node.name) if not func_val and is_recursive(node): @@ -178,11 +181,14 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: local_func_var = relax.Var(node.name, relax.FuncStructInfo(params_sinfo, ret_sinfo)) self.var_table.add(node.name, local_func_var) - purity = find_purity_annotation(node) + purity = find_decorator_annotation(node, "pure") + # treat the function as private if we are inside another function + # or if it has a privacy annotation + privacy = is_inner_function or find_decorator_annotation(node, "private", default=False) with self.var_table.with_frame(): with self.with_dispatch_token("relax"): - with R.function(is_pure=purity): + with R.function(is_pure=purity, is_private=privacy): R.func_name(node.name) collect_symbolic_var_from_params(self, node) @@ -202,22 +208,22 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: self.report_error(stmt, "inline prim_func is disallowed in Relax IR") self.visit_body(node.body) + self.inside_function = is_inner_function -def find_purity_annotation(node: doc.FunctionDef) -> bool: +def find_decorator_annotation(node: doc.FunctionDef, annotation: str, default: bool = True) -> bool: """ - Check the value of `pure` in the function decorator. - Returns the annotated purity if present, otherwise defaulting to True. - This allows for specifying the purity in the function signature. + Check the value of given annotation (argument name) in the function decorator. + Returns the value of the annotation if present, otherwise giving the default value. """ - # look for the pure argument in the function decorator + # look for the named argument in the function decorator for dec in node.decorator_list: if not isinstance(dec, doc.Call) or dec.func.attr != "function": continue for keyword in dec.keywords: - if keyword.arg == "pure": + if keyword.arg == annotation: return keyword.value.value - return True + return default @dispatch.register(token="relax", type_name="tvm_declare_function") @@ -238,7 +244,8 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True) params.append(relax.Var(arg.arg, param_sinfo)) - is_pure = find_purity_annotation(node) + is_pure = find_decorator_annotation(node, "pure") + func_signature = relax.Function.create_empty(params, ret_sinfo, is_pure=is_pure) return I.decl_function(node.name, func_signature) diff --git a/src/relax/training/utils.cc b/src/relax/training/utils.cc index 37582e301550..19faaad58b87 100644 --- a/src/relax/training/utils.cc +++ b/src/relax/training/utils.cc @@ -48,7 +48,9 @@ class AppendLossMutator : private ExprMutator { Function new_loss_func = CopyWithNewVars(loss_function); AppendLossMutator mutator(mod, new_loss_func, num_backbone_outputs); - auto new_func_transformed = Downcast(mutator.VisitExpr(new_func)); + auto new_func_transformed = + WithAttr(Downcast(mutator.VisitExpr(new_func)), tvm::attr::kGlobalSymbol, + new_func_name.value_or(func_name + "_loss")); auto new_module = GetRef(mod.CopyOnWrite()); auto new_var = GlobalVar(new_func_name.value_or(func_name + "_loss")); diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc index 7645ae8cb6c6..2cda7a972d3c 100644 --- a/src/relax/transform/gradient.cc +++ b/src/relax/transform/gradient.cc @@ -333,10 +333,14 @@ class GradientMutator : private ExprMutator { } GradientMutator mutator(mod, require_grads_value, target_index); - Function new_func_transformed = Downcast(mutator.VisitExpr(new_func)); + + // make the adjoint public + auto new_name = func_name + "_adjoint"; + Function new_func_transformed = WithAttr(Downcast(mutator.VisitExpr(new_func)), + tvm::attr::kGlobalSymbol, new_name); IRModule new_module = GetRef(mod.CopyOnWrite()); - new_module->Add(GlobalVar(func_name + "_adjoint"), new_func_transformed); + new_module->Add(GlobalVar(new_name), new_func_transformed); return new_module; } diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc index f7c9a4189dbb..fb1f2927769d 100644 --- a/src/relax/transform/lift_transform_params.cc +++ b/src/relax/transform/lift_transform_params.cc @@ -251,7 +251,10 @@ class TransformParamsLifter : public ExprMutator { lift_plan_ = planner.Plan(func, num_input); // Step 2: Add the lifted function to the module - builder_->AddFunction(lift_plan_.f_transform_params, new_func_name); + // (The lifted function should be public so we add a global symbol to it) + auto lift_func = + WithAttr(lift_plan_.f_transform_params, tvm::attr::kGlobalSymbol, new_func_name); + builder_->AddFunction(lift_func, new_func_name); // Step 3: Update the current function. diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc index 00bbd2a551a6..966af809c9b4 100644 --- a/src/script/ir_builder/relax/frame.cc +++ b/src/script/ir_builder/relax/frame.cc @@ -56,6 +56,11 @@ void FunctionFrameNode::ExitWithScope() { "`return` to return an Expr"; this->block_builder->BeginScope(params); Expr body = this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, output.value())); + // if the function is not private, add a global symbol to its attributes + if (!is_private.value_or(Bool(false))->value && name.defined() && + !attrs.count(tvm::attr::kGlobalSymbol)) { + attrs.Set(tvm::attr::kGlobalSymbol, name.value()); + } auto dict_attrs = attrs.empty() ? NullValue() : DictAttrs(attrs); this->block_builder->EndScope(); tvm::relax::Function func(/*params=*/params, diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc index 52d9f0cfe10e..d66e8d059813 100644 --- a/src/script/ir_builder/relax/ir.cc +++ b/src/script/ir_builder/relax/ir.cc @@ -52,7 +52,7 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable) /////////////////////////////// Function //////////////////////////////// -FunctionFrame Function(const Bool& is_pure) { +FunctionFrame Function(const Bool& is_pure, const Bool& is_private) { ObjectPtr n = make_object(); const IRBuilder& ir_builder = IRBuilder::Current(); Optional mod = NullOpt; @@ -61,6 +61,7 @@ FunctionFrame Function(const Bool& is_pure) { } n->block_builder = tvm::relax::BlockBuilder::Create(/*mod=*/mod); n->is_pure = is_pure; + n->is_private = is_private; return FunctionFrame(n); } diff --git a/src/script/printer/relax/function.cc b/src/script/printer/relax/function.cc index bd5d969563a6..bc5f12309f47 100644 --- a/src/script/printer/relax/function.cc +++ b/src/script/printer/relax/function.cc @@ -22,13 +22,36 @@ namespace tvm { namespace script { namespace printer { +bool AtTopLevelFunction(const IRDocsifier& d) { + // fewer than 2 frames: not in a function at all + if (d->frames.size() < 2) { + return false; + } + // if the first frame is a RelaxFrame, then this is not inside a module. + // 2 frames => we are at a function (more than 2 => nested function) + if (d->frames[0]->IsInstance()) { + return d->frames.size() == 2; + } + // otherwise the first two frames pertain to an IR module, + // so 3 frames => we are at a top-level function (more than 3 => nested function) + return d->frames.size() == 3; +} + TVM_REGISTER_NODE_TYPE(RelaxFrameNode); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](relax::Function n, ObjectPath n_p, IRDocsifier d) -> Doc { std::unordered_set func_vars; With f(d); - IdDoc func_name = d->Define(n, f(), FindFunctionName(d, n).value_or("main")); + + IdDoc func_name(""); + // if we are binding a local definition, then calling d->Define + // will result in a repeated definition and an incorrect displayed name + if (Optional name = GetBindingName(d)) { + func_name = std::move(IdDoc(name.value())); + } else { + func_name = std::move(d->Define(n, f(), FindFunctionName(d, n).value_or("main"))); + } (*f)->AddDispatchToken(d, "relax"); (*f)->is_func = true; (*f)->func_vars = &func_vars; @@ -52,17 +75,49 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) (*f)->func_vars = nullptr; // Step 4. Print attributes if (n->attrs.defined() && !n->attrs->dict.empty()) { - (*f)->stmts.push_back( - ExprStmtDoc(Relax(d, "func_attr") // - ->Call({d->AsDoc(n->attrs, n_p->Attr("attrs"))}))); + // If the function is a global function and has a global symbol, + // then don't print the global symbol (it will be implicit from not being private). + // For a function without an IR module whose global symbol + // doesn't match the function name, we should still print the global symbol attribute. + if (AtTopLevelFunction(d) && n->attrs->dict.count(tvm::attr::kGlobalSymbol) && + Downcast(n->attrs->dict.at(tvm::attr::kGlobalSymbol)) == func_name->name) { + Map new_attrs; + for (auto kv : n->attrs->dict) { + if (kv.first != tvm::attr::kGlobalSymbol) { + new_attrs.Set(kv.first, kv.second); + } + } + if (!new_attrs.empty()) { + (*f)->stmts.push_back(ExprStmtDoc( + Relax(d, "func_attr") // + ->Call({d->AsDoc(DictAttrs(new_attrs), n_p->Attr("attrs"))}))); + } + } else { + (*f)->stmts.push_back( + ExprStmtDoc(Relax(d, "func_attr") // + ->Call({d->AsDoc(n->attrs, n_p->Attr("attrs"))}))); + } } // Step 5. Prepare the decorator (include purity if it's impure) ExprDoc decorator = Relax(d, "function"); + Array pos_args = {}; + Array dec_keys; + Array dec_values; if (!n->is_pure) { - Array pos_args = {}; - decorator = std::move(decorator->Call( - pos_args, {"pure"}, {LiteralDoc::Boolean(false, Optional())})); + dec_keys.push_back("pure"); + dec_values.push_back(LiteralDoc::Boolean(false, Optional())); } + // if the function is global or is not in a module and does not have a global symbol, + // indicate that it's private + if (AtTopLevelFunction(d) && + (!n->attrs.defined() || !n->attrs->dict.count(tvm::attr::kGlobalSymbol))) { + dec_keys.push_back("private"); + dec_values.push_back(LiteralDoc::Boolean(true, Optional())); + } + if (dec_keys.size()) { + decorator = std::move(decorator->Call(pos_args, dec_keys, dec_values)); + } + // Step 6. Print body Array body = PrintSeqExpr(Downcast(n->body), n_p->Attr("body"), d, /*use_ret=*/true); diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index cbb19c674335..ea83807bf8cd 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -875,7 +875,7 @@ def rewriter(_, matchings): return R.multiply(matchings[x], R.const(2, "float32")) rewritten = rewrite_call(pattern, rewriter, main) - tvm.ir.assert_structural_equal(rewritten, expected1) + tvm.ir.assert_structural_equal(rewritten, expected1.with_attr("global_symbol", "main")) add1 = is_op("relax.add")(x, x) pattern = is_op("relax.add")(add1, add1) @@ -884,7 +884,7 @@ def rewriter(_, matchings): return R.multiply(matchings[x], R.const(4, "float32")) rewritten = rewrite_call(pattern, rewriter, main) - tvm.ir.assert_structural_equal(rewritten, expected2) + tvm.ir.assert_structural_equal(rewritten, expected2.with_attr("global_symbol", "main")) # No rewriting, return the original call node as is def rewriter(orig, _): @@ -959,7 +959,7 @@ def rewriter(_, matchings): return R.nn.attention(matchings[Q], matchings[K], matchings[V]) rewritten = rewrite_call(pattern, rewriter, main) - tvm.ir.assert_structural_equal(rewritten, expected) + tvm.ir.assert_structural_equal(rewritten, expected.with_attr("global_symbol", "main")) def test_attention_qkv(): @@ -1115,7 +1115,7 @@ def expected( inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, matmul2, matmul3 ) rewritten = rewrite_bindings(ctx, rewriter, qkv_x2) - tvm.ir.assert_structural_equal(rewritten, expected) + tvm.ir.assert_structural_equal(rewritten, expected.with_attr("global_symbol", "qkv_x2")) def test_combine_matmul_emit_order(): @@ -1173,7 +1173,7 @@ def expected( inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, matmul2, matmul3 ) rewritten = rewrite_bindings(ctx, rewriter, main) - tvm.ir.assert_structural_equal(rewritten, expected) + tvm.ir.assert_structural_equal(rewritten, expected.with_attr("global_symbol", "main")) # make sure it builds mod = tvm.IRModule() @@ -1272,7 +1272,7 @@ def rewriter(matchings, _): rewritten = rewrite_bindings(ctx, rewriter, main) print(rewritten.script()) - tvm.ir.assert_structural_equal(rewritten, expected) + tvm.ir.assert_structural_equal(rewritten, expected.with_attr("global_symbol", "main")) # make sure it builds mod = tvm.IRModule() diff --git a/tests/python/relax/test_training_loss.py b/tests/python/relax/test_training_loss.py index 0a2418aad756..0d456ceb3873 100644 --- a/tests/python/relax/test_training_loss.py +++ b/tests/python/relax/test_training_loss.py @@ -46,6 +46,7 @@ def test_l1_loss(): def expected( predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3, 5), "float32") ) -> R.Tensor((), "float32"): + R.func_attr({"global_symbol": "l1_loss"}) with R.dataflow(): lv: R.Tensor((3, 5), "float32") = R.subtract(predictions, targets) lv1: R.Tensor((3, 5), "float32") = R.abs(lv) @@ -70,6 +71,7 @@ def expected( b: R.Tensor((2, 4), "float32"), targets: R.Tensor((2, 4), "float32"), ) -> R.Tensor((), "float32"): + R.func_attr({"global_symbol": "forward_loss"}) with R.dataflow(): lv: R.Tensor((2, 4), "float32") = R.matmul(x, w, out_dtype="") out: R.Tensor((2, 4), "float32") = R.add(lv, b) @@ -93,6 +95,7 @@ def test_mse_loss(): def expected( predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3, 5), "float32") ) -> R.Tensor((), "float32"): + R.func_attr({"global_symbol": "mse_loss"}) with R.dataflow(): lv: R.Tensor((3, 5), "float32") = R.subtract(predictions, targets) lv1: R.Tensor((3, 5), "float32") = R.multiply(lv, lv) @@ -117,6 +120,7 @@ def expected( b: R.Tensor((2, 4), "float32"), targets: R.Tensor((2, 4), "float32"), ) -> R.Tensor((), "float32"): + R.func_attr({"global_symbol": "forward_loss"}) with R.dataflow(): lv: R.Tensor((2, 4), "float32") = R.matmul(x, w, out_dtype="") out: R.Tensor((2, 4), "float32") = R.add(lv, b) @@ -143,6 +147,7 @@ def expected( targets: R.Tensor((3,), "int64"), weights: R.Tensor((5,), "float32"), ) -> R.Tensor((), "float32"): + R.func_attr({"global_symbol": "cross_entropy_loss"}) with R.dataflow(): lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, axis=-1) gv: R.Tensor((), "float32") = R.nn.nll_loss( @@ -165,6 +170,7 @@ def test_cross_entropy_loss_without_weights(): def expected( predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3,), "int64") ) -> R.Tensor((), "float32"): + R.func_attr({"global_symbol": "cross_entropy_loss"}) with R.dataflow(): lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, axis=-1) gv: R.Tensor((), "float32") = R.nn.nll_loss( @@ -195,6 +201,7 @@ def expected( targets: R.Tensor((2,), "int64"), weights: R.Tensor((4,), "float32"), ) -> R.Tensor((), "float32"): + R.func_attr({"global_symbol": "forward_loss"}) with R.dataflow(): lv: R.Tensor((2, 4), "float32") = R.matmul(x, w, out_dtype="") out: R.Tensor((2, 4), "float32") = R.add(lv, b) @@ -224,6 +231,7 @@ def expected( targets: R.Tensor((3, 5), "int64"), weights: R.Tensor((5,), "float32"), ) -> R.Tensor((), "float32"): + R.func_attr({"global_symbol": "categorical_cross_entropy_loss"}) with R.dataflow(): lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, axis=-1) lv: R.Tensor((), "float32") = -lv * targets.astype("float32") @@ -245,6 +253,7 @@ def test_categorical_cross_entropy_loss_without_weights(): def expected( predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3, 5), "int64") ) -> R.Tensor((), "float32"): + R.func_attr({"global_symbol": "categorical_cross_entropy_loss"}) with R.dataflow(): lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, axis=-1) gv: R.Tensor((), "float32") = R.mean(-lv * targets.astype("float32")) @@ -270,6 +279,7 @@ def expected( targets: R.Tensor((3, 5), "int64"), weights: R.Tensor((5,), "float32"), ) -> R.Tensor((), "float32"): + R.func_attr({"global_symbol": "categorical_cross_entropy_loss"}) with R.dataflow(): lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, axis=-1) targets = relax.op.reshape( diff --git a/tests/python/relax/test_training_optimizer.py b/tests/python/relax/test_training_optimizer.py index b2246087c6d3..514422da8d01 100644 --- a/tests/python/relax/test_training_optimizer.py +++ b/tests/python/relax/test_training_optimizer.py @@ -67,6 +67,7 @@ def sgd_expected( R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")), R.Tuple(R.Tensor((), "int64")), ): + R.func_attr({"global_symbol": "SGD"}) # block 0 with R.dataflow(): num_steps: R.Tensor((), "int64") = optim_states[0] @@ -104,6 +105,7 @@ def sgd_expected( R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")), R.Tuple(R.Tensor((), "int64")), ): + R.func_attr({"global_symbol": "SGD"}) with R.dataflow(): num_steps: R.Tensor((), "int64") = optim_states[0] num_steps_new: R.Tensor((), "int64") = R.add(num_steps, R.const(1, "int64")) @@ -146,6 +148,7 @@ def msgd_expected( R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")), R.Tuple(R.Tensor((), "int64"), R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")), ): + R.func_attr({"global_symbol": "MomentumSGD"}) # block 0 with R.dataflow(): num_steps: R.Tensor((), "int64") = optim_states[0] @@ -195,6 +198,7 @@ def msgd_expected( R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")), R.Tuple(R.Tensor((), "int64"), R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")), ): + R.func_attr({"global_symbol": "MomentumSGD"}) # block 0 with R.dataflow(): num_steps: R.Tensor((), "int64") = optim_states[0] @@ -250,6 +254,7 @@ def msgd_expected( R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")), R.Tuple(R.Tensor((), "int64"), R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")), ): + R.func_attr({"global_symbol": "MomentumSGD"}) # block 0 with R.dataflow(): num_steps: R.Tensor((), "int64") = optim_states[0] @@ -321,6 +326,7 @@ def adam_expected( R.Tensor((3,), "float32"), ), ): + R.func_attr({"global_symbol": "Adam"}) # block 0 with R.dataflow(): num_steps: R.Tensor((), "int64") = optim_states[0] @@ -418,6 +424,7 @@ def adam_expected( R.Tensor((3,), "float32"), ), ): + R.func_attr({"global_symbol": "Adam"}) # block 0 with R.dataflow(): num_steps: R.Tensor((), "int64") = optim_states[0] @@ -519,6 +526,7 @@ def adam_expected( R.Tensor((3,), "float64"), ), ): + R.func_attr({"global_symbol": "Adam"}) # block 0 with R.dataflow(): num_steps: R.Tensor((), "int64") = optim_states[0] diff --git a/tests/python/relax/test_training_optimizer_numeric.py b/tests/python/relax/test_training_optimizer_numeric.py index 3b300e826120..23db8987f12d 100644 --- a/tests/python/relax/test_training_optimizer_numeric.py +++ b/tests/python/relax/test_training_optimizer_numeric.py @@ -69,7 +69,7 @@ def _test_optimizer(target, dev, np_func, opt_type, *args, **kwargs): x = relax.Var("x", R.Tensor((3, 3), "float32")) y = relax.Var("y", R.Tensor((3,), "float32")) opt = opt_type(*args, **kwargs).init([x, y]) - mod = IRModule.from_expr(opt.get_function()) + mod = IRModule.from_expr(opt.get_function().with_attr("global_symbol", "main")) tvm_func = _legalize_and_build(mod, target, dev)["main"] param_arr = [np.random.rand(3, 3).astype(np.float32), np.random.rand(3).astype(np.float32)] diff --git a/tests/python/relax/test_transform_attach_global_symbol.py b/tests/python/relax/test_transform_attach_global_symbol.py index 035e21609daa..680df969474a 100644 --- a/tests/python/relax/test_transform_attach_global_symbol.py +++ b/tests/python/relax/test_transform_attach_global_symbol.py @@ -43,7 +43,7 @@ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: C[vi, vj] = T.float32(0) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] - @R.function + @R.function(private=True) def main( x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32") ) -> R.Tensor: @@ -74,7 +74,6 @@ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: def main( x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32") ) -> R.Tensor: - R.func_attr({"global_symbol": "main"}) m, n, k = T.int64(), T.int64(), T.int64() gv0 = R.call_tir(Expected.tir_matmul, (x, w), R.Tensor((m, k), dtype="float32")) return gv0 @@ -94,7 +93,7 @@ class Before: def tir_zeros(x: T.Buffer((2), "float32")) -> None: x[0] = T.float32(0) - @R.function + @R.function(private=True) def main() -> R.Tensor: gv0 = R.call_tir(Before.tir_zeros, (), R.Tensor((2,), dtype="float32")) return gv0 @@ -110,7 +109,6 @@ def tir_zeros(x: T.Buffer((2), "float32")) -> None: @R.function def main() -> R.Tensor: - R.func_attr({"global_symbol": "main"}) gv0 = R.call_tir(Expected.tir_zeros, (), R.Tensor((2,), dtype="float32")) return gv0 diff --git a/tests/python/relax/test_transform_combine_parallel_matmul.py b/tests/python/relax/test_transform_combine_parallel_matmul.py index 719daaf4496d..97211f0dd0ff 100644 --- a/tests/python/relax/test_transform_combine_parallel_matmul.py +++ b/tests/python/relax/test_transform_combine_parallel_matmul.py @@ -97,7 +97,7 @@ def expected1( R.output(lv3) return lv3 - tvm.ir.assert_structural_equal(mod["main"], expected1) + tvm.ir.assert_structural_equal(mod["main"], expected1.with_attr("global_symbol", "main")) # Test a batched LHS case, slicing is done on the axis 2 mod = get_parallel_matmul(3, lhs_shape=(2, 1024, 640)) @@ -121,7 +121,7 @@ def expected2( R.output(lv3) return lv3 - tvm.ir.assert_structural_equal(mod["main"], expected2) + tvm.ir.assert_structural_equal(mod["main"], expected2.with_attr("global_symbol", "main")) def test_bias(): @@ -151,7 +151,7 @@ def expected1( R.output(lv6) return lv6 - tvm.ir.assert_structural_equal(mod["main"], expected1) + tvm.ir.assert_structural_equal(mod["main"], expected1.with_attr("global_symbol", "main")) mod = get_parallel_matmul(3, with_bias=[True, False, True]) mod = CombineParallelMatmul()(mod) @@ -178,7 +178,7 @@ def expected2( R.output(lv5) return lv5 - tvm.ir.assert_structural_equal(mod["main"], expected2) + tvm.ir.assert_structural_equal(mod["main"], expected2.with_attr("global_symbol", "main")) def test_activation(): @@ -204,7 +204,7 @@ def expected1( R.output(lv6) return lv6 - tvm.ir.assert_structural_equal(mod["main"], expected1) + tvm.ir.assert_structural_equal(mod["main"], expected1.with_attr("global_symbol", "main")) mod = get_parallel_matmul(3, activation=["gelu", "relu", "relu"]) mod = CombineParallelMatmul()(mod) @@ -230,7 +230,7 @@ def expected2( R.output(lv6) return lv6 - tvm.ir.assert_structural_equal(mod["main"], expected2) + tvm.ir.assert_structural_equal(mod["main"], expected2.with_attr("global_symbol", "main")) mod = get_parallel_matmul(3, activation=["relu", None, None]) mod = CombineParallelMatmul()(mod) @@ -255,7 +255,7 @@ def expected3( R.output(lv4) return lv4 - tvm.ir.assert_structural_equal(mod["main"], expected3) + tvm.ir.assert_structural_equal(mod["main"], expected3.with_attr("global_symbol", "main")) def test_bias_activation(): @@ -286,7 +286,7 @@ def expected1( R.output(lv9) return lv9 - tvm.ir.assert_structural_equal(mod["main"], expected1) + tvm.ir.assert_structural_equal(mod["main"], expected1.with_attr("global_symbol", "main")) mod = get_parallel_matmul(3, with_bias=[True, True, True], activation=["relu", None, "relu"]) mod = CombineParallelMatmul()(mod) @@ -316,7 +316,7 @@ def expected2( R.output(lv8) return lv8 - tvm.ir.assert_structural_equal(mod["main"], expected2) + tvm.ir.assert_structural_equal(mod["main"], expected2.with_attr("global_symbol", "main")) mod = get_parallel_matmul(3, with_bias=[True, False, True], activation=["relu", None, "relu"]) mod = CombineParallelMatmul()(mod) @@ -345,7 +345,7 @@ def expected3( R.output(lv7) return lv7 - tvm.ir.assert_structural_equal(mod["main"], expected3) + tvm.ir.assert_structural_equal(mod["main"], expected3.with_attr("global_symbol", "main")) def test_rhs_batched(): @@ -378,6 +378,7 @@ def expected1( w2: R.Tensor((2, 640, 640), dtype="float32"), w3: R.Tensor((3, 4, 640, 640), dtype="float32"), ) -> R.Tensor: + R.func_attr({"global_symbol": "main"}) with R.dataflow(): lv = R.concat((w0, w2), axis=2) lv1 = R.matmul(x, lv, out_dtype="float32") @@ -458,6 +459,7 @@ def expected1( b0: R.Tensor((640,), dtype="float32"), b1: R.Tensor((640,), dtype="float32"), ) -> R.Tensor: + R.func_attr({"global_symbol": "main"}) with R.dataflow(): lv = R.concat((w0, w1, w2), axis=1) lv1 = R.matmul(x1, lv, out_dtype="float32") @@ -515,6 +517,7 @@ def expected( w3: R.Tensor((640, 640), dtype="float32"), w4: R.Tensor((640, 640), dtype="float32"), ) -> R.Tensor: + R.func_attr({"global_symbol": "main"}) with R.dataflow(): lv = R.concat((w0, w1, w2), axis=1) lv1 = R.matmul(x1, lv, out_dtype="float32") diff --git a/tests/python/relax/test_transform_dead_code_elimination.py b/tests/python/relax/test_transform_dead_code_elimination.py index 9c6e0e0567fe..12a3de6acb30 100644 --- a/tests/python/relax/test_transform_dead_code_elimination.py +++ b/tests/python/relax/test_transform_dead_code_elimination.py @@ -168,7 +168,7 @@ def tir_add( vi, vj = T.axis.remap("SS", [i, j]) z[vi, vj] = x[vi, vj] + y[vi, vj] - @R.function + @R.function(private=True) def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")): gv0 = R.add(x, w) return gv0 @@ -202,7 +202,7 @@ def tir_add( vi, vj = T.axis.remap("SS", [i, j]) z[vi, vj] = x[vi, vj] + y[vi, vj] - @R.function + @R.function(private=True) def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")): gv0 = R.add(x, w) return gv0 @@ -239,7 +239,7 @@ def tir_add( vi, vj = T.axis.remap("SS", [i, j]) z[vi, vj] = x[vi, vj] + y[vi, vj] - @R.function + @R.function(private=True) def unused_func(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")): gv0 = R.add(x, w) return gv0 @@ -310,7 +310,7 @@ def unused_func1( vi, vj = T.axis.remap("SS", [i, j]) z[vi, vj] = x[vi, vj] + y[vi, vj] - @R.function + @R.function(private=True) def unused_func2(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")): gv0 = R.add(x, w) return gv0 diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py index 169539b07243..b04677b6f51e 100644 --- a/tests/python/relax/test_transform_fuse_ops.py +++ b/tests/python/relax/test_transform_fuse_ops.py @@ -48,7 +48,7 @@ def expected(): x = relax.Var("x", R.Tensor([10, 20], "float32")) p0 = relax.Var("p0", R.Tensor((), "float32")) - with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 1}): + with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 1}, private=True): with bb.dataflow(): lv0 = bb.emit_te(topi.add, x, p0) lv1 = bb.emit_te(topi.exp, lv0) @@ -100,7 +100,9 @@ def expected(dtype): x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) w = relax.Var("w", R.Tensor((16, 16, 3, 3), dtype)) p0 = relax.Var("p0", R.Tensor((), dtype)) - with bb.function("fused_conv2d_add1_add2", [x, w, p0], attrs={"Primitive": 1}): + with bb.function( + "fused_conv2d_add1_add2", [x, w, p0], attrs={"Primitive": 1}, private=True + ): with bb.dataflow(): lv0 = bb.emit_te( topi.nn.conv2d, @@ -119,7 +121,7 @@ def expected(dtype): x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) w = relax.Var("w", R.Tensor((16, 16, 1, 1), dtype)) y = relax.Var("y", R.Tensor((1, 16, 64, 64), dtype)) - with bb.function("fused_conv2d1_add2", [x, w, y], attrs={"Primitive": 1}): + with bb.function("fused_conv2d1_add2", [x, w, y], attrs={"Primitive": 1}, private=True): with bb.dataflow(): lv0 = bb.emit_te( topi.nn.conv2d, @@ -196,7 +198,9 @@ def expected(): x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) w = relax.Var("w", R.Tensor((1, 16, 32, 32), "float32")) p0 = relax.Var("p0", R.Tensor((), "float32")) - with bb.function("fused_upsampling_concatenate_add", [w, x, p0], attrs={"Primitive": 1}): + with bb.function( + "fused_upsampling_concatenate_add", [w, x, p0], attrs={"Primitive": 1}, private=True + ): with bb.dataflow(): lv0 = bb.emit_te(topi.nn.upsampling, w, scale_h=2.0, scale_w=2.0) lv1 = bb.emit_te(topi.concatenate, (lv0, x), axis=1) @@ -287,7 +291,10 @@ def expected(dim: int): # Grouped function dense = relax.Var("dense", R.Tensor((1, 3 * dim), "float32")) with bb.function( - "fused_split_sigmoid_tanh_exp_multiply_add", [dense], attrs={"Primitive": 1} + "fused_split_sigmoid_tanh_exp_multiply_add", + [dense], + attrs={"Primitive": 1}, + private=True, ): with bb.dataflow(): lv0 = bb.emit_te(topi.split, dense, indices_or_sections=3, axis=1) @@ -340,7 +347,7 @@ def expected(dim: int): # Grouped function x = relax.Var("x", R.Tensor((1, 3 * dim), "float32")) - with bb.function("fused_split", [x], attrs={"Primitive": 1}): + with bb.function("fused_split", [x], attrs={"Primitive": 1}, private=True): with bb.dataflow(): lv0 = bb.emit_te(topi.split, x, indices_or_sections=3, axis=1) gv = bb.emit_output(relax.TupleGetItem(lv0, 0)) @@ -398,6 +405,7 @@ def expected(): "fused_squeeze_add_squeeze1_add_add_add_concatenate_squeeze2_add1", [x, p0, p1, p2, p3, p4], attrs={"Primitive": 1}, + private=True, ): with bb.dataflow(): lv0 = bb.emit_te(topi.squeeze, x) @@ -500,6 +508,7 @@ def expected(): "fused_add_add_add_concatenate_add1_add_add_add_concatenate_add1_add_add_add_concatenate_add1_concatenate1", [x, p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11], attrs={"Primitive": 1}, + private=True, ): with bb.dataflow(): lv0 = bb.emit_te(topi.add, x, p0) @@ -523,7 +532,7 @@ def expected(): # Grouped function 2 concat = relax.Var("concat", R.Tensor((1, 144, 64, 64), "float32")) p0 = relax.Var("p0", R.Tensor((), "float32")) - with bb.function("fused_pool2d_add2", [concat, p0], attrs={"Primitive": 1}): + with bb.function("fused_pool2d_add2", [concat, p0], attrs={"Primitive": 1}, private=True): with bb.dataflow(): lv0 = bb.emit_te( topi.nn.pool2d, @@ -609,7 +618,7 @@ def expected(): # Grouped function 1 x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) w = relax.Var("w", R.Tensor((16, 16, 3, 3), "float32")) - with bb.function("fused_conv2d_relu", [x, w], attrs={"Primitive": 1}): + with bb.function("fused_conv2d_relu", [x, w], attrs={"Primitive": 1}, private=True): with bb.dataflow(): lv0 = bb.emit_te( topi.nn.conv2d, @@ -626,7 +635,7 @@ def expected(): # Grouped function 2 x = relax.Var("x", R.Tensor((1, 32, 64, 64), "float32")) w = relax.Var("w", R.Tensor((16, 32, 3, 3), "float32")) - with bb.function("fused_conv2d1_relu", [x, w], attrs={"Primitive": 1}): + with bb.function("fused_conv2d1_relu", [x, w], attrs={"Primitive": 1}, private=True): with bb.dataflow(): lv0 = bb.emit_te( topi.nn.conv2d, @@ -689,7 +698,10 @@ def expected(): x = relax.Var("x", R.Tensor((10, 20), "int32")) p0 = relax.Var("p0", R.Tensor((), "int32")) with bb.function( - "fused_add_squeeze_transpose_transpose1_left_shift", [x, p0], attrs={"Primitive": 1} + "fused_add_squeeze_transpose_transpose1_left_shift", + [x, p0], + attrs={"Primitive": 1}, + private=True, ): with bb.dataflow(): lv0 = bb.emit_te(topi.add, x, p0) @@ -734,7 +746,7 @@ def expected(): # Grouped function x = relax.Var("x", R.Tensor((16, 16), "float32")) - with bb.function("fused_softmax_cast", [x], attrs={"Primitive": 1}): + with bb.function("fused_softmax_cast", [x], attrs={"Primitive": 1}, private=True): with bb.dataflow(): lv0 = bb.emit_te(topi.nn.softmax, x) gv = bb.emit_output(bb.call_te(topi.cast, lv0, dtype="float16")) @@ -781,7 +793,7 @@ def expected(): x = relax.Var("x", R.Tensor([10, 20], "float32")) p0 = relax.Var("p0", R.Tensor((), "float32")) - with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 1}): + with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 1}, private=True): with bb.dataflow(): lv0 = bb.emit_te(topi.add, x, p0) lv1 = bb.emit_te(topi.exp, lv0) @@ -791,7 +803,7 @@ def expected(): x = relax.Var("x", R.Tensor([20, 10], "float32")) p0 = relax.Var("p0", R.Tensor((), "float32")) - with bb.function("fused_add1_exp1_squeeze1", [x, p0], attrs={"Primitive": 1}): + with bb.function("fused_add1_exp1_squeeze1", [x, p0], attrs={"Primitive": 1}, private=True): with bb.dataflow(): lv0 = bb.emit_te(topi.add, x, p0) lv1 = bb.emit_te(topi.exp, lv0) @@ -938,7 +950,7 @@ def relu(A: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "floa T.writes(B[v_i0, v_i1, v_i2, v_i3]) B[v_i0, v_i1, v_i2, v_i3] = T.max(A[v_i0, v_i1, v_i2, v_i3], T.float32(0)) - @R.function + @R.function(private=True) def fused_layer_norm_relu(x: R.Tensor((1, 512, 64, 64), dtype="float32"), mean: R.Tensor((64, 64), dtype="float32"), var: R.Tensor((64, 64), dtype="float32")) -> R.Tensor((1, 512, 64, 64), dtype="float32"): R.func_attr({"Primitive": 1}) cls = Expected @@ -1080,7 +1092,7 @@ def transpose(rxplaceholder: T.Buffer((T.int64(320), T.int64(1280)), "float32"), T.writes(T_transpose[v_ax0, v_ax1]) T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0] - @R.function + @R.function(private=True) def fused_conv2d_add_add2(inp_0: R.Tensor((2, 320, 64, 64), dtype="float32"), w1: R.Tensor((320, 320, 3, 3), dtype="float32"), lv28: R.Tensor((1, 320, 1, 1), dtype="float32"), lv35: R.Tensor((2, 320, 1, 1), dtype="float32")) -> R.Tensor((2, 320, 64, 64), dtype="float32"): R.func_attr({"Primitive": 1}) cls = Expected @@ -1091,7 +1103,7 @@ def fused_conv2d_add_add2(inp_0: R.Tensor((2, 320, 64, 64), dtype="float32"), w1 R.output(gv) return gv - @R.function + @R.function(private=True) def fused_matmul_add1(inp_1: R.Tensor((2, 1280), dtype="float32"), lv31: R.Tensor((1280, 320), dtype="float32"), b2: R.Tensor((320,), dtype="float32")) -> R.Tensor((2, 320), dtype="float32"): cls = Expected R.func_attr({"Primitive": 1}) @@ -1226,7 +1238,7 @@ def transpose1(rxplaceholder: T.Buffer((T.int64(10), T.int64(128)), "float32"), T.writes(T_transpose[v_ax0, v_ax1]) T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0] - @R.function + @R.function(private=True) def fused_matmul1_add1(inp_1: R.Tensor((1, 128), dtype="float32"), lv4: R.Tensor((128, 10), dtype="float32"), linear2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"): R.func_attr({"Primitive": 1}) cls = Expected @@ -1268,7 +1280,7 @@ def main(x: R.Tensor(["n", "m"], "float32")): @I.ir_module class Expected: - @R.function + @R.function(private=True) def fused_add_exp_squeeze( x: R.Tensor(["n", "m"], "float32"), p0: R.Tensor([], "float32") ) -> R.Tensor(["n", "m"], dtype="float32"): @@ -1306,7 +1318,7 @@ def main(s: R.Shape(["n"])): @I.ir_module class Expected: - @R.function + @R.function(private=True) def fused_full_trilu_broadcast_to( s: R.Shape(["n"]), ) -> R.Tensor([1, 1, "n", "n"], "float32"): @@ -1354,7 +1366,7 @@ def main(s: R.Shape(["n"]), kv_cache: R.Object): @I.ir_module class Expected: - @R.function + @R.function(private=True) def fused_full_trilu_broadcast_to( s: R.Shape(["n"]), ) -> R.Tensor([1, 1, "n", "n"], "float32"): diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py index 5fb2b3332c23..592132516bee 100644 --- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py +++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py @@ -60,7 +60,7 @@ def main( R.output(gv) return gv - @R.function + @R.function(private=True) def fused_relax_nn_conv2d_relax_nn_relu_dnnl( data1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight11: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -124,7 +124,7 @@ def main( R.output(gv) return gv - @R.function + @R.function(private=True) def fused_relax_nn_conv2d_relax_nn_relu( data1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight11: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -138,7 +138,7 @@ def fused_relax_nn_conv2d_relax_nn_relu( R.output(gv1) return gv1 - @R.function + @R.function(private=True) def fused_relax_nn_conv2d_relax_nn_relu1( conv1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight21: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -174,7 +174,7 @@ def main( R.output(conv2d) return conv2d - @R.function + @R.function(private=True) def fused_relax_nn_conv2d( data1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight11: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -187,7 +187,7 @@ def fused_relax_nn_conv2d( R.output(gv) return gv - @R.function + @R.function(private=True) def fused_relax_nn_conv2d1( conv11: R.Tensor((1, 64, 56, 56), dtype="float32"), weight21: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -236,7 +236,7 @@ def main( R.output(gv) return gv - @R.function + @R.function(private=True) def fused_relax_nn_conv2d_relax_nn_relu( conv1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight21: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -250,7 +250,7 @@ def fused_relax_nn_conv2d_relax_nn_relu( R.output(gv1) return gv1 - @R.function + @R.function(private=True) def fused_relax_nn_conv2d( data1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight11: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -303,7 +303,7 @@ def main( R.output(out) return out - @R.function + @R.function(private=True) def fused_relax_nn_conv2d_relax_nn_relu( data1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -377,7 +377,7 @@ def main( @tvm.script.ir_module class Conv2dx2_partitioned: - @R.function + @R.function(private=True) def fused_relax_nn_conv2d_cutlass( data: R.Tensor((16, 32, 32, 16), dtype="float16"), weight1: R.Tensor((16, 3, 3, 16), dtype="float16"), @@ -565,7 +565,7 @@ def relu( T.writes(out[i, j, k, l]) out[i, j, k, l] = T.max(data[i, j, k, l], T.float32(0)) - @R.function + @R.function(private=True) def fused_relax_nn_conv2d( data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -617,7 +617,7 @@ def main( @I.ir_module class Conv2dReLU_partitioned: - @R.function + @R.function(private=True) def fused_relax_nn_conv2d( data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -685,7 +685,7 @@ def main( @I.ir_module class Conv2dWithConstantWeight_partitioned: - @R.function + @R.function(private=True) def fused_relax_nn_conv2d( data: R.Tensor((1, 64, 56, 56), dtype="float32"), param_0: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -721,6 +721,7 @@ def main( def test_split(): @R.function def func(inp: R.Tensor((16, 32), "float32")): + R.func_attr({"global_symbol": "main"}) with R.dataflow(): tup = R.split(inp, [16], axis=1) out = R.add(tup[0], tup[1]) @@ -729,7 +730,7 @@ def func(inp: R.Tensor((16, 32), "float32")): @tvm.script.ir_module class Expected1: - @R.function + @R.function(private=True) def fused_relax_split( inp: R.Tensor((16, 32), dtype="float32") ) -> R.Tuple(R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16), dtype="float32")): @@ -756,7 +757,7 @@ def main(inp: R.Tensor((16, 32), dtype="float32")) -> R.Tensor((16, 16), dtype=" @I.ir_module class Expected2: - @R.function + @R.function(private=True) def fused_relax_split_relax_add( inp: R.Tensor((16, 32), dtype="float32") ) -> R.Tensor((16, 16), dtype="float32"): diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py index 00dc7146541b..f59e3f2e9e6f 100644 --- a/tests/python/relax/test_transform_fuse_tir.py +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -32,7 +32,7 @@ def before(): x = relax.Var("x", R.Tensor([10, 20], "float32")) p0 = relax.Var("p0", R.Tensor([], "float32")) - with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": True}): + with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": True}, private=True): with bb.dataflow(): lv0 = bb.emit_te(topi.add, x, p0) lv1 = bb.emit_te(topi.exp, lv0) @@ -565,7 +565,7 @@ def before(): x = relax.Var("x", R.Tensor([10, 20], "float32")) p0 = relax.Var("p0", R.Tensor((), "float32")) - with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 1}): + with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 1}, private=True): with bb.dataflow(): lv0 = bb.emit_te(topi.add, x, p0) lv1 = bb.emit_te(topi.exp, lv0) @@ -575,7 +575,7 @@ def before(): x = relax.Var("x", R.Tensor([20, 10], "float32")) p0 = relax.Var("p0", R.Tensor((), "float32")) - with bb.function("fused_add1_exp1_squeeze1", [x, p0], attrs={"Primitive": 1}): + with bb.function("fused_add1_exp1_squeeze1", [x, p0], attrs={"Primitive": 1}, private=True): with bb.dataflow(): lv0 = bb.emit_te(topi.add, x, p0) lv1 = bb.emit_te(topi.exp, lv0) diff --git a/tests/python/relax/test_transform_lambda_lift.py b/tests/python/relax/test_transform_lambda_lift.py index ddc274fee272..d67248417173 100644 --- a/tests/python/relax/test_transform_lambda_lift.py +++ b/tests/python/relax/test_transform_lambda_lift.py @@ -42,7 +42,7 @@ def test_basic(): # the target IRModule @tvm.script.ir_module class Expected: - @R.function + @R.function(private=True) def lifted_func_0( x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32") ) -> R.Tensor((10, 5), "float32"): @@ -97,12 +97,12 @@ def main( ) return res - @R.function + @R.function(private=True) def lifted_func_1(x1: R.Tensor((2, 3), "float32"), c1: R.Tensor((2, 3), "float32")): r_1: R.Tensor((2, 3), "float32") = R.add(x1, c1) return r_1 - @R.function + @R.function(private=True) def lifted_func_0(y: R.Tensor((2, 3), "float32")) -> R.Object: inner_func = R.make_closure(Expected.lifted_func_1, (y,)) return inner_func @@ -140,7 +140,7 @@ def test_recursive(): # the expected IRModule @tvm.script.ir_module class Expected: - @R.function + @R.function(private=True) def lifted_func_0( i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32"), x: R.Tensor((2, 3), "float32") ) -> R.Tensor((2, 3), "float32"): @@ -224,14 +224,14 @@ def glob_func_2( gv11: R.Tensor((10, 5), "float32") = inner(x11, y11) return gv11 - @R.function + @R.function(private=True) def lifted_func_0( x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32") ) -> R.Tensor((10, 5), "float32"): s: R.Tensor((10, 5), "float32") = R.add(x2, y2) return s - @R.function + @R.function(private=True) def lifted_func_1( x21: R.Tensor((10, 5), "float32"), y21: R.Tensor((10, 5), "float32") ) -> R.Tensor((10, 5), "float32"): @@ -308,7 +308,7 @@ def before(c0: R.Tensor((16, 16), "float32"), x: R.Tensor(dtype="float32", ndim= def test_impure_function(): @tvm.script.ir_module class Expected: - @R.function(pure=False) + @R.function(pure=False, private=True) def lifted_func_0() -> R.Tuple: y = R.print(format="Wow!") return y diff --git a/tests/python/relax/test_transform_merge_composite_functions.py b/tests/python/relax/test_transform_merge_composite_functions.py index 61df388c7888..d55226613137 100644 --- a/tests/python/relax/test_transform_merge_composite_functions.py +++ b/tests/python/relax/test_transform_merge_composite_functions.py @@ -41,7 +41,7 @@ def main( R.output(gv) return gv - @R.function + @R.function(private=True) def fused_relax_nn_conv2d_relax_nn_relu( data1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight11: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -57,7 +57,7 @@ def fused_relax_nn_conv2d_relax_nn_relu( R.output(gv1) return gv1 - @R.function + @R.function(private=True) def fused_relax_nn_conv2d_relax_nn_relu1( conv1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight21: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -164,7 +164,7 @@ def main( R.output(gv2) return gv2 - @R.function + @R.function(private=True) def fused_relax_nn_gelu( lv: R.Tensor((1, 64, 54, 54), dtype="float32") ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): @@ -174,7 +174,7 @@ def fused_relax_nn_gelu( R.output(gv) return gv - @R.function + @R.function(private=True) def fused_relax_nn_relu( lv1: R.Tensor((1, 64, 54, 54), dtype="float32") ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): @@ -184,7 +184,7 @@ def fused_relax_nn_relu( R.output(gv1) return gv1 - @R.function + @R.function(private=True) def fused_relax_add( lv5: R.Tensor((1, 64, 54, 54), dtype="float32"), gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"), @@ -195,7 +195,7 @@ def fused_relax_add( R.output(gv3) return gv3 - @R.function + @R.function(private=True) def fused_relax_nn_conv2d( data1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -332,7 +332,7 @@ def main( R.output(gv2) return gv2 - @R.function + @R.function(private=True) def fused_relax_nn_gelu( lv: R.Tensor((1, 64, 54, 54), dtype="float32") ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): @@ -342,7 +342,7 @@ def fused_relax_nn_gelu( R.output(gv) return gv - @R.function + @R.function(private=True) def fused_relax_nn_relu( lv1: R.Tensor((1, 64, 54, 54), dtype="float32") ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): @@ -352,7 +352,7 @@ def fused_relax_nn_relu( R.output(gv1) return gv1 - @R.function + @R.function(private=True) def fused_relax_add( lv5: R.Tensor((1, 64, 54, 54), dtype="float32"), gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"), @@ -363,7 +363,7 @@ def fused_relax_add( R.output(gv3) return gv3 - @R.function + @R.function(private=True) def fused_relax_nn_conv2d( data1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -509,7 +509,7 @@ def main( R.output(gv1) return gv1 - @R.function + @R.function(private=True) def fused_relax_nn_relu( x11: R.Tensor((10,), dtype="float32") ) -> R.Tensor((10,), dtype="float32"): @@ -519,7 +519,7 @@ def fused_relax_nn_relu( R.output(gv2) return gv2 - @R.function + @R.function(private=True) def fused_relax_nn_gelu( x21: R.Tensor((10,), dtype="float32") ) -> R.Tensor((10,), dtype="float32"): @@ -529,7 +529,7 @@ def fused_relax_nn_gelu( R.output(gv3) return gv3 - @R.function + @R.function(private=True) def fused_relax_add( lv: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,), dtype="float32") ) -> R.Tensor((10,), dtype="float32"): @@ -627,7 +627,7 @@ def main(x1: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32 R.output(gv1) return gv1 - @R.function + @R.function(private=True) def fused_relax_nn_relu( x11: R.Tensor((10,), dtype="float32") ) -> R.Tensor((10,), dtype="float32"): @@ -637,7 +637,7 @@ def fused_relax_nn_relu( R.output(gv2) return gv2 - @R.function + @R.function(private=True) def fused_relax_nn_gelu( x21: R.Tensor((10,), dtype="float32") ) -> R.Tensor((10,), dtype="float32"): @@ -647,7 +647,7 @@ def fused_relax_nn_gelu( R.output(gv3) return gv3 - @R.function + @R.function(private=True) def fused_relax_add( lv: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,), dtype="float32") ) -> R.Tensor((10,), dtype="float32"): @@ -759,7 +759,7 @@ def main( R.output(gv1) return gv1 - @R.function + @R.function(private=True) def fused_relax_nn_relu( add2: R.Tensor((10,), dtype="float32") ) -> R.Tensor((10,), dtype="float32"): @@ -769,7 +769,7 @@ def fused_relax_nn_relu( R.output(gv) return gv - @R.function + @R.function(private=True) def fused_relax_add( x11: R.Tensor((10,), dtype="float32"), x21: R.Tensor((10,), dtype="float32") ) -> R.Tensor((10,), dtype="float32"): @@ -779,7 +779,7 @@ def fused_relax_add( R.output(gv2) return gv2 - @R.function + @R.function(private=True) def fused_relax_nn_gelu( x31: R.Tensor((10,), dtype="float32") ) -> R.Tensor((10,), dtype="float32"): @@ -924,7 +924,7 @@ def main( R.output(conv) return conv - @R.function + @R.function(private=True) def fused_relax_nn_conv2d( data1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -1071,7 +1071,7 @@ def test_reshape(): # Verify that the non-CallNode input (shape in reshape) can be handled properly. @I.ir_module class Module: - @R.function + @R.function(private=True) def fused_relax_matmul( lv: R.Tensor((1, 784), dtype="float32"), lv1: R.Tensor((784, 512), dtype="float32") ) -> R.Tensor((1, 512), dtype="float32"): @@ -1081,7 +1081,7 @@ def fused_relax_matmul( R.output(gv) return gv - @R.function + @R.function(private=True) def fused_relax_reshape( inp_0: R.Tensor((1, 1, 28, 28), dtype="float32"), param_0: R.Shape([1, 784]) ) -> R.Tensor((1, 784), dtype="float32"): diff --git a/tests/python/relax/test_transform_normalize.py b/tests/python/relax/test_transform_normalize.py index 874e83c7f955..a6feb0b8abca 100644 --- a/tests/python/relax/test_transform_normalize.py +++ b/tests/python/relax/test_transform_normalize.py @@ -44,7 +44,7 @@ def test_normalize_function(): after_mod = relax.transform.Normalize()(before_mod) - @R.function + @R.function(private=True) def expected(x: R.Tensor(("m", "n"), "float16")) -> R.Tensor(dtype="float16", ndim=2): gv = R.add(x, x) gv1 = R.add(x, x) @@ -86,7 +86,7 @@ def test_normalize_if(): before_mod = tvm.IRModule.from_expr(f) after_mod = relax.transform.Normalize()(before_mod) - @R.function + @R.function(private=True) def expected( cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32") ) -> R.Tensor(dtype="float32", ndim=1): @@ -151,7 +151,7 @@ def test_normalize_seq_body(): before_mod = tvm.IRModule.from_expr(f) after_mod = relax.transform.Normalize()(before_mod) - @R.function + @R.function(private=True) def expected( x: R.Tensor((), dtype="int32"), y: R.Tensor((), dtype="int32") ) -> R.Tensor(ndim=0, dtype="int32"): @@ -175,7 +175,7 @@ def test_normalize_func_body(): before_mod = tvm.IRModule.from_expr(f) after_mod = relax.transform.Normalize()(before_mod) - @R.function + @R.function(private=True) def expected( x: R.Tensor((), dtype="int32"), y: R.Tensor((), dtype="int32") ) -> R.Tensor(ndim=0, dtype="int32"): @@ -207,7 +207,7 @@ def test_normalize_if_branches(): before_mod = tvm.IRModule.from_expr(f) after_mod = relax.transform.Normalize()(before_mod) - @R.function + @R.function(private=True) def expected( cond: R.Tensor((), dtype="bool"), x: R.Tensor((), dtype="int32"), @@ -257,7 +257,7 @@ def test_normalize_if_condition(): before_mod = tvm.IRModule.from_expr(f) after_mod = relax.transform.Normalize()(before_mod) - @R.function + @R.function(private=True) def expected( cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32") ) -> R.Tensor(dtype="float32", ndim=1): @@ -341,7 +341,7 @@ def test_normalize_combine_nearby_blocks(): after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f)) - @R.function + @R.function(private=True) def expected(x: R.Tensor((), "int32")): with R.dataflow(): v0 = x @@ -383,7 +383,7 @@ def test_normalize_nested_seq(): ) after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f)) - @R.function + @R.function(private=True) def expected(): x = relax.const(1) z = relax.const(2) @@ -434,7 +434,7 @@ def test_normalize_nested_seq_dataflow(): ) after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f)) - @R.function + @R.function(private=True) def expected(): x = relax.const(1) q = relax.const(2) @@ -507,7 +507,7 @@ def test_normalize_deeply_nested_seq(): ) after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f)) - @R.function + @R.function(private=True) def expected(): x = relax.const(1) u = relax.const(2) diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py b/tests/python/relax/test_transform_rewrite_cuda_graph.py index 931d206afbb1..66abb2f027ae 100644 --- a/tests/python/relax/test_transform_rewrite_cuda_graph.py +++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py @@ -82,7 +82,7 @@ def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T T.writes(compute[i0, i1]) compute[i0, i1] = T.exp(rxplaceholder[i0, i1], dtype="float32") - @R.function + @R.function(private=True) def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object, R.Object): R.func_attr({"relax.force_pure": True}) storage: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32")) @@ -91,7 +91,7 @@ def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object, R.Object): gv: R.Tuple(R.Object, R.Object, R.Object) = (storage, storage1, storage2) return gv - @R.function + @R.function(private=True) def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object, storage2: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")): R.func_attr({"relax.force_pure": True}) cls = Expected @@ -193,7 +193,7 @@ def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T T.writes(compute[i0, i1]) compute[i0, i1] = T.exp(rxplaceholder[i0, i1]) - @R.function + @R.function(private=True) def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object): R.func_attr({"relax.force_pure": True}) storage: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32")) @@ -201,7 +201,7 @@ def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object): gv: R.Tuple(R.Object, R.Object) = (storage, storage1) return gv - @R.function + @R.function(private=True) def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")): R.func_attr({"relax.force_pure": True}) cls = Expected diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 7a8bcdee26ec..9305cdbcb129 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -203,7 +203,7 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): x = relax.Var("x", R.Tensor((128, 128), "float32")) bb = relax.BlockBuilder() - with bb.function("foo", (x,)): + with bb.function("foo", (x,), {"global_symbol": "foo"}): out = bb.emit_te(lambda x: x + 1, x, primfunc_name_hint="tir_func") bb.emit_func_output(out) @@ -232,7 +232,7 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): x = relax.Var("x", R.Tensor((128, 128), "float32")) bb = relax.BlockBuilder() - with bb.function("foo", (x,)): + with bb.function("foo", (x,), {"global_symbol": "foo"}): out = bb.emit_te( lambda x: x + 1, x, @@ -254,7 +254,7 @@ def main(x: R.Tensor((10, 20), "float32")) -> R.Tensor((10, 20), dtype="float32" bb = relax.BlockBuilder() x = relax.Var("x", relax.TensorStructInfo([10, 20], "float32")) - with bb.function("main", [x]): + with bb.function("main", [x], {"global_symbol": "main"}): lv1 = bb.emit_te(topi.add, x, x) out = bb.emit_te(topi.multiply, lv1, lv1) bb.emit_func_output(out) @@ -294,7 +294,7 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): x = relax.Var("x", R.Tensor((128, 128), "float32")) bb = relax.BlockBuilder() - with bb.function("foo", (x,)): + with bb.function("foo", (x,), {"global_symbol": "foo"}): out = bb.emit_te(lambda x: x + 1, x, primfunc_name_hint="tir_func") bb.emit_func_output(out) mod = bb.get() @@ -834,7 +834,7 @@ def foo(x: R.Tensor((), "float32")): def test_call_tir_empty_tuple_arg(): bb = relax.BlockBuilder() dummy_param = relax.Var("dummy_param", R.Tensor(())) - with bb.function("foo", [dummy_param]): + with bb.function("foo", [dummy_param], {"global_symbol": "foo"}): output = bb.emit_te(topi.full, shape=(16, 32), dtype="float32", fill_value=1.0) bb.emit_func_output(output) @@ -1493,5 +1493,22 @@ def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: _check(foo, bb.get()["foo"]) +def test_private_function(): + @I.ir_module + class Addition: + @R.function(private=True) + def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + y = R.add(x, x) + return y + + x = relax.Var("x", R.Tensor((), "int32")) + bb = relax.BlockBuilder() + with bb.function("main", (x), private=True): + y = bb.emit(R.add(x, x)) + bb.emit_func_output(y) + + _check(Addition, bb.get()) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_datatype.py b/tests/python/relax/test_tvmscript_parser_op_datatype.py index ec71e868d45b..85c5faa8667b 100644 --- a/tests/python/relax/test_tvmscript_parser_op_datatype.py +++ b/tests/python/relax/test_tvmscript_parser_op_datatype.py @@ -47,7 +47,7 @@ def expected(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "float16" gv = bb.emit(relax.op.astype(x, "float16")) bb.emit_func_output(gv) - _check(expected, bb.get()["main"]) + _check(expected.with_attr("global_symbol", "main"), bb.get()["main"]) if __name__ == "__main__": diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index 7525c63be440..c37694317361 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -27,12 +27,15 @@ def _assert_print(obj, expected): if not isinstance(obj, str): obj = obj.script(verbose_expr=True) obj = obj.strip() - assert obj == expected.strip(), "\n" + obj + # compare line by line in case there is trailing whitespace in the _middle_ + for obj_line, expected_line in zip(obj.splitlines(), expected.strip().splitlines()): + assert obj_line.strip() == expected_line.strip(), "\n" + obj def test_function(): @R.function def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore + R.func_attr({"some_attr": 1}) return a _assert_print( @@ -41,19 +44,39 @@ def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore # from tvm.script import relax as R @R.function +def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): + R.func_attr({"some_attr": 1}) + return a""", + ) + + +def test_lone_private_function(): + @R.function(private=True) + def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore + R.func_attr({"some_attr": 1}) + return a + + # name prints as main because without a global symbol, the printer cannot assume a name + _assert_print( + func, + """ +# from tvm.script import relax as R + +@R.function(private=True) def main(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): + R.func_attr({"some_attr": 1}) return a""", ) def test_extern_func(): @R.function - def relax_func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore + def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore return a obj = IRModule( { - "func": relax_func, + "func": func, "my_ext": relax.ExternFunc("my_ext"), } ) @@ -73,6 +96,40 @@ def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): ) +def test_nested_function(): + @I.ir_module + class NestedFunction: + @R.function + def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + @R.function + def nested(y: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + return y + + z = nested(x) + return z + + _assert_print( + NestedFunction, + """ +# from tvm.script import ir as I +# from tvm.script import relax as R + +@I.ir_module +class Module: + @R.function + def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + # from tvm.script import relax as R + + @R.function + def nested(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + return y + + z: R.Tensor((), dtype="int32") = nested(x) + return z +""", + ) + + def test_object_struct_info(): obj = relax.ObjectStructInfo() _assert_print( @@ -576,5 +633,101 @@ def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): ) +def test_private_function(): + @I.ir_module + class AddMod: + @R.function(private=True) + def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + y: R.Tensor((), dtype="int32") = R.add(x, x) + return y + + _assert_print( + AddMod, + """ +# from tvm.script import ir as I +# from tvm.script import relax as R + +@I.ir_module +class Module: + @R.function(private=True) + def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + y: R.Tensor((), dtype="int32") = R.add(x, x) + return y +""", + ) + + +def test_directly_construct_private_funcs(): + # public + @R.function + def foo(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + y: R.Tensor((), dtype="int32") = R.add(x, x) + return y + + # private + @R.function(private=True) + def bar(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + y: R.Tensor((), dtype="int32") = R.multiply(x, x) + return y + + # public but there's another attribute + @R.function + def baz(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + R.func_attr({"relax.force_pure": True}) + y: R.Tuple = R.print(format="Hi there!") + z: R.Tensor((), dtype="int32") = R.add(x, x) + return z + + # private with an attribute + @R.function(private=True) + def quux(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + R.func_attr({"relax.force_pure": True}) + y: R.Tuple = R.print(format="Lol") + z: R.Tensor((), dtype="int32") = R.multiply(x, x) + return z + + obj = IRModule( + { + "foo": foo, + "bar": bar, + "baz": baz, + "quux": quux, + } + ) + _assert_print( + obj, + """ +# from tvm.script import ir as I +# from tvm.script import relax as R + +@I.ir_module +class Module: + @R.function(private=True) + def bar(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + y: R.Tensor((), dtype="int32") = R.multiply(x, x) + return y + + @R.function + def baz(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + R.func_attr({"relax.force_pure": 1}) + y: R.Tuple = R.print(format=R.str("Hi there!")) + z: R.Tensor((), dtype="int32") = R.add(x, x) + return z + + @R.function + def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + y: R.Tensor((), dtype="int32") = R.add(x, x) + return y + + @R.function(private=True) + def quux(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + R.func_attr({"relax.force_pure": 1}) + y: R.Tuple = R.print(format=R.str("Lol")) + z: R.Tensor((), dtype="int32") = R.multiply(x, x) + return z +""", + ) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py index c55876a3ba2d..f0c4ae0bd2a3 100644 --- a/tests/python/relax/test_utils.py +++ b/tests/python/relax/test_utils.py @@ -71,7 +71,9 @@ def func_copied(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")): gv = R.add(x, y) return gv - Actual["func_copied"] = relax.utils.copy_with_new_vars(Actual["func"]) + Actual["func_copied"] = relax.utils.copy_with_new_vars(Actual["func"]).with_attr( + "global_symbol", "func_copied" + ) # Assertion will fail if the f_copied contains the same VarNode that's used in # the original function, due to var mapping during structural equal. @@ -113,7 +115,9 @@ def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,), dtype="float32"): gv = R.add(x, y) return gv - Actual["func_copied"] = relax.utils.copy_with_new_vars(Actual["func"]) + Actual["func_copied"] = relax.utils.copy_with_new_vars(Actual["func"]).with_attr( + "global_symbol", "func_copied" + ) assert_structural_equal(Actual, Expected)