Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/model-quirks.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ NeMo-RL uses the vLLM V1 runtime for both synchronous and asynchronous inference

### Context Parallel with FSDP2

NeMo-RL implemented this feature based on torch CP [implementation](https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/experimental/_attention.py). And we inherit its limitations.
- NeMo-RL implemented this feature based on torch CP [implementation](https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/experimental/_attention.py). And we inherit its limitations.
Whether a model supports CP depends only on the arguments passed to `torch.nn.functional.scaled_dot_product_attention`. Currently, NeMo-RL passes an all-ones attention mask to `model.forward`. Gemma-3 does not ignore the attention mask, so the resulting `attn_bias` is not None, which is not supported by torch CP. Please see this [assertion](https://github.com/pytorch/pytorch/blob/134179474539648ba7dee1317959529fbd0e7f89/torch/distributed/tensor/experimental/_attention.py#L262).
- Context parallel can't be used together with sequence packing. Sequence packing requires `attn_implementation="flash_attention_2"`, which conflicts with context parallel's requirement for the SDPA implementation. Refer to [here](https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/modeling_utils.py#L2317) for more details.


## vLLM Async Rollout Timeout

Expand Down
4 changes: 4 additions & 0 deletions nemo_rl/distributed/batched_data_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,10 @@ def slice(self, start: int, end: int) -> "SlicedDataDict":
"""
sliced_batch = SlicedDataDict()
for k in self.data:
if isinstance(self.data[k], torch.Tensor):
assert end <= self.data[k].shape[0], (
f"end: {end} is greater than the shape of the tensor: {self.data[k].shape[0]} for key: {k}"
)
sliced_batch[k] = self.data[k][start:end]
return sliced_batch

Expand Down
8 changes: 4 additions & 4 deletions nemo_rl/distributed/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ def backward(


def dtensor_from_parallel_logits_to_logprobs(
vocab_parallel_logits: DTensor,
target: DTensor,
vocab_parallel_logits: torch.Tensor,
target: DTensor | torch.Tensor,
vocab_start_index: int,
vocab_end_index: int,
tp_group: torch.distributed.ProcessGroup,
Expand All @@ -132,7 +132,7 @@ def dtensor_from_parallel_logits_to_logprobs(
"""Get log probabilities from TP+CP sharded vocab logits.

Args:
vocab_parallel_logits (DTensor): Logits distributed across tensor parallel workers,
vocab_parallel_logits (orch.Tensor): Logits distributed across tensor parallel workers,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
vocab_parallel_logits (orch.Tensor): Logits distributed across tensor parallel workers,
vocab_parallel_logits (torch.Tensor): Logits distributed across tensor parallel workers,

with shape [batch_size, seq_len, vocab_size/tp_size].
target (DTensor): Target token indices with shape [batch_size, seq_len].
NOTE: Must be the unmodified targets as this function will shift them internally.
Expand All @@ -149,7 +149,7 @@ def dtensor_from_parallel_logits_to_logprobs(
"""
cp_size = 1

if "cp" in target.device_mesh.mesh_dim_names:
if isinstance(target, DTensor) and "cp" in target.device_mesh.mesh_dim_names:
cp_dim_index = target.device_mesh.mesh_dim_names.index("cp")
cp_size = target.device_mesh.shape[cp_dim_index]

Expand Down
4 changes: 3 additions & 1 deletion nemo_rl/models/dtensor/parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,8 +621,10 @@ def get_logprobs_from_vocab_parallel_logits(
Args:
vocab_parallel_logits (DTensor): Logits distributed across tensor parallel workers,
with shape [batch_size, seq_len, vocab_size/tp_size].
input_ids (torch.Tensor): Input token IDs for which to compute log probabilities,
input_ids (torch.Tensor | DTensor): Input token IDs for which to compute log probabilities,
with shape [batch_size, seq_len].
seq_index (Optional[torch.Tensor]): Sequence index for the input IDs,
with shape [sequence_length].

Returns:
torch.Tensor: Log probabilities for the given input IDs.
Expand Down
13 changes: 10 additions & 3 deletions nemo_rl/models/huggingface/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ def is_gemma_model(model_name: str) -> bool:


def group_and_cat_tensors(
tensors: list[torch.Tensor], group_sizes: list[int], padding_value: int = 0
tensors: list[torch.Tensor],
group_sizes: list[int],
padding_value: int = 0,
min_seq_len: int = 0,
) -> torch.Tensor:
"""Groups and concatenates tensors according to group_sizes, then pads them to form a 2D tensor.

Expand All @@ -82,6 +85,7 @@ def group_and_cat_tensors(
tensors: List of 1D tensors of varying lengths.
group_sizes: List of integers. Each integer specifies how many tensors to group.
padding_value: Integer used to pad shorter sequences.
min_seq_len: Minimum sequence length.

Returns:
A 2D tensor where each row is a padded concatenation of the grouped tensors.
Expand All @@ -108,6 +112,7 @@ def group_and_cat_tensors(

# Compute the maximum length for padding
max_len = max(t.size(0) for t in grouped)
max_len = max(max_len, min_seq_len)

# Pad each tensor to max_len
padded = torch.stack(
Expand All @@ -126,6 +131,7 @@ def pack_sequences(
packed_sequence_size: list[int],
padding_value: int = 0,
return_attention_mask: bool = True,
min_seq_len: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Packs sequences into rows where each row concatenates multiple sequences.

Expand All @@ -138,6 +144,7 @@ def pack_sequences(
packed_sequence_size (List[int]): How many sequences to pack per row
padding_value (int): Pad value for input_ids
return_attention_mask (bool): Whether to return per-row causal attention mask
min_seq_len (int): Minimum sequence length.

Returns:
Tuple:
Expand Down Expand Up @@ -197,10 +204,10 @@ def pack_sequences(

# Group and pad
input_ids_packed = group_and_cat_tensors(
flat_input_ids, packed_sequence_size, padding_value
flat_input_ids, packed_sequence_size, padding_value, min_seq_len=min_seq_len
)
position_ids_packed = group_and_cat_tensors(
position_ids, packed_sequence_size, padding_value=0
position_ids, packed_sequence_size, padding_value=0, min_seq_len=min_seq_len
)

# Compute max length
Expand Down
Loading
Loading