diff --git a/examples/nsys_profile_rank.sh b/examples/nsys_profile_rank.sh new file mode 100755 index 00000000000..4f356b03e9e --- /dev/null +++ b/examples/nsys_profile_rank.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +set -e + +# nsys profile -t cuda,nvtx,osrt -s none --cpuctxsw none --python-sampling true --python-sampling-frequency 1000 $@ || true + +# nsys profile -t cuda,nvtx,osrt --force-overwrite true \ +# --capture-range=cudaProfilerApi --capture-range-end=stop --gpu-metrics-device=0 \ +# --python-sampling-frequency 1000 --python-sampling true \ +# $@ || true + +nsys profile -w true -t cublas,cuda,nvtx,osrt -s cpu -c cudaProfilerApi -o "$NSYS_DIR/datetime_${DATETIME}_gpt_sft_hetero_cp_iter2_4_flash_global_8192_rank${OMPI_COMM_WORLD_RANK}" $@ || true + +# PROFILE_RANKS=(0 1 2 3 4 5 6 7 8) + +# if [[ " ${PROFILE_RANKS[*]} " =~ " $OMPI_COMM_WORLD_RANK " ]]; then +# nsys profile -w true -t cublas,cuda,nvtx,osrt -s cpu -c cudaProfilerApi -o "datetime_${DATETIME}_gpt_sft_hetero_cp_iter2_4_flash_global_8192_rank${OMPI_COMM_WORLD_RANK}" $@ || true +# else +# $@ || true +# fi diff --git a/examples/run_hybrid_cp.sh b/examples/run_hybrid_cp.sh new file mode 100755 index 00000000000..6d530a265b9 --- /dev/null +++ b/examples/run_hybrid_cp.sh @@ -0,0 +1,387 @@ +#!/bin/bash + +# set -euo pipefail +export DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` + +SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd) +PYTHONPATH= + +export NCCL_IB_SL=1 +export TOKENIZERS_PARALLELISM="false" + +#export NVTE_DEBUG=1 +#export NVTE_DEBUG_LEVEL=2 + +USER="wuguohao" + +# Auto-detect batch or interactive mode. 
+which srun +BATCH=$((1-$?)) + +DEBUG=0 +USE_TILING=1 +USE_CP=0 +USE_TE_CE=1 +USE_FLASH_ATTN=0 +USE_FSDP=0 +PROFILE=0 +PROFILE_MEMORY=0 +# PROFILE_RANKS=[0,1,2,3,4,5,6,7,8] +TRAIN_ITERS=10 +USE_MOCK_DATA=1 +MASTER_PORT=6103 +TP=1 +PP=8 +PP_l= +MIN_CP=1 +MAX_CP=8 +NUM_LAYERS=8 + +MBZ=1 +BZ=2048 +HIDDEN_SIZE=5120 +FFN_HIDDEN_SIZE=13824 +HEAD_DIM=128 +NUM_HEAD=$((HIDDEN_SIZE / HEAD_DIM)) +SEQ_LEN=131072 #131072 #81920 #65536 # 32768 #16384 +MIN_SEQ_LEN=256 +MAX_SEQLEN_PER_DP_CP_RANK=65536 +NW=16 +AD=0.0 +HD=0.0 +LI=1 +EXTRA_ARGS="" +NONDETERMINISTIC_ATTN=1 +# NUM_GPU=8 + +# Remember to update model and job name if running in batch mode!! +# if [[ $BATCH -eq 0 ]]; then +# DATETIME=`date +'%y-%m-%d-%H-%M-%S'` +# MODEL_NAME="interactive_hybrid_cp" +# WORKSPACE="/home/tailaim//work_data/megatron-lm/logs" +# SOURCE="/home/tailaim/work_data/megatron-lm" +# TOKENIZER="/home/tailaim/work_data/megatron-moe-scripts/Nemotron-H-4B-Instruct" +# else +# MODEL_NAME="interactive_hybrid_cp" +# WORKSPACE="/lustre/fsw/portfolios/coreai/users/tailaim/work_data/megatron-lm/logs" +# SOURCE="/lustre/fsw/portfolios/coreai/users/tailaim/work_data/megatron-lm" +# TOKENIZER="/lustre/fsw/portfolios/llmservice/users/kezhik/images/Nemotron-H-4B-Instruct" +# fi + +HOSTFILE=${HOSTFILE:-} +if [ -f /etc/mpi/hostfile ]; then + if [ ! -f /etc/mpi/hostfile_seq -a -z "$HOSTFILE" ]; then + echo "Please use kai_launch to generate /etc/mpi/hostfile_seq" + exit 1 + fi + HOSTFILE=${HOSTFILE:-/etc/mpi/hostfile_seq} +fi + +if [ -n "$HOSTFILE" ]; then + # 多机任务 + if [ -z "${MY_NODE_IP:-}" ]; then echo "Variable MY_NODE_IP does not exist."; exit 1; fi + if ! ifconfig | grep " $MY_NODE_IP " >/dev/null; then echo "MY_NODE_IP \"$MY_NODE_IP\" is not contained in \`ifconfig\`."; exit 1; fi + MASTER_ADDR=$MY_NODE_IP + if [ ! 
-f "$HOSTFILE" ]; then echo "Hostfile \"$HOSTFILE\" does not exist."; exit 1; fi + NP=${NP:-$(cat "$HOSTFILE" | grep -v '^#' | grep -oP 'slots=\K\d+' | awk '{sum += $1} END {print sum}')} +else + # 单机任务 + MASTER_ADDR=127.0.0.1 + NP=${NP:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)} +fi +# NP=16 + +function check_str() { + if [ ! -v "$1" ]; then echo "Variable $1 is not set."; exit 1; fi + if [[ -z "${!1}" ]]; then echo "Variable $1 is not a string."; exit 1; fi +} + +# PLM_RSH_ARGS +if [ -v TARGET_IP_PORT_FILE ]; then check_str TARGET_IP_PORT_FILE; PLM_RSH_ARGS="-F $TARGET_IP_PORT_FILE"; +elif [ ! -v TARGET_IP_PORT_FILE -a -n "$HOSTFILE" ]; then PORT=$(cat /etc/ssh/ssh_config | grep 'Port' | cut -d'"' -f2); check_integer PORT; PLM_RSH_ARGS="-p $PORT"; +else PLM_RSH_ARGS=; +fi + + +MODEL_NAME="interactive_hybrid_cp" +TOKENIZER=None + +WORKSPACE="../logs" +mkdir -p $WORKSPACE +OUTPUT_BASE="${WORKSPACE}/output" +OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}/$DATETIME" + +FINETUNE_DIR=${OUTPUT}/checkpoints +LOGS_DIR="${OUTPUT}/logs" +TENSORBOARD_DIR="${OUTPUT}/tensorboard" +DATACACHE_DIR="${OUTPUT}/data_cache" +export NSYS_DIR="${OUTPUT}/nsys" +PROFILE_MEMORY_PATH="${OUTPUT}/mem_profile" + +mkdir -p $FINETUNE_DIR +mkdir -p $LOGS_DIR +mkdir -p $TENSORBOARD_DIR +mkdir -p $DATACACHE_DIR +mkdir -p $NSYS_DIR +mkdir -p $COST_DATA_FILE +mkdir -p $PROFILE_MEMORY_PATH + +export HF_DATASETS_CACHE="${OUTPUT}/hf_datasets_cache" + +DATA_TRAIN="/home/tailaim/data/thd_formatted_100k.jsonl" + +CURRENT_DIR="$( cd "$( dirname "$0" )" && pwd )" +MEGATRON_PATH=$( dirname ${CURRENT_DIR}) + +# if [[ $DEBUG -eq 1 ]]; then +# MBZ=1 +# BZ=256 +# NW=4 +# AD=0.0 +# HD=0.0 +# LI=1 + +# # EXTRA_ARGS="--deterministic-mode --use-cpu-initialization" + +# NONDETERMINISTIC_ATTN=1 + +# NUM_GPU=8 +# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +# #export NCCL_ALGO=Tree +# #export CUBLAS_WORKSPACE_CONFIG=:4096:8 +# else +# fi + + +if [[ $USE_TE_CE -eq 1 ]]; then + EXTRA_ARGS+=" 
--cross-entropy-loss-fusion --cross-entropy-fusion-impl te" +fi + +if [[ $PROFILE -eq 1 ]]; then + EXTRA_ARGS+=" --profile --profile-step-start 2 --profile-step-end 6 " +fi + +echo $USE_MOCK_DATA +if [[ $USE_MOCK_DATA -eq 1 ]]; then + # EXTRA_ARGS+=" --mock-data --sft-mock-dataset-config-json '{\"mode\":\"file\",\"path\":\"path/to/file\"}'" + if [[ $BATCH -eq 0 ]]; then + # EXTRA_ARGS+=" --mock-data --sft-mock-dataset-config-json {\"mode\":\"distribution\",\"type\":\"lognormal\",\"min_seq_len\":256,\"max_seq_len\":$SEQ_LEN,\"mean_seq_len\":16384,\"lognormal_sigma\":1.1} " + # EXTRA_ARGS+=" --mock-data --sft-mock-dataset-config-json {\"mode\":\"distribution\",\"type\":\"linear\",\"min_seq_len\":1024,\"max_seq_len\":32768} " + EXTRA_ARGS+=" --mock-data --sft-mock-dataset-config-json {\"mode\":\"file\",\"path\":\"/m2v_model/wuguohao03/dataset/github/github_subset_2.csv\",\"min_seq_len\":$MIN_SEQ_LEN,\"max_seq_len\":$SEQ_LEN} " + # EXTRA_ARGS+=" --mock-data --sft-mock-dataset-config-json {\"mode\":\"file\",\"path\":\"/m2v_model/wuguohao03/dataset/commoncrawl/commoncrawl_subset_2.csv\",\"min_seq_len\":$MIN_SEQ_LEN,\"max_seq_len\":$SEQ_LEN} " + # EXTRA_ARGS+=" --mock-data --sft-mock-dataset-config-json {\"mode\":\"file\",\"path\":\"/m2v_model/wuguohao03/dataset/wikipedia/wikipedia_subset_2.csv\",\"min_seq_len\":$MIN_SEQ_LEN,\"max_seq_len\":$SEQ_LEN} " + # EXTRA_ARGS+=" --mock-data --sft-mock-dataset-config-json {\"mode\":\"indexed_file\",\"path\":\"${DATA_TRAIN}\",\"type\":\"lognormal\",\"min_seq_len\":1024,\"max_seq_len\":32768,\"mean_seq_len\":8192,\"lognormal_sigma\":1.1} " + else + # EXTRA_ARGS+=" --mock-data --sft-mock-dataset-config-json '{\"mode\":\"distribution\",\"type\":\"lognormal\",\"min_seq_len\":256,\"max_seq_len\":$SEQ_LEN,\"mean_seq_len\":16384,\"lognormal_sigma\":1.1}' " + # EXTRA_ARGS+=" --mock-data --sft-mock-dataset-config-json {\"mode\":\"distribution\",\"type\":\"linear\",\"min_seq_len\":1024,\"max_seq_len\":32768} " + EXTRA_ARGS+=" --mock-data 
--sft-mock-dataset-config-json {\"mode\":\"file\",\"path\":\"/m2v_model/wuguohao03/dataset/github/github_subset_2.csv\",\"min_seq_len\":$MIN_SEQ_LEN,\"max_seq_len\":$SEQ_LEN} " + # EXTRA_ARGS+=" --mock-data --sft-mock-dataset-config-json {\"mode\":\"file\",\"path\":\"/m2v_model/wuguohao03/dataset/commoncrawl/commoncrawl_subset_2.csv\",\"min_seq_len\":$MIN_SEQ_LEN,\"max_seq_len\":$SEQ_LEN} " + # EXTRA_ARGS+=" --mock-data --sft-mock-dataset-config-json {\"mode\":\"file\",\"path\":\"/m2v_model/wuguohao03/dataset/wikipedia/wikipedia_subset_2.csv\",\"min_seq_len\":$MIN_SEQ_LEN,\"max_seq_len\":$SEQ_LEN} " + # EXTRA_ARGS+=" --mock-data --sft-mock-dataset-config-json {\"mode\":\"indexed_file\",\"path\":\"${DATA_TRAIN}\",\"type\":\"lognormal\",\"min_seq_len\":1024,\"max_seq_len\":32768,\"mean_seq_len\":8192,\"lognormal_sigma\":1.1} " + fi +else + EXTRA_ARGS+=" --data-path ${DATA_TRAIN} --tokenizer-model ${TOKENIZER} " +fi + +if [[ $USE_FSDP -eq 1 ]]; then + # --ckpt-format fsdp_dtensor + EXTRA_ARGS+="--ckpt-format fsdp_dtensor --use-megatron-fsdp --data-parallel-sharding-strategy optim_grads_params --no-gradient-accumulation-fusion --use-distributed-optimizer " + unset CUDA_DEVICE_MAX_CONNECTIONS +else + export CUDA_DEVICE_MAX_CONNECTIONS=1 +fi + + + # --profile-ranks $PROFILE_RANKS \ + + # --use-gpu-timer \ + # --gpu-timer-interval 1 \ + # + # --hybrid-context-parallel-scheduler only_packing_no_scheduling \ + # --recompute-activations \ + # --disable-gloo-process-groups \ + # --add-qkv-bias \ + # --disable-gloo-process-groups \ + # --hybrid-context-parallel \ + # --async-hybrid-context-parallel-scheduler \ + # --hybrid-context-parallel-scheduler "only_packing_no_scheduling" \ + +OPTIONS=" \ + `if [ $PROFILE_MEMORY == 1 ]; then echo --profile-memory; fi` \ + `if [ $PROFILE_MEMORY == 1 ]; then echo --profile-memory-path $PROFILE_MEMORY_PATH; fi` \ + --log-throughput \ + --log-energy \ + --no-check-for-nan-in-loss-and-grad \ + --recompute-granularity full \ + 
--recompute-method uniform \ + --recompute-num-layers 1 \ + --timing-log-level 1 \ + --timing-log-option minmax \ + --sft-sequence-packing \ + --hybrid-context-parallel \ + --min-hybrid-context-parallel-size $MIN_CP \ + --max-hybrid-context-parallel-size $MAX_CP \ + --hybrid-context-parallel \ + --hybrid-context-parallel-scheduler only_packing_no_scheduling \ + --async-hybrid-context-parallel-scheduler \ + --max-seqlen-per-dp-cp-rank $MAX_SEQLEN_PER_DP_CP_RANK \ + --sft \ + --vocab-size $SEQ_LEN \ + --tokenizer-type NullTokenizer \ + --legacy-tokenizer \ + --use-distributed-optimizer \ + --disable-bias-linear \ + --sft-tokenizer-prompt-format nemotron-h-aligned \ + --transformer-impl transformer_engine \ + --normalization RMSNorm \ + --norm-epsilon 1e-06 \ + --attention-dropout ${AD} \ + --hidden-dropout ${HD} \ + --untie-embeddings-and-output-weights \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --swiglu \ + --tensor-model-parallel-size ${TP} \ + --pipeline-model-parallel-size ${PP} \ + ${PP_l:+--num-layers-per-virtual-pipeline-stage $PP_l} \ + --rerun-mode disabled \ + --num-layers $NUM_LAYERS \ + --hidden-size $HIDDEN_SIZE \ + --ffn-hidden-size $FFN_HIDDEN_SIZE \ + --num-attention-heads $NUM_HEAD \ + --num-workers ${NW} \ + --exit-duration-in-mins 230 \ + --seq-length ${SEQ_LEN} \ + --max-position-embeddings ${SEQ_LEN} \ + --train-iters $TRAIN_ITERS \ + --lr-warmup-samples 0 \ + --micro-batch-size ${MBZ} \ + --global-batch-size ${BZ} \ + --lr 2e-5 \ + --min-lr 0.0 \ + --lr-decay-style cosine \ + --log-interval ${LI} \ + --eval-iters 0 \ + --eval-interval 999999 \ + --save-interval 1000 \ + --data-cache-path ${DATACACHE_DIR} \ + --use-mcore-models \ + --no-create-attention-mask-in-dataloader \ + --no-mmap-bin-files \ + --split 100,0,0 \ + --clip-grad 1.0 \ + --weight-decay 0.05 \ + --adam-beta1 0.9 \ + --adam-beta2 0.999 \ + --init-method-std 0.014 \ + --bf16 \ + --tensorboard-dir ${TENSORBOARD_DIR} \ + ${EXTRA_ARGS} \ 
+ --distributed-timeout-minutes 60 \ + --calculate-per-token-loss \ + --attention-backend flash \ + --use-dist-ckpt \ +" + +# PROFILE_WRAPPER +if [ $PROFILE == 1 ]; then PROFILE_WRAPPER="$SCRIPT_DIR/nsys_profile_rank.sh"; +else PROFILE_WRAPPER=; fi + +# Interactive or batch mode +# if [[ $BATCH -eq 0 ]]; then +# DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` +# # if [[ $PROFILE -eq 1 ]]; then +# # nsys profile -w true -t cublas,cuda,nvtx,osrt -s cpu -c cudaProfilerApi -o gpt_sft_hetero_cp_iter7_8_flash_global_64 torchrun --nproc_per_node ${NUM_GPU} pretrain_gpt.py ${OPTIONS} +# # else +# # torchrun --nproc_per_node ${NUM_GPU} /home/tailaim/work_data/megatron-lm/pretrain_gpt.py ${OPTIONS} +# # fi +# echo "MASTER_ADDR = ${MASTER_ADDR}, NP = ${NP}, NODE_RANK = ${NODE_RANK}, NUM_GPU = ${NUM_GPU} " +# $PROFILE_WRAPPER torchrun --master_addr ${MASTER_ADDR} --master_port=12345 --nnodes ${NP} --node_rank ${NODE_RANK} --nproc_per_node ${NUM_GPU} /m2v_model/wuguohao03/nv_teamwork/Megatron-LM/pretrain_gpt.py ${OPTIONS} | tee ${LOGS_DIR}/$DATETIME.log +# else +# if [[ $PROFILE -eq 1 ]]; then +# run_cmd="cd ${SOURCE}; nsys profile -w true -t cublas,cuda,nvtx,osrt -s cpu -c cudaProfilerApi --capture-range-end stop -o without_hetero_cp_global_%q{SLURM_PROCID} python -u pretrain_gpt.py ${OPTIONS}" +# else +# run_cmd="cd ${SOURCE}; python -u pretrain_gpt.py ${OPTIONS}" +# fi + +# DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` +# echo "run_cmd: ${run_cmd}" +# srun -l --verbose \ +# --container-image /lustre/fsw/portfolios/coreai/users/tailaim/work_data/megatron-moe-scripts/mcore-moe-pytorch25.06.sqsh \ +# --container-mounts "/lustre" \ +# --no-container-mount-home \ +# --output=${LOGS_DIR}/%x_%j_$DATETIME.log \ +# sh -c "${run_cmd}" + +# set +x +# fi + +exec &> >(tee "${LOGS_DIR}/$DATETIME.log") +# echo "HOSTFILE = ${HOSTFILE} MASTER_ADDR = ${MASTER_ADDR}, NP = ${NP}, NUM_GPU = ${NUM_GPU} " + +cat $HOSTFILE + +set -x + +# mpirun --hostfile hostfile -np 24 cat $HOSTFILE + + # -x 
NVTE_DEBUG=1 \ + # -x NVTE_DEBUG_LEVEL=2 \ + # -x NCCL_ALGO=^NVLS,NVLSTree \ + # -x CUDA_DEVICE_MAX_CONNECTIONS=1 \ + # -x PYTHONPATH="$/m2v_model/wuguohao03/nv_teamwork/Megatron-LM":"/m2v_model/wangchenyu05/hot_switch/TransformerEngine":$PYTHONPATH \ + + +mpirun --allow-run-as-root --noprefix \ + ${HOSTFILE:+--hostfile "$HOSTFILE"} \ + --np $NP \ + --bind-to none --map-by slot \ + --mca plm_rsh_args "$PLM_RSH_ARGS" \ + --mca btl self,tcp \ + --mca pml ob1 \ + -mca plm_rsh_num_concurrent 600 \ + -mca routed_radix 600 \ + -mca btl_tcp_if_include bond0,eth01 \ + -mca oob_tcp_if_include bond0,eth01 \ + -mca btl_openib_allow_ib false \ + -mca opal_set_max_sys_limits 1 \ + -x HOROVOD_MPI_THREADS_DISABLE=1 \ + -x MPI_THREAD_SINGLE=1 \ + -x NCCL_IB_DISABLE=0 \ + -x NCCL_IB_GID_INDEX=3 \ + -x NCCL_IB_HCA=mlx5 \ + -x NCCL_IB_QPS_PER_CONNECTION=16 \ + -x NCCL_IB_TIMEOUT=20 \ + -x NCCL_ALGO=^NVLS,NVLSTree \ + -x NCCL_PROTO=^LL128 \ + -x KML_ID \ + -x TASK_ID \ + -x DATETIME \ + -x CREATOR \ + -x TASK_RECORD_URL \ + -x HOSTNAME \ + -x TRAIN_MODE=True \ + -x NSYS_DIR=$NSYS_DIR \ + -x PATH \ + -x PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True\ + ${LD_LIBRARY_PATH:+-x LD_LIBRARY_PATH} \ + -x PYTHONPATH:$MEGATRON_PATH:$PYTHONPATH \ + -x CUDA_DEVICE_MAX_CONNECTIONS \ + -x NCCL_IB_SL \ + -x TOKENIZERS_PARALLELISM \ + -x PYTORCH_CUDA_ALLOC_CONF \ + -x NCCL_DEBUG=WARN \ + -x http_proxy=http://oversea-squid2.ko.txyun:11080 \ + -x https_proxy=http://oversea-squid2.ko.txyun:11080 \ + -x no_proxy=localhost,127.0.0.1,localaddress,localdomain.com,internal,corp.kuaishou.com,test.gifshow.com,staging.kuaishou.com \ + $PROFILE_WRAPPER \ + with_nccl_local_env \ + python -u $MEGATRON_PATH/pretrain_gpt.py \ + ${OPTIONS} \ + --distributed-backend nccl \ + --master-addr ${MASTER_ADDR}:${MASTER_PORT} + + +exit 1 diff --git a/megatron/core/datasets/data_schedule.py b/megatron/core/datasets/data_schedule.py index 0f016473b6a..be165ecd315 100644 --- a/megatron/core/datasets/data_schedule.py +++ 
b/megatron/core/datasets/data_schedule.py @@ -5,7 +5,6 @@ import torch from megatron.core import parallel_state -from megatron.core.pipeline_parallel.hybrid_cp_schedule import BalancedCPScheduler from megatron.core.process_groups_config import ProcessGroupCollection @@ -15,7 +14,7 @@ class HybridCPDataLoaderWrapper: For every __next__ call, 1. Each DP rank pulls a batch of packed samples. 2. Extracts the sequence lengths of each sub-sample and all-gathers across the DP group. - 3. Schedules the sub-samples to the DPxCP ranks using the BalancedCPScheduler. + 3. Schedules the sub-samples to the DPxCP ranks using the BalancedHybridCPscheduler. 4. Based on the schedule, reroutes the sub-samples to the correct rank using all-to-all. 5. Returns the assigned sub-samples to this rank. @@ -42,7 +41,8 @@ def __init__( self.dp_cp_group is not None and self.dp_group is not None and self.tp_group is not None ), "dp_cp_group, dp_group, tp_group must not be None when using hybrid context parallel" - self.cp_balancing_scheduler = BalancedCPScheduler( + from megatron.core.pipeline_parallel.data_schedule import BalancedHybridCPscheduler + self.cp_balancing_scheduler = BalancedHybridCPscheduler( max_seq_len_per_rank=self.config.max_seqlen_per_dp_cp_rank, dp_cp_group=self.dp_cp_group ) diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py index f50a6a77f57..794370a7adb 100644 --- a/megatron/core/datasets/gpt_dataset.py +++ b/megatron/core/datasets/gpt_dataset.py @@ -67,6 +67,19 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig): data parallel size * context parallel size * sequence parallel size * 2. """ + hybrid_context_parallel_scheduler: str = 'balanced' + """Scheduler for hybrid context parallel. + balanced: balanced scheduler for hybrid context parallel. + only_packing_no_scheduling: scheduling is already handled by the data sampler, + this scheduler only performs packing. 
+ """ + + sft_mock_dataset_config_json: Optional[str] = None + """This config provides the necessary information for the mock dataset.""" + + sft_sequence_packing: bool = False + """Option to enable sequence packing for SFT training.""" + def __post_init__(self) -> None: """Do asserts and set fields post init""" super().__post_init__() diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index acb93ef7853..b6994139e9b 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -966,6 +966,13 @@ def __init__( else: extra_kwargs["cp_comm_type"] = cp_comm_type + # we need to create a single stream for cp=1 and enable hybrid cp case + if ( + self.config.hybrid_context_parallel + and getattr(TEDotProductAttention, "cp_stream") is None + ): + TEDotProductAttention.cp_stream = torch.cuda.Stream() + if self.config.deterministic_mode: if int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) != 0: raise RuntimeError( @@ -1071,7 +1078,8 @@ def forward( elif packed_seq_params.local_cp_size is not None: assert ( packed_seq_params.local_cp_size == 1 - ), "local_cp_size must be == 1 if provided without cp_group" + ), f"local_cp_size must be == 1 if provided without cp_group, " + f"but got {packed_seq_params.local_cp_size}." super().set_context_parallel_group(None, None, None, self.cp_comm_type) self.kept_packed_seq_params.discard("cp_group") self.kept_packed_seq_params.discard("local_cp_size") diff --git a/megatron/core/gpu_timers.py b/megatron/core/gpu_timers.py new file mode 100644 index 00000000000..2a174b17b6f --- /dev/null +++ b/megatron/core/gpu_timers.py @@ -0,0 +1,87 @@ +import torch +from typing import Optional +from contextlib import contextmanager +from collections import defaultdict + +class GPUTimer: + + def __init__(self, use_gpu_timer): + self._starts = defaultdict(list) # name -> [Event, ...] + self._ends = defaultdict(list) # name -> [Event, ...] 
+ self._times = defaultdict(list) # name -> [elapsed_ms, ...] + self.inactive = not use_gpu_timer + + def activate(self): + self.inactive = False + + def inactivate(self): + self.inactive = True + + def start(self, name: str = "default"): + if self.inactive: + return + evt = torch.cuda.Event(enable_timing=True) + evt.record() + self._starts[name].append(evt) + + def stop(self, name: str = "default"): + if self.inactive: + return + if not self._starts[name]: + raise ValueError(f"No start event recorded for '{name}'") + end_evt = torch.cuda.Event(enable_timing=True) + end_evt.record() + self._ends[name].append(end_evt) + + def compute(self, name: str = None): + if self.inactive: + return + keys = [name] if name is not None else list(self._starts.keys()) + for key in keys: + starts = self._starts[key] + ends = self._ends[key] + n = min(len(starts), len(ends)) + for i in range(n): + if i < len(self._times[key]): + continue + s_evt = starts[i] + e_evt = ends[i] + + e_evt.synchronize() + elapsed_ms = s_evt.elapsed_time(e_evt) + self._times[key].append(elapsed_ms) + + def elapsed(self, name: str = "default"): + if self.inactive: + return + if name not in self._times or not self._times[name]: + raise ValueError(f"No computed timings for '{name}'") + return list(self._times[name]) + + def reset(self, name: str = None): + if self.inactive: + return + if name is None: + self._starts.clear() + self._ends.clear() + self._times.clear() + else: + self._starts.pop(name, None) + self._ends.pop(name, None) + self._times.pop(name, None) + + def summary(self): + if self.inactive: + return + result = {key: list(vals) for key, vals in self._times.items()} + for name, times in result.items(): + print(f"DP rank {torch.distributed.get_rank()} {name}: {times} ms") + return result + + @contextmanager + def time(self, name: str = "default", auto_compute: bool = True): + self.start(name) + yield + self.stop(name) + if auto_compute: + self.compute(name) diff --git 
a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index 129135c4cc0..bd144fab3e5 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -62,7 +62,7 @@ class ModelParallelConfig: can handle without overflowing the memory. Typically, a good starting point is to set this to maximum sequence length / context parallel size. This is used to calculate the number and length of sub-samples assigned to - each rank when using hybrid_context_parallel. + each rank when using sft_sequence_packing. """ hybrid_context_parallel: bool = False @@ -70,6 +70,28 @@ class ModelParallelConfig: If true, enables hybrid context parallel. This is used to balance the workload of each CP rank when we use packed samples with variable sequence lengths. Please set max_seqlen_per_dp_cp_rank when using hybrid_context_parallel. + When enabling hybrid_context_parallel, sft_sequence_packing must be true. + """ + + hybrid_context_parallel_scheduler: str = 'balanced' + """ + Scheduler for hybrid context parallel. + balanced: balanced scheduler for hybrid context parallel which provided by MCore. + only_packing_no_scheduling: scheduling is already handled by the data sampler, + this scheduler only performs packing. + """ + + sft_sequence_packing: bool = False + """ + If true, enables sft sequence packing. + """ + + balanced_sequence_packing: bool = False + """ + If true, enables balanced sequence packing. + This is used to pack samples with variable sequence lengths into a single sample + such that each packed sample has similar total sequence lengths. + This is useful to improve the efficiency of sequence packing. 
""" expert_model_parallel_size: int = 1 diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index a1230568cbd..b72628d515b 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -651,7 +651,6 @@ def _postprocess( else: logits = None - # Restore sequence parallel execution to the output layer if necessary. if sequence_parallel_override: assert ( in_inference_mode diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index fd0d0d9b9d9..b69efe9afc8 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -422,7 +422,7 @@ def create_hybrid_dp_cp_groups(rank, ranks, pg_options): hybrid_dp_cp_groups = {} # Generate group for every power of 2 up to the number of CP ranks # We limit the allowed group sizes in order to avoid excessive overhead. - group_sizes = [2**i for i in range(int(log2(len(ranks))))][1:] + group_sizes = [2**i for i in range(int(log2(len(ranks))))] for group_size in group_sizes: for i in range(0, len(ranks), group_size): group = create_group( @@ -559,6 +559,8 @@ def initialize_model_parallel( high_priority_stream_groups: Optional[List[str]] = None, sharp_enabled_group: Optional[str] = None, hybrid_context_parallel: bool = False, + min_hybrid_context_parallel_size: int = 1, + max_hybrid_context_parallel_size: int = 1, ) -> None: """Initialize model data parallel groups. @@ -970,6 +972,25 @@ def initialize_model_parallel( if rank in ranks: _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS = hierarchical_groups + if hybrid_context_parallel: + # PyTorch is performing lazy initialization of the communicator group. + # Therefore, we need to perform a nccl call to ensure that the communicator group is created. 
+ upper_bound = int(log2(data_parallel_size)) + if max_hybrid_context_parallel_size != -1: + upper_bound = min(int(log2(max_hybrid_context_parallel_size)), upper_bound) + group_sizes = [ + 2**i + for i in range( + int(log2(min_hybrid_context_parallel_size)), upper_bound + ) + ] + if group_sizes[-1] * 2 == data_parallel_size: + group_sizes.append(data_parallel_size) + for group_size in group_sizes: + group = get_hybrid_data_context_parallel_groups(group_size=group_size) + torch.distributed.barrier(group=group, device_ids=[torch.cuda.current_device()]) + torch.cuda.synchronize() + # Build the model-parallel groups. global _MODEL_PARALLEL_GROUP global _MODEL_PARALLEL_GLOBAL_RANKS @@ -1444,6 +1465,10 @@ def get_hybrid_data_context_parallel_groups(check_initialized=True, group_size=N if check_initialized: assert _DATA_PARALLEL_GROUP_WITH_CP is not None return _DATA_PARALLEL_GROUP_WITH_CP + elif group_size == 1: + if check_initialized: + assert _CONTEXT_PARALLEL_GROUP is not None + return _CONTEXT_PARALLEL_GROUP if check_initialized: assert _HYBRID_DP_CP_GROUPS is not None return _HYBRID_DP_CP_GROUPS[group_size] diff --git a/megatron/core/pipeline_parallel/data_schedule.py b/megatron/core/pipeline_parallel/data_schedule.py new file mode 100644 index 00000000000..4a2bb3b5c03 --- /dev/null +++ b/megatron/core/pipeline_parallel/data_schedule.py @@ -0,0 +1,2403 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+ +import enum +import sys +import copy +import nvtx +from collections import deque +from functools import lru_cache +import math +from math import ceil, log2 +from typing import Callable, Dict, List, Optional, Tuple, Type, Union +import nvtx +import time + +import numpy as np +import torch +import torch.multiprocessing as mp + +from megatron.core import parallel_state +from megatron.core.datasets.megatron_dataset import MegatronDataset + +# from megatron.core.pipeline_parallel.utils import ( +# is_pp_first_stage, +# is_pp_last_stage, +# is_vp_first_stage, +# is_vp_last_stage, +# ) +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.rerun_state_machine import RerunDataIterator + +# time simulator +from megatron.pipeline_simulator.simulator.schedules import SplitFuseSchedule, InterleavedSchedule +from megatron.pipeline_simulator.simulator.solver import test_with_schedule + +class PackingScheduler(enum.Enum): + """Enum for supported sequence packing algorithms.""" + + HYBRID_CP = "hybrid_cp" + HYBRID_CP_WITH_PP = "hybrid_cp_with_pp" + NAIVE_SEQUENCE_PACKING = "naive_sequence_packing" + # schedule in data_samplers, only need to pack, no need to schedule + ONLY_PACKING_NO_SCHEDULING = "only_packing_no_scheduling" + + +def wrap_dataloader( + data_iterator, + config, + scheduler_type: Union[PackingScheduler, str], + pg_collection: Optional[ProcessGroupCollection] = None, +): + """ + A wrapper function that wraps around an existing data_iterator + and return the num_micro_batches for sequence packing. + + Args: + data_iterator: The original data_iterator to wrap around + config: The config object containing the max_seqlen_per_dp_cp_rank + dp_cp_group: Data parallel context parallel group. 
+ """ + if torch.distributed.get_rank() == 0: print(f"{scheduler_type=}") + + scheduler_map = { + "hybrid_cp": BalancedHybridCPscheduler, + "hybrid_cp_with_pp": PipelineAwareBalancedHybridCPscheduler, + "naive": NaiveSequencePackingScheduler, + "only_packing_no_scheduling": OnlyPackingNoSchedulingScheduler, + } + + scheduler_map: Dict[PackingScheduler, Type[BaseScheduler]] = { + PackingScheduler.HYBRID_CP_WITH_PP: PipelineAwareBalancedHybridCPscheduler, + PackingScheduler.HYBRID_CP: BalancedHybridCPscheduler, + PackingScheduler.NAIVE_SEQUENCE_PACKING: NaiveSequencePackingScheduler, + PackingScheduler.ONLY_PACKING_NO_SCHEDULING: OnlyPackingNoSchedulingScheduler, + } + + def _get_global_seqlens(subsample_seqlens: torch.Tensor, dp_group) -> List[int]: + """ + Gathers the sequence lengths of all subsamples from all DP ranks. + Each DP rank loads the same number of microbatches but each microbatch + may have a different number of subsamples. + + We find the number of subsamples each rank holds and then gather the + sequence lengths of all subsamples from all ranks. 
+ """ + # Collect the number of subsamples from all ranks + local_len = torch.tensor([subsample_seqlens.shape[0]], dtype=torch.int32) + dp_subsample_count = [torch.zeros_like(local_len) for _ in range(dp_group.size())] + torch.distributed.all_gather(dp_subsample_count, local_len, group=dp_group) + + # Find the max number of subsamples across all ranks and pad subsample_seqlens to max length + dp_subsample_counts = torch.stack(dp_subsample_count, dim=0).cpu().view(-1) + max_sub_samples = int(dp_subsample_counts.max().item()) + + if subsample_seqlens.shape[0] < max_sub_samples: + subsample_seqlens_padded = torch.cat( + [ + subsample_seqlens, + torch.zeros(max_sub_samples - subsample_seqlens.shape[0], dtype=torch.int32, device=subsample_seqlens.device), + ], + dim=0, + ) + else: + subsample_seqlens_padded = subsample_seqlens + + # Gather the subsample_seqlens from all ranks + seqlens_gathered = [ + torch.empty_like(subsample_seqlens_padded) for _ in range(dp_group.size()) + ] + torch.distributed.all_gather(seqlens_gathered, subsample_seqlens_padded, group=dp_group) + + # Trim each seqlens_gathered to the length of the correct sample + for dp_rank, seqlen in enumerate(seqlens_gathered): + seqlens_gathered[dp_rank] = seqlen[: dp_subsample_counts[dp_rank]] + + seqlens_gathered = torch.cat(seqlens_gathered, dim=0) + seqlens_gathered = seqlens_gathered.cpu().tolist() + + # Calculate the offsets to assign unique global ID to each subsample. + csum = torch.cumsum(dp_subsample_counts, dim=0, dtype=torch.int32) + offsets = torch.cat([torch.zeros(1, dtype=torch.int32), csum[:-1]], dim=0) + + nvtx.pop_range() + return seqlens_gathered, offsets + + def _get_global_id_seqlens(num_local_subsamples, offsets, seqlens_gathered, dp_group): + """ + Calculates the global ID for each subsample. + + We assign a unique global ID to each subsample. + + Returns: + global_id_seqlens: list of (global_id, seqlen) tuples for scheduling. 
+ global_ids_this_rank: list of global IDs locally present on this rank. + """ + nvtx.push_range("_get_global_id_seqlens") + dp_rank = dp_group.rank() + global_ids = torch.arange(len(seqlens_gathered), dtype=torch.int32) + # Create a list of (global_id, seqlen) tuples for scheduling + global_id_seqlens = [(i, seqlens_gathered[i]) for i in range(len(global_ids))] + # Get the global IDs locally present on this rank + global_ids_this_rank = global_ids[ + offsets[dp_rank] : offsets[dp_rank] + num_local_subsamples + ] + + nvtx.pop_range() + return global_id_seqlens, global_ids_this_rank + + def _gid_to_src_rank(gid: int, offsets: List[int], dp_group, tp_group, dp_cp_group) -> int: + dp_src_rank = torch.bucketize(gid, offsets[1:] - 1) + # Since the torch.distributed.get_process_group_ranks + # provides the global rank, we need to consider TP + hdp_rank = ( + torch.distributed.get_process_group_ranks(dp_group)[dp_src_rank] // tp_group.size() + ) % dp_cp_group.size() + return hdp_rank + + def cast_inputs_device(inputs, device, skip_device=[]): + if isinstance(inputs, (list, tuple)): + return inputs.__class__(cast_inputs_device(v, device, skip_device) for v in inputs) + elif isinstance(inputs, dict): + return {k: v if k in skip_device else cast_inputs_device(v, device, skip_device=skip_device) for k, v in inputs.items()} + elif isinstance(inputs, torch.Tensor): + if not inputs.is_cuda: + inputs = inputs.to(device=device, non_blocking=True) # here input is expected to be pinned + + return inputs + + def _reroute_samples_to_hdp_ranks( + batch, + global_ids_this_rank, + global_id_seqlens, + sample_id_groups, + offsets, + dp_group, + tp_group, + dp_cp_group, + total_hdp_gpus, + ): + """ + Reroutes the sub-samples to the correct rank after scheduling. + + For each key in the batch dict, we perform an all-to-all communication + to transfer the data to the correct ranks. 
+ Since all CP ranks within a DP group have the same data, we only need + to transfer data between matching CP ranks. + """ + nvtx.push_range("_reroute_samples_to_hdp_ranks") + gid2local_id = {int(gid): i for i, gid in enumerate(global_ids_this_rank)} + hdp_rank = dp_cp_group.rank() + dp_ranks = torch.distributed.get_process_group_ranks(dp_group) + # Here we actually want to get the DP group's rank within the HDP group, + # we need to consider TP + # tp-cp-ep-dp-pp + dp_ranks = [(r // tp_group.size()) % dp_cp_group.size() for r in dp_ranks] + + data_keys = batch[0].keys() + + # Create the send plan + combined_sample_id_groups: List[List[int]] = [[] for _ in range(total_hdp_gpus)] + + for d in range(total_hdp_gpus): + for sample_id_group in sample_id_groups: + combined_sample_id_groups[d].extend(sample_id_group[d]) + + for dest_rank in range(total_hdp_gpus): + combined_sample_id_groups[dest_rank].sort() + + # Filter out samples that are not present on this rank + send_ids_sorted = [ + gid + for d in dp_ranks + for gid in combined_sample_id_groups[d] + if gid in global_ids_this_rank + ] + # send_counts = [len(combined_sample_id_groups[d]) for d in range(total_hdp_gpus)] + + send_num_split = [0] * total_hdp_gpus + send_lens_split = [0] * total_hdp_gpus + for dest_rank in range(total_hdp_gpus): + if dest_rank in dp_ranks: + send_seq_lens = [ + global_id_seqlens[gid][1] + for gid in combined_sample_id_groups[dest_rank] + if gid in global_ids_this_rank + ] + send_num_split[dest_rank] = len(send_seq_lens) + send_lens_split[dest_rank] = sum(send_seq_lens) + else: + # We only need to share local data with DP ranks that have different data. 
+ send_lens_split[dest_rank] = 0 + + # Create the recv plan + recv_sample_id_groups = [[] for _ in range(total_hdp_gpus)] + for gid in combined_sample_id_groups[hdp_rank]: + src_rank = _gid_to_src_rank(gid, offsets, dp_group, tp_group, dp_cp_group) + recv_sample_id_groups[src_rank].append(gid) + + recv_lens_split = [0] * total_hdp_gpus + for src_rank in range(total_hdp_gpus): + recv_lens_split[src_rank] = sum( + [global_id_seqlens[gid][1] for gid in recv_sample_id_groups[src_rank]] + ) + + recv_ids_sorted = [gid for d in range(total_hdp_gpus) for gid in recv_sample_id_groups[d]] + recv_counts = [len(recv_sample_id_groups[d]) for d in range(total_hdp_gpus)] + + recv_samples = [{k: None for k in data_keys} for _ in range(sum(recv_counts))] + + def _pack_sample_by_key(key: str) -> torch.Tensor: + flattened_tensors = [] + for gid in send_ids_sorted: + t = batch[gid2local_id[gid]][key].to(torch.cuda.current_device(), non_blocking=True) + # flattened_tensors.append(t) + flattened_tensors.append(t.reshape(-1)) + return ( + torch.cat(flattened_tensors, dim=0) + if flattened_tensors + else torch.empty(0, device=torch.cuda.current_device(), dtype=batch[0][key].dtype) + ) + + def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor): + cursor = 0 + for i, gid in enumerate(recv_ids_sorted): + sample_len = 1 if key in ["original_seq_len"] else global_id_seqlens[gid][1] + recv_samples[i][key] = recv_tensor[cursor : cursor + sample_len] + cursor += sample_len + + for key in data_keys: + output_split_sizes, input_split_sizes = ( + (recv_counts, send_num_split) + if key in ["original_seq_len"] + else (recv_lens_split, send_lens_split) + ) + send_tensor = _pack_sample_by_key(key) + recv_tensor_size = sum(output_split_sizes) + recv_tensor = torch.empty( + recv_tensor_size, device=torch.cuda.current_device(), dtype=send_tensor.dtype + ) + torch.distributed.all_to_all_single( + output=recv_tensor, + input=send_tensor, + output_split_sizes=output_split_sizes, + 
input_split_sizes=input_split_sizes, + group=dp_cp_group, + ) + _unpack_sample_by_key(key, recv_tensor) + + recv_sample_with_id = { + recv_id: recv_samples[i] for i, recv_id in enumerate(recv_ids_sorted) + } + nvtx.pop_range() + return recv_sample_with_id + + def _unpack_batch(batch): + """ + Unpacks the packed samples into a list of sub-samples. + Since each sub-sample may be routed to different DPxCP ranks, + we unpack the sample here to avoid unnecessarily transferring + the entire packed sample. + """ + batch_unpacked = [] + for sample in batch: + sample_dict = {} + for key in sample.keys(): + if key in ["cu_seqlens", "batch_idx", "max_seqlen"]: + continue + sample_dict[key] = sample[key] + batch_unpacked.append(sample_dict) + return batch_unpacked + + def _broadcast_to_tp_group(item): + if item is not None: + torch.distributed.broadcast( + item, + parallel_state.get_tensor_model_parallel_src_rank(), + group=parallel_state.get_tensor_model_parallel_group(), + ) + + def _broadcast_to_pp_group(item): + if item is not None: + torch.distributed.broadcast( + item, + parallel_state.get_pipeline_model_parallel_first_rank(), + group=parallel_state.get_pipeline_model_parallel_group(), + ) + + def _pack_sequences( + samples: List[MegatronDataset], partner_cp_size: Optional[int] = None + ) -> Dict[str, torch.Tensor]: + # TODO(tailaim): do we need attention_mask for sequence packing? 
+ nvtx.push_range("_pack_sequences") + def _pack_tensors(tensors): + if len(tensors) == 1: + return tensors[0].reshape(-1) + return torch.cat([t.reshape(-1) for t in tensors], dim=0) + + nvtx.push_range("_pack_tensors") + tokens = _pack_tensors([sample["tokens"] for sample in samples]) + labels = _pack_tensors([sample["labels"] for sample in samples]) + loss_mask = _pack_tensors([sample["loss_mask"] for sample in samples]) + position_ids = _pack_tensors([sample["position_ids"] for sample in samples]) + nvtx.pop_range() + + new_sample = {} + new_sample["tokens"] = tokens + new_sample["labels"] = labels + new_sample["loss_mask"] = loss_mask + new_sample["position_ids"] = position_ids + if partner_cp_size is not None: + nvtx.push_range("create local_cp_size") + new_sample["local_cp_size"] = torch.tensor( + partner_cp_size, dtype=torch.int32 + ) + nvtx.pop_range() + + # create cu_seqlens_padded + lengths_padding = torch.tensor([s["tokens"].numel() for s in samples]).reshape(-1) + # create max_seqlen + max_seqlen = lengths_padding.max().int() + new_sample["max_seqlen"] = max_seqlen + lengths_padding = lengths_padding.pin_memory().to(dev, non_blocking=True) + cu_seqlens_padded = torch.cat([torch.zeros(1, dtype=torch.int32, device=lengths_padding.device), torch.cumsum(lengths_padding, dim=0, dtype=torch.int32).reshape(-1)], dim=0) + new_sample["cu_seqlens_padded"] = cu_seqlens_padded + + + # create cu_seqlens without padding + lengths = torch.stack([s["original_seq_len"] for s in samples], dim=0).reshape(-1) + cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32, device=lengths.device), torch.cumsum(lengths, dim=0, dtype=torch.int32).reshape(-1)], dim=0) + new_sample["cu_seqlens"] = cu_seqlens + nvtx.pop_range() + return new_sample + + # Convert string to enum if needed + if isinstance(scheduler_type, str): + try: + scheduler_type = PackingScheduler[scheduler_type.upper()] + except KeyError: + available_scheduler = ", ".join([scheduler.name for scheduler in 
PackingScheduler])
+            raise ValueError(
+                f"Unknown packing scheduler: {scheduler_type}. "
+                f"Available schedulers: {available_scheduler}"
+            )
+
+    if scheduler_type not in scheduler_map:
+        available_scheduler = ", ".join([scheduler.name for scheduler in PackingScheduler])
+        raise ValueError(
+            f"Unknown scheduler: {scheduler_type}. " f"Available schedulers: {available_scheduler}"
+        )
+
+    scheduler = scheduler_map[scheduler_type](config)
+    if pg_collection is None:
+        dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True)
+        dp_group_gloo = parallel_state.get_data_parallel_group_gloo()
+        dp_group = parallel_state.get_data_parallel_group()
+        tp_group = parallel_state.get_tensor_model_parallel_group()
+        pp_group = parallel_state.get_pipeline_model_parallel_group()
+    else:
+        dp_cp_group = pg_collection.dp_cp
+        dp_group = pg_collection.dp
+        tp_group = pg_collection.tp
+        pp_group = pg_collection.pp
+        assert (
+            dp_cp_group is not None and dp_group is not None and tp_group is not None
+        ), "dp_cp_group, dp_group, tp_group must not be None when using hybrid context parallel"
+
+    total_hdp_gpus = dp_cp_group.size()
+    dev = torch.cuda.current_device()
+
+    if (
+        config.virtual_pipeline_model_parallel_size is not None
+        and config.virtual_pipeline_model_parallel_size > 1
+    ):
+        if pp_group.rank() == pp_group.size() - 1:
+            assert len(data_iterator) == config.virtual_pipeline_model_parallel_size
+            data_iterator = data_iterator[-1]
+        else:
+            data_iterator = data_iterator[0]
+
+    def cast_inputs_device(inputs, device, skip_device={}):
+        if isinstance(inputs, (list, tuple)):
+            return inputs.__class__(cast_inputs_device(v, device, skip_device) for v in inputs)
+        elif isinstance(inputs, dict):
+            return {k: v if k in skip_device else cast_inputs_device(v, device, skip_device=skip_device) for k, v in inputs.items()}
+        elif isinstance(inputs, torch.Tensor):
+            if not inputs.is_cuda:
+                inputs = inputs.to(device=device, non_blocking=True) # here input is expected to be
pinned + + return inputs + + if data_iterator is not None: + # indicates TP rank 0, with PP stage 0 or -1. + local_cp_size = None + if scheduler_type is PackingScheduler.ONLY_PACKING_NO_SCHEDULING: + # ONLY_PACKING_NO_SCHEDULING scheduler does not schedule the data, + # just packing sequences + + # batch is a list of samples: List[MegatronDataset] + batch = next(data_iterator) + batch = cast_inputs_device(batch, dev) + # print(f"{batch=}") + num_micro_batches = batch[0]["num_micro_batches_left"] + 1 + + batch_all = [batch] + [next(data_iterator) for _ in range(num_micro_batches - 1)] + batch_all = cast_inputs_device(batch_all, dev) + + # calculate this two values for tflops calculation + seqlens_gathered = [ + sample["tokens"].numel() for samples in batch_all for sample in samples + ] + num_total_tokens = 0 + sequence_square_sum = 0 + + # pack sequences in the same group and create a new data iterator + new_samples = [] + for samples in batch_all: + partner_cp_size = samples[0]["local_cp_size"] + new_sample = _pack_sequences(samples, partner_cp_size) + new_samples.append(new_sample) + for sample in samples: + num_total_tokens += sample["tokens"].numel() / partner_cp_size + sequence_square_sum += sample["tokens"].numel() ** 2 / partner_cp_size + + elif ( + scheduler_type is PackingScheduler.HYBRID_CP + or scheduler_type is PackingScheduler.HYBRID_CP_WITH_PP + or scheduler_type is PackingScheduler.NAIVE_SEQUENCE_PACKING + ): + batch = next(data_iterator) + batch = cast_inputs_device(batch, dev) + subsample_seqlens = [] + for sample in batch: + subsample_seqlens.extend([sample["tokens"].numel()]) + subsample_seqlens = torch.tensor(subsample_seqlens, dtype=torch.int32) + subsample_seqlens = subsample_seqlens[subsample_seqlens != 0] + + nvtx.push_range("_get_global_seqlens") + seqlens_gathered, offsets = _get_global_seqlens(subsample_seqlens, dp_group_gloo) + nvtx.pop_range() + + nvtx.push_range("_get_global_id_seqlens") + global_id_seqlens, global_ids_this_rank = 
_get_global_id_seqlens( + subsample_seqlens.shape[0], offsets, seqlens_gathered, dp_group_gloo + ) + nvtx.pop_range() + + nvtx.push_range("scheduler.get_groups_and_subsamples") + groups, sample_id_groups = scheduler.get_groups_and_subsamples( + global_id_seqlens, config + ) + nvtx.pop_range() + + set_gbs = set() + for group in sample_id_groups: + for sub in group: + set_gbs.update(sub) + assert len(set_gbs) == len( + global_id_seqlens + ), f"set_gbs length: {len(set_gbs)} \ + != global_ids_this_rank length: {len(global_id_seqlens)}" + + nvtx.push_range("_unpack_batch") + batch = _unpack_batch(batch) + nvtx.pop_range() + + nvtx.push_range("_reroute_samples_to_hdp_ranks") + samples_this_rank_with_id = _reroute_samples_to_hdp_ranks( + batch, + global_ids_this_rank, + global_id_seqlens, + sample_id_groups, + offsets, + dp_group, + tp_group, + dp_cp_group, + total_hdp_gpus, + ) + nvtx.pop_range() + batch, sample_id_groups = samples_this_rank_with_id, sample_id_groups + + hdp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) + num_micro_batches = len(sample_id_groups) + # calculate this two values for tflops calculation + num_total_tokens_this_GB = np.int64(sum(seqlens_gathered)) + sequence_square_sum_this_GB = np.int64(sum(seqlen**2 for seqlen in seqlens_gathered)) + + new_samples = [] + cp_sizes = [] + for i in range(num_micro_batches): + # pack sequences in the same group and create a new data iterator + sample_ids_this_group = sample_id_groups[i][hdp_rank] + samples = [batch[sub_sample_id] for sub_sample_id in sample_ids_this_group] + partner_cp_size = ( + len( + [ + True + for sample_ids in sample_id_groups[i] + if sample_ids_this_group[0] in sample_ids + ] + ) + if config.hybrid_context_parallel + else None + ) + cp_sizes.append(partner_cp_size) + new_sample = _pack_sequences(samples, partner_cp_size) + new_samples.append(new_sample) + + + if scheduler_type is PackingScheduler.ONLY_PACKING_NO_SCHEDULING: + # allreduce to get the total number 
of microbatches + mfu_info_to_broadcast_this_hdp_group = torch.tensor( + [num_total_tokens, sequence_square_sum], dtype=torch.int64, pin_memory=True + ).to(dev, non_blocking=True) + torch.distributed.all_reduce(mfu_info_to_broadcast_this_hdp_group, group=dp_cp_group) + num_total_tokens_this_GB = mfu_info_to_broadcast_this_hdp_group[0].item() + sequence_square_sum_this_GB = mfu_info_to_broadcast_this_hdp_group[1].item() + + # # broadcast num_micro_batches, num_total_tokens_this_GB, sequence_square_sum_this_GB, + # # and packed_seq_params to tp group + # if pp_group.size() > 2 and tp_group.rank() == 0: + # if pp_group.rank() == 0: + # tensor_list = [ + # torch.tensor( + # [num_micro_batches, num_total_tokens_this_GB, sequence_square_sum_this_GB], + # dtype=torch.int64, + # ).cuda() + # ] + # for sample in new_samples: + # tensor_list.append(sample["max_seqlen"].unsqueeze(0)) + # for sample in new_samples: + # tensor_list.append( + # sample["local_cp_size"].unsqueeze(0) + # if scheduler_type is PackingScheduler.HYBRID_CP + # else torch.tensor([-1], dtype=torch.int32).cuda() + # ) + # for sample in new_samples: + # tensor_list.append(sample["cu_seqlens"]) + # tensor_list.append(sample["cu_seqlens_padded"]) + # info_to_broadcast_this_pp_group = torch.cat(tensor_list, dim=0).to( + # device=dev, dtype=torch.int64 + # ) + # info_length_tensor = torch.tensor( + # info_to_broadcast_this_pp_group.shape[0], dtype=torch.int32 + # ).cuda() + # _broadcast_to_pp_group(info_length_tensor) + # _broadcast_to_pp_group(info_to_broadcast_this_pp_group) + # else: + # info_length_tensor = torch.tensor(0, dtype=torch.int32).cuda() + # _broadcast_to_pp_group(info_length_tensor) + # info_to_broadcast_this_pp_group = torch.empty( + # info_length_tensor.item(), dtype=torch.int64 + # ).cuda() + # _broadcast_to_pp_group(info_to_broadcast_this_pp_group) + # if pp_group.rank() != pp_group.size() - 1: + # info_numpy = info_to_broadcast_this_pp_group.cpu().numpy() + # num_micro_batches = 
info_numpy[0] + # num_total_tokens_this_GB = info_numpy[1] + # sequence_square_sum_this_GB = info_numpy[2] + # max_seqlens = info_numpy[3 : 3 + num_micro_batches] + # local_cp_sizes = info_numpy[3 + num_micro_batches : 3 + 2 * num_micro_batches] + # cu_seqlens_list = [] + # cu_seqlens_padded_list = [] + # indices = np.where(info_numpy == 0)[0] + # for i in range(num_micro_batches): + # cu_seqlens_list.append(info_numpy[indices[i * 2] : indices[i * 2 + 1]]) + # if i == num_micro_batches - 1: + # cu_seqlens_padded_list.append(info_numpy[indices[i * 2 + 1] :]) + # else: + # cu_seqlens_padded_list.append( + # info_numpy[indices[i * 2 + 1] : indices[i * 2 + 2]] + # ) + + # new_samples = [] + # for i in range(num_micro_batches): + # new_sample = {} + # new_sample["max_seqlen"] = torch.tensor( + # max_seqlens[i], dtype=torch.int32 + # ).cuda() + # if local_cp_sizes[i] != -1: + # new_sample["local_cp_size"] = torch.tensor( + # local_cp_sizes[i], dtype=torch.int32 + # ).cuda() + # new_sample["cu_seqlens"] = torch.tensor( + # cu_seqlens_list[i], dtype=torch.int32 + # ).cuda() + # new_sample["cu_seqlens_padded"] = torch.tensor( + # cu_seqlens_padded_list[i], dtype=torch.int32 + # ).cuda() + # new_samples.append(new_sample) + + if data_iterator is not None: + def pin_tensor(sample): + for k, v in sample.items(): + if isinstance(v, torch.Tensor) and v.device == torch.device("cpu"): + sample[k] = v.pin_memory() + return sample + new_samples = type(new_samples)(map(pin_tensor, new_samples)) + + if tp_group.size() > 1: + if tp_group.rank() == 0: + info_to_broadcast_this_tpgroup = torch.tensor( + [num_micro_batches, num_total_tokens_this_GB, sequence_square_sum_this_GB], + dtype=torch.int64, + pin_memory=True + ).to(dev, non_blocking=True) + _broadcast_to_tp_group(info_to_broadcast_this_tpgroup) + else: + info_to_broadcast_this_tpgroup = torch.zeros(3, dtype=torch.int64, device=dev) + _broadcast_to_tp_group(info_to_broadcast_this_tpgroup) + info_numpy = 
info_to_broadcast_this_tpgroup.cpu().numpy() + (num_micro_batches, num_total_tokens_this_GB, sequence_square_sum_this_GB) = info_numpy[ + :3 + ] + + if ( + config.virtual_pipeline_model_parallel_size is not None + and config.virtual_pipeline_model_parallel_size > 1 + ): + vpp_size = config.virtual_pipeline_model_parallel_size + if tp_group.rank() == 0: + if pp_group.rank() == 0 or pp_group.rank() == pp_group.size() - 1: + new_samples_for_other_ppstage = [] + for sample in new_samples: + new_sample_for_other_ppstage = {} + new_sample_for_other_ppstage["max_seqlen"] = sample["max_seqlen"] + new_sample_for_other_ppstage["cu_seqlens"] = sample["cu_seqlens"] + new_sample_for_other_ppstage["cu_seqlens_padded"] = sample["cu_seqlens_padded"] + if config.hybrid_context_parallel: + new_sample_for_other_ppstage["local_cp_size"] = sample["local_cp_size"] + new_samples_for_other_ppstage.append(new_sample_for_other_ppstage) + if pp_group.rank() == 0: + new_data_iterator = [RerunDataIterator(iter(new_samples))] + [ + RerunDataIterator(iter(new_samples_for_other_ppstage)) + for _ in range(vpp_size - 1) + ] + else: + new_data_iterator = [ + RerunDataIterator(iter(new_samples_for_other_ppstage)) + for _ in range(vpp_size - 1) + ] + [RerunDataIterator(iter(new_samples))] + else: + new_data_iterator = [RerunDataIterator(iter(new_samples)) for _ in range(vpp_size)] + else: + new_data_iterator = [None for _ in range(vpp_size)] + else: + new_data_iterator = RerunDataIterator(iter(new_samples)) if tp_group.rank() == 0 else None + + return ( + new_data_iterator, + num_micro_batches, + num_total_tokens_this_GB, + sequence_square_sum_this_GB, + ) + + +class BaseScheduler: + """ + Base class for sequence packing schedulers. + """ + + def __init__(self, config): + pass + + +class NaiveSequencePackingScheduler(BaseScheduler): + """ + This scheduler simply packs sequences in their original order + until reaching the max sequence length. 
+    It does not reorder sequences nor perform any load balancing.
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.dp_size = int(parallel_state.get_data_parallel_world_size())
+        self.cp_size = int(parallel_state.get_context_parallel_world_size())
+        self.max_seq_len_all_ranks = config.max_seqlen_per_dp_cp_rank * self.cp_size
+
+    def get_groups_and_subsamples(self, sample_id_seqlens, config):
+        """
+        This scheduler simply packs sequences in their original order
+        until reaching the max sequence length.
+        It does not reorder sequences nor perform any load balancing.
+        """
+        groups = []
+        sample_id_groups = []
+        packed_id_groups = []
+        sum_seqlen = 0
+        single_microbatch = []
+
+        for i in range(len(sample_id_seqlens)):
+            if sum_seqlen + sample_id_seqlens[i][1] <= self.max_seq_len_all_ranks:
+                # if flag and sum_seqlen + sample_id_seqlens[i][1] <= self.max_seq_len_all_ranks:
+                # flag = False
+                single_microbatch.append(i)
+                sum_seqlen += sample_id_seqlens[i][1]
+            else:
+                packed_id_groups.append(single_microbatch)
+                single_microbatch = [i]
+                sum_seqlen = sample_id_seqlens[i][1]
+        if len(single_microbatch) > 0:
+            packed_id_groups.append(single_microbatch)
+
+        gbs_sum = 0
+        for i in packed_id_groups:
+            gbs_sum += len(i)
+        assert gbs_sum == len(
+            sample_id_seqlens
+        ), f"gbs_sum: {gbs_sum} != sample_id_seqlens length: {len(sample_id_seqlens)}"
+
+        groups.append(single_microbatch)
+        # fix: do not append single_microbatch to packed_id_groups a second time;
+        # it was already appended after the packing loop, and a duplicate entry
+        # would schedule the final microbatch's samples twice.
+
+        # we want the number of packed sequences to be multiple of dp_size
+        # so we move few samples from previous microbatch
+        # to the end of the microbatches if needed
+        num_packed_sequence = len(packed_id_groups)
+        if num_packed_sequence % self.dp_size != 0:
+            # print(f"{num_packed_sequence=}, {self.dp_size=}, {len(sample_id_seqlens)=}")
+            remainder = num_packed_sequence % self.dp_size
+            num_to_move = self.dp_size - remainder
+            i = num_packed_sequence - 1
+            while num_to_move > 0:
+                assert i > 0, "Not enough samples to move"
+                if 
len(packed_id_groups[i]) > 1: + seq_id = packed_id_groups[i].pop() + packed_id_groups.append([seq_id]) + num_to_move -= 1 + else: + i -= 1 + + num_micro_batches = int(len(packed_id_groups) / self.dp_size) + for i in range(num_micro_batches): + sample_id_groups.append([]) + for j in range(self.cp_size * self.dp_size): + seq_id = int(i * self.dp_size + j / self.cp_size) + sample_id_groups[i].append(packed_id_groups[seq_id]) + return groups, sample_id_groups + + +class BalancedHybridCPscheduler(BaseScheduler): + """ + This class provides the functionality to form groups of sub-samples + such that all DPxCP ranks have a roughly balanced workload in the group. + """ + + def __init__(self, config): + super().__init__(config) + self.max_seq_len_per_rank = config.max_seqlen_per_dp_cp_rank + self.num_subsamples = 0 + self.num_subsamples_processed = 0 + self.free_resources = [] + self.total_hdp_gpus = parallel_state.get_data_parallel_world_size( + with_context_parallel=True + ) + + @lru_cache(maxsize=128) + def get_total_workload(self, seq_length: int, cp_size: Optional[int] = None): + """ + seq_length: sequence length of a sub-sample + cp_size: total number of CP ranks working on this sub-sample + + Note: + This function is used to estimate the relative workload intensity + of a sub-sample. This is not meant to be an accurate flops calculator. + + Returns: workload of a sub-sample + """ + if cp_size is None: + cp_size = self.gpus_needed(seq_length) + return (seq_length * seq_length) / cp_size + + @lru_cache(maxsize=128) + def gpus_needed(self, seq_len: int) -> int: + """ + Calculates the number of GPUs needed for a given sequence length + and max sequence length per CP rank. + This is used to determine the CP size of a sub-sample. + + The number is rounded up to the next power of 2 to match the available + hybrid context parallel process group sizes. 
+ """ + return max(1, 2 ** ceil(log2((seq_len / self.max_seq_len_per_rank)))) + + def make_buckets_equal( + self, + sample_seqlens: List[Tuple[int, int]], # List of (sample_id, sequence_length) tuples + compute_estimator: Callable[[int], float], + ) -> List[deque]: + """ + Makes as many buckets as unique CP sizes needed. + This keeps sample IDs tethered to their sequence lengths throughout the bucketing process. + """ + # Extract just the sequence lengths for determining k + seqlens = [seq_len for _, seq_len in sample_seqlens] + + # Determine k based on unique GPU categories needed + k = len({self.gpus_needed(L) for L in seqlens}) + + # Create a work target for each bucket + # This is the total work divided by the number of buckets + work = [] + for _, s in sample_seqlens: + cp_size = self.gpus_needed(s) + work.append(compute_estimator(s, cp_size)) + total_work = sum(work) + target = total_work / k + buckets, cur, cur_work = [], [], 0.0 + remaining_work = total_work + remaining_k = k + + for i, (sample_id, seq_len) in enumerate(sample_seqlens): + work = compute_estimator(seq_len) + projected = cur_work + work + + # Check if we should close this bucket + if cur and ( + projected > target * 1.1 # Too much work + or len(sample_seqlens) - i <= remaining_k - len(buckets) + ): # Need to save sequences for remaining buckets + buckets.append(deque(cur)) + cur, cur_work = [], 0.0 + remaining_work -= sum(compute_estimator(seq_len) for _, seq_len in cur) + remaining_k -= 1 + + cur.append((sample_id, seq_len)) + cur_work += work + + if cur: + buckets.append(deque(cur)) + + return buckets + + def next_hdp_group( + self, + sample_seqlens: List[Tuple[int, int]], # List of (sample_id, sequence_length) tuples + compute_estimator: Callable[[int], float], + total_gpus: int, + delta: float = 0.05, # balance slack (e.g. 
5 %) + strategy: str = "dp", # "dp" or "pp" + eps_bucket: float = 0.10, # ε target for bucket balance + ) -> Tuple[List[List[int]], List[Tuple[int, int]], List[float], List[List[int]]]: + """ + Given a list of (sample_id, sequence_length) tuples, this function aims to assign + sequences in a group such that all GPUs in the DPxCP group have a roughly balanced + workload. Once each group is roughly balanced, we exit and return the + group and the leftover sequences. + + The function performs the following passes in order to form a balanced microbatch: + 1. We create buckets of sequences that are roughly balanced. + We try to create as many buckets as possible CP sizes. + 2. Given a bucket has sequences available, we assign the sample + a. To a new set of GPUs if there are enough free GPUs. + b. To an existing set of GPUs with the lowest load. + 3. We check if the group is balanced whenever we need to move onto a new CP size + in the same set of GPUs. + 4. We trim the group if removing the last added sequence helps improve balance. + 5. If we run out of sequences to assign and there are empty GPUs, + we redistribute work to empty GPUs by recursively increasing the CP size of a + sample until no empty GPUs are left. + + #TODO: Add clarification on when we check for balance. What does prev_needed do? + + Returns (micro_batches, leftover_sample_seqlens, exec_times, sample_ids_per_gpu). 
+ """ + if not sample_seqlens: + return ( + [[] for _ in range(total_gpus)], + [], + [0.0 for _ in range(total_gpus)], + [[] for _ in range(total_gpus)], + ) + + # Get buckets of sequences with balanced work + buckets = self.make_buckets_equal(sample_seqlens, compute_estimator) + + # Initialize tracking structures + micro_batches = [[] for _ in range(total_gpus)] + exec_times = [0.0 for _ in range(total_gpus)] + sample_ids_per_gpu = [[] for _ in range(total_gpus)] + # gid : seq_len + packing_sequence_len = {} + + gpu_group_id = [None] * total_gpus + group_members = {} + group_size = {} + next_gid = 0 + + pp_cursor = 0 + prev_needed = None + check_balance = False + + while buckets: + # ---- Step 1 – pick the next sequence we COULD place ------------------ + sample_seq_tuple = bucket_idx = None + needed = None + + scan_order = ( + range(len(buckets)) + if strategy == "dp" + else [(pp_cursor + i) % len(buckets) for i in range(len(buckets))] + ) + + for idx in scan_order: + if not buckets[idx]: + continue + cand_tuple = buckets[idx][0] # This is now (sample_id, seq_len) + cand_seq_len = cand_tuple[1] + needed = self.gpus_needed(cand_seq_len) + + # (a) Do we have an *existing* group of size `needed`? + candidate_gids = [gid for gid, sz in group_size.items() if sz == needed] + + # (b) Or enough completely free GPUs to start a new group? + free_ranks = [r for r, gid in enumerate(gpu_group_id) if gid is None] + if candidate_gids or len(free_ranks) >= needed: + sample_seq_tuple, bucket_idx = cand_tuple, idx + break + + # No place to put any remaining sequence – finish this micro‑batch + if sample_seq_tuple is None: + break + + # TODO[pmannan]: PP not yet supported. Add PP scheduling. 
+ if strategy == "pp": + pp_cursor = (bucket_idx + 1) % len(buckets) + + sample_id, seq_len = sample_seq_tuple + needed = self.gpus_needed(seq_len) + if prev_needed is None: + prev_needed = needed + + # (a) Existing groups of exactly this size + candidate_gids = [ + gid + for gid, sz in group_size.items() + if sz == needed + and packing_sequence_len[gid] + seq_len / needed <= self.max_seq_len_per_rank + ] + if candidate_gids: + best_gid, best_load = min( + ( + (gid, max(exec_times[r] for r in group_members[gid])) + for gid in candidate_gids + ), + key=lambda t: t[1], + ) + else: + best_gid, best_load = None, float("inf") + + # (b) Hypothetical **new** group from completely free GPUs + free_ranks = [r for r, gid in enumerate(gpu_group_id) if gid is None] + if len(free_ranks) >= needed: + free_sorted = sorted(free_ranks, key=lambda r: exec_times[r]) + new_members = free_sorted[:needed] + new_load = exec_times[new_members[-1]] + + if new_load < best_load: + best_gid = None + chosen_members = new_members + else: + chosen_members = group_members[best_gid] + else: + if best_gid is None: + break + chosen_members = group_members[best_gid] + + # ---- Step 2 – if we decided to create a fresh group ---------------- + if best_gid is None: + best_gid = next_gid + next_gid += 1 + group_members[best_gid] = chosen_members + group_size[best_gid] = needed + for r in chosen_members: + gpu_group_id[r] = best_gid + + # ---- Step 3 – assign the sequence to every member of that group ------ + per_gpu_cost = compute_estimator(seq_len) + + packing_sequence_len[best_gid] = ( + packing_sequence_len.get(best_gid, 0) + seq_len / needed + ) + for r in chosen_members: + micro_batches[r].append(seq_len) + exec_times[r] += per_gpu_cost + sample_ids_per_gpu[r].append(sample_id) + + # Remove the sequence definitively from its bucket + buckets[bucket_idx].popleft() + + # ---- Step 4 – tidy, balance‑check, maybe early‑exit ------------------ + while buckets and not buckets[0]: + buckets.pop(0) + 
pp_cursor %= max(1, len(buckets)) + + # TODO: Removing this helps reduce the number of groups when we have + # lots of samples with same CP size. + # But because we don't exit as soon as we get balanced, + # even if there is one group available that can take the next sample, + # we will keep adding samples to the same group. + # trim_overload() does not help because it only checks if removing the + # last added sample helps. + # We cannot check after adding every sample because there will always be imbalance + # if we don't wait for future scheduling. + + # IMPORTANT: So we need a solution here + if needed < prev_needed: + # When we get into a lower CP size in the same group, + # we can start checking for balance. There is still a gotcha here. + # Let's say we have a group of 3 GPU 0-2, then we move onto group of 2. + # We keep assigning group of 2 as we do in descending order but GPU 7/15 + # never sees a microbatch assigned to it + # until we run out of samples with CP2. + # This means we are never balanced as min(exec_times) will always be 0. + # We need a smart way of identifying that we have run out of big samples + # and if we are having to assign work to a GPU already working, + # is it because there are empty GPUs? + # Would assigning work to empty GPUs first by moving onto next CP bucket help? + # But we need to remember to come back to this CP size bucket and then + # check for balance. Maybe the scheduling algorithm should look at empty + # GPUs and find work rather than going sequence by sequence. 
+ check_balance = True + + if ( + check_balance + and buckets + and max(exec_times) - min(exec_times) <= delta * max(exec_times) + ): + break + + # Gather leftovers (flatten remaining buckets, preserve order) + leftovers = [] + for b in buckets: + for sample_seq_tuple in b: + leftovers.append(sample_seq_tuple) + + # --------------------------------------------------------------------------- + def trim_overload(): + """ + Iteratively pop the most-recent sequence from the *most-loaded group* + whenever doing so reduces the global slack. + """ + while True: + cur_max = max(exec_times) + cur_min = min(exec_times) + cur_slack = cur_max - cur_min + if cur_slack <= delta * cur_max: + # Slack is already within limit. + break + if cur_min == 0: + # There are empty GPUs that will be + # handled in the next step. + break + + max_r = exec_times.index(cur_max) + gid = gpu_group_id[max_r] + members = group_members[gid] + + if not micro_batches[max_r] or len(micro_batches[max_r]) <= 1: + break + + seq = micro_batches[max_r][-1] + need = group_size[gid] + per_gpu_cost = compute_estimator(seq) + + proj_times = exec_times[:] + for r in members: + proj_times[r] -= per_gpu_cost + + proj_slack = max(proj_times) - min(proj_times) + + # Check if trimming the workload helps imbalance + if proj_slack < cur_slack: + sample_id_to_remove = sample_ids_per_gpu[max_r][-1] + for r in members: + micro_batches[r].pop() + exec_times[r] -= per_gpu_cost + sample_ids_per_gpu[r].pop() + leftovers.append((sample_id_to_remove, seq)) + else: + break + + # TODO(tailaim): uncomment this to support different ranks have different num_microbatches + # trim_overload() + + # Track samples in this group before redistribution to empty GPUs + total_work_before = sum(len(mb) for mb in micro_batches) + + # Check for empty GPUs and redistribute work + def fill_empty_gpus( + micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size + ): + """ + Recursively check for empty GPUs and redistribute work by 
increasing + the number of GPUs sharing samples. This ensures all GPUs have work. + GPUs must be allocated consecutively so we may need to push existing + work to other ranks in order to expand samples. + """ + # Find empty GPUs + empty_gpus = [i for i in range(total_gpus) if not micro_batches[i]] + if not empty_gpus: + return ( + micro_batches, + exec_times, + sample_ids_per_gpu, + group_members, + group_size, + ) # No empty GPUs, we're done + + # Find the smallest group size that exists + existing_group_sizes = set(group_size.values()) + assert ( + existing_group_sizes + ), "There should be at least one group existing, cannot reditribute, " + "try to increase 'max-seqlen-per-dp-cp-rank'." + + min_group_size = min(existing_group_sizes) + # We have Hybrid DPxCP groups for every power of 2 of GPUs or the entire DPxCP group. + next_power = min(min_group_size * 2, total_gpus) + + # Find the first group of min_group_size that can be expanded + expandable_gid = None + expandable_members = None + expandable_new_gpus = None + + for gid, size in group_size.items(): + if size == min_group_size: + members = group_members[gid] + needed_count = next_power - min_group_size + group_start_gpu = members[0] + group_end_gpu = members[-1] + empty_gpu = [idx for idx, work in enumerate(micro_batches) if not work][0] + assert not all( + work for work in micro_batches[empty_gpu : empty_gpu + needed_count] + ), f"Empty GPUs were detected but not enough to expand." 
+ work_to_push = micro_batches[ + group_end_gpu + 1 : empty_gpu + ] # This is work of all other subsequent sub-samples + exec_times_to_push = exec_times[group_end_gpu + 1 : empty_gpu] + sample_ids_to_push = sample_ids_per_gpu[group_end_gpu + 1 : empty_gpu] + + new_micro_batches = [[]] * len(micro_batches) + new_exec_times = [0.0] * len(exec_times) + new_sample_ids_per_gpu = [[]] * len(sample_ids_per_gpu) + + # No change in work until the group selected for expansion + for i in range(group_start_gpu): + new_micro_batches[i] = micro_batches[i] + new_exec_times[i] = exec_times[i] + new_sample_ids_per_gpu[i] = sample_ids_per_gpu[i] + + # The work is distributed across the expanded group + for i in range(group_start_gpu, group_end_gpu + needed_count + 1): + new_micro_batches[i] = micro_batches[group_end_gpu] + new_exec_times[i] = self.get_total_workload( + micro_batches[group_end_gpu][0], next_power + ) + new_sample_ids_per_gpu[i] = sample_ids_per_gpu[group_end_gpu] + + # Any assigned work on expanded GPUs is pushed + for i, work in enumerate(work_to_push): + new_micro_batches[group_end_gpu + needed_count + 1 + i] = work + new_exec_times[group_end_gpu + needed_count + 1 + i] = exec_times_to_push[i] + new_sample_ids_per_gpu[group_end_gpu + needed_count + 1 + i] = ( + sample_ids_to_push[i] + ) + + group_size[gid] = next_power + group_members[gid] = list(range(members[0], members[-1] + needed_count + 1)) + for pushed_gid in group_size.keys(): + if pushed_gid > gid: + group_members[pushed_gid] = [ + x + needed_count for x in group_members[pushed_gid] + ] + + return ( + new_micro_batches, + new_exec_times, + new_sample_ids_per_gpu, + group_members, + group_size, + ) + + empty_gpus = any([not micro_batches[i] for i in range(total_gpus)]) + while empty_gpus: + micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size = ( + fill_empty_gpus( + micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size + ) + ) + empty_gpus = any([not micro_batches[i] 
for i in range(total_gpus)]) + + # Assert that no sample has been completely removed + total_work_after = sum(len(mb) for mb in micro_batches) + assert ( + total_work_after >= total_work_before + ), f"Samples were removed: {total_work_before} -> {total_work_after}" + + return micro_batches, leftovers, exec_times, sample_ids_per_gpu + + def get_groups_and_subsamples(self, sample_id_seqlens, config): + """ + This function recursively forms groups of sub-samples such that all DPxCP ranks + have a roughly balanced workload in the group. + """ + groups = [] + sample_id_groups = [] + # We assign a sample_id to each sub-sample in order to track assignment to each GPU. + sample_id_seqlens = sorted(sample_id_seqlens, key=lambda x: x[1], reverse=True) + while sample_id_seqlens: + mb, sample_id_seqlens, exec_times, sample_ids = self.next_hdp_group( + sample_id_seqlens, self.get_total_workload, self.total_hdp_gpus + ) + groups.append(mb) + if len(sample_ids) < self.total_hdp_gpus: + sample_ids.extend([] * (self.total_hdp_gpus - len(sample_ids))) + sample_id_groups.append(sample_ids) + + # if torch.distributed.get_rank() == 0: + # breakpoint() + # torch.distributed.barrier() + return groups, sample_id_groups + + +def compute_pp_bubble_ratio(PP, m, VPP=1): + return (PP - 1) / (m * VPP + PP - 1) + + +def greedy_assign_bucket_to_dp(curr_m, indices_buckets, normal_indexes, except_buckets, except_bucket_num_per_sample, + except_bucket_m_per_sample, except_bucket_dp_per_sample, buckets_for_current_m, + dp_size_for_current_m, used_flops, used_fwd_flops, used_bwd_flops, bucket_num_per_dp_curr_m, + all_flops, all_lengths, combination=None, config=None): + """ + 使用贪心算法将桶分配给数据并行(DP)组 + + 参数: + curr_m: 当前处理的m值(微批次数量) + indices_buckets: 所有桶的索引信息 + except_buckets: 特殊处理的序列桶 + except_bucket_num_per_sample: 每个特殊序列分配的桶数 + except_bucket_m_per_sample: 每个特殊序列分配的m值 + buckets_for_current_m: 当前m值对应的桶列表 + dp_size_for_current_m: 当前m值的DP组大小 + used_flops: 已使用的总FLOPs + used_fwd_flops: 已使用的前向FLOPs + 
def greedy_assign_bucket_to_dp(curr_m, indices_buckets, normal_indexes, except_buckets, except_bucket_num_per_sample,
                               except_bucket_m_per_sample, except_bucket_dp_per_sample, buckets_for_current_m,
                               dp_size_for_current_m, used_flops, used_fwd_flops, used_bwd_flops, bucket_num_per_dp_curr_m,
                               all_flops, all_lengths, combination=None, config=None):
    """
    Greedily assign buckets to data-parallel (DP) ranks.

    Args:
        curr_m: the m value (microbatch count) currently being processed
        indices_buckets: index information for all buckets
        normal_indexes: real data indices of the normal (non-split) samples
        except_buckets: buckets holding special (Seq1F1B / oversized) sequences
        except_bucket_num_per_sample: buckets consumed by each special sample
        except_bucket_m_per_sample: m value assigned to each special sample
        except_bucket_dp_per_sample: dp index assigned to each special sample
        buckets_for_current_m: the bucket list belonging to the current m value
        dp_size_for_current_m: DP group size for the current m value
        used_flops / used_fwd_flops / used_bwd_flops: consumed FLOPs accounting
        bucket_num_per_dp_curr_m: bucket-count cap per DP rank
        all_flops: per-sample FLOPs tables (all_flops[1] = fwd, all_flops[2] = bwd)
        all_lengths: per-sample sequence lengths
        combination, config: optional scheduling context

    Returns:
        (fwd_flops_for_dp_per_m, bwd_flops_for_dp_per_m, buckets_for_dp,
         sample_ids_for_dp, seq_len_for_dp_per_m, empty_bucket_flag)
    """

    # Initialize per-DP-rank bookkeeping lists.
    fwd_flops_for_dp_per_m = [[] for _ in range(dp_size_for_current_m)]  # forward FLOPs
    bwd_flops_for_dp_per_m = [[] for _ in range(dp_size_for_current_m)]  # backward FLOPs
    seq_len_for_dp_per_m = [[] for _ in range(dp_size_for_current_m)]    # seqlen of each microbatch
    buckets_for_dp = [[] for _ in range(dp_size_for_current_m)]          # assigned buckets
    sample_ids_for_dp = [[] for _ in range(dp_size_for_current_m)]       # assigned sample ids
    sample_lengths_for_dp = [[] for _ in range(dp_size_for_current_m)]   # NOTE(review): written nowhere below — appears unused

    # Per-DP-rank running FLOPs sum and number of buckets already used.
    fwd_flops_sum_per_dp_this_m = [0.0] * dp_size_for_current_m
    bucket_used_num_per_dp_this_m = [0] * dp_size_for_current_m
    prefix_sum_per_dp_this_m = 0  # tracks the position of special-sequence buckets

    # Step 1: assign the buckets of special (Seq1F1B) sequences.
    # Walk over every sample and check whether it was assigned to the current m.
    num_split_for_dp = [0] * dp_size_for_current_m
    for idx in range(len(except_bucket_m_per_sample)):
        # Only handle special sequences belonging to the current m value.
        if (except_bucket_m_per_sample[idx] != curr_m):
            continue

        # Position range of this sequence inside except_buckets.
        st = prefix_sum_per_dp_this_m
        ed = prefix_sum_per_dp_this_m + except_bucket_num_per_sample[idx]

        # Assign the sequence's fragments to their pre-selected DP ranks.
        for k in range(st, ed):
            bucket_tmp = except_buckets[curr_m][k]

            # Record FLOPs info; repack as (flops, time_dict, cp_size, dp_index).
            bucket_tmp.fwd_flops = (bucket_tmp.fwd_flops, {}, bucket_tmp.cp_size, bucket_tmp.dp_index)
            bucket_tmp.bwd_flops = (bucket_tmp.bwd_flops, {}, bucket_tmp.cp_size, bucket_tmp.dp_index)

            fwd_flops_for_dp_per_m[bucket_tmp.dp_index].append(bucket_tmp.fwd_flops)
            bwd_flops_for_dp_per_m[bucket_tmp.dp_index].append(bucket_tmp.bwd_flops)

            # Construct a 2-d array; correction factor for the memory simulator.
            seq_len_for_dp_per_m[bucket_tmp.dp_index].append([bucket_tmp.seq_len_sum // bucket_tmp.cp_size //
                                                              config.min_hybrid_context_parallel_size * config.context_parallel_size])

            # Update the DP rank's load statistics.
            fwd_flops_sum_per_dp_this_m[bucket_tmp.dp_index] += bucket_tmp.fwd_flops[0]

            buckets_for_dp[bucket_tmp.dp_index].append(bucket_tmp)
            sample_ids_for_dp[bucket_tmp.dp_index].append(bucket_tmp.samples)

            # Advance the assignment cursor and the bucket-usage counters.
            num_split_for_dp[bucket_tmp.dp_index] += 1
            bucket_used_num_per_dp_this_m[bucket_tmp.dp_index] += 1
        prefix_sum_per_dp_this_m += except_bucket_num_per_sample[idx]

    # Step 2: assign the normal buckets.
    empty_bucket_flag = False
    for j in range(len(buckets_for_current_m)):
        # Find the best-fitting DP rank: smallest load whose bucket cap is not reached.
        min_flops = sys.float_info.max
        min_flops_dp_rank = -1
        for dp_rank in range(len(fwd_flops_sum_per_dp_this_m)):
            if (min_flops > fwd_flops_sum_per_dp_this_m[dp_rank]) and \
                    (bucket_used_num_per_dp_this_m[dp_rank] < bucket_num_per_dp_curr_m):
                min_flops = fwd_flops_sum_per_dp_this_m[dp_rank]
                min_flops_dp_rank = dp_rank

        assert min_flops_dp_rank != -1  # a suitable DP rank must exist

        # Fetch the current bucket id and check whether it is empty.
        bucket_id = buckets_for_current_m[j][1]
        if not indices_buckets[bucket_id] or len(indices_buckets[bucket_id].samples) == 0:
            empty_bucket_flag = True

        # For testing/inspection only.
        indices_buckets[bucket_id].samples_fwd_flops = [all_flops[1][indice] for indice in indices_buckets[bucket_id].samples]

        # TFLOPs-to-time bookkeeping.
        scale = 0.5
        length_sum = 0
        length_square_sum = 0
        lengths = []
        # NOTE(review): hidden_size/scale/length_square_sum are computed but only
        # length_sum is consumed (in the assert below) — leftovers of a disabled
        # flops_to_times() path.
        hidden_size = config.hidden_size

        bucket_tmp = indices_buckets[bucket_id]
        for sample_id in bucket_tmp.samples:
            length = all_lengths[sample_id]
            length_sum += length
            lengths.append(length)
            length_square_sum += (length ** 2)

        # fwd_time, bwd_time = flops_to_times(...) is disabled; FLOPs stand in for time.
        fwd_time, bwd_time = bucket_tmp.fwd_flops, bucket_tmp.bwd_flops  # TODO(wuguohao)
        fwd_time_dict = {}

        split_num = 1
        split_idx = 0
        bwd_time_dict = {}  # TODO
        # Repack as (flops, time_dict, split_num, split_idx), mirroring step 1.
        bucket_tmp.fwd_flops = (bucket_tmp.fwd_flops, fwd_time_dict, split_num, split_idx)
        bucket_tmp.bwd_flops = (bucket_tmp.bwd_flops, bwd_time_dict, split_num, split_idx)

        assert length_sum == bucket_tmp.seq_len_sum, f"{length_sum=}, {bucket_tmp.seq_len_sum=}, {bucket_id=}"

        # (disabled) replace offset data indices with real data indices:
        # indices_buckets[bucket_id].samples = [normal_indexes[indice] for indice in indices_buckets[bucket_id].samples]

        # Hand the bucket to the selected DP rank.
        fwd_flops_for_dp_per_m[min_flops_dp_rank].append(bucket_tmp.fwd_flops)
        bwd_flops_for_dp_per_m[min_flops_dp_rank].append(bucket_tmp.bwd_flops)
        # Correction factor for the memory simulator.
        seq_len_for_dp_per_m[min_flops_dp_rank].append([bucket_tmp.seq_len_sum // config.min_hybrid_context_parallel_size * config.context_parallel_size])
        buckets_for_dp[min_flops_dp_rank].append(bucket_tmp)
        sample_ids_for_dp[min_flops_dp_rank].append(bucket_tmp.samples)

        # Update the DP rank's load statistics.
        fwd_flops_sum_per_dp_this_m[min_flops_dp_rank] += (bucket_tmp.fwd_flops[0])
        bucket_used_num_per_dp_this_m[min_flops_dp_rank] += 1

    # Propagate the per-DP split-bucket count onto every assigned bucket.
    for dp_rank in range(len(buckets_for_dp)):
        for bucket_i, bucket in enumerate(buckets_for_dp[dp_rank]):
            bucket.num_split_bucket_this_dp = num_split_for_dp[dp_rank]
    assert len(buckets_for_dp) == len(sample_ids_for_dp), f"{len(sample_ids_for_dp)=}, {len(buckets_for_dp)=}"
    return fwd_flops_for_dp_per_m, bwd_flops_for_dp_per_m, buckets_for_dp, sample_ids_for_dp, seq_len_for_dp_per_m, empty_bucket_flag
def fwd_flops_update_rule(bucket, index, all_density, all_lengths, all_flops):
    """
    Score candidate `index` for `bucket` by projected forward FLOPs.

    Returns None when adding the sample would exceed the bucket's length
    budget (the length cap doubles as the memory limit); otherwise a score
    that grows quadratically with the projected fill level, so the emptiest
    bucket wins.
    """
    if bucket.seq_len_sum + all_lengths[index] > bucket.target_length:  # add memory limit.
        return None

    projected = bucket.fwd_flops + all_flops[1][index]
    return projected * (projected / bucket.target_flops)


def length_update_rule(bucket, index, all_density, all_lengths, all_flops):
    """Score by squared relative deviation of the projected bucket length from its target."""
    gap = (bucket.seq_len_sum + all_lengths[index]) - bucket.target_length
    return gap ** 2 / bucket.target_length ** 2


class UpdateRule(enum.Enum):
    """Strategy used to score candidate buckets for a sample."""
    DENSITY = 1
    FW_FLOPS = 2
    LENGTH = 3


# Dispatch table from strategy to scoring function (DENSITY has no entry here).
update_rule_mapping = {
    UpdateRule.FW_FLOPS: fwd_flops_update_rule,
    UpdateRule.LENGTH: length_update_rule,
}


def fwd_flops_to_bwd_flops(pre_attn_fwd_time, attn_fwd_time, post_attn_fwd_time, mlp_fwd_time):
    """
    Derive backward-pass estimates from forward-pass estimates.

    NOTE(review): 2.77 / 2.7 look like empirically measured bwd/fwd ratios —
    confirm they match the target hardware.
    """
    return (
        2.7 * pre_attn_fwd_time,
        2.77 * attn_fwd_time,
        2.7 * post_attn_fwd_time,
        2.7 * mlp_fwd_time,
    )


def attention_tflops(s, h, scale):
    """Forward-only attention TFLOPs for sequence length `s` and hidden size `h`."""
    return 2 * 2 * (s ** 2) * h / 1e12 * scale


def linear_tflops(s, h1, h2):
    """Forward-only TFLOPs of an (s x h1) @ (h1 x h2) linear layer."""
    return 2 * s * h1 * h2 / 1e12
+ """ + scale = 0.5 + + ####### forward tflops ######## + gemm_fwd_tflops = linear_tflops(s1, config.hidden_size, config.hidden_size) + attn_fwd_tflops = attention_tflops(s1, config.hidden_size, scale) + + pre_attn_fwd_tflops = 3 * gemm_fwd_tflops + if config.num_query_groups is not None: + pre_attn_fwd_tflops = gemm_fwd_tflops + 2 * gemm_fwd_tflops / config.num_query_groups + post_attn_fwd_tflops = gemm_fwd_tflops + mlp_fc1_h = config.ffn_hidden_size * 2 if config.gated_linear_unit else config.ffn_hidden_size + mlp_fc2_h = config.ffn_hidden_size + mlp_fc1_fwd_tflops = linear_tflops(s1, config.hidden_size, mlp_fc1_h) + mlp_fc2_fwd_tflops = linear_tflops(s1, mlp_fc2_h, config.hidden_size) + + fwd_tflops = pre_attn_fwd_tflops + attn_fwd_tflops + post_attn_fwd_tflops + mlp_fc1_fwd_tflops + mlp_fc2_fwd_tflops + + ####### backward tflops ######## + pre_attn_bwd_tflops, attn_bwd_tflops, post_attn_bwd_tflops, mlp_bwd_tflops = \ + fwd_flops_to_bwd_flops(pre_attn_fwd_tflops, attn_fwd_tflops, post_attn_fwd_tflops, mlp_fc1_fwd_tflops + mlp_fc2_fwd_tflops) + + bwd_tflops = pre_attn_bwd_tflops + attn_bwd_tflops + post_attn_bwd_tflops + mlp_bwd_tflops + + ####### recompute tflops ######## + if config.recompute_granularity == "full": + bwd_tflops += fwd_tflops + else: + # TODO: add other recompute method here. 
+ pass + + tot_tflops = fwd_tflops + bwd_tflops + return tot_tflops, fwd_tflops, bwd_tflops + + +def compute_ratios(combination, PP): + # PP = mpu.get_pipeline_model_parallel_world_size() + VPP = 1 + ratios = [] + for num_m in range(1, len(combination)+1): + ratio = num_m * PP * (PP * VPP + PP - 1) / (num_m * PP * VPP + PP - 1) + ratios.append(ratio) + return ratios + + +class Bucket: + def __init__(self, target_flops, target_density, target_length, bucket_id, cp_size, samples, fwd_flops=0, bwd_flops=0, seq_len_sum=0, dp_index=-1): + self.bucket_id = bucket_id + self.samples = samples + self.cp_size = cp_size + self.target_flops = target_flops + self.target_density = target_density + self.target_length = target_length + self.current_density = 0 + self.fwd_flops = fwd_flops + self.bwd_flops = bwd_flops + self.seq_len_sum = seq_len_sum + self.dp_index = dp_index + self.type = "Bucket" + + def __str__(self): + return ( + f"Bucket {self.bucket_id}:\n" + f" Target Flop: {self.target_flops}\n" + f" Target Density: {self.target_density}\n" + f" Target Length: {self.target_length}\n" + f" Current Density: {self.current_density}\n" + f" Forward Flops: {self.fwd_flops}\n" + f" Backward Flops: {self.bwd_flops}\n" + f" Sequence Length Sum: {self.seq_len_sum}\n" + f" Samples: {self.samples}\n" + ) + + +def create_buckets(num_buckets, avg_fwd_flops_with_m, max_seq_len_for_fuse): + total_bucket_num = 0 + all_buckets = [] + + for i in range(len(num_buckets)): + for j in range(num_buckets[i]): + target_density = avg_fwd_flops_with_m[i] / max_seq_len_for_fuse + all_buckets.append( + Bucket( + target_flops = avg_fwd_flops_with_m[i], + target_density = target_density, + target_length = max_seq_len_for_fuse, + bucket_id = total_bucket_num, + cp_size = 1, + samples = [], + ) + ) + total_bucket_num += 1 + + return total_bucket_num, all_buckets + + +def assign_samples_to_buckets( + sorted_indices, + buckets, + all_density, + all_lengths, + all_flops, + update_rule=None, + 
def assign_samples_to_buckets(
    sorted_indices,
    buckets,
    all_density,
    all_lengths,
    all_flops,
    update_rule=None,
    remaining_sample_indices=None,
    print_score=False,
):
    """
    Greedily place each sample into the bucket where the selected update rule
    yields the lowest score; placed samples are removed from
    remaining_sample_indices, which is returned.

    Args:
        sorted_indices: sample indices, pre-sorted by the caller.
        buckets: candidate Bucket objects (mutated in place).
        all_density / all_lengths / all_flops: per-sample lookup tables
            (all_flops[1] holds forward FLOPs).
        update_rule: an UpdateRule member; DENSITY is not supported.
        remaining_sample_indices: mutable pool of unplaced indices.
        print_score: accepted for interface parity; not read here.

    Raises:
        NotImplementedError: if update_rule is UpdateRule.DENSITY.
    """
    preassigned_samples = []
    if update_rule is UpdateRule.DENSITY:
        # BUGFIX: was a bare `raise Exception()` followed by two unreachable
        # statements (the disabled density pre-assignment path).  Fail loudly
        # with a specific, documented exception instead.
        raise NotImplementedError("UpdateRule.DENSITY is not supported")
    assert len(preassigned_samples) == 0
    update_rule = update_rule_mapping[update_rule]

    for index in sorted_indices:
        if index in preassigned_samples:
            print(f"{index=}, {preassigned_samples=}")
            continue

        # Scan every bucket; rules return None to veto a bucket.
        min_score = float('inf')
        target_bucket = None
        score = None
        for bucket in buckets:
            score = update_rule(bucket, index, all_density, all_lengths, all_flops)
            if score is not None and score < min_score:
                min_score = score
                target_bucket = bucket

        if target_bucket is not None:
            target_bucket.fwd_flops += all_flops[1][index]
            target_bucket.bwd_flops += (2 * all_flops[1][index])  # TODO(wuguohao): use more precisely bwd_flops
            target_bucket.seq_len_sum += all_lengths[index]
            target_bucket.samples.append(index)
            remaining_sample_indices.remove(index)
        else:
            # Every bucket vetoed the sample; only log for the length rule.
            if update_rule == length_update_rule:
                if torch.distributed.get_rank() == 0: print(f"skip {index=}, {score=}, {min_score=} {len(buckets)=}, {len(remaining_sample_indices)=}, length_rule={update_rule == length_update_rule}, flops_rule={update_rule == fwd_flops_update_rule} ")

    return remaining_sample_indices


def nearest_pow2(n: int) -> int:
    """
    Round a positive integer to the nearest power of two.

    Returns 1 for n < 1; exact midpoints round up to the larger power.
    """
    if n < 1:
        return 1
    lower = 1 << (n.bit_length() - 1)  # 2^(floor(log2 n))
    upper = 1 << n.bit_length()        # 2^(ceil(log2 n)) == 2 * lower
    return lower if (n - lower) < (upper - n) else upper
def simulate_memory(chunks_list, config):
    """
    Estimate peak memory of a splitfuse pipeline schedule with the hotsim
    memory simulator.

    Args:
        chunks_list: per-chunk workload description consumed by the simulator
            (schema defined by hotsim — TODO confirm).
        config: training config providing model sizes, parallel sizes, and
            recompute granularity.

    Returns:
        The maximum value of the simulated peak-memory histogram.
    """
    # Imported lazily so the simulator is only a dependency when actually used.
    from megatron.pipeline_simulator.hotsim.model import Model
    from megatron.pipeline_simulator.hotsim.memory_model import MemoryModel
    from megatron.pipeline_simulator.hotsim.training_config import TrainingConfig
    from megatron.pipeline_simulator.hotsim.schedule import build_splitfuse_schedule
    model = Model(
        name="Llama",
        vocab_size=config.vocab_size,
        hidden_size=config.hidden_size,
        intermediate_size=config.ffn_hidden_size,
        num_hidden_layers=config.num_layers,
        num_attention_heads=config.num_attention_heads,
    )
    # Only "no" / "full" checkpointing are mapped; partial-recompute variants
    # are disabled in this revision.
    ckpt_type = "no"
    if config.recompute_granularity == "full":
        ckpt_type = "full"

    # Prefer the live world size; fall back to TP*PP*DP when torch.distributed
    # is not initialized (e.g. offline planning).
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        num_gpus = torch.distributed.get_world_size()
    else:
        num_gpus = parallel_state.get_tensor_model_parallel_world_size() \
            * parallel_state.get_pipeline_model_parallel_world_size() \
            * parallel_state.get_data_parallel_world_size()
    train_config = TrainingConfig(
        model=model,
        num_gpus=num_gpus,
        microbatch_size=1,
        tensor_parallel_size=parallel_state.get_tensor_model_parallel_world_size(),
        context_parallel_size=parallel_state.get_context_parallel_world_size(),
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
        pipeline_parallel_size=parallel_state.get_pipeline_model_parallel_world_size(),
        expert_parallel_size=parallel_state.get_expert_model_parallel_world_size(),
        num_model_chunks=1,
        ckpt=ckpt_type,
        offload_ratio=0,  # activation offload disabled in this revision
    )

    actions_by_rank = build_splitfuse_schedule(
        config.pipeline_model_parallel_size, chunks_list
    )

    memory_model = MemoryModel(train_config)
    memory_model.setup(chunks_list, actions_by_rank)
    memory_model.run()
    return max(memory_model.peak_memory_histogram)


def simulate_time(fwd_costs, bwd_costs, PP, VPP):
    """
    Simulate the execution time of a splitfuse pipeline schedule.

    Args:
        fwd_costs / bwd_costs: per-microbatch forward/backward cost estimates.
        PP: pipeline-parallel size.
        VPP: accepted for interface parity; the SplitFuseSchedule path below
            does not read it (an InterleavedSchedule alternative is disabled).

    Returns:
        Whatever test_with_schedule reports for the built schedule.
    """
    schedule = SplitFuseSchedule(PP, fwd_costs, bwd_costs)
    return test_with_schedule(schedule)
def fill_bucket_with_samples(
    curr_except_index,
    sorted_indices,
    target_flops,
    all_flops,
    all_lengths,
    max_seq_len,
    remaining_sample_indices=None,
    total_num=0,
    consumed_num_buckets=0,
    assign_all_sample_to_except_bucket_flag=False,
):
    """
    Pack normal samples alongside an oversized sample until the remaining
    FLOPs budget (or the length cap) is exhausted.

    Args:
        curr_except_index: index of the oversized sample anchoring this bucket.
        sorted_indices: candidate normal samples — assumed sorted by forward
            FLOPs in reverse order.
        target_flops: FLOPs budget to fill.
        all_flops: per-sample FLOPs tables (all_flops[1]=fwd, all_flops[2]=bwd).
        all_lengths: per-sample sequence lengths.
        max_seq_len: per-bucket length cap (scaled by consumed_num_buckets).
        remaining_sample_indices: mutable pool; selected samples are removed.
        total_num: global sample-budget threshold for the extra limit.
        consumed_num_buckets: number of buckets the oversized sample spans.
        assign_all_sample_to_except_bucket_flag: when True, bypass the
            extra limit so all samples may land in except buckets.

    Returns:
        (selected_indices, [selected_fwd_flops, selected_bwd_flops],
         selected_lengths, remaining_flops_budget, remaining_sample_indices)
    """
    selected_indices = []
    selected_fwd_flops = []
    selected_bwd_flops = []
    selected_lengths = []
    remained_flops = target_flops
    length_sum = all_lengths[curr_except_index]
    for index in sorted_indices:
        sample_fwd_flops = all_flops[1][index]
        sample_bwd_flops = all_flops[2][index]  # TODO(wuguohao): more precisely bwd_flops
        extra_limit = total_num < len(remaining_sample_indices) and length_sum < (max_seq_len * consumed_num_buckets)
        extra_limit = (assign_all_sample_to_except_bucket_flag) or extra_limit  # skip extra_limit if `assign_all_sample_to_except_bucket_flag` is True

        # A sample may overshoot the remaining budget by up to 5%.
        exceed_ratio = 1.05
        if sample_fwd_flops < remained_flops * exceed_ratio and extra_limit:  # TODO: consume num buckets * max seq len
            remained_flops -= sample_fwd_flops
            selected_indices.append(index)
            selected_fwd_flops.append(sample_fwd_flops)
            selected_bwd_flops.append(sample_bwd_flops)
            selected_lengths.append(all_lengths[index])
            remaining_sample_indices.remove(index)
            length_sum += all_lengths[index]

    selected_flops = [selected_fwd_flops, selected_bwd_flops]

    return selected_indices, selected_flops, selected_lengths, remained_flops, remaining_sample_indices


class PipelineAwareBalancedHybridCPscheduler(BaseScheduler):
    # Scheduler that balances hybrid DPxCP workloads while accounting for
    # pipeline-parallel bubbles.

    def __init__(self, config):
        super().__init__(config)
        # Per-rank sequence-length cap taken from the config.
        self.max_seq_len_per_rank = config.max_seqlen_per_dp_cp_rank
        self.num_subsamples = 0
        self.num_subsamples_processed = 0
        self.free_resources = []
        # Total GPUs in the combined DPxCP group.
        self.total_hdp_gpus = parallel_state.get_data_parallel_world_size(
            with_context_parallel=True
        )

    @lru_cache(maxsize=128)
    def get_total_workload(self, seq_length: int, cp_size: Optional[int] = None):
        """
        seq_length: sequence length of a sub-sample
        cp_size: total number of CP ranks working on this sub-sample

        Note:
            This function is used to estimate the relative workload intensity
            of a sub-sample. This is not meant to be an accurate flops calculator.

        Returns: workload of a sub-sample
        """
        # NOTE(review): lru_cache on an instance method keys on `self` and keeps
        # the instance alive for the cache's lifetime — confirm acceptable.
        if cp_size is None:
            cp_size = self.gpus_needed(seq_length)
        # Quadratic in length (attention-dominated), shared across cp_size ranks.
        return (seq_length * seq_length) / cp_size

    def get_groups_and_subsamples(self, sample_id_seqlens, config, return_cp_sizes=False):
        """
        This function recursively forms groups of sub-samples such that all DPxCP ranks
        have a roughly balanced workload in the group.

        Returns (groups, sample_id_groups) of per-microbatch, per-DP-rank lists;
        with return_cp_sizes=True, additionally returns the matching cp_sizes.
        """
        groups = []
        sample_id_groups = []
        cp_sizes = []
        # We assign a sample_id to each sub-sample in order to track assignment to each GPU.
        sample_id_seqlens = sorted(sample_id_seqlens, key=lambda x: x[1], reverse=True)

        # Single-shot bucket planning (the older iterative while-loop over
        # next_hdp_group is disabled in this revision).
        _, _, best_indices_buckets, best_sample_ids, best_dp_combination, _ = self.next_hdp_group(
            sample_id_seqlens, self.get_total_workload, self.total_hdp_gpus, config=config
        )

        # Pick the first m slot that actually received buckets.
        mi = -1
        for i in range(len(best_indices_buckets)):
            if len(best_indices_buckets[i]) > 0:
                mi = i
                break
        assert mi != -1
        best_sample_ids = best_sample_ids[mi]
        best_indices_buckets = best_indices_buckets[mi]

        assert len(best_indices_buckets) == len(best_sample_ids)

        def transpose_2d_list(matrix):
            # Rows become columns: [dp][mb] -> [mb][dp].
            return [list(row) for row in zip(*matrix)]

        local_sample_id_groups = transpose_2d_list(best_sample_ids)
        local_best_indices_buckets = transpose_2d_list(best_indices_buckets)
        min_hybrid_context_parallel_size = config.min_hybrid_context_parallel_size
        for microbatch_idx in range(len(local_sample_id_groups)):
            sample_id_groups.append([])
            groups.append([])
            cp_sizes.append([])
            # Expand each planning-level DP slot into min_hybrid_cp ranks.
            dpxcp = len(local_sample_id_groups[microbatch_idx]) * min_hybrid_context_parallel_size
            for dp_rank in range(dpxcp):
                sample_id_groups[microbatch_idx].append([])
                groups[microbatch_idx].append([])
                cp_sizes[microbatch_idx].append([])
                origin_dp_rank = dp_rank // min_hybrid_context_parallel_size
                for local_sample_idx in local_sample_id_groups[microbatch_idx][origin_dp_rank]:
                    # Translate local planning indices back to (sample_id, seq_len).
                    sample_id_groups[microbatch_idx][dp_rank].append(sample_id_seqlens[local_sample_idx][0])
                    groups[microbatch_idx][dp_rank].append(sample_id_seqlens[local_sample_idx][1])
                    final_cp_size = local_best_indices_buckets[microbatch_idx][origin_dp_rank].cp_size * min_hybrid_context_parallel_size
                    cp_sizes[microbatch_idx][dp_rank].append(final_cp_size)

        def flatten(lst):
            # Recursively flatten arbitrarily nested lists,
            # e.g. [1, [2, 3], [4, [5, 6]], 7] -> [1, 2, 3, 4, 5, 6, 7].
            result = []
            for item in lst:
                if isinstance(item, list):
                    result.extend(flatten(item))
                else:
                    result.append(item)
            return result

        if return_cp_sizes:
            return groups, sample_id_groups, cp_sizes

        return groups, sample_id_groups
    def split_sample(
        self,
        num_buckets: List[int],
        avg_fwd_flops_with_m: List[float],
        all_lengths,
        all_flops,
        except_indexes,
        normal_indexes,
        combination,
        DP, PP, UP, TP,
        max_split_size,
        max_seq_len,
        config,
    ):
        num_layers = config.num_layers          # number of model layers
        hidden_size = config.hidden_size        # hidden size
        num_heads = config.num_attention_heads  # number of attention heads
        assert hidden_size % num_heads == 0, "hidden_size should be divisible by num_heads"
        head_dim = hidden_size // num_heads     # per-head dimension
        ffn_size = config.ffn_hidden_size       # FFN hidden size

        # Bucket-assignment bookkeeping for special (oversized) sequences.
        except_buckets = [[] for _ in range(len(num_buckets))]  # special-sequence buckets per m value
        except_bucket_num = 0                   # special-bucket counter
        except_bucket_m_per_sample = []         # m value assigned to each sample
        except_bucket_dp_per_sample = []        # dp index assigned to each sample
        except_bucket_num_per_sample = []       # number of buckets each sample is split into

        # Buckets per single DP for each m
下单个 dp 的桶数(相同 m 的不同 dp 的桶数相等) + bucket_num_per_dp_per_m = [] + # import pdb;pdb.set_trace() + for i in range(len(num_buckets)): + if combination[i] > 0: + assert num_buckets[i] % combination[i] == 0, f"{i=}, {num_buckets[i]=}, {combination[i]=}" + bucket_num_per_dp_per_m.append(num_buckets[i] // combination[i]) + else: + bucket_num_per_dp_per_m.append(0) + # print_rank0(f"{bucket_num_per_dp_per_m=}") + + # 维护不同 dp 当前剩余桶数,使用该桶数去做大 UP + remain_buckets_num_per_dp_per_m = [] + for i in range(len(num_buckets)): + if combination[i] > 0: + assert num_buckets[i] % combination[i] == 0, f"{i=}, {num_buckets[i]=}, {combination[i]=}" + # import pdb;pdb.set_trace() + remain_buckets_num_per_dp_per_m.append([num_buckets[i] // combination[i]] * combination[i]) + else: + remain_buckets_num_per_dp_per_m.append([]) + + # 遍历所有需要独占一路 DP 的序列 + single_sample_indexes = [] # 去掉需要独占一个 DP 的样本后的 except_indexes + combination_used = [0] * len(combination) + + # 重新计算 桶的容积 + sum_fwd_flops = sum([all_flops[1][idx] for idx in except_indexes if idx not in single_sample_indexes]) + \ + sum([all_flops[1][idx] for idx in normal_indexes]) + + ratios = compute_ratios(combination, PP=PP) + avg_fwd_flops_with_m_new = [] + total_num = sum([(combination[j]-combination_used[j]) * ratios[j] for j in range(len(combination))]) # TODO: total num need to - exceed_buckets num + mean_fwd_flops_with_m = sum_fwd_flops / total_num + for i in range(1, len(combination)+1): + avg_fwd_flops_with_m_new.append(mean_fwd_flops_with_m * ratios[i - 1] / i / PP) + + avg_fwd_flops_with_m = avg_fwd_flops_with_m_new + + non_zero_combination = [(combination[idx]-combination_used[idx]) != 0 for idx in range(len(combination))] + first_non_zero_m = 1 + non_zero_combination.index(True) + threshold = 2 * sum_fwd_flops / (first_non_zero_m * (DP-len(single_sample_indexes)) * PP) + + consumed_num_buckets_backup = {} + consumed_num_buckets_raw_backup = {} + for idx, index in enumerate(except_indexes): + find_bucket = False + for i in 
range(len(num_buckets)): + # 只考虑有剩余桶的m值 + if combination[i] > 0: + # 计算当前序列需要的桶数量(向上取整) + consumed_num_buckets_raw = math.ceil(all_flops[1][index] / avg_fwd_flops_with_m[i]) + remain_num_split_sample = len(except_indexes) - 1 - idx + consumed_num_buckets = min(nearest_pow2(consumed_num_buckets_raw), max_split_size, DP//config.min_hybrid_context_parallel_size, num_buckets[i]-remain_num_split_sample) + consumed_num_buckets_raw_backup[index] = consumed_num_buckets_raw + consumed_num_buckets_backup[index] = consumed_num_buckets + # 更新剩余桶数量 + num_buckets[i] -= consumed_num_buckets + find_bucket = True + break + + + assign_all_sample_to_except_bucket_flag = False + assert sum(num_buckets) >= 0 + if sum(num_buckets) == 0: + assign_all_sample_to_except_bucket_flag = True + # if torch.distributed.get_rank() == 0: print(f"{num_buckets=}\n{consumed_num_buckets_raw_backup.keys()=}\n{consumed_num_buckets_raw_backup.values()=}\n{consumed_num_buckets_backup.keys()=}\n{consumed_num_buckets_backup.values()=}\n{except_indexes=}") + + for index in except_indexes: + find_bucket = False + for i in range(len(num_buckets)): + # 只考虑有剩余桶的m值 + if combination[i] > 0: + # 计算当前序列需要的桶数量(向上取整) + # consumed_num_buckets_raw = math.ceil(all_flops[1][index] / avg_fwd_flops_with_m[i]) + # consumed_num_buckets = min(min(nearest_pow2(consumed_num_buckets_raw), max_split_size), DP) + consumed_num_buckets = consumed_num_buckets_backup[index] + # print(f"{index=}, {i=}, {consumed_num_buckets_raw=}, {consumed_num_buckets=}") + remained_flops = consumed_num_buckets * avg_fwd_flops_with_m[i] - all_flops[1][index] + + # choose the CP interval + max_value = -1 + max_left = max_right = -1 + max_indexes = [-1] * consumed_num_buckets + for j in range(combination[i]): + left = (j // consumed_num_buckets) * consumed_num_buckets + right = (j // consumed_num_buckets + 1) * consumed_num_buckets + min_value_this_interval = 10000000 + # for dp size not divisible by consumed_num_buckets, continue to skip this search 
space + if right > len(remain_buckets_num_per_dp_per_m[i]): + continue + + for k in range(left, right): #left close right close + min_value_this_interval = min(min_value_this_interval, remain_buckets_num_per_dp_per_m[i][k]) + if max_value < min_value_this_interval: + max_value = min_value_this_interval + max_left = left + max_right = right + max_indexes = list(range(left, right)) + + normal_indexes_copy = copy.deepcopy(normal_indexes) + selected_indices, selected_flops, selected_lengths, remained_flops, remaining_sample_indices = fill_bucket_with_samples(index, normal_indexes, remained_flops, all_flops, all_lengths, max_seq_len, normal_indexes_copy, total_num, consumed_num_buckets, assign_all_sample_to_except_bucket_flag) + # print(f"\n{len(selected_indices)=}, {selected_indices=}\n{len(remaining_sample_indices)=}, {remaining_sample_indices=}\n{sum(selected_lengths)=}, {sum(selected_flops[0])=}, {remained_flops=}, {all_lengths[index]=}, {max_seq_len*consumed_num_buckets=}") + normal_indexes = remaining_sample_indices + for j in range(consumed_num_buckets): + remain_buckets_num_per_dp_per_m[i][max_indexes[j]] -= 1 + + assert len(max_indexes) == consumed_num_buckets, f"{len(max_indexes)=}, {consumed_num_buckets=}" + # 将分割后的序列片段分配到各个桶中 + for j in range(consumed_num_buckets): + bucket_fwd_flops = all_flops[1][index] + sum(selected_flops[0]) + bucket_bwd_flops = (3 * all_flops[1][index]) + sum(selected_flops[1]) # TODO(wuguohao): more precisely bwd_flops + bucket_length = all_lengths[index] + sum(selected_lengths) + bucket_tmp = [index] + selected_indices + #shenglong target_flops=1 to handle except use all buckets + except_buckets[i].append( + Bucket( + bucket_id=except_bucket_num, + samples=bucket_tmp, + cp_size=consumed_num_buckets, + fwd_flops=bucket_fwd_flops/consumed_num_buckets, + bwd_flops=bucket_bwd_flops/consumed_num_buckets, + seq_len_sum=bucket_length, + target_flops=1, target_density=0, target_length=0, + dp_index=max_indexes[j], + ) + ) + except_bucket_num 
+= 1 # 递增桶计数器 + + # 更新剩余桶数量 + # num_buckets[i] -= consumed_num_buckets + # 记录分配信息 + except_bucket_num_per_sample.append(consumed_num_buckets) + except_bucket_m_per_sample.append(i) + except_bucket_dp_per_sample.append(max_indexes) + + find_bucket = True + break # 成功分配到桶中,跳出循环 + + if not find_bucket: + raise NotImplementedError("not found a bucket for the sample") + + assert len(except_bucket_m_per_sample) == len(except_bucket_num_per_sample), f"{len(except_bucket_m_per_sample)=}, {len(except_bucket_num_per_sample)=}" + return except_buckets, num_buckets, except_bucket_num_per_sample, except_bucket_m_per_sample, except_bucket_dp_per_sample, except_indexes, normal_indexes, avg_fwd_flops_with_m + + def next_hdp_group( + self, + sample_seqlens: List[Tuple[int, int]], # List of (sample_id, sequence_length) tuples + compute_estimator: Callable[[int], float], + total_gpus: int, + delta: float = 0.05, # balance slack (e.g. 5 %) + strategy: str = "dp", # "dp" or "pp" + eps_bucket: float = 0.10, # ε target for bucket balance + config = None, + ) -> Tuple[List[List[int]], List[Tuple[int, int]], List[float], List[List[int]]]: + + DP = parallel_state.get_data_parallel_world_size() + PP = parallel_state.get_pipeline_model_parallel_world_size() + UP = parallel_state.get_context_parallel_world_size() + TP = parallel_state.get_tensor_model_parallel_world_size() + + VPP = 1 + if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: + VPP = parallel_state.get_virtual_pipeline_model_parallel_world_size() + + # if torch.distributed.get_rank() == 0: + # breakpoint() + # torch.distributed.barrier() + + max_split_size = config.max_hybrid_context_parallel_size // config.min_hybrid_context_parallel_size + max_seq_len = config.max_seqlen_per_dp_cp_rank + + all_lengths = [sample_seqlens[i][1] for i in range(len(sample_seqlens))] + + all_flops = [] + all_tot_flops = [] + all_fwd_flops = [] + all_bwd_flops = [] + for idx in range(len(all_lengths)): + length = 
all_lengths[idx] + flops = TFLOPs(length, config) + all_tot_flops.append(flops[0]) + all_fwd_flops.append(flops[1]) + all_bwd_flops.append(flops[2]) + all_flops.append(all_tot_flops) + all_flops.append(all_fwd_flops) + all_flops.append(all_bwd_flops) + + all_density = [all_flops[1][i] / all_lengths[i] for i in range(len(all_lengths))] + best_max_seq_per_m = 0 + sum_fwd_flops = sum(all_flops[1]) + + assert len(all_lengths) == len(all_flops[1]) + + def dynamic_loops_product(limits): + min_max_flops_sum_per_iter = sys.float_info.max / 10.0 + best_indices_buckets = [] + best_sample_ids = [] + best_dp_combination = [] + + assert DP % config.min_hybrid_context_parallel_size == 0 + limit_item = DP // config.min_hybrid_context_parallel_size + + for idx, limit in enumerate(limits): + combination = [0] * len(limits) + combination[idx] = limit + if sum(combination) != limit_item: + continue + + num_buckets = [PP * i * combination[i - 1] for i in range(1, len(combination)+1)] + num_buckets_sum = sum(num_buckets) + + if num_buckets_sum > len(all_lengths): + print(f"continue due to num_buckets_sum {num_buckets_sum=} > len(all_lengths) {len(all_lengths)=}") + continue + + ratios = compute_ratios(combination, PP=PP) + avg_fwd_flops_with_m = [] + total_num = sum(combination[j] * ratios[j] for j in range(len(combination))) + mean_fwd_flops_with_m = sum_fwd_flops / total_num + + for i in range(1, len(combination)+1): + avg_fwd_flops_with_m.append(mean_fwd_flops_with_m * ratios[i - 1] / i / PP) + + import time + st_time = time.time() + indices_buckets, sample_ids, max_flops_sum_per_iter, max_seq_per_m, used_flops = \ + self.solver(all_lengths, all_flops, all_density, num_buckets, avg_fwd_flops_with_m, combination, DP, PP, UP, TP, VPP, max_seq_len, max_split_size, config) + ed_time = time.time() + if torch.distributed.get_rank() == 0: print(f"solver cost time :{ed_time-st_time}s") + + if max_flops_sum_per_iter < min_max_flops_sum_per_iter: + min_max_flops_sum_per_iter = 
max_flops_sum_per_iter + best_indices_buckets = indices_buckets + best_sample_ids = sample_ids + best_dp_combination = combination + # if torch.distributed.get_rank() == 0: + # print(f"{best_dp_combination=}\n{best_indices_buckets=}\n{best_sample_ids=}") + + return min_max_flops_sum_per_iter, best_indices_buckets, best_sample_ids, best_dp_combination + + search_space = config.search_space + assert DP % config.min_hybrid_context_parallel_size == 0 + limit_item = DP // config.min_hybrid_context_parallel_size + # limits = [limit_item] * search_space + if isinstance(search_space, int): + limits = [limit_item] * search_space + elif isinstance(search_space, list): + limits = [0] * max(search_space) + for idx in search_space: + limits[idx] = limit_item + else: + raise Exception(f"`search_space` should be int or list, but {type(search_space)} found.") + + min_max_flops_sum_per_iter, best_indices_buckets, best_sample_ids, best_dp_combination = dynamic_loops_product(limits) + if torch.distributed.get_rank() == 0: print(f"{best_dp_combination=}") + # assert all DP have the same num_microbatch + sum_best_dp_combination = sum(best_dp_combination) + best_m = -1 + for idx, num_dp in enumerate(best_dp_combination): + if num_dp == sum_best_dp_combination: + best_m = idx + break + assert best_m != -1, f"{best_dp_combination=}" + + if not best_dp_combination: + raise Exception() + + best_var_m = 0 + return min_max_flops_sum_per_iter, best_max_seq_per_m, best_indices_buckets, best_sample_ids, best_dp_combination, best_var_m + + def solver( + self, + all_lengths: List[int], + all_flops: List[List[float]], + all_density: List[float], + num_buckets: List[int], + avg_fwd_flops_with_m: List[float], + combination, + DP, PP, UP, TP, VPP, + max_seq_len, + max_split_size, + config, + ): + except_indexes = [] + normal_indexes = [] + + non_zero_combination = [combination[idx] != 0 for idx in range(len(combination))] + first_non_zero_m = 1 + non_zero_combination.index(True) + + sum_fwd_flops = 
sum(all_flops[1]) + threshold = 1.3 * sum_fwd_flops / (first_non_zero_m * DP * PP) + for idx in range(len(all_flops[1])): + if all_flops[1][idx] > threshold: + except_indexes.append(idx) + else: + normal_indexes.append(idx) + # if torch.distributed.get_rank() == 0: + # print(f"\n{except_indexes=}") + # except_flops = [] + # for idx in except_indexes: + # except_flops.append(all_flops[1][idx]) + # print(f"{except_indexes=}\n{except_flops=}") + except_indexes = sorted(except_indexes, key=lambda x: all_flops[1][x], reverse=True) + normal_indexes = sorted(normal_indexes, key=lambda x: all_flops[1][x], reverse=True) + + except_buckets, num_buckets, except_bucket_num_per_sample, except_bucket_m_per_sample, except_bucket_dp_per_sample, except_indexes, normal_indexes, avg_fwd_flops_with_m = \ + self.split_sample(num_buckets, avg_fwd_flops_with_m, all_lengths, all_flops, except_indexes, normal_indexes, combination, DP, PP, UP, TP, max_split_size, max_seq_len, config) + + sum_remained_flops = sum([all_flops[1][index] for index in normal_indexes]) + + # for the case that except indexes take all buckets + if sum(num_buckets) != 0: + max_seq_len_for_fuse = sum([all_lengths[idx] for idx in normal_indexes]) / sum(num_buckets) + else: + max_seq_len_for_fuse = 0 + + + if max_seq_len_for_fuse == 0: + assert len(normal_indexes) == 0 + total_bucket_num, all_buckets = create_buckets(num_buckets, avg_fwd_flops_with_m, max_seq_len_for_fuse) + sorted_indices_fwdflops = sorted(normal_indexes, key=lambda x: all_flops[1][x], reverse=True) + sorted_all_buckets_fwd_flops = sorted(all_buckets, key=lambda bucket: bucket.fwd_flops) + all_sample_index_copy = copy.deepcopy(sorted_indices_fwdflops) + + all_sample_index_copy_bef_flops = copy.deepcopy(all_sample_index_copy) + all_sample_index_copy = assign_samples_to_buckets(sorted_indices_fwdflops, + sorted_all_buckets_fwd_flops, + all_density, + all_lengths, + all_flops, + update_rule=UpdateRule.FW_FLOPS, + 
remaining_sample_indices=all_sample_index_copy) + + # If there are some leftover of the samples + # (e.g. if put the sample in any of the bucket will cause the bucket exceed the memory limit), + # we will use the length update rule to assign those samples to the bucket. + # The all_sample_index_copy should contain only a few samples. Sorting might be unnecessary. + sorted_indices_length = sorted(all_sample_index_copy, key=lambda x: all_lengths[x], reverse=True) + sorted_all_buckets_length = sorted(all_buckets, key=lambda bucket: bucket.seq_len_sum) + + if len(all_sample_index_copy) > 0: + all_sample_index_copy_bef_len = copy.deepcopy(all_sample_index_copy) + all_sample_index_copy = assign_samples_to_buckets(sorted_indices_length, sorted_all_buckets_length, all_density, all_lengths, all_flops, update_rule=UpdateRule.LENGTH, remaining_sample_indices=all_sample_index_copy, print_score=True) + + assert len(all_sample_index_copy) == 0, f"sample {all_sample_index_copy} is not assigned to any bucket." 
+ + indices_buckets = [[] for _ in range(total_bucket_num)] + used_flops = [0.0] * total_bucket_num + used_fwd_flops = [0.0] * total_bucket_num + used_bwd_flops = [0.0] * total_bucket_num + max_seq_per_m = 0 + seq_per_m = [] + + for bucket in sorted_all_buckets_fwd_flops: + bucket_id = bucket.bucket_id + indices_buckets[bucket_id] = bucket + used_flops[bucket_id] = bucket.fwd_flops + bucket.bwd_flops + used_fwd_flops[bucket_id] = bucket.fwd_flops + used_bwd_flops[bucket_id] = bucket.bwd_flops + max_seq_per_m = max(bucket.seq_len_sum, max_seq_per_m) + seq_per_m.append(bucket.seq_len_sum) + + indices_buckets_2d = [[] for _ in range(len(num_buckets))] + sample_ids_2d = [[] for _ in range(len(num_buckets))] + new_cnt = 0 + max_sum_per_iter = 0.0 + rets = [0.0] * DP + thread_cnt = 0 + + max_iter_sum_among_dp_list = [] + for i in range(len(num_buckets)): + if len(except_buckets[i]) + num_buckets[i] == 0: + assert combination[i] == 0, f"{combination=}, {num_buckets=}, {len(except_buckets[i])=}" + continue + + total_buckets_for_current_m = num_buckets[i] + len(except_buckets[i]) + num_m = i + 1 + bucket_num_per_dp_curr_m = num_m * PP + assert total_buckets_for_current_m % bucket_num_per_dp_curr_m == 0, f"{total_buckets_for_current_m=}, {bucket_num_per_dp_curr_m=}" + dp_size_for_current_m = total_buckets_for_current_m // bucket_num_per_dp_curr_m + + buckets_for_current_m = [] + for j in range(num_buckets[i]): + buckets_for_current_m.append([used_flops[new_cnt], new_cnt, used_fwd_flops[new_cnt]]) + new_cnt += 1 + + buckets_for_current_m.sort(key=lambda x: x[2]) + + fwd_flops_for_dp_per_m, bwd_flops_for_dp_per_m, buckets_for_dp, sample_ids_for_dp, seq_len_for_dp_per_m, empty_bucket_flag = greedy_assign_bucket_to_dp(i, indices_buckets, normal_indexes, except_buckets, except_bucket_num_per_sample, except_bucket_m_per_sample, except_bucket_dp_per_sample, buckets_for_current_m, dp_size_for_current_m, used_flops, used_fwd_flops, used_bwd_flops, bucket_num_per_dp_curr_m, all_flops, 
all_lengths, combination, config) + + for j in range(len(buckets_for_dp)): + indices_buckets_2d[i].append(buckets_for_dp[j]) + sample_ids_2d[i].append(sample_ids_for_dp[j]) + + assert len(indices_buckets_2d) == len(sample_ids_2d), f"{len(indices_buckets_2d)=}, {len(sample_ids_2d)=}" + + bubble_time_list = [] + if empty_bucket_flag: + print(f"error, found empty bucket, skip") + max_sum_per_iter = sys.float_info.max / 10.0 + else: + for m in range(len(fwd_flops_for_dp_per_m)): + total_bucket_num_for_current_dp = len(fwd_flops_for_dp_per_m[m]) + forward_cost = [fwd_flops_for_dp_per_m[m][k][0] for k in range(len(fwd_flops_for_dp_per_m[m]))] + backward_cost = [bwd_flops_for_dp_per_m[m][k][0] for k in range(len(fwd_flops_for_dp_per_m[m]))] + seq_len_for_dp = seq_len_for_dp_per_m[m] + communication_cost = [0.0] * len(fwd_flops_for_dp_per_m[m]) + + forward_cost_cmp = [] + backward_cost_cmp = [] + assert len(fwd_flops_for_dp_per_m[m]) == len(bwd_flops_for_dp_per_m[m]) + for k in range(len(fwd_flops_for_dp_per_m[m])): + split_num = fwd_flops_for_dp_per_m[m][k][2] + split_idx = fwd_flops_for_dp_per_m[m][k][3] + fwd_cost = fwd_flops_for_dp_per_m[m][k][0] + bwd_cost = bwd_flops_for_dp_per_m[m][k][0] + + forward_cost_cmp.append([fwd_cost]) + backward_cost_cmp.append([bwd_cost]) + + max_iter_sum_among_dp = simulate_time(forward_cost_cmp, backward_cost_cmp, PP, VPP) + + max_iter_sum_among_dp_list.append(max_iter_sum_among_dp) + max_sum_per_iter = max(max_sum_per_iter, max_iter_sum_among_dp) + + if config.run_memory_simulator: + peak_memory = simulate_memory(seq_len_for_dp, config) + + forward_cost_cmp = torch.tensor(forward_cost_cmp).flatten().tolist() + backward_cost_cmp = torch.tensor(backward_cost_cmp).flatten().tolist() + + fwd_cost_total = sum(forward_cost_cmp) + bwd_cost_total = sum(backward_cost_cmp) + + fwd_bwd_cost_total = fwd_cost_total + bwd_cost_total + num_microbatch = (i+1) * PP + pp_bubble_ratio = compute_pp_bubble_ratio(PP, num_microbatch, VPP) + + pp_bubble_time = 
fwd_bwd_cost_total / (1 - pp_bubble_ratio) - fwd_bwd_cost_total + bubble_idle_time = max_iter_sum_among_dp - fwd_bwd_cost_total + imbalanced_bubble_time = bubble_idle_time - pp_bubble_time + + bubble_over_iter_time = bubble_idle_time / max_iter_sum_among_dp + bubble_over_compute_time = bubble_idle_time / fwd_bwd_cost_total + + pp_bubble_over_iter_time = pp_bubble_time / max_iter_sum_among_dp + pp_bubble_over_compute_time = pp_bubble_time / fwd_bwd_cost_total + + imbalanced_bubble_over_iter_time = imbalanced_bubble_time / max_iter_sum_among_dp + imbalanced_bubble_over_compute_time = imbalanced_bubble_time / fwd_bwd_cost_total + + bubble_time_list.append({ + "pp_bubble_ratio": pp_bubble_ratio, + "bubble_over_compute_time":bubble_over_compute_time, + "pp_bubble_over_compute_time":pp_bubble_over_compute_time, + "imbalanced_bubble_over_compute_time":imbalanced_bubble_over_compute_time, + }) + + if config.run_memory_simulator and peak_memory >= 70 * 1024**3: + max_sum_per_iter = sys.float_info.max / 10.0 # skip this m + print(f"rank={torch.distributed.get_rank()}, Peak memory usage: {peak_memory / 1024**3:.2f} GiB, {combination=}") + + if torch.distributed.get_rank() == 0: + print(f"{combination=}") + for k in range(len(bubble_time_list)): + for key in bubble_time_list[k].keys(): + bubble_time_list[k][key] = round(bubble_time_list[k][key], 3) + print(f"{k=}, {bubble_time_list[k]}") + + max_max_iter_sum = max(max_iter_sum_among_dp_list) + min_max_iter_sum = min(max_iter_sum_among_dp_list) + sum_max_iter_sum = sum(max_iter_sum_among_dp_list) + len_max_iter_sum = len(max_iter_sum_among_dp_list) + mean_max_iter_sum = sum_max_iter_sum/len_max_iter_sum + + # print(f"{sample_ids_2d=}") + + return indices_buckets_2d, sample_ids_2d, max_sum_per_iter, max_seq_per_m, used_flops + +class OnlyPackingNoSchedulingScheduler(BaseScheduler): + """ + This scheduler only packs sequences in their original order + and does not perform any load balancing. 
+ """ + + def __init__(self, config): + super().__init__(config) + self.dp_size = int(parallel_state.get_data_parallel_world_size()) + self.cp_size = int(parallel_state.get_context_parallel_world_size()) + self.max_seq_len_all_ranks = config.max_seqlen_per_dp_cp_rank * self.cp_size diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py deleted file mode 100644 index 27b5fc87945..00000000000 --- a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py +++ /dev/null @@ -1,660 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -from collections import deque -from functools import lru_cache -from math import ceil, log2 -from typing import Callable, List, Optional, Tuple - -import torch - -from megatron.core import parallel_state -from megatron.core.rerun_state_machine import RerunDataIterator - - -class BalancedCPScheduler: - """ - This class provides the functionality to form groups of sub-samples - such that all DPxCP ranks have a roughly balanced workload in the group. - """ - - def __init__(self, max_seq_len_per_rank: int, dp_cp_group: torch.distributed.ProcessGroup): - self.max_seq_len_per_rank = max_seq_len_per_rank - self.num_subsamples = 0 - self.num_subsamples_processed = 0 - self.free_resources = [] - self.total_hdp_gpus = dp_cp_group.size() - - @lru_cache(maxsize=128) - def get_total_workload(self, seq_length: int, cp_size: Optional[int] = None): - """ - seq_length: sequence length of a sub-sample - cp_size: total number of CP ranks working on this sub-sample - - Note: - This function is used to estimate the relative workload intensity - of a sub-sample. This is not meant to be an accurate flops calculator. 
- - Returns: workload of a sub-sample - """ - if cp_size is None: - cp_size = self.gpus_needed(seq_length) - return (seq_length * seq_length) / cp_size - - @lru_cache(maxsize=128) - def gpus_needed(self, seq_len: int) -> int: - """ - Calculates the number of GPUs needed for a given sequence length - and max sequence length per CP rank. - This is used to determine the CP size of a sub-sample. - - The number is rounded up to the next power of 2 to match the available - hybrid context parallel process group sizes. - """ - return max(1, 2 ** ceil(log2((seq_len / self.max_seq_len_per_rank)))) - - def make_buckets_equal( - self, - sample_seqlens: List[Tuple[int, int]], # List of (sample_id, sequence_length) tuples - compute_estimator: Callable[[int], float], - ) -> List[deque]: - """ - Makes as many buckets as unique CP sizes needed. - This keeps sample IDs tethered to their sequence lengths throughout the bucketing process. - """ - # Extract just the sequence lengths for determining k - seqlens = [seq_len for _, seq_len in sample_seqlens] - - # Determine k based on unique GPU categories needed - k = len({self.gpus_needed(L) for L in seqlens}) - - # Create a work target for each bucket - # This is the total work divided by the number of buckets - work = [] - for _, s in sample_seqlens: - cp_size = self.gpus_needed(s) - work.append(compute_estimator(s, cp_size)) - total_work = sum(work) - target = total_work / k - buckets, cur, cur_work = [], [], 0.0 - remaining_work = total_work - remaining_k = k - - for i, (sample_id, seq_len) in enumerate(sample_seqlens): - work = compute_estimator(seq_len) - projected = cur_work + work - - # Check if we should close this bucket - if cur and ( - projected > target * 1.1 # Too much work - or len(sample_seqlens) - i <= remaining_k - len(buckets) - ): # Need to save sequences for remaining buckets - buckets.append(deque(cur)) - cur, cur_work = [], 0.0 - remaining_work -= sum(compute_estimator(seq_len) for _, seq_len in cur) - remaining_k 
-= 1 - - cur.append((sample_id, seq_len)) - cur_work += work - - if cur: - buckets.append(deque(cur)) - - return buckets - - def next_hdp_group( - self, - sample_seqlens: List[Tuple[int, int]], # List of (sample_id, sequence_length) tuples - compute_estimator: Callable[[int], float], - total_gpus: int, - delta: float = 0.05, # balance slack (e.g. 5 %) - strategy: str = "dp", # "dp" or "pp" - eps_bucket: float = 0.10, # ε target for bucket balance - ) -> Tuple[List[List[int]], List[Tuple[int, int]], List[float], List[List[int]]]: - """ - Given a list of (sample_id, sequence_length) tuples, this function aims to assign - sequences in a group such that all GPUs in the DPxCP group have a roughly balanced - workload. Once each group is roughly balanced, we exit and return the - group and the leftover sequences. - - The function performs the following passes in order to form a balanced microbatch: - 1. We create buckets of sequences that are roughly balanced. - We try to create as many buckets as possible CP sizes. - 2. Given a bucket has sequences available, we assign the sample - a. To a new set of GPUs if there are enough free GPUs. - b. To an existing set of GPUs with the lowest load. - 3. We check if the group is balanced whenever we need to move onto a new CP size - in the same set of GPUs. - 4. We trim the group if removing the last added sequence helps improve balance. - 5. If we run out of sequences to assign and there are empty GPUs, - we redistribute work to empty GPUs by recursively increasing the CP size of a - sample until no empty GPUs are left. - - Returns (micro_batches, leftover_sample_seqlens, exec_times, sample_ids_per_gpu). 
- """ - if not sample_seqlens: - return ( - [[] for _ in range(total_gpus)], - [], - [0.0 for _ in range(total_gpus)], - [[] for _ in range(total_gpus)], - ) - - # Get buckets of sequences with balanced work - buckets = self.make_buckets_equal(sample_seqlens, compute_estimator) - - # Initialize tracking structures - micro_batches = [[] for _ in range(total_gpus)] - exec_times = [0.0 for _ in range(total_gpus)] - sample_ids_per_gpu = [[] for _ in range(total_gpus)] - - gpu_group_id = [None] * total_gpus - group_members = {} - group_size = {} - next_gid = 0 - - pp_cursor = 0 - prev_needed = None - check_balance = False - - while buckets: - # ---- Step 1 – pick the next sequence we COULD place ------------------ - sample_seq_tuple = bucket_idx = None - needed = None - - scan_order = ( - range(len(buckets)) - if strategy == "dp" - else [(pp_cursor + i) % len(buckets) for i in range(len(buckets))] - ) - - for idx in scan_order: - if not buckets[idx]: - continue - cand_tuple = buckets[idx][0] # This is now (sample_id, seq_len) - cand_seq_len = cand_tuple[1] - needed = self.gpus_needed(cand_seq_len) - - # (a) Do we have an *existing* group of size `needed`? - candidate_gids = [gid for gid, sz in group_size.items() if sz == needed] - - # (b) Or enough completely free GPUs to start a new group? - free_ranks = [r for r, gid in enumerate(gpu_group_id) if gid is None] - if candidate_gids or len(free_ranks) >= needed: - sample_seq_tuple, bucket_idx = cand_tuple, idx - break - - # No place to put any remaining sequence – finish this micro‑batch - if sample_seq_tuple is None: - break - - # TODO[pmannan]: PP not yet supported. Add PP scheduling. 
- if strategy == "pp": - pp_cursor = (bucket_idx + 1) % len(buckets) - - sample_id, seq_len = sample_seq_tuple - needed = self.gpus_needed(seq_len) - if prev_needed is None: - prev_needed = needed - - # (a) Existing groups of exactly this size - candidate_gids = [gid for gid, sz in group_size.items() if sz == needed] - if candidate_gids: - best_gid, best_load = min( - ( - (gid, max(exec_times[r] for r in group_members[gid])) - for gid in candidate_gids - ), - key=lambda t: t[1], - ) - else: - best_gid, best_load = None, float("inf") - - # (b) Hypothetical **new** group from completely free GPUs - free_ranks = [r for r, gid in enumerate(gpu_group_id) if gid is None] - if len(free_ranks) >= needed: - free_sorted = sorted(free_ranks, key=lambda r: exec_times[r]) - new_members = free_sorted[:needed] - new_load = exec_times[new_members[-1]] - - if new_load < best_load: - best_gid = None - chosen_members = new_members - else: - chosen_members = group_members[best_gid] - else: - chosen_members = group_members[best_gid] - - # ---- Step 2 – if we decided to create a fresh group ---------------- - if best_gid is None: - best_gid = next_gid - next_gid += 1 - group_members[best_gid] = chosen_members - group_size[best_gid] = needed - for r in chosen_members: - gpu_group_id[r] = best_gid - - # ---- Step 3 – assign the sequence to every member of that group ------ - per_gpu_cost = compute_estimator(seq_len) - - for r in chosen_members: - micro_batches[r].append(seq_len) - exec_times[r] += per_gpu_cost - sample_ids_per_gpu[r].append(sample_id) - - # Remove the sequence definitively from its bucket - buckets[bucket_idx].popleft() - - # ---- Step 4 – tidy, balance‑check, maybe early‑exit ------------------ - while buckets and not buckets[0]: - buckets.pop(0) - pp_cursor %= max(1, len(buckets)) - - # TODO: Removing this helps reduce the number of groups when we have - # lots of samples with same CP size. 
- # But because we don't exit as soon as we get balanced, - # even if there is one group available that can take the next sample, - # we will keep adding samples to the same group. - # trim_overload() does not help because it only checks if removing the - # last added sample helps. - # We cannot check after adding every sample because there will always be imbalance - # if we don't wait for future scheduling. - - # IMPORTANT: So we need a solution here - if needed < prev_needed: - # When we get into a lower CP size in the same group, - # we can start checking for balance. There is still a gotcha here. - # Let's say we have a group of 3 GPU 0-2, then we move onto group of 2. - # We keep assigning group of 2 as we do in descending order but GPU 7/15 - # never sees a microbatch assigned to it - # until we run out of samples with CP2. - # This means we are never balanced as min(exec_times) will always be 0. - # We need a smart way of identifying that we have run out of big samples - # and if we are having to assign work to a GPU already working, - # is it because there are empty GPUs? - # Would assigning work to empty GPUs first by moving onto next CP bucket help? - # But we need to remember to come back to this CP size bucket and then - # check for balance. Maybe the scheduling algorithm should look at empty - # GPUs and find work rather than going sequence by sequence. - check_balance = True - - if ( - check_balance - and buckets - and max(exec_times) - min(exec_times) <= delta * max(exec_times) - ): - break - - # Gather leftovers (flatten remaining buckets, preserve order) - leftovers = [] - for b in buckets: - for sample_seq_tuple in b: - leftovers.append(sample_seq_tuple) - - # --------------------------------------------------------------------------- - def trim_overload(): - """ - Iteratively pop the most‑recent sequence from the *most‑loaded group* - whenever doing so reduces the global slack. 
- """ - while True: - cur_max = max(exec_times) - cur_min = min(exec_times) - cur_slack = cur_max - cur_min - if cur_slack <= delta * cur_max: - # Slack is already within limit. - break - if cur_min == 0: - # There are empty GPUs that will be - # handled in the next step. - break - - max_r = exec_times.index(cur_max) - gid = gpu_group_id[max_r] - members = group_members[gid] - - if not micro_batches[max_r] or len(micro_batches[max_r]) <= 1: - break - - seq = micro_batches[max_r][-1] - need = group_size[gid] - per_gpu_cost = compute_estimator(seq) - - proj_times = exec_times[:] - for r in members: - proj_times[r] -= per_gpu_cost - - proj_slack = max(proj_times) - min(proj_times) - - # Check if trimming the workload helps imbalance - if proj_slack < cur_slack: - sample_id_to_remove = sample_ids_per_gpu[max_r][-1] - for r in members: - micro_batches[r].pop() - exec_times[r] -= per_gpu_cost - sample_ids_per_gpu[r].pop() - leftovers.append((sample_id_to_remove, seq)) - else: - break - - trim_overload() - - # Track samples in this group before redistribution to empty GPUs - total_work_before = sum(len(mb) for mb in micro_batches) - - # Check for empty GPUs and redistribute work - def fill_empty_gpus( - micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size - ): - """ - Recursively check for empty GPUs and redistribute work by increasing - the number of GPUs sharing samples. This ensures all GPUs have work. - GPUs must be allocated consecutively so we may need to push existing - work to other ranks in order to expand samples. 
- """ - # Find empty GPUs - empty_gpus = [i for i in range(total_gpus) if not micro_batches[i]] - if not empty_gpus: - return ( - micro_batches, - exec_times, - sample_ids_per_gpu, - group_members, - group_size, - ) # No empty GPUs, we're done - - # Find the smallest group size that exists - existing_group_sizes = set(group_size.values()) - assert ( - existing_group_sizes - ), "There should be at least one group existing, cannot reditribute, " - "try to increase 'max-seqlen-per-cp-rank'." - - min_group_size = min(existing_group_sizes) - # We have Hybrid DPxCP groups for every power of 2 of GPUs or the entire DPxCP group. - next_power = min(min_group_size * 2, total_gpus) - - # Find the first group of min_group_size that can be expanded - expandable_gid = None - expandable_members = None - expandable_new_gpus = None - - for gid, size in group_size.items(): - if size == min_group_size: - members = group_members[gid] - needed_count = next_power - min_group_size - group_start_gpu = members[0] - group_end_gpu = members[-1] - empty_gpu = [idx for idx, work in enumerate(micro_batches) if not work][0] - assert not all( - work for work in micro_batches[empty_gpu : empty_gpu + needed_count] - ), f"Empty GPUs were detected but not enough to expand." 
- work_to_push = micro_batches[ - group_end_gpu + 1 : empty_gpu - ] # This is work of all other subsequent sub-samples - exec_times_to_push = exec_times[group_end_gpu + 1 : empty_gpu] - sample_ids_to_push = sample_ids_per_gpu[group_end_gpu + 1 : empty_gpu] - - new_micro_batches = [[]] * len(micro_batches) - new_exec_times = [0.0] * len(exec_times) - new_sample_ids_per_gpu = [[]] * len(sample_ids_per_gpu) - - # No change in work until the group selected for expansion - for i in range(group_start_gpu): - new_micro_batches[i] = micro_batches[i] - new_exec_times[i] = exec_times[i] - new_sample_ids_per_gpu[i] = sample_ids_per_gpu[i] - - # The work is distributed across the expanded group - for i in range(group_start_gpu, group_end_gpu + needed_count + 1): - new_micro_batches[i] = micro_batches[group_end_gpu] - new_exec_times[i] = self.get_total_workload( - micro_batches[group_end_gpu][0], next_power - ) - new_sample_ids_per_gpu[i] = sample_ids_per_gpu[group_end_gpu] - - # Any assigned work on expanded GPUs is pushed - for i, work in enumerate(work_to_push): - new_micro_batches[group_end_gpu + needed_count + 1 + i] = work - new_exec_times[group_end_gpu + needed_count + 1 + i] = exec_times_to_push[i] - new_sample_ids_per_gpu[group_end_gpu + needed_count + 1 + i] = ( - sample_ids_to_push[i] - ) - - group_size[gid] = next_power - group_members[gid] = list(range(members[0], members[-1] + needed_count + 1)) - for pushed_gid in group_size.keys(): - if pushed_gid > gid: - group_members[pushed_gid] = [ - x + needed_count for x in group_members[pushed_gid] - ] - - return ( - new_micro_batches, - new_exec_times, - new_sample_ids_per_gpu, - group_members, - group_size, - ) - - empty_gpus = any([not micro_batches[i] for i in range(total_gpus)]) - while empty_gpus: - micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size = ( - fill_empty_gpus( - micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size - ) - ) - empty_gpus = any([not micro_batches[i] 
for i in range(total_gpus)]) - - # Assert that no sample has been completely removed - total_work_after = sum(len(mb) for mb in micro_batches) - assert ( - total_work_after >= total_work_before - ), f"Samples were removed: {total_work_before} -> {total_work_after}" - - return micro_batches, leftovers, exec_times, sample_ids_per_gpu - - def get_groups_and_subsamples(self, sample_id_seqlens, config): - """ - This function recursively forms groups of sub-samples such that all DPxCP ranks - have a roughly balanced workload in the group. - """ - groups = [] - sample_id_groups = [] - # We assign a sample_id to each sub-sample in order to track assignment to each GPU. - sample_id_seqlens = sorted(sample_id_seqlens, key=lambda x: x[1], reverse=True) - while sample_id_seqlens: - mb, sample_id_seqlens, exec_times, sample_ids = self.next_hdp_group( - sample_id_seqlens, self.get_total_workload, self.total_hdp_gpus - ) - groups.append(mb) - if len(sample_ids) < self.total_hdp_gpus: - sample_ids.extend([] * (self.total_hdp_gpus - len(sample_ids))) - sample_id_groups.append(sample_ids) - - return groups, sample_id_groups - - -def hybrid_context_parallel_forward_backward( - forward_step_func, - data_iterator, - model, - num_microbatches, - input_tensor, - output_tensor_grad, - forward_data_store, - config, - collect_non_loss_data, - first_val_step, - forward_only, - no_sync_func, - total_num_tokens, - check_first_val_step, - model_type, -): - """ - Scheduler for Hybrid Context Parallel. - - This function performs the packed sample scheduling and determines - 1. The number of microbatches to schedule for each CP rank - 2. The number of groups each CP rank should execute - 3. The number of sub-samples per group each CP rank should execute - - A group is defined by a set of samples that can run across the CP domain without any barrier. - There are many reasons why we may not be able to run endless samples within a single group. 
- For example, if we have 8 GPUs, - if GPU 0-5 are assigned a long sample that requires CP6, - GPU 6-7 are assigned a short sample that requires CP2, - The next sample which requires CP4 can be assigned GPU 4-7. - But GPU 6-7 will finish first and get deadlocked if GPU 4-5 are not participating in the group. - """ - from .schedules import backward_step, forward_step - - def _broadcast(item): - if item is not None: - torch.distributed.broadcast( - item, - parallel_state.get_tensor_model_parallel_src_rank(), - group=parallel_state.get_tensor_model_parallel_group(), - ) - - def _broadcast_num_samples_this_group(num_samples_this_group): - dev = torch.cuda.current_device() - torch.distributed.barrier() - - n = 0 if num_samples_this_group is None else int(num_samples_this_group.numel()) - n = torch.tensor([n], dtype=torch.int64, device=dev) - - _broadcast(n) - n = int(n.item()) - - assert n > 0, "there should be at least 1 sub samples in the group" - num_samples_this_group_broadcast = ( - torch.empty(n, dtype=torch.int32, device=dev) - if num_samples_this_group is None - else num_samples_this_group - ) - _broadcast(num_samples_this_group_broadcast) - return num_samples_this_group_broadcast - - def _get_new_data_iterator(sample_id_in_group, group_id): - if is_first_tp_rank: - sub_sample_id = sample_ids_this_group[sample_id_in_group] - sample = batch[sub_sample_id] - partner_cp_size = len( - [True for sample_ids in sample_id_groups[group_id] if sub_sample_id in sample_ids] - ) - sample["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) - new_data_iterator = RerunDataIterator(iter([sample])) - return new_data_iterator - else: - return None - - # We get data once per global batch and schedule the sub-samples. - # TODO(pmannan): Should we wrap the data_iterator here instead of the training.py file? 
- hdp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) - is_first_tp_rank = parallel_state.get_tensor_model_parallel_rank() == 0 - - if is_first_tp_rank: - data = next(data_iterator) - sample_id_groups = data[1] - batch = data[0] - else: - data, sample_id_groups, batch = None, None, None - - num_samples_this_group = None - if is_first_tp_rank: - num_samples_this_group = torch.tensor( - [len(group[hdp_rank]) for group in sample_id_groups], dtype=torch.int32, device='cuda' - ) - - num_samples_this_group = _broadcast_num_samples_this_group(num_samples_this_group) - num_samples_this_group = num_samples_this_group.cpu().numpy() - num_total_groups = num_samples_this_group.shape[0] - - current_microbatch = 0 - - # Upto last group, we don't need any sync. - with no_sync_func(): - for j in range(num_total_groups - 1): - sample_ids_this_group = sample_id_groups[j][hdp_rank] if is_first_tp_rank else None - for i in range(num_samples_this_group[j]): - # Call forward step for each sub-sample - new_data_iterator = _get_new_data_iterator(i, j) - # TODO: Find the usage of current_microbatch and is_first_microbatch and - # how that may affect my usage. - output_tensor, num_tokens = forward_step( - forward_step_func, - new_data_iterator, - model, - num_microbatches, - input_tensor, - forward_data_store, - config, - collect_non_loss_data, - is_first_microbatch=check_first_val_step( - first_val_step, forward_only, current_microbatch == 0 - ), - current_microbatch=current_microbatch, - ) - current_microbatch += 1 - total_num_tokens += num_tokens.item() - if not forward_only: - backward_step( - input_tensor, output_tensor, output_tensor_grad, model_type, config - ) - - # Create a barrier at end of each group. - # This barrier ensures that all ranks are prepared to change assigned CP group sizes and - # no rank is starting a sub-sample ahead of it's partner ranks. 
- torch.distributed.barrier( - parallel_state.get_data_parallel_group(with_context_parallel=True) - ) - - # For the last group, we need to run the last sub-sample out of the context handler. - with no_sync_func(): - sample_ids_this_group = sample_id_groups[-1][hdp_rank] if is_first_tp_rank else None - for i in range(num_samples_this_group[-1] - 1): - new_data_iterator = _get_new_data_iterator(i, -1) - # Call forward step for each sub-sample - output_tensor, num_tokens = forward_step( - forward_step_func, - new_data_iterator, - model, - num_microbatches, - input_tensor, - forward_data_store, - config, - collect_non_loss_data, - is_first_microbatch=check_first_val_step( - first_val_step, forward_only, current_microbatch == 0 - ), - current_microbatch=current_microbatch, - ) - current_microbatch += 1 - total_num_tokens += num_tokens.item() - if not forward_only: - backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) - - # The last sub-sample of the last group of the last microbatch is - # run out of the context handler. 
- new_data_iterator = _get_new_data_iterator(-1, -1) - # Call forward step for each sub-sample - output_tensor, num_tokens = forward_step( - forward_step_func, - new_data_iterator, - model, - num_microbatches, - input_tensor, - forward_data_store, - config, - collect_non_loss_data, - is_first_microbatch=check_first_val_step( - first_val_step, forward_only, current_microbatch == 0 - ), - current_microbatch=current_microbatch, - ) - total_num_tokens += num_tokens.item() - if not forward_only: - backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) - - return forward_data_store, total_num_tokens diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index a8fdf2324f2..29d38cedd0d 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -4,11 +4,13 @@ from functools import partial from typing import Callable, Iterator, List, Optional, Union +import nvtx import torch from torch.autograd.variable import Variable from megatron.core import parallel_state from megatron.core.enums import ModelType +from megatron.core.pipeline_parallel.data_schedule import PackingScheduler, wrap_dataloader from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( fine_grained_offloading_reset, ) @@ -36,7 +38,6 @@ combined_1f1b_schedule_for_interleaved_pipelining, combined_1f1b_schedule_for_no_pipelining, ) -from .hybrid_cp_schedule import hybrid_context_parallel_forward_backward # Types Shape = Union[List[int], torch.Size] @@ -394,6 +395,10 @@ def forward_step( if config.timers is not None: config.timers('forward-compute', log_level=2).start() + + from megatron.training.global_vars import get_gpu_timers + gpu_timer = get_gpu_timers() + gpu_timer.start(name="forward-compute") if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'): model.set_is_first_microbatch() @@ -431,6 +436,7 @@ def forward_step( cp_group_size, is_last_stage, 
) + gpu_timer.stop(name="forward-compute") if unwrap_output_tensor: return output_tensor, num_tokens @@ -453,6 +459,10 @@ def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, c if config.timers is not None: config.timers('backward-compute', log_level=2).start() + from megatron.training.global_vars import get_gpu_timers + gpu_timer = get_gpu_timers() + gpu_timer.start(name="backward-compute") + # Retain the grad on the input_tensor. unwrap_input_tensor_grad = False if not isinstance(input_tensor, list): @@ -498,6 +508,8 @@ def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, c if config.timers is not None: config.timers('backward-compute').stop() + gpu_timer.stop(name="backward-compute") + return input_tensor_grad @@ -509,6 +521,80 @@ def check_first_val_step(first_val_step, forward_only, cond): return cond +def wrap_iterator_helper( + config, + data_iterator: Union[Iterator, List[Iterator]], + num_microbatches: int, + pg_collection: Optional[ProcessGroupCollection] = None, +): + """Warp data iterator for sequence packing if needed.""" + if config.sft_sequence_packing: + num_total_tokens_this_GB, sequence_square_sum_this_GB = None, None + if config.hybrid_context_parallel: + if config.hybrid_context_parallel_scheduler == 'balanced': + ( + data_iterator, + num_microbatches, + num_total_tokens_this_GB, + sequence_square_sum_this_GB, + ) = wrap_dataloader( + data_iterator, config, PackingScheduler.HYBRID_CP, pg_collection=None + ) + elif config.hybrid_context_parallel_scheduler == 'balanced_with_pp': + ( + data_iterator, + num_microbatches, + num_total_tokens_this_GB, + sequence_square_sum_this_GB, + ) = wrap_dataloader( + data_iterator, config, PackingScheduler.HYBRID_CP_WITH_PP, pg_collection=None + ) + elif config.hybrid_context_parallel_scheduler == 'only_packing_no_scheduling': + ( + data_iterator, + num_microbatches, + num_total_tokens_this_GB, + sequence_square_sum_this_GB, + ) = wrap_dataloader( + 
data_iterator, + config, + PackingScheduler.ONLY_PACKING_NO_SCHEDULING, + pg_collection=None, + ) + else: + raise ValueError( + f"Invalid hybrid context parallel scheduler: \ + {config.hybrid_context_parallel_scheduler}" + ) + else: + if config.balanced_sequence_packing: + # enable balanced sequence packing scheduler, will be implemented later + pass + else: + # naive sequence packing scheduler + ( + data_iterator, + num_microbatches, + num_total_tokens_this_GB, + sequence_square_sum_this_GB, + ) = wrap_dataloader( + data_iterator, + config, + PackingScheduler.NAIVE_SEQUENCE_PACKING, + pg_collection=None, + ) + # if torch.distributed.get_rank() == 12: + # print(f"{data_iterator=}, {num_microbatches=}, {num_total_tokens_this_GB=}, {sequence_square_sum_this_GB=}") + return ( + data_iterator, + num_microbatches, + num_total_tokens_this_GB, + sequence_square_sum_this_GB, + ) + else: + return data_iterator, num_microbatches, None, None + + def forward_backward_no_pipelining( *, forward_step_func, @@ -591,6 +677,10 @@ def forward_backward_no_pipelining( input_tensor, output_tensor_grad = None, None total_num_tokens = torch.zeros([], dtype=torch.int, device="cuda") + data_iterator, num_microbatches, num_total_tokens_this_GB, sequence_square_sum_this_GB = ( + wrap_iterator_helper(config, data_iterator, num_microbatches, pg_collection) + ) + if config.overlap_moe_expert_parallel_comm and not forward_only: forward_data_store, total_num_tokens = combined_1f1b_schedule_for_no_pipelining( forward_step_func, @@ -608,24 +698,6 @@ def forward_backward_no_pipelining( total_num_tokens, partial(check_first_val_step, first_val_step, forward_only), ) - elif config.hybrid_context_parallel: - forward_data_store, total_num_tokens = hybrid_context_parallel_forward_backward( - forward_step_func, - data_iterator, - model, - num_microbatches, - input_tensor, - output_tensor_grad, - forward_data_store, - config, - collect_non_loss_data, - first_val_step, - forward_only, - no_sync_func, - 
total_num_tokens, - check_first_val_step, - model_type, - ) else: with no_sync_func(): for i in range(num_microbatches - 1): @@ -689,6 +761,9 @@ def forward_backward_no_pipelining( ): create_cudagraphs() + if config.sft_sequence_packing: + forward_data_store.append([num_total_tokens_this_GB, sequence_square_sum_this_GB]) + return forward_data_store @@ -941,6 +1016,10 @@ def forward_backward_pipelining_with_interleaving( if config.overlap_p2p_comm and config.batch_p2p_comm: raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm") + data_iterator, num_microbatches, num_total_tokens_this_GB, sequence_square_sum_this_GB = ( + wrap_iterator_helper(config, data_iterator, num_microbatches, pg_collection) + ) + # Needed only when gradients are finalized in M-Core if config.finalize_model_grads_func is not None and not forward_only: # vp is ignored for clear_embedding_activation_buffer @@ -1025,14 +1104,17 @@ def enable_grad_sync(): # If the final micro-batch group has fewer micro-batches than pipeline-parallel size, # the pipeline will have dependency bubbles. final_microbatch_group_size = num_microbatches % config.microbatch_group_size_per_vp_stage - if 0 < final_microbatch_group_size < pipeline_parallel_size: - msg = 'The remainder of M (the total micro-batches) divided by N (number of ' - msg += 'contiguous micro-batches in a virtual pipeline stage) should be 0, ' - msg += 'or larger than or equal to the pipeline-parallel size, but it is ' - msg += f'{final_microbatch_group_size}. ' - msg += 'Otherwise, it introduces dependency bubbles in the pipeline ' - msg += 'and reduces throughput.' 
- raise RuntimeError(msg) + if not config.sft_sequence_packing: + # sft sequence packing allows num_microbatches to change dynamically, + # we don't need to check this + if 0 < final_microbatch_group_size < pipeline_parallel_size: + msg = 'The remainder of M (the total micro-batches) divided by N (number of ' + msg += 'contiguous micro-batches in a virtual pipeline stage) should be 0, ' + msg += 'or larger than or equal to the pipeline-parallel size, but it is ' + msg += f'{final_microbatch_group_size}. ' + msg += 'Otherwise, it introduces dependency bubbles in the pipeline ' + msg += 'and reduces throughput.' + raise RuntimeError(msg) model_type = get_model_type(model[0]) @@ -1957,6 +2039,9 @@ def pp_post_backward(input_tensor_grad, vp_stage=None): create_cudagraphs() nvtx_range_pop(suffix="misc") + if config.sft_sequence_packing: + forward_data_store.append([num_total_tokens_this_GB, sequence_square_sum_this_GB]) + return forward_data_store @@ -2073,6 +2158,49 @@ def forward_backward_pipelining_without_interleaving( "Invalid combination of p2p_communicator, pg_collection " "provide none or provide all the process groups" ) + data_iterator, num_microbatches, num_total_tokens_this_GB, sequence_square_sum_this_GB = ( + wrap_iterator_helper(config, data_iterator, num_microbatches, pg_collection) + ) + if torch.distributed.get_rank() == 0: + print(f"rank={torch.distributed.get_rank()}, {num_microbatches=}") + + if is_pp_first_stage(p2p_communicator.pp_group) or is_pp_last_stage(p2p_communicator.pp_group): + nvtx.push_range("send info_tensor among pp ranks") + if config.sft_sequence_packing: + info_tensor = torch.tensor( + [num_microbatches, num_total_tokens_this_GB, sequence_square_sum_this_GB], + dtype=torch.float, pin_memory=True + ).to("cuda", non_blocking=True) + if not is_pp_last_stage(p2p_communicator.pp_group): + next_rank = torch.distributed.get_global_rank( + p2p_communicator.pp_group, p2p_communicator.pp_group.rank() + 1 + ) + 
torch.distributed.send(info_tensor, dst=next_rank) + + # TODO(tailaim): last pp rank does not need to receive num_microbatches + if config.sft_sequence_packing and not (is_pp_first_stage(p2p_communicator.pp_group)): + nvtx.push_range("recv info_tensor among pp ranks") + info_tensor = torch.empty(3, dtype=torch.float, device="cuda") + prev_rank = torch.distributed.get_global_rank( + p2p_communicator.pp_group, p2p_communicator.pp_group.rank() - 1 + ) + torch.distributed.recv(info_tensor, src=prev_rank) + + if not is_pp_last_stage(p2p_communicator.pp_group): + next_rank = torch.distributed.get_global_rank( + p2p_communicator.pp_group, p2p_communicator.pp_group.rank() + 1 + ) + torch.distributed.send(info_tensor, dst=next_rank) + + info_tensor = info_tensor.cpu() + num_microbatches = int(info_tensor[0].item()) + num_total_tokens_this_GB = int(info_tensor[1].item()) + sequence_square_sum_this_GB = info_tensor[2].item() + nvtx.pop_range() + + # data_iterator, num_microbatches, num_total_tokens_this_GB, sequence_square_sum_this_GB = ( + # wrap_iterator_helper(config, data_iterator, num_microbatches, pg_collection) + # ) # Needed only when gradients are finalized in M-Core if config.finalize_model_grads_func is not None and not forward_only: @@ -2080,8 +2208,10 @@ def forward_backward_pipelining_without_interleaving( config, model, is_pp_last_stage(p2p_communicator.pp_group) ) + nvtx.push_range("config.barrier_with_L1_time") if config.timers is not None: config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) + nvtx.pop_range() if not forward_only and config.fine_grained_activation_offloading: fine_grained_offloading_reset() @@ -2343,4 +2473,7 @@ def enable_grad_sync(): ): create_cudagraphs() + if config.sft_sequence_packing: + forward_data_store.append([num_total_tokens_this_GB, sequence_square_sum_this_GB]) + return forward_data_store diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py index 
d3ec11aaf5c..3af344e88e7 100644 --- a/megatron/core/tensor_parallel/layers.py +++ b/megatron/core/tensor_parallel/layers.py @@ -287,6 +287,7 @@ def forward(self, input_): masked_input[input_mask] = 0 else: masked_input = input_ + masked_input = masked_input % self.num_embeddings # Get the embeddings. if self.deterministic_mode: output_parallel = self.weight[masked_input] diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 3c1c05f8c86..96e3619bbb4 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -701,6 +701,10 @@ def forward( (Tuple[Tensor, Tensor]) Attention output and bias. """ + # here we need to set the right cp group for hybrid-cp + if packed_seq_params is not None and packed_seq_params.local_cp_size is not None: + self.pg_collection.cp = packed_seq_params.cp_group + # Check if we need to skip RoPE # no_rope is 0-indexed array and self.layer_number is 1-indexed no_rope = ( @@ -876,6 +880,8 @@ def forward( else: cu_seqlens_q = cu_seqlens_kv = None + # print(f"rank={torch.distributed.get_rank()}, {cu_seqlens_q=}") + if split_qkv: if q_pos_emb is not None: # TODO VIJAY: simplify diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index af8ae572adb..5ac14dd69e0 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -664,12 +664,50 @@ def token_dispatch(self, permutated_local_input_tokens, permuted_probs): self.tokens_per_expert = self._maybe_dtoh_and_synchronize( "before_ep_alltoall", self.tokens_per_expert ) - global_input_tokens = all_to_all( - self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits - ) - global_probs = all_to_all( - self.ep_group, permuted_probs, self.output_splits, self.input_splits - ) + # TODO(tailaim): remove this after testing + # debugmtl + # global_input_tokens = all_to_all( + # 
self.ep_group, permutated_local_input_tokens, + # self.output_splits, self.input_splits + # ) + # global_probs = all_to_all( + # self.ep_group, permuted_probs, self.output_splits, + # self.input_splits + # ) + try: + global_input_tokens = all_to_all( + self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits + ) + global_probs = all_to_all( + self.ep_group, permuted_probs, self.output_splits, self.input_splits + ) + except RuntimeError as e: + # 获取 EP group 内的 rank(防止 group 还没初始化时报错) + try: + rank = torch.distributed.get_rank(self.ep_group) + except Exception: + rank = -1 + + print(f"[MoE all_to_all error] rank={rank}, err={e}") + print( + f"[MoE all_to_all debug] " + f"tokens_shape={getattr(permutated_local_input_tokens, 'shape', None)}, " + f"probs_shape={getattr(permuted_probs, 'shape', None)}" + ) + print( + f"[MoE all_to_all debug] " + f"input_splits={self.input_splits}, sum={sum(self.input_splits) if self.input_splits is not None else None}, " + f"output_splits={self.output_splits}, sum={sum(self.output_splits) if self.output_splits is not None else None}" + ) + print( + f"[MoE all_to_all debug] " + f"tokens_per_expert={self.tokens_per_expert}, " + f"sum={self.tokens_per_expert.sum() if hasattr(self.tokens_per_expert, 'sum') else None}" + ) + torch.set_printoptions(profile="full") + print(f"hidden_states shape: {self.hidden_shape}") + print(f"routing_map: {self.routing_map}") + raise return global_input_tokens, global_probs diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index 023db1fe75a..467f2c49af0 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -711,49 +711,56 @@ def forward( ) else: for l_no, layer in enumerate(self.layers): - # Get appropriate inner quantization context - if use_inner_quantization_context: - if self.config.fp8: - inner_quantization_context = get_fp8_context( - self.config, 
layer.layer_number - 1 - ) - elif self.config.fp4: - inner_quantization_context = get_fp4_context( - self.config, layer.layer_number - 1 - ) + # debugmtl + try: + # Get appropriate inner quantization context + if use_inner_quantization_context: + if self.config.fp8: + inner_quantization_context = get_fp8_context( + self.config, layer.layer_number - 1 + ) + elif self.config.fp4: + inner_quantization_context = get_fp4_context( + self.config, layer.layer_number - 1 + ) + else: + inner_quantization_context = nullcontext() else: inner_quantization_context = nullcontext() - else: - inner_quantization_context = nullcontext() - if self.config.fine_grained_activation_offloading: - fine_grained_offloading_set_last_layer( - l_no == self.num_layers_per_pipeline_rank - 1 - ) - - with self.offload_context, inner_quantization_context: - hidden_states, context = layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - context=context, - context_mask=context_mask, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - rotary_pos_cos_sin=rotary_pos_cos_sin, - attention_bias=attention_bias, - inference_context=inference_context, - packed_seq_params=packed_seq_params, - sequence_len_offset=sequence_len_offset, - ) + if self.config.fine_grained_activation_offloading: + fine_grained_offloading_set_last_layer( + l_no == self.num_layers_per_pipeline_rank - 1 + ) - if ( - torch.is_grad_enabled() - and self.config.cpu_offloading - and self.group_prefetch_offload_commit_async is not None - ): - hidden_states = self.group_prefetch_offload_commit_async(hidden_states) + with self.offload_context, inner_quantization_context: + hidden_states, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + rotary_pos_cos_sin=rotary_pos_cos_sin, + 
attention_bias=attention_bias, + inference_context=inference_context, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + ) + if ( + torch.is_grad_enabled() + and self.config.cpu_offloading + and self.group_prefetch_offload_commit_async is not None + ): + hidden_states = self.group_prefetch_offload_commit_async(hidden_states) + except Exception as e: + # print( + # f"rank:{torch.distributed.get_rank()}, error: {e}, \ + # error layer number: {layer.layer_number}" + # ) + raise e # Final layer norm. if self.final_layernorm is not None: hidden_states = self.final_layernorm(hidden_states) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 31dd5a98a58..225d6679824 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -43,6 +43,15 @@ class TransformerConfig(ModelParallelConfig): #################### # model architecture #################### + min_hybrid_context_parallel_size: int = 1 + + max_hybrid_context_parallel_size: int = 1 + + run_memory_simulator = False + + search_space = 6 + + vocab_size: int = 0 num_layers: int = 0 """Number of transformer layers in a transformer block.""" diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 3ea40577009..89bd6b28677 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -457,7 +457,12 @@ def forward(self, *args, **kwargs): # runners in the cuda graph manager kwargs.pop("dynamic_inference_decode_only", None) hidden_states, context = self._forward_attention(*args, **kwargs) + # debugmtl sync here hang + # torch.cuda.synchronize() output = self._forward_mlp(hidden_states, kwargs.get("inference_context", None)) + # debugmtl barrier here works, sync here works + # torch.distributed.barrier() + # torch.cuda.synchronize() return output, context def 
_forward_attention( @@ -677,6 +682,10 @@ def _forward_mlp(self, hidden_states, inference_context=None): else: mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) + # debugmtl barrier here works, sync here hang + # torch.distributed.barrier() + # torch.cuda.synchronize() + if self.recompute_pre_mlp_layernorm: # discard the output of the pre-mlp layernorm and register the recompute # as a gradient hook of mlp_output_with_bias[0] @@ -710,6 +719,10 @@ def _forward_post_mlp(self, mlp_output_with_bias, residual): hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)( mlp_output_with_bias, residual, self.hidden_dropout ) + + # debugmtl barrier here works + # torch.distributed.barrier() + nvtx_range_pop(suffix="mlp_bda") if self.offload_mlp_norm: (hidden_states,) = fine_grained_offloading_group_commit( @@ -726,6 +739,9 @@ def _forward_post_mlp(self, mlp_output_with_bias, residual): inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True ) + # debugmtl barrier here works! 
+ # torch.distributed.barrier() + return output def sharded_state_dict( diff --git a/megatron/core/utils.py b/megatron/core/utils.py index 3a153468ae6..00c9f5c5fb2 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -42,6 +42,7 @@ from megatron.core import parallel_state from megatron.core.dist_checkpointing.mapping import ShardedTensor +from megatron.core.packed_seq_params import PackedSeqParams try: from packaging.version import Version as PkgVersion @@ -57,6 +58,12 @@ except ImportError: HAVE_NVTX = False +# Register the TE CUDA kernels +import transformer_engine # pylint: disable=unused-import + +# Alias the PyTorch wrapper so we can call tex.* APIs +import transformer_engine_torch as tex + logger = logging.getLogger(__name__) try: @@ -1956,7 +1963,7 @@ def is_submodule(module, parent_module, strict=True): def get_batch_on_this_cp_rank( - batch: Dict[str, Any], cp_group: Optional[torch.distributed.ProcessGroup] = None + batch: Dict[str, Any], cp_size: Optional[int] = None, cp_rank: Optional[int] = None ): """Slice batch input along sequence dimension into multiple chunks, which are parallelized across GPUs in a context parallel group. @@ -1974,14 +1981,15 @@ def get_batch_on_this_cp_rank( # we split sequence into 2*CP ranks. Assuming CP=2, we then get 4 chunks, chunk_0 # and chunk_3 are assigned to GPU0, chunk_1 and chunk_2 are assigned to GPU1, so # that we can get balanced workload among GPUs in a context parallel group. 
- # Determine CP topology either from provided group or from current context parallel state - if cp_group is not None: - cp_size = get_pg_size(cp_group) - cp_rank = get_pg_rank(cp_group) - else: + if cp_size is not None or cp_rank is not None: + assert ( + cp_size is not None and cp_rank is not None + ), "Both cp_size and cp_rank must be provided for batch slicing" + + if cp_size is None: cp_size = parallel_state.get_context_parallel_world_size() + if cp_rank is None: cp_rank = parallel_state.get_context_parallel_rank() - if cp_size > 1: for key, val in batch.items(): if val is not None: @@ -2006,97 +2014,74 @@ def get_thd_batch_on_this_cp_rank( batch: Dict[str, Any], cu_seqlens: torch.Tensor, cu_seqlens_padded: torch.Tensor, - max_seqlen: torch.Tensor, + max_seqlen: Optional[int] = None, + cp_size: Optional[int] = None, + cp_rank: Optional[int] = None, + local_cp_size: Optional[int] = None, cp_group: Optional[torch.distributed.ProcessGroup] = None, + only_packed_seq_params: bool = False, + vp_stage: Optional[int] = None, ): """Slice each sub-sample in a packed sample batch input along sequence dimension into multiple chunks, which are parallelized across GPUs in a context parallel group. 
""" - packed_seq_params = PackedSeqParams( - qkv_format="thd", - cu_seqlens_q=cu_seqlens, - cu_seqlens_kv=cu_seqlens, - cu_seqlens_q_padded=cu_seqlens_padded, - cu_seqlens_kv_padded=cu_seqlens_padded, - max_seqlen_q=int(max_seqlen[0].item()), - max_seqlen_kv=int(max_seqlen[0].item()), - ) - - if cp_group is not None: - cp_size = get_pg_size(cp_group) - cp_rank = get_pg_rank(cp_group) + if local_cp_size: + # enable hybrid context parallel + cp_size = local_cp_size + if cp_group is None: + cp_group = parallel_state.get_hybrid_data_context_parallel_groups(group_size=cp_size) + cp_rank = torch.distributed.get_rank(group=cp_group) + assert cp_group.size() == cp_size + else: + assert cp_group.size() == local_cp_size else: cp_size = parallel_state.get_context_parallel_world_size() cp_rank = parallel_state.get_context_parallel_rank() - if cp_size > 1: # slice batch along sequence dimension for context parallelism - assert tex is not None and is_te_min_version("1.10.0"), ( - "Please update Transformer Engine to >= 1.10 to use " - "Context Parallel with THD format data" - ) - index = tex.thd_get_partitioned_indices( - cu_seqlens_padded, batch['tokens'].size(1), cp_size, cp_rank - ) - for key, data in batch.items(): - if key in {'attention_mask', 'cu_seqlens', 'cu_seqlens_padded', 'max_seqlen'}: - continue - batch[key] = data.index_select(1, index) - - return batch, packed_seq_params - - -################################ -### hybrid context parallel ### -################################ - + cp_group = None -def get_batch_on_this_hybrid_cp_rank( - batch: Dict[str, Any], - local_cp_size: int, - cp_group: Optional[torch.distributed.ProcessGroup] = None, -): - """Slice batch input along sequence dimension into multiple chunks, - which are parallelized across GPUs in a context parallel group. 
- """ - assert local_cp_size is not None - if cp_group is None: - # Get the local cp group required for as defined by the HybridCPDataLoaderWrapper - if local_cp_size > 1: - cp_group = parallel_state.get_hybrid_data_context_parallel_groups( - group_size=local_cp_size - ) - else: - # If cp group is provided, it must match the local cp size - # as defined by the HybridCPDataLoaderWrapper - assert cp_group.size() == local_cp_size - - # Convert [seqlen] to [1, seqlen] similar to default collate_fn - # as hybrid_context_parallel dataloader wrapper does not go through default collate_fn - for key, data in batch.items(): - if key in ['attention_mask']: - continue - batch[key] = torch.stack([data], 0) - sample_length = batch['tokens'].shape[1] - # TODO(pmannan): Take care of padding tokens here if not divisible by cp_size*2 - # Create packed_seq_params for SBHD format with cp group information. packed_seq_params = PackedSeqParams( - qkv_format="sbhd", - cu_seqlens_q=torch.tensor([0, sample_length], device="cuda", pin_memory=True), - cu_seqlens_kv=torch.tensor([0, sample_length], device="cuda", pin_memory=True), - cu_seqlens_q_padded=torch.tensor([0, sample_length], device="cuda", pin_memory=True), - cu_seqlens_kv_padded=torch.tensor([0, sample_length], device="cuda", pin_memory=True), - max_seqlen_q=sample_length, - max_seqlen_kv=sample_length, + qkv_format="thd", + cu_seqlens_q=cu_seqlens_padded, + cu_seqlens_kv=cu_seqlens_padded, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + max_seqlen_q=max_seqlen, + max_seqlen_kv=max_seqlen, local_cp_size=local_cp_size, cp_group=cp_group, ) - - if cp_group is not None and cp_group.size() > 1: - # When using hybrid_context_parallel, each sub-sample of a packed sample is - # required to be divisible by CP*DP*2 or CP*DP*TP*2 (if using sequence parallel) - batch = get_batch_on_this_cp_rank(batch, cp_group) - - return batch, packed_seq_params + if not only_packed_seq_params: + batch_keys = [] + if 
parallel_state.is_pipeline_first_stage(vp_stage=vp_stage): + batch_keys += ['tokens', 'position_ids'] + if parallel_state.is_pipeline_last_stage(vp_stage=vp_stage): + batch_keys += ['labels', 'loss_mask'] + + for key in ["tokens", "position_ids", "labels", "loss_mask"]: + if key in batch: + if batch[key] is not None: + batch[key] = batch[key].unsqueeze(0) + + if cp_size > 1: # slice batch along sequence dimension for context parallelism + assert tex is not None and is_te_min_version("1.10.0"), ( + "Please update Transformer Engine to >= 1.10 to use " + "Context Parallel with THD format data" + ) + # print(f"tokens shape before cp slice: {batch['tokens'].shape}") + size = ( + batch['tokens'].size(1) if batch['tokens'] is not None else batch['labels'].size(1) + ) + index = tex.thd_get_partitioned_indices(cu_seqlens_padded, size, cp_size, cp_rank) + for key, data in batch.items(): + if key in {'attention_mask'}: + continue + if data is not None: + batch[key] = data.index_select(1, index) + + return batch, packed_seq_params + else: + return batch, packed_seq_params ###################### diff --git a/megatron/legacy/data/data_samplers.py b/megatron/legacy/data/data_samplers.py index 79bdc7b193f..dd5f587935e 100644 --- a/megatron/legacy/data/data_samplers.py +++ b/megatron/legacy/data/data_samplers.py @@ -2,10 +2,15 @@ """Dataloaders.""" +import os +import threading +import ctypes +import sys import random import torch import numpy as np +import torch.multiprocessing as mp from torch.utils.data import Dataset from megatron.training import get_args from megatron.core import mpu @@ -34,14 +39,25 @@ def build_pretraining_data_loader(dataset, consumed_samples): data_parallel_rank=mpu.get_data_parallel_rank(), data_parallel_size=mpu.get_data_parallel_world_size()) elif args.dataloader_type == 'single': - if args.hybrid_context_parallel: - batch_sampler = HybridCPMegatronPretrainingSampler( - total_samples=len(dataset), - consumed_samples=consumed_samples, - 
micro_batch_size=args.micro_batch_size, - global_batch_size=args.global_batch_size, - data_parallel_rank=mpu.get_data_parallel_rank(), - data_parallel_size=mpu.get_data_parallel_world_size()) + if args.sft_sequence_packing: + if args.async_hybrid_context_parallel_scheduler: + assert args.hybrid_context_parallel_scheduler == "only_packing_no_scheduling" + batch_sampler = MegatronSFTPrefetchDPBalancedSampler( + dataset=dataset, + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=args.micro_batch_size, + global_batch_size=args.global_batch_size, + data_parallel_rank=mpu.get_data_parallel_rank(), + data_parallel_size=mpu.get_data_parallel_world_size()) + else: + batch_sampler = MegatronSFTSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=args.micro_batch_size, + global_batch_size=args.global_batch_size, + data_parallel_rank=mpu.get_data_parallel_rank(), + data_parallel_size=mpu.get_data_parallel_world_size()) else: # Megatron sampler batch_sampler = MegatronPretrainingSampler( @@ -68,7 +84,7 @@ def build_pretraining_data_loader(dataset, consumed_samples): args.dataloader_type)) # Torch dataloader. - if args.hybrid_context_parallel: + if args.sft_sequence_packing: extra_kwargs = {"collate_fn": lambda x: x,} else: extra_kwargs = {} @@ -128,7 +144,7 @@ def __iter__(self): start_idx, end_idx = self.get_start_end_idx() yield batch[start_idx:end_idx] -class HybridCPMegatronPretrainingSampler(MegatronPretrainingSampler): +class MegatronSFTSampler(MegatronPretrainingSampler): """ Data sampler for hybrid context parallel (Hybrid CP) format. This data sampler pulls in the entire global batch at once across all data parallel ranks. 
@@ -162,6 +178,9 @@ def __iter__(self): for i in range(self.num_micro_batches): global_batch_idx.extend(batch[start_idx[i]:end_idx[i]]) yield global_batch_idx + # if torch.distributed.get_rank() == 0: + # print(f"rank={torch.distributed.get_rank()}, {batch=}") + # yield batch batch = [] # Check the last partial batch and see drop_last is set @@ -172,6 +191,136 @@ def __iter__(self): global_batch_idx.extend(batch[start_idx[i]:end_idx[i]]) yield global_batch_idx + +class MegatronSFTPrefetchDPBalancedSampler(MegatronPretrainingSampler): + """ + Data sampler for hybrid context parallel (Hybrid CP) format. + This data sampler pulls in the entire global batch at once across all data parallel ranks. + This helps provide the Hybrid CP Dataloader Wrapper to schedule and load balance sub-samples + of the entire global batch. + """ + + def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size, global_batch_size, + data_parallel_rank, data_parallel_size, drop_last=True): + super().__init__(total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size, drop_last) + self.dataset = dataset + self.global_batch_size = global_batch_size + self.data_parallel_size = data_parallel_size + self.num_micro_batches = self.global_batch_size // self.micro_batch_times_data_parallel_size + + from megatron.training.yaml_arguments import core_transformer_config_from_yaml + from megatron.training.arguments import core_transformer_config_from_args + args = get_args() + if args.yaml_cfg is not None: + config = core_transformer_config_from_yaml(args, "language_model") + else: + config = core_transformer_config_from_args(args) + + self.config = config + from megatron.core.pipeline_parallel.data_schedule import PipelineAwareBalancedHybridCPscheduler + self.data_scheduler = PipelineAwareBalancedHybridCPscheduler(self.config) + + ctx = mp.get_context('fork') + self._queue1 = ctx.Queue() + self._queue2 = ctx.Queue() + self._prefetch_process = 
ctx.Process(target=self.prefetch_batch, + args=(self._queue1, self._queue2), + name=f'prefetch_batch', daemon=False) + self._prefetch_process.start() + + def __len__(self): + return self.total_samples + + # def get_start_end_idx_global_batch(self): + # start_idx = [self.data_parallel_rank * self.micro_batch_size + i * self.micro_batch_size * self.data_parallel_size for i in range(self.num_micro_batches)] + # end_idx = [start_idx[i] + self.micro_batch_size for i in range(self.num_micro_batches)] + # return start_idx, end_idx + + def get_shape(self, idx): + data = self.dataset[idx] + shape = data["tokens"].shape + return shape + + def get_numel(self, idx): + data = self.dataset[idx] + numel = data["tokens"].numel() + return [idx, numel] + + def prepare_info(self, batch, batch_numel): + pass + + def prefetch_batch(self, queue1, queue2): + torch.multiprocessing._set_thread_name("pt_prefetch_batch") + torch.set_num_threads(1) + # global_store = DistKVStore(world_size=torch.distributed.get_world_size(), rank=torch.distributed.get_rank(), group_name=global_group_name) + # within_node_store = DistKVStore(world_size=8, rank=torch.distributed.get_rank(), group_name=within_node_group_name) + # assert torch.distributed.get_world_size() % 8 == 0, f"world_size should be divisible by 8" # 单机8卡 + # if torch.distributed.get_rank() % 8 == 0: # 每个节点的0号rank + # cross_node_store = DistKVStore(world_size=torch.distributed.get_world_size() // 8, rank=torch.distributed.get_rank(), group_name=cross_node_group_name) + # else: + # cross_node_store = None + + while True: + full_batch = queue1.get() + # print(f"GET queue1, {full_batch=}") + if full_batch == None: + return + batch_data = self.prepare_batch(full_batch) + queue2.put(batch_data) + # print(f"PUT queue2, {batch_data=}") + + def prepare_batch(self, batch): + batch_numel = [self.get_numel(idx) for idx in batch] + # TODO: use distributed `get_numel` to reduce io pressure. 
+ + groups, sample_id_groups, cp_sizes = self.data_scheduler.get_groups_and_subsamples(batch_numel, self.config, return_cp_sizes=True) + return groups, sample_id_groups, cp_sizes + + def __iter__(self): + # batch = [] + # Last batch will be dropped if drop_last is not set False + batch = list(range(self.consumed_samples, min(self.consumed_samples + self.global_batch_size, self.total_samples))) + # print(f"PUT queue1, {batch=}") + self._queue1.put(batch) + while self.consumed_samples < self.total_samples: + # for idx in range(self.consumed_samples, self.total_samples): + # batch.append(idx) + # if len(batch) == self.global_batch_size: + # groups, sample_id_groups, cp_sizes = self.prepare_batch(batch) + batch_data = self._queue2.get(timeout=3000) + groups, sample_id_groups, cp_sizes = batch_data + # print(f"GET queue2, {groups=}") + + consumed_samples_before = self.consumed_samples + next_full_batch = list(range(consumed_samples_before + self.global_batch_size, min(consumed_samples_before + 2*self.global_batch_size, self.total_samples))) + # print(f"PUT queue1, {next_full_batch=}") + self._queue1.put(next_full_batch) + + # global_batch_idx = [] + for microbatch_idx in range(len(sample_id_groups)): + microbatch = sample_id_groups[microbatch_idx][self.data_parallel_rank] + microbatch_cp_sizes = cp_sizes[microbatch_idx][self.data_parallel_rank] + num_microbatch_left = [len(sample_id_groups)-microbatch_idx-1] * len(microbatch) + # print(f"{groups=}\n{sample_id_groups=}") + yield list(zip(microbatch, num_microbatch_left, microbatch_cp_sizes)) + # global_batch_idx.extend(microbatch) + + # yield global_batch_idx + # batch = [] + + # Check the last partial batch and see drop_last is set + if len(batch) > 0 and not self.drop_last: + groups, sample_id_groups, cp_sizes = self.prepare_batch(batch) + global_batch_idx = [] + for microbatch_idx in range(len(sample_id_groups)): + microbatch = sample_id_groups[microbatch_idx][self.data_parallel_rank] + microbatch_cp_sizes = 
cp_sizes[microbatch_idx][self.data_parallel_rank] + num_microbatch_left = [len(sample_id_groups)-microbatch_idx-1] * len(microbatch) + assert len(microbatch) == len(microbatch_cp_sizes) + global_batch_idx.extend(list(zip(microbatch, num_microbatch_left, microbatch_cp_sizes))) + yield global_batch_idx + + class RandomSeedDataset(Dataset): def __init__(self, dataset): diff --git a/megatron/legacy/model/transformer.py b/megatron/legacy/model/transformer.py index 2a662a55b16..412d5b094c4 100644 --- a/megatron/legacy/model/transformer.py +++ b/megatron/legacy/model/transformer.py @@ -836,6 +836,7 @@ def forward(self, hidden_states, attention_mask, # Output. [sq, b, h] # ================= + output, bias = self.dense(context_layer) return output, bias diff --git a/megatron/pipeline_simulator/.gitignore b/megatron/pipeline_simulator/.gitignore new file mode 100644 index 00000000000..0a197900e25 --- /dev/null +++ b/megatron/pipeline_simulator/.gitignore @@ -0,0 +1,174 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. 
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc diff --git a/megatron/pipeline_simulator/.python-version b/megatron/pipeline_simulator/.python-version new file mode 100644 index 00000000000..e4fba218358 --- /dev/null +++ b/megatron/pipeline_simulator/.python-version @@ -0,0 +1 @@ +3.12 diff --git a/megatron/pipeline_simulator/README.md b/megatron/pipeline_simulator/README.md new file mode 100644 index 00000000000..2d156733961 --- /dev/null +++ b/megatron/pipeline_simulator/README.md @@ -0,0 +1,3 @@ +# Pipeline Simulator V2 + +TODO(Wei Zhang): Add a description of the project here. 
diff --git a/megatron/pipeline_simulator/hotsim/memory_model.py b/megatron/pipeline_simulator/hotsim/memory_model.py new file mode 100644 index 00000000000..cedbb78a3c0 --- /dev/null +++ b/megatron/pipeline_simulator/hotsim/memory_model.py @@ -0,0 +1,522 @@ +from .model import Model +from .schedule import ( + Action, + Chunk, + Op, + build_splitfuse_schedule, + build_hybrid_schedule, + build_1f1b_schedule, + plot_combined_visualization, +) +from .training_config import TrainingConfig + + +class MemoryModel: + """ + MemoryModel simulates memory usage during pipeline parallel training of large language models. + + This class tracks memory allocation and deallocation for model parameters, optimizer states, + activations, KV caches, and gradients during the forward and backward passes. + + The model supports various parallelism strategies (tensor, pipeline, expert, context, data) + and checkpointing configurations to optimize memory usage. + + Methods: + setup(chunks_list): Prepares the memory model with chunk information and builds the execution schedule + run(rank): Simulates execution on the specified rank and calculates memory usage + """ + + def __init__(self, config: TrainingConfig): + self.model = config.model + self.config = config + + if self.config.microbatch_size != 1: + raise ValueError("This memory model only supports a microbatch size of 1.") + + self._init_memory() + + def _init_memory(self) -> None: + # Model parameters (FP16) and gradients (FP32): 2 + 4 = 6 bytes per parameter + self.parameter_size = ( + 6 + * self.model.params_per_layer + * self.config.num_layers_per_stage + / self.config.tensor_parallel_size + ) + + # Optimizer states: weights (4) + momentum (4) + variance (4) = 12 bytes per parameter + self.optimizer_size = ( + 12 + * self.model.params_per_layer + * self.config.num_layers_per_stage + / self.config.tensor_parallel_size + / self.config.context_parallel_size + / self.config.data_parallel_size + ) + + # Model states and optimizer states 
for the embedding layer + self.embedding_size = ( + 6 + * self.model.vocab_size + * self.model.hidden_size + / self.config.tensor_parallel_size + ) + ( + 12 + * self.model.vocab_size + * self.model.hidden_size + / self.config.tensor_parallel_size + / self.config.context_parallel_size + / self.config.data_parallel_size + ) + + # TODO(Wei Zhang): Implement output size calculation + # Output logits memory (FP32): 4 bytes per element + # self.output_size = ( + # 4 + # * self.config.microbatch_size + # * self.config.seq_length + # * self.model.vocab_size + # / self.config.tensor_parallel_size + # / self.config.context_parallel_size + # ) + self.output_size = 0 + + def _init_cache(self) -> None: + num_layers = self.model.num_hidden_layers + # key: stage_id, value: Dict[(batch_id, chunk_id), size] + self.activations: list[dict[tuple[int, int], float]] = [ + dict() for _ in range(num_layers) + ] + self.kv_caches: list[dict[tuple[int, int], float]] = [ + dict() for _ in range(num_layers) + ] + self.kv_gradients: list[dict[tuple[int, int], float]] = [ + dict() for _ in range(num_layers) + ] + self.offload_caches: list[dict[tuple[int, int], float]] = [ + dict() for _ in range(num_layers) + ] + # There will be two buffers for offload and reload + self.offload_buffer_size = 0.0 + # Memory changes for every forward/backward pass + self.memory_histogram: list[float] = [] + self.peak_memory_histogram: list[float] = [] + CC = 90 + self.matmul_buffer = 2 * {80: 8320 * 1024, 90: 32 * 1024 * 1024}[CC] + + def _get_activation_size(self, chunk: Chunk, recompute=False) -> float: + # Activation memory (FP16): 2 bytes per activation + activation_size = ( + 2 + * self.model.acts_per_layer( + chunk.batch_size, chunk.length, self.config.ckpt, recompute + ) + / self.config.tensor_parallel_size + / self.config.context_parallel_size + ) + return activation_size + + def _get_kv_cache_size(self, chunk: Chunk) -> float: + num_chunks = self.num_chunks[chunk.batch_id] + if num_chunks == 1: + # No kv 
cache if not sliced + return 0 + kv_cache_size = ( + 2 + * self.model.kv_acts_per_layer(chunk.batch_size, chunk.length) + / self.config.tensor_parallel_size + / self.config.context_parallel_size + ) + return kv_cache_size + + def _get_kv_gradient_size(self, chunk: Chunk) -> float: + batch_id = chunk.batch_id + chunk_id = chunk.chunk_id + num_chunks = self.num_chunks[batch_id] + if num_chunks == 1 or chunk_id == 0: + # No kv gradient if not sliced or first chunk + return 0 + + # dKdV will be used for all previous chunks + length = sum(self.chunks_list[batch_id][:chunk_id]) + + kv_gradient_size = ( + 2 + * self.model.kv_acts_per_layer(chunk.batch_size, length) + / self.config.tensor_parallel_size + / self.config.context_parallel_size + ) + return kv_gradient_size + + def _get_offload_cache_size(self, activation_size: float) -> float: + offload_cache_size = activation_size * self.config.offload_ratio + return offload_cache_size + + def _forward_layer(self, action: Action, layer_id: int) -> None: + chunk = action.chunk + stage_id, batch_id, chunk_id = action.stage_id, chunk.batch_id, chunk.chunk_id + + key = (batch_id, chunk_id) + + # Forward generates activations and kv caches + kv_cache_size = self._get_kv_cache_size(action.chunk) + activation_size = self._get_activation_size(action.chunk) + # We have duplicate KV in activations unless full checkpointing is used + if self.config.ckpt != "full": + activation_size -= kv_cache_size + + # If the new activation size is larger than the offload buffer size, + # we need to expand the offload buffer + offload_cache_size = self._get_offload_cache_size(activation_size) + offload_buffer_expansion = 0.0 + expanded_buffer_size = 2 * 3 * offload_cache_size * self.config.offload_ratio + if self.offload_buffer_size < expanded_buffer_size: + offload_buffer_expansion = expanded_buffer_size - self.offload_buffer_size + self.offload_buffer_size = expanded_buffer_size + + tp_all_gather_buffer = ( + self.config.microbatch_size + * 
action.chunk.length + // self.config.context_parallel_size + * self.model.hidden_size + * 2 + if self.config.tensor_parallel_size >= 2 + else 0 + ) + # The peak memory occurs just before we start offloading + peak_memory = ( + self.memory_histogram[-1] + + activation_size + + kv_cache_size + + offload_buffer_expansion + + tp_all_gather_buffer + ) + self.peak_memory_histogram.append(peak_memory) + + # Some activations are offloaded + remaining_activation_size = activation_size - offload_cache_size + + self.activations[layer_id][key] = remaining_activation_size + self.kv_caches[layer_id][key] = kv_cache_size + self.offload_caches[layer_id][key] = offload_cache_size + + # Add up what's left in memory + memory = ( + self.memory_histogram[-1] + + remaining_activation_size + + kv_cache_size + + offload_buffer_expansion + ) + self.memory_histogram.append(memory) + + def _forward(self, action: Action) -> None: + for i in range(self.config.num_layers_per_stage): + before = self.memory_histogram[-1] + self._forward_layer( + action, action.stage_id * self.config.num_layers_per_stage + i + ) + after = self.memory_histogram[-1] + + def _backward_layer(self, action: Action, layer_id: int) -> None: + chunk = action.chunk + stage_id, batch_id, chunk_id = action.stage_id, chunk.batch_id, chunk.chunk_id + + key = (batch_id, chunk_id) + + # Reload activations from offload cache + offload_cache_size = self.offload_caches[layer_id].pop(key) + activation_size = self.activations[layer_id].pop(key) + kv_cache_size = self.kv_caches[layer_id].pop(key) + kv_gradient_size = self._get_kv_gradient_size(chunk) + if chunk_id > 0: + self.kv_gradients[layer_id][key] = kv_gradient_size + + recomputed_activation_size = self._get_activation_size(chunk, recompute=True) + # With full checkpointing, KV are regenerated during recompute + if self.config.ckpt == "full": + recomputed_activation_size -= kv_cache_size + + tp_all_gather_buffer = ( + self.config.microbatch_size + * action.chunk.length + // 
self.config.context_parallel_size + * self.model.hidden_size + * 2 + if self.config.tensor_parallel_size >= 2 + else 0 + ) + # The peak memory occurs just before we discard activations, kv cache and kv gradients + peak_memory = ( + self.memory_histogram[-1] + + recomputed_activation_size + + offload_cache_size + + kv_gradient_size + + tp_all_gather_buffer + + (1 + 2) * 2 * chunk.batch_size * chunk.length * self.model.intermediate_size # fc2 bwd & swiglu_back + ) + self.peak_memory_histogram.append(peak_memory) + + prev_kv_gradient_size = 0.0 + if chunk_id < self.num_chunks[batch_id] - 1: + prev_key = (batch_id, chunk_id + 1) + prev_kv_gradient_size = self.kv_gradients[layer_id].pop(prev_key) + + # Release activations, kv cache and previous kv gradients + memory = ( + self.memory_histogram[-1] + - activation_size + - kv_cache_size + + kv_gradient_size + - prev_kv_gradient_size + ) + self.memory_histogram.append(memory) + + def _backward(self, action: Action) -> None: + for i in reversed(range(self.config.num_layers_per_stage)): + self._backward_layer( + action, action.stage_id * self.config.num_layers_per_stage + i + ) + + def _validate_cache(self) -> None: + storage_list = [ + self.activations, + self.kv_caches, + self.kv_gradients, + self.offload_caches, + ] + for storage in storage_list: + for cache in storage: + if len(cache) != 0: + raise ValueError("Memory leakage detected.") + + def _simulate_execution(self, rank: int = 0) -> None: + self._init_cache() + base_memory = self.parameter_size + self.optimizer_size + self.matmul_buffer + if rank == 0: + base_memory += self.embedding_size + if rank == self.config.pipeline_parallel_size - 1: + base_memory += self.output_size + self.embedding_size + self.memory_histogram = [base_memory] + self.peak_memory_histogram = [base_memory] + + actions = self.actions_by_rank[rank] + for i, action in enumerate(actions): + if action is None: + # Pipeline bubble + memory = self.memory_histogram[-1] + 
self.peak_memory_histogram.append(memory) + self.memory_histogram.append(memory) + elif action.op == Op.FORWARD: + self._forward(action) + elif action.op == Op.BACKWARD: + self._backward(action) + else: + raise ValueError(f"Unknown operation: {action.op}") + + self._validate_cache() + + def setup( + self, chunks_list: list[list[int]], actions_by_rank: list[list[Action]] + ) -> None: + """ + Setup the memory model with chunk information and build the execution schedule. + Parameters + ---------- + chunks_list : list[list[int]] + List of micro-batch slice counts for each batch to process. + Each number must be divisible by the number of ranks (p). + kfkb : bool, optional + Flag to indicate if KFKB scheduling is enabled. Default is False. + """ + self.chunks_list = chunks_list + self.num_microbatches = len(chunks_list) + self.num_chunks = [len(chunks) for chunks in chunks_list] + self.actions_by_rank = actions_by_rank + + def run(self, rank: int = 0): + """ + Simulate execution on the specified rank and calculate memory usage. + Parameters + ---------- + rank : int + The rank to simulate execution for. Default is 0. + """ + if rank < 0 or rank >= self.config.pipeline_parallel_size: + raise ValueError( + f"Rank {rank} is out of range. Must be between 0 and {self.config.pipeline_parallel_size - 1}." + ) + if self.actions_by_rank is None: + raise ValueError("The schedule has not been built yet. 
Call setup() first.") + + self._simulate_execution(rank) + + +def test_splitfuse_schedule(): + model = Model( + name="Llama 7B", + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + ) + config = TrainingConfig( + model=model, + num_gpus=8, + microbatch_size=1, + tensor_parallel_size=1, + context_parallel_size=1, + data_parallel_size=2, + pipeline_parallel_size=4, + expert_parallel_size=1, + num_model_chunks=1, + ckpt="partial", + offload_ratio=0.0, + ) + print(model) + # Example usage + p = config.pipeline_parallel_size # Number of ranks + chunks_list = [ + [4545, 4432], + [4545, 4542], + [4564], + [4566], + [4594], + [4624], + [4638], + [4616], + [4629], + [4645], + [4644], + [4647], + [4644], + [4673], + [4667], + [4672], + [4671], + [4817], + [4948], + [4910], + [4971], + [4959], + ] + actions_by_rank = build_splitfuse_schedule(p, chunks_list=chunks_list) + + memory_model = MemoryModel(config) + memory_model.setup(chunks_list, actions_by_rank) + memory_model.run(rank=1) + print(f"Initial memory usage: {memory_model.memory_histogram[0] / 1024**2:.2f} MiB") + print( + f"Peak memory usage: {max(memory_model.peak_memory_histogram) / 1024**2:.2f} MiB" + ) + print( + f"Delta memory usage: {(max(memory_model.peak_memory_histogram) - memory_model.memory_histogram[0]) / 1024**2:.2f} MiB" + ) + print(f"offload buffer size: {memory_model.offload_buffer_size / 1024**2:.2f} MiB") + plot_combined_visualization( + memory_model.actions_by_rank, + memory_model.memory_histogram, + memory_model.peak_memory_histogram, + ) + + +def test_hybrid_schedule(): + model = Model( + name="Llama 13B", + vocab_size=128000, + hidden_size=5120, + intermediate_size=13824, + num_hidden_layers=40, + num_attention_heads=40, + ) + config = TrainingConfig( + model=model, + num_gpus=32, + microbatch_size=1, + tensor_parallel_size=8, + context_parallel_size=1, + data_parallel_size=1, + pipeline_parallel_size=4, + expert_parallel_size=1, + 
num_model_chunks=1, + ckpt="no", + offload_ratio=0, + ) + # Example usage + p = config.pipeline_parallel_size # Number of ranks + k = 2 + num_chunks = [2] * 8 # Number of slices for each batch + seq_length = 32 * 1024 # Sequence length + chunks_list = [] + for num in num_chunks: + chunks_list.append([seq_length // num] * num) + + actions_by_rank = build_hybrid_schedule( + p, k, fwd_switch=(3, 0), bwd_switch=(3, 1), chunks_list=chunks_list + ) + + memory_model = MemoryModel(config) + memory_model.setup(chunks_list, actions_by_rank) + memory_model.run() + print( + f"Peak memory usage: {max(memory_model.peak_memory_histogram) / 1024**3:.2f} GiB" + ) + plot_combined_visualization( + memory_model.actions_by_rank, + memory_model.memory_histogram, + memory_model.peak_memory_histogram, + ) + + +def test_1f1b_schedule(): + model = Model( + name="Llama 7B", + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + ) + config = TrainingConfig( + model=model, + num_gpus=16, + microbatch_size=1, + tensor_parallel_size=2, + context_parallel_size=1, + data_parallel_size=2, + pipeline_parallel_size=4, + expert_parallel_size=1, + num_model_chunks=1, + ckpt="no", + offload_ratio=0.0, + ) + # Example usage + p = config.pipeline_parallel_size # Number of ranks + chunks = [16 * 1024] * 8 + chunks_list = [[16 * 1024]] * 8 + actions_by_rank = build_1f1b_schedule(p, chunks=chunks) + + memory_model = MemoryModel(config) + memory_model.setup(chunks_list, actions_by_rank) + for i in range(p): + memory_model.run(rank=i) + print(f"Rank {i} memory usage:") + print( + f"Initial memory usage: {memory_model.memory_histogram[0] / 1024**2:.2f} MiB" + ) + print( + f"Peak memory usage: {max(memory_model.peak_memory_histogram) / 1024**2:.2f} MiB" + ) + print( + f"Delta memory usage: {(max(memory_model.peak_memory_histogram) - memory_model.memory_histogram[0]) / 1024**2:.2f} MiB" + ) + print( + f"offload buffer size: 
{memory_model.offload_buffer_size / 1024**2:.2f} MiB" + ) + print("=" * 80) + + +if __name__ == "__main__": + test_splitfuse_schedule() + test_hybrid_schedule() + test_1f1b_schedule() diff --git a/megatron/pipeline_simulator/hotsim/model.py b/megatron/pipeline_simulator/hotsim/model.py new file mode 100644 index 00000000000..cf26bc9b885 --- /dev/null +++ b/megatron/pipeline_simulator/hotsim/model.py @@ -0,0 +1,304 @@ +from dataclasses import dataclass, field + + +@dataclass +class Model: + """Configuration class for a LLaMA-like model architecture. + + Defines parameters for a transformer model including dimensions, layers, and attention configuration. + + Attributes: + name: Model name (auto-generated based on parameters if not provided) + vocab_size: Size of token vocabulary + hidden_size: Dimension of hidden representations + intermediate_size: Dimension of MLP layers + num_hidden_layers: Number of transformer layers + num_attention_heads: Number of attention heads + num_key_value_heads: Number of key/value heads for grouped query attention + num_experts: Number of expert layers in mixture of experts + num_active_experts: Number of experts activated per token + moe_layer_interval: Interval between MoE layers + head_dim: Dimension of each attention head (calculated automatically) + """ + + name: str = field(default=None, repr=True) + vocab_size: int = 128000 + hidden_size: int = 4096 + intermediate_size: int = 11008 + num_hidden_layers: int = 32 + num_attention_heads: int = field(default=None, repr=True) + num_key_value_heads: int = field(default=None, repr=True) + num_experts: int = 1 + num_active_experts: int = 1 + moe_layer_interval: int = 1 + head_dim: int = field(init=False, repr=False) + + def __post_init__(self): + # Set default attention heads if not provided + if self.num_attention_heads is None: + self.num_attention_heads = self.hidden_size // 128 + + # Calculate head dimension + self.head_dim = self.hidden_size // self.num_attention_heads + + # Set 
default KV heads if not provided + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + detailed_params = self.calc_detailed_parameters() + self.total_params = detailed_params["total_params"] + self.active_params = detailed_params["active_params"] + self.params_per_dense_layer = detailed_params["params_per_dense_layer"] + self.params_per_sparse_layer = detailed_params["params_per_sparse_layer"] + self.active_params_per_sparse_layer = detailed_params[ + "active_params_per_sparse_layer" + ] + self.expert_params_per_sparse_layer = detailed_params[ + "expert_params_per_sparse_layer" + ] + self.non_expert_params_per_sparse_layer = detailed_params[ + "non_expert_params_per_sparse_layer" + ] + + is_moe = self.num_experts > 1 + self.params_per_layer = ( + self.params_per_sparse_layer if is_moe else self.params_per_dense_layer + ) + + # Auto-generate name if not provided + if self.name is None: + if is_moe: + self.name = ( + f"Mixtral-{self.num_experts}x{self.active_params / 1e9:.0f}B" + ) + else: + gqa_suffix = ( + "-GQA" + if self.num_key_value_heads < self.num_attention_heads + else "" + ) + self.name = f"LLaMA-{self.total_params / 1e9:.0f}B{gqa_suffix}" + + def kv_acts_per_layer(self, batch_size: int, seq_len: int) -> float: + """Calculates activations for key-value cache per layer. + + Args: + batch_size: Batch size of input + seq_length: Sequence length of input + + Returns: + float: Number of key-value activations + """ + kv_ratio = self.num_key_value_heads / self.num_attention_heads + return batch_size * seq_len * self.hidden_size * (2 * kv_ratio) + + def acts_per_layer( + self, batch_size: int, seq_len: int, ckpt: str = "no", recompute: bool = False + ) -> float: + """Calculates activations needed per transformer layer. 
+ + Args: + batch_size: Number of sequences in batch + seq_length: Length of each sequence + ckpt: Activation checkpointing strategy ("no", "partial", "partial+fc1", or "full") + recompute: Whether to return recomputed activations instead of stored ones + + Returns: + float: Number of activation elements needed + """ + # Model dimensions + h = self.hidden_size + H = self.intermediate_size + b = batch_size + s = seq_len + kv_ratio = self.num_key_value_heads / self.num_attention_heads + + # If full checkpointing is used, we only need to store layer input + if ckpt == "full" and not recompute: + return b * s * h + + # Calculate individual activation components + attn_input = b * s * h # Input to attention block + attn_norm = b * s * h # Layer norm output (not counted separately) + attn = b * s * h * (2 + 2 * kv_ratio) # QKV projections and output + ffn_input = b * s * h # Input to feed-forward network + ffn_norm = b * s * h # Layer norm output (not counted separately) + fc1 = 2 * b * s * H * self.num_active_experts # First FFN projection for SwiGLU + swiglu = b * s * H * self.num_active_experts # SwiGLU activations + + # Sum all activations + total_acts = attn_input + attn + ffn_input + fc1 + swiglu + + # Return appropriate activations based on checkpointing strategy + if ckpt == "no": + return total_acts if not recompute else 0 + elif ckpt == "partial": + return total_acts - swiglu if not recompute else swiglu + elif ckpt == "partial+fc1": + return total_acts - fc1 - swiglu if not recompute else fc1 + swiglu + elif ckpt == "full": + return attn_input if not recompute else total_acts - attn_input + + # Handle invalid checkpointing strategy + raise ValueError(f"Unknown checkpointing strategy: {ckpt}") + + def tflops(self, batch_size: int, seq_len: int) -> float: + """Calculates total TeraFLOPs for a complete forward and backward pass. 
+ + Args: + batch_size: Batch size of input + seq_length: Sequence length of input + + Returns: + float: Total TFLOPs required + """ + detailed_flops = self.calc_detailed_flops(seq_len) + return 3 * batch_size * detailed_flops["flops_per_forward"] / 1e12 + + def calc_detailed_parameters(self) -> dict: + """Calculate detailed model parameters breakdown. + + Provides a comprehensive breakdown of parameter counts for both standard + transformer models and mixture-of-experts architectures. + + Returns: + dict: Parameter statistics including total, active, and per-layer counts + """ + # Basic model configuration + is_moe = self.num_experts > 1 + h = self.hidden_size + H = self.intermediate_size + V = self.vocab_size + kv_ratio = self.num_key_value_heads / self.num_attention_heads + + # Calculate layer distribution + num_sparse_layers = ( + self.num_hidden_layers // self.moe_layer_interval if is_moe else 0 + ) + num_dense_layers = self.num_hidden_layers - num_sparse_layers + + # Calculate attention parameters (same for both dense and sparse layers) + params_attention = ( + h + h * h * 2 + h * (h * kv_ratio) * 2 + ) # Q/K/V bias + Q,O,K,V projections + + # Calculate dense layer parameters + params_dense_layer = ( + params_attention + h + 3 * h * H + ) # Attention + layer norm + MLP + + # Calculate MoE layer parameters (total and active) + params_sparse_layer = 0 + active_params_sparse_layer = 0 + expert_params_sparse_layer = 0 + non_expert_params_sparse_layer = 0 + + if is_moe: + params_sparse_layer = ( + params_attention + + h + + h * self.num_experts + + 3 * h * H * self.num_experts + ) # Attention + norm + router + MLP + active_params_sparse_layer = ( + params_attention + + h + + h * self.num_experts + + 3 * h * H * self.num_active_experts + ) + expert_params_sparse_layer = 3 * h * H * self.num_experts + non_expert_params_sparse_layer = ( + params_sparse_layer - expert_params_sparse_layer + ) + + # Calculate total parameters + total_params = ( + num_dense_layers * 
params_dense_layer + + num_sparse_layers * params_sparse_layer + + h # Final layer norm + + 2 * h * V # Embedding and output layers + ) + + # Calculate active parameters during forward pass + active_params = ( + num_dense_layers * params_dense_layer + + num_sparse_layers * active_params_sparse_layer + + h # Final layer norm + + 2 * h * V # Embedding and output layers + ) + + return { + "total_params": total_params, + "active_params": active_params, + "params_per_dense_layer": params_dense_layer, + "params_per_sparse_layer": params_sparse_layer, + "active_params_per_sparse_layer": active_params_sparse_layer, + "expert_params_per_sparse_layer": expert_params_sparse_layer, + "non_expert_params_per_sparse_layer": non_expert_params_sparse_layer, + } + + def calc_detailed_flops(self, seq_len: int) -> dict: + """Calculate detailed FLOPs breakdown for model computation. + + Provides FLOP counts for model components during the forward pass. + + Args: + seq_len: Sequence length for FLOP calculations + + Returns: + dict: FLOP counts for dense layers, sparse layers, and complete forward pass + """ + # Basic model configuration + is_moe = self.num_experts > 1 + h = self.hidden_size + H = self.intermediate_size + V = self.vocab_size + s = seq_len + kv_ratio = self.num_key_value_heads / self.num_attention_heads + + # Calculate layer distribution + num_sparse_layers = ( + self.num_hidden_layers // self.moe_layer_interval if is_moe else 0 + ) + num_dense_layers = self.num_hidden_layers - num_sparse_layers + + # Calculate FLOPs for attention mechanism + flops_attention = ( + 4 * s * h * h # Q, K, V projections + + 4 * s * h * (h * kv_ratio) # Attention computations + + 2 * s * s * h # Attention matrix multiplications + ) + + # Calculate FLOPs for dense layer + flops_dense_layer = flops_attention + 6 * s * h * H # MLP operations + + # Calculate FLOPs for MoE layer + flops_sparse_layer = 0 + if is_moe: + flops_sparse_layer = ( + flops_attention + + 2 * s * h * self.num_experts # 
Router + + 6 * s * h * H * self.num_active_experts # MLP operations + ) + + # Calculate total forward pass FLOPs + flops_forward = ( + num_dense_layers * flops_dense_layer + + num_sparse_layers * flops_sparse_layer + + 2 * s * h * V # Embedding and final projection + ) + + return { + "flops_per_dense_layer": flops_dense_layer, + "flops_per_sparse_layer": flops_sparse_layer, + "flops_per_forward": flops_forward, + } + + def __repr__(self) -> str: + """String representation of the model configuration.""" + return ( + f"({self.name}, V={self.vocab_size}, h={self.hidden_size}, " + f"H={self.intermediate_size}, L={self.num_hidden_layers}, " + f"a={self.num_attention_heads}, g={self.num_key_value_heads}, " + f"d={self.head_dim}, e={self.num_experts}, topk={self.num_active_experts})" + ) diff --git a/megatron/pipeline_simulator/hotsim/schedule.py b/megatron/pipeline_simulator/hotsim/schedule.py new file mode 100644 index 00000000000..04cdf6c01a8 --- /dev/null +++ b/megatron/pipeline_simulator/hotsim/schedule.py @@ -0,0 +1,517 @@ +import pprint +from dataclasses import dataclass +from enum import Enum + +import matplotlib.patches as patches +import matplotlib.pyplot as plt +import numpy as np + +pp = pprint.PrettyPrinter() + + +class Op(Enum): + FORWARD = "F" + BACKWARD = "B" + + +@dataclass(frozen=True) +class Chunk: + batch_id: int + chunk_id: int + length: int + # Assume that we always use a batch size of 1 for simplicity + batch_size: int = 1 + + +@dataclass(frozen=True) +class Action: + stage_id: int + op: Op + chunk: Chunk + + +def plot_timeline(actions_by_rank: list[list[Action]]) -> None: + fig = plt.figure(figsize=(16, 4)) + ax = fig.add_axes([0, 0, 1, 1]) + height = len(actions_by_rank) + width = max([len(actions) for actions in actions_by_rank]) + blues = ["#BFDBFE", "#60A5FA", "#2563EB", "#1E40AF"] + greens = ["#BBF7D0", "#34D399", "#16A34A", "#166534"] + for rank, actions in enumerate(actions_by_rank): + for step, action in enumerate(actions): + if action is 
None: + continue + if action.op == Op.FORWARD: + color = blues[action.chunk.batch_id % len(blues)] + else: + color = greens[action.chunk.batch_id % len(greens)] + ax.add_patch( + patches.Rectangle( + (step / width, rank / height), + 1 / width, + 1 / height, + facecolor=color, + edgecolor="black", + ) + ) + if action is not None: + chunk = action.chunk + ax.text( + (step + 0.5) / width, + (rank + 0.5) / height, + f"{chunk.chunk_id}\n{chunk.batch_id}", + fontsize=8, + verticalalignment="center", + horizontalalignment="center", + ) + ax.invert_yaxis() + plt.show() + + +def plot_memory_histogram(histogram: list[int], peak_histogram: list[int]) -> None: + histogram = np.array(histogram) / 1024**3 # Convert to GiB + peak_histogram = np.array(peak_histogram) / 1024**3 # Convert to GiB + + fig, ax = plt.subplots(figsize=(10, 6)) + ax.plot(histogram, label="Memory Usage", color="blue") + ax.plot(peak_histogram, label="Peak Memory Usage", color="red") + ax.set_xlabel("Time Steps") + ax.set_ylabel("Memory Size (GiB)") + ax.set_title("Memory Usage Over Time") + ax.legend() + plt.show() + + +def plot_combined_visualization( + actions_by_rank: list[list[Action]], histogram: list[int], peak_histogram: list[int] +) -> None: + """ + Create a combined visualization with timeline and memory histogram stacked vertically. 
+ + Parameters: + ----------- + actions_by_rank : list[list[Action]] + Actions organized by ranks for timeline visualization + histogram : list[int] + Memory usage over time + peak_histogram : list[int] + Peak memory usage over time + """ + fig, (ax1, ax2) = plt.subplots( + 2, 1, figsize=(16, 10), gridspec_kw={"height_ratios": [1, 1]} + ) + + # Plot timeline on the top subplot + height = len(actions_by_rank) + width = max([len(actions) for actions in actions_by_rank]) + blues = ["#BFDBFE", "#60A5FA", "#2563EB", "#1E40AF"] + greens = ["#BBF7D0", "#34D399", "#16A34A", "#166534"] + + for rank, actions in enumerate(actions_by_rank): + for step, action in enumerate(actions): + if action is None: + continue + if action.op == Op.FORWARD: + color = blues[action.chunk.batch_id % len(blues)] + else: + color = greens[action.chunk.batch_id % len(greens)] + ax1.add_patch( + patches.Rectangle( + (step, rank), + 1, + 1, + facecolor=color, + edgecolor="black", + ) + ) + if action is not None: + chunk = action.chunk + ax1.text( + step + 0.5, + rank + 0.5, + f"{chunk.chunk_id}\n{chunk.batch_id}", + fontsize=8, + verticalalignment="center", + horizontalalignment="center", + ) + + ax1.set_xlim(0, width) + ax1.set_ylim(height, 0) # Invert y-axis + ax1.set_title("Pipeline Execution Timeline") + ax1.set_xlabel("Time Steps") + ax1.set_ylabel("Ranks") + + # Plot memory histogram on the bottom subplot + histogram = np.array(histogram) / 1024**3 # Convert to GiB + peak_histogram = np.array(peak_histogram) / 1024**3 # Convert to GiB + ax2.plot(histogram, label="Memory Usage", color="blue") + ax2.plot(peak_histogram, label="Peak Memory Usage", color="red") + ax2.set_xlabel("Time Steps") + ax2.set_ylabel("Memory Size (GiB)") + ax2.set_title("Memory Usage Over Time") + ax2.legend() + ax2.set_xlim(0, len(histogram)) + ax2.grid(visible=True, which="both", linestyle="--", linewidth=0.5) + + plt.tight_layout() + plt.show() + + +def build_slimpipe_schedule(p, v, chunks_list: list[list[int]]) -> 
list[list[Action]]: + """ + Simulate pipeline parallelism with different configurations. + + This function simulates the behavior of pipeline parallelism in distributed training, + analyzing how data flows through different pipeline stages across multiple ranks. + + Parameters + ---------- + p : int + Number of ranks (processors/GPUs) available for parallel processing. + v : int + Number of pipeline stages in the model. + num_chunks : list[int] + List of micro-batch slice counts for each batch to process. + Each number must be divisible by the number of ranks (p). + + Returns + ------- + None + """ + num_chunks = [len(chunks) for chunks in chunks_list] + + if any(n == 0 for n in num_chunks): + raise ValueError("Number of slices must be greater than 0.") + if any(n % p != 0 for n in num_chunks): + raise ValueError("Number of slices must be divisible by the number of ranks.") + if any(num_chunks[i] > num_chunks[i - 1] for i in range(1, len(num_chunks))): + raise ValueError("Number of slices must be non-increasing.") + + m = len(num_chunks) + forwards = [] + for batch_id in range(m): + n = num_chunks[batch_id] + for start_chunk_id in range(0, n, p): + for stage_id in range(v): + for chunk_id in range(start_chunk_id, start_chunk_id + p): + chunk = Chunk( + batch_id=batch_id, + chunk_id=chunk_id, + length=chunks_list[batch_id][chunk_id], + ) + computation = Action( + stage_id=stage_id, + op=Op.FORWARD, + chunk=chunk, + ) + forwards.append(computation) + + backwards = [] + for batch_id in range(m): + n = num_chunks[batch_id] + for start_chunk_id in range(n - 1, -1, -p): + for stage_id in range(v - 1, -1, -1): + for chunk_id in range(start_chunk_id, start_chunk_id - p, -1): + chunk = Chunk( + batch_id=batch_id, + chunk_id=chunk_id, + length=chunks_list[batch_id][chunk_id], + ) + computation = Action( + stage_id=stage_id, + op=Op.BACKWARD, + chunk=chunk, + ) + backwards.append(computation) + + actions = [] + warmup = num_chunks[0] * v + fwd, bwd = 1 - p, 0 + # While there 
are still backward slices to process on the first rank + while bwd < len(forwards) + p - 1: + if fwd < warmup: + op = Op.FORWARD + elif fwd == len(forwards): + op = Op.BACKWARD + elif fwd - bwd == warmup: + op = Op.BACKWARD + else: + op = Op.FORWARD + + ops = [] + if op == Op.FORWARD: + for rank in range(p): + fwd_idx = fwd + rank + ops.append( + forwards[fwd_idx] + if fwd_idx >= 0 and fwd_idx < len(forwards) + else None + ) + fwd += 1 + else: + for rank in range(p): + bwd_idx = bwd - rank + ops.append( + backwards[bwd_idx] + if bwd_idx >= 0 and bwd_idx < len(backwards) + else None + ) + bwd += 1 + actions.append(ops[::-1]) + + actions_by_rank = [list(row) for row in zip(*actions)] + return actions_by_rank + + +def build_splitfuse_schedule(p, chunks_list: list[list[int]]) -> list[list[Action]]: + forwards = [] + for batch_id, chunks in enumerate(chunks_list): + for chunk_id, size in enumerate(chunks): + chunk = Chunk( + batch_id=batch_id, + chunk_id=chunk_id, + length=size, + ) + action = Action( + stage_id=0, + op=Op.FORWARD, + chunk=chunk, + ) + forwards.append(action) + + backwards = [] + for batch_id, chunks in enumerate(chunks_list): + for chunk_id, size in reversed(list(enumerate(chunks))): + chunk = Chunk( + batch_id=batch_id, + chunk_id=chunk_id, + length=size, + ) + action = Action( + stage_id=0, + op=Op.BACKWARD, + chunk=chunk, + ) + backwards.append(action) + + actions = [] + warmup = len(chunks_list[0]) + fwd, bwd = 1 - p, 0 + # While there are still backward slices to process on the first rank + while bwd < len(forwards) + p - 1: + if fwd < warmup: + op = Op.FORWARD + elif fwd == len(forwards): + op = Op.BACKWARD + elif fwd - bwd == warmup: + op = Op.BACKWARD + else: + op = Op.FORWARD + + ops = [] + if op == Op.FORWARD: + for rank in range(p): + fwd_idx = fwd + rank + if fwd_idx >= 0 and fwd_idx < len(forwards): + action = forwards[fwd_idx] + ops.append( + Action(action.stage_id + p - 1 - rank, action.op, action.chunk) + ) + else: + ops.append(None) + 
fwd += 1 + else: + for rank in range(p): + bwd_idx = bwd - rank + if bwd_idx >= 0 and bwd_idx < len(backwards): + action = backwards[bwd_idx] + ops.append( + Action(action.stage_id + p - 1 - rank, action.op, action.chunk) + ) + else: + ops.append(None) + bwd += 1 + actions.append(ops[::-1]) + + return [list(row) for row in zip(*actions)] + + +def build_hybrid_schedule( + p: int, + k: int, + fwd_switch: tuple[int, int], + bwd_switch: tuple[int, int], + chunks_list: list[list[int]], +) -> list[list[Action]]: + forwards = [] + for microbatch_id, chunks in enumerate(chunks_list): + for chunk_id, size in enumerate(chunks): + chunk = Chunk( + batch_id=microbatch_id, + chunk_id=chunk_id, + length=size, + ) + action = Action( + stage_id=0, + op=Op.FORWARD, + chunk=chunk, + ) + forwards.append(action) + + backwards = [] + for microbatch_id, chunks in enumerate(chunks_list): + for chunk_id, size in reversed(list(enumerate(chunks))): + chunk = Chunk( + batch_id=microbatch_id, + chunk_id=chunk_id, + length=size, + ) + action = Action( + stage_id=0, + op=Op.BACKWARD, + chunk=chunk, + ) + backwards.append(action) + + num_chunks = [len(chunks) for chunks in chunks_list] + actions_by_rank: list[list[Action]] = [] + + for rank in range(p): + fwd_switched, bwd_switched = False, False + actions = [] + slimpipe_warmup = num_chunks[0] + 2 * (p - 1 - rank) + kfkb_warmup = k * (p - rank) + fwd, bwd = 0, 0 + + # Warmup phase + counter = 0 + while counter < ( + kfkb_warmup if fwd_switched else slimpipe_warmup + ) and fwd < len(forwards): + chunk = forwards[fwd].chunk + data_id = (chunk.batch_id, chunk.chunk_id) + if not fwd_switched and data_id == fwd_switch: + fwd_switched = True + continue + cnt = k if fwd_switched else 1 + for _ in range(cnt): + if fwd >= len(forwards): + break + action = forwards[fwd] + actions.append(Action(rank, action.op, action.chunk)) + fwd += 1 + counter += 1 + + # Steady state phase + while fwd < len(forwards): + # Backward + if not bwd_switched: + chunk = 
backwards[bwd].chunk + data_id = (chunk.batch_id, chunk.chunk_id) + if data_id == bwd_switch: + bwd_switched = True + cnt = k if bwd_switched else 1 + for _ in range(cnt): + if bwd >= len(backwards): + break + action = backwards[bwd] + actions.append(Action(rank, action.op, action.chunk)) + bwd += 1 + + # Forward + if not fwd_switched: + chunk = forwards[fwd].chunk + data_id = (chunk.batch_id, chunk.chunk_id) + if data_id == fwd_switch: + fwd_switched = True + cnt = k if fwd_switched else 1 + for _ in range(cnt): + if fwd >= len(forwards): + break + action = forwards[fwd] + actions.append(Action(rank, action.op, action.chunk)) + fwd += 1 + + # Cooldown phase + while bwd < len(backwards): + action = backwards[bwd] + actions.append(Action(rank, action.op, action.chunk)) + bwd += 1 + + actions_by_rank.append(actions) + + return actions_by_rank + + +def build_1f1b_schedule(p, chunks: list[int]) -> list[list[Action]]: + forwards = [] + for batch_id, size in enumerate(chunks): + chunk = Chunk( + batch_id=batch_id, + chunk_id=0, + length=size, + ) + action = Action( + stage_id=0, + op=Op.FORWARD, + chunk=chunk, + ) + forwards.append(action) + + backwards = [] + for batch_id, size in enumerate(chunks): + chunk = Chunk( + batch_id=batch_id, + chunk_id=0, + length=size, + ) + action = Action( + stage_id=0, + op=Op.BACKWARD, + chunk=chunk, + ) + backwards.append(action) + + actions_by_rank: list[list[Action]] = [] + + for rank in range(p): + actions: list[Action] = [] + warmup = min(len(chunks), p - rank) + fwd, bwd = 0, 0 + + while fwd < warmup: + action = forwards[fwd] + actions.append(Action(rank, action.op, action.chunk)) + fwd += 1 + + while fwd < len(forwards): + action = backwards[bwd] + actions.append(Action(rank, action.op, action.chunk)) + bwd += 1 + + action = forwards[fwd] + actions.append(Action(rank, action.op, action.chunk)) + fwd += 1 + + while bwd < len(backwards): + action = backwards[bwd] + actions.append(Action(rank, action.op, action.chunk)) + bwd += 1 + 
+ actions_by_rank.append(actions) + + return actions_by_rank + + +if __name__ == "__main__": + # Example usage + p = 4 # Number of ranks + num_chunks = [8, 6, 4, 1] # Number of slices for each batch + + chunks_list = [] + for num in num_chunks: + chunks_list.append([1] * num) + + actions_by_rank = build_splitfuse_schedule(p, chunks_list) + pp.pprint(actions_by_rank) + plot_timeline(actions_by_rank) diff --git a/megatron/pipeline_simulator/hotsim/training_config.py b/megatron/pipeline_simulator/hotsim/training_config.py new file mode 100644 index 00000000000..067c6f48c1a --- /dev/null +++ b/megatron/pipeline_simulator/hotsim/training_config.py @@ -0,0 +1,82 @@ +from dataclasses import dataclass + +from .model import Model + + +@dataclass +class TrainingConfig: + """Configuration for a distributed training experiment. + + This class calculates memory requirements and efficiency metrics for training + large language models across multiple GPUs using various parallelism techniques. + """ + + # Input configuration + model: Model # Model architecture details + num_gpus: int # Total GPUs available + microbatch_size: int # Batch size per microbatch + + # Parallelism dimensions + tensor_parallel_size: int # Number of tensor parallel groups + context_parallel_size: int # Number of context parallel groups + data_parallel_size: int # Number of data parallel groups + pipeline_parallel_size: int # Number of pipeline stages + expert_parallel_size: int # Number of expert parallel groups + + # Model execution strategy + num_model_chunks: int | None = None # Number of model chunks per pipeline stage + num_layers_per_virtual_stage: int | None = ( + None # Transformer layers per virtual pipeline stage + ) + ckpt: str = "no" # Whether to use activation checkpointing + offload_ratio: float = 0.0 # Ratio of activations to offload + + def __post_init__(self): + # Validate parallelism configuration + if ( + self.tensor_parallel_size + * self.context_parallel_size + * 
self.pipeline_parallel_size + * self.data_parallel_size + != self.num_gpus + ): + raise ValueError( + "Product of parallel dimensions must equal the number of GPUs" + ) + + # Validate model chunking configuration + if self.num_model_chunks is None and self.num_layers_per_virtual_stage is None: + raise ValueError( + "Either num_model_chunks or num_layers_per_virtual_stage must be specified" + ) + + # Compute derived parameters for model chunking + if self.num_layers_per_virtual_stage: + self.num_model_chunks = ( + self.model.num_hidden_layers + // self.pipeline_parallel_size + // self.num_layers_per_virtual_stage + ) + else: + self.num_layers_per_virtual_stage = ( + self.model.num_hidden_layers + // self.pipeline_parallel_size + // self.num_model_chunks + ) + + # Calculate number of transformer layers per pipeline stage + self.num_layers_per_stage = ( + self.model.num_hidden_layers // self.pipeline_parallel_size + ) + + def __repr__(self): + """Generate a human-readable representation of the experiment configuration.""" + return ( + f"#GPUs={self.num_gpus},\n" + f"Model={self.model},\n" + f"B={self.global_batch_size}, b={self.microbatch_size},\n" + f"t={self.tensor_parallel_size}, c={self.context_parallel_size}, " + f"d={self.data_parallel_size}, p={self.pipeline_parallel_size}\n" + f"v={self.num_model_chunks}, l={self.num_layers_per_virtual_stage}\n" + f"ckpt={self.ckpt}, offload={self.offload_ratio:.2%}" + ) diff --git a/megatron/pipeline_simulator/pyproject.toml b/megatron/pipeline_simulator/pyproject.toml new file mode 100644 index 00000000000..8a337ac984a --- /dev/null +++ b/megatron/pipeline_simulator/pyproject.toml @@ -0,0 +1,22 @@ +[project] +name = "pipeline-simulator" +version = "0.2.0" +description = "A pipeline parallelism simulator" +readme = "README.md" +requires-python = ">=3.12" +dependencies = [ + "matplotlib>=3.10.1", + "networkx>=3.4.2", + "numpy>=2.2.4", + "scipy>=1.15.2", +] + +[tool.ruff] +line-length = 88 + +[dependency-groups] +dev = [ + 
"mypy>=1.15.0", + "ruff>=0.11.6", + "types-networkx>=3.4.2.20250319", +] diff --git a/megatron/pipeline_simulator/simulator/__init__.py b/megatron/pipeline_simulator/simulator/__init__.py new file mode 100644 index 00000000000..dff642b17c8 --- /dev/null +++ b/megatron/pipeline_simulator/simulator/__init__.py @@ -0,0 +1,14 @@ +""" +Pipeline Simulator package. + +This package provides tools for simulating and analyzing pipelined processor execution. +""" + +# Import main components to make them available when importing the package +# For example: +# from .simulator import Simulator +# from .pipeline import Pipeline +# from .instruction import Instruction + +__version__ = "0.2.0" +__author__ = "Wei Zhang" diff --git a/megatron/pipeline_simulator/simulator/ir.py b/megatron/pipeline_simulator/simulator/ir.py new file mode 100644 index 00000000000..4dffa4c29cf --- /dev/null +++ b/megatron/pipeline_simulator/simulator/ir.py @@ -0,0 +1,145 @@ +""" +Intermediate representations for pipeline simulation. + +This module defines data structures and enumerations for representing computations and actions in a +pipeline parallel neural network training system. Several data structures (e.g. ActionType) are +copied from Pytorch, please check the license and copyright information in the original repository. + +Classes: + ActionType: Enumeration of different computation/communication types in the pipeline. + Action: Represents a specific action to be performed on a chunk. 
+""" + +import re +from enum import Enum, auto +from typing import NamedTuple, Optional + + +class ActionType(Enum): + """Types of actions that can be performed in the pipeline.""" + + FORWARD = auto() + BACKWARD_INPUT = auto() + BACKWARD_WEIGHT = auto() + SEND_F = auto() + RECV_F = auto() + SEND_B = auto() + RECV_B = auto() + FULL_BACKWARD = auto() + + def __str__(self) -> str: + m = { + FORWARD: "F", + BACKWARD_INPUT: "I", + BACKWARD_WEIGHT: "W", + SEND_F: "SEND_F", + RECV_F: "RECV_F", + SEND_B: "SEND_B", + RECV_B: "RECV_B", + FULL_BACKWARD: "B", + } + return m.get(self, f"Unknown({self.value})") + + @classmethod + def from_str(cls, action: str) -> "ActionType": + """Convert string representation to ActionType.""" + m = { + "F": FORWARD, + "I": BACKWARD_INPUT, + "W": BACKWARD_WEIGHT, + "SEND_F": SEND_F, + "RECV_F": RECV_F, + "SEND_B": SEND_B, + "RECV_B": RECV_B, + "B": FULL_BACKWARD, + } + if action in m: + return m[action] + raise ValueError(f"Invalid action type: {action}") + + +# Global constants for convenience +FORWARD = ActionType.FORWARD +BACKWARD_INPUT = ActionType.BACKWARD_INPUT +BACKWARD_WEIGHT = ActionType.BACKWARD_WEIGHT +SEND_F = ActionType.SEND_F +RECV_F = ActionType.RECV_F +SEND_B = ActionType.SEND_B +RECV_B = ActionType.RECV_B +FULL_BACKWARD = ActionType.FULL_BACKWARD + +# Convenience shorthand for compute actions only +F = FORWARD +I = BACKWARD_INPUT +W = BACKWARD_WEIGHT +B = FULL_BACKWARD + +# Regular expression for parsing action strings +_ACTION_REGEX = re.compile(r"(.+)(F|I|B|W|SEND_F|RECV_F|SEND_B|RECV_B)(.+)") + + +class Action(NamedTuple): + """ + An action to be performed on a chunk. + + Attributes: + stage_id: The id of the stage in the pipeline. + action_type: The type of action to be performed. + data_id: The id of the data chunk. + """ + + stage_id: int + action_type: ActionType + data_id: str | int | tuple[int, ...] 
+ + def __repr__(self) -> str: + return f"{self.stage_id}{self.action_type}{self.data_id}" + + @staticmethod + def from_str(action_string: str) -> Optional["Action"]: + """ + Parse a string representation of an Action. + + Args: + action_string: String formatted as [stage_id][action_type][chunk_id] + e.g. `2F0`, `3SEND_F1` + + Returns: + The parsed Action object, or None if the string is empty. + + Raises: + ValueError: If the action string format is invalid. + """ + action_string = action_string.strip() + if not action_string: + return None + + match = _ACTION_REGEX.match(action_string) + if not match: + raise ValueError( + f"Invalid action string: {action_string}, should be formatted as " + f"[stage_id][action_type][data_id] (e.g. 2F0, 5B(2, 3))." + ) + + stage_id, action_type, data_id = match.groups() + return Action( + stage_id=int(stage_id), + action_type=ActionType.from_str(action_type), + data_id=data_id, + ) + + +class Stats(NamedTuple): + """ + Timing information for an action. + + Attributes: + start_time: The start time of the action. + end_time: The end time of the action. + """ + + start_time: int | float + end_time: int | float + + def __repr__(self) -> str: + return f"({self.start_time}, {self.end_time})" diff --git a/megatron/pipeline_simulator/simulator/parser.py b/megatron/pipeline_simulator/simulator/parser.py new file mode 100644 index 00000000000..a0fe9a15410 --- /dev/null +++ b/megatron/pipeline_simulator/simulator/parser.py @@ -0,0 +1,107 @@ +from copy import deepcopy +from itertools import zip_longest +from typing import Mapping + +from .ir import Action +from .schedules import InterleavedSchedule, SlimPipeSchedule + + +class Parser: + @staticmethod + def print( + pipeline_order: Mapping[int, list[Action | None]], + error_step_number: int | None = None, + ) -> str: + """ + Formats the pipeline order in a timestep (row) x rank (column) grid of actions. 
+ + Args: + pipeline_order: Dictionary mapping ranks to their action sequences + error_step_number: Optional step number to highlight with an error marker + + Returns: + Formatted string representation of the pipeline schedule + """ + # Create a deep copy to avoid mutating the original + pipeline_order = deepcopy(pipeline_order) + + # Replace None values with empty strings + for rank_actions in pipeline_order.values(): + for i, action in enumerate(rank_actions): + if action is None: + rank_actions[i] = "" # type: ignore + + # Calculate dimensions and labels + num_steps = max(len(actions) for actions in pipeline_order.values()) + num_ranks = len(pipeline_order) + + step_labels = [ + f"Step {i:0{len(str(num_steps - 1))}d}" for i in range(num_steps) + ] + rank_labels = [f"Rank {i}" for i in range(num_ranks)] + + # Get actions for each rank in sorted order + rank_actions = [ + pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order) + ] + + # Transpose to get actions by step instead of by rank + transposed_actions = list(zip_longest(*rank_actions, fillvalue="")) + + # Calculate column widths for alignment + max_lengths = [ + max( + len(str(item)) + for item in [rank_labels[i], *[row[i] for row in transposed_actions]] + ) + for i in range(num_ranks) + ] + + # Format the header row + label_width = len(step_labels[0]) + header_row = " " * (label_width + 2) + " ".join( + f"{label:<{max_lengths[i]}}" for i, label in enumerate(rank_labels) + ) + + # Format each row with proper alignment + formatted_rows = [] + for step_num, (label, actions) in enumerate( + zip(step_labels, transposed_actions) + ): + row = f"{label}: " + " ".join( + f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(actions) + ) + + # Add error indicator if needed + if error_step_number is not None and step_num == error_step_number: + row += " <-- ERROR HERE" + + formatted_rows.append(row) + + # Join all rows into the final table + return header_row + "\n" + "\n".join(formatted_rows) 
+ "\n" + + +def test_interleaved(): + p = 4 + v = 2 + chunks_list = [1] * 8 + schedule = InterleavedSchedule(p, v, chunks_list) + order_deps = schedule.order_deps() + pipeline_order = {rank: actions for rank, actions in enumerate(order_deps)} + print(Parser.print(pipeline_order)) + + +def test_slimpipe(): + p = 4 + v = 2 + chunks_list = [[0, 1, 2, 3], [4, 5, 6, 7]] + schedule = SlimPipeSchedule(p, v, chunks_list) + order_deps = schedule.order_deps() + pipeline_order = {rank: actions for rank, actions in enumerate(order_deps)} + print(Parser.print(pipeline_order)) + + +if __name__ == "__main__": + test_interleaved() + test_slimpipe() diff --git a/megatron/pipeline_simulator/simulator/plotter.py b/megatron/pipeline_simulator/simulator/plotter.py new file mode 100644 index 00000000000..9eb51cf027c --- /dev/null +++ b/megatron/pipeline_simulator/simulator/plotter.py @@ -0,0 +1,87 @@ +from typing import Mapping + +import matplotlib.patches as patches +import matplotlib.pyplot as plt +from matplotlib import colormaps + +try: + from .ir import Action, B, F, Stats + from .schedules import AbstractSchedule +except: + from ir import Action, B, F, Stats + from schedules import AbstractSchedule + +class Plotter: + def __init__(self, schedule: AbstractSchedule) -> None: + self.schedule = schedule + + def draw_timeline(self, timeline: Mapping[Action, Stats]) -> None: + """Draw a timeline visualization of the pipeline execution.""" + fig, ax = plt.subplots(figsize=(10, 4), layout="constrained") + + pp = self.schedule.pipeline_parallelism() + stage_map = self.schedule.stage_map() + width = max(timeline[action].end_time for action in timeline) + + # Get distinct colormaps for forward and backward passes + forward_cmap = colormaps["Blues"] + backward_cmap = colormaps["Greens"] + + # Calculate number of colors needed based on virtual stages + num_stages = max(stage_id // pp for stage_id in stage_map.keys()) + 1 + + # Generate color lists with brightness range from 0.4 to 0.8 + blues 
= [ + forward_cmap(0.3 + 0.3 * i / max(1, num_stages - 1)) + for i in range(num_stages) + ] + greens = [ + backward_cmap(0.3 + 0.3 * i / max(1, num_stages - 1)) + for i in range(num_stages) + ] + + for action, stats in timeline.items(): + stage_id, action_type, data_id = action + start_time, end_time = stats + cost = end_time - start_time + + virtual_stage_id = stage_id // pp + if action_type == F: + color = blues[virtual_stage_id % len(blues)] + elif action_type == B: + color = greens[virtual_stage_id % len(greens)] + else: + raise ValueError(f"Invalid action: {action}") + + rank = stage_map[stage_id] + + # Add rectangle for action + ax.add_patch( + patches.Rectangle( + (start_time / width, rank / pp), + cost / width, + 1 / pp, + facecolor=color, + edgecolor="black", + ) + ) + + # Add label + if action_type in (F, B): + label = ( + "\n".join(map(str, data_id)) + if isinstance(data_id, tuple) + else str(data_id) + ) + ax.text( + (start_time + 0.5 * cost) / width, + (rank + 0.5) / pp, + label, + fontsize=8, + verticalalignment="center", + horizontalalignment="center", + ) + + ax.invert_yaxis() + plt.axis("off") + plt.savefig("/m2v_model/wuguohao03/nv_teamwork/Megatron-LM/megatron/pipeline_simulator/simulator/pipeline.png") diff --git a/megatron/pipeline_simulator/simulator/schedules.py b/megatron/pipeline_simulator/simulator/schedules.py new file mode 100644 index 00000000000..f11ed55cf05 --- /dev/null +++ b/megatron/pipeline_simulator/simulator/schedules.py @@ -0,0 +1,822 @@ +from abc import ABC, abstractmethod +from itertools import pairwise +from typing import Mapping + +try: + from .ir import Action, B, F +except: + from ir import Action, B, F + + +class AbstractSchedule(ABC): + @abstractmethod + def pipeline_parallelism(self) -> int: + """ + Returns the number of ranks in the schedule. + """ + pass + + @abstractmethod + def data_deps(self) -> list[list[Action]]: + """ + Returns the data dependencies for the schedule. 
+ """ + pass + + @abstractmethod + def order_deps(self) -> list[list[Action]]: + """ + Returns the order dependencies for the schedule. + """ + pass + + @abstractmethod + def cost_map(self) -> Mapping[Action, int | float]: + """ + Returns the costs associated with each action in the schedule. + """ + pass + + @abstractmethod + def stage_map(self) -> Mapping[int, int]: + """ + Returns the mapping of stages to ranks. + """ + pass + + +class InterleavedSchedule(AbstractSchedule): + def __init__(self, p: int, v: int, microbatches: list[int] | list[float], backward_microbatches=None) -> None: + self.p = p + self.v = v + self.microbatches = microbatches + self.backward_microbatches = None + if backward_microbatches is not None: + self.backward_microbatches = backward_microbatches + + num_microbatches = len(microbatches) + if num_microbatches % self.p != 0 or num_microbatches // self.p <= 1: + raise ValueError( + "Number of microbatches must be divisible by the number of ranks and twice the number of ranks." 
+ ) + + def pipeline_parallelism(self) -> int: + return self.p + + def data_deps(self) -> list[list[Action]]: + num_stages = self.p * self.v + deps = [] + + # Latter stages depend on earlier stages + for microbatch_id in range(len(self.microbatches)): + dep = [] + for stage_id in range(num_stages): + action = Action( + stage_id=stage_id, + action_type=F, + data_id=microbatch_id, + ) + dep.append(action) + for stage_id in reversed(range(num_stages)): + action = Action( + stage_id=stage_id, + action_type=B, + data_id=microbatch_id, + ) + dep.append(action) + deps.append(dep) + + return deps + + def order_deps(self) -> list[list[Action]]: + m, p, v = len(self.microbatches), self.p, self.v + + forwards = [] + for start_microbatch_id in range(0, m, p): + for virtual_stage_id in range(v): + for microbatch_id in range( + start_microbatch_id, start_microbatch_id + p + ): + action = Action( + stage_id=virtual_stage_id * p, + action_type=F, + data_id=microbatch_id, + ) + forwards.append(action) + + backwards = [] + for start_microbatch_id in range(0, m, p): + for virtual_stage_id in range(v - 1, -1, -1): + for microbatch_id in range( + start_microbatch_id, start_microbatch_id + p + ): + action = Action( + stage_id=virtual_stage_id * p, + action_type=B, + data_id=microbatch_id, + ) + backwards.append(action) + + actions_by_rank = [] + for rank in range(p): + actions = [] + warmup = v * p + p - 1 - 2 * rank + fwd, bwd = 0, 0 + while bwd < len(backwards): + if fwd < warmup: + op = F + elif fwd == len(forwards): + op = B + elif fwd - bwd == warmup: + op = B + else: + op = F + + action = forwards[fwd] if op == F else backwards[bwd] + if op == F: + fwd += 1 + else: + bwd += 1 + stage_id, action_type, data_id = action + actions.append(Action(stage_id + rank, action_type, data_id)) + actions_by_rank.append(actions) + + return actions_by_rank + + def cost_map(self) -> Mapping[Action, int | float]: + costs = {} + for microbatch_id, cost in enumerate(self.microbatches): + if 
self.backward_microbatches is not None: + bwd_cost = self.backward_microbatches[microbatch_id] + else: + bwd_cost = 2 * cost + for stage_id in range(self.p * self.v): + fwd = Action( + stage_id=stage_id, + action_type=F, + data_id=microbatch_id, + ) + bwd = Action( + stage_id=stage_id, + action_type=B, + data_id=microbatch_id, + ) + costs[fwd] = cost + costs[bwd] = bwd_cost + return costs + + def stage_map(self) -> Mapping[int, int]: + return {i: i % self.p for i in range(self.p * self.v)} + + +class SlimPipeSchedule(AbstractSchedule): + def __init__( + self, p: int, v: int, chunks_list: list[list[int]] | list[list[float]] + ) -> None: + self.p = p + self.v = v + self.chunks_list = chunks_list + + num_chunks = [len(chunks) for chunks in self.chunks_list] + if any(n == 0 for n in num_chunks): + raise ValueError("Number of slices must be greater than 0.") + if any(n % self.p != 0 for n in num_chunks): + raise ValueError( + "Number of slices must be divisible by the number of ranks." + ) + if any(prev < curr for prev, curr in pairwise(num_chunks)): + raise ValueError("Number of slices must be non-increasing.") + + def pipeline_parallelism(self) -> int: + return self.p + + def data_deps(self) -> list[list[Action]]: + num_stages = self.p * self.v + deps = [] + + # Latter stages depend on earlier stages + for microbatch_id, chunks in enumerate(self.chunks_list): + for chunk_id in range(len(chunks)): + dep = [] + for stage_id in range(num_stages): + action = Action( + stage_id=stage_id, + action_type=F, + data_id=(microbatch_id, chunk_id), + ) + dep.append(action) + for stage_id in reversed(range(num_stages)): + action = Action( + stage_id=stage_id, + action_type=B, + data_id=(microbatch_id, chunk_id), + ) + dep.append(action) + deps.append(dep) + + # Latter chunks depend on earlier chunks + for stage_id in range(num_stages): + for microbatch_id, chunks in enumerate(self.chunks_list): + dep = [] + for chunk_id in range(len(chunks)): + action = Action( + stage_id=stage_id, 
+ action_type=F, + data_id=(microbatch_id, chunk_id), + ) + dep.append(action) + for chunk_id in reversed(range(len(chunks))): + action = Action( + stage_id=stage_id, + action_type=B, + data_id=(microbatch_id, chunk_id), + ) + dep.append(action) + deps.append(dep) + + return deps + + def order_deps(self) -> list[list[Action]]: + p, v = self.p, self.v + + forwards = [] + for microbatch_id, chunks in enumerate(self.chunks_list): + n = len(chunks) + for start_chunk_id in range(0, n, p): + for virtual_stage_id in range(v): + for chunk_id in range(start_chunk_id, start_chunk_id + p): + action = Action( + stage_id=virtual_stage_id * p, + action_type=F, + data_id=(microbatch_id, chunk_id), + ) + forwards.append(action) + + backwards = [] + for microbatch_id, chunks in enumerate(self.chunks_list): + n = len(chunks) + for start_chunk_id in range(n - 1, -1, -p): + for virtual_stage_id in range(v - 1, -1, -1): + for chunk_id in range(start_chunk_id, start_chunk_id - p, -1): + action = Action( + stage_id=virtual_stage_id * p, + action_type=B, + data_id=(microbatch_id, chunk_id), + ) + backwards.append(action) + + actions = [] + warmup = len(self.chunks_list[0]) * v + fwd, bwd = 1 - p, 0 + # While there are still backward slices to process on the first rank + while bwd < len(forwards) + p - 1: + if fwd < warmup: + op = F + elif fwd == len(forwards): + op = B + elif fwd - bwd == warmup: + op = B + else: + op = F + + ops = [] + if op == F: + for rank in range(p): + fwd_idx = fwd + rank + if fwd_idx >= 0 and fwd_idx < len(forwards): + stage_id, action_type, data_id = forwards[fwd_idx] + ops.append( + Action(stage_id + p - 1 - rank, action_type, data_id) + ) + else: + ops.append(None) + fwd += 1 + else: + for rank in range(p): + bwd_idx = bwd - rank + if bwd_idx >= 0 and bwd_idx < len(backwards): + stage_id, action_type, data_id = backwards[bwd_idx] + ops.append( + Action(stage_id + p - 1 - rank, action_type, data_id) + ) + else: + ops.append(None) + bwd += 1 + 
actions.append(ops[::-1]) + + return [list(row) for row in zip(*actions)] + + def cost_map(self) -> Mapping[Action, int | float]: + costs = {} + for microbatch_id, chunks in enumerate(self.chunks_list): + for chunk_id, cost in enumerate(chunks): + for stage_id in range(self.p * self.v): + fwd = Action( + stage_id=stage_id, + action_type=F, + data_id=(microbatch_id, chunk_id), + ) + bwd = Action( + stage_id=stage_id, + action_type=B, + data_id=(microbatch_id, chunk_id), + ) + costs[fwd] = cost + costs[bwd] = 2 * cost + return costs + + def stage_map(self) -> Mapping[int, int]: + return {i: i % self.p for i in range(self.p * self.v)} + + +class SplitFuseSchedule(AbstractSchedule): + def __init__( + self, + p: int, + fwd_costs: list[list[float]], + bwd_costs: list[list[float]], + ) -> None: + self.p = p + self.fwd_costs = fwd_costs + self.bwd_costs = bwd_costs + + if len(fwd_costs) != len(bwd_costs): + raise ValueError("Number of microbatches must be the same for fwd and bwd.") + for fwds, bwds in zip(fwd_costs, bwd_costs): + if len(fwds) != len(bwds): + raise ValueError( + "Number of slices must be the same for the same microbatch." 
+ ) + + num_chunks = [len(chunks) for chunks in self.fwd_costs] + if any(n == 0 for n in num_chunks): + raise ValueError("Number of slices must be greater than 0.") + if any(prev < curr for prev, curr in pairwise(num_chunks)): + raise ValueError("Number of slices must be non-increasing.") + + def pipeline_parallelism(self) -> int: + return self.p + + def data_deps(self) -> list[list[Action]]: + num_stages = self.p + deps = [] + + # Latter stages depend on earlier stages + for microbatch_id, chunks in enumerate(self.fwd_costs): + for chunk_id in range(len(chunks)): + dep = [] + for stage_id in range(num_stages): + action = Action( + stage_id=stage_id, + action_type=F, + data_id=(microbatch_id, chunk_id), + ) + dep.append(action) + for stage_id in reversed(range(num_stages)): + action = Action( + stage_id=stage_id, + action_type=B, + data_id=(microbatch_id, chunk_id), + ) + dep.append(action) + deps.append(dep) + + # Latter chunks depend on earlier chunks + for stage_id in range(num_stages): + for microbatch_id, chunks in enumerate(self.fwd_costs): + dep = [] + for chunk_id in range(len(chunks)): + action = Action( + stage_id=stage_id, + action_type=F, + data_id=(microbatch_id, chunk_id), + ) + dep.append(action) + for chunk_id in reversed(range(len(chunks))): + action = Action( + stage_id=stage_id, + action_type=B, + data_id=(microbatch_id, chunk_id), + ) + dep.append(action) + deps.append(dep) + + return deps + + def order_deps(self) -> list[list[Action]]: + p = self.p + + forwards = [] + for microbatch_id, chunks in enumerate(self.fwd_costs): + for chunk_id, chunk in enumerate(chunks): + action = Action( + stage_id=0, + action_type=F, + data_id=(microbatch_id, chunk_id), + ) + forwards.append(action) + + backwards = [] + for microbatch_id, chunks in enumerate(self.fwd_costs): + for chunk_id, chunk in reversed(list(enumerate(chunks))): + action = Action( + stage_id=0, + action_type=B, + data_id=(microbatch_id, chunk_id), + ) + backwards.append(action) + + actions = 
[] + warmup = len(self.fwd_costs[0]) + fwd, bwd = 1 - p, 0 + # While there are still backward slices to process on the first rank + while bwd < len(forwards) + p - 1: + if fwd < warmup: + op = F + elif fwd == len(forwards): + op = B + elif fwd - bwd == warmup: + op = B + else: + op = F + + ops = [] + if op == F: + for rank in range(p): + fwd_idx = fwd + rank + if fwd_idx >= 0 and fwd_idx < len(forwards): + stage_id, action_type, data_id = forwards[fwd_idx] + ops.append( + Action(stage_id + p - 1 - rank, action_type, data_id) + ) + else: + ops.append(None) + fwd += 1 + else: + for rank in range(p): + bwd_idx = bwd - rank + if bwd_idx >= 0 and bwd_idx < len(backwards): + stage_id, action_type, data_id = backwards[bwd_idx] + ops.append( + Action(stage_id + p - 1 - rank, action_type, data_id) + ) + else: + ops.append(None) + bwd += 1 + actions.append(ops[::-1]) + + actions_by_rank = [list(row) for row in zip(*actions)] + + bwd_fwds = [] + for rank, actions in enumerate(reversed(actions_by_rank)): + i = 0 + while i < len(actions) and (not actions[i] or actions[i].action_type == F): + i += 1 + while i < len(actions) and (not actions[i] or actions[i].action_type == B): + bwd_fwds.append(actions[i]) + i += 1 + while i < len(actions) and (not actions[i] or actions[i].action_type == F): + bwd_fwds.append(actions[i]) + i += 1 + + actions_by_rank.append(bwd_fwds) + + return actions_by_rank + + def cost_map(self) -> Mapping[Action, int | float]: + costs = {} + for microbatch_id, (fwd_chunks, bwd_chunks) in enumerate( + zip(self.fwd_costs, self.bwd_costs) + ): + for chunk_id, (fwd_cost, bwd_cost) in enumerate( + zip(fwd_chunks, bwd_chunks) + ): + for stage_id in range(self.p): + fwd = Action( + stage_id=stage_id, + action_type=F, + data_id=(microbatch_id, chunk_id), + ) + bwd = Action( + stage_id=stage_id, + action_type=B, + data_id=(microbatch_id, chunk_id), + ) + costs[fwd] = fwd_cost + costs[bwd] = bwd_cost + return costs + + def stage_map(self) -> Mapping[int, int]: + return 
{i: i for i in range(self.p)} + + +class kFkBSchedule(AbstractSchedule): + def __init__(self, p: int, k: int, microbatches: list[int] | list[float], backward_microbatches=None) -> None: + self.p = p + self.k = k + self.microbatches = microbatches + self.backward_microbatches = None + if backward_microbatches is not None: + self.backward_microbatches = backward_microbatches + + def pipeline_parallelism(self) -> int: + return self.p + + def data_deps(self) -> list[list[Action]]: + num_stages = self.p + deps = [] + + # Latter stages depend on earlier stages + for microbatch_id in range(len(self.microbatches)): + dep = [] + for stage_id in range(num_stages): + action = Action( + stage_id=stage_id, + action_type=F, + data_id=microbatch_id, + ) + dep.append(action) + for stage_id in reversed(range(num_stages)): + action = Action( + stage_id=stage_id, + action_type=B, + data_id=microbatch_id, + ) + dep.append(action) + deps.append(dep) + + return deps + + def order_deps(self) -> list[list[Action]]: + p, k = self.p, self.k + + forwards = [] + for microbatch_id in range(len(self.microbatches)): + action = Action( + stage_id=0, + action_type=F, + data_id=microbatch_id, + ) + forwards.append(action) + + backwards = [] + for microbatch_id in range(len(self.microbatches)): + action = Action( + stage_id=0, + action_type=B, + data_id=microbatch_id, + ) + backwards.append(action) + + actions_by_rank: list[list[Action]] = [] + for rank in range(p): + actions = [] + warmup = min(k * (p - rank), len(forwards)) + fwd, bwd = 0, 0 + while bwd < len(backwards): + if fwd < warmup: + op = F + elif fwd == len(forwards): + op = B + elif fwd - bwd >= warmup: + op = B + else: + op = F + + for i in range(k): + action = forwards[fwd] if op == F else backwards[bwd] + stage_id, action_type, data_id = action + actions.append(Action(stage_id + rank, action_type, data_id)) + if op == F: + fwd += 1 + else: + bwd += 1 + if fwd >= len(forwards) or bwd >= len(backwards): + break + 
actions_by_rank.append(actions) + + return actions_by_rank + + def cost_map(self) -> Mapping[Action, int | float]: + costs = {} + + microbatch_id_list = list(range(len(self.microbatches))) + for idx, microbatch_id in enumerate(microbatch_id_list): + cost = self.microbatches[idx] + if self.backward_microbatches: + bwd_cost = self.backward_microbatches[idx] + else: + bwd_cost = 2 * cost + for stage_id in range(self.p): + fwd = Action( + stage_id=stage_id, + action_type=F, + data_id=microbatch_id, + ) + bwd = Action( + stage_id=stage_id, + action_type=B, + data_id=microbatch_id, + ) + costs[fwd] = cost + costs[bwd] = bwd_cost + + return costs + + def stage_map(self) -> Mapping[int, int]: + return {i: i for i in range(self.p)} + + +class HybridSlimSchedule(SplitFuseSchedule): + def __init__( + self, + p: int, + k: int, + fwd_switch: tuple[int, int], + bwd_switch: tuple[int, int], + fwd_costs: list[list[float]], + bwd_costs: list[list[float]], + ) -> None: + """ + Initializes the HybridSlimSchedule with the given parameters. + + Args: + p (int): Number of pipeline stages. + k (int): Number of k as in kFkB. + fwd_switch (tuple[int, int]): Starting index of kFkB forward. + bwd_switch (tuple[int, int]): Starting index of kFkB backward. + fwd_costs (list[list[float]]): Forward costs for each microbatch and chunk. + bwd_costs (list[list[float]]): Backward costs for each microbatch and chunk. 
+ """ + super().__init__(p, fwd_costs, bwd_costs) + self.k = k + self.fwd_switch = fwd_switch + self.bwd_switch = bwd_switch + + num_chunks = [len(chunks) for chunks in self.fwd_costs] + if any(n == 0 for n in num_chunks): + raise ValueError("Number of slices must be greater than 0.") + if any(prev < curr for prev, curr in pairwise(num_chunks)): + raise ValueError("Number of slices must be non-increasing.") + + if self.k <= 1: + raise ValueError("k for kFkB schedule must be greater than 1.") + if self.k < num_chunks[0]: + raise ValueError( + "k for kFkB schedule must be greater than or equal to the maximal number of chunks." + ) + + microbatch_id, chunk_id = fwd_switch + if ( + microbatch_id < 0 + or microbatch_id >= len(fwd_costs) + or chunk_id < 0 + or chunk_id >= len(fwd_costs[microbatch_id]) + ): + raise ValueError("Invalid fwd_switch indices.") + + microbatch_id, chunk_id = bwd_switch + if ( + microbatch_id < 0 + or microbatch_id >= len(bwd_costs) + or chunk_id < 0 + or chunk_id >= len(bwd_costs[microbatch_id]) + ): + raise ValueError("Invalid bwd_switch indices.") + + def order_deps(self) -> list[list[Action]]: + p, k = self.p, self.k + + forwards = [] + for microbatch_id, chunks in enumerate(self.fwd_costs): + for chunk_id in range(len(chunks)): + action = Action( + stage_id=0, + action_type=F, + data_id=(microbatch_id, chunk_id), + ) + forwards.append(action) + + backwards = [] + for microbatch_id, chunks in enumerate(self.fwd_costs): + for chunk_id in reversed(range(len(chunks))): + action = Action( + stage_id=0, + action_type=B, + data_id=(microbatch_id, chunk_id), + ) + backwards.append(action) + + num_chunks = [len(chunks) for chunks in self.fwd_costs] + total_chunks = sum(num_chunks) + actions_by_rank: list[list[Action]] = [] + + for rank in range(p): + fwd_switched, bwd_switched = False, False + actions = [] + slimpipe_warmup = num_chunks[0] + 2 * (p - 1 - rank) + kfkb_warmup = k * (p - rank) + fwd, bwd = 0, 0 + + # Warmup phase + counter = 0 + while 
counter < ( + kfkb_warmup if fwd_switched else slimpipe_warmup + ) and fwd < len(forwards): + action = forwards[fwd] + stage_id, action_type, data_id = action + if not fwd_switched and data_id == self.fwd_switch: + fwd_switched = True + continue + cnt = k if fwd_switched else 1 + for _ in range(cnt): + if fwd >= len(forwards): + break + action = forwards[fwd] + stage_id, action_type, data_id = action + actions.append(Action(stage_id + rank, action_type, data_id)) + fwd += 1 + counter += 1 + + # Steady state phase + while fwd < len(forwards): + # Backward + if not bwd_switched: + stage_id, action_type, data_id = backwards[bwd] + if data_id == self.bwd_switch: + bwd_switched = True + cnt = k if bwd_switched else 1 + for _ in range(cnt): + if bwd >= len(backwards): + break + action = backwards[bwd] + stage_id, action_type, data_id = action + actions.append(Action(rank, action_type, data_id)) + bwd += 1 + + # Forward + if not fwd_switched: + stage_id, action_type, data_id = forwards[fwd] + if data_id == self.fwd_switch: + fwd_switched = True + cnt = k if fwd_switched else 1 + for _ in range(cnt): + if fwd >= len(forwards): + break + action = forwards[fwd] + stage_id, action_type, data_id = action + actions.append(Action(rank, action_type, data_id)) + fwd += 1 + + # Cooldown phase + while bwd < len(backwards): + action = backwards[bwd] + stage_id, action_type, data_id = action + actions.append(Action(rank, action_type, data_id)) + bwd += 1 + + actions_by_rank.append(actions) + + bwd_fwds = [] + for rank, actions in enumerate(reversed(actions_by_rank)): + i = 0 + while i < len(actions) and actions[i].action_type == F: + i += 1 + while i < len(actions) and actions[i].action_type == B: + bwd_fwds.append(actions[i]) + i += 1 + while i < len(actions) and actions[i].action_type == F: + bwd_fwds.append(actions[i]) + i += 1 + + actions_by_rank.append(bwd_fwds) + + return actions_by_rank + + +if __name__ == "__main__": + p = 4 + v = 2 + chunks_list = [[2, 3, 4, 5], [4, 5, 6, 7]] 
+ schedule = SlimPipeSchedule(p, v, chunks_list) + order_deps = schedule.order_deps() + data_deps = schedule.data_deps() + + print("\nData Dependencies:") + for rank in range(p): + print(f"Rank {rank}:") + for action in data_deps[rank]: + print(action) + print("\nOrder Dependencies:") + for rank in range(p): + print(f"Rank {rank}:") + for action in order_deps[rank]: + print(action) + print("\nCosts:") + costs = schedule.cost_map() + for action, cost in costs.items(): + print(f"{action}: {cost}") + print("\nStage Mapping:") + stage_mapping = schedule.stage_map() + for stage_id, rank in stage_mapping.items(): + print(f"Stage {stage_id} -> Rank {rank}") diff --git a/megatron/pipeline_simulator/simulator/solver.py b/megatron/pipeline_simulator/simulator/solver.py new file mode 100644 index 00000000000..0bee1d8cc6c --- /dev/null +++ b/megatron/pipeline_simulator/simulator/solver.py @@ -0,0 +1,209 @@ +from typing import Iterable, Mapping + +import matplotlib.pyplot as plt +import networkx as nx + +try: + from .ir import Action, Stats + from .plotter import Plotter + from .schedules import ( + AbstractSchedule, + InterleavedSchedule, + SlimPipeSchedule, + SplitFuseSchedule, + kFkBSchedule, + HybridSlimSchedule, + ) +except: + from ir import Action, Stats + from plotter import Plotter + from schedules import ( + AbstractSchedule, + InterleavedSchedule, + SlimPipeSchedule, + SplitFuseSchedule, + kFkBSchedule, + HybridSlimSchedule, + ) + + +class Solver: + def __init__(self): + self.G = nx.DiGraph() + self.sorted_actions = [] + + def add_deps(self, actions: Iterable[Action]) -> None: + """ + Add dependencies on actions. The actions will be executed in the order provided. + """ + nx.add_path(self.G, actions) + + def add_costs(self, costs: Mapping[Action, float]) -> None: + """ + Add costs to the actions in the graph. + """ + nx.set_node_attributes(self.G, costs, "cost") + + def sort(self) -> list[Action]: + """ + Sort the actions in the graph based on their dependencies. 
+ """ + self.sorted_actions = list(nx.topological_sort(self.G)) + return self.sorted_actions + + def get_sources(self) -> list[Action]: + """ + Get the source actions in the graph (actions with no prerequisites). + """ + return [node for node, in_degree in self.G.in_degree() if in_degree == 0] + + def solve(self) -> dict[Action, Stats]: + """ + Solves the pipeline scheduling problem and returns the makespan. + + Sets start times of source nodes to 0, then computes end times based on costs. + Updates successor nodes to respect precedence constraints. + + Returns: + The makespan of the schedule (maximum end time of any node). + """ + for node in self.get_sources(): + self.G.nodes[node]["start_time"] = 0 + + for u in nx.topological_sort(self.G): + start_time = self.G.nodes[u]["start_time"] + cost = self.G.nodes[u]["cost"] + end_time = start_time + cost + self.G.nodes[u]["end_time"] = end_time + + for v in self.G.successors(u): + v_node = self.G.nodes[v] + if "start_time" not in v_node or v_node["start_time"] < end_time: + v_node["start_time"] = end_time + + timeline = {} + for node in self.G.nodes: + start_time = self.G.nodes[node]["start_time"] + end_time = self.G.nodes[node]["end_time"] + timeline[node] = Stats(start_time, end_time) + + return timeline + + def show(self) -> None: + """Draw the dependency graph using networkx.""" + options = { + "font_size": 8, + "node_size": 16, + "linewidths": 1, + } + nx.draw_networkx( + self.G, with_labels=True, pos=nx.spring_layout(self.G), **options + ) + plt.axis("off") + plt.show() + + +def test_with_schedule(schedule: AbstractSchedule) -> None: + data_deps = schedule.data_deps() + order_deps = schedule.order_deps() + cost_map = schedule.cost_map() + + # Set up and run solver + solver = Solver() + for actions in order_deps: + # Filter out None values and add dependencies for remaining actions + solver.add_deps(filter(None, actions)) + + for actions in data_deps: + solver.add_deps(actions) + + solver.add_costs(cost_map) + 
timeline = solver.solve() + # solver.show() + + # plotter = Plotter(schedule) + # plotter.draw_timeline(timeline) + + # print(f"The makespan is {max(stats.end_time for stats in timeline.values())}") + return max(stats.end_time for stats in timeline.values()) + + +def test_interleaved() -> None: + """Test the solver with an interleaved schedule.""" + p = 4 # Number of pipeline stages + v = 2 # Number of microbatches + microbatches_1 = [2, 6, 8, 4, 3, 4, 4, 2] # Cost of each microbatch + microbatches_2 = [4, 8, 5, 4, 10, 3, 3, 3] # Cost of each microbatch + # microbatches = [(microbatches_1[i] + microbatches_2[i])/2 for i in range(len(microbatches_1))] + + microbatches_balanced = (sum(microbatches_1) + sum(microbatches_2)) // (len(microbatches_1) + len(microbatches_2)) + microbatches = [microbatches_balanced] * len(microbatches_1) + + # microbatches = microbatches_2 + + # microbatches = [3, 1, 2, 4, 8/2, 7/2, 4, 2] # Cost of each microbatch + # microbatches = microbatches.sort() + + # Build the schedule and dependencies + schedule = InterleavedSchedule(p, v, microbatches) + test_with_schedule(schedule) + + +def test_slimpipe(): + """Test the solver with a slimpipe schedule.""" + p = 4 # Number of pipeline stages + v = 2 # Number of microbatches + chunks_list = [[1] * 8, [1] * 4] # Cost of each chunk + + # Build the schedule and dependencies + schedule = SlimPipeSchedule(p, v, chunks_list) + test_with_schedule(schedule) + + +def test_splitfuse(): + """Test the solver with a slimpipe schedule.""" + p = 4 # Number of pipeline stages + fwd_costs = [[1] * 8, [2] * 7, [1.5] * 4, [1] * 3] # Cost of each chunk + bwd_costs = [[2] * 8, [2.5] * 7, [3] * 4, [1.8] * 3] # Cost of each chunk + + # Build the schedule and dependencies + schedule = SplitFuseSchedule(p, fwd_costs, bwd_costs) + test_with_schedule(schedule) + + +def test_kfkb(): + """Test the solver with a kFkB schedule.""" + p = 4 # Number of pipeline stages + k = 2 # Number for kFkB + fwd_costs = [1] * 12 # Cost of 
each chunk + + # Build the schedule and dependencies + schedule = kFkBSchedule(p, k, fwd_costs) + test_with_schedule(schedule) + + +def test_hybrid(): + """Test the solver with a hybrid slimpipe schedule.""" + p = 4 # Number of pipeline stages + k = 2 # Number for kFkB + fwd_costs = [[1, 1] for i in reversed(range(40, 56, 2))] + bwd_costs = [[2, 2] for i in reversed(range(40, 56, 2))] + + # Build the schedule and dependencies + schedule = HybridSlimSchedule( + p, + k, + fwd_switch=(3, 0), + bwd_switch=(3, 1), + fwd_costs=fwd_costs, + bwd_costs=bwd_costs, + ) + test_with_schedule(schedule) + + +if __name__ == "__main__": + test_interleaved() + # test_slimpipe() + # test_splitfuse() + # test_kfkb() + # test_hybrid() diff --git a/megatron/pipeline_simulator/uv.lock b/megatron/pipeline_simulator/uv.lock new file mode 100644 index 00000000000..b59ced20355 --- /dev/null +++ b/megatron/pipeline_simulator/uv.lock @@ -0,0 +1,444 @@ +version = 1 +revision = 1 +requires-python = ">=3.12" + +[[package]] +name = "contourpy" +version = "1.3.2" +source = { registry = "http://mirrors.aliyun.com/pypi/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "http://mirrors.aliyun.com/pypi/packages/66/54/eb9bfc647b19f2009dd5c7f5ec51c4e6ca831725f1aea7a993034f483147/contourpy-1.3.2.tar.gz", hash = "sha256:b6945942715a034c671b7fc54f9588126b0b8bf23db2696e3ca8328f3ff0ab54" } +wheels = [ + { url = "http://mirrors.aliyun.com/pypi/packages/34/f7/44785876384eff370c251d58fd65f6ad7f39adce4a093c934d4a67a7c6b6/contourpy-1.3.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4caf2bcd2969402bf77edc4cb6034c7dd7c0803213b3523f111eb7460a51b8d2" }, + { url = "http://mirrors.aliyun.com/pypi/packages/93/3b/0004767622a9826ea3d95f0e9d98cd8729015768075d61f9fea8eeca42a8/contourpy-1.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:82199cb78276249796419fe36b7386bd8d2cc3f28b3bc19fe2454fe2e26c4c15" }, + { url = 
"http://mirrors.aliyun.com/pypi/packages/e7/bb/7bd49e1f4fa805772d9fd130e0d375554ebc771ed7172f48dfcd4ca61549/contourpy-1.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:106fab697af11456fcba3e352ad50effe493a90f893fca6c2ca5c033820cea92" }, + { url = "http://mirrors.aliyun.com/pypi/packages/fc/97/e1d5dbbfa170725ef78357a9a0edc996b09ae4af170927ba8ce977e60a5f/contourpy-1.3.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d14f12932a8d620e307f715857107b1d1845cc44fdb5da2bc8e850f5ceba9f87" }, + { url = "http://mirrors.aliyun.com/pypi/packages/6f/66/e69e6e904f5ecf6901be3dd16e7e54d41b6ec6ae3405a535286d4418ffb4/contourpy-1.3.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:532fd26e715560721bb0d5fc7610fce279b3699b018600ab999d1be895b09415" }, + { url = "http://mirrors.aliyun.com/pypi/packages/a8/32/b8a1c8965e4f72482ff2d1ac2cd670ce0b542f203c8e1d34e7c3e6925da7/contourpy-1.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f26b383144cf2d2c29f01a1e8170f50dacf0eac02d64139dcd709a8ac4eb3cfe" }, + { url = "http://mirrors.aliyun.com/pypi/packages/30/c6/12a7e6811d08757c7162a541ca4c5c6a34c0f4e98ef2b338791093518e40/contourpy-1.3.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:c49f73e61f1f774650a55d221803b101d966ca0c5a2d6d5e4320ec3997489441" }, + { url = "http://mirrors.aliyun.com/pypi/packages/2a/8a/bebe5a3f68b484d3a2b8ffaf84704b3e343ef1addea528132ef148e22b3b/contourpy-1.3.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3d80b2c0300583228ac98d0a927a1ba6a2ba6b8a742463c564f1d419ee5b211e" }, + { url = "http://mirrors.aliyun.com/pypi/packages/34/db/fcd325f19b5978fb509a7d55e06d99f5f856294c1991097534360b307cf1/contourpy-1.3.2-cp312-cp312-win32.whl", hash = "sha256:90df94c89a91b7362e1142cbee7568f86514412ab8a2c0d0fca72d7e91b62912" }, + { url = 
"http://mirrors.aliyun.com/pypi/packages/01/c8/fadd0b92ffa7b5eb5949bf340a63a4a496a6930a6c37a7ba0f12acb076d6/contourpy-1.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:8c942a01d9163e2e5cfb05cb66110121b8d07ad438a17f9e766317bcb62abf73" }, + { url = "http://mirrors.aliyun.com/pypi/packages/2e/61/5673f7e364b31e4e7ef6f61a4b5121c5f170f941895912f773d95270f3a2/contourpy-1.3.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:de39db2604ae755316cb5967728f4bea92685884b1e767b7c24e983ef5f771cb" }, + { url = "http://mirrors.aliyun.com/pypi/packages/ff/66/a40badddd1223822c95798c55292844b7e871e50f6bfd9f158cb25e0bd39/contourpy-1.3.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3f9e896f447c5c8618f1edb2bafa9a4030f22a575ec418ad70611450720b5b08" }, + { url = "http://mirrors.aliyun.com/pypi/packages/1e/c7/cf9fdee8200805c9bc3b148f49cb9482a4e3ea2719e772602a425c9b09f8/contourpy-1.3.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:71e2bd4a1c4188f5c2b8d274da78faab884b59df20df63c34f74aa1813c4427c" }, + { url = "http://mirrors.aliyun.com/pypi/packages/dd/e7/ccb9bec80e1ba121efbffad7f38021021cda5be87532ec16fd96533bb2e0/contourpy-1.3.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de425af81b6cea33101ae95ece1f696af39446db9682a0b56daaa48cfc29f38f" }, + { url = "http://mirrors.aliyun.com/pypi/packages/dc/49/ca13bb2da90391fa4219fdb23b078d6065ada886658ac7818e5441448b78/contourpy-1.3.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:977e98a0e0480d3fe292246417239d2d45435904afd6d7332d8455981c408b85" }, + { url = "http://mirrors.aliyun.com/pypi/packages/c8/65/5245ce8c548a8422236c13ffcdcdada6a2a812c361e9e0c70548bb40b661/contourpy-1.3.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:434f0adf84911c924519d2b08fc10491dd282b20bdd3fa8f60fd816ea0b48841" }, + { url = 
"http://mirrors.aliyun.com/pypi/packages/72/30/669b8eb48e0a01c660ead3752a25b44fdb2e5ebc13a55782f639170772f9/contourpy-1.3.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c66c4906cdbc50e9cba65978823e6e00b45682eb09adbb78c9775b74eb222422" }, + { url = "http://mirrors.aliyun.com/pypi/packages/05/5a/b569f4250decee6e8d54498be7bdf29021a4c256e77fe8138c8319ef8eb3/contourpy-1.3.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:8b7fc0cd78ba2f4695fd0a6ad81a19e7e3ab825c31b577f384aa9d7817dc3bef" }, + { url = "http://mirrors.aliyun.com/pypi/packages/19/ba/b227c3886d120e60e41b28740ac3617b2f2b971b9f601c835661194579f1/contourpy-1.3.2-cp313-cp313-win32.whl", hash = "sha256:15ce6ab60957ca74cff444fe66d9045c1fd3e92c8936894ebd1f3eef2fff075f" }, + { url = "http://mirrors.aliyun.com/pypi/packages/12/6e/2fed56cd47ca739b43e892707ae9a13790a486a3173be063681ca67d2262/contourpy-1.3.2-cp313-cp313-win_amd64.whl", hash = "sha256:e1578f7eafce927b168752ed7e22646dad6cd9bca673c60bff55889fa236ebf9" }, + { url = "http://mirrors.aliyun.com/pypi/packages/54/4c/e76fe2a03014a7c767d79ea35c86a747e9325537a8b7627e0e5b3ba266b4/contourpy-1.3.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0475b1f6604896bc7c53bb070e355e9321e1bc0d381735421a2d2068ec56531f" }, + { url = "http://mirrors.aliyun.com/pypi/packages/7b/e2/5aba47debd55d668e00baf9651b721e7733975dc9fc27264a62b0dd26eb8/contourpy-1.3.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:c85bb486e9be652314bb5b9e2e3b0d1b2e643d5eec4992c0fbe8ac71775da739" }, + { url = "http://mirrors.aliyun.com/pypi/packages/a1/37/cd45f1f051fe6230f751cc5cdd2728bb3a203f5619510ef11e732109593c/contourpy-1.3.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:745b57db7758f3ffc05a10254edd3182a2a83402a89c00957a8e8a22f5582823" }, + { url = "http://mirrors.aliyun.com/pypi/packages/8b/a2/36ea6140c306c9ff6dd38e3bcec80b3b018474ef4d17eb68ceecd26675f4/contourpy-1.3.2-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:970e9173dbd7eba9b4e01aab19215a48ee5dd3f43cef736eebde064a171f89a5" }, + { url = "http://mirrors.aliyun.com/pypi/packages/95/b7/2fc76bc539693180488f7b6cc518da7acbbb9e3b931fd9280504128bf956/contourpy-1.3.2-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c6c4639a9c22230276b7bffb6a850dfc8258a2521305e1faefe804d006b2e532" }, + { url = "http://mirrors.aliyun.com/pypi/packages/f4/10/76d4f778458b0aa83f96e59d65ece72a060bacb20cfbee46cf6cd5ceba41/contourpy-1.3.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc829960f34ba36aad4302e78eabf3ef16a3a100863f0d4eeddf30e8a485a03b" }, + { url = "http://mirrors.aliyun.com/pypi/packages/43/a3/10cf483ea683f9f8ab096c24bad3cce20e0d1dd9a4baa0e2093c1c962d9d/contourpy-1.3.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:d32530b534e986374fc19eaa77fcb87e8a99e5431499949b828312bdcd20ac52" }, + { url = "http://mirrors.aliyun.com/pypi/packages/78/73/69dd9a024444489e22d86108e7b913f3528f56cfc312b5c5727a44188471/contourpy-1.3.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:e298e7e70cf4eb179cc1077be1c725b5fd131ebc81181bf0c03525c8abc297fd" }, + { url = "http://mirrors.aliyun.com/pypi/packages/0f/1b/96d586ccf1b1a9d2004dd519b25fbf104a11589abfd05484ff12199cca21/contourpy-1.3.2-cp313-cp313t-win32.whl", hash = "sha256:d0e589ae0d55204991450bb5c23f571c64fe43adaa53f93fc902a84c96f52fe1" }, + { url = "http://mirrors.aliyun.com/pypi/packages/b0/e6/6000d0094e8a5e32ad62591c8609e269febb6e4db83a1c75ff8868b42731/contourpy-1.3.2-cp313-cp313t-win_amd64.whl", hash = "sha256:78e9253c3de756b3f6a5174d024c4835acd59eb3f8e2ca13e775dbffe1558f69" }, +] + +[[package]] +name = "cycler" +version = "0.12.1" +source = { registry = "http://mirrors.aliyun.com/pypi/simple" } +sdist = { url = "http://mirrors.aliyun.com/pypi/packages/a9/95/a3dbbb5028f35eafb79008e7522a75244477d2838f38cbb722248dabc2a8/cycler-0.12.1.tar.gz", hash = "sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c" } 
+wheels = [ + { url = "http://mirrors.aliyun.com/pypi/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30" }, +] + +[[package]] +name = "fonttools" +version = "4.57.0" +source = { registry = "http://mirrors.aliyun.com/pypi/simple" } +sdist = { url = "http://mirrors.aliyun.com/pypi/packages/03/2d/a9a0b6e3a0cf6bd502e64fc16d894269011930cabfc89aee20d1635b1441/fonttools-4.57.0.tar.gz", hash = "sha256:727ece10e065be2f9dd239d15dd5d60a66e17eac11aea47d447f9f03fdbc42de" } +wheels = [ + { url = "http://mirrors.aliyun.com/pypi/packages/cb/98/d4bc42d43392982eecaaca117d79845734d675219680cd43070bb001bc1f/fonttools-4.57.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:889e45e976c74abc7256d3064aa7c1295aa283c6bb19810b9f8b604dfe5c7f31" }, + { url = "http://mirrors.aliyun.com/pypi/packages/1a/62/7168030eeca3742fecf45f31e63b5ef48969fa230a672216b805f1d61548/fonttools-4.57.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:0425c2e052a5f1516c94e5855dbda706ae5a768631e9fcc34e57d074d1b65b92" }, + { url = "http://mirrors.aliyun.com/pypi/packages/5d/82/121a26d9646f0986ddb35fbbaf58ef791c25b59ecb63ffea2aab0099044f/fonttools-4.57.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:44c26a311be2ac130f40a96769264809d3b0cb297518669db437d1cc82974888" }, + { url = "http://mirrors.aliyun.com/pypi/packages/5b/26/e0f2fb662e022d565bbe280a3cfe6dafdaabf58889ff86fdef2d31ff1dde/fonttools-4.57.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:84c41ba992df5b8d680b89fd84c6a1f2aca2b9f1ae8a67400c8930cd4ea115f6" }, + { url = "http://mirrors.aliyun.com/pypi/packages/9e/44/9075e323347b1891cdece4b3f10a3b84a8f4c42a7684077429d9ce842056/fonttools-4.57.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ea1e9e43ca56b0c12440a7c689b1350066595bebcaa83baad05b8b2675129d98" }, + { 
url = "http://mirrors.aliyun.com/pypi/packages/48/28/caa8df32743462fb966be6de6a79d7f30393859636d7732e82efa09fbbb4/fonttools-4.57.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:84fd56c78d431606332a0627c16e2a63d243d0d8b05521257d77c6529abe14d8" }, + { url = "http://mirrors.aliyun.com/pypi/packages/f6/46/95ab0f0d2e33c5b1a4fc1c0efe5e286ba9359602c0a9907adb1faca44175/fonttools-4.57.0-cp312-cp312-win32.whl", hash = "sha256:f4376819c1c778d59e0a31db5dc6ede854e9edf28bbfa5b756604727f7f800ac" }, + { url = "http://mirrors.aliyun.com/pypi/packages/06/5d/1be5424bb305880e1113631f49a55ea7c7da3a5fe02608ca7c16a03a21da/fonttools-4.57.0-cp312-cp312-win_amd64.whl", hash = "sha256:57e30241524879ea10cdf79c737037221f77cc126a8cdc8ff2c94d4a522504b9" }, + { url = "http://mirrors.aliyun.com/pypi/packages/e9/2f/11439f3af51e4bb75ac9598c29f8601aa501902dcedf034bdc41f47dd799/fonttools-4.57.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:408ce299696012d503b714778d89aa476f032414ae57e57b42e4b92363e0b8ef" }, + { url = "http://mirrors.aliyun.com/pypi/packages/25/52/677b55a4c0972dc3820c8dba20a29c358197a78229daa2ea219fdb19e5d5/fonttools-4.57.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:bbceffc80aa02d9e8b99f2a7491ed8c4a783b2fc4020119dc405ca14fb5c758c" }, + { url = "http://mirrors.aliyun.com/pypi/packages/64/79/184555f8fa77b827b9460a4acdbbc0b5952bb6915332b84c615c3a236826/fonttools-4.57.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f022601f3ee9e1f6658ed6d184ce27fa5216cee5b82d279e0f0bde5deebece72" }, + { url = "http://mirrors.aliyun.com/pypi/packages/f8/ad/c25116352f456c0d1287545a7aa24e98987b6d99c5b0456c4bd14321f20f/fonttools-4.57.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4dea5893b58d4637ffa925536462ba626f8a1b9ffbe2f5c272cdf2c6ebadb817" }, + { url = 
"http://mirrors.aliyun.com/pypi/packages/53/ae/398b2a833897297797a44f519c9af911c2136eb7aa27d3f1352c6d1129fa/fonttools-4.57.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:dff02c5c8423a657c550b48231d0a48d7e2b2e131088e55983cfe74ccc2c7cc9" }, + { url = "http://mirrors.aliyun.com/pypi/packages/b7/5d/7cb31c4bc9ffb9a2bbe8b08f8f53bad94aeb158efad75da645b40b62cb73/fonttools-4.57.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:767604f244dc17c68d3e2dbf98e038d11a18abc078f2d0f84b6c24571d9c0b13" }, + { url = "http://mirrors.aliyun.com/pypi/packages/4c/e4/6934513ec2c4d3d69ca1bc3bd34d5c69dafcbf68c15388dd3bb062daf345/fonttools-4.57.0-cp313-cp313-win32.whl", hash = "sha256:8e2e12d0d862f43d51e5afb8b9751c77e6bec7d2dc00aad80641364e9df5b199" }, + { url = "http://mirrors.aliyun.com/pypi/packages/c4/0d/2177b7fdd23d017bcfb702fd41e47d4573766b9114da2fddbac20dcc4957/fonttools-4.57.0-cp313-cp313-win_amd64.whl", hash = "sha256:f1d6bc9c23356908db712d282acb3eebd4ae5ec6d8b696aa40342b1d84f8e9e3" }, + { url = "http://mirrors.aliyun.com/pypi/packages/90/27/45f8957c3132917f91aaa56b700bcfc2396be1253f685bd5c68529b6f610/fonttools-4.57.0-py3-none-any.whl", hash = "sha256:3122c604a675513c68bd24c6a8f9091f1c2376d18e8f5fe5a101746c81b3e98f" }, +] + +[[package]] +name = "kiwisolver" +version = "1.4.8" +source = { registry = "http://mirrors.aliyun.com/pypi/simple" } +sdist = { url = "http://mirrors.aliyun.com/pypi/packages/82/59/7c91426a8ac292e1cdd53a63b6d9439abd573c875c3f92c146767dd33faf/kiwisolver-1.4.8.tar.gz", hash = "sha256:23d5f023bdc8c7e54eb65f03ca5d5bb25b601eac4d7f1a042888a1f45237987e" } +wheels = [ + { url = "http://mirrors.aliyun.com/pypi/packages/fc/aa/cea685c4ab647f349c3bc92d2daf7ae34c8e8cf405a6dcd3a497f58a2ac3/kiwisolver-1.4.8-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:d6af5e8815fd02997cb6ad9bbed0ee1e60014438ee1a5c2444c96f87b8843502" }, + { url = 
"http://mirrors.aliyun.com/pypi/packages/c5/0b/8db6d2e2452d60d5ebc4ce4b204feeb16176a851fd42462f66ade6808084/kiwisolver-1.4.8-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:bade438f86e21d91e0cf5dd7c0ed00cda0f77c8c1616bd83f9fc157fa6760d31" }, + { url = "http://mirrors.aliyun.com/pypi/packages/60/26/d6a0db6785dd35d3ba5bf2b2df0aedc5af089962c6eb2cbf67a15b81369e/kiwisolver-1.4.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b83dc6769ddbc57613280118fb4ce3cd08899cc3369f7d0e0fab518a7cf37fdb" }, + { url = "http://mirrors.aliyun.com/pypi/packages/c9/ed/1d97f7e3561e09757a196231edccc1bcf59d55ddccefa2afc9c615abd8e0/kiwisolver-1.4.8-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:111793b232842991be367ed828076b03d96202c19221b5ebab421ce8bcad016f" }, + { url = "http://mirrors.aliyun.com/pypi/packages/29/61/39d30b99954e6b46f760e6289c12fede2ab96a254c443639052d1b573fbc/kiwisolver-1.4.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:257af1622860e51b1a9d0ce387bf5c2c4f36a90594cb9514f55b074bcc787cfc" }, + { url = "http://mirrors.aliyun.com/pypi/packages/0c/3e/804163b932f7603ef256e4a715e5843a9600802bb23a68b4e08c8c0ff61d/kiwisolver-1.4.8-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:69b5637c3f316cab1ec1c9a12b8c5f4750a4c4b71af9157645bf32830e39c03a" }, + { url = "http://mirrors.aliyun.com/pypi/packages/8a/9e/60eaa75169a154700be74f875a4d9961b11ba048bef315fbe89cb6999056/kiwisolver-1.4.8-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:782bb86f245ec18009890e7cb8d13a5ef54dcf2ebe18ed65f795e635a96a1c6a" }, + { url = "http://mirrors.aliyun.com/pypi/packages/bc/b3/9458adb9472e61a998c8c4d95cfdfec91c73c53a375b30b1428310f923e4/kiwisolver-1.4.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc978a80a0db3a66d25767b03688f1147a69e6237175c0f4ffffaaedf744055a" }, + { url = 
"http://mirrors.aliyun.com/pypi/packages/e4/7a/0a42d9571e35798de80aef4bb43a9b672aa7f8e58643d7bd1950398ffb0a/kiwisolver-1.4.8-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:36dbbfd34838500a31f52c9786990d00150860e46cd5041386f217101350f0d3" }, + { url = "http://mirrors.aliyun.com/pypi/packages/d9/07/1255dc8d80271400126ed8db35a1795b1a2c098ac3a72645075d06fe5c5d/kiwisolver-1.4.8-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:eaa973f1e05131de5ff3569bbba7f5fd07ea0595d3870ed4a526d486fe57fa1b" }, + { url = "http://mirrors.aliyun.com/pypi/packages/84/df/5a3b4cf13780ef6f6942df67b138b03b7e79e9f1f08f57c49957d5867f6e/kiwisolver-1.4.8-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:a66f60f8d0c87ab7f59b6fb80e642ebb29fec354a4dfad687ca4092ae69d04f4" }, + { url = "http://mirrors.aliyun.com/pypi/packages/8f/10/2348d068e8b0f635c8c86892788dac7a6b5c0cb12356620ab575775aad89/kiwisolver-1.4.8-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:858416b7fb777a53f0c59ca08190ce24e9abbd3cffa18886a5781b8e3e26f65d" }, + { url = "http://mirrors.aliyun.com/pypi/packages/32/d8/014b89fee5d4dce157d814303b0fce4d31385a2af4c41fed194b173b81ac/kiwisolver-1.4.8-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:085940635c62697391baafaaeabdf3dd7a6c3643577dde337f4d66eba021b2b8" }, + { url = "http://mirrors.aliyun.com/pypi/packages/bd/72/dfff0cc97f2a0776e1c9eb5bef1ddfd45f46246c6533b0191887a427bca5/kiwisolver-1.4.8-cp312-cp312-win_amd64.whl", hash = "sha256:01c3d31902c7db5fb6182832713d3b4122ad9317c2c5877d0539227d96bb2e50" }, + { url = "http://mirrors.aliyun.com/pypi/packages/dc/85/220d13d914485c0948a00f0b9eb419efaf6da81b7d72e88ce2391f7aed8d/kiwisolver-1.4.8-cp312-cp312-win_arm64.whl", hash = "sha256:a3c44cb68861de93f0c4a8175fbaa691f0aa22550c331fefef02b618a9dcb476" }, + { url = "http://mirrors.aliyun.com/pypi/packages/79/b3/e62464a652f4f8cd9006e13d07abad844a47df1e6537f73ddfbf1bc997ec/kiwisolver-1.4.8-cp313-cp313-macosx_10_13_universal2.whl", hash = 
"sha256:1c8ceb754339793c24aee1c9fb2485b5b1f5bb1c2c214ff13368431e51fc9a09" }, + { url = "http://mirrors.aliyun.com/pypi/packages/8d/2d/f13d06998b546a2ad4f48607a146e045bbe48030774de29f90bdc573df15/kiwisolver-1.4.8-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:54a62808ac74b5e55a04a408cda6156f986cefbcf0ada13572696b507cc92fa1" }, + { url = "http://mirrors.aliyun.com/pypi/packages/59/e3/b8bd14b0a54998a9fd1e8da591c60998dc003618cb19a3f94cb233ec1511/kiwisolver-1.4.8-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:68269e60ee4929893aad82666821aaacbd455284124817af45c11e50a4b42e3c" }, + { url = "http://mirrors.aliyun.com/pypi/packages/f0/1c/6c86f6d85ffe4d0ce04228d976f00674f1df5dc893bf2dd4f1928748f187/kiwisolver-1.4.8-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:34d142fba9c464bc3bbfeff15c96eab0e7310343d6aefb62a79d51421fcc5f1b" }, + { url = "http://mirrors.aliyun.com/pypi/packages/4e/b9/1c6e9f6dcb103ac5cf87cb695845f5fa71379021500153566d8a8a9fc291/kiwisolver-1.4.8-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3ddc373e0eef45b59197de815b1b28ef89ae3955e7722cc9710fb91cd77b7f47" }, + { url = "http://mirrors.aliyun.com/pypi/packages/ee/81/aca1eb176de671f8bda479b11acdc42c132b61a2ac861c883907dde6debb/kiwisolver-1.4.8-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:77e6f57a20b9bd4e1e2cedda4d0b986ebd0216236f0106e55c28aea3d3d69b16" }, + { url = "http://mirrors.aliyun.com/pypi/packages/49/f4/e081522473671c97b2687d380e9e4c26f748a86363ce5af48b4a28e48d06/kiwisolver-1.4.8-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:08e77738ed7538f036cd1170cbed942ef749137b1311fa2bbe2a7fda2f6bf3cc" }, + { url = "http://mirrors.aliyun.com/pypi/packages/8f/e9/6a7d025d8da8c4931522922cd706105aa32b3291d1add8c5427cdcd66e63/kiwisolver-1.4.8-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:a5ce1e481a74b44dd5e92ff03ea0cb371ae7a0268318e202be06c8f04f4f1246" }, + { url = "http://mirrors.aliyun.com/pypi/packages/82/13/13fa685ae167bee5d94b415991c4fc7bb0a1b6ebea6e753a87044b209678/kiwisolver-1.4.8-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:fc2ace710ba7c1dfd1a3b42530b62b9ceed115f19a1656adefce7b1782a37794" }, + { url = "http://mirrors.aliyun.com/pypi/packages/ef/92/bb7c9395489b99a6cb41d502d3686bac692586db2045adc19e45ee64ed23/kiwisolver-1.4.8-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:3452046c37c7692bd52b0e752b87954ef86ee2224e624ef7ce6cb21e8c41cc1b" }, + { url = "http://mirrors.aliyun.com/pypi/packages/ed/12/87f0e9271e2b63d35d0d8524954145837dd1a6c15b62a2d8c1ebe0f182b4/kiwisolver-1.4.8-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:7e9a60b50fe8b2ec6f448fe8d81b07e40141bfced7f896309df271a0b92f80f3" }, + { url = "http://mirrors.aliyun.com/pypi/packages/02/6e/c8af39288edbce8bf0fa35dee427b082758a4b71e9c91ef18fa667782138/kiwisolver-1.4.8-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:918139571133f366e8362fa4a297aeba86c7816b7ecf0bc79168080e2bd79957" }, + { url = "http://mirrors.aliyun.com/pypi/packages/13/78/df381bc7b26e535c91469f77f16adcd073beb3e2dd25042efd064af82323/kiwisolver-1.4.8-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e063ef9f89885a1d68dd8b2e18f5ead48653176d10a0e324e3b0030e3a69adeb" }, + { url = "http://mirrors.aliyun.com/pypi/packages/d0/dc/c1abe38c37c071d0fc71c9a474fd0b9ede05d42f5a458d584619cfd2371a/kiwisolver-1.4.8-cp313-cp313-win_amd64.whl", hash = "sha256:a17b7c4f5b2c51bb68ed379defd608a03954a1845dfed7cc0117f1cc8a9b7fd2" }, + { url = "http://mirrors.aliyun.com/pypi/packages/a0/b6/21529d595b126ac298fdd90b705d87d4c5693de60023e0efcb4f387ed99e/kiwisolver-1.4.8-cp313-cp313-win_arm64.whl", hash = "sha256:3cd3bc628b25f74aedc6d374d5babf0166a92ff1317f46267f12d2ed54bc1d30" }, + { url = 
"http://mirrors.aliyun.com/pypi/packages/34/bd/b89380b7298e3af9b39f49334e3e2a4af0e04819789f04b43d560516c0c8/kiwisolver-1.4.8-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:370fd2df41660ed4e26b8c9d6bbcad668fbe2560462cba151a721d49e5b6628c" }, + { url = "http://mirrors.aliyun.com/pypi/packages/83/41/5857dc72e5e4148eaac5aa76e0703e594e4465f8ab7ec0fc60e3a9bb8fea/kiwisolver-1.4.8-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:84a2f830d42707de1d191b9490ac186bf7997a9495d4e9072210a1296345f7dc" }, + { url = "http://mirrors.aliyun.com/pypi/packages/e1/d1/be059b8db56ac270489fb0b3297fd1e53d195ba76e9bbb30e5401fa6b759/kiwisolver-1.4.8-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:7a3ad337add5148cf51ce0b55642dc551c0b9d6248458a757f98796ca7348712" }, + { url = "http://mirrors.aliyun.com/pypi/packages/e1/83/4b73975f149819eb7dcf9299ed467eba068ecb16439a98990dcb12e63fdd/kiwisolver-1.4.8-cp313-cp313t-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7506488470f41169b86d8c9aeff587293f530a23a23a49d6bc64dab66bedc71e" }, + { url = "http://mirrors.aliyun.com/pypi/packages/c7/2c/30a5cdde5102958e602c07466bce058b9d7cb48734aa7a4327261ac8e002/kiwisolver-1.4.8-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2f0121b07b356a22fb0414cec4666bbe36fd6d0d759db3d37228f496ed67c880" }, + { url = "http://mirrors.aliyun.com/pypi/packages/ff/9b/1e71db1c000385aa069704f5990574b8244cce854ecd83119c19e83c9586/kiwisolver-1.4.8-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d6d6bd87df62c27d4185de7c511c6248040afae67028a8a22012b010bc7ad062" }, + { url = "http://mirrors.aliyun.com/pypi/packages/85/92/c8fec52ddf06231b31cbb779af77e99b8253cd96bd135250b9498144c78b/kiwisolver-1.4.8-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:291331973c64bb9cce50bbe871fb2e675c4331dab4f31abe89f175ad7679a4d7" }, + { url = 
"http://mirrors.aliyun.com/pypi/packages/0b/51/9eb7e2cd07a15d8bdd976f6190c0164f92ce1904e5c0c79198c4972926b7/kiwisolver-1.4.8-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:893f5525bb92d3d735878ec00f781b2de998333659507d29ea4466208df37bed" }, + { url = "http://mirrors.aliyun.com/pypi/packages/0f/95/c5a00387a5405e68ba32cc64af65ce881a39b98d73cc394b24143bebc5b8/kiwisolver-1.4.8-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:b47a465040146981dc9db8647981b8cb96366fbc8d452b031e4f8fdffec3f26d" }, + { url = "http://mirrors.aliyun.com/pypi/packages/44/83/eeb7af7d706b8347548313fa3a3a15931f404533cc54fe01f39e830dd231/kiwisolver-1.4.8-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:99cea8b9dd34ff80c521aef46a1dddb0dcc0283cf18bde6d756f1e6f31772165" }, + { url = "http://mirrors.aliyun.com/pypi/packages/05/f9/27e94c1b3eb29e6933b6986ffc5fa1177d2cd1f0c8efc5f02c91c9ac61de/kiwisolver-1.4.8-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:151dffc4865e5fe6dafce5480fab84f950d14566c480c08a53c663a0020504b6" }, + { url = "http://mirrors.aliyun.com/pypi/packages/d9/d4/3c9735faa36ac591a4afcc2980d2691000506050b7a7e80bcfe44048daa7/kiwisolver-1.4.8-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:577facaa411c10421314598b50413aa1ebcf5126f704f1e5d72d7e4e9f020d90" }, + { url = "http://mirrors.aliyun.com/pypi/packages/4c/fa/be89a49c640930180657482a74970cdcf6f7072c8d2471e1babe17a222dc/kiwisolver-1.4.8-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:be4816dc51c8a471749d664161b434912eee82f2ea66bd7628bd14583a833e85" }, +] + +[[package]] +name = "matplotlib" +version = "3.10.1" +source = { registry = "http://mirrors.aliyun.com/pypi/simple" } +dependencies = [ + { name = "contourpy" }, + { name = "cycler" }, + { name = "fonttools" }, + { name = "kiwisolver" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "pillow" }, + { name = "pyparsing" }, + { name = "python-dateutil" }, +] +sdist = { url = 
"http://mirrors.aliyun.com/pypi/packages/2f/08/b89867ecea2e305f408fbb417139a8dd941ecf7b23a2e02157c36da546f0/matplotlib-3.10.1.tar.gz", hash = "sha256:e8d2d0e3881b129268585bf4765ad3ee73a4591d77b9a18c214ac7e3a79fb2ba" } +wheels = [ + { url = "http://mirrors.aliyun.com/pypi/packages/7c/1d/5e0dc3b59c034e43de16f94deb68f4ad8a96b3ea00f4b37c160b7474928e/matplotlib-3.10.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:66e907a06e68cb6cfd652c193311d61a12b54f56809cafbed9736ce5ad92f107" }, + { url = "http://mirrors.aliyun.com/pypi/packages/7a/81/dae7e14042e74da658c3336ab9799128e09a1ee03964f2d89630b5d12106/matplotlib-3.10.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e9b4bb156abb8fa5e5b2b460196f7db7264fc6d62678c03457979e7d5254b7be" }, + { url = "http://mirrors.aliyun.com/pypi/packages/21/c4/22516775dcde10fc9c9571d155f90710761b028fc44f660508106c363c97/matplotlib-3.10.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1985ad3d97f51307a2cbfc801a930f120def19ba22864182dacef55277102ba6" }, + { url = "http://mirrors.aliyun.com/pypi/packages/63/23/c0615001f67ce7c96b3051d856baedc0c818a2ed84570b9bf9bde200f85d/matplotlib-3.10.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c96f2c2f825d1257e437a1482c5a2cf4fee15db4261bd6fc0750f81ba2b4ba3d" }, + { url = "http://mirrors.aliyun.com/pypi/packages/ca/c0/a07939a82aed77770514348f4568177d7dadab9787ebc618a616fe3d665e/matplotlib-3.10.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:35e87384ee9e488d8dd5a2dd7baf471178d38b90618d8ea147aced4ab59c9bea" }, + { url = "http://mirrors.aliyun.com/pypi/packages/a6/b6/a9405484fb40746fdc6ae4502b16a9d6e53282ba5baaf9ebe2da579f68c4/matplotlib-3.10.1-cp312-cp312-win_amd64.whl", hash = "sha256:cfd414bce89cc78a7e1d25202e979b3f1af799e416010a20ab2b5ebb3a02425c" }, + { url = "http://mirrors.aliyun.com/pypi/packages/60/73/6770ff5e5523d00f3bc584acb6031e29ee5c8adc2336b16cd1d003675fe0/matplotlib-3.10.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = 
"sha256:c42eee41e1b60fd83ee3292ed83a97a5f2a8239b10c26715d8a6172226988d7b" }, + { url = "http://mirrors.aliyun.com/pypi/packages/08/97/b0ca5da0ed54a3f6599c3ab568bdda65269bc27c21a2c97868c1625e4554/matplotlib-3.10.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4f0647b17b667ae745c13721602b540f7aadb2a32c5b96e924cd4fea5dcb90f1" }, + { url = "http://mirrors.aliyun.com/pypi/packages/df/9a/1acbdc3b165d4ce2dcd2b1a6d4ffb46a7220ceee960c922c3d50d8514067/matplotlib-3.10.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aa3854b5f9473564ef40a41bc922be978fab217776e9ae1545c9b3a5cf2092a3" }, + { url = "http://mirrors.aliyun.com/pypi/packages/51/d0/2bc4368abf766203e548dc7ab57cf7e9c621f1a3c72b516cc7715347b179/matplotlib-3.10.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e496c01441be4c7d5f96d4e40f7fca06e20dcb40e44c8daa2e740e1757ad9e6" }, + { url = "http://mirrors.aliyun.com/pypi/packages/ab/1b/8b350f8a1746c37ab69dda7d7528d1fc696efb06db6ade9727b7887be16d/matplotlib-3.10.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5d45d3f5245be5b469843450617dcad9af75ca50568acf59997bed9311131a0b" }, + { url = "http://mirrors.aliyun.com/pypi/packages/89/06/f570373d24d93503988ba8d04f213a372fa1ce48381c5eb15da985728498/matplotlib-3.10.1-cp313-cp313-win_amd64.whl", hash = "sha256:8e8e25b1209161d20dfe93037c8a7f7ca796ec9aa326e6e4588d8c4a5dd1e473" }, + { url = "http://mirrors.aliyun.com/pypi/packages/fc/e0/8c811a925b5a7ad75135f0e5af46408b78af88bbb02a1df775100ef9bfef/matplotlib-3.10.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:19b06241ad89c3ae9469e07d77efa87041eac65d78df4fcf9cac318028009b01" }, + { url = "http://mirrors.aliyun.com/pypi/packages/4a/34/319ec2139f68ba26da9d00fce2ff9f27679fb799a6c8e7358539801fd629/matplotlib-3.10.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:01e63101ebb3014e6e9f80d9cf9ee361a8599ddca2c3e166c563628b39305dbb" }, + { url = 
"http://mirrors.aliyun.com/pypi/packages/77/ea/9812124ab9a99df5b2eec1110e9b2edc0b8f77039abf4c56e0a376e84a29/matplotlib-3.10.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f06bad951eea6422ac4e8bdebcf3a70c59ea0a03338c5d2b109f57b64eb3972" }, + { url = "http://mirrors.aliyun.com/pypi/packages/c9/db/b05bf463689134789b06dea85828f8ebe506fa1e37593f723b65b86c9582/matplotlib-3.10.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3dfb036f34873b46978f55e240cff7a239f6c4409eac62d8145bad3fc6ba5a3" }, + { url = "http://mirrors.aliyun.com/pypi/packages/c2/04/41ccec4409f3023a7576df3b5c025f1a8c8b81fbfe922ecfd837ac36e081/matplotlib-3.10.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:dc6ab14a7ab3b4d813b88ba957fc05c79493a037f54e246162033591e770de6f" }, + { url = "http://mirrors.aliyun.com/pypi/packages/ac/c2/0d5aae823bdcc42cc99327ecdd4d28585e15ccd5218c453b7bcd827f3421/matplotlib-3.10.1-cp313-cp313t-win_amd64.whl", hash = "sha256:bc411ebd5889a78dabbc457b3fa153203e22248bfa6eedc6797be5df0164dbf9" }, +] + +[[package]] +name = "mypy" +version = "1.15.0" +source = { registry = "http://mirrors.aliyun.com/pypi/simple" } +dependencies = [ + { name = "mypy-extensions" }, + { name = "typing-extensions" }, +] +sdist = { url = "http://mirrors.aliyun.com/pypi/packages/ce/43/d5e49a86afa64bd3839ea0d5b9c7103487007d728e1293f52525d6d5486a/mypy-1.15.0.tar.gz", hash = "sha256:404534629d51d3efea5c800ee7c42b72a6554d6c400e6a79eafe15d11341fd43" } +wheels = [ + { url = "http://mirrors.aliyun.com/pypi/packages/98/3a/03c74331c5eb8bd025734e04c9840532226775c47a2c39b56a0c8d4f128d/mypy-1.15.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:aea39e0583d05124836ea645f412e88a5c7d0fd77a6d694b60d9b6b2d9f184fd" }, + { url = "http://mirrors.aliyun.com/pypi/packages/f0/1a/41759b18f2cfd568848a37c89030aeb03534411eef981df621d8fad08a1d/mypy-1.15.0-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:2f2147ab812b75e5b5499b01ade1f4a81489a147c01585cda36019102538615f" }, + { url = "http://mirrors.aliyun.com/pypi/packages/12/7e/873481abf1ef112c582db832740f4c11b2bfa510e829d6da29b0ab8c3f9c/mypy-1.15.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ce436f4c6d218a070048ed6a44c0bbb10cd2cc5e272b29e7845f6a2f57ee4464" }, + { url = "http://mirrors.aliyun.com/pypi/packages/b3/d0/92ae4cde706923a2d3f2d6c39629134063ff64b9dedca9c1388363da072d/mypy-1.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8023ff13985661b50a5928fc7a5ca15f3d1affb41e5f0a9952cb68ef090b31ee" }, + { url = "http://mirrors.aliyun.com/pypi/packages/46/8b/df49974b337cce35f828ba6fda228152d6db45fed4c86ba56ffe442434fd/mypy-1.15.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1124a18bc11a6a62887e3e137f37f53fbae476dc36c185d549d4f837a2a6a14e" }, + { url = "http://mirrors.aliyun.com/pypi/packages/13/50/da5203fcf6c53044a0b699939f31075c45ae8a4cadf538a9069b165c1050/mypy-1.15.0-cp312-cp312-win_amd64.whl", hash = "sha256:171a9ca9a40cd1843abeca0e405bc1940cd9b305eaeea2dda769ba096932bb22" }, + { url = "http://mirrors.aliyun.com/pypi/packages/6a/9b/fd2e05d6ffff24d912f150b87db9e364fa8282045c875654ce7e32fffa66/mypy-1.15.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:93faf3fdb04768d44bf28693293f3904bbb555d076b781ad2530214ee53e3445" }, + { url = "http://mirrors.aliyun.com/pypi/packages/74/37/b246d711c28a03ead1fd906bbc7106659aed7c089d55fe40dd58db812628/mypy-1.15.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:811aeccadfb730024c5d3e326b2fbe9249bb7413553f15499a4050f7c30e801d" }, + { url = "http://mirrors.aliyun.com/pypi/packages/a6/ac/395808a92e10cfdac8003c3de9a2ab6dc7cde6c0d2a4df3df1b815ffd067/mypy-1.15.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:98b7b9b9aedb65fe628c62a6dc57f6d5088ef2dfca37903a7d9ee374d03acca5" }, + { url = 
"http://mirrors.aliyun.com/pypi/packages/d2/8b/801aa06445d2de3895f59e476f38f3f8d610ef5d6908245f07d002676cbf/mypy-1.15.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c43a7682e24b4f576d93072216bf56eeff70d9140241f9edec0c104d0c515036" }, + { url = "http://mirrors.aliyun.com/pypi/packages/c7/67/5a4268782eb77344cc613a4cf23540928e41f018a9a1ec4c6882baf20ab8/mypy-1.15.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:baefc32840a9f00babd83251560e0ae1573e2f9d1b067719479bfb0e987c6357" }, + { url = "http://mirrors.aliyun.com/pypi/packages/83/3e/57bb447f7bbbfaabf1712d96f9df142624a386d98fb026a761532526057e/mypy-1.15.0-cp313-cp313-win_amd64.whl", hash = "sha256:b9378e2c00146c44793c98b8d5a61039a048e31f429fb0eb546d93f4b000bedf" }, + { url = "http://mirrors.aliyun.com/pypi/packages/09/4e/a7d65c7322c510de2c409ff3828b03354a7c43f5a8ed458a7a131b41c7b9/mypy-1.15.0-py3-none-any.whl", hash = "sha256:5469affef548bd1895d86d3bf10ce2b44e33d86923c29e4d675b3e323437ea3e" }, +] + +[[package]] +name = "mypy-extensions" +version = "1.0.0" +source = { registry = "http://mirrors.aliyun.com/pypi/simple" } +sdist = { url = "http://mirrors.aliyun.com/pypi/packages/98/a4/1ab47638b92648243faf97a5aeb6ea83059cc3624972ab6b8d2316078d3f/mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782" } +wheels = [ + { url = "http://mirrors.aliyun.com/pypi/packages/2a/e2/5d3f6ada4297caebe1a2add3b126fe800c96f56dbe5d1988a2cbe0b267aa/mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d" }, +] + +[[package]] +name = "networkx" +version = "3.4.2" +source = { registry = "http://mirrors.aliyun.com/pypi/simple" } +sdist = { url = "http://mirrors.aliyun.com/pypi/packages/fd/1d/06475e1cd5264c0b870ea2cc6fdb3e37177c1e565c43f56ff17a10e3937f/networkx-3.4.2.tar.gz", hash = "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1" } +wheels = 
[ + { url = "http://mirrors.aliyun.com/pypi/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl", hash = "sha256:df5d4365b724cf81b8c6a7312509d0c22386097011ad1abe274afd5e9d3bbc5f" }, +] + +[[package]] +name = "numpy" +version = "2.2.4" +source = { registry = "http://mirrors.aliyun.com/pypi/simple" } +sdist = { url = "http://mirrors.aliyun.com/pypi/packages/e1/78/31103410a57bc2c2b93a3597340a8119588571f6a4539067546cb9a0bfac/numpy-2.2.4.tar.gz", hash = "sha256:9ba03692a45d3eef66559efe1d1096c4b9b75c0986b5dff5530c378fb8331d4f" } +wheels = [ + { url = "http://mirrors.aliyun.com/pypi/packages/a2/30/182db21d4f2a95904cec1a6f779479ea1ac07c0647f064dea454ec650c42/numpy-2.2.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a7b9084668aa0f64e64bd00d27ba5146ef1c3a8835f3bd912e7a9e01326804c4" }, + { url = "http://mirrors.aliyun.com/pypi/packages/24/6d/9483566acfbda6c62c6bc74b6e981c777229d2af93c8eb2469b26ac1b7bc/numpy-2.2.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:dbe512c511956b893d2dacd007d955a3f03d555ae05cfa3ff1c1ff6df8851854" }, + { url = "http://mirrors.aliyun.com/pypi/packages/27/f6/dba8a258acbf9d2bed2525cdcbb9493ef9bae5199d7a9cb92ee7e9b2aea6/numpy-2.2.4-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:bb649f8b207ab07caebba230d851b579a3c8711a851d29efe15008e31bb4de24" }, + { url = "http://mirrors.aliyun.com/pypi/packages/62/30/82116199d1c249446723c68f2c9da40d7f062551036f50b8c4caa42ae252/numpy-2.2.4-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:f34dc300df798742b3d06515aa2a0aee20941c13579d7a2f2e10af01ae4901ee" }, + { url = "http://mirrors.aliyun.com/pypi/packages/0e/b2/54122b3c6df5df3e87582b2e9430f1bdb63af4023c739ba300164c9ae503/numpy-2.2.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3f7ac96b16955634e223b579a3e5798df59007ca43e8d451a0e6a50f6bfdfba" }, + { url = 
"http://mirrors.aliyun.com/pypi/packages/02/e2/e2cbb8d634151aab9528ef7b8bab52ee4ab10e076509285602c2a3a686e0/numpy-2.2.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f92084defa704deadd4e0a5ab1dc52d8ac9e8a8ef617f3fbb853e79b0ea3592" }, + { url = "http://mirrors.aliyun.com/pypi/packages/8e/21/efd47800e4affc993e8be50c1b768de038363dd88865920439ef7b422c60/numpy-2.2.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:7a4e84a6283b36632e2a5b56e121961f6542ab886bc9e12f8f9818b3c266bfbb" }, + { url = "http://mirrors.aliyun.com/pypi/packages/04/1e/f8bb88f6157045dd5d9b27ccf433d016981032690969aa5c19e332b138c0/numpy-2.2.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:11c43995255eb4127115956495f43e9343736edb7fcdb0d973defd9de14cd84f" }, + { url = "http://mirrors.aliyun.com/pypi/packages/2b/93/df59a5a3897c1f036ae8ff845e45f4081bb06943039ae28a3c1c7c780f22/numpy-2.2.4-cp312-cp312-win32.whl", hash = "sha256:65ef3468b53269eb5fdb3a5c09508c032b793da03251d5f8722b1194f1790c00" }, + { url = "http://mirrors.aliyun.com/pypi/packages/46/69/8c4f928741c2a8efa255fdc7e9097527c6dc4e4df147e3cadc5d9357ce85/numpy-2.2.4-cp312-cp312-win_amd64.whl", hash = "sha256:2aad3c17ed2ff455b8eaafe06bcdae0062a1db77cb99f4b9cbb5f4ecb13c5146" }, + { url = "http://mirrors.aliyun.com/pypi/packages/2a/d0/bd5ad792e78017f5decfb2ecc947422a3669a34f775679a76317af671ffc/numpy-2.2.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1cf4e5c6a278d620dee9ddeb487dc6a860f9b199eadeecc567f777daace1e9e7" }, + { url = "http://mirrors.aliyun.com/pypi/packages/c3/bc/2b3545766337b95409868f8e62053135bdc7fa2ce630aba983a2aa60b559/numpy-2.2.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1974afec0b479e50438fc3648974268f972e2d908ddb6d7fb634598cdb8260a0" }, + { url = "http://mirrors.aliyun.com/pypi/packages/6a/70/67b24d68a56551d43a6ec9fe8c5f91b526d4c1a46a6387b956bf2d64744e/numpy-2.2.4-cp313-cp313-macosx_14_0_arm64.whl", hash = 
"sha256:79bd5f0a02aa16808fcbc79a9a376a147cc1045f7dfe44c6e7d53fa8b8a79392" }, + { url = "http://mirrors.aliyun.com/pypi/packages/1c/8b/e2fc8a75fcb7be12d90b31477c9356c0cbb44abce7ffb36be39a0017afad/numpy-2.2.4-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:3387dd7232804b341165cedcb90694565a6015433ee076c6754775e85d86f1fc" }, + { url = "http://mirrors.aliyun.com/pypi/packages/13/73/41b7b27f169ecf368b52533edb72e56a133f9e86256e809e169362553b49/numpy-2.2.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f527d8fdb0286fd2fd97a2a96c6be17ba4232da346931d967a0630050dfd298" }, + { url = "http://mirrors.aliyun.com/pypi/packages/4b/04/e208ff3ae3ddfbafc05910f89546382f15a3f10186b1f56bd99f159689c2/numpy-2.2.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bce43e386c16898b91e162e5baaad90c4b06f9dcbe36282490032cec98dc8ae7" }, + { url = "http://mirrors.aliyun.com/pypi/packages/fe/bc/2218160574d862d5e55f803d88ddcad88beff94791f9c5f86d67bd8fbf1c/numpy-2.2.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:31504f970f563d99f71a3512d0c01a645b692b12a63630d6aafa0939e52361e6" }, + { url = "http://mirrors.aliyun.com/pypi/packages/a5/78/97c775bc4f05abc8a8426436b7cb1be806a02a2994b195945600855e3a25/numpy-2.2.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:81413336ef121a6ba746892fad881a83351ee3e1e4011f52e97fba79233611fd" }, + { url = "http://mirrors.aliyun.com/pypi/packages/b9/eb/38c06217a5f6de27dcb41524ca95a44e395e6a1decdc0c99fec0832ce6ae/numpy-2.2.4-cp313-cp313-win32.whl", hash = "sha256:f486038e44caa08dbd97275a9a35a283a8f1d2f0ee60ac260a1790e76660833c" }, + { url = "http://mirrors.aliyun.com/pypi/packages/52/17/d0dd10ab6d125c6d11ffb6dfa3423c3571befab8358d4f85cd4471964fcd/numpy-2.2.4-cp313-cp313-win_amd64.whl", hash = "sha256:207a2b8441cc8b6a2a78c9ddc64d00d20c303d79fba08c577752f080c4007ee3" }, + { url = 
"http://mirrors.aliyun.com/pypi/packages/fa/e2/793288ede17a0fdc921172916efb40f3cbc2aa97e76c5c84aba6dc7e8747/numpy-2.2.4-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:8120575cb4882318c791f839a4fd66161a6fa46f3f0a5e613071aae35b5dd8f8" }, + { url = "http://mirrors.aliyun.com/pypi/packages/3a/75/bb4573f6c462afd1ea5cbedcc362fe3e9bdbcc57aefd37c681be1155fbaa/numpy-2.2.4-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:a761ba0fa886a7bb33c6c8f6f20213735cb19642c580a931c625ee377ee8bd39" }, + { url = "http://mirrors.aliyun.com/pypi/packages/03/68/07b4cd01090ca46c7a336958b413cdbe75002286295f2addea767b7f16c9/numpy-2.2.4-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:ac0280f1ba4a4bfff363a99a6aceed4f8e123f8a9b234c89140f5e894e452ecd" }, + { url = "http://mirrors.aliyun.com/pypi/packages/a5/fd/d4a29478d622fedff5c4b4b4cedfc37a00691079623c0575978d2446db9e/numpy-2.2.4-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:879cf3a9a2b53a4672a168c21375166171bc3932b7e21f622201811c43cdd3b0" }, + { url = "http://mirrors.aliyun.com/pypi/packages/41/78/96dddb75bb9be730b87c72f30ffdd62611aba234e4e460576a068c98eff6/numpy-2.2.4-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f05d4198c1bacc9124018109c5fba2f3201dbe7ab6e92ff100494f236209c960" }, + { url = "http://mirrors.aliyun.com/pypi/packages/00/06/5306b8199bffac2a29d9119c11f457f6c7d41115a335b78d3f86fad4dbe8/numpy-2.2.4-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2f085ce2e813a50dfd0e01fbfc0c12bbe5d2063d99f8b29da30e544fb6483b8" }, + { url = "http://mirrors.aliyun.com/pypi/packages/fa/03/74c5b631ee1ded596945c12027649e6344614144369fd3ec1aaced782882/numpy-2.2.4-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:92bda934a791c01d6d9d8e038363c50918ef7c40601552a58ac84c9613a665bc" }, + { url = "http://mirrors.aliyun.com/pypi/packages/cb/dc/4fc7c0283abe0981e3b89f9b332a134e237dd476b0c018e1e21083310c31/numpy-2.2.4-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = 
"sha256:ee4d528022f4c5ff67332469e10efe06a267e32f4067dc76bb7e2cddf3cd25ff" }, + { url = "http://mirrors.aliyun.com/pypi/packages/e5/2b/878576190c5cfa29ed896b518cc516aecc7c98a919e20706c12480465f43/numpy-2.2.4-cp313-cp313t-win32.whl", hash = "sha256:05c076d531e9998e7e694c36e8b349969c56eadd2cdcd07242958489d79a7286" }, + { url = "http://mirrors.aliyun.com/pypi/packages/3e/05/eb7eec66b95cf697f08c754ef26c3549d03ebd682819f794cb039574a0a6/numpy-2.2.4-cp313-cp313t-win_amd64.whl", hash = "sha256:188dcbca89834cc2e14eb2f106c96d6d46f200fe0200310fc29089657379c58d" }, +] + +[[package]] +name = "packaging" +version = "24.2" +source = { registry = "http://mirrors.aliyun.com/pypi/simple" } +sdist = { url = "http://mirrors.aliyun.com/pypi/packages/d0/63/68dbb6eb2de9cb10ee4c9c14a0148804425e13c4fb20d61cce69f53106da/packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f" } +wheels = [ + { url = "http://mirrors.aliyun.com/pypi/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759" }, +] + +[[package]] +name = "pillow" +version = "11.2.1" +source = { registry = "http://mirrors.aliyun.com/pypi/simple" } +sdist = { url = "http://mirrors.aliyun.com/pypi/packages/af/cb/bb5c01fcd2a69335b86c22142b2bccfc3464087efb7fd382eee5ffc7fdf7/pillow-11.2.1.tar.gz", hash = "sha256:a64dd61998416367b7ef979b73d3a85853ba9bec4c2925f74e588879a58716b6" } +wheels = [ + { url = "http://mirrors.aliyun.com/pypi/packages/c7/40/052610b15a1b8961f52537cc8326ca6a881408bc2bdad0d852edeb6ed33b/pillow-11.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:78afba22027b4accef10dbd5eed84425930ba41b3ea0a86fa8d20baaf19d807f" }, + { url = "http://mirrors.aliyun.com/pypi/packages/e5/7e/b86dbd35a5f938632093dc40d1682874c33dcfe832558fc80ca56bfcb774/pillow-11.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:78092232a4ab376a35d68c4e6d5e00dfd73454bd12b230420025fbe178ee3b0b" }, + { url = "http://mirrors.aliyun.com/pypi/packages/a4/5c/467a161f9ed53e5eab51a42923c33051bf8d1a2af4626ac04f5166e58e0c/pillow-11.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25a5f306095c6780c52e6bbb6109624b95c5b18e40aab1c3041da3e9e0cd3e2d" }, + { url = "http://mirrors.aliyun.com/pypi/packages/62/73/972b7742e38ae0e2ac76ab137ca6005dcf877480da0d9d61d93b613065b4/pillow-11.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c7b29dbd4281923a2bfe562acb734cee96bbb129e96e6972d315ed9f232bef4" }, + { url = "http://mirrors.aliyun.com/pypi/packages/e4/3a/427e4cb0b9e177efbc1a84798ed20498c4f233abde003c06d2650a6d60cb/pillow-11.2.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:3e645b020f3209a0181a418bffe7b4a93171eef6c4ef6cc20980b30bebf17b7d" }, + { url = "http://mirrors.aliyun.com/pypi/packages/fe/7c/d8b1330458e4d2f3f45d9508796d7caf0c0d3764c00c823d10f6f1a3b76d/pillow-11.2.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b2dbea1012ccb784a65349f57bbc93730b96e85b42e9bf7b01ef40443db720b4" }, + { url = "http://mirrors.aliyun.com/pypi/packages/b3/2f/65738384e0b1acf451de5a573d8153fe84103772d139e1e0bdf1596be2ea/pillow-11.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:da3104c57bbd72948d75f6a9389e6727d2ab6333c3617f0a89d72d4940aa0443" }, + { url = "http://mirrors.aliyun.com/pypi/packages/6a/c5/e795c9f2ddf3debb2dedd0df889f2fe4b053308bb59a3cc02a0cd144d641/pillow-11.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:598174aef4589af795f66f9caab87ba4ff860ce08cd5bb447c6fc553ffee603c" }, + { url = "http://mirrors.aliyun.com/pypi/packages/96/ae/ca0099a3995976a9fce2f423166f7bff9b12244afdc7520f6ed38911539a/pillow-11.2.1-cp312-cp312-win32.whl", hash = "sha256:1d535df14716e7f8776b9e7fee118576d65572b4aad3ed639be9e4fa88a1cad3" }, + { url = 
"http://mirrors.aliyun.com/pypi/packages/7c/18/24bff2ad716257fc03da964c5e8f05d9790a779a8895d6566e493ccf0189/pillow-11.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:14e33b28bf17c7a38eede290f77db7c664e4eb01f7869e37fa98a5aa95978941" }, + { url = "http://mirrors.aliyun.com/pypi/packages/da/bb/e8d656c9543276517ee40184aaa39dcb41e683bca121022f9323ae11b39d/pillow-11.2.1-cp312-cp312-win_arm64.whl", hash = "sha256:21e1470ac9e5739ff880c211fc3af01e3ae505859392bf65458c224d0bf283eb" }, + { url = "http://mirrors.aliyun.com/pypi/packages/36/9c/447528ee3776e7ab8897fe33697a7ff3f0475bb490c5ac1456a03dc57956/pillow-11.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:fdec757fea0b793056419bca3e9932eb2b0ceec90ef4813ea4c1e072c389eb28" }, + { url = "http://mirrors.aliyun.com/pypi/packages/b5/09/29d5cd052f7566a63e5b506fac9c60526e9ecc553825551333e1e18a4858/pillow-11.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:b0e130705d568e2f43a17bcbe74d90958e8a16263868a12c3e0d9c8162690830" }, + { url = "http://mirrors.aliyun.com/pypi/packages/71/5d/446ee132ad35e7600652133f9c2840b4799bbd8e4adba881284860da0a36/pillow-11.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7bdb5e09068332578214cadd9c05e3d64d99e0e87591be22a324bdbc18925be0" }, + { url = "http://mirrors.aliyun.com/pypi/packages/69/5f/cbe509c0ddf91cc3a03bbacf40e5c2339c4912d16458fcb797bb47bcb269/pillow-11.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d189ba1bebfbc0c0e529159631ec72bb9e9bc041f01ec6d3233d6d82eb823bc1" }, + { url = "http://mirrors.aliyun.com/pypi/packages/f9/b3/dd4338d8fb8a5f312021f2977fb8198a1184893f9b00b02b75d565c33b51/pillow-11.2.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:191955c55d8a712fab8934a42bfefbf99dd0b5875078240943f913bb66d46d9f" }, + { url = "http://mirrors.aliyun.com/pypi/packages/13/eb/2552ecebc0b887f539111c2cd241f538b8ff5891b8903dfe672e997529be/pillow-11.2.1-cp313-cp313-manylinux_2_28_x86_64.whl", hash = 
"sha256:ad275964d52e2243430472fc5d2c2334b4fc3ff9c16cb0a19254e25efa03a155" }, + { url = "http://mirrors.aliyun.com/pypi/packages/72/d1/924ce51bea494cb6e7959522d69d7b1c7e74f6821d84c63c3dc430cbbf3b/pillow-11.2.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:750f96efe0597382660d8b53e90dd1dd44568a8edb51cb7f9d5d918b80d4de14" }, + { url = "http://mirrors.aliyun.com/pypi/packages/43/ab/8f81312d255d713b99ca37479a4cb4b0f48195e530cdc1611990eb8fd04b/pillow-11.2.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:fe15238d3798788d00716637b3d4e7bb6bde18b26e5d08335a96e88564a36b6b" }, + { url = "http://mirrors.aliyun.com/pypi/packages/94/86/8f2e9d2dc3d308dfd137a07fe1cc478df0a23d42a6c4093b087e738e4827/pillow-11.2.1-cp313-cp313-win32.whl", hash = "sha256:3fe735ced9a607fee4f481423a9c36701a39719252a9bb251679635f99d0f7d2" }, + { url = "http://mirrors.aliyun.com/pypi/packages/6d/ec/1179083b8d6067a613e4d595359b5fdea65d0a3b7ad623fee906e1b3c4d2/pillow-11.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:74ee3d7ecb3f3c05459ba95eed5efa28d6092d751ce9bf20e3e253a4e497e691" }, + { url = "http://mirrors.aliyun.com/pypi/packages/23/f1/2fc1e1e294de897df39fa8622d829b8828ddad938b0eaea256d65b84dd72/pillow-11.2.1-cp313-cp313-win_arm64.whl", hash = "sha256:5119225c622403afb4b44bad4c1ca6c1f98eed79db8d3bc6e4e160fc6339d66c" }, + { url = "http://mirrors.aliyun.com/pypi/packages/c4/3e/c328c48b3f0ead7bab765a84b4977acb29f101d10e4ef57a5e3400447c03/pillow-11.2.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:8ce2e8411c7aaef53e6bb29fe98f28cd4fbd9a1d9be2eeea434331aac0536b22" }, + { url = "http://mirrors.aliyun.com/pypi/packages/18/0e/1c68532d833fc8b9f404d3a642991441d9058eccd5606eab31617f29b6d4/pillow-11.2.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:9ee66787e095127116d91dea2143db65c7bb1e232f617aa5957c0d9d2a3f23a7" }, + { url = 
"http://mirrors.aliyun.com/pypi/packages/b7/cb/6faf3fb1e7705fd2db74e070f3bf6f88693601b0ed8e81049a8266de4754/pillow-11.2.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9622e3b6c1d8b551b6e6f21873bdcc55762b4b2126633014cea1803368a9aa16" }, + { url = "http://mirrors.aliyun.com/pypi/packages/07/94/8be03d50b70ca47fb434a358919d6a8d6580f282bbb7af7e4aa40103461d/pillow-11.2.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63b5dff3a68f371ea06025a1a6966c9a1e1ee452fc8020c2cd0ea41b83e9037b" }, + { url = "http://mirrors.aliyun.com/pypi/packages/fd/a4/bfe78777076dc405e3bd2080bc32da5ab3945b5a25dc5d8acaa9de64a162/pillow-11.2.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:31df6e2d3d8fc99f993fd253e97fae451a8db2e7207acf97859732273e108406" }, + { url = "http://mirrors.aliyun.com/pypi/packages/65/4d/eaf9068dc687c24979e977ce5677e253624bd8b616b286f543f0c1b91662/pillow-11.2.1-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:062b7a42d672c45a70fa1f8b43d1d38ff76b63421cbbe7f88146b39e8a558d91" }, + { url = "http://mirrors.aliyun.com/pypi/packages/1d/26/0fd443365d9c63bc79feb219f97d935cd4b93af28353cba78d8e77b61719/pillow-11.2.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:4eb92eca2711ef8be42fd3f67533765d9fd043b8c80db204f16c8ea62ee1a751" }, + { url = "http://mirrors.aliyun.com/pypi/packages/49/65/dca4d2506be482c2c6641cacdba5c602bc76d8ceb618fd37de855653a419/pillow-11.2.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f91ebf30830a48c825590aede79376cb40f110b387c17ee9bd59932c961044f9" }, + { url = "http://mirrors.aliyun.com/pypi/packages/b3/92/1ca0c3f09233bd7decf8f7105a1c4e3162fb9142128c74adad0fb361b7eb/pillow-11.2.1-cp313-cp313t-win32.whl", hash = "sha256:e0b55f27f584ed623221cfe995c912c61606be8513bfa0e07d2c674b4516d9dd" }, + { url = "http://mirrors.aliyun.com/pypi/packages/a5/ac/77525347cb43b83ae905ffe257bbe2cc6fd23acb9796639a1f56aa59d191/pillow-11.2.1-cp313-cp313t-win_amd64.whl", hash = 
"sha256:36d6b82164c39ce5482f649b437382c0fb2395eabc1e2b1702a6deb8ad647d6e" }, + { url = "http://mirrors.aliyun.com/pypi/packages/67/32/32dc030cfa91ca0fc52baebbba2e009bb001122a1daa8b6a79ad830b38d3/pillow-11.2.1-cp313-cp313t-win_arm64.whl", hash = "sha256:225c832a13326e34f212d2072982bb1adb210e0cc0b153e688743018c94a2681" }, +] + +[[package]] +name = "pipeline-simulator" +version = "0.2.0" +source = { virtual = "." } +dependencies = [ + { name = "matplotlib" }, + { name = "networkx" }, + { name = "numpy" }, + { name = "scipy" }, +] + +[package.dev-dependencies] +dev = [ + { name = "mypy" }, + { name = "ruff" }, + { name = "types-networkx" }, +] + +[package.metadata] +requires-dist = [ + { name = "matplotlib", specifier = ">=3.10.1" }, + { name = "networkx", specifier = ">=3.4.2" }, + { name = "numpy", specifier = ">=2.2.4" }, + { name = "scipy", specifier = ">=1.15.2" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "mypy", specifier = ">=1.15.0" }, + { name = "ruff", specifier = ">=0.11.6" }, + { name = "types-networkx", specifier = ">=3.4.2.20250319" }, +] + +[[package]] +name = "pyparsing" +version = "3.2.3" +source = { registry = "http://mirrors.aliyun.com/pypi/simple" } +sdist = { url = "http://mirrors.aliyun.com/pypi/packages/bb/22/f1129e69d94ffff626bdb5c835506b3a5b4f3d070f17ea295e12c2c6f60f/pyparsing-3.2.3.tar.gz", hash = "sha256:b9c13f1ab8b3b542f72e28f634bad4de758ab3ce4546e4301970ad6fa77c38be" } +wheels = [ + { url = "http://mirrors.aliyun.com/pypi/packages/05/e7/df2285f3d08fee213f2d041540fa4fc9ca6c2d44cf36d3a035bf2a8d2bcc/pyparsing-3.2.3-py3-none-any.whl", hash = "sha256:a749938e02d6fd0b59b356ca504a24982314bb090c383e3cf201c95ef7e2bfcf" }, +] + +[[package]] +name = "python-dateutil" +version = "2.9.0.post0" +source = { registry = "http://mirrors.aliyun.com/pypi/simple" } +dependencies = [ + { name = "six" }, +] +sdist = { url = 
"http://mirrors.aliyun.com/pypi/packages/66/c0/0c8b6ad9f17a802ee498c46e004a0eb49bc148f2fd230864601a86dcf6db/python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3" } +wheels = [ + { url = "http://mirrors.aliyun.com/pypi/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427" }, +] + +[[package]] +name = "ruff" +version = "0.11.6" +source = { registry = "http://mirrors.aliyun.com/pypi/simple" } +sdist = { url = "http://mirrors.aliyun.com/pypi/packages/d9/11/bcef6784c7e5d200b8a1f5c2ddf53e5da0efec37e6e5a44d163fb97e04ba/ruff-0.11.6.tar.gz", hash = "sha256:bec8bcc3ac228a45ccc811e45f7eb61b950dbf4cf31a67fa89352574b01c7d79" } +wheels = [ + { url = "http://mirrors.aliyun.com/pypi/packages/6e/1f/8848b625100ebcc8740c8bac5b5dd8ba97dd4ee210970e98832092c1635b/ruff-0.11.6-py3-none-linux_armv6l.whl", hash = "sha256:d84dcbe74cf9356d1bdb4a78cf74fd47c740bf7bdeb7529068f69b08272239a1" }, + { url = "http://mirrors.aliyun.com/pypi/packages/e0/47/c44036e70c6cc11e6ee24399c2a1e1f1e99be5152bd7dff0190e4b325b76/ruff-0.11.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:9bc583628e1096148011a5d51ff3c836f51899e61112e03e5f2b1573a9b726de" }, + { url = "http://mirrors.aliyun.com/pypi/packages/ed/5b/170444061650202d84d316e8f112de02d092bff71fafe060d3542f5bc5df/ruff-0.11.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:f2959049faeb5ba5e3b378709e9d1bf0cab06528b306b9dd6ebd2a312127964a" }, + { url = "http://mirrors.aliyun.com/pypi/packages/ff/91/f02839fb3787c678e112c8865f2c3e87cfe1744dcc96ff9fc56cfb97dda2/ruff-0.11.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63c5d4e30d9d0de7fedbfb3e9e20d134b73a30c1e74b596f40f0629d5c28a193" }, + { url = 
"http://mirrors.aliyun.com/pypi/packages/9e/f3/c09933306096ff7a08abede3cc2534d6fcf5529ccd26504c16bf363989b5/ruff-0.11.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:26a4b9a4e1439f7d0a091c6763a100cef8fbdc10d68593df6f3cfa5abdd9246e" }, + { url = "http://mirrors.aliyun.com/pypi/packages/e0/0d/a87f8933fccbc0d8c653cfbf44bedda69c9582ba09210a309c066794e2ee/ruff-0.11.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b5edf270223dd622218256569636dc3e708c2cb989242262fe378609eccf1308" }, + { url = "http://mirrors.aliyun.com/pypi/packages/52/7d/8eac0bd083ea8a0b55b7e4628428203441ca68cd55e0b67c135a4bc6e309/ruff-0.11.6-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:f55844e818206a9dd31ff27f91385afb538067e2dc0beb05f82c293ab84f7d55" }, + { url = "http://mirrors.aliyun.com/pypi/packages/c2/dc/d0c17d875662d0c86fadcf4ca014ab2001f867621b793d5d7eef01b9dcce/ruff-0.11.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d8f782286c5ff562e4e00344f954b9320026d8e3fae2ba9e6948443fafd9ffc" }, + { url = "http://mirrors.aliyun.com/pypi/packages/f9/f3/81a1aea17f1065449a72509fc7ccc3659cf93148b136ff2a8291c4bc3ef1/ruff-0.11.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:01c63ba219514271cee955cd0adc26a4083df1956d57847978383b0e50ffd7d2" }, + { url = "http://mirrors.aliyun.com/pypi/packages/61/9f/a3e34de425a668284e7024ee6fd41f452f6fa9d817f1f3495b46e5e3a407/ruff-0.11.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15adac20ef2ca296dd3d8e2bedc6202ea6de81c091a74661c3666e5c4c223ff6" }, + { url = "http://mirrors.aliyun.com/pypi/packages/df/c5/4a57a86d12542c0f6e2744f262257b2aa5a3783098ec14e40f3e4b3a354a/ruff-0.11.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:4dd6b09e98144ad7aec026f5588e493c65057d1b387dd937d7787baa531d9bc2" }, + { url = 
"http://mirrors.aliyun.com/pypi/packages/58/3f/a3b4346dff07ef5b862e2ba06d98fcbf71f66f04cf01d375e871382b5e4b/ruff-0.11.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:45b2e1d6c0eed89c248d024ea95074d0e09988d8e7b1dad8d3ab9a67017a5b03" }, + { url = "http://mirrors.aliyun.com/pypi/packages/93/cc/7ed02e0b86a649216b845b3ac66ed55d8aa86f5898c5f1691797f408fcb9/ruff-0.11.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:bd40de4115b2ec4850302f1a1d8067f42e70b4990b68838ccb9ccd9f110c5e8b" }, + { url = "http://mirrors.aliyun.com/pypi/packages/39/5e/5b09840fef0eff1a6fa1dea6296c07d09c17cb6fb94ed5593aa591b50460/ruff-0.11.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:77cda2dfbac1ab73aef5e514c4cbfc4ec1fbef4b84a44c736cc26f61b3814cd9" }, + { url = "http://mirrors.aliyun.com/pypi/packages/6f/4c/1cd5a84a412d3626335ae69f5f9de2bb554eea0faf46deb1f0cb48534042/ruff-0.11.6-py3-none-win32.whl", hash = "sha256:5151a871554be3036cd6e51d0ec6eef56334d74dfe1702de717a995ee3d5b287" }, + { url = "http://mirrors.aliyun.com/pypi/packages/42/46/8997872bc44d43df986491c18d4418f1caff03bc47b7f381261d62c23442/ruff-0.11.6-py3-none-win_amd64.whl", hash = "sha256:cce85721d09c51f3b782c331b0abd07e9d7d5f775840379c640606d3159cae0e" }, + { url = "http://mirrors.aliyun.com/pypi/packages/d7/6a/65fecd51a9ca19e1477c3879a7fda24f8904174d1275b419422ac00f6eee/ruff-0.11.6-py3-none-win_arm64.whl", hash = "sha256:3567ba0d07fb170b1b48d944715e3294b77f5b7679e8ba258199a250383ccb79" }, +] + +[[package]] +name = "scipy" +version = "1.15.2" +source = { registry = "http://mirrors.aliyun.com/pypi/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "http://mirrors.aliyun.com/pypi/packages/b7/b9/31ba9cd990e626574baf93fbc1ac61cf9ed54faafd04c479117517661637/scipy-1.15.2.tar.gz", hash = "sha256:cd58a314d92838f7e6f755c8a2167ead4f27e1fd5c1251fd54289569ef3495ec" } +wheels = [ + { url = 
"http://mirrors.aliyun.com/pypi/packages/4b/5d/3c78815cbab499610f26b5bae6aed33e227225a9fa5290008a733a64f6fc/scipy-1.15.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c4697a10da8f8765bb7c83e24a470da5797e37041edfd77fd95ba3811a47c4fd" }, + { url = "http://mirrors.aliyun.com/pypi/packages/37/20/3d04eb066b471b6e171827548b9ddb3c21c6bbea72a4d84fc5989933910b/scipy-1.15.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:869269b767d5ee7ea6991ed7e22b3ca1f22de73ab9a49c44bad338b725603301" }, + { url = "http://mirrors.aliyun.com/pypi/packages/a4/98/e5c964526c929ef1f795d4c343b2ff98634ad2051bd2bbadfef9e772e413/scipy-1.15.2-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:bad78d580270a4d32470563ea86c6590b465cb98f83d760ff5b0990cb5518a93" }, + { url = "http://mirrors.aliyun.com/pypi/packages/1d/cd/1dc7371e29195ecbf5222f9afeedb210e0a75057d8afbd942aa6cf8c8eca/scipy-1.15.2-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:b09ae80010f52efddb15551025f9016c910296cf70adbf03ce2a8704f3a5ad20" }, + { url = "http://mirrors.aliyun.com/pypi/packages/f0/24/1a181a9e5050090e0b5138c5f496fee33293c342b788d02586bc410c6477/scipy-1.15.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5a6fd6eac1ce74a9f77a7fc724080d507c5812d61e72bd5e4c489b042455865e" }, + { url = "http://mirrors.aliyun.com/pypi/packages/c0/53/eaada1a414c026673eb983f8b4a55fe5eb172725d33d62c1b21f63ff6ca4/scipy-1.15.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2b871df1fe1a3ba85d90e22742b93584f8d2b8e6124f8372ab15c71b73e428b8" }, + { url = "http://mirrors.aliyun.com/pypi/packages/e9/06/0449b744892ed22b7e7b9a1994a866e64895363572677a316a9042af1fe5/scipy-1.15.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:03205d57a28e18dfd39f0377d5002725bf1f19a46f444108c29bdb246b6c8a11" }, + { url = "http://mirrors.aliyun.com/pypi/packages/6a/6f/a8ac3cfd9505ec695c1bc35edc034d13afbd2fc1882a7c6b473e280397bb/scipy-1.15.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = 
"sha256:601881dfb761311045b03114c5fe718a12634e5608c3b403737ae463c9885d53" }, + { url = "http://mirrors.aliyun.com/pypi/packages/f5/6f/e6e5aff77ea2a48dd96808bb51d7450875af154ee7cbe72188afb0b37929/scipy-1.15.2-cp312-cp312-win_amd64.whl", hash = "sha256:e7c68b6a43259ba0aab737237876e5c2c549a031ddb7abc28c7b47f22e202ded" }, + { url = "http://mirrors.aliyun.com/pypi/packages/53/40/09319f6e0f276ea2754196185f95cd191cb852288440ce035d5c3a931ea2/scipy-1.15.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:01edfac9f0798ad6b46d9c4c9ca0e0ad23dbf0b1eb70e96adb9fa7f525eff0bf" }, + { url = "http://mirrors.aliyun.com/pypi/packages/fe/c3/2854f40ecd19585d65afaef601e5e1f8dbf6758b2f95b5ea93d38655a2c6/scipy-1.15.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:08b57a9336b8e79b305a143c3655cc5bdbe6d5ece3378578888d2afbb51c4e37" }, + { url = "http://mirrors.aliyun.com/pypi/packages/dd/b1/f9fe6e3c828cb5930b5fe74cb479de5f3d66d682fa8adb77249acaf545b8/scipy-1.15.2-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:54c462098484e7466362a9f1672d20888f724911a74c22ae35b61f9c5919183d" }, + { url = "http://mirrors.aliyun.com/pypi/packages/15/9d/a60db8c795700414c3f681908a2b911e031e024d93214f2d23c6dae174ab/scipy-1.15.2-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:cf72ff559a53a6a6d77bd8eefd12a17995ffa44ad86c77a5df96f533d4e6c6bb" }, + { url = "http://mirrors.aliyun.com/pypi/packages/37/3b/9bda92a85cd93f19f9ed90ade84aa1e51657e29988317fabdd44544f1dd4/scipy-1.15.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9de9d1416b3d9e7df9923ab23cd2fe714244af10b763975bea9e4f2e81cebd27" }, + { url = "http://mirrors.aliyun.com/pypi/packages/03/5a/fc34bf1aa14dc7c0e701691fa8685f3faec80e57d816615e3625f28feb43/scipy-1.15.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb530e4794fc8ea76a4a21ccb67dea33e5e0e60f07fc38a49e821e1eae3b71a0" }, + { url = 
"http://mirrors.aliyun.com/pypi/packages/4a/71/472eac45440cee134c8a180dbe4c01b3ec247e0338b7c759e6cd71f199a7/scipy-1.15.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:5ea7ed46d437fc52350b028b1d44e002646e28f3e8ddc714011aaf87330f2f32" }, + { url = "http://mirrors.aliyun.com/pypi/packages/01/b3/21f890f4f42daf20e4d3aaa18182dddb9192771cd47445aaae2e318f6738/scipy-1.15.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:11e7ad32cf184b74380f43d3c0a706f49358b904fa7d5345f16ddf993609184d" }, + { url = "http://mirrors.aliyun.com/pypi/packages/0d/76/77cf2ac1f2a9cc00c073d49e1e16244e389dd88e2490c91d84e1e3e4d126/scipy-1.15.2-cp313-cp313-win_amd64.whl", hash = "sha256:a5080a79dfb9b78b768cebf3c9dcbc7b665c5875793569f48bf0e2b1d7f68f6f" }, + { url = "http://mirrors.aliyun.com/pypi/packages/4c/4b/a57f8ddcf48e129e6054fa9899a2a86d1fc6b07a0e15c7eebff7ca94533f/scipy-1.15.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:447ce30cee6a9d5d1379087c9e474628dab3db4a67484be1b7dc3196bfb2fac9" }, + { url = "http://mirrors.aliyun.com/pypi/packages/0c/43/c304d69a56c91ad5f188c0714f6a97b9c1fed93128c691148621274a3a68/scipy-1.15.2-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:c90ebe8aaa4397eaefa8455a8182b164a6cc1d59ad53f79943f266d99f68687f" }, + { url = "http://mirrors.aliyun.com/pypi/packages/44/1a/6c21b45d2548eb73be9b9bff421aaaa7e85e22c1f9b3bc44b23485dfce0a/scipy-1.15.2-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:def751dd08243934c884a3221156d63e15234a3155cf25978b0a668409d45eb6" }, + { url = "http://mirrors.aliyun.com/pypi/packages/74/4b/aefac4bba80ef815b64f55da06f62f92be5d03b467f2ce3668071799429a/scipy-1.15.2-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:302093e7dfb120e55515936cb55618ee0b895f8bcaf18ff81eca086c17bd80af" }, + { url = "http://mirrors.aliyun.com/pypi/packages/b1/53/1cbb148e6e8f1660aacd9f0a9dfa2b05e9ff1cb54b4386fe868477972ac2/scipy-1.15.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:7cd5b77413e1855351cdde594eca99c1f4a588c2d63711388b6a1f1c01f62274" }, + { url = "http://mirrors.aliyun.com/pypi/packages/2c/23/e0eb7f31a9c13cf2dca083828b97992dd22f8184c6ce4fec5deec0c81fcf/scipy-1.15.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d0194c37037707b2afa7a2f2a924cf7bac3dc292d51b6a925e5fcb89bc5c776" }, + { url = "http://mirrors.aliyun.com/pypi/packages/03/f3/e699e19cabe96bbac5189c04aaa970718f0105cff03d458dc5e2b6bd1e8c/scipy-1.15.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:bae43364d600fdc3ac327db99659dcb79e6e7ecd279a75fe1266669d9a652828" }, + { url = "http://mirrors.aliyun.com/pypi/packages/af/f5/ab3838e56fe5cc22383d6fcf2336e48c8fe33e944b9037fbf6cbdf5a11f8/scipy-1.15.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f031846580d9acccd0044efd1a90e6f4df3a6e12b4b6bd694a7bc03a89892b28" }, + { url = "http://mirrors.aliyun.com/pypi/packages/0a/c8/b3f566db71461cabd4b2d5b39bcc24a7e1c119535c8361f81426be39bb47/scipy-1.15.2-cp313-cp313t-win_amd64.whl", hash = "sha256:fe8a9eb875d430d81755472c5ba75e84acc980e4a8f6204d402849234d3017db" }, +] + +[[package]] +name = "six" +version = "1.17.0" +source = { registry = "http://mirrors.aliyun.com/pypi/simple" } +sdist = { url = "http://mirrors.aliyun.com/pypi/packages/94/e7/b2c673351809dca68a0e064b6af791aa332cf192da575fd474ed7d6f16a2/six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81" } +wheels = [ + { url = "http://mirrors.aliyun.com/pypi/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274" }, +] + +[[package]] +name = "types-networkx" +version = "3.4.2.20250319" +source = { registry = "http://mirrors.aliyun.com/pypi/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = 
"http://mirrors.aliyun.com/pypi/packages/41/ba/ddc3510d544cd919d751876abda8ff9da356d5064e1f9a9e9b3b92863b0c/types_networkx-3.4.2.20250319.tar.gz", hash = "sha256:514b373c21f03fef94d71ca7182026ae73773fe3d4eaaf9139e9bdd78f5361c5" } +wheels = [ + { url = "http://mirrors.aliyun.com/pypi/packages/23/23/13414ead12ee55491c25e75029b5bfce41990bc01071a1e2e07f1c642178/types_networkx-3.4.2.20250319-py3-none-any.whl", hash = "sha256:95266a3b4ee04779638231a8ef4c312e745f3eeab9da242614d03ad0bfbb0cc9" }, +] + +[[package]] +name = "typing-extensions" +version = "4.13.2" +source = { registry = "http://mirrors.aliyun.com/pypi/simple" } +sdist = { url = "http://mirrors.aliyun.com/pypi/packages/f6/37/23083fcd6e35492953e8d2aaaa68b860eb422b34627b13f2ce3eb6106061/typing_extensions-4.13.2.tar.gz", hash = "sha256:e6c81219bd689f51865d9e372991c540bda33a0379d5573cddb9a3a23f7caaef" } +wheels = [ + { url = "http://mirrors.aliyun.com/pypi/packages/8b/54/b1ae86c0973cc6f0210b53d508ca3641fb6d0c56823f288d108bc7ab3cc8/typing_extensions-4.13.2-py3-none-any.whl", hash = "sha256:a439e7c04b49fec3e5d3e2beaa21755cadbbdc391694e28ccdd36ca4a1408f8c" }, +] diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 682bd94bdf9..fd4ed44508d 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -118,6 +118,16 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False): args.rank = int(os.getenv('RANK', '0')) args.world_size = int(os.getenv("WORLD_SIZE", '1')) + # launch from mpi + if int(os.getenv('OMPI_COMM_WORLD_SIZE', '0')) > 0: + args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) + args.local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) + args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) + addr, port = args.master_addr.split(':') + os.environ['MASTER_ADDR'] = addr + os.environ['MASTER_PORT'] = port + delattr(args, 'master_addr') + # Args to disable MSC if not args.enable_msc: MultiStorageClientFeature.disable() @@ -802,7 +812,13 
@@ def validate_args(args, defaults={}): # across batches/microbatches. Due to additional communication overhead # during pipeline parallelism, it should not be set if sequence length # is constant during training. - args.variable_seq_lengths = False + if args.sft_sequence_packing: + args.variable_seq_lengths = True + # TODO(tailaim): add support for other dispatcher types + print(f"Setting moe_token_dispatcher_type to alltoall for sft sequence packing with pipeline parallelism") + args.moe_token_dispatcher_type = "alltoall" + else: + args.variable_seq_lengths = False # Iteration-based training. if args.train_iters: @@ -956,11 +972,27 @@ def validate_args(args, defaults={}): assert args.sequence_parallel == True, 'Tensor parallel communication/GEMM overlap can happen only when sequence parallelism is enabled' if args.hybrid_context_parallel: - assert not args.pipeline_model_parallel_size > 1, 'Hybrid context parallelism not supported with pipeline parallelism' + assert not (args.pipeline_model_parallel_size > 1 and args.use_megatron_fsdp), \ + 'Hybrid context parallelism not supported with pipeline parallelism when using FSDP' assert not args.enable_cuda_graph, 'Hybrid context parallelism not supported with CUDA Graph' - assert not args.use_megatron_fsdp, 'Hybrid context parallelism not supported with Megatron FSDP' assert args.dataloader_type == 'single', 'Hybrid context parallelism only supported with single dataloader type' assert args.calculate_per_token_loss, 'Hybrid context parallelism must be used with --calculate-per-token-loss' + # assert args.context_parallel_size == 1, 'context parallel size must be 1 for hybrid context parallelism' + + if args.sft_sequence_packing: + # Validate that packed sequence buffer is large enough for single sequences + if args.hybrid_context_parallel: + # packed_buffer_size = hdp_size * max_seqlen_per_rank >= single_seq_max_len + hdp_size = args.world_size // (args.tensor_model_parallel_size * args.pipeline_model_parallel_size) 
+ assert hdp_size * args.max_seqlen_per_dp_cp_rank >= args.seq_length, \ + f'Packed sequence buffer size ({hdp_size * args.max_seqlen_per_dp_cp_rank}) ' \ + f'must be >= single sequence max length ({args.seq_length})' + else: + # packed_buffer_size = cp_size * max_seqlen_per_rank >= single_seq_max_len + assert args.context_parallel_size * args.max_seqlen_per_dp_cp_rank >= args.seq_length, \ + f'Packed sequence buffer size ({args.context_parallel_size * args.max_seqlen_per_dp_cp_rank}) ' \ + f'must be >= single sequence max length ({args.seq_length})' + # disable async_tensor_model_parallel_allreduce when # model parallel memory optimization is enabled @@ -1333,6 +1365,8 @@ def core_transformer_config_from_args(args, config_class=None): for f in dataclasses.fields(config_class): if hasattr(args, f.name): kw_args[f.name] = getattr(args, f.name) + kw_args['vocab_size'] = args.vocab_size + kw_args['min_hybrid_context_parallel_size'] = args.min_hybrid_context_parallel_size kw_args['persist_layer_norm'] = not args.no_persist_layer_norm kw_args['layernorm_zero_centered_gamma'] = args.apply_layernorm_1p kw_args['layernorm_epsilon'] = args.norm_epsilon @@ -1946,6 +1980,10 @@ def _add_logging_args(parser): ' max: report the max timing across all ranks' ' minmax: report min and max timings across all ranks' ' all: report timings of all ranks.') + group.add_argument("--use-gpu-timer", action='store_true', default=False, + help='Enable GPU timer.') + group.add_argument('--gpu-timer-interval', type=int, default=100, + help='Number of iterations to run for time record for gpu timer.') group.add_argument('--tensorboard-log-interval', type=int, default=1, help='Report to tensorboard interval.') group.add_argument('--tensorboard-queue-size', type=int, default=1000, @@ -2211,10 +2249,14 @@ def _add_training_args(parser): '-o --force-overwrite true ' '--capture-range=cudaProfilerApi ' '--capture-range-end=stop`.') - group.add_argument('--profile-step-start', type=int, default=10, + 
group.add_argument('--profile-step-start', type=int, default=4, help='Global step to start profiling.') - group.add_argument('--profile-step-end', type=int, default=12, + group.add_argument('--profile-step-end', type=int, default=6, help='Global step to stop profiling.') + group.add_argument('--profile-memory', action='store_true', + default=False, help='Record memory info for analysis purpose. ') + group.add_argument('--profile-memory-path', type=str, default=None, + help='filepath to saveRecord memory info. ') group.add_argument('--iterations-to-skip', nargs='+', type=int, default=[], help='List of iterations to skip, empty by default.') group.add_argument('--result-rejected-tracker-filename', type=str, default=None, @@ -2226,7 +2268,7 @@ def _add_training_args(parser): help='Use the built-in pytorch profiler. ' 'Useful if you wish to view profiles in tensorboard.', dest='use_pytorch_profiler') - group.add_argument('--profile-ranks', nargs='+', type=int, default=[0], + group.add_argument('--profile-ranks', nargs='+', type=int, default=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 20, 24, 28, 32, 64, 96, 128, 160, 192, 224], help='Global ranks to profile.') group.add_argument('--record-memory-history', action="store_true", default=False, help='Record memory history in last rank.') @@ -2720,6 +2762,8 @@ def _add_mixed_precision_args(parser): def _add_distributed_args(parser): group = parser.add_argument_group(title='distributed') + group.add_argument('--master-addr', type=str, default='127.0.0.1:8389', + help='master add.') group.add_argument('--tensor-model-parallel-size', type=int, default=1, help='Degree of tensor model parallelism.') group.add_argument('--pipeline-model-parallel-size', type=int, default=1, @@ -2881,13 +2925,35 @@ def _add_distributed_args(parser): '--hierarchical-context-parallel-sizes 2 4 indicates every two adjacent gpus ' 'forms the first level of cp groups and the cp ranks with the same odevity ' 'forms the second level of cp 
groups.') - group.add_argument('--max-seqlen-per-cp-rank', type=int, default=None, + group.add_argument('--max-seqlen-per-dp-cp-rank', type=int, default=None, help='Maximum sequence length per CP rank. This is used to calculate the ' 'number of sub-samples assigned to each CP rank when using heterogeneous context parallel.') group.add_argument('--hybrid-context-parallel', action='store_true', default=False, help='Enables hybrid context parallel. This is used to balance the workload ' 'of each CP rank when we use packed samples with variable sequence lengths. ' - 'Requires --max-seqlen-per-cp-rank to be set.') + 'Requires --max-seqlen-per-dp-cp-rank to be set.') + group.add_argument('--min-hybrid-context-parallel-size', type=int, default=1, + help='Minimum size of the hybrid context parallel groups.') + group.add_argument('--max-hybrid-context-parallel-size', type=int, default=-1, + help='Minimum size of the hybrid context parallel groups.') + group.add_argument('--hybrid-context-parallel-scheduler', type=str, default='balanced', + choices=['balanced', 'balanced_with_pp', 'only_packing_no_scheduling'], + help='Scheduler for hybrid context parallel. ' + 'balanced: balanced scheduler for hybrid context parallel. ' + 'balanced_with_pp: balanced scheduler for hybrid context parallel with pipeline parallel. ' + 'only_packing_no_scheduling: scheduling is already handled by the data sampler, ' + 'this scheduler only performs packing.') + group.add_argument('--async-hybrid-context-parallel-scheduler', action='store_true', + default=False, help='Use asynchronize context parallel scheduler to avoid scheduler execution and extra communication time. ') + group.add_argument('--run-memory-simulator', action='store_true', + default=False, help='run memory simulator for pp scheduler. 
') + group.add_argument('--search-space', nargs='+', type=int, default=[1,2,3,4,5,6], + help='search space for `PipelineAwareBalancedHybridCPscheduler`, ' + 'if only one param, it means the range of the search space.' + 'For example, 4 means choose PP*1, PP*2, PP*3, PP*4 to search.' + 'if more than one param, it means the selected number of microbatch to search.' + 'For example, [1,2,3,4] means choose PP*1, PP*2, PP*3, PP*4 to search' + '[2,4] means only choose PP*2 and PP*4 to search.') group.add_argument('--nccl-communicator-config-path', type=str, default=None, help='Path to the yaml file with NCCL communicator ' 'configurations. The number of min/max thread groups and thread ' @@ -3579,4 +3645,8 @@ def _add_sft_args(parser): group.add_argument('--sft', action="store_true", help='Megatron SFT training') group.add_argument('--sft-tokenizer-prompt-format', type=str, default="nemotron-h-aligned", help='SFT prompt format.') + group.add_argument('--sft-sequence-packing', action='store_true', + help='use sequence packing(thd format) for SFT training') + group.add_argument('--sft-mock-dataset-config-json', type=str, default=None, + help='This config provides the necessary information for the mock dataset. You can either specify a CSV file that contains sequence lengths, where each line stores the length of a sequence, for example: {"mode":"file","path":"/path/to/file"}. 
Alternatively, you can specify a distribution (currently only supporting lognormal distribution) along with the required parameters, for example, {"mode":"distribution","type":"lognormal","min_seq_len":1024,"max_seq_len":2048,"mean_seq_len":1536,"lognormal_sigma":1.1}, where sigma controls the variability of the lognormal distribution.') return parser diff --git a/megatron/training/datasets/sft_dataset.py b/megatron/training/datasets/sft_dataset.py index e4d8a6faf24..caf886216b7 100644 --- a/megatron/training/datasets/sft_dataset.py +++ b/megatron/training/datasets/sft_dataset.py @@ -1,12 +1,16 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -from typing import Any, Dict, Optional +import json +import math +from typing import Any, Dict, Optional, List import numpy as np +import pandas as pd import torch from megatron.core.datasets.gpt_dataset import GPTDatasetConfig from megatron.core.datasets.megatron_dataset import LowLevelDataset, MegatronDataset +from megatron.core.datasets.indexed_dataset import IndexedDataset from megatron.core.datasets.utils import Split IGNORE_INDEX = -100 @@ -56,6 +60,8 @@ def __init__( config: GPTDatasetConfig, ) -> None: super().__init__(dataset, dataset_path, indices, num_samples, index_split, config) + # Pre-calculate padding divisor to avoid redundant computation in get_padding_size + self.padding_divisor = self._calculate_padding_divisor() @staticmethod def numel_low_level_dataset(low_level_dataset: LowLevelDataset) -> int: @@ -68,8 +74,38 @@ def build_low_level_dataset(dataset_path: str, config: GPTDatasetConfig) -> LowL def __len__(self) -> int: return self.num_samples - def __getitem__(self, idx: int) -> Dict[str, Any]: + def _calculate_padding_divisor(self) -> int: + """ + Calculate the divisor used for sequence padding. 
+ tp_pad = tp_size * 2 if tp_size > 1 else 1 + cp_pad = cp_size * 2 if cp_size > 1 else 1 + cp_pad = cp_pad * dp_size if hybrid_cp else cp_pad + divisor = cp_pad * tp_pad + """ + if self.config.hybrid_context_parallel: + # Hybrid CP: consider both CP and DP + cp_pad = self.config.data_parallel_size * self.config.context_parallel_size * 2 + else: + # Standard CP: only consider CP + cp_pad = self.config.context_parallel_size * 2 if self.config.context_parallel_size > 1 else 1 + tp_pad = self.config.sequence_parallel_size if self.config.sequence_parallel_size > 0 else 1 + divisor = cp_pad * tp_pad + # TODO(tailaim): do we need to pad for FP8 execution? + # divisor = ((divisor + 15) // 16) * 16 + return divisor + + def get_padding_size( + self, + seq_len: int, + ) -> int: + seq_len_padded = math.ceil(seq_len / self.padding_divisor) * self.padding_divisor + assert seq_len > seq_len_padded / 2 / self.config.context_parallel_size * (self.config.context_parallel_size - 1), \ + f"sequence length {seq_len} is too short, the divisor is {self.padding_divisor}, that means cp_rank \ + {self.config.context_parallel_size-1} will have no valid tokens" + return seq_len_padded + def __getitem__(self, idx: int) -> Dict[str, Any]: + sft_sequence_packing = self.config.sft_sequence_packing tokenizer = self.config.tokenizer max_seq_len = self.config.sequence_length @@ -84,11 +120,15 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: tokens = tokens[: max_seq_len - force_eod_length] target = target[: max_seq_len - force_eod_length] - # padding + # if use sequence packing, pad according to get_padding_size + # else pad to max_seq_len num_tokens = len(tokens) + force_eod_length - padding_len = max_seq_len - num_tokens + if sft_sequence_packing: + padding_len = self.get_padding_size(num_tokens) - num_tokens + else: + padding_len = max_seq_len - num_tokens assert padding_len >= 0 - filler = [tokenizer.eod] * force_eod_length + [tokenizer.pad] * (padding_len + 1) + filler = [1] * 
force_eod_length + [1] * (padding_len + 1) tokens = np.array(tokens.tolist() + filler, dtype=np.int64) target = np.array(target.tolist() + filler, dtype=np.int64) @@ -98,9 +138,10 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: tokens = tokens[:-1].contiguous() target = target[1:].contiguous() + seq_len = tokens.numel() loss_mask, position_ids, attention_mask = self._get_ltor_masks_and_position_ids( - max_seq_len, target, tokenizer.pad + seq_len, target, 1 ) if self.config.create_attention_mask: @@ -119,6 +160,10 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: 'position_ids': position_ids, } + if sft_sequence_packing: + # sequence packing need both original sequence length and padded length + ret['original_seq_len'] = torch.tensor(num_tokens, dtype=torch.int32) + return ret def _get_ltor_masks_and_position_ids(self, max_seq_len, target, pad_token): @@ -136,7 +181,7 @@ def _get_ltor_masks_and_position_ids(self, max_seq_len, target, pad_token): if self.config.create_attention_mask: attention_mask = torch.tril( - torch.ones((seq_length, seq_length), device=data.device) + torch.ones((max_seq_len, max_seq_len), device=target.device) ).unsqueeze(0) # Convert attention mask to binary: attention_mask = attention_mask < 0.5 @@ -144,3 +189,184 @@ def _get_ltor_masks_and_position_ids(self, max_seq_len, target, pad_token): attention_mask = None return loss_mask, position_ids, attention_mask + + +class MockSFTLowLevelDataset: + """The low-level mock dataset for SFT + + Args: + mock_config (dict): The config for mock dataset. + """ + + seed: int = 0 + """The hard-coded random seed to use to set the NumPy RNG""" + + size: int = 1000000 + """The hard-coded number of sequence to generate""" + + # This is to maintain consistency with the SFT dataset that uses real data. In the real dataset, an element in the low-level dataset often contains multiple sequences. So here, each element in the mock low-level dataset also contains num_sequence_per_sample sequences. 
This will be made more reasonable in the future. + + + def __init__(self, config: Dict) -> None: + np.random.seed(self.seed) + # either choose to load sequence lengths from external file, or generate random sequence lengths + + assert "mode" in config, f"mode must be set, either 'file' or 'distribution'" + + if config["mode"] == "indexed_file": + self.dataset = IndexedDataset(config["path"]) + self.size = len(self.dataset) + self.sequence_lengths = self.dataset.sequence_lengths + elif config["mode"] == "file": + min_seq_len = config["min_seq_len"] + max_seq_len = config["max_seq_len"] + self.sequence_lengths = np.array(pd.read_csv(config["path"])).flatten() + self.sequence_lengths = self.sequence_lengths[(min_seq_len <= self.sequence_lengths) & (self.sequence_lengths <= max_seq_len)] + self.size = len(self.sequence_lengths) + elif config["mode"] == "distribution": + min_seq_len = config["min_seq_len"] + max_seq_len = config["max_seq_len"] + if config["type"] == "lognormal": + mean_seq_len = config["mean_seq_len"] + lognormal_sigma = config["lognormal_sigma"] + self.sequence_lengths = self.generate_lognormal_samples(self.size, mean_seq_len,lognormal_sigma, min_seq_len, max_seq_len) + elif config["type"] == "linear": + self.sequence_lengths = self.generate_linear_samples(self.size, min_seq_len, max_seq_len) + print(f"{self.sequence_lengths=}") + else: + raise ValueError(f"Unsupported sequence length distribution type {config['type']}") + + def generate_linear_samples(self, size, min_seq_len, max_seq_len, step=256): + samples = np.arange(min_seq_len, max_seq_len, step) + print(f"{size=}, {min_seq_len=}, {max_seq_len=}, {samples=}") + return samples.astype(int) + + def generate_lognormal_samples(self, size, mean, sigma, min_seq_len, max_seq_len): + mu = np.log(mean) - sigma**2 / 2 + samples = np.random.lognormal(mu, sigma, size) + samples = np.clip(samples, min_seq_len, max_seq_len) + return samples.astype(int) + + def __len__(self) -> int: + return self.size + + def 
__getitem__(self, idx: int) -> List[np.ndarray]: + if hasattr(self, "dataset") and self.dataset is not None: + return self.dataset[idx] + length = self.sequence_lengths[idx % len(self.sequence_lengths)] + # the length of sample is 'length', but only length-1 elements are generated here, + # because an eod token will be appended at the end later in SFTDataset + sample = np.arange(2, length + 1 , dtype=np.int64) + return sample + + +class MockSFTDataset(SFTDataset): + """The mock dataset used during SFT""" + + def __init__( + self, + dataset: LowLevelDataset, + dataset_path: Optional[str], + indices: np.ndarray, + num_samples: Optional[int], + index_split: Split, + config: GPTDatasetConfig, + ) -> None: + super().__init__(dataset, dataset_path, indices, num_samples, index_split, config) + + @staticmethod + def build_low_level_dataset(dataset_path: str, config: GPTDatasetConfig) -> LowLevelDataset: + mock_config = json.loads(config.sft_mock_dataset_config_json) + return MockSFTLowLevelDataset(mock_config) + + def __len__(self) -> int: + return self.num_samples + + @property + def sequence_lengths(self) -> np.ndarray: + """Get the sequence lengths + + Returns: + numpy.ndarray: The sequence lengths + """ + return self.dataset.sequence_lengths + + def __getitem__(self, idx: int) -> Dict[str, Any]: + num_microbatch_left = -1 + cp_size = -1 + if isinstance(idx, tuple): + # print(f"{idx=}") + if len(idx) == 2: + idx, num_microbatch_left = idx + elif len(idx) == 3: + idx, num_microbatch_left, cp_size = idx + + sft_sequence_packing = self.config.sft_sequence_packing + tokenizer = self.config.tokenizer + max_seq_len = self.config.sequence_length + + tokens = self.dataset[int(self.indices[idx % len(self.indices)])] + target = np.array(tokens, dtype=np.int64) + + # force_eod_length = int(tokenizer.force_eod) + force_eod_length = 1 + + if len(tokens) > max_seq_len - force_eod_length: + # cut the right side + tokens = tokens[: max_seq_len - force_eod_length] + target = target[: 
max_seq_len - force_eod_length] + # tokens = tokens[(-max_seq_len + force_eod_length):] + # target = target[(-max_seq_len + force_eod_length):] + + # padding + num_tokens = len(tokens) + force_eod_length + if sft_sequence_packing: + padding_len = self.get_padding_size(num_tokens) - num_tokens + else: + padding_len = max_seq_len - num_tokens + assert padding_len >= 0 + filler = [1] * force_eod_length + [1] * (padding_len + 1) + + tokens = np.array(tokens.tolist() + filler, dtype=np.int64) + target = np.array(target.tolist() + filler, dtype=np.int64) + + tokens = torch.tensor(tokens) + target = torch.tensor(target) + + tokens = tokens[:-1].contiguous() + target = target[1:].contiguous() + seq_len = tokens.numel() + + loss_mask, position_ids, attention_mask = self._get_ltor_masks_and_position_ids( + seq_len, target, 1 + ) + + if self.config.create_attention_mask: + ret = { + 'tokens': tokens, + 'labels': target, + 'attention_mask': attention_mask, + 'loss_mask': loss_mask, + 'position_ids': position_ids, + } + if num_microbatch_left != -1: + ret['num_micro_batches_left'] = num_microbatch_left + if cp_size != -1: + ret['local_cp_size'] = cp_size + else: + ret = { + 'tokens': tokens, + 'labels': target, + 'loss_mask': loss_mask, + 'position_ids': position_ids, + } + if num_microbatch_left != -1: + ret['num_micro_batches_left'] = num_microbatch_left + if cp_size != -1: + ret['local_cp_size'] = cp_size + + if sft_sequence_packing: + # sequence packing need both original sequence length and padded length + ret['original_seq_len'] = torch.tensor(num_tokens, dtype=torch.int32) + + return ret diff --git a/megatron/training/global_vars.py b/megatron/training/global_vars.py index ec402263d29..76d5ddb5424 100644 --- a/megatron/training/global_vars.py +++ b/megatron/training/global_vars.py @@ -7,6 +7,7 @@ import torch from megatron.core import Timers +from megatron.core.gpu_timers import GPUTimer from megatron.core.config import set_experimental_flag from 
megatron.core.energy_monitor import EnergyMonitor from megatron.core.jit import disable_jit_fuser @@ -21,6 +22,7 @@ _GLOBAL_ONE_LOGGER = None _GLOBAL_ADLR_AUTORESUME = None _GLOBAL_TIMERS = None +_GLOBAL_GPU_TIMERS = None _GLOBAL_ENERGY_MONITOR = None _GLOBAL_SIGNAL_HANDLER = None @@ -64,6 +66,14 @@ def get_timers(): _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers') return _GLOBAL_TIMERS + +def get_gpu_timers(): + """Return timers.""" + global _GLOBAL_GPU_TIMERS + _ensure_var_is_initialized(_GLOBAL_GPU_TIMERS, 'timers') + return _GLOBAL_GPU_TIMERS + + def get_energy_monitor(): """Return energy monitor.""" _ensure_var_is_initialized(_GLOBAL_ENERGY_MONITOR, 'energy monitor') @@ -141,6 +151,7 @@ def unset_global_variables(): _GLOBAL_ONE_LOGGER = None _GLOBAL_ADLR_AUTORESUME = None _GLOBAL_TIMERS = None + _GLOBAL_GPU_TIMERS = None _GLOBAL_ENERGY_MONITOR = None _GLOBAL_SIGNAL_HANDLER = None @@ -262,9 +273,10 @@ def _set_adlr_autoresume(args): def _set_timers(args): """Initialize timers.""" - global _GLOBAL_TIMERS + global _GLOBAL_TIMERS, _GLOBAL_GPU_TIMERS _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers') _GLOBAL_TIMERS = Timers(args.timing_log_level, args.timing_log_option) + _GLOBAL_GPU_TIMERS = GPUTimer(args.use_gpu_timer) def _set_energy_monitor(args): """Initialize energy monitor.""" @@ -303,6 +315,7 @@ def destroy_global_vars(): global _GLOBAL_TIMERS _GLOBAL_TIMERS = None + _GLOBAL_GPU_TIMERS = None global _GLOBAL_ENERGY_MONITOR _GLOBAL_ENERGY_MONITOR = None diff --git a/megatron/training/initialize.py b/megatron/training/initialize.py index fb9a3aa273b..b36615b8d2f 100644 --- a/megatron/training/initialize.py +++ b/megatron/training/initialize.py @@ -6,6 +6,7 @@ import random import time import warnings +import resource from datetime import timedelta import numpy as np @@ -87,6 +88,9 @@ def initialize_megatron( else: validate_args(args, args_defaults) + soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE) + 
resource.setrlimit(resource.RLIMIT_NOFILE, (hard_limit, hard_limit)) + # set global args, build tokenizer, and set adlr-autoresume, # tensorboard-writer, and timers. set_global_variables(args) @@ -381,6 +385,8 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks, s create_gloo_process_groups=args.enable_gloo_process_groups, high_priority_stream_groups=args.high_priority_stream_groups, sharp_enabled_group=args.sharp_enabled_group, + min_hybrid_context_parallel_size=args.min_hybrid_context_parallel_size, + max_hybrid_context_parallel_size=args.max_hybrid_context_parallel_size, ) if args.rank == 0: print( diff --git a/megatron/training/tokenizer/sft_tokenizer.py b/megatron/training/tokenizer/sft_tokenizer.py index f525352e892..5801dec53f6 100644 --- a/megatron/training/tokenizer/sft_tokenizer.py +++ b/megatron/training/tokenizer/sft_tokenizer.py @@ -62,7 +62,9 @@ def __init__( raise NotImplementedError("unknown SFT prompt format", prompt_format) self._prompt_format = prompt_format - + if self._prompt_config.pad_token_id is None: + self._prompt_config.pad_token_id = self._tokenizer.eos_token_id - 1 + print(f"pad token id is not set, set to (eos_token_id - 1): {self._prompt_config.pad_token_id} for {prompt_format}") def tokenize_conversation( self, conversation: List[Dict], return_target: bool, add_generation_prompt: bool @@ -179,6 +181,11 @@ def bos(self): def eod(self): """End of sentence token ID.""" return self._tokenizer.eos_token_id + + @property + def eos(self): + """End of sentence token ID.""" + return self._tokenizer.eos_token_id @property def vocab(self): diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py index 13b7526ca07..fb74d32b49c 100644 --- a/megatron/training/tokenizer/tokenizer.py +++ b/megatron/training/tokenizer/tokenizer.py @@ -871,6 +871,14 @@ def eos(self): def additional_special_tokens_ids(self): return None + @property + def force_eod(self): + """To force an EOD at the 
end of every data sample in SFT.""" + return True + + @property + def pad(self): + return self._eod_id - 1 class _NullMultimodalTokenizer(MegatronLegacyTokenizer): def __init__(self, vocab_size, image_token=None, image_token_id=None): diff --git a/megatron/training/training.py b/megatron/training/training.py index c29c48d4c9f..3129f9704ff 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -134,6 +134,7 @@ get_args, get_signal_handler, get_timers, + get_gpu_timers, get_tensorboard_writer, get_wandb_writer, get_one_logger, @@ -164,7 +165,7 @@ def print_datetime(string): print_rank_0(f'[{string}] datetime: {time_str} ') -def num_floating_point_operations(args, batch_size): +def num_floating_point_operations(args, num_total_tokens_this_GB, sequence_square_sum_this_GB): def calculate_layer_counts(): """Calculate the number of attention, Mamba, and MLP layers.""" if args.hybrid_override_pattern: @@ -179,27 +180,25 @@ def calculate_layer_counts(): num_mamba_layers = args.num_layers - num_attn_layers - num_mlp_layers return num_attn_layers, num_mamba_layers, num_mlp_layers - def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=False): + def mlp_layer_flops(num_total_tokens_this_GB, hidden_size, expansion=4.0, swiglu=False): """Calculate FLOPs for an MLP layer.""" scale_factor = 3.0 / 2.0 if swiglu else 1.0 - return 4 * expansion * scale_factor * batch_size * seq_len * hidden_size**2 + return 4 * expansion * scale_factor * num_total_tokens_this_GB * hidden_size**2 def attn_layer_flops( - batch_size, seq_len, hidden_size, num_heads, gqa=True, gqa_groups=8, kv_channels=None + num_total_tokens_this_GB, sequence_square_sum_this_GB, hidden_size, num_heads, gqa=True, gqa_groups=8, kv_channels=None ): """Calculate FLOPs for an attention layer.""" p = (kv_channels * num_heads / hidden_size) if kv_channels else 1 g = gqa_groups if gqa else num_heads return ( 4 - * batch_size - * seq_len * hidden_size * p - * (hidden_size + 
(hidden_size * (g / num_heads)) + (seq_len / 2)) + * (hidden_size * num_total_tokens_this_GB + (hidden_size * (g / num_heads)) * num_total_tokens_this_GB + (sequence_square_sum_this_GB / 2)) ) - def mamba_layer_flops(batch_size, seq_len, hidden_size, state_dim=16, + def mamba_layer_flops(num_total_tokens_this_GB, hidden_size, state_dim=16, head_dim=64, num_groups=1, num_heads=128): """Calculate FLOPs for a Mamba layer.""" # Note (rwaleffe): flops estimate for scan should be updated based on new SSD kernels, @@ -212,16 +211,15 @@ def mamba_layer_flops(batch_size, seq_len, hidden_size, state_dim=16, return ( ( 2 - * batch_size - * seq_len + * num_total_tokens_this_GB * hidden_size * (2 * d_in + 2 * num_groups * state_dim + nheads) ) # in_proj - + (7 * batch_size * seq_len * d_in * state_dim) # scan - + (2 * batch_size * seq_len * d_in * hidden_size) # out_proj + + (7 * num_total_tokens_this_GB * d_in * state_dim) # scan + + (2 * num_total_tokens_this_GB * d_in * hidden_size) # out_proj ) - def hybrid_flops(batch_size, seq_len, hidden_size, + def hybrid_flops(num_total_tokens_this_GB, sequence_square_sum_this_GB, hidden_size, num_attn_layers, num_mamba_layers, num_mlp_layers, mamba_state_dim=128, mamba_head_dim=64, mamba_num_groups=8, mamba_num_heads=128, @@ -231,20 +229,23 @@ def hybrid_flops(batch_size, seq_len, hidden_size, vocab_size=256000): """Calculate total FLOPs for the hybrid model.""" flops_fwd = ( - num_attn_layers * attn_layer_flops(batch_size, seq_len, hidden_size, + num_attn_layers * attn_layer_flops(num_total_tokens_this_GB, sequence_square_sum_this_GB, hidden_size, num_attn_heads, gqa, gqa_groups, kv_channels) + - num_mlp_layers * mlp_layer_flops(batch_size, seq_len, hidden_size, + num_mlp_layers * mlp_layer_flops(num_total_tokens_this_GB, hidden_size, mlp_expansion, swiglu) + - num_mamba_layers * mamba_layer_flops(batch_size, seq_len, hidden_size, + num_mamba_layers * mamba_layer_flops(num_total_tokens_this_GB, hidden_size, mamba_state_dim, 
mamba_head_dim, mamba_num_groups, mamba_num_heads) + - (2 * batch_size * seq_len * hidden_size * vocab_size) # logits computation + (2 * num_total_tokens_this_GB * hidden_size * vocab_size) # logits computation ) return flops_fwd * 3 def transformer_flops(): """Calculate FLOPs for a standard Transformer model.""" # TODO(helenn/dnarayanan): Refactor this to reuse the helper methods. + # Attention projection size. + query_projection_size = args.kv_channels * args.num_attention_heads + query_projection_to_hidden_size_ratio = query_projection_size / args.hidden_size # Group Query Attention. if not args.group_query_attention: args.num_query_groups = args.num_attention_heads @@ -311,13 +312,18 @@ def transformer_flops(): assert not args.group_query_attention ''' Basic arithmetic - let B is batch size, s is seq_len, h is embedding dim, - for one self_attnetion block (prenorm is not included) - qkv projection: 6Bsh^2 - attn: 2Bs^2h - attn over value: 2Bs^2h - oproj: 2Bsh^2 - + + Let h be the embedding dim. 
+ We use two statistics to unify BSHD and THD cases: + num_total_tokens_this_GB: total number of tokens in this global batch + sequence_square_sum_this_GB: sum of squared sequence lengths in this global batch + + For one self-attention block (prenorm not included): + qkv projection: 6 * num_total_tokens_this_GB * h^2 + attn: 2 * sequence_square_sum_this_GB * h + attn over value: 2 * sequence_square_sum_this_GB * h + oproj: 2 * num_total_tokens_this_GB * h^2 + references https://arxiv.org/abs/2305.10403 https://arxiv.org/abs/2205.05198 @@ -338,7 +344,7 @@ def transformer_flops(): standard_self_attn_term = ( 3 * 2 # fwd(1) + bwd(2) *FMA - * ( + * ( num_total_tokens_this_GB * ( ## q lora + rope + q norm q_term ## kv lora + rope + kv norm @@ -350,12 +356,12 @@ def transformer_flops(): ) + args.hidden_size * args.qk_pos_emb_head_dim ## o proj - + (args.num_attention_heads * args.v_head_dim) * args.hidden_size + + (args.num_attention_heads * args.v_head_dim) * args.hidden_size) ## core attn - + args.seq_length + + sequence_square_sum_this_GB * (args.num_attention_heads * (args.qk_head_dim + args.qk_pos_emb_head_dim)) - / 2 # causal mask (only half of the mask is non-zero) - + args.seq_length * args.num_attention_heads * args.v_head_dim / 2 + / 2 # causal mask (only half of the mask is non-zero) + + sequence_square_sum_this_GB * args.num_attention_heads * args.v_head_dim / 2 ) ) @@ -367,17 +373,17 @@ def transformer_flops(): standard_self_attn_term = ( 3 * 2 # fwd(1) + bwd(2) *FMA - * ( + * ( num_total_tokens_this_GB *( ## qkv proj args.hidden_size - * (query_projection_size + key_projection_size + value_projection_size) + * (query_projection_size + key_projection_size + value_projection_size)) ## core attention + query_projection_size - * args.seq_length + * sequence_square_sum_this_GB / 2 # causal mask (only half of the mask is non-zero) * 2 # QK^T and (QK^T)V ## out proj - + query_projection_size + + num_total_tokens_this_GB * query_projection_size * args.hidden_size ) 
) @@ -450,8 +456,7 @@ def transformer_flops(): ) total_floating_point_operations = ( - batch_size - * args.seq_length + num_total_tokens_this_GB * ( # MLP expansion_factor @@ -468,8 +473,6 @@ def transformer_flops(): + (shared_expert_ffn_hidden_size * gated_linear_multiplier) * (num_moe_layers / num_layers) ) - # Self Attention - + self_attn_term # MTP norms and proj + 3 * 2 @@ -483,6 +486,10 @@ def transformer_flops(): # Logit. + 3 * 2 * args.hidden_size * args.padded_vocab_size * (mtp_num_layers + 1) ) + + + # Self Attention + self_attn_term + ) return total_floating_point_operations @@ -493,8 +500,8 @@ def transformer_flops(): # Compute hybrid model FLOPs. return hybrid_flops( - batch_size=batch_size, - seq_len=args.seq_length, + num_total_tokens_this_GB=num_total_tokens_this_GB, + sequence_square_sum_this_GB=sequence_square_sum_this_GB, hidden_size=args.hidden_size, num_attn_layers=num_attn_layers, num_mamba_layers=num_mamba_layers, @@ -1348,6 +1355,7 @@ def setup_model_and_optimizer( def dummy_train_step(data_iterator): + # TODO(tailaim): this need to be modified """Single dummy training step.""" num_microbatches = get_num_microbatches() rerun_state_machine = get_rerun_state_machine() @@ -1362,6 +1370,7 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch """Single training step.""" args = get_args() timers = get_timers() + gpu_timer = get_gpu_timers() rerun_state_machine = get_rerun_state_machine() while rerun_state_machine.should_run_forward_backward(data_iterator): @@ -1398,9 +1407,16 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch forward_only=False, adjust_tensor_shapes_fn=adjust_tensor_shapes_fn, ) + + if args.sft_sequence_packing: + num_total_tokens_this_GB, sequence_square_sum_this_GB = losses_reduced.pop() + else: + sequence_square_sum_this_GB = args.seq_length ** 2 * args.micro_batch_size * args.data_parallel_size * get_num_microbatches() + num_total_tokens_this_GB = args.seq_length * 
args.micro_batch_size * args.data_parallel_size * get_num_microbatches() + should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit() if should_exit: - return {}, True, should_checkpoint, should_exit, exit_code, None, None, 0 + return {}, True, should_checkpoint, should_exit, exit_code, None, None, num_total_tokens_this_GB, sequence_square_sum_this_GB, 0 # Empty unused memory. if args.empty_unused_memory_level >= 1: @@ -1450,6 +1466,22 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch if args.empty_unused_memory_level >= 2: torch.cuda.empty_cache() + + if mpu.get_pipeline_model_parallel_world_size() > 1 and not gpu_timer.inactive: + gpu_timer.compute(name="forward-compute") + gpu_timer.compute(name="backward-compute") + # gpu_timer.compute(name="forward-backward") + # fwd_bwd_tot_time = gpu_timer.elapsed("forward-backward") + fwd_each_time = gpu_timer.elapsed("forward-compute") # list, len = each virtual pp stage + bwd_each_time = gpu_timer.elapsed("backward-compute") + # print(f"rank={torch.distributed.get_rank()}, pp_rank={mpu.get_pipeline_model_parallel_rank()}, dp_rank={mpu.get_data_parallel_rank()}, {fwd_each_time=}, {bwd_each_time=}") + # summary_data_parallel_imbalance(fwd_bwd_tot_time, fwd_each_time, bwd_each_time) + # summary_pipeline_parallel_imbalance(fwd_bwd_tot_time, fwd_each_time, bwd_each_time) + gpu_timer.reset() + gpu_timer.inactivate() + if args.use_gpu_timer and int(args.curr_iteration) % args.gpu_timer_interval == 0: + gpu_timer.activate() + if mpu.is_pipeline_last_stage(ignore_virtual=True): # Average loss across microbatches. 
loss_reduced = {} @@ -1480,8 +1512,10 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch grad_norm, num_zeros_in_grad, log_max_attention_logit, + num_total_tokens_this_GB, + sequence_square_sum_this_GB, ) - return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad, log_max_attention_logit + return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad, log_max_attention_logit, num_total_tokens_this_GB, sequence_square_sum_this_GB def training_log( @@ -1497,6 +1531,8 @@ def training_log( params_norm, num_zeros_in_grad, max_attention_logit, + num_total_tokens_this_GB, + sequence_square_sum_this_GB, pg_collection=None, ): """Log training information such as losses, timing, ....""" @@ -1713,9 +1749,10 @@ def training_log( dump(snapshot, f) elapsed_time = timers('interval-time').elapsed(barrier=True) + # elapsed_time = timers('forward-backward').elapsed(barrier=True) elapsed_time_per_iteration = elapsed_time / total_iterations - throughput = num_floating_point_operations(args, batch_size) / ( + throughput = num_floating_point_operations(args,num_total_tokens_this_GB, sequence_square_sum_this_GB) / ( elapsed_time_per_iteration * 10**12 * args.world_size ) @@ -2102,7 +2139,7 @@ def train( """Training function: run train_step desired number of times, run validation, checkpoint.""" args = get_args() timers = get_timers() - + if getattr(args, 'perform_rl_step', False): assert has_rl_utils, "RL cannot run without the megatron.rl package" @@ -2165,9 +2202,6 @@ def train( energy_monitor = get_energy_monitor() one_logger = get_one_logger() - if args.hybrid_context_parallel: - train_data_iterator = iter(HybridCPDataLoaderWrapper(train_data_iterator, config)) - if args.run_workload_inspector_server: try: from workload_inspector.utils.webserver import run_server @@ -2361,6 +2395,18 @@ def get_e2e_base_metrics(): nsys_nvtx_context = 
torch.autograd.profiler.emit_nvtx(record_shapes=True) nsys_nvtx_context.__enter__() + if args.profile_memory and torch.distributed.get_rank() in args.profile_ranks: + if iteration == args.profile_step_start: + print(f"start profile") + torch.cuda.memory._record_memory_history(max_entries=8000000) + if iteration == args.profile_step_end: + filepath = args.profile_memory_path + filename = f"memory_record_DP{mpu.get_data_parallel_rank()}_TP{mpu.get_tensor_model_parallel_rank()}_PP{mpu.get_pipeline_model_parallel_rank()}_rank{torch.distributed.get_rank()}.json" + filename = os.path.join(filepath, filename) + print(f"end profile, {filename=}") + torch.cuda.memory._dump_snapshot(filename) + torch.cuda.memory._record_memory_history(enabled=None) + ft_integration.on_checkpointing_start() maybe_finalize_async_save(blocking=False) ft_integration.on_checkpointing_end(is_async_finalization=True) @@ -2454,6 +2500,8 @@ def get_e2e_base_metrics(): grad_norm, num_zeros_in_grad, max_attention_logit, + num_total_tokens_this_GB, + sequence_square_sum_this_GB, ) = train_step( forward_step_func, train_data_iterator, model, optimizer, opt_param_scheduler, config, forward_backward_func ) @@ -2524,7 +2572,7 @@ def get_e2e_base_metrics(): else: assert num_skipped_samples_in_batch == 0 args.skipped_train_samples += num_skipped_samples_in_batch - num_floating_point_operations_in_batch = num_floating_point_operations(args, batch_size) + num_floating_point_operations_in_batch = num_floating_point_operations(args, num_total_tokens_this_GB, sequence_square_sum_this_GB) num_floating_point_operations_so_far += num_floating_point_operations_in_batch num_floating_point_operations_since_last_log_event += num_floating_point_operations_in_batch @@ -2559,6 +2607,8 @@ def get_e2e_base_metrics(): params_norm, num_zeros_in_grad, max_attention_logit, + num_total_tokens_this_GB, + sequence_square_sum_this_GB, pg_collection=model_pg_collection, ) @@ -2731,6 +2781,8 @@ def evaluate( 
decoder_seq_length=args.decoder_seq_length, forward_only=True, ) + # need to drop first two elements which are total_num_tokens and total_sequence_square_sum + loss_dicts = loss_dicts[2:] ft_integration.on_eval_step_end() config.timers = get_timers() @@ -2769,6 +2821,8 @@ def evaluate( group=mpu.get_data_parallel_group(with_context_parallel=True) ) total_loss_dict[key] += val + + elif val[0].numel() == 1: val = torch.cat(val).sum() total_loss_dict[key][0] += val @@ -2993,19 +3047,21 @@ def build_train_valid_test_data_loaders(build_train_valid_test_datasets_provider if not args.skip_train: train_dataloader = build_pretraining_data_loader(train_ds, args.consumed_train_samples) - valid_dataloaders = [] - for valid_d in valid_ds: - if args.skip_train or args.full_validation: - valid_dataloaders.append(build_pretraining_data_loader(valid_d, 0)) - else: - if args.multiple_validation_sets: - # TODO(bnorick): for multiple validation sets without full validation, args.consumed_valid_samples is not - # correct and needs to be calculated/set per validation set - raise NotImplementedError("--multiple-validation-sets currently requires --full-validation") - valid_dataloaders.append(build_pretraining_data_loader(valid_d, args.consumed_valid_samples)) - if not args.multiple_validation_sets: - assert len(valid_dataloaders) == 1 - test_dataloader = build_pretraining_data_loader(test_ds, 0) + valid_dataloaders = None + test_dataloader = None + # valid_dataloaders = [] + # for valid_d in valid_ds: + # if args.skip_train or args.full_validation: + # valid_dataloaders.append(build_pretraining_data_loader(valid_d, 0)) + # else: + # if args.multiple_validation_sets: + # # TODO(bnorick): for multiple validation sets without full validation, args.consumed_valid_samples is not + # # correct and needs to be calculated/set per validation set + # raise NotImplementedError("--multiple-validation-sets currently requires --full-validation") + # 
valid_dataloaders.append(build_pretraining_data_loader(valid_d, args.consumed_valid_samples)) + # if not args.multiple_validation_sets: + # assert len(valid_dataloaders) == 1 + # test_dataloader = build_pretraining_data_loader(test_ds, 0) # Flags to know if we need to do training/validation/testing. do_train = train_dataloader is not None and args.train_iters > 0 diff --git a/megatron/training/utils.py b/megatron/training/utils.py index 4730a525271..457fe4ccd03 100644 --- a/megatron/training/utils.py +++ b/megatron/training/utils.py @@ -8,6 +8,7 @@ from contextlib import contextmanager from datetime import datetime from collections import defaultdict +from typing import Optional import torch @@ -43,6 +44,7 @@ unwrap_model, ) from megatron.legacy.model.module import param_is_not_shared +from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator def calc_params_l2_norm(model, force_create_fp32_copy=False): @@ -514,8 +516,11 @@ def get_blend_and_blend_per_split(args): return blend, blend_per_split - -def get_batch_on_this_tp_rank(data_iterator, mtp_on_this_rank: bool = False): +def get_batch_on_this_tp_rank( + data_iterator, + mtp_on_this_rank: bool = False, + vp_stage: Optional[int] = None, + ): args = get_args() @@ -526,42 +531,64 @@ def _broadcast(item): mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group(), ) - + if mpu.get_tensor_model_parallel_rank() == 0: assert data_iterator is not None data = next(data_iterator) + batch = { - 'tokens': data["tokens"].cuda(non_blocking=True), - 'labels': data["labels"].cuda(non_blocking=True), - 'loss_mask': data["loss_mask"].cuda(non_blocking=True), + 'tokens': ( + data["tokens"].cuda(non_blocking=True) + if "tokens" in data + else None + ), + 'labels': ( + data["labels"].cuda(non_blocking=True) + if "labels" in data + else None + ), + 'loss_mask': ( + data["loss_mask"].cuda(non_blocking=True) + if "loss_mask" in data + else None + ), 'attention_mask': ( - None - if 
"attention_mask" not in data - else data["attention_mask"].cuda(non_blocking=True) + data["attention_mask"].cuda(non_blocking=True) + if "attention_mask" in data + else None + ), + 'position_ids': ( + data["position_ids"].cuda(non_blocking=True) + if "position_ids" in data + else None ), - 'position_ids': data["position_ids"].cuda(non_blocking=True), 'cu_seqlens': ( - None - if "cu_seqlens" not in data - else data["cu_seqlens"].cuda(non_blocking=True) + data["cu_seqlens"].cuda(non_blocking=True) + if "cu_seqlens" in data + else None + ), + 'cu_seqlens_padded': ( + data["cu_seqlens_padded"].cuda(non_blocking=True) + if "cu_seqlens_padded" in data + else None ), 'max_seqlen': ( - None - if "max_seqlen" not in data - else data["max_seqlen"].cuda(non_blocking=True) + data["max_seqlen"].cuda(non_blocking=True) + if "max_seqlen" in data + else None ), 'local_cp_size': ( - None - if "local_cp_size" not in data - else data["local_cp_size"].cuda(non_blocking=True) + data["local_cp_size"].cuda(non_blocking=True) + if "local_cp_size" in data + else None ), } def _broadcast_cu_seqlens(cu_seqlens): dev = torch.cuda.current_device() n = 0 if cu_seqlens is None else int(cu_seqlens.numel()) - n_tensor = torch.tensor(n, dtype=torch.int64, device=dev) + n_tensor = torch.tensor(n, dtype=torch.int32, pin_memory=True).to(dev, non_blocking=True) _broadcast(n_tensor) if n == 0: @@ -569,12 +596,11 @@ def _broadcast_cu_seqlens(cu_seqlens): else: assert isinstance(cu_seqlens, torch.Tensor) assert cu_seqlens.dtype == torch.int32 - assert cu_seqlens.shape[0] == 1, "micro-batch-size must be 1 for packing" buf = cu_seqlens.to(device=dev, non_blocking=True).contiguous() _broadcast(buf) - if args.hybrid_context_parallel: - seq_len = torch.tensor(batch['tokens'].shape[0], dtype=torch.int32, device=torch.cuda.current_device()) + if args.sft_sequence_packing and is_first_or_last_pipeline_stage(vp_stage): + seq_len = torch.tensor(batch['labels'].shape[0], dtype=torch.int32, 
pin_memory=True).to(torch.cuda.current_device(), non_blocking=True) _broadcast(seq_len) if args.pipeline_model_parallel_size == 1 or mtp_on_this_rank: @@ -583,69 +609,88 @@ def _broadcast_cu_seqlens(cu_seqlens): _broadcast(batch['loss_mask']) _broadcast(batch['attention_mask']) _broadcast(batch['position_ids']) - _broadcast_cu_seqlens(batch['cu_seqlens']) _broadcast(batch['max_seqlen']) _broadcast(batch['local_cp_size']) + if args.sft_sequence_packing: + _broadcast_cu_seqlens(batch['cu_seqlens']) + _broadcast_cu_seqlens(batch['cu_seqlens_padded']) - elif mpu.is_pipeline_first_stage(): + elif mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=vp_stage): _broadcast(batch['tokens']) _broadcast(batch['attention_mask']) _broadcast(batch['position_ids']) - _broadcast_cu_seqlens(batch['cu_seqlens']) _broadcast(batch['max_seqlen']) + _broadcast(batch['local_cp_size']) + if args.sft_sequence_packing: + _broadcast_cu_seqlens(batch['cu_seqlens']) + _broadcast_cu_seqlens(batch['cu_seqlens_padded']) - elif mpu.is_pipeline_last_stage(): + elif mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage): # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding. # Currently the Multi-Token Prediction (MTP) layers is fixed on the last stage, so we need # to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage. _broadcast(batch['labels']) _broadcast(batch['loss_mask']) _broadcast(batch['attention_mask']) + _broadcast(batch['max_seqlen']) + _broadcast(batch['local_cp_size']) + if args.sft_sequence_packing: + _broadcast_cu_seqlens(batch['cu_seqlens']) + _broadcast_cu_seqlens(batch['cu_seqlens_padded']) + + elif (not is_first_or_last_pipeline_stage(vp_stage)) and args.sft_sequence_packing: + # Except for PP rank 0 and the last PP rank, broadcast + # cu_seqlens, cu_seqlens_padded and max_seqlen for the THD format. 
+ _broadcast_cu_seqlens(batch['cu_seqlens']) + _broadcast_cu_seqlens(batch['cu_seqlens_padded']) + _broadcast(batch['max_seqlen']) + _broadcast(batch['local_cp_size']) else: - if args.hybrid_context_parallel: - seq_len = torch.tensor(0, dtype=torch.int32, device=torch.cuda.current_device()) - _broadcast(seq_len) - shape = (seq_len.item()) - else: - shape = (args.micro_batch_size, args.seq_length) - - tokens = torch.empty( - shape, - dtype=torch.int64, - device=torch.cuda.current_device(), - ) - labels = torch.empty( - shape, - dtype=torch.int64, - device=torch.cuda.current_device(), - ) - loss_mask = torch.empty( - shape, - dtype=torch.float32, - device=torch.cuda.current_device(), - ) - if args.create_attention_mask_in_dataloader: - shape_attention_mask = (args.micro_batch_size, 1, args.seq_length, args.seq_length) if not args.hybrid_context_parallel else (1, 1, shape[0], shape[0]) - attention_mask = torch.empty( - shape_attention_mask, - dtype=torch.bool, + if is_first_or_last_pipeline_stage(vp_stage): + if args.sft_sequence_packing: + seq_len = torch.zeros(1, dtype=torch.int32, device=torch.cuda.current_device()) + _broadcast(seq_len) + shape = (seq_len.item()) + else: + shape = (args.micro_batch_size, args.seq_length) + tokens = torch.empty( + shape, + dtype=torch.int64, + device=torch.cuda.current_device(), + ) + labels = torch.empty( + shape, + dtype=torch.int64, + device=torch.cuda.current_device(), + ) + loss_mask = torch.empty( + shape, + dtype=torch.float32, + device=torch.cuda.current_device(), + ) + if args.create_attention_mask_in_dataloader: + shape_attention_mask = (args.micro_batch_size, 1, args.seq_length, args.seq_length) if not args.hybrid_context_parallel else (1, 1, shape[0], shape[0]) + attention_mask = torch.empty( + shape_attention_mask, + dtype=torch.bool, + device=torch.cuda.current_device(), + ) + else: + attention_mask = None + position_ids = torch.empty( + shape, + dtype=torch.int64, device=torch.cuda.current_device(), ) - else: - 
attention_mask = None - position_ids = torch.empty( - shape, - dtype=torch.int64, - device=torch.cuda.current_device(), - ) cu_seqlens = None + cu_seqlens_padded = None max_seqlen = torch.empty( 1, dtype=torch.int32, device=torch.cuda.current_device(), - ) if args.hybrid_context_parallel else None + ) if args.sft_sequence_packing else None local_cp_size = torch.empty( 1, dtype=torch.int32, @@ -655,49 +700,81 @@ def _broadcast_cu_seqlens(cu_seqlens): def _broadcast_cu_seqlens(): dev = torch.cuda.current_device() - n = torch.empty((), dtype=torch.int64, device=dev) + n = torch.empty((), dtype=torch.int32, device=dev) _broadcast(n) n = int(n.item()) - if n == 0: cu_seqlens = torch.empty(0, dtype=torch.int32, device=dev) else: - cu_seqlens = torch.empty((args.micro_batch_size, n), dtype=torch.int32, device=dev) + cu_seqlens = torch.empty(n, dtype=torch.int32, device=dev) _broadcast(cu_seqlens) return cu_seqlens if n > 0 else None + cu_seqlens = None + cu_seqlens_padded = None + max_seqlen = torch.empty( + 1, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) if args.sft_sequence_packing else None + local_cp_size = torch.empty( + 1, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) if args.hybrid_context_parallel else None + if args.pipeline_model_parallel_size == 1 or mtp_on_this_rank: _broadcast(tokens) _broadcast(labels) _broadcast(loss_mask) _broadcast(attention_mask) _broadcast(position_ids) - cu_seqlens = _broadcast_cu_seqlens() _broadcast(max_seqlen) _broadcast(local_cp_size) + if args.sft_sequence_packing: + cu_seqlens = _broadcast_cu_seqlens() + cu_seqlens_padded = _broadcast_cu_seqlens() - elif mpu.is_pipeline_first_stage(): + elif mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=vp_stage): labels = None loss_mask = None - _broadcast(tokens) _broadcast(attention_mask) _broadcast(position_ids) - cu_seqlens = _broadcast_cu_seqlens() _broadcast(max_seqlen) + _broadcast(local_cp_size) + if args.sft_sequence_packing: + cu_seqlens 
= _broadcast_cu_seqlens() + cu_seqlens_padded = _broadcast_cu_seqlens() - elif mpu.is_pipeline_last_stage(): + elif mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage): # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding. # Currently the Multi-Token Prediction (MTP) layers is fixed on the last stage, so we need # to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage. tokens = None position_ids = None cu_seqlens = None - max_seqlen = None + cu_seqlens_padded = None + # max_seqlen = None + _broadcast(labels) _broadcast(loss_mask) _broadcast(attention_mask) + _broadcast(max_seqlen) + _broadcast(local_cp_size) + if args.sft_sequence_packing: + cu_seqlens = _broadcast_cu_seqlens() + cu_seqlens_padded = _broadcast_cu_seqlens() + + elif (not is_first_or_last_pipeline_stage(vp_stage)) and args.sft_sequence_packing: + # Except for PP rank 0 and the last PP rank, broadcast + # cu_seqlens, cu_seqlens_padded and max_seqlen for the THD format. 
+ tokens, labels, loss_mask, attention_mask, position_ids = None, None, None, None, None + cu_seqlens = _broadcast_cu_seqlens() + cu_seqlens_padded = _broadcast_cu_seqlens() + _broadcast(max_seqlen) + _broadcast(local_cp_size) batch = { 'tokens': tokens, @@ -706,10 +783,19 @@ def _broadcast_cu_seqlens(): 'attention_mask': attention_mask, 'position_ids': position_ids, 'cu_seqlens': cu_seqlens, + 'cu_seqlens_padded': cu_seqlens_padded, 'max_seqlen': max_seqlen, 'local_cp_size': local_cp_size, } + if args.sft_sequence_packing and not args.hybrid_context_parallel: + # using THD(sequence packing) but not using hybrid-cp, + # so we need to pop the local_cp_size + batch.pop('local_cp_size') + elif not args.sft_sequence_packing: + keys_to_keep = ['tokens', 'labels', 'loss_mask', 'attention_mask', 'position_ids'] + batch = {k: v for k, v in batch.items() if k in keys_to_keep} + return batch diff --git a/pretrain_gpt.py b/pretrain_gpt.py index e976f5aff79..16a6af09a0c 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -6,15 +6,21 @@ from typing import List, Optional, Tuple import torch +import os +import nvtx from gpt_builders import gpt_builder from megatron.core import parallel_state +from megatron.core.parallel_state import ( + get_context_parallel_rank, + get_context_parallel_world_size, +) from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset from megatron.core.enums import ModelType from megatron.core.models.gpt import GPTModel from megatron.core.rerun_state_machine import get_rerun_state_machine -from megatron.core.utils import get_attr_wrapped_model, get_thd_batch_on_this_cp_rank, get_batch_on_this_hybrid_cp_rank, StragglerDetector +from megatron.core.utils import get_attr_wrapped_model, get_thd_batch_on_this_cp_rank, StragglerDetector from megatron.core.tokenizers.text.utils.build_tokenizer import build_tokenizer from 
megatron.core.transformer.multi_token_prediction import mtp_on_this_rank, get_mtp_ranks from megatron.training.arguments import core_transformer_config_from_args @@ -26,6 +32,7 @@ get_blend_and_blend_per_split, is_first_or_last_pipeline_stage, ) +from megatron.training.datasets.sft_dataset import SFTDataset, MockSFTDataset from model_provider import model_provider try: @@ -36,6 +43,16 @@ except ImportError: has_nvidia_modelopt = False +try: + # Register the TE CUDA kernels + import transformer_engine # pylint: disable=unused-import + + # Alias the PyTorch wrapper so we can call tex.* APIs + import transformer_engine_torch as tex +except ImportError: + # TE isn’t installed or the torch wrapper is missing + tex = None + stimer = StragglerDetector() @@ -43,35 +60,46 @@ def get_batch(data_iterator, vp_stage: Optional[int] = None): """Generate a batch.""" args = get_args() config = core_transformer_config_from_args(args) - # TODO: this is pretty hacky, find a better way - if not is_first_or_last_pipeline_stage(vp_stage) and ( - (not mtp_on_this_rank(config, ignore_virtual=False, vp_stage=vp_stage))): - return None, None, None, None, None, None - - # get batches based on the TP rank you are on - batch = get_batch_on_this_tp_rank( + + if args.sft_sequence_packing: + + # get batches based on the TP rank you are on + nvtx.push_range("get_batch_on_this_tp_rank") + batch = get_batch_on_this_tp_rank( + data_iterator, + mtp_on_this_rank=mtp_on_this_rank(config, ignore_virtual=False, vp_stage=vp_stage), + vp_stage=vp_stage, + ) + nvtx.pop_range() + + cu_seqlens = batch.pop('cu_seqlens') + cu_seqlens_padded = batch.pop('cu_seqlens_padded') + max_seqlen = int(batch.pop('max_seqlen').item()) + # local_cp_size is None if we disable hybrid-cp + local_cp_size = int(batch.pop('local_cp_size').item()) if ('local_cp_size' in batch) else None + + if is_first_or_last_pipeline_stage(vp_stage): + batch, packed_seq_params = get_thd_batch_on_this_cp_rank(batch, cu_seqlens, + cu_seqlens_padded, 
max_seqlen, local_cp_size=local_cp_size, vp_stage=vp_stage) + return (*batch.values(), packed_seq_params) + + else: + _, packed_seq_params = get_thd_batch_on_this_cp_rank(batch, cu_seqlens, + cu_seqlens_padded, max_seqlen, local_cp_size=local_cp_size, only_packed_seq_params=True) + return None, None, None, None, None, packed_seq_params + else: + # TODO: this is pretty hacky, find a better way + if not is_first_or_last_pipeline_stage(vp_stage) and ( + (not mtp_on_this_rank(config, ignore_virtual=False, vp_stage=vp_stage))): + return None, None, None, None, None, None + batch = get_batch_on_this_tp_rank( data_iterator, - mtp_on_this_rank=mtp_on_this_rank(config, ignore_virtual=False, vp_stage=vp_stage) + mtp_on_this_rank=mtp_on_this_rank(config, ignore_virtual=False, vp_stage=vp_stage), + vp_stage=vp_stage ) - - cu_seqlens = batch.pop('cu_seqlens', None) - cu_seqlens_padded = batch.pop('cu_seqlens_padded', None) - max_seqlen = batch.pop('max_seqlen', None) - local_cp_size = batch.pop('local_cp_size', None) - if local_cp_size is not None: - local_cp_size = int(local_cp_size.item()) - - if cu_seqlens is None and local_cp_size is None: - # slice batch along sequence dimension for context parallelism batch = get_batch_on_this_cp_rank(batch) # The implementation of this function is in MCore packed_seq_params = None - elif local_cp_size is None: # Packed THD format - assert max_seqlen.dim() == 1 - batch, packed_seq_params = get_thd_batch_on_this_cp_rank(batch, cu_seqlens, cu_seqlens_padded, max_seqlen) - else: # Hybrid CP format - batch, packed_seq_params = get_batch_on_this_hybrid_cp_rank(batch, local_cp_size) - - return (*batch.values(), packed_seq_params) + return (*batch.values(), packed_seq_params) # define spiky loss as a loss that's 10x the max loss observed @@ -158,12 +186,13 @@ def forward_step(data_iterator, model: GPTModel, return_schedule_plan: bool = Fa vp_stage = get_attr_wrapped_model(model, "vp_stage") tokens, labels, loss_mask, attention_mask, 
position_ids, packed_seq_params = get_batch(data_iterator, vp_stage) timers('batch-generator').stop() - + # if parallel_state.get_pipeline_model_parallel_rank() == 0: + # print(f"{tokens.shape=}, dp rank:{parallel_state.get_data_parallel_rank()}") with stimer: if args.use_legacy_models: output_tensor = model(tokens, position_ids, attention_mask, labels=labels) else: - if return_schedule_plan: + if return_schedule_plan: assert args.overlap_moe_expert_parallel_comm, \ "overlap_moe_expert_parallel_comm must be enabled to return the schedule plan" schedule_plan = model.build_schedule_plan( @@ -182,10 +211,11 @@ def forward_step(data_iterator, model: GPTModel, return_schedule_plan: bool = Fa def is_dataset_built_on_rank(vp_stage=None): args = get_args() config = core_transformer_config_from_args(args) - return ( - is_first_or_last_pipeline_stage(vp_stage) - or mtp_on_this_rank(config, ignore_virtual=False, vp_stage=vp_stage) - ) and parallel_state.get_tensor_model_parallel_rank() == 0 + # return ( + # is_first_or_last_pipeline_stage(vp_stage) + # or mtp_on_this_rank(config, ignore_virtual=False, vp_stage=vp_stage) + # ) and parallel_state.get_tensor_model_parallel_rank() == 0 + return parallel_state.get_tensor_model_parallel_rank() == 0 def core_gpt_dataset_config_from_args(args): @@ -222,6 +252,9 @@ def core_gpt_dataset_config_from_args(args): data_parallel_size=args.data_parallel_size, sequence_parallel_size=args.tensor_model_parallel_size*args.sequence_parallel, hybrid_context_parallel=args.hybrid_context_parallel, + sft_mock_dataset_config_json=args.sft_mock_dataset_config_json, + sft_sequence_packing=args.sft_sequence_packing, + hybrid_context_parallel_scheduler=args.hybrid_context_parallel_scheduler, ) @@ -236,7 +269,10 @@ def train_valid_test_datasets_provider(train_val_test_num_samples, vp_stage=None config = core_gpt_dataset_config_from_args(args) if args.sft: - dataset_type = SFTDataset + if args.mock_data: + dataset_type = MockSFTDataset + else: + 
dataset_type = SFTDataset else: if args.mock_data: dataset_type = MockGPTDataset @@ -276,8 +312,9 @@ def get_embedding_ranks(pp_ranks: List[int]): train_valid_test_datasets_provider.is_distributed = True # Optionally enable inprocess restart on pretrain - pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain) - + # pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain) + store = None + # torch.cuda.memory._record_memory_history(max_entries=8000000) pretrain( train_valid_test_datasets_provider, partial(model_provider, gpt_builder),