Canary Adapters tutorial (#9670) #9777

Merged · 1 commit · Jul 18, 2024
4 changes: 2 additions & 2 deletions nemo/collections/asr/models/aed_multitask_models.py
@@ -434,7 +434,7 @@ def change_prompt(
prompt_cls = PromptFormatter.resolve(self.prompt_format)
self.prompt = prompt_cls(
tokenizer=self.tokenizer,
-            defaults=OmegaConf.to_container(pd) if (pd := self.cfg.prompt_defaults) is not None else None,
+            defaults=OmegaConf.to_container(pd) if (pd := self.cfg.get('prompt_defaults')) is not None else None,
)

# Update config
@@ -979,7 +979,7 @@ def _transcribe_on_end(self, trcfg: MultiTaskTranscriptionConfig):
"""
super()._transcribe_on_end(trcfg)

-        self.transf_decoder.unfreeze()
+        self.transf_decoder.unfreeze(partial=True)

def _may_be_make_dict_and_fix_paths(self, json_items, manifest_path, trcfg: MultiTaskTranscriptionConfig):
"""
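The switch from attribute access to `self.cfg.get('prompt_defaults')` lets the prompt setup tolerate configs that simply lack a `prompt_defaults` key (for example, older checkpoints). A minimal sketch of the difference with a standalone OmegaConf config; the keys below are illustrative, not taken from a real checkpoint:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({"prompt_format": "canary"})  # 'prompt_defaults' is absent

# .get() returns None for a missing key, so the walrus expression falls
# through to None instead of raising.
defaults = OmegaConf.to_container(pd) if (pd := cfg.get("prompt_defaults")) is not None else None
print(defaults)  # -> None

# Direct attribute access (cfg.prompt_defaults) can instead raise on configs
# where the key is missing, notably when struct mode is enabled.
```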
14 changes: 0 additions & 14 deletions nemo/collections/asr/models/ctc_models.py
@@ -665,20 +665,6 @@ def test_dataloader(self):

""" Transcription related methods """

-    def _transcribe_on_begin(self, audio, trcfg: TranscribeConfig):
-        super()._transcribe_on_begin(audio, trcfg)
-
-        # Freeze the encoder and decoder modules
-        self.encoder.freeze()
-        self.decoder.freeze()
-
-    def _transcribe_on_end(self, trcfg: TranscribeConfig):
-        super()._transcribe_on_end(trcfg)
-
-        # Unfreeze the encoder and decoder modules
-        self.encoder.unfreeze()
-        self.decoder.unfreeze()

def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig):
logits, logits_len, greedy_predictions = self.forward(input_signal=batch[0], input_signal_length=batch[1])
output = dict(logits=logits, logits_len=logits_len)
2 changes: 1 addition & 1 deletion nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
@@ -157,7 +157,7 @@ def _transcribe_on_end(self, trcfg: TranscribeConfig):
super()._transcribe_on_end(trcfg)

if hasattr(self, 'ctc_decoder'):
-            self.ctc_decoder.unfreeze()
+            self.ctc_decoder.unfreeze(partial=True)

def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig):
if self.cur_decoder == "rnnt":
2 changes: 1 addition & 1 deletion nemo/collections/asr/models/transformer_bpe_models.py
@@ -633,4 +633,4 @@ def _transcribe_on_end(self, trcfg: TranscribeConfig):
super()._transcribe_on_end(trcfg)

# Unfreeze the encoder and decoder modules
-        self.transf_decoder.unfreeze()
+        self.transf_decoder.unfreeze(partial=True)
6 changes: 3 additions & 3 deletions nemo/collections/asr/parts/mixins/transcription.py
@@ -770,13 +770,13 @@ def _transcribe_on_end(self, trcfg: TranscribeConfig):

# Unfreeze the encoder and decoder modules
if hasattr(self, 'encoder'):
-            self.encoder.unfreeze()
+            self.encoder.unfreeze(partial=True)

if hasattr(self, 'decoder'):
-            self.decoder.unfreeze()
+            self.decoder.unfreeze(partial=True)

if hasattr(self, 'joint'):
-            self.joint.unfreeze()
+            self.joint.unfreeze(partial=True)

@classmethod
def get_transcribe_config(cls) -> TranscribeConfig:
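The practical effect of `partial=True` in these hooks is that a module a user froze on purpose before calling `transcribe()` (for example an encoder kept frozen while training adapters) is not silently made trainable again afterwards. A hedged sketch of that workflow; the checkpoint name and audio path are placeholders, and any ASR model built on this transcription mixin should behave the same way:

```python
import nemo.collections.asr as nemo_asr

# Placeholder checkpoint and audio path, for illustration only.
model = nemo_asr.models.ASRModel.from_pretrained("nvidia/canary-1b")
model.encoder.freeze()  # deliberately frozen, e.g. for adapter fine-tuning

_ = model.transcribe(["/path/to/audio.wav"])  # mixin freezes modules, then calls unfreeze(partial=True)

# The encoder stays frozen: unfreeze(partial=True) restores the requires_grad
# states recorded by freeze(), and those were already False here.
assert all(not p.requires_grad for p in model.encoder.parameters())
```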
5 changes: 4 additions & 1 deletion nemo/collections/common/prompts/formatter.py
@@ -25,6 +25,9 @@ class BaseModalityType:
def matches(value: Any) -> bool:
raise NotImplementedError

+    def __repr__(self):
+        return f"Modality.{self.__class__.__name__}()"


class Text(BaseModalityType):
"""Modality for text values."""
@@ -42,7 +45,7 @@ def matches(self, value: str) -> bool:
return isinstance(value, str) and value in self.allowed_values

def __repr__(self):
return f"{self.__class__.__name__}({self.allowed_values})"
return f"Modality.{self.__class__.__name__}(allowed_values={self.allowed_values})"


class Modality:
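The updated `__repr__` methods make prompt-slot definitions self-describing when printed or logged. A self-contained illustration of the pattern; these are stand-in classes mirroring the diff, not the actual NeMo ones, and the `TextLiteral` constructor signature is assumed:

```python
from typing import Any


class BaseModalityType:
    @staticmethod
    def matches(value: Any) -> bool:
        raise NotImplementedError

    def __repr__(self):
        return f"Modality.{self.__class__.__name__}()"


class TextLiteral(BaseModalityType):
    def __init__(self, *allowed_values):
        self.allowed_values = list(allowed_values)

    def matches(self, value: str) -> bool:
        return isinstance(value, str) and value in self.allowed_values

    def __repr__(self):
        return f"Modality.{self.__class__.__name__}(allowed_values={self.allowed_values})"


print(TextLiteral("en", "de"))  # Modality.TextLiteral(allowed_values=['en', 'de'])
```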
99 changes: 86 additions & 13 deletions nemo/core/classes/module.py
@@ -18,6 +18,7 @@
from torch.nn import Module

from nemo.core.classes.common import FileIO, Serialization, Typing
+from nemo.utils import logging

__all__ = ['NeuralModule']

@@ -54,39 +55,111 @@ def input_example(self, max_batch=None, max_dim=None):
def freeze(self) -> None:
r"""
Freeze all params for inference.

This method sets `requires_grad` to False for all parameters of the module.
It also stores the original `requires_grad` state of each parameter in a dictionary,
        so that `unfreeze(partial=True)` can restore the original state.
"""
-        for param in self.parameters():
grad_map = {}

for pname, param in self.named_parameters():
# Store the original grad state
grad_map[pname] = param.requires_grad
# Freeze the parameter
param.requires_grad = False

# Store the frozen grad map
if not hasattr(self, '_frozen_grad_map'):
self._frozen_grad_map = grad_map
else:
self._frozen_grad_map.update(grad_map)

self.eval()

-    def unfreeze(self) -> None:
+    def unfreeze(self, partial: bool = False) -> None:
"""
Unfreeze all parameters for training.

        Allows for either a total unfreeze or a partial unfreeze (if the module was explicitly frozen previously with `freeze()`).
        The `partial` argument determines whether to unfreeze all parameters or only the parameters that were
        trainable (unfrozen) before `freeze()` was called.

Example:
Consider a model that has an encoder and a decoder module. Assume we want the encoder to be frozen always.

```python
model.encoder.freeze() # Freezes all parameters in the encoder explicitly
```

During inference, all parameters of the model should be frozen - we do this by calling the model's freeze method.
This step records that the encoder module parameters were already frozen, and so if partial unfreeze is called,
we should keep the encoder parameters frozen.

```python
model.freeze() # Freezes all parameters in the model; encoder remains frozen
```

Now, during fine-tuning, we want to unfreeze the decoder but keep the encoder frozen. We can do this by calling
`unfreeze(partial=True)`.

```python
model.unfreeze(partial=True) # Unfreezes only the decoder; encoder remains frozen
```

Args:
            partial: If True, only unfreeze parameters that were trainable before `freeze()` was called. A parameter that
                was already frozen when `freeze()` was called remains frozen after calling `unfreeze(partial=True)`.
"""
-        for param in self.parameters():
-            param.requires_grad = True
if partial and not hasattr(self, '_frozen_grad_map'):
raise ValueError("Cannot unfreeze partially without first freezing the module with `freeze()`")

for pname, param in self.named_parameters():
if not partial:
# Unfreeze all parameters
param.requires_grad = True
else:
# Unfreeze only parameters that were previously frozen

# Check if the parameter was frozen
if pname in self._frozen_grad_map:
param.requires_grad = self._frozen_grad_map[pname]
else:
# Log a warning if the parameter was not found in the frozen grad map
logging.warning(
f"Parameter {pname} not found in list of previously frozen parameters. "
f"Unfreezing this parameter."
)
param.requires_grad = True

# Clean up the frozen grad map
if hasattr(self, '_frozen_grad_map'):
delattr(self, '_frozen_grad_map')

self.train()

@contextmanager
def as_frozen(self):
"""
-        Context manager which temporarily freezes a module, yields control and finally unfreezes the module.
        Context manager which temporarily freezes a module, yields control, and finally unfreezes the module
        partially to restore its original state.

        On exit, the module is unfrozen with `unfreeze(partial=True)`, so each parameter's original `requires_grad`
        state is restored; parameters that were already frozen before entering the context remain frozen.

        Example:
            with model.as_frozen():
                # Do something with the model
                pass

            # The parameters' requires_grad states are now back to their original values
"""
training_mode = self.training
-        grad_map = {}
-        for pname, param in self.named_parameters():
-            grad_map[pname] = param.requires_grad

self.freeze()
try:
yield
finally:
-            self.unfreeze()

-            for pname, param in self.named_parameters():
-                param.requires_grad = grad_map[pname]
+            self.unfreeze(partial=True)

if training_mode:
self.train()
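Taken together, the new freeze bookkeeping can be sketched with a toy module, mirroring the unit test added below; this snippet is illustrative and not part of the PR:

```python
import torch

from nemo.core.classes.module import NeuralModule


class Toy(NeuralModule):
    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(10, 10, bias=False)
        self.layer2 = torch.nn.Linear(10, 10, bias=False)


m = Toy()
m.layer1.weight.requires_grad = False  # layer1 deliberately frozen up front

m.freeze()                # records each parameter's requires_grad state, then freezes everything
m.unfreeze(partial=True)  # restores the recorded states instead of unfreezing blindly

assert m.layer1.weight.requires_grad is False  # stayed frozen
assert m.layer2.weight.requires_grad is True   # trainable again

with m.as_frozen():       # temporary freeze; original states are restored on exit
    assert not any(p.requires_grad for p in m.parameters())

assert m.layer1.weight.requires_grad is False
assert m.layer2.weight.requires_grad is True
```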
89 changes: 89 additions & 0 deletions tests/core/test_neural_module.py
@@ -0,0 +1,89 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile

import pytest
import torch

from nemo.core.classes.module import NeuralModule


class TempModule(NeuralModule):

def __init__(self):
super().__init__()

self.layer1 = torch.nn.Linear(10, 10, bias=False)
self.layer2 = torch.nn.Linear(10, 10, bias=False)


class TestNeuralModule:

@pytest.mark.unit
def test_num_weights(self):
module = TempModule()
assert module.num_weights == 200

@pytest.mark.unit
def test_freeze(self):
module = TempModule()
module.freeze()
for p in module.parameters():
assert not p.requires_grad

@pytest.mark.unit
def test_unfreeze(self):
module = TempModule()
module.freeze()
module.unfreeze()
for p in module.parameters():
assert p.requires_grad

@pytest.mark.unit
def test_as_frozen(self):
module = TempModule()

for p in module.parameters():
assert p.requires_grad

with module.as_frozen():
for p in module.parameters():
assert not p.requires_grad

for p in module.parameters():
assert p.requires_grad

@pytest.mark.unit
def test_partial_unfreeze(self):
module = TempModule()

for param in module.layer1.parameters():
param.requires_grad = False

module.freeze()

for param in module.layer1.parameters():
assert not param.requires_grad

assert module._frozen_grad_map is not None
assert len(module._frozen_grad_map) == 2
assert module._frozen_grad_map['layer1.weight'] is False

module.unfreeze(partial=True)

# layer1 should still be frozen due to partial unfreeze
assert module.layer1.weight.requires_grad is False
assert not hasattr(module, '_frozen_grad_map')