diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index 1078da5fa3..1bf472d830 100644 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -122,6 +122,7 @@ def __call__( prev_logprobs = data["prev_logprobs"][:, 1:] generation_logprobs = data["generation_logprobs"][:, 1:] reference_policy_logprobs = data["reference_policy_logprobs"][:, 1:] + seq_index = data.get("seq_index", None) mask = token_mask * sample_mask.unsqueeze(-1) @@ -146,12 +147,12 @@ def __call__( data["input_ids"], vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], - group=vocab_parallel_group, + tp_group=vocab_parallel_group, inference_only=False, ) elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): curr_logprobs = get_logprobs_from_vocab_parallel_logits( - next_token_logits, data["input_ids"] + next_token_logits, data["input_ids"], seq_index=seq_index ) else: next_token_logits_wo_last = next_token_logits[ @@ -332,7 +333,7 @@ def __call__( data["input_ids"], vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], - group=vocab_parallel_group, + tp_group=vocab_parallel_group, inference_only=False, ) elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): @@ -480,7 +481,7 @@ def _preference_loss( data["input_ids"], vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], - group=vocab_parallel_group, + tp_group=vocab_parallel_group, inference_only=False, ) elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index f1e2e6ac81..31ac71cc23 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional import torch +from torch.distributed.tensor import DTensor, distribute_tensor @torch.no_grad() @@ -121,11 +122,12 @@ def backward( def from_parallel_logits_to_logprobs( vocab_parallel_logits: torch.Tensor, - target: torch.Tensor, + target: torch.Tensor | DTensor, vocab_start_index: int, vocab_end_index: int, - group: torch.distributed.ProcessGroup, + tp_group: torch.distributed.ProcessGroup, inference_only: bool = False, + seq_index: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Get log probabilities from TP sharded vocab logits. @@ -136,8 +138,10 @@ def from_parallel_logits_to_logprobs( NOTE: Must be the unmodified targets as this function will shift them internally. vocab_start_index (int): Starting vocabulary index for this worker's partition. vocab_end_index (int): Ending vocabulary index for this worker's partition. - group (torch.distributed.ProcessGroup): Process group for distributed communication. + tp_group (torch.distributed.ProcessGroup): Process group for distributed communication. inference_only (bool, optional): If True, tensors won't be saved for backward pass. Defaults to False. + seq_index (Optional[torch.Tensor]): Sequence index tensor with shape [seq_len]. + It is only provided for cp sharded logits. It represents how tensor is sharded across the sequence dimension. Returns: torch.Tensor: Log probabilities tensor with shape [batch_size, seq_len-1]. @@ -145,13 +149,42 @@ def from_parallel_logits_to_logprobs( Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L354 """ - target = target.roll(shifts=-1, dims=-1) + cp_size = 1 + + if isinstance(target, DTensor) and "cp" in target.device_mesh.mesh_dim_names: + cp_dim_index = target.device_mesh.mesh_dim_names.index("cp") + cp_size = target.device_mesh.shape[cp_dim_index] + + if cp_size > 1: + assert seq_index is not None, "seq_index must be provided for cp sharded logits" + target_shape = torch.Size(target.shape) + cp_mesh = target.device_mesh + cp_placements = target.placements + _, sorted_indices = torch.sort(seq_index) + # Recover the original order of the target + target = target.full_tensor()[:, sorted_indices] + target = target.roll(shifts=-1, dims=-1)[:, seq_index] + + # Reshard + target = distribute_tensor(target, cp_mesh, cp_placements) + target = target.to_local() + else: + target = target.roll(shifts=-1, dims=-1) + probs: torch.Tensor = DistributedLogprob.apply( # type: ignore vocab_parallel_logits, target, vocab_start_index, vocab_end_index, - group, + tp_group, inference_only, ).contiguous() + + if cp_size > 1: + # probs is sharded on the sequence dimension. + # Get full sequence tensor, vocab dim has been reduced already. + probs_dtensor = DTensor.from_local(probs, cp_mesh, cp_placements) + probs = probs_dtensor.full_tensor()[:, sorted_indices] + assert probs.shape == target_shape + return probs[:, :-1] diff --git a/nemo_rl/models/dtensor/parallelize.py b/nemo_rl/models/dtensor/parallelize.py index fb9c720c20..370624a163 100644 --- a/nemo_rl/models/dtensor/parallelize.py +++ b/nemo_rl/models/dtensor/parallelize.py @@ -616,7 +616,9 @@ def get_grad_norm( def get_logprobs_from_vocab_parallel_logits( - vocab_parallel_logits: DTensor, input_ids: torch.Tensor + vocab_parallel_logits: DTensor, + input_ids: torch.Tensor | DTensor, + seq_index: Optional[torch.Tensor] = None, ): """Computes log probabilities from vocabulary-parallel logits. @@ -632,16 +634,26 @@ def get_logprobs_from_vocab_parallel_logits( Returns: torch.Tensor: Log probabilities for the given input IDs. """ - tp_mesh = vocab_parallel_logits.device_mesh - tp_rank: int = tp_mesh.get_local_rank() + device_mesh = vocab_parallel_logits.device_mesh + if seq_index is not None: + assert "cp" in device_mesh.mesh_dim_names, ( + "seq_index must be provided for cp sharded logits" + ) + + tp_size = 1 + + tp_group = device_mesh.get_group("tp") + tp_rank = tp_group.rank() + tp_size = tp_group.size() - vocab_interval_per_rank = vocab_parallel_logits.shape[-1] // tp_mesh.size() + vocab_interval_per_rank = vocab_parallel_logits.shape[-1] // tp_size return from_parallel_logits_to_logprobs( vocab_parallel_logits.to_local(), input_ids, vocab_interval_per_rank * tp_rank, (tp_rank + 1) * vocab_interval_per_rank, - tp_mesh.get_group(), + tp_group, inference_only=not torch.is_grad_enabled(), + seq_index=seq_index, ) diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 6872250d10..b94e95fd72 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -574,7 +574,8 @@ def train( .full_tensor() .squeeze(0) ) - _, sorted_indices = torch.sort(seq_index_dtensor) + + mb["seq_index"] = seq_index_dtensor for tensor_name in mb: current_tensor = mb[tensor_name] @@ -587,18 +588,28 @@ def train( current_tensor, device_mesh=self.cp_mesh, placements=[Shard(sequence_dim)], - ).full_tensor()[:, sorted_indices] + ) break if isinstance(logits, DTensor): - logits = logits.full_tensor() - - logits_dtensor = DTensor.from_local( - logits, - device_mesh=self.cp_mesh, - placements=[Shard(sequence_dim)], - ) - logits = logits_dtensor.full_tensor()[:, sorted_indices] + # Must be tp sharded + assert ( + logits.device_mesh.ndim == 1 + and logits.device_mesh.mesh_dim_names[0] == "tp" + ), "logits must be tp sharded" + + # CP is implicitly sharded on the seq dim, so we need to redistribute to the tp dim + logits = DTensor.from_local( + logits.to_local(), + device_mesh=self.device_mesh[("cp", "tp")], + placements=[Shard(sequence_dim), Shard(-1)], + ) + else: + logits = DTensor.from_local( + logits, + device_mesh=self.device_mesh[("cp", "tp")], + placements=[Shard(sequence_dim), Shard(-1)], + ) loss, loss_metrics = loss_fn( logits, mb, global_valid_seqs, global_valid_toks diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 9113723af0..415da80a40 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -942,7 +942,7 @@ def collection_fn(output_tensor): target=input_ids, vocab_start_index=tp_rank * output_tensor.shape[-1], vocab_end_index=(tp_rank + 1) * output_tensor.shape[-1], - group=tp_grp, + tp_group=tp_grp, inference_only=True, )