diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index 785b6e0d2e..28126b526c 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -47,6 +47,10 @@ policy: weight_decay: 0.1 betas: [0.9, 0.98] eps: 1e-5 + # when using Dtensor, we need to set foreach + # and fused to False + foreach: False + fused: False data: max_input_seq_length: ${policy.max_total_sequence_length} diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index d674e7deb0..ef5a698678 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -112,7 +112,7 @@ def __call__( next_token_logprobs = torch.nn.functional.log_softmax( next_token_logits, dim=-1 ) - next_tokens = data["input_ids"][:, 1:] # Skip first token + next_tokens = data.get("input_ids")[:, 1:].cuda() # Skip first token curr_logprobs = next_token_logprobs.gather( dim=-1, index=next_tokens.unsqueeze(-1) ).squeeze(-1) @@ -168,14 +168,22 @@ def __call__( sample_mask = data["sample_mask"] mask = token_mask * sample_mask.unsqueeze(-1) - next_tokens = data.get("input_ids")[:, 1:].cuda() # Skip first token - next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1) - logprobs = next_token_logprobs[:, :-1] # Remove last position's logits + next_token_logits = next_token_logits.to(torch.float32) # Gather the logprobs for the actual next tokens - token_logprobs = logprobs.gather( - dim=-1, index=next_tokens.unsqueeze(-1) - ).squeeze(-1) + if isinstance(next_token_logits, torch.distributed.tensor.DTensor): + token_logprobs = get_logprobs_from_vocab_parallel_logits( + next_token_logits, data["input_ids"] + ) + else: + next_tokens = data.get("input_ids")[:, 1:].cuda() # Skip first token + next_token_logprobs = torch.nn.functional.log_softmax( + next_token_logits, dim=-1 + ) + logprobs = next_token_logprobs[:, :-1] # Remove last position's logits + token_logprobs = logprobs.gather( + dim=-1, index=next_tokens.unsqueeze(-1) + ).squeeze(-1) # Only compute loss on generated tokens (not input tokens) # by applying the token_loss_mask (shifted by 1 since we're predicting next tokens) diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index 45f4f08575..e6a6b3f418 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -237,6 +237,7 @@ def validate( val_metrics = {"val_loss": 0.0} + policy.prepare_for_training() for batch_idx, val_batch in enumerate(val_dataloader): ## add loss mask based on role to every message add_loss_mask_to_message_log( @@ -247,6 +248,9 @@ def validate( cat_and_padded, input_lengths = batched_message_log_to_flat_message( val_batch["message_log"], pad_value_dict={"token_ids": tokenizer.pad_token_id}, + make_sequence_length_divisible_by=master_config["policy"][ + "make_sequence_length_divisible_by" + ], ) val_data: BatchedDataDict = BatchedDataDict( @@ -358,6 +362,9 @@ def sft_train( cat_and_padded, input_lengths = batched_message_log_to_flat_message( batch["message_log"], pad_value_dict={"token_ids": tokenizer.pad_token_id}, + make_sequence_length_divisible_by=master_config["policy"][ + "make_sequence_length_divisible_by" + ], ) train_data: BatchedDataDict = BatchedDataDict( diff --git a/nemo_reinforcer/models/policy/dtensor_policy_worker.py b/nemo_reinforcer/models/policy/dtensor_policy_worker.py index c967a53c97..a7c7f717fb 100644 --- a/nemo_reinforcer/models/policy/dtensor_policy_worker.py +++ b/nemo_reinforcer/models/policy/dtensor_policy_worker.py @@ -321,6 +321,7 @@ def train( mb_losses.append(loss.item()) all_mb_metrics.append(loss_metrics) + grad_norm = None if not eval_mode: with torch.no_grad(): grad_norm = get_grad_norm( @@ -347,7 +348,7 @@ def train( with torch.no_grad(): local_loss = torch.tensor(losses, device="cuda") global_loss = torch.zeros_like(local_loss) - torch.distributed.all_reduce(local_loss) + torch.distributed.all_reduce(local_loss, group=self.dp_mesh.get_group()) global_loss = local_loss / self.dp_size # Aggregate metrics across all microbatches