fix: Replace decode-based prefix matching with EOS-boundary splicing #1337

Status: Merged

Changes from all commits (7 commits)
Commits:
- 0565301 (parthchadha): Replace decode-based prefix matching with EOS-boundary splicing to ro…
- f294d46 (parthchadha): Add missing test file
- 6369b21 (parthchadha): Fix failing unit test
- d053d7a (parthchadha): Use kwargs for _preprocess_chat
- 78811ba (parthchadha): Merge remote-tracking branch 'origin/main' into vllm-async-token-merg…
- 1882c84 (parthchadha): Update Megatron-Bridge submodule to match main branch
- d44a558 (parthchadha): Merge branch 'main' into vllm-async-token-merging-improve
```diff
@@ -22,7 +22,6 @@
 import torch
 import uvicorn
 from fastapi import FastAPI
-from transformers.tokenization_utils_base import PreTrainedTokenizerBase

 from nemo_rl.distributed.batched_data_dict import BatchedDataDict
 from nemo_rl.distributed.virtual_cluster import _get_free_port_local, _get_node_ip_local
```
```diff
@@ -36,88 +35,90 @@
 from nemo_rl.models.generation.vllm.vllm_worker import BaseVllmGenerationWorker


-def _maybe_correct_merged_tokens(
-    tokenizer: PreTrainedTokenizerBase,
-    reference_token_ids: list[int],
-    actual_token_ids: list[int],
+def _replace_prefix_tokens(
+    tokenizer,
+    model_prefix_token_ids: list[int],
+    template_prefix_token_ids: list[int],
+    template_token_ids: list[int],
 ) -> list[int]:
-    """This is a subroutine used inside the vLLM Chat Completion server. Some environments (namely Penguin) require an OpenAI compatible server endpoint rather than an inference engine handle. This is fine for the most part, but it may cause issues when the environment is used as a part of training.
-
-    RL training frameworks train models on token IDs, but the OpenAI compatible server communicates in what is basically de-tokenized text. When multiple model calls are made to the OpenAI compatible server in a single trajectory, model generations in previous model calls may be re-tokenized to something that is different than what was generated. This is not too big of an issue (that we know of) at inference time, but the log probs the model produces are different enough for the differently re-tokenized generation result that it causes the training to be off policy. Off policy isn't necessarily a bad thing in isolation, but this source of off-policyness may cause unexpected issues if not properly accounted for. It also mis-aligns the token ID sequences across model calls, which feels very strange during training.
-
-    Thus, in this function we attempt to correct any minor re-tokenization errors in an effort to stay as on-policy as possible. We require the tokenizer, the ground truth reference token ids taken directly from previous model calls, and the re-tokenized actual token ids.
-
-    In other words, for the current model call:
-    - reference_token_ids = all_prefill_so_far + new_generation
-        - all_prefill_so_far: the last model call model engine input token ids. Literally what the model sees during the last generation call.
-        - new_generation: the last model call model engine generated token ids. Literally what the model generates during the last generation call.
-    - actual_token_ids = all_prefill_so_far_maybe_diff_tokenization + new_generation_maybe_diff_tokenization + tool_response_or_user + assistant_generation_prompt
-        - all_prefill_so_far_maybe_diff_tokenization: the re-tokenized version of all_prefill_so_far. Since the token IDs in all_prefill_so_far were de-tokenized and returned as OpenAI schema, they must be re-tokenized for the current model call, which means that it may differ from all_prefill_so_far
-        - new_generation_maybe_diff_tokenization: analogous version of all_prefill_so_far_maybe_diff_tokenization for new_generation
-        - tool_response_or_user: some returned user or tool message. It doesn't matter that this is tokenized here since it has never been tokenized before. However, at the next model call, this will become part of the all_prefill_so_far.
-        - assistant_generation_prompt: a common sequence of tokens to instruct the model to generate an assistant response.
-
-    The goal of this subroutine is to find the prefix in actual_token_ids that corresponds to the de-tokenized text of reference_token_ids.
-    The idea of this subroutine implementation is to just de-tokenize subsequences of actual_token_ids (called candidate_token_ids) until the de-tokenized text matches the de-tokenized text of reference_token_ids.
-
-    TODO When NeMo RL supports training image generation models, we want to revisit and possibly update this function. This issue occurs when the model generates tokens that are de-tokenized into text or images, and then re-tokenized into tokens. So if there is a situation like that with images and image tokenization is non-unique, then we will need to update this function.
-    """
+    """This is a subroutine used inside the vLLM Chat Completion server.
+
+    This function is for fixing up the chat template-tokenized messages history
+    to match the model output tokenization up to the last assistant turn,
+    in order to preserve the monotonic tokens property for optimized multi-turn
+    training.
+
+    Some environments (namely Penguin) require an OpenAI compatible server
+    endpoint rather than an inference engine handle. This is fine for the most
+    part, but it may cause issues when the environment is used as a part of
+    training.
+
+    RL training frameworks train models on token IDs, but the OpenAI compatible
+    server communicates in what is basically de-tokenized text. When multiple
+    model calls are made to the OpenAI compatible server in a single trajectory,
+    model generations in previous model calls may be re-tokenized to something
+    that is different than what was generated. This is not too big of an issue
+    (that we know of) at inference time, but the log probs the model produces
+    are different enough for the differently re-tokenized generation result that
+    it causes the training to be off policy. Off policy isn't necessarily a bad
+    thing in isolation, but this source of off-policyness may cause unexpected
+    issues if not properly accounted for. It also mis-aligns the token ID
+    sequences across model calls, which feels very strange during training.
+
+    There are real cases where the model output string _does not match_ the chat
+    template tokenization of the parsed model output. A concrete example is
+    inconsistent whitespace tokens around tool call special tokens.
+
+    TODO When NeMo RL supports training image generation models, we want to
+    revisit and possibly update this function. This issue occurs when the model
+    generates tokens that are de-tokenized into text or images, and then
+    re-tokenized into tokens. So if there is a situation like that with images
+    and image tokenization is non-unique, then we will need to update this
+    function.
+
+    Example (turn-by-turn, concise; eos_token_id = 2):
+    Turn 1:
+    - prefill_T1 (template prefill) = [11,12,13,40,41]
+    - model output = [220,17,2]  # decodes to " 4" + EOS
+    - model_prefix_token_ids = prefill_T1 + model output
+      => [11,12,13,40,41,220,17,2]
+
+    Turn 2 (template retokenizes prior assistant text differently):
+    - template_prefix_token_ids = [11,12,13,40,41,1001,2]  # 1001 decodes to " 4"
+    - template_token_ids = [11,12,13,40,41,1001,2,21,22,40,41]
+
+    _replace_prefix_tokens keeps the exact prior model tokens up to EOS and
+    resumes from the template after that EOS:
+    output => [11,12,13,40,41,220,17,2,21,22,40,41]
+    """
-    if not reference_token_ids:
-        return actual_token_ids
-
-    # No re-tokenization errors
-    if reference_token_ids == actual_token_ids[: len(reference_token_ids)]:
-        return actual_token_ids
-
-    reference_str, actual_str = tokenizer.batch_decode(
-        [reference_token_ids, actual_token_ids]
-    )
-
-    # For now, if a trajectory is not monotonically increasing, we assert.
-    # Eventually when we support non-monotonic training, we need to update this logic
-    assert (
-        reference_str == actual_str[: len(reference_str)]
-    ), f"""Found a non-monotonically increasing trajectory that is not caused by a token merge on re-tokenization!
-Reference str: {reference_str}
-Actual str: {actual_str}
-
-Reference token ids: {reference_token_ids}
-Actual token ids: {actual_token_ids}"""
-
-    # Now we want to try to find the subsequence of actual_token_ids that corresponds to reference_str
-    # Our first guess is just the prefix in actual_token_ids of length reference_token_ids. How good of a guess this is depends on the distribution of the number of re-tokenization errors.
-    # If there are a lot, this will be a poor guess. If there aren't that many this is a good guess.
-    candidate_token_ids = actual_token_ids[: len(reference_token_ids)]
-    candidate_str = tokenizer.decode(candidate_token_ids)
-
-    # If it's longer, we remove
-    if len(candidate_str) > len(reference_str):
-        while (
-            candidate_str != reference_str
-            and len(candidate_str) > len(reference_str)
-            and candidate_token_ids
-        ):
-            candidate_token_ids.pop()
-            candidate_str = tokenizer.decode(candidate_token_ids)
-    # If it's shorter we append
-    elif len(candidate_str) < len(reference_str):
-        while (
-            candidate_str != reference_str
-            and len(candidate_str) < len(reference_str)
-            and len(candidate_token_ids) < len(actual_token_ids) - 1
-        ):
-            candidate_token_ids.append(actual_token_ids[len(candidate_token_ids)])
-            candidate_str = tokenizer.decode(candidate_token_ids)
-    # If it's equal we should not need to do any modification. The assert below will directly error out.
-    else:
-        pass
-
-    # If we break above, it must be that we either found a correct match or that we didn't find a valid match
-    # e.g. in cases where there is some token merging that occurs at the very end of the reference_token_ids
-    # We scream loudly here.
-    assert candidate_str == reference_str
-
-    return reference_token_ids + actual_token_ids[len(candidate_token_ids) :]
+    if not model_prefix_token_ids:
+        return template_token_ids
+
+    eos_token_id = tokenizer.eos_token_id
+    assert eos_token_id is not None, "Your tokenizer must have an EOS token ID!"
+
+    model_cut_end = len(model_prefix_token_ids)
+    if model_prefix_token_ids:
+        # We are not always guaranteed that the model outputs an EOS token as the stop criteria of the previous model call e.g. when the model reaches max_tokens.
+        # And since chat templates will always add one for us, we just cut the model input to right before the EOS token ID (if applicable)
+        if model_prefix_token_ids[-1] == eos_token_id:
+            model_cut_end -= 1
+
+    # We take everything starting with the EOS token ID.
+    template_cut_start = -1
+    for pos in reversed(range(len(template_prefix_token_ids))):
+        if template_token_ids[pos] == eos_token_id:
+            template_cut_start = pos
+            break
+
+    # This should never be the case, but
+    assert template_cut_start >= 0, (
+        "No EOS token ID found in the chat-templated messages!"
+    )
+
+    return (
+        model_prefix_token_ids[:model_cut_end] + template_token_ids[template_cut_start:]
+    )


 @ray.remote(
```
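The splicing logic above can be exercised standalone. Below is a minimal sketch, assuming a stub tokenizer object that only exposes `eos_token_id`; the token IDs mirror the worked example in the `_replace_prefix_tokens` docstring.

```python
class StubTokenizer:
    """Hypothetical stand-in for a tokenizer; only eos_token_id is needed here."""
    eos_token_id = 2


def replace_prefix_tokens(tokenizer, model_prefix_token_ids,
                          template_prefix_token_ids, template_token_ids):
    if not model_prefix_token_ids:
        return template_token_ids

    eos = tokenizer.eos_token_id
    assert eos is not None

    # Drop a trailing EOS from the model tokens; the chat template re-adds one.
    model_cut_end = len(model_prefix_token_ids)
    if model_prefix_token_ids[-1] == eos:
        model_cut_end -= 1

    # Resume from the template at the last EOS within the template prefix.
    template_cut_start = -1
    for pos in reversed(range(len(template_prefix_token_ids))):
        if template_token_ids[pos] == eos:
            template_cut_start = pos
            break
    assert template_cut_start >= 0

    return (model_prefix_token_ids[:model_cut_end]
            + template_token_ids[template_cut_start:])


# Turn-2 example from the docstring: the template re-tokenized " 4" as 1001,
# but the model actually generated [220, 17] before the EOS.
spliced = replace_prefix_tokens(
    StubTokenizer(),
    model_prefix_token_ids=[11, 12, 13, 40, 41, 220, 17, 2],
    template_prefix_token_ids=[11, 12, 13, 40, 41, 1001, 2],
    template_token_ids=[11, 12, 13, 40, 41, 1001, 2, 21, 22, 40, 41],
)
print(spliced)  # [11, 12, 13, 40, 41, 220, 17, 2, 21, 22, 40, 41]
```

Note that no decode/re-encode round trip is needed: the splice point is located purely by token ID, which is what makes this cheaper and more robust than the decode-based prefix matching it replaces.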
```diff
@@ -151,6 +152,9 @@ async def report_dp_openai_server_base_url(self) -> Optional[str]:
         return self.base_url

     def _setup_vllm_openai_api_server(self, app: FastAPI) -> FastAPI:
+        from copy import deepcopy
+        from logging import Filter as LoggingFilter
+        from logging import LogRecord
         from typing import List, Optional, Union

         from fastapi import Request
```
```diff
@@ -169,6 +173,7 @@ def _setup_vllm_openai_api_server(self, app: FastAPI) -> FastAPI:
             TokenizeCompletionRequest,
             TokenizeResponse,
         )
+        from vllm.v1.engine.async_llm import logger as vllm_async_llm_logger

         engine_client = self.llm
         model_config = self.llm_async_engine_args.create_model_config()
```
```diff
@@ -214,6 +219,14 @@ async def _preprocess_chat(
             truncate_prompt_tokens=None,
             add_special_tokens=False,
         ):
+            # Materialize the message tool calls so we can deepcopy below.
+            for message in messages:
+                if message.get("tool_calls"):
+                    message["tool_calls"] = list(message["tool_calls"])
+
+            # Deepcopy messages here since _preprocess_chat may be destructive.
+            messages_for_replace_prefix_tokens = deepcopy(messages)
+
             # res is conversation, [request_prompt], [engine_prompt]
             res = await super()._preprocess_chat(
                 request,
```
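The materialize-then-deepcopy step above exists because downstream preprocessing may mutate nested message fields in place. A small illustration of why a shallow copy is not enough (the message contents here are hypothetical):

```python
from copy import deepcopy

# A chat history entry with a nested, mutable tool_calls structure.
messages = [{"role": "assistant", "tool_calls": [{"name": "search"}]}]

shallow = list(messages)      # new list, but shares the inner dicts
saved = deepcopy(messages)    # fully independent copy

# Simulate a destructive preprocessing step mutating a nested dict in place.
messages[0]["tool_calls"][0]["name"] = "mutated"

print(shallow[0]["tool_calls"][0]["name"])  # mutated (shared nested object)
print(saved[0]["tool_calls"][0]["name"])    # search (independent copy)
```

This is also why tool calls are converted with `list(...)` first: `deepcopy` can copy lists of dicts, but not lazy iterators, so any generator-valued field has to be materialized before the copy.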
```diff
@@ -234,14 +247,50 @@ async def _preprocess_chat(
             if request.required_prefix_token_ids is None:
                 return res

+            # Find the last assistant message
+            last_assistant_message_idx = None
+            for i in reversed(range(len(messages_for_replace_prefix_tokens))):
+                if messages_for_replace_prefix_tokens[i]["role"] == "assistant":
+                    last_assistant_message_idx = i
+                    break
+
+            # If there's no assistant message, we don't have any issues.
+            if last_assistant_message_idx is None:
+                return res
+
+            # Include the last assistant message itself.
+            messages_to_last_assistant_message = messages_for_replace_prefix_tokens[
+                : last_assistant_message_idx + 1
+            ]
+            # Call the actual preprocess chat subroutine so we don't miss anything. Whatever they do is whatever we do since we literally do what they do.
+            corresponding_res = await super()._preprocess_chat(
+                request,
+                tokenizer,
+                messages_to_last_assistant_message,
+                chat_template,
+                chat_template_content_format,
+                add_generation_prompt=False,
+                continue_final_message=False,
+                tool_dicts=tool_dicts,
+                documents=documents,
+                chat_template_kwargs=chat_template_kwargs,
+                tool_parser=tool_parser,
+                truncate_prompt_tokens=truncate_prompt_tokens,
+                add_special_tokens=add_special_tokens,
+            )
+            actual_corresponding_token_ids = corresponding_res[2][0][
+                "prompt_token_ids"
+            ]
+
             engine_prompt = res[2][
                 0
             ]  # We need to modify engine_prompt.prompt_token_ids

-            final_prompt_token_ids = _maybe_correct_merged_tokens(
+            final_prompt_token_ids = _replace_prefix_tokens(
                 tokenizer=tokenizer,
-                reference_token_ids=request.required_prefix_token_ids,
-                actual_token_ids=engine_prompt["prompt_token_ids"],
+                model_prefix_token_ids=request.required_prefix_token_ids,
+                template_prefix_token_ids=request.required_prefix_token_ids,
+                template_token_ids=engine_prompt["prompt_token_ids"],
             )

             engine_prompt["prompt_token_ids"] = final_prompt_token_ids
```
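The trim-to-last-assistant-turn step above can be sketched in isolation (the message list here is hypothetical): everything up to and including the final assistant message is re-templated, so its tokenization can be compared against what the model actually generated.

```python
# Hypothetical chat history; the final message is a tool response, so the
# last assistant turn sits in the middle of the list.
messages = [
    {"role": "user", "content": "What is 2+2?"},
    {"role": "assistant", "content": " 4"},
    {"role": "tool", "content": "ok"},
]

# Scan backwards for the last assistant message, as the diff does.
last_assistant_idx = None
for i in reversed(range(len(messages))):
    if messages[i]["role"] == "assistant":
        last_assistant_idx = i
        break

if last_assistant_idx is not None:
    # Include the assistant message itself in the trimmed history.
    trimmed = messages[: last_assistant_idx + 1]
    print([m["role"] for m in trimmed])  # ['user', 'assistant']
```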
```diff
@@ -330,7 +379,38 @@ class NeMoRLTokenizeChatRequest(
         class NeMoRLOpenAIServingTokenization(
             NeMoRLOpenAIServingMixin, OpenAIServingTokenization
         ):
-            pass
+            async def create_tokenize(self, request, raw_request):
+                """Override to handle required_prefix_token_ids for tokenization."""
+                # Call parent's create_tokenize first
+                result = await super().create_tokenize(request, raw_request)
+
+                # If there's an error or no required_prefix_token_ids, return as-is
+                if isinstance(result, ErrorResponse):
+                    return result
+
+                # Only process chat requests (not completion requests)
+                if not hasattr(request, "messages"):
+                    return result
+
+                # Get the template-tokenized tokens from the result
+                template_token_ids = result.tokens
+
+                # Get the tokenizer from the engine client
+                tokenizer = await self.engine_client.get_tokenizer()
+
+                # Apply _replace_prefix_tokens to fix up the tokenization
+                final_token_ids = _replace_prefix_tokens(
+                    tokenizer=tokenizer,
+                    model_prefix_token_ids=request.required_prefix_token_ids,
+                    template_prefix_token_ids=request.required_prefix_token_ids,
+                    template_token_ids=template_token_ids,
+                )
+
+                # Update the result with the corrected tokens
+                result.tokens = final_token_ids
+                result.count = len(final_token_ids)
+
+                return result

         openai_serving_tokenization = NeMoRLOpenAIServingTokenization(
             engine_client,
```

Contributor review comment on the `template_prefix_token_ids=request.required_prefix_token_ids` line: "I think this should actually be
```diff
@@ -356,6 +436,20 @@ async def tokenize(request: NeMoRLTokenizeRequest, raw_request: Request):
             elif isinstance(generator, TokenizeResponse):
                 return JSONResponse(content=generator.model_dump())

+        ########################################
+        # Logging
+        ########################################
+        print(
+            "Adding a vLLM logging filter so that the logs aren't spammed with `Added request ...` messages. This is to help errors pop up better and filter out noise."
+        )
+
+        class NoAddedRequestFilter(LoggingFilter):
+            def filter(self, record: LogRecord) -> bool:
+                msg = record.getMessage()
+                return "Added request" not in msg
+
+        vllm_async_llm_logger.addFilter(NoAddedRequestFilter())
+
         return app

     def _setup_vllm_server(self) -> "tuple[threading.Thread, str, uvicorn.Server]":
```
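The log-noise filter added above follows the standard `logging.Filter` protocol: `filter()` returning False drops the record. A self-contained sketch of the same idea, attached to a plain demo logger rather than vLLM's:

```python
import logging


class NoAddedRequestFilter(logging.Filter):
    """Drop log records whose message contains 'Added request'."""

    def filter(self, record: logging.LogRecord) -> bool:
        # Returning False suppresses the record; True lets it through.
        return "Added request" not in record.getMessage()


logger = logging.getLogger("demo")
logger.addHandler(logging.StreamHandler())
logger.addFilter(NoAddedRequestFilter())

logger.warning("Added request chatcmpl-123")  # suppressed by the filter
logger.warning("engine error: timeout")       # emitted to stderr
```

Adding the filter to the logger (rather than a handler) suppresses the record for every handler attached to that logger, which is why the diff attaches it directly to `vllm_async_llm_logger`.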