Skip to content

Commit b91d4e5

Browse files
authored
[TVMScript] Produce empty DictAttrs when R.func_attrs is absent (#16844)
A follow-up to #16745. For Relax functions produced in TVMScript, when `R.func_attrs` was not present, the default was set to `None` instead of an empty dictionary.
1 parent b01de08 commit b91d4e5

File tree

4 files changed

+31
-2
lines changed

4 files changed

+31
-2
lines changed

src/relax/ir/expr.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,10 @@ TVM_REGISTER_NODE_TYPE(FunctionNode);
493493

494494
Function::Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info, bool is_pure,
495495
DictAttrs attrs, Span span) {
496+
if (!attrs.defined()) {
497+
attrs = DictAttrs();
498+
}
499+
496500
// Set the function type.
497501
// For function, we take a conservative approach and require the function type
498502
// to be known at construction time.

src/script/ir_builder/relax/frame.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,12 @@ void FunctionFrameNode::ExitWithScope() {
6161
!attrs.count(tvm::attr::kGlobalSymbol)) {
6262
attrs.Set(tvm::attr::kGlobalSymbol, name.value());
6363
}
64-
auto dict_attrs = attrs.empty() ? NullValue<DictAttrs>() : DictAttrs(attrs);
6564
this->block_builder->EndScope();
6665
tvm::relax::Function func(/*params=*/params,
6766
/*body=*/body,
6867
/*ret_struct_info=*/ret_struct_info,
6968
/*is_pure=*/is_pure.value_or(Bool(true))->value,
70-
/*attrs=*/dict_attrs);
69+
/*attrs=*/DictAttrs(attrs));
7170
// Step 2: Update IRModule.
7271
if (builder->frames.empty()) {
7372
// Case 0. No outer frame, return function directly

src/tir/ir/function.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ relax::StructInfo InferStructInfo(const PrimFunc& prim_func) {
7070
// Get the function type of a PrimFunc
7171
PrimFunc::PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type,
7272
Map<tir::Var, Buffer> buffer_map, DictAttrs attrs, Span span) {
73+
if (!attrs.defined()) {
74+
attrs = DictAttrs();
75+
}
76+
7377
// Assume void-return type for now
7478
// TODO(tvm-team) consider type deduction from body.
7579
if (!ret_type.defined()) {

tests/python/relax/test_tvmscript_parser.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2271,5 +2271,27 @@ def main(A: R.Tensor, B: R.Tensor):
22712271
tvm.ir.assert_structural_equal(DefinedAllAtOnce, MainDefinedLater)
22722272

22732273

2274+
def test_function_attributes_are_defined():
2275+
"""func.attrs defaults to an empty DictAttrs"""
2276+
2277+
@I.ir_module
2278+
class Module:
2279+
@R.function
2280+
def main(x: R.Tensor, shape: R.Shape(["m", "n"])):
2281+
output = Module.subroutine(x, shape)
2282+
return output
2283+
2284+
@R.function
2285+
def subroutine(x: R.Tensor, _: R.Shape(["m", "n"])) -> R.Tensor(["m", "n"]):
2286+
q = x
2287+
m, n = T.int64(), T.int64()
2288+
z = R.match_cast(q, R.Tensor((m, n)))
2289+
w = z
2290+
return w
2291+
2292+
for gvar, func in Module.functions.items():
2293+
assert func.attrs is not None
2294+
2295+
22742296
if __name__ == "__main__":
22752297
tvm.testing.main()

0 commit comments

Comments
 (0)