activation-level distillation #388
base: main
Changes from 32 commits
First changed file (block configuration module):

```diff
@@ -1,4 +1,5 @@
 import functools
+import logging
 import typing
 import warnings

@@ -8,12 +9,14 @@
 from fast_llm.engine.config_utils.tensor_dim import TensorDim
 from fast_llm.engine.distributed.config import DistributedConfig
 from fast_llm.layers.common.peft.config import PeftConfig
-from fast_llm.utils import Assert
+from fast_llm.utils import Assert, log

 if typing.TYPE_CHECKING:
     from fast_llm.layers.block.block import BlockBase
     from fast_llm.layers.block.sequence import FixedBlockSequence, PatternBlockSequence

+logger = logging.getLogger(__name__)


 class BlockDimNames:
     # A set of common tensor dim names packed into a namespace.

@@ -37,6 +40,8 @@ class BlockKwargs:
     sequence_lengths = "sequence_lengths"
     # TODO: Belongs elsewhere?
     grad_output = "grad_output"
+    activation_distillation_storage = "activation_distillation_storage"
+    activation_distillation_targets = "activation_distillation_targets"
     iteration = "iteration"
     device = "device"
     hidden_states = "hidden_states"
```
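The two new kwargs keys act as a hand-off channel between a teacher and a student forward pass. A minimal sketch of the intended flow, where the runner wiring is hypothetical and only the key names come from this diff:

```python
# Hypothetical wiring; only the kwarg names come from this PR.
teacher_kwargs = {"activation_distillation_storage": {}}
# Teacher forward: each decoder block writes its detached mixer output under its module name.
# teacher_model.forward(batch, teacher_kwargs)

# The filled storage is then handed to the student as its target dict.
student_kwargs = {
    "activation_distillation_targets": teacher_kwargs["activation_distillation_storage"]
}
# Student forward: each block pops its own entry and adds an L2 penalty on the difference.
# student_model.forward(batch, student_kwargs)
```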
```diff
@@ -87,6 +92,9 @@ def get_layer(
             peft=peft,
         )

+    def get_distillation_models(self) -> set[str]:
+        return set()


 @config_class(registry=True)
 class BlockSequenceConfig(BlockConfig):

@@ -118,6 +126,9 @@ def layer_class(self) -> "type[FixedBlockSequence]":

         return FixedBlockSequence

+    def get_distillation_models(self) -> set[str]:
+        return self.block.get_distillation_models()


 @config_class(dynamic_type={BlockSequenceConfig: "pattern"})
 class PatternBlockSequenceConfig(BlockSequenceConfig):

@@ -164,3 +175,21 @@ def expanded_pattern(self) -> list[str]:
     def preprocessing_layers(self) -> dict[str, int]:
         # The index at which each block first appears. These blocks are used for preprocessing.
         return {name: self.expanded_pattern.index(name) for name in set(self.expanded_pattern)}

+    def get_distillation_models(self) -> set[str]:
+        models = set()
+        for block in self.blocks.values():
+            models.update(block.get_distillation_models())
+        return models
+
+    @classmethod
+    def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self:
+        # Patch creeping type parameters from pretrained model
+        # TODO: fix this
+        if "block" in default:
+            removed = default.pop("block")
+            log(
+                f"Removing 'block' from default dict in PatternBlockSequenceConfig._from_dict: {removed}",
+                log_fn=logger.warning,
+            )
+        return super()._from_dict(default, strict=strict)
```
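As a quick illustration of the aggregation above, a pattern sequence simply unions whatever its blocks report. The toy classes below are stand-ins, not the Fast-LLM config objects:

```python
# Toy stand-ins for the config classes above; names are illustrative only.
class ToyBlock:
    def __init__(self, distillation_model: str | None = None):
        self.distillation_model = distillation_model

    def get_distillation_models(self) -> set[str]:
        return set() if self.distillation_model is None else {self.distillation_model}


class ToyPatternSequence:
    def __init__(self, blocks: dict[str, ToyBlock]):
        self.blocks = blocks

    def get_distillation_models(self) -> set[str]:
        models = set()
        for block in self.blocks.values():
            models.update(block.get_distillation_models())
        return models


seq = ToyPatternSequence({"a": ToyBlock("teacher"), "b": ToyBlock(None)})
print(seq.get_distillation_models())  # {'teacher'}
```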
Second changed file (decoder block forward pass):

```diff
@@ -4,16 +4,19 @@

 import torch

-from fast_llm.core.distributed import set_generator
+from fast_llm.core.distributed import ReduceOp, all_reduce, set_generator
 from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig
 from fast_llm.engine.config_utils.tensor_dim import TensorDim
 from fast_llm.engine.distributed.config import DistributedConfig
 from fast_llm.engine.distributed.distributed import Distributed
 from fast_llm.layers.block.block import Block
 from fast_llm.layers.block.config import BlockKwargs
+from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss
 from fast_llm.layers.common.peft.config import PeftConfig
 from fast_llm.layers.decoder.config import BlockWithBiasConfig, DecoderBlockConfig
+from fast_llm.layers.language_model.head import _format_name
 from fast_llm.tensor import TensorMeta
 from fast_llm.utils import Assert

 logger = logging.getLogger(__name__)


@@ -134,6 +137,9 @@ def forward(
         hidden_states = self.norm_1(input_)
         self._debug(hidden_states, "norm_1", kwargs.get(BlockKwargs.hidden_dims), kwargs)
         hidden_states, bias = self.mixer(hidden_states, kwargs)

+        hidden_states, bias = self.activation_distillation_loss(hidden_states, bias, kwargs, losses)

         with set_generator(generator):
             input_ = self._bias_dropout_add(hidden_states, bias, input_)
             self._debug(input_, "mixer_residual", kwargs.get(BlockKwargs.hidden_dims), kwargs)

@@ -148,6 +154,51 @@ def forward(
             hidden_states = torch.stack((fw_input, hidden_states), dim=0)
         return hidden_states

+    def activation_distillation_loss(self, hidden_states, bias, kwargs, losses):
+        """
+        Maybe apply activation distillation loss and setup backward hooks
+        """
+        mixer_output = hidden_states if bias is None else hidden_states + bias
```
**Collaborator** (on the `mixer_output` line): This should only be evaluated if needed.
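One way to act on that comment would be to materialize the sum only when a storage or target dict is actually present. The helper below is a sketch, not part of the PR; `_maybe_mixer_output` is a hypothetical name and the string keys mirror the `BlockKwargs` values added above:

```python
def _maybe_mixer_output(hidden_states, bias, kwargs):
    # Only materialize the (possibly large) tensor addition when distillation is active.
    needs_output = (
        kwargs.get("activation_distillation_storage") is not None
        or kwargs.get("activation_distillation_targets") is not None
    )
    if not needs_output:
        return None
    return hidden_states if bias is None else hidden_states + bias
```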
```diff
+        # Teacher populates mixer activations for distillation.
+        activation_storage = kwargs.get(BlockKwargs.activation_distillation_storage)
+        if activation_storage is not None:
+            activation_storage[self.module_name] = mixer_output.detach()
+        # Student gets teacher activations and computes the activation-level loss.
+        activation_targets = kwargs.get(BlockKwargs.activation_distillation_targets)
+        if (
+            activation_targets is not None
+            and self.training
+            and (teacher_output := activation_targets.pop(self.module_name, None)) is not None
+        ):
+            # Compare student mixer output with the teacher's stored activation and accumulate the loss.
+            teacher_tensor = teacher_output.detach().to(device=mixer_output.device, dtype=mixer_output.dtype)
+            Assert.eq(teacher_tensor.shape, mixer_output.shape)
+            # TODO: un-scaled loss for reporting? Average loss over layers?
+            # L2 loss
+            activation_loss_factor = self._config.activation_distillation_factor
+            # (batch, sequence, hidden) or (sequence, batch, hidden). Take the norm over hidden dim.
+            # TODO: handle possible padding?
+            local_loss_sum = torch.sum(torch.norm(mixer_output - teacher_tensor, p=2, dim=(2)))
+            # mixer_output.shape is (batch, sequence, hidden) or (sequence, batch, hidden)
+            # In either case, dims 0 and 1 are batch and sequence
+            total_count = mixer_output.shape[0] * mixer_output.shape[1]
+
+            # All-reduce across tensor-parallel group if sequence-parallel is enabled
+            if self._sequence_parallel and self._distributed.tensor_group is not None:
+                all_reduce(local_loss_sum, group=self._distributed.tensor_group, op=ReduceOp.SUM)
+                # Assume all ranks contribute the same count (not the case if padding)
+                total_count *= self._distributed.tensor_group.size()
+
+            activation_loss = activation_loss_factor * (local_loss_sum / total_count)
+
+            # Backward hooks
+            hidden_states = AuxiliaryLoss.apply(hidden_states, activation_loss, 1.0)
+            bias = AuxiliaryLoss.apply(bias, activation_loss, 1.0) if bias is not None else None
+            # Logging
+            if losses is not None and self._activation_distillation_loss_name in losses:
+                losses[self._activation_distillation_loss_name].append(activation_loss.detach())
+        return hidden_states, bias

     def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
         # TODO: Add marginal compute? (normalization, bias_dropout_add)
         return sum(
```
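In isolation, the loss added above is an L2 distance between the student's and the teacher's mixer outputs: the norm is taken over the hidden dimension, summed over all tokens, averaged over batch × sequence, and scaled by a configurable factor. A self-contained sketch of the same computation, with made-up shapes and a made-up factor of 0.1 standing in for `activation_distillation_factor`:

```python
import torch

batch, seq, hidden = 2, 4, 8
factor = 0.1  # stands in for activation_distillation_factor

student = torch.randn(batch, seq, hidden, requires_grad=True)
teacher = torch.randn(batch, seq, hidden)

# Per-token L2 norm over the hidden dimension, summed, then averaged over batch * sequence.
loss_sum = torch.norm(student - teacher.detach(), p=2, dim=2).sum()
activation_loss = factor * loss_sum / (batch * seq)

activation_loss.backward()
print(activation_loss.item(), student.grad.shape)
```

Under tensor parallelism with sequence-parallel layouts, each rank only sees a slice of the tokens, which is why the PR all-reduces the local sum and multiplies the token count by the group size before dividing.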
```diff
@@ -161,5 +212,21 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
         self.mixer.preprocess(kwargs)
         self.mlp.preprocess(kwargs)

+    # TODO: add layer_index
+    _activation_distillation_loss_name = "activation_distillation_loss"
```
**Contributor:** would be nice to have a layer index in logging

**Author:** Agree! This involves a bit more changes because …

**Author:** We'd also need to be careful not to drown the logs in these activation losses for each layer
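A possible follow-up along the lines of that thread would be to derive a per-layer loss name from the block index instead of the single shared class attribute. The helper below is hypothetical, not part of the PR:

```python
def activation_distillation_loss_name(block_index: int) -> str:
    # e.g. "activation_distillation_loss_3" for the fourth decoder block
    return f"activation_distillation_loss_{block_index}"

print(activation_distillation_loss_name(3))  # activation_distillation_loss_3
```

Whether each per-layer loss should then be reported separately is exactly the log-volume concern raised in the last comment.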
```diff
     def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
-        return self.mixer.get_loss_definitions(count=count) + self.mlp.get_loss_definitions(count=count)
+        loss_definitions = []
+        if self._config.activation_distillation_factor > 0.0 and self._config.distillation_model is not None:
+            loss_definitions.append(
+                LossDef(
+                    name=self._activation_distillation_loss_name,
+                    formatted_name=_format_name(self._activation_distillation_loss_name),
+                    count=count,
+                )
+            )
+        return (
+            loss_definitions
+            + self.mixer.get_loss_definitions(count=count)
+            + self.mlp.get_loss_definitions(count=count)
+        )
```