From 99344cf2fbd5cb1574c4546a086891157cbf599a Mon Sep 17 00:00:00 2001 From: shuailong616 <452509829@qq.com> Date: Wed, 29 Oct 2025 16:19:36 +0800 Subject: [PATCH 1/8] Single-process simulator adapted to the new version of Flagscale code --- flagscale/train/hetero/parallel_context.py | 66 +++++++++++++++++---- flagscale/train/theoretical_memory_usage.py | 3 +- flagscale/train/train.py | 45 ++++++-------- flagscale/train/train_gpt.py | 2 + 4 files changed, 78 insertions(+), 38 deletions(-) diff --git a/flagscale/train/hetero/parallel_context.py b/flagscale/train/hetero/parallel_context.py index 8054b1c288..5d6905ce79 100644 --- a/flagscale/train/hetero/parallel_context.py +++ b/flagscale/train/hetero/parallel_context.py @@ -124,6 +124,7 @@ def __init__(self, args): self._rank_infos = {} self._physical_rank_to_logical_rank = {} self._logical_rank_to_physical_rank = {} + self._enable_simulator = args.enable_simulator self.build_rank_mapping() def build_rank_mapping(self): @@ -133,8 +134,16 @@ def build_rank_mapping(self): all_rank_infos = [None] * world_size cur_rank_info = {'rank': rank, 'device_type': self._hetero_current_device_type} - torch.distributed.all_gather_object( - all_rank_infos, cur_rank_info) + #torch.distributed.all_gather_object( + # all_rank_infos, cur_rank_info) + if self._enable_simulator: + for index, value in enumerate(all_rank_infos): + corresponding_rank_info = {'rank': index, 'device_type': self._hetero_current_device_type} + all_rank_infos[index] = corresponding_rank_info + else: + torch.distributed.all_gather_object( + all_rank_infos, cur_rank_info) + physical_ranks = [] for info in all_rank_infos: self._rank_infos[info['rank']] = info @@ -308,10 +317,17 @@ def build_process_group( ranks = self._rank_mapper.to_physical_ranks(logical_ranks) group = create_group(ranks, timeout=self._timeout, backend=self._distributed_backend, pg_options=pg_options, group_desc=group_name) if gloo: - if create_gloo_process_groups: - group_gloo = create_group(ranks, timeout=self._timeout, backend="gloo", group_desc=group_name+"_gloo") + if self._args.enable_simulator: + if create_gloo_process_groups: + group_gloo = create_group(ranks, timeout=self._timeout, backend=self._distributed_backend, group_desc=group_name+"_gloo") + else: + group_gloo = None + else: - group_gloo = None + if create_gloo_process_groups: + group_gloo = create_group(ranks, timeout=self._timeout, backend="gloo", group_desc=group_name+"_gloo") + else: + group_gloo = None self._all_group_ranks[group_name].append(ranks) if self._rank in ranks: self._group_ranks[group_name] = ranks @@ -638,9 +654,19 @@ def build_all_process_meshes(self): "rank": rank, "process_mesh_idx": self._current_process_mesh_index, } - torch.distributed.all_gather_object( - all_rank_to_process_mesh, cur_rank_to_process_mesh - ) + #torch.distributed.all_gather_object( + # all_rank_to_process_mesh, cur_rank_to_process_mesh + #) + + if self._args.enable_simulator: + for index, value in enumerate(all_rank_to_process_mesh): + corresponding_mesh_info = {'rank': index, 'process_mesh_idx': self._current_process_mesh_index} + all_rank_to_process_mesh[index] = corresponding_mesh_info + else: + torch.distributed.all_gather_object( + all_rank_to_process_mesh, cur_rank_to_process_mesh + ) + for item in all_rank_to_process_mesh: self._rank_to_process_mesh[item["rank"]] = self._process_meshes[ item["process_mesh_idx"] @@ -756,7 +782,12 @@ def _backtrack(mesh_index, prev_rank, path, token = "pp", is_expert=False): aggregated_ranks = [rank for ranks in path for rank in ranks] self._global_all_group_ranks[group_name].append(aggregated_ranks) # NOTE: "use_local_synchronization=True" works well in torhch <= 2.5, but it causes hang in torch >= 2.6 - group = create_group(aggregated_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc=group_name) + #group = create_group(aggregated_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc=group_name) + if self._args.enable_simulator: + group = create_group(aggregated_ranks, timeout=self._timeout, use_local_synchronization=False, backend=self._args.distributed_backend, group_desc=group_name) + else: + group = create_group(aggregated_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc=group_name) + if self._rank in aggregated_ranks: self._global_process_groups[group_name].append(group) self._global_group_ranks[group_name].append(aggregated_ranks) @@ -812,13 +843,24 @@ def _backtrack(mesh_index, prev_rank, path, token = "pp", is_expert=False): else: embedding_ranks = ranks position_embedding_ranks = ranks - group = create_group(embedding_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc="embd") + #group = create_group(embedding_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc="embd") + if self._args.enable_simulator: + group = create_group(embedding_ranks, timeout=self._timeout, use_local_synchronization=False, backend=self._args.distributed_backend, group_desc="embd") + else: + group = create_group(embedding_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc="embd") + + if self._rank in embedding_ranks and ("embd" not in self._global_group_ranks or embedding_ranks not in self._global_group_ranks["embd"]): self._global_process_groups["embd"].append(group) self._global_process_group_to_ranks[group] = embedding_ranks self._global_group_ranks["embd"].append(embedding_ranks) - group = create_group(position_embedding_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc="embd_pos") + #group = create_group(position_embedding_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc="embd_pos") + if self._args.enable_simulator: + group = create_group(position_embedding_ranks, timeout=self._timeout, use_local_synchronization=False, backend=self._args.distributed_backend, group_desc="embd_pos") + else: + group = create_group(position_embedding_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc="embd_pos") + if self._rank in position_embedding_ranks: self._global_process_groups["embd_pos"].append(group) self._global_process_group_to_ranks[group] = position_embedding_ranks @@ -1634,6 +1676,8 @@ def _build_ddp_config(args): kwargs[f.name] = getattr(args, f.name) kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32 kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad + if args.enable_simulator: + args.check_for_nan_in_loss_and_grad = False kwargs['bucket_size'] = args.ddp_bucket_size kwargs['average_in_collective'] = args.ddp_average_in_collective ddp_config = DistributedDataParallelConfig(**kwargs) diff --git a/flagscale/train/theoretical_memory_usage.py b/flagscale/train/theoretical_memory_usage.py index e5da1d04c4..d2863c987a 100644 --- a/flagscale/train/theoretical_memory_usage.py +++ b/flagscale/train/theoretical_memory_usage.py @@ -467,7 +467,8 @@ def compute_activation_memory(args, num_microbatches, verbose=False): 2 * args.seq_length * args.micro_batch_size * args.hidden_size ) # Attention: - if args.multi_latent_attention: + # if args.multi_latent_attention: + if getattr(args, "multi_latent_attention", False): # 1. Q, K, V matrix multiplies if args.q_lora_rank is None: QKV_activation_memory = ( diff --git a/flagscale/train/train.py b/flagscale/train/train.py index e02e2890e5..1e0571d6ca 100644 --- a/flagscale/train/train.py +++ b/flagscale/train/train.py @@ -781,11 +781,9 @@ def pretrain( inprocess_call_wrapper: an optional instance of inprocess.CallWrapper, it is automatically injected when in-process restart is in use """ - if inprocess_call_wrapper is not None: iteration = inprocess_call_wrapper.iteration store = torch.distributed.PrefixStore(str(iteration), store) - # Initalize and get arguments, timers, and Tensorboard writer. initialize_megatron( extra_args_provider=extra_args_provider, @@ -797,7 +795,6 @@ def pretrain( args = get_args() timers = get_timers() - if args.log_progress: append_to_progress_log("Starting job") @@ -809,7 +806,6 @@ def pretrain( # Set pytorch JIT layer fusion options and warmup JIT functions. set_jit_fusion_options() - # Adjust the startup time so it reflects the largest value. # This will be closer to what scheduler will see (outside of # image ... launches. @@ -832,10 +828,8 @@ def pretrain( ) print_datetime('after megatron is initialized') app_metrics['app_model_init_finish_time'] = one_logger_utils.get_timestamp_in_ms() - # Track E2E metrics on pretrain start one_logger_utils.on_pretrain_start() - # Context used for persisting some state between checkpoint saves. if args.non_persistent_ckpt_type == 'local': try: @@ -869,22 +863,18 @@ def pretrain( } else: checkpointing_context = {} - ########## FlagScale Begin ########## num_microbatches = get_num_microbatches() fs_report_theoretical_memory(args, num_microbatches=num_microbatches, verbose=True) ########## FlagScale End ########## - # Model, optimizer, and learning rate. timers('model-and-optimizer-setup', log_level=0).start(barrier=True) model, optimizer, opt_param_scheduler = setup_model_and_optimizer( model_provider, model_type, checkpointing_context=checkpointing_context ) - timers('model-and-optimizer-setup').stop() print_datetime('after model, optimizer, and learning rate ' 'scheduler are built') config = get_model_config(model[0]) - # Data stuff. app_metrics['app_build_dataiters_start_time'] = one_logger_utils.get_timestamp_in_ms() timers('train/valid/test-data-iterators-setup', log_level=0).start(barrier=True) @@ -929,7 +919,6 @@ def pretrain( # Print setup timing. print_rank_0('done with setup ...') timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'], barrier=True) - one_logger = get_one_logger() one_logger and one_logger.log_metrics(app_metrics) @@ -937,10 +926,8 @@ def pretrain( if wandb_writer: # Add job name to the wandb config to make it easier to run more singleton dependency jobs. wandb_writer.config.update({'slurm_job_name': os.getenv("SLURM_JOB_NAME", "N/A")}) - if not args.skip_train: print_rank_0('training ...') - if args.dataloader_type == 'cyclic' and args.retro_project_dir: assert args.retro_cyclic_train_iters is not None args.train_iters = args.retro_cyclic_train_iters @@ -1055,7 +1042,6 @@ def pretrain( non_loss_data_func=non_loss_data_func ) ######### FlagScale End ########## - wandb_writer = get_wandb_writer() if wandb_writer: wandb_writer.finish() @@ -1071,7 +1057,6 @@ def pretrain( ft_integration.shutdown() one_logger_utils.finish() - def update_train_iters(args): # For iteration-based training, we don't need to do anything @@ -1414,14 +1399,14 @@ def setup_model_and_optimizer( no_wd_decay_cond, scale_lr_cond, lr_mult, - use_gloo_process_groups=args.enable_gloo_process_groups, + #use_gloo_process_groups=args.enable_gloo_process_groups, + #use_gloo_process_groups=False, # If the user is asking for a non-zero embedding init std, skip weight decay for embeddings # to avoid embeddings from shrinking to zero as recommended in https://arxiv.org/abs/2312.16903 - default_skip_embedding_weight_decay=args.embedding_init_method_std is not None, + #default_skip_embedding_weight_decay=args.embedding_init_method_std is not None, ) opt_param_scheduler = get_optimizer_param_scheduler(optimizer) one_logger and one_logger.log_metrics({"app_build_optimzer_finish_time": one_logger_utils.get_timestamp_in_ms()}) - if args.moe_use_upcycling: torch.distributed.barrier() assert not checkpoint_exists(args.save), ( @@ -1528,7 +1513,6 @@ def setup_model_and_optimizer( print_rank_0("> converted checkpoint: %s -> %s." % (load_ckpt_format, args.ckpt_format)) torch.distributed.barrier() exit() - return model, optimizer, opt_param_scheduler @@ -1547,14 +1531,12 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch """Single training step.""" args = get_args() timers = get_timers() - rerun_state_machine = get_rerun_state_machine() while rerun_state_machine.should_run_forward_backward(data_iterator): # Set grad to zero. for model_chunk in model: model_chunk.zero_grad_buffer() optimizer.zero_grad() - if has_nvidia_modelopt: # [ModelOpt]: Pipeline-parallel Distillation stacks student and teacher tensors adjust_tensor_shapes_fn = get_tensor_shapes_adjust_fn_for_distillation( @@ -1572,6 +1554,9 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch optim_instance._copy_main_params_to_param_buffer() # Forward pass. + # =================== Forward + Backward timing =================== + torch.cuda.synchronize() + t_fwd_start = time.time() losses_reduced = forward_backward_func( forward_step_func=forward_step_func, data_iterator=data_iterator, @@ -1580,9 +1565,15 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch seq_length=args.seq_length, micro_batch_size=args.micro_batch_size, decoder_seq_length=args.decoder_seq_length, - forward_only=False, + forward_only=True, adjust_tensor_shapes_fn=adjust_tensor_shapes_fn, ) + torch.cuda.synchronize() + t_fwd_end = time.time() + fwd_time = t_fwd_end - t_fwd_start + bwd_time = fwd_time * 2.0 + print(f"[simulatior output] forward: {fwd_time:.2f}, backward: {bwd_time:.2f}", flush=True) + # ================================================================ should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit() if should_exit: return {}, True, should_checkpoint, should_exit, exit_code, None, None @@ -1609,11 +1600,11 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch unwrapped_model.cancel_gradients_last_layer(args.curr_iteration) # Update parameters. - + # =================== Communication / Optimizer timing =================== timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time) update_successful, grad_norm, num_zeros_in_grad = optimizer.step() timers('optimizer').stop() - + #print(f"[simulatior output] forward: {fwd_time:.2f}, backward: {bwd_time:.2f}, communication: {comm_time:.2f}", flush=True) # when freezing sub-models we may have a mixture of successful and unsucessful ranks, # so we must gather across mp ranks update_successful = logical_and_across_model_parallel_group(update_successful) @@ -1640,6 +1631,7 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch if args.empty_unused_memory_level >= 2: torch.cuda.empty_cache() + if mpu.is_pipeline_last_stage(ignore_virtual=True): # Average loss across microbatches. loss_reduced = {} @@ -2640,7 +2632,7 @@ def get_e2e_base_metrics(): model, optimizer, iteration, ref_state_dict, ) train_data_iterator = buffered_rollouts - + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 ft_integration.on_training_step_start() ( loss_dict, @@ -3257,7 +3249,8 @@ def build_train_valid_test_data_loaders(build_train_valid_test_datasets_provider args.do_test = getattr(args, "do_test", False) or flags[2].item() if getattr(args, 'perform_rl_step', False): args.to_test = False - + if args.enable_simulator: + args.do_train = 1 return train_dataloader, valid_dataloaders, test_dataloader diff --git a/flagscale/train/train_gpt.py b/flagscale/train/train_gpt.py index 780d247866..32dcb8e45c 100644 --- a/flagscale/train/train_gpt.py +++ b/flagscale/train/train_gpt.py @@ -88,6 +88,8 @@ def loss_func( # Check individual rank losses are not NaN prior to DP all-reduce. rerun_state_machine = get_rerun_state_machine() + if args.enable_simulator: + args.check_for_nan_in_loss_and_grad = False if args.check_for_nan_in_loss_and_grad: rerun_state_machine.validate_result( result=loss, From 835581a1a61bc5bb7efca5813fac255164021dce Mon Sep 17 00:00:00 2001 From: shuailong616 <452509829@qq.com> Date: Wed, 29 Oct 2025 16:21:35 +0800 Subject: [PATCH 2/8] Single-process simulator adapted to the new version of Flagscale code --- .../runner/auto_tuner/simulator/README.md | 23 ++ .../simulator/analylize_pipeline_time.py | 225 ++++++++++ .../runner/auto_tuner/simulator/config_gen.py | 391 ++++++++++++++++++ .../custom_backend/include/dummy.hpp | 157 +++++++ .../simulator/custom_backend/setup.py | 25 ++ .../simulator/custom_backend/src/dummy.cpp | 285 +++++++++++++ 6 files changed, 1106 insertions(+) create mode 100644 flagscale/runner/auto_tuner/simulator/README.md create mode 100644 flagscale/runner/auto_tuner/simulator/analylize_pipeline_time.py create mode 100644 flagscale/runner/auto_tuner/simulator/config_gen.py create mode 100644 flagscale/runner/auto_tuner/simulator/custom_backend/include/dummy.hpp create mode 100644 flagscale/runner/auto_tuner/simulator/custom_backend/setup.py create mode 100644 flagscale/runner/auto_tuner/simulator/custom_backend/src/dummy.cpp diff --git a/flagscale/runner/auto_tuner/simulator/README.md b/flagscale/runner/auto_tuner/simulator/README.md new file mode 100644 index 0000000000..eee9c9513e --- /dev/null +++ b/flagscale/runner/auto_tuner/simulator/README.md @@ -0,0 +1,23 @@ +# Environment +Begin at the root path of `FlagScale` repository: +``` +cd flagscale/flagscale/runner/auto_tuner/simulator/custom_backend/ +python setup.py develop +``` + +# Setup +Set necessary parameters in `config_gen.py`. For example: +``` +device_type_list = ["A", "B"] +device_num_list = [4, 4] +global_batch_size = 32 +num_micro_batches = 8 +num_layers = 4 +``` +# Run a Task +Start the auto-tuning: +``` +export PYTHONPATH=/****/FlagScale:$PYTHONPATH +export PYTHONPATH=$PYTHONPATH:/***/FlagScale/third_party/Megatron-LM + +python flagscale/runner/auto_tuner/simulator/config_gen.py diff --git a/flagscale/runner/auto_tuner/simulator/analylize_pipeline_time.py b/flagscale/runner/auto_tuner/simulator/analylize_pipeline_time.py new file mode 100644 index 0000000000..ce61302b10 --- /dev/null +++ b/flagscale/runner/auto_tuner/simulator/analylize_pipeline_time.py @@ -0,0 +1,225 @@ +import os +import re +import subprocess +import time + +# from megatron.training import get_args + + +def kill_other_python_processes(): + current_pid = os.getpid() + clear_cmd = f"pkill -f python -o --signal TERM --ignore \"${current_pid}\"" + subprocess.run(clear_cmd, text=True, shell=True) + + +def compute_pipeline_parallelism_cost( + scheme: str = '1F1B', + # num_stages: int=1, + num_micro_batches: int = 1, + process_mesh: list = None, + pp_layers_split: list = None, + fwd_time_per_stage_chunk: list = None, + bwd_time_per_stage_chunk: list = None, + comm_time_between_stages: list = None, + # TODO: add fine-greaied recomputation +): + print(f"--- Compute Pipeline Cost ---") + + # process_mesh: [tp0,cp0,ep0,dp0,pp0,(tp1,cp1,...)] + # comm_time_between_stages[i] means the comm time between stage i-1 and stage i + num_pp_stages = sum(process_mesh[4::5]) + assert ( + len(pp_layers_split) == num_pp_stages + ), "\flength of list {num_layers_per_stage} should match {num_stages}" + assert ( + len(fwd_time_per_stage_chunk) == num_pp_stages + ), "\flength of list {fwd_time_per_stage_chunk} should match {num_stages}" + assert ( + len(bwd_time_per_stage_chunk) == num_pp_stages + ), "\flength of list {bwd_time_per_stage_chunk} should match {num_stages}" + assert ( + len(comm_time_between_stages) == num_pp_stages + ), "\flength of list {comm_time_between_stages} should match {num_stages}" + + pp_last_stage_time = num_micro_batches * ( + fwd_time_per_stage_chunk[num_pp_stages - 1] + bwd_time_per_stage_chunk[num_pp_stages - 1] + ) + if num_pp_stages == 1: + return num_micro_batches * ( + fwd_time_per_stage_chunk[num_pp_stages - 1] + + bwd_time_per_stage_chunk[num_pp_stages - 1] + ) + + pipeline_cost = 0 + # TODO: consider when comm time > comp time + # each stage onlt depends on its next stage + if scheme == '1F1B' or scheme == 'AFAB': + pipeline_cost = pp_last_stage_time + for stage_from_last in range(2, num_pp_stages): + pp_this_stage_overlapped_time = (num_micro_batches - 1) * ( + fwd_time_per_stage_chunk[num_pp_stages - 1] + + bwd_time_per_stage_chunk[num_pp_stages - 1] + ) + pp_this_stage_compute_time = ( + fwd_time_per_stage_chunk[num_pp_stages - stage_from_last] + + bwd_time_per_stage_chunk[num_pp_stages - stage_from_last] + ) + pp_last_stage_overall_time = ( + pipeline_cost + 2 * comm_time_between_stages[num_pp_stages - stage_from_last + 1] + ) + # not consider the situation that comm stucks the comp + # which means the comm time should no more than the comp time(fwd time) + pipeline_cost = pp_this_stage_compute_time + max( + pp_last_stage_overall_time, pp_this_stage_overlapped_time + ) + else: + raise (ValueError("Scheme must be '1F1B' or 'AFAB'.")) + + return pipeline_cost + + +import random + + +def simulator( + process_mesh: list = None, + stage: int = 0, + num_layers: int = None, + simulated_rank: int = None, + pp_layers_split: list = None, +): + + # os.environ["PYTHONPATH"] = "/share/project/heyongzhe/FlagScale/megatron:/share/project/heyongzhe/FlagScale" + os.environ["PYTHONPATH"] = ( + "/workspace/20251010/new/FlagScale:" + "/workspace/20251010/new/FlagScale/third_party/Megatron-LM" + ) + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + os.environ["RANK"] = str(simulated_rank) + os.environ["LOCAL_RANK"] = str(simulated_rank) + # os.environ["WORLD_SIZE"] = args.world_size + os.environ["WORLD_SIZE"] = "8" + # os.environ["WORLD_SIZE"] = "32" + rdav_endpoint = random.randint(0, 40000) + os.environ["RDZV_ENDPOINT"] = "localhost:" + str(rdav_endpoint) + # os.environ["RZDV_ENDPOINT"]="localhost:37832" + os.environ["RDZV_BACKEND"] = "c10d" + os.environ["MASTER_ADDR"] = "localhost" + + program_entry = " ./flagscale/train/train_aquila_sft.py " + simulation_arguments = " --enable-hetero --enable-simulator --distributed-backend dummy " + # fine_grained_recomputation_args = "--recompute-granularity-per-stage-micro-batch '[1, 1, 1]' --recompute-method-per-stage-micro-batch '[1, 1, 1]' --recompute-num-layers-per-stage-micro-batch '[1, 1, 1]'" + fine_grained_recomputation_args = "" + # print(stage) + + pp_layer_split_args = " --hetero-pipeline-layer-split " + for layers in pp_layers_split: + pp_layer_split_args = pp_layer_split_args + str(layers) + " " + + process_mesh_str = " --hetero-process-meshes " + for dim in process_mesh: + process_mesh_str = process_mesh_str + str(dim) + " " + + num_pp_stages = sum(process_mesh[4::5]) + pp_size_args = " --pipeline-model-parallel-size " + str(num_pp_stages) + " " + + # TODO: too ugly to show this command in the code, re-organize these parameters in another way later + train_command = ( + "python " + + program_entry + + "--tensor-model-parallel-size 1 --timing-log-level 2 --disable-bias-linear --use-flash-attn --sequence-parallel --use-distributed-optimizer --use-mcore-models --transformer-impl transformer_engine --hetero-device-types A800 BI150 --hetero-current-device-type A800 --recompute-granularity full --recompute-method uniform --recompute-num-layers 1 --bf16 --attention-softmax-in-fp32 --accumulate-allreduce-grads-in-fp32 --log-interval 1 --log-throughput --tensorboard-log-interval 1 --wandb-project aquila2 --wandb-exp-name test --tensorboard-dir /share/project/heyongzhe/FlagScale/outputs/tensorboard --wandb-save-dir /share/project/heyongzhe/FlagScale/outputs/wandb --num-layers 32 --hidden-size 4096 --num-attention-heads 32 --seq-length 2048 --max-position-embeddings 2048 --norm-epsilon 1e-05 --use-rotary-position-embeddings --no-position-embedding --swiglu --multiple-of 256 --normalization RMSNorm --untie-embeddings-and-output-weights --init-method-std 0.0165 --attention-dropout 0.0 --hidden-dropout 0.0 --weight-decay 0.1 --clip-grad 1.0 --train-samples 128 --global-batch-size 64 --micro-batch-size 1 --seed 42 --lr 0.0002 --weight-decay 0.01 --adam-beta1 0.9 --adam-beta2 0.95 --lr 0.00015 --min-lr 1.5e-05 --lr-warmup-samples 0 --lr-decay-style cosine --data-path /workspace/FlagScale/datapath/pile_wikipedia_demo --split 1 --tokenizer-type AquilaTokenizerFS --vocab-file ./examples/aquila/tokenizer/vocab.json --merge-file ./examples/aquila/tokenizer/merges.txt --special-tokens-file ./examples/aquila/tokenizer/special_tokens.txt --vocab-size 100008 " + + process_mesh_str + + simulation_arguments + + pp_layer_split_args + + fine_grained_recomputation_args + + pp_size_args + ) + + # enough sleeping time is needed to really kill the survival megatron process + # as least 5 sec before & after killing can not succeed every time + print("sleeping...") + # print(train_command) + # time.sleep(10) + kill_other_python_processes() + # time.sleep(10) + print("start...") + result = subprocess.run(train_command, capture_output=True, text=True, shell=True) + output = result.stdout.strip() + print(train_command) + print(output) + # example output: "[simulatior output] forward: 12.34, backward: 56.78, communication: 90.12" + match = re.search(r"forward:\s*([\d.]+),\s*backward:\s*([\d.]+)", output) + if match: + fwd_time = float(match.group(1)) + bwd_time = float(match.group(2)) + # comm_time = float(match.group(3)) + comm_time = 0.01 + print("forward:", fwd_time) + print("backward:", bwd_time) + print("communication:", comm_time) + else: + # print(fwd_time,bwd_time,comm_time) + fwd_time = 12.34 + bwd_time = 56.78 + comm_time = 90.12 + print("forward:", fwd_time) + print("backward:", bwd_time) + print("communication:", comm_time) + # raise(ValueError("Results not found. Example output: \"[simulatior output] forward: 12.34, backward: 56.78, communication: 90.12\"")) + return fwd_time, bwd_time, comm_time + + +# call simulator to obtain the execution of each stage +def simulate_pipeline_parallelism_per_stage_time( + process_mesh: list = None, + pp_layers_split: list = None, + fwd_time_per_stage_chunk: list = None, + bwd_time_per_stage_chunk: list = None, + comm_time_between_stages: list = None, +): + print(f"--- Simulation Begin ---") + print(f"Process Mesh: {process_mesh}") + print(f"PP Layer Split: {pp_layers_split}") + for stage, num_layers in enumerate(pp_layers_split): + # TODO: confirm simulated_rank for different stage + print(f"Stage: {stage}; Num Layers: {num_layers}") + simulated_rank = 0 + fwd_time, bwd_time, comm_time = simulator( + process_mesh, stage, num_layers, simulated_rank, pp_layers_split + ) + fwd_time_per_stage_chunk.append(fwd_time) + bwd_time_per_stage_chunk.append(bwd_time) + comm_time_between_stages.append(comm_time) + print(f"--- Simulation End ---") + + +def analyze_pp_time( + scheme: str = '1F1B', + num_micro_batches: int = 1, + process_mesh: list = None, + pp_layers_split: list = None, +): + fwd_time_per_stage_chunk = [] + bwd_time_per_stage_chunk = [] + comm_time_between_stages = [] + + simulate_pipeline_parallelism_per_stage_time( + process_mesh=process_mesh, + pp_layers_split=pp_layers_split, + fwd_time_per_stage_chunk=fwd_time_per_stage_chunk, + bwd_time_per_stage_chunk=bwd_time_per_stage_chunk, + comm_time_between_stages=comm_time_between_stages, + ) + + pipeline_cost = compute_pipeline_parallelism_cost( + scheme=scheme, + num_micro_batches=num_micro_batches, + process_mesh=process_mesh, + pp_layers_split=pp_layers_split, + fwd_time_per_stage_chunk=fwd_time_per_stage_chunk, + bwd_time_per_stage_chunk=bwd_time_per_stage_chunk, + comm_time_between_stages=comm_time_between_stages, + ) + + return pipeline_cost diff --git a/flagscale/runner/auto_tuner/simulator/config_gen.py b/flagscale/runner/auto_tuner/simulator/config_gen.py new file mode 100644 index 0000000000..0b97e0ed6e --- /dev/null +++ b/flagscale/runner/auto_tuner/simulator/config_gen.py @@ -0,0 +1,391 @@ +import ast +import json +import os + +from functools import reduce +from itertools import combinations, product + +import analylize_pipeline_time + +# from itertools import product +import flagscale.train.theoretical_memory_usage as mem_usg + +BYTES_OF_GB = 10**9 + +# device_type_list = ["A800", "A800", "BI150", "BI150"] +# device_num_list = [8, 8, 8, 8] +# memory_capacity_of_devices = [80, 80, 32, 32] # GB + +device_type_list = ["A800", "BI150"] +device_num_list = [4, 4] +memory_capacity_of_devices = [80, 32] # GB + +global_batch_size = 512 +num_micro_batches = 8 +num_layers = 32 + +num_gpus = sum(device_num_list) + + +class DevicesInfo: + def __init__(self, device_type_list: list, device_num_list: list): + assert len(device_type_list) == len( + device_num_list + ), "\flength of list {device_type_list} should match {device_num_list}" + self.device_type_list = device_type_list + self.device_num_list = device_num_list + self.device_types_count = len(device_type_list) + self.possible_parallelisms = [] + + +class HeteroConfig: + def __init__( + self, + mesh, + device_types, + pp_layer_split, + recompute_granularity=None, + recompute_method="uniform", + recompute_num_layers=1, + theory_peak_memory=0.0, + oom_error=False, + ): + self.mesh = mesh + self.device_types = device_types + self.pp_layer_split = pp_layer_split + # self.micro_batch_size = 1 + self.recompute_granularity = recompute_granularity + self.recompute_method = recompute_method + self.recompute_num_layers = recompute_num_layers + + self.simulated_time = 0.0 + self.theory_peak_memory = theory_peak_memory + self.oom_error = oom_error + + +def generate_hetero_meshes( + devices_info: DevicesInfo, + global_batch_size: int = None, + num_layers: int = None, + output_file: str = "results.json", +): + def enumerate_parallelism(device_num: int = None): + possible_parallelisms = [] + for tp in range(1, device_num + 1): + for dp in range(1, device_num // tp + 1): + if device_num % (dp * tp) == 0: + pp = device_num // (dp * tp) + # mesh: [tp, cp, ep, dp, pp] + possible_parallelisms.append([tp, 1, 1, dp, pp]) + return possible_parallelisms + + def is_legal_combination(comb: list): + pp = sum(comb[4::5]) + # check dp is legal + max_dp = global_batch_size // pp + for dp in comb[3::5]: + if max_dp % dp != 0: + return False + return True + + def is_extreme_strategy(comb: list): + for mesh_index in range(len(comb) // 5): + # num_devices_in_mesh = sum( + # comb[ + # mesh_index * 5 : mesh_index * 5 + 4 + # ] + # ) + num_devices_in_mesh = reduce( + lambda x, y: x * y, comb[mesh_index * 5 : mesh_index * 5 + 5] + ) + dp_size_in_mesh = comb[mesh_index * 5 + 3] + tp_size_in_mesh = comb[mesh_index * 5 + 0] + pp_size_in_mesh = comb[mesh_index * 5 + 4] + print( + mesh_index, + comb[mesh_index * 5 : mesh_index * 5 + 5], + num_devices_in_mesh, + dp_size_in_mesh, + tp_size_in_mesh, + pp_size_in_mesh, + ) + if ( + pp_size_in_mesh > num_devices_in_mesh // 2 + or tp_size_in_mesh > 8 + or dp_size_in_mesh > num_devices_in_mesh / 4 + ): + return True + else: + return False + + def combine_possible_parallelisms(possible_parallelisms, output_file): + '''Combine and filter results, writing them to a file to avoid OOM.''' + all_combinations = product(*possible_parallelisms) + with open(output_file, "w") as f: + for comb in all_combinations: + result = sum(comb, []) + if is_legal_combination(result): + if not is_extreme_strategy(result): + f.write(",".join(map(str, result)) + "\n") + + # Ensure output file does not exist initially + if os.path.exists(output_file): + os.remove(output_file) + + # Enumerate all possible meshes for each kind of device + for i in range(devices_info.device_types_count): + device_num = devices_info.device_num_list[i] + devices_info.possible_parallelisms.append(enumerate_parallelism(device_num)) + + # Combine possibilities and write results to file + combine_possible_parallelisms(devices_info.possible_parallelisms, output_file) + print(f"Results written to {output_file}") + + +def split_layers(num_layers, pp_stages): + results = [] + # print(pp_stages) + for split_points in combinations(range(1, num_layers), pp_stages - 1): + # print(split_points) + if len(split_points) == 0: + continue + splits = ( + [split_points[0]] + + [split_points[i] - split_points[i - 1] for i in range(1, len(split_points))] + + [num_layers - split_points[-1]] + ) + # to prune some extreme splits + if max(splits) / min(splits) > 2: + continue + # print(splits) + results.append(splits) + return results + + +class MeshArguments: + def __init__(self, mesh_config: HeteroConfig): + # [tp, cp, ep, dp, pp] + self.data_parallel_size = mesh_config.mesh[3] + # TODO: pp size not correct when computing memory, because former method divides the layers evenly + # no embed and dropout for stages except the 1st, and make the layers changable + + # if args.pipeline_model_parallel_size > 1: + # activation_memory = ( + # perlayer_activation + # * args.num_layers + # / args.pipeline_model_parallel_size + # * in_flight_microbatches + # + embedding_activation_memory + # + dropout_activation_memory + # ) + # else: + # activation_memory = ( + # perlayer_activation * args.num_layers + # + embedding_activation_memory + # + dropout_activation_memory + # + output_layer_and_loss_activation_memory + # ) + self.pipeline_model_parallel_size = sum(mesh_config.mesh[4::5]) + self.tensor_model_parallel_size = mesh_config.mesh[0] + self.virtual_pipeline_model_parallel_size = None + self.num_experts = 1 + self.context_parallel_size = 1 + self.swiglu = True + self.micro_batch_size = global_batch_size / num_micro_batches / self.data_parallel_size + self.num_layers = num_layers + self.num_attention_heads = 32 + self.group_query_attention = None # TODO + self.num_query_groups = 1 # TODO + self.moe_layer_freq = 2 + self.moe_router_topk = 1 + self.multi_latent_attention = False + self.seq_length = 2048 + self.padded_vocab_size = 4096 # TODO + self.hidden_size = 4096 + self.qk_layernorm = False + self.mtp_num_layers = None + self.expert_model_parallel_size = 1 + self.world_size = 8 + self.moe_shared_expert_intermediate_size = 16384 + self.moe_ffn_hidden_size = 4 * self.hidden_size + # self.ffn_hidden_size + self.multiple_of = 256 + hidden_dim = int(4 * self.hidden_size * 2 / 3) + self.ffn_hidden_size = self.multiple_of * ( + (hidden_dim + self.multiple_of - 1) // self.multiple_of + ) + # self.kv_channels + self.kv_channels = self.hidden_size // self.num_attention_heads + + self.recompute_granularity = mesh_config.recompute_granularity + self.recompute_method = mesh_config.recompute_method + self.recompute_num_layers = mesh_config.recompute_num_layers + self.expert_tensor_parallel_size = 1 + self.use_flash_attn = True + self.sequence_parallel = True + self.use_distributed_optimizer = True + self.untie_embeddings_and_output_weights = False # TODO + + self.enable_hetero = True + + +def report_oom_error( + memory_capacity_of_devices: list, meshes_config: list, peak_memory_usage_per_stage: list +): + stage_index = 0 + for mesh_index, num_stages_in_current_mesh in enumerate(meshes_config[4::5]): + for i in range(num_stages_in_current_mesh): + if ( + peak_memory_usage_per_stage[stage_index + i] + >= memory_capacity_of_devices[mesh_index] + ): + return True + stage_index = stage_index + num_stages_in_current_mesh + return False + + +def calculate_peak_memory_per_stage(mesh_config): + peak_memory_usage_per_stage = [] + model_parallel_training_args = MeshArguments(mesh_config) + stage_index = 0 + mesh_index = 0 + for pp_stage_num_per_mesh in mesh_config.mesh[4::5]: + model_parallel_training_args.data_parallel_size = mesh_config.mesh[3 + 5 * mesh_index] + model_parallel_training_args.tensor_model_parallel_size = mesh_config.mesh[ + 0 + 5 * mesh_index + ] + for stage in range(pp_stage_num_per_mesh): + model_parallel_training_args.num_layers = mesh_config.pp_layer_split[stage_index] + + peak_activation_memory_usage = mem_usg.compute_activation_memory( + args=model_parallel_training_args, num_microbatches=num_micro_batches + ) + peak_weight_optimizer_usage = mem_usg.compute_weight_and_optimizer_memory( + args=model_parallel_training_args + ) + peak_memory_usage = peak_activation_memory_usage + peak_weight_optimizer_usage + + peak_memory_usage_per_stage.append(peak_memory_usage / BYTES_OF_GB) + stage_index = stage_index + 1 + + mesh_index = mesh_index + 1 + + return peak_memory_usage_per_stage + + +def gen_hetero_configs( + device_type_list, + device_num_list, + global_batch_size, + num_layers, + # num_micro_batches, + # hetero_configs: list, + output_config_file: str = "hetero_configs.json", # 新增参数用于保存 hetero_config +): + devices_info = DevicesInfo(device_type_list=device_type_list, device_num_list=device_num_list) + + # 调用 generate_hetero_meshes,生成并写入结果文件 + generate_hetero_meshes( + devices_info=devices_info, + global_batch_size=global_batch_size, + num_layers=num_layers, + output_file="results.json", # 保存 hetero_meshes 的中间文件 + ) + + # 从 results.json 读取 hetero_meshes + hetero_meshes = [] + with open("results.json", "r") as f: + for line in f: + hetero_meshes.append(list(map(int, line.strip().split(",")))) + # print(hetero_meshes) + # assert False + # 遍历 hetero_meshes 并生成 hetero_config + with open(output_config_file, "w") as config_file: # 打开输出文件 + for mesh in hetero_meshes: + pp_stages = sum(mesh[4::5]) + # in order to prune the num of layers in each stage to even number + pp_layer_splits = split_layers(num_layers=num_layers // 2, pp_stages=pp_stages) + for split in pp_layer_splits: + split = [x * 2 for x in split] + hetero_config = HeteroConfig( + mesh=mesh, pp_layer_split=split, device_types=device_type_list + ) + # hetero_configs.append(hetero_config) + + # 保存 HeteroConfig 的每个成员变量到文件 + theory_peak_memory_per_stage = calculate_peak_memory_per_stage(hetero_config) + oom_error = report_oom_error( + memory_capacity_of_devices=memory_capacity_of_devices, + meshes_config=mesh, + peak_memory_usage_per_stage=theory_peak_memory_per_stage, + ) + # if oom_error: + # continue + config_data = { + "mesh": hetero_config.mesh, + "device_types": hetero_config.device_types, + "pp_layer_split": hetero_config.pp_layer_split, + "recompute_granularity": hetero_config.recompute_granularity, + "recompute_method": hetero_config.recompute_method, + "recompute_num_layers": hetero_config.recompute_num_layers, + "simulated_time": hetero_config.simulated_time, + "theory_peak_memory": theory_peak_memory_per_stage, + "oom_error": oom_error, + } + config_file.write(f"{config_data}\n") + + print(f"Hetero configurations saved to {output_config_file}") + + +import ast +import json + + +def read_configs_from_json(file_path: str): + configs_list = [] + with open(file_path, "r") as file: + for line in file: + # config_data = json.loads(line.strip()) + config_data = ast.literal_eval(line.strip()) + configs_list.append(config_data) + return configs_list + + +def get_min_simulated_time_config(hetero_configs): + if not hetero_configs: + return None + return min(hetero_configs, key=lambda x: x.get("simulated_time", float("inf"))) + + +# for test and usage +if __name__ == "__main__": + # hetero_configs = [] + + # generate all possible and legal mesh configs, each element of hetero_configs is a mesh list + gen_hetero_configs( + device_type_list=device_type_list, + device_num_list=device_num_list, + global_batch_size=global_batch_size, + num_layers=num_layers, + output_config_file="hetero_configs.json", + # num_micro_batches=num_micro_batches, + # hetero_configs=hetero_configs + ) + + # assert False + # simulation + file_path = "hetero_configs.json" + hetero_configs = read_configs_from_json(file_path) + for hetero_config in hetero_configs: + print(hetero_config) + pp_cost = hetero_config['simulated_time'] = analylize_pipeline_time.analyze_pp_time( + # pp_cost = hetero_config.simulated_time = analylize_pipeline_time.analyze_pp_time( + scheme="1F1B", + num_micro_batches=num_micro_batches, + process_mesh=hetero_config['mesh'], + pp_layers_split=hetero_config['pp_layer_split'], + ) + print(f"pipeline cost: {pp_cost}") + break + best_config = get_min_simulated_time_config(hetero_configs) + print(best_config) diff --git a/flagscale/runner/auto_tuner/simulator/custom_backend/include/dummy.hpp b/flagscale/runner/auto_tuner/simulator/custom_backend/include/dummy.hpp new file mode 100644 index 0000000000..a71eb8536a --- /dev/null +++ b/flagscale/runner/auto_tuner/simulator/custom_backend/include/dummy.hpp @@ -0,0 +1,157 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +#include + +#include + +using AnyType = std::variant; + + +namespace c10d { + +class ProcessGroup; // 假设的类 +class Store; // 假设的类 + +class BackendDummy : public Backend { + public: + + BackendDummy(int rank, int size); + + const std::string getBackendName() const override; + void startCoalescing() override; + c10::intrusive_ptr endCoalescing() override; + +c10::intrusive_ptr reduce_scatter_tensor_coalesced( + std::vector& outputTensors, + std::vector& inputTensors, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; + +c10::intrusive_ptr allgather_into_tensor_coalesced( + std::vector& outputTensors/* outputs */, + std::vector& inputTensors/* inputs */, + const AllgatherOptions& /* opts */ = AllgatherOptions()) override; + +c10::intrusive_ptr _reduce_scatter_base( + at::Tensor& outputTensors/* outputBuffer */, + at::Tensor& inputTensors/* inputBuffer */, + const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) override; + +c10::intrusive_ptr broadcast( + std::vector &data, + const BroadcastOptions &opts = BroadcastOptions()) override; + +c10::intrusive_ptr allreduce( + std::vector &tensors, + const AllreduceOptions &opts = AllreduceOptions()) override; + +c10::intrusive_ptr allreduce_coalesced( + std::vector &tensors, + const AllreduceCoalescedOptions &opts = + AllreduceCoalescedOptions()) override; + +c10::intrusive_ptr reduce( + std::vector &tensors, + const ReduceOptions &opts = ReduceOptions()) override; + +c10::intrusive_ptr all_gather_object( + std::vector &outputTensors, + AnyType &inputTensors, + const AllgatherOptions &opts = AllgatherOptions()); + +c10::intrusive_ptr allgather( + std::vector> &outputTensors, + std::vector &inputTensors, + const AllgatherOptions &opts = AllgatherOptions()) override; + +c10::intrusive_ptr _allgather_base( + at::Tensor &outputBuffer, + at::Tensor &inputBuffer, + const AllgatherOptions &opts = AllgatherOptions()) override; + +c10::intrusive_ptr barrier( + const BarrierOptions &opts = BarrierOptions()) override; + +c10::intrusive_ptr gather( + std::vector> &outputTensors, + std::vector &inputTensors, + const GatherOptions &opts = GatherOptions()) override; + +c10::intrusive_ptr scatter( + std::vector &outputTensors, + std::vector> &inputTensors, + const ScatterOptions &opts = ScatterOptions()) override; + +c10::intrusive_ptr reduce_scatter( + std::vector &outputTensors, + std::vector> &inputTensors, + const ReduceScatterOptions &opts = ReduceScatterOptions()) override; + +c10::intrusive_ptr alltoall_base( + at::Tensor &outputTensor, + at::Tensor &inputTensor, + std::vector &outputSplitSizes, + std::vector &inputSplitSizes, + const AllToAllOptions &opts = AllToAllOptions()) override; + +c10::intrusive_ptr alltoall( + std::vector &outputTensors, + std::vector &inputTensors, + const AllToAllOptions &opts = AllToAllOptions()) override; + +c10::intrusive_ptr send( + std::vector &tensors, + int dstRank, + int tag) override; + +c10::intrusive_ptr recv( + std::vector &tensors, + int srcRank, + int tag) override; + +c10::intrusive_ptr recvAnysource( + std::vector &tensors, + int tag) override; + +static c10::intrusive_ptr createBackendDummy( + const c10::intrusive_ptr<::c10d::Store> &store, + int rank, + int size, + const std::chrono::duration &timeout); + +static void BackendDummyConstructor() __attribute__((constructor)) +{ + py::object module = py::module::import("torch.distributed"); + py::object register_backend = + module.attr("Backend").attr("register_backend"); + register_backend("dummy", py::cpp_function(createBackendDummy)); + } +}; + +class WorkDummy : public Work { + friend class BackendDummy; +public: + WorkDummy( + OpType opType, + c10::intrusive_ptr future) // future of the output + : Work( + -1, // rank, only used by recvAnySource, irrelevant in this demo + opType), + future_(std::move(future)) {} + bool isCompleted() override; + bool isSuccess() const override; + bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override; + virtual c10::intrusive_ptr getFuture() override; + +private: + c10::intrusive_ptr future_; +}; + +} // namespace c10d diff --git a/flagscale/runner/auto_tuner/simulator/custom_backend/setup.py b/flagscale/runner/auto_tuner/simulator/custom_backend/setup.py new file mode 100644 index 0000000000..172d5ad0e6 --- /dev/null +++ b/flagscale/runner/auto_tuner/simulator/custom_backend/setup.py @@ -0,0 +1,25 @@ +import os + +import torch + +from setuptools import setup +from torch.utils import cpp_extension + +sources = ["src/dummy.cpp"] +include_dirs = [f"{os.path.dirname(os.path.abspath(__file__))}/include/"] + +if torch.cuda.is_available(): + module = cpp_extension.CUDAExtension( + name="dummy_collectives", sources=sources, include_dirs=include_dirs + ) +else: + module = cpp_extension.CppExtension( + name="dummy_collectives", sources=sources, include_dirs=include_dirs + ) + +setup( + name="Dummy-Collectives", + version="0.0.1", + ext_modules=[module], + cmdclass={'build_ext': cpp_extension.BuildExtension}, +) diff --git a/flagscale/runner/auto_tuner/simulator/custom_backend/src/dummy.cpp b/flagscale/runner/auto_tuner/simulator/custom_backend/src/dummy.cpp new file mode 100644 index 0000000000..231ef1b1e7 --- /dev/null +++ b/flagscale/runner/auto_tuner/simulator/custom_backend/src/dummy.cpp @@ -0,0 +1,285 @@ +#include "dummy.hpp" +#include +// #include +// #include +// #include +// #include + +namespace c10d { + + +bool WorkDummy::isCompleted() { + return true; +} + +bool WorkDummy::isSuccess() const { + return true; +} + +bool WorkDummy::wait(std::chrono::milliseconds /* unused */) { + return true; +} + +c10::intrusive_ptr WorkDummy::getFuture() { + return future_; +} + +// If necessary, pass store/rank/size to the ctor and exchange connection +// information here +BackendDummy::BackendDummy(int rank, int size) + : Backend(rank, size) {} + +const std::string BackendDummy::getBackendName() const{ + return "dummy"; +} + +void BackendDummy::startCoalescing(){ + return; + } + +c10::intrusive_ptr BackendDummy::endCoalescing(){ + at::Tensor outputTensors; + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + future->markCompleted(c10::IValue(outputTensors)); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::reduce_scatter_tensor_coalesced( + std::vector& outputTensors, + std::vector& inputTensors, + const ReduceScatterOptions&) { + // printf("dummy reduce_scatter_tensor_coalesced\n"); + for (auto& outputTensor : outputTensors) { + outputTensor.fill_(1); + } + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + future->markCompleted(c10::IValue(outputTensors)); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::allgather_into_tensor_coalesced( + std::vector& outputTensors/* outputs */, + std::vector& inputTensors/* inputs */, + const AllgatherOptions& ) { + // printf("dummy reduce_scatter_tensor_coalesced\n"); + for (auto& outputTensor : outputTensors) { + outputTensor.fill_(1); + } + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + future->markCompleted(c10::IValue(outputTensors)); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::_reduce_scatter_base( + at::Tensor& outputTensors/* outputBuffer */, + at::Tensor& inputTensors/* inputBuffer */, + const ReduceScatterOptions& ) { + // printf("dummy _reduce_scatter_base\n"); + // for (auto& outputTensor : outputTensors) { + // outputTensor.fill_(1); + // } + outputTensors.fill_(1); + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + future->markCompleted(c10::IValue(outputTensors)); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +// This is a dummy allgather that sets all output tensors to zero +// Modify the implementation to conduct real communication asynchronously +c10::intrusive_ptr BackendDummy::allgather( + std::vector>& outputTensors, + std::vector& inputTensors, + const AllgatherOptions& /* unused */) { + // printf("dummy allgather\n"); + for (auto& outputTensorVec : outputTensors) { + for (auto& outputTensor : outputTensorVec) { + outputTensor.fill_(1); + } + } + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + future->markCompleted(c10::IValue(outputTensors)); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::all_gather_object( + std::vector& outputTensors, + AnyType& inputTensors, + const AllgatherOptions& /* unused */) { + // printf("dummy all_gather_object Begin\n"); + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::_allgather_base( + at::Tensor& /* unused */, + at::Tensor& /* unused */, + const AllgatherOptions& /* unused */) { + // printf("dummy _allgather_base\n"); + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +// This is a dummy allreduce that sets all output tensors to zero +// Modify the implementation to conduct real communication asynchronously +c10::intrusive_ptr BackendDummy::allreduce( + std::vector& tensors, + const AllreduceOptions& opts) { + // printf("dummy allreduce\n"); + for (auto& tensor : tensors) { + tensor.zero_(); + } + + auto future = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get())); + future->markCompleted(c10::IValue(tensors)); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::allreduce_coalesced( + std::vector& /* unused */, + const AllreduceCoalescedOptions& /* unused */) { + // printf("dummy allreduce_coalesced\n"); + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::alltoall( + std::vector& /* unused */, + std::vector& /* unused */, + const AllToAllOptions& /* unused */) { + // printf("dummy alltoall\n"); + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::alltoall_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + std::vector& outputSplitSizes, + std::vector& inputSplitSizes, + const AllToAllOptions& /* unused */) { + // printf("dummy alltoall_base\n"); + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::barrier( + const BarrierOptions& /* unused */) { + // printf("dummy barrier\n"); + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::broadcast( + std::vector& tensors, + const BroadcastOptions& opts) { + // printf("dummy broadcast\n"); + for (auto& tensor : tensors) { + tensor.zero_(); + } + + auto future = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get())); + future->markCompleted(c10::IValue(tensors)); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::gather( + std::vector>& /* unused */, + std::vector& /* unused */, + const GatherOptions& /* unused */) { + // printf("dummy gather\n"); + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::reduce( + std::vector& /* unused */, + const ReduceOptions& /* unused */) { + // printf("dummy reduce\n"); + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::reduce_scatter( + std::vector& /* unused */, + std::vector>& /* unused */, + const ReduceScatterOptions& /* unused */) { + // printf("dummy reduce_scatter\n"); + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::scatter( + std::vector& /* unused */, + std::vector>& /* unused */, + const ScatterOptions& /* unused */) { + // printf("dummy scatter\n"); + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::send( + std::vector& tensors, + int dstRank, + int tag) { + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::recv( + std::vector& tensors, + int srcRank, + int tag) { + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::recvAnysource( + std::vector& tensors, + int tag) { + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::createBackendDummy( + const c10::intrusive_ptr<::c10d::Store>& /* unused */, + int rank, + int size, + const std::chrono::duration& /* unused */) { + return c10::make_intrusive(rank, size); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("createBackendDummy", &BackendDummy::createBackendDummy); +} + +} // namespace c10d From 9b45cded92cea4252fcef2f677eeb70d606ebc35 Mon Sep 17 00:00:00 2001 From: shuailong616 <452509829@qq.com> Date: Wed, 29 Oct 2025 19:17:18 +0800 Subject: [PATCH 3/8] modified arguments.py --- flagscale/train/arguments.py | 72 ++++++++++++++++++++++++++++-------- 1 file changed, 56 insertions(+), 16 deletions(-) diff --git a/flagscale/train/arguments.py b/flagscale/train/arguments.py index a74c17acbf..e085369cc9 100644 --- a/flagscale/train/arguments.py +++ b/flagscale/train/arguments.py @@ -13,6 +13,12 @@ warnings.warn( "flagcx is not installed, you can't use flagcx backend for communication.", ImportWarning ) +import datetime +import multiprocessing +import os +import threading + +import dummy_collectives from flagscale.train.hetero.parallel_context import RankMapper @@ -55,22 +61,56 @@ def _initialize_distributed(self): else: device_id = None - # Call the init process - init_process_group_kwargs = { - "backend": args.distributed_backend, - "world_size": args.world_size, - "rank": args.rank, - "timeout": timedelta(minutes=args.distributed_timeout_minutes), - } - if args.distributed_backend == "flagcx": - init_process_group_kwargs["backend"] = "cpu:gloo,cuda:flagcx" - # for communication based cpu - if args.enable_hetero and args.hetero_use_cpu_communication: - # if not all(device_type == args.hetero_device_types[0] for device_type in args.hetero_device_types): - # init_process_group_kwargs['backend'] = 'cpu:gloo' - # Force the group of backend gloo only support cpu - init_process_group_kwargs["backend"] = "cpu:gloo" - torch.distributed.init_process_group(**init_process_group_kwargs) + if args.enable_simulator: + # Define a function to initialize and run operations with a virtual rank + def run_virtual_rank(rank, world_size, timeout): + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "37832" + init_process_group_kwargs = { + 'backend': args.distributed_backend, + 'world_size': world_size, + 'rank': rank, + 'timeout': datetime.timedelta(minutes=timeout), + } + torch.distributed.init_process_group(**init_process_group_kwargs) + torch.distributed.barrier() + + # Call the init process with multithreads + args.distributed_timeout_minutes = 1 + threads = [] + # Start a thread for each virtual rank + for rank in range(1, 2): # 2 for skipping launching thousands of threads + # for rank in range(1, args.world_size): + thread = threading.Thread( + target=run_virtual_rank, + args=(rank, args.world_size, args.distributed_timeout_minutes), + ) + thread.start() + threads.append(thread) + rank = 0 + gpu_task = multiprocessing.Process( + target=run_virtual_rank, + args=(rank, args.world_size, args.distributed_timeout_minutes), + ) + gpu_task.start() + # Wait for all threads to complete + for thread in threads: + thread.join() + else: + # Call the init process + init_process_group_kwargs = { + "backend": args.distributed_backend, + "world_size": args.world_size, + "rank": args.rank, + "timeout": timedelta(minutes=args.distributed_timeout_minutes), + } + # for communication based cpu + if args.enable_hetero and args.hetero_use_cpu_communication: + # if not all(device_type == args.hetero_device_types[0] for device_type in args.hetero_device_types): + # init_process_group_kwargs['backend'] = 'gloo' + # Force the group of backend gloo only support cpu + init_process_group_kwargs['backend'] = 'cpu:gloo' + torch.distributed.init_process_group(**init_process_group_kwargs) def _build_rank_mapper(self): self._initialize_distributed() From e655abe5878d9d4f08cdf5d10bbd5296c534ae66 Mon Sep 17 00:00:00 2001 From: shuailong616 <452509829@qq.com> Date: Wed, 29 Oct 2025 19:33:56 +0800 Subject: [PATCH 4/8] modified config_gen.py --- flagscale/runner/auto_tuner/simulator/config_gen.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flagscale/runner/auto_tuner/simulator/config_gen.py b/flagscale/runner/auto_tuner/simulator/config_gen.py index 0b97e0ed6e..864788692e 100644 --- a/flagscale/runner/auto_tuner/simulator/config_gen.py +++ b/flagscale/runner/auto_tuner/simulator/config_gen.py @@ -386,6 +386,5 @@ def get_min_simulated_time_config(hetero_configs): pp_layers_split=hetero_config['pp_layer_split'], ) print(f"pipeline cost: {pp_cost}") - break best_config = get_min_simulated_time_config(hetero_configs) print(best_config) From 12eb33c8eb27aa7b140c6099cbed5914257feb09 Mon Sep 17 00:00:00 2001 From: shuailong616 <452509829@qq.com> Date: Tue, 18 Nov 2025 17:45:02 +0800 Subject: [PATCH 5/8] add vpp scheme & support simulate all stage --- .../simulator/analylize_pipeline_time.py | 185 +++++++++++++++--- .../runner/auto_tuner/simulator/config_gen.py | 83 +++++--- flagscale/train/arguments.py | 4 +- 3 files changed, 221 insertions(+), 51 deletions(-) diff --git a/flagscale/runner/auto_tuner/simulator/analylize_pipeline_time.py b/flagscale/runner/auto_tuner/simulator/analylize_pipeline_time.py index ce61302b10..c9c7400723 100644 --- a/flagscale/runner/auto_tuner/simulator/analylize_pipeline_time.py +++ b/flagscale/runner/auto_tuner/simulator/analylize_pipeline_time.py @@ -21,6 +21,7 @@ def compute_pipeline_parallelism_cost( fwd_time_per_stage_chunk: list = None, bwd_time_per_stage_chunk: list = None, comm_time_between_stages: list = None, + vpp_partition: list = None, # TODO: add fine-greaied recomputation ): print(f"--- Compute Pipeline Cost ---") @@ -28,9 +29,13 @@ def compute_pipeline_parallelism_cost( # process_mesh: [tp0,cp0,ep0,dp0,pp0,(tp1,cp1,...)] # comm_time_between_stages[i] means the comm time between stage i-1 and stage i num_pp_stages = sum(process_mesh[4::5]) + assert ( len(pp_layers_split) == num_pp_stages ), "\flength of list {num_layers_per_stage} should match {num_stages}" + if scheme == 'vpp': + num_pp_stages = sum(vpp_partition) + assert ( len(fwd_time_per_stage_chunk) == num_pp_stages ), "\flength of list {fwd_time_per_stage_chunk} should match {num_stages}" @@ -72,8 +77,41 @@ def compute_pipeline_parallelism_cost( pipeline_cost = pp_this_stage_compute_time + max( pp_last_stage_overall_time, pp_this_stage_overlapped_time ) - else: - raise (ValueError("Scheme must be '1F1B' or 'AFAB'.")) + # else: + # raise (ValueError("Scheme must be '1F1B' or 'AFAB'.")) + elif scheme == 'vpp': + num_vp_stages = len(fwd_time_per_stage_chunk) + num_pp_stages = len(comm_time_between_stages) # error + vstage_to_pp = [] + for i, count in enumerate(vpp_partition): + vstage_to_pp += [i] * count + + comm_per_vstage = [0.0] * num_vp_stages + for i in range(num_vp_stages - 1): + cur_pp, next_pp = vstage_to_pp[i], vstage_to_pp[i + 1] + if next_pp != cur_pp: + comm_per_vstage[i] = comm_time_between_stages[cur_pp + 1] + + vp_last_stage_time = num_micro_batches * ( + fwd_time_per_stage_chunk[-1] + bwd_time_per_stage_chunk[-1] + ) + pipeline_cost = vp_last_stage_time + for vp_from_last in range(2, num_vp_stages + 1): + this_vp_idx = num_vp_stages - vp_from_last + this_stage_fwd = fwd_time_per_stage_chunk[this_vp_idx] + this_stage_bwd = bwd_time_per_stage_chunk[this_vp_idx] + this_stage_compute_time = this_stage_fwd + this_stage_bwd + + pp_idx = this_vp_idx % num_pp_stages + comm_time = comm_time_between_stages[min(pp_idx + 1, num_pp_stages - 1)] + + this_vp_overlapped_time = (num_micro_batches - 1) * this_stage_compute_time + + last_vp_total_time = pipeline_cost + 2 * comm_time + + pipeline_cost = this_stage_compute_time + max( + this_vp_overlapped_time, last_vp_total_time + ) return pipeline_cost @@ -91,13 +129,13 @@ def simulator( # os.environ["PYTHONPATH"] = "/share/project/heyongzhe/FlagScale/megatron:/share/project/heyongzhe/FlagScale" os.environ["PYTHONPATH"] = ( - "/workspace/20251010/new/FlagScale:" - "/workspace/20251010/new/FlagScale/third_party/Megatron-LM" + "/workspace/single_process_simulator_nd/FlagScale:" + "/workspace/single_process_simulator_nd/FlagScale/third_party/Megatron-LM" ) - os.environ["CUDA_VISIBLE_DEVICES"] = "0" + os.environ["CUDA_VISIBLE_DEVICES"] = "6" os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" os.environ["RANK"] = str(simulated_rank) - os.environ["LOCAL_RANK"] = str(simulated_rank) + os.environ["LOCAL_RANK"] = "0" # os.environ["WORLD_SIZE"] = args.world_size os.environ["WORLD_SIZE"] = "8" # os.environ["WORLD_SIZE"] = "32" @@ -128,7 +166,7 @@ def simulator( train_command = ( "python " + program_entry - + "--tensor-model-parallel-size 1 --timing-log-level 2 --disable-bias-linear --use-flash-attn --sequence-parallel --use-distributed-optimizer --use-mcore-models --transformer-impl transformer_engine --hetero-device-types A800 BI150 --hetero-current-device-type A800 --recompute-granularity full --recompute-method uniform --recompute-num-layers 1 --bf16 --attention-softmax-in-fp32 --accumulate-allreduce-grads-in-fp32 --log-interval 1 --log-throughput --tensorboard-log-interval 1 --wandb-project aquila2 --wandb-exp-name test --tensorboard-dir /share/project/heyongzhe/FlagScale/outputs/tensorboard --wandb-save-dir /share/project/heyongzhe/FlagScale/outputs/wandb --num-layers 32 --hidden-size 4096 --num-attention-heads 32 --seq-length 2048 --max-position-embeddings 2048 --norm-epsilon 1e-05 --use-rotary-position-embeddings --no-position-embedding --swiglu --multiple-of 256 --normalization RMSNorm --untie-embeddings-and-output-weights --init-method-std 0.0165 --attention-dropout 0.0 --hidden-dropout 0.0 --weight-decay 0.1 --clip-grad 1.0 --train-samples 128 --global-batch-size 64 --micro-batch-size 1 --seed 42 --lr 0.0002 --weight-decay 0.01 --adam-beta1 0.9 --adam-beta2 0.95 --lr 0.00015 --min-lr 1.5e-05 --lr-warmup-samples 0 --lr-decay-style cosine --data-path /workspace/FlagScale/datapath/pile_wikipedia_demo --split 1 --tokenizer-type AquilaTokenizerFS --vocab-file ./examples/aquila/tokenizer/vocab.json --merge-file ./examples/aquila/tokenizer/merges.txt --special-tokens-file ./examples/aquila/tokenizer/special_tokens.txt --vocab-size 100008 " + + "--tensor-model-parallel-size 1 --timing-log-level 2 --disable-bias-linear --use-flash-attn --sequence-parallel --use-distributed-optimizer --use-mcore-models --transformer-impl transformer_engine --hetero-device-types A800 A800 --hetero-current-device-type A800 --bf16 --attention-softmax-in-fp32 --accumulate-allreduce-grads-in-fp32 --log-interval 1 --log-throughput --tensorboard-log-interval 1 --wandb-project aquila2 --wandb-exp-name test --tensorboard-dir /share/project/heyongzhe/FlagScale/outputs/tensorboard --wandb-save-dir /share/project/heyongzhe/FlagScale/outputs/wandb --num-layers 32 --hidden-size 4096 --num-attention-heads 32 --seq-length 2048 --max-position-embeddings 2048 --norm-epsilon 1e-05 --use-rotary-position-embeddings --no-position-embedding --swiglu --multiple-of 256 --normalization RMSNorm --untie-embeddings-and-output-weights --init-method-std 0.0165 --attention-dropout 0.0 --hidden-dropout 0.0 --weight-decay 0.1 --clip-grad 1.0 --train-samples 128 --global-batch-size 64 --micro-batch-size 1 --seed 42 --lr 0.0002 --weight-decay 0.01 --adam-beta1 0.9 --adam-beta2 0.95 --lr 0.00015 --min-lr 1.5e-05 --lr-warmup-samples 0 --lr-decay-style cosine --data-path /workspace/FlagScale/datapath/pile_wikipedia_demo --split 1 --tokenizer-type AquilaTokenizerFS --vocab-file ./examples/aquila/tokenizer/vocab.json --merge-file ./examples/aquila/tokenizer/merges.txt --special-tokens-file ./examples/aquila/tokenizer/special_tokens.txt --vocab-size 100008 " + process_mesh_str + simulation_arguments + pp_layer_split_args @@ -145,6 +183,7 @@ def simulator( # time.sleep(10) print("start...") result = subprocess.run(train_command, capture_output=True, text=True, shell=True) + print(result) output = result.stdout.strip() print(train_command) print(output) @@ -154,26 +193,88 @@ def simulator( fwd_time = float(match.group(1)) bwd_time = float(match.group(2)) # comm_time = float(match.group(3)) - comm_time = 0.01 + comm_time = estimate_comm_time_between_stages(1, 2048, 4096) print("forward:", fwd_time) print("backward:", bwd_time) print("communication:", comm_time) else: - # print(fwd_time,bwd_time,comm_time) - fwd_time = 12.34 - bwd_time = 56.78 - comm_time = 90.12 - print("forward:", fwd_time) - print("backward:", bwd_time) - print("communication:", comm_time) - # raise(ValueError("Results not found. Example output: \"[simulatior output] forward: 12.34, backward: 56.78, communication: 90.12\"")) + raise ( + ValueError( + "Results not found. Example output: \"[simulatior output] forward: 12.34, backward: 56.78, communication: 90.12\"" + ) + ) return fwd_time, bwd_time, comm_time +def compute_vpp_from_layers( + pp_layers_split, target_layers_per_vstage=2, device_speed=None, min_layers_per_virtual_stage=2 +): + """ + Args: + pp_layers_split: list[int] + target_layers_per_vstage: int + device_speed: list[float] + min_layers_per_virtual_stage: + Returns: + vpp_list: list[int], + """ + vpp_list = [] + max_speed = max(device_speed) if device_speed else 1.0 + + for i, num_layers in enumerate(pp_layers_split): + base_vpp = max(1, round(num_layers / target_layers_per_vstage)) + + if device_speed: + scale = device_speed[i] / max_speed + base_vpp = max(1, round(base_vpp * scale)) + + base_vpp = min(base_vpp, num_layers // min_layers_per_virtual_stage) + if base_vpp == 0: + base_vpp = 1 + + while num_layers % base_vpp != 0 and base_vpp > 1: + base_vpp -= 1 + + vpp_list.append(base_vpp) + + return vpp_list + + +def estimate_comm_time_between_stages( + batch_size: int, + seq_len: int, + hidden_size: int, + dtype_bytes: int = 2, # bf16 + bandwidth_GBps: float = 300.0, + latency_ms: float = 0.01, + tensor_model_parallel_size: int = 1, + virtual_pipeline_size: int = 1, + activation_fraction: float = 1.0, + use_allgather_for_activation: bool = False, +): + bytes_one_way = batch_size * seq_len * hidden_size * dtype_bytes * activation_fraction + if tensor_model_parallel_size > 1: + bytes_one_way /= tensor_model_parallel_size + + K = max(1, virtual_pipeline_size) + per_transfer_bytes = bytes_one_way / K + bw_Bps = bandwidth_GBps * 1e9 + one_way_time = per_transfer_bytes / bw_Bps + latency_ms / 1000.0 + comm_time = 2 * K * one_way_time # fwd+bwd + + if use_allgather_for_activation and tensor_model_parallel_size > 1: + extra_bytes = ( + (tensor_model_parallel_size - 1) / tensor_model_parallel_size + ) * bytes_one_way + comm_time += extra_bytes / bw_Bps + (latency_ms / 1000.0) + return comm_time + + # call simulator to obtain the execution of each stage def simulate_pipeline_parallelism_per_stage_time( process_mesh: list = None, pp_layers_split: list = None, + scheme: str = '1F1B', fwd_time_per_stage_chunk: list = None, bwd_time_per_stage_chunk: list = None, comm_time_between_stages: list = None, @@ -181,16 +282,45 @@ def simulate_pipeline_parallelism_per_stage_time( print(f"--- Simulation Begin ---") print(f"Process Mesh: {process_mesh}") print(f"PP Layer Split: {pp_layers_split}") - for stage, num_layers in enumerate(pp_layers_split): - # TODO: confirm simulated_rank for different stage - print(f"Stage: {stage}; Num Layers: {num_layers}") - simulated_rank = 0 - fwd_time, bwd_time, comm_time = simulator( - process_mesh, stage, num_layers, simulated_rank, pp_layers_split - ) - fwd_time_per_stage_chunk.append(fwd_time) - bwd_time_per_stage_chunk.append(bwd_time) - comm_time_between_stages.append(comm_time) + if scheme == '1F1B': + for stage, num_layers in enumerate(pp_layers_split): + # TODO: confirm simulated_rank for different stage + print(f"Stage: {stage}; Num Layers: {num_layers}") + simulated_rank = stage + try: + fwd_time, bwd_time, comm_time = simulator( + process_mesh, stage, num_layers, simulated_rank, pp_layers_split + ) + fwd_time_per_stage_chunk.append(fwd_time) + bwd_time_per_stage_chunk.append(bwd_time) + comm_time_between_stages.append(comm_time) + except Exception as e: + print(f"[Error] Simulator failed at stage {stage}, skip. Reason: {e}") + continue + + elif scheme == 'vpp': + vpp_list = compute_vpp_from_layers(pp_layers_split) + print(vpp_list) + for stage_idx, (num_layers, vpp) in enumerate(zip(pp_layers_split, vpp_list)): + layers_per_chunk = num_layers // vpp + for vstage_idx in range(vpp): + vstage_name = f"{stage_idx}-{vstage_idx}" + print(f" ->Stage {vstage_name} : ( {layers_per_chunk})") + try: + fwd_time, bwd_time, comm_time = simulator( + process_mesh=process_mesh, + stage=vstage_name, + num_layers=layers_per_chunk, + simulated_rank=stage_idx, + pp_layers_split=pp_layers_split, + ) + fwd_time_per_stage_chunk.append(fwd_time) + bwd_time_per_stage_chunk.append(bwd_time) + comm_time_between_stages.append(comm_time) + except Exception as e: + print(f"[Error] Simulator failed at V-stage {vstage_name}, skip. Reason: {e}") + continue + print(f"--- Simulation End ---") @@ -203,10 +333,12 @@ def analyze_pp_time( fwd_time_per_stage_chunk = [] bwd_time_per_stage_chunk = [] comm_time_between_stages = [] + vpp_partition = compute_vpp_from_layers(pp_layers_split) simulate_pipeline_parallelism_per_stage_time( process_mesh=process_mesh, pp_layers_split=pp_layers_split, + scheme=scheme, fwd_time_per_stage_chunk=fwd_time_per_stage_chunk, bwd_time_per_stage_chunk=bwd_time_per_stage_chunk, comm_time_between_stages=comm_time_between_stages, @@ -220,6 +352,7 @@ def analyze_pp_time( fwd_time_per_stage_chunk=fwd_time_per_stage_chunk, bwd_time_per_stage_chunk=bwd_time_per_stage_chunk, comm_time_between_stages=comm_time_between_stages, + vpp_partition=vpp_partition, ) return pipeline_cost diff --git a/flagscale/runner/auto_tuner/simulator/config_gen.py b/flagscale/runner/auto_tuner/simulator/config_gen.py index 864788692e..d16411c9d9 100644 --- a/flagscale/runner/auto_tuner/simulator/config_gen.py +++ b/flagscale/runner/auto_tuner/simulator/config_gen.py @@ -12,13 +12,10 @@ BYTES_OF_GB = 10**9 -# device_type_list = ["A800", "A800", "BI150", "BI150"] -# device_num_list = [8, 8, 8, 8] -# memory_capacity_of_devices = [80, 80, 32, 32] # GB -device_type_list = ["A800", "BI150"] +device_type_list = ["A800", "A800"] device_num_list = [4, 4] -memory_capacity_of_devices = [80, 32] # GB +memory_capacity_of_devices = [80, 80] # GB global_batch_size = 512 num_micro_batches = 8 @@ -86,15 +83,15 @@ def is_legal_combination(comb: list): for dp in comb[3::5]: if max_dp % dp != 0: return False + for i in range(len(comb) // 5): + tp, _, _, dp, pp = comb[i * 5 : i * 5 + 5] + device_num = devices_info.device_num_list[i] + if tp * dp * pp != device_num: + return False return True def is_extreme_strategy(comb: list): for mesh_index in range(len(comb) // 5): - # num_devices_in_mesh = sum( - # comb[ - # mesh_index * 5 : mesh_index * 5 + 4 - # ] - # ) num_devices_in_mesh = reduce( lambda x, y: x * y, comb[mesh_index * 5 : mesh_index * 5 + 5] ) @@ -142,8 +139,16 @@ def combine_possible_parallelisms(possible_parallelisms, output_file): print(f"Results written to {output_file}") -def split_layers(num_layers, pp_stages): +def extract_mesh_stage_structure(mesh): + stage_counts = [] + for i in range(0, len(mesh), 5): + stage_counts.append(mesh[i + 4]) + return stage_counts + + +def split_layers(num_layers, pp_stages, mesh): results = [] + mesh_stage_counts = extract_mesh_stage_structure(mesh) # print(pp_stages) for split_points in combinations(range(1, num_layers), pp_stages - 1): # print(split_points) @@ -155,10 +160,31 @@ def split_layers(num_layers, pp_stages): + [num_layers - split_points[-1]] ) # to prune some extreme splits - if max(splits) / min(splits) > 2: + stage_index = 0 + mesh_total_layers = [] + violate = False + + for m in mesh_stage_counts: + sub_splits = splits[stage_index : stage_index + m] + stage_index += m + + if not sub_splits: + continue + + if max(sub_splits) - min(sub_splits) > 4: + violate = True + break + + mesh_total_layers.append(sum(sub_splits)) + + if violate: + continue + + if max(mesh_total_layers) - min(mesh_total_layers) > 4: continue - # print(splits) + results.append(splits) + return results @@ -280,39 +306,35 @@ def gen_hetero_configs( num_layers, # num_micro_batches, # hetero_configs: list, - output_config_file: str = "hetero_configs.json", # 新增参数用于保存 hetero_config + output_config_file: str = "hetero_configs.json", ): devices_info = DevicesInfo(device_type_list=device_type_list, device_num_list=device_num_list) - # 调用 generate_hetero_meshes,生成并写入结果文件 generate_hetero_meshes( devices_info=devices_info, global_batch_size=global_batch_size, num_layers=num_layers, - output_file="results.json", # 保存 hetero_meshes 的中间文件 + output_file="results.json", ) - # 从 results.json 读取 hetero_meshes hetero_meshes = [] with open("results.json", "r") as f: for line in f: hetero_meshes.append(list(map(int, line.strip().split(",")))) # print(hetero_meshes) # assert False - # 遍历 hetero_meshes 并生成 hetero_config + seen = set() with open(output_config_file, "w") as config_file: # 打开输出文件 for mesh in hetero_meshes: pp_stages = sum(mesh[4::5]) # in order to prune the num of layers in each stage to even number - pp_layer_splits = split_layers(num_layers=num_layers // 2, pp_stages=pp_stages) + pp_layer_splits = split_layers(num_layers=num_layers, pp_stages=pp_stages, mesh=mesh) for split in pp_layer_splits: - split = [x * 2 for x in split] + split = [x for x in split] hetero_config = HeteroConfig( mesh=mesh, pp_layer_split=split, device_types=device_type_list ) # hetero_configs.append(hetero_config) - - # 保存 HeteroConfig 的每个成员变量到文件 theory_peak_memory_per_stage = calculate_peak_memory_per_stage(hetero_config) oom_error = report_oom_error( memory_capacity_of_devices=memory_capacity_of_devices, @@ -321,6 +343,12 @@ def gen_hetero_configs( ) # if oom_error: # continue + + key = (tuple(mesh), tuple(split), tuple(device_type_list)) + if key in seen: + continue # 跳过重复项 + seen.add(key) + config_data = { "mesh": hetero_config.mesh, "device_types": hetero_config.device_types, @@ -357,6 +385,11 @@ def get_min_simulated_time_config(hetero_configs): return min(hetero_configs, key=lambda x: x.get("simulated_time", float("inf"))) +def append_config_to_file(file_path: str, config: dict): + with open(file_path, "a") as f: + f.write(str(config) + "\n") + + # for test and usage if __name__ == "__main__": # hetero_configs = [] @@ -375,16 +408,20 @@ def get_min_simulated_time_config(hetero_configs): # assert False # simulation file_path = "hetero_configs.json" + result_path = "simulate_time.json" hetero_configs = read_configs_from_json(file_path) for hetero_config in hetero_configs: print(hetero_config) pp_cost = hetero_config['simulated_time'] = analylize_pipeline_time.analyze_pp_time( # pp_cost = hetero_config.simulated_time = analylize_pipeline_time.analyze_pp_time( - scheme="1F1B", + scheme="vpp", + # scheme="1F1B", num_micro_batches=num_micro_batches, process_mesh=hetero_config['mesh'], pp_layers_split=hetero_config['pp_layer_split'], ) print(f"pipeline cost: {pp_cost}") + append_config_to_file(result_path, hetero_config) + best_config = get_min_simulated_time_config(hetero_configs) print(best_config) diff --git a/flagscale/train/arguments.py b/flagscale/train/arguments.py index e085369cc9..b2c04f1780 100644 --- a/flagscale/train/arguments.py +++ b/flagscale/train/arguments.py @@ -79,8 +79,8 @@ def run_virtual_rank(rank, world_size, timeout): args.distributed_timeout_minutes = 1 threads = [] # Start a thread for each virtual rank - for rank in range(1, 2): # 2 for skipping launching thousands of threads - # for rank in range(1, args.world_size): + # for rank in range(1, 2): # 2 for skipping launching thousands of threads + for rank in range(1, args.world_size): thread = threading.Thread( target=run_virtual_rank, args=(rank, args.world_size, args.distributed_timeout_minutes), From a779bbbd039715e9677636116c30d0a369412b8a Mon Sep 17 00:00:00 2001 From: shuailong616 <452509829@qq.com> Date: Wed, 19 Nov 2025 10:59:08 +0800 Subject: [PATCH 6/8] adpate to single-process simulator, which still has problems --- .../megatron/core/optimizer/__init__.py.patch | 58 +++++++++++++------ 1 file changed, 41 insertions(+), 17 deletions(-) diff --git a/flagscale/backends/Megatron-LM/megatron/core/optimizer/__init__.py.patch b/flagscale/backends/Megatron-LM/megatron/core/optimizer/__init__.py.patch index c3199e793a..ece9ab2ef8 100644 --- a/flagscale/backends/Megatron-LM/megatron/core/optimizer/__init__.py.patch +++ b/flagscale/backends/Megatron-LM/megatron/core/optimizer/__init__.py.patch @@ -1,5 +1,5 @@ diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py -index 1846907e9..7fd0554c4 100644 +index 1846907e9..70d0e72b4 100644 --- a/megatron/core/optimizer/__init__.py +++ b/megatron/core/optimizer/__init__.py @@ -55,6 +55,7 @@ def _get_param_groups( @@ -19,23 +19,41 @@ index 1846907e9..7fd0554c4 100644 if scale_lr_cond is not None: scale_lr = scale_lr_cond(name, param) -@@ -128,8 +131,14 @@ def _get_param_groups( +@@ -128,11 +131,32 @@ def _get_param_groups( param, 'is_embedding_or_output_parameter', False ): is_decoupled_lr = True -+ +- + key = (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr) + if key not in params_map: + params_map[key] = [] + params_map[key].append(param) ++ param_groups = [] ++ for (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr), params in params_map.items(): ++ assert len(params) > 0 ++ param_group = { ++ 'params': params, ++ 'wd_mult': wd_mult, ++ 'lr_mult': _lr_mult, ++ 'is_expert_parallel': is_expert_parallel, ++ 'is_decoupled_lr': is_decoupled_lr, ++ } ++ param_groups.append(param_group) ++ ''' + is_vision_model_param = False + if "vision_model" in name: + is_vision_model_param = True + else: + is_vision_model_param = False - -- key = (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr) ++ + key = (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr, is_vision_model_param) - if key not in params_map: - params_map[key] = [] - params_map[key].append(param) -@@ -147,7 +156,7 @@ def _get_param_groups( ++ if key not in params_map: ++ params_map[key] = [] ++ params_map[key].append(param) + + # Distributed checkpoint requires all ranks to have the same param groups, + # so we need to align the param groups across ranks, otherwise we may have +@@ -147,7 +171,7 @@ def _get_param_groups( param_groups = [] for key in params_key: @@ -44,7 +62,7 @@ index 1846907e9..7fd0554c4 100644 params = params_map[key] if key in params_map else [] param_group = { 'params': params, -@@ -155,6 +164,7 @@ def _get_param_groups( +@@ -155,11 +179,13 @@ def _get_param_groups( 'lr_mult': _lr_mult, 'is_expert_parallel': is_expert_parallel, 'is_decoupled_lr': is_decoupled_lr, @@ -52,7 +70,13 @@ index 1846907e9..7fd0554c4 100644 } # Ensure param_group has required keys for matching when loading optimizer state # See MegatronOptimizer._filter_and_reorder_param_groups. -@@ -167,6 +177,7 @@ def _get_param_groups( + assert set(param_group.keys()) - set(param_group_identifier_keys) == {'params'} + param_groups.append(param_group) ++ ''' + + param_groups = _update_min_and_max_lr_in_param_groups( + param_groups, +@@ -167,6 +193,7 @@ def _get_param_groups( min_lr=min_lr, decoupled_lr=decoupled_lr, decoupled_min_lr=decoupled_min_lr, @@ -60,7 +84,7 @@ index 1846907e9..7fd0554c4 100644 ) return param_groups -@@ -178,6 +189,7 @@ def _update_min_and_max_lr_in_param_groups( +@@ -178,6 +205,7 @@ def _update_min_and_max_lr_in_param_groups( min_lr: float, decoupled_lr: Optional[float], decoupled_min_lr: Optional[float], @@ -68,16 +92,16 @@ index 1846907e9..7fd0554c4 100644 ) -> List[Dict]: """ Updates `max_lr` and `min_lr` values in each parameter group, and returns new list. -@@ -206,7 +218,7 @@ def _update_min_and_max_lr_in_param_groups( +@@ -206,7 +234,7 @@ def _update_min_and_max_lr_in_param_groups( param_group['max_lr'] = decoupled_lr param_group['min_lr'] = decoupled_min_lr else: - param_group['max_lr'] = lr -+ param_group['max_lr'] = lr if not param_group['is_vision_model_param'] else lr * vision_ration # NOTE(lizhiyu): change the ration here ++ param_group['max_lr'] = lr #if not param_group['is_vision_model_param'] else lr * vision_ration # NOTE(lizhiyu): change the ration here param_group['min_lr'] = min_lr return param_groups -@@ -255,6 +267,7 @@ def _get_param_groups_and_buffers( +@@ -255,6 +283,7 @@ def _get_param_groups_and_buffers( decoupled_lr=config.decoupled_lr, decoupled_min_lr=config.decoupled_min_lr, default_skip_embedding_weight_decay=default_skip_embedding_weight_decay, @@ -85,7 +109,7 @@ index 1846907e9..7fd0554c4 100644 ) param_groups = list(filter(filter_fn, param_groups)) buffers = {} -@@ -511,6 +524,10 @@ def get_megatron_optimizer( +@@ -511,6 +540,10 @@ def get_megatron_optimizer( intra_dp_cp_group = process_groups['intra_dp_cp_group'] intra_expt_dp_group = process_groups['intra_expt_dp_group'] mp_group = process_groups['mp_group'] @@ -96,7 +120,7 @@ index 1846907e9..7fd0554c4 100644 expt_tp_pp_group = process_groups['expt_tp_pp_group'] intra_dp_cp_group_gloo = process_groups['intra_dp_cp_group_gloo'] intra_expt_dp_group_gloo = process_groups['intra_expt_dp_group_gloo'] -@@ -609,7 +626,11 @@ def get_megatron_optimizer( +@@ -609,7 +642,11 @@ def get_megatron_optimizer( default_skip_embedding_weight_decay=default_skip_embedding_weight_decay, ) if len(moe_param_groups) > 0: From c7178671a797807aa81d8cdedb562b29f03901ff Mon Sep 17 00:00:00 2001 From: shuailong616 <452509829@qq.com> Date: Thu, 20 Nov 2025 14:52:38 +0800 Subject: [PATCH 7/8] fix README.md --- .../runner/auto_tuner/simulator/README.md | 36 +++++++++++++++---- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/flagscale/runner/auto_tuner/simulator/README.md b/flagscale/runner/auto_tuner/simulator/README.md index eee9c9513e..4ac43d38ce 100644 --- a/flagscale/runner/auto_tuner/simulator/README.md +++ b/flagscale/runner/auto_tuner/simulator/README.md @@ -1,12 +1,13 @@ # Environment Begin at the root path of `FlagScale` repository: +1. Install backend ``` -cd flagscale/flagscale/runner/auto_tuner/simulator/custom_backend/ +cd flagscale/runner/auto_tuner/simulator/custom_backend/ python setup.py develop ``` # Setup -Set necessary parameters in `config_gen.py`. For example: +2. Set necessary parameters in `config_gen.py`. For example: ``` device_type_list = ["A", "B"] device_num_list = [4, 4] @@ -15,9 +16,30 @@ num_micro_batches = 8 num_layers = 4 ``` # Run a Task -Start the auto-tuning: -``` -export PYTHONPATH=/****/FlagScale:$PYTHONPATH -export PYTHONPATH=$PYTHONPATH:/***/FlagScale/third_party/Megatron-LM +3. Start the auto-tuning: + a. set PYTHONPATH + ``` + export PYTHONPATH=/***/FlagScale:$PYTHONPATH + export PYTHONPATH=$PYTHONPATH:/***/FlagScale/third_party/Megatron-LM + + vim /***/FlagScale/flagscale/runner/auto_tuner/simulator/analylize_pipeline_time.py + os.environ["PYTHONPATH"] = ( + "/***/FlagScale:" + "/***/FlagScale/third_party/Megatron-LM" + ) + ``` + b. run + + vim flagscale/runner/auto_tuner/simulator/config_gen.py + + set scheme = vpp or scheme = 1F1B + + python flagscale/runner/auto_tuner/simulator/config_gen.py -python flagscale/runner/auto_tuner/simulator/config_gen.py + c. result + ``` + {'mesh': [2, 1, 1, 1, 2, 1, 1, 1, 1, 4], 'device_types': ['A800', 'A800'], 'pp_layer_split': [8, 8, 5, 5, 5, 1], 'recompute_granularity': None, 'recompute_method': 'uniform', 'recompute_num_layers': 1, 'simulated_time': 57.52105478485333, 'theory_peak_memory': [110.487650304, 118.80914944, 158.35625472, 158.35625472, 158.35625472, 42.519842816], 'oom_error': True} + {'mesh': [2, 1, 1, 1, 2, 1, 1, 1, 1, 4], 'device_types': ['A800', 'A800'], 'pp_layer_split': [8, 7, 5, 5, 5, 2], 'recompute_granularity': None, 'recompute_method': 'uniform', 'recompute_num_layers': 1, 'simulated_time': 61.20105478485332, 'theory_peak_memory': [110.487650304, 109.345202176, 158.35625472, 158.35625472, 158.35625472, 61.447737344], 'oom_error': True} + {'mesh': [2, 1, 1, 1, 2, 1, 1, 1, 1, 4], 'device_types': ['A800', 'A800'], 'pp_layer_split': [8, 8, 5, 5, 4, 2], 'recompute_granularity': None, 'recompute_method': 'uniform', 'recompute_num_layers': 1, 'simulated_time': 54.73105478485331, 'theory_peak_memory': [110.487650304, 118.80914944, 158.35625472, 158.35625472, 119.365943296, 61.447737344], 'oom_error': True} +... +``` From e4f0b28562346d0025fc3a77e9c73a449035cfe5 Mon Sep 17 00:00:00 2001 From: shuailong616 <452509829@qq.com> Date: Wed, 26 Nov 2025 19:01:08 +0800 Subject: [PATCH 8/8] fix some bug --- .../megatron/core/optimizer/__init__.py.patch | 156 ++++++++++++------ .../megatron/training/arguments.py.patch | 9 +- .../runner/auto_tuner/simulator/config_gen.py | 15 +- 3 files changed, 116 insertions(+), 64 deletions(-) diff --git a/flagscale/backends/Megatron-LM/megatron/core/optimizer/__init__.py.patch b/flagscale/backends/Megatron-LM/megatron/core/optimizer/__init__.py.patch index ece9ab2ef8..665124d88e 100644 --- a/flagscale/backends/Megatron-LM/megatron/core/optimizer/__init__.py.patch +++ b/flagscale/backends/Megatron-LM/megatron/core/optimizer/__init__.py.patch @@ -1,8 +1,16 @@ diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py -index 1846907e9..70d0e72b4 100644 +index 1846907e9..8355608fe 100644 --- a/megatron/core/optimizer/__init__.py +++ b/megatron/core/optimizer/__init__.py -@@ -55,6 +55,7 @@ def _get_param_groups( +@@ -3,6 +3,7 @@ import logging + import warnings + from typing import Callable, Dict, List, Optional, Tuple + ++import os + import torch + from torch.optim import SGD as CPUSGD + from torch.optim import AdamW as CPUAdam +@@ -55,6 +56,7 @@ def _get_param_groups( decoupled_lr: Optional[float], decoupled_min_lr: Optional[float], default_skip_embedding_weight_decay: bool = False, @@ -10,7 +18,7 @@ index 1846907e9..70d0e72b4 100644 ) -> List[Dict]: """Create parameter groups for optimizer. -@@ -106,6 +107,8 @@ def _get_param_groups( +@@ -106,6 +108,8 @@ def _get_param_groups( or len(param.shape) == 1 or (default_skip_embedding_weight_decay and "embedding" in name) ) @@ -19,64 +27,99 @@ index 1846907e9..70d0e72b4 100644 if scale_lr_cond is not None: scale_lr = scale_lr_cond(name, param) -@@ -128,11 +131,32 @@ def _get_param_groups( - param, 'is_embedding_or_output_parameter', False +@@ -129,37 +133,64 @@ def _get_param_groups( ): is_decoupled_lr = True -- - key = (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr) - if key not in params_map: - params_map[key] = [] - params_map[key].append(param) -+ param_groups = [] -+ for (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr), params in params_map.items(): -+ assert len(params) > 0 -+ param_group = { -+ 'params': params, -+ 'wd_mult': wd_mult, -+ 'lr_mult': _lr_mult, -+ 'is_expert_parallel': is_expert_parallel, -+ 'is_decoupled_lr': is_decoupled_lr, -+ } -+ param_groups.append(param_group) -+ ''' -+ is_vision_model_param = False -+ if "vision_model" in name: -+ is_vision_model_param = True + +- key = (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr) +- if key not in params_map: +- params_map[key] = [] +- params_map[key].append(param) ++ if os.environ.get("ENABLE_SIMULATOR") == "1": ++ key = (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr) ++ if key not in params_map: ++ params_map[key] = [] ++ params_map[key].append(param) + else: + is_vision_model_param = False ++ if "vision_model" in name: ++ is_vision_model_param = True ++ else: ++ is_vision_model_param = False + -+ key = (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr, is_vision_model_param) -+ if key not in params_map: -+ params_map[key] = [] -+ params_map[key].append(param) ++ key = (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr, is_vision_model_param) ++ if key not in params_map: ++ params_map[key] = [] ++ params_map[key].append(param) ++ if os.environ.get("ENABLE_SIMULATOR") == "1": ++ param_groups = [] ++ for (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr), params in params_map.items(): ++ assert len(params) > 0 ++ param_group = { ++ 'params': params, ++ 'wd_mult': wd_mult, ++ 'lr_mult': _lr_mult, ++ 'is_expert_parallel': is_expert_parallel, ++ 'is_decoupled_lr': is_decoupled_lr, ++ } ++ param_groups.append(param_group) # Distributed checkpoint requires all ranks to have the same param groups, # so we need to align the param groups across ranks, otherwise we may have -@@ -147,7 +171,7 @@ def _get_param_groups( - - param_groups = [] - for key in params_key: + # runtime error when loading the checkpoint or numerical error when resuming training. +- params_key = list(params_map.keys()) +- gathered_params_key = [None for _ in range(torch.distributed.get_world_size())] +- torch.distributed.all_gather_object(gathered_params_key, params_key) +- for keys in gathered_params_key: +- for key in keys: +- if key not in params_key: +- params_key.append(key) +- +- param_groups = [] +- for key in params_key: - wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr = key -+ wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr, is_vision_model_param = key - params = params_map[key] if key in params_map else [] - param_group = { - 'params': params, -@@ -155,11 +179,13 @@ def _get_param_groups( - 'lr_mult': _lr_mult, - 'is_expert_parallel': is_expert_parallel, - 'is_decoupled_lr': is_decoupled_lr, -+ 'is_vision_model_param': is_vision_model_param, - } - # Ensure param_group has required keys for matching when loading optimizer state - # See MegatronOptimizer._filter_and_reorder_param_groups. - assert set(param_group.keys()) - set(param_group_identifier_keys) == {'params'} - param_groups.append(param_group) -+ ''' +- params = params_map[key] if key in params_map else [] +- param_group = { +- 'params': params, +- 'wd_mult': wd_mult, +- 'lr_mult': _lr_mult, +- 'is_expert_parallel': is_expert_parallel, +- 'is_decoupled_lr': is_decoupled_lr, +- } +- # Ensure param_group has required keys for matching when loading optimizer state +- # See MegatronOptimizer._filter_and_reorder_param_groups. +- assert set(param_group.keys()) - set(param_group_identifier_keys) == {'params'} +- param_groups.append(param_group) ++ else: ++ params_key = list(params_map.keys()) ++ gathered_params_key = [None for _ in range(torch.distributed.get_world_size())] ++ torch.distributed.all_gather_object(gathered_params_key, params_key) ++ for keys in gathered_params_key: ++ for key in keys: ++ if key not in params_key: ++ params_key.append(key) ++ ++ param_groups = [] ++ for key in params_key: ++ wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr, is_vision_model_param = key ++ params = params_map[key] if key in params_map else [] ++ param_group = { ++ 'params': params, ++ 'wd_mult': wd_mult, ++ 'lr_mult': _lr_mult, ++ 'is_expert_parallel': is_expert_parallel, ++ 'is_decoupled_lr': is_decoupled_lr, ++ 'is_vision_model_param': is_vision_model_param, ++ } ++ # Ensure param_group has required keys for matching when loading optimizer state ++ # See MegatronOptimizer._filter_and_reorder_param_groups. ++ assert set(param_group.keys()) - set(param_group_identifier_keys) == {'params'} ++ param_groups.append(param_group) ++ param_groups = _update_min_and_max_lr_in_param_groups( param_groups, -@@ -167,6 +193,7 @@ def _get_param_groups( +@@ -167,6 +198,7 @@ def _get_param_groups( min_lr=min_lr, decoupled_lr=decoupled_lr, decoupled_min_lr=decoupled_min_lr, @@ -84,7 +127,7 @@ index 1846907e9..70d0e72b4 100644 ) return param_groups -@@ -178,6 +205,7 @@ def _update_min_and_max_lr_in_param_groups( +@@ -178,6 +210,7 @@ def _update_min_and_max_lr_in_param_groups( min_lr: float, decoupled_lr: Optional[float], decoupled_min_lr: Optional[float], @@ -92,16 +135,19 @@ index 1846907e9..70d0e72b4 100644 ) -> List[Dict]: """ Updates `max_lr` and `min_lr` values in each parameter group, and returns new list. -@@ -206,7 +234,7 @@ def _update_min_and_max_lr_in_param_groups( +@@ -206,7 +239,10 @@ def _update_min_and_max_lr_in_param_groups( param_group['max_lr'] = decoupled_lr param_group['min_lr'] = decoupled_min_lr else: - param_group['max_lr'] = lr -+ param_group['max_lr'] = lr #if not param_group['is_vision_model_param'] else lr * vision_ration # NOTE(lizhiyu): change the ration here ++ if os.environ.get("ENABLE_SIMULATOR") == "1": ++ param_group['max_lr'] = lr ++ else: ++ param_group['max_lr'] = lr if not param_group['is_vision_model_param'] else lr * vision_ration # NOTE(lizhiyu): change the ration here param_group['min_lr'] = min_lr return param_groups -@@ -255,6 +283,7 @@ def _get_param_groups_and_buffers( +@@ -255,6 +291,7 @@ def _get_param_groups_and_buffers( decoupled_lr=config.decoupled_lr, decoupled_min_lr=config.decoupled_min_lr, default_skip_embedding_weight_decay=default_skip_embedding_weight_decay, @@ -109,7 +155,7 @@ index 1846907e9..70d0e72b4 100644 ) param_groups = list(filter(filter_fn, param_groups)) buffers = {} -@@ -511,6 +540,10 @@ def get_megatron_optimizer( +@@ -511,6 +548,10 @@ def get_megatron_optimizer( intra_dp_cp_group = process_groups['intra_dp_cp_group'] intra_expt_dp_group = process_groups['intra_expt_dp_group'] mp_group = process_groups['mp_group'] @@ -120,7 +166,7 @@ index 1846907e9..70d0e72b4 100644 expt_tp_pp_group = process_groups['expt_tp_pp_group'] intra_dp_cp_group_gloo = process_groups['intra_dp_cp_group_gloo'] intra_expt_dp_group_gloo = process_groups['intra_expt_dp_group_gloo'] -@@ -609,7 +642,11 @@ def get_megatron_optimizer( +@@ -609,7 +650,11 @@ def get_megatron_optimizer( default_skip_embedding_weight_decay=default_skip_embedding_weight_decay, ) if len(moe_param_groups) > 0: diff --git a/flagscale/backends/Megatron-LM/megatron/training/arguments.py.patch b/flagscale/backends/Megatron-LM/megatron/training/arguments.py.patch index 6e946d062a..f01686ca7e 100644 --- a/flagscale/backends/Megatron-LM/megatron/training/arguments.py.patch +++ b/flagscale/backends/Megatron-LM/megatron/training/arguments.py.patch @@ -1,5 +1,5 @@ diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py -index 1120c7529..190fac52b 100644 +index 1120c7529..ebb1467c3 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -67,6 +67,7 @@ def add_megatron_arguments(parser: argparse.ArgumentParser): @@ -477,7 +477,7 @@ index 1120c7529..190fac52b 100644 dest='overlap_p2p_comm_warmup_flush') group.add_argument('--distributed-backend', default='nccl', - choices=['nccl', 'gloo'], -+ choices=['nccl', 'gloo', 'flagcx'], ++ choices=['nccl', 'gloo', 'flagcx', 'dummy'], help='Which backend to use for distributed training.') group.add_argument('--distributed-timeout-minutes', type=int, default=10, help='Timeout minutes for torch.distributed.') @@ -602,7 +602,7 @@ index 1120c7529..190fac52b 100644 return parser -@@ -3275,3 +3542,75 @@ def _add_sft_args(parser): +@@ -3275,3 +3542,78 @@ def _add_sft_args(parser): group.add_argument('--sft-tokenizer-prompt-format', type=str, default="nemotron-h-aligned", help='SFT prompt format.') return parser @@ -637,6 +637,9 @@ index 1120c7529..190fac52b 100644 + group.add_argument('--auto-tune', action='store_true', + help='use auto tuner') + ++ group.add_argument('--enable-simulator', action='store_true', ++ help='Use single process to simulate the distributed training.') ++ + return parser + + diff --git a/flagscale/runner/auto_tuner/simulator/config_gen.py b/flagscale/runner/auto_tuner/simulator/config_gen.py index d16411c9d9..47b0cc37e3 100644 --- a/flagscale/runner/auto_tuner/simulator/config_gen.py +++ b/flagscale/runner/auto_tuner/simulator/config_gen.py @@ -214,7 +214,8 @@ def __init__(self, mesh_config: HeteroConfig): self.pipeline_model_parallel_size = sum(mesh_config.mesh[4::5]) self.tensor_model_parallel_size = mesh_config.mesh[0] self.virtual_pipeline_model_parallel_size = None - self.num_experts = 1 + self.num_experts = None + self.context_parallel_size = 1 self.swiglu = True self.micro_batch_size = global_batch_size / num_micro_batches / self.data_parallel_size @@ -222,8 +223,8 @@ def __init__(self, mesh_config: HeteroConfig): self.num_attention_heads = 32 self.group_query_attention = None # TODO self.num_query_groups = 1 # TODO - self.moe_layer_freq = 2 - self.moe_router_topk = 1 + # self.moe_layer_freq = 2 + # self.moe_router_topk = 1 self.multi_latent_attention = False self.seq_length = 2048 self.padded_vocab_size = 4096 # TODO @@ -232,9 +233,9 @@ def __init__(self, mesh_config: HeteroConfig): self.mtp_num_layers = None self.expert_model_parallel_size = 1 self.world_size = 8 - self.moe_shared_expert_intermediate_size = 16384 - self.moe_ffn_hidden_size = 4 * self.hidden_size - # self.ffn_hidden_size + self.moe_shared_expert_intermediate_size = None + self.moe_ffn_hidden_size = None + ## self.ffn_hidden_size self.multiple_of = 256 hidden_dim = int(4 * self.hidden_size * 2 / 3) self.ffn_hidden_size = self.multiple_of * ( @@ -246,7 +247,9 @@ def __init__(self, mesh_config: HeteroConfig): self.recompute_granularity = mesh_config.recompute_granularity self.recompute_method = mesh_config.recompute_method self.recompute_num_layers = mesh_config.recompute_num_layers + self.expert_tensor_parallel_size = 1 + self.use_flash_attn = True self.sequence_parallel = True self.use_distributed_optimizer = True