diff --git a/llm/auto_parallel/deepseek-v2/run_pretrain_auto.py b/llm/auto_parallel/deepseek-v2/run_pretrain_auto.py deleted file mode 100644 index cd9051106bf3..000000000000 --- a/llm/auto_parallel/deepseek-v2/run_pretrain_auto.py +++ /dev/null @@ -1,725 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -deepseek-v2 auto parallel pretraining scripts. -""" -import os -import random -import sys -import types -from collections import OrderedDict -from dataclasses import dataclass, field -from typing import List, Optional - -import numpy as np -import paddle -import paddle.distributed as dist -from paddle.distributed import fleet - -from paddlenlp.ops import Topology -from paddlenlp.trainer import ( - AutoTrainingArguments, - PdArgumentParser, - get_last_checkpoint, -) -from paddlenlp.trainer.auto_trainer import AutoTrainer -from paddlenlp.trainer.trainer_utils import IntervalStrategy, _get_distributed_seeds -from paddlenlp.transformers import ( - AutoTokenizer, - CosineAnnealingWithWarmupDecay, - DeepseekV2Config, - DeepseekV2ForCausalLMAuto, - DeepseekV2PretrainingCriterion, - LinearAnnealingWithWarmupDecay, -) -from paddlenlp.utils.log import logger - -MODEL_CLASSES = { - "deepseekv2_auto": (DeepseekV2Config, DeepseekV2ForCausalLMAuto, DeepseekV2PretrainingCriterion), -} - - -from paddlenlp.data.causal_dataset import ( - build_train_valid_test_datasets, - check_data_split, - print_rank_0, -) -from paddlenlp.trainer.utils.doc import add_start_docstrings - - -@dataclass -@add_start_docstrings(AutoTrainingArguments.__doc__) -class PreTrainingArguments(AutoTrainingArguments): - min_learning_rate: float = field( - default=1e-5, - metadata={"help": "Minimum learning rate deacyed to."}, - ) - decay_steps: float = field( - default=None, - metadata={ - "help": "The steps use to control the learing rate. If the step > decay_steps, will use the min_learning_rate." - }, - ) - enable_linear_fused_grad_add: bool = field( - default=False, - metadata={ - "help": "Enable fused linear grad add strategy, which will reduce elementwise add for grad accumulation in the backward of nn.Linear ." - }, - ) - job_schedule_profiler_start: int = field( - default=-1, - metadata={"help": "The step to start job_schedule_profiler."}, - ) - job_schedule_profiler_end: int = field( - default=-1, - metadata={"help": "The step to end job_schedule_profiler."}, - ) - pipeline_schedule_mode: str = field( - default="1F1B", metadata={"help": "The pipeline schedule mode, support FThenB, 1F1B, VPP and Eager-1F1B."} - ) - sr: Optional[int] = field(default=0, metadata={"help": "The count of chunks without recompute."}) - virtual_pipeline_seg_method: str = field( - default="DeepseekV2DecoderLayerAuto", - metadata={"help": "The seg method of spliting pp layer for virtual pipeline."}, - ) - # NOTE(gongenlei): new add autotuner_benchmark - autotuner_benchmark: bool = field( - default=False, - metadata={"help": "Weather to run benchmark by autotuner. 
True for from_scratch and pad_max_length."}, - ) - - def __post_init__(self): - super().__post_init__() - assert self.enable_auto_parallel - - # NOTE(gongenlei): new add autotuner_benchmark - if self.autotuner_benchmark: - self.max_steps = 5 - self.do_train = True - self.do_export = False - self.do_predict = False - self.do_eval = False - self.overwrite_output_dir = True - self.load_best_model_at_end = False - self.report_to = [] - self.save_strategy = IntervalStrategy.NO - self.evaluation_strategy = IntervalStrategy.NO - - logger.info(self.strategy) - - -@dataclass -class DataArguments: - """ - Arguments pertaining to what data we are going to input our model for training and evaluating. - Using `PdArgumentParser` we can turn this class into argparse arguments to be able to - specify them on the command line. - """ - - input_dir: str = field( - default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} - ) - split: str = field(default="949,50,1", metadata={"help": "Train/valid/test data split."}) - - max_seq_length: int = field( - default=1024, - metadata={ - "help": "The maximum total input sequence length after tokenization. Sequences longer " - "than this will be truncated, sequences shorter will be padded." - }, - ) - share_folder: bool = field( - default=False, - metadata={"help": "Use share folder for data dir and output dir on multi machine."}, - ) - - data_impl: str = field(default="mmap", metadata={"help": "The format of the preprocessed data."}) - skip_warmup: bool = field( - default=True, - metadata={"help": "Whether to skip the warmup process of mmap files."}, - ) - data_cache: str = field(default=None, metadata={"help": "The path of the cached dataset."}) - - -@dataclass -class ModelArguments: - """ - Arguments pertaining to which model/config/tokenizer we are going to pre-train from. - """ - - model_type: Optional[str] = field( - default="deepseekv2", metadata={"help": "Only support for llama pre-training for now."} - ) - model_name_or_path: str = field( - default="deepseek-ai/DeepSeek-V2-Lite", - metadata={ - "help": "Path to pretrained model or model identifier from https://paddlenlp.readthedocs.io/zh/latest/model_zoo/transformers.html" - }, - ) - tokenizer_name_or_path: Optional[str] = field( - default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} - ) - - config_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} - ) - vocab_size: Optional[int] = field( - default=None, - metadata={ - "help": ".Vocabulary size of the deeepseekv2 model. 
Defines the number of different tokens that can be represented by the `inputs_ids`" - }, - ) - hidden_size: Optional[int] = field(default=None, metadata={"help": "Dimension of the hidden representations."}) - intermediate_size: Optional[int] = field(default=None, metadata={"help": "Dimension of the MLP representations."}) - num_hidden_layers: Optional[int] = field( - default=None, metadata={"help": "Number of hidden layers in the Transformer encoder."} - ) - num_attention_heads: Optional[int] = field( - default=None, - metadata={"help": "Number of attention heads for each attention layer in the Transformer encoder."}, - ) - use_flash_attention: bool = field( - default=False, - metadata={"help": "use_flash_attention"}, - ) - use_fused_rms_norm: bool = field( - default=False, - metadata={"help": "deepseekv2, use_fused_rms_norm"}, - ) - fuse_attention_qkv: bool = field( - default=False, - metadata={"help": "whether to fuse attention qkv"}, - ) - fuse_attention_ffn: bool = field( - default=False, - metadata={"help": "whether to fuse first up and gate proj in mlp block"}, - ) - recompute_granularity: str = field( - default="full", - metadata={"help": "Choose among ['full', 'core_attn', 'full_attn']"}, - ) - virtual_pp_degree: int = field( - default=1, - metadata={"help": "virtual_pp_degree"}, - ) - continue_training: bool = field( - default=False, - metadata={ - "help": "Pre-training from existing paddlenlp model weights. Default False and model will train from scratch. If set True, the model_name_or_path argument must exist in the paddlenlp models." - }, - ) - use_fused_rope: Optional[bool] = field( - default=False, - metadata={"help": "Enable rope fusion or not."}, - ) - no_recompute_layers: Optional[List[int]] = field( - default=None, - metadata={"help": "Specify the full transformer layers that should not be recomputed."}, - ) - pp_recompute_interval: int = field( - default=1, - metadata={ - "help": "The interval for the number of layers at which recomputation occurs. A value of 0 indicates no recomputation. Default is 0." - }, - ) - recompute_use_reentrant: bool = field( - default=False, - metadata={"help": "recompute_use_reentrant"}, - ) - - -def create_pretrained_dataset( - data_args, - training_args, - data_file, - tokenizer, - need_data=True, -): - - check_data_split(data_args.split, training_args.do_train, training_args.do_eval, training_args.do_predict) - - train_val_test_num_samples = [ - training_args.per_device_train_batch_size - * training_args.dataset_world_size - * training_args.max_steps - * training_args.gradient_accumulation_steps, - training_args.per_device_eval_batch_size - * training_args.dataset_world_size - * training_args.eval_iters - * (training_args.max_steps // training_args.eval_steps + 1), - training_args.per_device_eval_batch_size * training_args.dataset_world_size * training_args.test_iters, - ] - - print_rank_0(" > datasets target sizes (minimum size):") - if training_args.do_train: - print_rank_0(" train: {}".format(train_val_test_num_samples[0])) - if training_args.do_eval: - print_rank_0(" validation: {}".format(train_val_test_num_samples[1])) - if training_args.do_predict: - print_rank_0(" test: {}".format(train_val_test_num_samples[2])) - - # Build the datasets. 
- train_dataset, valid_dataset, test_dataset = build_train_valid_test_datasets( - data_prefix=data_file, - data_impl=data_args.data_impl, - splits_string=data_args.split, - train_val_test_num_samples=train_val_test_num_samples, - seq_length=data_args.max_seq_length, - seed=training_args.seed, - skip_warmup=data_args.skip_warmup, - share_folder=data_args.share_folder, - data_cache_path=data_args.data_cache, - need_data=need_data, - ) - - def print_dataset(data, mode="train"): - logger.info(f"Sample data for {mode} mode.") - # input_ids, loss_mask, attention_mask, position_ids, labels = data - input_ids = data["text"] - - logger.info(tokenizer._decode(input_ids)) - - from paddlenlp.data import Stack - - def _collate_data(data, stack_fn=Stack()): - tokens_ = stack_fn([x["text"] for x in data]) - - labels = tokens_[:, 1:] - tokens = tokens_[:, :-1] - - return { - "input_ids": tokens, - "labels": labels, - } - - if need_data: - if training_args.do_train: - print_dataset(train_dataset[0], "train") - if training_args.do_eval: - print_dataset(valid_dataset[0], "valid") - if training_args.do_predict: - print_dataset(test_dataset[0], "test") - - return train_dataset, valid_dataset, test_dataset, _collate_data - - -def get_train_data_file(args): - if len(args.input_dir.split()) > 1: - # weight-1 data-prefix-1 weight-2 data-prefix-2 ... - return args.input_dir.split() - else: - files = [ - os.path.join(args.input_dir, f) - for f in os.listdir(args.input_dir) - if (os.path.isfile(os.path.join(args.input_dir, f)) and ("_idx.npz" in str(f) or ".idx" in str(f))) - ] - files = [x.replace("_idx.npz", "") for x in files] - files = [x.replace(".idx", "") for x in files] # add - - if len(files) > 1: - ret = [] - logger.info("You are using multi-dataset:") - for x in files: - ret.append(1.0) - ret.append(x) - logger.info(" > set weight of %s dataset to 1.0" % x) - return ret - - return files - - -class PretrainingTrainer(AutoTrainer): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.is_pretraining = True - - def _wrap_for_dist_loader(self, train_dataloader): - dist_loader = super()._wrap_for_dist_loader(train_dataloader) - dist_loader._input_keys = ["input_ids", "labels"] - return dist_loader - - def _get_train_sampler(self) -> Optional[paddle.io.Sampler]: - if self.train_dataset is None: - return None - - total_batch_size_per_acc_step = self.args.per_device_train_batch_size * self.args.dataset_world_size - total_batch_size = total_batch_size_per_acc_step - - # In llm/llama/run_pretrain.py, it uses paddlenlp.utils.batch_sampler.DistributedBatchSampler, - # which does no shuffle when shuffle is set True. 
- sampler = paddle.io.BatchSampler( - dataset=self.train_dataset, - shuffle=False, - batch_size=total_batch_size, - drop_last=self.args.dataloader_drop_last, - ) - sampler._acc_steps = self.args.gradient_accumulation_steps - return sampler - - -def print_config(args, key=""): - """ - print config values - """ - logger.info("=" * 60) - if args is None: - args = args - key = "Training" - import paddlenlp - - logger.info("{:^40}".format("{} Configuration Arguments".format(key))) - logger.info("{:30}: {}".format("paddle commit id", paddle.version.commit)) - logger.info("{:30}: {}".format("paddlenlp commit id", paddlenlp.version.commit)) - - for a in dir(args): - if a[:2] != "__": # don't print double underscore methods - v = getattr(args, a) - if not isinstance(v, types.MethodType): - logger.info("{:30}: {}".format(a, v)) - - logger.info("") - - -def init_seed(seed: int = 1234, args=None): - if args is None: - random.seed(seed) - np.random.seed(seed) - paddle.seed(seed) - else: - assert not args.use_hybrid_parallel and args.enable_auto_parallel - if dist.get_world_size() > 1: - if args.hybrid_parallel_topo_order is None or args.hybrid_parallel_topo_order == "pp_first": - order = ["pp", "dp", "sharding", "mp", "sep"] - elif args.hybrid_parallel_topo_order == "sharding_first": - order = ["dp", "sharding", "pp", "mp", "sep"] - topo = Topology( - dist.get_rank(), - dist.get_world_size(), - dp_degree=args.dataset_world_size, - pp_degree=args.pipeline_parallel_degree, - mp_degree=args.tensor_parallel_degree, - sharding_degree=1, # auto_parallel's sharding is not orthogonal with dp, mp and pp - order=order, - ) - - global_seed, local_seed, random_seed = _get_distributed_seeds(args.seed, topo) - - paddle.seed(local_seed) - random.seed(random_seed) - np.random.seed(random_seed) - - logger.info( - "The global seed is set to {}, local seed is set to {} and " - "random seed is set to {}.".format(global_seed, local_seed, random_seed) - ) - else: - random.seed(args.seed) - np.random.seed(args.seed) - paddle.seed(args.seed) - - -def get_mesh(pp_idx=0): - mesh = fleet.auto.get_mesh() - if "pp" in mesh.dim_names: - mesh = mesh.get_mesh_with_dim("pp")[pp_idx] - return mesh - - -def shard_fn(layer, mesh_idx, placements): - paran_name = layer.weight.name - layer.weight = dist.shard_tensor(layer.weight, get_mesh(mesh_idx), placements) - layer.weight.name = paran_name - - -def main(): - parser = PdArgumentParser((ModelArguments, DataArguments, PreTrainingArguments)) - if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) - else: - model_args, data_args, training_args = parser.parse_args_into_dataclasses() - - if training_args.enable_linear_fused_grad_add: - from fused_layers import mock_layers - - mock_layers() - - if model_args.tokenizer_name_or_path is None: - model_args.tokenizer_name_or_path = model_args.model_name_or_path - - if data_args.data_cache is not None: - os.makedirs(data_args.data_cache, exist_ok=True) - - init_seed(args=training_args) - paddle.set_device(training_args.device) - if paddle.distributed.get_world_size() > 1: - paddle.distributed.init_parallel_env() - - training_args.eval_iters = 10 - training_args.test_iters = training_args.eval_iters * 10 - - # Log model and data config - training_args.print_config(model_args, "Model") - training_args.print_config(data_args, "Data") - - # Log on each process the small summary: - logger.warning( - f"Process rank: {training_args.local_rank}, device: 
{training_args.device}, world_size: {training_args.world_size}, " - + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}" - ) - - # Detecting last checkpoint. - last_checkpoint = None - if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: - last_checkpoint = get_last_checkpoint(training_args.output_dir) - if last_checkpoint is not None and training_args.resume_from_checkpoint is None: - logger.info( - f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " - "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." - ) - - config_class, model_class, criterion_class = MODEL_CLASSES[model_args.model_type] - - tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path) - - config = config_class.from_pretrained(model_args.model_name_or_path) - - config.seq_length = data_args.max_seq_length - # There are some technique extend RotaryEmbedding context. so don't change max_position_embeddings - if not model_args.continue_training: - config.max_position_embeddings = max(config.max_position_embeddings, data_args.max_seq_length) - - if not model_args.continue_training: - config.vocab_size = max(config.vocab_size, ((tokenizer.vocab_size - 1) // 128 + 1) * 128) - logger.info(f"Reset vocab size to {config.vocab_size} for batter amp peformance.") - - if model_args.no_recompute_layers is not None: - model_args.no_recompute_layers.sort() - - config.vocab_size = model_args.vocab_size if model_args.vocab_size is not None else config.vocab_size - config.hidden_size = model_args.hidden_size if model_args.hidden_size is not None else config.hidden_size - config.intermediate_size = ( - model_args.intermediate_size if model_args.intermediate_size is not None else config.intermediate_size - ) - config.num_hidden_layers = ( - model_args.num_hidden_layers if model_args.num_hidden_layers is not None else config.num_hidden_layers - ) - config.num_attention_heads = ( - model_args.num_attention_heads if model_args.num_attention_heads is not None else config.num_attention_heads - ) - - config.use_flash_attention = model_args.use_flash_attention - config.use_fused_rms_norm = model_args.use_fused_rms_norm - config.fuse_attention_qkv = model_args.fuse_attention_qkv - config.fuse_attention_ffn = model_args.fuse_attention_ffn - config.recompute_granularity = model_args.recompute_granularity - config.virtual_pp_degree = model_args.virtual_pp_degree - config.sequence_parallel = training_args.sequence_parallel - - config.fuse_sequence_parallel_allreduce = training_args.fuse_sequence_parallel_allreduce - - config.use_fused_rope = model_args.use_fused_rope - config.no_recompute_layers = model_args.no_recompute_layers - config.pp_recompute_interval = model_args.pp_recompute_interval - config.recompute_use_reentrant = model_args.recompute_use_reentrant - - config.use_recompute = training_args.recompute - config.tensor_parallel_degree = training_args.tensor_parallel_degree - config.tensor_parallel_rank = training_args.tensor_parallel_rank - config.sharding_parallel_degree = training_args.sharding_parallel_degree - - if training_args.strategy.pipeline.enable and config.virtual_pp_degree > 1: - pipeline = training_args.strategy.pipeline - pipeline.vpp_degree = config.virtual_pp_degree - pipeline.vpp_seg_method = training_args.virtual_pipeline_seg_method - - print("Final pre-training config:", config) - - # # Set the dtype for 
loading model - # dtype = "float32" - # if training_args.fp16_opt_level == "O2": - # if training_args.fp16: - # dtype = "float16" - # if training_args.bf16: - # dtype = "bfloat16" - - with paddle.LazyGuard(): - model = model_class.from_config(config, dtype="float32") - criterion = criterion_class(config) - - if training_args.recompute: - - def fn(layer): - if hasattr(layer, "enable_recompute") and (layer.enable_recompute is False or layer.enable_recompute == 0): - layer.enable_recompute = True - - model.apply(fn) - - # Create the learning_rate sheduler and optimizer - if training_args.decay_steps is None: - training_args.decay_steps = training_args.max_steps - - if training_args.warmup_steps > 0: - warmup_steps = training_args.warmup_steps - else: - warmup_steps = training_args.warmup_ratio * training_args.max_steps - - lr_scheduler = None - if training_args.lr_scheduler_type.value == "cosine": - lr_scheduler = CosineAnnealingWithWarmupDecay( - max_lr=training_args.learning_rate, - min_lr=training_args.min_learning_rate, - warmup_step=warmup_steps, - decay_step=training_args.decay_steps, - last_epoch=0, - ) - elif training_args.lr_scheduler_type.value == "linear": - lr_scheduler = LinearAnnealingWithWarmupDecay( - max_lr=training_args.learning_rate, - min_lr=training_args.min_learning_rate, - warmup_step=warmup_steps, - decay_step=training_args.decay_steps, - last_epoch=0, - ) - - data_file = get_train_data_file(data_args) - train_dataset, eval_dataset, test_dataset, data_collator = create_pretrained_dataset( - data_args, - training_args, - data_file, - tokenizer, - need_data=training_args.should_load_dataset, - ) - trainer = PretrainingTrainer( - model=model, - criterion=criterion, - args=training_args, - data_collator=data_collator, - train_dataset=train_dataset if training_args.do_train else None, - eval_dataset=eval_dataset if training_args.do_eval else None, - optimizers=(None, lr_scheduler), - tokenizer=tokenizer, - ) - - checkpoint = None - if training_args.resume_from_checkpoint is not None: - checkpoint = training_args.resume_from_checkpoint - elif last_checkpoint is not None: - checkpoint = last_checkpoint - - # Training - if training_args.do_train: - train_result = trainer.train(resume_from_checkpoint=checkpoint) - - # NOTE(gongenlei): new add - if not training_args.autotuner_benchmark: - metrics = train_result.metrics - if not int(os.getenv("test_ci_no_save_model", 0)): - trainer.save_model() - trainer.log_metrics("train", metrics) - trainer.save_metrics("train", metrics) - trainer.save_state() - - if training_args.do_predict: - test_ret = trainer.predict(test_dataset) - trainer.log_metrics("test", test_ret.metrics) - - # if training_args.should_load_dataset: - # effective_tokens_per_second = total_effective_tokens / train_result.metrics["train_runtime"] - # print(f"Effective Tokens per second: {effective_tokens_per_second:.2f}") - # print(f"ips: {effective_tokens_per_second:.2f} tokens/s") - - -def shard_model(model): - pp_stage = 0 - for name, layer in model.named_sublayers(include_self=False): - if hasattr(layer, "ipp"): - pp_stage = layer.ipp - # print(f"name {name},pp_stage {pp_stage}==>", type(layer)) - if "embed_tokens" in name: - # embedding only support column split now. 
it will update in the future - shard_fn(layer, 0, [dist.Replicate(), dist.Shard(1)]) - for n in [ - "self_attn.q_proj", - "self_attn.k_proj", - "self_attn.v_proj", - "self_attn.qkv_proj", - "gate_proj", - "up_proj", - "gate_up_fused_proj", - ]: - if n in name: - shard_fn(layer, pp_stage, [dist.Replicate(), dist.Shard(1)]) - break - for n in ["self_attn.o_proj", "down_proj"]: - if n in name: - shard_fn(layer, pp_stage, [dist.Replicate(), dist.Shard(0)]) - break - if "lm_head" in name: - shard_fn(layer, -1, [dist.Replicate(), dist.Shard(1)]) - - -def load_model(model): - model_state_dict = model.state_dict() - state_dict = paddle.load("hand/all.pdparams") - tmp = OrderedDict() - (tmp, state_dict) = (state_dict, tmp) - for (k, v) in tmp.items(): - k = map_structure_name(k) - state_dict[k] = v - model.set_state_dict(state_dict) - assert len(model_state_dict) == len(state_dict), f"{len(model_state_dict)} vs {len(state_dict)}" - """ - print("=======model_state_dict=======") - for (k,v) in model_state_dict.items(): - print(f"{k}=>{v.shape}") - """ - print("=======state_dict=======") - for (k, v) in state_dict.items(): - assert k in model_state_dict - print(f"{k}=>{v.shape}") - - -def print_grad(model): - model_state_dict = model.state_dict() - name_mapping = {v.name: k for (k, v) in model_state_dict.items()} - for p in model.parameters(): - assert p.name in name_mapping - if p.grad is not None: - print(f"{name_mapping[p.name]} {p.name}_grad shape: {p.grad.shape} md5sum: {p.grad._md5sum()}") - - -def print_param(model): - model_state_dict = model.state_dict() - name_mapping = {v.name: k for (k, v) in model_state_dict.items()} - for p in model.parameters(): - assert p.name in name_mapping - if p.grad is not None: - print(f"{name_mapping[p.name]} {p.name} shape: {p.shape} md5sum: {p._md5sum()}") - - -def map_structure_name(k): - fs = k.split(".") - idx = int(fs[1]) - if idx == 0: - return "deepseek_v2.embed_tokens.weight" - if idx == 28: - return "deepseek_v2.norm.weight" - if idx == 29: - return "lm_head.weight" - else: - return f"deepseek_v2.layers.{idx-1}." + ".".join(fs[2:]) - - -if __name__ == "__main__": - main() diff --git a/llm/auto_parallel/deepseek-v2/run_pretrain_auto.sh b/llm/auto_parallel/deepseek-v2/run_pretrain_auto.sh deleted file mode 100644 index 1ed3bb01edaf..000000000000 --- a/llm/auto_parallel/deepseek-v2/run_pretrain_auto.sh +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -#!/bin/bash -set -x -unset CUDA_VISIBLE_DEVICES - -task_name="deepseekv2" -rm -rf output/$task_name/ -rm -rf "output/$task_name""_log" -rm -rf /root/paddlejob/workspace/env_run/xuxinyi/PaddleNLP/llm/auto_parallel/deepseek-v2/log - -export SOT_LOG_LEVEL=4 -export PYTHONPATH=/root/paddlejob/workspace/env_run/xuxinyi/PaddleNLP:$PYTHONPATH -#ulimit -c unlimited -# export GLOG_v=3 - -# export FLAGS_call_stack_level=3 -# export FLAGS_use_cuda_managed_memory=true - -# export FLAGS_embedding_deterministic=1 -# export FLAGS_cudnn_deterministic=1 -# export NVIDIA_TF32_OVERRIDE=0 - -to_static=0 # 是否开启动转静训练 - -python -u -m paddle.distributed.launch \ - --gpus "0,1,2,3" \ - --log_dir "log" \ - run_pretrain_auto.py \ - --model_type "deepseekv2_auto" \ - --model_name_or_path "deepseek-ai/DeepSeek-V2-Lite" \ - --tokenizer_name_or_path "deepseek-ai/DeepSeek-V2-Lite" \ - --input_dir "./data" \ - --output_dir "output/$task_name" \ - --split 949,50,1 \ - --max_seq_length 2048 \ - --per_device_train_batch_size 1 \ - --per_device_eval_batch_size 2 \ - --gradient_accumulation_steps 2 \ - --use_flash_attention 0 \ - --use_fused_rms_norm 1 \ - --fp16 0 \ - --fp16_opt_level "O2" \ - --scale_loss 1024 \ - --pipeline_parallel_degree 1 \ - --tensor_parallel_degree 2 \ - --sharding_parallel_degree 1 \ - --learning_rate 0.0001 \ - --min_learning_rate 0.00001 \ - --max_steps 2 \ - --save_steps 5000000 \ - --weight_decay 0.01 \ - --warmup_ratio 0.01 \ - --logging_steps 1\ - --dataloader_num_workers 1 \ - --sharding "" \ - --eval_steps 1000000 \ - --disable_tqdm true \ - --continue_training 0\ - --recompute 0 \ - --do_train \ - --do_eval \ - --device "gpu" \ - --data_impl "mmap" \ - --enable_auto_parallel 1 \ - --max_grad_norm 1.0 \ - --num_hidden_layers 1 \ - --use_intermediate_api true \ - --to_static $to_static \ diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py index ee58e1b638d9..c4d635fe33a7 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling.py +++ b/paddlenlp/transformers/deepseek_v2/modeling.py @@ -675,6 +675,8 @@ def __init__(self, config, num_experts, expert_hidden_size, **kwargs): is_bias=False, default_initializer=nn.initializer.Constant(1.0), ) + print("==== weight after init ====") + print(self.weight) if config.topk_method == "noaux_tc": self.e_score_correction_bias = paddle.create_parameter( @@ -691,6 +693,10 @@ def forward(self, hidden_states): _, h_dim = hidden_states.shape # compute gating score + print("==== weight ====") + print(self.weight) + print("==== hidden_states ====") + print(hidden_states) logits = F.linear(hidden_states, self.weight, None) with paddle.amp.auto_cast(False): diff --git a/paddlenlp/transformers/deepseek_v2/modeling_auto.py b/paddlenlp/transformers/deepseek_v2/modeling_auto.py index e21daddf2666..a6ca38e7af12 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling_auto.py +++ b/paddlenlp/transformers/deepseek_v2/modeling_auto.py @@ -49,17 +49,16 @@ from ..llama.modeling import get_use_casual_mask from ..model_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ..model_utils import PretrainedModel, register_base_model -from ..moe_layer import MoELayer +from ..moe_layer_auto import MoELayer +from ..moe_gate_auto import PretrainedMoEGate from .configuration import DeepseekV2Config from .modeling import ( - AddAuxiliaryLoss, DeepseekV2DynamicNTKScalingRotaryEmbedding, DeepseekV2LinearScalingRotaryEmbedding, DeepseekV2PretrainingCriterion, DeepseekV2RMSNorm, DeepseekV2RotaryEmbedding, 
DeepseekV2YarnRotaryEmbedding, - MoEGate, _expand_2d_mask, _make_causal_mask, apply_rotary_pos_emb, @@ -169,6 +168,68 @@ def scaled_dot_product_attention( return (attn_output, attn_weights) if output_attentions else attn_output +class MoEGate(PretrainedMoEGate): + def __init__(self, config, num_experts, expert_hidden_size, **kwargs): + super().__init__(config, num_experts, expert_hidden_size, **kwargs) + # [hidden_size, n_expert] + + self.scoring_func = config.scoring_func + self.topk_method = config.topk_method + + self.weight = paddle.create_parameter( + shape=[expert_hidden_size, num_experts], + dtype=paddle.get_default_dtype(), + is_bias=False, + default_initializer=nn.initializer.Constant(1.0), + ) + + if config.topk_method == "noaux_tc": + self.e_score_correction_bias = paddle.create_parameter( + shape=[num_experts], + dtype=paddle.get_default_dtype(), + default_initializer=nn.initializer.Constant(0.0), + ) + + def forward(self, hidden_states): + """ + Args: + hidden_states (_type_): [batch_size * seq_len, hidden_size] + """ + _, h_dim = hidden_states.shape + + # compute gating score + logits = F.linear(hidden_states, self.weight, None) + + with paddle.amp.auto_cast(False): + scores = self.gate_score_func(logits=logits) + scores = scores.cast(paddle.get_default_dtype()) + + capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.topkgating(scores) + + return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss + + +class AddAuxiliaryLoss(paddle.autograd.PyLayer): + """ + The trick function of adding auxiliary (aux) loss, + which includes the gradient of the aux loss during backpropagation. + """ + + @staticmethod + def forward(ctx, x, loss): + assert paddle.numel(loss) == 1 + ctx.dtype = loss.dtype + ctx.required_aux_loss = not loss.stop_gradient + return x + + @staticmethod + def backward(ctx, grad_output): + grad_loss = None + if ctx.required_aux_loss: + grad_loss = paddle.ones(1, dtype=ctx.dtype) + return grad_output, grad_loss + + class DeepseekV2MLPAuto(nn.Layer): def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size=None, is_moe=False): super().__init__() @@ -972,18 +1033,19 @@ def _reorder_cache(past_key_values, beam_idx): def auto_dist_config(self, prefix=""): if prefix != "": assert prefix.endswith(".") - config = { - "mp_config": { - "parallelize_plan": { - f"{prefix}deepseek_v2.embed_tokens": dist.ColWiseParallel(gather_output=True), - f"{prefix}deepseek_v2.layers.*.self_attn.q_proj": dist.ColWiseParallel(), - f"{prefix}deepseek_v2.layers.*.self_attn.kv_b_proj": dist.ColWiseParallel(), - f"{prefix}deepseek_v2.layers.*.self_attn.o_proj": dist.RowWiseParallel(), - f"{prefix}deepseek_v2.layers.*.mlp.gate_proj": dist.ColWiseParallel(), - f"{prefix}deepseek_v2.layers.*.mlp.up_proj": dist.ColWiseParallel(), - f"{prefix}deepseek_v2.layers.*.mlp.down_proj": dist.RowWiseParallel(), - f"{prefix}lm_head.weight": dist.ColWiseParallel(), - } - }, - } + config = {} + # config = { + # "mp_config": { + # "parallelize_plan": { + # f"{prefix}deepseek_v2.embed_tokens": dist.ColWiseParallel(gather_output=True), + # f"{prefix}deepseek_v2.layers.*.self_attn.q_proj": dist.ColWiseParallel(), + # f"{prefix}deepseek_v2.layers.*.self_attn.kv_b_proj": dist.ColWiseParallel(), + # f"{prefix}deepseek_v2.layers.*.self_attn.o_proj": dist.RowWiseParallel(), + # f"{prefix}deepseek_v2.layers.*.mlp.gate_proj": dist.ColWiseParallel(), + # f"{prefix}deepseek_v2.layers.*.mlp.up_proj": dist.ColWiseParallel(), + # 
f"{prefix}deepseek_v2.layers.*.mlp.down_proj": dist.RowWiseParallel(), + # f"{prefix}lm_head.weight": dist.ColWiseParallel(), + # } + # }, + # } return config diff --git a/paddlenlp/transformers/moe_gate_auto.py b/paddlenlp/transformers/moe_gate_auto.py new file mode 100644 index 000000000000..b7cb622bbf40 --- /dev/null +++ b/paddlenlp/transformers/moe_gate_auto.py @@ -0,0 +1,549 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Tuple + +import traceback +import paddle +import paddle.distributed as dist +import paddle.nn as nn +import paddle.nn.functional as F + +from ..utils.log import logger + + +class MoEGateMixin: + def gate_score_func(self, logits: paddle.Tensor) -> paddle.Tensor: + # [..., hidden_dim] -> [..., num_experts] + with paddle.amp.auto_cast(False): + scoring_func = getattr(self, "scoring_func", None) + if scoring_func == "softmax": + scores = F.softmax(logits.cast("float32"), axis=-1) + elif scoring_func == "sigmoid": + scores = F.sigmoid(logits.cast("float32")) + elif scoring_func == "tanh": + scores = F.tanh(logits.cast("float32")) + elif scoring_func == "relu": + scores = F.relu(logits.cast("float32")) + elif scoring_func == "gelu": + scores = F.gelu(logits.cast("float32")) + elif scoring_func == "leaky_relu": + scores = F.leaky_relu(logits.cast("float32")) + else: + logger.warning_once( + f"insupportable scoring function for MoE gating: {scoring_func}, use softmax instead" + ) + scores = F.softmax(logits.cast("float32"), axis=-1) + return scores + + def gumbel_rsample(self, logits: paddle.Tensor) -> paddle.Tensor: + gumbel = paddle.distribution.gumbel.Gumbel(0, 1) + return gumbel.rsample(logits.shape) + + def uniform_sample(self, logits: paddle.Tensor) -> paddle.Tensor: + uniform = paddle.distribution.uniform.Uniform(0, 1) + return uniform.sample(logits.shape) + + @paddle.no_grad() + def _one_hot_to_float(self, x, num_classes): + if x.dtype not in (paddle.int32, paddle.int64): + x = paddle.cast(x, paddle.int64) + return F.one_hot(x, num_classes=num_classes).cast(paddle.get_default_dtype()) + + @paddle.no_grad() + def _one_hot_to_int64(self, x, num_classes): + if x.dtype not in (paddle.int32, paddle.int64): + x = paddle.cast(x, paddle.int64) + return F.one_hot(x, num_classes=num_classes).cast(paddle.int64) + + @paddle.no_grad() + def _capacity( + self, + gates: paddle.Tensor, + capacity_factor: float, + max_capacity: int, + min_capacity: int, + ) -> paddle.Tensor: + """Calculate the capacity for each expert based on the gates and capacity factor. + + Args: + gates (paddle.Tensor): A tensor of shape [num_tokens, num_experts] representing the probability distribution + over experts for each token. + capacity_factor (float): A scalar float value representing the capacity factor for each expert. + min_capacity (int): A scalar integer value representing the minimum capacity for each expert. 
+ + Returns: + int: A tensor value representing the calculated capacity for each expert. + """ + assert gates.ndim == 2, f"gates should be 2D, but got {gates.ndim}, {gates.shape}" + # gates has shape of SE + num_tokens = gates.shape[0] + num_experts = gates.shape[1] + print(f"==== num_tokens:{num_tokens}, num_experts:{num_experts} ====") + capacity = int((num_tokens // num_experts) * capacity_factor) + if capacity < min_capacity: + capacity = min_capacity + if capacity > max_capacity: + capacity = max_capacity + assert capacity > 0, f"requires capacity > 0, capacity_factor: {capacity_factor}, input_shape: {gates.shape}" + + return capacity + + def _cal_aux_loss(self, gates, mask): + """ + Calculate auxiliary loss + + Args: + gates (paddle.Tensor): Represents the output probability of each expert. The shape is [batch_size, num_experts] + mask (paddle.Tensor): Represents whether each sample belongs to a certain expert. The shape is [batch_size, num_experts] + + Returns: + paddle.Tensor: The value of auxiliary loss. + + """ + # TODO: @DrownFish19 update aux_loss for Qwen2MoE and DeepSeekV2&V3 + me = paddle.mean(gates, axis=0) + ce = paddle.mean(mask.cast("float32"), axis=0) + if self.global_aux_loss: + me_list, ce_list = [], [] + # dist.all_gather(me_list, me, group=self.group) + # dist.all_gather(ce_list, ce, group=self.group) + + me_list[self.rank] = me + ce_list[self.rank] = ce + me = paddle.stack(me_list).mean(0) + ce = paddle.stack(ce_list).mean(0) + aux_loss = paddle.sum(me * ce) * float(self.num_experts) + return aux_loss + + def _cal_z_loss(self, logits) -> paddle.Tensor: + """ + Calculate the z loss. + + Args: + logits (paddle.Tensor): Model output. The shape is [batch_size, num_experts]. + + Returns: + paddle.Tensor: The z loss value. + """ + l_zloss = logits.exp().sum(1).log().square().mean() + return l_zloss + + def _cal_orthogonal_loss(self) -> paddle.Tensor: + """Gate weight orthogonal loss. 
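+        Computes the mean squared difference between W^T W and the identity matrix, where the columns of the
+        gate weight W (one per expert) are L2-normalized, encouraging the per-expert gate vectors to stay
+        mutually orthogonal.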
+ + Returns: + Paddle.Tensor: orthogonal loss + """ + weight = F.normalize(self.weight, axis=0) + orthogonal_loss = paddle.mean(paddle.square(paddle.matmul(weight.T, weight) - paddle.eye(self.num_experts))) + return orthogonal_loss + + +class PretrainedMoEGate(nn.Layer, MoEGateMixin): + def __init__(self, config, num_experts, expert_hidden_size, **kwargs): + super(PretrainedMoEGate, self).__init__() + + self.config = config + + self.num_experts = num_experts + self.expert_hidden_size = expert_hidden_size + + # force keep in float32 when using amp + self._cast_to_low_precision = False + + self.capacity_factor = kwargs.pop("capacity_factor", 1.0) + self.eval_capacity_factor = kwargs.pop("eval_capacity_factor", 1.0) + self.min_capacity = kwargs.pop("min_capacity", 1.0) + self.max_capacity = kwargs.pop("max_capacity", pow(2, 32)) + + self.group = kwargs.pop("group", None) + self.global_aux_loss = kwargs.pop("global_aux_loss", False) + if self.global_aux_loss: + assert self.group is not None, "group is required when global_aux_loss is True" + self.rank = dist.get_rank(self.group) + + self.expert_drop = kwargs.pop("expert_drop", False) + self.noisy_gate_policy = kwargs.pop("noisy_gate_policy", None) + self.drop_tokens = kwargs.pop("drop_tokens", True) + self.use_rts = kwargs.pop("use_rts", True) + self.top2_2nd_expert_sampling = kwargs.pop("top2_2nd_expert_sampling", True) + + self.drop_policy = kwargs.pop("drop_policy", "probs") + # Qwen2MoE: greedy + # DeepSeekV2&V3: group_limited_greedy for training, and noaux_tc for inference + self.topk_method = kwargs.pop("topk_method", "greedy") + self.top_k = kwargs.pop("top_k", 2) + self.n_group = kwargs.pop("n_group", 1) # for group_limited_greedy + self.topk_group = kwargs.pop("topk_group", 1) # for group_limited_greedy + self.norm_topk_prob = kwargs.pop("norm_topk_prob", False) + self.routed_scaling_factor = kwargs.pop("routed_scaling_factor", 1.0) + + def _priority(self, topk_idx: paddle.Tensor, capacity: int) -> paddle.Tensor: + """_summary_ + The priority is the cumulative sum of the expert indices. + + This method is used in hunyuan model + Args: + topk_idx (paddle.Tensor): [batch_size * seq_len, topk] + + Returns: + paddle.Tensor: cumsum locations + """ + # Make num_selected_experts the leading axis to ensure that top-1 choices + # have priority over top-2 choices, which have priority over top-3 choices, + # etc. + expert_index = paddle.transpose(topk_idx, [1, 0]) # [topk, B*S] + # Shape: [num_selected_experts * tokens_per_group] + expert_index = expert_index.reshape([-1]) + + # Create mask out of indices. + # Shape: [tokens_per_group * num_selected_experts, num_experts]. + expert_mask = F.one_hot(expert_index, self.num_experts).cast(paddle.int32) + + # Experts have a fixed capacity that we cannot exceed. A token's priority + # within the expert's buffer is given by the masked, cumulative capacity of + # its target expert. + # Shape: [tokens_per_group * num_selected_experts, num_experts]. + token_priority = paddle.cumsum(expert_mask, axis=0) * expert_mask - 1 + # Shape: [num_selected_experts, tokens_per_group, num_experts]. + token_priority = token_priority.reshape((self.top_k, -1, self.num_experts)) + # Shape: [tokens_per_group, num_selected_experts, num_experts]. + token_priority = paddle.transpose(token_priority, [1, 0, 2]) + # For each token, across all selected experts, select the only non-negative + # (unmasked) priority. Now, for group G routing to expert E, token T has + # non-negative priority (i.e. 
token_priority[G,T,E] >= 0) if and only if E + # is its targeted expert. + # Shape: [tokens_per_group, num_experts]. + token_priority = paddle.max(token_priority, axis=1) + + # Token T can only be routed to expert E if its priority is positive and + # less than the expert capacity. One-hot matrix will ignore indices outside + # the range [0, expert_capacity). + # Shape: [tokens_per_group, num_experts, expert_capacity]. + valid_mask = paddle.logical_and(token_priority >= 0, token_priority < capacity) + token_priority = paddle.masked_fill(token_priority, ~valid_mask, 0) + dispatch_mask = F.one_hot(token_priority, capacity).cast(paddle.int32) + valid_mask = valid_mask.unsqueeze(-1).expand(valid_mask.shape + [capacity]) + dispatch_mask = paddle.masked_fill(dispatch_mask, ~valid_mask, 0) + + return dispatch_mask + + def _topk_greedy(self, scores: paddle.Tensor, k: int) -> Tuple[paddle.Tensor, paddle.Tensor]: + """_summary_ + + Args: + scores (paddle.Tensor): [bsz*seq_len, n_experts] + k (int): select the top k experts + + Returns: + Tuple[paddle.Tensor, paddle.Tensor]: topk_weight, topk_idx + topk_weight: [bsz*seq_len, k] + topk_idx: [bsz*seq_len, k] + """ + topk_weight, topk_idx = paddle.topk(scores, k=k, axis=-1, sorted=True) + return topk_weight, topk_idx + + def _topk_group_limited_greedy( + self, scores: paddle.Tensor, k: int, n_group: int, topk_group: int + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """_summary_ + + Args: + scores (paddle.Tensor): [bsz*seq_len, n_experts] + k (int): select the top k experts in each group + n_groups (int): the number of groups for all experts + topk_group (int): the number of groups selected + + Returns: + Tuple[paddle.Tensor, paddle.Tensor]: topk_weight, topk_idx + topk_weight: [bsz*seq_len, k] + topk_idx: [bsz*seq_len, k] + + Note: the group size is normal greater than the number of k + """ + bsz_seq_len, n_experts = scores.shape + assert n_experts % n_group == 0, "n_experts must be divisible by n_groups" + + group_scores = scores.reshape([0, n_group, -1]).max(axis=-1) # [n, n_group] + group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=False)[1] # [n, top_k_group] + group_mask = paddle.zeros_like(group_scores).put_along_axis(group_idx, paddle.to_tensor(1.0), axis=-1) # fmt:skip + score_mask = ( + group_mask.unsqueeze(-1).expand([bsz_seq_len, n_group, n_experts // n_group]).reshape([bsz_seq_len, -1]) + ) # [n, e] + tmp_scores = scores * score_mask # [n, e] + topk_weight, topk_idx = paddle.topk(tmp_scores, k=k, axis=-1, sorted=False) + + return topk_weight, topk_idx + + def _topk_noaux_tc( + self, scores: paddle.Tensor, k: int, n_group: int, topk_group: int + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """_summary_ + + Args: + scores (paddle.Tensor): [bsz*seq_len, n_experts] + k (int): select the top k experts in each group + n_groups (int): the number of groups for all experts + topk_group (int): the number of groups selected + + Returns: + Tuple[paddle.Tensor, paddle.Tensor]: topk_weight, topk_idx + topk_weight: [bsz*seq_len, k] + topk_idx: [bsz*seq_len, k] + + Note: the group size is normal greater than the number of k + """ + bsz_seq_len, n_experts = scores.shape + assert n_experts % n_group == 0, "n_experts must be divisible by n_groups" + + assert self.e_score_correction_bias is not None, "e_score_correction_bias is None" + scores = scores.reshape([bsz_seq_len, -1]) + self.e_score_correction_bias.unsqueeze(0) + group_scores = scores.reshape([bsz_seq_len, self.n_group, -1]).topk(2, axis=-1)[0].sum(axis=-1) # [n, n_group] + group_idx 
= paddle.topk(group_scores, k=topk_group, axis=-1, sorted=False)[1] # [n, top_k_group] + group_mask = paddle.zeros_like(group_scores).put_along_axis(group_idx, paddle.to_tensor(1.0), axis=-1) # fmt:skip + score_mask = ( + group_mask.unsqueeze(-1).expand([bsz_seq_len, n_group, n_experts // n_group]).reshape([bsz_seq_len, -1]) + ) # [n, e] + tmp_scores = scores * score_mask # [n, e] + topk_weight, topk_idx = paddle.topk(tmp_scores, k=k, axis=-1, sorted=False) + topk_weight = scores.gather(topk_idx, axis=1) if not self.training else topk_weight + + return topk_weight, topk_idx + + def top1gating( + self, + logits: paddle.Tensor, + used_token: paddle.Tensor = None, + ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Implements Top1Gating on logits.""" + if self.noisy_gate_policy == "RSample": + logits += self.gumbel_rsample(logits.shape) + + print("==== top1 ====") + traceback.print_stack() + gates = self.gate_score_func(logits=logits) + capacity = self._capacity(gates, self.capacity_factor, self.max_capacity, self.min_capacity) + + # Create a mask for 1st's expert per token + # noisy gating + # Only save the position of the maximum value + indices1_s = paddle.argmax(logits if self.noisy_gate_policy == "RSample" else gates, axis=1) + # Convert the position of the maximum value to a one-hot vector [s, e] + mask1 = self._one_hot_to_float(indices1_s, num_classes=self.num_experts) + + # mask only used tokens + if used_token is not None: + mask1 = paddle.einsum( + "s,se->se", used_token, mask1 + ) # Element-wise multiply used_token with mask1 to obtain a new mask1 + + # gating decisions + exp_counts = paddle.sum(mask1, axis=0) # Calculate the number of tokens for each expert + + # if we don't want to drop any tokens + if not self.drop_tokens: + new_capacity = paddle.max(exp_counts) # Calculate the number of tokens for each expert + # Communicate across expert processes to pick the maximum capacity. + if self.group is not None: + dist.all_reduce( + new_capacity, op=dist.ReduceOp.MAX, group=self.group + ) # Calculate the maximum value among expert processes + # Make sure the capacity value does not exceed the number of tokens. + capacity = int(min(new_capacity, paddle.tensor(mask1.size(0)))) + + l_aux = self._cal_aux_loss(gates, mask1) + l_zloss = self._cal_z_loss(logits) + + # Random Token Selection + if self.use_rts: + mask1_rand = mask1 * self.uniform_sample(mask1) + else: + mask1_rand = mask1 + + assert ( + logits.shape[0] >= self.min_capacity + ), "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size." 
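+        # Keep at most `capacity` tokens per expert: rank tokens by the (optionally RTS-randomized) mask
+        # scores along the token axis and zero out the remaining entries of mask1.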
+ + _, top_idx = paddle.topk(mask1_rand, k=capacity, axis=0) # Select top_capacity tokens + + new_mask1 = mask1 * paddle.zeros_like(mask1).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=0) + mask1 = new_mask1 + + # Compute locations in capacity buffer + locations1 = paddle.cumsum(mask1, axis=0) - 1 # Compute the position of each token in mask1 + + # Store the capacity location for each token + locations1_s = paddle.sum(locations1 * mask1, axis=1).cast(paddle.int64) + + # Normalize gate probabilities + mask1_float = mask1.cast(paddle.float32) + gates = gates / gates * mask1_float + + locations1_sc = self._one_hot_to_float(locations1_s, capacity) + combine_weights = paddle.einsum("se,sc->sec", gates, locations1_sc) + dispatch_mask = combine_weights.cast(paddle.bool).detach() + + return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss + + def top2gating( + self, + logits: paddle.Tensor, + ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + # everything is in fp32 in this function + print("==== top2 ====") + traceback.print_stack() + gates = self.gate_score_func(logits=logits) + + # Create a mask for 1st's expert per token. + indices1_s = paddle.argmax(gates, axis=1) # [S, 1] + mask1 = self._one_hot_to_int64(indices1_s, self.num_experts) # [S, E] + + if self.top2_2nd_expert_sampling: + # Create a mask for 2nd's expert per token using Gumbel-max trick. + # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/ + logits += self.gumbel_rsample(logits) + + # Replace top-expert with min value + logits_except1 = logits.masked_fill(mask1.cast(paddle.bool), float("-inf")) # [S, E] + indices2_s = paddle.argmax(logits_except1, axis=1) # [S, 1] + mask2 = self._one_hot_to_int64(indices2_s, self.num_experts) # [S, E] + + # Note: mask1 and mask2 can be combined to form a single mask. + # mask = paddle.concat([mask1, mask2], axis=0) + # locations = paddle.cumsum(mask, axis=0) - 1 + # locations1, locations2 = locations.split(2, axis=0) + # Compute locations in capacity buffer. + locations1 = paddle.cumsum(mask1, axis=0) - 1 # [S, E] + locations2 = paddle.cumsum(mask2, axis=0) - 1 # [S, E] + # Update 2nd's location by accounting for locations of 1st. + locations2 += paddle.sum(mask1, axis=0, keepdim=True) + + l_aux = self._cal_aux_loss(gates, mask1) + l_zloss = self._cal_z_loss(logits) + + # gating decisions + exp_counts = paddle.sum(mask1 + mask2, axis=0) + if self.drop_tokens: + # Calculate configured capacity and remove locations outside capacity from mask + capacity = self._capacity(gates, self.capacity_factor, self.max_capacity, self.min_capacity) + # Remove locations outside capacity from mask. + mask1 *= (locations1 < capacity).cast(paddle.int64) + mask2 *= (locations2 < capacity).cast(paddle.int64) + else: + # Do not drop tokens - set capacity according to current expert assignments + new_capacity = paddle.max(exp_counts) + if self.group is not None: + dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=self.group) + capacity = int(new_capacity) + + # Store the capacity location for each token. 
+ locations1_s = paddle.sum(locations1 * mask1, axis=1) + locations2_s = paddle.sum(locations2 * mask2, axis=1) + + # Normalize gate probabilities + mask1_float = mask1.cast(paddle.float32) + mask2_float = mask2.cast(paddle.float32) + gates1_s = paddle.einsum("se,se->s", gates, mask1_float) + gates2_s = paddle.einsum("se,se->s", gates, mask2_float) + denom_s = gates1_s + gates2_s + # Avoid divide-by-zero + denom_s = paddle.clip(denom_s, min=paddle.finfo(denom_s.dtype).eps) + gates1_s /= denom_s + gates2_s /= denom_s + + # Calculate combine_weights and dispatch_mask + gates1 = paddle.einsum("s,se->se", gates1_s, mask1_float) + gates2 = paddle.einsum("s,se->se", gates2_s, mask2_float) + locations1_sc = self._one_hot_to_float(locations1_s, capacity) + locations2_sc = self._one_hot_to_float(locations2_s, capacity) + combine1_sec = paddle.einsum("se,sc->sec", gates1, locations1_sc) + combine2_sec = paddle.einsum("se,sc->sec", gates2, locations2_sc) + combine_weights = combine1_sec + combine2_sec + dispatch_mask = combine_weights.cast(paddle.bool) + + return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss + + def topkgating( + self, + gates: paddle.Tensor, + ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Implements TopKGating on logits.""" + print("==== topk ====") + traceback.print_stack() + l_zloss = self._cal_z_loss(gates) + + # get topk gates + if self.topk_method == "greedy": + top_gate, top_idx = self._topk_greedy(gates, k=self.top_k) + elif self.topk_method == "group_limited_greedy": + top_gate, top_idx = self._topk_group_limited_greedy( + gates, k=self.top_k, n_group=self.n_group, topk_group=self.topk_group + ) + elif self.topk_method == "noaux_tc": + top_gate, top_idx = self._topk_noaux_tc( + gates, k=self.top_k, n_group=self.n_group, topk_group=self.topk_group + ) + # norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = top_gate.sum(axis=-1, keepdim=True) + 1e-20 + top_gate = top_gate / denominator + else: + top_gate = top_gate * self.routed_scaling_factor + + # get topk mask + mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=1) + l_aux = self._cal_aux_loss(gates, mask) + + exp_counts = paddle.sum(mask.cast(paddle.int64), axis=0) + + if self.drop_tokens: + # Calculate configured capacity and remove locations outside capacity from mask + capacity = self._capacity( + gates, + self.capacity_factor * self.top_k, + self.max_capacity, + self.min_capacity, + ) + + # update mask and locations by capacity + if self.drop_policy == "probs": + topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1) + capacity_probs, capacity_indices = paddle.topk(topk_masked_gates, k=capacity, axis=0, sorted=False) + token_priority = self._priority(capacity_indices, capacity) + + elif self.drop_policy == "position": + token_priority = self._priority(top_idx, capacity) + else: + raise ValueError(f"Invalid drop_policy: {self.drop_policy}") + else: + # Do not drop tokens - set capacity according to current expert assignments + local_capacity = paddle.max(exp_counts) + if self.group is not None: + dist.all_reduce(local_capacity, op=dist.ReduceOp.MAX, group=self.group) + capacity = int(local_capacity) + token_priority = self._priority(top_idx, capacity) + + # normalize gates + gates_masked = gates * mask + gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True) + denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps) + if 
self.norm_topk_prob: + gates_masked = gates_masked / denom_s + + combine_weights = paddle.einsum("se,sec->sec", gates_masked, token_priority.cast(paddle.get_default_dtype())) + dispatch_mask = combine_weights.cast(paddle.bool) + + return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss diff --git a/paddlenlp/transformers/moe_layer_auto.py b/paddlenlp/transformers/moe_layer_auto.py new file mode 100644 index 000000000000..771e38c92706 --- /dev/null +++ b/paddlenlp/transformers/moe_layer_auto.py @@ -0,0 +1,357 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Any, Tuple + +import paddle +import paddle.distributed as dist +import paddle.nn.functional as F +from paddle import Tensor, nn +from paddle.distributed.communication import stream +from paddle.distributed.communication.group import Group + +from .moe_gate_auto import PretrainedMoEGate + + +def dispatching(x, dispatch_mask, scatter_index, num_experts, capacity): + """ + Rearranges the input tensor `x` based on gate results, truncates it according to the specified capacity, and performs padding. + + Args: + x (Tensor)[Seq, Dim]: The input tensor. + dispatch_mask (List[Tensor[Seq, 1], Tensor[Seq, 1]]): A list of dispatch masks. + scatter_index (Union[List[Tensor[Seq,], Tensor[Seq]], Tensor[Seq, 2]]): A list or tensor representing scatter indices. + num_experts (int): The number of experts. + capacity (int): The capacity size. + + Returns: + Tensor [Expert*Capacity, Dim]: The output tensor after dispatching. + """ + output = None + orig_dtype = x.dtype + if isinstance(scatter_index, paddle.Tensor): + scatter_index = scatter_index.unbind(1) + for i_scatter_index, i_dispatch_mask in zip(scatter_index, dispatch_mask): + init_output = paddle.zeros([num_experts * capacity, x.shape[-1]], dtype="float32") + updates = x * i_dispatch_mask.cast(x.dtype) + if output is None: + output = paddle.scatter( + init_output, + i_scatter_index, + updates, + overwrite=False, + ) + else: + output = output + paddle.scatter( + init_output, + i_scatter_index, + updates, + overwrite=False, + ) + if output.dtype != orig_dtype: + output = output.cast(orig_dtype) + return output + + +def combining(x, combine_weights, scatter_index): + """ + Performs combination and aggregation operations on the input matrix. + + Args: + x: Tensor[num_experts * capacity, dim] - The input matrix to be processed, where the last dimension represents the number of features. + combine_weights: Union[List[Tensor[seq, 1], Tensor[seq, 1]], Tensor[seq, 2, 1]] - A list or tensor containing combination weights for each feature. + scatter_index: Union[List[Tensor[seq], Tensor[seq]], Tensor[seq, 2]] - A tuple of indices indicating which elements are to be aggregated, where the first element is the row index and the second element is the column index. 
+ + Returns: + Tensor: The output matrix after combination and aggregation, with a shape of [n, dim * num_features], where n is the number of samples in the input matrix. + """ + + dim = x.shape[-1] + if isinstance(scatter_index, (list, tuple)): + scatter_index = paddle.concat([i.unsqueeze([-1]) for i in scatter_index], -1) + scatter_index = scatter_index.reshape([-1]) + num_k = len(combine_weights) if isinstance(combine_weights, (list, tuple)) else combine_weights.shape[-1] + x = paddle.gather(x, scatter_index).reshape([-1, num_k, dim]) # [seq,2,dim] + if isinstance(combine_weights, (list, tuple)): + combine_weights = paddle.concat(combine_weights, -1).unsqueeze([1]) + return paddle.matmul(combine_weights, x).squeeze(1) # [seq,1,2] @ [seq,2,dim] -> [seq,1,dim] + + +class _AllToAll(paddle.autograd.PyLayer): + @staticmethod + def forward( + ctx: Any, + input: Tensor, + group: Group, + ) -> Tensor: # type: ignore + """ + All-to-all communication in the group. + + Args: + ctx (Any): Context object. + input (Tensor): Input tensor. + group (Group): The group object. + + Returns: + Tensor: Output tensor. + """ + + ctx.group = group + # return input + if dist.get_world_size(group) <= 1: + return input + output = paddle.empty_like(input) + stream.alltoall_single(output, input, None, None, group, True, True) + return output + + @staticmethod + def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor]: + """ + Aggregates gradient information from all input tensors into a single tensor. + + Args: + ctx (Any): The context object used to store information that needs to be passed. + *grad_output (Tensor): A list of input tensors whose gradients are to be aggregated. + + Returns: + Tuple[Tensor]: A tuple containing a tensor that holds the gradients of all input tensors. + + """ + # return grad_output + return _AllToAll.apply(*grad_output, ctx.group) + + +class LocalPart(dist.LocalLayer): + def __init__(self, out_dist_attrs, config, gate: PretrainedMoEGate): + print("==== out_dist_attrs ====") + print(out_dist_attrs) + super().__init__(out_dist_attrs) + self.config = config + self.gate = gate + + def forward(self, hidden_state, gate_weight, used_token=None): + # Implement Algorithm 2 from GShard paper. + batch_size, seq_len, d_model = hidden_state.shape + + # Initial implementation -> Reshape into S tokens by dropping sequence dimension. 
# Reshape into G groups so that each group can distribute tokens equally
+        # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
+        reshaped_input = hidden_state.reshape([-1, d_model])
+        print("==== reshaped_input ====")
+        print(reshaped_input)
+
+        _, h_dim = reshaped_input.shape
+
+        # compute gating score
+        logits = F.linear(reshaped_input, gate_weight, None)
+
+        with paddle.amp.auto_cast(False):
+            scores = self.gate.gate_score_func(logits=logits)
+            scores = scores.cast(paddle.get_default_dtype())
+
+        # capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.topkgating(scores)
+        capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.gate.topkgating(scores)
+        print("==== combine_weights ====")
+        print(combine_weights)
+        print("==== dispatch_mask ====")
+        print(dispatch_mask)
+
+        # self.l_aux :
+        # combine_weights : sec
+        # dispatch_mask : sec
+        # self.exp_counts :
+        dispatched_input = paddle.einsum("sec,sm->ecm", paddle.cast(dispatch_mask, hidden_state.dtype), reshaped_input)
+
+        return dispatched_input, combine_weights, l_aux, l_zloss
+
+
+class LocalCombine(dist.LocalLayer):
+    def __init__(self, out_dist_attrs):
+        super().__init__(out_dist_attrs)
+
+    def forward(self, combine_weights, expert_output, dtype="float32"):
+        combined_output = paddle.einsum("sec,ecm->sm", combine_weights.cast(dtype), expert_output)
+        return combined_output
+
+
+def get_mesh(pp_idx=0):
+    """
+    Get the mesh for pipeline stage `pp_idx`.
+    """
+    mesh = dist.fleet.auto.get_mesh()
+    if "pp" in mesh.dim_names:
+        mesh = mesh.get_mesh_with_dim("pp", pp_idx)
+    return mesh
+
+
+class MoELayer(nn.Layer):
+    def __init__(
+        self,
+        config,
+        moe_num_experts: int,
+        expert_class: nn.Layer,
+        expert_kwargs: dict,
+        gate: PretrainedMoEGate,
+        capacity: int = 1.0,
+        moe_group: str = "data",
+        all_to_all_dropout=0.0,
+    ):
+        super().__init__()
+
+        self.config = config
+
+        print(f"moe_num_experts:{moe_num_experts}")
+        self.moe_num_experts = moe_num_experts
+        self.capacity = capacity
+        self.expert_parallel_degree = 1
+
+        self.all_to_all_dropout = all_to_all_dropout
+        self.enable_recompute = False
+
+        self.experts = nn.LayerList([])
+        for i in range(self.moe_num_experts):
+            self.experts.append(expert_class(**expert_kwargs))
+
+        self.moe_num_experts_per_device = self._parse_moe_expert_parallel(
+            self.moe_num_experts, self.expert_parallel_degree
+        )
+        self.moe_group = None
+        self.gate = gate
+        self.gate.group = self.moe_group
+        self.is_dummy_moe = True
+        self._post_init()
+
+        mesh = get_mesh()
+        local_out_dist_attrs = [
+            (mesh, [dist.Shard(1)]),  # dispatched_input [e,c,h]
+            (mesh, [dist.Shard(0)]),  # combine_weights [s,e,c]
+            (mesh, [dist.Partial()]),  # l_aux, scalar
+            (mesh, [dist.Partial()]),  # l_zloss, scalar
+        ]
+        self.local_computes = LocalPart(local_out_dist_attrs, config, gate)
+
+        local_combine_dist_attrs = [
+            (mesh, [dist.Shard(0)])  # combined_output [s,m]
+        ]
+        self.local_combine = LocalCombine(local_combine_dist_attrs)
+
+    def _parse_moe_expert_parallel(self, moe_num_experts, expert_parallel_degree):
+        assert (
+            moe_num_experts >= expert_parallel_degree
+        ), f"expert moe_num_experts={moe_num_experts} >= moe_world_size={expert_parallel_degree}"
+        assert (
+            moe_num_experts % expert_parallel_degree == 0
+        ), f"expert moe_num_experts={moe_num_experts} % moe_world_size={expert_parallel_degree} == 0"
+        moe_num_experts_per_device = moe_num_experts // expert_parallel_degree
+        return moe_num_experts_per_device
+
+    def _post_init(self):
+        for p in self.gate.parameters():
+            p.is_gate = True
+
+        for k in self.experts:
+            if k is not None:
+                for p in k.parameters():
+                    p.expert = not self.is_dummy_moe
+                    p.no_sync = not self.is_dummy_moe
+                    # logger.info(f"expert param={p.name}, no-sync={p.no_sync}")
+
+    def expert_forward(self, dispatched_input):
+        expert_outputs = []
+        chunks = dispatched_input.unbind(1)
+        for chunk, expert in zip(chunks, self.experts):
+            chunk = chunk.contiguous()
+            expert_outputs += [expert(chunk)]
+        expert_output = paddle.stack(expert_outputs, axis=1)  # [ecm]
+        return expert_output
+
+    def forward(
+        self,
+        hidden_state: paddle.Tensor,
+        used_token: paddle.Tensor = None,
+    ):
+        """Forward pass of the MoE layer.
+
+        Args:
+            hidden_state (paddle.Tensor): Input hidden states with shape [batch_size, seq_len, d_model].
+            used_token (paddle.Tensor, optional): Optional token mask forwarded to the local gating computation.
+
+        Returns:
+            Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: The combined expert output reshaped to the
+            input shape, the auxiliary load-balancing loss `l_aux`, and the router z-loss `l_zloss`.
+        """
+        # Implement Algorithm 2 from GShard paper.
+        batch_size, seq_len, d_model = hidden_state.shape
+
+        # Initial implementation -> Reshape into S tokens by dropping sequence dimension.
+        # Reshape into G groups so that each group can distribute tokens equally
+        # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
+        # reshaped_input = hidden_state.reshape([-1, d_model])
+        # reshaped_input = dist.reshard(reshaped_input, reshaped_input.process_mesh, [dist.Replicate(), dist.Replicate()])
+        # print("==== reshaped_input ====")
+        # print(reshaped_input)
+
+        # capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.gate(reshaped_input)
+        # print("==== combine_weights ====")
+        # print(combine_weights)
+
+        # # self.l_aux :
+        # # combine_weights : sec
+        # # dispatch_mask : sec
+        # # self.exp_counts :
+        # dispatched_input = paddle.einsum("sec,sm->ecm", paddle.cast(dispatch_mask, hidden_state.dtype), reshaped_input)
+        # print("==== dispatched_input ====")
+        # print(dispatched_input)
+
+        print("==== hidden_state ====")
+        print(hidden_state)
+        dispatched_input, combine_weights, l_aux, l_zloss = self.local_computes(hidden_state, self.gate.weight, used_token=used_token)
+
+        # dispatched_input = dist.reshard(dispatched_input, get_mesh(), [dist.Shard(0)])
+        # if self.expert_parallel_degree > 1:
+        #     dispatched_input = _AllToAll.apply(dispatched_input, self.moe_group)
+
+        # Re-shape after all-to-all: ecm -> gecm
+        dispatched_input = dispatched_input.reshape(
+            [self.expert_parallel_degree, self.moe_num_experts_per_device, -1, d_model]
+        )
+        expert_output = self.expert_forward(dispatched_input)
+        # Re-shape before drop_tokens: gecm -> ecm
+        expert_output = expert_output.reshape(
+            [self.expert_parallel_degree * self.moe_num_experts_per_device, -1, d_model]
+        )
+        print("==== expert_output ====")
+        print(expert_output)
+
+        expert_output = dist.reshard(expert_output, get_mesh(), [dist.Shard(1)])
+        print("==== expert_output after reshard ====")
+        print(expert_output)
+        # if self.expert_parallel_degree > 1:
+        #     expert_output = _AllToAll.apply(expert_output, self.moe_group)
+
+        print("==== combine_weights ====")
+        print(combine_weights)
+        # combine with the expert weights
+        # Einsum infermeta does not support auto-parallel dist tensors yet,
+        # so use a local layer here.
+        # combined_output = paddle.einsum("sec,ecm->sm", combine_weights.cast(hidden_state[0].dtype), expert_output)
+        combined_output = self.local_combine(combine_weights, expert_output, dtype=hidden_state[0].dtype)
+
+        a = combined_output.reshape(hidden_state.shape)
+        print("==== a ====")
+        print(a)
+
+        return a, l_aux, l_zloss
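For reference, the dispatch/combine einsums used in LocalPart.forward and LocalCombine.forward follow GShard's notation: `s` is the token dimension, `e` the expert dimension, `c` the capacity slot per expert, and `m` the model width. The short standalone sketch below (toy sizes and a hypothetical identity stand-in for the experts; none of these names appear in moe_layer_auto.py) illustrates what `einsum("sec,sm->ecm", ...)` and `einsum("sec,ecm->sm", ...)` compute.

import paddle

# Toy sizes: s tokens, e experts, c capacity slots per expert, m hidden size.
s, e, c, m = 4, 2, 2, 3

x = paddle.randn([s, m])                 # token representations  [s, m]
dispatch_mask = paddle.zeros([s, e, c])  # one-hot routing        [s, e, c]
dispatch_mask[0, 0, 0] = 1.0             # token 0 -> expert 0, slot 0
dispatch_mask[1, 1, 0] = 1.0             # token 1 -> expert 1, slot 0
dispatch_mask[2, 0, 1] = 1.0             # token 2 -> expert 0, slot 1
dispatch_mask[3, 1, 1] = 1.0             # token 3 -> expert 1, slot 1
combine_weights = dispatch_mask * 0.9    # router probabilities   [s, e, c]

# Dispatch: scatter each token into its expert/capacity slot.
dispatched = paddle.einsum("sec,sm->ecm", dispatch_mask, x)  # [e, c, m]

# Identity stand-in for the experts; the real layer calls expert_forward() here.
expert_output = dispatched

# Combine: gather expert outputs back to token positions, scaled by the gate weights.
combined = paddle.einsum("sec,ecm->sm", combine_weights, expert_output)  # [s, m]

# With identity experts, each row of `combined` is 0.9 times the matching row of x.
print(paddle.allclose(combined, 0.9 * x))  # True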