diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py
index 531bc206..867cca98 100644
--- a/fast_llm/engine/training/config.py
+++ b/fast_llm/engine/training/config.py
@@ -361,7 +361,6 @@ def _validate(self) -> None:
         # TODO: Add support.
         Assert.eq(self.model.distributed.pipeline_parallel, 1)
         # TODO: Check if these work.
-        Assert.eq(self.model.distributed.tensor_parallel, 1)
         Assert.eq(self.model.distributed.sequence_data_parallel, 1)
         if self.run.experiment_dir is None:
             assert not self.training.checkpoint.enabled()
diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py
index d56dce98..c22319c1 100644
--- a/fast_llm/functional/cross_entropy.py
+++ b/fast_llm/functional/cross_entropy.py
@@ -151,7 +151,7 @@ def _fused_cross_entropy_forward_backward(
     loss = per_sample_loss.mean()
     if target_format != TargetFormat.labels and group is not None:
-        all_reduce(loss, op=ReduceOp.MEAN, group=group)
+        all_reduce(loss, op=ReduceOp.AVG, group=group)

     return loss, grad

@@ -277,7 +277,7 @@ def _torch_reverse_kl_forward_backward(
         loss = (loss_per_sample * loss_mask).mean()

     if group is not None and target_format != TargetFormat.labels:
-        all_reduce(loss, op=ReduceOp.MEAN, group=group)
+        all_reduce(loss, op=ReduceOp.AVG, group=group)

     if grad_output is not None:
         loss.backward(torch.full_like(loss, grad_output))
diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py
index d9a27c45..261d5402 100644
--- a/fast_llm/layers/block/config.py
+++ b/fast_llm/layers/block/config.py
@@ -1,4 +1,5 @@
 import functools
+import logging
 import typing
 import warnings

@@ -8,12 +9,14 @@
 from fast_llm.engine.config_utils.tensor_dim import TensorDim
 from fast_llm.engine.distributed.config import DistributedConfig
 from fast_llm.layers.common.peft.config import PeftConfig
-from fast_llm.utils import Assert
+from fast_llm.utils import Assert, log

 if typing.TYPE_CHECKING:
     from fast_llm.layers.block.block import BlockBase
     from fast_llm.layers.block.sequence import FixedBlockSequence, PatternBlockSequence

+logger = logging.getLogger(__name__)
+

 class BlockDimNames:
     # A set of common tensor dim names packed into a namespace.
@@ -37,6 +40,7 @@ class BlockKwargs:
     sequence_lengths = "sequence_lengths"
     # TODO: Belongs elsewhere?
     grad_output = "grad_output"
+    activation_distillation_targets = "activation_distillation_targets"
     iteration = "iteration"
     device = "device"
     hidden_states = "hidden_states"
@@ -87,6 +91,9 @@ def get_layer(
             peft=peft,
         )

+    def get_distillation_models(self) -> set[str]:
+        return set()
+

 @config_class(registry=True)
 class BlockSequenceConfig(BlockConfig):
@@ -118,6 +125,9 @@ def layer_class(self) -> "type[FixedBlockSequence]":

         return FixedBlockSequence

+    def get_distillation_models(self) -> set[str]:
+        return self.block.get_distillation_models()
+

 @config_class(dynamic_type={BlockSequenceConfig: "pattern"})
 class PatternBlockSequenceConfig(BlockSequenceConfig):
@@ -164,3 +174,21 @@ def expanded_pattern(self) -> list[str]:
     def preprocessing_layers(self) -> dict[str, int]:
         # The index at which each block first appears. These blocks are used for preprocessing.
         return {name: self.expanded_pattern.index(name) for name in set(self.expanded_pattern)}
+
+    def get_distillation_models(self) -> set[str]:
+        models = set()
+        for block in self.blocks.values():
+            models.update(block.get_distillation_models())
+        return models
+
+    @classmethod
+    def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self:
+        # Patch creeping type parameters from pretrained model
+        # TODO: fix this
+        if "block" in default:
+            removed = default.pop("block")
+            log(
+                f"Removing 'block' from default dict in PatternBlockSequenceConfig._from_dict: {removed}",
+                log_fn=logger.warning,
+            )
+        return super()._from_dict(default, strict=strict)
diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py
index a915b16d..148dabd5 100644
--- a/fast_llm/layers/decoder/block.py
+++ b/fast_llm/layers/decoder/block.py
@@ -4,16 +4,19 @@
 import torch

-from fast_llm.core.distributed import set_generator
+from fast_llm.core.distributed import ReduceOp, all_reduce, set_generator
 from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig
 from fast_llm.engine.config_utils.tensor_dim import TensorDim
 from fast_llm.engine.distributed.config import DistributedConfig
 from fast_llm.engine.distributed.distributed import Distributed
 from fast_llm.layers.block.block import Block
 from fast_llm.layers.block.config import BlockKwargs
+from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss
 from fast_llm.layers.common.peft.config import PeftConfig
 from fast_llm.layers.decoder.config import BlockWithBiasConfig, DecoderBlockConfig
+from fast_llm.layers.language_model.head import _format_name
 from fast_llm.tensor import TensorMeta
+from fast_llm.utils import Assert

 logger = logging.getLogger(__name__)

@@ -134,6 +137,9 @@ def forward(
         hidden_states = self.norm_1(input_)
         self._debug(hidden_states, "norm_1", kwargs.get(BlockKwargs.hidden_dims), kwargs)
         hidden_states, bias = self.mixer(hidden_states, kwargs)
+
+        hidden_states, bias = self.activation_distillation_loss(hidden_states, bias, kwargs, losses)
+
         with set_generator(generator):
             input_ = self._bias_dropout_add(hidden_states, bias, input_)
         self._debug(input_, "mixer_residual", kwargs.get(BlockKwargs.hidden_dims), kwargs)
@@ -148,6 +154,52 @@
             hidden_states = torch.stack((fw_input, hidden_states), dim=0)
         return hidden_states

+    def activation_distillation_loss(self, hidden_states, bias, kwargs, losses):
+        """
+        Maybe apply the activation distillation loss and set up backward hooks.
+        """
+        mixer_output = hidden_states if bias is None else hidden_states + bias
+
+        # Teacher: output mixer activations via the _debug interface
+        self._debug(mixer_output.detach(), "mixer_output", kwargs.get(BlockKwargs.hidden_dims), kwargs)
+
+        # Student gets teacher activations and computes the activation-level loss.
+        activation_targets = kwargs.get(BlockKwargs.activation_distillation_targets)
+        key = f"{self.module_name}.mixer_output"
+        if (
+            activation_targets is not None
+            and self.training
+            and (teacher_output := activation_targets.pop(key, None)) is not None
+        ):
+            # Compare the student mixer output with the teacher's stored activation and accumulate the loss.
+            teacher_tensor = teacher_output.detach().to(device=mixer_output.device, dtype=mixer_output.dtype)
+            Assert.eq(teacher_tensor.shape, mixer_output.shape)
+            # TODO: un-scaled loss for reporting? Average loss over layers?
+            # L2 loss
+            activation_loss_factor = self._config.activation_distillation_factor
+            # (batch, sequence, hidden) or (sequence, batch, hidden). Take the norm over hidden dim.
+            # TODO: handle possible padding?
+            local_loss_sum = torch.sum(torch.norm(mixer_output - teacher_tensor, p=2, dim=(2)))
+            # mixer_output.shape is (batch, sequence, hidden) or (sequence, batch, hidden)
+            # In either case, dims 0 and 1 are batch and sequence
+            total_count = mixer_output.shape[0] * mixer_output.shape[1]
+
+            # All-reduce across tensor-parallel group if sequence-parallel is enabled
+            if self._sequence_parallel and self._distributed.tensor_group is not None:
+                all_reduce(local_loss_sum, group=self._distributed.tensor_group, op=ReduceOp.SUM)
+                # Assume all ranks contribute the same count (not the case if padding)
+                total_count *= self._distributed.tensor_group.size()
+
+            activation_loss = activation_loss_factor * (local_loss_sum / total_count)
+
+            # Backward hooks
+            hidden_states = AuxiliaryLoss.apply(hidden_states, activation_loss, 1.0)
+            bias = AuxiliaryLoss.apply(bias, activation_loss, 1.0) if bias is not None else None
+            # Logging
+            if losses is not None and self._activation_distillation_loss_name in losses:
+                losses[self._activation_distillation_loss_name].append(activation_loss.detach())
+        return hidden_states, bias
+
     def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
         # TODO: Add marginal compute? (normalization, bias_dropout_add)
         return sum(
@@ -161,5 +213,21 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
         self.mixer.preprocess(kwargs)
         self.mlp.preprocess(kwargs)

+    # TODO: add layer_index
+    _activation_distillation_loss_name = "activation_distillation_loss"
+
     def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
-        return self.mixer.get_loss_definitions(count=count) + self.mlp.get_loss_definitions(count=count)
+        loss_definitions = []
+        if self._config.activation_distillation_factor > 0.0 and self._config.distillation_model is not None:
+            loss_definitions.append(
+                LossDef(
+                    name=self._activation_distillation_loss_name,
+                    formatted_name=_format_name(self._activation_distillation_loss_name),
+                    count=count,
+                )
+            )
+        return (
+            loss_definitions
+            + self.mixer.get_loss_definitions(count=count)
+            + self.mlp.get_loss_definitions(count=count)
+        )
diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py
index 4b2bec1c..83087570 100644
--- a/fast_llm/layers/decoder/config.py
+++ b/fast_llm/layers/decoder/config.py
@@ -200,6 +200,22 @@ class DecoderBlockConfig(BlockConfig):
         hint=FieldHint.feature,
         valid=check_field(Assert.geq, 0),
     )
+    distillation_model: str | None = Field(
+        default=None,
+        desc="Name of the reference model to use for activation-level distillation.",
+        hint=FieldHint.feature,
+    )
+    activation_distillation_factor: float = Field(
+        default=0.0,
+        desc="Factor to scale the activation-level distillation loss by.",
+        hint=FieldHint.feature,
+        valid=check_field(Assert.geq, 0),
+    )
+
+    def _validate(self) -> None:
+        super()._validate()
+        if self.activation_distillation_factor > 0.0 and self.distillation_model is None:
+            raise ValueError("Activation distillation requires a distillation_model.")

     @property
     def layer_class(self) -> "type[DecoderBlock]":
@@ -223,3 +239,8 @@ def get_layer(
             peft=peft,
             return_input=return_input,
         )
+
+    def get_distillation_models(self) -> set[str]:
+        if self.distillation_model is not None and self.activation_distillation_factor > 0.0:
+            return {self.distillation_model}
+        return set()
diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py
index f9334816..f5af7d8b 100644
--- a/fast_llm/models/gpt/config.py
+++ b/fast_llm/models/gpt/config.py
@@ -167,6 +167,7 @@ def _validate(self) -> None:
             prediction_heads = 1

         expected_names = {name for name in (head.distillation_model, head.dpo_reference_model) if name is not None}
+        expected_names.update(self.model.base_model.decoder.get_distillation_models())
         Assert.eq(self.reference_models.keys(), expected_names)

         for reference_model in self.reference_models.values():
diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py
index b42f68b2..acb1a71d 100644
--- a/fast_llm/models/gpt/conversion/llama.py
+++ b/fast_llm/models/gpt/conversion/llama.py
@@ -16,7 +16,7 @@
 from fast_llm.layers.attention.config import AttentionConfig
 from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig
 from fast_llm.layers.attention.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex
-from fast_llm.layers.block.config import FixedBlockSequenceConfig
+from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig
 from fast_llm.layers.common.normalization.config import RMSNormalizationConfig
 from fast_llm.layers.decoder.config import DecoderBlockConfig
 from fast_llm.layers.decoder.mlp.config import MLPConfig
@@ -419,8 +419,19 @@ def import_config(cls, config: dict) -> dict:
         }

     @classmethod
-    def export_config(cls, config: FixedBlockSequenceConfig) -> dict:
-        # TODO: Support PatternBlockSequenceConfig with compatible configs.
+    def export_config(cls, config: FixedBlockSequenceConfig | PatternBlockSequenceConfig) -> dict:
+        if isinstance(config, PatternBlockSequenceConfig):
+            # All exported block configs must be equal
+            exported_block_configs = [
+                safe_merge_dicts(
+                    cls.block_converter_class.export_config(block_config),
+                    {"num_hidden_layers": config.num_blocks},
+                )
+                for block_config in config.blocks.values()
+            ]
+            for other in exported_block_configs[1:]:
+                Assert.eq(exported_block_configs[0], other)
+            return exported_block_configs[0]
         Assert.custom(isinstance, config, FixedBlockSequenceConfig)
         return safe_merge_dicts(
             cls.block_converter_class.export_config(config.block),
@@ -430,15 +441,19 @@ def export_config(cls, config: FixedBlockSequenceConfig) -> dict:
     @classmethod
     def get_converters(
         cls,
-        config: FixedBlockSequenceConfig,
+        config: FixedBlockSequenceConfig | PatternBlockSequenceConfig,
         fast_llm_prefix: str,
         hf_prefix: str,
         drop_on_export: bool = False,
     ) -> list[WeightConverter]:
+        # In the case of PatternBlockSequenceConfig, compatibility was already checked in export_config
+        block_config = (
+            config.block if isinstance(config, FixedBlockSequenceConfig) else next(iter(config.blocks.values()))
+        )
         converters = []
         for block_index in range(config.num_blocks):
             converters += cls.block_converter_class.get_converters(
-                config.block,
+                block_config,
                 f"{fast_llm_prefix}.{block_index}",
                 f"{hf_prefix}.{block_index}",
                 drop_on_export,
diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py
index 0f26d14f..a0c38143 100644
--- a/fast_llm/models/gpt/model.py
+++ b/fast_llm/models/gpt/model.py
@@ -1,4 +1,5 @@
 import logging
+import re
 import typing

 import torch
@@ -166,14 +167,28 @@ def preprocess_batch(
         if preprocessed_meta is None:
             preprocessed_meta = self.preprocess_meta(batch, phase)

+        distillation_models = self._config.decoder.get_distillation_models()
+        # TODO: Support multiple distillation models?
+        assert len(distillation_models) <= 1
         reference_logits = [{} for _ in preprocessed_meta]
         for name, reference_model in self._reference_models.items():
             reference_preprocessed_meta = [
                 (tokens_meta, kwargs_meta["reference_models"][name]) for tokens_meta, kwargs_meta in preprocessed_meta
             ]

+            # Set output_hidden_states in reference metadata before preprocessing if needed for distillation
+            if name in distillation_models:
+                reference_output_hidden_states = [r"decoder\.\d+\.mixer_output$"]
+                for _, ref_kwargs_meta in reference_preprocessed_meta:
+                    ref_kwargs_meta[BlockKwargs.output_hidden_states] = [
+                        re.compile(pattern) for pattern in reference_output_hidden_states
+                    ]
+
             reference_batch = reference_model.fast_llm_model.base_model.preprocess_batch(
-                batch, reference_preprocessed_meta, phase=PhaseType.inference, iteration=iteration
+                batch,
+                reference_preprocessed_meta,
+                phase=PhaseType.inference,
+                iteration=iteration,
             )

             # TODO: Do things work with >1?
@@ -181,6 +196,14 @@
             for i, (reference_tokens, reference_kwargs) in enumerate(reference_batch):
                 reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration)
                 reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"]
+                if BlockKwargs.hidden_states in reference_kwargs and reference_kwargs[BlockKwargs.hidden_states]:
+                    # Extract activations from hidden_states dict (stored by _debug method)
+                    # Format: {layer_name: (meta, tensor), ...}
+                    activations = {
+                        layer_name: tensor
+                        for layer_name, (meta, tensor) in reference_kwargs[BlockKwargs.hidden_states].items()
+                    }
+                    reference_logits[i][f"{name}_activations"] = activations

         preprocessed = []
         presents = None
@@ -205,6 +228,13 @@
                 **reference_logits[i],
             }

+            # Add activation-distillation targets
+            assert len(distillation_models) <= 1
+            for distillation_model in distillation_models:
+                teacher_key = f"{distillation_model}_activations"
+                if teacher_key in reference_logits[i]:
+                    kwargs[BlockKwargs.activation_distillation_targets] = reference_logits[i].pop(teacher_key)
+
             if phase != PhaseType.inference:
                 labels_begin = tokens_begin + 1
                 labels_end = tokens_end + self._config.head.max_prediction_distance
diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py
index 752e3a8c..16277720 100644
--- a/tests/utils/model_configs.py
+++ b/tests/utils/model_configs.py
@@ -188,6 +188,37 @@ def _update_and_add_testing_config(
 init_1 = {"initialization": {"type": "normal", "std": 2**-5.5}}
 # Needed to match Megatron (init_1 / (2 * num_layers) ** 0.5)
 init_2 = {"initialization": {"type": "normal", "std": 2**-6.5}}
+base_model = {
+    "embeddings": {
+        "word_embeddings": init_1,
+        "position_embeddings": {"enabled": True, **init_1},
+        "num_position_embeddings": 512,
+        "vocab_size": MODEL_TEST_VOCAB_SIZE,
+    },
+    "decoder": {
+        "block": {
+            "mixer": {
+                "query_layer": {"weight": init_1},
+                "key_layer": {"weight": init_1},
+                "value_layer": {"weight": init_1},
+                "dense_layer": {"weight": init_2},
+                "heads": 8,
+                "head_groups": 8,
+                "head_size": 32,
+                # "cross_document_attention":False,
+            },
+            "mlp": {
+                "layer_1": {"weight": init_1},
+                "layer_2": {"weight": init_2},
+                "intermediate_size": 1024,
+            },
+        },
+        "num_blocks": 2,
+    },
+    "head": {"output_weight": init_1},
+    "hidden_size": 256,
+    "tied_embedding_weight": True,
+}

 MODEL_CONFIGS["gpt_2"] = ModelTestingConfig(
     # Tests gpt2 features (absolute embeddings, layer norm, relu activation, tied embeddings, MHA, linear biases).
@@ -207,37 +238,7 @@ def _update_and_add_testing_config(
             "timeout": 30,
         },
         "model": {
-            "base_model": {
-                "embeddings": {
-                    "word_embeddings": init_1,
-                    "position_embeddings": {"enabled": True, **init_1},
-                    "num_position_embeddings": 512,
-                    "vocab_size": MODEL_TEST_VOCAB_SIZE,
-                },
-                "decoder": {
-                    "block": {
-                        "mixer": {
-                            "query_layer": {"weight": init_1},
-                            "key_layer": {"weight": init_1},
-                            "value_layer": {"weight": init_1},
-                            "dense_layer": {"weight": init_2},
-                            "heads": 8,
-                            "head_groups": 8,
-                            "head_size": 32,
-                            # "cross_document_attention":False,
-                        },
-                        "mlp": {
-                            "layer_1": {"weight": init_1},
-                            "layer_2": {"weight": init_2},
-                            "intermediate_size": 1024,
-                        },
-                    },
-                    "num_blocks": 2,
-                },
-                "head": {"output_weight": init_1},
-                "hidden_size": 256,
-                "tied_embedding_weight": True,
-            },
+            "base_model": base_model,
             "multi_stage": {
                 "debug_param_init": _LOG_LEVEL,
                 "debug_layer_outputs": _LOG_LEVEL,
@@ -538,6 +539,84 @@
     },
 )

+_update_and_add_testing_config(
+    # Tests logit distillation.
+    "mistral",
+    "mistral_distill_logits",
+    updates={
+        ("model", "base_model", "head", "distillation_model"): "teacher",
+        ("reference_models"): {
+            "teacher": {
+                "model": {"base_model": base_model},
+            },
+        },
+    },
+    megatron_args=None,
+    checkpoint_format=MistralCheckpointFormat,
+    groups={
+        ModelTestingGroup.basic: ModelTestingGroupAction.normal,
+        ModelTestingGroup.checkpoint: ModelTestingGroupAction.unimportant,
+        ModelTestingGroup.convert: ModelTestingGroupAction.unimportant,
+        ModelTestingGroup.generate: ModelTestingGroupAction.unimportant,
+        ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
+        ModelTestingGroup.distributed: ModelTestingGroupAction.broken,  # failing: tp2, stp2, stp2_ce4
+    },
+    compare_factor=1.5,
+    # modes not supported with reference models
+    skip_tests=("ms", "pp2s1_bf4", "pp2s2_bf4", "sdp2"),
+)
+
+_update_and_add_testing_config(
+    "mistral_distill_logits",
+    "mistral_reverse_kl",
+    updates={
+        ("model", "base_model", "head", "distillation_loss_implementation"): "reverse_kl",
+    },
+    megatron_args=None,
+    checkpoint_format=MistralCheckpointFormat,
+    groups={
+        ModelTestingGroup.basic: ModelTestingGroupAction.normal,
+        ModelTestingGroup.checkpoint: ModelTestingGroupAction.unimportant,
+        ModelTestingGroup.convert: ModelTestingGroupAction.unimportant,
+        ModelTestingGroup.generate: ModelTestingGroupAction.unimportant,
+        ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
+        ModelTestingGroup.distributed: ModelTestingGroupAction.broken,  # failing: fp16, tp2, stp2, stp2_ce4
+    },
+    compare_factor=2,
+    # modes not supported with reference models
+    skip_tests=("ms", "pp2s1_bf4", "pp2s2_bf4", "sdp2"),
+)
+
+_update_and_add_testing_config(
+    "mistral_distill_logits",
+    "mistral_distill_activations",
+    updates={
+        ("model", "base_model", "head", "distillation_loss_factor"): 0.001,
+        ("model", "base_model", "decoder", "block", "distillation_model"): "teacher",
+        ("model", "base_model", "decoder", "block", "activation_distillation_factor"): 0.1,
+        ("reference_models"): {
+            "teacher": {
+                "model": {"base_model": base_model},
+            },
+        },
+    },
+    # Megatron doesn't support sliding windows.
+    megatron_args=None,
+    checkpoint_format=MistralCheckpointFormat,
+    # TODO: Add back generate as `normal` when stable.
+    groups={
+        ModelTestingGroup.basic: ModelTestingGroupAction.normal,
+        ModelTestingGroup.checkpoint: ModelTestingGroupAction.unimportant,
+        ModelTestingGroup.convert: ModelTestingGroupAction.unimportant,
+        ModelTestingGroup.generate: ModelTestingGroupAction.unimportant,
+        ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
+        ModelTestingGroup.distributed: ModelTestingGroupAction.broken,  # failing: fp16, df4, df4_sf, tp2, stp2,
+    },
+    compare_factor=8,
+    # modes not supported with reference models
+    skip_tests=("ms", "pp2s1_bf4", "pp2s2_bf4", "sdp2", "stp2_ce4"),
+)
+
 _update_and_add_testing_config(
     # Tests mixture of experts, mixtral converter.
     "llama",
diff --git a/tests/utils/run_test_script.py b/tests/utils/run_test_script.py
index 5a24e593..5c07324c 100644
--- a/tests/utils/run_test_script.py
+++ b/tests/utils/run_test_script.py
@@ -64,6 +64,40 @@ def run_test_script_base_path(model_testing_config, result_path, request):
     return result_path / "models" / model_testing_config.name


+def _propagate_config_args_to_reference_models(config_args: list[str]) -> list[str]:
+    """
+    Propagate certain model config args to reference models.
+
+    Some config args that affect model behavior need to be applied to both
+    the main model and the reference models to ensure compatibility.
+    """
+    propagated_args = []
+    # Patterns that should be propagated to reference models.
+    # Only model-level configs should be propagated, not batch-level configs
+    # (batch is shared at the trainer level, not per-model).
+    propagate_patterns = [
+        ("model", "base_model", "sequence_first"),
+        ("model", "base_model", "embeddings", "vocab_parallel"),
+    ]
+
+    for arg in config_args:
+        if "=" not in arg:
+            continue
+        key, value = arg.split("=", 1)
+        key_tuple = tuple(key.split("."))
+
+        # Check if this arg should be propagated
+        for pattern in propagate_patterns:
+            if key_tuple == pattern:
+                # Add the reference-model version of this arg
+                # for each reference model (we check whether they exist in the config).
+                ref_key = f"reference_models.teacher.{key}"
+                propagated_args.append(f"{ref_key}={value}")
+                break
+
+    return propagated_args
+
+
 def do_run_test_script_for_all_models(
     distributed_testing_config: DistributedTestingConfig,
     model_testing_config: ModelTestingConfig,
@@ -72,12 +106,19 @@
 ):
     Assert.leq(distributed_testing_config.num_gpus, DistributedConfig.default_world_size)
     model_testing_config.get_dataset()
+
+    # Propagate certain config args to reference models if they exist
+    propagated_args = []
+    if "reference_models" in str(model_testing_config.config_dict):
+        propagated_args = _propagate_config_args_to_reference_models(distributed_testing_config.config_args)
+
     args = [
         "fast-llm",
         runnable_type,
         model_testing_config.model_type,
         *model_testing_config.config_args,
         *distributed_testing_config.config_args,
+        *propagated_args,
         f"model.distributed.world_size={distributed_testing_config.num_gpus}",
         f"model.distributed.local_world_size={distributed_testing_config.num_gpus}",
         f"run.experiment_dir={base_path/distributed_testing_config.name}",