Skip to content

Commit

Permalink
Merge pull request #11 from dantp-ai/solution/task0_4
Browse files Browse the repository at this point in the history
Solution/task0 4
  • Loading branch information
dantp-ai authored Mar 19, 2024
2 parents 35b1edd + bdc90c5 commit b41dc13
Showing 1 changed file with 39 additions and 8 deletions.
47 changes: 39 additions & 8 deletions minitorch/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,24 @@ def modules(self) -> Sequence[Module]:

def train(self) -> None:
"Set the mode of this module and all descendent modules to `train`."
# TODO: Implement for Task 0.4.
raise NotImplementedError("Need to implement for Task 0.4")
self._set_mode(self, train=True)

def _set_mode(self, root: Module, train: bool = False) -> None:
if not root:
return

if train:
root.training = True
else:
root.training = False
modules = getattr(root, "_modules", None)
if modules is not None:
for _, v in modules.items():
self._set_mode(v, train)

def eval(self) -> None:
"Set the mode of this module and all descendent modules to `eval`."
# TODO: Implement for Task 0.4.
raise NotImplementedError("Need to implement for Task 0.4")
self._set_mode(self, train=False)

def named_parameters(self) -> Sequence[Tuple[str, Parameter]]:
"""
Expand All @@ -47,13 +58,33 @@ def named_parameters(self) -> Sequence[Tuple[str, Parameter]]:
Returns:
The name and `Parameter` of each ancestor parameter.
"""
# TODO: Implement for Task 0.4.
raise NotImplementedError("Need to implement for Task 0.4")
return self._traverse_tree(self, named=True)

def _traverse_tree(
self, root: Module, named: bool = False, path: str = ""
) -> Sequence[Tuple[str, Parameter] | Parameter]:
result = []
if not root:
return result

parameters = getattr(root, "_parameters", None)
if parameters is not None:
for k, p in parameters.items():
if named:
result.append((f"{path}.{k}" if path else f"{k}", p))
else:
result.append(p)
modules = getattr(root, "_modules", None)
if modules is not None:
for name, v in modules.items():
path_sofar = f"{path}.{name}" if path else name
result.extend(self._traverse_tree(v, named=named, path=path_sofar))

return result

def parameters(self) -> Sequence[Parameter]:
"Enumerate over all the parameters of this module and its descendents."
# TODO: Implement for Task 0.4.
raise NotImplementedError("Need to implement for Task 0.4")
return self._traverse_tree(self)

def add_parameter(self, k: str, v: Any) -> Parameter:
"""
Expand Down

0 comments on commit b41dc13

Please sign in to comment.