diff --git a/requirements.txt b/requirements.txt
index 6a051b8e458..11960721e58 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,7 +15,7 @@ pre-commit
 ray[default]
 tensordict>=0.8.0,<=0.10.0,!=0.9.0
 torchdata
-transformers<5.0.0
+transformers
 # vllm==0.8.4
 wandb
 packaging>=20.0
diff --git a/scripts/legacy_model_merger.py b/scripts/legacy_model_merger.py
index a6da5072df0..f8290997d92 100644
--- a/scripts/legacy_model_merger.py
+++ b/scripts/legacy_model_merger.py
@@ -55,7 +55,6 @@
     AutoConfig,
     AutoModelForCausalLM,
     AutoModelForTokenClassification,
-    AutoModelForVision2Seq,
     GenerationConfig,
     PretrainedConfig,
 )
@@ -69,6 +68,9 @@
 from tqdm import tqdm

 from verl.utils import hf_processor, hf_tokenizer
+from verl.utils.transformers_compat import get_auto_model_for_vision2seq
+
+AutoModelForVision2Seq = get_auto_model_for_vision2seq()


 @dataclass
diff --git a/tests/experimental/reward_loop/test_reward_model_disrm.py b/tests/experimental/reward_loop/test_reward_model_disrm.py
index cc91dad10af..2e42be92ab1 100644
--- a/tests/experimental/reward_loop/test_reward_model_disrm.py
+++ b/tests/experimental/reward_loop/test_reward_model_disrm.py
@@ -21,6 +21,7 @@
 from verl.protocol import DataProto
 from verl.utils import hf_tokenizer
 from verl.utils.model import compute_position_id_with_mask
+from verl.utils.tokenizer import normalize_token_ids


 def create_data_samples(tokenizer) -> DataProto:
@@ -63,8 +64,8 @@ def create_data_samples(tokenizer) -> DataProto:
     pad_token_id = tokenizer.pad_token_id
     prompts, responses, input_ids, attention_masks = [], [], [], []
     for conv in convs:
-        prompt_tokens = tokenizer.apply_chat_template(conv[:1], tokenize=True)
-        response_tokens = tokenizer.apply_chat_template(conv, tokenize=True)[len(prompt_tokens) :]
+        prompt_tokens = normalize_token_ids(tokenizer.apply_chat_template(conv[:1], tokenize=True))
+        response_tokens = normalize_token_ids(tokenizer.apply_chat_template(conv, tokenize=True))[len(prompt_tokens) :]

         padded_prompt = [pad_token_id] * (prompt_length - len(prompt_tokens)) + prompt_tokens
         padded_response = response_tokens + [pad_token_id] * (response_length - len(response_tokens))
diff --git a/tests/experimental/reward_loop/test_reward_model_genrm.py b/tests/experimental/reward_loop/test_reward_model_genrm.py
index 6decc85b6d8..d86e044d129 100644
--- a/tests/experimental/reward_loop/test_reward_model_genrm.py
+++ b/tests/experimental/reward_loop/test_reward_model_genrm.py
@@ -22,6 +22,7 @@
 from verl.protocol import DataProto
 from verl.utils import hf_tokenizer
 from verl.utils.model import compute_position_id_with_mask
+from verl.utils.tokenizer import normalize_token_ids


 def create_data_samples(tokenizer) -> DataProto:
@@ -64,8 +65,8 @@ def create_data_samples(tokenizer) -> DataProto:
     pad_token_id = tokenizer.pad_token_id
     prompts, responses, input_ids, attention_masks = [], [], [], []
     for conv in convs:
-        prompt_tokens = tokenizer.apply_chat_template(conv[:1], tokenize=True)
-        response_tokens = tokenizer.apply_chat_template(conv, tokenize=True)[len(prompt_tokens) :]
+        prompt_tokens = normalize_token_ids(tokenizer.apply_chat_template(conv[:1], tokenize=True))
+        response_tokens = normalize_token_ids(tokenizer.apply_chat_template(conv, tokenize=True))[len(prompt_tokens) :]

         padded_prompt = [pad_token_id] * (prompt_length - len(prompt_tokens)) + prompt_tokens
         padded_response = response_tokens + [pad_token_id] * (response_length - len(response_tokens))
diff --git a/tests/utils/test_tokenizer_normalize_on_cpu.py b/tests/utils/test_tokenizer_normalize_on_cpu.py
new file mode 100644
index 00000000000..ae871db70e2
--- /dev/null
+++ b/tests/utils/test_tokenizer_normalize_on_cpu.py
@@ -0,0 +1,68 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import pytest
+
+from verl.utils.tokenizer import normalize_token_ids
+
+
+class DummyBatchEncoding:
+    def __init__(self, input_ids):
+        self.input_ids = input_ids
+
+
+class DummyToList:
+    def __init__(self, data):
+        self._data = data
+
+    def tolist(self):
+        return self._data
+
+
+@pytest.mark.parametrize(
+    ("tokenized_output", "expected"),
+    [
+        # transformers v4-style direct token ids
+        ([1, 2, 3], [1, 2, 3]),
+        ((1, 2, 3), [1, 2, 3]),
+        # common list-like outputs with tolist()/ndarray paths
+        (DummyToList([1, 2, 3]), [1, 2, 3]),
+        (np.array([1, 2, 3], dtype=np.int64), [1, 2, 3]),
+        # transformers v5-like mapping / BatchEncoding-style outputs
+        ({"input_ids": [1, 2, 3]}, [1, 2, 3]),
+        ({"input_ids": DummyToList([1, 2, 3])}, [1, 2, 3]),
+        ({"input_ids": [[1, 2, 3]]}, [1, 2, 3]),
+        (DummyBatchEncoding([1, 2, 3]), [1, 2, 3]),
+        (DummyBatchEncoding(DummyToList([[1, 2, 3]])), [1, 2, 3]),
+        # scalar item() support
+        ([np.int64(1), np.int32(2), np.int16(3)], [1, 2, 3]),
+    ],
+)
+def test_normalize_token_ids_valid_outputs(tokenized_output, expected):
+    assert normalize_token_ids(tokenized_output) == expected
+
+
+@pytest.mark.parametrize(
+    "tokenized_output",
+    [
+        "not-token-ids",
+        {"attention_mask": [1, 1, 1]},
+        [[1, 2], [3, 4]],  # ambiguous batched ids should fail fast
+        [1, object(), 3],
+    ],
+)
+def test_normalize_token_ids_invalid_outputs(tokenized_output):
+    with pytest.raises(TypeError):
+        normalize_token_ids(tokenized_output)
diff --git a/tests/workers/rollout/rollout_vllm/test_vllm_abort.py b/tests/workers/rollout/rollout_vllm/test_vllm_abort.py
index 82034f1e905..cad7cbb5e83 100644
--- a/tests/workers/rollout/rollout_vllm/test_vllm_abort.py
+++ b/tests/workers/rollout/rollout_vllm/test_vllm_abort.py
@@ -63,6 +63,8 @@ def test_vllm_abort():
     print("\n[2] Creating config...")
     from hydra import compose, initialize_config_dir

+    from verl.utils.tokenizer import normalize_token_ids
+
     config_dir = os.path.abspath("verl/verl/trainer/config")
    if not os.path.exists(config_dir):
         config_dir = os.path.abspath("verl/trainer/config")
@@ -121,7 +123,9 @@
     all_prompt_ids = []
     for prompt in prompts[:NUM_PROMPTS]:
         messages = [{"role": "user", "content": prompt}]
-        prompt_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
+        prompt_ids = normalize_token_ids(
+            tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
+        )
         all_prompt_ids.append(prompt_ids)

     print(f"Prepared {NUM_PROMPTS} prompts")
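For orientation, a quick sketch of the shapes normalize_token_ids accepts, mirroring the parametrized cases in test_tokenizer_normalize_on_cpu.py above (requires only numpy and verl):

import numpy as np

from verl.utils.tokenizer import normalize_token_ids

assert normalize_token_ids([1, 2, 3]) == [1, 2, 3]                   # transformers v4-style flat ids
assert normalize_token_ids(np.array([1, 2, 3])) == [1, 2, 3]         # array-likes go through tolist()
assert normalize_token_ids({"input_ids": [[1, 2, 3]]}) == [1, 2, 3]  # v5-style mapping, batch of one
try:
    normalize_token_ids([[1, 2], [3, 4]])  # ambiguous multi-row batch
except TypeError:
    pass  # fails fast instead of guessing which row holds the prompt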
diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py
index 6dd3871c34e..f52ead64570 100644
--- a/verl/experimental/agent_loop/agent_loop.py
+++ b/verl/experimental/agent_loop/agent_loop.py
@@ -45,6 +45,7 @@
     rollout_trace_attr,
     rollout_trace_op,
 )
+from verl.utils.tokenizer import normalize_token_ids
 from verl.workers.config import HFModelConfig, RolloutConfig
 from verl.workers.rollout.replica import TokenOutput, get_rollout_replica_class

@@ -297,9 +298,9 @@ async def apply_chat_template(
                 return_tensors="pt",
                 do_sample_frames=False,
             )
-            prompt_ids = model_inputs.pop("input_ids").squeeze(0).tolist()
+            prompt_ids = normalize_token_ids(model_inputs.pop("input_ids"))
         else:
-            prompt_ids = await self.loop.run_in_executor(
+            tokenized_prompt = await self.loop.run_in_executor(
                 None,
                 lambda: self.tokenizer.apply_chat_template(
                     messages,
@@ -309,6 +310,7 @@ async def apply_chat_template(
                     **self.apply_chat_template_kwargs,
                 ),
             )
+            prompt_ids = normalize_token_ids(tokenized_prompt)

         if remove_system_prompt:
             prompt_ids = prompt_ids[len(self.system_prompt) :]
diff --git a/verl/experimental/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py b/verl/experimental/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py
index 6982184f8f6..92ea23c6f2c 100644
--- a/verl/experimental/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py
+++ b/verl/experimental/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py
@@ -19,6 +19,7 @@
 from verl.experimental.agent_loop import AgentLoopBase
 from verl.experimental.agent_loop.agent_loop import AgentLoopOutput, register
 from verl.utils.profiler import simple_timer
+from verl.utils.tokenizer import normalize_token_ids

 logger = logging.getLogger(__file__)
 logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
@@ -75,12 +76,13 @@ def get_prompt_ids():
                     videos=videos,
                 )
             else:
-                prompt_ids = await self.loop.run_in_executor(
+                tokenized_prompt = await self.loop.run_in_executor(
                     None,
                     lambda: self.tokenizer.apply_chat_template(
                         messages, add_generation_prompt=True, tokenize=True, **self.apply_chat_template_kwargs
                     ),
                 )
+                prompt_ids = normalize_token_ids(tokenized_prompt)
         else:
             if output.extra_fields.get("is_cancel", False):
                 # Resume the paused sample,
diff --git a/verl/experimental/fully_async_policy/vllm_rollout/vllm_async_server.py b/verl/experimental/fully_async_policy/vllm_rollout/vllm_async_server.py
index 46834874c41..d45dbec890f 100644
--- a/verl/experimental/fully_async_policy/vllm_rollout/vllm_async_server.py
+++ b/verl/experimental/fully_async_policy/vllm_rollout/vllm_async_server.py
@@ -21,6 +21,7 @@
 from vllm.inputs import TokensPrompt
 from vllm.outputs import RequestOutput

+from verl.utils.tokenizer import normalize_token_ids
 from verl.workers.config import HFModelConfig, RolloutConfig
 from verl.workers.rollout.replica import RolloutMode
 from verl.workers.rollout.vllm_rollout.vllm_async_server import (
@@ -72,6 +73,7 @@ async def _generate_step(
         image_data: Optional[list[Any]] = None,
         video_data: Optional[list[Any]] = None,
     ):
+        prompt_ids = normalize_token_ids(prompt_ids)
         max_tokens = self.config.max_model_len - len(prompt_ids)
         sampling_params["logprobs"] = 1
         sampling_params.setdefault("repetition_penalty", self.config.get("repetition_penalty", 1.0))
diff --git a/verl/experimental/vla/models/register_vla_models.py b/verl/experimental/vla/models/register_vla_models.py
index 95dcdd76b87..c3d97cf1952 100644
--- a/verl/experimental/vla/models/register_vla_models.py
+++ b/verl/experimental/vla/models/register_vla_models.py
@@ -15,7 +15,9 @@

 """Utility helpers to register custom VLA models with Hugging Face Auto classes."""

-from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor
+from transformers import AutoConfig, AutoImageProcessor, AutoProcessor
+
+from verl.utils.transformers_compat import get_auto_model_for_vision2seq

 from .openvla_oft.configuration_prismatic import OpenVLAConfig
 from .openvla_oft.modeling_prismatic import OpenVLAForActionPrediction
@@ -26,6 +28,7 @@
     "openvla_oft": False,
     "pi0_torch": False,
 }
+AutoModelForVision2Seq = get_auto_model_for_vision2seq()


 def register_openvla_oft() -> None:
diff --git a/verl/model_merger/base_model_merger.py b/verl/model_merger/base_model_merger.py
index 1dc64042d1e..c86e8def33f 100644
--- a/verl/model_merger/base_model_merger.py
+++ b/verl/model_merger/base_model_merger.py
@@ -28,6 +28,9 @@
 )

 from verl.utils import hf_processor, hf_tokenizer
+from verl.utils.transformers_compat import get_auto_model_for_vision2seq
+
+AutoModelForVision2Seq = get_auto_model_for_vision2seq()


 def parse_args():
@@ -201,20 +204,9 @@ def get_transformers_auto_model_class(self):
                 case "AutoModelForTokenClassification":
                     return AutoModelForTokenClassification
                 case "AutoModelForVision2Seq":
-                    # Handle different transformers versions for Vision2Seq models
-                    import transformers
-                    from packaging import version
-
-                    if version.parse(transformers.__version__) >= version.parse("4.54.0"):
-                        # transformers >= 4.54.0 uses AutoModelForImageTextToText
-                        from transformers import AutoModelForImageTextToText
-
-                        return AutoModelForImageTextToText
-                    else:
-                        # transformers < 4.54.0 uses AutoModelForVision2Seq
-                        from transformers import AutoModelForVision2Seq
-
-                        return AutoModelForVision2Seq
+                    return AutoModelForVision2Seq
+                case "AutoModelForImageTextToText":
+                    return AutoModelForVision2Seq
                 case _:
                     raise NotImplementedError(f"Unknown auto class {auto_class}")
         else:
diff --git a/verl/utils/__init__.py b/verl/utils/__init__.py
index bc40ffb32e1..449c14764a7 100644
--- a/verl/utils/__init__.py
+++ b/verl/utils/__init__.py
@@ -15,11 +15,11 @@
 from . import config, tokenizer
 from .config import omega_conf_to_dataclass, validate_config
 from .groupwise import as_torch_index, group_mean_std
-from .tokenizer import hf_processor, hf_tokenizer
+from .tokenizer import hf_processor, hf_tokenizer, normalize_token_ids

 __all__ = (
     tokenizer.__all__
     + config.__all__
-    + ["hf_processor", "hf_tokenizer", "omega_conf_to_dataclass", "validate_config"]
+    + ["hf_processor", "hf_tokenizer", "normalize_token_ids", "omega_conf_to_dataclass", "validate_config"]
     + ["as_torch_index", "group_mean_std"]
 )
diff --git a/verl/utils/chat_template.py b/verl/utils/chat_template.py
index 64300601c58..e5f8d3e9d1d 100644
--- a/verl/utils/chat_template.py
+++ b/verl/utils/chat_template.py
@@ -2,6 +2,8 @@
 import logging
 import os

+from verl.utils.tokenizer import normalize_token_ids
+
 logger = logging.getLogger(__name__)
 logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))

@@ -17,11 +19,11 @@ def initialize_system_prompt(tokenizer, **apply_chat_template_kwargs) -> list[in
     Returns:
         List of token IDs for the system prompt, or empty list if not supported
     """
-    token1 = tokenizer.apply_chat_template(
-        [{"role": "user", "content": ""}], add_generation_prompt=False, tokenize=True
+    token1 = normalize_token_ids(
+        tokenizer.apply_chat_template([{"role": "user", "content": ""}], add_generation_prompt=False, tokenize=True)
     )
-    token2 = tokenizer.apply_chat_template(
-        [{"role": "user", "content": ""}] * 2, add_generation_prompt=False, tokenize=True
+    token2 = normalize_token_ids(
+        tokenizer.apply_chat_template([{"role": "user", "content": ""}] * 2, add_generation_prompt=False, tokenize=True)
     )
     # get system prompt tokens
     system_prompt = token1[: -(len(token2) - len(token1))]
@@ -29,16 +31,18 @@


 def extract_system_prompt_and_generation(tokenizer):
-    token1 = tokenizer.apply_chat_template(
-        [{"role": "user", "content": ""}], add_generation_prompt=False, tokenize=True
+    token1 = normalize_token_ids(
+        tokenizer.apply_chat_template([{"role": "user", "content": ""}], add_generation_prompt=False, tokenize=True)
     )
-    token2 = tokenizer.apply_chat_template(
-        [{"role": "user", "content": ""}] * 2, add_generation_prompt=False, tokenize=True
+    token2 = normalize_token_ids(
+        tokenizer.apply_chat_template([{"role": "user", "content": ""}] * 2, add_generation_prompt=False, tokenize=True)
     )
     # get system prompt tokens
     system_prompt = token1[: -(len(token2) - len(token1))]
     # get generate prompt tokens
-    token3 = tokenizer.apply_chat_template([{"role": "user", "content": ""}], add_generation_prompt=True, tokenize=True)
+    token3 = normalize_token_ids(
+        tokenizer.apply_chat_template([{"role": "user", "content": ""}], add_generation_prompt=True, tokenize=True)
+    )
     generate_prompt = token3[len(token1) :]
     return system_prompt, generate_prompt
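The length-difference trick in initialize_system_prompt is easiest to see with concrete numbers. A worked sketch, using hypothetical token ids for a template that emits a system preamble once plus a fixed wrapper per user turn:

# Hypothetical ids: [9, 9] is the system preamble, [5, 6] wraps one empty user turn.
token1 = [9, 9, 5, 6]        # one empty user message
token2 = [9, 9, 5, 6, 5, 6]  # the same message twice

# The second turn adds len(token2) - len(token1) == 2 tokens, so whatever
# precedes that suffix in token1 is the per-conversation system prompt.
system_prompt = token1[: -(len(token2) - len(token1))]
assert system_prompt == [9, 9]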
diff --git a/verl/utils/checkpoint/fsdp_checkpoint_manager.py b/verl/utils/checkpoint/fsdp_checkpoint_manager.py
index 1730bd84542..245a0ce5340 100644
--- a/verl/utils/checkpoint/fsdp_checkpoint_manager.py
+++ b/verl/utils/checkpoint/fsdp_checkpoint_manager.py
@@ -32,6 +32,7 @@
 from verl.utils.fs import copy_to_local, is_non_local, local_mkdir_safe
 from verl.utils.fsdp_utils import fsdp_version, get_fsdp_full_state_dict, get_fsdp_state_ctx
 from verl.utils.logger import log_with_rank
+from verl.utils.transformers_compat import get_auto_model_for_vision2seq

 from .checkpoint_manager import BaseCheckpointManager

@@ -318,20 +319,7 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
                     auto_model_cls = AutoModelForCausalLM
                 elif "ForConditionalGeneration" in model_config.architectures[0]:
-                    # Handle different transformers versions for Vision2Seq models
-                    import transformers
-                    from packaging import version
-
-                    if version.parse(transformers.__version__) >= version.parse("4.54.0"):
-                        # transformers >= 4.54.0 uses AutoModelForImageTextToText
-                        from transformers import AutoModelForImageTextToText
-
-                        auto_model_cls = AutoModelForImageTextToText
-                    else:
-                        # transformers < 4.54.0 uses AutoModelForVision2Seq
-                        from transformers import AutoModelForVision2Seq
-
-                        auto_model_cls = AutoModelForVision2Seq
+                    auto_model_cls = get_auto_model_for_vision2seq()
                 else:
                     raise NotImplementedError(f"Unknown architecture {model_config['architectures']}")
diff --git a/verl/utils/dataset/rl_dataset.py b/verl/utils/dataset/rl_dataset.py
index 5fc327961f0..fc2e987af50 100644
--- a/verl/utils/dataset/rl_dataset.py
+++ b/verl/utils/dataset/rl_dataset.py
@@ -32,6 +32,7 @@
 from transformers import PreTrainedTokenizer, ProcessorMixin

 from verl.utils.import_utils import load_extern_object
+from verl.utils.tokenizer import normalize_token_ids

 logger = logging.getLogger(__name__)

@@ -243,9 +244,15 @@ def doc2len(doc) -> int:
                 if self.tool_schemas is not None:
                     apply_kwargs["tools"] = self.tool_schemas

-                return len(
-                    tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True, **apply_kwargs)
+                # Keep explicit tokenization to avoid transformers version default changes.
+                apply_kwargs.pop("tokenize", None)
+                apply_kwargs.pop("return_dict", None)
+                apply_kwargs.pop("return_tensors", None)
+
+                tokenized_prompt = tokenizer.apply_chat_template(
+                    doc[prompt_key], add_generation_prompt=True, tokenize=True, **apply_kwargs
                 )
+                return len(normalize_token_ids(tokenized_prompt))
             except Exception:
                 print("Error processing one of the samples, skipping...")
                 traceback.print_exc()
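The doc2len change above pins every apply_chat_template knob whose default shifted between transformers majors. A minimal sketch of the pattern, assuming a tokenizer and a messages list are already in scope and a hypothetical caller-supplied kwargs dict:

from verl.utils.tokenizer import normalize_token_ids

# Hypothetical kwargs as they might arrive from dataset config; under
# transformers v5 a stray return_dict=True silently changes the return type.
apply_kwargs = {"return_dict": True}

# Pin tokenization behavior regardless of what the caller passed in.
for key in ("tokenize", "return_dict", "return_tensors"):
    apply_kwargs.pop(key, None)

tokenized = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, **apply_kwargs)
prompt_len = len(normalize_token_ids(tokenized))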
diff --git a/verl/utils/model.py b/verl/utils/model.py
index a59c4c32962..a67a9d8479c 100644
--- a/verl/utils/model.py
+++ b/verl/utils/model.py
@@ -33,7 +33,6 @@
     AutoModelForImageTextToText,
     AutoModelForSequenceClassification,
     AutoModelForTokenClassification,
-    AutoModelForVision2Seq,
     GenerationConfig,
     MistralForSequenceClassification,
     PretrainedConfig,
@@ -43,6 +42,9 @@

 from verl.models.registry import ModelRegistry
 from verl.utils.import_utils import is_trl_available
+from verl.utils.transformers_compat import get_auto_model_for_vision2seq
+
+AutoModelForVision2Seq = get_auto_model_for_vision2seq()


 class LambdaLayer(nn.Module):
@@ -619,7 +621,7 @@ def can_generate(self):


 def load_valuehead_model(local_path, torch_dtype, model_config, trust_remote_code):
-    from transformers import AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq
+    from transformers import AutoModelForCausalLM, AutoModelForTokenClassification

     try:
         model = AutoModelForTokenClassification.from_pretrained(
diff --git a/verl/utils/tokenizer.py b/verl/utils/tokenizer.py
index 861fd3a5d17..5a99f1e695c 100644
--- a/verl/utils/tokenizer.py
+++ b/verl/utils/tokenizer.py
@@ -16,7 +16,44 @@
 import types
 import warnings

-__all__ = ["hf_tokenizer", "hf_processor"]
+__all__ = ["hf_tokenizer", "hf_processor", "normalize_token_ids"]
+
+
+def normalize_token_ids(tokenized_output) -> list[int]:
+    """Normalize tokenizer outputs into a flat ``list[int]``.
+
+    This handles Transformers 4/5 differences where ``apply_chat_template(tokenize=True)``
+    may return either ``list[int]`` or a ``BatchEncoding``/mapping with ``input_ids``.
+    """
+    token_ids = tokenized_output
+    if isinstance(tokenized_output, dict):
+        if "input_ids" in tokenized_output:
+            token_ids = tokenized_output["input_ids"]
+    elif hasattr(tokenized_output, "input_ids"):
+        token_ids = tokenized_output.input_ids
+
+    if hasattr(token_ids, "tolist"):
+        token_ids = token_ids.tolist()
+
+    if isinstance(token_ids, tuple):
+        token_ids = list(token_ids)
+
+    if isinstance(token_ids, list) and len(token_ids) == 1 and isinstance(token_ids[0], list | tuple):
+        token_ids = list(token_ids[0])
+
+    if not isinstance(token_ids, list):
+        raise TypeError(f"token_ids must be list-like token ids, got {type(token_ids).__name__}: {token_ids!r}")
+
+    normalized_ids = []
+    for token_id in token_ids:
+        if hasattr(token_id, "item"):
+            token_id = token_id.item()
+        try:
+            normalized_ids.append(int(token_id))
+        except (TypeError, ValueError) as e:
+            raise TypeError(f"token_id must be int-convertible, got {type(token_id).__name__}: {token_id!r}") from e
+    return normalized_ids


 def set_pad_token_id(tokenizer):
@@ -71,12 +108,20 @@ def hf_processor(name_or_path, **kwargs):
         name_or_path (str): The name of the processor.

     Returns:
-        transformers.ProcessorMixin: The pretrained processor.
+        Optional[transformers.ProcessorMixin]: The pretrained multimodal processor.
+        Returns ``None`` for text-only models (including AutoProcessor fallbacks to
+        tokenizer backends such as ``TokenizersBackend``).
     """
-    from transformers import AutoConfig, AutoProcessor
+    from transformers import AutoConfig, AutoProcessor, PreTrainedTokenizerBase

     try:
         processor = AutoProcessor.from_pretrained(name_or_path, **kwargs)
+        # In newer transformers, AutoProcessor may legitimately fall back to a
+        # tokenizer backend (e.g. TokenizersBackend) for text-only models.
+        # Treat it as "no multimodal processor" and let callers use hf_tokenizer.
+        if isinstance(processor, PreTrainedTokenizerBase):
+            return None
+
         config = AutoConfig.from_pretrained(name_or_path, **kwargs)

         # Bind vlm model's get_rope_index method to processor
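A usage sketch of the revised hf_processor contract; the model name is a placeholder for any text-only checkpoint:

from verl.utils import hf_processor, hf_tokenizer

model_path = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder: any text-only model

# Under newer transformers, AutoProcessor can hand back a tokenizer backend
# for text-only models; hf_processor now reports that case as None.
processor = hf_processor(model_path)
if processor is None:
    tokenizer = hf_tokenizer(model_path)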
diff --git a/verl/utils/transformers_compat.py b/verl/utils/transformers_compat.py
index cfcb9f4dda4..9a03c658512 100644
--- a/verl/utils/transformers_compat.py
+++ b/verl/utils/transformers_compat.py
@@ -55,3 +55,20 @@ def is_transformers_version_in_range(min_version: Optional[str] = None, max_vers
         upper_bound_check = transformers_version <= version.parse(max_version)

     return lower_bound_check and upper_bound_check
+
+
+@lru_cache
+def get_auto_model_for_vision2seq():
+    """Return the available VL auto model class across transformers versions."""
+
+    try:
+        # Prefer the newer class when available. In transformers 4.x this class has
+        # a broader mapping than AutoModelForVision2Seq, and AutoModelForVision2Seq
+        # is deprecated for removal in v5.
+        from transformers import AutoModelForImageTextToText
+    except ImportError:
+        from transformers import AutoModelForVision2Seq
+
+        return AutoModelForVision2Seq
+
+    return AutoModelForImageTextToText
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index 24b46e13c1b..b6fc052db3d 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -86,6 +86,7 @@
 # QAT support
 from verl.utils.qat import apply_qat, enable_qat_fuse
 from verl.utils.ray_utils import get_event_loop
+from verl.utils.transformers_compat import get_auto_model_for_vision2seq
 from verl.workers.config import FSDPCriticConfig, FSDPEngineConfig, HFModelConfig, RolloutConfig
 from verl.workers.config.optimizer import build_optimizer
 from verl.workers.rollout import get_rollout_class
@@ -348,12 +349,13 @@ def _build_model_optimizer(
             AutoModel,
             AutoModelForCausalLM,
             AutoModelForImageTextToText,
-            AutoModelForVision2Seq,
         )

         from verl.utils.model import get_generation_config, print_model_size, update_model_config
         from verl.utils.torch_dtypes import PrecisionType

+        AutoModelForVision2Seq = get_auto_model_for_vision2seq()
+
         assert role in ["actor", "ref"]

         # TiledMLP requires FSDP2 for correct gradient computation
diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
index 3fe3aef3216..9ac052a773e 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
@@ -39,6 +39,7 @@
 from verl.utils.device import get_resource_name, get_visible_devices_keyword
 from verl.utils.net_utils import get_free_port, is_valid_ipv6_address
 from verl.utils.profiler import DistProfiler, build_vllm_profiler_args
+from verl.utils.tokenizer import normalize_token_ids
 from verl.utils.vllm.vllm_fp8_utils import apply_vllm_fp8_patches
 from verl.workers.config import HFModelConfig, RolloutConfig
 from verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput
@@ -510,6 +511,8 @@ async def generate(
         priority: int = 0,
     ) -> TokenOutput:
         """Generate sequence with token-in-token-out."""
+        prompt_ids = normalize_token_ids(prompt_ids)
+
         # Calculate the maximum possible new tokens based on available context space
         # This serves as a safety upper bound
         max_possible_tokens = self.config.max_model_len - len(prompt_ids)
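Taken together, callers resolve the vision-language auto class once through the shim and then use it like any other auto class. A sketch with a placeholder checkpoint path:

from verl.utils.transformers_compat import get_auto_model_for_vision2seq

# Resolves to AutoModelForImageTextToText when available (transformers >= 4.54
# and 5.x) and falls back to the legacy AutoModelForVision2Seq otherwise.
AutoModelForVision2Seq = get_auto_model_for_vision2seq()
model = AutoModelForVision2Seq.from_pretrained("/path/to/vl-checkpoint")  # placeholder path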