diff --git a/examples/configs/recipes/llm/grpo-deepscaler-1.5b-24K.yaml b/examples/configs/recipes/llm/grpo-deepscaler-1.5b-24K.yaml
index c48b54996d..ab4918df73 100644
--- a/examples/configs/recipes/llm/grpo-deepscaler-1.5b-24K.yaml
+++ b/examples/configs/recipes/llm/grpo-deepscaler-1.5b-24K.yaml
@@ -45,7 +45,3 @@ policy:
       gpu_memory_utilization: 0.8
       enforce_eager: True
       max_model_len: ${policy.max_total_sequence_length}
-
-cluster:
-  gpus_per_node: 8
-  num_nodes: 4
diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py
index 190c366ebf..f27a8c0a97 100644
--- a/nemo_rl/models/policy/dtensor_policy_worker.py
+++ b/nemo_rl/models/policy/dtensor_policy_worker.py
@@ -720,6 +720,7 @@ def train(
                     logits = self.model.lm_head(outputs.last_hidden_state)
                 else:
                     logits = outputs.logits
+                del outputs
 
                 # Apply temperature scaling
                 logits = self._apply_temperature_scaling(logits)
@@ -786,6 +787,7 @@ def train(
                         global_valid_seqs,
                         global_valid_toks,
                     )
+                del logits
 
                 # skip the update for dummy batches
                 if mb_idx < iterator_len:
@@ -1044,8 +1046,9 @@ def get_logprobs(
                         placements=[Shard(sequence_dim), Shard(-1)],
                     )
 
+                    logits = logits.to(torch.float32)
                     token_logprobs = get_logprobs_from_vocab_parallel_logits(
-                        logits.to(torch.float32),
+                        logits,
                         input_ids_dtensor,
                         seq_index_tensor,
                     )
@@ -1053,8 +1056,9 @@ def get_logprobs(
                     assert token_logprobs.shape[1] == seq_len - 1
                 else:
                     if isinstance(logits, DTensor):
+                        logits = logits.to(torch.float32)
                         token_logprobs = get_logprobs_from_vocab_parallel_logits(
-                            logits.to(torch.float32), input_ids
+                            logits, input_ids
                         )
                     else:
                         # Extract logprobs for each token in the sequence by gathering the logprob
@@ -1064,16 +1068,16 @@ def get_logprobs(
                         # token_ids: [batch_size, sequence_length] - actual tokens
                         # Output shape: [batch_size, sequence_length] - logprob of each token given previous
                         # We get logprob of token[t+1] from logits[t], prepending 0 to maintain sequence length
-
-                        log_probs = torch.nn.functional.log_softmax(
-                            outputs.logits.to(torch.float32), dim=-1
-                        )
+                        logits = outputs.logits.to(torch.float32)
+                        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
                         next_tokens = input_ids[:, 1:]
                         log_probs = log_probs[:, :-1]
                         token_logprobs = log_probs.gather(
                             dim=-1, index=next_tokens.unsqueeze(-1)
                         ).squeeze(-1)
+
+                        del outputs, logits
+
                         token_logprobs = torch.cat(
                             [torch.zeros_like(token_logprobs[:, :1]), token_logprobs], dim=1
                         )