Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions examples/configs/sft.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ policy:
weight_decay: 0.1
betas: [0.9, 0.98]
eps: 1e-5
# When using DTensor, the optimizer's foreach
# and fused options must be set to False
foreach: False
fused: False

data:
max_input_seq_length: ${policy.max_total_sequence_length}
Expand Down
22 changes: 15 additions & 7 deletions nemo_reinforcer/algorithms/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __call__(
next_token_logprobs = torch.nn.functional.log_softmax(
next_token_logits, dim=-1
)
next_tokens = data["input_ids"][:, 1:] # Skip first token
next_tokens = data.get("input_ids")[:, 1:].cuda() # Skip first token
curr_logprobs = next_token_logprobs.gather(
dim=-1, index=next_tokens.unsqueeze(-1)
).squeeze(-1)
Expand Down Expand Up @@ -168,14 +168,22 @@ def __call__(
sample_mask = data["sample_mask"]
mask = token_mask * sample_mask.unsqueeze(-1)

next_tokens = data.get("input_ids")[:, 1:].cuda() # Skip first token
next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1)
logprobs = next_token_logprobs[:, :-1] # Remove last position's logits
next_token_logits = next_token_logits.to(torch.float32)

# Gather the logprobs for the actual next tokens
token_logprobs = logprobs.gather(
dim=-1, index=next_tokens.unsqueeze(-1)
).squeeze(-1)
if isinstance(next_token_logits, torch.distributed.tensor.DTensor):
token_logprobs = get_logprobs_from_vocab_parallel_logits(
next_token_logits, data["input_ids"]
)
else:
next_tokens = data.get("input_ids")[:, 1:].cuda() # Skip first token
next_token_logprobs = torch.nn.functional.log_softmax(
next_token_logits, dim=-1
)
logprobs = next_token_logprobs[:, :-1] # Remove last position's logits
token_logprobs = logprobs.gather(
dim=-1, index=next_tokens.unsqueeze(-1)
).squeeze(-1)

# Only compute loss on generated tokens (not input tokens)
# by applying the token_loss_mask (shifted by 1 since we're predicting next tokens)
Expand Down
7 changes: 7 additions & 0 deletions nemo_reinforcer/algorithms/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def validate(

val_metrics = {"val_loss": 0.0}

policy.prepare_for_training()
for batch_idx, val_batch in enumerate(val_dataloader):
## add loss mask based on role to every message
add_loss_mask_to_message_log(
Expand All @@ -247,6 +248,9 @@ def validate(
cat_and_padded, input_lengths = batched_message_log_to_flat_message(
val_batch["message_log"],
pad_value_dict={"token_ids": tokenizer.pad_token_id},
make_sequence_length_divisible_by=master_config["policy"][
"make_sequence_length_divisible_by"
],
)

val_data: BatchedDataDict = BatchedDataDict(
Expand Down Expand Up @@ -358,6 +362,9 @@ def sft_train(
cat_and_padded, input_lengths = batched_message_log_to_flat_message(
batch["message_log"],
pad_value_dict={"token_ids": tokenizer.pad_token_id},
make_sequence_length_divisible_by=master_config["policy"][
"make_sequence_length_divisible_by"
],
)

train_data: BatchedDataDict = BatchedDataDict(
Expand Down
3 changes: 2 additions & 1 deletion nemo_reinforcer/models/policy/dtensor_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ def train(
mb_losses.append(loss.item())
all_mb_metrics.append(loss_metrics)

grad_norm = None
if not eval_mode:
with torch.no_grad():
grad_norm = get_grad_norm(
Expand All @@ -347,7 +348,7 @@ def train(
with torch.no_grad():
local_loss = torch.tensor(losses, device="cuda")
global_loss = torch.zeros_like(local_loss)
torch.distributed.all_reduce(local_loss)
torch.distributed.all_reduce(local_loss, group=self.dp_mesh.get_group())
global_loss = local_loss / self.dp_size

# Aggregate metrics across all microbatches
Expand Down