diff --git a/flax/nnx/module.py b/flax/nnx/module.py index 795bb9a08..05e14988a 100644 --- a/flax/nnx/module.py +++ b/flax/nnx/module.py @@ -268,8 +268,8 @@ def set_attributes( raise_if_not_found: bool = True, **attributes: tp.Any, ) -> None: - """Sets the attributes of nested Modules including the current Module. - If the attribute is not found in the Module, it is ignored. + """Sets the attributes of nested :class:`flax.nnx.Module`'s including the current + ``nnx.Module``. If the attribute is not found in the ``nnx.Module``, it is ignored. Example:: @@ -288,7 +288,8 @@ def set_attributes( >>> block.dropout.deterministic, block.batch_norm.use_running_average (True, True) - ``Filter``'s can be used to set the attributes of specific Modules:: + ``Filter``'s (``flax.nnx.filterlib``) can be used to set the attributes of specific + ``nnx.Module``'s:: >>> block = Block(2, 5, rngs=nnx.Rngs(0)) >>> block.set_attributes(nnx.Dropout, deterministic=True) @@ -297,8 +298,8 @@ def set_attributes( (True, False) Args: - *filters: Filters to select the Modules to set the attributes of. - raise_if_not_found: If True (default), raises a ValueError if at least one attribute + *filters: NNX ``Filter``'s to select the :class:`flax.nnx.Module`'s whose attributes will be to be set. + raise_if_not_found: If ``True`` (default), raises a ValueError if at least one attribute instance is not found in one of the selected Modules. **attributes: The attributes to set. """