diff --git a/nemo/core/classes/mixins/adapter_mixins.py b/nemo/core/classes/mixins/adapter_mixins.py index 3d789be7dc61..557c43fb5813 100644 --- a/nemo/core/classes/mixins/adapter_mixins.py +++ b/nemo/core/classes/mixins/adapter_mixins.py @@ -143,7 +143,9 @@ class AdapterModuleMixin(ABC): - `adapter_metadata_cfg_key`: A str representing a key in the model config that is used to preserve the metadata of the adapter config. - **Note**: This module is **not** responsible for maintaining its config. Subclasses must ensure config is updated + .. note:: + + This module is **not** responsible for maintaining its config. Subclasses must ensure config is updated or preserved as needed. It is the responsibility of the subclasses to propagate the most up to date config to lower layers. """ @@ -435,8 +437,6 @@ def forward_enabled_adapters(self, input: 'torch.Tensor'): Utilizes the implicit merge strategy of each adapter when computing the adapter's output, and how that output will be merged back with the original input. - **Note**: - Args: input: The output tensor of the calling module is the input to the first adapter, whose output is then chained to the next adapter until all adapters are consumed. @@ -519,7 +519,9 @@ def forward_single_enabled_adapter_( """ Perform the forward step of a single adapter module on some input data. - **Note**: Subclasses can override this method to accommodate more complicate adapter forward steps. + .. note:: + + Subclasses can override this method to accommodate more complicated adapter forward steps. Args: input: input: The output tensor of the calling module is the input to the first adapter, whose output @@ -756,8 +758,10 @@ def save_adapters(self, filepath: str, name: str = None): Utility method that saves only the adapter module(s), and not the entire model itself. This allows the sharing of adapters which are often just a fraction of the size of the full model, enabling easier deliver. + + .. note::
- Note: The saved file is a pytorch compatible pickle file, containing the state dicts of the adapter(s), + The saved file is a pytorch compatible pickle file, containing the state dicts of the adapter(s), as well as a binary representation of the adapter config. Args: @@ -835,7 +839,9 @@ def load_adapters(self, filepath: str, name: str = None, map_location: str = Non This allows the sharing of adapters which are often just a fraction of the size of the full model, enabling easier deliver. - Note: During restoration, assumes that the model does not currently already have an adapter with + .. note:: + + During restoration, assumes that the model does not currently already have an adapter with the name (if provided), or any adapter that shares a name with the state dict's modules (if name is not provided). This is to ensure that each adapter name is globally unique in a model. @@ -964,7 +970,9 @@ def adapter_module_names(self) -> List[str]: """ List of valid adapter modules that are supported by the model. - **Note**: Subclasses should override this property and return a list of str names, of all the modules + .. note:: + + Subclasses should override this property and return a list of str names, of all the modules that they support, which will enable users to determine where to place the adapter modules. Returns: