From 71a4d6a0a9af4e360cf84a54ef12552610867f27 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Wed, 19 Nov 2025 12:45:14 +0100 Subject: [PATCH 1/4] Update OpenEnv examples --- docs/source/openenv.md | 2 +- examples/scripts/openenv/catch.py | 2 +- examples/scripts/openenv/echo.py | 26 +++++++++++++++++++------- examples/scripts/openenv/wordle.py | 2 +- 4 files changed, 22 insertions(+), 10 deletions(-) diff --git a/docs/source/openenv.md b/docs/source/openenv.md index 6829627dcfd..f9fe26c0c6f 100644 --- a/docs/source/openenv.md +++ b/docs/source/openenv.md @@ -89,7 +89,7 @@ args = GRPOConfig( args = GRPOConfig( use_vllm=True, vllm_mode="server", - vllm_server_url="http://localhost:8000", + vllm_server_base_url="http://localhost:8000", # ... other args ) diff --git a/examples/scripts/openenv/catch.py b/examples/scripts/openenv/catch.py index 1e93b7dc0a9..ab109d818b7 100644 --- a/examples/scripts/openenv/catch.py +++ b/examples/scripts/openenv/catch.py @@ -201,7 +201,7 @@ def main(): output_dir=f"{args.model.split('/')[-1]}-GRPO-Catch", use_vllm=True, vllm_mode=args.vllm_mode, - vllm_server_url=args.vllm_server_url if args.vllm_mode == "server" else None, + vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None, logging_steps=1, report_to="trackio", num_train_epochs=1, diff --git a/examples/scripts/openenv/echo.py b/examples/scripts/openenv/echo.py index e890579c527..a38c6e5ce90 100644 --- a/examples/scripts/openenv/echo.py +++ b/examples/scripts/openenv/echo.py @@ -71,9 +71,9 @@ def parse_args(): parser.add_argument("--env-port", type=int, default=8001, help="Port for the Echo environment.") parser.add_argument( "--env-mode", - choices=["local", "docker", "space"], - default="docker", - help="Where to run the Echo environment: 'local' to launch it, 'docker' if already running, or 'space' to use a remote Space URL.", + choices=["local", "docker-local", "docker-image", "docker-hub", "space"], + default="docker-image", + help="Where to run the Echo environment: 'local' to launch it, 'docker-local' if already running locally, 'docker-image' to run from a Docker image, 'docker-hub' to run from Docker Hub, or 'space' to use a remote Space URL.", ) parser.add_argument( "--model", @@ -87,6 +87,9 @@ def parse_args(): default="trl-lib/ultrafeedback-prompt", help="Dataset to use for training.", ) + parser.add_argument( + "--env-image", type=str, default="echo-env:latest", help="Docker image for the Echo environment." + ) parser.add_argument( "--vllm-mode", choices=["colocate", "server"], @@ -146,25 +149,34 @@ def main(): if args.env_mode == "local": env_url = f"http://{args.env_host}:{args.env_port}" server_process = start_env_server(args.env_host, args.env_port) - elif args.env_mode == "docker": + elif args.env_mode == "docker-local": env_url = f"http://{args.env_host}:{args.env_port}" server_process = None print(f"🌍 Using existing Echo Environment (Docker) at: {env_url}") + elif args.env_mode == "docker-image": + client = EchoEnv.from_docker_image(args.env_image) + server_process = None + print("🌍 Using Echo Environment (Docker) from local Image") + elif args.env_mode == "docker-hub": + client = EchoEnv.from_hub(args.env_image) + server_process = None + print("🌍 Using existing Echo Environment (Docker) from Hub Image") elif args.env_mode == "space": env_url = args.env_host server_process = None - print(f"🚀 Using Hugging Face Space environment at: {env_url}") + print(f"🌍 Using Hugging Face Space environment at: {env_url}") else: raise ValueError(f"Unknown environment mode: {args.env_mode}") - client = EchoEnv(base_url=env_url) + if args.env_mode != "docker-hub" and args.env_mode != "docker-image": + client = EchoEnv(base_url=env_url) dataset = load_dataset(args.dataset, split="train[:1000]") training_args = GRPOConfig( output_dir=f"{args.model.split('/')[-1]}-GRPO-Rollout", use_vllm=True, vllm_mode=args.vllm_mode, - vllm_server_url=args.vllm_server_url if args.vllm_mode == "server" else None, + vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None, logging_steps=1, report_to="trackio", num_train_epochs=1, diff --git a/examples/scripts/openenv/wordle.py b/examples/scripts/openenv/wordle.py index e4ab82196d0..bf53ec2099c 100644 --- a/examples/scripts/openenv/wordle.py +++ b/examples/scripts/openenv/wordle.py @@ -429,7 +429,7 @@ def main() -> None: grpo_config = GRPOConfig( use_vllm=True, vllm_mode=cli_args.vllm_mode, - vllm_server_url=cli_args.vllm_server_url if cli_args.vllm_mode == "server" else None, + vllm_server_base_url=cli_args.vllm_server_url if cli_args.vllm_mode == "server" else None, output_dir=str(output_dir), num_train_epochs=cli_args.num_epochs, learning_rate=cli_args.learning_rate, From 145c87e5d1f4581da1cfbd5e7da96a20c095bb31 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Wed, 19 Nov 2025 13:03:27 +0100 Subject: [PATCH 2/4] Update catch --- examples/scripts/openenv/catch.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/examples/scripts/openenv/catch.py b/examples/scripts/openenv/catch.py index ab109d818b7..2578f39041d 100644 --- a/examples/scripts/openenv/catch.py +++ b/examples/scripts/openenv/catch.py @@ -73,9 +73,9 @@ def parse_args(): parser.add_argument("--env-port", type=int, default=8001, help="Port for the environment server.") parser.add_argument( "--env-mode", - choices=["local", "docker", "space"], - default="docker", - help="Where to run the environment: 'local', 'docker', or 'space'.", + choices=["local", "docker-local", "docker-image", "docker-hub", "space"], + default="docker-image", + help="Where to run the environment: 'local' to launch it, 'docker-local' if already running locally, 'docker-image' to run from a Docker image, 'docker-hub' to run from Docker Hub, or 'space' to use a remote Space URL.", ) # --- Generation and model config --- parser.add_argument( @@ -90,6 +90,9 @@ def parse_args(): default=1000, help="Number of prompts to use for training dataset.", ) + parser.add_argument( + "--env-image", type=str, default="openspiel-env:latest", help="Docker image for the OpenSpiel environment." + ) parser.add_argument( "--vllm-mode", choices=["colocate", "server"], @@ -183,18 +186,27 @@ def main(): if args.env_mode == "local": env_url = f"http://{args.env_host}:{args.env_port}" server_process = start_env_server(args.env_host, args.env_port) - elif args.env_mode == "docker": + elif args.env_mode == "docker-local": env_url = f"http://{args.env_host}:{args.env_port}" server_process = None - print(f"🌍 Using existing Docker environment at {env_url}") + print(f"🌍 Using existing OpenSpiel Environment (Docker) at: {env_url}") + elif args.env_mode == "docker-image": + client = OpenSpielEnv.from_docker_image(args.env_image) + server_process = None + print("🌍 Using OpenSpiel Environment (Docker) from local Image") + elif args.env_mode == "docker-hub": + client = OpenSpielEnv.from_hub(args.env_image) + server_process = None + print("🌍 Using existing OpenSpiel Environment (Docker) from Hub Image") elif args.env_mode == "space": env_url = args.env_host server_process = None - print(f"🚀 Using Hugging Face Space environment at {env_url}") + print(f"🌍 Using Hugging Face Space environment at: {env_url}") else: - raise ValueError(f"Unknown env mode: {args.env_mode}") + raise ValueError(f"Unknown environment mode: {args.env_mode}") - client = OpenSpielEnv(base_url=env_url) + if args.env_mode != "docker-hub" and args.env_mode != "docker-image": + client = OpenSpielEnv(base_url=env_url) dataset = Dataset.from_dict({"prompt": [BASE_PROMPT] * args.dataset_size}) training_args = GRPOConfig( From 548526b1ee435a2bec44ef8c3d2379c11faa65cf Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Wed, 19 Nov 2025 14:20:46 +0100 Subject: [PATCH 3/4] Update wordle script --- examples/scripts/openenv/wordle.py | 91 +++++++++++++++++++----------- 1 file changed, 57 insertions(+), 34 deletions(-) diff --git a/examples/scripts/openenv/wordle.py b/examples/scripts/openenv/wordle.py index bf53ec2099c..69cb3dd50b6 100644 --- a/examples/scripts/openenv/wordle.py +++ b/examples/scripts/openenv/wordle.py @@ -85,11 +85,15 @@ def parse_args() -> argparse.Namespace: default="Qwen/Qwen3-1.7B", help="Model identifier passed to GRPOTrainer for fine-tuning.", ) + parser.add_argument("--env-host", type=str, default="0.0.0.0", help="Host for the environment server.") + parser.add_argument("--env-port", type=int, default=8001, help="Port for the environment server.") parser.add_argument( - "--env-url", - default="https://burtenshaw-textarena.hf.space", - help="URL for the TextArena Wordle environment.", + "--env-mode", + choices=["docker-local", "docker-image", "docker-hub", "space"], + default="docker-image", + help="Where to run the environment: 'docker-local' if already running locally, 'docker-image' to run from a Docker image, 'docker-hub' to run from Docker Hub, or 'space' to use a remote Space URL.", ) + parser.add_argument("--env-image", type=str, default="textarena-env:latest", help="Docker image for the TextArena environment.") parser.add_argument( "--system-prompt-path", default="wordle_prompt.txt", @@ -411,46 +415,65 @@ def reward_repetition(completions: list[str], **kwargs) -> list[float]: def main() -> None: - cli_args = parse_args() + args = parse_args() - tokenizer = AutoTokenizer.from_pretrained(cli_args.tokenizer_id) + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id) tokenizer.pad_token = tokenizer.eos_token - env = TextArenaEnv(base_url=cli_args.env_url) + # Select environment mode + if args.env_mode == "docker-local": + env_url = f"http://{args.env_host}:{args.env_port}" + print(f"🌍 Using existing TextArena Environment (Docker) at: {env_url}") + elif args.env_mode == "docker-image": + client = TextArenaEnv.from_docker_image(args.env_image) + print(f"🌍 Using TextArena Environment (Docker) from local Image") + elif args.env_mode == "docker-hub": + client = TextArenaEnv.from_hub(args.env_image) + print(f"🌍 Using existing TextArena Environment (Docker) from Hub Image") + elif args.env_mode == "space": + env_url = args.env_host + print(f"🌍 Using Hugging Face Space environment at: {env_url}") + else: + raise ValueError(f"Unknown environment mode: {args.env_mode}") - system_prompt = resolve_system_prompt(cli_args.system_prompt_path) + if args.env_mode != "docker-hub" and args.env_mode != "docker-image": + client = TextArenaEnv(base_url=env_url) - dataset = Dataset.from_dict({"prompt": [cli_args.dataset_prompt] * cli_args.dataset_size}) + #env = TextArenaEnv(base_url=args.env_url) + + system_prompt = resolve_system_prompt(args.system_prompt_path) + + dataset = Dataset.from_dict({"prompt": [args.dataset_prompt] * args.dataset_size}) timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - default_output_dir = Path("outputs") / f"wordle-grpo-{sanitize_name(cli_args.model_id)}-{timestamp}" - output_dir = Path(cli_args.output_dir or default_output_dir) + default_output_dir = Path("outputs") / f"wordle-grpo-{sanitize_name(args.model_id)}-{timestamp}" + output_dir = Path(args.output_dir or default_output_dir) grpo_config = GRPOConfig( use_vllm=True, - vllm_mode=cli_args.vllm_mode, - vllm_server_base_url=cli_args.vllm_server_url if cli_args.vllm_mode == "server" else None, + vllm_mode=args.vllm_mode, + vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None, output_dir=str(output_dir), - num_train_epochs=cli_args.num_epochs, - learning_rate=cli_args.learning_rate, - weight_decay=cli_args.weight_decay, - gradient_accumulation_steps=cli_args.gradient_accumulation_steps, - per_device_train_batch_size=cli_args.per_device_batch_size, - warmup_steps=cli_args.warmup_steps, - num_generations=cli_args.num_generations, - max_completion_length=cli_args.max_new_tokens, - logging_steps=cli_args.logging_steps, + num_train_epochs=args.num_epochs, + learning_rate=args.learning_rate, + weight_decay=args.weight_decay, + gradient_accumulation_steps=args.gradient_accumulation_steps, + per_device_train_batch_size=args.per_device_batch_size, + warmup_steps=args.warmup_steps, + num_generations=args.num_generations, + max_completion_length=args.max_new_tokens, + logging_steps=args.logging_steps, save_strategy="steps", - save_steps=cli_args.save_interval, - save_total_limit=cli_args.save_total_limit, - temperature=cli_args.temperature, - top_k=cli_args.top_k, - top_p=cli_args.top_p, + save_steps=args.save_interval, + save_total_limit=args.save_total_limit, + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, ) - grpo_config.run_name = cli_args.run_name or f"run-{timestamp}" - grpo_config.project = cli_args.project or f"group-{sanitize_name(cli_args.model_id)}" - grpo_config.trackio_space_id = cli_args.trackio_space_id + grpo_config.run_name = args.run_name or f"run-{timestamp}" + grpo_config.project = args.project or f"group-{sanitize_name(args.model_id)}" + grpo_config.trackio_space_id = args.trackio_space_id def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: episode_prompt_ids: list[list[int]] = [] @@ -464,11 +487,11 @@ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: for prompt_text in prompts: episode = rollout_once( trainer=trainer, - env=env, + env=client, tokenizer=tokenizer, dataset_prompt=prompt_text, system_prompt=system_prompt, - max_turns=cli_args.max_turns, + max_turns=args.max_turns, ) episode_prompt_ids.append(episode["prompt_ids"]) episode_completion_ids.append(episode["completion_ids"]) @@ -489,7 +512,7 @@ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: } trainer = GRPOTrainer( - model=cli_args.model_id, + model=args.model_id, processing_class=tokenizer, reward_funcs=[ reward_correct, @@ -503,12 +526,12 @@ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: ) print("Starting GRPO training with Wordle environment...") - print(f"Using {cli_args.num_generations} rollouts per dataset prompt") + print(f"Using {args.num_generations} rollouts per dataset prompt") try: trainer.train() finally: - env.close() + client.close() if __name__ == "__main__": From 8abaf5f6295e5a533da37e7a906beaa1136fe265 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Wed, 19 Nov 2025 14:35:41 +0100 Subject: [PATCH 4/4] Small nit --- examples/scripts/openenv/catch.py | 2 +- examples/scripts/openenv/echo.py | 2 +- examples/scripts/openenv/wordle.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/scripts/openenv/catch.py b/examples/scripts/openenv/catch.py index 2578f39041d..5bac3b1cfbe 100644 --- a/examples/scripts/openenv/catch.py +++ b/examples/scripts/openenv/catch.py @@ -24,7 +24,7 @@ Usage: -# Start the docker container for the Catch environment (recommended). Alternatively, you can run it locally or directly from a HF Space. +# Start the environment only if using --env-mode docker-local; In other modes, the env is automatically managed by the script. ```sh docker run -d -p 8001:8001 registry.hf.space/openenv-openspiel-env:latest ``` diff --git a/examples/scripts/openenv/echo.py b/examples/scripts/openenv/echo.py index a38c6e5ce90..91a960c3c7e 100644 --- a/examples/scripts/openenv/echo.py +++ b/examples/scripts/openenv/echo.py @@ -24,7 +24,7 @@ Usage: -# Start the docker container for the Echo environment (recommended). Alternatively, you can run it locally or directly from a HF Space. +# Start the environment only if using --env-mode docker-local; In other modes, the env is automatically managed by the script. ```sh docker run -d -p 8001:8001 registry.hf.space/openenv-echo-env:latest ``` diff --git a/examples/scripts/openenv/wordle.py b/examples/scripts/openenv/wordle.py index 69cb3dd50b6..d37ee9eefe2 100644 --- a/examples/scripts/openenv/wordle.py +++ b/examples/scripts/openenv/wordle.py @@ -23,7 +23,7 @@ Usage: -# Start the docker container for the Wordle environment (recommended). Alternatively, you can run it locally or directly from a HF Space. +# Start the environment only if using --env-mode docker-local; In other modes, the env is automatically managed by the script. ```sh docker run -d -p 8001:8001 registry.hf.space/burtenshaw-textarena:latest # or TEXTARENA_ENV_ID=Wordle-v0 TEXTARENA_NUM_PLAYERS=1 python -m src.envs.textarena_env.server.app