diff --git a/docs/model-quirks.md b/docs/model-quirks.md index ec37048469..31827af86d 100644 --- a/docs/model-quirks.md +++ b/docs/model-quirks.md @@ -33,6 +33,11 @@ NeMo-RL uses the vLLM V1 runtime for both synchronous and asynchronous inference - NeMo-RL implemented this feature based on torch CP [implementation](https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/experimental/_attention.py). And we inherit its limitations. Whether model level support CP only depends on arguments passed to `torch.nn.functional.scaled_dot_product_attention`. Current NeMo-RL passed all ones attention mask to `model.forward`. For Gemma-3, it won't ignore attention mask as result `attn_bias` is not None which is not supported by torch CP. Please see [assertion](https://github.com/pytorch/pytorch/blob/134179474539648ba7dee1317959529fbd0e7f89/torch/distributed/tensor/experimental/_attention.py#L262) . + - Context parallel can't be used together with sequence packing. Sequence packing requires `attn_implementation="flash_attention_2"`, which conflicts with context parallel because context parallel requires the SDPA attention implementation. Refer to [here](https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/modeling_utils.py#L2317) for more details. + + +- It's a known issue that context parallel can't be used together with sequence parallel. +Refer to [here](https://github.com/NVIDIA-NeMo/RL/issues/659) for more details. - It's a known issue that context parallel can't be used together with sequence parallel. Refer to [here](https://github.com/NVIDIA-NeMo/RL/issues/659) for more details. 
diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml index 8a5f5dcf00..bcbffb0761 100755 --- a/examples/configs/dpo.yaml +++ b/examples/configs/dpo.yaml @@ -57,6 +57,9 @@ policy: dynamic_batching: enabled: false + sequence_packing: + enabled: false + # makes the training sequence length divisible by the tensor parallel size # this is useful for sequence parallel training make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} diff --git a/examples/configs/grpo-deepscaler-1.5b-8K.yaml b/examples/configs/grpo-deepscaler-1.5b-8K.yaml index 9efc308a0a..08d021f582 100644 --- a/examples/configs/grpo-deepscaler-1.5b-8K.yaml +++ b/examples/configs/grpo-deepscaler-1.5b-8K.yaml @@ -55,6 +55,9 @@ policy: dynamic_batching: enabled: False + sequence_packing: + enabled: False + # makes the training sequence length divisible by the tensor parallel size # this is useful for sequence parallel training make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} diff --git a/examples/configs/grpo_deepscaler-1.5b-24K.yaml b/examples/configs/grpo_deepscaler-1.5b-24K.yaml index f2552eea7e..dc9db4ceab 100644 --- a/examples/configs/grpo_deepscaler-1.5b-24K.yaml +++ b/examples/configs/grpo_deepscaler-1.5b-24K.yaml @@ -21,6 +21,9 @@ policy: dynamic_batching: enabled: False + sequence_packing: + enabled: False + optimizer: name: "torch.optim.AdamW" kwargs: diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index c003eea2e1..a388f7b2cc 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -51,6 +51,9 @@ policy: tensor_parallel_size: 1 context_parallel_size: 1 custom_parallel_plan: null + + megatron_cfg: + enabled: false # dynamic_batching improves performance by ensuring logprob and training microbatches # have a sufficent number of tokens to maximize GPU utilization. 
Specifically, variable length @@ -58,9 +61,16 @@ policy: # amount of tokens is approximately close to 'train_mb_tokens' and 'logprob_mb_tokens' for the # training and logprob stages respectively. dynamic_batching: + enabled: False + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} + sequence_length_round: 64 + + sequence_packing: enabled: True train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} + algorithm: "modified_first_fit_decreasing" sequence_length_round: 64 # makes the training sequence length divisible by the tensor parallel size diff --git a/examples/configs/grpo_math_1B_megatron.yaml b/examples/configs/grpo_math_1B_megatron.yaml index 79a579e278..d58eb47aae 100644 --- a/examples/configs/grpo_math_1B_megatron.yaml +++ b/examples/configs/grpo_math_1B_megatron.yaml @@ -49,14 +49,19 @@ policy: # responses are sorted by sequence length and bucketed into microbatches with a total # amount of tokens is approximately close to 'train_mb_tokens' and 'logprob_mb_tokens' for the # training and logprob stages respectively. + # + # We disable it for Megatron as it is incompatible with Pipeline parallelism. 
Instead, we use sequence packing dynamic_batching: enabled: False + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} + sequence_length_round: 64 sequence_packing: - enabled: False # coming soon + enabled: True train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} - algorithm: "modified_ffd" + algorithm: "modified_first_fit_decreasing" sequence_length_round: 64 max_grad_norm: 1.0 diff --git a/examples/configs/grpo_math_8B_megatron.yaml b/examples/configs/grpo_math_8B_megatron.yaml index ef0e932b0c..004bc738b0 100644 --- a/examples/configs/grpo_math_8B_megatron.yaml +++ b/examples/configs/grpo_math_8B_megatron.yaml @@ -72,4 +72,4 @@ policy: cluster: gpus_per_node: 8 - num_nodes: 1 \ No newline at end of file + num_nodes: 1 diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml index 8655dede0a..b060004882 100644 --- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml @@ -43,7 +43,10 @@ policy: custom_parallel_plan: null dynamic_batching: - enabled: False + enabled: false + + sequence_packing: + enabled: false make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} max_grad_norm: 1.0 diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml index e8e8f472c0..c34771595b 100644 --- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml +++ 
b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml @@ -43,7 +43,10 @@ policy: custom_parallel_plan: null dynamic_batching: - enabled: False + enabled: false + + sequence_packing: + enabled: false make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} max_grad_norm: 1.0 diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml index 3dbb98006e..abc42f30eb 100644 --- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml @@ -37,7 +37,10 @@ policy: enabled: false dynamic_batching: - enabled: False + enabled: false + + sequence_packing: + enabled: false make_sequence_length_divisible_by: ${policy.megatron_cfg.tensor_model_parallel_size} max_grad_norm: 1.0 diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml index 082520095e..a571f32582 100644 --- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml @@ -37,7 +37,10 @@ policy: enabled: false dynamic_batching: - enabled: False + enabled: false + + sequence_packing: + enabled: false make_sequence_length_divisible_by: ${policy.megatron_cfg.tensor_model_parallel_size} max_grad_norm: 1.0 diff --git a/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml index afe19bf4ea..832d989b59 100644 --- a/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml @@ -44,7 +44,10 @@ policy: custom_parallel_plan: null dynamic_batching: - 
enabled: False + enabled: false + + sequence_packing: + enabled: false make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} max_grad_norm: 1.0 diff --git a/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml b/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml index a2c61ebce9..b503afad4b 100644 --- a/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml +++ b/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml @@ -49,6 +49,8 @@ policy: train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} sequence_length_round: 64 + sequence_packing: + enabled: false make_sequence_length_divisible_by: 1 max_grad_norm: 1 optimizer: diff --git a/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml b/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml index 0fe72a150d..ea3188b9ae 100644 --- a/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml +++ b/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml @@ -50,6 +50,8 @@ policy: train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} sequence_length_round: 64 + sequence_packing: + enabled: false make_sequence_length_divisible_by: 8 max_grad_norm: 1 optimizer: diff --git a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml index 2ad3228001..d29b88c4e0 100644 --- a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml @@ -49,6 +49,8 @@ policy: 
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} sequence_length_round: 64 + sequence_packing: + enabled: false make_sequence_length_divisible_by: 1 max_grad_norm: 1 optimizer: diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml index 3caf0ccdbd..355cd3a5d3 100644 --- a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml @@ -49,6 +49,8 @@ policy: train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} sequence_length_round: 64 + sequence_packing: + enabled: false make_sequence_length_divisible_by: 1 max_grad_norm: 1 optimizer: diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml index ae6426e305..0ce93de5ae 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml @@ -49,6 +49,8 @@ policy: train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} sequence_length_round: 64 + sequence_packing: + enabled: false make_sequence_length_divisible_by: 8 max_grad_norm: 1 optimizer: diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml index e4449ae147..45788b3172 100644 --- 
a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml @@ -49,6 +49,8 @@ policy: train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} sequence_length_round: 64 + sequence_packing: + enabled: false make_sequence_length_divisible_by: 8 max_grad_norm: 1 optimizer: diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml index 585a8f5d88..ae0add9bd2 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml @@ -49,6 +49,8 @@ policy: train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} sequence_length_round: 64 + sequence_packing: + enabled: false make_sequence_length_divisible_by: 4 max_grad_norm: 1 optimizer: diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml index 78bfeee82d..cce3f5b327 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml @@ -49,6 +49,8 @@ policy: train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} sequence_length_round: 64 + sequence_packing: + enabled: false make_sequence_length_divisible_by: 1 max_grad_norm: 1 optimizer: diff --git 
a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml index a35a13533e..50aa3b96c6 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml @@ -34,7 +34,9 @@ policy: context_parallel_size: 1 custom_parallel_plan: null dynamic_batching: - enabled: False + enabled: false + sequence_packing: + enabled: false make_sequence_length_divisible_by: 1 max_grad_norm: 1 optimizer: diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml index 608edace8d..7a774c3654 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml @@ -34,7 +34,9 @@ policy: context_parallel_size: 1 custom_parallel_plan: null dynamic_batching: - enabled: False + enabled: false + sequence_packing: + enabled: false make_sequence_length_divisible_by: 2 max_grad_norm: 1 optimizer: diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml index 4fdfb5d37b..14c2f9692e 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml @@ -28,7 +28,9 @@ policy: dtensor_cfg: enabled: false dynamic_batching: - enabled: False + enabled: false + sequence_packing: + enabled: false make_sequence_length_divisible_by: ${policy.megatron_cfg.tensor_model_parallel_size} max_grad_norm: 1 optimizer: null diff --git a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml index 
8dac5cb980..617ce45096 100644 --- a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml +++ b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml @@ -34,7 +34,9 @@ policy: context_parallel_size: 1 custom_parallel_plan: null dynamic_batching: - enabled: False + enabled: false + sequence_packing: + enabled: false make_sequence_length_divisible_by: 1 max_grad_norm: 1 optimizer: diff --git a/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml b/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml index bf38f37eb7..6761e2f015 100644 --- a/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml +++ b/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml @@ -34,7 +34,9 @@ policy: context_parallel_size: 1 custom_parallel_plan: null dynamic_batching: - enabled: False + enabled: false + sequence_packing: + enabled: false make_sequence_length_divisible_by: 8 max_grad_norm: 1 optimizer: diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index b14c6304dd..8a8b7d7129 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -44,6 +44,12 @@ policy: dynamic_batching: enabled: false + sequence_packing: + enabled: False + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + algorithm: "modified_first_fit_decreasing" + sequence_length_round: 64 + # makes the training sequence length divisible by the tensor parallel size # this is useful for sequence parallel training make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} @@ -121,7 +127,7 @@ policy: average_in_collective: true data_parallel_sharding_strategy: "optim_grads_params" - + data: max_input_seq_length: ${policy.max_total_sequence_length} dataset_name: "squad" diff --git a/examples/configs/sft_openmathinstruct2.yaml b/examples/configs/sft_openmathinstruct2.yaml index de9fab880a..aa128e5a99 100644 --- 
a/examples/configs/sft_openmathinstruct2.yaml +++ b/examples/configs/sft_openmathinstruct2.yaml @@ -40,6 +40,9 @@ policy: dynamic_batching: enabled: false + sequence_packing: + enabled: false + # makes the training sequence length divisible by the tensor parallel size # this is useful for sequence parallel training make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} diff --git a/examples/run_sft.py b/examples/run_sft.py index ce5b258b0c..df0d7ce3f7 100644 --- a/examples/run_sft.py +++ b/examples/run_sft.py @@ -31,6 +31,8 @@ from nemo_rl.utils.config import load_config, parse_hydra_overrides from nemo_rl.utils.logger import get_next_experiment_dir +OmegaConf.register_new_resolver("mul", lambda a, b: a * b) + def parse_args(): """Parse command line arguments.""" diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index 1bf472d830..923e836554 100644 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -114,6 +114,7 @@ def __call__( global_valid_toks: torch.Tensor, vocab_parallel_rank: Optional[int] = None, vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> tuple[torch.Tensor, dict]: """Clipped Policy Gradient RL loss function.""" token_mask = data["token_mask"][:, 1:] @@ -149,7 +150,10 @@ def __call__( vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], tp_group=vocab_parallel_group, inference_only=False, + cp_group=context_parallel_group, ) + # slice off to the correct length to remove potential CP padding + curr_logprobs = curr_logprobs[:, : data["input_ids"].shape[1] - 1] elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): curr_logprobs = get_logprobs_from_vocab_parallel_logits( next_token_logits, data["input_ids"], seq_index=seq_index @@ -312,6 +316,7 @@ def __call__( global_valid_toks: Tensor, vocab_parallel_rank: Optional[int] = 
None, vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, dpo_loss: bool = False, dpo_average_log_probs: bool = False, ) -> tuple[torch.Tensor, dict[str, Any]]: @@ -335,7 +340,10 @@ def __call__( vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], tp_group=vocab_parallel_group, inference_only=False, + cp_group=context_parallel_group, ) + # slice off to the correct length to remove potential CP padding + token_logprobs = token_logprobs[:, : data["input_ids"].shape[1] - 1] elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): token_logprobs = get_logprobs_from_vocab_parallel_logits( next_token_logits, data["input_ids"] @@ -466,6 +474,7 @@ def _preference_loss( global_valid_seqs: Tensor, vocab_parallel_rank: Optional[int] = None, vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> tuple[Tensor, Tensor, Tensor, Tensor]: ## TODO(@ashors): there's some duplicate code here with the NLLLoss function. 
We should refactor token_mask = data["token_mask"][:, 1:] @@ -483,7 +492,10 @@ def _preference_loss( vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], tp_group=vocab_parallel_group, inference_only=False, + cp_group=context_parallel_group, ) + # slice off to the correct length to remove potential CP padding + token_logprobs = token_logprobs[:, : data["input_ids"].shape[1] - 1] elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): token_logprobs = get_logprobs_from_vocab_parallel_logits( next_token_logits, data["input_ids"] @@ -548,6 +560,7 @@ def __call__( global_valid_toks: Tensor | None, vocab_parallel_rank: Optional[int] = None, vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> tuple[torch.Tensor, dict[str, Any]]: sft_loss_chosen = torch.tensor(0.0) if self.sft_loss_weight > 0: @@ -561,6 +574,7 @@ def __call__( global_valid_toks=global_valid_toks, ## unused because sft loss returned is at the sample level vocab_parallel_rank=vocab_parallel_rank, vocab_parallel_group=vocab_parallel_group, + context_parallel_group=context_parallel_group, dpo_loss=True, dpo_average_log_probs=self.sft_average_log_probs, ) @@ -582,6 +596,7 @@ def __call__( global_valid_seqs, vocab_parallel_rank=vocab_parallel_rank, vocab_parallel_group=vocab_parallel_group, + context_parallel_group=context_parallel_group, ) dpo_loss = ( @@ -601,3 +616,83 @@ def __call__( "rewards_rejected_mean": rewards_rejected_mean.item(), "num_valid_samples": num_valid_samples.item(), } + + +class SequencePackingLossWrapper: + def __init__( + self, + loss_fn: LossFunction, + cu_seqlens_q: Tensor, + cu_seqlens_q_padded: Optional[Tensor] = None, + ): + self.loss_fn = loss_fn + self.cu_seqlens_q = cu_seqlens_q + self.cu_seqlens_q_padded = cu_seqlens_q_padded + + def __call__( + self, + next_token_logits: Tensor, + data: BatchedDataDict[Any], + global_valid_seqs: Tensor | None, + 
global_valid_toks: Tensor | None, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + ) -> tuple[Tensor, dict[str, Any]]: + """Wraps a loss function to handle sequence packing by doing one sequence at a time to avoid excessive padding.""" + unpadded_cu_seqlens = self.cu_seqlens_q + unpadded_seq_lengths = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1] + if self.cu_seqlens_q_padded is not None: + padded_cu_seqlens = self.cu_seqlens_q_padded + padded_seq_lengths = ( + self.cu_seqlens_q_padded[1:] - self.cu_seqlens_q_padded[:-1] + ) + else: + padded_cu_seqlens = unpadded_cu_seqlens + padded_seq_lengths = unpadded_seq_lengths + seq_starts = padded_cu_seqlens[:-1] + seq_ends = padded_cu_seqlens[1:] + + loss_accum = 0 + metrics_accum = {} + for seq_idx in range(len(seq_starts)): + seq_start = seq_starts[seq_idx].item() + seq_end = seq_ends[seq_idx].item() + + # get sequence and unpad all 'data' tensors. 
The data dict is a BatchedDataDict of unpacked tensors + seq_data = data.slice(seq_idx, seq_idx + 1) + unpadded_seq_data = {} + for k, v in seq_data.items(): + if isinstance(v, torch.Tensor) and v.ndim > 1 and v.shape[1] > 1: + unpadded_seq_data[k] = v[:, : unpadded_seq_lengths[seq_idx]] + else: + unpadded_seq_data[k] = v + + # get next_token_logits + cp_size = ( + 1 + if context_parallel_group is None + else torch.distributed.get_world_size(context_parallel_group) + ) + logit_slice_idxs = slice( + seq_start // cp_size, + (seq_start + padded_seq_lengths[seq_idx]) // cp_size, + ) + next_token_logits_slice = next_token_logits[:, logit_slice_idxs, :] + + loss, metrics = self.loss_fn( + next_token_logits_slice, + unpadded_seq_data, + global_valid_seqs, + global_valid_toks, + vocab_parallel_rank=vocab_parallel_rank, + vocab_parallel_group=vocab_parallel_group, + context_parallel_group=context_parallel_group, + ) + loss_accum += loss + for k, v in metrics.items(): + if k not in metrics_accum: + metrics_accum[k] = 0 + metrics_accum[k] += v + + return loss_accum, metrics_accum diff --git a/nemo_rl/data/packing/__init__.py b/nemo_rl/data/packing/__init__.py new file mode 100644 index 0000000000..a955f681cc --- /dev/null +++ b/nemo_rl/data/packing/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from nemo_rl.data.packing.algorithms import ( + ConcatenativePacker, + FirstFitDecreasingPacker, + FirstFitShufflePacker, + ModifiedFirstFitDecreasingPacker, + PackingAlgorithm, + SequencePacker, + get_packer, +) +from nemo_rl.data.packing.metrics import PackingMetrics + +__all__ = [ + "PackingAlgorithm", + "SequencePacker", + "ConcatenativePacker", + "FirstFitDecreasingPacker", + "FirstFitShufflePacker", + "ModifiedFirstFitDecreasingPacker", + "get_packer", + "PackingMetrics", +] diff --git a/nemo_rl/data/packing/algorithms.py b/nemo_rl/data/packing/algorithms.py new file mode 100644 index 0000000000..71e643f2b7 --- /dev/null +++ b/nemo_rl/data/packing/algorithms.py @@ -0,0 +1,571 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sequence packing algorithms for efficient batching of variable-length sequences.""" + +import enum +import math +import random +from abc import ABC, abstractmethod +from typing import Dict, List, Tuple, Type, Union + + +class PackingAlgorithm(enum.Enum): + """Enum for supported sequence packing algorithms.""" + + CONCATENATIVE = "concatenative" + FIRST_FIT_DECREASING = "first_fit_decreasing" + FIRST_FIT_SHUFFLE = "first_fit_shuffle" + MODIFIED_FIRST_FIT_DECREASING = "modified_first_fit_decreasing" + + +class SequencePacker(ABC): + """Abstract base class for sequence packing algorithms. 
+ + Sequence packing is the process of efficiently arranging sequences of different + lengths into fixed-capacity bins (batches) to maximize computational efficiency. + """ + + def __init__(self, bin_capacity: int, collect_metrics: bool = False): + """Initialize the sequence packer. + + Args: + bin_capacity: The maximum capacity of each bin. + collect_metrics: Whether to collect metrics across multiple packing operations. + """ + self.bin_capacity = bin_capacity + self.collect_metrics = collect_metrics + self.metrics = None + + if collect_metrics: + from nemo_rl.data.packing.metrics import PackingMetrics + + self.metrics = PackingMetrics() + + @abstractmethod + def _pack_implementation(self, sequence_lengths: List[int]) -> List[List[int]]: + """Implementation of the packing algorithm. + + Args: + sequence_lengths: A list of sequence lengths to pack. + + Returns: + A list of bins, where each bin is a list of indices into the original + sequence_lengths list. + """ + pass + + def pack(self, sequence_lengths: List[int]) -> List[List[int]]: + """Pack sequences into bins and update metrics if enabled. + + Args: + sequence_lengths: A list of sequence lengths to pack. + + Returns: + A list of bins, where each bin is a list of indices into the original + sequence_lengths list. + """ + # Call the implementation + bins = self._pack_implementation(sequence_lengths) + + # Update metrics if collection is enabled + if self.collect_metrics and self.metrics: + self.metrics.update(sequence_lengths, bins, self.bin_capacity) + + return bins + + def reset_metrics(self) -> None: + """Reset collected metrics.""" + if self.metrics: + self.metrics.reset() + + def compute_metrics( + self, sequence_lengths: List[int], bins: List[List[int]] + ) -> Dict[str, float]: + """Calculate metrics for a packing solution without updating the metrics tracker. 
+ + Args: + sequence_lengths: List of sequence lengths + bins: List of bins, where each bin is a list of indices + + Returns: + Dictionary of packing metrics + """ + if self.metrics: + return self.metrics.calculate_stats_only( + sequence_lengths, bins, self.bin_capacity + ) + else: + # Create a temporary metrics object if not collecting + from nemo_rl.data.packing.metrics import PackingMetrics + + temp_metrics = PackingMetrics() + return temp_metrics.calculate_stats_only( + sequence_lengths, bins, self.bin_capacity + ) + + def get_aggregated_metrics(self) -> Dict[str, float]: + """Get aggregated metrics across all packing operations. + + Returns: + Dictionary of aggregated metrics, or empty dict if not collecting + """ + if self.metrics: + return self.metrics.get_aggregated_stats() + else: + return {} + + def print_metrics(self) -> None: + """Print the current metrics in a formatted way.""" + if not self.metrics: + print( + "Metrics collection is not enabled. Initialize with collect_metrics=True." + ) + return + + self.metrics.print_aggregated_stats() + + def _validate_sequence_lengths(self, sequence_lengths: List[int]) -> None: + """Validate that all sequence lengths are within bin capacity. + + Args: + sequence_lengths: A list of sequence lengths to validate. + + Raises: + ValueError: If any sequence length exceeds bin capacity. + """ + for length in sequence_lengths: + if length > self.bin_capacity: + raise ValueError( + f"Sequence length {length} exceeds bin capacity {self.bin_capacity}" + ) + + def _create_indexed_lengths( + self, sequence_lengths: List[int], reverse: bool = False + ) -> List[Tuple[int, int]]: + """Create a list of (length, index) pairs from sequence lengths. + + Args: + sequence_lengths: A list of sequence lengths. + reverse: Whether to sort in descending order (True) or ascending order (False). + + Returns: + A list of (length, index) pairs, optionally sorted. 
+ """ + indexed_lengths = [(length, i) for i, length in enumerate(sequence_lengths)] + if reverse: + indexed_lengths.sort(reverse=True) # Sort in descending order + return indexed_lengths + + def _estimate_bins_needed(self, sequence_lengths: List[int]) -> int: + """Estimate the number of bins needed based on total length. + + Args: + sequence_lengths: A list of sequence lengths. + + Returns: + Estimated number of bins needed. + """ + total_length = sum(sequence_lengths) + return max(1, math.ceil(total_length / self.bin_capacity)) + + +class ConcatenativePacker(SequencePacker): + """Concatenative packing algorithm. + + This algorithm simply concatenates sequences in order until reaching the bin capacity, + then starts a new bin. It doesn't try to optimize the packing in any way. + + Time complexity: O(n) where n is the number of sequences. + + Example: + ```python + >>> examples = { + ... "sequence_lengths": [4, 1, 3, 2, 1, 3, 4, 5] + ... } + >>> # If packed with seq_length=5: + ... {"bins": [ [0, 1], [2, 3], [4, 5], [6], [7] ]} + >>> # If packed with seq_length=8: + ... {"bins": [ [0, 1, 2], [3, 4, 5], [6], [7] ]} + """ + + # Global class variable to limit the number of sequences packed in a unit + # -1 disables this limit + max_sequences_per_bin = 4 # Useful for debugging and testing + + def _pack_implementation(self, sequence_lengths: List[int]) -> List[List[int]]: + """Pack sequences using the Concatenative algorithm. + + Args: + sequence_lengths: A list of sequence lengths to pack. + + Returns: + A list of bins, where each bin is a list of indices into the original + sequence_lengths list. 
+ """ + # Validate sequence lengths + self._validate_sequence_lengths(sequence_lengths) + + bins = [] # List of bins, each bin is a list of sequence indices + current_bin = [] # Current bin being filled + current_length = 0 # Current length of sequences in the bin + + for i, length in enumerate(sequence_lengths): + # Check if adding this sequence would exceed bin capacity or sequence limit + exceeds_capacity = current_length + length > self.bin_capacity + exceeds_sequence_limit = ( + self.max_sequences_per_bin != -1 + and len(current_bin) >= self.max_sequences_per_bin + ) + + # If adding this sequence would exceed constraints, start a new bin + if exceeds_capacity or exceeds_sequence_limit: + if current_bin: # Only add the bin if it's not empty + bins.append(current_bin) + current_bin = [i] + current_length = length + else: + # Add the sequence to the current bin + current_bin.append(i) + current_length += length + + # Add the last bin if it's not empty + if current_bin: + bins.append(current_bin) + + return bins + + +class FirstFitPacker(SequencePacker): + """Base class for First-Fit algorithms. + + First-Fit algorithms place each sequence into the first bin where it fits. + If no bin can fit the sequence, a new bin is created. + + This is an abstract base class that provides the common implementation for + First-Fit variants. Subclasses must implement the _prepare_sequences method + to determine the order in which sequences are processed. + """ + + def _prepare_sequences(self, sequence_lengths: List[int]) -> List[Tuple[int, int]]: + """Prepare sequences for packing. + + This method determines the order in which sequences are processed. + Subclasses must override this method. + + Args: + sequence_lengths: A list of sequence lengths to pack. + + Returns: + A list of (length, index) pairs. 
+ """ + raise NotImplementedError("Subclasses must implement _prepare_sequences") + + def _pack_implementation(self, sequence_lengths: List[int]) -> List[List[int]]: + """Pack sequences using the First-Fit algorithm. + + Args: + sequence_lengths: A list of sequence lengths to pack. + + Returns: + A list of bins, where each bin is a list of indices into the original + sequence_lengths list. + """ + # Prepare sequences for packing (order determined by subclass) + indexed_lengths = self._prepare_sequences(sequence_lengths) + + bins = [] # List of bins, each bin is a list of sequence indices + bin_remaining = [] # Remaining capacity for each bin + + for length, idx in indexed_lengths: + # If the sequence is larger than the bin capacity, it cannot be packed + if length > self.bin_capacity: + raise ValueError( + f"Sequence length {length} exceeds bin capacity {self.bin_capacity}" + ) + + # Try to find a bin where the sequence fits + bin_found = False + for i, remaining in enumerate(bin_remaining): + if remaining >= length: + # Add the sequence to this bin + bins[i].append(idx) + bin_remaining[i] -= length + bin_found = True + break + + # If no suitable bin was found, create a new one + if not bin_found: + bins.append([idx]) + bin_remaining.append(self.bin_capacity - length) + + return bins + + +class FirstFitDecreasingPacker(FirstFitPacker): + """First-Fit Decreasing (FFD) algorithm for sequence packing. + + This algorithm sorts sequences by length in descending order and then + places each sequence into the first bin where it fits. + + Time complexity: O(n log n) for sorting + O(n * m) for packing, + where n is the number of sequences and m is the number of bins. + """ + + def _prepare_sequences(self, sequence_lengths: List[int]) -> List[Tuple[int, int]]: + """Prepare sequences for packing by sorting them in descending order. + + Args: + sequence_lengths: A list of sequence lengths to pack. 
+ + Returns: + A list of (length, index) pairs sorted by length in descending order. + """ + # Create a list of (length, index) pairs + indexed_lengths = [(length, i) for i, length in enumerate(sequence_lengths)] + + # Sort by length in descending order + indexed_lengths.sort(reverse=True) + + return indexed_lengths + + +class FirstFitShufflePacker(FirstFitPacker): + """First-Fit Shuffle algorithm for sequence packing. + + This algorithm randomly shuffles the sequences and then places each + sequence into the first bin where it fits. + + Time complexity: O(n * m) for packing, where n is the number of sequences + and m is the number of bins. + """ + + def _prepare_sequences(self, sequence_lengths: List[int]) -> List[Tuple[int, int]]: + """Prepare sequences for packing by randomly shuffling them. + + Args: + sequence_lengths: A list of sequence lengths to pack. + + Returns: + A list of (length, index) pairs in random order. + """ + # Create a list of (length, index) pairs + indexed_lengths = [(length, i) for i, length in enumerate(sequence_lengths)] + + # Shuffle the sequences + random.shuffle(indexed_lengths) + + return indexed_lengths + + +class ModifiedFirstFitDecreasingPacker(SequencePacker): + """Modified First-Fit Decreasing (MFFD) algorithm for sequence packing. + + This algorithm implements the Johnson & Garey (1985) Modified First-Fit-Decreasing + heuristic. It classifies items into four categories (large, medium, small, tiny) + and uses a sophisticated 5-phase packing strategy to achieve better bin utilization + than standard First-Fit Decreasing. + + The algorithm phases: + 1. Classify items by size relative to bin capacity + 2. Create one bin per large item + 3. Add medium items to large bins (forward pass) + 4. Add pairs of small items to bins with medium items (backward pass) + 5. Greedily fit remaining items + 6. 
Apply FFD to any leftovers + + Time complexity: O(n log n) for sorting + O(n * m) for packing, + where n is the number of sequences and m is the number of bins. + """ + + def _classify_items( + self, items: List[Tuple[int, int]] + ) -> Tuple[ + List[Tuple[int, int]], + List[Tuple[int, int]], + List[Tuple[int, int]], + List[Tuple[int, int]], + ]: + """Split items into large / medium / small / tiny classes. + + Follows the classification used by Johnson & Garey: + large : (C/2, C] + medium : (C/3, C/2] + small : (C/6, C/3] + tiny : (0 , C/6] + + Args: + items: List of (index, size) tuples + + Returns: + Tuple of four lists (large, medium, small, tiny) without additional sorting. + """ + large, medium, small, tiny = [], [], [], [] + for idx, size in items: + if size > self.bin_capacity / 2: + large.append((idx, size)) + elif size > self.bin_capacity / 3: + medium.append((idx, size)) + elif size > self.bin_capacity / 6: + small.append((idx, size)) + else: + tiny.append((idx, size)) + return large, medium, small, tiny + + def _pack_implementation(self, sequence_lengths: List[int]) -> List[List[int]]: + """Pack sequences using the Modified First-Fit Decreasing algorithm. + + Args: + sequence_lengths: A list of sequence lengths to pack. + + Returns: + A list of bins, where each bin is a list of indices into the original + sequence_lengths list. 
+ """ + # Validate inputs + if self.bin_capacity <= 0: + raise ValueError("bin_capacity must be positive") + if any(l <= 0 for l in sequence_lengths): + raise ValueError("sequence lengths must be positive") + + # Validate sequence lengths don't exceed capacity + self._validate_sequence_lengths(sequence_lengths) + + items: List[Tuple[int, int]] = [(i, l) for i, l in enumerate(sequence_lengths)] + + # Phase-0: classify + large, medium, small, tiny = self._classify_items(items) + + # Sort according to the rules of MFFD + large.sort(key=lambda x: x[1], reverse=True) # descending size + medium.sort(key=lambda x: x[1], reverse=True) + small.sort(key=lambda x: x[1]) # ascending size + tiny.sort(key=lambda x: x[1]) + + # Phase-1: start one bin per large item + bins: List[List[Tuple[int, int]]] = [[item] for item in large] + + # Phase-2: try to add one medium item to each large bin (forward pass) + for b in bins: + remaining = self.bin_capacity - sum(size for _, size in b) + for i, (idx, size) in enumerate(medium): + if size <= remaining: + b.append(medium.pop(i)) + break + + # Phase-3: backward pass – fill with two small items where possible + for b in reversed(bins): + has_medium = any( + self.bin_capacity / 3 < size <= self.bin_capacity / 2 for _, size in b + ) + if has_medium or len(small) < 2: + continue + remaining = self.bin_capacity - sum(size for _, size in b) + if small[0][1] + small[1][1] > remaining: + continue + first_small = small.pop(0) + # pick the *largest* small that fits with first_small (so iterate from end) + second_idx = None + for j in range(len(small) - 1, -1, -1): + if small[j][1] <= remaining - first_small[1]: + second_idx = j + break + if second_idx is not None: + second_small = small.pop(second_idx) + b.extend([first_small, second_small]) + + # Phase-4: forward greedy fit of remaining items + remaining_items = sorted( + medium + small + tiny, key=lambda x: x[1], reverse=True + ) + for b in bins: + while remaining_items: + rem = self.bin_capacity 
- sum(size for _, size in b) + # if even the smallest remaining doesn't fit we break + if rem < remaining_items[-1][1]: + break + + # pick the first (largest) that fits + chosen_idx = None + for i, (_, size) in enumerate(remaining_items): + if size <= rem: + chosen_idx = i + break + if chosen_idx is None: + break + b.append(remaining_items.pop(chosen_idx)) + + # Phase-5: FFD on leftovers + leftovers = remaining_items # renamed for clarity + ffd_bins: List[List[Tuple[int, int]]] = [] + for idx, size in sorted(leftovers, key=lambda x: x[1], reverse=True): + placed = False + for bin_ffd in ffd_bins: + if size <= self.bin_capacity - sum(s for _, s in bin_ffd): + bin_ffd.append((idx, size)) + placed = True + break + if not placed: + ffd_bins.append([(idx, size)]) + bins.extend(ffd_bins) + + # Convert to list of index lists (discard sizes) + return [[idx for idx, _ in b] for b in bins] + + +def get_packer( + algorithm: Union[PackingAlgorithm, str], + bin_capacity: int, + collect_metrics: bool = False, +) -> SequencePacker: + """Factory function to get a sequence packer based on the algorithm. + + Args: + algorithm: The packing algorithm to use. Can be either a PackingAlgorithm enum value + or a string (case-insensitive) matching one of the enum names. + bin_capacity: The maximum capacity of each bin. + collect_metrics: Whether to collect metrics across multiple packing operations. + + Returns: + A SequencePacker instance for the specified algorithm. + + Raises: + ValueError: If the algorithm is not recognized. 
+ """ + packers: Dict[PackingAlgorithm, Type[SequencePacker]] = { + PackingAlgorithm.CONCATENATIVE: ConcatenativePacker, + PackingAlgorithm.FIRST_FIT_DECREASING: FirstFitDecreasingPacker, + PackingAlgorithm.FIRST_FIT_SHUFFLE: FirstFitShufflePacker, + PackingAlgorithm.MODIFIED_FIRST_FIT_DECREASING: ModifiedFirstFitDecreasingPacker, + } + + # Convert string to enum if needed + if isinstance(algorithm, str): + try: + algorithm = PackingAlgorithm[algorithm.upper()] + except KeyError: + available_algorithms = ", ".join([alg.name for alg in PackingAlgorithm]) + raise ValueError( + f"Unknown packing algorithm: {algorithm}. " + f"Available algorithms: {available_algorithms}" + ) + + if algorithm not in packers: + available_algorithms = ", ".join([alg.name for alg in PackingAlgorithm]) + raise ValueError( + f"Unknown packing algorithm: {algorithm}. " + f"Available algorithms: {available_algorithms}" + ) + + return packers[algorithm](bin_capacity, collect_metrics=collect_metrics) diff --git a/nemo_rl/data/packing/metrics.py b/nemo_rl/data/packing/metrics.py new file mode 100644 index 0000000000..f4c8da0aae --- /dev/null +++ b/nemo_rl/data/packing/metrics.py @@ -0,0 +1,249 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Metrics for evaluating sequence packing algorithms.""" + +import math +import statistics +from typing import Dict, List, Optional + + +class PackingMetrics: + """Class for tracking and computing metrics for sequence packing algorithms. + + This class provides methods to calculate various metrics that evaluate the + efficiency and effectiveness of sequence packing algorithms, such as bin + utilization, waste, and imbalance. + """ + + def __init__(self): + """Initialize the metrics tracker.""" + self.reset() + + def reset(self) -> None: + """Reset all metrics.""" + # Counters for aggregated metrics + self.total_sequences = 0 + self.total_bins = 0 + self.total_sequence_length = 0 + self.total_bin_capacity = 0 + self.total_waste = 0 + self.bin_utilizations = [] + self.bin_counts = [] + self.packing_times = [] + + # Tracking best and worst cases + self.min_utilization = 1.0 + self.max_utilization = 0.0 + self.min_waste_ratio = 1.0 + self.max_waste_ratio = 0.0 + + def update( + self, + sequence_lengths: List[int], + bins: List[List[int]], + bin_capacity: int, + packing_time: Optional[float] = None, + ) -> Dict[str, float]: + """Update metrics with a new packing solution. 
+ + Args: + sequence_lengths: List of sequence lengths + bins: List of bins, where each bin is a list of indices + bin_capacity: Maximum capacity of each bin + packing_time: Optional time taken to compute the packing solution + + Returns: + Dictionary of metrics for this packing solution + """ + # Calculate metrics for this solution + stats = self.calculate_stats_only(sequence_lengths, bins, bin_capacity) + + # Update counters + self.total_sequences += len(sequence_lengths) + self.total_bins += len(bins) + self.total_sequence_length += sum(sequence_lengths) + self.total_bin_capacity += len(bins) * bin_capacity + self.total_waste += stats["total_waste"] + self.bin_utilizations.append(stats["average_utilization"]) + self.bin_counts.append(len(bins)) + + if packing_time is not None: + self.packing_times.append(packing_time) + + # Update min/max values + self.min_utilization = min(self.min_utilization, stats["average_utilization"]) + self.max_utilization = max(self.max_utilization, stats["average_utilization"]) + self.min_waste_ratio = min(self.min_waste_ratio, stats["waste_ratio"]) + self.max_waste_ratio = max(self.max_waste_ratio, stats["waste_ratio"]) + + return stats + + def calculate_stats_only( + self, sequence_lengths: List[int], bins: List[List[int]], bin_capacity: int + ) -> Dict[str, float]: + """Calculate metrics for a packing solution without updating the tracker. 
+ + Args: + sequence_lengths: List of sequence lengths + bins: List of bins, where each bin is a list of indices + bin_capacity: Maximum capacity of each bin + + Returns: + Dictionary of metrics for this packing solution + """ + if not bins: + return { + "num_sequences": 0, + "num_bins": 0, + "total_sequence_length": 0, + "total_bin_capacity": 0, + "total_waste": 0, + "average_utilization": 0.0, + "waste_ratio": 0.0, + "bin_balance": 0.0, + "theoretical_min_bins": 0, + "bin_efficiency": 0.0, + } + + # Calculate bin loads + bin_loads = [ + sum(sequence_lengths[idx] for idx in bin_indices) for bin_indices in bins + ] + + # Calculate basic metrics + num_sequences = len(sequence_lengths) + num_bins = len(bins) + total_sequence_length = sum(sequence_lengths) + total_bin_capacity = num_bins * bin_capacity + total_waste = total_bin_capacity - total_sequence_length + + # Calculate utilization metrics + bin_utilizations = [load / bin_capacity for load in bin_loads] + average_utilization = total_sequence_length / total_bin_capacity + waste_ratio = total_waste / total_bin_capacity + + # Calculate bin balance metrics (standard deviation of utilization) + if num_bins > 1: + bin_balance = 1.0 - statistics.stdev(bin_utilizations) / average_utilization + else: + bin_balance = 1.0 + + # Calculate theoretical minimum number of bins + theoretical_min_bins = math.ceil(total_sequence_length / bin_capacity) + + # Calculate bin efficiency (ratio of theoretical min bins to actual bins) + bin_efficiency = theoretical_min_bins / num_bins if num_bins > 0 else 0.0 + + return { + "num_sequences": num_sequences, + "num_bins": num_bins, + "total_sequence_length": total_sequence_length, + "total_bin_capacity": total_bin_capacity, + "total_waste": total_waste, + "average_utilization": average_utilization, + "waste_ratio": waste_ratio, + "bin_balance": bin_balance, + "theoretical_min_bins": theoretical_min_bins, + "bin_efficiency": bin_efficiency, + } + + def get_aggregated_stats(self) -> Dict[str, 
float]: + """Get aggregated metrics across all packing operations. + + Returns: + Dictionary of aggregated metrics + """ + if not self.bin_utilizations: + return {} + + # Calculate aggregated metrics + avg_utilization = ( + self.total_sequence_length / self.total_bin_capacity + if self.total_bin_capacity > 0 + else 0.0 + ) + avg_waste_ratio = ( + self.total_waste / self.total_bin_capacity + if self.total_bin_capacity > 0 + else 0.0 + ) + avg_bin_count = ( + sum(self.bin_counts) / len(self.bin_counts) if self.bin_counts else 0.0 + ) + + # Calculate theoretical minimum number of bins + theoretical_min_bins = ( + math.ceil( + self.total_sequence_length / (self.total_bin_capacity / self.total_bins) + ) + if self.total_bins > 0 + else 0 + ) + + # Calculate bin efficiency (ratio of theoretical min bins to actual bins) + bin_efficiency = ( + theoretical_min_bins / self.total_bins if self.total_bins > 0 else 0.0 + ) + + # Calculate average packing time if available + avg_packing_time = ( + sum(self.packing_times) / len(self.packing_times) + if self.packing_times + else None + ) + + stats = { + "total_sequences": self.total_sequences, + "total_bins": self.total_bins, + "average_utilization": avg_utilization, + "min_utilization": self.min_utilization, + "max_utilization": self.max_utilization, + "average_waste_ratio": avg_waste_ratio, + "min_waste_ratio": self.min_waste_ratio, + "max_waste_ratio": self.max_waste_ratio, + "average_bin_count": avg_bin_count, + "bin_efficiency": bin_efficiency, + } + + if avg_packing_time is not None: + stats["average_packing_time"] = avg_packing_time + + return stats + + def print_aggregated_stats(self) -> None: + """Print the aggregated metrics in a formatted way.""" + stats = self.get_aggregated_stats() + + if not stats: + print("No metrics collected yet.") + return + + print("\n=== Packing Metrics Summary ===") + print(f"Total sequences packed: {stats['total_sequences']}") + print(f"Total bins used: {stats['total_bins']}") + print( + 
f"Average bin utilization: {stats['average_utilization']:.4f} (min: {stats['min_utilization']:.4f}, max: {stats['max_utilization']:.4f})" + ) + print( + f"Average waste ratio: {stats['average_waste_ratio']:.4f} (min: {stats['min_waste_ratio']:.4f}, max: {stats['max_waste_ratio']:.4f})" + ) + print( + f"Bin efficiency (theoretical min bins / actual bins): {stats['bin_efficiency']:.4f}" + ) + + if "average_packing_time" in stats: + print(f"Average packing time: {stats['average_packing_time']:.6f} seconds") + + print("===============================\n") diff --git a/nemo_rl/distributed/batched_data_dict.py b/nemo_rl/distributed/batched_data_dict.py index dc30d68364..969738d203 100644 --- a/nemo_rl/distributed/batched_data_dict.py +++ b/nemo_rl/distributed/batched_data_dict.py @@ -28,6 +28,7 @@ import torch from typing_extensions import Self +from nemo_rl.data.packing import get_packer from nemo_rl.distributed.collectives import ( gather_jagged_object_lists, rebalance_nd_tensor, @@ -36,6 +37,21 @@ DictT = TypeVar("DictT", bound=Mapping[str, Any]) +class SequencePackingArgs(TypedDict): + """Configuration settings for sequence packing. + + Pass this to 'shard_by_batch_size()' to preprocess batches for sequence packing. + """ + + max_tokens_per_microbatch: int + input_key: str + input_lengths_key: str + algorithm: str + sequence_length_pad_multiple: ( + int # pad each sequence to a multiple of this value (for CP/TP alignment) + ) + + class DynamicBatchingArgs(TypedDict): """Configuration settings for dynamic batching. 
@@ -58,6 +74,7 @@ def __init__(self, *args, **kwargs): self.micro_batch_indices = None self.micro_batch_lengths = None + self.elem_counts_per_gb = None @classmethod def from_batches( @@ -204,6 +221,7 @@ def shard_by_batch_size( batch_size: Optional[int] = None, allow_uneven_shards: bool = False, dynamic_batching_args: Optional[DynamicBatchingArgs] = None, + sequence_packing_args: Optional[SequencePackingArgs] = None, ) -> list["SlicedDataDict"] | tuple[list["SlicedDataDict"], list[int]]: """Shards a batch by first dividing it into chunks of size batch_size, then further dividing each chunk into shards equal parts. Finally aggregates the sub-shards by their position. @@ -219,7 +237,7 @@ def shard_by_batch_size( allow_uneven_shards (bool): Whether to allow shards to be unevenly sized. If True, the last shard may be smaller than the others. dynamic_batching_args (dict): If passed, preprocess batch for dynamic batching. This - dict requires two keys: + dict requires four keys: 1. max_tokens_per_microbatch (int): the maximum number of tokens in a microbatch 2. sequence_length_round (int): round each all @@ -229,6 +247,21 @@ def shard_by_batch_size( 4. input_lengths_key (str): the key in the batch which holds the sequence length per value. The sequence dim index is assumed to be 1. + Cannot be passed with sequence_packing_args. + + sequence_packing_args (dict): If passed, preprocess batch for sequence packing. This + dict requires five keys: + 1. max_tokens_per_microbatch (int): the maximum + number of tokens in a microbatch + 2. input_key (str): the key in the batch + which holds input ids. + 3. input_lengths_key (str): the key in the batch + which holds the sequence length per value. + The sequence dim index is assumed to be 1. + 4. algorithm (str): the algorithm to use for sequence packing. + 5. sequence_length_pad_multiple (int): the multiple to pad each sequence to. + With CP enabled, this should be set to a multiple of 2*CP and SP. 
+ Cannot be passed with dynamic_batching_args. Returns: list[BatchedDataDict]: A list of BatchedDataDicts, length equal to shards. @@ -268,6 +301,9 @@ def shard_by_batch_size( assert batch_size is None, ( "batch_size must be None if allow_uneven_shards is True" ) + assert dynamic_batching_args is None or sequence_packing_args is None, ( + "dynamic_batching_args and sequence_packing_args cannot be passed together" + ) # Get the total batch size batch_sizes = set() @@ -336,6 +372,112 @@ def shard_by_batch_size( else: sorted_v = [v[i] for i in batch_sorted_indices] data[k] = sorted_v + + elif sequence_packing_args is not None: + bin_packer = get_packer( + algorithm=sequence_packing_args["algorithm"], + bin_capacity=sequence_packing_args["max_tokens_per_microbatch"], + collect_metrics=False, # TODO(ahmadki): make configurable + ) + + input_lengths_key = sequence_packing_args["input_lengths_key"] + input_lens = self.data[input_lengths_key] + if not isinstance(input_lens, torch.Tensor): + input_lens = torch.tensor(input_lens) + + pad_multiple = sequence_packing_args["sequence_length_pad_multiple"] + + def _get_padded_seqlen(seqlen: int) -> int: + return (seqlen + pad_multiple - 1) // pad_multiple * pad_multiple + + # Store bin assignments for each chunk to reuse later + all_chunk_bin_assignments = [] + + # Process each chunk separately to respect chunk boundaries + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * batch_size + chunk_end = (chunk_idx + 1) * batch_size + + # Get sequence lengths for this chunk + chunk_seqlens = input_lens[chunk_start:chunk_end] + chunk_padded_seqlens_list = [ + _get_padded_seqlen(seq_len.item()) for seq_len in chunk_seqlens + ] + + # Pack sequences in this chunk into bins + chunk_bin_assignments = bin_packer.pack( + sequence_lengths=chunk_padded_seqlens_list, + ) + all_chunk_bin_assignments.append(chunk_bin_assignments) + + # create shards with the packed bins + sharded_data: list[list[dict]] = [[] for _ in range(shards)] + 
sharded_micro_indices: list = [[] for _ in range(shards)] + sharded_micro_lengths: list = [[] for _ in range(shards)] + sharded_elem_counts_per_gb: list = [[] for _ in range(shards)] + global_indices_per_shard: list[list[int]] = [[] for _ in range(shards)] + for chunk_idx in range(num_chunks): + chunk_sharded_micro_indices: list[list[list[int]]] = [ + [] for _ in range(shards) + ] + chunk_sharded_micro_lengths: list[list[int]] = [ + [] for _ in range(shards) + ] + + num_bins = len(all_chunk_bin_assignments[chunk_idx]) + chunk_start = chunk_idx * batch_size + for bin_idx in range(num_bins): + shard_idx = bin_idx % shards + bin_indices = all_chunk_bin_assignments[chunk_idx][bin_idx] + global_bin_indices = [i + chunk_start for i in bin_indices] + sharded_data[shard_idx].append( + self.select_indices(global_bin_indices) + ) + global_indices_per_shard[shard_idx].extend(global_bin_indices) + bin_seqlen = sum( + [ + _get_padded_seqlen(input_lens[i].item()) + for i in global_bin_indices + ] + ) + + if chunk_sharded_micro_indices[shard_idx] == []: + chunk_sharded_micro_indices[shard_idx].append( + [0, len(bin_indices)] + ) + else: + prev_bin_end = chunk_sharded_micro_indices[shard_idx][-1][1] + chunk_sharded_micro_indices[shard_idx].append( + [prev_bin_end, prev_bin_end + len(bin_indices)] + ) + chunk_sharded_micro_lengths[shard_idx].append(bin_seqlen) + + for shard_idx in range(shards): + sharded_micro_indices[shard_idx].append( + chunk_sharded_micro_indices[shard_idx] + ) + sharded_micro_lengths[shard_idx].append( + chunk_sharded_micro_lengths[shard_idx] + ) + sharded_elem_counts_per_gb[shard_idx].append( + chunk_sharded_micro_indices[shard_idx][-1][1] + ) + + # flatten global_indices_per_shard + batch_sorted_indices = [] + for shard_idx in range(shards): + batch_sorted_indices.extend(global_indices_per_shard[shard_idx]) + + aggregated_shards = [] + for shard_idx in range(shards): + shard = SlicedDataDict.from_batches(sharded_data[shard_idx]) + shard.micro_batch_indices = 
sharded_micro_indices[shard_idx] + shard.micro_batch_lengths = sharded_micro_lengths[shard_idx] + shard.elem_counts_per_gb = sharded_elem_counts_per_gb[shard_idx] + aggregated_shards.append(shard) + + return aggregated_shards, batch_sorted_indices + else: data = self.data @@ -457,7 +599,7 @@ def shard_by_batch_size( return aggregated_shards - def get_batch(self, batch_idx, batch_size) -> "SlicedDataDict": + def get_batch(self, batch_idx, batch_size=None) -> "SlicedDataDict": """Slices a subbatch from the batch. Args: @@ -467,6 +609,21 @@ def get_batch(self, batch_idx, batch_size) -> "SlicedDataDict": Returns: BatchedDataDict: A new BatchedDataDict containing the sliced data """ + if self.elem_counts_per_gb is not None: + assert self.micro_batch_indices is not None, ( + "micro_batch_indices must be provided if sequence_packing is True" + ) + elem_count = self.elem_counts_per_gb[batch_idx] + cum_elem_count = [0] + for i in range(len(self.elem_counts_per_gb)): + cum_elem_count.append(cum_elem_count[i] + self.elem_counts_per_gb[i]) + + batch = self.slice(cum_elem_count[batch_idx], cum_elem_count[batch_idx + 1]) + batch.micro_batch_indices = [self.micro_batch_indices[batch_idx]] + batch.micro_batch_lengths = [self.micro_batch_lengths[batch_idx]] # type: ignore # This exists if idxs do + batch.elem_counts_per_gb = [elem_count] + return batch + start = batch_size * batch_idx end = batch_size * (batch_idx + 1) batch = self.slice(start, end) @@ -488,6 +645,10 @@ def slice(self, start: int, end: int) -> "SlicedDataDict": """ sliced_batch = SlicedDataDict() for k in self.data: + if isinstance(self.data[k], torch.Tensor): + assert end <= self.data[k].shape[0], ( + f"end: {end} is greater than the shape of the tensor: {self.data[k].shape[0]} for key: {k}" + ) sliced_batch[k] = self.data[k][start:end] return sliced_batch @@ -520,7 +681,7 @@ def make_microbatch_iterator_with_dynamic_shapes( self, sequence_dim: int = 1, ) -> Iterator["SlicedDataDict"]: - """Makes an interator that 
yields microbatchs of dynamic batch and sequence sizes. + """Makes an iterator that yields microbatchs of dynamic batch and sequence sizes. Args: sequence_dim: the index of the sequence dim for all tensors in the data dict @@ -542,9 +703,29 @@ def make_microbatch_iterator_with_dynamic_shapes( yield mb def get_microbatch_iterator_dynamic_shapes_len(self) -> int: - """Get the length of the microbatch iterator with dynamic shapes.""" + """Get the length of the microbatch iterator for dynamic shapes.""" return len(self.micro_batch_indices[0]) + def make_microbatch_iterator_for_packable_sequences( + self, + ) -> Iterator["SlicedDataDict"]: + """Make an iterator over the batch that yields microbatches that can be packed into a given max_tokens_per_microbatch.""" + assert ( + self.micro_batch_indices is not None + and len(self.micro_batch_indices) == 1 + and self.micro_batch_lengths is not None + ) + + for seqlen, (start_idx, end_idx) in zip( + self.micro_batch_lengths[0], self.micro_batch_indices[0] + ): + mb = self.slice(start_idx, end_idx) + yield mb + + def get_microbatch_iterator_for_packable_sequences_len(self) -> tuple[int, int]: + """Get the length of the microbatch iterator for sequence packing and the max packed seqlen.""" + return len(self.micro_batch_indices[0]), max(self.micro_batch_lengths[0]) + def make_microbatch_iterator( self, microbatch_size: int ) -> Iterator["SlicedDataDict"]: diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index 31ac71cc23..606fd8464b 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -120,21 +120,21 @@ def backward( return grad_input, None, None, None, None, None, None -def from_parallel_logits_to_logprobs( +def dtensor_from_parallel_logits_to_logprobs( vocab_parallel_logits: torch.Tensor, - target: torch.Tensor | DTensor, + target: DTensor | torch.Tensor, vocab_start_index: int, vocab_end_index: int, tp_group: torch.distributed.ProcessGroup, 
     inference_only: bool = False,
     seq_index: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    """Get log probabilities from TP sharded vocab logits.
+    """Get log probabilities from TP+CP sharded vocab logits.
 
     Args:
-        vocab_parallel_logits (torch.Tensor): Logits tensor with shape [batch_size, seq_len, vocab_size//TP]
-            where TP is the tensor parallel size.
-        target (torch.Tensor): Target token indices with shape [batch_size, seq_len].
+        vocab_parallel_logits (torch.Tensor): Logits distributed across tensor parallel workers,
+            with shape [batch_size, seq_len, vocab_size/tp_size].
+        target (DTensor | torch.Tensor): Target token indices with shape [batch_size, seq_len].
             NOTE: Must be the unmodified targets as this function will shift them internally.
         vocab_start_index (int): Starting vocabulary index for this worker's partition.
         vocab_end_index (int): Ending vocabulary index for this worker's partition.
@@ -146,8 +146,6 @@ def from_parallel_logits_to_logprobs(
     Returns:
         torch.Tensor: Log probabilities tensor with shape [batch_size, seq_len-1].
             The sequence dimension is reduced by 1 due to the target shifting.
-
-    Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L354
     """
     cp_size = 1
@@ -188,3 +186,286 @@
+
+
+def from_parallel_logits_to_logprobs(
+    vocab_parallel_logits: torch.Tensor,
+    target: torch.Tensor,
+    vocab_start_index: int,
+    vocab_end_index: int,
+    tp_group: torch.distributed.ProcessGroup,
+    inference_only: bool = False,
+    cp_group: Optional[torch.distributed.ProcessGroup] = None,
+) -> torch.Tensor:
+    """Get log probabilities from TP+CP sharded vocab logits.
+
+    Args:
+        vocab_parallel_logits (torch.Tensor): Logits tensor with shape [batch_size, seq_len // CP, vocab_size // TP]
+            where TP is the tensor parallel size.
+        target (torch.Tensor): Target token indices with shape [batch_size, seq_len].
+ NOTE: Must be the unmodified targets as this function will shift them internally. + vocab_start_index (int): Starting vocabulary index for this worker's partition. + vocab_end_index (int): Ending vocabulary index for this worker's partition. + tp_group (torch.distributed.ProcessGroup): Process group for distributed communication. + inference_only (bool, optional): If True, tensors won't be saved for backward pass. Defaults to False. + cp_group (torch.distributed.ProcessGroup, optional): Context parallelism process group. Defaults to None. + + Returns: + torch.Tensor: Log probabilities tensor with shape [batch_size, seq_len-1]. + The sequence dimension is reduced by 1 due to the target shifting. + + Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L354 + """ + target = target.roll(shifts=-1, dims=-1) + cp_size = 1 if cp_group is None else torch.distributed.get_world_size(cp_group) + pad_len = 0 + # if cp_size > 1: + # Pad the targets to local size * cp_size + pad_len = vocab_parallel_logits.shape[1] * cp_size - target.shape[1] + if pad_len > 0: + target = torch.nn.functional.pad(target, (0, pad_len), value=0) + + # Shard the targets by context parallelism + cp_rank = torch.distributed.get_rank(cp_group) + target = _get_tokens_on_this_cp_rank(target, cp_rank, cp_size, seq_dim=1) + + probs: torch.Tensor = DistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + target, + vocab_start_index, + vocab_end_index, + tp_group, + inference_only, + ).contiguous() + + if cp_size > 1: + # we need to gather the logits by context parallelism + probs = allgather_cp_sharded_tensor( + probs, cp_group, seq_dim=1 + ) # , unpadded_seqlen=target.shape[1]) + + if pad_len > 0: + probs = probs[:, :-pad_len] + + return probs[:, :-1] + + +def from_parallel_logits_to_logprobs_packed_sequences( + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + cu_seqlens_padded: torch.Tensor, + 
unpacked_seqlen: int,
+    vocab_start_index: int,
+    vocab_end_index: int,
+    group: torch.distributed.ProcessGroup,
+    inference_only: bool = False,
+    cp_group: Optional[torch.distributed.ProcessGroup] = None,
+) -> torch.Tensor:
+    """Get log probabilities from TP sharded vocab logits for packed sequences.
+
+    Args:
+        vocab_parallel_logits (torch.Tensor): Packed logits tensor with shape [1, T // CP, vocab_size//TP]
+            where T is the total number of tokens across all packed sequences.
+        target (torch.Tensor): Packed target token indices with shape [1, T].
+            NOTE: Must be the unmodified targets as this function will shift them internally.
+        cu_seqlens_padded (torch.Tensor): Cumulative sequence lengths tensor with shape [batch_size + 1].
+            cu_seqlens_padded[i] indicates the start position of sequence i in the packed format.
+        unpacked_seqlen (int): The length of the unpacked sequence tensor.
+        vocab_start_index (int): Starting vocabulary index for this worker's partition.
+        vocab_end_index (int): Ending vocabulary index for this worker's partition.
+        group (torch.distributed.ProcessGroup): Process group for distributed communication.
+        inference_only (bool, optional): If True, tensors won't be saved for backward pass. Defaults to False.
+        cp_group (torch.distributed.ProcessGroup, optional): Context parallelism process group. Defaults to None.
+
+    Returns:
+        torch.Tensor: Unpacked log probabilities tensor with shape [batch_size, unpacked_seqlen-1].
+            The total length is reduced by batch_size due to target shifting (one token per sequence).
+ """ + # Remove batch dimension to work with [T, vocab_size] and [T] + vocab_parallel_logits = vocab_parallel_logits.squeeze(0) + target = target.squeeze(0) + + batch_size = cu_seqlens_padded.shape[0] - 1 + cp_size = 1 if cp_group is None else torch.distributed.get_world_size(cp_group) + cp_rank = 0 if cp_group is None else torch.distributed.get_rank(cp_group) + + # Roll each sequence individually + rolled_targets = torch.zeros( + target.shape[0] // cp_size, dtype=target.dtype, device=target.device + ) + for i in range(batch_size): + start_idx = cu_seqlens_padded[i].item() + end_idx = cu_seqlens_padded[i + 1].item() + + # Get the sequence targets and roll by -1 + seq_targets = target[start_idx:end_idx] + rolled_seq_targets = seq_targets.roll(shifts=-1, dims=0) + rolled_targets[start_idx // cp_size : end_idx // cp_size] = ( + _get_tokens_on_this_cp_rank(rolled_seq_targets, cp_rank, cp_size, seq_dim=0) + ) + + # Add batch dimension back for DistributedLogprob + rolled_targets = rolled_targets.unsqueeze(0) + vocab_parallel_logits = vocab_parallel_logits.unsqueeze(0) + + # Apply distributed log probability computation + probs: torch.Tensor = DistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + rolled_targets, + vocab_start_index, + vocab_end_index, + group, + inference_only, + ).contiguous() + + # Remove batch dimension for filtering + probs = probs.squeeze(0) + + # Ensure probs is 1D after squeezing + if probs.dim() != 1: + raise ValueError( + f"Expected probs to be 1D after squeezing, but got shape {probs.shape}. 
" + f"Original shape before squeeze: {probs.unsqueeze(0).shape}" + ) + + if cp_size > 1: + # per-sequence cp_allgather + final_probs = torch.zeros(probs.shape[0] * cp_size, device=probs.device) + for i in range(batch_size): + start_idx = cu_seqlens_padded[i].item() + end_idx = cu_seqlens_padded[i + 1].item() + final_probs[start_idx:end_idx] = allgather_cp_sharded_tensor( + probs[start_idx // cp_size : end_idx // cp_size], cp_group, seq_dim=0 + ) + probs = final_probs + + out_logprobs = torch.zeros( + (batch_size, unpacked_seqlen - 1), dtype=probs.dtype, device=probs.device + ) + # Filter out the last token of each sequence + for i in range(batch_size): + start_idx = cu_seqlens_padded[i].item() + end_idx = cu_seqlens_padded[i + 1].item() + + # Exclude the last position (which has the rolled target from position 0) + if end_idx - start_idx > 0: + seq_probs = probs[start_idx : end_idx - 1] + # Ensure seq_probs is 1D + if seq_probs.dim() > 1: + seq_probs = seq_probs.squeeze() + + # Ensure we don't exceed the unpacked sequence length + seq_len = min(seq_probs.shape[0], unpacked_seqlen - 1) + if seq_len > 0: + out_logprobs[i, :seq_len] = seq_probs[:seq_len] + + return out_logprobs + + +def _get_tokens_on_this_cp_rank( + input_ids: torch.Tensor, + cp_rank: int, + cp_size: int, + seq_dim: int = 1, +) -> torch.Tensor: + """Get tokens on this context parallelism rank. + + Assumes that input_ids are already padded to a multiple of cp_size * 2 or cp_size == 1. 
+ + Args: + input_ids: Input token IDs [seq_length, ] + cp_rank: Context parallelism rank + cp_size: Context parallelism size + + Returns: + Tokens on this context parallelism rank [1, seq_length // cp_size] + """ + if cp_size == 1: + return input_ids + + # load balance for causal attention + shard_size = input_ids.shape[seq_dim] // (cp_size * 2) + shard_inds = (cp_rank, (cp_size * 2) - cp_rank - 1) + + # Create slices for each dimension + slices = [slice(None)] * input_ids.dim() + ids_chunks = [] + + for ind in shard_inds: + slices[seq_dim] = slice(ind * shard_size, (ind + 1) * shard_size) + ids_chunks.append(input_ids[slices]) + + ids = torch.cat(ids_chunks, dim=seq_dim) + return ids + + +def allgather_cp_sharded_tensor( + tensor, cp_group, seq_dim=1 +): # , unpadded_seqlen=None): + return AllGatherCPTensor.apply(tensor, cp_group, seq_dim) # , unpadded_seqlen) + + +class AllGatherCPTensor(torch.autograd.Function): + def forward( + ctx, tensor, cp_group: torch.distributed.ProcessGroup, seq_dim=1 + ): # , unpadded_seqlen: Optional[int] = None): + cp_size = torch.distributed.get_world_size(cp_group) + cp_rank_chunks = [] + for _ in range(cp_size): + cp_rank_chunks.append(torch.empty_like(tensor)) + + torch.distributed.all_gather( + tensor_list=cp_rank_chunks, tensor=tensor, group=cp_group + ) + + # undo the CP load balancing chunking + tensor_chunks = [] + for logit_chunk in cp_rank_chunks: + tensor_chunks.extend(torch.chunk(logit_chunk, chunks=2, dim=seq_dim)) + + chunk_indices = [] + for cp_rank in range(cp_size): + chunk_indices.append(cp_rank) + chunk_indices.append(2 * cp_size - cp_rank - 1) + + chunks_and_indices = list(zip(tensor_chunks, chunk_indices)) + chunks_and_indices = sorted(chunks_and_indices, key=lambda tup: tup[1]) + ret_tensor = [chunk for chunk, _ in chunks_and_indices] + ret_tensor = torch.cat(ret_tensor, dim=seq_dim) + + ctx.seq_dim = seq_dim + ctx.cp_group = cp_group + # ctx.unpadded_seqlen = unpadded_seqlen + + return ret_tensor + + def 
backward(ctx, grad_output): + cp_size = torch.distributed.get_world_size(ctx.cp_group) + cp_rank = torch.distributed.get_rank(ctx.cp_group) + torch.distributed.all_reduce(grad_output, group=ctx.cp_group) + + # chunk the seqdim in 2*cp chunks, and select with a CP load balanced indexing + seq_dim = ctx.seq_dim + # if ctx.unpadded_seqlen is not None: + # # Zero out grad_output along the seq_dim after unpadded_seqlen + # slicer = [slice(None)] * grad_output.dim() + # slicer[seq_dim] = slice(ctx.unpadded_seqlen, None) + # grad_output[tuple(slicer)] = 0 + + grad_output = grad_output.view( + *grad_output.shape[0:seq_dim], + 2 * cp_size, + grad_output.shape[seq_dim] // (2 * cp_size), + *grad_output.shape[(seq_dim + 1) :], + ) + + index = torch.tensor( + [cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True + ).cuda(non_blocking=True) + + grad_input = grad_output.index_select(seq_dim, index) + grad_input = grad_input.view( + *grad_input.shape[0:seq_dim], -1, *grad_input.shape[(seq_dim + 2) :] + ) + + return grad_input, None, None # , None diff --git a/nemo_rl/models/dtensor/parallelize.py b/nemo_rl/models/dtensor/parallelize.py index f668834f19..cc069be65d 100644 --- a/nemo_rl/models/dtensor/parallelize.py +++ b/nemo_rl/models/dtensor/parallelize.py @@ -41,7 +41,7 @@ from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM -from nemo_rl.distributed.model_utils import from_parallel_logits_to_logprobs +from nemo_rl.distributed.model_utils import dtensor_from_parallel_logits_to_logprobs from nemo_rl.models.policy.utils import import_class_from_path @@ -621,8 +621,10 @@ def get_logprobs_from_vocab_parallel_logits( Args: vocab_parallel_logits (DTensor): Logits distributed across tensor parallel workers, with shape [batch_size, seq_len, vocab_size/tp_size]. 
- input_ids (torch.Tensor): Input token IDs for which to compute log probabilities, + input_ids (torch.Tensor | DTensor): Input token IDs for which to compute log probabilities, with shape [batch_size, seq_len]. + seq_index (Optional[torch.Tensor]): Sequence index for the input IDs, + with shape [sequence_length]. Returns: torch.Tensor: Log probabilities for the given input IDs. @@ -641,7 +643,7 @@ def get_logprobs_from_vocab_parallel_logits( vocab_interval_per_rank = vocab_parallel_logits.shape[-1] // tp_size - return from_parallel_logits_to_logprobs( + return dtensor_from_parallel_logits_to_logprobs( vocab_parallel_logits.to_local(), input_ids, vocab_interval_per_rank * tp_rank, diff --git a/nemo_rl/models/huggingface/common.py b/nemo_rl/models/huggingface/common.py index df913f95b4..c057f6d89a 100644 --- a/nemo_rl/models/huggingface/common.py +++ b/nemo_rl/models/huggingface/common.py @@ -12,10 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass from enum import Enum, auto +from typing import Optional, Tuple, TypeVar +import torch from transformers import AutoConfig +Tensor = TypeVar("Tensor", bound=torch.Tensor) + + +@dataclass +class FlashAttentionKwargs: + """Dataclass to hold FlashAttention v2 kwargs.""" + + cu_seqlens_q: Tensor + cu_seqlens_k: Tensor + max_seqlen_q: int + max_seqlen_k: int + class ModelFlag(Enum): """Enum that defines special flags for model-specific behaviors. @@ -53,3 +68,241 @@ def is_gemma_model(model_name: str) -> bool: "gemma3", "gemma3_text", ] + + +def group_and_cat_tensors( + tensors: list[torch.Tensor], + group_sizes: list[int], + padding_value: int = 0, + min_seq_len: int = 0, +) -> torch.Tensor: + """Groups and concatenates tensors according to group_sizes, then pads them to form a 2D tensor. 
+ + Each group of 1D tensors is concatenated into a single 1D tensor, and all resulting + group tensors are padded to the same length and stacked into a 2D tensor. + + Args: + tensors: List of 1D tensors of varying lengths. + group_sizes: List of integers. Each integer specifies how many tensors to group. + padding_value: Integer used to pad shorter sequences. + min_seq_len: Minimum sequence length. + + Returns: + A 2D tensor where each row is a padded concatenation of the grouped tensors. + + Example: + >>> tensors = [ + ... torch.tensor([1, 2]), + ... torch.tensor([3]), + ... torch.tensor([4, 5, 6]), + ... torch.tensor([7]) + ... ] + >>> group_sizes = [2, 2] + >>> group_and_cat_tensors(tensors, group_sizes, padding_value=-1) + tensor([[ 1, 2, 3, -1, -1], + [ 4, 5, 6, 7, -1]]) + """ + grouped = [] + index = 0 + for size in group_sizes: + group = tensors[index : index + size] + concat = torch.cat(group, dim=0) + grouped.append(concat) + index += size + + # Compute the maximum length for padding + max_len = max(t.size(0) for t in grouped) + max_len = max(max_len, min_seq_len) + + # Pad each tensor to max_len + padded = torch.stack( + [ + torch.nn.functional.pad(t, (0, max_len - t.size(0)), value=padding_value) + for t in grouped + ] + ) + + return padded + + +def pack_sequences( + input_ids: torch.Tensor, + input_lengths: torch.Tensor, + packed_sequence_size: list[int], + padding_value: int = 0, + return_attention_mask: bool = True, + min_seq_len: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Packs sequences into rows where each row concatenates multiple sequences. + + Useful for sequence packing in transformer models (e.g. for SFT training). Returns: + packed input_ids, packed position_ids, and optional attention_mask. 
+ + Args: + input_ids (torch.Tensor): Tensor of shape [num_sequences, max_seq_len] + input_lengths (torch.Tensor): Tensor of shape [num_sequences], containing true lengths + packed_sequence_size (List[int]): How many sequences to pack per row + padding_value (int): Pad value for input_ids + return_attention_mask (bool): Whether to return per-row causal attention mask + min_seq_len (int): Minimum sequence length. + + Returns: + Tuple: + input_ids_packed (torch.Tensor): [batch_size, max_packed_seq_len] + position_ids_packed (torch.Tensor): [batch_size, max_packed_seq_len] + attention_mask (Optional[torch.Tensor]): [batch_size, max_len, max_len] if requested + + Example: + >>> input_ids = torch.tensor([ + ... [1, 2, 0, 0], # len 2 + ... [3, 4, 5, 0], # len 3 + ... [6, 0, 0, 0], # len 1 + ... [7, 8, 9, 9], # len 4 + ... [8, 7, 0, 0], # len 2 + ... [6, 0, 0, 0], # len 1 + ... [5, 4, 3, 0], # len 3 + ... ]) + >>> input_lengths = torch.tensor([2, 3, 1, 4, 2, 1, 3]) + >>> packed_sequence_size = [3, 4] + >>> input_ids_packed, position_ids_packed, attention_mask = pack_sequences( + ... input_ids, input_lengths, packed_sequence_size, padding_value=-1, return_attention_mask=True + ... 
) + >>> input_ids_packed + tensor([ + [ 1, 2, 3, 4, 5, 6, -1, -1, -1, -1], + [ 7, 8, 9, 9, 8, 7, 6, 5, 4, 3] + ]) + >>> position_ids_packed + tensor([ + [0, 1, 0, 1, 2, 0, 0, 0, 0, 0], + [0, 1, 2, 3, 0, 1, 0, 0, 1, 2] + ]) + >>> attention_mask[0] + tensor([ + [ True, True, False, False, False, False, False, False, False, False], + [False, False, True, True, True, False, False, False, False, False], + [False, False, False, False, False, True, False, False, False, False], + [False, False, False, False, False, False, False, False, False, False], + ]) + >>> attention_mask[1] + tensor([ + [ True, True, True, True, False, False, False, False, False, False], + [False, False, False, False, True, True, True, False, False, False], + [False, False, False, False, False, False, True, True, True, True], + [False, False, False, False, False, False, False, True, True, True], + ]) + """ + flat_input_ids = [] + position_ids = [] + flat_lengths = input_lengths.tolist() + + for i, seq_len in enumerate(flat_lengths): + flat_input_ids.append(input_ids[i, :seq_len]) + position_ids.append( + torch.arange(seq_len, dtype=torch.long, device=input_ids.device) + ) + + # Group and pad + input_ids_packed = group_and_cat_tensors( + flat_input_ids, packed_sequence_size, padding_value, min_seq_len=min_seq_len + ) + position_ids_packed = group_and_cat_tensors( + position_ids, packed_sequence_size, padding_value=0, min_seq_len=min_seq_len + ) + + # Compute max length + batch_size, max_seq_len = input_ids_packed.shape + + attention_mask = None + if return_attention_mask: + attention_mask = torch.zeros( + (batch_size, max_seq_len, max_seq_len), + dtype=torch.bool, + device=input_ids.device, + ) + index = 0 + for i, group_size in enumerate(packed_sequence_size): + group_lengths = flat_lengths[index : index + group_size] + total_len = sum(group_lengths) + attention_mask[i, :total_len, :total_len] = torch.tril( + torch.ones( + (total_len, total_len), dtype=torch.bool, device=input_ids.device + ) + ) + 
index += group_size + + return input_ids_packed, position_ids_packed, attention_mask + + +# TODO(ahmadki): the function doesn't actually handle returning 2D tensors because none of the backends support this. +# but we should support this anyways +def unpack_tensor(tensor, input_lengths): + """Unpacks a packed tensor into individual sequences padded to the same length. + + Args: + tensor (torch.Tensor): Packed tensor of shape [batch_size, packed_seq_len]. + packed_lengths (List[int]): Original sequence lengths in the order they were packed. + + Returns: + torch.Tensor: [num_sequences, max_seq_len], each row is one unpacked and padded sequence. + + Example: + >>> packed_tensor = torch.tensor([ + ... [1, 2, 3, 4, 5, 6, -1, -1], + ... [7, 8, 9, 9, 8, 7, 6, -1] + ... ]) + >>> packed_lengths = [2, 3, 1, 4, 2] + >>> unpack_tensor(packed_tensor, packed_lengths) + tensor([ + [1, 2, 0, 0], + [3, 4, 5, 0], + [6, 0, 0, 0], + [7, 8, 9, 9], + [8, 7, 0, 0], + ]) + """ + packed_seqlen = tensor.shape[1] + splitsizes = input_lengths.tolist() + splitsizes.append(packed_seqlen - sum(splitsizes)) + tensor_split = torch.split(tensor, tuple(splitsizes), dim=1) + + max_len = max(input_lengths.tolist()) # max sequence length in the batch + + tensor_stacked = [] + for t in tensor_split[0:-1]: + padding_needed = max_len - t.shape[1] + tensor_stacked.append( + torch.nn.functional.pad( + t, (0, 0, 0, padding_needed), mode="constant", value=0.0 + ) + ) + return torch.cat(tensor_stacked, dim=0) + + +def get_flash_attention_kwargs(input_lengths: torch.Tensor) -> FlashAttentionKwargs: + """Returns kwargs required for FlashAttention v2 forward functions. 
+ + Args: + input_lengths (torch.Tensor): [batch_size] containing lengths of each sequence + + Returns: + Dict[str, torch.Tensor | int]: + { + "cu_seqlens_q": Tensor[int32], + "cu_seqlens_k": Tensor[int32], + "max_seqlen_q": int, + "max_seqlen_k": int + } + """ + input_lengths_int32 = input_lengths.to(torch.int32) + cu_seqlens = torch.nn.functional.pad( + input_lengths_int32.cumsum(dim=0), (1, 0) + ) # prepend 0 + max_len = input_lengths.max().item() + + return FlashAttentionKwargs( + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens.clone(), # same for self-attention + max_seqlen_q=max_len, + max_seqlen_k=max_len, + ) diff --git a/nemo_rl/models/megatron/common.py b/nemo_rl/models/megatron/common.py index 5c6431b15e..bc0d499f08 100644 --- a/nemo_rl/models/megatron/common.py +++ b/nemo_rl/models/megatron/common.py @@ -12,20 +12,240 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Any, Iterator +from typing import Any, Iterator, Optional import torch import torch.distributed as dist from megatron.core.models.gpt import GPTModel +from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.parallel_state import ( + get_context_parallel_group, + get_context_parallel_rank, + get_context_parallel_world_size, get_tensor_model_parallel_group, get_tensor_model_parallel_rank, ) from megatron.training.utils import get_ltor_masks_and_position_ids from nemo.tron.state import GlobalState -from nemo_rl.algorithms.loss_functions import LossFunction +from nemo_rl.algorithms.loss_functions import LossFunction, SequencePackingLossWrapper from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.model_utils import _get_tokens_on_this_cp_rank + + +def _pack_sequences_for_megatron( + input_ids: torch.Tensor, + seq_lengths: torch.Tensor, + pad_individual_seqs_to_multiple_of: int = 1, + pad_packed_seq_to: Optional[int] = None, + 
cp_rank: int = 0, + cp_size: int = 1, +) -> tuple[torch.Tensor, PackedSeqParams, torch.Tensor, Optional[torch.Tensor]]: + """Pack sequences for Megatron model processing with optional context parallelism. + + Args: + input_ids: Input token IDs [batch_size, seq_length] + seq_lengths: Actual sequence lengths for each sample [batch_size] + pad_individual_seqs_to_multiple_of: Pad individual sequences to a multiple of this value + pad_packed_seq_to: Pad packed sequences to this value (before CP) + cp_size: Context parallelism size + + Returns: + Tuple of: + - packed_input_ids: Packed input tensor [1, T] + - input_ids_cp_sharded: Sharded input tensor [cp_size, T // cp_size] + - packed_seq_params: PackedSeqParams object + - cu_seqlens: Cumulative sequence lengths + - cu_seqlens_padded: Padded cumulative sequence lengths + """ + batch_size = input_ids.shape[0] + + # Build cumulative sequence lengths (cu_seqlens) and extract valid tokens + cu_seqlens = [0] + cu_seqlens_padded = ( + [0] + if pad_individual_seqs_to_multiple_of > 1 or pad_packed_seq_to is not None + else None + ) + valid_tokens = [] + + pad_factor = pad_individual_seqs_to_multiple_of + + for b in range(batch_size): + seq_len = ( + seq_lengths[b].item() if torch.is_tensor(seq_lengths[b]) else seq_lengths[b] + ) + + # Extract valid tokens for this sequence + valid_tokens.append(input_ids[b, :seq_len]) + + # Update cumulative sequence lengths + cu_seqlens.append(cu_seqlens[-1] + seq_len) + + # For context parallelism, track padded sequence lengths + if pad_factor > 1 or pad_packed_seq_to is not None: + # Pad sequence length to multiple of (cp_size * 2) + padded_seq_len = ((seq_len + pad_factor - 1) // pad_factor) * pad_factor + cu_seqlens_padded.append(cu_seqlens_padded[-1] + padded_seq_len) + + # Convert to tensors + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=input_ids.device) + if pad_factor > 1 or pad_packed_seq_to is not None: + cu_seqlens_padded = torch.tensor( + cu_seqlens_padded, 
dtype=torch.int32, device=input_ids.device + ) + if pad_packed_seq_to is not None: + cu_seqlens_padded[-1] = pad_packed_seq_to + + # Calculate max sequence length (padded if using CP) + if pad_factor > 1 or (pad_packed_seq_to is not None): + seq_lens_padded = cu_seqlens_padded[1:] - cu_seqlens_padded[:-1] + max_seqlen = seq_lens_padded.max().item() + else: + seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] + max_seqlen = seq_lens.max().item() + + # Concatenate all valid tokens + # If using individual padding, we need to pad individual sequences + # CP will always need padding (of at least cp_size * 2) + running_seq_len = 0 + if pad_factor > 1: + all_input_ids = [] + padded_tokens = [] + for b in range(batch_size): + seq_len = ( + seq_lengths[b].item() + if torch.is_tensor(seq_lengths[b]) + else seq_lengths[b] + ) + # if last element, pad to the max sequence length + if b == batch_size - 1 and pad_packed_seq_to is not None: + padded_seq_len = pad_packed_seq_to - running_seq_len + running_seq_len += padded_seq_len + else: + padded_seq_len = ((seq_len + pad_factor - 1) // pad_factor) * pad_factor + + running_seq_len += padded_seq_len + + # Pad this sequence to the required length + seq_tokens = input_ids[b, :seq_len] + if padded_seq_len > seq_len: + # Pad with zeros (or use a padding token if available) + seq_tokens = torch.nn.functional.pad( + seq_tokens, (0, padded_seq_len - seq_len), value=0 + ) + all_input_ids.append(seq_tokens) + + if cp_size > 1: + seq_tokens = _get_tokens_on_this_cp_rank( + seq_tokens, cp_rank, cp_size, seq_dim=0 + ) + + padded_tokens.append(seq_tokens) + + # Concatenate all padded tokens + # For 'thd' format, the shape should be [1, T] where T is total tokens + packed_input_ids = torch.cat(padded_tokens, dim=0).unsqueeze(0) + all_input_ids = torch.cat(all_input_ids, dim=0).unsqueeze(0) + else: + # No individual padding, just concatenate valid tokens + # For 'thd' format, the shape should be [1, T] where T is total tokens + packed_input_ids = 
torch.cat(valid_tokens, dim=0).unsqueeze(0) + all_input_ids = packed_input_ids + if pad_packed_seq_to is not None: + pad_len = pad_packed_seq_to - packed_input_ids.shape[1] + if pad_len > 0: + packed_input_ids = torch.nn.functional.pad( + packed_input_ids, (0, pad_len), value=0 + ) + all_input_ids = torch.nn.functional.pad( + all_input_ids, (0, pad_len), value=0 + ) + + if cu_seqlens_padded is None: + cu_seqlens_padded = cu_seqlens.clone() + + packed_seq_params = PackedSeqParams( + cu_seqlens_q=cu_seqlens_padded, + cu_seqlens_kv=cu_seqlens_padded, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + max_seqlen_q=int(max_seqlen), + max_seqlen_kv=int(max_seqlen), + qkv_format="thd", + ) + + return ( + all_input_ids.contiguous(), + packed_input_ids.contiguous(), + packed_seq_params, + cu_seqlens, + cu_seqlens_padded, + ) + + +def _unpack_sequences_from_megatron( + output_tensor: torch.Tensor, + seq_lengths: torch.Tensor, + cu_seqlens: torch.Tensor, + cu_seqlens_padded: Optional[torch.Tensor], + original_batch_size: int, + original_seq_length: int, +) -> torch.Tensor: + """Unpack sequences from Megatron output format. 
+ + Args: + output_tensor: Packed output tensor [1, T, vocab_size] + seq_lengths: Actual sequence lengths for each sample + cu_seqlens: Cumulative sequence lengths + cu_seqlens_padded: Padded cumulative sequence lengths (if CP was used) + original_batch_size: Original batch size + original_seq_length: Original maximum sequence length + + Returns: + Unpacked output tensor [batch_size, seq_length, vocab_size] + """ + # Remove the batch dimension to get [T, vocab_size] + output_tensor = output_tensor.squeeze(0) + + # Create a padded output tensor with original shape + vocab_size = output_tensor.shape[-1] + unpacked_output = torch.zeros( + (original_batch_size, original_seq_length, vocab_size), + dtype=output_tensor.dtype, + device=output_tensor.device, + ) + + # Get context parallel size to determine which cu_seqlens to use + cp_size = get_context_parallel_world_size() + + # Fill in the unpacked output tensor with valid tokens + for b in range(original_batch_size): + # Get actual sequence length for this sample + seq_len = ( + seq_lengths[b].item() if torch.is_tensor(seq_lengths[b]) else seq_lengths[b] + ) + + if cp_size > 1 and cu_seqlens_padded is not None: + # When using CP, we need to account for padding + # Calculate the padded sequence boundaries + pad_factor = cp_size * 2 + padded_seq_len = ((seq_len + pad_factor - 1) // pad_factor) * pad_factor + start_idx = cu_seqlens_padded[b].item() + + # Only copy the valid tokens (not the padding) + unpacked_output[b, :seq_len] = output_tensor[ + start_idx : start_idx + seq_len + ] + else: + # No CP, use regular cu_seqlens + start_idx = cu_seqlens[b].item() + end_idx = cu_seqlens[b + 1].item() + + # Copy the valid tokens to the unpacked tensor + unpacked_output[b, :seq_len] = output_tensor[start_idx:end_idx] + + return unpacked_output def forward_step_arbitrary_loss( @@ -35,6 +255,11 @@ def forward_step_arbitrary_loss( data_iterator: Iterator[BatchedDataDict[Any]], model: GPTModel, loss_fn: LossFunction, + pack_sequences: 
bool = False, + seq_length_key: Optional[str] = None, + pad_individual_seqs_to_multiple_of: int = 1, + pad_full_seq_to: Optional[int] = None, + cp_normalize: bool = True, ): """Forward training step with support for packed sequences and context parallelism. @@ -45,30 +270,111 @@ def forward_step_arbitrary_loss( data_iterator: Input data iterator model (GPTModel): The GPT Model loss_fn (LossFunction): Loss function to apply + pack_sequences (bool): Whether to pack sequences for efficiency + seq_length_key (Optional[str]): Key in data_dict containing actual sequence lengths + cp_normalize (bool): Whether to normalize the loss by the cp_size + + Notes on packed sequences with context parallelism (CP): + - When CP > 1, each sequence is padded to a multiple of (cp_size * 2) + - The factor of 2 ensures load balancing for causal attention + - cu_seqlens tracks actual sequence boundaries + - cu_seqlens_padded tracks padded sequence boundaries for CP + - Requires TransformerEngine >= 1.10 for CP support """ straggler_timer = state.straggler_timer with straggler_timer(bdata=True): data_dict = next(data_iterator).to("cuda") input_ids = data_dict["input_ids"] - attention_mask, _, position_ids = get_ltor_masks_and_position_ids( - input_ids, 0, False, False, False - ) + attention_mask = None + position_ids = None + packed_seq_params = None + + original_batch_size = input_ids.shape[0] + original_seq_length = input_ids.shape[1] + seq_lengths = None # Will be set if using packed sequences + cu_seqlens = None + cu_seqlens_padded = None + + if pack_sequences: + # For packed sequences with padded input, we need sequence lengths + assert seq_length_key is not None, ( + "seq_length_key must be provided for packed sequences" + ) + assert seq_length_key in data_dict, ( + f"{seq_length_key} not found in data_dict" + ) + + # Get sequence lengths and context parallel size + seq_lengths = data_dict[seq_length_key] + + # Pack sequences + ( + input_ids, + input_ids_cp_sharded, + 
packed_seq_params, + cu_seqlens, + cu_seqlens_padded, + ) = _pack_sequences_for_megatron( + input_ids, + seq_lengths, + pad_individual_seqs_to_multiple_of, + pad_full_seq_to, + cp_rank=get_context_parallel_rank(), + cp_size=get_context_parallel_world_size(), + ) + + # For packed sequences, position_ids and attention_mask are typically None + # The PackedSeqParams handles all necessary sequence information + position_ids = None + attention_mask = None + else: + input_ids_cp_sharded = input_ids + attention_mask, _, position_ids = get_ltor_masks_and_position_ids( + input_ids, 0, False, False, False + ) with straggler_timer: - output_tensor = model(input_ids, position_ids, attention_mask) + output_tensor = model( + input_ids_cp_sharded, + position_ids, + attention_mask, + packed_seq_params=packed_seq_params, + ) + + # Unpack the output tensor if we did packed sequences + if pack_sequences and packed_seq_params is not None: + # remove padding + loss_fn = SequencePackingLossWrapper( + loss_fn=loss_fn, + cu_seqlens_q=packed_seq_params.cu_seqlens_q, + cu_seqlens_q_padded=packed_seq_params.cu_seqlens_q_padded, + ) loss_data = data_dict - return output_tensor, partial( + loss_fn_wrapped = partial( loss_fn, data=loss_data, global_valid_seqs=global_valid_seqs, global_valid_toks=global_valid_toks, vocab_parallel_rank=get_tensor_model_parallel_rank(), vocab_parallel_group=get_tensor_model_parallel_group(), + context_parallel_group=get_context_parallel_group(), ) + if cp_normalize: + cp_size = get_context_parallel_world_size() + orig_loss_fn_wrapped = loss_fn_wrapped + + def _div_by_cp_size(*args, **kwargs): + loss, metrics = orig_loss_fn_wrapped(*args, **kwargs) + return loss / cp_size, metrics + + loss_fn_wrapped = _div_by_cp_size + + return output_tensor, loss_fn_wrapped + def broadcast_tensor( tensor: torch.Tensor | None, src_rank: int, group: dist.ProcessGroup diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 
f501c978c5..61dfe9b51c 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -14,6 +14,7 @@ import contextlib import gc +import itertools import os from collections import defaultdict from contextlib import AbstractContextManager, contextmanager, nullcontext @@ -40,6 +41,7 @@ from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM from nemo_rl.algorithms.interfaces import LossFunction, LossType +from nemo_rl.algorithms.loss_functions import SequencePackingLossWrapper from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.models.dtensor.parallelize import ( _parallelize_model, @@ -48,7 +50,11 @@ get_logprobs_from_vocab_parallel_logits, to_local_if_dtensor, ) -from nemo_rl.models.huggingface.common import ModelFlag +from nemo_rl.models.huggingface.common import ( + ModelFlag, + get_flash_attention_kwargs, + pack_sequences, +) from nemo_rl.models.policy import PolicyConfig from nemo_rl.models.policy.interfaces import ( LogprobOutputSpec, @@ -170,6 +176,14 @@ def __init__( else: raise ValueError(f"Unknown precision: {self.cfg['precision']}") + print(f"[Rank {self.rank}] Loading model {model_name} on CPU...") + self.enable_seq_packing = self.cfg["sequence_packing"]["enabled"] + if self.enable_seq_packing: + print( + f"[Rank {self.rank}] Sequence packing is enabled for model {model_name}" + ) + print(f"[Rank {self.rank}] Using FlashAttention2 for sequence packing") + model_config = AutoConfig.from_pretrained( model_name, # Always load the model in float32 to keep master weights in float32. 
@@ -179,6 +193,9 @@ def __init__( **sliding_window_overwrite( model_name ), # due to https://github.com/huggingface/transformers/issues/38002 + attn_implementation="flash_attention_2" + if self.enable_seq_packing + else None, ) full_state_dict = None @@ -216,6 +233,10 @@ def __init__( tp_size = self.cfg["dtensor_cfg"]["tensor_parallel_size"] cp_size = self.cfg["dtensor_cfg"]["context_parallel_size"] + if cp_size > 1 and self.enable_seq_packing: + raise ValueError( + "Context parallel is not supported for sequence packing. Refer to https://github.com/NVIDIA/NeMo-RL/blob/main/docs/model-quirks.md#context-parallel-with-fsdp2 for more details." + ) dp_size = world_size // tp_size // cp_size sequence_parallel_enabled = self.cfg["dtensor_cfg"]["sequence_parallel"] assert world_size == dp_size * tp_size * cp_size, ( @@ -463,8 +484,13 @@ def train( if mbs is None: mbs = self.cfg["train_micro_batch_size"] local_gbs = gbs // self.dp_size - dataset_size = data["input_ids"].shape[0] - num_global_batches = dataset_size // local_gbs + total_dataset_size = torch.tensor(data.size, device="cuda") + torch.distributed.all_reduce( + total_dataset_size, + op=torch.distributed.ReduceOp.SUM, + group=self.dp_mesh.get_group(), + ) + num_global_batches = int(total_dataset_size.item()) // gbs # dim 1 is always assumed to be the sequence dim, sanity check this here sequence_dim = 1 @@ -489,10 +515,8 @@ def train( losses = [] all_mb_metrics = [] - for gb_idx, gb_start in enumerate(range(0, dataset_size, local_gbs)): - global_batch: BatchedDataDict[Any] = data.slice( - gb_start, gb_start + local_gbs - ) + for gb_idx in range(num_global_batches): + global_batch = data.get_batch(batch_idx=gb_idx, batch_size=local_gbs) assert "sample_mask" in global_batch, ( "sample_mask must be present in the data!" 
@@ -528,32 +552,72 @@ def train( # Calculate number of microbatches to process # make_microbatch_iterator assumes that the batch size is a multiple of the microbatch size # so its safe to not check for the case where the last data slice is smaller than mbs + dummy_iterator = iter([]) if self.cfg["dynamic_batching"]["enabled"]: mb_iterator = batch.make_microbatch_iterator_with_dynamic_shapes() + iterator_len = batch.get_microbatch_iterator_dynamic_shapes_len() + elif self.enable_seq_packing: + mb_iterator = ( + batch.make_microbatch_iterator_for_packable_sequences() + ) + iterator_len, max_seqlen = ( + batch.get_microbatch_iterator_for_packable_sequences_len() + ) + max_batch_ct = torch.tensor([iterator_len], device="cuda") + torch.distributed.all_reduce( + max_batch_ct, op=torch.distributed.ReduceOp.MAX + ) + + # Sequence packing can end up with unevenly distributed batch counts across DP ranks. + # We add dummy batches to the end of the iterator to make the batch counts equal. + dummy_batch_ct = int(max_batch_ct.item() - iterator_len) + dummy_iterator = ( + batch.make_microbatch_iterator_for_packable_sequences() + ) + dummy_iterator = itertools.islice( + itertools.cycle(dummy_iterator), dummy_batch_ct + ) else: mb_iterator = batch.make_microbatch_iterator(mbs) + iterator_len = batch.size // mbs - for mb in mb_iterator: - input_ids = mb.get("input_ids").cuda() - input_lengths = mb.get("input_lengths") - batch_size, seq_len = input_ids.shape + for mb_idx, mb in enumerate( + itertools.chain(mb_iterator, dummy_iterator) + ): + with torch.autocast(device_type="cuda", dtype=self.dtype): + if self.enable_seq_packing: + input_ids = mb.get("input_ids").cuda() + input_ids, position_ids, _ = pack_sequences( + input_ids=input_ids, + input_lengths=mb["input_lengths"], + packed_sequence_size=[ + len(mb["input_lengths"]) + ], # flash attention 2 expects flattened input + padding_value=self.tokenizer.eos_token_id, + return_attention_mask=False, + 
min_seq_len=self.cfg["sequence_packing"][ + "train_mb_tokens" + ], # TODO: this is a WAR for sequence packing, we should fix this. Without this, backward will fail when TP is enabled. + ) + seq_len = input_ids.shape[1] + attention_mask = None + flash_attn_kwargs = get_flash_attention_kwargs( + input_lengths=mb["input_lengths"], + ) - attention_mask = torch.zeros( - (batch_size, seq_len), dtype=torch.long, device=input_ids.device - ) - for i, length in enumerate(input_lengths): - # For right-padded sequence, set 1s at the beginning of the sequence - attention_mask[i, :length] = 1 + else: + input_ids = mb.get("input_ids").cuda() + batch_size, seq_len = input_ids.shape - with torch.autocast(device_type="cuda", dtype=self.dtype): - attention_mask_input_all_ones = torch.ones( - (batch_size, seq_len), - dtype=torch.long, - device=input_ids.device, - ) - position_ids = torch.arange( - seq_len, device=input_ids.device - ).repeat(batch_size, 1) + attention_mask = torch.ones( + (batch_size, seq_len), + dtype=torch.long, + device=input_ids.device, + ) + position_ids = torch.arange( + seq_len, device=input_ids.device + ).repeat(batch_size, 1) + flash_attn_kwargs = {} context_parallel_ctx = None if self.cp_size > 1: @@ -578,9 +642,10 @@ def train( with torch.autocast(device_type="cuda", dtype=self.dtype): outputs = self.model( input_ids=input_ids, - attention_mask=attention_mask_input_all_ones, + attention_mask=attention_mask, position_ids=position_ids, use_cache=False, + flash_attn_kwargs=flash_attn_kwargs, ) # Get logprobs @@ -648,18 +713,34 @@ def train( placements=[Shard(sequence_dim), Shard(-1)], ) - loss, loss_metrics = loss_fn( - logits, mb, global_valid_seqs, global_valid_toks + if self.enable_seq_packing: + loss_fn_ = SequencePackingLossWrapper( + loss_fn=loss_fn, + cu_seqlens_q=flash_attn_kwargs.cu_seqlens_q, + cu_seqlens_q_padded=flash_attn_kwargs.cu_seqlens_q, + ) + else: + loss_fn_ = loss_fn + + loss, loss_metrics = loss_fn_( + logits, + mb, + global_valid_seqs, + 
global_valid_toks, ) - ## scale by the number of global batches so we get the correct - ## value when summing metrics across all microbatches - for k in loss_metrics.keys(): - loss_metrics[k] /= num_global_batches - num_valid_samples = loss_metrics["num_valid_samples"] - loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"] - loss_metrics["global_valid_seqs"] = global_valid_seqs.item() - loss_metrics["global_valid_toks"] = global_valid_toks.item() + # skip the update for dummy batches + if mb_idx < iterator_len: + ## scale by the number of global batches so we get the correct + ## value when summing metrics across all microbatches + for k in loss_metrics.keys(): + loss_metrics[k] /= num_global_batches + num_valid_samples = loss_metrics["num_valid_samples"] + loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"] + loss_metrics["global_valid_seqs"] = global_valid_seqs.item() + loss_metrics["global_valid_toks"] = global_valid_toks.item() + else: + loss *= 0 # Backward pass if not eval_mode: @@ -762,29 +843,70 @@ def get_logprobs( with unshard_fsdp2_model(self.model), torch.no_grad(): data.to("cuda") + dummy_iterator = iter([]) if self.cfg["dynamic_batching"]["enabled"]: mb_iterator = data.make_microbatch_iterator_with_dynamic_shapes() + iterator_len = data.get_microbatch_iterator_dynamic_shapes_len() + elif self.enable_seq_packing: + mb_iterator = data.make_microbatch_iterator_for_packable_sequences() + iterator_len, max_seqlen = ( + data.get_microbatch_iterator_for_packable_sequences_len() + ) + max_batch_ct = torch.tensor([iterator_len], device="cuda") + torch.distributed.all_reduce( + max_batch_ct, op=torch.distributed.ReduceOp.MAX + ) + + # Sequence packing can end up with unevenly distributed batch counts across DP ranks. + # We add dummy batches to the end of the iterator to make the batch counts equal. 
+ dummy_batch_ct = int(max_batch_ct.item() - iterator_len) + dummy_iterator = data.make_microbatch_iterator_for_packable_sequences() + dummy_iterator = itertools.islice( + itertools.cycle(dummy_iterator), dummy_batch_ct + ) else: mb_iterator = data.make_microbatch_iterator(logprob_batch_size) + iterator_len = data.size // logprob_batch_size - for lp_batch in mb_iterator: + step = 0 + for batch_idx, lp_batch in enumerate( + itertools.chain(mb_iterator, dummy_iterator) + ): + step += 1 input_ids = lp_batch.get("input_ids").cuda() input_lengths = lp_batch.get("input_lengths") batch_size, seq_len = input_ids.shape - # Create attention mask for right-padded data - attention_mask = torch.zeros( - (batch_size, seq_len), dtype=torch.long, device=input_ids.device - ) - for i, length in enumerate(input_lengths): - # For right-padded sequence, set 1s at the beginning of the sequence - attention_mask[i, :length] = 1 - - # explicitly create position ids for the input, otherwise the sharding - # for DTensor will be incorrect - position_ids = torch.arange(seq_len, device=input_ids.device).repeat( - batch_size, 1 - ) + if self.enable_seq_packing: + input_ids, position_ids, _ = pack_sequences( + input_ids=input_ids, + input_lengths=input_lengths, + packed_sequence_size=[ + batch_size + ], # flash attention 2 expects flattened input + padding_value=self.tokenizer.eos_token_id, + return_attention_mask=False, + ) + seq_len = input_ids.shape[1] + attention_mask = None + flash_attn_kwargs = get_flash_attention_kwargs( + input_lengths=input_lengths, + ) + else: + # Create attention mask for right-padded data + attention_mask = torch.zeros( + (batch_size, seq_len), dtype=torch.long, device=input_ids.device + ) + for i, length in enumerate(input_lengths): + # For right-padded sequence, set 1s at the beginning of the sequence + attention_mask[i, :length] = 1 + + # explicitly create position ids for the input, otherwise the sharding + # for DTensor will be incorrect + position_ids = 
torch.arange( + seq_len, device=input_ids.device + ).repeat(batch_size, 1) + flash_attn_kwargs = {} with torch.autocast(device_type="cuda", dtype=self.dtype): # DTensor requires the casual attention kernel to hit, @@ -795,41 +917,128 @@ def get_logprobs( (batch_size, seq_len), dtype=torch.long, device=input_ids.device ) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask_input_all_ones, - position_ids=position_ids, - use_cache=False, + context_parallel_ctx = None + if self.cp_size > 1: + seq_index = torch.arange(seq_len, device=input_ids.device).repeat( + 1, 1 ) - - if isinstance(outputs.logits, DTensor): - token_logprobs = get_logprobs_from_vocab_parallel_logits( - outputs.logits.to(torch.float32), input_ids + cp_buffers = [input_ids, position_ids, seq_index] + + # Create context parallel context + context_parallel_ctx = self.create_context_parallel_ctx( + cp_mesh=self.cp_mesh, + cp_buffers=cp_buffers, + cp_seq_dims=[sequence_dim] * len(cp_buffers), + cp_no_restore_buffers=set(cp_buffers), ) - else: - # Extract logprobs for each token in the sequence by gathering the logprob - # corresponding to the next token at each position - # Input shapes: - # log_probs: [batch_size, sequence_length, vocab_size] - logits for each position - # token_ids: [batch_size, sequence_length] - actual tokens - # Output shape: [batch_size, sequence_length] - logprob of each token given previous - # We get logprob of token[t+1] from logits[t], prepending 0 to maintain sequence length - - log_probs = torch.nn.functional.log_softmax( - outputs.logits.to(torch.float32), dim=-1 - ) - next_tokens = input_ids[:, 1:] - log_probs = log_probs[:, :-1] - token_logprobs = log_probs.gather( - dim=-1, index=next_tokens.unsqueeze(-1) - ).squeeze(-1) + + with DTensorPolicyWorker.train_context(context_parallel_ctx): + with torch.autocast(device_type="cuda", dtype=self.dtype): + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask_input_all_ones, + 
position_ids=position_ids, + use_cache=False, + flash_attn_kwargs=flash_attn_kwargs, + ) + + logits = outputs.logits + + if self.cp_size > 1: + seq_index_tensor = ( + DTensor.from_local( + seq_index, + device_mesh=self.cp_mesh, + placements=[Shard(1)], + ) + .full_tensor() + .squeeze(0) + ) + + input_ids_dtensor = DTensor.from_local( + input_ids, + device_mesh=self.cp_mesh, + placements=[Shard(sequence_dim)], + ) + + if isinstance(logits, DTensor): + # Must be tp sharded + assert ( + logits.device_mesh.ndim == 1 + and logits.device_mesh.mesh_dim_names[0] == "tp" + ), "logits must be tp sharded" + + # CP is implicitly sharded on the seq dim, so we need to redistribute to the tp dim + logits = DTensor.from_local( + logits.to_local(), + device_mesh=self.device_mesh[("cp", "tp")], + placements=[Shard(sequence_dim), Shard(-1)], + ) + else: + logits = DTensor.from_local( + logits, + device_mesh=self.device_mesh[("cp", "tp")], + placements=[Shard(sequence_dim), Shard(-1)], + ) + + token_logprobs = get_logprobs_from_vocab_parallel_logits( + logits.to(torch.float32), + input_ids_dtensor, + seq_index_tensor, + ) + + assert token_logprobs.shape[1] == seq_len - 1 + else: + if isinstance(logits, DTensor): + token_logprobs = get_logprobs_from_vocab_parallel_logits( + logits.to(torch.float32), input_ids + ) + else: + # Extract logprobs for each token in the sequence by gathering the logprob + # corresponding to the next token at each position + # Input shapes: + # log_probs: [batch_size, sequence_length, vocab_size] - logits for each position + # token_ids: [batch_size, sequence_length] - actual tokens + # Output shape: [batch_size, sequence_length] - logprob of each token given previous + # We get logprob of token[t+1] from logits[t], prepending 0 to maintain sequence length + + log_probs = torch.nn.functional.log_softmax( + outputs.logits.to(torch.float32), dim=-1 + ) + next_tokens = input_ids[:, 1:] + log_probs = log_probs[:, :-1] + token_logprobs = log_probs.gather( + dim=-1, 
index=next_tokens.unsqueeze(-1) + ).squeeze(-1) token_logprobs = torch.cat( [torch.zeros_like(token_logprobs[:, :1]), token_logprobs], dim=1 ) - # Apply mask to zero out padding tokens logprobs - token_logprobs = token_logprobs * attention_mask + # skip keeping the logprobs for the dummy batches + if batch_idx >= iterator_len: + continue + + if not self.enable_seq_packing: + # Apply mask to zero out padding tokens logprobs + token_logprobs = token_logprobs * attention_mask + else: + # For packed sequences, unpack logprobs + unpacked_logprobs = torch.zeros( + (batch_size, seq_dim_size), + dtype=token_logprobs.dtype, + device=token_logprobs.device, + ) + cu_seqlens = flash_attn_kwargs.cu_seqlens_q + for i in range(batch_size): + start = cu_seqlens[i].item() + 1 + end = cu_seqlens[i + 1].item() + seq_len_actual = input_lengths[i].item() + unpacked_logprobs[i, 1:seq_len_actual] = token_logprobs[ + 0, start:end + ] + token_logprobs = unpacked_logprobs + all_log_probs.append(token_logprobs) # Concatenate all batches diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index c77f2460e7..dbbf5ddc1e 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -24,6 +24,7 @@ from nemo_rl.distributed.batched_data_dict import ( BatchedDataDict, DynamicBatchingArgs, + SequencePackingArgs, SlicedDataDict, ) from nemo_rl.distributed.named_sharding import NamedSharding @@ -140,9 +141,35 @@ def __init__( ], "max_tokens_per_microbatch": 0, # Override this in each different call (presumably different sizes) } + assert not config["sequence_packing"]["enabled"], ( + "Dynamic Batching is exclusive of Sequence Packing. Please disable Sequence Packing to use Dynamic Batching" + ) else: self.use_dynamic_batches = False + if config["sequence_packing"]["enabled"]: + assert ( + config["megatron_cfg"]["enabled"] or config["dtensor_cfg"]["enabled"] + ), "Sequence packing requires Megatron or DTensor policies." 
+ self.use_sequence_packing = True + self.sequence_packing_args: SequencePackingArgs = { + "train_mb_tokens": config["sequence_packing"]["train_mb_tokens"], + "logprob_mb_tokens": config["sequence_packing"].get( + "logprob_mb_tokens", None + ), + "algorithm": config["sequence_packing"]["algorithm"], + "input_key": "input_ids", + "input_lengths_key": "input_lengths", + "sequence_length_pad_multiple": (cp_size * 2 * tp_size) + if cp_size > 1 + else tp_size, + } + assert not config["dynamic_batching"]["enabled"], ( + "Sequence Packing is exclusive of Dynamic Batching. Please disable Dynamic Batching" + ) + else: + self.use_sequence_packing = False + self.cfg = config def init_collective( @@ -166,7 +193,6 @@ def get_logprobs( The logprob of input token i is specified at position i in the output logprobs tensor. """ dp_size = self.sharding_annotations.get_axis_size("data_parallel") - cp_size = self.sharding_annotations.get_axis_size("context_parallel") sharded_data: list[SlicedDataDict] unsorted_data_indices: list[int] @@ -175,32 +201,40 @@ def get_logprobs( "dynamic_batching" ]["logprob_mb_tokens"] sharded_data, unsorted_data_indices = data.shard_by_batch_size( # type: ignore - cp_size * dp_size, + dp_size, batch_size=None, dynamic_batching_args=self.dynamic_batching_args, ) + elif self.use_sequence_packing: + self.sequence_packing_args["max_tokens_per_microbatch"] = self.cfg[ + "sequence_packing" + ]["logprob_mb_tokens"] + # we just shard into DP shards here as Sequence packing allows for CP. 
+ sharded_data, unsorted_data_indices = data.shard_by_batch_size( + dp_size, + batch_size=None, + sequence_packing_args=self.sequence_packing_args, + ) else: sharded_data = data.shard_by_batch_size( # type: ignore - cp_size * dp_size, + dp_size, batch_size=None, ) - sharded_data_2d = [] - shard_idx = 0 - # Convert to 2d dim array - for _ in range(dp_size): - cp_data = [] - for _ in range(cp_size): - cp_data.append(sharded_data[shard_idx]) - shard_idx += 1 - sharded_data_2d.append(cp_data) - futures = self.worker_group.run_all_workers_sharded_data( "get_logprobs", - data=sharded_data_2d, - in_sharded_axes=["data_parallel", "context_parallel"], - replicate_on_axes=["tensor_parallel", "pipeline_parallel"], - output_is_replicated=["tensor_parallel", "pipeline_parallel"], + data=sharded_data, + in_sharded_axes=["data_parallel"], + replicate_on_axes=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + output_is_replicated=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], ) logprobs: BatchedDataDict[LogprobOutputSpec] = BatchedDataDict.from_batches( self.worker_group.get_all_worker_results(futures) @@ -208,7 +242,7 @@ def get_logprobs( # dynamic batching sorts the inputs by sequence length to improve load balancing, # so change it back here - if self.use_dynamic_batches: + if self.use_dynamic_batches or self.use_sequence_packing: logprobs.reorder_data(unsorted_data_indices) return logprobs @@ -223,7 +257,6 @@ def get_reference_policy_logprobs( Returns: Identical to get_logprobs. 
""" dp_size = self.sharding_annotations.get_axis_size("data_parallel") - cp_size = self.sharding_annotations.get_axis_size("context_parallel") sharded_data: list[SlicedDataDict] unsorted_data_indices: list[int] if self.use_dynamic_batches: @@ -231,32 +264,39 @@ def get_reference_policy_logprobs( "dynamic_batching" ]["logprob_mb_tokens"] sharded_data, unsorted_data_indices = data.shard_by_batch_size( # type: ignore - cp_size * dp_size, + dp_size, batch_size=None, dynamic_batching_args=self.dynamic_batching_args, ) + elif self.use_sequence_packing: + self.sequence_packing_args["max_tokens_per_microbatch"] = self.cfg[ + "sequence_packing" + ]["logprob_mb_tokens"] + sharded_data, unsorted_data_indices = data.shard_by_batch_size( + dp_size, + batch_size=None, + sequence_packing_args=self.sequence_packing_args, + ) else: sharded_data = data.shard_by_batch_size( # type: ignore - cp_size * dp_size, + dp_size, batch_size=None, ) - sharded_data_2d = [] - shard_idx = 0 - # Convert to 2d dim array - for _ in range(dp_size): - cp_data = [] - for _ in range(cp_size): - cp_data.append(sharded_data[shard_idx]) - shard_idx += 1 - sharded_data_2d.append(cp_data) - futures = self.worker_group.run_all_workers_sharded_data( "get_reference_policy_logprobs", - data=sharded_data_2d, - in_sharded_axes=["data_parallel", "context_parallel"], - replicate_on_axes=["tensor_parallel", "pipeline_parallel"], - output_is_replicated=["tensor_parallel", "pipeline_parallel"], + data=sharded_data, + in_sharded_axes=["data_parallel"], + replicate_on_axes=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + output_is_replicated=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], common_kwargs={"micro_batch_size": micro_batch_size}, ) logprobs: BatchedDataDict[ReferenceLogprobOutputSpec] = ( @@ -267,7 +307,7 @@ def get_reference_policy_logprobs( # dynamic batching sorts the inputs by sequence length to improve load balancing, # so change it back here - if 
self.use_dynamic_batches: + if self.use_dynamic_batches or self.use_sequence_packing: logprobs.reorder_data(unsorted_data_indices) return logprobs @@ -294,6 +334,15 @@ def train( batch_size=batch_size, dynamic_batching_args=self.dynamic_batching_args, ) + elif self.use_sequence_packing: + self.sequence_packing_args["max_tokens_per_microbatch"] = self.cfg[ + "sequence_packing" + ]["train_mb_tokens"] + sharded_data, _ = data.shard_by_batch_size( + dp_size, + batch_size=batch_size, + sequence_packing_args=self.sequence_packing_args, + ) else: sharded_data = data.shard_by_batch_size( dp_size, diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index b0d544aabb..867f27ea1d 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -39,6 +39,8 @@ from megatron.core.models.gpt import GPTModel from megatron.core.optimizer import ChainedOptimizer from megatron.core.parallel_state import ( + get_context_parallel_group, + get_context_parallel_rank, get_pipeline_model_parallel_group, get_pipeline_model_parallel_last_rank, get_pipeline_model_parallel_rank, @@ -86,7 +88,10 @@ from nemo_rl.algorithms.interfaces import LossFunction, LossType from nemo_rl.distributed.batched_data_dict import BatchedDataDict -from nemo_rl.distributed.model_utils import from_parallel_logits_to_logprobs +from nemo_rl.distributed.model_utils import ( + from_parallel_logits_to_logprobs, + from_parallel_logits_to_logprobs_packed_sequences, +) from nemo_rl.distributed.named_sharding import NamedSharding from nemo_rl.models.generation.interfaces import ( GenerationDatumSpec, @@ -94,6 +99,7 @@ verify_right_padding, ) from nemo_rl.models.megatron.common import ( + _pack_sequences_for_megatron, broadcast_tensor, forward_step_arbitrary_loss, ) @@ -452,12 +458,11 @@ def __init__( model_cfg.sequence_parallel = self.cfg["megatron_cfg"]["sequence_parallel"] model_cfg.context_parallel_size = 
self.cfg["megatron_cfg"][ "context_parallel_size" - ] # not supported right now - assert model_cfg.context_parallel_size == 1, ( - "Context parallel is not supported right now" - ) - - ## moe-related + ] + if model_cfg.context_parallel_size > 1: + assert self.cfg["sequence_packing"]["enabled"], ( + "Sequence Packing must be enabled to use Context Parallelism with MCore" + ) model_cfg.expert_tensor_parallel_size = self.cfg["megatron_cfg"][ "expert_tensor_parallel_size" ] @@ -568,7 +573,6 @@ def __init__( ), ) self.megatron_cfg.validate() - print(f"cfg: {self.megatron_cfg}") ( self.mcore_state, self.model, @@ -799,11 +803,32 @@ def train( ) batch = data.get_batch(batch_idx=gb_idx, batch_size=local_gbs) + pack_seqs = False + seqlen_key = None + pad_factor = 1 + pad_full_seq_to = None if self.cfg["dynamic_batching"]["enabled"]: data_iterator = batch.make_microbatch_iterator_with_dynamic_shapes() data_iterator_len = ( batch.get_microbatch_iterator_dynamic_shapes_len() ) + elif self.cfg["sequence_packing"]["enabled"]: + data_iterator = ( + batch.make_microbatch_iterator_for_packable_sequences() + ) + data_iterator_len, seq_dim_size = ( + batch.get_microbatch_iterator_for_packable_sequences_len() + ) + mbs = 1 + pack_seqs = True + seqlen_key = "input_lengths" + tp_size = self.cfg["megatron_cfg"]["tensor_model_parallel_size"] + cp_size = self.cfg["megatron_cfg"]["context_parallel_size"] + pad_factor = cp_size * 2 * tp_size if cp_size > 1 else tp_size + if self.cfg["megatron_cfg"]["pipeline_model_parallel_size"] > 1: + _, pad_full_seq_to = ( + batch.get_microbatch_iterator_for_packable_sequences_len() + ) else: data_iterator = batch.make_microbatch_iterator(mbs) data_iterator_len = local_gbs // mbs @@ -822,6 +847,10 @@ def train( self.mcore_state, global_valid_seqs, global_valid_toks, + pack_sequences=pack_seqs, + seq_length_key=seqlen_key, + pad_individual_seqs_to_multiple_of=pad_factor, + pad_full_seq_to=pad_full_seq_to, ), data_iterator=data_iterator, model=self.model, 
@@ -929,7 +958,6 @@ def train( } return metrics - # Temporary fix, 'data' is a kwarg due to some sort of ray bug def get_logprobs( self, *, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None ) -> BatchedDataDict[LogprobOutputSpec]: @@ -969,33 +997,88 @@ def get_logprobs( pp_rank = get_pipeline_model_parallel_rank() pp_grp = get_pipeline_model_parallel_group() pp_size = get_pipeline_model_parallel_world_size() + cp_size = self.cfg["megatron_cfg"]["context_parallel_size"] + # if pp_size > 1, we need to pad the full sequence to the max sequence length to maintain a static PP buffer + if ( + self.cfg["sequence_packing"]["enabled"] + and self.cfg["megatron_cfg"]["pipeline_model_parallel_size"] > 1 + ): + _, pad_full_seq_to = ( + data.get_microbatch_iterator_for_packable_sequences_len() + ) + pp_seq_dim_size = pad_full_seq_to + else: + pad_full_seq_to = None def forward_step_fn( data_iterator: Iterator[BatchedDataDict[Any]], model: GPTModel ): + nonlocal pad_full_seq_to data_dict = next(data_iterator).to("cuda") - input_ids = data_dict["input_ids"] - attention_mask, _, position_ids = get_ltor_masks_and_position_ids( - input_ids, 0, False, False, False - ) + if self.cfg["sequence_packing"]["enabled"]: + original_seq_length = data_dict["input_ids"].shape[1] + tp_size = self.cfg["megatron_cfg"]["tensor_model_parallel_size"] + pp_size = self.cfg["megatron_cfg"]["pipeline_model_parallel_size"] + cp_size = self.cfg["megatron_cfg"]["context_parallel_size"] + cp_rank = get_context_parallel_rank() + pad_factor = cp_size * 2 * tp_size if cp_size > 1 else tp_size + ( + input_ids, + input_ids_cp_sharded, + packed_seq_params, + cu_seqlens, + cu_seqlens_padded, + ) = _pack_sequences_for_megatron( + data_dict["input_ids"].clone(), + data_dict["input_lengths"], + pad_individual_seqs_to_multiple_of=pad_factor, + pad_packed_seq_to=pad_full_seq_to, + cp_rank=cp_rank, + cp_size=cp_size, + ) + attention_mask, position_ids = None, None + unpacked_input_ids = 
data_dict["input_ids"] + else: + input_ids = data_dict["input_ids"] + input_ids_cp_sharded = input_ids + attention_mask, _, position_ids = get_ltor_masks_and_position_ids( + input_ids, 0, False, False, False + ) + packed_seq_params = None + unpacked_input_ids = input_ids output_tensor = model( - input_ids, + input_ids_cp_sharded, position_ids, attention_mask, + packed_seq_params=packed_seq_params, ) def collection_fn(output_tensor): + stc = time.time() tp_grp = get_tensor_model_parallel_group() tp_rank = get_tensor_model_parallel_rank() - token_logprobs = from_parallel_logits_to_logprobs( - output_tensor.to(torch.float32), - target=input_ids, - vocab_start_index=tp_rank * output_tensor.shape[-1], - vocab_end_index=(tp_rank + 1) * output_tensor.shape[-1], - tp_group=tp_grp, - inference_only=True, - ) + if self.cfg["sequence_packing"]["enabled"]: + token_logprobs = from_parallel_logits_to_logprobs_packed_sequences( + output_tensor, + target=input_ids, + cu_seqlens_padded=cu_seqlens_padded, + unpacked_seqlen=original_seq_length, + vocab_start_index=tp_rank * output_tensor.shape[-1], + vocab_end_index=(tp_rank + 1) * output_tensor.shape[-1], + group=tp_grp, + inference_only=True, + cp_group=get_context_parallel_group(), + ) + else: + token_logprobs = from_parallel_logits_to_logprobs( + output_tensor.to(torch.float32), + target=unpacked_input_ids, + vocab_start_index=tp_rank * output_tensor.shape[-1], + vocab_end_index=(tp_rank + 1) * output_tensor.shape[-1], + tp_group=tp_grp, + inference_only=True, + ) # Prepend 0 logprob for first token to maintain same sequence length as input token_logprobs = torch.cat( @@ -1010,10 +1093,17 @@ def collection_fn(output_tensor): if self.cfg["dynamic_batching"]["enabled"]: mb_iterator = data.make_microbatch_iterator_with_dynamic_shapes() data_iterator_len = data.get_microbatch_iterator_dynamic_shapes_len() + micro_batch_size = logprob_batch_size + elif self.cfg["sequence_packing"]["enabled"]: + mb_iterator = 
data.make_microbatch_iterator_for_packable_sequences() + data_iterator_len, _ = ( + data.get_microbatch_iterator_for_packable_sequences_len() + ) + micro_batch_size = 1 else: mb_iterator = data.make_microbatch_iterator(logprob_batch_size) data_iterator_len = max(1, data.size // logprob_batch_size) - micro_batch_size = logprob_batch_size + micro_batch_size = logprob_batch_size forward_backward_func = get_forward_backward_func() list_of_logprobs = forward_backward_func( diff --git a/pyproject.toml b/pyproject.toml index dd03c5939a..cddda79abe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,6 @@ requires = ["setuptools>=42", "wheel"] build-backend = "setuptools.build_meta" - [tool.setuptools] packages = ["nemo_rl"] diff --git a/tests/functional/test_converter_roundtrip.py b/tests/functional/test_converter_roundtrip.py index 9679fcc724..c0c4b2fdd8 100644 --- a/tests/functional/test_converter_roundtrip.py +++ b/tests/functional/test_converter_roundtrip.py @@ -81,6 +81,7 @@ def create_test_config() -> Dict[str, Any]: "custom_parallel_plan": None, }, "dynamic_batching": {"enabled": False}, + "sequence_packing": {"enabled": False}, "make_sequence_length_divisible_by": 1, "max_grad_norm": 1.0, "optimizer": { diff --git a/tests/unit/algorithms/__init__.py b/tests/unit/algorithms/__init__.py index e69de29bb2..341a77c5bc 100644 --- a/tests/unit/algorithms/__init__.py +++ b/tests/unit/algorithms/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py index c81cf8a686..b387d1e2f0 100644 --- a/tests/unit/algorithms/test_grpo.py +++ b/tests/unit/algorithms/test_grpo.py @@ -77,17 +77,7 @@ def create_mock_batch( @pytest.fixture(scope="module") -def ray_init(): - """Initialize Ray for testing.""" - if not ray.is_initialized(): - ray.init(ignore_reinit_error=True) - yield - if ray.is_initialized(): - ray.shutdown() - - -@pytest.fixture(scope="module") -def mock_env(ray_init): +def mock_env(): """Create a mock environment for single task tests.""" env = MockEnvironment.remote(rewards=[1.0, 2.0]) yield env @@ -95,7 +85,7 @@ def mock_env(ray_init): @pytest.fixture(scope="module") -def mock_envs(ray_init): +def mock_envs(): """Create mock environments for multiple task tests.""" math_env = MockEnvironment.remote(rewards=[1.0, 2.0]) code_env = MockEnvironment.remote(rewards=[3.0, 4.0]) diff --git a/tests/unit/algorithms/test_sequence_packing_gradients.py b/tests/unit/algorithms/test_sequence_packing_gradients.py new file mode 100644 index 0000000000..33d858fbe4 --- /dev/null +++ b/tests/unit/algorithms/test_sequence_packing_gradients.py @@ -0,0 +1,449 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Test script to debug high gradients with sequence packing + context parallelism.""" + +import os + +import pytest +import ray +import torch + +from nemo_rl.algorithms.loss_functions import ( + ClippedPGLossFn, + SequencePackingLossWrapper, +) +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.named_sharding import NamedSharding +from nemo_rl.distributed.ray_actor_environment_registry import ( + ACTOR_ENVIRONMENT_REGISTRY, + PY_EXECUTABLES, +) +from nemo_rl.distributed.virtual_cluster import RayVirtualCluster +from nemo_rl.distributed.worker_groups import RayWorkerBuilder, RayWorkerGroup + + +@ray.remote(num_gpus=1) +class SequencePackingGradientTestActor: + def __init__(self, cp_size): + self.cp_size = cp_size + self.env_vars = dict(os.environ) + + def test_sequence_packing_gradients(self): + from nemo_rl.distributed.model_utils import _get_tokens_on_this_cp_rank + from nemo_rl.models.megatron.common import ( + _pack_sequences_for_megatron, + forward_step_arbitrary_loss, + ) + + # Initialize process group + torch.distributed.init_process_group(backend="nccl") + + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + # Create CP group - all ranks participate in CP + cp_group = torch.distributed.new_group(ranks=list(range(world_size))) + + # Patch get_context_parallel_group to always return cp_group + # (Assume it's imported from nemo_rl.models.megatron.common) + import megatron.core.parallel_state as parallel_state + + parallel_state._CONTEXT_PARALLEL_GROUP = cp_group + parallel_state._TENSOR_MODEL_PARALLEL_GROUP = torch.distributed.new_group( + ranks=[rank] + ) + + # Test parameters + batch_size = 4 + max_seq_len = 512 + vocab_size = 1000 + cp_size = self.cp_size + + # Ensure sequence length is compatible with CP load balancing + if max_seq_len % (2 * cp_size) != 0: + max_seq_len = (max_seq_len // (2 * cp_size) + 1) * (2 * cp_size) + + # Create test data with varying sequence lengths + 
torch.manual_seed(42) # For reproducibility + seq_lengths = torch.tensor( + [ + max_seq_len // 4, + max_seq_len * 1 // 4, + max_seq_len // 4, + max_seq_len * 3 // 4, + ] + ) + + # Create input data + input_ids = torch.zeros( + batch_size, max_seq_len, dtype=torch.long, device="cuda" + ) + token_mask = torch.zeros( + batch_size, max_seq_len, dtype=torch.float, device="cuda" + ) + + # Fill with random tokens up to seq_length + for i in range(batch_size): + length = seq_lengths[i] + input_ids[i, :length] = torch.randint( + 0, vocab_size, (length,), device="cuda" + ) + token_mask[i, :length] = 1.0 + + # Create other required tensors + sample_mask = torch.ones(batch_size, dtype=torch.float, device="cuda") + advantages = torch.randn(batch_size, max_seq_len, device="cuda") + prev_logprobs = torch.randn(batch_size, max_seq_len, device="cuda") + generation_logprobs = torch.randn(batch_size, max_seq_len, device="cuda") + reference_policy_logprobs = generation_logprobs.clone() + + original_data = { + "input_ids": input_ids, + "input_lengths": seq_lengths, + "token_mask": token_mask, + "sample_mask": sample_mask, + "advantages": advantages, + "prev_logprobs": prev_logprobs, + "generation_logprobs": generation_logprobs, + "reference_policy_logprobs": reference_policy_logprobs, + } + + # ===== TEST 1: Baseline (no sequence packing) ===== + print(f"Rank {rank}: Testing baseline (no sequence packing)") + + baseline_logits = torch.randn( + batch_size, max_seq_len, vocab_size, requires_grad=True, device="cuda" + ) + + loss_config = { + "reference_policy_kl_penalty": 0.1, + "ratio_clip_min": 0.2, + "ratio_clip_max": 0.2, + "ratio_clip_c": 3.0, + "use_on_policy_kl_approximation": False, + "use_importance_sampling_correction": False, + "token_level_loss": True, + } + + base_loss_fn = ClippedPGLossFn(loss_config) + data_dict = BatchedDataDict(original_data) + + global_valid_toks = torch.tensor( + sum(seq_lengths).item(), dtype=torch.float, device="cuda" + ) + global_valid_seqs = 
torch.tensor(batch_size, dtype=torch.float, device="cuda") + + # Forward pass + baseline_loss, baseline_metrics = base_loss_fn( + baseline_logits, + data_dict, + global_valid_seqs, + global_valid_toks, + ) + + # Backward pass + baseline_loss.backward() + + # Check baseline gradients + baseline_grad_norm = torch.norm(baseline_logits.grad).item() + baseline_grad_max = torch.max(torch.abs(baseline_logits.grad)).item() + baseline_grad_mean = torch.mean(torch.abs(baseline_logits.grad)).item() + baseline_grad_store = baseline_logits.grad.clone() + baseline_logits.grad.zero_() + + print( + f"Rank {rank}: Baseline gradient stats - norm: {baseline_grad_norm:.4f}, max: {baseline_grad_max:.4f}, mean: {baseline_grad_mean:.4f}" + ) + + # ===== TEST 2: Sequence packing with context parallelism ===== + print(f"Rank {rank}: Testing with sequence packing + CP") + + # Pack sequences + pad_to_multiple = cp_size * 2 # Common requirement for CP + ( + packed_input_ids, + packed_input_ids_cp, + packed_seq_params, + cu_seqlens, + cu_seqlens_padded, + ) = _pack_sequences_for_megatron( + input_ids, + seq_lengths, + pad_individual_seqs_to_multiple_of=pad_to_multiple, + pad_packed_seq_to=max_seq_len * batch_size if cp_size > 1 else None, + cp_rank=rank, + cp_size=cp_size, + ) + + # For CP, logits are sharded across context parallel ranks + def make_packed_logits(logits): + packed_logits = torch.zeros( + 1, packed_input_ids_cp.shape[1], vocab_size, device="cuda" + ) + run_seq = 0 + for i, seq_len in enumerate(seq_lengths): + padded_seqlen = cu_seqlens_padded[i + 1] - cu_seqlens_padded[i] + if padded_seqlen > baseline_logits.shape[1]: + # pad the logits with zeros + tmp_logits = torch.zeros( + 1, padded_seqlen, vocab_size, device="cuda" + ) + tmp_logits[:, :seq_len] = baseline_logits[i : i + 1, :seq_len] + else: + tmp_logits = baseline_logits[i : i + 1, :padded_seqlen] + packed_logits[ + :, run_seq // cp_size : (run_seq + padded_seqlen) // cp_size, : + ] = 
_get_tokens_on_this_cp_rank(tmp_logits, rank, cp_size) + run_seq += padded_seqlen + return packed_logits + + packed_logits = make_packed_logits(baseline_logits) + + # Create sequence packing wrapper + wrapper = SequencePackingLossWrapper( + loss_fn=base_loss_fn, + cu_seqlens_q=cu_seqlens, + cu_seqlens_q_padded=cu_seqlens_padded, + ) + + # Create data dict for packed sequences + packed_data_dict = BatchedDataDict(original_data) + + tp_group = torch.distributed.new_group(ranks=[rank]) + + # Forward pass + packed_loss, packed_metrics = wrapper( + packed_logits, + packed_data_dict, + global_valid_seqs, + global_valid_toks, + vocab_parallel_rank=0, + vocab_parallel_group=tp_group, + context_parallel_group=cp_group, + ) + + # Backward pass + packed_loss /= cp_size + packed_loss.backward() + + # Check packed gradients + packed_grad = baseline_logits.grad.clone() + # all-reduce across cp ranks + torch.distributed.all_reduce(packed_grad, op=torch.distributed.ReduceOp.SUM) + + packed_grad_norm = torch.norm(packed_grad).item() + packed_grad_max = torch.max(torch.abs(packed_grad)).item() + packed_grad_mean = torch.mean(torch.abs(packed_grad)).item() + # print(f"max grad on dims {torch.max(torch.abs(packed_grad), dim=0)}, {torch.max(torch.abs(packed_grad), dim=1)}, {torch.max(torch.abs(packed_grad), dim=2)}") + + print( + f"Rank {rank}: Packed gradient stats - norm: {packed_grad_norm:.4f}, max: {packed_grad_max:.4f}, mean: {packed_grad_mean:.4f}" + ) + + # ===== ANALYSIS ===== + gradient_ratio_norm = ( + packed_grad_norm / baseline_grad_norm + if baseline_grad_norm > 0 + else float("inf") + ) + gradient_ratio_max = ( + packed_grad_max / baseline_grad_max + if baseline_grad_max > 0 + else float("inf") + ) + gradient_ratio_mean = ( + packed_grad_mean / baseline_grad_mean + if baseline_grad_mean > 0 + else float("inf") + ) + + print( + f"Rank {rank}: Gradient ratios - norm: {gradient_ratio_norm:.4f}, max: {gradient_ratio_max:.4f}, mean: {gradient_ratio_mean:.4f}" + ) + + print( + 
f"differences by token: {torch.sum(torch.abs(packed_grad - baseline_grad_store), dim=-1)}" + ) + + torch.testing.assert_close( + packed_grad, baseline_grad_store, atol=1e-5, rtol=1e-5 + ) + + # test 3: with forward_step_arbitrary_loss + # reset grad + baseline_logits.grad.zero_() + packed_logits = make_packed_logits(baseline_logits) + + # mock model forward + class MockModel: + def __init__(self): + self.logits = packed_logits + + def __call__(self, *args, **kwargs): + return self.logits + + def forward( + self, input_ids, position_ids, attention_mask, packed_seq_params=None + ): + return self.logits + + class MockMcoreState: + def __init__(self): + # context that does nothing, but supports both with straggler_timer and with straggler_timer(bdata=True) + from contextlib import nullcontext + + class DummyStragglerTimer: + def __call__(self, *args, **kwargs): + return nullcontext() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + self.straggler_timer = DummyStragglerTimer() + + output_tensor, wrapped_loss_fn = forward_step_arbitrary_loss( + MockMcoreState(), + global_valid_seqs, + global_valid_toks, + data_iterator=iter([packed_data_dict]), + model=MockModel(), + loss_fn=base_loss_fn, + pack_sequences=True, + seq_length_key="input_lengths", + pad_individual_seqs_to_multiple_of=pad_to_multiple, + pad_full_seq_to=max_seq_len * batch_size if cp_size > 1 else None, + cp_normalize=True, + ) + loss, metrics = wrapped_loss_fn(output_tensor) + + loss.backward() + + # Check packed gradients + packed_grad = baseline_logits.grad.clone() + # all-reduce across cp ranks + torch.distributed.all_reduce(packed_grad, op=torch.distributed.ReduceOp.SUM) + + packed_grad_norm = torch.norm(packed_grad).item() + packed_grad_max = torch.max(torch.abs(packed_grad)).item() + packed_grad_mean = torch.mean(torch.abs(packed_grad)).item() + print( + f"Rank {rank}: Packed gradient stats - norm: {packed_grad_norm:.4f}, max: {packed_grad_max:.4f}, 
mean: {packed_grad_mean:.4f}" + ) + + gradient_ratio_norm = ( + packed_grad_norm / baseline_grad_norm + if baseline_grad_norm > 0 + else float("inf") + ) + gradient_ratio_max = ( + packed_grad_max / baseline_grad_max + if baseline_grad_max > 0 + else float("inf") + ) + + print( + f"Rank {rank}: Gradient ratios - norm: {gradient_ratio_norm:.4f}, max: {gradient_ratio_max:.4f}" + ) + print( + f"differences by token: {torch.sum(torch.abs(packed_grad - baseline_grad_store), dim=-1)}" + ) + + +SEQUENCE_PACKING_GRADIENT_TEST_ACTOR_FQN = ( + f"{SequencePackingGradientTestActor.__module__}.SequencePackingGradientTestActor" +) + + +@pytest.fixture +def register_sequence_packing_gradient_test_actor(): + """Register the SequencePackingGradientTestActor for use in tests.""" + original_registry_value = ACTOR_ENVIRONMENT_REGISTRY.get( + SEQUENCE_PACKING_GRADIENT_TEST_ACTOR_FQN + ) + ACTOR_ENVIRONMENT_REGISTRY[SEQUENCE_PACKING_GRADIENT_TEST_ACTOR_FQN] = ( + PY_EXECUTABLES.MCORE + ) + + yield SEQUENCE_PACKING_GRADIENT_TEST_ACTOR_FQN + + # Clean up registry + if SEQUENCE_PACKING_GRADIENT_TEST_ACTOR_FQN in ACTOR_ENVIRONMENT_REGISTRY: + if original_registry_value is None: + del ACTOR_ENVIRONMENT_REGISTRY[SEQUENCE_PACKING_GRADIENT_TEST_ACTOR_FQN] + else: + ACTOR_ENVIRONMENT_REGISTRY[SEQUENCE_PACKING_GRADIENT_TEST_ACTOR_FQN] = ( + original_registry_value + ) + + +@pytest.fixture(scope="function") +def cluster_fixture(request): + """Create and teardown a virtual cluster for CP tests.""" + cp_size = request.node.callspec.params["cp_size"] + + # Skip if not enough GPUs + if not torch.cuda.is_available() or torch.cuda.device_count() < cp_size: + pytest.skip( + f"Not enough GPUs available. Need {cp_size}, got {torch.cuda.device_count()}" + ) + + # Mysteriously, Ray is not initialized in this test, so we need to initialize it here. 
+ if not ray.is_initialized(): + print("Ray not initialized, initializing now...") + from nemo_rl.distributed.virtual_cluster import init_ray + + init_ray() + print("Ray initialized successfully") + else: + print("Ray is already initialized") + + cluster_name = f"test-sequence-packing-cp{cp_size}" + print(f"Creating virtual cluster '{cluster_name}' for {cp_size} GPUs...") + + cluster = RayVirtualCluster( + name=cluster_name, bundle_ct_per_node_list=[cp_size], use_gpus=True + ) + yield cluster + print(f"Shutting down cluster '{cluster_name}'...") + cluster.shutdown() + + +@pytest.mark.parametrize("cp_size", [1, 2]) +def test_sequence_packing_gradients_with_cp( + cluster_fixture, register_sequence_packing_gradient_test_actor, cp_size +): + """Test sequence packing gradients with context parallelism.""" + cluster = cluster_fixture + actor_fqn = register_sequence_packing_gradient_test_actor + + # For CP, all ranks are in a single group + sharding = NamedSharding(layout=list(range(cp_size)), names=["cp"]) + builder = RayWorkerBuilder(actor_fqn, cp_size) + + worker_group = RayWorkerGroup( + cluster=cluster, + remote_worker_builder=builder, + workers_per_node=None, + sharding_annotations=sharding, + ) + + # Run the test on all workers + futures = worker_group.run_all_workers_single_data( + "test_sequence_packing_gradients" + ) + _ = ray.get(futures) + worker_group.shutdown(force=True) diff --git a/tests/unit/data/packing/__init__.py b/tests/unit/data/packing/__init__.py new file mode 100644 index 0000000000..913e5a1c57 --- /dev/null +++ b/tests/unit/data/packing/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for sequence packing algorithms.""" diff --git a/tests/unit/data/packing/test_algorithms.py b/tests/unit/data/packing/test_algorithms.py new file mode 100644 index 0000000000..a47951969e --- /dev/null +++ b/tests/unit/data/packing/test_algorithms.py @@ -0,0 +1,326 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for sequence packing algorithms.""" + +import random +from typing import Dict, List + +import pytest + +from nemo_rl.data.packing.algorithms import ( + PackingAlgorithm, + SequencePacker, + get_packer, +) + + +def validate_solution( + sequence_lengths: List[int], bins: List[List[int]], bin_capacity: int +) -> bool: + """Validate that a packing solution is valid. + + Args: + sequence_lengths: The original list of sequence lengths. + bins: The packing solution, where each bin is a list of indices into sequence_lengths. + bin_capacity: The maximum capacity of each bin. + + Returns: + True if the packing is valid, False otherwise. 
+ """ + # Check that all sequences are packed + all_indices = set() + for bin_indices in bins: + all_indices.update(bin_indices) + + if len(all_indices) != len(sequence_lengths): + return False + + # Check that each bin doesn't exceed capacity + for bin_indices in bins: + bin_load = sum(sequence_lengths[idx] for idx in bin_indices) + if bin_load > bin_capacity: + return False + + return True + + +class TestSequencePacker: + """Test suite for sequence packing algorithms.""" + + @pytest.fixture + def bin_capacity(self) -> int: + """Fixture for bin capacity.""" + return 100 + + @pytest.fixture + def small_sequence_lengths(self) -> List[int]: + """Fixture for a small list of sequence lengths.""" + return [10, 20, 30, 40, 50, 60, 70, 80, 90] + + @pytest.fixture + def medium_sequence_lengths(self) -> List[int]: + """Fixture for a medium-sized list of sequence lengths.""" + return [25, 35, 45, 55, 65, 75, 85, 95, 15, 25, 35, 45, 55, 65, 75, 85, 95] + + @pytest.fixture + def large_sequence_lengths(self) -> List[int]: + """Fixture for a large list of sequence lengths.""" + # Set a seed for reproducibility + random.seed(42) + return [random.randint(10, 90) for _ in range(100)] + + @pytest.fixture + def edge_cases(self) -> Dict[str, List[int]]: + """Fixture for edge cases.""" + return { + "empty": [], + "single_item": [50], + "all_same_size": [30, 30, 30, 30, 30], + "max_size": [100, 100, 100], + "mixed_sizes": [10, 50, 100, 20, 80, 30, 70, 40, 60, 90], + } + + # TODO(ahmadki): use the function to specify all test algorithms ins tead of lists below + @pytest.fixture + def algorithms(self) -> List[PackingAlgorithm]: + """Fixture for packing algorithms.""" + return [ + PackingAlgorithm.CONCATENATIVE, + PackingAlgorithm.FIRST_FIT_DECREASING, + PackingAlgorithm.FIRST_FIT_SHUFFLE, + PackingAlgorithm.MODIFIED_FIRST_FIT_DECREASING, + ] + + def test_get_packer(self, bin_capacity: int, algorithms: List[PackingAlgorithm]): + """Test the get_packer factory function.""" + # Test that 
each algorithm name returns the correct packer + for algorithm in algorithms: + packer = get_packer(algorithm, bin_capacity) + assert isinstance(packer, SequencePacker) + + # Test with an invalid algorithm value + with pytest.raises(ValueError): + # Create a non-existent enum value by using an arbitrary object + invalid_algorithm = object() + get_packer(invalid_algorithm, bin_capacity) # type: ignore + + @pytest.mark.parametrize( + "algorithm", + [ + PackingAlgorithm.CONCATENATIVE, + PackingAlgorithm.FIRST_FIT_DECREASING, + PackingAlgorithm.FIRST_FIT_SHUFFLE, + PackingAlgorithm.MODIFIED_FIRST_FIT_DECREASING, + ], + ) + def test_small_sequences( + self, + bin_capacity: int, + small_sequence_lengths: List[int], + algorithm: PackingAlgorithm, + ): + """Test packing small sequences with all algorithms.""" + packer = get_packer(algorithm, bin_capacity) + bins = packer.pack(small_sequence_lengths) + + # Validate the packing + assert validate_solution(small_sequence_lengths, bins, bin_capacity) + + # Print the number of bins used (for information) + print(f"{algorithm.name} used {len(bins)} bins for small sequences") + + @pytest.mark.parametrize( + "algorithm", + [ + PackingAlgorithm.CONCATENATIVE, + PackingAlgorithm.FIRST_FIT_DECREASING, + PackingAlgorithm.FIRST_FIT_SHUFFLE, + PackingAlgorithm.MODIFIED_FIRST_FIT_DECREASING, + ], + ) + def test_medium_sequences( + self, + bin_capacity: int, + medium_sequence_lengths: List[int], + algorithm: PackingAlgorithm, + ): + """Test packing medium-sized sequences with all algorithms.""" + packer = get_packer(algorithm, bin_capacity) + bins = packer.pack(medium_sequence_lengths) + + # Validate the packing + assert validate_solution(medium_sequence_lengths, bins, bin_capacity) + + # Print the number of bins used (for information) + print(f"{algorithm.name} used {len(bins)} bins for medium sequences") + + @pytest.mark.parametrize( + "algorithm", + [ + PackingAlgorithm.CONCATENATIVE, + PackingAlgorithm.FIRST_FIT_DECREASING, + 
PackingAlgorithm.FIRST_FIT_SHUFFLE, + PackingAlgorithm.MODIFIED_FIRST_FIT_DECREASING, + ], + ) + def test_large_sequences( + self, + bin_capacity: int, + large_sequence_lengths: List[int], + algorithm: PackingAlgorithm, + ): + """Test packing large sequences with all algorithms.""" + packer = get_packer(algorithm, bin_capacity) + bins = packer.pack(large_sequence_lengths) + + # Validate the packing + assert validate_solution(large_sequence_lengths, bins, bin_capacity) + + # Print the number of bins used (for information) + print(f"{algorithm.name} used {len(bins)} bins for large sequences") + + @pytest.mark.parametrize( + "algorithm", + [ + PackingAlgorithm.CONCATENATIVE, + PackingAlgorithm.FIRST_FIT_DECREASING, + PackingAlgorithm.FIRST_FIT_SHUFFLE, + PackingAlgorithm.MODIFIED_FIRST_FIT_DECREASING, + ], + ) + # TODO(ahmadki): use the function to specify all test algorithms instead of lists below + @pytest.mark.parametrize( + "case_name, sequence_lengths", + [ + ("single_item", [50]), + ("all_same_size", [30, 30, 30, 30, 30]), + ("max_size", [100, 100, 100]), + ("mixed_sizes", [10, 50, 100, 20, 80, 30, 70, 40, 60, 90]), + ], + ) + def test_edge_cases( + self, + bin_capacity: int, + algorithm: PackingAlgorithm, + case_name: str, + sequence_lengths: List[int], + ): + """Test edge cases with all algorithms.""" + packer = get_packer(algorithm, bin_capacity) + bins = packer.pack(sequence_lengths) + + # Validate the packing + assert validate_solution(sequence_lengths, bins, bin_capacity) + + # For single item, check that only one bin is created + if case_name == "single_item": + assert len(bins) == 1 + + @pytest.mark.parametrize( + "algorithm", + [ + PackingAlgorithm.CONCATENATIVE, + PackingAlgorithm.FIRST_FIT_DECREASING, + PackingAlgorithm.FIRST_FIT_SHUFFLE, + PackingAlgorithm.MODIFIED_FIRST_FIT_DECREASING, + ], + ) + def test_empty_list(self, bin_capacity: int, algorithm: PackingAlgorithm): + """Test empty list with algorithms that can handle it.""" + packer = 
get_packer(algorithm, bin_capacity) + bins = packer.pack([]) + + # For empty list, check that no bins are created + assert len(bins) == 0 + + @pytest.mark.parametrize( + "algorithm", + [ + PackingAlgorithm.CONCATENATIVE, + PackingAlgorithm.FIRST_FIT_DECREASING, + PackingAlgorithm.FIRST_FIT_SHUFFLE, + PackingAlgorithm.MODIFIED_FIRST_FIT_DECREASING, + ], + ) + def test_error_cases(self, bin_capacity: int, algorithm: PackingAlgorithm): + """Test error cases with all algorithms.""" + # Test with a sequence length that exceeds bin capacity + sequence_lengths = [50, 150, 70] # 150 > bin_capacity (100) + + packer = get_packer(algorithm, bin_capacity) + with pytest.raises(ValueError): + packer.pack(sequence_lengths) + + @pytest.mark.parametrize( + "algorithm", + [ + PackingAlgorithm.CONCATENATIVE, + PackingAlgorithm.FIRST_FIT_DECREASING, + PackingAlgorithm.MODIFIED_FIRST_FIT_DECREASING, + ], + ) + def test_deterministic( + self, + bin_capacity: int, + medium_sequence_lengths: List[int], + algorithm: PackingAlgorithm, + ): + """Test that deterministic algorithms produce the same result on multiple runs.""" + packer = get_packer(algorithm, bin_capacity) + + # Run the algorithm twice and check that the results are the same + bins1 = packer.pack(medium_sequence_lengths) + bins2 = packer.pack(medium_sequence_lengths) + + # Convert to a format that can be compared (sort each bin and then sort the bins) + sorted_bins1 = sorted([sorted(bin_indices) for bin_indices in bins1]) + sorted_bins2 = sorted([sorted(bin_indices) for bin_indices in bins2]) + + assert sorted_bins1 == sorted_bins2 + + @pytest.mark.parametrize( + "algorithm", + [ + PackingAlgorithm.FIRST_FIT_SHUFFLE, + ], + ) + def test_randomized( + self, + bin_capacity: int, + medium_sequence_lengths: List[int], + algorithm: PackingAlgorithm, + ): + """Test that randomized algorithms can produce different results on multiple runs.""" + # Note: This test might occasionally fail due to randomness + + # Set different seeds to 
ensure different random behavior + random.seed(42) + packer1 = get_packer(algorithm, bin_capacity) + bins1 = packer1.pack(medium_sequence_lengths) + + random.seed(43) + packer2 = get_packer(algorithm, bin_capacity) + bins2 = packer2.pack(medium_sequence_lengths) + + # Convert to a format that can be compared + sorted_bins1 = sorted([sorted(bin_indices) for bin_indices in bins1]) + sorted_bins2 = sorted([sorted(bin_indices) for bin_indices in bins2]) + + # Check if the results are different + # This is a weak test, as randomness might still produce the same result + if sorted_bins1 == sorted_bins2: + print( + f"Warning: {algorithm.name} produced the same result with different seeds" + ) diff --git a/tests/unit/distributed/test_batched_data_dict.py b/tests/unit/distributed/test_batched_data_dict.py index 6b6c95c092..eaebf2dd8a 100644 --- a/tests/unit/distributed/test_batched_data_dict.py +++ b/tests/unit/distributed/test_batched_data_dict.py @@ -14,7 +14,11 @@ import pytest import torch -from nemo_rl.distributed.batched_data_dict import BatchedDataDict, DynamicBatchingArgs +from nemo_rl.distributed.batched_data_dict import ( + BatchedDataDict, + DynamicBatchingArgs, + SequencePackingArgs, +) def test_shard_by_batch_size_basic(): @@ -236,3 +240,367 @@ def test_shard_by_batch_size_dynamic(): batch_size, seqlen = mb["data"].shape assert seqlen % 4 == 0 assert seqlen <= 32 + + +def test_sequence_packing_basic(): + """Test basic functionality of sequence packing with modified FFD algorithm.""" + # Create sample data with varying sequence lengths + batch_size = 8 + max_seq_length = 512 + + # Generate random sequence lengths between 50 and 400 + torch.manual_seed(42) + sequence_lengths = torch.randint(50, 400, (batch_size,)) + + # Create input tensors with padding + input_ids = [] + for seq_len in sequence_lengths: + # Create a sequence with actual tokens up to seq_len, then padding + seq = torch.cat( + [ + torch.randint(1, 1000, (seq_len,)), # Actual tokens + 
torch.zeros(max_seq_length - seq_len, dtype=torch.long), # Padding + ] + ) + input_ids.append(seq) + + input_ids = torch.stack(input_ids) + + # Create batch data dict + batch_data = BatchedDataDict( + { + "input_ids": input_ids, + "sequence_lengths": sequence_lengths, + "problem_ids": torch.arange(batch_size), + } + ) + + # Configure sequence packing + sequence_packing_args = SequencePackingArgs( + max_tokens_per_microbatch=1024, + input_key="input_ids", + input_lengths_key="sequence_lengths", + algorithm="modified_first_fit_decreasing", + sequence_length_pad_multiple=1, + ) + + # Shard the batch with sequence packing + shards = 2 + sharded_batches, sorted_indices = batch_data.shard_by_batch_size( + shards=shards, sequence_packing_args=sequence_packing_args + ) + + # Verify output structure + assert len(sharded_batches) == shards + assert len(sorted_indices) == batch_size + + # Verify each shard has microbatch indices and lengths + for shard in sharded_batches: + assert hasattr(shard, "micro_batch_indices") + assert hasattr(shard, "micro_batch_lengths") + assert len(shard.micro_batch_indices) > 0 + assert len(shard.micro_batch_lengths) > 0 + + problem_ids_seen = set() + + # Verify microbatch structure + for chunk_indices, chunk_lengths in zip( + shard.micro_batch_indices, shard.micro_batch_lengths + ): + assert len(chunk_indices) == len(chunk_lengths) + + # Verify each microbatch respects the token limit + for (start_idx, end_idx), packed_len in zip(chunk_indices, chunk_lengths): + assert packed_len <= sequence_packing_args["max_tokens_per_microbatch"] + + for s in sharded_batches: + for mb in s.make_microbatch_iterator_for_packable_sequences(): + mb_len = mb["sequence_lengths"].sum().item() + assert mb_len <= sequence_packing_args["max_tokens_per_microbatch"] + for i in range(mb["input_ids"].shape[0]): + problem_id = mb["problem_ids"][i].item() + assert problem_id not in problem_ids_seen, ( + f"Problem ID {problem_id} seen twice" + ) + 
problem_ids_seen.add(problem_id) + assert len(problem_ids_seen) == batch_size + + +def test_sequence_packing_uniform_lengths(): + """Test sequence packing when all sequences have the same length.""" + batch_size = 12 + seq_length = 256 + + batch_data = BatchedDataDict( + { + "input_ids": torch.ones(batch_size, seq_length, dtype=torch.long), + "sequence_lengths": torch.full((batch_size,), seq_length), + "problem_ids": torch.arange(batch_size), + } + ) + + sequence_packing_args = SequencePackingArgs( + max_tokens_per_microbatch=1024, + input_key="input_ids", + input_lengths_key="sequence_lengths", + algorithm="modified_first_fit_decreasing", + sequence_length_pad_multiple=1, + ) + + sharded_batches, sorted_indices = batch_data.shard_by_batch_size( + shards=2, sequence_packing_args=sequence_packing_args + ) + + # With uniform lengths, sequences should be efficiently packed + assert len(sharded_batches) == 2 + len_0 = len( + list(sharded_batches[0].make_microbatch_iterator_for_packable_sequences()) + ) + len_1 = len( + list(sharded_batches[1].make_microbatch_iterator_for_packable_sequences()) + ) + assert len_0 + len_1 == 3 + assert min(len_0, len_1) == 1 + + # Each microbatch should pack as many sequences as possible + for shard in sharded_batches: + for chunk_indices, chunk_lengths in zip( + shard.micro_batch_indices, shard.micro_batch_lengths + ): + for (start_idx, end_idx), packed_len in zip(chunk_indices, chunk_lengths): + # With 256 tokens per sequence and 1024 max, should pack 4 sequences + assert packed_len <= 1024 + num_seqs = end_idx - start_idx + assert num_seqs <= 4 # Can fit at most 4 sequences of length 256 + + problem_ids_seen = set() + for s in sharded_batches: + for mb in s.make_microbatch_iterator_for_packable_sequences(): + mb_len = mb["sequence_lengths"].sum().item() + assert mb_len <= sequence_packing_args["max_tokens_per_microbatch"] + for i in range(mb["input_ids"].shape[0]): + problem_id = mb["problem_ids"][i].item() + assert problem_id not in 
problem_ids_seen, ( + f"Problem ID {problem_id} seen twice" + ) + problem_ids_seen.add(problem_id) + assert len(problem_ids_seen) == batch_size + + +def test_sequence_packing_long_sequences(): + """Test sequence packing with very long sequences that require individual microbatches.""" + batch_size = 4 + + batch_data = BatchedDataDict( + { + "input_ids": torch.ones(batch_size, 2048, dtype=torch.long), + "sequence_lengths": torch.tensor([900, 850, 1000, 950]), + "problem_ids": torch.arange(batch_size), + } + ) + + sequence_packing_args = SequencePackingArgs( + max_tokens_per_microbatch=1024, + input_key="input_ids", + input_lengths_key="sequence_lengths", + algorithm="modified_first_fit_decreasing", + sequence_length_pad_multiple=1, + ) + + sharded_batches, sorted_indices = batch_data.shard_by_batch_size( + shards=2, sequence_packing_args=sequence_packing_args + ) + + # Each sequence should be in its own microbatch due to length + for shard in sharded_batches: + for chunk_indices, chunk_lengths in zip( + shard.micro_batch_indices, shard.micro_batch_lengths + ): + for (start_idx, end_idx), max_len in zip(chunk_indices, chunk_lengths): + num_seqs = end_idx - start_idx + # Each long sequence should be alone in its microbatch + assert num_seqs == 1 + + problem_ids_seen = set() + for s in sharded_batches: + for mb in s.make_microbatch_iterator_for_packable_sequences(): + mb_len = mb["sequence_lengths"].sum().item() + assert mb_len <= sequence_packing_args["max_tokens_per_microbatch"] + for i in range(mb["input_ids"].shape[0]): + problem_id = mb["problem_ids"][i].item() + assert problem_id not in problem_ids_seen, ( + f"Problem ID {problem_id} seen twice" + ) + problem_ids_seen.add(problem_id) + assert len(problem_ids_seen) == batch_size + + +def test_sequence_packing_with_dynamic_batching_conflict(): + """Test that sequence packing and dynamic batching cannot be used together.""" + batch_data = BatchedDataDict( + { + "input_ids": torch.ones(4, 100, dtype=torch.long), + 
"sequence_lengths": torch.tensor([50, 60, 70, 80]), + } + ) + + sequence_packing_args = SequencePackingArgs( + max_tokens_per_microbatch=1024, + input_key="input_ids", + input_lengths_key="sequence_lengths", + algorithm="modified_first_fit_decreasing", + ) + + dynamic_batching_args: DynamicBatchingArgs = { + "input_key": "input_ids", + "input_lengths_key": "sequence_lengths", + "sequence_length_round": 4, + "max_tokens_per_microbatch": 1024, + } + + with pytest.raises( + AssertionError, + match="dynamic_batching_args and sequence_packing_args cannot be passed together", + ): + batch_data.shard_by_batch_size( + shards=2, + sequence_packing_args=sequence_packing_args, + dynamic_batching_args=dynamic_batching_args, + ) + + +@pytest.mark.parametrize("pad_to_multiple_of", [1, 32, 64, 256]) +def test_sequence_packing_microbatch_boundaries(pad_to_multiple_of): + """Test that microbatch boundaries are correctly maintained across chunks with random sequences.""" + # Create a large batch with random sequence lengths to test boundary handling + torch.manual_seed(123) # For reproducible tests + batch_size = 1024 + num_global_batches = 4 + max_seq_length = 1024 + max_tokens_per_microbatch = 1200 + + def _get_padded_seqlen(seqlen: int) -> int: + return (seqlen + (pad_to_multiple_of - 1)) // pad_to_multiple_of + + # Generate random sequence lengths with good variety + sequence_lengths = torch.randint(50, 800, (batch_size,)) + + # Create input tensors with padding + input_ids = [] + for i, seq_len in enumerate(sequence_lengths): + # Create a sequence with actual tokens up to seq_len, then padding + seq = torch.cat( + [ + torch.randint(1, 1000, (seq_len,)), # Actual tokens + torch.zeros(max_seq_length - seq_len, dtype=torch.long), # Padding + ] + ) + input_ids.append(seq) + + input_ids = torch.stack(input_ids) + + batch_data = BatchedDataDict( + { + "input_ids": input_ids, + "sequence_lengths": sequence_lengths, + "problem_ids": torch.arange(batch_size), + } + ) + + 
sequence_packing_args = SequencePackingArgs( + max_tokens_per_microbatch=max_tokens_per_microbatch, + input_key="input_ids", + input_lengths_key="sequence_lengths", + algorithm="modified_first_fit_decreasing", + sequence_length_pad_multiple=pad_to_multiple_of, + ) + + # Test with multiple shards and explicit batch_size to create chunks + shards = 4 + chunk_batch_size = batch_size // num_global_batches + sharded_batches, sorted_indices = batch_data.shard_by_batch_size( + shards=shards, + batch_size=chunk_batch_size, + sequence_packing_args=sequence_packing_args, + ) + + # Verify output structure + assert len(sharded_batches) == shards + assert len(sorted_indices) == batch_size + + # Track all problem IDs to ensure completeness and no duplicates + problem_ids_seen = set() + + for gb_idx in range(num_global_batches): + mb_count_for_gb = 0 + min_mb_count = 100000000 # arbitrary large number + max_mb_count = 0 + legal_problem_ids = set( + range(gb_idx * chunk_batch_size, (gb_idx + 1) * chunk_batch_size) + ) + for shard_idx in range(shards): + shard_batch = sharded_batches[shard_idx].get_batch(gb_idx) + mb_count = 0 + for mb in shard_batch.make_microbatch_iterator_for_packable_sequences(): + mb_count += 1 + for i in range(mb["input_ids"].shape[0]): + problem_id = mb["problem_ids"][i].item() + assert problem_id in legal_problem_ids, ( + f"Problem ID {problem_id} not in legal problem IDs" + ) + assert problem_id not in problem_ids_seen, ( + f"Problem ID {problem_id} seen twice" + ) + problem_ids_seen.add(problem_id) + assert ( + _get_padded_seqlen(mb["sequence_lengths"]).sum().item() + <= max_tokens_per_microbatch + ), ( + f"Sequence length {_get_padded_seqlen(mb['sequence_lengths']).sum().item()} is greater than max tokens per microbatch {max_tokens_per_microbatch}" + ) + + min_mb_count = min(min_mb_count, mb_count) + max_mb_count = max(max_mb_count, mb_count) + mb_count_for_gb += mb_count + assert max_mb_count - min_mb_count <= 1 + + num_actual_tokens = sum( + 
sequence_lengths[ + gb_idx * chunk_batch_size : (gb_idx + 1) * chunk_batch_size + ] + ) + packing_efficiency = num_actual_tokens / ( + mb_count_for_gb * max_tokens_per_microbatch + ) + + pack_efficiency_standards = { + 1: (0.97, 1.0), + 32: (0.92, 0.97), + 64: (0.85, 0.92), + 256: (0.60, 0.80), + } + assert packing_efficiency >= pack_efficiency_standards[pad_to_multiple_of][0], ( + f"We expect packing efficiency to be above {pack_efficiency_standards[pad_to_multiple_of][0]} for these nice random inputs with padding to multiples of {pad_to_multiple_of}. Got {packing_efficiency}" + ) + assert packing_efficiency <= pack_efficiency_standards[pad_to_multiple_of][1], ( + f"We expect packing efficiency to be below {pack_efficiency_standards[pad_to_multiple_of][1]} for these nice random inputs with padding to multiples of {pad_to_multiple_of}. Got {packing_efficiency}" + ) + + assert len(problem_ids_seen) == batch_size + + # Finally, test that we can reorder everything back to how it was before + reconstructed = BatchedDataDict.from_batches(sharded_batches) + # check that it's different from the original + assert not torch.all(reconstructed["problem_ids"] == batch_data["problem_ids"]) + assert not torch.all(reconstructed["input_ids"] == batch_data["input_ids"]) + assert not torch.all( + reconstructed["sequence_lengths"] == batch_data["sequence_lengths"] + ) + + reconstructed.reorder_data(sorted_indices) + # check that it's the same as the original + assert torch.all(reconstructed["problem_ids"] == batch_data["problem_ids"]) + assert torch.all(reconstructed["input_ids"] == batch_data["input_ids"]) + assert torch.all( + reconstructed["sequence_lengths"] == batch_data["sequence_lengths"] + ) diff --git a/tests/unit/distributed/test_model_utils.py b/tests/unit/distributed/test_model_utils.py new file mode 100644 index 0000000000..cee92c49b0 --- /dev/null +++ b/tests/unit/distributed/test_model_utils.py @@ -0,0 +1,424 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import pytest +import ray +import torch + +from nemo_rl.distributed.model_utils import ( + _get_tokens_on_this_cp_rank, + allgather_cp_sharded_tensor, + from_parallel_logits_to_logprobs, + from_parallel_logits_to_logprobs_packed_sequences, +) +from nemo_rl.distributed.named_sharding import NamedSharding +from nemo_rl.distributed.ray_actor_environment_registry import ( + ACTOR_ENVIRONMENT_REGISTRY, + PY_EXECUTABLES, +) +from nemo_rl.distributed.virtual_cluster import RayVirtualCluster +from nemo_rl.distributed.worker_groups import RayWorkerBuilder, RayWorkerGroup + + +@ray.remote(num_gpus=1) +class ModelUtilsTestActor: + def __init__(self, tp_size, cp_size, sharding): + self.tp_size = tp_size + self.cp_size = cp_size + self.sharding = sharding + self.env_vars = dict(os.environ) + + def test_packed_sequences_equivalence(self): + """Test that packed and unpacked functions return the same results.""" + # Initialize worker groups + torch.distributed.init_process_group(backend="nccl") + + tp_rank = int(os.environ["RANK"]) % self.tp_size + cp_rank = int(os.environ["RANK"]) // self.tp_size + tp_ranks = self.sharding.get_ranks(tp=tp_rank) + if type(tp_ranks) != int: + tp_ranks = tp_ranks.layout.tolist() + else: + tp_ranks = [tp_ranks] + cp_ranks = self.sharding.get_ranks(cp=cp_rank) + if type(cp_ranks) != int: + cp_ranks = cp_ranks.layout.tolist() + else: + cp_ranks = [cp_ranks] + + tp_group = 
torch.distributed.new_group(ranks=cp_ranks) + cp_group = torch.distributed.new_group(ranks=tp_ranks) # this is correct + + # Test parameters + batch_size = 4 + seq_len = 32 + vocab_size = 1024 + + if self.cp_size > 1 and seq_len % (2 * self.cp_size) != 0: + seq_len = (seq_len // (2 * self.cp_size) + 1) * (2 * self.cp_size) + + vocab_part_size = vocab_size // self.tp_size + vocab_start_index = tp_rank * vocab_part_size + vocab_end_index = (tp_rank + 1) * vocab_part_size + + unpacked_seq_len = seq_len + + # Create random data + torch.manual_seed(42) # For reproducibility + unpacked_logits = torch.randn( + batch_size, unpacked_seq_len, vocab_part_size, device="cuda" + ) + unpacked_target_ids = ( + torch.arange(batch_size * seq_len).reshape(batch_size, seq_len).to("cuda") + ) + + # 1. Get expected logprobs from non-packed function + expected_logprobs = from_parallel_logits_to_logprobs( + unpacked_logits, + unpacked_target_ids, + vocab_start_index, + vocab_end_index, + tp_group, + cp_group=None, + ) + + # 1.5 get with_cp logprobs + with_cp_logprobs = from_parallel_logits_to_logprobs( + _get_tokens_on_this_cp_rank( + unpacked_logits, cp_rank, self.cp_size, seq_dim=1 + ), + unpacked_target_ids, + vocab_start_index, + vocab_end_index, + tp_group, + cp_group=cp_group, + ) + + torch.testing.assert_close( + with_cp_logprobs, expected_logprobs, rtol=1e-5, atol=1e-5 + ) + + # 2. Prepare inputs for packed function + # For simplicity, all sequences have the same length + seq_lengths = torch.full((batch_size,), seq_len, dtype=torch.int32) + cu_seqlens = torch.nn.functional.pad( + torch.cumsum(seq_lengths, dim=0, dtype=torch.int32), (1, 0) + ).to("cuda") + + # Pack the logits and target_ids + packed_logits = _get_tokens_on_this_cp_rank( + unpacked_logits, cp_rank, self.cp_size, seq_dim=1 + ).reshape(1, -1, vocab_part_size) + packed_target_ids = unpacked_target_ids.reshape(1, -1) + + # 3. 
Get actual logprobs from packed function + actual_logprobs = from_parallel_logits_to_logprobs_packed_sequences( + packed_logits, + packed_target_ids, + cu_seqlens, + seq_len, # unpacked_seqlen + vocab_start_index, + vocab_end_index, + tp_group, + cp_group=cp_group, + ) + + # 4. Compare results + torch.testing.assert_close( + actual_logprobs, expected_logprobs, rtol=1e-5, atol=1e-5 + ) + return {"success": True, "error": None} + + +MODEL_UTILS_TEST_ACTOR_FQN = f"{ModelUtilsTestActor.__module__}.ModelUtilsTestActor" + + +@pytest.fixture +def register_model_utils_test_actor(): + """Register the ModelUtilsTestActor for use in tests.""" + original_registry_value = ACTOR_ENVIRONMENT_REGISTRY.get(MODEL_UTILS_TEST_ACTOR_FQN) + ACTOR_ENVIRONMENT_REGISTRY[MODEL_UTILS_TEST_ACTOR_FQN] = PY_EXECUTABLES.SYSTEM + + yield MODEL_UTILS_TEST_ACTOR_FQN + + # Clean up registry + if MODEL_UTILS_TEST_ACTOR_FQN in ACTOR_ENVIRONMENT_REGISTRY: + if original_registry_value is None: + del ACTOR_ENVIRONMENT_REGISTRY[MODEL_UTILS_TEST_ACTOR_FQN] + else: + ACTOR_ENVIRONMENT_REGISTRY[MODEL_UTILS_TEST_ACTOR_FQN] = ( + original_registry_value + ) + + +@pytest.fixture +def virtual_cluster_2_gpus(): + """Create a virtual cluster with 2 GPU bundles.""" + cluster = RayVirtualCluster(bundle_ct_per_node_list=[2], use_gpus=True) + yield cluster + cluster.shutdown() + + +@pytest.fixture +def virtual_cluster_4_gpus(): + """Create a virtual cluster with 4 GPU bundles.""" + cluster = RayVirtualCluster(bundle_ct_per_node_list=[4], use_gpus=True) + yield cluster + cluster.shutdown() + + +import numpy as np + + +@pytest.mark.parametrize( + "tp_cp_config", + [ + (2, 1), # TP=2, CP=1 + (1, 2), # TP=1, CP=2 + ], +) +def test_from_parallel_logits_to_logprobs_packed_sequences( + register_model_utils_test_actor, tp_cp_config +): + """Test packed sequences function against unpacked version.""" + tp_size, cp_size = tp_cp_config + world_size = tp_size * cp_size + + # Skip if not enough GPUs + if not 
torch.cuda.is_available() or torch.cuda.device_count() < world_size: + pytest.skip( + f"Not enough GPUs available. Need {world_size}, got {torch.cuda.device_count()}" + ) + + # Create appropriate virtual cluster + cluster = RayVirtualCluster(bundle_ct_per_node_list=[2], use_gpus=True) + + try: + actor_fqn = register_model_utils_test_actor + + sharding = NamedSharding( + layout=np.arange(world_size).reshape(tp_size, cp_size), names=["tp", "cp"] + ) + builder = RayWorkerBuilder(actor_fqn, tp_size, cp_size, sharding) + + worker_group = RayWorkerGroup( + cluster=cluster, + remote_worker_builder=builder, + workers_per_node=None, + sharding_annotations=sharding, + ) + + # Run the test on all workers + futures = worker_group.run_all_workers_single_data( + "test_packed_sequences_equivalence" + ) + results = ray.get(futures) + + # Check that all workers succeeded + for i, result in enumerate(results): + assert result["success"], f"Worker {i} failed: {result['error']}" + + worker_group.shutdown(force=True) + + finally: + cluster.shutdown() + + +@ray.remote(num_gpus=1) +class AllGatherCPTestActor: + def __init__(self, cp_size): + self.cp_size = cp_size + self.env_vars = dict(os.environ) + + def test_allgather_cp_tensor(self): + """Test that allgather_cp_sharded_tensor correctly reconstructs tensors.""" + # Initialize process group + torch.distributed.init_process_group(backend="nccl") + + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + # Create CP group - all ranks participate in CP + cp_group = torch.distributed.new_group(ranks=list(range(world_size))) + + # Test parameters + batch_size = 2 + original_seq_len = 8 + hidden_size = 16 + + # Ensure sequence length is compatible with CP load balancing + if original_seq_len % (2 * self.cp_size) != 0: + original_seq_len = (original_seq_len // (2 * self.cp_size) + 1) * ( + 2 * self.cp_size + ) + + # Create original tensor (same on all ranks for testing) + torch.manual_seed(42) # Same seed for 
reproducibility + original_tensor = ( + torch.arange( + batch_size * original_seq_len * hidden_size, dtype=torch.float32 + ) + .reshape(batch_size, original_seq_len, hidden_size) + .to("cuda") + ) + original_tensor.requires_grad = True + + # Shard the tensor using CP logic + sharded_tensor = _get_tokens_on_this_cp_rank( + original_tensor, rank, self.cp_size, seq_dim=1 + ) + + # Test 1: Gather sharded tensor and verify it matches original + gathered_tensor = allgather_cp_sharded_tensor( + sharded_tensor, cp_group, seq_dim=1 + ) + + # Verify shapes match + if gathered_tensor.shape != original_tensor.shape: + return { + "success": False, + "error": f"Shape mismatch: expected {original_tensor.shape}, got {gathered_tensor.shape}", + } + + # Verify content matches (should be identical) + torch.testing.assert_close( + gathered_tensor, original_tensor, rtol=1e-5, atol=1e-5 + ) + + # test backward + def loss_fn(x): + return torch.sum(x**2) + + loss = loss_fn(gathered_tensor) + loss.backward() + grad = original_tensor.grad / self.cp_size + grad_sharded = _get_tokens_on_this_cp_rank(grad, rank, self.cp_size, seq_dim=1) + + torch.testing.assert_close( + grad_sharded, + _get_tokens_on_this_cp_rank( + 2 * original_tensor, rank, self.cp_size, seq_dim=1 + ), + rtol=1e-5, + atol=1e-5, + ) + torch.testing.assert_close( + _get_tokens_on_this_cp_rank( + grad, (rank + 1) % self.cp_size, self.cp_size, seq_dim=1 + ), + torch.zeros_like(sharded_tensor), + rtol=1e-5, + atol=1e-5, + ) + + # Test 2: Test with different sequence dimension (seq_dim=0) + # Create a tensor with sequence dimension at dim=0 + original_tensor_dim0 = torch.randn( + original_seq_len, batch_size, hidden_size, device="cuda" + ) + + sharded_tensor_dim0 = _get_tokens_on_this_cp_rank( + original_tensor_dim0, rank, self.cp_size, seq_dim=0 + ) + + gathered_tensor_dim0 = allgather_cp_sharded_tensor( + sharded_tensor_dim0, cp_group, seq_dim=0 + ) + + # Verify shapes and content match + if gathered_tensor_dim0.shape != 
original_tensor_dim0.shape: + return { + "success": False, + "error": f"Shape mismatch for seq_dim=0: expected {original_tensor_dim0.shape}, got {gathered_tensor_dim0.shape}", + } + + torch.testing.assert_close( + gathered_tensor_dim0, original_tensor_dim0, rtol=1e-5, atol=1e-5 + ) + + # Test 3: Test with different tensor shapes + # Test with 2D tensor + original_2d = torch.randn(original_seq_len, hidden_size, device="cuda") + sharded_2d = _get_tokens_on_this_cp_rank( + original_2d, rank, self.cp_size, seq_dim=0 + ) + gathered_2d = allgather_cp_sharded_tensor(sharded_2d, cp_group, seq_dim=0) + + torch.testing.assert_close(gathered_2d, original_2d, rtol=1e-5, atol=1e-5) + + return {"success": True, "error": None} + + +ALLGATHER_CP_TEST_ACTOR_FQN = f"{AllGatherCPTestActor.__module__}.AllGatherCPTestActor" + + +@pytest.fixture +def register_allgather_cp_test_actor(): + """Register the AllGatherCPTestActor for use in tests.""" + original_registry_value = ACTOR_ENVIRONMENT_REGISTRY.get( + ALLGATHER_CP_TEST_ACTOR_FQN + ) + ACTOR_ENVIRONMENT_REGISTRY[ALLGATHER_CP_TEST_ACTOR_FQN] = PY_EXECUTABLES.SYSTEM + + yield ALLGATHER_CP_TEST_ACTOR_FQN + + # Clean up registry + if ALLGATHER_CP_TEST_ACTOR_FQN in ACTOR_ENVIRONMENT_REGISTRY: + if original_registry_value is None: + del ACTOR_ENVIRONMENT_REGISTRY[ALLGATHER_CP_TEST_ACTOR_FQN] + else: + ACTOR_ENVIRONMENT_REGISTRY[ALLGATHER_CP_TEST_ACTOR_FQN] = ( + original_registry_value + ) + + +@pytest.mark.parametrize("cp_size", [2]) +def test_allgather_cp_sharded_tensor(register_allgather_cp_test_actor, cp_size): + """Test allgather_cp_sharded_tensor function.""" + # Skip if not enough GPUs + if not torch.cuda.is_available() or torch.cuda.device_count() < cp_size: + pytest.skip( + f"Not enough GPUs available. 
Need {cp_size}, got {torch.cuda.device_count()}" + ) + + # Create virtual cluster + cluster = RayVirtualCluster(bundle_ct_per_node_list=[cp_size], use_gpus=True) + + try: + actor_fqn = register_allgather_cp_test_actor + + # For CP, all ranks are in a single group + sharding = NamedSharding(layout=list(range(cp_size)), names=["cp"]) + builder = RayWorkerBuilder(actor_fqn, cp_size) + + worker_group = RayWorkerGroup( + cluster=cluster, + remote_worker_builder=builder, + workers_per_node=None, + sharding_annotations=sharding, + ) + + # Run the test on all workers + futures = worker_group.run_all_workers_single_data("test_allgather_cp_tensor") + results = ray.get(futures) + + # Check that all workers succeeded + for i, result in enumerate(results): + assert result["success"], f"Worker {i} failed: {result['error']}" + + worker_group.shutdown(force=True) + + finally: + cluster.shutdown() diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 8d6cff05f8..8a38e5c61e 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -107,6 +107,9 @@ "logprob_mb_tokens": 40, "sequence_length_round": 4, }, + "sequence_packing": { + "enabled": False, + }, "max_grad_norm": 1.0, "make_sequence_length_divisible_by": 1, "generation": deepcopy(basic_vllm_test_config), @@ -139,6 +142,9 @@ def get_basic_megatron_test_config( "dynamic_batching": { "enabled": False, # Start with simple batching }, + "sequence_packing": { + "enabled": False, + }, "megatron_cfg": { "enabled": True, "empty_unused_memory_level": 0, diff --git a/tests/unit/models/megatron/__init__.py b/tests/unit/models/megatron/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/tests/unit/models/megatron/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/models/megatron/test_common.py b/tests/unit/models/megatron/test_common.py new file mode 100644 index 0000000000..cc1214566a --- /dev/null +++ b/tests/unit/models/megatron/test_common.py @@ -0,0 +1,707 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os + +import pytest +import ray +import torch + +from nemo_rl.distributed.named_sharding import NamedSharding +from nemo_rl.distributed.ray_actor_environment_registry import ( + ACTOR_ENVIRONMENT_REGISTRY, + PY_EXECUTABLES, +) +from nemo_rl.distributed.virtual_cluster import RayVirtualCluster +from nemo_rl.distributed.worker_groups import RayWorkerBuilder, RayWorkerGroup + + +@ray.remote(num_gpus=1) +class PackSequencesTestActor: + def __init__(self, cp_size): + self.cp_size = cp_size + self.env_vars = dict(os.environ) + + def run_all_pack_sequences_tests(self): + """Run all sequence packing tests in a single call to avoid expensive reinitializations.""" + from nemo_rl.distributed.model_utils import _get_tokens_on_this_cp_rank + from nemo_rl.models.megatron.common import _pack_sequences_for_megatron + + # Initialize process group if CP > 1 + if self.cp_size > 1: + torch.distributed.init_process_group(backend="nccl") + rank = int(os.environ["RANK"]) + else: + rank = 0 + + results = {} + + # Test 1: Basic packing functionality + results["basic"] = self._test_basic_packing(_pack_sequences_for_megatron) + if not results["basic"]["success"]: + return results["basic"] + + # Test 2: Variable sequence lengths + results["variable_lengths"] = self._test_variable_lengths( + _pack_sequences_for_megatron + ) + if not results["variable_lengths"]["success"]: + return results["variable_lengths"] + + # Test 3: Content preservation and consistency + results["consistency"] = self._test_consistency(_pack_sequences_for_megatron) + if not results["consistency"]["success"]: + return results["consistency"] + + # Test 4: Edge cases + results["edge_cases"] = self._test_edge_cases(_pack_sequences_for_megatron) + if not results["edge_cases"]["success"]: + return results["edge_cases"] + + # Test 5: Context parallelism (only if CP > 1) + if self.cp_size > 1: + results["context_parallel"] = self._test_context_parallel( + _pack_sequences_for_megatron, _get_tokens_on_this_cp_rank, rank + ) 
+ if not results["context_parallel"]["success"]: + return results["context_parallel"] + else: + results["context_parallel"] = { + "success": True, + "error": None, + "skipped": "CP=1", + } + + return {"success": True, "error": None, "detailed_results": results} + + def _test_basic_packing(self, _pack_sequences_for_megatron): + """Test basic sequence packing without context parallelism.""" + try: + # Test parameters + batch_size = 3 + max_seq_len = 10 + vocab_size = 100 + + # Create test data with variable sequence lengths + input_ids = torch.randint( + 0, vocab_size, (batch_size, max_seq_len), device="cuda" + ) + seq_lengths = torch.tensor([8, 5, 7], device="cuda") + + # Test 1: Basic packing without CP + packed_input_ids, _, packed_seq_params, cu_seqlens, cu_seqlens_padded = ( + _pack_sequences_for_megatron( + input_ids, seq_lengths, cp_rank=0, cp_size=1 + ) + ) + + # Verify shapes + expected_total_tokens = seq_lengths.sum().item() + if packed_input_ids.shape != (1, expected_total_tokens): + return { + "success": False, + "error": f"Basic packing shape mismatch: expected (1, {expected_total_tokens}), got {packed_input_ids.shape}", + } + + # Verify cu_seqlens + expected_cu_seqlens = torch.tensor( + [0, 8, 13, 20], device="cuda", dtype=torch.int32 + ) + if not torch.equal(cu_seqlens, expected_cu_seqlens): + return { + "success": False, + "error": f"cu_seqlens mismatch: expected {expected_cu_seqlens}, got {cu_seqlens}", + } + + # Verify PackedSeqParams + if packed_seq_params.qkv_format != "thd": + return { + "success": False, + "error": f"Wrong qkv_format: expected 'thd', got {packed_seq_params.qkv_format}", + } + + if packed_seq_params.max_seqlen_q != 8: + return { + "success": False, + "error": f"Wrong max_seqlen_q: expected 8, got {packed_seq_params.max_seqlen_q}", + } + + # Test 2: Packing with individual sequence padding + ( + packed_input_ids_pad, + _, + packed_seq_params_pad, + cu_seqlens_pad, + cu_seqlens_padded_pad, + ) = _pack_sequences_for_megatron( + 
input_ids, + seq_lengths, + pad_individual_seqs_to_multiple_of=4, + cp_rank=0, + cp_size=1, + ) + + # With padding to multiple of 4: [8, 5, 7] -> [8, 8, 8] = 24 tokens + expected_total_tokens_pad = 24 + if packed_input_ids_pad.shape != (1, expected_total_tokens_pad): + return { + "success": False, + "error": f"Padded packing shape mismatch: expected (1, {expected_total_tokens_pad}), got {packed_input_ids_pad.shape}", + } + + # Verify padded cu_seqlens + expected_cu_seqlens_padded = torch.tensor( + [0, 8, 16, 24], device="cuda", dtype=torch.int32 + ) + if not torch.equal(cu_seqlens_padded_pad, expected_cu_seqlens_padded): + return { + "success": False, + "error": f"Padded cu_seqlens mismatch: expected {expected_cu_seqlens_padded}, got {cu_seqlens_padded_pad}", + } + + return {"success": True, "error": None} + + except Exception as e: + return {"success": False, "error": f"Basic packing test failed: {str(e)}"} + + def _test_variable_lengths(self, _pack_sequences_for_megatron): + """Test sequence packing with variable sequence lengths.""" + try: + # Test parameters + batch_size = 4 + max_seq_len = 12 + vocab_size = 50 + + # Create test data with highly variable sequence lengths + input_ids = torch.randint( + 0, vocab_size, (batch_size, max_seq_len), device="cuda" + ) + seq_lengths = torch.tensor([12, 3, 8, 1], device="cuda") + + # Test 1: Variable lengths without padding + packed_input_ids, _, packed_seq_params, cu_seqlens, cu_seqlens_padded = ( + _pack_sequences_for_megatron( + input_ids, seq_lengths, cp_rank=0, cp_size=1 + ) + ) + + # Verify total tokens + expected_total_tokens = seq_lengths.sum().item() # 12 + 3 + 8 + 1 = 24 + if packed_input_ids.shape != (1, expected_total_tokens): + return { + "success": False, + "error": f"Variable lengths shape mismatch: expected (1, {expected_total_tokens}), got {packed_input_ids.shape}", + } + + # Verify cu_seqlens + expected_cu_seqlens = torch.tensor( + [0, 12, 15, 23, 24], device="cuda", dtype=torch.int32 + ) + if not 
torch.equal(cu_seqlens, expected_cu_seqlens): + return { + "success": False, + "error": f"Variable lengths cu_seqlens mismatch: expected {expected_cu_seqlens}, got {cu_seqlens}", + } + + # Test 2: Variable lengths with padding + ( + packed_input_ids_pad, + _, + packed_seq_params_pad, + cu_seqlens_pad, + cu_seqlens_padded_pad, + ) = _pack_sequences_for_megatron( + input_ids, + seq_lengths, + pad_individual_seqs_to_multiple_of=4, + cp_rank=0, + cp_size=1, + ) + + # With padding to multiple of 4: [12, 3, 8, 1] -> [12, 4, 8, 4] = 28 tokens + expected_total_tokens_pad = 28 + if packed_input_ids_pad.shape != (1, expected_total_tokens_pad): + return { + "success": False, + "error": f"Variable lengths padded shape mismatch: expected (1, {expected_total_tokens_pad}), got {packed_input_ids_pad.shape}", + } + + # Verify padded cu_seqlens + expected_cu_seqlens_padded = torch.tensor( + [0, 12, 16, 24, 28], device="cuda", dtype=torch.int32 + ) + if not torch.equal(cu_seqlens_padded_pad, expected_cu_seqlens_padded): + return { + "success": False, + "error": f"Variable lengths padded cu_seqlens mismatch: expected {expected_cu_seqlens_padded}, got {cu_seqlens_padded_pad}", + } + + # Verify max_seqlen + if packed_seq_params.max_seqlen_q != 12: + return { + "success": False, + "error": f"Variable lengths wrong max_seqlen_q: expected 12, got {packed_seq_params.max_seqlen_q}", + } + + if packed_seq_params_pad.max_seqlen_q != 12: + return { + "success": False, + "error": f"Variable lengths padded wrong max_seqlen_q: expected 12, got {packed_seq_params_pad.max_seqlen_q}", + } + + return {"success": True, "error": None} + + except Exception as e: + return { + "success": False, + "error": f"Variable lengths test failed: {str(e)}", + } + + def _test_consistency(self, _pack_sequences_for_megatron): + """Test that packing produces consistent results and that content is preserved.""" + try: + # Test parameters + batch_size = 2 + seq_len = 8 + vocab_size = 20 + + # Create deterministic test 
data + torch.manual_seed(123) + input_ids = torch.randint( + 0, vocab_size, (batch_size, seq_len), device="cuda" + ) + seq_lengths = torch.tensor([6, 4], device="cuda") + + # Test consistency between multiple calls + ( + packed_input_ids_1, + _, + packed_seq_params_1, + cu_seqlens_1, + cu_seqlens_padded_1, + ) = _pack_sequences_for_megatron( + input_ids, seq_lengths, cp_rank=0, cp_size=1 + ) + + ( + packed_input_ids_2, + _, + packed_seq_params_2, + cu_seqlens_2, + cu_seqlens_padded_2, + ) = _pack_sequences_for_megatron( + input_ids, seq_lengths, cp_rank=0, cp_size=1 + ) + + # Verify consistency + if not torch.equal(packed_input_ids_1, packed_input_ids_2): + return { + "success": False, + "error": "Inconsistent packed_input_ids between calls", + } + + if not torch.equal(cu_seqlens_1, cu_seqlens_2): + return { + "success": False, + "error": "Inconsistent cu_seqlens between calls", + } + + # Verify content preservation + # Extract the first sequence (length 6) and compare with original + first_seq_packed = packed_input_ids_1[0, :6] + first_seq_original = input_ids[0, :6] + + if not torch.equal(first_seq_packed, first_seq_original): + return { + "success": False, + "error": "Content not preserved in first sequence", + } + + # Extract the second sequence (length 4) and compare with original + second_seq_packed = packed_input_ids_1[0, 6:10] + second_seq_original = input_ids[1, :4] + + if not torch.equal(second_seq_packed, second_seq_original): + return { + "success": False, + "error": "Content not preserved in second sequence", + } + + return {"success": True, "error": None} + + except Exception as e: + return {"success": False, "error": f"Consistency test failed: {str(e)}"} + + def _test_edge_cases(self, _pack_sequences_for_megatron): + """Test edge cases and error conditions.""" + try: + # Test 1: Single sequence + batch_size = 1 + seq_len = 10 + vocab_size = 50 + + input_ids = torch.randint( + 0, vocab_size, (batch_size, seq_len), device="cuda" + ) + seq_lengths = 
torch.tensor([seq_len], device="cuda") + + packed_input_ids, _, packed_seq_params, cu_seqlens, cu_seqlens_padded = ( + _pack_sequences_for_megatron( + input_ids, seq_lengths, cp_rank=0, cp_size=1 + ) + ) + + # Verify single sequence packing + if packed_input_ids.shape != (1, seq_len): + return { + "success": False, + "error": f"Single sequence shape mismatch: expected (1, {seq_len}), got {packed_input_ids.shape}", + } + + expected_cu_seqlens = torch.tensor( + [0, seq_len], device="cuda", dtype=torch.int32 + ) + if not torch.equal(cu_seqlens, expected_cu_seqlens): + return { + "success": False, + "error": f"Single sequence cu_seqlens mismatch: expected {expected_cu_seqlens}, got {cu_seqlens}", + } + + # Test 2: Empty sequences (length 0) + batch_size = 3 + max_seq_len = 5 + input_ids = torch.randint( + 0, vocab_size, (batch_size, max_seq_len), device="cuda" + ) + seq_lengths = torch.tensor([3, 0, 2], device="cuda") + + packed_input_ids, _, packed_seq_params, cu_seqlens, cu_seqlens_padded = ( + _pack_sequences_for_megatron( + input_ids, seq_lengths, cp_rank=0, cp_size=1 + ) + ) + + # Should handle empty sequences gracefully + expected_total_tokens = 5 # 3 + 0 + 2 + if packed_input_ids.shape != (1, expected_total_tokens): + return { + "success": False, + "error": f"Empty sequence shape mismatch: expected (1, {expected_total_tokens}), got {packed_input_ids.shape}", + } + + expected_cu_seqlens = torch.tensor( + [0, 3, 3, 5], device="cuda", dtype=torch.int32 + ) + if not torch.equal(cu_seqlens, expected_cu_seqlens): + return { + "success": False, + "error": f"Empty sequence cu_seqlens mismatch: expected {expected_cu_seqlens}, got {cu_seqlens}", + } + + # Test 3: Large padding values + batch_size = 2 + seq_len = 4 + input_ids = torch.randint( + 0, vocab_size, (batch_size, seq_len), device="cuda" + ) + seq_lengths = torch.tensor([3, 2], device="cuda") + + packed_input_ids, _, packed_seq_params, cu_seqlens, cu_seqlens_padded = ( + _pack_sequences_for_megatron( + input_ids, 
+ seq_lengths, + pad_individual_seqs_to_multiple_of=8, + cp_rank=0, + cp_size=1, + ) + ) + + # With padding to multiple of 8: [3, 2] -> [8, 8] = 16 tokens + expected_total_tokens = 16 + if packed_input_ids.shape != (1, expected_total_tokens): + return { + "success": False, + "error": f"Large padding shape mismatch: expected (1, {expected_total_tokens}), got {packed_input_ids.shape}", + } + + return {"success": True, "error": None} + + except Exception as e: + return {"success": False, "error": f"Edge cases test failed: {str(e)}"} + + def _test_context_parallel( + self, _pack_sequences_for_megatron, _get_tokens_on_this_cp_rank, rank + ): + """Test sequence packing with context parallelism.""" + # Test parameters + batch_size = 2 + seq_len = 16 # Ensure divisible by cp_size * 2 + vocab_size = 100 + + # Ensure sequence length is compatible with CP + if seq_len % (2 * self.cp_size) != 0: + seq_len = (seq_len // (2 * self.cp_size) + 1) * (2 * self.cp_size) + + # Create test data + torch.manual_seed(42) # For reproducibility + input_ids = torch.arange(seq_len * batch_size, device="cuda").reshape( + batch_size, seq_len + ) + seq_lengths = torch.tensor([seq_len, seq_len], device="cuda") + + # Test 1: CP packing with individual sequence padding + ( + packed_input_ids, + packed_input_ids_cp_sharded, + packed_seq_params, + cu_seqlens, + cu_seqlens_padded, + ) = _pack_sequences_for_megatron( + input_ids, + seq_lengths, + pad_individual_seqs_to_multiple_of=self.cp_size * 2, + cp_rank=rank, + cp_size=self.cp_size, + ) + + # Verify the packed tensor shape + expected_tokens_per_rank = seq_len // self.cp_size + expected_total_tokens = batch_size * expected_tokens_per_rank + if packed_input_ids_cp_sharded.shape != (1, expected_total_tokens): + return { + "success": False, + "error": f"CP packing shape mismatch: expected (1, {expected_total_tokens}), got {packed_input_ids_cp_sharded.shape}", + } + + # Verify cu_seqlens for original sequences + expected_cu_seqlens = torch.tensor( + 
[0, seq_len, seq_len * 2], device="cuda", dtype=torch.int32 + ) + if not torch.equal(cu_seqlens, expected_cu_seqlens): + return { + "success": False, + "error": f"CP cu_seqlens mismatch: expected {expected_cu_seqlens}, got {cu_seqlens}", + } + + # Verify PackedSeqParams + if packed_seq_params.qkv_format != "thd": + return { + "success": False, + "error": f"CP wrong qkv_format: expected 'thd', got {packed_seq_params.qkv_format}", + } + + # Test 2: CP packing with full sequence padding + pad_full_seq_to = (batch_size * seq_len) + 8 # Add some padding + ( + packed_input_ids_full, + packed_input_ids_cp_sharded, + packed_seq_params_full, + cu_seqlens_full, + cu_seqlens_padded_full, + ) = _pack_sequences_for_megatron( + input_ids, + seq_lengths, + pad_individual_seqs_to_multiple_of=self.cp_size * 2, + pad_packed_seq_to=pad_full_seq_to, + cp_rank=rank, + cp_size=self.cp_size, + ) + + # Verify the packed tensor shape with full padding + expected_tokens_per_rank_full = pad_full_seq_to // self.cp_size + if packed_input_ids_cp_sharded.shape != (1, expected_tokens_per_rank_full): + return { + "success": False, + "error": f"CP full padding shape mismatch: expected (1, {expected_tokens_per_rank_full}), got {packed_input_ids_cp_sharded.shape}", + } + + # Verify cu_seqlens_padded for full padding + expected_cu_seqlens_padded_full = torch.tensor( + [0, seq_len, pad_full_seq_to], device="cuda", dtype=torch.int32 + ) + if not torch.equal(cu_seqlens_padded_full, expected_cu_seqlens_padded_full): + return { + "success": False, + "error": f"CP full padding cu_seqlens_padded mismatch: expected {expected_cu_seqlens_padded_full}, got {cu_seqlens_padded_full}", + } + + correct_ids_0 = torch.tensor( + [0, 1, 2, 3, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 0, 0, 0, 0, 0, 0], + device="cuda", + ) + correct_ids_1 = torch.tensor( + [4, 5, 6, 7, 8, 9, 10, 11, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 0], + device="cuda", + ) + + if ( + rank == 0 + and 
torch.sum(torch.abs(packed_input_ids_cp_sharded - correct_ids_0)).item() + != 0 + ): + return { + "success": False, + "error": f"CP full padding ids mismatch: expected {correct_ids_0}, got {packed_input_ids_cp_sharded[0, :20]}", + } + if ( + rank == 1 + and torch.sum(torch.abs(packed_input_ids_cp_sharded - correct_ids_1)).item() + != 0 + ): + return { + "success": False, + "error": f"CP full padding ids mismatch: expected {correct_ids_1}, got {packed_input_ids_cp_sharded[0, 20:]}", + } + + return {"success": True, "error": None} + + +PACK_SEQUENCES_TEST_ACTOR_FQN = ( + f"{PackSequencesTestActor.__module__}.PackSequencesTestActor" +) + + +@pytest.fixture +def register_pack_sequences_test_actor(): + """Register the PackSequencesTestActor for use in tests.""" + original_registry_value = ACTOR_ENVIRONMENT_REGISTRY.get( + PACK_SEQUENCES_TEST_ACTOR_FQN + ) + ACTOR_ENVIRONMENT_REGISTRY[PACK_SEQUENCES_TEST_ACTOR_FQN] = PY_EXECUTABLES.MCORE + + yield PACK_SEQUENCES_TEST_ACTOR_FQN + + # Clean up registry + if PACK_SEQUENCES_TEST_ACTOR_FQN in ACTOR_ENVIRONMENT_REGISTRY: + if original_registry_value is None: + del ACTOR_ENVIRONMENT_REGISTRY[PACK_SEQUENCES_TEST_ACTOR_FQN] + else: + ACTOR_ENVIRONMENT_REGISTRY[PACK_SEQUENCES_TEST_ACTOR_FQN] = ( + original_registry_value + ) + + +@pytest.fixture +def pack_sequences_setup(request): + """Setup and teardown for pack sequences tests - creates a virtual cluster and reusable actor.""" + # Get parameters from request + if hasattr(request, "param") and request.param is not None: + cp_size = request.param + else: + cp_size = 1 + + cluster = None + worker_group = None + + try: + # Skip if not enough GPUs + if not torch.cuda.is_available() or torch.cuda.device_count() < cp_size: + pytest.skip( + f"Not enough GPUs available. 
Need {cp_size}, got {torch.cuda.device_count()}" + ) + + cluster_name = f"test-pack-sequences-cp{cp_size}" + print(f"Creating virtual cluster '{cluster_name}' for {cp_size} GPUs...") + + cluster = RayVirtualCluster( + name=cluster_name, + bundle_ct_per_node_list=[cp_size], + use_gpus=True, + max_colocated_worker_groups=1, + ) + + actor_fqn = PACK_SEQUENCES_TEST_ACTOR_FQN + + # Register the actor + original_registry_value = ACTOR_ENVIRONMENT_REGISTRY.get(actor_fqn) + ACTOR_ENVIRONMENT_REGISTRY[actor_fqn] = PY_EXECUTABLES.MCORE + + try: + # For CP tests + sharding = NamedSharding(layout=list(range(cp_size)), names=["cp"]) + builder = RayWorkerBuilder(actor_fqn, cp_size) + + worker_group = RayWorkerGroup( + cluster=cluster, + remote_worker_builder=builder, + workers_per_node=None, + sharding_annotations=sharding, + ) + + yield worker_group + + finally: + # Clean up registry + if actor_fqn in ACTOR_ENVIRONMENT_REGISTRY: + if original_registry_value is None: + del ACTOR_ENVIRONMENT_REGISTRY[actor_fqn] + else: + ACTOR_ENVIRONMENT_REGISTRY[actor_fqn] = original_registry_value + + finally: + print("Cleaning up pack sequences test resources...") + if worker_group: + worker_group.shutdown(force=True) + if cluster: + cluster.shutdown() + + +@pytest.mark.parametrize("pack_sequences_setup", [1], indirect=True, ids=["cp1"]) +def test_pack_sequences_comprehensive(pack_sequences_setup): + """Comprehensive test of pack sequences functionality without context parallelism.""" + worker_group = pack_sequences_setup + + # Run all tests in a single call to the actor + futures = worker_group.run_all_workers_single_data("run_all_pack_sequences_tests") + results = ray.get(futures) + + # Check that all workers succeeded + for i, result in enumerate(results): + assert result["success"], f"Worker {i} failed: {result['error']}" + + # Print detailed results for debugging + if "detailed_results" in result: + detailed = result["detailed_results"] + print(f"Worker {i} detailed results:") + for 
test_name, test_result in detailed.items(): + status = "PASSED" if test_result["success"] else "FAILED" + print(f" {test_name}: {status}") + if not test_result["success"]: + print(f" Error: {test_result['error']}") + + +@pytest.mark.parametrize("pack_sequences_setup", [2], indirect=True, ids=["cp2"]) +def test_pack_sequences_with_context_parallel(pack_sequences_setup): + """Test pack sequences functionality with context parallelism.""" + worker_group = pack_sequences_setup + + # Run all tests including CP tests + futures = worker_group.run_all_workers_single_data("run_all_pack_sequences_tests") + results = ray.get(futures) + + # Check that all workers succeeded + for i, result in enumerate(results): + assert result["success"], f"Worker {i} failed: {result['error']}" + + # Print detailed results for debugging + if "detailed_results" in result: + detailed = result["detailed_results"] + print(f"Worker {i} detailed results:") + for test_name, test_result in detailed.items(): + if "skipped" in test_result: + print(f" {test_name}: SKIPPED ({test_result['skipped']})") + else: + status = "PASSED" if test_result["success"] else "FAILED" + print(f" {test_name}: {status}") + if not test_result["success"]: + print(f" Error: {test_result['error']}") diff --git a/tests/unit/models/policy/test_dtensor_worker.py b/tests/unit/models/policy/test_dtensor_worker.py index fcd0977117..c176082698 100644 --- a/tests/unit/models/policy/test_dtensor_worker.py +++ b/tests/unit/models/policy/test_dtensor_worker.py @@ -84,6 +84,9 @@ def create_test_config( "logprob_mb_tokens": 128, "sequence_length_round": 4, }, + "sequence_packing": { + "enabled": False, + }, "optimizer": { "name": "torch.optim.AdamW", "kwargs": { diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py index ea1c70f9b3..a399bca0d5 100644 --- a/tests/unit/models/policy/test_megatron_worker.py +++ b/tests/unit/models/policy/test_megatron_worker.py @@ -74,6 +74,9 @@ def 
create_megatron_test_config( "dynamic_batching": { "enabled": False, # Start with simple batching }, + "sequence_packing": { + "enabled": False, # Start with simple batching + }, "megatron_cfg": { "enabled": True, "empty_unused_memory_level": 0, @@ -1318,3 +1321,413 @@ def test_megatron_sft_training(): finally: policy.shutdown() cluster.shutdown() + + +@pytest.mark.timeout(300) +def test_megatron_context_parallel_logprob_agreement(): + """Test that CP and non-CP models produce identical logprobs with sequence packing enabled.""" + num_gpus = 2 + batch_size = 4 + seq_len = 64 + vocab_size = 32000 + + # Create test data with varying sequence lengths to test sequence packing + torch.manual_seed(42) # Fixed seed for reproducibility + input_ids = torch.arange(seq_len * batch_size, device="cuda").reshape( + batch_size, seq_len + ) + # Create varied sequence lengths for more realistic sequence packing test + input_lengths = torch.tensor([31, 21, 29, 56], dtype=torch.int32) + attention_mask = torch.zeros(batch_size, seq_len) + for i, length in enumerate(input_lengths): + attention_mask[i, :length] = 1 + + data = BatchedDataDict( + { + "input_ids": input_ids, + "input_lengths": input_lengths, + "attention_mask": attention_mask, + } + ) + + # Test 1: Non-CP model (context_parallel_size=1) with sequence packing + print( + "=== Testing Non-CP model (context_parallel_size=1) with sequence packing ===" + ) + cluster_no_cp = RayVirtualCluster( + name="test-no-cp-packing", + bundle_ct_per_node_list=[num_gpus], + use_gpus=True, + num_gpus_per_node=num_gpus, + max_colocated_worker_groups=1, + ) + + config_no_cp = create_megatron_test_config(tp=1, pp=1, precision="bfloat16") + # Ensure context parallel is disabled + config_no_cp["megatron_cfg"]["context_parallel_size"] = 1 + + # Enable sequence packing + config_no_cp["sequence_packing"] = { + "enabled": True, + "train_mb_tokens": seq_len, + "logprob_mb_tokens": seq_len, + "algorithm": "modified_first_fit_decreasing", + } + + 
tokenizer = get_tokenizer(config_no_cp["tokenizer"]) + config_no_cp["generation"] = configure_generation_config( + config_no_cp["generation"], tokenizer + ) + + policy_no_cp = Policy( + cluster=cluster_no_cp, + config=config_no_cp, + tokenizer=tokenizer, + init_reference_model=False, + ) + + # Get logprobs from non-CP model with sequence packing + policy_no_cp.prepare_for_lp_inference() + logprobs_no_cp = policy_no_cp.get_logprobs(data)["logprobs"] + logprobs_no_cp = logprobs_no_cp * attention_mask + print(f"Non-CP logprobs shape: {logprobs_no_cp.shape}") + print(f"Non-CP logprobs sample: {logprobs_no_cp[0, :5]}") + + # Cleanup non-CP resources + policy_no_cp.shutdown() + + config_no_cp_no_packing = config_no_cp.copy() + config_no_cp_no_packing["sequence_packing"] = { + "enabled": False, + } + policy_no_cp_no_packing = Policy( + cluster=cluster_no_cp, + config=config_no_cp_no_packing, + tokenizer=tokenizer, + init_reference_model=False, + ) + # Get logprobs from non-CP model with sequence packing + policy_no_cp_no_packing.prepare_for_lp_inference() + logprobs_no_cp_no_packing = policy_no_cp_no_packing.get_logprobs(data)["logprobs"] + logprobs_no_cp_no_packing = logprobs_no_cp_no_packing * attention_mask + print(f"Non-CP logprobs no packing shape: {logprobs_no_cp_no_packing.shape}") + print(f"Non-CP logprobs no packing sample: {logprobs_no_cp_no_packing[0, :5]}") + + cluster_no_cp.shutdown() + + # Verify logprobs match between CP and non-CP models with sequence packing + print("=== Comparing logprobs ===") + + # Check shapes match + print(f"diff packing {logprobs_no_cp - logprobs_no_cp_no_packing}") + assert logprobs_no_cp.shape == logprobs_no_cp_no_packing.shape, ( + f"Logprob shapes should match: {logprobs_no_cp.shape} vs {logprobs_no_cp_no_packing.shape}" + ) + ( + torch.testing.assert_close( + logprobs_no_cp, logprobs_no_cp_no_packing, rtol=1e-3, atol=1e-3 + ), + ( + "Logprobs should match between non-CP and non-CP models with sequence packing" + ), + ) + + # 
Test 2: CP model (context_parallel_size=2) with sequence packing + print("=== Testing CP model (context_parallel_size=2) with sequence packing ===") + cluster_cp = RayVirtualCluster( + name="test-cp-packing", + bundle_ct_per_node_list=[num_gpus], + use_gpus=True, + num_gpus_per_node=num_gpus, + max_colocated_worker_groups=1, + ) + + config_cp = create_megatron_test_config(tp=1, pp=1, precision="bfloat16") + # Enable context parallel + config_cp["megatron_cfg"]["context_parallel_size"] = 2 + + # Enable sequence packing + config_cp["sequence_packing"] = { + "enabled": True, + "train_mb_tokens": seq_len, + "logprob_mb_tokens": seq_len, + "algorithm": "modified_first_fit_decreasing", + } + + config_cp["generation"] = configure_generation_config( + config_cp["generation"], tokenizer + ) + + policy_cp = Policy( + cluster=cluster_cp, + config=config_cp, + tokenizer=tokenizer, + init_reference_model=False, + ) + + # Get logprobs from CP model with sequence packing + policy_cp.prepare_for_lp_inference() + logprobs_cp = policy_cp.get_logprobs(data)["logprobs"] + print(f"CP logprobs shape: {logprobs_cp.shape}") + print(f"CP logprobs sample: {logprobs_cp[0, :5]}") + + # Cleanup CP resources + policy_cp.shutdown() + cluster_cp.shutdown() + + # Verify logprobs match between CP and non-CP models with sequence packing + print("=== Comparing logprobs ===") + + # Check shapes match + assert logprobs_no_cp.shape == logprobs_cp.shape, ( + f"Logprob shapes should match: {logprobs_no_cp.shape} vs {logprobs_cp.shape}" + ) + + # Check that neither contains NaN or Inf + assert not torch.isnan(logprobs_no_cp).any(), ( + "Non-CP logprobs should not contain NaN" + ) + assert not torch.isinf(logprobs_no_cp).any(), ( + "Non-CP logprobs should not contain Inf" + ) + assert not torch.isnan(logprobs_cp).any(), "CP logprobs should not contain NaN" + assert not torch.isinf(logprobs_cp).any(), "CP logprobs should not contain Inf" + + # Check that first token logprobs are zero (by convention) + assert 
torch.all(logprobs_no_cp[:, 0] == 0), ( + "First token logprobs should be zero (non-CP)" + ) + assert torch.all(logprobs_cp[:, 0] == 0), "First token logprobs should be zero (CP)" + + # Compare logprobs with tight tolerance + logprobs_cp = logprobs_cp * attention_mask + print(f"diff {logprobs_no_cp_no_packing - logprobs_cp}") + max_diff = torch.max(torch.abs(logprobs_no_cp_no_packing - logprobs_cp)).item() + mean_diff = torch.mean(torch.abs(logprobs_no_cp_no_packing - logprobs_cp)).item() + print(f"Max difference: {max_diff}") + print(f"Mean difference: {mean_diff}") + + # Assert logprobs are identical (or very close due to floating point) + torch.testing.assert_close( + logprobs_no_cp_no_packing, + logprobs_cp, + rtol=1e-3, + atol=1e-2, + msg="CP and non-CP models should produce identical logprobs with sequence packing", + ) + + print( + "✓ SUCCESS: CP and non-CP models produce identical logprobs with sequence packing" + ) + + +@pytest.mark.timeout(300) +def test_megatron_context_parallel_training_agreement(): + """Test that CP and non-CP models produce consistent training results with ClippedPG loss and sequence packing.""" + num_gpus = 2 + batch_size = 2 + seq_len = 64 + vocab_size = 32000 + + # Create test data with varying sequence lengths to test sequence packing + torch.manual_seed(42) # Fixed seed for reproducibility + input_ids = torch.arange(seq_len * batch_size, device="cuda").reshape( + batch_size, seq_len + ) + + # Create varied sequence lengths for more realistic sequence packing test + input_lengths = torch.tensor([33, 48], dtype=torch.int32) + attention_mask = torch.zeros(batch_size, seq_len) + for i, length in enumerate(input_lengths): + attention_mask[i, :length] = 1 + + # Create additional data required for ClippedPG loss + token_mask = torch.zeros(batch_size, seq_len) + sample_mask = torch.ones(batch_size) + advantages = torch.randn(batch_size, seq_len) + prev_logprobs = torch.randn(batch_size, seq_len) + generation_logprobs = 
prev_logprobs.clone() + reference_policy_logprobs = prev_logprobs.clone() + labels = torch.randint(0, vocab_size, (batch_size, seq_len)) + + for i in range(batch_size): + token_mask[i, : input_lengths[i]] = 1 + + base_data = BatchedDataDict( + { + "input_ids": input_ids, + "input_lengths": input_lengths, + "attention_mask": attention_mask, + "token_mask": token_mask, + "sample_mask": sample_mask, + "advantages": advantages, + "prev_logprobs": prev_logprobs, + "generation_logprobs": generation_logprobs, + "reference_policy_logprobs": reference_policy_logprobs, + "labels": labels, + } + ) + + # Test 1: Non-CP model (context_parallel_size=1) with sequence packing + print( + "=== Testing Non-CP model (context_parallel_size=1) with sequence packing ===" + ) + cluster_no_cp = RayVirtualCluster( + name="test-no-cp-training", + bundle_ct_per_node_list=[1], + use_gpus=True, + num_gpus_per_node=1, + max_colocated_worker_groups=1, + ) + + config_no_cp = create_megatron_test_config(tp=1, pp=1, precision="bfloat16") + # Ensure context parallel is disabled + config_no_cp["megatron_cfg"]["context_parallel_size"] = 1 + config_no_cp["train_global_batch_size"] = 2 + + # Enable sequence packing + config_no_cp["sequence_packing"] = { + "enabled": True, + "train_mb_tokens": seq_len, + "logprob_mb_tokens": seq_len, + "algorithm": "modified_first_fit_decreasing", + } + + tokenizer = get_tokenizer(config_no_cp["tokenizer"]) + config_no_cp["generation"] = configure_generation_config( + config_no_cp["generation"], tokenizer + ) + + policy_no_cp = Policy( + cluster=cluster_no_cp, + config=config_no_cp, + tokenizer=tokenizer, + init_reference_model=False, + ) + + # Create ClippedPG loss function + loss_fn = ClippedPGLossFn( + { + "ratio_clip_min": 0.2, + "ratio_clip_max": 0.2, + "ratio_clip_c": None, + "reference_policy_kl_penalty": 0.1, + "disable_ppo_ratio": False, + "use_on_policy_kl_approximation": False, + "use_importance_sampling_correction": False, + "token_level_loss": True, + } + ) + 
+ # Train non-CP model + policy_no_cp.prepare_for_training() + no_cp_results = policy_no_cp.train(base_data, loss_fn) + no_cp_loss = no_cp_results["loss"] + no_cp_metrics = no_cp_results["all_mb_metrics"] + + print(f"Non-CP training loss: {no_cp_loss}") + print(f"Non-CP metrics: {no_cp_metrics}") + + # Cleanup non-CP resources + policy_no_cp.shutdown() + cluster_no_cp.shutdown() + + # Test 2: CP model (context_parallel_size=2) with sequence packing + print("=== Testing CP model (context_parallel_size=2) with sequence packing ===") + cluster_cp = RayVirtualCluster( + name="test-cp-training", + bundle_ct_per_node_list=[num_gpus], + use_gpus=True, + num_gpus_per_node=num_gpus, + max_colocated_worker_groups=1, + ) + + config_cp = create_megatron_test_config(tp=1, pp=1, precision="bfloat16") + # Enable context parallel + config_cp["megatron_cfg"]["context_parallel_size"] = 2 + config_cp["train_global_batch_size"] = 2 + + # Enable sequence packing + config_cp["sequence_packing"] = { + "enabled": True, + "train_mb_tokens": seq_len, + "logprob_mb_tokens": seq_len, + "algorithm": "modified_first_fit_decreasing", + } + + config_cp["generation"] = configure_generation_config( + config_cp["generation"], tokenizer + ) + + policy_cp = Policy( + cluster=cluster_cp, + config=config_cp, + tokenizer=tokenizer, + init_reference_model=False, + ) + + # Train CP model + policy_cp.prepare_for_training() + cp_results = policy_cp.train(base_data, loss_fn) + cp_loss = cp_results["loss"] + cp_metrics = cp_results["all_mb_metrics"] + + print(f"CP training loss: {cp_loss}") + print(f"CP metrics: {cp_metrics}") + + # Cleanup CP resources + policy_cp.shutdown() + cluster_cp.shutdown() + + # Compare training results + print("=== Comparing training results ===") + + # Check that neither contains NaN or Inf + assert not torch.isnan(no_cp_loss).any(), "Non-CP loss should not contain NaN" + assert not torch.isinf(no_cp_loss).any(), "Non-CP loss should not contain Inf" + assert not 
torch.isnan(cp_loss).any(), "CP loss should not contain NaN" + assert not torch.isinf(cp_loss).any(), "CP loss should not contain Inf" + + # Check shapes match + assert no_cp_loss.shape == cp_loss.shape, ( + f"Loss shapes should match: {no_cp_loss.shape} vs {cp_loss.shape}" + ) + + # Compare loss values with tolerance + loss_diff = torch.abs(no_cp_loss - cp_loss) + max_loss_diff = torch.max(loss_diff).item() + mean_loss_diff = torch.mean(loss_diff).item() + + print(f"Loss difference - Max: {max_loss_diff:.6f}, Mean: {mean_loss_diff:.6f}") + + # Check key metrics are similar + key_metrics = ["probs_ratio", "grad_norm", "kl_penalty", "approx_entropy"] + for metric in key_metrics: + if metric in no_cp_metrics and metric in cp_metrics: + no_cp_val = no_cp_metrics[metric] + cp_val = cp_metrics[metric] + if metric == "grad_norm": + diff = abs(sum(no_cp_val) - sum(cp_val) * 2) + else: + diff = abs(sum(no_cp_val) - sum(cp_val)) + print( + f"Metric {metric}: Non-CP={sum(no_cp_val):.6f}, CP={sum(cp_val):.6f}, Diff={diff:.6f}" + ) + + # Allow some tolerance for floating point differences + assert diff < 0.01 * sum(no_cp_val) or diff < 1e-4, ( + f"Metric {metric} differs too much: {diff:.6f}" + ) + + # Assert losses are very close (accounting for minor floating point differences) + torch.testing.assert_close( + no_cp_loss, + cp_loss, + rtol=1e-2, + atol=1e-2, + msg="CP and non-CP models should produce very similar training losses with sequence packing", + ) + + print( + "✓ SUCCESS: CP and non-CP models produce consistent training results with ClippedPG loss and sequence packing" + ) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 31f4b16321..11515ec661 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -30,6 +30,7 @@ def __call__( global_valid_toks: torch.Tensor | None, vocab_parallel_rank: Optional[int] = None, vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: 
Optional[torch.distributed.ProcessGroup] = None, ) -> tuple[torch.Tensor, dict[str, Any]]: # Just return mean of logprobs as the loss for testing loss = next_token_logits.mean() @@ -53,6 +54,7 @@ def __call__( global_valid_toks: torch.Tensor | None, vocab_parallel_rank: Optional[int] = None, vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> tuple[torch.Tensor, dict[str, Any]]: # logits shape: [batch_size, seq_len, vocab_size] # Get the next token logits for each position diff --git a/tests/unit/utils/test_native_checkpoint.py b/tests/unit/utils/test_native_checkpoint.py index feca16365d..eb7c7a19f0 100755 --- a/tests/unit/utils/test_native_checkpoint.py +++ b/tests/unit/utils/test_native_checkpoint.py @@ -65,6 +65,9 @@ "dynamic_batching": { "enabled": False, }, + "sequence_packing": { + "enabled": False, + }, "max_grad_norm": 1.0, "generation": { "backend": "vllm",