Merged
260 changes: 177 additions & 83 deletions nemo_rl/models/generation/vllm/vllm_worker_async.py
@@ -22,7 +22,6 @@
import torch
import uvicorn
from fastapi import FastAPI
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.distributed.virtual_cluster import _get_free_port_local, _get_node_ip_local
@@ -36,88 +35,90 @@
from nemo_rl.models.generation.vllm.vllm_worker import BaseVllmGenerationWorker


def _maybe_correct_merged_tokens(
tokenizer: PreTrainedTokenizerBase,
reference_token_ids: list[int],
actual_token_ids: list[int],
def _replace_prefix_tokens(
tokenizer,
model_prefix_token_ids: list[int],
template_prefix_token_ids: list[int],
template_token_ids: list[int],
) -> list[int]:
"""This is a subroutine used inside the vLLM Chat Completion server. Some environments (namely Penguin) require an OpenAI compatible server endpoint rather than an inference engine handle. This is fine for the most part, but it may cause issues when the environment is used as a part of training.

RL training frameworks train models on token IDs, but the OpenAI compatible server communicates in what is basically de-tokenized text. When multiple model calls are made to the OpenAI compatible server in a single trajectory, model generations in previous model calls may be re-tokenized to something that is different than what was generated. This is not too big of an issue (that we know of) at inference time, but the log probs the model produces are different enough for the differently re-tokenized generation result that it causes the training to be off policy. Off policy isn't necessarily a bad thing in isolation, but this source of off-policyness may cause unexpected issues if not properly accounted for. It also mis-aligns the token ID sequences across model calls, which feels very strange during training.

Thus, in this function we attempt to correct any minor re-tokenization errors in an effort to stay as on-policy as possible. We require the tokenizer, the ground truth reference token ids taken directly from previous model calls, and the re-tokenized actual token ids.

In other words, for the current model call:
- reference_token_ids = all_prefill_so_far + new_generation
- all_prefill_so_far: the last model call model engine input token ids. Literally what the model sees during the last generation call.
- new_generation: the last model call model engine generated token ids. Literally what the model generates during the last generation call.
- actual_token_ids = all_prefill_so_far_maybe_diff_tokenization + new_generation_maybe_diff_tokenization + tool_response_or_user + assistant_generation_prompt
- all_prefill_so_far_maybe_diff_tokenization: the re-tokenized version of all_prefill_so_far. Since the token IDs in all_prefill_so_far were de-tokenized and returned as OpenAI schema, they must be re-tokenized for the current model call, which means that it may differ from all_prefill_so_far
- new_generation_maybe_diff_tokenization: analogous version of all_prefill_so_far_maybe_diff_tokenization for new_generation
- tool_response_or_user: some returned user or tool message. It doesn't matter that this is tokenized here since it has never been tokenized before. However, at the next model call, this will become part of the all_prefill_so_far.
- assistant_generation_prompt: a common sequence of tokens to instruct the model to generate an assistant response.

The goal of this subroutine is to find the prefix in actual_token_ids that corresponds to the de-tokenized text of reference_token_ids.
The idea of this subroutine implementation is to just de-tokenize subsequences of actual_token_ids (called candidate_token_ids) until the de-tokenized text matches the de-tokenized text of reference_token_ids.

TODO When NeMo RL supports training image generation models, we want to revisit and possibly update this function. This issue occurs when the model generates tokens that are de-tokenized into text or images, and then re-tokenized into tokens. So if there is a situation like that with images and image tokenization is non-unique, then we will need to update this function.
"""This is a subroutine used inside the vLLM Chat Completion server.

This function is for fixing up the chat template-tokenized messages history
to match the model output tokenization up to the last assistant turn,
in order to preserve the monotonic tokens property for optimized multi-turn
training.

Some environments (namely Penguin) require an OpenAI compatible server
endpoint rather than an inference engine handle. This is fine for the most
part, but it may cause issues when the environment is used as a part of
training.

RL training frameworks train models on token IDs, but the OpenAI compatible
server communicates in what is basically de-tokenized text. When multiple
model calls are made to the OpenAI compatible server in a single trajectory,
model generations in previous model calls may be re-tokenized to something
that is different than what was generated. This is not too big of an issue
(that we know of) at inference time, but the log probs the model produces
are different enough for the differently re-tokenized generation result that
it causes the training to be off policy. Off policy isn't necessarily a bad
thing in isolation, but this source of off-policyness may cause unexpected
issues if not properly accounted for. It also mis-aligns the token ID
sequences across model calls, which feels very strange during training.

There are real cases where the model output string _does not match_ the chat
template tokenization of the parsed model output. A concrete example is
inconsistent whitespace tokens around tool call special tokens.

TODO When NeMo RL supports training image generation models, we want to
revisit and possibly update this function. This issue occurs when the model
generates tokens that are de-tokenized into text or images, and then
re-tokenized into tokens. So if there is a situation like that with images
and image tokenization is non-unique, then we will need to update this
function.

Example (turn-by-turn, concise; eos_token_id = 2):
Turn 1:
- prefill_T1 (template prefill) = [11,12,13,40,41]
- model output = [220,17,2] # decodes to " 4" + EOS
- model_prefix_token_ids = prefill_T1 + model output
=> [11,12,13,40,41,220,17,2]

Turn 2 (template retokenizes prior assistant text differently):
- template_prefix_token_ids = [11,12,13,40,41,1001,2] # 1001 decodes to " 4"
- template_token_ids = [11,12,13,40,41,1001,2,21,22,40,41]

_replace_prefix_tokens keeps the exact prior model tokens up to EOS and
resumes from the template after that EOS:
output => [11,12,13,40,41,220,17,2,21,22,40,41]
"""
if not reference_token_ids:
return actual_token_ids

# No re-tokenization errors
if reference_token_ids == actual_token_ids[: len(reference_token_ids)]:
return actual_token_ids

reference_str, actual_str = tokenizer.batch_decode(
[reference_token_ids, actual_token_ids]
if not model_prefix_token_ids:
return template_token_ids

eos_token_id = tokenizer.eos_token_id
assert eos_token_id is not None, "Your tokenizer must have an EOS token ID!"

model_cut_end = len(model_prefix_token_ids)
if model_prefix_token_ids:
# The model is not guaranteed to end its previous call with an EOS token, e.g. when it stops at max_tokens.
# Since chat templates will always add one for us, we just cut the model prefix to right before its trailing EOS token ID (if present).
if model_prefix_token_ids[-1] == eos_token_id:
model_cut_end -= 1

# We take everything starting with the EOS token ID.
template_cut_start = -1
for pos in reversed(range(len(template_prefix_token_ids))):
if template_token_ids[pos] == eos_token_id:
template_cut_start = pos
break

# This should never happen, but assert defensively just in case.
assert template_cut_start >= 0, (
"No EOS token ID found in the chat-templated messages!"
)

# For now, if a trajectory is not monotonically increasing, we assert.
# Eventually when we support non-monotonic training, we need to update this logic
assert (
reference_str == actual_str[: len(reference_str)]
), f"""Found a non-monotonically increasing trajectory that is not caused by a token merge on re-tokenization!
Reference str: {reference_str}
Actual str: {actual_str}

Reference token ids: {reference_token_ids}
Actual token ids: {actual_token_ids}"""

# Now we want to try to find the subsequence of actual_token_ids that corresponds to reference_str
# Our first guess is just the prefix in actual_token_ids of length reference_token_ids. How good of a guess this is depends on the distribution of the number of re-tokenization errors.
# If there are a lot, this will be a poor guess. If there aren't that many this is a good guess.
candidate_token_ids = actual_token_ids[: len(reference_token_ids)]
candidate_str = tokenizer.decode(candidate_token_ids)

# If it's longer, we remove
if len(candidate_str) > len(reference_str):
while (
candidate_str != reference_str
and len(candidate_str) > len(reference_str)
and candidate_token_ids
):
candidate_token_ids.pop()
candidate_str = tokenizer.decode(candidate_token_ids)
# If it's shorter we append
elif len(candidate_str) < len(reference_str):
while (
candidate_str != reference_str
and len(candidate_str) < len(reference_str)
and len(candidate_token_ids) < len(actual_token_ids) - 1
):
candidate_token_ids.append(actual_token_ids[len(candidate_token_ids)])
candidate_str = tokenizer.decode(candidate_token_ids)
# If it's equal we should not need to do any modification. The assert below will directly error out.
else:
pass

# If we break above, it must be that we either found a correct match or that we didn't find a valid match
# e.g. in cases where there is some token merging that occurs at the very end of the reference_token_ids
# We scream loudly here.
assert candidate_str == reference_str

return reference_token_ids + actual_token_ids[len(candidate_token_ids) :]
return (
model_prefix_token_ids[:model_cut_end] + template_token_ids[template_cut_start:]
)
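To make the EOS-cut splicing above concrete, here is a minimal standalone sketch of the same logic, using a stub tokenizer (only `eos_token_id` is needed) and the token IDs from the docstring example. The stub and values are illustrative, not the real tokenizer or real vocabulary.

```python
from types import SimpleNamespace


def replace_prefix_tokens_sketch(tokenizer, model_prefix, template_prefix, template_tokens):
    # Mirror of _replace_prefix_tokens: keep the exact model tokens up to
    # (but excluding) a trailing EOS, then resume from the template at the
    # last EOS found inside the template prefix region.
    if not model_prefix:
        return template_tokens

    eos = tokenizer.eos_token_id
    assert eos is not None, "tokenizer must define an EOS token id"

    # The model may have stopped on max_tokens instead of EOS, so only cut
    # the trailing EOS if it is actually there.
    model_cut_end = len(model_prefix)
    if model_prefix[-1] == eos:
        model_cut_end -= 1

    # Find the last EOS within the template prefix region.
    template_cut_start = next(
        (pos for pos in reversed(range(len(template_prefix)))
         if template_tokens[pos] == eos),
        None,
    )
    assert template_cut_start is not None, "no EOS in chat-templated prefix"

    return model_prefix[:model_cut_end] + template_tokens[template_cut_start:]


# Values from the docstring example (eos_token_id = 2).
tok = SimpleNamespace(eos_token_id=2)
out = replace_prefix_tokens_sketch(
    tok,
    model_prefix=[11, 12, 13, 40, 41, 220, 17, 2],
    template_prefix=[11, 12, 13, 40, 41, 1001, 2],
    template_tokens=[11, 12, 13, 40, 41, 1001, 2, 21, 22, 40, 41],
)
print(out)  # [11, 12, 13, 40, 41, 220, 17, 2, 21, 22, 40, 41]
```

The key design choice is splicing on the EOS boundary rather than string-matching decoded text, which keeps the model's own generated token IDs verbatim for training.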


@ray.remote(
@@ -151,6 +152,9 @@ async def report_dp_openai_server_base_url(self) -> Optional[str]:
return self.base_url

def _setup_vllm_openai_api_server(self, app: FastAPI) -> FastAPI:
from copy import deepcopy
from logging import Filter as LoggingFilter
from logging import LogRecord
from typing import List, Optional, Union

from fastapi import Request
@@ -169,6 +173,7 @@ def _setup_vllm_openai_api_server(self, app: FastAPI) -> FastAPI:
TokenizeCompletionRequest,
TokenizeResponse,
)
from vllm.v1.engine.async_llm import logger as vllm_async_llm_logger

engine_client = self.llm
model_config = self.llm_async_engine_args.create_model_config()
@@ -214,6 +219,14 @@ async def _preprocess_chat(
truncate_prompt_tokens=None,
add_special_tokens=False,
):
# Materialize the message tool calls so we can deepcopy below.
for message in messages:
if message.get("tool_calls"):
message["tool_calls"] = list(message["tool_calls"])

# Deepcopy messages here since _preprocess_chat may be destructive.
messages_for_replace_prefix_tokens = deepcopy(messages)

# res is conversation, [request_prompt], [engine_prompt]
res = await super()._preprocess_chat(
request,
@@ -234,14 +247,50 @@
if request.required_prefix_token_ids is None:
return res

# Find the last assistant message
last_assistant_message_idx = None
for i in reversed(range(len(messages_for_replace_prefix_tokens))):
if messages_for_replace_prefix_tokens[i]["role"] == "assistant":
last_assistant_message_idx = i
break

# If there's no assistant message, we don't have any issues.
if last_assistant_message_idx is None:
return res

# Include the last assistant message itself.
messages_to_last_assistant_message = messages_for_replace_prefix_tokens[
: last_assistant_message_idx + 1
]
# Call the parent _preprocess_chat subroutine so we don't miss anything: the prefix is tokenized exactly the same way as the engine path tokenizes it.
corresponding_res = await super()._preprocess_chat(
request,
tokenizer,
messages_to_last_assistant_message,
chat_template,
chat_template_content_format,
add_generation_prompt=False,
continue_final_message=False,
tool_dicts=tool_dicts,
documents=documents,
chat_template_kwargs=chat_template_kwargs,
tool_parser=tool_parser,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)
actual_corresponding_token_ids = corresponding_res[2][0][
"prompt_token_ids"
]

engine_prompt = res[2][
0
] # We need to modify engine_prompt.prompt_token_ids

final_prompt_token_ids = _maybe_correct_merged_tokens(
final_prompt_token_ids = _replace_prefix_tokens(
tokenizer=tokenizer,
reference_token_ids=request.required_prefix_token_ids,
actual_token_ids=engine_prompt["prompt_token_ids"],
model_prefix_token_ids=request.required_prefix_token_ids,
template_prefix_token_ids=request.required_prefix_token_ids,
template_token_ids=engine_prompt["prompt_token_ids"],
)

engine_prompt["prompt_token_ids"] = final_prompt_token_ids
@@ -330,7 +379,38 @@ class NeMoRLTokenizeChatRequest(
class NeMoRLOpenAIServingTokenization(
NeMoRLOpenAIServingMixin, OpenAIServingTokenization
):
pass
async def create_tokenize(self, request, raw_request):
"""Override to handle required_prefix_token_ids for tokenization."""
# Call parent's create_tokenize first
result = await super().create_tokenize(request, raw_request)

# If there's an error, return the response as-is
if isinstance(result, ErrorResponse):
return result

# Only process chat requests (not completion requests)
if not hasattr(request, "messages"):
return result

# Get the template-tokenized tokens from the result
template_token_ids = result.tokens

# Get the tokenizer from the engine client
tokenizer = await self.engine_client.get_tokenizer()

# Apply _replace_prefix_tokens to fix up the tokenization
final_token_ids = _replace_prefix_tokens(
tokenizer=tokenizer,
model_prefix_token_ids=request.required_prefix_token_ids,
template_prefix_token_ids=request.required_prefix_token_ids,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should actually be actual_corresponding_token_ids

template_token_ids=template_token_ids,
)

# Update the result with the corrected tokens
result.tokens = final_token_ids
result.count = len(final_token_ids)

return result

openai_serving_tokenization = NeMoRLOpenAIServingTokenization(
engine_client,
@@ -356,6 +436,20 @@ async def tokenize(request: NeMoRLTokenizeRequest, raw_request: Request):
elif isinstance(generator, TokenizeResponse):
return JSONResponse(content=generator.model_dump())

########################################
# Logging
########################################
print(
"Adding a vLLM logging filter so that the logs aren't spammed with `Added request ...` messages. This makes real errors easier to spot by filtering out noise."
)

class NoAddedRequestFilter(LoggingFilter):
def filter(self, record: LogRecord) -> bool:
msg = record.getMessage()
return "Added request" not in msg

vllm_async_llm_logger.addFilter(NoAddedRequestFilter())

return app
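As a standalone illustration of the filter pattern used above, here is a sketch with only the standard-library `logging` module; the logger name and messages are illustrative:

```python
import logging


class NoAddedRequestFilter(logging.Filter):
    """Drop records whose message contains 'Added request'."""

    def filter(self, record: logging.LogRecord) -> bool:
        # Returning False suppresses the record; True lets it through.
        return "Added request" not in record.getMessage()


logger = logging.getLogger("vllm_demo")
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler())
logger.addFilter(NoAddedRequestFilter())

logger.info("Added request chatcmpl-123")  # suppressed by the filter
logger.info("Engine ready")                # emitted
```

Attaching the filter to the logger (rather than a handler) suppresses the noisy records for every handler that logger feeds.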

def _setup_vllm_server(self) -> "tuple[threading.Thread, str, uvicorn.Server]":