Skip to content

Commit 819b002

Browse files
authored
[Relax] Support nested ModuleList in nn.Module (#16971)
1 parent 28d32b5 commit 819b002

File tree

2 files changed

+24
-6
lines changed

2 files changed

+24
-6
lines changed

python/tvm/relax/frontend/nn/core.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -607,16 +607,19 @@ def wrap_nested(expr: rx.Expr, name: str) -> Union[Tensor, Sequence[Tensor]]:
607607

608608
def _attribute_finder(root: Module, prefix: str, condition_yield: Callable[[Any], bool]):
609609
"""Find attributes that satisfy the condition recursively"""
610+
if isinstance(root, ModuleList):
611+
for i, subitem in enumerate(root):
612+
yield from _attribute_finder(subitem, prefix + f"{i}.", condition_yield)
613+
return
610614
for name, item in root.__dict__.items():
611615
if condition_yield(item):
612616
yield prefix + name, item
613617
elif isinstance(item, ModuleList):
614-
for i, subitem in enumerate(item):
615-
yield from _attribute_finder(
616-
subitem,
617-
prefix + name + f".{i}.",
618-
condition_yield,
619-
)
618+
yield from _attribute_finder(
619+
item,
620+
prefix + name + ".",
621+
condition_yield,
622+
)
620623
elif isinstance(item, Module):
621624
yield from _attribute_finder(
622625
item,

tests/python/relax/test_frontend_nn_modules.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,5 +700,20 @@ def forward(x: R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tensor((10, 5), dty
700700
assert_structural_equal(tvm_mod["forward"], forward)
701701

702702

703+
def test_module_list():
704+
class Module(nn.Module):
705+
def __init__(self):
706+
self.layers = nn.ModuleList(
707+
[nn.ModuleList([nn.Linear(4, 4, bias=False) for _ in range(2)]) for _ in range(1)]
708+
)
709+
710+
def forward(self, x: nn.Tensor):
711+
return self.layers(x)
712+
713+
mod = Module()
714+
named_params = dict(mod.named_parameters())
715+
assert ["layers.0.0.weight", "layers.0.1.weight"] == sorted(list(named_params.keys()))
716+
717+
703718
if __name__ == "__main__":
704719
tvm.testing.main()

0 commit comments

Comments
 (0)