Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions tests/utils/debug/test_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2025 Individual Contributor: TomQunChaoA
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from verl.protocol import DataProto
from verl.utils.debug.metrics import calculate_debug_metrics


class TestMetrics(unittest.TestCase):
    """Unit tests for verl.utils.debug.metrics."""

    def test_calculate_debug_metrics(self):
        """A well-formed batch must be reported as valid by calculate_debug_metrics."""
        rollout_log_probs = torch.tensor(
            [
                [-1.5085, -0.1200, -0.6650, -0.4823, -0.1426, -1.5557, -2.8532, -0.3919, -0.4294, -0.4700],
                [-0.0585, -0.0573, -0.4681, -0.5187, -0.7451, -1.2737, -0.0682, -0.4284, -0.5754, -0.0611],
            ]
        )
        old_log_probs = torch.tensor(
            [
                [-1.8636, -0.7863, -0.2136, -0.4376, -2.0257, -0.2579, -1.1547, -0.5203, -0.3802, -0.9872],
                [-0.3507, -0.5426, -0.2725, -0.4637, -0.3577, -0.3733, -1.7560, -1.9542, -0.4229, -1.3098],
            ]
        )
        loss_mask = torch.tensor([[1, 0, 0, 0, 1, 1, 0, 1, 1, 0], [1, 0, 1, 0, 1, 1, 1, 0, 1, 1]])
        batch = DataProto.from_dict(
            {
                "rollout_log_probs": rollout_log_probs,
                "old_log_probs": old_log_probs,
                "loss_mask": loss_mask,
                "responses": torch.zeros((2, 10)),
            }
        )
        metrics = calculate_debug_metrics(batch)
        print(metrics)
        assert metrics["training/rollout_probs_diff_valid"] == 1


# Allow running this test module directly: `python test_metrics.py`.
if __name__ == "__main__":
    unittest.main()
1 change: 1 addition & 0 deletions tests/workers/rollout/utils_sglang.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def get_rollout_config(
"use_inference_chat_template": False,
"tokenization_sanity_check_mode": "strict",
},
"calculate_log_probs": False,
"max_model_len": None,
**sampling_params,
}
Expand Down
24 changes: 3 additions & 21 deletions verl/trainer/ppo/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1194,27 +1194,9 @@ def fit(self):

if "rollout_log_probs" in batch.batch.keys():
# TODO: we may want to add diff of probs too.
rollout_old_log_probs = batch.batch["rollout_log_probs"]
actor_old_log_probs = batch.batch["old_log_probs"]
attention_mask = batch.batch["attention_mask"]
responses = batch.batch["responses"]
response_length = responses.size(1)
response_mask = attention_mask[:, -response_length:]

rollout_probs = torch.exp(rollout_old_log_probs)
actor_probs = torch.exp(actor_old_log_probs)
rollout_probs_diff = torch.abs(rollout_probs - actor_probs)
rollout_probs_diff = torch.masked_select(rollout_probs_diff, response_mask.bool())
rollout_probs_diff_max = torch.max(rollout_probs_diff)
rollout_probs_diff_mean = torch.mean(rollout_probs_diff)
rollout_probs_diff_std = torch.std(rollout_probs_diff)
metrics.update(
{
"training/rollout_probs_diff_max": rollout_probs_diff_max.detach().item(),
"training/rollout_probs_diff_mean": rollout_probs_diff_mean.detach().item(),
"training/rollout_probs_diff_std": rollout_probs_diff_std.detach().item(),
}
)
from verl.utils.debug.metrics import calculate_debug_metrics

metrics.update(calculate_debug_metrics(batch))

if self.use_reference_policy:
# compute reference log_prob
Expand Down
109 changes: 109 additions & 0 deletions verl/utils/debug/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright 2025 Individual Contributor: TomQunChaoA
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import torch

from verl.protocol import DataProto

logger = logging.getLogger(__file__)


def calculate_token_list_diff(tensor1: torch.Tensor, tensor2: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Count, per row, how many masked-in tokens differ between two token tensors.

    Args:
        tensor1: token ids, shape (batch, seq_len)
        tensor2: token ids, shape (batch, seq_len); moved to tensor1's device if needed
        mask: 0/1 mask, shape (batch, seq_len); only positions where mask == 1 are compared

    Returns:
        torch.Tensor: shape (batch,) — the number of differing valid tokens in each row.
            Empty inputs yield all zeros; mismatched shapes yield all ones as a
            sentinel so the caller notices without crashing.
    """
    # degenerate input: nothing to compare
    if tensor1.numel() == 0 or tensor2.numel() == 0:
        return torch.zeros(tensor1.shape[0], dtype=torch.long, device=tensor1.device)
    if tensor1.shape != tensor2.shape or mask.shape != tensor1.shape:
        print(
            f"<WARN> dim of tensor1, tensor2, mask is not equal, {(tensor1.shape)=},{(tensor2.shape)=}, {(mask.shape)=}"
        )
        # fix: previously returned ones_like(tensor1) of shape (batch, seq_len),
        # inconsistent with the (batch,) per-row contract of the other paths
        return torch.ones(tensor1.shape[0], dtype=torch.long, device=tensor1.device)
    # move operands onto tensor1's device before elementwise comparison
    if tensor2.device != tensor1.device:
        tensor2 = tensor2.to(tensor1.device)
    if mask.device != tensor1.device:
        mask = mask.to(tensor1.device)

    # per-position difference, restricted to valid (mask == 1) positions
    diff_mask = tensor1 != tensor2
    valid_diff_mask = diff_mask & (mask == 1)

    return valid_diff_mask.sum(dim=1)


def pearson_correlation_coefficient(tensor1: torch.Tensor, tensor2: torch.Tensor, mask: torch.Tensor) -> float:
    """Pearson correlation between the masked-in elements of two tensors.

    Implementation of the probs-correlation debug metric from
    https://arxiv.org/pdf/2506.13585.

    Args:
        tensor1: float tensor
        tensor2: float tensor, same shape as tensor1
        mask: bool tensor of the same shape; True marks elements to include

    Returns:
        float: the correlation coefficient, or 0.0 when shapes mismatch or
            fewer than two elements are selected (corrcoef would be nan).
    """
    if tensor1.shape != tensor2.shape or mask.shape != tensor1.shape:
        # fix: return float 0.0 instead of int 0 for a consistent return type
        return 0.0
    mt1 = torch.masked_select(tensor1, mask)
    mt2 = torch.masked_select(tensor2, mask)
    if mt1.numel() < 2:
        # correlation is undefined for fewer than two samples
        return 0.0
    result = torch.corrcoef(torch.stack([mt1, mt2], dim=0))
    return result[0][1].detach().item()


def calculate_log_prob_diff(log_probs1: torch.Tensor, log_probs2: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Return the absolute element-wise differences, keeping only positions where ``mask`` is True."""
    absolute_diff = (log_probs1 - log_probs2).abs()
    return absolute_diff.masked_select(mask)


def calculate_debug_metrics(data: DataProto) -> dict:
    """
    calculate rollout vs actor logprobs diff, for debugging purpose

    Args:
        data: DataProto
            the data batch to calculate
            rollout_log_probs: log_probs recorded when rollout forwards tokens
            old_log_probs (actor log probs): log_probs recorded when the actor forwards tokens
            response_mask, loss_mask or attention_mask: to mask unrelated tokens
            responses: the response tokens, for calculating size
    Returns:
        dict: metrics
            "training/rollout_probs_diff_valid": 1 -> input is valid, 0 -> input is invalid
            "training/rollout_probs_diff_max": max value of probs diff of rollout vs. actor
            "training/rollout_probs_diff_mean": mean value of probs diff of rollout vs. actor
            "training/rollout_probs_diff_std": std value of probs diff of rollout vs. actor
            "training/rollout_actor_probs_pearson_corr": probs' pearson corrcoef of rollout vs. actor, reference to https://arxiv.org/pdf/2506.13585
    """

    rollout_old_log_probs = data.batch["rollout_log_probs"]
    actor_old_log_probs = data.batch["old_log_probs"]
    # pick the most specific mask available; "loss_mask" is part of the documented
    # contract above (and is what the unit test supplies), so it must be honored
    if "response_mask" in data.batch:
        logger.debug("response mask found, use it to mask log probs")
        log_prob_mask = data.batch["response_mask"]
    elif "loss_mask" in data.batch:
        logger.debug("loss mask found, use it to mask log probs")
        log_prob_mask = data.batch["loss_mask"]
    elif "attention_mask" in data.batch:
        log_prob_mask = data.batch["attention_mask"]
    else:
        logger.warning(f"no mask info found, use all log probs, {(data.batch.keys())=}")
        log_prob_mask = torch.ones_like(rollout_old_log_probs)
    responses = data.batch["responses"]
    response_length = responses.size(1)

    # masks may cover prompt + response; keep only the response tail
    response_mask = log_prob_mask[:, -response_length:]
    response_mask_bool = response_mask.bool()
    # compare in probability space (exp of log probs), as the trainer did originally
    actor_probs = torch.exp(actor_old_log_probs)
    rollout_probs = torch.exp(rollout_old_log_probs)
    rollout_probs_diff = calculate_log_prob_diff(actor_probs, rollout_probs, response_mask_bool)
    if rollout_probs_diff.numel() == 0:
        # every token masked out: max/mean would crash / be nan, report invalid input
        logger.warning("no valid tokens to compare, skipping rollout probs diff metrics")
        return {"training/rollout_probs_diff_valid": 0}
    pearson_corrcoef = pearson_correlation_coefficient(actor_probs, rollout_probs, response_mask_bool)
    return {
        "training/rollout_probs_diff_valid": 1,
        "training/rollout_probs_diff_max": torch.max(rollout_probs_diff).detach().item(),
        "training/rollout_probs_diff_mean": torch.mean(rollout_probs_diff).detach().item(),
        "training/rollout_probs_diff_std": torch.std(rollout_probs_diff).detach().item(),
        "training/rollout_actor_probs_pearson_corr": pearson_corrcoef,
    }
2 changes: 2 additions & 0 deletions verl/workers/rollout/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ class AsyncRolloutRequest(BaseModel):
max_response_len: int = 8192
max_model_len: int = 32768
metrics: dict[str, list[Any]] = {}
output_token_ids: torch.Tensor | None = None
rollout_log_probs: torch.Tensor | None = None

use_inference_chat_template: bool
tokenization_sanity_check_mode: TokenizationSanityCheckModeEnum
Expand Down
51 changes: 50 additions & 1 deletion verl/workers/rollout/sglang_rollout/sglang_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,22 @@ def _pre_process_inputs(
return prompt_token_ids[non_pad_index:]


def _extract_logprob_from_output(output):
"""
extract log_prob from single sglang inference output
"""

def _map_each_response(resp):
input_token_logprobs = resp["meta_info"]["input_token_logprobs"]
log_probs, output_token_ids = zip(
*[(log_prob, token_ids) for log_prob, token_ids, _ in input_token_logprobs[1:]], strict=False
)
return torch.tensor(output_token_ids), torch.tensor(log_probs)

output_token_ids, log_probs = _map_each_response(output)
return output_token_ids, log_probs


# NOTE(linjunrong): adhoc
def _post_process_outputs(processing_class, output):
try:
Expand Down Expand Up @@ -998,7 +1014,19 @@ async def calc_reward_and_release_fn(name: str, tool: BaseTool):
tool_reward_scores = dict(tool_reward_scores)
all_rewards = {**tool_reward_scores, **{"user_turn_rewards": user_turn_rewards}}
_req.finalize(self.processing_class, all_rewards, finish_reason_type)

if self.config.calculate_log_probs:
# Re-run the full input_ids through sglang with max_new_tokens=0 so that it returns log_probs without generating new tokens
debug_sampling_params = {**self.sampling_params}
debug_sampling_params["max_new_tokens"] = 0
output = await self._engine.async_generate(
prompt=None,
input_ids=_req.input_ids,
sampling_params=debug_sampling_params,
return_logprob=True,
logprob_start_len=0,
)
# len(input_token_logprobs) == len(input_tokens) - 1, because the logprob of the 1st token is None
_req.output_token_ids, _req.rollout_log_probs = _extract_logprob_from_output(output)
return _req

async def _handle_engine_call(
Expand Down Expand Up @@ -1097,6 +1125,9 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro
reward_scores = []
multi_modal_inputs = []
request_ids = []
if self.config.calculate_log_probs:
output_logprobs = []
rollout_output_token_ids = []

for req in sorted_output_req_list:
assert req.state == AsyncRolloutRequestStateEnum.COMPLETED, f"Request {req.request_id} is not completed"
Expand Down Expand Up @@ -1137,6 +1168,10 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro
reward_scores.append(req.reward_scores)
multi_modal_inputs.append(req.multi_modal_inputs)
request_ids.append(req.request_id)
if self.config.calculate_log_probs:
# extract output log_probs
output_logprobs.append(req.rollout_log_probs[-len(req.response_ids) :])
rollout_output_token_ids.append(req.output_token_ids[-len(req.response_ids) :])

prompt_ids = pad_sequence(
prompt_ids,
Expand Down Expand Up @@ -1201,6 +1236,17 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro
response_loss_mask = pad_sequence(response_loss_mask, batch_first=True, padding_value=0)
if response_loss_mask.shape[1] < self.config.response_length:
response_loss_mask = pad_sequence_to_length(response_loss_mask, self.config.response_length, 0)
if self.config.calculate_log_probs:
output_logprobs = pad_sequence(output_logprobs, padding_value=0.0, batch_first=True)
output_logprobs = pad_sequence_to_length(
output_logprobs, pad_token_id=0.0, max_seq_len=response_ids.shape[-1]
).to(tgt_device)
rollout_output_token_ids = pad_sequence(
rollout_output_token_ids, padding_value=self.pad_token_id, batch_first=True
)
rollout_output_token_ids = pad_sequence_to_length(
rollout_output_token_ids, pad_token_id=self.pad_token_id, max_seq_len=response_ids.shape[-1]
).to(tgt_device)

input_ids = torch.cat((prompt_ids, response_ids), dim=-1)
attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1)
Expand All @@ -1218,6 +1264,9 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro
},
batch_size=len(sorted_output_req_list),
)
if self.config.calculate_log_probs:
batch["rollout_log_probs"] = output_logprobs
batch["rollout_output_token_ids"] = rollout_output_token_ids

# free cache engine
if self._engine is not None and self._tp_rank == 0:
Expand Down