diff --git a/docs/model-quirks.md b/docs/model-quirks.md index 7824e8bf78..dc3d025d22 100644 --- a/docs/model-quirks.md +++ b/docs/model-quirks.md @@ -31,8 +31,10 @@ NeMo-RL uses the vLLM V1 runtime for both synchronous and asynchronous inference ### Context Parallel with FSDP2 -NeMo-RL implemented this feature based on torch CP [implementation](https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/experimental/_attention.py). And we inherit its limitations. + - NeMo-RL implemented this feature based on the torch CP [implementation](https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/experimental/_attention.py), and we inherit its limitations. Whether a model supports CP depends only on the arguments passed to `torch.nn.functional.scaled_dot_product_attention`. NeMo-RL currently passes an all-ones attention mask to `model.forward`. Gemma-3 does not ignore this attention mask, so `attn_bias` ends up being not None, which is not supported by torch CP. Please see this [assertion](https://github.com/pytorch/pytorch/blob/134179474539648ba7dee1317959529fbd0e7f89/torch/distributed/tensor/experimental/_attention.py#L262). + - Context parallel can't be used together with sequence packing. Sequence packing requires `attn_implementation="flash_attention_2"`, which conflicts with context parallel's requirement for the SDPA implementation. Refer to [here](https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/modeling_utils.py#L2317) for more details. 
+ ## vLLM Async Rollout Timeout diff --git a/nemo_rl/distributed/batched_data_dict.py b/nemo_rl/distributed/batched_data_dict.py index 52d478295f..969738d203 100644 --- a/nemo_rl/distributed/batched_data_dict.py +++ b/nemo_rl/distributed/batched_data_dict.py @@ -645,6 +645,10 @@ def slice(self, start: int, end: int) -> "SlicedDataDict": """ sliced_batch = SlicedDataDict() for k in self.data: + if isinstance(self.data[k], torch.Tensor): + assert end <= self.data[k].shape[0], ( + f"end: {end} is greater than the shape of the tensor: {self.data[k].shape[0]} for key: {k}" + ) sliced_batch[k] = self.data[k][start:end] return sliced_batch diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index 98cb312ebd..606fd8464b 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -121,8 +121,8 @@ def backward( def dtensor_from_parallel_logits_to_logprobs( - vocab_parallel_logits: DTensor, - target: DTensor, + vocab_parallel_logits: torch.Tensor, + target: DTensor | torch.Tensor, vocab_start_index: int, vocab_end_index: int, tp_group: torch.distributed.ProcessGroup, @@ -132,7 +132,7 @@ def dtensor_from_parallel_logits_to_logprobs( """Get log probabilities from TP+CP sharded vocab logits. Args: - vocab_parallel_logits (DTensor): Logits distributed across tensor parallel workers, + vocab_parallel_logits (torch.Tensor): Logits distributed across tensor parallel workers, with shape [batch_size, seq_len, vocab_size/tp_size]. target (DTensor): Target token indices with shape [batch_size, seq_len]. NOTE: Must be the unmodified targets as this function will shift them internally. 
@@ -149,7 +149,7 @@ def dtensor_from_parallel_logits_to_logprobs( """ cp_size = 1 - if "cp" in target.device_mesh.mesh_dim_names: + if isinstance(target, DTensor) and "cp" in target.device_mesh.mesh_dim_names: cp_dim_index = target.device_mesh.mesh_dim_names.index("cp") cp_size = target.device_mesh.shape[cp_dim_index] diff --git a/nemo_rl/models/dtensor/parallelize.py b/nemo_rl/models/dtensor/parallelize.py index f82797929f..fef3d9ba2a 100644 --- a/nemo_rl/models/dtensor/parallelize.py +++ b/nemo_rl/models/dtensor/parallelize.py @@ -621,8 +621,10 @@ def get_logprobs_from_vocab_parallel_logits( Args: vocab_parallel_logits (DTensor): Logits distributed across tensor parallel workers, with shape [batch_size, seq_len, vocab_size/tp_size]. - input_ids (torch.Tensor): Input token IDs for which to compute log probabilities, + input_ids (torch.Tensor | DTensor): Input token IDs for which to compute log probabilities, with shape [batch_size, seq_len]. + seq_index (Optional[torch.Tensor]): Sequence index for the input IDs, + with shape [sequence_length]. Returns: torch.Tensor: Log probabilities for the given input IDs. diff --git a/nemo_rl/models/huggingface/common.py b/nemo_rl/models/huggingface/common.py index cdfab8ef04..c057f6d89a 100644 --- a/nemo_rl/models/huggingface/common.py +++ b/nemo_rl/models/huggingface/common.py @@ -71,7 +71,10 @@ def is_gemma_model(model_name: str) -> bool: def group_and_cat_tensors( - tensors: list[torch.Tensor], group_sizes: list[int], padding_value: int = 0 + tensors: list[torch.Tensor], + group_sizes: list[int], + padding_value: int = 0, + min_seq_len: int = 0, ) -> torch.Tensor: """Groups and concatenates tensors according to group_sizes, then pads them to form a 2D tensor. @@ -82,6 +85,7 @@ def group_and_cat_tensors( tensors: List of 1D tensors of varying lengths. group_sizes: List of integers. Each integer specifies how many tensors to group. padding_value: Integer used to pad shorter sequences. + min_seq_len: Minimum sequence length. 
Returns: A 2D tensor where each row is a padded concatenation of the grouped tensors. @@ -108,6 +112,7 @@ def group_and_cat_tensors( # Compute the maximum length for padding max_len = max(t.size(0) for t in grouped) + max_len = max(max_len, min_seq_len) # Pad each tensor to max_len padded = torch.stack( @@ -126,6 +131,7 @@ def pack_sequences( packed_sequence_size: list[int], padding_value: int = 0, return_attention_mask: bool = True, + min_seq_len: int = 0, ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """Packs sequences into rows where each row concatenates multiple sequences. @@ -138,6 +144,7 @@ def pack_sequences( packed_sequence_size (List[int]): How many sequences to pack per row padding_value (int): Pad value for input_ids return_attention_mask (bool): Whether to return per-row causal attention mask + min_seq_len (int): Minimum sequence length. Returns: Tuple: @@ -197,10 +204,10 @@ def pack_sequences( # Group and pad input_ids_packed = group_and_cat_tensors( - flat_input_ids, packed_sequence_size, padding_value + flat_input_ids, packed_sequence_size, padding_value, min_seq_len=min_seq_len ) position_ids_packed = group_and_cat_tensors( - position_ids, packed_sequence_size, padding_value=0 + position_ids, packed_sequence_size, padding_value=0, min_seq_len=min_seq_len ) # Compute max length diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index b668fc41e8..671a68cd38 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -14,11 +14,11 @@ import contextlib import gc +import itertools import os from collections import defaultdict from contextlib import AbstractContextManager, contextmanager, nullcontext from typing import Any, Generator, Iterable, Optional, Set, Union, cast -import itertools import ray import torch @@ -233,6 +233,10 @@ def __init__( tp_size = self.cfg["dtensor_cfg"]["tensor_parallel_size"] cp_size = 
self.cfg["dtensor_cfg"]["context_parallel_size"] + if cp_size > 1 and self.enable_seq_packing: + raise ValueError( + "Context parallel is not supported for sequence packing. Refer to https://github.com/NVIDIA/NeMo-RL/blob/main/docs/model-quirks.md#context-parallel-with-fsdp2 for more details." + ) dp_size = world_size // tp_size // cp_size assert world_size == dp_size * tp_size * cp_size, ( f"World size({world_size}) must equal to dp_size({dp_size}) * tp_size({tp_size}) * cp_size({cp_size}) to use DTensor" @@ -543,20 +547,30 @@ def train( mb_iterator = ( batch.make_microbatch_iterator_for_packable_sequences() ) - iterator_len, max_seqlen = batch.get_microbatch_iterator_for_packable_sequences_len() - max_batch_ct = torch.tensor([iterator_len], device='cuda') - torch.distributed.all_reduce(max_batch_ct, op=torch.distributed.ReduceOp.MAX) + iterator_len, max_seqlen = ( + batch.get_microbatch_iterator_for_packable_sequences_len() + ) + max_batch_ct = torch.tensor([iterator_len], device="cuda") + torch.distributed.all_reduce( + max_batch_ct, op=torch.distributed.ReduceOp.MAX + ) # Sequence packing can end up with unevenly distributed batch counts across DP ranks. # We add dummy batches to the end of the iterator to make the batch counts equal. 
dummy_batch_ct = int(max_batch_ct.item() - iterator_len) - dummy_iterator = batch.make_microbatch_iterator_for_packable_sequences() - dummy_iterator = itertools.islice(itertools.cycle(dummy_iterator), dummy_batch_ct) + dummy_iterator = ( + batch.make_microbatch_iterator_for_packable_sequences() + ) + dummy_iterator = itertools.islice( + itertools.cycle(dummy_iterator), dummy_batch_ct + ) else: mb_iterator = batch.make_microbatch_iterator(mbs) iterator_len = batch.size // mbs - for mb_idx, mb in enumerate(itertools.chain(mb_iterator, dummy_iterator)): + for mb_idx, mb in enumerate( + itertools.chain(mb_iterator, dummy_iterator) + ): with torch.autocast(device_type="cuda", dtype=self.dtype): if self.enable_seq_packing: input_ids = mb.get("input_ids").cuda() @@ -568,6 +582,9 @@ def train( ], # flash attention 2 expects flattened input padding_value=self.tokenizer.eos_token_id, return_attention_mask=False, + min_seq_len=self.cfg["sequence_packing"][ + "train_mb_tokens" + ], # TODO: this is a WAR for sequence packing, we should fix this. Without this, backward will fail when TP is enabled. ) seq_len = input_ids.shape[1] attention_mask = None @@ -819,21 +836,29 @@ def get_logprobs( iterator_len = data.get_microbatch_iterator_dynamic_shapes_len() elif self.enable_seq_packing: mb_iterator = data.make_microbatch_iterator_for_packable_sequences() - iterator_len, max_seqlen = data.get_microbatch_iterator_for_packable_sequences_len() - max_batch_ct = torch.tensor([iterator_len], device='cuda') - torch.distributed.all_reduce(max_batch_ct, op=torch.distributed.ReduceOp.MAX) + iterator_len, max_seqlen = ( + data.get_microbatch_iterator_for_packable_sequences_len() + ) + max_batch_ct = torch.tensor([iterator_len], device="cuda") + torch.distributed.all_reduce( + max_batch_ct, op=torch.distributed.ReduceOp.MAX + ) # Sequence packing can end up with unevenly distributed batch counts across DP ranks. # We add dummy batches to the end of the iterator to make the batch counts equal. 
dummy_batch_ct = int(max_batch_ct.item() - iterator_len) dummy_iterator = data.make_microbatch_iterator_for_packable_sequences() - dummy_iterator = itertools.islice(itertools.cycle(dummy_iterator), dummy_batch_ct) + dummy_iterator = itertools.islice( + itertools.cycle(dummy_iterator), dummy_batch_ct + ) else: mb_iterator = data.make_microbatch_iterator(logprob_batch_size) iterator_len = data.size // logprob_batch_size step = 0 - for batch_idx, lp_batch in enumerate(itertools.chain(mb_iterator, dummy_iterator)): + for batch_idx, lp_batch in enumerate( + itertools.chain(mb_iterator, dummy_iterator) + ): step += 1 input_ids = lp_batch.get("input_ids").cuda() input_lengths = lp_batch.get("input_lengths") @@ -843,7 +868,9 @@ def get_logprobs( input_ids, position_ids, _ = pack_sequences( input_ids=input_ids, input_lengths=input_lengths, - packed_sequence_size=[batch_size], # flash attention 2 expects flattened input + packed_sequence_size=[ + batch_size + ], # flash attention 2 expects flattened input padding_value=self.tokenizer.eos_token_id, return_attention_mask=False, ) @@ -863,9 +890,9 @@ def get_logprobs( # explicitly create position ids for the input, otherwise the sharding # for DTensor will be incorrect - position_ids = torch.arange(seq_len, device=input_ids.device).repeat( - batch_size, 1 - ) + position_ids = torch.arange( + seq_len, device=input_ids.device + ).repeat(batch_size, 1) flash_attn_kwargs = {} with torch.autocast(device_type="cuda", dtype=self.dtype): @@ -877,35 +904,99 @@ def get_logprobs( (batch_size, seq_len), dtype=torch.long, device=input_ids.device ) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask_input_all_ones, - position_ids=position_ids, - use_cache=False, - flash_attn_kwargs=flash_attn_kwargs, - ) - - if isinstance(outputs.logits, DTensor): - token_logprobs = get_logprobs_from_vocab_parallel_logits( - outputs.logits.to(torch.float32), input_ids + context_parallel_ctx = None + if self.cp_size > 1: + seq_index 
= torch.arange(seq_len, device=input_ids.device).repeat( + 1, 1 ) - else: - # Extract logprobs for each token in the sequence by gathering the logprob - # corresponding to the next token at each position - # Input shapes: - # log_probs: [batch_size, sequence_length, vocab_size] - logits for each position - # token_ids: [batch_size, sequence_length] - actual tokens - # Output shape: [batch_size, sequence_length] - logprob of each token given previous - # We get logprob of token[t+1] from logits[t], prepending 0 to maintain sequence length - - log_probs = torch.nn.functional.log_softmax( - outputs.logits.to(torch.float32), dim=-1 + cp_buffers = [input_ids, position_ids, seq_index] + + # Create context parallel context + context_parallel_ctx = self.create_context_parallel_ctx( + cp_mesh=self.cp_mesh, + cp_buffers=cp_buffers, + cp_seq_dims=[sequence_dim] * len(cp_buffers), + cp_no_restore_buffers=set(cp_buffers), ) - next_tokens = input_ids[:, 1:] - log_probs = log_probs[:, :-1] - token_logprobs = log_probs.gather( - dim=-1, index=next_tokens.unsqueeze(-1) - ).squeeze(-1) + + with DTensorPolicyWorker.train_context(context_parallel_ctx): + with torch.autocast(device_type="cuda", dtype=self.dtype): + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask_input_all_ones, + position_ids=position_ids, + use_cache=False, + flash_attn_kwargs=flash_attn_kwargs, + ) + + logits = outputs.logits + + if self.cp_size > 1: + seq_index_tensor = ( + DTensor.from_local( + seq_index, + device_mesh=self.cp_mesh, + placements=[Shard(1)], + ) + .full_tensor() + .squeeze(0) + ) + + input_ids_dtensor = DTensor.from_local( + input_ids, + device_mesh=self.cp_mesh, + placements=[Shard(sequence_dim)], + ) + + if isinstance(logits, DTensor): + # Must be tp sharded + assert ( + logits.device_mesh.ndim == 1 + and logits.device_mesh.mesh_dim_names[0] == "tp" + ), "logits must be tp sharded" + + # CP is implicitly sharded on the seq dim, so we need to redistribute to the tp dim 
+ logits = DTensor.from_local( + logits.to_local(), + device_mesh=self.device_mesh[("cp", "tp")], + placements=[Shard(sequence_dim), Shard(-1)], + ) + else: + logits = DTensor.from_local( + logits, + device_mesh=self.device_mesh[("cp", "tp")], + placements=[Shard(sequence_dim), Shard(-1)], + ) + + token_logprobs = get_logprobs_from_vocab_parallel_logits( + logits.to(torch.float32), + input_ids_dtensor, + seq_index_tensor, + ) + + assert token_logprobs.shape[1] == seq_len - 1 + else: + if isinstance(logits, DTensor): + token_logprobs = get_logprobs_from_vocab_parallel_logits( + logits.to(torch.float32), input_ids + ) + else: + # Extract logprobs for each token in the sequence by gathering the logprob + # corresponding to the next token at each position + # Input shapes: + # log_probs: [batch_size, sequence_length, vocab_size] - logits for each position + # token_ids: [batch_size, sequence_length] - actual tokens + # Output shape: [batch_size, sequence_length] - logprob of each token given previous + # We get logprob of token[t+1] from logits[t], prepending 0 to maintain sequence length + + log_probs = torch.nn.functional.log_softmax( + outputs.logits.to(torch.float32), dim=-1 + ) + next_tokens = input_ids[:, 1:] + log_probs = log_probs[:, :-1] + token_logprobs = log_probs.gather( + dim=-1, index=next_tokens.unsqueeze(-1) + ).squeeze(-1) token_logprobs = torch.cat( [torch.zeros_like(token_logprobs[:, :1]), token_logprobs], dim=1 @@ -920,13 +1011,19 @@ def get_logprobs( token_logprobs = token_logprobs * attention_mask else: # For packed sequences, unpack logprobs - unpacked_logprobs = torch.zeros((batch_size, seq_dim_size), dtype=token_logprobs.dtype, device=token_logprobs.device) + unpacked_logprobs = torch.zeros( + (batch_size, seq_dim_size), + dtype=token_logprobs.dtype, + device=token_logprobs.device, + ) cu_seqlens = flash_attn_kwargs.cu_seqlens_q for i in range(batch_size): start = cu_seqlens[i].item() + 1 - end = cu_seqlens[i+1].item() + end = cu_seqlens[i + 
1].item() seq_len_actual = input_lengths[i].item() - unpacked_logprobs[i, 1: seq_len_actual] = token_logprobs[0, start:end] + unpacked_logprobs[i, 1:seq_len_actual] = token_logprobs[ + 0, start:end + ] token_logprobs = unpacked_logprobs all_log_probs.append(token_logprobs) diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 9996b213ff..dbbf5ddc1e 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -193,7 +193,6 @@ def get_logprobs( The logprob of input token i is specified at position i in the output logprobs tensor. """ dp_size = self.sharding_annotations.get_axis_size("data_parallel") - cp_size = self.sharding_annotations.get_axis_size("context_parallel") sharded_data: list[SlicedDataDict] unsorted_data_indices: list[int] @@ -202,7 +201,7 @@ def get_logprobs( "dynamic_batching" ]["logprob_mb_tokens"] sharded_data, unsorted_data_indices = data.shard_by_batch_size( # type: ignore - cp_size * dp_size, + dp_size, batch_size=None, dynamic_batching_args=self.dynamic_batching_args, ) @@ -218,24 +217,13 @@ def get_logprobs( ) else: sharded_data = data.shard_by_batch_size( # type: ignore - cp_size * dp_size, + dp_size, batch_size=None, ) - # sharded_data_2d = [] - # shard_idx = 0 - # # Convert to 2d dim array - # for _ in range(dp_size): - # cp_data = [] - # for _ in range(cp_size): - # cp_data.append(sharded_data[shard_idx]) - # shard_idx += 1 - # sharded_data_2d.append(cp_data) - sharded_data_2d = sharded_data - futures = self.worker_group.run_all_workers_sharded_data( "get_logprobs", - data=sharded_data_2d, + data=sharded_data, in_sharded_axes=["data_parallel"], replicate_on_axes=[ "context_parallel", @@ -269,7 +257,6 @@ def get_reference_policy_logprobs( Returns: Identical to get_logprobs. 
""" dp_size = self.sharding_annotations.get_axis_size("data_parallel") - cp_size = self.sharding_annotations.get_axis_size("context_parallel") sharded_data: list[SlicedDataDict] unsorted_data_indices: list[int] if self.use_dynamic_batches: @@ -277,7 +264,7 @@ def get_reference_policy_logprobs( "dynamic_batching" ]["logprob_mb_tokens"] sharded_data, unsorted_data_indices = data.shard_by_batch_size( # type: ignore - cp_size * dp_size, + dp_size, batch_size=None, dynamic_batching_args=self.dynamic_batching_args, ) @@ -292,24 +279,13 @@ def get_reference_policy_logprobs( ) else: sharded_data = data.shard_by_batch_size( # type: ignore - cp_size * dp_size, + dp_size, batch_size=None, ) - sharded_data_2d = sharded_data - # sharded_data_2d = [] - # shard_idx = 0 - # # Convert to 2d dim array - # for _ in range(dp_size): - # cp_data = [] - # for _ in range(cp_size): - # cp_data.append(sharded_data[shard_idx]) - # shard_idx += 1 - # sharded_data_2d.append(cp_data) - futures = self.worker_group.run_all_workers_sharded_data( "get_reference_policy_logprobs", - data=sharded_data_2d, + data=sharded_data, in_sharded_axes=["data_parallel"], replicate_on_axes=[ "context_parallel",