diff --git a/minitorch/module.py b/minitorch/module.py index 8f17cfb..78e391d 100644 --- a/minitorch/module.py +++ b/minitorch/module.py @@ -31,13 +31,24 @@ def modules(self) -> Sequence[Module]: def train(self) -> None: "Set the mode of this module and all descendent modules to `train`." - # TODO: Implement for Task 0.4. - raise NotImplementedError("Need to implement for Task 0.4") + self._set_mode(self, train=True) + + def _set_mode(self, root: Module, train: bool = False) -> None: + if not root: + return + + if train: + root.training = True + else: + root.training = False + modules = getattr(root, "_modules", None) + if modules is not None: + for _, v in modules.items(): + self._set_mode(v, train) def eval(self) -> None: "Set the mode of this module and all descendent modules to `eval`." - # TODO: Implement for Task 0.4. - raise NotImplementedError("Need to implement for Task 0.4") + self._set_mode(self, train=False) def named_parameters(self) -> Sequence[Tuple[str, Parameter]]: """ @@ -47,13 +58,33 @@ def named_parameters(self) -> Sequence[Tuple[str, Parameter]]: Returns: The name and `Parameter` of each ancestor parameter. """ - # TODO: Implement for Task 0.4. - raise NotImplementedError("Need to implement for Task 0.4") + return self._traverse_tree(self, named=True) + + def _traverse_tree( + self, root: Module, named: bool = False, path: str = "" + ) -> Sequence[Tuple[str, Parameter] | Parameter]: + result = [] + if not root: + return result + + parameters = getattr(root, "_parameters", None) + if parameters is not None: + for k, p in parameters.items(): + if named: + result.append((f"{path}.{k}" if path else f"{k}", p)) + else: + result.append(p) + modules = getattr(root, "_modules", None) + if modules is not None: + for name, v in modules.items(): + path_sofar = f"{path}.{name}" if path else name + result.extend(self._traverse_tree(v, named=named, path=path_sofar)) + + return result def parameters(self) -> Sequence[Parameter]: "Enumerate over all the parameters of this module and its descendents." - # TODO: Implement for Task 0.4. - raise NotImplementedError("Need to implement for Task 0.4") + return self._traverse_tree(self) def add_parameter(self, k: str, v: Any) -> Parameter: """