9 changes: 5 additions & 4 deletions nemo_rl/algorithms/loss_functions.py
@@ -122,6 +122,7 @@ def __call__(
         prev_logprobs = data["prev_logprobs"][:, 1:]
         generation_logprobs = data["generation_logprobs"][:, 1:]
         reference_policy_logprobs = data["reference_policy_logprobs"][:, 1:]
+        seq_index = data.get("seq_index", None)
 
         mask = token_mask * sample_mask.unsqueeze(-1)
 
@@ -146,12 +147,12 @@ def __call__(
                 data["input_ids"],
                 vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1],
                 vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1],
-                group=vocab_parallel_group,
+                tp_group=vocab_parallel_group,
                 inference_only=False,
             )
         elif isinstance(next_token_logits, torch.distributed.tensor.DTensor):
             curr_logprobs = get_logprobs_from_vocab_parallel_logits(
-                next_token_logits, data["input_ids"]
+                next_token_logits, data["input_ids"], seq_index=seq_index
             )
         else:
             next_token_logits_wo_last = next_token_logits[
@@ -332,7 +333,7 @@ def __call__(
                 data["input_ids"],
                 vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1],
                 vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1],
-                group=vocab_parallel_group,
+                tp_group=vocab_parallel_group,
                 inference_only=False,
             )
         elif isinstance(next_token_logits, torch.distributed.tensor.DTensor):
@@ -480,7 +481,7 @@ def _preference_loss(
                 data["input_ids"],
                 vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1],
                 vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1],
-                group=vocab_parallel_group,
+                tp_group=vocab_parallel_group,
                 inference_only=False,
             )
         elif isinstance(next_token_logits, torch.distributed.tensor.DTensor):
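Side note on the call sites above: each tensor-parallel rank holds a contiguous slice of the vocabulary, which is why `vocab_start_index` / `vocab_end_index` are derived from the rank and the local logit width. A minimal standalone sketch of that interval math (illustrative code, not part of this diff; all names are hypothetical):

```python
import torch

def vocab_interval(tp_rank: int, local_vocab_size: int) -> tuple[int, int]:
    """The [start, end) vocabulary slice owned by a given tensor-parallel rank."""
    return tp_rank * local_vocab_size, (tp_rank + 1) * local_vocab_size

if __name__ == "__main__":
    full_vocab, tp_size = 16, 4
    local = full_vocab // tp_size          # each rank holds 4 logit columns
    target = torch.tensor([0, 5, 11, 15])  # token ids in [0, full_vocab)
    for rank in range(tp_size):
        start, end = vocab_interval(rank, local)
        owns = (target >= start) & (target < end)
        print(f"rank {rank}: interval [{start}, {end}), owns {owns.tolist()}")
```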
45 changes: 39 additions & 6 deletions nemo_rl/distributed/model_utils.py
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any
+from typing import Any, Optional
 
 import torch
+from torch.distributed.tensor import DTensor, distribute_tensor
 
 
 @torch.no_grad()
@@ -121,11 +122,12 @@ def backward(
 
 def from_parallel_logits_to_logprobs(
     vocab_parallel_logits: torch.Tensor,
-    target: torch.Tensor,
+    target: torch.Tensor | DTensor,
     vocab_start_index: int,
     vocab_end_index: int,
-    group: torch.distributed.ProcessGroup,
+    tp_group: torch.distributed.ProcessGroup,
     inference_only: bool = False,
+    seq_index: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """Get log probabilities from TP sharded vocab logits.
 
@@ -136,22 +138,53 @@ def from_parallel_logits_to_logprobs(
             NOTE: Must be the unmodified targets as this function will shift them internally.
         vocab_start_index (int): Starting vocabulary index for this worker's partition.
         vocab_end_index (int): Ending vocabulary index for this worker's partition.
-        group (torch.distributed.ProcessGroup): Process group for distributed communication.
+        tp_group (torch.distributed.ProcessGroup): Process group for distributed communication.
         inference_only (bool, optional): If True, tensors won't be saved for backward pass. Defaults to False.
+        seq_index (Optional[torch.Tensor]): Sequence index tensor with shape [seq_len].
+            Only provided for cp-sharded logits; it describes how the tensor is
+            sharded along the sequence dimension.
 
     Returns:
         torch.Tensor: Log probabilities tensor with shape [batch_size, seq_len-1].
             The sequence dimension is reduced by 1 due to the target shifting.
 
     Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L354
     """
-    target = target.roll(shifts=-1, dims=-1)
+    cp_size = 1
+
+    if isinstance(target, DTensor) and "cp" in target.device_mesh.mesh_dim_names:
+        cp_dim_index = target.device_mesh.mesh_dim_names.index("cp")
+        cp_size = target.device_mesh.shape[cp_dim_index]
+
+    if cp_size > 1:
+        assert seq_index is not None, "seq_index must be provided for cp sharded logits"
+        target_shape = torch.Size(target.shape)
+        cp_mesh = target.device_mesh
+        cp_placements = target.placements
+        _, sorted_indices = torch.sort(seq_index)
+        # Recover the original order of the target
+        target = target.full_tensor()[:, sorted_indices]
+        target = target.roll(shifts=-1, dims=-1)[:, seq_index]
+
+        # Reshard
+        target = distribute_tensor(target, cp_mesh, cp_placements)
+        target = target.to_local()
+    else:
+        target = target.roll(shifts=-1, dims=-1)
 
     probs: torch.Tensor = DistributedLogprob.apply(  # type: ignore
         vocab_parallel_logits,
        target,
        vocab_start_index,
        vocab_end_index,
-        group,
+        tp_group,
        inference_only,
    ).contiguous()
 
+    if cp_size > 1:
+        # probs is sharded on the sequence dimension; gather the full sequence
+        # (the vocab dim has already been reduced) and restore the original order.
+        probs_dtensor = DTensor.from_local(probs, cp_mesh, cp_placements)
+        probs = probs_dtensor.full_tensor()[:, sorted_indices]
+        assert probs.shape == target_shape
+
    return probs[:, :-1]
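The cp branch above hinges on `seq_index` being the permutation that produced the context-parallel sequence layout: `torch.sort(seq_index)` yields the inverse permutation, the targets are rolled by -1 in the original order, then re-permuted back into the cp layout. A single-process sketch of just that index math (illustrative, not part of this diff; assumes `seq_index` is a permutation of 0..seq_len-1, no device mesh required):

```python
import torch

def shift_targets_cp(target_permuted: torch.Tensor, seq_index: torch.Tensor) -> torch.Tensor:
    """Roll targets by -1 in *original* order, then return to the cp-permuted order."""
    _, sorted_indices = torch.sort(seq_index)      # inverse permutation
    original = target_permuted[:, sorted_indices]  # undo the cp reordering
    shifted = original.roll(shifts=-1, dims=-1)    # next-token targets
    return shifted[:, seq_index]                   # back to the cp layout

if __name__ == "__main__":
    tokens = torch.arange(8).unsqueeze(0)               # original order 0..7
    seq_index = torch.tensor([0, 1, 6, 7, 2, 3, 4, 5])  # e.g. a zigzag cp layout
    shifted = shift_targets_cp(tokens[:, seq_index], seq_index)
    _, inv = torch.sort(seq_index)
    # Un-permuting the result recovers roll(-1) of the original sequence.
    assert torch.equal(shifted[:, inv], tokens.roll(-1, dims=-1))
    print(shifted.tolist())
```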
22 changes: 17 additions & 5 deletions nemo_rl/models/dtensor/parallelize.py
@@ -616,7 +616,9 @@ def get_grad_norm(
 
 
 def get_logprobs_from_vocab_parallel_logits(
-    vocab_parallel_logits: DTensor, input_ids: torch.Tensor
+    vocab_parallel_logits: DTensor,
+    input_ids: torch.Tensor | DTensor,
+    seq_index: Optional[torch.Tensor] = None,
 ):
     """Computes log probabilities from vocabulary-parallel logits.
 
@@ -632,16 +634,26 @@ def get_logprobs_from_vocab_parallel_logits(
     Returns:
         torch.Tensor: Log probabilities for the given input IDs.
     """
-    tp_mesh = vocab_parallel_logits.device_mesh
-    tp_rank: int = tp_mesh.get_local_rank()
+    device_mesh = vocab_parallel_logits.device_mesh
+    if seq_index is not None:
+        assert "cp" in device_mesh.mesh_dim_names, (
+            "seq_index should only be provided for cp-sharded logits"
+        )
+
+    tp_size = 1
+
+    tp_group = device_mesh.get_group("tp")
+    tp_rank = tp_group.rank()
+    tp_size = tp_group.size()
 
-    vocab_interval_per_rank = vocab_parallel_logits.shape[-1] // tp_mesh.size()
+    vocab_interval_per_rank = vocab_parallel_logits.shape[-1] // tp_size
 
     return from_parallel_logits_to_logprobs(
         vocab_parallel_logits.to_local(),
         input_ids,
         vocab_interval_per_rank * tp_rank,
         (tp_rank + 1) * vocab_interval_per_rank,
-        tp_mesh.get_group(),
+        tp_group,
         inference_only=not torch.is_grad_enabled(),
+        seq_index=seq_index,
     )
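For context, `from_parallel_logits_to_logprobs` receives only this rank's vocab shard plus the [start, end) interval it covers; the target logit is picked on whichever rank owns the token id and the log-sum-exp denominator is reduced across ranks. A minimal single-process emulation of that computation (illustrative only; the real `DistributedLogprob` uses all-reduces over `tp_group` rather than a concat):

```python
import torch

def vocab_parallel_logprobs(shards: list[torch.Tensor], target: torch.Tensor) -> torch.Tensor:
    """Log-prob of `target` under logits column-split across `shards`."""
    local_v = shards[0].shape[-1]
    # Denominator: log-sum-exp over the full vocab (emulated here with a concat;
    # the distributed kernel accumulates it across ranks instead).
    lse = torch.logsumexp(torch.cat(shards, dim=-1), dim=-1)
    picked = torch.zeros_like(lse)
    for rank, shard in enumerate(shards):
        start = rank * local_v
        mask = (target >= start) & (target < start + local_v)
        local_ids = (target - start).clamp(0, local_v - 1).unsqueeze(-1)
        # Only the rank whose interval contains the token contributes its logit.
        picked += shard.gather(-1, local_ids).squeeze(-1) * mask
    return picked - lse

if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(2, 5, 16)          # [batch, seq, vocab]
    target = torch.randint(0, 16, (2, 5))
    shards = list(logits.chunk(4, dim=-1))  # 4 "tp ranks", 4 columns each
    ref = torch.log_softmax(logits, dim=-1).gather(-1, target.unsqueeze(-1)).squeeze(-1)
    assert torch.allclose(vocab_parallel_logprobs(shards, target), ref, atol=1e-5)
    print("ok")
```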
31 changes: 21 additions & 10 deletions nemo_rl/models/policy/dtensor_policy_worker.py
@@ -574,7 +574,8 @@ def train(
                         .full_tensor()
                         .squeeze(0)
                     )
-                    _, sorted_indices = torch.sort(seq_index_dtensor)
+
+                    mb["seq_index"] = seq_index_dtensor
 
                     for tensor_name in mb:
                         current_tensor = mb[tensor_name]
@@ -587,18 +588,28 @@ def train(
                                 current_tensor,
                                 device_mesh=self.cp_mesh,
                                 placements=[Shard(sequence_dim)],
-                            ).full_tensor()[:, sorted_indices]
+                            )
                             break
 
                     if isinstance(logits, DTensor):
-                        logits = logits.full_tensor()
-
-                        logits_dtensor = DTensor.from_local(
-                            logits,
-                            device_mesh=self.cp_mesh,
-                            placements=[Shard(sequence_dim)],
-                        )
-                        logits = logits_dtensor.full_tensor()[:, sorted_indices]
+                        # Must be tp sharded
+                        assert (
+                            logits.device_mesh.ndim == 1
+                            and logits.device_mesh.mesh_dim_names[0] == "tp"
+                        ), "logits must be tp sharded"
+
+                        # The cp sharding on the seq dim is implicit, so redistribute onto the combined (cp, tp) mesh
+                        logits = DTensor.from_local(
+                            logits.to_local(),
+                            device_mesh=self.device_mesh[("cp", "tp")],
+                            placements=[Shard(sequence_dim), Shard(-1)],
+                        )
+                    else:
+                        logits = DTensor.from_local(
+                            logits,
+                            device_mesh=self.device_mesh[("cp", "tp")],
+                            placements=[Shard(sequence_dim), Shard(-1)],
+                        )
 
                     loss, loss_metrics = loss_fn(
                         logits, mb, global_valid_seqs, global_valid_toks
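The rewritten branch above makes the implicit sharding explicit: the local logits shard is re-wrapped as a DTensor on the combined ("cp", "tp") mesh with `Shard(sequence_dim)` and `Shard(-1)` placements, so downstream code sees one global tensor. A runnable single-process sketch of that API (illustrative; a 1x1 gloo mesh is used only to show the call shape, real runs span cp*tp ranks):

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

def main() -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    mesh = init_device_mesh("cpu", (1, 1), mesh_dim_names=("cp", "tp"))

    sequence_dim = 1
    local_logits = torch.randn(2, 8, 16)  # this rank's [batch, seq/cp, vocab/tp] shard
    logits = DTensor.from_local(
        local_logits,
        device_mesh=mesh,
        placements=[Shard(sequence_dim), Shard(-1)],  # seq split by cp, vocab by tp
    )
    print(logits.shape)  # global shape: (2, 8 * cp_size, 16 * tp_size)
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```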
2 changes: 1 addition & 1 deletion nemo_rl/models/policy/megatron_policy_worker.py
@@ -942,7 +942,7 @@ def collection_fn(output_tensor):
                 target=input_ids,
                 vocab_start_index=tp_rank * output_tensor.shape[-1],
                 vocab_end_index=(tp_rank + 1) * output_tensor.shape[-1],
-                group=tp_grp,
+                tp_group=tp_grp,
                 inference_only=True,
             )
