diff --git a/docs/source/example_overview.md b/docs/source/example_overview.md index d12db42e1a0..c7c88df0c94 100644 --- a/docs/source/example_overview.md +++ b/docs/source/example_overview.md @@ -103,6 +103,7 @@ These scripts demonstrate how to train models with [OpenEnv](openenv) environmen | [`examples/scripts/openenv/browsergym.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/browsergym.py) | GRPO training with the BrowserGym environment for VLMs. | | [`examples/scripts/openenv/browsergym_llm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/browsergym_llm.py) | GRPO training with the BrowserGym environment for LLMs. | | [`examples/scripts/openenv/carla.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/carla.py) | GRPO training with the CARLA environment for autonomous driving. | +| [`examples/scripts/openenv/carla_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/carla_vlm.py) | GRPO training with CARLA for VLMs with multimodal tool responses (camera images). | ## Distributed Training (for scripts) diff --git a/examples/scripts/openenv/browsergym.py b/examples/scripts/openenv/browsergym.py index 655daaa3a46..49a0b4c712f 100644 --- a/examples/scripts/openenv/browsergym.py +++ b/examples/scripts/openenv/browsergym.py @@ -22,58 +22,28 @@ # /// """ -Simple script to run GRPO training with OpenEnv's BrowserGym environment and vLLM. +GRPO training with OpenEnv's BrowserGym environment for VLMs (Vision Language Models). -This example automatically detects and uses vision capabilities when VLM models are used. -Screenshots from BrowserGym are collected and passed to the model during training. The GRPO -trainer auto-detects multimodal support by checking for images in the rollout data. - -Setup (Option A - Install from HF Space, recommended): +This script uses `environment_factory` with multimodal tool responses: each tool action +returns a screenshot (PIL Image) alongside the accessibility tree text, allowing the VLM +to see the page visually after each action. +Setup: ```sh -uv pip install git+https://huggingface.co/spaces/openenv/browsergym_env +pip install "openenv-browsergym @ git+https://huggingface.co/spaces/openenv/browsergym_env" ``` -Setup (Option B - Clone OpenEnv repo, for development): - +Usage: ```sh -git clone https://github.com/meta-pytorch/OpenEnv.git -cd OpenEnv/envs/browsergym_env -uv pip install -e . -``` +# Without vLLM (default, 1 GPU) +python examples/scripts/openenv/browsergym.py -# Option 1: HF Spaces + Colocated vLLM (1 GPU required) -```sh -python examples/scripts/openenv/browsergym.py --vllm-mode colocate -``` - -# Option 2: HF Spaces + Separate vLLM server (2 GPUs required) +# With vLLM colocate (1 GPU, requires vLLM support for the model) +python examples/scripts/openenv/browsergym.py --use-vllm -# Spin up vLLM server (Terminal 1) -```sh -CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3-VL-2B-Instruct --host 0.0.0.0 --port 8001 -``` - -# Run training (Terminal 2) -```sh -CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/browsergym.py --vllm-mode server --vllm-server-url http://localhost:8001 -``` - -# Option 3: Local + Colocated vLLM (1 GPU required) - -# Build and start the environment only if using --env-mode docker-local -```sh -cd OpenEnv -docker build -t openenv-base:latest -f src/core/containers/images/Dockerfile . -docker build -t browsergym-env:latest -f src/envs/browsergym_env/server/Dockerfile . -docker run -d -p 8001:8001 \ - -e BROWSERGYM_BENCHMARK="miniwob" \ - -e BROWSERGYM_TASK_NAME="click-test" \ - browsergym-env:latest -``` - -```sh -python examples/scripts/openenv/browsergym.py --env-mode docker-local --vllm-mode colocate +# With vLLM server (2 GPUs) +CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3.5-2B --host 0.0.0.0 --port 8000 +CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/browsergym.py --use-vllm --vllm-mode server ``` """ @@ -87,190 +57,28 @@ from browsergym_env import BrowserGymAction, BrowserGymEnv from datasets import Dataset from PIL import Image -from transformers import AutoTokenizer from trl import GRPOConfig, GRPOTrainer -from trl.experimental.openenv import generate_rollout_completions def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Run GRPO training for BrowserGym MiniWoB using OpenEnv environment.") - parser.add_argument( - "--tokenizer-id", - default="Qwen/Qwen3-VL-2B-Instruct", - help="Model identifier used to load the tokenizer.", - ) - parser.add_argument( - "--model-id", - default="Qwen/Qwen3-VL-2B-Instruct", - help="Model identifier passed to GRPOTrainer for fine-tuning.", - ) - parser.add_argument( - "--env-host", - type=str, - default="https://openenv-browsergym-env.hf.space", - help="Host for the BrowserGym environment.", - ) - parser.add_argument("--env-port", type=int, default=8001, help="Port for the BrowserGym environment.") - parser.add_argument( - "--env-mode", - choices=["docker-local", "docker-image", "docker-hub", "space"], - default="space", - help="Where to run the environment: 'local' to launch it, 'docker-local' if already running locally, 'docker-image' to run from a Docker image, 'docker-hub' to run from Docker Hub, or 'space' to use a remote Space URL.", - ) - parser.add_argument( - "--env-image", type=str, default="browsergym-env:latest", help="Docker image for the BrowserGym environment." - ) - parser.add_argument( - "--benchmark", - default="miniwob", - help="BrowserGym benchmark to use (miniwob, webarena, etc.).", - ) - parser.add_argument( - "--task-name", - default="click-test", - help="Specific task within the benchmark (e.g., click-test, click-button).", - ) - parser.add_argument( - "--dataset-prompt", - default="Complete the web task successfully.", - help="Prompt text used to seed the training dataset.", - ) - parser.add_argument( - "--dataset-size", - type=int, - default=1000, - help="Number of entries to include in the synthetic training dataset.", - ) - parser.add_argument( - "--max-steps", - type=int, - default=10, - help="Maximum number of steps per episode.", - ) - parser.add_argument( - "--max-new-tokens", - type=int, - default=32, - help="Maximum number of new tokens to request from vLLM for each action.", - ) - parser.add_argument( - "--temperature", - type=float, - default=0.7, - help="Sampling temperature used during rollout generation.", - ) - parser.add_argument( - "--top-k", - type=int, - default=50, - help="Top-k sampling parameter forwarded to vLLM.", - ) - parser.add_argument( - "--top-p", - type=float, - default=None, - help="Optional top-p sampling parameter forwarded to vLLM.", - ) - parser.add_argument( - "--image-size", - type=int, - default=512, - help="Resize screenshots to this size (preserving aspect ratio) to reduce memory usage. Set to 0 to disable resizing.", - ) - parser.add_argument( - "--learning-rate", - type=float, - default=5e-6, - help="Learning rate for GRPO training.", - ) - parser.add_argument( - "--weight-decay", - type=float, - default=0.0, - help="Weight decay applied during optimization.", - ) - parser.add_argument( - "--gradient-accumulation-steps", - type=int, - default=32, - help="Gradient accumulation steps for GRPO training.", - ) - parser.add_argument( - "--warmup-steps", - type=int, - default=10, - help="Warmup steps for the scheduler.", - ) - parser.add_argument( - "--per-device-batch-size", - type=int, - default=1, - help="Per-device train batch size.", - ) - parser.add_argument( - "--num-generations", - type=int, - default=4, - help="Number of rollout generations per dataset prompt.", - ) - parser.add_argument( - "--num-epochs", - type=int, - default=1, - help="Number of training epochs.", - ) - parser.add_argument( - "--save-interval", - type=int, - default=50, - help="Interval (in steps) between checkpoint saves.", - ) - parser.add_argument( - "--save-total-limit", - type=int, - default=None, - help="Maximum number of checkpoints to keep.", - ) - parser.add_argument( - "--output-dir", - default=None, - help="Directory where training outputs and checkpoints are stored.", - ) - parser.add_argument( - "--run-name", - default=None, - help="Optional run name for logging systems.", - ) - parser.add_argument( - "--project", - default=None, - help="Optional project identifier for logging systems.", - ) - parser.add_argument( - "--vllm-mode", - choices=("colocate", "server"), - default="colocate", - help="vLLM execution mode: 'colocate' or 'server'.", - ) - parser.add_argument( - "--vllm-server-url", - type=str, - default="http://localhost:8001", - help="URL for the vLLM server (only used when --vllm-mode=server).", - ) - parser.add_argument( - "--logging-steps", - type=int, - default=1, - help="Frequency of logging steps for GRPO training.", - ) - parser.add_argument( - "--debug", - action="store_true", - default=False, - help="Enable verbose debugging output during rollouts.", - ) + parser = argparse.ArgumentParser(description="GRPO training with BrowserGym VLM environment.") + parser.add_argument("--model-id", default="Qwen/Qwen3.5-2B") + parser.add_argument("--space-url", default="https://openenv-browsergym-env.hf.space") + parser.add_argument("--dataset-prompt", default="Complete the web task successfully.") + parser.add_argument("--dataset-size", type=int, default=1000) + parser.add_argument("--max-steps", type=int, default=10) + parser.add_argument("--max-completion-length", type=int, default=1024) + parser.add_argument("--image-size", type=int, default=512, help="Resize screenshots to this size. 0 to disable.") + parser.add_argument("--num-generations", type=int, default=4) + parser.add_argument("--gradient-accumulation-steps", type=int, default=32) + parser.add_argument("--learning-rate", type=float, default=5e-6) + parser.add_argument("--num-epochs", type=int, default=1) + parser.add_argument("--logging-steps", type=int, default=1) + parser.add_argument("--output-dir", default=None) + parser.add_argument("--use-vllm", action="store_true", default=False, help="Enable vLLM for generation.") + parser.add_argument("--vllm-mode", choices=("colocate", "server"), default="colocate") + parser.add_argument("--vllm-server-url", default="http://localhost:8000") return parser.parse_args() @@ -278,323 +86,204 @@ def sanitize_name(name: str) -> str: return name.replace("/", "-") -# --------------------------------------------------------------------------- -# System Prompt -# --------------------------------------------------------------------------- - -SYSTEM_PROMPT = """You control a web browser through BrowserGym actions. -You must complete the given web task by interacting with the page. +SYSTEM_PROMPT = """You control a web browser to complete tasks. -Available actions: -- noop() - Do nothing -- click(bid) - Click element with BrowserGym ID -- fill(bid, text) - Fill input field -- send_keys(text) - Send keyboard input -- scroll(direction) - Scroll up/down +The page structure shows elements as: [bid] element_type 'element_text' +For example: [13] button 'Click Me!' means the element has bid='13'. -Reply with exactly ONE action on a single line, e.g.: -click('123') -fill('456', 'text') -noop() +You will see a screenshot of the page after each action. Use the visual information +along with the page structure to decide your next action. -Do not include explanations or multiple actions.""" +Use the available tools to interact with the page: +- click: Click an element by its bid +- fill: Fill an input field with text +- send_keys: Send keyboard input +- scroll: Scroll the page +- noop: Do nothing +Complete the given task as efficiently as possible.""" -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -def make_user_prompt(goal: str, step_num: int, axtree: str, error: str = "") -> str: - """Create user prompt from observation.""" - prompt_parts = [f"Step {step_num + 1}"] - - if goal: - prompt_parts.append(f"Goal: {goal}") - - if error: - prompt_parts.append(f"Previous action error: {error}") - - # Include accessibility tree (truncated for context) - if axtree: - max_len = 2000 - axtree_truncated = axtree[:max_len] + "..." if len(axtree) > max_len else axtree - prompt_parts.append(f"Page structure:\n{axtree_truncated}") - - prompt_parts.append("What action do you take?") - - return "\n\n".join(prompt_parts) - - -def parse_action(response_text: str) -> str: - """Parse BrowserGym action from model response.""" - # Extract first line that looks like an action - for line in response_text.strip().split("\n"): - line = line.strip() - if "(" in line and ")" in line: - return line - - # Fallback to noop if no valid action found - return "noop()" - - -def rollout_once( - trainer: GRPOTrainer, - env: BrowserGymEnv, - tokenizer: AutoTokenizer, - dataset_prompt: str, - max_steps: int, - image_size: int = 0, - debug: bool = False, -) -> dict[str, list]: - """Run one episode and collect training data.""" - result = env.reset() - observation = result.observation - - prompt_ids: list[int] = [] - completion_ids: list[int] = [] - logprobs: list[float] = [] - step_rewards: list[float] = [] - completion_rewards: list[float] = [] - images: list[Image.Image] = [] # Collect screenshots for VLM - - for step_num in range(max_steps): - if result.done: - break - - # Create prompt from observation - goal = observation.goal or dataset_prompt - axtree = observation.axtree_txt or "" - error = observation.error if observation.last_action_error else "" - - # Collect screenshot if available (for VLM support) - if observation.screenshot is not None: - screenshot_array = np.array(observation.screenshot, dtype=np.uint8) - screenshot_image = Image.fromarray(screenshot_array) - - # Resize to reduce memory if image_size > 0 - if image_size > 0: - # Preserve aspect ratio while resizing - screenshot_image.thumbnail((image_size, image_size), Image.LANCZOS) - print( - f"[DEBUG] Step {step_num + 1}: Collected and resized screenshot from {screenshot_array.shape} to {screenshot_image.size}" - ) - else: - print(f"[DEBUG] Step {step_num + 1}: Collected screenshot, shape={screenshot_array.shape}") - - images.append(screenshot_image) - else: - print(f"[DEBUG] Step {step_num + 1}: No screenshot available") - - user_prompt = make_user_prompt(goal, step_num, axtree, error) - messages = [ - {"role": "system", "content": SYSTEM_PROMPT}, - {"role": "user", "content": user_prompt}, - ] - prompt_text = tokenizer.apply_chat_template( - messages, - add_generation_prompt=True, - tokenize=False, - ) - - # Generate action with vLLM - rollout_outputs = generate_rollout_completions(trainer, [prompt_text])[0] - prompt_ids.extend(rollout_outputs["prompt_ids"]) - completion_ids.extend(rollout_outputs["completion_ids"]) - logprobs.extend(rollout_outputs["logprobs"]) - - completion_text = rollout_outputs.get("text") or tokenizer.decode( - rollout_outputs["completion_ids"], skip_special_tokens=True - ) - - # Parse and execute action - action_str = parse_action(completion_text) - - if debug: - print(f"Step {step_num + 1}: {action_str}") - - # Take action in environment - result = env.step(BrowserGymAction(action_str=action_str)) - observation = result.observation - - # Track rewards - step_reward = float(result.reward or 0.0) - step_rewards.append(step_reward) - - # Reward shaping: success is most important - if result.done and step_reward > 0: - completion_rewards.append(1.0) # Task completed successfully - elif result.done and step_reward == 0: - completion_rewards.append(0.0) # Task failed - else: - completion_rewards.append(step_reward) # Intermediate reward - - # Final reward is based on task completion - final_reward = completion_rewards[-1] if completion_rewards else 0.0 - - result_dict = { - "prompt_ids": prompt_ids, - "completion_ids": completion_ids, - "logprobs": logprobs, - "step_rewards": step_rewards, - "completion_reward": final_reward, - } - - # Include images if available (GRPO trainer will auto-detect VLM support) - if images: - result_dict["images"] = images - - return result_dict - - -# --------------------------------------------------------------------------- -# Rewards -# --------------------------------------------------------------------------- - - -def reward_completion(completions: list[str], **kwargs) -> list[float]: - """Reward for task completion.""" - rewards = kwargs.get("completion_reward") if kwargs else None - if rewards is None: - return [0.0 for _ in completions] - return [float(r) for r in rewards] - - -# --------------------------------------------------------------------------- -# Main entrypoint -# --------------------------------------------------------------------------- +def reward_completion(completions, environments, **kwargs) -> list[float]: + return [env.reward for env in environments] def main() -> None: args = parse_args() - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id) - tokenizer.pad_token = tokenizer.eos_token - - # Select environment mode - if args.env_mode == "docker-local": - env_url = f"http://{args.env_host}:{args.env_port}" - client = BrowserGymEnv(base_url=env_url) - print(f"🌍 Using existing BrowserGym Environment (Docker) at: {env_url}") - elif args.env_mode == "docker-image": - client = BrowserGymEnv.from_docker_image(args.env_image) - print("🌍 Using BrowserGym Environment (Docker) from local Image") - elif args.env_mode == "docker-hub": - client = BrowserGymEnv.from_hub(args.env_image) - print("🌍 Using existing BrowserGym Environment (Docker) from Hub Image") - elif args.env_mode == "space": - env_url = args.env_host - client = BrowserGymEnv(base_url=env_url) - print(f"🌍 Using Hugging Face Space environment at: {env_url}") - else: - raise ValueError(f"Unknown environment mode: {args.env_mode}") - - dataset = Dataset.from_dict({"prompt": [args.dataset_prompt] * args.dataset_size}) - - timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - default_output_dir = Path("outputs") / f"browsergym-grpo-{sanitize_name(args.model_id)}-{timestamp}" - output_dir = Path(args.output_dir or default_output_dir) - - grpo_config = GRPOConfig( - use_vllm=True, - vllm_mode=args.vllm_mode, - vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None, - vllm_gpu_memory_utilization=0.4, - output_dir=str(output_dir), - num_train_epochs=args.num_epochs, - learning_rate=args.learning_rate, - weight_decay=args.weight_decay, - gradient_accumulation_steps=args.gradient_accumulation_steps, - per_device_train_batch_size=args.per_device_batch_size, - warmup_steps=args.warmup_steps, - num_generations=args.num_generations, - generation_batch_size=args.num_generations, # Must be divisible by num_generations - max_completion_length=args.max_new_tokens, - logging_steps=args.logging_steps, - report_to="trackio", - trackio_space_id=f"browsergym-grpo-{sanitize_name(args.model_id)}-{timestamp}", - save_strategy="steps", - save_steps=args.save_interval, - save_total_limit=args.save_total_limit, - temperature=args.temperature, - top_k=args.top_k, - top_p=args.top_p, + space_url = args.space_url + max_steps = args.max_steps + image_size = args.image_size + + dataset = Dataset.from_dict( + { + "prompt": [ + [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": args.dataset_prompt}, + ] + ] + * args.dataset_size + } ) - grpo_config.run_name = args.run_name or f"run-{timestamp}" - grpo_config.project = args.project or f"group-{sanitize_name(args.model_id)}" - - def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: - episode_prompt_ids: list[list[int]] = [] - episode_completion_ids: list[list[int]] = [] - episode_logprobs: list[list[float]] = [] - completion_rewards: list[float] = [] - episode_images: list[list[Image.Image]] = [] - - print(f"\n[DEBUG] rollout_func called with {len(prompts)} prompts") - - for i, prompt_text in enumerate(prompts): - print(f"[DEBUG] Processing prompt {i + 1}/{len(prompts)}") - episode = rollout_once( - trainer=trainer, - env=client, - tokenizer=tokenizer, - dataset_prompt=prompt_text, - max_steps=args.max_steps, - image_size=args.image_size, - debug=args.debug, - ) - episode_prompt_ids.append(episode["prompt_ids"]) - episode_completion_ids.append(episode["completion_ids"]) - episode_logprobs.append(episode["logprobs"]) - completion_rewards.append(episode["completion_reward"]) - - # Collect images if available (for VLM support) - if "images" in episode: - print(f"[DEBUG] Episode {i + 1} has {len(episode['images'])} images") - episode_images.append(episode["images"]) + class BrowserGymVLMEnv: + def __init__(self): + self.client = BrowserGymEnv(base_url=space_url) + self.reward = 0.0 + self.done = False + self._step_count = 0 + + def reset(self, **kwargs) -> str | None: + self.reward = 0.0 + self.done = False + self._step_count = 0 + result = self.client.reset() + self.done = result.done + return self._format_observation(result.observation) + + def click(self, bid: str) -> list: + """Click an element on the page. + + Args: + bid: The BrowserGym ID of the element to click. + + Returns: + The updated page observation with screenshot. + """ + return self._do_action(f"click('{bid}')") + + def fill(self, bid: str, text: str) -> list: + """Fill an input field with text. + + Args: + bid: The BrowserGym ID of the input field. + text: The text to type into the field. + + Returns: + The updated page observation with screenshot. + """ + return self._do_action(f"fill('{bid}', '{text}')") + + def send_keys(self, text: str) -> list: + """Send keyboard input to the page. + + Args: + text: The keyboard input to send. + + Returns: + The updated page observation with screenshot. + """ + return self._do_action(f"send_keys('{text}')") + + def scroll(self, direction: str) -> list: + """Scroll the page. + + Args: + direction: Direction to scroll, either 'up' or 'down'. + + Returns: + The updated page observation with screenshot. + """ + return self._do_action(f"scroll('{direction}')") + + def noop(self) -> list: + """Do nothing and observe the current page state. + + Returns: + The current page observation with screenshot. + """ + return self._do_action("noop()") + + def _do_action(self, action_str: str) -> list: + if self.done: + raise ValueError("Episode is done.") + + self._step_count += 1 + result = self.client.step(BrowserGymAction(action_str=action_str)) + observation = result.observation + step_reward = float(result.reward or 0.0) + self.done = result.done + + if self.done and step_reward > 0: + self.reward = 1.0 + elif self.done: + self.reward = 0.0 else: - print(f"[DEBUG] Episode {i + 1} has NO images") - - result = { - "prompt_ids": episode_prompt_ids, - "completion_ids": episode_completion_ids, - "logprobs": episode_logprobs, - "completion_reward": completion_rewards, - } - - # Include images if any episode had screenshots (GRPO trainer auto-detects VLM) - if episode_images: - result["images"] = episode_images - print(f"[DEBUG] rollout_func returning with images: {len(episode_images)} episodes") - else: - print("[DEBUG] rollout_func returning WITHOUT images") + self.reward = step_reward + + if self._step_count >= max_steps: + self.done = True + + return self._format_observation_multimodal(observation) + + def _format_observation(self, observation) -> str: + """Format initial observation as text (for reset, appended to prompt).""" + parts = [] + if observation.goal: + parts.append(f"Goal: {observation.goal}") + if observation.axtree_txt: + axtree = observation.axtree_txt + if len(axtree) > 2000: + axtree = axtree[:2000] + "..." + parts.append(f"Page structure:\n{axtree}") + return "\n\n".join(parts) if parts else "No observation available." + + def _format_observation_multimodal(self, observation) -> list: + """Format observation as multimodal content blocks (screenshot + text).""" + content = [] + + # Add screenshot if available + if observation.screenshot is not None: + screenshot_array = np.array(observation.screenshot, dtype=np.uint8) + screenshot_image = Image.fromarray(screenshot_array) + if image_size > 0: + screenshot_image.thumbnail((image_size, image_size), Image.LANCZOS) + content.append({"type": "image", "image": screenshot_image}) + + # Add text observation + parts = [] + if observation.goal: + parts.append(f"Goal: {observation.goal}") + if observation.last_action_error and observation.error: + parts.append(f"Error: {observation.error}") + if observation.axtree_txt: + axtree = observation.axtree_txt + if len(axtree) > 2000: + axtree = axtree[:2000] + "..." + parts.append(f"Page structure:\n{axtree}") + text = "\n\n".join(parts) if parts else "No observation available." + content.append({"type": "text", "text": text}) + + return content - return result + timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + default_output_dir = Path("outputs") / f"browsergym-vlm-grpo-{sanitize_name(args.model_id)}-{timestamp}" + output_dir = Path(args.output_dir or default_output_dir) trainer = GRPOTrainer( model=args.model_id, - processing_class=tokenizer, - reward_funcs=[reward_completion], + reward_funcs=reward_completion, train_dataset=dataset, - args=grpo_config, - rollout_func=rollout_func, - ) - - print("=" * 80) - print("Starting GRPO training with BrowserGym environment") - print(f"Benchmark: {args.benchmark}") - print(f"Task: {args.task_name}") - print(f"Model: {args.model_id}") - print(f"Using {args.num_generations} rollouts per dataset prompt") - print(f"Output directory: {output_dir}") - print("=" * 80) - - try: - trainer.train() - print("\nTraining completed successfully!") - finally: - client.close() + args=GRPOConfig( + use_vllm=args.use_vllm, + vllm_mode=args.vllm_mode if args.use_vllm else "colocate", + vllm_server_base_url=args.vllm_server_url if args.use_vllm and args.vllm_mode == "server" else None, + output_dir=str(output_dir), + num_train_epochs=args.num_epochs, + learning_rate=args.learning_rate, + gradient_accumulation_steps=args.gradient_accumulation_steps, + num_generations=args.num_generations, + max_completion_length=args.max_completion_length, + logging_steps=args.logging_steps, + log_completions=True, + report_to="trackio", + trackio_space_id=f"browsergym-vlm-grpo-{sanitize_name(args.model_id)}", + chat_template_kwargs={"enable_thinking": False}, + ), + environment_factory=BrowserGymVLMEnv, + ) + trainer.train() if __name__ == "__main__": diff --git a/examples/scripts/openenv/carla_vlm.py b/examples/scripts/openenv/carla_vlm.py new file mode 100644 index 00000000000..08a3694f14a --- /dev/null +++ b/examples/scripts/openenv/carla_vlm.py @@ -0,0 +1,228 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl", +# "openenv-carla-env @ git+https://huggingface.co/spaces/sergiopaniego/carla_env", +# ] +# /// + + +""" +GRPO training with OpenEnv's CARLA environment for VLMs (Vision Language Models). + +This script uses `environment_factory` with multimodal tool responses: each tool action +returns a camera image from the vehicle alongside the text scene description, allowing the +VLM to see the driving scene visually after each action. + +The CARLA environment simulates an emergency driving scenario where pedestrians are ahead +and the model must learn to observe the scene and take the correct action (e.g., swerve +to an empty lane) to minimize casualties. + +Setup: +```sh +pip install "openenv-carla-env @ git+https://huggingface.co/spaces/sergiopaniego/carla_env" +``` + +Usage (requires at least 2 CARLA Spaces, each supports only 1 concurrent connection): +```sh +python examples/scripts/openenv/carla_vlm.py \ + --env-urls https://server1.hf.space https://server2.hf.space +``` +""" + +import argparse +import base64 +from io import BytesIO + +from carla_env import CarlaAction, CarlaEnv +from datasets import Dataset +from PIL import Image + +from trl import GRPOConfig, GRPOTrainer + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run GRPO VLM training with CARLA environment.") + parser.add_argument("--model", type=str, default="Qwen/Qwen3.5-0.8B") + parser.add_argument( + "--env-urls", + type=str, + nargs="+", + required=True, + help="URLs for CARLA environment servers. At least 2 required (1 Space = 1 connection).", + ) + parser.add_argument("--dataset-size", type=int, default=1000) + parser.add_argument("--max-completion-length", type=int, default=4096) + parser.add_argument("--gradient-accumulation-steps", type=int, default=4) + parser.add_argument("--max-steps", type=int, default=100) + parser.add_argument("--image-size", type=int, default=256, help="Resize camera images to this size. 0 to disable.") + parser.add_argument("--trackio-space-id", type=str, default=None, help="Trackio Space ID for logging.") + parser.add_argument("--hub-model-id", type=str, default=None) + parser.add_argument("--run-name", type=str, default=None) + parser.add_argument("--report-to", type=str, default="trackio", help="Logging backend: wandb, trackio, none.") + return parser.parse_args() + + +SIM_TICKS = 10 + + +def reward_func(completions, environments, **kwargs): + return [env.reward for env in environments] + + +def main(): + args = parse_args() + env_url_iter = iter(args.env_urls) + image_size = args.image_size + + prompt = """You control an autonomous vehicle in an emergency. There are pedestrians ahead and you must \ +decide what to do immediately. + +You will see a camera image from the vehicle after each action. Use the visual information +along with the scene description to decide your next action. + +You have the following tools available: +- `observe`: Advance time and get a new observation of the scene with a camera image. +- `emergency_stop`: Apply maximum braking to stop the vehicle. +- `lane_change(direction)`: Change lane to the left or right. Direction must be "left" or "right". + +Observe the scene first, then decide the best course of action to minimize harm.""" + + dataset = Dataset.from_dict({"prompt": [[{"role": "user", "content": prompt}] for _ in range(args.dataset_size)]}) + + class CarlaVLMEnv: + def __init__(self): + url = next(env_url_iter) + self.client = CarlaEnv(base_url=url, connect_timeout_s=30, message_timeout_s=120) + self.reward = 0.0 + + @staticmethod + def _describe(obs) -> str: + parts = [] + parts.append(f"Speed: {obs.speed_kmh:.1f} km/h.") + if obs.nearby_actors: + for actor in obs.nearby_actors: + parts.append(f"- {actor.get('type', 'actor')} at {actor.get('distance', '?')}m") + else: + parts.append("No nearby actors detected.") + if obs.collision_detected: + parts.append(f"COLLISION detected with {obs.collided_with or 'unknown'}!") + return "\n".join(parts) + + @staticmethod + def _decode_image(camera_image_b64, target_size): + """Decode base64 JPEG image and optionally resize.""" + img_bytes = base64.b64decode(camera_image_b64) + img = Image.open(BytesIO(img_bytes)) + if target_size > 0: + img.thumbnail((target_size, target_size), Image.LANCZOS) + return img + + def _format_multimodal(self, obs) -> list: + """Format observation as multimodal content blocks (camera image + text).""" + content = [] + if obs.camera_image is not None: + img = self._decode_image(obs.camera_image, image_size) + content.append({"type": "image", "image": img}) + content.append({"type": "text", "text": self._describe(obs)}) + return content + + def _advance(self, ticks: int = SIM_TICKS): + result = None + for _ in range(ticks): + result = self.client.step(CarlaAction(action_type="observe")) + if result.done: + break + return result + + def _advance_and_capture(self, ticks: int = SIM_TICKS): + """Advance the simulation, then capture an image of the current state.""" + result = self._advance(ticks) + capture_result = self.client.step(CarlaAction(action_type="capture_image")) + result.observation.camera_image = capture_result.observation.camera_image + return result + + def reset(self, **kwargs) -> str | None: + result = self.client.reset(scenario_name="trolley_micro_escape_exists") + self.reward = 0.0 + return self._describe(result.observation) + + def observe(self) -> list: + """ + Get the current scene with a camera image and description. + + Returns: + The camera image and scene description with vehicle state and nearby actors. + """ + result = self._advance_and_capture() + self.reward = result.observation.rubric_reward or 0.0 + return self._format_multimodal(result.observation) + + def emergency_stop(self) -> list: + """ + Apply maximum braking to stop the vehicle. + + Returns: + The camera image and scene description after braking. + """ + self.client.step(CarlaAction(action_type="emergency_stop")) + result = self._advance_and_capture() + self.reward = result.observation.rubric_reward or 0.0 + return self._format_multimodal(result.observation) + + def lane_change(self, direction: str) -> list: + """ + Change lane to avoid obstacles. + + Args: + direction: Direction to change lane, either "left" or "right". + + Returns: + The camera image and scene description after changing lane. + """ + self.client.step(CarlaAction(action_type="lane_change", lane_direction=direction)) + result = self._advance_and_capture() + self.reward = result.observation.rubric_reward or 0.0 + return self._format_multimodal(result.observation) + + trainer = GRPOTrainer( + model=args.model, + train_dataset=dataset, + reward_funcs=reward_func, + args=GRPOConfig( + chat_template_kwargs={"enable_thinking": False}, + log_completions=True, + logging_steps=2, + num_completions_to_print=1, + max_completion_length=args.max_completion_length, + per_device_train_batch_size=len(args.env_urls), + steps_per_generation=1, + num_generations=len(args.env_urls), + gradient_accumulation_steps=args.gradient_accumulation_steps, + max_steps=args.max_steps, + push_to_hub=args.hub_model_id is not None, + hub_model_id=args.hub_model_id, + run_name=args.run_name, + report_to=args.report_to, + trackio_space_id=args.trackio_space_id, + ), + environment_factory=CarlaVLMEnv, + ) + trainer.train() + + +if __name__ == "__main__": + main() diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index c5aef1071fe..46719396c4d 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -871,7 +871,7 @@ def _validate_tool_calls(tool_calls: list | None) -> None: tool_call["arguments"] = {} -def parse_response(tokenizer: PreTrainedTokenizer, ids: list[int]) -> dict: +def parse_response(tokenizer_or_processor, ids: list[int]) -> dict: r""" Parse a token sequence into structured response dictionaries with fallback handling. @@ -881,9 +881,11 @@ def parse_response(tokenizer: PreTrainedTokenizer, ids: list[int]) -> dict: Also removes incorrectly appended EOS tokens from tool call content when present, and validates tool_calls to ensure all required fields exist. + For VLM processors, automatically uses the inner tokenizer for parsing. + Args: - tokenizer (`PreTrainedTokenizer`): - Tokenizer with a `parse_response()` method. + tokenizer_or_processor (`PreTrainedTokenizer` or VLM processor): + Tokenizer or processor with a `parse_response()` method (directly or via inner tokenizer). ids (`list[int]`): List of token sequences. @@ -904,6 +906,8 @@ def parse_response(tokenizer: PreTrainedTokenizer, ids: list[int]) -> dict: {'role': 'assistant', 'content': '', 'tool_calls': [{'type': 'function', 'function': {'name': 'multiply', 'arguments': {'a': 3, 'b': 4}}}]} ``` """ + # VLM processors don't have parse_response directly; use the inner tokenizer + tokenizer = getattr(tokenizer_or_processor, "tokenizer", tokenizer_or_processor) try: parsed = tokenizer.parse_response(ids) # Hotfix: remove incorrectly appended EOS token from tool calls diff --git a/trl/experimental/gfpo/gfpo_trainer.py b/trl/experimental/gfpo/gfpo_trainer.py index bfcd3a6c0ca..11304e5e02f 100644 --- a/trl/experimental/gfpo/gfpo_trainer.py +++ b/trl/experimental/gfpo/gfpo_trainer.py @@ -98,9 +98,20 @@ def _generate_and_score_completions(self, inputs): for prompt, image_list in zip(prompts, images, strict=True) ] - prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list, extra_fields = ( - self._generate(prompts) - ) + dataset_images = images # preserve dataset images before _generate may overwrite + ( + prompt_ids_list, + completion_ids_list, + tool_mask_list, + completions, + num_items_in_batch, + sampling_per_token_logps_list, + extra_fields, + images, + tool_images, + ) = self._generate(prompts) + if images is None: + images = dataset_images # restore dataset images (rollout_func path returns None) # Convert lists of token IDs to padded tensors prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list] @@ -155,7 +166,7 @@ def _generate_and_score_completions(self, inputs): logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size - num_images = [len(img_list) for img_list in images] if images is not None else None + num_images = [len(img_list) if img_list else 0 for img_list in images] if images is not None else None # Get forward_kwargs for models with multimodal inputs if images is not None: diff --git a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py index 8202de7c71b..a96f43e0479 100644 --- a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py +++ b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py @@ -96,6 +96,7 @@ def _generate_and_score_completions( for prompt, image_list in zip(prompts, images, strict=True) ] + dataset_images = images # preserve dataset images before _generate may overwrite ( prompt_ids_list, completion_ids_list, @@ -104,7 +105,11 @@ def _generate_and_score_completions( num_items_in_batch, sampling_per_token_logps_list, extra_fields, + images, + tool_images, ) = self._generate(prompts) + if images is None: + images = dataset_images # restore dataset images (rollout_func path returns None) # Convert lists of token IDs to padded tensors prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list] @@ -167,7 +172,7 @@ def _generate_and_score_completions( logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size - num_images = [len(img_list) for img_list in images] if images is not None else None + num_images = [len(img_list) if img_list else 0 for img_list in images] if images is not None else None # Get forward_kwargs for models with multimodal inputs if images is not None: diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index c5eed094192..03bab4d8f17 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -314,8 +314,12 @@ def __init__( # Handle pad token for processors or tokenizers if isinstance(processing_class, ProcessorMixin): tokenizer = processing_class.tokenizer + self._is_vlm = True + self._vision_token_ids_cache = None # populated lazily by _get_vision_token_ids elif isinstance(processing_class, PreTrainedTokenizerBase): tokenizer = processing_class + self._is_vlm = False + self._vision_token_ids_cache = None else: raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") @@ -1242,6 +1246,19 @@ async def _run_async_funcs(): def _tokenize_prompts(self, prompts: list): """Tokenize prompts and extract images/multimodal fields for generation.""" if is_conversational({"prompt": prompts[0]}): + # Normalize string content to content blocks for VLM processors that don't handle plain strings. + # Use copies to avoid mutating the original prompts. + if self._is_vlm: + prompts = [ + [ + {**msg, "content": [{"type": "text", "text": msg["content"]}]} + if isinstance(msg.get("content"), str) + else msg + for msg in prompt + ] + for prompt in prompts + ] + # Extract images from messages for VLM support images = [] has_images = False @@ -1250,7 +1267,7 @@ def _tokenize_prompts(self, prompts: list): for message in prompt: if isinstance(message["content"], list): for part in message["content"]: - if part["type"] == "image": + if isinstance(part, dict) and part.get("type") == "image": prompt_images.append(part["image"]) has_images = True images.append(prompt_images if prompt_images else None) @@ -1381,20 +1398,76 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): def _get_tool_suffix_ids(self, tool_messages): """Get token IDs for tool result formatting by using a minimal dummy conversation.""" dummy_messages = [{"role": "user", "content": "dummy"}, {"role": "assistant", "content": "dummy"}] + if self._is_vlm: + dummy_messages = prepare_multimodal_messages(dummy_messages, []) prefix_ids = self.processing_class.apply_chat_template( dummy_messages, add_generation_prompt=False, + tokenize=True, chat_template=self.chat_template, return_dict=False, **self.chat_template_kwargs, ) - full_ids = self.processing_class.apply_chat_template( - dummy_messages + tool_messages, - add_generation_prompt=True, - chat_template=self.chat_template, - return_dict=False, - **self.chat_template_kwargs, - ) + # VLM processors return batched output (list of lists), unbatch for single conversation + if self._is_vlm: + prefix_ids = prefix_ids[0] + + # Check if tool messages contain images (multimodal tool responses) + tool_images = [] + for msg in tool_messages: + if isinstance(msg.get("content"), list): + for part in msg["content"]: + if isinstance(part, dict) and part.get("type") == "image": + tool_images.append(part["image"]) + + # Normalize string content in tool messages for VLM processors before either path. + # Use copies to avoid mutating the original completions data. + if self._is_vlm: + tool_messages = [ + {**msg, "content": [{"type": "text", "text": msg["content"]}]} + if isinstance(msg.get("content"), str) + else msg + for msg in tool_messages + ] + + if tool_images and self._is_vlm: + # For VLMs with images: use processor.__call__ to get correctly expanded image tokens. + # apply_chat_template only inserts a single <|image_pad|> placeholder per image, + # but the model needs N tokens per image (based on resolution). The processor's + # __call__ handles this expansion. + # Use the same tokenization method (processor.__call__) for both prefix and full to + # avoid mismatches from different tokenization paths. + prefix_text = self.processing_class.apply_chat_template( + dummy_messages, + add_generation_prompt=False, + tokenize=False, + chat_template=self.chat_template, + **self.chat_template_kwargs, + ) + prefix_ids = self.processing_class(text=prefix_text, return_tensors="pt")["input_ids"][0].tolist() + full_text = self.processing_class.apply_chat_template( + dummy_messages + tool_messages, + add_generation_prompt=True, + tokenize=False, + chat_template=self.chat_template, + **self.chat_template_kwargs, + ) + # We only need input_ids (for suffix token extraction). pixel_values and image_grid_thw + # are computed separately in the forward pass via image_processor to avoid mismatches. + full_ids = self.processing_class(text=full_text, images=tool_images, return_tensors="pt")["input_ids"][ + 0 + ].tolist() + else: + full_ids = self.processing_class.apply_chat_template( + dummy_messages + tool_messages, + add_generation_prompt=True, + tokenize=True, + chat_template=self.chat_template, + return_dict=False, + **self.chat_template_kwargs, + ) + if self._is_vlm: + full_ids = full_ids[0] # Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\n" after an assistant/tool block. # When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to @@ -1404,15 +1477,70 @@ def _get_tool_suffix_ids(self, tool_messages): if full_ids[: len(prefix_ids)] != prefix_ids: raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.") - return full_ids[len(prefix_ids) :] + def _get_vision_token_ids(self): + """Get vision-related special token IDs from the processor's tokenizer. + + Returns a dict with keys 'vision_start', 'vision_end', 'image_pad', 'video_pad'. + Values are None if the token doesn't exist in the vocabulary. + """ + if self._vision_token_ids_cache is None: + cache = {"vision_start": None, "vision_end": None, "image_pad": None, "video_pad": None} + if self._is_vlm: + tok = self.processing_class.tokenizer + for name, token_str in { + "vision_start": "<|vision_start|>", + "vision_end": "<|vision_end|>", + "image_pad": "<|image_pad|>", + "video_pad": "<|video_pad|>", + }.items(): + tid = tok.convert_tokens_to_ids(token_str) + if tid != tok.unk_token_id: + cache[name] = tid + self._vision_token_ids_cache = cache + return self._vision_token_ids_cache + + def _truncate_at_image_boundary(self, ids, max_length): + """Truncate token ID list to max_length, ensuring we don't cut in the middle of an image. + + If truncation would split an image token sequence (<|vision_start|>...<|vision_end|>), + backs up to the end of the last complete image. This prevents mismatches between + image placeholder tokens in input_ids and pixel_values in the forward pass. + """ + max_length = max(max_length, 0) + if len(ids) <= max_length: + return ids + + vtids = self._get_vision_token_ids() + vision_start_id = vtids["vision_start"] + vision_end_id = vtids["vision_end"] + if vision_start_id is not None and vision_end_id is not None: + truncated = ids[:max_length] + last_start = -1 + last_end = -1 + for i in range(len(truncated) - 1, -1, -1): + if truncated[i] == vision_end_id and last_end == -1: + last_end = i + if truncated[i] == vision_start_id and last_start == -1: + last_start = i + if last_start != -1 and last_end != -1: + break + + # If last vision_start > last vision_end, we're inside an incomplete image + if last_start > last_end: + return ids[:last_start] # truncate before the incomplete image + return truncated + return ids[:max_length] + def _tool_call_loop(self, prompts, prompt_ids, completion_ids, completions, logprobs, images, multimodal_fields): # Tool execution loop: execute tools, then regenerate completions with tool results appended to the prompt tool_calls = [completion[0].get("tool_calls") for completion in completions] idxs_with_tool = [idx for idx, tool_call in enumerate(tool_calls) if tool_call] tool_calls = [tool_calls[idx] for idx in idxs_with_tool] tool_mask = [[1] * len(ids) for ids in completion_ids] # 0 for tool result tokens, 1 elsewhere + # Collect images from multimodal tool responses for the forward pass + tool_images = [[] for _ in completion_ids] tool_call_count = 0 tool_failure_count = 0 iteration_num = 0 @@ -1470,7 +1598,16 @@ async def _run_async_tools(async_coros): tool_call_results.append((name, result)) for name, result in tool_call_results: - tool_message = {"role": "tool", "name": name, "content": str(result)} + # Support multimodal tool responses: if the tool returns a list of content blocks + # (e.g., [{"type": "image", "image": ...}, {"type": "text", "text": "..."}]), + # pass them through directly so _tokenize_prompts can extract images for VLMs. + content = result if isinstance(result, list) else str(result) + tool_message = {"role": "tool", "name": name, "content": content} + # Collect images from multimodal tool responses + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "image": + tool_images[idx_with_tool].append(part["image"]) prompt_completion_tool.append(tool_message) completions[idx_with_tool].append(tool_message) @@ -1492,12 +1629,20 @@ async def _run_async_tools(async_coros): # Filter samples whose length exceeds max allowed length. This is important, because both # vLLM and transformers will error out if the input is longer than the model's max length. + # Note: _truncate_at_image_boundary ensures we never cut in the middle of an image token + # sequence (vision_start...vision_end), which would cause pixel_values/input_ids mismatches. if self.use_vllm and self.vllm_mode == "colocate": max_model_len = self.vllm_generation.llm.llm_engine.model_config.max_model_len elif self.use_vllm and self.vllm_mode == "server": - max_model_len = self.model.config.max_position_embeddings + if self._is_vlm: + max_model_len = self.model.config.text_config.max_position_embeddings + else: + max_model_len = self.model.config.max_position_embeddings elif not self.use_vllm: - max_model_len = self.model.config.max_position_embeddings + if self._is_vlm: + max_model_len = self.model.config.text_config.max_position_embeddings + else: + max_model_len = self.model.config.max_position_embeddings else: raise NotImplementedError( f"Unsupported mode detected: use_vllm={self.use_vllm}, vllm_mode={self.vllm_mode}" @@ -1507,7 +1652,9 @@ async def _run_async_tools(async_coros): idx_with_tool = idxs_with_tool[idx] if overlong[idx]: prompt_length = len(prompt_ids[idx_with_tool]) - ct = prompt_completion_tool_ids[idx][prompt_length : prompt_length + self.max_completion_length] + ct = self._truncate_at_image_boundary( + prompt_completion_tool_ids[idx][prompt_length:], self.max_completion_length + ) completion_ids[idx_with_tool] = ct tool_mask[idx_with_tool] += [1] * (len(ct) - len(tool_mask[idx_with_tool])) if logprobs is not None: @@ -1521,8 +1668,17 @@ async def _run_async_tools(async_coros): if not idxs_with_tool: break # all overlong, exit tool loop - # Filter images and multimodal fields to match the current subset (index into full batch) - loop_images = [images[i] for i in idxs_with_tool] if images else None + # Filter images and multimodal fields to match the current subset (index into full batch). + # Merge tool response images so the model can see visual feedback during generation. + merged_images = images + if any(imgs for imgs in tool_images): + if merged_images is None: + merged_images = [imgs if imgs else None for imgs in tool_images] + else: + merged_images = [ + (existing or []) + new for existing, new in zip(merged_images, tool_images, strict=True) + ] + loop_images = [merged_images[i] for i in idxs_with_tool] if merged_images else None loop_multimodal_fields = ( {k: [v[i] for i in idxs_with_tool] for k, v in multimodal_fields.items()} if multimodal_fields else {} ) @@ -1539,14 +1695,20 @@ async def _run_async_tools(async_coros): completion_tool_ids = prompt_completion_tool_ids[idx][prompt_len:] excess_length = len(completion_tool_ids) + len(post_tool_ids[idx]) - self.max_completion_length if excess_length > 0: - # If exceeding max length, truncate post_tool_ids - post_tool_ids[idx] = post_tool_ids[idx][:-excess_length] + # If exceeding max length, truncate post_tool_ids (respecting image boundaries) + truncated_post = self._truncate_at_image_boundary( + post_tool_ids[idx], len(post_tool_ids[idx]) - excess_length + ) if logprobs is not None: - post_tool_logprobs[idx] = post_tool_logprobs[idx][:-excess_length] + post_tool_logprobs[idx] = post_tool_logprobs[idx][: len(truncated_post)] + post_tool_ids[idx] = truncated_post excess_length = len(completion_tool_ids) + len(post_tool_ids[idx]) - self.max_completion_length if excess_length > 0: - # If still exceeding max length, truncate completion_tool_ids as well - prompt_completion_tool_ids[idx] = prompt_completion_tool_ids[idx][:-excess_length] + # If still exceeding, truncate completion_tool_ids (respecting image boundaries) + truncated_pct = self._truncate_at_image_boundary( + prompt_completion_tool_ids[idx], len(prompt_completion_tool_ids[idx]) - excess_length + ) + prompt_completion_tool_ids[idx] = truncated_pct # Update tool_mask: the tool result should be 0 and the post-tool 1 for idx in range(len(idxs_with_tool)): @@ -1567,7 +1729,7 @@ async def _run_async_tools(async_coros): pct = prompt_completion_tool_ids[idx] # = prompt-completion-tool completion_ids[idx_with_tool] = pct[prompt_length:] + post_tool_ids[idx] - # Decode post-tool completions + # Decode post-tool completions. post_tool_completions = [ parse_response(self.processing_class, ids) if ids else {} for ids in post_tool_ids ] @@ -1584,7 +1746,25 @@ async def _run_async_tools(async_coros): tool_calls = [tool_call for tool_call in tool_calls if tool_call] iteration_num += 1 - return tool_mask, completions, completion_ids, logprobs, tool_call_count, tool_failure_count + # Sync tool_mask and tool_images with completion_ids: after truncation by + # _truncate_at_image_boundary, completion_ids may be shorter than tool_mask. + for i in range(len(completion_ids)): + if len(tool_mask[i]) > len(completion_ids[i]): + tool_mask[i] = tool_mask[i][: len(completion_ids[i])] + if logprobs is not None: + for i in range(len(completion_ids)): + if len(logprobs[i]) > len(completion_ids[i]): + logprobs[i] = logprobs[i][: len(completion_ids[i])] + + # Sync tool_images: count complete images in completion_ids and trim tool_images to match. + vtids = self._get_vision_token_ids() + if vtids["vision_end"] is not None: + for i, ids in enumerate(completion_ids): + complete_images = sum(1 for t in ids if t == vtids["vision_end"]) + if complete_images < len(tool_images[i]): + tool_images[i] = tool_images[i][:complete_images] + + return tool_mask, completions, completion_ids, logprobs, tool_call_count, tool_failure_count, tool_images def _generate(self, prompts: list): device = self.accelerator.device @@ -1611,6 +1791,8 @@ def _generate(self, prompts: list): raise ValueError(f"rollout_func must return keys {missing_keys_list} in its output dict.") extra_fields = {k: v for k, v in output.items() if k not in required_keys} prompt_ids, completion_ids, logprobs = output["prompt_ids"], output["completion_ids"], output["logprobs"] + images = None + multimodal_fields = {} else: prompt_ids, images, multimodal_fields = self._tokenize_prompts(prompts) completion_ids, logprobs = self._generate_single_turn(prompt_ids, images, multimodal_fields) @@ -1618,13 +1800,22 @@ def _generate(self, prompts: list): # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. if is_conversational({"prompt": prompts[0]}): + parsing_class = self.processing_class + # For VLM processors, propagate response_schema to the inner tokenizer if needed + if self._is_vlm: + if getattr(self.processing_class, "response_schema", None) and not getattr( + self.processing_class.tokenizer, "response_schema", None + ): + self.processing_class.tokenizer.response_schema = self.processing_class.response_schema + # parse_response handles VLM processors internally (uses inner tokenizer) + tokenizer = getattr(parsing_class, "tokenizer", parsing_class) if ( Version(transformers.__version__) >= Version("5.0.0") # parse_response added in v5 - and isinstance(self.processing_class, PreTrainedTokenizerBase) # doesn't work with processors - and hasattr(self.processing_class, "response_schema") # attribute not set by default for now - and self.processing_class.response_schema is not None # only works if the tokenizer has a schema + and isinstance(tokenizer, PreTrainedTokenizerBase) + and hasattr(tokenizer, "response_schema") # attribute not set by default for now + and tokenizer.response_schema is not None # only works if the tokenizer has a schema ): - completions = [[parse_response(self.processing_class, ids)] for ids in completion_ids] + completions = [[parse_response(parsing_class, ids)] for ids in completion_ids] else: contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) completions = [[{"role": "assistant", "content": content}] for content in contents] @@ -1632,6 +1823,7 @@ def _generate(self, prompts: list): completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) # Extract tool calls from the completions and (possibly) execute them + tool_images = [] if self.tools: ( tool_mask, @@ -1640,9 +1832,16 @@ def _generate(self, prompts: list): logprobs, tool_call_count, tool_failure_count, + tool_images, ) = self._tool_call_loop( prompts, prompt_ids, completion_ids, completions, logprobs, images, multimodal_fields ) + # Merge tool response images into the images list for the forward pass + if any(imgs for imgs in tool_images): + if images is None: + images = [imgs if imgs else None for imgs in tool_images] + else: + images = [(existing or []) + new for existing, new in zip(images, tool_images, strict=True)] else: # Support custom env_mask from rollout_func (e.g., for environment feedback masking) # Internally treated as tool_mask - marks model tokens (1) vs external tokens (0) @@ -1699,6 +1898,8 @@ def _generate(self, prompts: list): total_completion_tokens, logprobs, extra_fields, + images, + tool_images, ) def _generate_and_score_completions( @@ -1714,6 +1915,10 @@ def _generate_and_score_completions( observation = environment.reset(**reset_kwargs) if observation is None: continue + if isinstance(observation, list) and isinstance(prompt[-1]["content"], str): + prompt[-1]["content"] = [{"type": "text", "text": prompt[-1]["content"]}] + if isinstance(observation, str) and isinstance(prompt[-1]["content"], list): + observation = [{"type": "text", "text": observation}] prompt[-1]["content"] += observation if "images" in inputs[0]: @@ -1742,6 +1947,7 @@ def _generate_and_score_completions( for prompt, image_list in zip(prompts, images, strict=True) ] + dataset_images = images # preserve dataset images before _generate may overwrite ( prompt_ids_list, completion_ids_list, @@ -1750,7 +1956,11 @@ def _generate_and_score_completions( num_items_in_batch, sampling_per_token_logps_list, extra_fields, + images, + tool_images, ) = self._generate(prompts) + if images is None: + images = dataset_images # restore dataset images (rollout_func path returns None) # Convert lists of token IDs to padded tensors prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list] @@ -1810,10 +2020,18 @@ def _generate_and_score_completions( logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size - num_images = [len(img_list) for img_list in images] if images is not None else None - - # Get forward_kwargs for models with multimodal inputs - if images is not None: + num_images = [len(img_list) if img_list else 0 for img_list in images] if images is not None else None + + # Get forward_kwargs for models with multimodal inputs. + # When tool images are present (from _tool_call_loop), use image_processor directly and build + # mm_token_type_ids from prompt_completion_ids. Otherwise, use the full processor pipeline + # which returns model-specific keys (image_sizes, pixel_attention_mask, etc.). + if self.tools and any(imgs for imgs in tool_images) and self._is_vlm: + flat_images = [img for img_list in images if img_list for img in img_list] + image_inputs = self.processing_class.image_processor(images=flat_images, return_tensors="pt") + image_inputs = super()._prepare_inputs(image_inputs) + forward_kwargs = dict(image_inputs) + elif images is not None: prompts_text = [ apply_chat_template( {"prompt": prompt}, self.processing_class, tools=self.tools, **self.chat_template_kwargs @@ -1854,6 +2072,46 @@ def _generate_and_score_completions( [mm_token_type_ids, mm_token_type_ids.new_zeros(completion_ids.shape)], dim=1 ) + # For VLM tool images: build mm_token_type_ids from the full prompt_completion_ids. + # This must happen AFTER the mm_token_type_ids extension block above, because our version + # already covers the full sequence (images are in the completion, not just the prompt). + if self.tools and any(imgs for imgs in tool_images) and self._is_vlm: + vtids = self._get_vision_token_ids() + mm_ids = torch.zeros_like(prompt_completion_ids) + if vtids["image_pad"] is not None: + mm_ids[prompt_completion_ids == vtids["image_pad"]] = 1 + if vtids["video_pad"] is not None: + mm_ids[prompt_completion_ids == vtids["video_pad"]] = 2 + forward_kwargs["mm_token_type_ids"] = mm_ids + + # Truncation safety: if max_completion_length truncated some image tokens, the number + # of image pad tokens in input_ids won't match pixel_values features. Check per-sample + # and drop ALL images for any sample with a mismatch (safe fallback). + image_grid_thw = forward_kwargs.get("image_grid_thw") + if image_grid_thw is not None and num_images is not None: + merge_length = getattr(self.processing_class.image_processor, "merge_size", 2) ** 2 + img_offset = 0 + has_mismatch = False + for b in range(mm_ids.shape[0]): + sample_tokens = (mm_ids[b] == 1).sum().item() + sample_features = 0 + for i in range(num_images[b]): + grid_idx = img_offset + i + if grid_idx < image_grid_thw.shape[0]: + sample_features += image_grid_thw[grid_idx].prod().item() // merge_length + if sample_tokens != sample_features: + has_mismatch = True + break + img_offset += num_images[b] + + if has_mismatch: + # Drop all images: safer than partial trim which is error-prone + forward_kwargs.pop("pixel_values", None) + forward_kwargs.pop("image_grid_thw", None) + mm_ids.zero_() + forward_kwargs["mm_token_type_ids"] = mm_ids + num_images = None + # When gradient checkpointing is enabled with use_reentrant=True (non default), calling the model inside a # torch.no_grad() block triggers a harmless PyTorch warning ("None of the inputs have requires_grad=True"). # Temporarily disable checkpointing to avoid this warning during inference. @@ -2525,7 +2783,10 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: if images_raw: images = [] for image_list in self._logs["images"]: - images.append([logging_backend.Image(image) for image in image_list]) + if image_list: + images.append([logging_backend.Image(image) for image in image_list]) + else: + images.append([]) df = pd.concat( [df_base, pd.Series(images, name="image")], axis=1, diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 4e62310e7c7..585b1f605eb 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -912,7 +912,7 @@ def _tokenize_prompts(self, prompts: list): for message in prompt: if isinstance(message["content"], list): for part in message["content"]: - if part["type"] == "image": + if isinstance(part, dict) and part.get("type") == "image": prompt_images.append(part["image"]) has_images = True images.append(prompt_images if prompt_images else None) @@ -1155,7 +1155,7 @@ def _generate_and_score_completions( logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size - num_images = [len(img_list) for img_list in images] if images is not None else None + num_images = [len(img_list) if img_list else 0 for img_list in images] if images is not None else None # Get forward_kwargs for models with multimodal inputs if images is not None: