diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index 7a43e69f6c..49f6bde57f 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -297,8 +297,12 @@ def loss_func(logits, data): loss_fn_outputs.append( { - "logprobs": action_log_probs[i, :valid_len].detach().cpu().tolist(), - "elementwise_loss": elementwise_loss[i, :valid_len].detach().cpu().tolist(), + "logprobs": ( + action_log_probs[i, -valid_len:].detach().cpu().tolist() if valid_len > 0 else [] + ), + "elementwise_loss": ( + elementwise_loss[i, -valid_len:].detach().cpu().tolist() if valid_len > 0 else [] + ), } ) @@ -351,7 +355,7 @@ def loss_func(logits, data): for i, valid_len in enumerate(valid_lens): loss_fn_outputs.append( { - "logprobs": detached_log_probs[i, :valid_len].tolist(), + "logprobs": detached_log_probs[i, -valid_len:].tolist() if valid_len > 0 else [], } ) diff --git a/skyrl/backends/skyrl_train/workers/worker.py b/skyrl/backends/skyrl_train/workers/worker.py index fc0c99d490..40efbf02e5 100644 --- a/skyrl/backends/skyrl_train/workers/worker.py +++ b/skyrl/backends/skyrl_train/workers/worker.py @@ -856,8 +856,10 @@ def _forward_backward_micro( loss_fn_outputs.append( { - "logprobs": action_log_probs[i, :valid_len].detach().cpu().tolist(), - "elementwise_loss": elementwise_loss[i, :valid_len].detach().cpu().tolist(), + "logprobs": action_log_probs[i, -valid_len:].detach().cpu().tolist() if valid_len > 0 else [], + "elementwise_loss": ( + elementwise_loss[i, -valid_len:].detach().cpu().tolist() if valid_len > 0 else [] + ), } ) @@ -913,7 +915,7 @@ def _forward_backward_micro( for i, valid_len in enumerate(valid_lens): loss_fn_outputs.append( { - "logprobs": detached_log_probs[i, :valid_len].tolist(), + "logprobs": detached_log_probs[i, -valid_len:].tolist() if valid_len > 0 else [], } )