2 changes: 1 addition & 1 deletion requirements.txt
@@ -15,7 +15,7 @@ pre-commit
ray[default]
tensordict>=0.8.0,<=0.10.0,!=0.9.0
torchdata
transformers<5.0.0
transformers
# vllm==0.8.4
wandb
packaging>=20.0
4 changes: 3 additions & 1 deletion scripts/legacy_model_merger.py
@@ -55,7 +55,6 @@
AutoConfig,
AutoModelForCausalLM,
AutoModelForTokenClassification,
AutoModelForVision2Seq,
GenerationConfig,
PretrainedConfig,
)
@@ -69,6 +68,9 @@
from tqdm import tqdm

from verl.utils import hf_processor, hf_tokenizer
from verl.utils.transformers_compat import get_auto_model_for_vision2seq

AutoModelForVision2Seq = get_auto_model_for_vision2seq()


@dataclass
5 changes: 3 additions & 2 deletions tests/experimental/reward_loop/test_reward_model_disrm.py
@@ -21,6 +21,7 @@
from verl.protocol import DataProto
from verl.utils import hf_tokenizer
from verl.utils.model import compute_position_id_with_mask
from verl.utils.tokenizer import normalize_token_ids


def create_data_samples(tokenizer) -> DataProto:
@@ -63,8 +64,8 @@ def create_data_samples(tokenizer) -> DataProto:
pad_token_id = tokenizer.pad_token_id
prompts, responses, input_ids, attention_masks = [], [], [], []
for conv in convs:
prompt_tokens = tokenizer.apply_chat_template(conv[:1], tokenize=True)
response_tokens = tokenizer.apply_chat_template(conv, tokenize=True)[len(prompt_tokens) :]
prompt_tokens = normalize_token_ids(tokenizer.apply_chat_template(conv[:1], tokenize=True))
response_tokens = normalize_token_ids(tokenizer.apply_chat_template(conv, tokenize=True))[len(prompt_tokens) :]

padded_prompt = [pad_token_id] * (prompt_length - len(prompt_tokens)) + prompt_tokens
padded_response = response_tokens + [pad_token_id] * (response_length - len(response_tokens))
5 changes: 3 additions & 2 deletions tests/experimental/reward_loop/test_reward_model_genrm.py
@@ -22,6 +22,7 @@
from verl.protocol import DataProto
from verl.utils import hf_tokenizer
from verl.utils.model import compute_position_id_with_mask
from verl.utils.tokenizer import normalize_token_ids


def create_data_samples(tokenizer) -> DataProto:
@@ -64,8 +65,8 @@ def create_data_samples(tokenizer) -> DataProto:
pad_token_id = tokenizer.pad_token_id
prompts, responses, input_ids, attention_masks = [], [], [], []
for conv in convs:
prompt_tokens = tokenizer.apply_chat_template(conv[:1], tokenize=True)
response_tokens = tokenizer.apply_chat_template(conv, tokenize=True)[len(prompt_tokens) :]
prompt_tokens = normalize_token_ids(tokenizer.apply_chat_template(conv[:1], tokenize=True))
response_tokens = normalize_token_ids(tokenizer.apply_chat_template(conv, tokenize=True))[len(prompt_tokens) :]

padded_prompt = [pad_token_id] * (prompt_length - len(prompt_tokens)) + prompt_tokens
padded_response = response_tokens + [pad_token_id] * (response_length - len(response_tokens))
68 changes: 68 additions & 0 deletions tests/utils/test_tokenizer_normalize_on_cpu.py
@@ -0,0 +1,68 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest

from verl.utils.tokenizer import normalize_token_ids


class DummyBatchEncoding:
def __init__(self, input_ids):
self.input_ids = input_ids


class DummyToList:
def __init__(self, data):
self._data = data

def tolist(self):
return self._data


@pytest.mark.parametrize(
("tokenized_output", "expected"),
[
# transformers v4-style direct token ids
([1, 2, 3], [1, 2, 3]),
((1, 2, 3), [1, 2, 3]),
# common list-like outputs with tolist()/ndarray paths
(DummyToList([1, 2, 3]), [1, 2, 3]),
(np.array([1, 2, 3], dtype=np.int64), [1, 2, 3]),
# transformers v5-like mapping / BatchEncoding-style outputs
({"input_ids": [1, 2, 3]}, [1, 2, 3]),
({"input_ids": DummyToList([1, 2, 3])}, [1, 2, 3]),
({"input_ids": [[1, 2, 3]]}, [1, 2, 3]),
(DummyBatchEncoding([1, 2, 3]), [1, 2, 3]),
(DummyBatchEncoding(DummyToList([[1, 2, 3]])), [1, 2, 3]),
# scalar item() support
([np.int64(1), np.int32(2), np.int16(3)], [1, 2, 3]),
],
)
def test_normalize_token_ids_valid_outputs(tokenized_output, expected):
assert normalize_token_ids(tokenized_output) == expected


@pytest.mark.parametrize(
"tokenized_output",
[
"not-token-ids",
{"attention_mask": [1, 1, 1]},
[[1, 2], [3, 4]], # ambiguous batched ids should fail fast
[1, object(), 3],
],
)
def test_normalize_token_ids_invalid_outputs(tokenized_output):
with pytest.raises(TypeError):
normalize_token_ids(tokenized_output)
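
Note: this view does not include the body of normalize_token_ids itself (it lives in verl/utils/tokenizer.py). A minimal sketch consistent with the test cases above — the shipped implementation may be organized differently — could look like:

def normalize_token_ids(tokenized_output):
    """Coerce apply_chat_template output into a flat list[int] of token ids."""
    # transformers v5-style BatchEncoding objects expose token ids via .input_ids.
    if hasattr(tokenized_output, "input_ids") and not isinstance(tokenized_output, (list, tuple)):
        return normalize_token_ids(tokenized_output.input_ids)
    # Mapping outputs (e.g. from return_dict=True) must carry an "input_ids" entry.
    if isinstance(tokenized_output, dict):
        if "input_ids" not in tokenized_output:
            raise TypeError(f"mapping output has no 'input_ids': {list(tokenized_output)}")
        return normalize_token_ids(tokenized_output["input_ids"])
    # ndarray / tensor-like outputs expose tolist().
    if hasattr(tokenized_output, "tolist") and not isinstance(tokenized_output, (list, tuple)):
        return normalize_token_ids(tokenized_output.tolist())
    if isinstance(tokenized_output, (list, tuple)):
        ids = list(tokenized_output)
        # Unwrap a single-element batch dimension, e.g. [[1, 2, 3]] -> [1, 2, 3].
        if len(ids) == 1 and isinstance(ids[0], (list, tuple)):
            ids = list(ids[0])
        out = []
        for tok in ids:
            if isinstance(tok, int) and not isinstance(tok, bool):
                out.append(tok)
            elif hasattr(tok, "item"):  # numpy / torch integer scalars
                out.append(int(tok.item()))
            else:
                raise TypeError(f"cannot interpret token id of type {type(tok)}")
        return out
    raise TypeError(f"unsupported tokenizer output type: {type(tokenized_output)}")
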
6 changes: 5 additions & 1 deletion tests/workers/rollout/rollout_vllm/test_vllm_abort.py
@@ -63,6 +63,8 @@ def test_vllm_abort():
print("\n[2] Creating config...")
from hydra import compose, initialize_config_dir

from verl.utils.tokenizer import normalize_token_ids

config_dir = os.path.abspath("verl/verl/trainer/config")
if not os.path.exists(config_dir):
config_dir = os.path.abspath("verl/trainer/config")
@@ -121,7 +123,9 @@ def test_vllm_abort():
all_prompt_ids = []
for prompt in prompts[:NUM_PROMPTS]:
messages = [{"role": "user", "content": prompt}]
prompt_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
prompt_ids = normalize_token_ids(
tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
)
all_prompt_ids.append(prompt_ids)
print(f"Prepared {NUM_PROMPTS} prompts")

6 changes: 4 additions & 2 deletions verl/experimental/agent_loop/agent_loop.py
@@ -45,6 +45,7 @@
rollout_trace_attr,
rollout_trace_op,
)
from verl.utils.tokenizer import normalize_token_ids
from verl.workers.config import HFModelConfig, RolloutConfig
from verl.workers.rollout.replica import TokenOutput, get_rollout_replica_class

@@ -297,9 +298,9 @@ async def apply_chat_template(
return_tensors="pt",
do_sample_frames=False,
)
prompt_ids = model_inputs.pop("input_ids").squeeze(0).tolist()
prompt_ids = normalize_token_ids(model_inputs.pop("input_ids"))
else:
prompt_ids = await self.loop.run_in_executor(
tokenized_prompt = await self.loop.run_in_executor(
None,
lambda: self.tokenizer.apply_chat_template(
messages,
@@ -309,6 +310,7 @@
**self.apply_chat_template_kwargs,
),
)
prompt_ids = normalize_token_ids(tokenized_prompt)

if remove_system_prompt:
prompt_ids = prompt_ids[len(self.system_prompt) :]
@@ -19,6 +19,7 @@
from verl.experimental.agent_loop import AgentLoopBase
from verl.experimental.agent_loop.agent_loop import AgentLoopOutput, register
from verl.utils.profiler import simple_timer
from verl.utils.tokenizer import normalize_token_ids

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
@@ -75,12 +76,13 @@ def get_prompt_ids():
videos=videos,
)
else:
prompt_ids = await self.loop.run_in_executor(
tokenized_prompt = await self.loop.run_in_executor(
None,
lambda: self.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, **self.apply_chat_template_kwargs
),
)
prompt_ids = normalize_token_ids(tokenized_prompt)
else:
if output.extra_fields.get("is_cancel", False):
# Resume the paused sample,
@@ -21,6 +21,7 @@
from vllm.inputs import TokensPrompt
from vllm.outputs import RequestOutput

from verl.utils.tokenizer import normalize_token_ids
from verl.workers.config import HFModelConfig, RolloutConfig
from verl.workers.rollout.replica import RolloutMode
from verl.workers.rollout.vllm_rollout.vllm_async_server import (
@@ -72,6 +73,7 @@ async def _generate_step(
image_data: Optional[list[Any]] = None,
video_data: Optional[list[Any]] = None,
):
prompt_ids = normalize_token_ids(prompt_ids)
max_tokens = self.config.max_model_len - len(prompt_ids)
sampling_params["logprobs"] = 1
sampling_params.setdefault("repetition_penalty", self.config.get("repetition_penalty", 1.0))
5 changes: 4 additions & 1 deletion verl/experimental/vla/models/register_vla_models.py
@@ -15,7 +15,9 @@

"""Utility helpers to register custom VLA models with Hugging Face Auto classes."""

from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor
from transformers import AutoConfig, AutoImageProcessor, AutoProcessor

from verl.utils.transformers_compat import get_auto_model_for_vision2seq

from .openvla_oft.configuration_prismatic import OpenVLAConfig
from .openvla_oft.modeling_prismatic import OpenVLAForActionPrediction
@@ -26,6 +28,7 @@
"openvla_oft": False,
"pi0_torch": False,
}
AutoModelForVision2Seq = get_auto_model_for_vision2seq()


def register_openvla_oft() -> None:
20 changes: 6 additions & 14 deletions verl/model_merger/base_model_merger.py
@@ -28,6 +28,9 @@
)

from verl.utils import hf_processor, hf_tokenizer
from verl.utils.transformers_compat import get_auto_model_for_vision2seq

AutoModelForVision2Seq = get_auto_model_for_vision2seq()


def parse_args():
@@ -201,20 +204,9 @@ def get_transformers_auto_model_class(self):
case "AutoModelForTokenClassification":
return AutoModelForTokenClassification
case "AutoModelForVision2Seq":
# Handle different transformers versions for Vision2Seq models
import transformers
from packaging import version

if version.parse(transformers.__version__) >= version.parse("4.54.0"):
# transformers >= 4.54.0 uses AutoModelForImageTextToText
from transformers import AutoModelForImageTextToText

return AutoModelForImageTextToText
else:
# transformers < 4.54.0 uses AutoModelForVision2Seq
from transformers import AutoModelForVision2Seq

return AutoModelForVision2Seq
return AutoModelForVision2Seq
case "AutoModelForImageTextToText":
return AutoModelForVision2Seq
case _:
raise NotImplementedError(f"Unknown auto class {auto_class}")
else:
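
The get_auto_model_for_vision2seq helper imported above is not shown in this diff; judging from the version check it replaces (previously inlined here and in fsdp_checkpoint_manager.py), a sketch of verl/utils/transformers_compat.py might look like the following — the actual helper may handle additional cases:

def get_auto_model_for_vision2seq():
    """Return the vision-to-text auto class appropriate for the installed transformers."""
    import transformers
    from packaging import version

    if version.parse(transformers.__version__) >= version.parse("4.54.0"):
        # transformers >= 4.54.0 exposes AutoModelForImageTextToText; presumably the direct
        # AutoModelForVision2Seq import goes away in 5.x, hence this indirection.
        from transformers import AutoModelForImageTextToText

        return AutoModelForImageTextToText
    # Older transformers only ships AutoModelForVision2Seq.
    from transformers import AutoModelForVision2Seq

    return AutoModelForVision2Seq
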
4 changes: 2 additions & 2 deletions verl/utils/__init__.py
@@ -15,11 +15,11 @@
from . import config, tokenizer
from .config import omega_conf_to_dataclass, validate_config
from .groupwise import as_torch_index, group_mean_std
from .tokenizer import hf_processor, hf_tokenizer
from .tokenizer import hf_processor, hf_tokenizer, normalize_token_ids

__all__ = (
tokenizer.__all__
+ config.__all__
+ ["hf_processor", "hf_tokenizer", "omega_conf_to_dataclass", "validate_config"]
+ ["hf_processor", "hf_tokenizer", "normalize_token_ids", "omega_conf_to_dataclass", "validate_config"]
+ ["as_torch_index", "group_mean_std"]
)
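
For illustration only (not part of the diff), the new export can be used as a drop-in wrapper around chat-template tokenization; the model id below is a placeholder:

from verl.utils import hf_tokenizer, normalize_token_ids

tokenizer = hf_tokenizer("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model id
ids = normalize_token_ids(
    tokenizer.apply_chat_template(
        [{"role": "user", "content": "hello"}], add_generation_prompt=True, tokenize=True
    )
)
assert isinstance(ids, list) and all(isinstance(t, int) for t in ids)
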
22 changes: 13 additions & 9 deletions verl/utils/chat_template.py
@@ -2,6 +2,8 @@
import logging
import os

from verl.utils.tokenizer import normalize_token_ids

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))

@@ -17,28 +19,30 @@ def initialize_system_prompt(tokenizer, **apply_chat_template_kwargs) -> list[in
Returns:
List of token IDs for the system prompt, or empty list if not supported
"""
token1 = tokenizer.apply_chat_template(
[{"role": "user", "content": ""}], add_generation_prompt=False, tokenize=True
token1 = normalize_token_ids(
tokenizer.apply_chat_template([{"role": "user", "content": ""}], add_generation_prompt=False, tokenize=True)
)
token2 = tokenizer.apply_chat_template(
[{"role": "user", "content": ""}] * 2, add_generation_prompt=False, tokenize=True
token2 = normalize_token_ids(
tokenizer.apply_chat_template([{"role": "user", "content": ""}] * 2, add_generation_prompt=False, tokenize=True)
)
# get system prompt tokens
system_prompt = token1[: -(len(token2) - len(token1))]
return system_prompt


def extract_system_prompt_and_generation(tokenizer):
token1 = tokenizer.apply_chat_template(
[{"role": "user", "content": ""}], add_generation_prompt=False, tokenize=True
token1 = normalize_token_ids(
tokenizer.apply_chat_template([{"role": "user", "content": ""}], add_generation_prompt=False, tokenize=True)
)
token2 = tokenizer.apply_chat_template(
[{"role": "user", "content": ""}] * 2, add_generation_prompt=False, tokenize=True
token2 = normalize_token_ids(
tokenizer.apply_chat_template([{"role": "user", "content": ""}] * 2, add_generation_prompt=False, tokenize=True)
)
# get system prompt tokens
system_prompt = token1[: -(len(token2) - len(token1))]
# get generate prompt tokens
token3 = tokenizer.apply_chat_template([{"role": "user", "content": ""}], add_generation_prompt=True, tokenize=True)
token3 = normalize_token_ids(
tokenizer.apply_chat_template([{"role": "user", "content": ""}], add_generation_prompt=True, tokenize=True)
)
generate_prompt = token3[len(token1) :]

return system_prompt, generate_prompt
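
A hypothetical worked example of the slicing above (token ids are made up): the second conversation repeats the same empty user turn, so the length difference between the two encodings is exactly one user block, and trimming that many tokens from the end of token1 leaves the system-prompt prefix.

token1 = [5, 6, 11, 12, 13]              # system prefix + one empty user turn (made-up ids)
token2 = [5, 6, 11, 12, 13, 11, 12, 13]  # system prefix + the same user turn twice
token3 = [5, 6, 11, 12, 13, 21, 22]      # same prompt with add_generation_prompt=True
assert token1[: -(len(token2) - len(token1))] == [5, 6]  # -> system_prompt
assert token3[len(token1):] == [21, 22]                  # -> generate_prompt
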
16 changes: 2 additions & 14 deletions verl/utils/checkpoint/fsdp_checkpoint_manager.py
@@ -32,6 +32,7 @@
from verl.utils.fs import copy_to_local, is_non_local, local_mkdir_safe
from verl.utils.fsdp_utils import fsdp_version, get_fsdp_full_state_dict, get_fsdp_state_ctx
from verl.utils.logger import log_with_rank
from verl.utils.transformers_compat import get_auto_model_for_vision2seq

from .checkpoint_manager import BaseCheckpointManager

@@ -318,20 +319,7 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i

auto_model_cls = AutoModelForCausalLM
elif "ForConditionalGeneration" in model_config.architectures[0]:
# Handle different transformers versions for Vision2Seq models
import transformers
from packaging import version

if version.parse(transformers.__version__) >= version.parse("4.54.0"):
# transformers >= 4.54.0 uses AutoModelForImageTextToText
from transformers import AutoModelForImageTextToText

auto_model_cls = AutoModelForImageTextToText
else:
# transformers < 4.54.0 uses AutoModelForVision2Seq
from transformers import AutoModelForVision2Seq

auto_model_cls = AutoModelForVision2Seq
auto_model_cls = get_auto_model_for_vision2seq()
else:
raise NotImplementedError(f"Unknown architecture {model_config['architectures']}")

11 changes: 9 additions & 2 deletions verl/utils/dataset/rl_dataset.py
@@ -32,6 +32,7 @@
from transformers import PreTrainedTokenizer, ProcessorMixin

from verl.utils.import_utils import load_extern_object
from verl.utils.tokenizer import normalize_token_ids

logger = logging.getLogger(__name__)

@@ -243,9 +244,15 @@ def doc2len(doc) -> int:
if self.tool_schemas is not None:
apply_kwargs["tools"] = self.tool_schemas

return len(
tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True, **apply_kwargs)
# Keep explicit tokenization to avoid transformers version default changes.
apply_kwargs.pop("tokenize", None)
apply_kwargs.pop("return_dict", None)
apply_kwargs.pop("return_tensors", None)

tokenized_prompt = tokenizer.apply_chat_template(
doc[prompt_key], add_generation_prompt=True, tokenize=True, **apply_kwargs
)
return len(normalize_token_ids(tokenized_prompt))
except Exception:
print("Error processing one of the samples, skipping...")
traceback.print_exc()