diff --git a/conf/base.yaml b/conf/base.yaml index e3122f5a..5043aa56 100644 --- a/conf/base.yaml +++ b/conf/base.yaml @@ -47,7 +47,7 @@ llm: temperature: 1.0 test_llm: parameters: - max_tokens: 16000 + max_tokens: 8192 temperature: 1.0 top_p: 0.95 top_k: 50 @@ -67,6 +67,7 @@ vllm_config: tensor-parallel-size: 1 pipeline-parallel-size: 1 generation-config: vllm + max_model_len: 10000 world: replicas: 1 @@ -75,7 +76,8 @@ world: preprocessor_fraction: 0 finetune_fraction: 4 - env_replicas: 2 + # Number of environment servers per actor VLLM server + env_replicas_per_actor: 1 actor_group_port: 9000 environment_start_port: 7777 diff --git a/conf/miniwob.yaml b/conf/miniwob.yaml index 07017d76..a8dc3868 100644 --- a/conf/miniwob.yaml +++ b/conf/miniwob.yaml @@ -1,34 +1,32 @@ defaults: - base + - override streams: redis + - override finetune: ppo + - _self_ world: - actor_fraction: 1 - preprocessor_fraction: 1 - finetune_fraction: 4 + actor_fraction: 2 + preprocessor_fraction: 0 + finetune_fraction: 6 # debug: # mode: actor save_tapes: False -output_dir: results/miniwob_debug/${now:%Y-%m-%d}/${now:%H-%M-%S} +output_dir: results/miniwob/${now:%Y-%m-%d}/${now:%H-%M-%S} model_path: meta-llama/Llama-3.1-8B-Instruct finetune: - save_checkpoint_steps: 10 - seq_length: 4096 + seq_length: 16384 # input + output tokens + max_train_steps: 1000 # 1000 optim steps = 1000 * bs samples train_batch_size: 1 gradient_accumulation_passes: 1024 - learning_rate: 1e-6 - optim: adamw_torch - rl: - kl_coef: 0.01 # GRPO beta coefficient - reward_minus_kl_coef: 0.0 # RLOO beta coefficient - use_advantages: true - algo: grpo + +eval_every_n_versions: 10240 # 1024 effective bs * 10 "optim steps" llm: parameters: - max_tokens: 3072 + max_tokens: 4096 # output tokens temperature: 1.0 test_llm: parameters: @@ -39,24 +37,37 @@ test_llm: vllm_config: vllm_kwargs: - enable-auto-tool-choice: "" - tool-call-parser: llama3_json # use hermes for qwen - chat_template: pipelinerl/domains/miniwob/tool_chat_template_llama3.1_json.jinja # copy pasted from https://github.com/vllm-project/vllm/blob/main/examples/tool_chat_template_llama3.1_json.jinja - enforce-eager: "" # speed the actor llm startup a bit + max_model_len: 16384 # input + output tokens actor: rollout_policy: pipelinerl.domains.miniwob.rollouts.generate_miniwob_rollout shared_memory_entry_size: 100000000 + llm_max_rollouts: 32 preprocess: - shared_memory_entry_size: 1000000000 + n_workers: 32 # Increase from 8 + chunk_n_groups: 8 # Increase from 2 for better throughput + # queue for loaded raw groups + raw_queue_size: 32 # Increase from 8 + # queue for processed chunks of multiple groups + input_queue_size: 64 # Increase from 32 + # queue for ready chunks for multiple groups + output_queue_size: 64 # Increase from 32 + # ring buffer to replace old samples with new ones when training is slow + ring_buffer_size: 1024 # Increase from 128 + # "virtual" sample queue per lead trainer + max_ready_samples_per_lead: 256 # Increase from 64 + shared_memory_entry_size: 1000000000 # Increase from 100M # AGENT CONFIGURATION agent_max_loops: 10 # max number of agent - environment interactions for each task +agent_attempts: 3 # number of attempts to run the agent (retry on errors) +rollout_timeout: 600 # overall timeout for entire rollout in seconds (10 minutes) +reward_computation: nico agent: _target_: tapeagents.agent.Agent name : web_agent - max_iterations: 4 # max number of iterations (make_prompt + llm? 
+ generate_steps) for each loop + max_iterations: 4 # max number of iterations (make_prompt + llm + generate_steps) for each loop store_llm_calls: true templates: system_prompt: | @@ -65,50 +76,64 @@ agent: Keep your replies concise and direct. Prioritize clarity and avoid over-elaboration. You will be provided with the content of the current page and a task from the user. Do not express your emotions or opinions about the user question. - allowed_tools: | - You have access to the following tools: - {tools_description} - thought_format: | - Important! Respond with the plain text, do not include any JSON or code. - Do not output anything besides what I asked in this message. + allowed_steps: | + You are allowed to produce ONLY steps with the following json schemas: + {allowed_steps} + Do not reproduce schema when producing the steps, use it as a reference. + json_format: | + Important! Respond with very simple parsable JSON! + Do not use any special characters or code. Do not use new lines, tabs, or any other formatting inside the JSON. + Do not output anything besides one simple JSON object. nodes: - _target_: pipelinerl.domains.miniwob.agent.WebNode name: set_goal system_prompt: ${agent.templates.system_prompt} guidance: | - Produce the thought that describes the intended solution to the task. In the reasoning lines: + Produce the reasoning_thought step that describes the intended solution to the task. In the reasoning lines: - review the instructions from the user and the content of the page. - outline the main task to be accomplished and the steps to be taken to achieve it. - produce definiton of done, that will be checked later to verify if the task was completed. - ${agent.templates.thought_format} - steps_prompt: ${agent.templates.allowed_tools} + Produce only one reasoning_thought step! + ${agent.templates.json_format} + steps_prompt: ${agent.templates.allowed_steps} + steps: + - tapeagents.steps.ReasoningThought trim_obs_except_last_n: 3 # keep the last 3 observations from the tape in prompt messages max_chars_page_observation: 3000 # keep up to 3000 chars in PageObservation steps - _target_: pipelinerl.domains.miniwob.agent.WebNode name: reflect system_prompt: ${agent.templates.system_prompt} guidance: | - Review the current state of the page and previous steps to find the best possible next action to accomplish the task. - Produce the reflection_thought to describe the current page state, reflect on your last action, describe what is left to do, and what will be the immediate next action. - Produce only one reflection_thought step! - ${agent.templates.thought_format} - steps_prompt: ${agent.templates.allowed_tools} + Produce the reasoning_thought step that describes the current state of the page, the previous actions, and what should be the next best action to accomplish the task. In the reasoning lines: + - think about which information could be relevant to the given task, note relevant BIDs and coordinates. + - describe the last action taken, what were its expected effects on the page, versus the actual effects you can observe. Are they the same or not? if not, what could have gone wrong? + - check if you are stuck with repeating the same action over and over again, if so, try something else and change the action. + - check if you think the task is done, if not give a detailed list of actions to do next to accomplish the task. + - finally, if the task is not done, describe the immediate next action to be performed and its expected effect on the page. 
+ Produce only one reasoning_thought step! Be brief and to the point. You can skip some details if they are not relevant for this step. + ${agent.templates.json_format} + steps_prompt: ${agent.templates.allowed_steps} + steps: + - tapeagents.steps.ReasoningThought trim_obs_except_last_n: 3 # keep the last 3 observations from the tape in prompt messages max_chars_page_observation: 3000 # keep up to 3000 chars in PageObservation steps - _target_: pipelinerl.domains.miniwob.agent.WebNode name: act system_prompt: ${agent.templates.system_prompt} guidance: | - Produce the single next tool call to be performed with the current page. - If you think that the task is solved, call the FinalAnswer. + Produce the next action to be performed with the current page. + If you think that the task is solved, produce the final_answer_action. You can interact with the page elements using their BIDs or coordinates as arguments for actions. HINTS: - You can use the BIDs of the elements or the mouse position in x, y coordinates to interact with them. - - To select value in a dropdown or combobox, ALWAYS use SelectOption tool. + - To select value in a dropdown or combobox, ALWAYS use select_action. - To click on a checkbox or radio button, ALWAYS use BID (or coordinates) of the corresponding Text and not the BID (or coordinates) of the element itself. - Press enter key to submit the search query. + - Always produce only one step at a time. + - Step kind is always lowercase and underscore separated. + ${agent.templates.json_format} + steps_prompt: ${agent.templates.allowed_steps} use_known_actions: true - use_function_calls: true steps: - pipelinerl.domains.miniwob.steps.FinalAnswerAction trim_obs_except_last_n: 3 # keep the last 3 observations from the tape in prompt messages @@ -120,12 +145,12 @@ agent: start_attempts: 3 # number of attempts to start each task environment: _target_: pipelinerl.domains.miniwob.environment_server.WebEnvironmentServer - miniwob_url: file:///home/toolkit/miniwob-plusplus/miniwob/html/miniwob/ - n_envs: 64 + miniwob_url: ??? + n_envs: 32 host: "0.0.0.0" - max_session_inactivity_secs: 300 - web_env_target: pipelinerl.domains.miniwob.environment.WebEnvironment - exp_path: ${output_dir}/env_server + env_call_timeout: 60 # timeout for each environment call (e.g. start_task, act, etc.) 
+  web_env_target: examples.rl_webagent.environment.WebEnvironment
+  exp_path: null
+  headless: true
   observation_format: html
diff --git a/conf/miniwob_grpo.yaml b/conf/miniwob_grpo.yaml
new file mode 100644
index 00000000..f6cfeed3
--- /dev/null
+++ b/conf/miniwob_grpo.yaml
@@ -0,0 +1,10 @@
+defaults:
+  - miniwob
+  - override finetune: grpo
+  - _self_
+
+finetune:
+  seq_length: 16384 # input + output tokens
+  max_train_steps: 1000 # 1000 optim steps = 1000 * bs samples
+  train_batch_size: 1
+  gradient_accumulation_passes: 1024
diff --git a/conf/miniwob_uic_grpo.yaml b/conf/miniwob_uic_grpo.yaml
new file mode 100644
index 00000000..7e7746c6
--- /dev/null
+++ b/conf/miniwob_uic_grpo.yaml
@@ -0,0 +1,16 @@
+defaults:
+  - miniwob_grpo
+  - _self_
+
+train_dataset_names:
+  - uic_train
+test_dataset_names:
+  - uic_train_train_heldout_goals
+  - uic_test
+
+reward_computation: uic
+
+finetune:
+  gradient_accumulation_passes: 512
+
+eval_every_n_versions: 5120 # 512 effective bs * 10 "optim steps"
diff --git a/conf/miniwob_uic_ppo.yaml b/conf/miniwob_uic_ppo.yaml
new file mode 100644
index 00000000..db0f5792
--- /dev/null
+++ b/conf/miniwob_uic_ppo.yaml
@@ -0,0 +1,16 @@
+defaults:
+  - miniwob
+  - _self_
+
+train_dataset_names:
+  - uic_train
+test_dataset_names:
+  - uic_train_train_heldout_goals
+  - uic_test
+
+reward_computation: uic
+
+finetune:
+  gradient_accumulation_passes: 512
+
+eval_every_n_versions: 5120 # 512 effective bs * 10 "optim steps"
diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py
index 1c238ff9..6d3317af 100644
--- a/pipelinerl/actor.py
+++ b/pipelinerl/actor.py
@@ -216,6 +216,7 @@ async def rollout_and_maybe_produce_result(
                 f"groups in progress: {len(group_rollouts)}, "
                 f"rollouts started so far: {started_rollouts}, "
                 f"rollouts finished so far: {finished_rollouts}, "
+                f"groups started so far: {group_id}, "
                 f"max group size in bytes: {result_queue.max_actual_entry_size()}, "
             )
             last_logged = time.time()
@@ -489,6 +490,9 @@ def run(self, dataset: list[tuple[str, dict]]):
             assert isinstance(rollout_results, list)
             assert isinstance(rollout_results[0], RolloutResult)
+            assert len(rollout_results) == attempts, (
+                f"Expected {attempts} rollouts, got {len(rollout_results)}"
+            )
             group_samples = sum(len(r.training_texts) for r in rollout_results)
             published_samples += group_samples
@@ -505,7 +509,6 @@ def run(self, dataset: list[tuple[str, dict]]):
                 f" {in_progress} groups in progress"
             )
-            self.update_stats(rollout_results=rollout_results)
             finished_groups += 1
diff --git a/pipelinerl/domains/miniwob/README.md b/pipelinerl/domains/miniwob/README.md
new file mode 100644
index 00000000..e9af1b42
--- /dev/null
+++ b/pipelinerl/domains/miniwob/README.md
@@ -0,0 +1,34 @@
+# MiniWoB example
+
+## Prerequisites
+
+### TapeAgents
+
+Clone [TapeAgents](https://github.com/ServiceNow/TapeAgents/) into the parent folder of this repo and install it.
+```bash
+cd ..
+git clone git@github.com:ServiceNow/TapeAgents.git
+cd TapeAgents
+pip install -e .
+pip install 'tapeagents[finetune,converters]==0.1.12'
+cd ../PipelineRL
+```
+
+Make sure to add the TapeAgents folder to your Python path.
+```bash
+export PYTHONPATH="/path/to/TapeAgents:$PYTHONPATH"
+```
+
+### MiniWoB
+
+See the setup instructions here: https://github.com/ServiceNow/BrowserGym/blob/main/browsergym/miniwob/README.md
+
+### Playwright
+
+The environment server will need to have Playwright installed.
+ +`playwright install` + +## Launch Command + +`python -m pipelinerl.launch --config-name miniwob environment.miniwob_url=file:///PATH/TO/miniwob-plusplus/miniwob/html/miniwob/` diff --git a/pipelinerl/domains/miniwob/environment_server.py b/pipelinerl/domains/miniwob/environment_server.py index 1934e09d..b30f9ef7 100644 --- a/pipelinerl/domains/miniwob/environment_server.py +++ b/pipelinerl/domains/miniwob/environment_server.py @@ -13,12 +13,14 @@ def __init__(self, exp_path: str, headless: bool = True, observation_format: str = "html", - max_session_inactivity_secs: int = 600, + env_call_timeout: int = 60, ): os.environ["MINIWOB_URL"] = miniwob_url + # Remote environment server configuration self.n_envs = n_envs self.host = host - self.max_session_inactivity_secs = max_session_inactivity_secs + self.env_call_timeout = env_call_timeout + # Individual web environment configuration self.web_env_target = web_env_target self.exp_path = exp_path self.headless = headless @@ -29,7 +31,7 @@ def launch(self, port: int): """ Serve the web environment in TapeAgent. """ - env_server = EnvironmentServer(n_envs=self.n_envs, host=self.host, port=port) + env_server = EnvironmentServer(n_envs=self.n_envs, host=self.host, port=port, env_call_timeout=self.env_call_timeout) env_server.launch(OmegaConf.create({ "_target_": self.web_env_target, "exp_path": self.exp_path, diff --git a/pipelinerl/domains/miniwob/load_tasks.py b/pipelinerl/domains/miniwob/load_tasks.py index e5056c80..a0efc261 100644 --- a/pipelinerl/domains/miniwob/load_tasks.py +++ b/pipelinerl/domains/miniwob/load_tasks.py @@ -1,4 +1,5 @@ import random + from browsergym.miniwob import ALL_MINIWOB_TASKS DEBUG_SPLIT = [ @@ -34,6 +35,132 @@ "miniwob.tic-tac-toe", "miniwob.use-autocomplete-nodelay" ] +UIC_TRAIN_SPLIT = [ + "miniwob.ascending-numbers", + "miniwob.bisect-angle", + "miniwob.book-flight", + "miniwob.choose-date", + "miniwob.choose-date-easy", + "miniwob.choose-date-medium", + "miniwob.choose-date-nodelay", + "miniwob.choose-list", + "miniwob.circle-center", + "miniwob.click-button-sequence", + "miniwob.click-checkboxes-soft", + "miniwob.click-checkboxes-transfer", + "miniwob.click-collapsible-2", + "miniwob.click-collapsible-2-nodelay", + "miniwob.click-collapsible-nodelay", + "miniwob.click-color", + "miniwob.click-dialog", + "miniwob.click-dialog-2", + "miniwob.click-link", + "miniwob.click-menu", + "miniwob.click-menu-2", + "miniwob.click-scroll-list", + "miniwob.click-shape", + "miniwob.click-tab", + "miniwob.click-tab-2", + "miniwob.click-tab-2-hard", + "miniwob.click-tab-2-medium", + "miniwob.click-test", + "miniwob.click-test-2", + "miniwob.click-test-transfer", + "miniwob.click-widget", + "miniwob.copy-paste", + "miniwob.copy-paste-2", + "miniwob.count-shape", + "miniwob.count-sides", + "miniwob.daily-calendar", + "miniwob.drag-box", + "miniwob.drag-circle", + "miniwob.drag-cube", + "miniwob.drag-items", + "miniwob.drag-items-grid", + "miniwob.drag-shapes", + "miniwob.drag-shapes-2", + "miniwob.drag-sort-numbers", + "miniwob.draw-circle", + "miniwob.draw-line", + "miniwob.email-inbox", + "miniwob.email-inbox-delete", + "miniwob.email-inbox-forward", + "miniwob.email-inbox-forward-nl", + "miniwob.email-inbox-forward-nl-turk", + "miniwob.email-inbox-important", + "miniwob.email-inbox-noscroll", + "miniwob.email-inbox-reply", + "miniwob.email-inbox-star-reply", + "miniwob.enter-date", + "miniwob.enter-text", + "miniwob.enter-text-dynamic", + "miniwob.enter-time", + "miniwob.find-greatest", + "miniwob.find-word", + 
"miniwob.focus-text-2", + "miniwob.form-sequence", + "miniwob.form-sequence-2", + "miniwob.generate-number", + "miniwob.grid-coordinate", + "miniwob.guess-number", + "miniwob.highlight-text", + "miniwob.hot-cold", + "miniwob.identify-shape", + "miniwob.login-user", + "miniwob.login-user-popup", + "miniwob.multi-layouts", + "miniwob.multi-orderings", + "miniwob.navigate-tree", + "miniwob.odd-or-even", + "miniwob.order-food", + "miniwob.phone-book", + "miniwob.read-table", + "miniwob.read-table-2", + "miniwob.resize-textarea", + "miniwob.right-angle", + "miniwob.scroll-text", + "miniwob.scroll-text-2", + "miniwob.search-engine", + "miniwob.sign-agreement", + "miniwob.simple-algebra", + "miniwob.social-media", + "miniwob.social-media-all", + "miniwob.social-media-some", + "miniwob.text-editor", + "miniwob.text-transform", + "miniwob.tic-tac-toe", + "miniwob.use-autocomplete", + "miniwob.use-autocomplete-nodelay", + "miniwob.use-colorwheel", + "miniwob.use-colorwheel-2", + "miniwob.use-spinner", + "miniwob.visual-addition", +] +UIC_TEST_SPLIT = [ + "miniwob.buy-ticket", + "miniwob.click-button", + "miniwob.click-option", + "miniwob.click-pie-nodelay", + "miniwob.drag-single-shape", + "miniwob.email-inbox-nl-turk", + "miniwob.enter-text-2", + "miniwob.find-midpoint", + "miniwob.focus-text", + "miniwob.simple-arithmetic", + "miniwob.stock-market", + "miniwob.use-slider-2", + "miniwob.click-checkboxes", + "miniwob.click-checkboxes-large", + "miniwob.click-collapsible", + "miniwob.click-pie", + "miniwob.click-shades", + "miniwob.click-tab-2-easy", + "miniwob.enter-password", + "miniwob.form-sequence-3", + "miniwob.highlight-text-2", + "miniwob.unicode-test", + "miniwob.use-slider", +] TRAIN_SPLIT = None TEST_SPLIT = None @@ -72,5 +199,20 @@ def load_tasks(dataset_names: list[str], train_split: float = 0.6, seeds: list[i {"dataset": "miniwob.test", "task": task, "seed": seed} for task in TEST_SPLIT for seed in seeds ]) + elif name == "uic_train": + tasks.extend([ + {"dataset": "miniwob.uic_train", "task": task, "seed": seed} + for task in UIC_TRAIN_SPLIT for seed in range(3,10) # seeds 0-2 are used for held out goals in Mass setup + ]) + elif name == "uic_train_heldout_goals": + tasks.extend([ + {"dataset": "miniwob.uic_train_heldout_goals", "task": task, "seed": seed} + for task in UIC_TRAIN_SPLIT for seed in range(3) # seeds 0-2 are used for held out goals in Mass setup + ]) + elif name == "uic_test": + tasks.extend([ + {"dataset": "miniwob.uic_test", "task": task, "seed": seed} + for task in UIC_TEST_SPLIT for seed in range(10) + ]) return tasks diff --git a/pipelinerl/domains/miniwob/rollouts.py b/pipelinerl/domains/miniwob/rollouts.py index da429562..ea850814 100644 --- a/pipelinerl/domains/miniwob/rollouts.py +++ b/pipelinerl/domains/miniwob/rollouts.py @@ -1,23 +1,26 @@ - import asyncio +import json import logging import os import random import time +import traceback import aiohttp +from examples.rl_webagent.steps import WebTape from hydra.utils import instantiate from omegaconf import DictConfig from tapeagents.agent import DEFAULT, Agent -from tapeagents.core import LLMOutputParsingFailureAction, Observation +from tapeagents.core import LLMCall, LLMOutputParsingFailureAction, Observation from tapeagents.io import save_json_tape +from tapeagents.llms.trainable import TrainableLLM from tapeagents.orchestrator import async_execute_agent from tapeagents.remote_environment import AsyncRemoteEnvironment from tapeagents.tools.simple_browser import PageObservation from pipelinerl.async_llm import 
make_training_text -from pipelinerl.llm import TrainableLLM, LLMCall -from pipelinerl.rollouts import RolloutResult +from pipelinerl.llm import LLMCall, TrainableLLM +from pipelinerl.rollouts import BaseMetrics, RolloutResult from pipelinerl.world import Job from .steps import WebTape @@ -25,6 +28,23 @@ logger = logging.getLogger(__name__) +class MiniwobMetrics(BaseMetrics): + reward: float + success: bool + no_error: bool + no_answer: bool + overflow: bool + n_llm_calls: int + n_step_errors: int + n_page_observations: int + n_steps: int + total_execution_time: float + agent_execution_time: float + environment_execution_time: float + env_step_time: float + agent_step_time: float + + def tape_contains_an_error(tape: WebTape) -> bool: """ Returns true if the tape ends with an error, ie if one of the following is true: @@ -33,12 +53,35 @@ def tape_contains_an_error(tape: WebTape) -> bool: - the last step is a PageObservation with an error """ return ( - isinstance(tape.steps[-1], LLMOutputParsingFailureAction) + len(tape.steps) == 0 + or isinstance(tape.steps[-1], LLMOutputParsingFailureAction) or tape.metadata.result.get("error") is not None or (isinstance(tape.steps[-1], PageObservation) and tape.steps[-1].error) ) +async def check_env_server_health(env_job: Job, session: aiohttp.ClientSession) -> dict: + """Check environment server health via HTTP API.""" + try: + url = f"http://{env_job.hostname}:{env_job.port}/health" + async with session.get(url, timeout=5) as response: + if response.status == 200: + health_data = await response.json() + return { + "healthy": True, + "health_data": health_data, + "last_check": time.time() + } + else: + error_text = await response.text() + return {"healthy": False, "error_message": f"HTTP {response.status}: {error_text}", "last_check": time.time()} + except Exception as e: + exception_type = type(e).__name__ + exception_message = str(e) if str(e) else "No message available" + logger.exception(f"Error checking environment server health: {exception_type}: {exception_message}", stack_info=True) + return {"healthy": False, "error_message": f"Exception: {exception_type}: {exception_message}", "last_check": time.time(), "error_stacktrace": traceback.format_exc()} + + async def generate_miniwob_rollout( cfg: DictConfig, llm: TrainableLLM, @@ -54,61 +97,169 @@ async def generate_miniwob_rollout( # get training text from llm calls start_time = time.time() + + # Overall timeout for the entire rollout to prevent hanging + rollout_timeout = getattr(cfg, 'rollout_timeout', 600) # 10 minutes default - # (1) Choose a random environment server env_jobs = [Job(**job) for job in cfg.jobs if job["kind"] == "environment"] - # choose the env job randomly - env_job = random.choice(env_jobs) - assert env_job.port is not None - env_job_url = f"http://{env_job.hostname}:{env_job.port}" + env_jobs_url_tried = [] + + # Try each environment server with health checks until one of them returns a rollout result + for _ in range(len(env_jobs)): + # Choose the next environment server to try randomly from the ones that have not been tried yet + env_job = random.choice([job for job in env_jobs if f"http://{job.hostname}:{job.port}" not in env_jobs_url_tried]) + env_job_url = f"http://{env_job.hostname}:{env_job.port}" + env_jobs_url_tried.append(env_job_url) + # Check server health before using + health = await check_env_server_health(env_job, session) + if not health["healthy"]: + logger.warning(f"Environment server {env_job_url} is unhealthy: {health}") + logger.warning(f"Get health 
error stacktrace: {health['error_stacktrace']}") + continue + # Log health status for monitoring + if health["healthy"]: + logger.info(f"Using healthy environment server {env_job_url}: {health}") + + try: + # Execute the entire rollout with a timeout + return await asyncio.wait_for( + _execute_rollout_with_timeout(cfg, llm, problem, session, start_time, env_job_url), + timeout=rollout_timeout + ) + except asyncio.TimeoutError: + health = await check_env_server_health(env_job, session) + if stack_trace := health.get("error_stacktrace"): + logger.warning(f"Get health error stacktrace: {stack_trace}") + logger.warning(f"Rollout timeout error stacktrace: {traceback.format_exc()}") + logger.warning(f"Rollout timed out after {rollout_timeout} seconds for task {problem['dataset']}/{problem['task']}/{problem['seed']} on environment {env_job_url}. Health: {health}. Trying next server.") + continue + except Exception as e: + health = await check_env_server_health(env_job, session) + if stack_trace := health.get("error_stacktrace"): + logger.warning(f"Get health error stacktrace: {stack_trace}") + logger.warning(f"Rollout failed error stacktrace: {traceback.format_exc()}") + logger.warning(f"Rollout failed for task {problem['dataset']}/{problem['task']}/{problem['seed']} on environment {env_job_url}. Health: {health}. Trying next server.") + continue + # If all servers failed + logger.error(f"All environment servers failed for task {problem['dataset']}/{problem['task']}/{problem['seed']}. Returning a failed rollout result.") + return _create_failed_rollout_result(problem, start_time, "all environment servers failed") + + +async def _execute_rollout_with_timeout( + cfg: DictConfig, + llm: TrainableLLM, + problem: dict, + session: aiohttp.ClientSession, + start_time: float, + env_job_url: str, +) -> RolloutResult: # (2) Generate environment, TapeAgent, and run them to get a Tape + no_error = True # track if there was an error in the tape environment = AsyncRemoteEnvironment(server_url=env_job_url) # type: ignore async with environment.acontext(session, wait_for_env=True) as env: start_attempts = cfg.start_attempts t = time.perf_counter() - while True: + while start_attempts > 0: try: - tape_dict, _ = await env.start_task(problem) + tape_dict, info = await env.start_task(problem) + if info.get("error"): + raise ValueError(info['error']) break except Exception as e: start_attempts -= 1 + logger.warning(f"Failed to start task {problem['dataset']}/{problem['task']}/{problem['seed']}. {start_attempts} attempts remaining. Error: {e}") if start_attempts <= 0: - raise e - logger.warning(f"Failed to start task, retry after 5 seconds: {e}") - await asyncio.sleep(5) - logger.info(f"Task {problem['dataset']}/{problem['task']}/{problem['seed']} started in {time.perf_counter() - t:.2f} seconds") + logger.error(f"Failed to start task after all retry attempts: {e}") + no_error = False + tape_dict = {} + break + else: + logger.warning("Retry start task after 5 seconds.") + await asyncio.sleep(5) + logger.info( + f"Task {problem['dataset']}/{problem['task']}/{problem['seed']} started in {time.perf_counter() - t:.2f} seconds. Worker ID: {env.worker_id}. 
Tape dict: {tape_dict}" + ) tape: WebTape = WebTape(**tape_dict) # convert http response dict to WebTape object t = time.perf_counter() - try: - actions = await env.a_actions() - tools_description = await env.a_tools_description() - logger.debug(f"Available tools: {tools_description}") - agent: Agent = instantiate(cfg.agent, known_actions=actions, tools_description=tools_description) - agent.llms = {DEFAULT: llm} - tape = await async_execute_agent(agent, tape, env, session, max_loops=cfg.agent_max_loops) - except Exception as e: - logger.error(f"Error occurred while running agent: {e}") - tape.metadata.result = {"execution_time": time.perf_counter() - t} + if no_error: # only run the agent if the task started successfully + logger.info(f"Running agent for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id}") + agent_attempts = cfg.agent_attempts + while agent_attempts > 0: + # check if the worker is alive. + try: + # this will either raise RuntimeError if worker is not alive anymore, or return a dictionary with the worker status + worker_status = await env.check_worker_alive() + if worker_status.get("status") == "starting": + logger.warning(f"Worker {env.worker_id} for task {problem['dataset']}/{problem['task']}/{problem['seed']} and tape ID {tape.metadata.id} is starting, waiting 5 seconds for it to be fully started.") + await asyncio.sleep(5) + continue + except Exception as e: + # if worker is dead, no need to retry + logger.exception(f"Worker {env.worker_id} for task {problem['dataset']}/{problem['task']}/{problem['seed']} and tape ID {tape.metadata.id} is dead. Error: {e}", stack_info=True) + no_error = False + break + # if worker is alive, run the agent + try: + actions = await env.a_actions() + tools_description = await env.a_tools_description() + agent: Agent = instantiate(cfg.agent, known_actions=actions, tools_description=tools_description) + agent.llms = {DEFAULT: llm} + tape = await async_execute_agent(agent, tape, env, session, max_loops=cfg.agent_max_loops) + # Check if the tape has an error from the orchestrator (e.g., SocketTimeoutError, RuntimeError: Worker is not alive, etc.) + if tape.metadata.error: + logger.error(f"Agent execution for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id} returned a tape with error: {tape.metadata.error}") + raise ValueError(tape.metadata.error) + else: + # Success - break out of retry loop + logger.info(f"Agent execution for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id} finished successfully") + break + except Exception as e: + agent_attempts -= 1 + logger.warning(f"Error occurred while running agent for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id}. {agent_attempts} attempts remaining. 
Error: {e}") + if agent_attempts <= 0: + logger.error(f"Agent execution failed after all retry attempts for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id}: {e}") + no_error = False + break + else: + logger.warning(f"Retry agent execution after 5 seconds for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id}.") + await asyncio.sleep(5) + logger.info( + f"Agent finished task {problem['dataset']}/{problem['task']}/{problem['seed']} in {time.perf_counter() - t:.2f} seconds with worker ID: {env.worker_id} and tape ID {tape.metadata.id}" + ) + tape.metadata.result.update({"total_execution_time": time.perf_counter() - t}) # save the tape as we go if cfg.save_tapes: save_json_tape(tape, os.path.join(cfg.output_dir, "tapes"), tape.metadata.id) # (3) Compute rewards - last_obs = [step for step in tape if isinstance(step, Observation)][-1] - # in Miniwob, the observation "reward" is defined as RAW_REWARD_GLOBAL > 0 - # see here: https://github.com/ServiceNow/BrowserGym/blob/main/browsergym/miniwob/src/browsergym/miniwob/base.py#L183 - # Let's take directly the RAW_REWARD_GLOBAL from the metadata - # raw_reward = last_obs.metadata.other.get("reward", 0.0) - raw_reward = last_obs.metadata.other.get("info", {}).get("task_info", {}).get("REWARD_GLOBAL", -1.0) - no_error = not tape_contains_an_error(tape) + obs_steps = [step for step in tape if isinstance(step, Observation)] + if obs_steps: + last_obs = obs_steps[-1] + # in Miniwob, the observation "reward" is defined as RAW_REWARD_GLOBAL > 0 + # see here: https://github.com/ServiceNow/BrowserGym/blob/main/browsergym/miniwob/src/browsergym/miniwob/base.py#L188 + # Let's take directly the RAW_REWARD_GLOBAL from the metadata + # raw_reward = last_obs.metadata.other.get("reward", 0.0) + raw_reward = last_obs.metadata.other.get("info", {}).get("task_info", {}).get("REWARD_GLOBAL", -1.0) + else: + raw_reward = -1.0 + + no_error = no_error and not tape_contains_an_error(tape) # get the number of LLMOutputParsingFailureAction in the tape n_step_errors = len([step for step in tape.steps if isinstance(step, LLMOutputParsingFailureAction)]) # get the number of PageObservation steps in the tape n_page_observations = len([step for step in tape.steps if isinstance(step, PageObservation)]) - reward = raw_reward * 0.99**n_step_errors if no_error and raw_reward >= 0 else -1.0 + if cfg.reward_computation == "nico": + reward = raw_reward * 0.99**n_step_errors if no_error and raw_reward >= 0 else -1.0 + elif cfg.reward_computation == "uic": + reward = float(raw_reward>0) + if reward == 0.0: + reward = -1.0 + reward *= 0.98 ** n_page_observations + else: + raise ValueError(f"Invalid reward configuration: {cfg.reward_computation}") # (3) Get LLM calls from Tape llm_calls = [step for step in tape.steps if step.metadata.other.get("llm_call") is not None] @@ -120,7 +271,7 @@ async def generate_miniwob_rollout( ] # (4) # For each LLM interaction in the tape, make a training example. 
- all_finished = 0 + all_finished = 1 prompt_tokens = [llm_call.prompt_length_tokens for llm_call in llm_calls] output_tokens = [llm_call.output_length_tokens for llm_call in llm_calls] training_texts = [make_training_text(llm, llm_call) for llm_call in llm_calls] @@ -129,19 +280,26 @@ async def generate_miniwob_rollout( all_finished &= 1 if text.input_ids[-1] == llm.tokenizer.eos_token_id else 0 latency = time.time() - start_time - - metrics = { - "reward": reward, - "success": 1 if reward > 0.5 else 0, - "no_error": no_error, - "no_answer": 1 if reward < 0 else 0, - "overflow": 0 if all_finished else 1, - "n_llm_calls": n_llm_calls, - "n_step_errors": n_step_errors, - "n_page_observations": n_page_observations, - "n_steps": len(tape.steps), - } - logger.info(f"Created {len(training_texts)} training texts, reward: {reward}, has error: {not no_error}") + agent_time = tape.metadata.result.get("agent_execution_time", -1.0) + env_time = tape.metadata.result.get("environment_execution_time", -1.0) + n_observations = len([s for s in tape.steps if isinstance(s, Observation)]) # TODO: is this not the same n_page_observations?? + n_other_steps = len(tape.steps) - n_observations + metrics = MiniwobMetrics( + reward=reward, + success=reward > 0.5, + no_error=no_error, + no_answer=reward < 0, + overflow=not all_finished, + n_llm_calls=n_llm_calls, + n_step_errors=n_step_errors, + n_page_observations=n_page_observations, + n_steps=len(tape.steps), + total_execution_time=tape.metadata.result.get("total_execution_time", -1.0), + agent_execution_time=agent_time, + environment_execution_time=env_time, + env_step_time=env_time / n_observations if env_time > 0 and n_observations > 0 else -1.0, + agent_step_time=agent_time / n_other_steps if agent_time > 0 and n_other_steps > 0 else -1.0, + ) return RolloutResult( training_texts=training_texts, @@ -152,3 +310,34 @@ async def generate_miniwob_rollout( output_tokens=output_tokens, ) + +def _create_failed_rollout_result(problem: dict, start_time: float, error_type: str) -> RolloutResult: + """Create a failed rollout result for timeout or other errors.""" + latency = time.time() - start_time + + # Create empty training texts and metrics for failed rollout + metrics = MiniwobMetrics( + reward=-1.0, + success=False, + no_error=False, + no_answer=True, + overflow=False, + n_llm_calls=0, + n_step_errors=0, + n_page_observations=0, + n_steps=0, + total_execution_time=latency, + agent_execution_time=-1.0, + environment_execution_time=-1.0, + env_step_time=-1.0, + agent_step_time=-1.0, + ) + + return RolloutResult( + training_texts=[], + metrics=metrics, + latency=latency, + dataset_name=problem["dataset"], + prompt_tokens=[], + output_tokens=[], + ) diff --git a/pipelinerl/domains/miniwob/tool_chat_template_llama3.1_json.jinja b/pipelinerl/domains/miniwob/tool_chat_template_llama3.1_json.jinja deleted file mode 100644 index a3bc9f02..00000000 --- a/pipelinerl/domains/miniwob/tool_chat_template_llama3.1_json.jinja +++ /dev/null @@ -1,120 +0,0 @@ -{{- bos_token }} -{%- if custom_tools is defined %} - {%- set tools = custom_tools %} -{%- endif %} -{%- if not tools_in_user_message is defined %} - {#- Llama 3.1 doesn't pass all tests if the tools are in the system prompt #} - {%- set tools_in_user_message = true %} -{%- endif %} -{%- if not date_string is defined %} - {%- if strftime_now is defined %} - {%- set date_string = strftime_now("%d %b %Y") %} - {%- else %} - {%- set date_string = "26 Jul 2024" %} - {%- endif %} -{%- endif %} -{%- if not tools is defined %} - {%- 
set tools = none %} -{%- endif %} - -{#- This block extracts the system message, so we can slot it into the right place. #} -{%- if messages[0]['role'] == 'system' %} - {%- if messages[0]['content'] is string %} - {%- set system_message = messages[0]['content']|trim %} - {%- else %} - {%- set system_message = messages[0]['content'][0]['text']|trim %} - {%- endif %} - {%- set messages = messages[1:] %} -{%- else %} - {%- if tools is not none %} - {%- set system_message = "You are a helpful assistant with tool calling capabilities. Only reply with a tool call if the function exists in the library provided by the user. If it doesn't exist, just reply directly in natural language. When you receive a tool call response, use the output to format an answer to the original user question." %} - {%- else %} - {%- set system_message = "" %} - {%- endif %} -{%- endif %} - -{#- System message #} -{{- "<|start_header_id|>system<|end_header_id|>\n\n" }} -{%- if tools is not none %} - {{- "Environment: ipython\n" }} -{%- endif %} -{{- "Cutting Knowledge Date: December 2023\n" }} -{{- "Today Date: " + date_string + "\n\n" }} -{%- if tools is not none and not tools_in_user_message %} - {{- "You have access to the following functions. To call a function, please respond with JSON for a function call. " }} - {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. ' }} - {{- "Do not use variables.\n\n" }} - {%- for t in tools %} - {{- t | tojson(indent=4) }} - {{- "\n\n" }} - {%- endfor %} -{%- endif %} -{{- system_message }} -{{- "<|eot_id|>" }} - -{#- Custom tools are passed in a user message with some extra guidance #} -{%- if tools_in_user_message and not tools is none %} - {#- Extract the first user message so we can plug it in here #} - {%- if messages | length != 0 %} - {%- if messages[0]['content'] is string %} - {%- set first_user_message = messages[0]['content']|trim %} - {%- else %} - {%- set first_user_message = messages[0]['content'] | selectattr('type', 'equalto', 'text') | map(attribute='text') | map('trim') | join('\n') %} - {%- endif %} - {%- set messages = messages[1:] %} - {%- else %} - {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} - {%- endif %} - {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} - {{- "Given the following functions, please respond with a JSON for a function call " }} - {{- "with its proper arguments that best answers the given prompt.\n\n" }} - {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. 
' }} - {{- "Do not use variables.\n\n" }} - {%- for t in tools %} - {{- t | tojson(indent=4) }} - {{- "\n\n" }} - {%- endfor %} - {{- first_user_message + "<|eot_id|>"}} -{%- endif %} - -{%- for message in messages %} - {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} - {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }} - {%- if message['content'] is string %} - {{- message['content'] | trim}} - {%- else %} - {%- for content in message['content'] %} - {%- if content['type'] == 'text' %} - {{- content['text'] | trim }} - {%- endif %} - {%- endfor %} - {%- endif %} - {{- '<|eot_id|>' }} - {%- elif 'tool_calls' in message %} - {%- if not message.tool_calls|length == 1 %} - {{- raise_exception("This model only supports single tool-calls at once!") }} - {%- endif %} - {%- set tool_call = message.tool_calls[0].function %} - {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} - {{- '{"name": "' + tool_call.name + '", ' }} - {{- '"parameters": ' }} - {{- tool_call.arguments | tojson }} - {{- "}" }} - {{- "<|eot_id|>" }} - {%- elif message.role == "tool" or message.role == "ipython" %} - {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} - {%- if message.content is string %} - {{- { "output": message.content } | tojson }} - {%- else %} - {%- for content in message['content'] %} - {%- if content['type'] == 'text' %} - {{- { "output": content['text'] } | tojson }} - {%- endif %} - {%- endfor %} - {%- endif %} - {{- "<|eot_id|>" }} - {%- endif %} -{%- endfor %} -{%- if add_generation_prompt %} - {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} -{%- endif %} \ No newline at end of file diff --git a/pipelinerl/launch.py b/pipelinerl/launch.py index 114be006..8d2c33d8 100644 --- a/pipelinerl/launch.py +++ b/pipelinerl/launch.py @@ -78,6 +78,13 @@ def validate_config(cfg: DictConfig): if not hasattr(cfg.finetune.rl, "value_loss_coef") or cfg.finetune.rl.value_loss_coef <= 0.0: raise ValueError("value_loss_coef must be greater than 0 when using causal-language-modeling-with-value-head") + # Check that model being tuned to the max length accepted by inference + if cfg.finetune.seq_length < cfg.vllm_config.vllm_kwargs.max_model_len: + raise ValueError( + f"seq_length {cfg.finetune.seq_length} must be greater than or equal to " + f"vllm_kwargs.max_model_len {cfg.vllm_config.vllm_kwargs.max_model_len}" + ) + # Check for asymmetric PPO clipping if cfg.finetune.rl.policy_loss == "ppo" and cfg.finetune.rl.epsilon_low != cfg.finetune.rl.epsilon_high: if cfg.finetune.model_class == "causal-language-modeling-with-value-head": diff --git a/pipelinerl/preprocess.py b/pipelinerl/preprocess.py index 65fcee47..0a6015e4 100644 --- a/pipelinerl/preprocess.py +++ b/pipelinerl/preprocess.py @@ -157,7 +157,19 @@ def preprocess_dataset( entry["step_index"] = entry["metadata"]["step_index"] if not isinstance(tokenizer.eos_token_id, int): raise ValueError(f"Tokenizer {tokenizer} does not have an eos_token_id") - dataset = populate_rl_data(dataset=dataset, eos_token_id=tokenizer.eos_token_id, config=rl_config) + try: + dataset = populate_rl_data(dataset=dataset, eos_token_id=tokenizer.eos_token_id, config=rl_config) + except Exception as e: + logger.error(f"Error in populate_rl_data: {e}", extra={ + "data": data, + "dataset": dataset, + "tokenizer": tokenizer, + "eos_token_id": tokenizer.eos_token_id, + "rl_config": rl_config, + "llm": llm, + "seq_length": seq_length, + }) + raise return dataset diff --git a/pipelinerl/utils.py 
b/pipelinerl/utils.py index fbcd9926..bd7fe5b5 100644 --- a/pipelinerl/utils.py +++ b/pipelinerl/utils.py @@ -237,6 +237,9 @@ def calculate_stats(stats: List | Dict[Any, Any]) -> Dict[str, float]: if not isinstance(stats, list): raise TypeError(f"Expected stats to be a list, got {type(stats)}") + if len(stats) == 0: + return {} + aggregated_stats = { "max": float(max(stats)), "min": float(min(stats)), diff --git a/pipelinerl/world.py b/pipelinerl/world.py index f41714e4..992a7c4d 100644 --- a/pipelinerl/world.py +++ b/pipelinerl/world.py @@ -1,9 +1,9 @@ import logging import os -from typing import Literal -from pydantic import BaseModel -from omegaconf import DictConfig + import torch +from omegaconf import DictConfig +from pydantic import BaseModel logger = logging.getLogger(__name__)
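
Reviewer note (not part of the patch): the two reward schemes added in `pipelinerl/domains/miniwob/rollouts.py` and selected by the new `reward_computation` config key are easy to misread in diff form. The sketch below restates that logic as a standalone function; the names are illustrative, and `raw_reward` is assumed to be MiniWoB's `REWARD_GLOBAL` taken from the last observation (or `-1.0` when no `PageObservation` is available), as in the patch.

```python
def compute_reward(
    scheme: str,
    raw_reward: float,
    no_error: bool,
    n_step_errors: int,
    n_page_observations: int,
) -> float:
    """Standalone restatement of the reward logic added in rollouts.py."""
    if scheme == "nico":
        # Keep the raw MiniWoB reward, discounted by 1% per LLM parsing error;
        # any tape error or negative raw reward collapses to -1.0.
        return raw_reward * 0.99**n_step_errors if no_error and raw_reward >= 0 else -1.0
    if scheme == "uic":
        # Binarize the raw reward to +1/-1, then discount by the number of page
        # observations (note the discount also shrinks the -1 penalty toward 0).
        reward = 1.0 if raw_reward > 0 else -1.0
        return reward * 0.98**n_page_observations
    raise ValueError(f"Invalid reward configuration: {scheme}")
```

For example, a successful `uic` rollout that used 4 page observations gets `0.98**4 ≈ 0.92` and a failed one gets about `-0.92`, while under `nico` a success with 2 parsing errors keeps `raw_reward * 0.99**2`.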