Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/openenv.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ args = GRPOConfig(
args = GRPOConfig(
use_vllm=True,
vllm_mode="server",
vllm_server_url="http://localhost:8000",
vllm_server_base_url="http://localhost:8000",
# ... other args
)

Expand Down
32 changes: 22 additions & 10 deletions examples/scripts/openenv/catch.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

Usage:

# Start the docker container for the Catch environment (recommended). Alternatively, you can run it locally or directly from a HF Space.
# Start the environment manually only when using --env-mode docker-local; in all other modes, the script launches or connects to the environment itself.
```sh
docker run -d -p 8001:8001 registry.hf.space/openenv-openspiel-env:latest
```
Expand Down Expand Up @@ -73,9 +73,9 @@ def parse_args():
parser.add_argument("--env-port", type=int, default=8001, help="Port for the environment server.")
parser.add_argument(
"--env-mode",
choices=["local", "docker", "space"],
default="docker",
help="Where to run the environment: 'local', 'docker', or 'space'.",
choices=["local", "docker-local", "docker-image", "docker-hub", "space"],
default="docker-image",
help="Where to run the environment: 'local' to launch it, 'docker-local' if already running locally, 'docker-image' to run from a Docker image, 'docker-hub' to run from Docker Hub, or 'space' to use a remote Space URL.",
)
# --- Generation and model config ---
parser.add_argument(
Expand All @@ -90,6 +90,9 @@ def parse_args():
default=1000,
help="Number of prompts to use for training dataset.",
)
parser.add_argument(
"--env-image", type=str, default="openspiel-env:latest", help="Docker image for the OpenSpiel environment."
)
parser.add_argument(
"--vllm-mode",
choices=["colocate", "server"],
Expand Down Expand Up @@ -183,25 +186,34 @@ def main():
if args.env_mode == "local":
env_url = f"http://{args.env_host}:{args.env_port}"
server_process = start_env_server(args.env_host, args.env_port)
elif args.env_mode == "docker":
elif args.env_mode == "docker-local":
env_url = f"http://{args.env_host}:{args.env_port}"
server_process = None
print(f"🌍 Using existing Docker environment at {env_url}")
print(f"🌍 Using existing OpenSpiel Environment (Docker) at: {env_url}")
elif args.env_mode == "docker-image":
client = OpenSpielEnv.from_docker_image(args.env_image)
server_process = None
print("🌍 Using OpenSpiel Environment (Docker) from local Image")
elif args.env_mode == "docker-hub":
client = OpenSpielEnv.from_hub(args.env_image)
server_process = None
print("🌍 Using existing OpenSpiel Environment (Docker) from Hub Image")
elif args.env_mode == "space":
env_url = args.env_host
server_process = None
print(f"🚀 Using Hugging Face Space environment at {env_url}")
print(f"🌍 Using Hugging Face Space environment at: {env_url}")
else:
raise ValueError(f"Unknown env mode: {args.env_mode}")
raise ValueError(f"Unknown environment mode: {args.env_mode}")

client = OpenSpielEnv(base_url=env_url)
if args.env_mode != "docker-hub" and args.env_mode != "docker-image":
client = OpenSpielEnv(base_url=env_url)
dataset = Dataset.from_dict({"prompt": [BASE_PROMPT] * args.dataset_size})

training_args = GRPOConfig(
output_dir=f"{args.model.split('/')[-1]}-GRPO-Catch",
use_vllm=True,
vllm_mode=args.vllm_mode,
vllm_server_url=args.vllm_server_url if args.vllm_mode == "server" else None,
vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None,
logging_steps=1,
report_to="trackio",
num_train_epochs=1,
Expand Down
28 changes: 20 additions & 8 deletions examples/scripts/openenv/echo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

Usage:

# Start the docker container for the Echo environment (recommended). Alternatively, you can run it locally or directly from a HF Space.
# Start the environment manually only when using --env-mode docker-local; in all other modes, the script launches or connects to the environment itself.
```sh
docker run -d -p 8001:8001 registry.hf.space/openenv-echo-env:latest
```
Expand Down Expand Up @@ -71,9 +71,9 @@ def parse_args():
parser.add_argument("--env-port", type=int, default=8001, help="Port for the Echo environment.")
parser.add_argument(
"--env-mode",
choices=["local", "docker", "space"],
default="docker",
help="Where to run the Echo environment: 'local' to launch it, 'docker' if already running, or 'space' to use a remote Space URL.",
choices=["local", "docker-local", "docker-image", "docker-hub", "space"],
default="docker-image",
help="Where to run the Echo environment: 'local' to launch it, 'docker-local' if already running locally, 'docker-image' to run from a Docker image, 'docker-hub' to run from Docker Hub, or 'space' to use a remote Space URL.",
)
parser.add_argument(
"--model",
Expand All @@ -87,6 +87,9 @@ def parse_args():
default="trl-lib/ultrafeedback-prompt",
help="Dataset to use for training.",
)
parser.add_argument(
"--env-image", type=str, default="echo-env:latest", help="Docker image for the Echo environment."
)
parser.add_argument(
"--vllm-mode",
choices=["colocate", "server"],
Expand Down Expand Up @@ -146,25 +149,34 @@ def main():
if args.env_mode == "local":
env_url = f"http://{args.env_host}:{args.env_port}"
server_process = start_env_server(args.env_host, args.env_port)
elif args.env_mode == "docker":
elif args.env_mode == "docker-local":
env_url = f"http://{args.env_host}:{args.env_port}"
server_process = None
print(f"🌍 Using existing Echo Environment (Docker) at: {env_url}")
elif args.env_mode == "docker-image":
client = EchoEnv.from_docker_image(args.env_image)
server_process = None
print("🌍 Using Echo Environment (Docker) from local Image")
elif args.env_mode == "docker-hub":
client = EchoEnv.from_hub(args.env_image)
server_process = None
print("🌍 Using existing Echo Environment (Docker) from Hub Image")
elif args.env_mode == "space":
env_url = args.env_host
server_process = None
print(f"🚀 Using Hugging Face Space environment at: {env_url}")
print(f"🌍 Using Hugging Face Space environment at: {env_url}")
else:
raise ValueError(f"Unknown environment mode: {args.env_mode}")

client = EchoEnv(base_url=env_url)
if args.env_mode != "docker-hub" and args.env_mode != "docker-image":
client = EchoEnv(base_url=env_url)
dataset = load_dataset(args.dataset, split="train[:1000]")

training_args = GRPOConfig(
output_dir=f"{args.model.split('/')[-1]}-GRPO-Rollout",
use_vllm=True,
vllm_mode=args.vllm_mode,
vllm_server_url=args.vllm_server_url if args.vllm_mode == "server" else None,
vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None,
logging_steps=1,
report_to="trackio",
num_train_epochs=1,
Expand Down
93 changes: 58 additions & 35 deletions examples/scripts/openenv/wordle.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

Usage:

# Start the docker container for the Wordle environment (recommended). Alternatively, you can run it locally or directly from a HF Space.
# Start the environment manually only when using --env-mode docker-local; in all other modes, the script launches or connects to the environment itself.
```sh
docker run -d -p 8001:8001 registry.hf.space/burtenshaw-textarena:latest
# or TEXTARENA_ENV_ID=Wordle-v0 TEXTARENA_NUM_PLAYERS=1 python -m src.envs.textarena_env.server.app
Expand Down Expand Up @@ -85,11 +85,15 @@ def parse_args() -> argparse.Namespace:
default="Qwen/Qwen3-1.7B",
help="Model identifier passed to GRPOTrainer for fine-tuning.",
)
parser.add_argument("--env-host", type=str, default="0.0.0.0", help="Host for the environment server.")
parser.add_argument("--env-port", type=int, default=8001, help="Port for the environment server.")
parser.add_argument(
"--env-url",
default="https://burtenshaw-textarena.hf.space",
help="URL for the TextArena Wordle environment.",
"--env-mode",
choices=["docker-local", "docker-image", "docker-hub", "space"],
default="docker-image",
help="Where to run the environment: 'docker-local' if already running locally, 'docker-image' to run from a Docker image, 'docker-hub' to run from Docker Hub, or 'space' to use a remote Space URL.",
)
parser.add_argument("--env-image", type=str, default="textarena-env:latest", help="Docker image for the TextArena environment.")
parser.add_argument(
"--system-prompt-path",
default="wordle_prompt.txt",
Expand Down Expand Up @@ -411,46 +415,65 @@ def reward_repetition(completions: list[str], **kwargs) -> list[float]:


def main() -> None:
cli_args = parse_args()
args = parse_args()

tokenizer = AutoTokenizer.from_pretrained(cli_args.tokenizer_id)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id)
tokenizer.pad_token = tokenizer.eos_token

env = TextArenaEnv(base_url=cli_args.env_url)
# Select environment mode
if args.env_mode == "docker-local":
env_url = f"http://{args.env_host}:{args.env_port}"
print(f"🌍 Using existing TextArena Environment (Docker) at: {env_url}")
elif args.env_mode == "docker-image":
client = TextArenaEnv.from_docker_image(args.env_image)
print(f"🌍 Using TextArena Environment (Docker) from local Image")
elif args.env_mode == "docker-hub":
client = TextArenaEnv.from_hub(args.env_image)
print(f"🌍 Using existing TextArena Environment (Docker) from Hub Image")
elif args.env_mode == "space":
env_url = args.env_host
print(f"🌍 Using Hugging Face Space environment at: {env_url}")
else:
raise ValueError(f"Unknown environment mode: {args.env_mode}")

system_prompt = resolve_system_prompt(cli_args.system_prompt_path)
if args.env_mode != "docker-hub" and args.env_mode != "docker-image":
client = TextArenaEnv(base_url=env_url)

dataset = Dataset.from_dict({"prompt": [cli_args.dataset_prompt] * cli_args.dataset_size})
#env = TextArenaEnv(base_url=args.env_url)

system_prompt = resolve_system_prompt(args.system_prompt_path)

dataset = Dataset.from_dict({"prompt": [args.dataset_prompt] * args.dataset_size})

timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
default_output_dir = Path("outputs") / f"wordle-grpo-{sanitize_name(cli_args.model_id)}-{timestamp}"
output_dir = Path(cli_args.output_dir or default_output_dir)
default_output_dir = Path("outputs") / f"wordle-grpo-{sanitize_name(args.model_id)}-{timestamp}"
output_dir = Path(args.output_dir or default_output_dir)

grpo_config = GRPOConfig(
use_vllm=True,
vllm_mode=cli_args.vllm_mode,
vllm_server_url=cli_args.vllm_server_url if cli_args.vllm_mode == "server" else None,
vllm_mode=args.vllm_mode,
vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None,
output_dir=str(output_dir),
num_train_epochs=cli_args.num_epochs,
learning_rate=cli_args.learning_rate,
weight_decay=cli_args.weight_decay,
gradient_accumulation_steps=cli_args.gradient_accumulation_steps,
per_device_train_batch_size=cli_args.per_device_batch_size,
warmup_steps=cli_args.warmup_steps,
num_generations=cli_args.num_generations,
max_completion_length=cli_args.max_new_tokens,
logging_steps=cli_args.logging_steps,
num_train_epochs=args.num_epochs,
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
gradient_accumulation_steps=args.gradient_accumulation_steps,
per_device_train_batch_size=args.per_device_batch_size,
warmup_steps=args.warmup_steps,
num_generations=args.num_generations,
max_completion_length=args.max_new_tokens,
logging_steps=args.logging_steps,
save_strategy="steps",
save_steps=cli_args.save_interval,
save_total_limit=cli_args.save_total_limit,
temperature=cli_args.temperature,
top_k=cli_args.top_k,
top_p=cli_args.top_p,
save_steps=args.save_interval,
save_total_limit=args.save_total_limit,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
)

grpo_config.run_name = cli_args.run_name or f"run-{timestamp}"
grpo_config.project = cli_args.project or f"group-{sanitize_name(cli_args.model_id)}"
grpo_config.trackio_space_id = cli_args.trackio_space_id
grpo_config.run_name = args.run_name or f"run-{timestamp}"
grpo_config.project = args.project or f"group-{sanitize_name(args.model_id)}"
grpo_config.trackio_space_id = args.trackio_space_id

def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
episode_prompt_ids: list[list[int]] = []
Expand All @@ -464,11 +487,11 @@ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
for prompt_text in prompts:
episode = rollout_once(
trainer=trainer,
env=env,
env=client,
tokenizer=tokenizer,
dataset_prompt=prompt_text,
system_prompt=system_prompt,
max_turns=cli_args.max_turns,
max_turns=args.max_turns,
)
episode_prompt_ids.append(episode["prompt_ids"])
episode_completion_ids.append(episode["completion_ids"])
Expand All @@ -489,7 +512,7 @@ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
}

trainer = GRPOTrainer(
model=cli_args.model_id,
model=args.model_id,
processing_class=tokenizer,
reward_funcs=[
reward_correct,
Expand All @@ -503,12 +526,12 @@ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
)

print("Starting GRPO training with Wordle environment...")
print(f"Using {cli_args.num_generations} rollouts per dataset prompt")
print(f"Using {args.num_generations} rollouts per dataset prompt")

try:
trainer.train()
finally:
env.close()
client.close()


if __name__ == "__main__":
Expand Down
Loading