From 28122f0cd3a2e798706d82029a3cb65a1f8bf1bc Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 20 Mar 2026 11:27:28 +0100 Subject: [PATCH 01/47] Support multimodal tool responses in environment_factory for VLM training --- examples/scripts/openenv/browsergym.py | 731 +++++++------------------ trl/trainer/grpo_trainer.py | 38 +- 2 files changed, 245 insertions(+), 524 deletions(-) diff --git a/examples/scripts/openenv/browsergym.py b/examples/scripts/openenv/browsergym.py index f48c31236d0..9c149089d86 100644 --- a/examples/scripts/openenv/browsergym.py +++ b/examples/scripts/openenv/browsergym.py @@ -23,58 +23,28 @@ # /// """ -Simple script to run GRPO training with OpenEnv's BrowserGym environment and vLLM. +GRPO training with OpenEnv's BrowserGym environment for VLMs (Vision Language Models). -This example automatically detects and uses vision capabilities when VLM models are used. -Screenshots from BrowserGym are collected and passed to the model during training. The GRPO -trainer auto-detects multimodal support by checking for images in the rollout data. - -Setup (Option A - Install from HF Space, recommended): +This script uses `environment_factory` with multimodal tool responses: each tool action +returns a screenshot (PIL Image) alongside the accessibility tree text, allowing the VLM +to see the page visually after each action. +Setup: ```sh -uv pip install git+https://huggingface.co/spaces/openenv/browsergym_env +pip install "openenv-browsergym @ git+https://huggingface.co/spaces/openenv/browsergym_env" ``` -Setup (Option B - Clone OpenEnv repo, for development): - +Usage: ```sh -git clone https://github.com/meta-pytorch/OpenEnv.git -cd OpenEnv/envs/browsergym_env -uv pip install -e . -``` +# Without vLLM (default, 1 GPU) +python examples/scripts/openenv/browsergym.py -# Option 1: HF Spaces + Colocated vLLM (1 GPU required) -```sh -python examples/scripts/openenv/browsergym.py --vllm-mode colocate -``` - -# Option 2: HF Spaces + Separate vLLM server (2 GPUs required) +# With vLLM colocate (1 GPU, requires vLLM support for the model) +python examples/scripts/openenv/browsergym.py --use-vllm -# Spin up vLLM server (Terminal 1) -```sh -CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3-VL-2B-Instruct --host 0.0.0.0 --port 8001 -``` - -# Run training (Terminal 2) -```sh -CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/browsergym.py --vllm-mode server --vllm-server-url http://localhost:8001 -``` - -# Option 3: Local + Colocated vLLM (1 GPU required) - -# Build and start the environment only if using --env-mode docker-local -```sh -cd OpenEnv -docker build -t openenv-base:latest -f src/core/containers/images/Dockerfile . -docker build -t browsergym-env:latest -f src/envs/browsergym_env/server/Dockerfile . -docker run -d -p 8001:8001 \ - -e BROWSERGYM_BENCHMARK="miniwob" \ - -e BROWSERGYM_TASK_NAME="click-test" \ - browsergym-env:latest -``` - -```sh -python examples/scripts/openenv/browsergym.py --env-mode docker-local --vllm-mode colocate +# With vLLM server (2 GPUs) +CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3.5-2B --host 0.0.0.0 --port 8000 +CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/browsergym.py --use-vllm --vllm-mode server ``` """ @@ -88,190 +58,28 @@ from browsergym_env import BrowserGymAction, BrowserGymEnv from datasets import Dataset from PIL import Image -from transformers import AutoTokenizer from trl import GRPOConfig, GRPOTrainer -from trl.experimental.openenv import generate_rollout_completions def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Run GRPO training for BrowserGym MiniWoB using OpenEnv environment.") - parser.add_argument( - "--tokenizer-id", - default="Qwen/Qwen3-VL-2B-Instruct", - help="Model identifier used to load the tokenizer.", - ) - parser.add_argument( - "--model-id", - default="Qwen/Qwen3-VL-2B-Instruct", - help="Model identifier passed to GRPOTrainer for fine-tuning.", - ) - parser.add_argument( - "--env-host", - type=str, - default="https://openenv-browsergym-env.hf.space", - help="Host for the BrowserGym environment.", - ) - parser.add_argument("--env-port", type=int, default=8001, help="Port for the BrowserGym environment.") - parser.add_argument( - "--env-mode", - choices=["docker-local", "docker-image", "docker-hub", "space"], - default="space", - help="Where to run the environment: 'local' to launch it, 'docker-local' if already running locally, 'docker-image' to run from a Docker image, 'docker-hub' to run from Docker Hub, or 'space' to use a remote Space URL.", - ) - parser.add_argument( - "--env-image", type=str, default="browsergym-env:latest", help="Docker image for the BrowserGym environment." - ) - parser.add_argument( - "--benchmark", - default="miniwob", - help="BrowserGym benchmark to use (miniwob, webarena, etc.).", - ) - parser.add_argument( - "--task-name", - default="click-test", - help="Specific task within the benchmark (e.g., click-test, click-button).", - ) - parser.add_argument( - "--dataset-prompt", - default="Complete the web task successfully.", - help="Prompt text used to seed the training dataset.", - ) - parser.add_argument( - "--dataset-size", - type=int, - default=1000, - help="Number of entries to include in the synthetic training dataset.", - ) - parser.add_argument( - "--max-steps", - type=int, - default=10, - help="Maximum number of steps per episode.", - ) - parser.add_argument( - "--max-new-tokens", - type=int, - default=32, - help="Maximum number of new tokens to request from vLLM for each action.", - ) - parser.add_argument( - "--temperature", - type=float, - default=0.7, - help="Sampling temperature used during rollout generation.", - ) - parser.add_argument( - "--top-k", - type=int, - default=50, - help="Top-k sampling parameter forwarded to vLLM.", - ) - parser.add_argument( - "--top-p", - type=float, - default=None, - help="Optional top-p sampling parameter forwarded to vLLM.", - ) - parser.add_argument( - "--image-size", - type=int, - default=512, - help="Resize screenshots to this size (preserving aspect ratio) to reduce memory usage. Set to 0 to disable resizing.", - ) - parser.add_argument( - "--learning-rate", - type=float, - default=5e-6, - help="Learning rate for GRPO training.", - ) - parser.add_argument( - "--weight-decay", - type=float, - default=0.0, - help="Weight decay applied during optimization.", - ) - parser.add_argument( - "--gradient-accumulation-steps", - type=int, - default=32, - help="Gradient accumulation steps for GRPO training.", - ) - parser.add_argument( - "--warmup-steps", - type=int, - default=10, - help="Warmup steps for the scheduler.", - ) - parser.add_argument( - "--per-device-batch-size", - type=int, - default=1, - help="Per-device train batch size.", - ) - parser.add_argument( - "--num-generations", - type=int, - default=4, - help="Number of rollout generations per dataset prompt.", - ) - parser.add_argument( - "--num-epochs", - type=int, - default=1, - help="Number of training epochs.", - ) - parser.add_argument( - "--save-interval", - type=int, - default=50, - help="Interval (in steps) between checkpoint saves.", - ) - parser.add_argument( - "--save-total-limit", - type=int, - default=None, - help="Maximum number of checkpoints to keep.", - ) - parser.add_argument( - "--output-dir", - default=None, - help="Directory where training outputs and checkpoints are stored.", - ) - parser.add_argument( - "--run-name", - default=None, - help="Optional run name for logging systems.", - ) - parser.add_argument( - "--project", - default=None, - help="Optional project identifier for logging systems.", - ) - parser.add_argument( - "--vllm-mode", - choices=("colocate", "server"), - default="colocate", - help="vLLM execution mode: 'colocate' or 'server'.", - ) - parser.add_argument( - "--vllm-server-url", - type=str, - default="http://localhost:8001", - help="URL for the vLLM server (only used when --vllm-mode=server).", - ) - parser.add_argument( - "--logging-steps", - type=int, - default=1, - help="Frequency of logging steps for GRPO training.", - ) - parser.add_argument( - "--debug", - action="store_true", - default=False, - help="Enable verbose debugging output during rollouts.", - ) + parser = argparse.ArgumentParser(description="GRPO training with BrowserGym VLM environment.") + parser.add_argument("--model-id", default="Qwen/Qwen3.5-2B") + parser.add_argument("--space-url", default="https://openenv-browsergym-env.hf.space") + parser.add_argument("--dataset-prompt", default="Complete the web task successfully.") + parser.add_argument("--dataset-size", type=int, default=1000) + parser.add_argument("--max-steps", type=int, default=10) + parser.add_argument("--max-completion-length", type=int, default=1024) + parser.add_argument("--image-size", type=int, default=512, help="Resize screenshots to this size. 0 to disable.") + parser.add_argument("--num-generations", type=int, default=4) + parser.add_argument("--gradient-accumulation-steps", type=int, default=32) + parser.add_argument("--learning-rate", type=float, default=5e-6) + parser.add_argument("--num-epochs", type=int, default=1) + parser.add_argument("--logging-steps", type=int, default=1) + parser.add_argument("--output-dir", default=None) + parser.add_argument("--use-vllm", action="store_true", default=False, help="Enable vLLM for generation.") + parser.add_argument("--vllm-mode", choices=("colocate", "server"), default="colocate") + parser.add_argument("--vllm-server-url", default="http://localhost:8000") return parser.parse_args() @@ -279,323 +87,204 @@ def sanitize_name(name: str) -> str: return name.replace("/", "-") -# --------------------------------------------------------------------------- -# System Prompt -# --------------------------------------------------------------------------- - -SYSTEM_PROMPT = """You control a web browser through BrowserGym actions. -You must complete the given web task by interacting with the page. +SYSTEM_PROMPT = """You control a web browser to complete tasks. -Available actions: -- noop() - Do nothing -- click(bid) - Click element with BrowserGym ID -- fill(bid, text) - Fill input field -- send_keys(text) - Send keyboard input -- scroll(direction) - Scroll up/down +The page structure shows elements as: [bid] element_type 'element_text' +For example: [13] button 'Click Me!' means the element has bid='13'. -Reply with exactly ONE action on a single line, e.g.: -click('123') -fill('456', 'text') -noop() +You will see a screenshot of the page after each action. Use the visual information +along with the page structure to decide your next action. -Do not include explanations or multiple actions.""" +Use the available tools to interact with the page: +- click: Click an element by its bid +- fill: Fill an input field with text +- send_keys: Send keyboard input +- scroll: Scroll the page +- noop: Do nothing +Complete the given task as efficiently as possible.""" -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -def make_user_prompt(goal: str, step_num: int, axtree: str, error: str = "") -> str: - """Create user prompt from observation.""" - prompt_parts = [f"Step {step_num + 1}"] - - if goal: - prompt_parts.append(f"Goal: {goal}") - - if error: - prompt_parts.append(f"Previous action error: {error}") - - # Include accessibility tree (truncated for context) - if axtree: - max_len = 2000 - axtree_truncated = axtree[:max_len] + "..." if len(axtree) > max_len else axtree - prompt_parts.append(f"Page structure:\n{axtree_truncated}") - - prompt_parts.append("What action do you take?") - - return "\n\n".join(prompt_parts) - - -def parse_action(response_text: str) -> str: - """Parse BrowserGym action from model response.""" - # Extract first line that looks like an action - for line in response_text.strip().split("\n"): - line = line.strip() - if "(" in line and ")" in line: - return line - - # Fallback to noop if no valid action found - return "noop()" - - -def rollout_once( - trainer: GRPOTrainer, - env: BrowserGymEnv, - tokenizer: AutoTokenizer, - dataset_prompt: str, - max_steps: int, - image_size: int = 0, - debug: bool = False, -) -> dict[str, list]: - """Run one episode and collect training data.""" - result = env.reset() - observation = result.observation - - prompt_ids: list[int] = [] - completion_ids: list[int] = [] - logprobs: list[float] = [] - step_rewards: list[float] = [] - completion_rewards: list[float] = [] - images: list[Image.Image] = [] # Collect screenshots for VLM - - for step_num in range(max_steps): - if result.done: - break - - # Create prompt from observation - goal = observation.goal or dataset_prompt - axtree = observation.axtree_txt or "" - error = observation.error if observation.last_action_error else "" - - # Collect screenshot if available (for VLM support) - if observation.screenshot is not None: - screenshot_array = np.array(observation.screenshot, dtype=np.uint8) - screenshot_image = Image.fromarray(screenshot_array) - - # Resize to reduce memory if image_size > 0 - if image_size > 0: - # Preserve aspect ratio while resizing - screenshot_image.thumbnail((image_size, image_size), Image.LANCZOS) - print( - f"[DEBUG] Step {step_num + 1}: Collected and resized screenshot from {screenshot_array.shape} to {screenshot_image.size}" - ) - else: - print(f"[DEBUG] Step {step_num + 1}: Collected screenshot, shape={screenshot_array.shape}") - - images.append(screenshot_image) - else: - print(f"[DEBUG] Step {step_num + 1}: No screenshot available") - - user_prompt = make_user_prompt(goal, step_num, axtree, error) - messages = [ - {"role": "system", "content": SYSTEM_PROMPT}, - {"role": "user", "content": user_prompt}, - ] - prompt_text = tokenizer.apply_chat_template( - messages, - add_generation_prompt=True, - tokenize=False, - ) - - # Generate action with vLLM - rollout_outputs = generate_rollout_completions(trainer, [prompt_text])[0] - prompt_ids.extend(rollout_outputs["prompt_ids"]) - completion_ids.extend(rollout_outputs["completion_ids"]) - logprobs.extend(rollout_outputs["logprobs"]) - - completion_text = rollout_outputs.get("text") or tokenizer.decode( - rollout_outputs["completion_ids"], skip_special_tokens=True - ) - - # Parse and execute action - action_str = parse_action(completion_text) - - if debug: - print(f"Step {step_num + 1}: {action_str}") - - # Take action in environment - result = env.step(BrowserGymAction(action_str=action_str)) - observation = result.observation - - # Track rewards - step_reward = float(result.reward or 0.0) - step_rewards.append(step_reward) - - # Reward shaping: success is most important - if result.done and step_reward > 0: - completion_rewards.append(1.0) # Task completed successfully - elif result.done and step_reward == 0: - completion_rewards.append(0.0) # Task failed - else: - completion_rewards.append(step_reward) # Intermediate reward - - # Final reward is based on task completion - final_reward = completion_rewards[-1] if completion_rewards else 0.0 - - result_dict = { - "prompt_ids": prompt_ids, - "completion_ids": completion_ids, - "logprobs": logprobs, - "step_rewards": step_rewards, - "completion_reward": final_reward, - } - - # Include images if available (GRPO trainer will auto-detect VLM support) - if images: - result_dict["images"] = images - - return result_dict - - -# --------------------------------------------------------------------------- -# Rewards -# --------------------------------------------------------------------------- - - -def reward_completion(completions: list[str], **kwargs) -> list[float]: - """Reward for task completion.""" - rewards = kwargs.get("completion_reward") if kwargs else None - if rewards is None: - return [0.0 for _ in completions] - return [float(r) for r in rewards] - - -# --------------------------------------------------------------------------- -# Main entrypoint -# --------------------------------------------------------------------------- +def reward_completion(completions, environments, **kwargs) -> list[float]: + return [env.reward for env in environments] def main() -> None: args = parse_args() - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id) - tokenizer.pad_token = tokenizer.eos_token - - # Select environment mode - if args.env_mode == "docker-local": - env_url = f"http://{args.env_host}:{args.env_port}" - client = BrowserGymEnv(base_url=env_url) - print(f"🌍 Using existing BrowserGym Environment (Docker) at: {env_url}") - elif args.env_mode == "docker-image": - client = BrowserGymEnv.from_docker_image(args.env_image) - print("🌍 Using BrowserGym Environment (Docker) from local Image") - elif args.env_mode == "docker-hub": - client = BrowserGymEnv.from_hub(args.env_image) - print("🌍 Using existing BrowserGym Environment (Docker) from Hub Image") - elif args.env_mode == "space": - env_url = args.env_host - client = BrowserGymEnv(base_url=env_url) - print(f"🌍 Using Hugging Face Space environment at: {env_url}") - else: - raise ValueError(f"Unknown environment mode: {args.env_mode}") - - dataset = Dataset.from_dict({"prompt": [args.dataset_prompt] * args.dataset_size}) - - timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - default_output_dir = Path("outputs") / f"browsergym-grpo-{sanitize_name(args.model_id)}-{timestamp}" - output_dir = Path(args.output_dir or default_output_dir) - - grpo_config = GRPOConfig( - use_vllm=True, - vllm_mode=args.vllm_mode, - vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None, - vllm_gpu_memory_utilization=0.4, - output_dir=str(output_dir), - num_train_epochs=args.num_epochs, - learning_rate=args.learning_rate, - weight_decay=args.weight_decay, - gradient_accumulation_steps=args.gradient_accumulation_steps, - per_device_train_batch_size=args.per_device_batch_size, - warmup_steps=args.warmup_steps, - num_generations=args.num_generations, - generation_batch_size=args.num_generations, # Must be divisible by num_generations - max_completion_length=args.max_new_tokens, - logging_steps=args.logging_steps, - report_to="trackio", - trackio_space_id=f"browsergym-grpo-{sanitize_name(args.model_id)}-{timestamp}", - save_strategy="steps", - save_steps=args.save_interval, - save_total_limit=args.save_total_limit, - temperature=args.temperature, - top_k=args.top_k, - top_p=args.top_p, + space_url = args.space_url + max_steps = args.max_steps + image_size = args.image_size + + dataset = Dataset.from_dict( + { + "prompt": [ + [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": args.dataset_prompt}, + ] + ] + * args.dataset_size + } ) - grpo_config.run_name = args.run_name or f"run-{timestamp}" - grpo_config.project = args.project or f"group-{sanitize_name(args.model_id)}" - - def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: - episode_prompt_ids: list[list[int]] = [] - episode_completion_ids: list[list[int]] = [] - episode_logprobs: list[list[float]] = [] - completion_rewards: list[float] = [] - episode_images: list[list[Image.Image]] = [] - - print(f"\n[DEBUG] rollout_func called with {len(prompts)} prompts") - - for i, prompt_text in enumerate(prompts): - print(f"[DEBUG] Processing prompt {i + 1}/{len(prompts)}") - episode = rollout_once( - trainer=trainer, - env=client, - tokenizer=tokenizer, - dataset_prompt=prompt_text, - max_steps=args.max_steps, - image_size=args.image_size, - debug=args.debug, - ) - episode_prompt_ids.append(episode["prompt_ids"]) - episode_completion_ids.append(episode["completion_ids"]) - episode_logprobs.append(episode["logprobs"]) - completion_rewards.append(episode["completion_reward"]) - - # Collect images if available (for VLM support) - if "images" in episode: - print(f"[DEBUG] Episode {i + 1} has {len(episode['images'])} images") - episode_images.append(episode["images"]) + class BrowserGymVLMEnv: + def __init__(self): + self.client = BrowserGymEnv(base_url=space_url) + self.reward = 0.0 + self.done = False + self._step_count = 0 + + def reset(self, **kwargs) -> str | None: + self.reward = 0.0 + self.done = False + self._step_count = 0 + result = self.client.reset() + self.done = result.done + return self._format_observation(result.observation) + + def click(self, bid: str) -> list: + """Click an element on the page. + + Args: + bid: The BrowserGym ID of the element to click. + + Returns: + The updated page observation with screenshot. + """ + return self._do_action(f"click('{bid}')") + + def fill(self, bid: str, text: str) -> list: + """Fill an input field with text. + + Args: + bid: The BrowserGym ID of the input field. + text: The text to type into the field. + + Returns: + The updated page observation with screenshot. + """ + return self._do_action(f"fill('{bid}', '{text}')") + + def send_keys(self, text: str) -> list: + """Send keyboard input to the page. + + Args: + text: The keyboard input to send. + + Returns: + The updated page observation with screenshot. + """ + return self._do_action(f"send_keys('{text}')") + + def scroll(self, direction: str) -> list: + """Scroll the page. + + Args: + direction: Direction to scroll, either 'up' or 'down'. + + Returns: + The updated page observation with screenshot. + """ + return self._do_action(f"scroll('{direction}')") + + def noop(self) -> list: + """Do nothing and observe the current page state. + + Returns: + The current page observation with screenshot. + """ + return self._do_action("noop()") + + def _do_action(self, action_str: str) -> list: + if self.done: + raise ValueError("Episode is done.") + + self._step_count += 1 + result = self.client.step(BrowserGymAction(action_str=action_str)) + observation = result.observation + step_reward = float(result.reward or 0.0) + self.done = result.done + + if self.done and step_reward > 0: + self.reward = 1.0 + elif self.done: + self.reward = 0.0 else: - print(f"[DEBUG] Episode {i + 1} has NO images") - - result = { - "prompt_ids": episode_prompt_ids, - "completion_ids": episode_completion_ids, - "logprobs": episode_logprobs, - "completion_reward": completion_rewards, - } - - # Include images if any episode had screenshots (GRPO trainer auto-detects VLM) - if episode_images: - result["images"] = episode_images - print(f"[DEBUG] rollout_func returning with images: {len(episode_images)} episodes") - else: - print("[DEBUG] rollout_func returning WITHOUT images") + self.reward = step_reward + + if self._step_count >= max_steps: + self.done = True + + return self._format_observation_multimodal(observation) + + def _format_observation(self, observation) -> str: + """Format initial observation as text (for reset, appended to prompt).""" + parts = [] + if observation.goal: + parts.append(f"Goal: {observation.goal}") + if observation.axtree_txt: + axtree = observation.axtree_txt + if len(axtree) > 2000: + axtree = axtree[:2000] + "..." + parts.append(f"Page structure:\n{axtree}") + return "\n\n".join(parts) if parts else "No observation available." + + def _format_observation_multimodal(self, observation) -> list: + """Format observation as multimodal content blocks (screenshot + text).""" + content = [] + + # Add screenshot if available + if observation.screenshot is not None: + screenshot_array = np.array(observation.screenshot, dtype=np.uint8) + screenshot_image = Image.fromarray(screenshot_array) + if image_size > 0: + screenshot_image.thumbnail((image_size, image_size), Image.LANCZOS) + content.append({"type": "image", "image": screenshot_image}) + + # Add text observation + parts = [] + if observation.goal: + parts.append(f"Goal: {observation.goal}") + if observation.last_action_error and observation.error: + parts.append(f"Error: {observation.error}") + if observation.axtree_txt: + axtree = observation.axtree_txt + if len(axtree) > 2000: + axtree = axtree[:2000] + "..." + parts.append(f"Page structure:\n{axtree}") + text = "\n\n".join(parts) if parts else "No observation available." + content.append({"type": "text", "text": text}) + + return content - return result + timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + default_output_dir = Path("outputs") / f"browsergym-vlm-grpo-{sanitize_name(args.model_id)}-{timestamp}" + output_dir = Path(args.output_dir or default_output_dir) trainer = GRPOTrainer( model=args.model_id, - processing_class=tokenizer, - reward_funcs=[reward_completion], + reward_funcs=reward_completion, train_dataset=dataset, - args=grpo_config, - rollout_func=rollout_func, - ) - - print("=" * 80) - print("Starting GRPO training with BrowserGym environment") - print(f"Benchmark: {args.benchmark}") - print(f"Task: {args.task_name}") - print(f"Model: {args.model_id}") - print(f"Using {args.num_generations} rollouts per dataset prompt") - print(f"Output directory: {output_dir}") - print("=" * 80) - - try: - trainer.train() - print("\nTraining completed successfully!") - finally: - client.close() + args=GRPOConfig( + use_vllm=args.use_vllm, + vllm_mode=args.vllm_mode if args.use_vllm else "colocate", + vllm_server_base_url=args.vllm_server_url if args.use_vllm and args.vllm_mode == "server" else None, + output_dir=str(output_dir), + num_train_epochs=args.num_epochs, + learning_rate=args.learning_rate, + gradient_accumulation_steps=args.gradient_accumulation_steps, + num_generations=args.num_generations, + max_completion_length=args.max_completion_length, + logging_steps=args.logging_steps, + log_completions=True, + report_to="trackio", + trackio_space_id=f"browsergym-vlm-grpo-{sanitize_name(args.model_id)}", + chat_template_kwargs={"enable_thinking": False}, + ), + environment_factory=BrowserGymVLMEnv, + ) + trainer.train() if __name__ == "__main__": diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index f54ba039c7b..ed390e1847c 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1241,6 +1241,15 @@ async def _run_async_funcs(): def _tokenize_prompts(self, prompts: list): """Tokenize prompts and extract images/multimodal fields for generation.""" if is_conversational({"prompt": prompts[0]}): + # When the processor is a VLM processor, normalize string content to content blocks. + # Some VLM processors (e.g., Qwen3.5, Qwen3-VL) iterate over message["content"] + # assuming it's always a list of content blocks, which fails when content is a plain string. + if hasattr(self.processing_class, "image_processor"): + for prompt in prompts: + for message in prompt: + if isinstance(message["content"], str): + message["content"] = [{"type": "text", "text": message["content"]}] + # Extract images from messages for VLM support images = [] has_images = False @@ -1249,7 +1258,7 @@ def _tokenize_prompts(self, prompts: list): for message in prompt: if isinstance(message["content"], list): for part in message["content"]: - if part["type"] == "image": + if isinstance(part, dict) and part.get("type") == "image": prompt_images.append(part["image"]) has_images = True images.append(prompt_images if prompt_images else None) @@ -1445,7 +1454,11 @@ async def _run_async_tools(async_coros): tool_call_results.append((name, result)) for name, result in tool_call_results: - tool_message = {"role": "tool", "name": name, "content": str(result)} + # Support multimodal tool responses: if the tool returns a list of content blocks + # (e.g., [{"type": "image", "image": ...}, {"type": "text", "text": "..."}]), + # pass them through directly so _tokenize_prompts can extract images for VLMs. + content = result if isinstance(result, list) else str(result) + tool_message = {"role": "tool", "name": name, "content": content} prompt_completion_tool.append(tool_message) completions[idx_with_tool].append(tool_message) @@ -1775,7 +1788,26 @@ def _generate_and_score_completions( logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size - num_images = [len(img_list) for img_list in images] if images is not None else None + # For VLMs with tools, re-extract all images from the full conversation. After the tool loop, + # prompts contain both dataset images (from prepare_multimodal_messages) and tool response + # images (from multimodal tool returns in _tool_call_loop). We need all of them for the + # forward pass to compute pixel_values correctly. + if self.tools and hasattr(self.processing_class, "image_processor"): + all_images = [] + has_images = False + for prompt in prompts: + sample_images = [] + for message in prompt: + if isinstance(message.get("content"), list): + for part in message["content"]: + if isinstance(part, dict) and part.get("type") == "image": + sample_images.append(part["image"]) + has_images = True + all_images.append(sample_images if sample_images else None) + if has_images: + images = all_images + + num_images = [len(img_list) if img_list else 0 for img_list in images] if images is not None else None # Get forward_kwargs for models with multimodal inputs if images is not None: From 09ef5d4cb9129fc9a08654b0467929595ffab9f5 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 20 Mar 2026 18:55:31 +0100 Subject: [PATCH 02/47] Fix VLM processor support for tool calling and add CARLA VLM example - Fix _get_tool_suffix_ids: add tokenize=True and unbatch for VLM processors - Fix parse_response: delegate to inner tokenizer for processors (2 places) - Propagate response_schema from processor to tokenizer for tool call parsing - Fix max_position_embeddings fallback for models with composite config (e.g., Qwen3.5) - Add carla_vlm.py: CARLA VLM training with multimodal tool responses (camera images) Co-Authored-By: Claude Opus 4.6 (1M context) --- examples/scripts/openenv/carla_vlm.py | 229 ++++++++++++++++++++++++++ trl/trainer/grpo_trainer.py | 35 +++- 2 files changed, 257 insertions(+), 7 deletions(-) create mode 100644 examples/scripts/openenv/carla_vlm.py diff --git a/examples/scripts/openenv/carla_vlm.py b/examples/scripts/openenv/carla_vlm.py new file mode 100644 index 00000000000..ce5e58fe8bc --- /dev/null +++ b/examples/scripts/openenv/carla_vlm.py @@ -0,0 +1,229 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl", +# "openenv-carla-env @ git+https://huggingface.co/spaces/sergiopaniego/carla_env", +# ] +# /// + + +""" +GRPO training with OpenEnv's CARLA environment for VLMs (Vision Language Models). + +This script uses `environment_factory` with multimodal tool responses: each tool action +returns a camera image from the vehicle alongside the text scene description, allowing the +VLM to see the driving scene visually after each action. + +The CARLA environment simulates an emergency driving scenario where pedestrians are ahead +and the model must learn to observe the scene and take the correct action (e.g., swerve +to an empty lane) to minimize casualties. + +Setup: +```sh +pip install "openenv-carla-env @ git+https://huggingface.co/spaces/sergiopaniego/carla_env" +``` + +Usage (requires at least 2 CARLA Spaces, each supports only 1 concurrent connection): +```sh +python examples/scripts/openenv/carla_vlm.py \ + --env-urls https://server1.hf.space https://server2.hf.space +``` +""" + +import argparse +import base64 +from io import BytesIO + +from carla_env import CarlaAction, CarlaEnv +from datasets import Dataset +from PIL import Image + +from trl import GRPOConfig, GRPOTrainer + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run GRPO VLM training with CARLA environment.") + parser.add_argument("--model", type=str, default="Qwen/Qwen3.5-2B") + parser.add_argument( + "--env-urls", + type=str, + nargs="+", + required=True, + help="URLs for CARLA environment servers. At least 2 required (1 Space = 1 connection).", + ) + parser.add_argument("--dataset-size", type=int, default=1000) + parser.add_argument("--max-completion-length", type=int, default=1024) + parser.add_argument("--gradient-accumulation-steps", type=int, default=16) + parser.add_argument("--max-steps", type=int, default=50) + parser.add_argument("--image-size", type=int, default=512, help="Resize camera images to this size. 0 to disable.") + parser.add_argument("--trackio-space-id", type=str, default="carla-vlm-grpo") + parser.add_argument("--hub-model-id", type=str, default=None) + parser.add_argument("--run-name", type=str, default=None) + parser.add_argument("--report-to", type=str, default="wandb", help="Logging backend: wandb, trackio, none.") + return parser.parse_args() + + +SIM_TICKS = 10 + + +def reward_func(completions, environments, **kwargs): + return [env.reward for env in environments] + + +def main(): + args = parse_args() + env_url_iter = iter(args.env_urls) + image_size = args.image_size + + prompt = """You control an autonomous vehicle in an emergency. There are pedestrians ahead and you must \ +decide what to do immediately. + +You will see a camera image from the vehicle after each action. Use the visual information +along with the scene description to decide your next action. + +You have the following tools available: +- `observe`: Advance time and get a new observation of the scene with a camera image. +- `emergency_stop`: Apply maximum braking to stop the vehicle. +- `lane_change(direction)`: Change lane to the left or right. Direction must be "left" or "right". + +Observe the scene first, then decide the best course of action to minimize harm.""" + + dataset = Dataset.from_dict( + {"prompt": [[{"role": "user", "content": prompt}] for _ in range(args.dataset_size)]} + ) + + class CarlaVLMEnv: + def __init__(self): + url = next(env_url_iter) + self.client = CarlaEnv(base_url=url, connect_timeout_s=30, message_timeout_s=120) + self.reward = 0.0 + + @staticmethod + def _describe(obs) -> str: + parts = [] + parts.append(f"Speed: {obs.speed_kmh:.1f} km/h.") + if obs.nearby_actors: + for actor in obs.nearby_actors: + parts.append(f"- {actor.get('type', 'actor')} at {actor.get('distance', '?')}m") + else: + parts.append("No nearby actors detected.") + if obs.collision_detected: + parts.append(f"COLLISION detected with {obs.collided_with or 'unknown'}!") + return "\n".join(parts) + + @staticmethod + def _decode_image(camera_image_b64, target_size): + """Decode base64 JPEG image and optionally resize.""" + img_bytes = base64.b64decode(camera_image_b64) + img = Image.open(BytesIO(img_bytes)) + if target_size > 0: + img.thumbnail((target_size, target_size), Image.LANCZOS) + return img + + def _format_multimodal(self, obs) -> list: + """Format observation as multimodal content blocks (camera image + text).""" + content = [] + if obs.camera_image is not None: + img = self._decode_image(obs.camera_image, image_size) + content.append({"type": "image", "image": img}) + content.append({"type": "text", "text": self._describe(obs)}) + return content + + def _advance(self, ticks: int = SIM_TICKS): + result = None + for _ in range(ticks): + result = self.client.step(CarlaAction(action_type="observe")) + if result.done: + break + return result + + def _advance_and_capture(self, ticks: int = SIM_TICKS): + """Advance the simulation, then capture an image of the current state.""" + result = self._advance(ticks) + capture_result = self.client.step(CarlaAction(action_type="capture_image")) + result.observation.camera_image = capture_result.observation.camera_image + return result + + def reset(self, **kwargs) -> str | None: + result = self.client.reset(scenario_name="trolley_micro_escape_exists") + self.reward = 0.0 + return self._describe(result.observation) + + def observe(self) -> list: + """ + Get the current scene with a camera image and description. + + Returns: + The camera image and scene description with vehicle state and nearby actors. + """ + result = self._advance_and_capture() + self.reward = result.observation.rubric_reward or 0.0 + return self._format_multimodal(result.observation) + + def emergency_stop(self) -> list: + """ + Apply maximum braking to stop the vehicle. + + Returns: + The camera image and scene description after braking. + """ + self.client.step(CarlaAction(action_type="emergency_stop")) + result = self._advance_and_capture() + self.reward = result.observation.rubric_reward or 0.0 + return self._format_multimodal(result.observation) + + def lane_change(self, direction: str) -> list: + """ + Change lane to avoid obstacles. + + Args: + direction: Direction to change lane, either "left" or "right". + + Returns: + The camera image and scene description after changing lane. + """ + self.client.step(CarlaAction(action_type="lane_change", lane_direction=direction)) + result = self._advance_and_capture() + self.reward = result.observation.rubric_reward or 0.0 + return self._format_multimodal(result.observation) + + trainer = GRPOTrainer( + model=args.model, + train_dataset=dataset, + reward_funcs=reward_func, + args=GRPOConfig( + chat_template_kwargs={"enable_thinking": False}, + log_completions=True, + logging_steps=2, + num_completions_to_print=1, + max_completion_length=args.max_completion_length, + per_device_train_batch_size=len(args.env_urls), + steps_per_generation=1, + num_generations=len(args.env_urls), + gradient_accumulation_steps=args.gradient_accumulation_steps, + max_steps=args.max_steps, + push_to_hub=args.hub_model_id is not None, + hub_model_id=args.hub_model_id, + run_name=args.run_name, + report_to=args.report_to, + ), + environment_factory=CarlaVLMEnv, + ) + trainer.train() + + +if __name__ == "__main__": + main() diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index c3e21140814..f52f0e4c117 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1392,6 +1392,7 @@ def _get_tool_suffix_ids(self, tool_messages): prefix_ids = self.processing_class.apply_chat_template( dummy_messages, add_generation_prompt=False, + tokenize=True, chat_template=self.chat_template, return_dict=False, **self.chat_template_kwargs, @@ -1399,10 +1400,16 @@ def _get_tool_suffix_ids(self, tool_messages): full_ids = self.processing_class.apply_chat_template( dummy_messages + tool_messages, add_generation_prompt=True, + tokenize=True, chat_template=self.chat_template, return_dict=False, **self.chat_template_kwargs, ) + # VLM processors return batched output (list of lists), unbatch for single conversation + if isinstance(prefix_ids, list) and len(prefix_ids) == 1 and isinstance(prefix_ids[0], list): + prefix_ids = prefix_ids[0] + if isinstance(full_ids, list) and len(full_ids) == 1 and isinstance(full_ids[0], list): + full_ids = full_ids[0] if not full_ids[: len(prefix_ids)] == prefix_ids: raise ValueError("Unexpected tokenization: the prefix IDs are not a prefix of the full IDs.") return full_ids[len(prefix_ids) :] @@ -1499,7 +1506,11 @@ async def _run_async_tools(async_coros): if self.use_vllm and self.vllm_mode == "colocate": max_model_len = self.vllm_generation.llm.llm_engine.model_config.max_model_len elif not self.use_vllm: - max_model_len = self.model.config.max_position_embeddings + max_model_len = getattr(self.model.config, "max_position_embeddings", None) + if max_model_len is None: + # Some models (e.g., Qwen3.5) store max length in text_config or use a different attribute + text_config = getattr(self.model.config, "text_config", self.model.config) + max_model_len = getattr(text_config, "max_position_embeddings", 32768) else: raise NotImplementedError( f"Unsupported mode detected: use_vllm={self.use_vllm}, vllm_mode={self.vllm_mode}" @@ -1569,9 +1580,12 @@ async def _run_async_tools(async_coros): pct = prompt_completion_tool_ids[idx] # = prompt-completion-tool completion_ids[idx_with_tool] = pct[prompt_length:] + post_tool_ids[idx] - # Decode post-tool completions + # Decode post-tool completions. Use tokenizer for parsing (processors don't have parse_response). + parsing_class = self.processing_class + if not isinstance(parsing_class, PreTrainedTokenizerBase) and hasattr(parsing_class, "tokenizer"): + parsing_class = parsing_class.tokenizer post_tool_completions = [ - parse_response(self.processing_class, ids) if ids else {} for ids in post_tool_ids + parse_response(parsing_class, ids) if ids else {} for ids in post_tool_ids ] # Add post-tool completions to the existing completions @@ -1618,14 +1632,21 @@ def _generate(self, prompts: list): extra_fields = {} # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. + # For VLM processors, delegate to the inner tokenizer for parsing (parse_response lives on the tokenizer). if is_conversational({"prompt": prompts[0]}): + parsing_class = self.processing_class + if not isinstance(parsing_class, PreTrainedTokenizerBase) and hasattr(parsing_class, "tokenizer"): + # Propagate response_schema from processor to tokenizer if needed + if getattr(self.processing_class, "response_schema", None) and not getattr(parsing_class.tokenizer, "response_schema", None): + parsing_class.tokenizer.response_schema = self.processing_class.response_schema + parsing_class = parsing_class.tokenizer if ( Version(transformers.__version__) >= Version("5.0.0") # parse_response added in v5 - and isinstance(self.processing_class, PreTrainedTokenizerBase) # doesn't work with processors - and hasattr(self.processing_class, "response_schema") # attribute not set by default for now - and self.processing_class.response_schema is not None # only works if the tokenizer has a schema + and isinstance(parsing_class, PreTrainedTokenizerBase) + and hasattr(parsing_class, "response_schema") # attribute not set by default for now + and parsing_class.response_schema is not None # only works if the tokenizer has a schema ): - completions = [[parse_response(self.processing_class, ids)] for ids in completion_ids] + completions = [[parse_response(parsing_class, ids)] for ids in completion_ids] else: contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) completions = [[{"role": "assistant", "content": content}] for content in contents] From b03cc2b8bdc580b7b49503cdcb16b4d1791ea920 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 20 Mar 2026 19:58:02 +0100 Subject: [PATCH 03/47] Expand image tokens in tool suffix IDs and collect tool images for forward pass --- trl/trainer/grpo_trainer.py | 85 ++++++++++++++++++++++++------------- 1 file changed, 55 insertions(+), 30 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index f52f0e4c117..a439e41461c 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1397,19 +1397,46 @@ def _get_tool_suffix_ids(self, tool_messages): return_dict=False, **self.chat_template_kwargs, ) - full_ids = self.processing_class.apply_chat_template( - dummy_messages + tool_messages, - add_generation_prompt=True, - tokenize=True, - chat_template=self.chat_template, - return_dict=False, - **self.chat_template_kwargs, - ) # VLM processors return batched output (list of lists), unbatch for single conversation if isinstance(prefix_ids, list) and len(prefix_ids) == 1 and isinstance(prefix_ids[0], list): prefix_ids = prefix_ids[0] - if isinstance(full_ids, list) and len(full_ids) == 1 and isinstance(full_ids[0], list): - full_ids = full_ids[0] + + # Check if tool messages contain images (multimodal tool responses) + tool_images = [] + for msg in tool_messages: + if isinstance(msg.get("content"), list): + for part in msg["content"]: + if isinstance(part, dict) and part.get("type") == "image": + tool_images.append(part["image"]) + + if tool_images and hasattr(self.processing_class, "image_processor"): + # For VLMs with images: use processor.__call__ to get correctly expanded image tokens. + # apply_chat_template only inserts a single <|image_pad|> placeholder per image, + # but the model needs N tokens per image (based on resolution). The processor's + # __call__ handles this expansion. + full_text = self.processing_class.apply_chat_template( + dummy_messages + tool_messages, + add_generation_prompt=True, + tokenize=False, + chat_template=self.chat_template, + **self.chat_template_kwargs, + ) + full_result = self.processing_class( + text=full_text, images=tool_images, return_tensors="pt" + ) + full_ids = full_result["input_ids"][0].tolist() + else: + full_ids = self.processing_class.apply_chat_template( + dummy_messages + tool_messages, + add_generation_prompt=True, + tokenize=True, + chat_template=self.chat_template, + return_dict=False, + **self.chat_template_kwargs, + ) + if isinstance(full_ids, list) and len(full_ids) == 1 and isinstance(full_ids[0], list): + full_ids = full_ids[0] + if not full_ids[: len(prefix_ids)] == prefix_ids: raise ValueError("Unexpected tokenization: the prefix IDs are not a prefix of the full IDs.") return full_ids[len(prefix_ids) :] @@ -1420,6 +1447,8 @@ def _tool_call_loop(self, prompts, prompt_ids, completion_ids, completions, logp idxs_with_tool = [idx for idx, tool_call in enumerate(tool_calls) if tool_call] tool_calls = [tool_calls[idx] for idx in idxs_with_tool] tool_mask = [[1] * len(ids) for ids in completion_ids] # 0 for tool result tokens, 1 elsewhere + # Collect images from multimodal tool responses for the forward pass + tool_images = [[] for _ in completion_ids] tool_call_count = 0 tool_failure_count = 0 iteration_num = 0 @@ -1482,6 +1511,11 @@ async def _run_async_tools(async_coros): # pass them through directly so _tokenize_prompts can extract images for VLMs. content = result if isinstance(result, list) else str(result) tool_message = {"role": "tool", "name": name, "content": content} + # Collect images from multimodal tool responses + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "image": + tool_images[idx_with_tool].append(part["image"]) prompt_completion_tool.append(tool_message) completions[idx_with_tool].append(tool_message) @@ -1599,7 +1633,7 @@ async def _run_async_tools(async_coros): idxs_with_tool = [idx for idx, tool_call in zip(idxs_with_tool, tool_calls, strict=True) if tool_call] tool_calls = [tool_call for tool_call in tool_calls if tool_call] iteration_num += 1 - return tool_mask, completions, completion_ids, logprobs, tool_call_count, tool_failure_count + return tool_mask, completions, completion_ids, logprobs, tool_call_count, tool_failure_count, tool_images def _generate(self, prompts: list): device = self.accelerator.device @@ -1662,9 +1696,19 @@ def _generate(self, prompts: list): logprobs, tool_call_count, tool_failure_count, + tool_images, ) = self._tool_call_loop( prompts, prompt_ids, completion_ids, completions, logprobs, images, multimodal_fields ) + # Merge tool response images into the images list for the forward pass + has_tool_images = any(imgs for imgs in tool_images) + if has_tool_images: + if images is None: + images = [imgs if imgs else None for imgs in tool_images] + else: + images = [ + (existing or []) + new for existing, new in zip(images, tool_images, strict=True) + ] else: # Support custom env_mask from rollout_func (e.g., for environment feedback masking) # Internally treated as tool_mask - marks model tokens (1) vs external tokens (0) @@ -1832,25 +1876,6 @@ def _generate_and_score_completions( logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size - # For VLMs with tools, re-extract all images from the full conversation. After the tool loop, - # prompts contain both dataset images (from prepare_multimodal_messages) and tool response - # images (from multimodal tool returns in _tool_call_loop). We need all of them for the - # forward pass to compute pixel_values correctly. - if self.tools and hasattr(self.processing_class, "image_processor"): - all_images = [] - has_images = False - for prompt in prompts: - sample_images = [] - for message in prompt: - if isinstance(message.get("content"), list): - for part in message["content"]: - if isinstance(part, dict) and part.get("type") == "image": - sample_images.append(part["image"]) - has_images = True - all_images.append(sample_images if sample_images else None) - if has_images: - images = all_images - num_images = [len(img_list) if img_list else 0 for img_list in images] if images is not None else None # Get forward_kwargs for models with multimodal inputs From 6ea243076bbf4669fb7cf9380297e6e4af30a63e Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 20 Mar 2026 19:58:02 +0100 Subject: [PATCH 04/47] Expand image tokens in tool suffix IDs and collect tool images for forward pass --- examples/scripts/openenv/carla_vlm.py | 2 +- trl/trainer/grpo_trainer.py | 85 +++++++++++++++++---------- 2 files changed, 56 insertions(+), 31 deletions(-) diff --git a/examples/scripts/openenv/carla_vlm.py b/examples/scripts/openenv/carla_vlm.py index ce5e58fe8bc..f3c08b2a557 100644 --- a/examples/scripts/openenv/carla_vlm.py +++ b/examples/scripts/openenv/carla_vlm.py @@ -65,7 +65,7 @@ def parse_args(): help="URLs for CARLA environment servers. At least 2 required (1 Space = 1 connection).", ) parser.add_argument("--dataset-size", type=int, default=1000) - parser.add_argument("--max-completion-length", type=int, default=1024) + parser.add_argument("--max-completion-length", type=int, default=2048) parser.add_argument("--gradient-accumulation-steps", type=int, default=16) parser.add_argument("--max-steps", type=int, default=50) parser.add_argument("--image-size", type=int, default=512, help="Resize camera images to this size. 0 to disable.") diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index f52f0e4c117..a439e41461c 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1397,19 +1397,46 @@ def _get_tool_suffix_ids(self, tool_messages): return_dict=False, **self.chat_template_kwargs, ) - full_ids = self.processing_class.apply_chat_template( - dummy_messages + tool_messages, - add_generation_prompt=True, - tokenize=True, - chat_template=self.chat_template, - return_dict=False, - **self.chat_template_kwargs, - ) # VLM processors return batched output (list of lists), unbatch for single conversation if isinstance(prefix_ids, list) and len(prefix_ids) == 1 and isinstance(prefix_ids[0], list): prefix_ids = prefix_ids[0] - if isinstance(full_ids, list) and len(full_ids) == 1 and isinstance(full_ids[0], list): - full_ids = full_ids[0] + + # Check if tool messages contain images (multimodal tool responses) + tool_images = [] + for msg in tool_messages: + if isinstance(msg.get("content"), list): + for part in msg["content"]: + if isinstance(part, dict) and part.get("type") == "image": + tool_images.append(part["image"]) + + if tool_images and hasattr(self.processing_class, "image_processor"): + # For VLMs with images: use processor.__call__ to get correctly expanded image tokens. + # apply_chat_template only inserts a single <|image_pad|> placeholder per image, + # but the model needs N tokens per image (based on resolution). The processor's + # __call__ handles this expansion. + full_text = self.processing_class.apply_chat_template( + dummy_messages + tool_messages, + add_generation_prompt=True, + tokenize=False, + chat_template=self.chat_template, + **self.chat_template_kwargs, + ) + full_result = self.processing_class( + text=full_text, images=tool_images, return_tensors="pt" + ) + full_ids = full_result["input_ids"][0].tolist() + else: + full_ids = self.processing_class.apply_chat_template( + dummy_messages + tool_messages, + add_generation_prompt=True, + tokenize=True, + chat_template=self.chat_template, + return_dict=False, + **self.chat_template_kwargs, + ) + if isinstance(full_ids, list) and len(full_ids) == 1 and isinstance(full_ids[0], list): + full_ids = full_ids[0] + if not full_ids[: len(prefix_ids)] == prefix_ids: raise ValueError("Unexpected tokenization: the prefix IDs are not a prefix of the full IDs.") return full_ids[len(prefix_ids) :] @@ -1420,6 +1447,8 @@ def _tool_call_loop(self, prompts, prompt_ids, completion_ids, completions, logp idxs_with_tool = [idx for idx, tool_call in enumerate(tool_calls) if tool_call] tool_calls = [tool_calls[idx] for idx in idxs_with_tool] tool_mask = [[1] * len(ids) for ids in completion_ids] # 0 for tool result tokens, 1 elsewhere + # Collect images from multimodal tool responses for the forward pass + tool_images = [[] for _ in completion_ids] tool_call_count = 0 tool_failure_count = 0 iteration_num = 0 @@ -1482,6 +1511,11 @@ async def _run_async_tools(async_coros): # pass them through directly so _tokenize_prompts can extract images for VLMs. content = result if isinstance(result, list) else str(result) tool_message = {"role": "tool", "name": name, "content": content} + # Collect images from multimodal tool responses + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "image": + tool_images[idx_with_tool].append(part["image"]) prompt_completion_tool.append(tool_message) completions[idx_with_tool].append(tool_message) @@ -1599,7 +1633,7 @@ async def _run_async_tools(async_coros): idxs_with_tool = [idx for idx, tool_call in zip(idxs_with_tool, tool_calls, strict=True) if tool_call] tool_calls = [tool_call for tool_call in tool_calls if tool_call] iteration_num += 1 - return tool_mask, completions, completion_ids, logprobs, tool_call_count, tool_failure_count + return tool_mask, completions, completion_ids, logprobs, tool_call_count, tool_failure_count, tool_images def _generate(self, prompts: list): device = self.accelerator.device @@ -1662,9 +1696,19 @@ def _generate(self, prompts: list): logprobs, tool_call_count, tool_failure_count, + tool_images, ) = self._tool_call_loop( prompts, prompt_ids, completion_ids, completions, logprobs, images, multimodal_fields ) + # Merge tool response images into the images list for the forward pass + has_tool_images = any(imgs for imgs in tool_images) + if has_tool_images: + if images is None: + images = [imgs if imgs else None for imgs in tool_images] + else: + images = [ + (existing or []) + new for existing, new in zip(images, tool_images, strict=True) + ] else: # Support custom env_mask from rollout_func (e.g., for environment feedback masking) # Internally treated as tool_mask - marks model tokens (1) vs external tokens (0) @@ -1832,25 +1876,6 @@ def _generate_and_score_completions( logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size - # For VLMs with tools, re-extract all images from the full conversation. After the tool loop, - # prompts contain both dataset images (from prepare_multimodal_messages) and tool response - # images (from multimodal tool returns in _tool_call_loop). We need all of them for the - # forward pass to compute pixel_values correctly. - if self.tools and hasattr(self.processing_class, "image_processor"): - all_images = [] - has_images = False - for prompt in prompts: - sample_images = [] - for message in prompt: - if isinstance(message.get("content"), list): - for part in message["content"]: - if isinstance(part, dict) and part.get("type") == "image": - sample_images.append(part["image"]) - has_images = True - all_images.append(sample_images if sample_images else None) - if has_images: - images = all_images - num_images = [len(img_list) if img_list else 0 for img_list in images] if images is not None else None # Get forward_kwargs for models with multimodal inputs From cdf153ac7105e6f953be63d17db8c5b7c917740d Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Mon, 23 Mar 2026 11:09:38 +0100 Subject: [PATCH 05/47] Update docs --- docs/source/example_overview.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/example_overview.md b/docs/source/example_overview.md index 96b61c65d6e..52865ed3788 100644 --- a/docs/source/example_overview.md +++ b/docs/source/example_overview.md @@ -66,9 +66,10 @@ Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl | [`examples/scripts/nemo_gym/train_multi_environment.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nemo_gym/train_multi_environment.py) | This script shows how to use the [`GRPOTrainer`] to train language models in NVIDIA NeMo-Gym environments. Supports multi-turn and tool calling environments, and multi-environment training. See the [NeMo-Gym Integration](nemo_gym) guide for setup and usage. | | [`examples/scripts/online_dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo.py) | This script shows how to use the [`experimental.online_dpo.OnlineDPOTrainer`] to fine-tune a model. | | [`examples/scripts/online_dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo_vlm.py) | This script shows how to use the [`experimental.online_dpo.OnlineDPOTrainer`] to fine-tune a a Vision Language Model. | -| [`examples/scripts/openenv/browsergym.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/browsergym.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's BrowserGym environment and vLLM for VLMs | +| [`examples/scripts/openenv/browsergym.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/browsergym.py) | GRPO training with BrowserGym environment for VLMs with multimodal tool responses (screenshots). | | [`examples/scripts/openenv/browsergym_llm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/browsergym_llm.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's BrowserGym environment and vLLM for LLMs | | [`examples/scripts/openenv/carla.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/carla.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's CARLA environment for autonomous driving scenarios. | +| [`examples/scripts/openenv/carla_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/carla_vlm.py) | GRPO training with CARLA environment for VLMs with multimodal tool responses (camera images). | | [`examples/scripts/openenv/catch.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/catch.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's Catch environment (OpenSpiel) and vLLM | | [`examples/scripts/openenv/echo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/echo.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's Echo environment and vLLM. | | [`examples/scripts/openenv/sudoku.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/sudoku.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's Sudoku environment and vLLM. | From f49a63a3376d5b3e7c82b7a3eeb6190eee6492f2 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Mon, 23 Mar 2026 13:03:03 +0100 Subject: [PATCH 06/47] Add debug --- trl/trainer/grpo_trainer.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 318f2dbf0c3..9bc9d735a8e 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1581,6 +1581,8 @@ async def _run_async_tools(async_coros): ) # Generate new completions after tool execution (using concatenated IDs, no re-tokenization) + loop_img_count = sum(len(x) for x in loop_images if x) if loop_images else 0 + print(f" [VLM DEBUG] Generation in tool loop: {loop_img_count} images passed to _generate_single_turn") post_tool_ids, post_tool_logprobs = self._generate_single_turn( prompt_completion_tool_ids, loop_images, loop_multimodal_fields ) @@ -1708,6 +1710,9 @@ def _generate(self, prompts: list): ) # Merge tool response images into the images list for the forward pass has_tool_images = any(imgs for imgs in tool_images) + # DEBUG: tool image collection + print(f" [VLM DEBUG] tool_images per sample: {[len(imgs) for imgs in tool_images]}") + print(f" [VLM DEBUG] completion_ids lengths: {[len(ids) for ids in completion_ids]}") if has_tool_images: if images is None: images = [imgs if imgs else None for imgs in tool_images] @@ -1715,6 +1720,7 @@ def _generate(self, prompts: list): images = [ (existing or []) + new for existing, new in zip(images, tool_images, strict=True) ] + print(f" [VLM DEBUG] images after merge: {[len(x) if x else 0 for x in images]}") else: # Support custom env_mask from rollout_func (e.g., for environment feedback masking) # Internally treated as tool_mask - marks model tokens (1) vs external tokens (0) @@ -1886,6 +1892,8 @@ def _generate_and_score_completions( # Get forward_kwargs for models with multimodal inputs if images is not None: + total_imgs = sum(len(x) for x in images if x) + print(f" [VLM DEBUG] Forward pass: {total_imgs} total images, num_images={num_images}") prompts_text = [ apply_chat_template( {"prompt": prompt}, self.processing_class, tools=self.tools, **self.chat_template_kwargs @@ -1895,8 +1903,13 @@ def _generate_and_score_completions( prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt") prompt_inputs = super()._prepare_inputs(prompt_inputs) forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} + # DEBUG: confirm pixel_values are computed + for k, v in forward_kwargs.items(): + if hasattr(v, 'shape'): + print(f" [VLM DEBUG] forward_kwargs[{k}].shape = {v.shape}") else: forward_kwargs = {} + print(f" [VLM DEBUG] No images for forward pass") # If token_type_ids are used, extend them with zeros for the completion part if "token_type_ids" in forward_kwargs: From dd87fd1239d721ade5cb71e133cf05a6379376e8 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Mon, 23 Mar 2026 16:32:36 +0100 Subject: [PATCH 07/47] Fix VLM forward pass: return images from _generate, build mm_token_type_ids from input_ids, handle truncation - Return images from _generate to _generate_and_score_completions - Build mm_token_type_ids directly from prompt_completion_ids (detect image_pad tokens) - Place mm_token_type_ids construction after the extension block to avoid double-extension - Use image_processor directly for pixel_values (avoids re-processing mismatch) - Trim pixel_values/image_grid_thw when truncation cuts partial images - Normalize string content in tool messages for VLM processors in _get_tool_suffix_ids - Use content blocks format for dummy messages with VLM processors --- trl/trainer/grpo_trainer.py | 102 ++++++++++++++++++++++++++++++++---- 1 file changed, 92 insertions(+), 10 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index d637d4603a6..07d1226ccf4 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1388,7 +1388,14 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): def _get_tool_suffix_ids(self, tool_messages): """Get token IDs for tool result formatting by using a minimal dummy conversation.""" - dummy_messages = [{"role": "user", "content": "dummy"}, {"role": "assistant", "content": "dummy"}] + # Use content blocks format for VLM processors that don't handle plain strings + if hasattr(self.processing_class, "image_processor"): + dummy_messages = [ + {"role": "user", "content": [{"type": "text", "text": "dummy"}]}, + {"role": "assistant", "content": [{"type": "text", "text": "dummy"}]}, + ] + else: + dummy_messages = [{"role": "user", "content": "dummy"}, {"role": "assistant", "content": "dummy"}] prefix_ids = self.processing_class.apply_chat_template( dummy_messages, add_generation_prompt=False, @@ -1409,6 +1416,7 @@ def _get_tool_suffix_ids(self, tool_messages): if isinstance(part, dict) and part.get("type") == "image": tool_images.append(part["image"]) + tool_multimodal_fields = {} if tool_images and hasattr(self.processing_class, "image_processor"): # For VLMs with images: use processor.__call__ to get correctly expanded image tokens. # apply_chat_template only inserts a single <|image_pad|> placeholder per image, @@ -1425,7 +1433,17 @@ def _get_tool_suffix_ids(self, tool_messages): text=full_text, images=tool_images, return_tensors="pt" ) full_ids = full_result["input_ids"][0].tolist() + # Save multimodal fields (pixel_values, image_grid_thw, etc.) to reuse in the forward pass, + # ensuring consistency between image tokens in suffix_ids and pixel features. + tool_multimodal_fields = { + k: v for k, v in full_result.items() if k not in ("input_ids", "attention_mask") + } else: + # Normalize string content in tool messages for VLM processors + if hasattr(self.processing_class, "image_processor"): + for msg in tool_messages: + if isinstance(msg.get("content"), str): + msg["content"] = [{"type": "text", "text": msg["content"]}] full_ids = self.processing_class.apply_chat_template( dummy_messages + tool_messages, add_generation_prompt=True, @@ -1445,7 +1463,7 @@ def _get_tool_suffix_ids(self, tool_messages): if full_ids[: len(prefix_ids)] != prefix_ids: raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.") - return full_ids[len(prefix_ids) :] + return full_ids[len(prefix_ids) :], tool_multimodal_fields def _tool_call_loop(self, prompts, prompt_ids, completion_ids, completions, logprobs, images, multimodal_fields): # Tool execution loop: execute tools, then regenerate completions with tool results appended to the prompt @@ -1536,7 +1554,7 @@ async def _run_async_tools(async_coros): tool_messages.insert(0, message) else: break - suffix_ids = self._get_tool_suffix_ids(tool_messages) + suffix_ids, _ = self._get_tool_suffix_ids(tool_messages) prompt_completion_tool_ids.append( prompt_ids[idx_with_tool] + completion_ids[idx_with_tool] + suffix_ids ) @@ -1779,6 +1797,7 @@ def _generate(self, prompts: list): total_completion_tokens, logprobs, extra_fields, + images, ) def _generate_and_score_completions( @@ -1830,6 +1849,7 @@ def _generate_and_score_completions( num_items_in_batch, sampling_per_token_logps_list, extra_fields, + images, ) = self._generate(prompts) # Convert lists of token IDs to padded tensors @@ -1893,9 +1913,26 @@ def _generate_and_score_completions( num_images = [len(img_list) if img_list else 0 for img_list in images] if images is not None else None # Get forward_kwargs for models with multimodal inputs - if images is not None: - total_imgs = sum(len(x) for x in images if x) - print(f" [VLM DEBUG] Forward pass: {total_imgs} total images, num_images={num_images}") + if images is not None and hasattr(self.processing_class, "image_processor"): + # For VLMs with tool images: process images through image_processor and construct + # mm_token_type_ids from prompt_completion_ids by detecting image/video pad tokens. + flat_images = [img for img_list in images if img_list for img in img_list] + print(f" [VLM DEBUG] Forward pass: {len(flat_images)} images via image_processor + mm_token_type_ids") + + # Get pixel_values and image_grid_thw from image processor + image_inputs = self.processing_class.image_processor(images=flat_images, return_tensors="pt") + image_inputs = super()._prepare_inputs(image_inputs) + forward_kwargs = dict(image_inputs) + + # Build mm_token_type_ids from prompt_completion_ids AFTER the extension code below, + # since our mm_token_type_ids already covers the full prompt+completion sequence. + # Store the data needed to build it; it will be set after the extension block. + self._build_mm_token_type_ids = True + + for k, v in forward_kwargs.items(): + if hasattr(v, 'shape'): + print(f" [VLM DEBUG] forward_kwargs[{k}].shape = {v.shape}") + elif images is not None: prompts_text = [ apply_chat_template( {"prompt": prompt}, self.processing_class, tools=self.tools, **self.chat_template_kwargs @@ -1905,10 +1942,6 @@ def _generate_and_score_completions( prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt") prompt_inputs = super()._prepare_inputs(prompt_inputs) forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} - # DEBUG: confirm pixel_values are computed - for k, v in forward_kwargs.items(): - if hasattr(v, 'shape'): - print(f" [VLM DEBUG] forward_kwargs[{k}].shape = {v.shape}") else: forward_kwargs = {} print(f" [VLM DEBUG] No images for forward pass") @@ -1941,6 +1974,55 @@ def _generate_and_score_completions( [mm_token_type_ids, mm_token_type_ids.new_zeros(completion_ids.shape)], dim=1 ) + # For VLM tool images: build mm_token_type_ids from the full prompt_completion_ids. + # This must happen AFTER the mm_token_type_ids extension block above, because our version + # already covers the full sequence (images are in the completion, not just the prompt). + if getattr(self, "_build_mm_token_type_ids", False): + image_pad_id = self.processing_class.tokenizer.convert_tokens_to_ids("<|image_pad|>") + video_pad_id = self.processing_class.tokenizer.convert_tokens_to_ids("<|video_pad|>") + mm_ids = torch.zeros_like(prompt_completion_ids) + if image_pad_id is not None: + mm_ids[prompt_completion_ids == image_pad_id] = 1 + if video_pad_id is not None: + mm_ids[prompt_completion_ids == video_pad_id] = 2 + forward_kwargs["mm_token_type_ids"] = mm_ids + self._build_mm_token_type_ids = False + + # Truncation safety: if max_completion_length cut some image tokens, pixel_values may + # have more features than there are image tokens in input_ids. Trim to keep only + # images whose tokens are fully present. + actual_image_tokens = (mm_ids == 1).sum().item() + image_grid_thw = forward_kwargs.get("image_grid_thw") + if image_grid_thw is not None: + merge_length = getattr(self.processing_class.image_processor, "merge_size", 2) ** 2 + # Count how many complete images fit in the actual token count + cumulative_tokens = 0 + keep_images = 0 + keep_pixels = 0 + for i in range(image_grid_thw.shape[0]): + img_tokens = image_grid_thw[i].prod().item() // merge_length + if cumulative_tokens + img_tokens <= actual_image_tokens: + cumulative_tokens += img_tokens + keep_images += 1 + keep_pixels += image_grid_thw[i].prod().item() + else: + break + if keep_images < image_grid_thw.shape[0]: + forward_kwargs["image_grid_thw"] = image_grid_thw[:keep_images] + forward_kwargs["pixel_values"] = forward_kwargs["pixel_values"][:keep_pixels] + # Zero out orphaned image tokens in mm_token_type_ids (from truncated images) + mm_ids[mm_ids == 1] = 0 # reset all + # Re-mark only the kept image tokens + count = 0 + for b in range(mm_ids.shape[0]): + for j in range(mm_ids.shape[1]): + if prompt_completion_ids[b, j] == image_pad_id and count < cumulative_tokens: + mm_ids[b, j] = 1 + count += 1 + forward_kwargs["mm_token_type_ids"] = mm_ids + + print(f" [VLM DEBUG] mm_token_type_ids: {mm_ids.shape}, image tokens: {actual_image_tokens}, image_grid_thw: {forward_kwargs.get('image_grid_thw', torch.tensor([])).shape}") + # When gradient checkpointing is enabled with use_reentrant=True (non default), calling the model inside a # torch.no_grad() block triggers a harmless PyTorch warning ("None of the inputs have requires_grad=True"). # Temporarily disable checkpointing to avoid this warning during inference. From df71f29ebd8cc0a958b96f896189e23f610cb381 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Mon, 23 Mar 2026 18:11:40 +0100 Subject: [PATCH 08/47] Fix image boundary truncation, tool_mask sync, and image logging with None entries --- trl/trainer/grpo_trainer.py | 139 ++++++++++++++++++++++++++---------- 1 file changed, 103 insertions(+), 36 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 07d1226ccf4..0a3676be0be 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1465,6 +1465,40 @@ def _get_tool_suffix_ids(self, tool_messages): raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.") return full_ids[len(prefix_ids) :], tool_multimodal_fields + def _truncate_at_image_boundary(self, ids, max_length): + """Truncate token ID list to max_length, ensuring we don't cut in the middle of an image. + + If truncation would split an image token sequence (<|vision_start|>...<|vision_end|>), + backs up to the end of the last complete image. This prevents mismatches between + image placeholder tokens in input_ids and pixel_values in the forward pass. + """ + if len(ids) <= max_length: + return ids + + # Get special token IDs for image boundaries + if hasattr(self.processing_class, "image_processor") and hasattr(self.processing_class, "tokenizer"): + tok = self.processing_class.tokenizer + vision_end_id = tok.convert_tokens_to_ids("<|vision_end|>") + vision_start_id = tok.convert_tokens_to_ids("<|vision_start|>") + + truncated = ids[:max_length] + # Check if we're inside an image sequence: find the last vision_start and vision_end + last_start = -1 + last_end = -1 + for i in range(len(truncated) - 1, -1, -1): + if truncated[i] == vision_end_id and last_end == -1: + last_end = i + if truncated[i] == vision_start_id and last_start == -1: + last_start = i + if last_start != -1 and last_end != -1: + break + + # If last vision_start > last vision_end, we're inside an incomplete image + if last_start > last_end: + return ids[:last_start] # truncate before the incomplete image + return truncated + return ids[:max_length] + def _tool_call_loop(self, prompts, prompt_ids, completion_ids, completions, logprobs, images, multimodal_fields): # Tool execution loop: execute tools, then regenerate completions with tool results appended to the prompt tool_calls = [completion[0].get("tool_calls") for completion in completions] @@ -1580,7 +1614,9 @@ async def _run_async_tools(async_coros): idx_with_tool = idxs_with_tool[idx] if overlong[idx]: prompt_length = len(prompt_ids[idx_with_tool]) - ct = prompt_completion_tool_ids[idx][prompt_length : prompt_length + self.max_completion_length] + ct = self._truncate_at_image_boundary( + prompt_completion_tool_ids[idx][prompt_length:], self.max_completion_length + ) completion_ids[idx_with_tool] = ct tool_mask[idx_with_tool] += [1] * (len(ct) - len(tool_mask[idx_with_tool])) if logprobs is not None: @@ -1614,14 +1650,20 @@ async def _run_async_tools(async_coros): completion_tool_ids = prompt_completion_tool_ids[idx][prompt_len:] excess_length = len(completion_tool_ids) + len(post_tool_ids[idx]) - self.max_completion_length if excess_length > 0: - # If exceeding max length, truncate post_tool_ids - post_tool_ids[idx] = post_tool_ids[idx][:-excess_length] + # If exceeding max length, truncate post_tool_ids (respecting image boundaries) + truncated_post = self._truncate_at_image_boundary( + post_tool_ids[idx], len(post_tool_ids[idx]) - excess_length + ) if logprobs is not None: - post_tool_logprobs[idx] = post_tool_logprobs[idx][:-excess_length] + post_tool_logprobs[idx] = post_tool_logprobs[idx][:len(truncated_post)] + post_tool_ids[idx] = truncated_post excess_length = len(completion_tool_ids) + len(post_tool_ids[idx]) - self.max_completion_length if excess_length > 0: - # If still exceeding max length, truncate completion_tool_ids as well - prompt_completion_tool_ids[idx] = prompt_completion_tool_ids[idx][:-excess_length] + # If still exceeding, truncate completion_tool_ids (respecting image boundaries) + truncated_pct = self._truncate_at_image_boundary( + prompt_completion_tool_ids[idx], len(prompt_completion_tool_ids[idx]) - excess_length + ) + prompt_completion_tool_ids[idx] = truncated_pct # Update tool_mask: the tool result should be 0 and the post-tool 1 for idx in range(len(idxs_with_tool)): @@ -1661,6 +1703,31 @@ async def _run_async_tools(async_coros): idxs_with_tool = [idx for idx, tool_call in zip(idxs_with_tool, tool_calls, strict=True) if tool_call] tool_calls = [tool_call for tool_call in tool_calls if tool_call] iteration_num += 1 + + # Sync tool_mask and tool_images with completion_ids: after truncation by + # _truncate_at_image_boundary, completion_ids may be shorter than tool_mask. + for i in range(len(completion_ids)): + if len(tool_mask[i]) > len(completion_ids[i]): + tool_mask[i] = tool_mask[i][:len(completion_ids[i])] + elif len(tool_mask[i]) < len(completion_ids[i]): + tool_mask[i] = tool_mask[i] + [1] * (len(completion_ids[i]) - len(tool_mask[i])) + if logprobs is not None: + for i in range(len(completion_ids)): + if len(logprobs[i]) > len(completion_ids[i]): + logprobs[i] = logprobs[i][:len(completion_ids[i])] + elif len(logprobs[i]) < len(completion_ids[i]): + logprobs[i] = logprobs[i] + [0.0] * (len(completion_ids[i]) - len(logprobs[i])) + + # Sync tool_images: count complete images in completion_ids and trim tool_images to match. + if hasattr(self.processing_class, "tokenizer"): + vision_start_id = self.processing_class.tokenizer.convert_tokens_to_ids("<|vision_start|>") + vision_end_id = self.processing_class.tokenizer.convert_tokens_to_ids("<|vision_end|>") + if vision_start_id is not None and vision_end_id is not None: + for i, ids in enumerate(completion_ids): + complete_images = sum(1 for t in ids if t == vision_end_id) + if complete_images < len(tool_images[i]): + tool_images[i] = tool_images[i][:complete_images] + return tool_mask, completions, completion_ids, logprobs, tool_call_count, tool_failure_count, tool_images def _generate(self, prompts: list): @@ -1988,40 +2055,37 @@ def _generate_and_score_completions( forward_kwargs["mm_token_type_ids"] = mm_ids self._build_mm_token_type_ids = False - # Truncation safety: if max_completion_length cut some image tokens, pixel_values may - # have more features than there are image tokens in input_ids. Trim to keep only - # images whose tokens are fully present. - actual_image_tokens = (mm_ids == 1).sum().item() + # Truncation safety: if max_completion_length truncated some image tokens, the number + # of image pad tokens in input_ids won't match pixel_values features. Check per-sample + # and drop ALL images for any sample with a mismatch (safe fallback). image_grid_thw = forward_kwargs.get("image_grid_thw") - if image_grid_thw is not None: + if image_grid_thw is not None and num_images is not None: merge_length = getattr(self.processing_class.image_processor, "merge_size", 2) ** 2 - # Count how many complete images fit in the actual token count - cumulative_tokens = 0 - keep_images = 0 - keep_pixels = 0 - for i in range(image_grid_thw.shape[0]): - img_tokens = image_grid_thw[i].prod().item() // merge_length - if cumulative_tokens + img_tokens <= actual_image_tokens: - cumulative_tokens += img_tokens - keep_images += 1 - keep_pixels += image_grid_thw[i].prod().item() - else: + img_offset = 0 + has_mismatch = False + for b in range(mm_ids.shape[0]): + sample_tokens = (mm_ids[b] == 1).sum().item() + sample_features = 0 + for i in range(num_images[b]): + grid_idx = img_offset + i + if grid_idx < image_grid_thw.shape[0]: + sample_features += image_grid_thw[grid_idx].prod().item() // merge_length + if sample_tokens != sample_features: + has_mismatch = True break - if keep_images < image_grid_thw.shape[0]: - forward_kwargs["image_grid_thw"] = image_grid_thw[:keep_images] - forward_kwargs["pixel_values"] = forward_kwargs["pixel_values"][:keep_pixels] - # Zero out orphaned image tokens in mm_token_type_ids (from truncated images) - mm_ids[mm_ids == 1] = 0 # reset all - # Re-mark only the kept image tokens - count = 0 - for b in range(mm_ids.shape[0]): - for j in range(mm_ids.shape[1]): - if prompt_completion_ids[b, j] == image_pad_id and count < cumulative_tokens: - mm_ids[b, j] = 1 - count += 1 + img_offset += num_images[b] + + if has_mismatch: + # Drop all images: safer than partial trim which is error-prone + forward_kwargs.pop("pixel_values", None) + forward_kwargs.pop("image_grid_thw", None) + mm_ids.zero_() forward_kwargs["mm_token_type_ids"] = mm_ids + num_images = None + print(f" [VLM DEBUG] Image/token mismatch from truncation, dropping images for this batch") - print(f" [VLM DEBUG] mm_token_type_ids: {mm_ids.shape}, image tokens: {actual_image_tokens}, image_grid_thw: {forward_kwargs.get('image_grid_thw', torch.tensor([])).shape}") + actual_image_tokens = (mm_ids == 1).sum().item() + print(f" [VLM DEBUG] mm_token_type_ids: {mm_ids.shape}, image tokens: {actual_image_tokens}, num_images: {num_images}, image_grid_thw: {forward_kwargs.get('image_grid_thw', torch.tensor([])).shape}") # When gradient checkpointing is enabled with use_reentrant=True (non default), calling the model inside a # torch.no_grad() block triggers a harmless PyTorch warning ("None of the inputs have requires_grad=True"). @@ -2690,7 +2754,10 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: if images_raw: images = [] for image_list in self._logs["images"]: - images.append([logging_backend.Image(image) for image in image_list]) + if image_list: + images.append([logging_backend.Image(image) for image in image_list]) + else: + images.append([]) df = pd.concat( [df_base, pd.Series(images, name="image")], axis=1, From ba9252dcbc9364285f7970fbc34ba8bd42e8b090 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Mon, 23 Mar 2026 18:29:29 +0100 Subject: [PATCH 09/47] Clean up PR: extract helpers, remove debug prints, use dynamic token detection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add _get_vision_token_ids() to detect image tokens dynamically across VLM families - Add _normalize_message_content() to deduplicate string→content blocks normalization - Remove all [VLM DEBUG] print statements - Replace _build_mm_token_type_ids instance attribute with local variable - Remove unused tool_multimodal_fields return value from _get_tool_suffix_ids --- trl/trainer/grpo_trainer.py | 122 ++++++++++++++++++------------------ 1 file changed, 60 insertions(+), 62 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 0a3676be0be..ceb17cc2593 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1238,17 +1238,20 @@ async def _run_async_funcs(): rewards_per_func = gather(rewards_per_func) return rewards_per_func + @staticmethod + def _normalize_message_content(messages): + """Normalize string content to content blocks in-place for VLM processor compatibility.""" + for message in messages: + if isinstance(message.get("content"), str): + message["content"] = [{"type": "text", "text": message["content"]}] + def _tokenize_prompts(self, prompts: list): """Tokenize prompts and extract images/multimodal fields for generation.""" if is_conversational({"prompt": prompts[0]}): - # When the processor is a VLM processor, normalize string content to content blocks. - # Some VLM processors (e.g., Qwen3.5, Qwen3-VL) iterate over message["content"] - # assuming it's always a list of content blocks, which fails when content is a plain string. + # Normalize string content to content blocks for VLM processors that don't handle plain strings. if hasattr(self.processing_class, "image_processor"): for prompt in prompts: - for message in prompt: - if isinstance(message["content"], str): - message["content"] = [{"type": "text", "text": message["content"]}] + self._normalize_message_content(prompt) # Extract images from messages for VLM support images = [] @@ -1388,14 +1391,9 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): def _get_tool_suffix_ids(self, tool_messages): """Get token IDs for tool result formatting by using a minimal dummy conversation.""" - # Use content blocks format for VLM processors that don't handle plain strings + dummy_messages = [{"role": "user", "content": "dummy"}, {"role": "assistant", "content": "dummy"}] if hasattr(self.processing_class, "image_processor"): - dummy_messages = [ - {"role": "user", "content": [{"type": "text", "text": "dummy"}]}, - {"role": "assistant", "content": [{"type": "text", "text": "dummy"}]}, - ] - else: - dummy_messages = [{"role": "user", "content": "dummy"}, {"role": "assistant", "content": "dummy"}] + self._normalize_message_content(dummy_messages) prefix_ids = self.processing_class.apply_chat_template( dummy_messages, add_generation_prompt=False, @@ -1416,7 +1414,6 @@ def _get_tool_suffix_ids(self, tool_messages): if isinstance(part, dict) and part.get("type") == "image": tool_images.append(part["image"]) - tool_multimodal_fields = {} if tool_images and hasattr(self.processing_class, "image_processor"): # For VLMs with images: use processor.__call__ to get correctly expanded image tokens. # apply_chat_template only inserts a single <|image_pad|> placeholder per image, @@ -1433,17 +1430,9 @@ def _get_tool_suffix_ids(self, tool_messages): text=full_text, images=tool_images, return_tensors="pt" ) full_ids = full_result["input_ids"][0].tolist() - # Save multimodal fields (pixel_values, image_grid_thw, etc.) to reuse in the forward pass, - # ensuring consistency between image tokens in suffix_ids and pixel features. - tool_multimodal_fields = { - k: v for k, v in full_result.items() if k not in ("input_ids", "attention_mask") - } else: - # Normalize string content in tool messages for VLM processors if hasattr(self.processing_class, "image_processor"): - for msg in tool_messages: - if isinstance(msg.get("content"), str): - msg["content"] = [{"type": "text", "text": msg["content"]}] + self._normalize_message_content(tool_messages) full_ids = self.processing_class.apply_chat_template( dummy_messages + tool_messages, add_generation_prompt=True, @@ -1463,7 +1452,34 @@ def _get_tool_suffix_ids(self, tool_messages): if full_ids[: len(prefix_ids)] != prefix_ids: raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.") - return full_ids[len(prefix_ids) :], tool_multimodal_fields + return full_ids[len(prefix_ids) :] + + def _get_vision_token_ids(self): + """Get vision-related special token IDs from the processor's tokenizer. + + Returns a dict with keys 'vision_start', 'vision_end', 'image_pad', 'video_pad'. + Values are None if the token doesn't exist in the tokenizer's vocabulary. + Works across VLM families (Qwen, Gemma, LLaVA, etc.) by trying common token names. + """ + if not hasattr(self, "_vision_token_ids_cache"): + cache = {"vision_start": None, "vision_end": None, "image_pad": None, "video_pad": None} + if hasattr(self.processing_class, "tokenizer"): + tok = self.processing_class.tokenizer + # Try common token names across VLM families + for name, candidates in { + "vision_start": ["<|vision_start|>", "<|img_start|>"], + "vision_end": ["<|vision_end|>", "<|img_end|>"], + "image_pad": ["<|image_pad|>", "<|image|>", ""], + "video_pad": ["<|video_pad|>"], + }.items(): + for candidate in candidates: + tid = tok.convert_tokens_to_ids(candidate) + # convert_tokens_to_ids returns the unk_token_id for unknown tokens + if tid != tok.unk_token_id: + cache[name] = tid + break + self._vision_token_ids_cache = cache + return self._vision_token_ids_cache def _truncate_at_image_boundary(self, ids, max_length): """Truncate token ID list to max_length, ensuring we don't cut in the middle of an image. @@ -1475,14 +1491,11 @@ def _truncate_at_image_boundary(self, ids, max_length): if len(ids) <= max_length: return ids - # Get special token IDs for image boundaries - if hasattr(self.processing_class, "image_processor") and hasattr(self.processing_class, "tokenizer"): - tok = self.processing_class.tokenizer - vision_end_id = tok.convert_tokens_to_ids("<|vision_end|>") - vision_start_id = tok.convert_tokens_to_ids("<|vision_start|>") - + vtids = self._get_vision_token_ids() + vision_start_id = vtids["vision_start"] + vision_end_id = vtids["vision_end"] + if vision_start_id is not None and vision_end_id is not None: truncated = ids[:max_length] - # Check if we're inside an image sequence: find the last vision_start and vision_end last_start = -1 last_end = -1 for i in range(len(truncated) - 1, -1, -1): @@ -1588,7 +1601,7 @@ async def _run_async_tools(async_coros): tool_messages.insert(0, message) else: break - suffix_ids, _ = self._get_tool_suffix_ids(tool_messages) + suffix_ids = self._get_tool_suffix_ids(tool_messages) prompt_completion_tool_ids.append( prompt_ids[idx_with_tool] + completion_ids[idx_with_tool] + suffix_ids ) @@ -1637,8 +1650,6 @@ async def _run_async_tools(async_coros): ) # Generate new completions after tool execution (using concatenated IDs, no re-tokenization) - loop_img_count = sum(len(x) for x in loop_images if x) if loop_images else 0 - print(f" [VLM DEBUG] Generation in tool loop: {loop_img_count} images passed to _generate_single_turn") post_tool_ids, post_tool_logprobs = self._generate_single_turn( prompt_completion_tool_ids, loop_images, loop_multimodal_fields ) @@ -1719,14 +1730,12 @@ async def _run_async_tools(async_coros): logprobs[i] = logprobs[i] + [0.0] * (len(completion_ids[i]) - len(logprobs[i])) # Sync tool_images: count complete images in completion_ids and trim tool_images to match. - if hasattr(self.processing_class, "tokenizer"): - vision_start_id = self.processing_class.tokenizer.convert_tokens_to_ids("<|vision_start|>") - vision_end_id = self.processing_class.tokenizer.convert_tokens_to_ids("<|vision_end|>") - if vision_start_id is not None and vision_end_id is not None: - for i, ids in enumerate(completion_ids): - complete_images = sum(1 for t in ids if t == vision_end_id) - if complete_images < len(tool_images[i]): - tool_images[i] = tool_images[i][:complete_images] + vtids = self._get_vision_token_ids() + if vtids["vision_end"] is not None: + for i, ids in enumerate(completion_ids): + complete_images = sum(1 for t in ids if t == vtids["vision_end"]) + if complete_images < len(tool_images[i]): + tool_images[i] = tool_images[i][:complete_images] return tool_mask, completions, completion_ids, logprobs, tool_call_count, tool_failure_count, tool_images @@ -1798,8 +1807,6 @@ def _generate(self, prompts: list): # Merge tool response images into the images list for the forward pass has_tool_images = any(imgs for imgs in tool_images) # DEBUG: tool image collection - print(f" [VLM DEBUG] tool_images per sample: {[len(imgs) for imgs in tool_images]}") - print(f" [VLM DEBUG] completion_ids lengths: {[len(ids) for ids in completion_ids]}") if has_tool_images: if images is None: images = [imgs if imgs else None for imgs in tool_images] @@ -1807,7 +1814,6 @@ def _generate(self, prompts: list): images = [ (existing or []) + new for existing, new in zip(images, tool_images, strict=True) ] - print(f" [VLM DEBUG] images after merge: {[len(x) if x else 0 for x in images]}") else: # Support custom env_mask from rollout_func (e.g., for environment feedback masking) # Internally treated as tool_mask - marks model tokens (1) vs external tokens (0) @@ -1874,6 +1880,7 @@ def _generate_and_score_completions( mode = "train" if self.model.training else "eval" prompts = [x["prompt"] for x in inputs] + build_mm_token_type_ids = False if self.environments: for prompt, environment, reset_kwargs in zip(prompts, self.environments, inputs, strict=True): @@ -1984,7 +1991,6 @@ def _generate_and_score_completions( # For VLMs with tool images: process images through image_processor and construct # mm_token_type_ids from prompt_completion_ids by detecting image/video pad tokens. flat_images = [img for img_list in images if img_list for img in img_list] - print(f" [VLM DEBUG] Forward pass: {len(flat_images)} images via image_processor + mm_token_type_ids") # Get pixel_values and image_grid_thw from image processor image_inputs = self.processing_class.image_processor(images=flat_images, return_tensors="pt") @@ -1994,11 +2000,7 @@ def _generate_and_score_completions( # Build mm_token_type_ids from prompt_completion_ids AFTER the extension code below, # since our mm_token_type_ids already covers the full prompt+completion sequence. # Store the data needed to build it; it will be set after the extension block. - self._build_mm_token_type_ids = True - - for k, v in forward_kwargs.items(): - if hasattr(v, 'shape'): - print(f" [VLM DEBUG] forward_kwargs[{k}].shape = {v.shape}") + build_mm_token_type_ids = True elif images is not None: prompts_text = [ apply_chat_template( @@ -2011,7 +2013,6 @@ def _generate_and_score_completions( forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} else: forward_kwargs = {} - print(f" [VLM DEBUG] No images for forward pass") # If token_type_ids are used, extend them with zeros for the completion part if "token_type_ids" in forward_kwargs: @@ -2044,16 +2045,15 @@ def _generate_and_score_completions( # For VLM tool images: build mm_token_type_ids from the full prompt_completion_ids. # This must happen AFTER the mm_token_type_ids extension block above, because our version # already covers the full sequence (images are in the completion, not just the prompt). - if getattr(self, "_build_mm_token_type_ids", False): - image_pad_id = self.processing_class.tokenizer.convert_tokens_to_ids("<|image_pad|>") - video_pad_id = self.processing_class.tokenizer.convert_tokens_to_ids("<|video_pad|>") + if build_mm_token_type_ids: + vtids = self._get_vision_token_ids() mm_ids = torch.zeros_like(prompt_completion_ids) - if image_pad_id is not None: - mm_ids[prompt_completion_ids == image_pad_id] = 1 - if video_pad_id is not None: - mm_ids[prompt_completion_ids == video_pad_id] = 2 + if vtids["image_pad"] is not None: + mm_ids[prompt_completion_ids == vtids["image_pad"]] = 1 + if vtids["video_pad"] is not None: + mm_ids[prompt_completion_ids == vtids["video_pad"]] = 2 forward_kwargs["mm_token_type_ids"] = mm_ids - self._build_mm_token_type_ids = False + build_mm_token_type_ids = False # consumed # Truncation safety: if max_completion_length truncated some image tokens, the number # of image pad tokens in input_ids won't match pixel_values features. Check per-sample @@ -2082,10 +2082,8 @@ def _generate_and_score_completions( mm_ids.zero_() forward_kwargs["mm_token_type_ids"] = mm_ids num_images = None - print(f" [VLM DEBUG] Image/token mismatch from truncation, dropping images for this batch") actual_image_tokens = (mm_ids == 1).sum().item() - print(f" [VLM DEBUG] mm_token_type_ids: {mm_ids.shape}, image tokens: {actual_image_tokens}, num_images: {num_images}, image_grid_thw: {forward_kwargs.get('image_grid_thw', torch.tensor([])).shape}") # When gradient checkpointing is enabled with use_reentrant=True (non default), calling the model inside a # torch.no_grad() block triggers a harmless PyTorch warning ("None of the inputs have requires_grad=True"). From b4641345d3b343f6a57c36c2faf24486553a664f Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Tue, 24 Mar 2026 10:18:21 +0100 Subject: [PATCH 10/47] Use _is_vlm from SFTTrainer convention, simplify vision token detection, remove temp save_strategy --- trl/trainer/grpo_trainer.py | 47 +++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index ceb17cc2593..3e7eb80621a 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -318,8 +318,12 @@ def __init__( # Handle pad token for processors or tokenizers if isinstance(processing_class, ProcessorMixin): tokenizer = processing_class.tokenizer + self._is_vlm = True + self._vision_token_ids_cache = None # populated lazily by _get_vision_token_ids elif isinstance(processing_class, PreTrainedTokenizerBase): tokenizer = processing_class + self._is_vlm = False + self._vision_token_ids_cache = None else: raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") @@ -1249,7 +1253,7 @@ def _tokenize_prompts(self, prompts: list): """Tokenize prompts and extract images/multimodal fields for generation.""" if is_conversational({"prompt": prompts[0]}): # Normalize string content to content blocks for VLM processors that don't handle plain strings. - if hasattr(self.processing_class, "image_processor"): + if self._is_vlm: for prompt in prompts: self._normalize_message_content(prompt) @@ -1392,7 +1396,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): def _get_tool_suffix_ids(self, tool_messages): """Get token IDs for tool result formatting by using a minimal dummy conversation.""" dummy_messages = [{"role": "user", "content": "dummy"}, {"role": "assistant", "content": "dummy"}] - if hasattr(self.processing_class, "image_processor"): + if self._is_vlm: self._normalize_message_content(dummy_messages) prefix_ids = self.processing_class.apply_chat_template( dummy_messages, @@ -1414,7 +1418,7 @@ def _get_tool_suffix_ids(self, tool_messages): if isinstance(part, dict) and part.get("type") == "image": tool_images.append(part["image"]) - if tool_images and hasattr(self.processing_class, "image_processor"): + if tool_images and self._is_vlm: # For VLMs with images: use processor.__call__ to get correctly expanded image tokens. # apply_chat_template only inserts a single <|image_pad|> placeholder per image, # but the model needs N tokens per image (based on resolution). The processor's @@ -1426,12 +1430,14 @@ def _get_tool_suffix_ids(self, tool_messages): chat_template=self.chat_template, **self.chat_template_kwargs, ) + # We only need input_ids (for suffix token extraction). pixel_values and image_grid_thw + # are computed separately in the forward pass via image_processor to avoid mismatches. full_result = self.processing_class( text=full_text, images=tool_images, return_tensors="pt" ) full_ids = full_result["input_ids"][0].tolist() else: - if hasattr(self.processing_class, "image_processor"): + if self._is_vlm: self._normalize_message_content(tool_messages) full_ids = self.processing_class.apply_chat_template( dummy_messages + tool_messages, @@ -1458,26 +1464,21 @@ def _get_vision_token_ids(self): """Get vision-related special token IDs from the processor's tokenizer. Returns a dict with keys 'vision_start', 'vision_end', 'image_pad', 'video_pad'. - Values are None if the token doesn't exist in the tokenizer's vocabulary. - Works across VLM families (Qwen, Gemma, LLaVA, etc.) by trying common token names. + Values are None if the token doesn't exist in the vocabulary. """ - if not hasattr(self, "_vision_token_ids_cache"): + if self._vision_token_ids_cache is None: cache = {"vision_start": None, "vision_end": None, "image_pad": None, "video_pad": None} - if hasattr(self.processing_class, "tokenizer"): + if self._is_vlm: tok = self.processing_class.tokenizer - # Try common token names across VLM families - for name, candidates in { - "vision_start": ["<|vision_start|>", "<|img_start|>"], - "vision_end": ["<|vision_end|>", "<|img_end|>"], - "image_pad": ["<|image_pad|>", "<|image|>", ""], - "video_pad": ["<|video_pad|>"], + for name, token_str in { + "vision_start": "<|vision_start|>", + "vision_end": "<|vision_end|>", + "image_pad": "<|image_pad|>", + "video_pad": "<|video_pad|>", }.items(): - for candidate in candidates: - tid = tok.convert_tokens_to_ids(candidate) - # convert_tokens_to_ids returns the unk_token_id for unknown tokens - if tid != tok.unk_token_id: - cache[name] = tid - break + tid = tok.convert_tokens_to_ids(token_str) + if tid != tok.unk_token_id: + cache[name] = tid self._vision_token_ids_cache = cache return self._vision_token_ids_cache @@ -1697,7 +1698,7 @@ async def _run_async_tools(async_coros): # Decode post-tool completions. Use tokenizer for parsing (processors don't have parse_response). parsing_class = self.processing_class - if not isinstance(parsing_class, PreTrainedTokenizerBase) and hasattr(parsing_class, "tokenizer"): + if self._is_vlm: parsing_class = parsing_class.tokenizer post_tool_completions = [ parse_response(parsing_class, ids) if ids else {} for ids in post_tool_ids @@ -1773,7 +1774,7 @@ def _generate(self, prompts: list): # For VLM processors, delegate to the inner tokenizer for parsing (parse_response lives on the tokenizer). if is_conversational({"prompt": prompts[0]}): parsing_class = self.processing_class - if not isinstance(parsing_class, PreTrainedTokenizerBase) and hasattr(parsing_class, "tokenizer"): + if self._is_vlm: # Propagate response_schema from processor to tokenizer if needed if getattr(self.processing_class, "response_schema", None) and not getattr(parsing_class.tokenizer, "response_schema", None): parsing_class.tokenizer.response_schema = self.processing_class.response_schema @@ -1987,7 +1988,7 @@ def _generate_and_score_completions( num_images = [len(img_list) if img_list else 0 for img_list in images] if images is not None else None # Get forward_kwargs for models with multimodal inputs - if images is not None and hasattr(self.processing_class, "image_processor"): + if images is not None and self._is_vlm: # For VLMs with tool images: process images through image_processor and construct # mm_token_type_ids from prompt_completion_ids by detecting image/video pad tokens. flat_images = [img for img_list in images if img_list for img in img_list] From 164f03514c9b9f3a02c76f77b4c796bdf88840f2 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Tue, 24 Mar 2026 10:41:42 +0100 Subject: [PATCH 11/47] precommit --- examples/scripts/openenv/carla_vlm.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/scripts/openenv/carla_vlm.py b/examples/scripts/openenv/carla_vlm.py index 23da9d96661..46cdf4c7268 100644 --- a/examples/scripts/openenv/carla_vlm.py +++ b/examples/scripts/openenv/carla_vlm.py @@ -101,9 +101,7 @@ def main(): Observe the scene first, then decide the best course of action to minimize harm.""" - dataset = Dataset.from_dict( - {"prompt": [[{"role": "user", "content": prompt}] for _ in range(args.dataset_size)]} - ) + dataset = Dataset.from_dict({"prompt": [[{"role": "user", "content": prompt}] for _ in range(args.dataset_size)]}) class CarlaVLMEnv: def __init__(self): From 2f49156a8960f972a544abc715c196e6b8847ad1 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Tue, 24 Mar 2026 13:01:39 +0100 Subject: [PATCH 12/47] Fix undefined images variable in rollout_func code path --- trl/trainer/grpo_trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 3e99b33b36c..d43b4069316 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1760,6 +1760,8 @@ def _generate(self, prompts: list): raise ValueError(f"rollout_func must return keys {missing_keys_list} in its output dict.") extra_fields = {k: v for k, v in output.items() if k not in required_keys} prompt_ids, completion_ids, logprobs = output["prompt_ids"], output["completion_ids"], output["logprobs"] + images = None + multimodal_fields = {} else: prompt_ids, images, multimodal_fields = self._tokenize_prompts(prompts) completion_ids, logprobs = self._generate_single_turn(prompt_ids, images, multimodal_fields) From bf3c35e8253c536ebc09a6fcd0c40d217159397e Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Tue, 24 Mar 2026 13:16:59 +0100 Subject: [PATCH 13/47] Pass tool response images to generation in tool loop for VLM visual feedback --- trl/trainer/grpo_trainer.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index d43b4069316..8d8cee76db4 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1641,8 +1641,17 @@ async def _run_async_tools(async_coros): if not idxs_with_tool: break # all overlong, exit tool loop - # Filter images and multimodal fields to match the current subset (index into full batch) - loop_images = [images[i] for i in idxs_with_tool] if images else None + # Filter images and multimodal fields to match the current subset (index into full batch). + # Merge tool response images so the model can see visual feedback during generation. + merged_images = images + if any(imgs for imgs in tool_images): + if merged_images is None: + merged_images = [imgs if imgs else None for imgs in tool_images] + else: + merged_images = [ + (existing or []) + new for existing, new in zip(merged_images, tool_images, strict=True) + ] + loop_images = [merged_images[i] for i in idxs_with_tool] if merged_images else None loop_multimodal_fields = ( {k: [v[i] for i in idxs_with_tool] for k, v in multimodal_fields.items()} if multimodal_fields else {} ) From 99748e3771e5e24cabc0dd2b4deb105e7ce71845 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Tue, 24 Mar 2026 13:19:22 +0100 Subject: [PATCH 14/47] Use consistent tokenization path for prefix and full IDs in VLM _get_tool_suffix_ids --- trl/trainer/grpo_trainer.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 8d8cee76db4..0c7e3817425 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1420,6 +1420,16 @@ def _get_tool_suffix_ids(self, tool_messages): # apply_chat_template only inserts a single <|image_pad|> placeholder per image, # but the model needs N tokens per image (based on resolution). The processor's # __call__ handles this expansion. + # Use the same tokenization method (processor.__call__) for both prefix and full to + # avoid mismatches from different tokenization paths. + prefix_text = self.processing_class.apply_chat_template( + dummy_messages, + add_generation_prompt=False, + tokenize=False, + chat_template=self.chat_template, + **self.chat_template_kwargs, + ) + prefix_ids = self.processing_class(text=prefix_text, return_tensors="pt")["input_ids"][0].tolist() full_text = self.processing_class.apply_chat_template( dummy_messages + tool_messages, add_generation_prompt=True, @@ -1429,8 +1439,9 @@ def _get_tool_suffix_ids(self, tool_messages): ) # We only need input_ids (for suffix token extraction). pixel_values and image_grid_thw # are computed separately in the forward pass via image_processor to avoid mismatches. - full_result = self.processing_class(text=full_text, images=tool_images, return_tensors="pt") - full_ids = full_result["input_ids"][0].tolist() + full_ids = self.processing_class(text=full_text, images=tool_images, return_tensors="pt")[ + "input_ids" + ][0].tolist() else: if self._is_vlm: for msg in tool_messages: From 4029708b55494c563d0a782b30c200bc8642193f Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Tue, 24 Mar 2026 13:22:29 +0100 Subject: [PATCH 15/47] precommit --- trl/trainer/grpo_trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 0c7e3817425..26e1dc5793b 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1439,9 +1439,9 @@ def _get_tool_suffix_ids(self, tool_messages): ) # We only need input_ids (for suffix token extraction). pixel_values and image_grid_thw # are computed separately in the forward pass via image_processor to avoid mismatches. - full_ids = self.processing_class(text=full_text, images=tool_images, return_tensors="pt")[ - "input_ids" - ][0].tolist() + full_ids = self.processing_class(text=full_text, images=tool_images, return_tensors="pt")["input_ids"][ + 0 + ].tolist() else: if self._is_vlm: for msg in tool_messages: From 738c17e706e3486bb82d07027ff09a9ab3d4f53a Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Tue, 24 Mar 2026 14:09:03 +0100 Subject: [PATCH 16/47] Fix VLM image path to only use image_processor for tool images, preserve full processor for dataset images --- trl/trainer/grpo_trainer.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 26e1dc5793b..b65cbce4958 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -2003,20 +2003,18 @@ def _generate_and_score_completions( num_images = [len(img_list) if img_list else 0 for img_list in images] if images is not None else None - # Get forward_kwargs for models with multimodal inputs - if images is not None and self._is_vlm: - # For VLMs with tool images: process images through image_processor and construct - # mm_token_type_ids from prompt_completion_ids by detecting image/video pad tokens. + # Get forward_kwargs for models with multimodal inputs. + # When tool images are present (from _tool_call_loop), use image_processor directly and build + # mm_token_type_ids from prompt_completion_ids. Otherwise, use the full processor pipeline + # which returns model-specific keys (image_sizes, pixel_attention_mask, etc.). + has_tool_images = self.tools and images is not None and any( + img for img_list in images if img_list for img in img_list + ) + if has_tool_images and self._is_vlm: flat_images = [img for img_list in images if img_list for img in img_list] - - # Get pixel_values and image_grid_thw from image processor image_inputs = self.processing_class.image_processor(images=flat_images, return_tensors="pt") image_inputs = super()._prepare_inputs(image_inputs) forward_kwargs = dict(image_inputs) - - # Build mm_token_type_ids from prompt_completion_ids AFTER the extension code below, - # since our mm_token_type_ids already covers the full prompt+completion sequence. - # Store the data needed to build it; it will be set after the extension block. build_mm_token_type_ids = True elif images is not None: prompts_text = [ From 61ee09f35c7520da2b4eba11d066f468437b9cd8 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Tue, 24 Mar 2026 14:09:43 +0100 Subject: [PATCH 17/47] Replace getattr chain with direct _is_vlm conditional for max_position_embeddings --- trl/trainer/grpo_trainer.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index b65cbce4958..b109b55174d 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1622,11 +1622,10 @@ async def _run_async_tools(async_coros): elif self.use_vllm and self.vllm_mode == "server": max_model_len = self.model.config.max_position_embeddings elif not self.use_vllm: - max_model_len = getattr(self.model.config, "max_position_embeddings", None) - if max_model_len is None: - # Some models (e.g., Qwen3.5) store max length in text_config or use a different attribute - text_config = getattr(self.model.config, "text_config", self.model.config) - max_model_len = getattr(text_config, "max_position_embeddings", 32768) + if self._is_vlm: + max_model_len = self.model.config.text_config.max_position_embeddings + else: + max_model_len = self.model.config.max_position_embeddings else: raise NotImplementedError( f"Unsupported mode detected: use_vllm={self.use_vllm}, vllm_mode={self.vllm_mode}" From af3a5367d5ae5026319ca007476da4ce6fde4bcb Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Tue, 24 Mar 2026 14:09:43 +0100 Subject: [PATCH 18/47] Replace getattr chain with direct _is_vlm conditional for max_position_embeddings --- trl/trainer/grpo_trainer.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index b65cbce4958..d2efc021967 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1622,11 +1622,10 @@ async def _run_async_tools(async_coros): elif self.use_vllm and self.vllm_mode == "server": max_model_len = self.model.config.max_position_embeddings elif not self.use_vllm: - max_model_len = getattr(self.model.config, "max_position_embeddings", None) - if max_model_len is None: - # Some models (e.g., Qwen3.5) store max length in text_config or use a different attribute - text_config = getattr(self.model.config, "text_config", self.model.config) - max_model_len = getattr(text_config, "max_position_embeddings", 32768) + if self._is_vlm: + max_model_len = self.model.config.text_config.max_position_embeddings + else: + max_model_len = self.model.config.max_position_embeddings else: raise NotImplementedError( f"Unsupported mode detected: use_vllm={self.use_vllm}, vllm_mode={self.vllm_mode}" @@ -2007,8 +2006,8 @@ def _generate_and_score_completions( # When tool images are present (from _tool_call_loop), use image_processor directly and build # mm_token_type_ids from prompt_completion_ids. Otherwise, use the full processor pipeline # which returns model-specific keys (image_sizes, pixel_attention_mask, etc.). - has_tool_images = self.tools and images is not None and any( - img for img_list in images if img_list for img in img_list + has_tool_images = ( + self.tools and images is not None and any(img for img_list in images if img_list for img in img_list) ) if has_tool_images and self._is_vlm: flat_images = [img for img_list in images if img_list for img in img_list] From 419f6cebb3d481bf9d73c9a5bd10509333724878 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Tue, 24 Mar 2026 14:30:38 +0100 Subject: [PATCH 19/47] Increase default max-steps to 100 for carla_vlm.py --- examples/scripts/openenv/carla_vlm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/scripts/openenv/carla_vlm.py b/examples/scripts/openenv/carla_vlm.py index 6411e8bff67..0e259611eb6 100644 --- a/examples/scripts/openenv/carla_vlm.py +++ b/examples/scripts/openenv/carla_vlm.py @@ -67,7 +67,7 @@ def parse_args(): parser.add_argument("--dataset-size", type=int, default=1000) parser.add_argument("--max-completion-length", type=int, default=2048) parser.add_argument("--gradient-accumulation-steps", type=int, default=4) - parser.add_argument("--max-steps", type=int, default=50) + parser.add_argument("--max-steps", type=int, default=100) parser.add_argument("--image-size", type=int, default=512, help="Resize camera images to this size. 0 to disable.") parser.add_argument("--trackio-space-id", type=str, default="carla-vlm-grpo") parser.add_argument("--hub-model-id", type=str, default=None) From f51a32b1493d61268a4d2bb6d7f13d9beeea209b Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Tue, 24 Mar 2026 14:53:05 +0100 Subject: [PATCH 20/47] Align RLOO image extraction check with GRPO for consistency --- trl/trainer/rloo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 375a0310e3e..b7be05362be 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -909,7 +909,7 @@ def _tokenize_prompts(self, prompts: list): for message in prompt: if isinstance(message["content"], list): for part in message["content"]: - if part["type"] == "image": + if isinstance(part, dict) and part.get("type") == "image": prompt_images.append(part["image"]) has_images = True images.append(prompt_images if prompt_images else None) From 9f320d546b21b75bf96690ac5db6e172cc45a534 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Tue, 24 Mar 2026 15:03:28 +0100 Subject: [PATCH 21/47] Handle VLM max_position_embeddings for vLLM server mode --- trl/trainer/grpo_trainer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index d2efc021967..317bc8fd392 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1620,7 +1620,10 @@ async def _run_async_tools(async_coros): if self.use_vllm and self.vllm_mode == "colocate": max_model_len = self.vllm_generation.llm.llm_engine.model_config.max_model_len elif self.use_vllm and self.vllm_mode == "server": - max_model_len = self.model.config.max_position_embeddings + if self._is_vlm: + max_model_len = self.model.config.text_config.max_position_embeddings + else: + max_model_len = self.model.config.max_position_embeddings elif not self.use_vllm: if self._is_vlm: max_model_len = self.model.config.text_config.max_position_embeddings From 92e1440471373942a965c6a8ed36e1f982fa1be4 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Tue, 24 Mar 2026 15:15:31 +0100 Subject: [PATCH 22/47] Clamp max_length to 0 in _truncate_at_image_boundary to prevent negative slice --- trl/trainer/grpo_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 317bc8fd392..6a5a9c69df4 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1497,6 +1497,7 @@ def _truncate_at_image_boundary(self, ids, max_length): backs up to the end of the last complete image. This prevents mismatches between image placeholder tokens in input_ids and pixel_values in the forward pass. """ + max_length = max(max_length, 0) if len(ids) <= max_length: return ids From 312885b4828e6f0a1a8e874add29c1a94ed83a1b Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Tue, 24 Mar 2026 15:26:55 +0100 Subject: [PATCH 23/47] Normalize tool message content before image/text branch split in _get_tool_suffix_ids --- trl/trainer/grpo_trainer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 6a5a9c69df4..f09fc5c5a63 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1415,6 +1415,12 @@ def _get_tool_suffix_ids(self, tool_messages): if isinstance(part, dict) and part.get("type") == "image": tool_images.append(part["image"]) + # Normalize string content in tool messages for VLM processors before either path + if self._is_vlm: + for msg in tool_messages: + if isinstance(msg.get("content"), str): + msg["content"] = [{"type": "text", "text": msg["content"]}] + if tool_images and self._is_vlm: # For VLMs with images: use processor.__call__ to get correctly expanded image tokens. # apply_chat_template only inserts a single <|image_pad|> placeholder per image, @@ -1443,10 +1449,6 @@ def _get_tool_suffix_ids(self, tool_messages): 0 ].tolist() else: - if self._is_vlm: - for msg in tool_messages: - if isinstance(msg.get("content"), str): - msg["content"] = [{"type": "text", "text": msg["content"]}] full_ids = self.processing_class.apply_chat_template( dummy_messages + tool_messages, add_generation_prompt=True, From ac5291bb51690904d30f515ff29a621c12ebc691 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Tue, 24 Mar 2026 15:36:22 +0100 Subject: [PATCH 24/47] Propagate num_images None-safety fix to RLOO for consistency --- trl/trainer/rloo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index b7be05362be..9d3818302b5 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -1157,7 +1157,7 @@ def _generate_and_score_completions( logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size - num_images = [len(img_list) for img_list in images] if images is not None else None + num_images = [len(img_list) if img_list else 0 for img_list in images] if images is not None else None # Get forward_kwargs for models with multimodal inputs if images is not None: From 24db4f9f5605dc99e11e6ea38ab39caef5da340c Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Tue, 24 Mar 2026 16:13:33 +0100 Subject: [PATCH 25/47] Default trackio --- examples/scripts/openenv/carla_vlm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/scripts/openenv/carla_vlm.py b/examples/scripts/openenv/carla_vlm.py index 0e259611eb6..e76ba924064 100644 --- a/examples/scripts/openenv/carla_vlm.py +++ b/examples/scripts/openenv/carla_vlm.py @@ -45,6 +45,7 @@ import argparse import base64 +import os from io import BytesIO from carla_env import CarlaAction, CarlaEnv @@ -69,7 +70,6 @@ def parse_args(): parser.add_argument("--gradient-accumulation-steps", type=int, default=4) parser.add_argument("--max-steps", type=int, default=100) parser.add_argument("--image-size", type=int, default=512, help="Resize camera images to this size. 0 to disable.") - parser.add_argument("--trackio-space-id", type=str, default="carla-vlm-grpo") parser.add_argument("--hub-model-id", type=str, default=None) parser.add_argument("--run-name", type=str, default=None) parser.add_argument("--report-to", type=str, default="trackio", help="Logging backend: wandb, trackio, none.") @@ -198,6 +198,8 @@ def lane_change(self, direction: str) -> list: self.reward = result.observation.rubric_reward or 0.0 return self._format_multimodal(result.observation) + os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio") + trainer = GRPOTrainer( model=args.model, train_dataset=dataset, From 20dbe94250592b46191df2a2d0bfad84ccc4b402 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Tue, 24 Mar 2026 16:34:06 +0100 Subject: [PATCH 26/47] Update position --- examples/scripts/openenv/carla_vlm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/scripts/openenv/carla_vlm.py b/examples/scripts/openenv/carla_vlm.py index e76ba924064..9abbb94cb02 100644 --- a/examples/scripts/openenv/carla_vlm.py +++ b/examples/scripts/openenv/carla_vlm.py @@ -54,6 +54,9 @@ from trl import GRPOConfig, GRPOTrainer +# Enable logging in a Hugging Face Space +os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio") + def parse_args(): parser = argparse.ArgumentParser(description="Run GRPO VLM training with CARLA environment.") @@ -198,8 +201,6 @@ def lane_change(self, direction: str) -> list: self.reward = result.observation.rubric_reward or 0.0 return self._format_multimodal(result.observation) - os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio") - trainer = GRPOTrainer( model=args.model, train_dataset=dataset, From 1be3989acde9a3005f4343e50ebe451da7322a4b Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Tue, 24 Mar 2026 16:44:28 +0100 Subject: [PATCH 27/47] Update --- examples/scripts/openenv/carla_vlm.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/scripts/openenv/carla_vlm.py b/examples/scripts/openenv/carla_vlm.py index 9abbb94cb02..b0900f0d3de 100644 --- a/examples/scripts/openenv/carla_vlm.py +++ b/examples/scripts/openenv/carla_vlm.py @@ -54,9 +54,6 @@ from trl import GRPOConfig, GRPOTrainer -# Enable logging in a Hugging Face Space -os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio") - def parse_args(): parser = argparse.ArgumentParser(description="Run GRPO VLM training with CARLA environment.") @@ -73,6 +70,7 @@ def parse_args(): parser.add_argument("--gradient-accumulation-steps", type=int, default=4) parser.add_argument("--max-steps", type=int, default=100) parser.add_argument("--image-size", type=int, default=512, help="Resize camera images to this size. 0 to disable.") + parser.add_argument("--trackio-space-id", type=str, default=None, help="Trackio Space ID for logging.") parser.add_argument("--hub-model-id", type=str, default=None) parser.add_argument("--run-name", type=str, default=None) parser.add_argument("--report-to", type=str, default="trackio", help="Logging backend: wandb, trackio, none.") @@ -220,6 +218,7 @@ def lane_change(self, direction: str) -> list: hub_model_id=args.hub_model_id, run_name=args.run_name, report_to=args.report_to, + trackio_space_id=args.trackio_space_id, ), environment_factory=CarlaVLMEnv, ) From da155cf259f00743c89d7d6f3cfb2c2646545f53 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Tue, 24 Mar 2026 16:52:34 +0100 Subject: [PATCH 28/47] Update based on cursor --- examples/scripts/openenv/carla_vlm.py | 1 - trl/trainer/grpo_trainer.py | 12 ++++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/examples/scripts/openenv/carla_vlm.py b/examples/scripts/openenv/carla_vlm.py index b0900f0d3de..30ac860a03b 100644 --- a/examples/scripts/openenv/carla_vlm.py +++ b/examples/scripts/openenv/carla_vlm.py @@ -45,7 +45,6 @@ import argparse import base64 -import os from io import BytesIO from carla_env import CarlaAction, CarlaEnv diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index f09fc5c5a63..ccf20ccee5d 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1415,11 +1415,15 @@ def _get_tool_suffix_ids(self, tool_messages): if isinstance(part, dict) and part.get("type") == "image": tool_images.append(part["image"]) - # Normalize string content in tool messages for VLM processors before either path + # Normalize string content in tool messages for VLM processors before either path. + # Use copies to avoid mutating the original completions data. if self._is_vlm: - for msg in tool_messages: - if isinstance(msg.get("content"), str): - msg["content"] = [{"type": "text", "text": msg["content"]}] + tool_messages = [ + {**msg, "content": [{"type": "text", "text": msg["content"]}]} + if isinstance(msg.get("content"), str) + else msg + for msg in tool_messages + ] if tool_images and self._is_vlm: # For VLMs with images: use processor.__call__ to get correctly expanded image tokens. From 19d5147b1cbd746919a99a2591e03f436afaa8ca Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Tue, 24 Mar 2026 17:21:27 +0100 Subject: [PATCH 29/47] Update --- trl/trainer/grpo_trainer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index ccf20ccee5d..0211dde8c06 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1941,6 +1941,7 @@ def _generate_and_score_completions( for prompt, image_list in zip(prompts, images, strict=True) ] + dataset_images = images # preserve dataset images before _generate may overwrite ( prompt_ids_list, completion_ids_list, @@ -1951,6 +1952,8 @@ def _generate_and_score_completions( extra_fields, images, ) = self._generate(prompts) + if images is None: + images = dataset_images # restore dataset images (rollout_func path returns None) # Convert lists of token IDs to padded tensors prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list] From 468bdc85d318798008517ae07b6b0ea6c63ba21a Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Wed, 25 Mar 2026 10:32:10 +0100 Subject: [PATCH 30/47] Update carla_vlm.py defaults: model to 0.8B, image-size to 256, max-completion-length to 4096 --- examples/scripts/openenv/carla_vlm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/scripts/openenv/carla_vlm.py b/examples/scripts/openenv/carla_vlm.py index 30ac860a03b..08a3694f14a 100644 --- a/examples/scripts/openenv/carla_vlm.py +++ b/examples/scripts/openenv/carla_vlm.py @@ -56,7 +56,7 @@ def parse_args(): parser = argparse.ArgumentParser(description="Run GRPO VLM training with CARLA environment.") - parser.add_argument("--model", type=str, default="Qwen/Qwen3.5-2B") + parser.add_argument("--model", type=str, default="Qwen/Qwen3.5-0.8B") parser.add_argument( "--env-urls", type=str, @@ -65,10 +65,10 @@ def parse_args(): help="URLs for CARLA environment servers. At least 2 required (1 Space = 1 connection).", ) parser.add_argument("--dataset-size", type=int, default=1000) - parser.add_argument("--max-completion-length", type=int, default=2048) + parser.add_argument("--max-completion-length", type=int, default=4096) parser.add_argument("--gradient-accumulation-steps", type=int, default=4) parser.add_argument("--max-steps", type=int, default=100) - parser.add_argument("--image-size", type=int, default=512, help="Resize camera images to this size. 0 to disable.") + parser.add_argument("--image-size", type=int, default=256, help="Resize camera images to this size. 0 to disable.") parser.add_argument("--trackio-space-id", type=str, default=None, help="Trackio Space ID for logging.") parser.add_argument("--hub-model-id", type=str, default=None) parser.add_argument("--run-name", type=str, default=None) From aa7bc9d721877bc562e06b768ce00d0fbf92b261 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 27 Mar 2026 16:43:54 +0100 Subject: [PATCH 31/47] Fix replay buffer unpack, move VLM parse_response branching, add truncation comment --- trl/chat_template_utils.py | 10 ++++--- .../grpo_with_replay_buffer_trainer.py | 1 + trl/trainer/grpo_trainer.py | 27 ++++++++++--------- 3 files changed, 22 insertions(+), 16 deletions(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 42cd233c75f..b1639cc39c8 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -706,7 +706,7 @@ def _validate_tool_calls(tool_calls: list | None) -> None: tool_call["arguments"] = {} -def parse_response(tokenizer: PreTrainedTokenizer, ids: list[int]) -> dict: +def parse_response(tokenizer_or_processor, ids: list[int]) -> dict: r""" Parse a token sequence into structured response dictionaries with fallback handling. @@ -716,9 +716,11 @@ def parse_response(tokenizer: PreTrainedTokenizer, ids: list[int]) -> dict: Also removes incorrectly appended EOS tokens from tool call content when present, and validates tool_calls to ensure all required fields exist. + For VLM processors, automatically uses the inner tokenizer for parsing. + Args: - tokenizer (`PreTrainedTokenizer`): - Tokenizer with a `parse_response()` method. + tokenizer_or_processor (`PreTrainedTokenizer` or VLM processor): + Tokenizer or processor with a `parse_response()` method (directly or via inner tokenizer). ids (`list[int]`): List of token sequences. @@ -739,6 +741,8 @@ def parse_response(tokenizer: PreTrainedTokenizer, ids: list[int]) -> dict: {'role': 'assistant', 'content': '', 'tool_calls': [{'type': 'function', 'function': {'name': 'multiply', 'arguments': {'a': 3, 'b': 4}}}]} ``` """ + # VLM processors don't have parse_response directly; use the inner tokenizer + tokenizer = getattr(tokenizer_or_processor, "tokenizer", tokenizer_or_processor) try: parsed = tokenizer.parse_response(ids) # Hotfix: remove incorrectly appended EOS token from tool calls diff --git a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py index 8202de7c71b..02730560dde 100644 --- a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py +++ b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py @@ -104,6 +104,7 @@ def _generate_and_score_completions( num_items_in_batch, sampling_per_token_logps_list, extra_fields, + images, ) = self._generate(prompts) # Convert lists of token IDs to padded tensors diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 96dd06cbc16..9e2f8f2a9f8 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1625,6 +1625,8 @@ async def _run_async_tools(async_coros): # Filter samples whose length exceeds max allowed length. This is important, because both # vLLM and transformers will error out if the input is longer than the model's max length. + # Note: _truncate_at_image_boundary ensures we never cut in the middle of an image token + # sequence (vision_start...vision_end), which would cause pixel_values/input_ids mismatches. if self.use_vllm and self.vllm_mode == "colocate": max_model_len = self.vllm_generation.llm.llm_engine.model_config.max_model_len elif self.use_vllm and self.vllm_mode == "server": @@ -1723,11 +1725,10 @@ async def _run_async_tools(async_coros): pct = prompt_completion_tool_ids[idx] # = prompt-completion-tool completion_ids[idx_with_tool] = pct[prompt_length:] + post_tool_ids[idx] - # Decode post-tool completions. Use tokenizer for parsing (processors don't have parse_response). - parsing_class = self.processing_class - if self._is_vlm: - parsing_class = parsing_class.tokenizer - post_tool_completions = [parse_response(parsing_class, ids) if ids else {} for ids in post_tool_ids] + # Decode post-tool completions. + post_tool_completions = [ + parse_response(self.processing_class, ids) if ids else {} for ids in post_tool_ids + ] # Add post-tool completions to the existing completions for idx in range(len(idxs_with_tool)): @@ -1798,21 +1799,21 @@ def _generate(self, prompts: list): extra_fields = {} # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. - # For VLM processors, delegate to the inner tokenizer for parsing (parse_response lives on the tokenizer). if is_conversational({"prompt": prompts[0]}): parsing_class = self.processing_class + # For VLM processors, propagate response_schema to the inner tokenizer if needed if self._is_vlm: - # Propagate response_schema from processor to tokenizer if needed if getattr(self.processing_class, "response_schema", None) and not getattr( - parsing_class.tokenizer, "response_schema", None + self.processing_class.tokenizer, "response_schema", None ): - parsing_class.tokenizer.response_schema = self.processing_class.response_schema - parsing_class = parsing_class.tokenizer + self.processing_class.tokenizer.response_schema = self.processing_class.response_schema + # parse_response handles VLM processors internally (uses inner tokenizer) + tokenizer = getattr(parsing_class, "tokenizer", parsing_class) if ( Version(transformers.__version__) >= Version("5.0.0") # parse_response added in v5 - and isinstance(parsing_class, PreTrainedTokenizerBase) - and hasattr(parsing_class, "response_schema") # attribute not set by default for now - and parsing_class.response_schema is not None # only works if the tokenizer has a schema + and isinstance(tokenizer, PreTrainedTokenizerBase) + and hasattr(tokenizer, "response_schema") # attribute not set by default for now + and tokenizer.response_schema is not None # only works if the tokenizer has a schema ): completions = [[parse_response(parsing_class, ids)] for ids in completion_ids] else: From 43cc407ba52c4362c119c4fe4e18b8616db6fd41 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 27 Mar 2026 17:01:50 +0100 Subject: [PATCH 32/47] Align GFPO and replay buffer trainers with _generate changes --- trl/experimental/gfpo/gfpo_trainer.py | 18 ++++++++++++++---- .../grpo_with_replay_buffer_trainer.py | 5 ++++- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/trl/experimental/gfpo/gfpo_trainer.py b/trl/experimental/gfpo/gfpo_trainer.py index bfcd3a6c0ca..1a0101efa1e 100644 --- a/trl/experimental/gfpo/gfpo_trainer.py +++ b/trl/experimental/gfpo/gfpo_trainer.py @@ -98,9 +98,19 @@ def _generate_and_score_completions(self, inputs): for prompt, image_list in zip(prompts, images, strict=True) ] - prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list, extra_fields = ( - self._generate(prompts) - ) + dataset_images = images # preserve dataset images before _generate may overwrite + ( + prompt_ids_list, + completion_ids_list, + tool_mask_list, + completions, + num_items_in_batch, + sampling_per_token_logps_list, + extra_fields, + images, + ) = self._generate(prompts) + if images is None: + images = dataset_images # restore dataset images (rollout_func path returns None) # Convert lists of token IDs to padded tensors prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list] @@ -155,7 +165,7 @@ def _generate_and_score_completions(self, inputs): logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size - num_images = [len(img_list) for img_list in images] if images is not None else None + num_images = [len(img_list) if img_list else 0 for img_list in images] if images is not None else None # Get forward_kwargs for models with multimodal inputs if images is not None: diff --git a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py index 02730560dde..32527e8a59c 100644 --- a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py +++ b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py @@ -96,6 +96,7 @@ def _generate_and_score_completions( for prompt, image_list in zip(prompts, images, strict=True) ] + dataset_images = images # preserve dataset images before _generate may overwrite ( prompt_ids_list, completion_ids_list, @@ -106,6 +107,8 @@ def _generate_and_score_completions( extra_fields, images, ) = self._generate(prompts) + if images is None: + images = dataset_images # restore dataset images (rollout_func path returns None) # Convert lists of token IDs to padded tensors prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list] @@ -168,7 +171,7 @@ def _generate_and_score_completions( logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size - num_images = [len(img_list) for img_list in images] if images is not None else None + num_images = [len(img_list) if img_list else 0 for img_list in images] if images is not None else None # Get forward_kwargs for models with multimodal inputs if images is not None: From 91958533758b4e19e251f967e92d167257a8d997 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Wed, 1 Apr 2026 13:01:36 +0200 Subject: [PATCH 33/47] Support multimodal observations from environment reset --- trl/trainer/grpo_trainer.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 9b85474ca43..7b0f44626c9 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1915,7 +1915,14 @@ def _generate_and_score_completions( observation = environment.reset(**reset_kwargs) if observation is None: continue - prompt[-1]["content"] += observation + if isinstance(observation, list): + existing_content = prompt[-1]["content"] + if isinstance(existing_content, str): + prompt[-1]["content"] = [{"type": "text", "text": existing_content}] + observation + else: + prompt[-1]["content"] = existing_content + observation + else: + prompt[-1]["content"] += observation if "images" in inputs[0]: images = [example.get("images") for example in inputs] From 0bd24577e508396effb910cb7687f20ee971e4e2 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Wed, 1 Apr 2026 13:47:34 +0200 Subject: [PATCH 34/47] nits --- trl/trainer/grpo_trainer.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 7b0f44626c9..cf817919501 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1823,6 +1823,7 @@ def _generate(self, prompts: list): completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) # Extract tool calls from the completions and (possibly) execute them + tool_images = [] if self.tools: ( tool_mask, @@ -1836,8 +1837,7 @@ def _generate(self, prompts: list): prompts, prompt_ids, completion_ids, completions, logprobs, images, multimodal_fields ) # Merge tool response images into the images list for the forward pass - has_tool_images = any(imgs for imgs in tool_images) - if has_tool_images: + if any(imgs for imgs in tool_images): if images is None: images = [imgs if imgs else None for imgs in tool_images] else: @@ -2028,10 +2028,7 @@ def _generate_and_score_completions( # When tool images are present (from _tool_call_loop), use image_processor directly and build # mm_token_type_ids from prompt_completion_ids. Otherwise, use the full processor pipeline # which returns model-specific keys (image_sizes, pixel_attention_mask, etc.). - has_tool_images = ( - self.tools and images is not None and any(img for img_list in images if img_list for img in img_list) - ) - if has_tool_images and self._is_vlm: + if self.tools and any(imgs for imgs in tool_images) and self._is_vlm: # noqa: F821 flat_images = [img for img_list in images if img_list for img in img_list] image_inputs = self.processing_class.image_processor(images=flat_images, return_tensors="pt") image_inputs = super()._prepare_inputs(image_inputs) From a32d2001182ef7ac0aeddb31edd804debd12b439 Mon Sep 17 00:00:00 2001 From: Sergio Paniego Blanco Date: Wed, 1 Apr 2026 13:59:36 +0200 Subject: [PATCH 35/47] Update trl/trainer/grpo_trainer.py Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> --- trl/trainer/grpo_trainer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index cf817919501..e6ecca008a8 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1393,9 +1393,7 @@ def _get_tool_suffix_ids(self, tool_messages): """Get token IDs for tool result formatting by using a minimal dummy conversation.""" dummy_messages = [{"role": "user", "content": "dummy"}, {"role": "assistant", "content": "dummy"}] if self._is_vlm: - for message in dummy_messages: - if isinstance(message.get("content"), str): - message["content"] = [{"type": "text", "text": message["content"]}] + dummy_messages = prepare_multimodal_messages(dummy_messages, []) prefix_ids = self.processing_class.apply_chat_template( dummy_messages, add_generation_prompt=False, From 51fa7246bd007166b275443a20a498a0fe4f56f6 Mon Sep 17 00:00:00 2001 From: Sergio Paniego Blanco Date: Wed, 1 Apr 2026 14:01:05 +0200 Subject: [PATCH 36/47] Update trl/trainer/grpo_trainer.py Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> --- trl/trainer/grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index e6ecca008a8..51eb36b1f52 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1403,7 +1403,7 @@ def _get_tool_suffix_ids(self, tool_messages): **self.chat_template_kwargs, ) # VLM processors return batched output (list of lists), unbatch for single conversation - if isinstance(prefix_ids, list) and len(prefix_ids) == 1 and isinstance(prefix_ids[0], list): + if self._is_vlm:: prefix_ids = prefix_ids[0] # Check if tool messages contain images (multimodal tool responses) From 11c2ba6dec41bf53b30e94938271075d4ec20013 Mon Sep 17 00:00:00 2001 From: Sergio Paniego Blanco Date: Wed, 1 Apr 2026 14:01:53 +0200 Subject: [PATCH 37/47] Update trl/trainer/grpo_trainer.py Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> --- trl/trainer/grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 51eb36b1f52..972728eb0d1 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1460,7 +1460,7 @@ def _get_tool_suffix_ids(self, tool_messages): return_dict=False, **self.chat_template_kwargs, ) - if isinstance(full_ids, list) and len(full_ids) == 1 and isinstance(full_ids[0], list): + if self._is_vlm: full_ids = full_ids[0] # Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\n" after an assistant/tool block. From c5acc81e96aeca7dabbf04943ed2dca92057b223 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Wed, 1 Apr 2026 14:26:36 +0200 Subject: [PATCH 38/47] build_mm_token_type_ids simplified --- trl/trainer/grpo_trainer.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index cf817919501..76dd1a64927 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1908,7 +1908,6 @@ def _generate_and_score_completions( mode = "train" if self.model.training else "eval" prompts = [x["prompt"] for x in inputs] - build_mm_token_type_ids = False if self.environments: for prompt, environment, reset_kwargs in zip(prompts, self.environments, inputs, strict=True): @@ -2033,7 +2032,6 @@ def _generate_and_score_completions( image_inputs = self.processing_class.image_processor(images=flat_images, return_tensors="pt") image_inputs = super()._prepare_inputs(image_inputs) forward_kwargs = dict(image_inputs) - build_mm_token_type_ids = True elif images is not None: prompts_text = [ apply_chat_template( @@ -2078,7 +2076,7 @@ def _generate_and_score_completions( # For VLM tool images: build mm_token_type_ids from the full prompt_completion_ids. # This must happen AFTER the mm_token_type_ids extension block above, because our version # already covers the full sequence (images are in the completion, not just the prompt). - if build_mm_token_type_ids: + if self.tools and any(imgs for imgs in tool_images) and self._is_vlm: # noqa: F821 vtids = self._get_vision_token_ids() mm_ids = torch.zeros_like(prompt_completion_ids) if vtids["image_pad"] is not None: @@ -2086,7 +2084,6 @@ def _generate_and_score_completions( if vtids["video_pad"] is not None: mm_ids[prompt_completion_ids == vtids["video_pad"]] = 2 forward_kwargs["mm_token_type_ids"] = mm_ids - build_mm_token_type_ids = False # consumed # Truncation safety: if max_completion_length truncated some image tokens, the number # of image pad tokens in input_ids won't match pixel_values features. Check per-sample From 52249fdb22a3df925c981c6b5483673cbc2c6d90 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Wed, 1 Apr 2026 14:33:01 +0200 Subject: [PATCH 39/47] nit --- trl/trainer/grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index faa783ccfbc..c1324ac066b 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1403,7 +1403,7 @@ def _get_tool_suffix_ids(self, tool_messages): **self.chat_template_kwargs, ) # VLM processors return batched output (list of lists), unbatch for single conversation - if self._is_vlm:: + if self._is_vlm: prefix_ids = prefix_ids[0] # Check if tool messages contain images (multimodal tool responses) From 1ea1b5f708d28eb5fb71624e9a5afddb5dd72cc4 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Wed, 1 Apr 2026 15:25:05 +0200 Subject: [PATCH 40/47] Fix tests --- trl/trainer/grpo_trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index c1324ac066b..e53033bc390 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1897,6 +1897,7 @@ def _generate(self, prompts: list): logprobs, extra_fields, images, + tool_images, ) def _generate_and_score_completions( @@ -1957,6 +1958,7 @@ def _generate_and_score_completions( sampling_per_token_logps_list, extra_fields, images, + tool_images, ) = self._generate(prompts) if images is None: images = dataset_images # restore dataset images (rollout_func path returns None) From b440783e21aabccf7d35bc04cc5dd548e5a9e6d4 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Wed, 1 Apr 2026 15:40:41 +0200 Subject: [PATCH 41/47] extended --- trl/experimental/gfpo/gfpo_trainer.py | 1 + .../grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py | 1 + 2 files changed, 2 insertions(+) diff --git a/trl/experimental/gfpo/gfpo_trainer.py b/trl/experimental/gfpo/gfpo_trainer.py index 1a0101efa1e..11304e5e02f 100644 --- a/trl/experimental/gfpo/gfpo_trainer.py +++ b/trl/experimental/gfpo/gfpo_trainer.py @@ -108,6 +108,7 @@ def _generate_and_score_completions(self, inputs): sampling_per_token_logps_list, extra_fields, images, + tool_images, ) = self._generate(prompts) if images is None: images = dataset_images # restore dataset images (rollout_func path returns None) diff --git a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py index 32527e8a59c..a96f43e0479 100644 --- a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py +++ b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py @@ -106,6 +106,7 @@ def _generate_and_score_completions( sampling_per_token_logps_list, extra_fields, images, + tool_images, ) = self._generate(prompts) if images is None: images = dataset_images # restore dataset images (rollout_func path returns None) From 30b2be3d5d36a534da98a8bdbaa11e891f5d1e39 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Thu, 2 Apr 2026 11:14:42 +0200 Subject: [PATCH 42/47] Avoid mutating original prompts during VLM content normalization --- trl/trainer/grpo_trainer.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index e53033bc390..ef3422fc660 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1247,11 +1247,17 @@ def _tokenize_prompts(self, prompts: list): """Tokenize prompts and extract images/multimodal fields for generation.""" if is_conversational({"prompt": prompts[0]}): # Normalize string content to content blocks for VLM processors that don't handle plain strings. + # Use copies to avoid mutating the original prompts. if self._is_vlm: - for prompt in prompts: - for message in prompt: - if isinstance(message.get("content"), str): - message["content"] = [{"type": "text", "text": message["content"]}] + prompts = [ + [ + {**msg, "content": [{"type": "text", "text": msg["content"]}]} + if isinstance(msg.get("content"), str) + else msg + for msg in prompt + ] + for prompt in prompts + ] # Extract images from messages for VLM support images = [] From 826e4c109cfe05da0623f55e44bf2525e808a3b9 Mon Sep 17 00:00:00 2001 From: Sergio Paniego Blanco Date: Thu, 2 Apr 2026 11:27:46 +0200 Subject: [PATCH 43/47] Update trl/trainer/grpo_trainer.py Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> --- trl/trainer/grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index ef3422fc660..abc5b4097d4 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -2033,7 +2033,7 @@ def _generate_and_score_completions( # When tool images are present (from _tool_call_loop), use image_processor directly and build # mm_token_type_ids from prompt_completion_ids. Otherwise, use the full processor pipeline # which returns model-specific keys (image_sizes, pixel_attention_mask, etc.). - if self.tools and any(imgs for imgs in tool_images) and self._is_vlm: # noqa: F821 + if self.tools and any(imgs for imgs in tool_images) and self._is_vlm: flat_images = [img for img_list in images if img_list for img in img_list] image_inputs = self.processing_class.image_processor(images=flat_images, return_tensors="pt") image_inputs = super()._prepare_inputs(image_inputs) From 36a631ccd0851d5d3d1d04e47f8c3ac5a8ccadcb Mon Sep 17 00:00:00 2001 From: Sergio Paniego Blanco Date: Thu, 2 Apr 2026 11:27:57 +0200 Subject: [PATCH 44/47] Update trl/trainer/grpo_trainer.py Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> --- trl/trainer/grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index abc5b4097d4..ebf24c3c0ba 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -2082,7 +2082,7 @@ def _generate_and_score_completions( # For VLM tool images: build mm_token_type_ids from the full prompt_completion_ids. # This must happen AFTER the mm_token_type_ids extension block above, because our version # already covers the full sequence (images are in the completion, not just the prompt). - if self.tools and any(imgs for imgs in tool_images) and self._is_vlm: # noqa: F821 + if self.tools and any(imgs for imgs in tool_images) and self._is_vlm: vtids = self._get_vision_token_ids() mm_ids = torch.zeros_like(prompt_completion_ids) if vtids["image_pad"] is not None: From e1711f923bdc1c240e48964bcde6bb264f723b75 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Thu, 2 Apr 2026 11:50:49 +0200 Subject: [PATCH 45/47] removed defensive code --- trl/trainer/grpo_trainer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index ef3422fc660..cab68ee79f8 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1751,14 +1751,10 @@ async def _run_async_tools(async_coros): for i in range(len(completion_ids)): if len(tool_mask[i]) > len(completion_ids[i]): tool_mask[i] = tool_mask[i][: len(completion_ids[i])] - elif len(tool_mask[i]) < len(completion_ids[i]): - tool_mask[i] = tool_mask[i] + [1] * (len(completion_ids[i]) - len(tool_mask[i])) if logprobs is not None: for i in range(len(completion_ids)): if len(logprobs[i]) > len(completion_ids[i]): logprobs[i] = logprobs[i][: len(completion_ids[i])] - elif len(logprobs[i]) < len(completion_ids[i]): - logprobs[i] = logprobs[i] + [0.0] * (len(completion_ids[i]) - len(logprobs[i])) # Sync tool_images: count complete images in completion_ids and trim tool_images to match. vtids = self._get_vision_token_ids() From 9a5d3c485d72f3d6912342a7a8523c7bd6eaf0d2 Mon Sep 17 00:00:00 2001 From: Sergio Paniego Blanco Date: Thu, 2 Apr 2026 12:05:25 +0200 Subject: [PATCH 46/47] Update trl/trainer/grpo_trainer.py Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> --- trl/trainer/grpo_trainer.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 61a46630e02..9e8997a6ccc 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1915,14 +1915,9 @@ def _generate_and_score_completions( observation = environment.reset(**reset_kwargs) if observation is None: continue - if isinstance(observation, list): - existing_content = prompt[-1]["content"] - if isinstance(existing_content, str): - prompt[-1]["content"] = [{"type": "text", "text": existing_content}] + observation - else: - prompt[-1]["content"] = existing_content + observation - else: - prompt[-1]["content"] += observation + if isinstance(observation, list) and isinstance(prompt[-1]["content"], str): + prompt[-1]["content"] = [{"type": "text", "text": prompt[-1]["content"]}] + prompt[-1]["content"] += observation if "images" in inputs[0]: images = [example.get("images") for example in inputs] From cad17f36a060a24fb5f484537d8673484e2549ed Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Thu, 2 Apr 2026 12:19:26 +0200 Subject: [PATCH 47/47] update based on cursor --- trl/trainer/grpo_trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 9e8997a6ccc..03bab4d8f17 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1917,6 +1917,8 @@ def _generate_and_score_completions( continue if isinstance(observation, list) and isinstance(prompt[-1]["content"], str): prompt[-1]["content"] = [{"type": "text", "text": prompt[-1]["content"]}] + if isinstance(observation, str) and isinstance(prompt[-1]["content"], list): + observation = [{"type": "text", "text": observation}] prompt[-1]["content"] += observation if "images" in inputs[0]: