Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/openenv.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ args = GRPOConfig(
args = GRPOConfig(
use_vllm=True,
vllm_mode="server",
vllm_server_url="http://localhost:8000",
vllm_server_base_url="http://localhost:8000",
# ... other args
)

Expand Down
30 changes: 21 additions & 9 deletions examples/scripts/openenv/catch.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ def parse_args():
parser.add_argument("--env-port", type=int, default=8001, help="Port for the environment server.")
parser.add_argument(
"--env-mode",
choices=["local", "docker", "space"],
default="docker",
help="Where to run the environment: 'local', 'docker', or 'space'.",
choices=["local", "docker-local", "docker-image", "docker-hub", "space"],
default="docker-image",
help="Where to run the environment: 'local' to launch it, 'docker-local' if already running locally, 'docker-image' to run from a Docker image, 'docker-hub' to run from Docker Hub, or 'space' to use a remote Space URL.",
)
# --- Generation and model config ---
parser.add_argument(
Expand All @@ -90,6 +90,9 @@ def parse_args():
default=1000,
help="Number of prompts to use for training dataset.",
)
parser.add_argument(
"--env-image", type=str, default="openspiel-env:latest", help="Docker image for the OpenSpiel environment."
)
parser.add_argument(
"--vllm-mode",
choices=["colocate", "server"],
Expand Down Expand Up @@ -183,25 +186,34 @@ def main():
if args.env_mode == "local":
env_url = f"http://{args.env_host}:{args.env_port}"
server_process = start_env_server(args.env_host, args.env_port)
elif args.env_mode == "docker":
elif args.env_mode == "docker-local":
env_url = f"http://{args.env_host}:{args.env_port}"
server_process = None
print(f"🌍 Using existing Docker environment at {env_url}")
print(f"🌍 Using existing OpenSpiel Environment (Docker) at: {env_url}")
elif args.env_mode == "docker-image":
client = OpenSpielEnv.from_docker_image(args.env_image)
server_process = None
print("🌍 Using OpenSpiel Environment (Docker) from local Image")
elif args.env_mode == "docker-hub":
client = OpenSpielEnv.from_hub(args.env_image)
server_process = None
print("🌍 Using existing OpenSpiel Environment (Docker) from Hub Image")
elif args.env_mode == "space":
env_url = args.env_host
server_process = None
print(f"🚀 Using Hugging Face Space environment at {env_url}")
print(f"🌍 Using Hugging Face Space environment at: {env_url}")
else:
raise ValueError(f"Unknown env mode: {args.env_mode}")
raise ValueError(f"Unknown environment mode: {args.env_mode}")

client = OpenSpielEnv(base_url=env_url)
if args.env_mode != "docker-hub" and args.env_mode != "docker-image":
client = OpenSpielEnv(base_url=env_url)
dataset = Dataset.from_dict({"prompt": [BASE_PROMPT] * args.dataset_size})

training_args = GRPOConfig(
output_dir=f"{args.model.split('/')[-1]}-GRPO-Catch",
use_vllm=True,
vllm_mode=args.vllm_mode,
vllm_server_url=args.vllm_server_url if args.vllm_mode == "server" else None,
vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None,
logging_steps=1,
report_to="trackio",
num_train_epochs=1,
Expand Down
26 changes: 19 additions & 7 deletions examples/scripts/openenv/echo.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ def parse_args():
parser.add_argument("--env-port", type=int, default=8001, help="Port for the Echo environment.")
parser.add_argument(
"--env-mode",
choices=["local", "docker", "space"],
default="docker",
help="Where to run the Echo environment: 'local' to launch it, 'docker' if already running, or 'space' to use a remote Space URL.",
choices=["local", "docker-local", "docker-image", "docker-hub", "space"],
default="docker-image",
help="Where to run the Echo environment: 'local' to launch it, 'docker-local' if already running locally, 'docker-image' to run from a Docker image, 'docker-hub' to run from Docker Hub, or 'space' to use a remote Space URL.",
)
parser.add_argument(
"--model",
Expand All @@ -87,6 +87,9 @@ def parse_args():
default="trl-lib/ultrafeedback-prompt",
help="Dataset to use for training.",
)
parser.add_argument(
"--env-image", type=str, default="echo-env:latest", help="Docker image for the Echo environment."
)
parser.add_argument(
"--vllm-mode",
choices=["colocate", "server"],
Expand Down Expand Up @@ -146,25 +149,34 @@ def main():
if args.env_mode == "local":
env_url = f"http://{args.env_host}:{args.env_port}"
server_process = start_env_server(args.env_host, args.env_port)
elif args.env_mode == "docker":
elif args.env_mode == "docker-local":
env_url = f"http://{args.env_host}:{args.env_port}"
server_process = None
print(f"🌍 Using existing Echo Environment (Docker) at: {env_url}")
elif args.env_mode == "docker-image":
client = EchoEnv.from_docker_image(args.env_image)
server_process = None
print("🌍 Using Echo Environment (Docker) from local Image")
elif args.env_mode == "docker-hub":
client = EchoEnv.from_hub(args.env_image)
server_process = None
print("🌍 Using existing Echo Environment (Docker) from Hub Image")
elif args.env_mode == "space":
env_url = args.env_host
server_process = None
print(f"🚀 Using Hugging Face Space environment at: {env_url}")
print(f"🌍 Using Hugging Face Space environment at: {env_url}")
else:
raise ValueError(f"Unknown environment mode: {args.env_mode}")

client = EchoEnv(base_url=env_url)
if args.env_mode != "docker-hub" and args.env_mode != "docker-image":
client = EchoEnv(base_url=env_url)
dataset = load_dataset(args.dataset, split="train[:1000]")

training_args = GRPOConfig(
output_dir=f"{args.model.split('/')[-1]}-GRPO-Rollout",
use_vllm=True,
vllm_mode=args.vllm_mode,
vllm_server_url=args.vllm_server_url if args.vllm_mode == "server" else None,
vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None,
logging_steps=1,
report_to="trackio",
num_train_epochs=1,
Expand Down
91 changes: 57 additions & 34 deletions examples/scripts/openenv/wordle.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,15 @@ def parse_args() -> argparse.Namespace:
default="Qwen/Qwen3-1.7B",
help="Model identifier passed to GRPOTrainer for fine-tuning.",
)
parser.add_argument("--env-host", type=str, default="0.0.0.0", help="Host for the environment server.")
parser.add_argument("--env-port", type=int, default=8001, help="Port for the environment server.")
parser.add_argument(
"--env-url",
default="https://burtenshaw-textarena.hf.space",
help="URL for the TextArena Wordle environment.",
"--env-mode",
choices=["docker-local", "docker-image", "docker-hub", "space"],
default="docker-image",
help="Where to run the environment: 'docker-local' if already running locally, 'docker-image' to run from a Docker image, 'docker-hub' to run from Docker Hub, or 'space' to use a remote Space URL.",
)
parser.add_argument("--env-image", type=str, default="textarena-env:latest", help="Docker image for the TextArena environment.")
parser.add_argument(
"--system-prompt-path",
default="wordle_prompt.txt",
Expand Down Expand Up @@ -411,46 +415,65 @@ def reward_repetition(completions: list[str], **kwargs) -> list[float]:


def main() -> None:
cli_args = parse_args()
args = parse_args()

tokenizer = AutoTokenizer.from_pretrained(cli_args.tokenizer_id)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id)
tokenizer.pad_token = tokenizer.eos_token

env = TextArenaEnv(base_url=cli_args.env_url)
# Select environment mode
if args.env_mode == "docker-local":
env_url = f"http://{args.env_host}:{args.env_port}"
print(f"🌍 Using existing TextArena Environment (Docker) at: {env_url}")
elif args.env_mode == "docker-image":
client = TextArenaEnv.from_docker_image(args.env_image)
print(f"🌍 Using TextArena Environment (Docker) from local Image")
elif args.env_mode == "docker-hub":
client = TextArenaEnv.from_hub(args.env_image)
print(f"🌍 Using existing TextArena Environment (Docker) from Hub Image")
elif args.env_mode == "space":
env_url = args.env_host
print(f"🌍 Using Hugging Face Space environment at: {env_url}")
else:
raise ValueError(f"Unknown environment mode: {args.env_mode}")

system_prompt = resolve_system_prompt(cli_args.system_prompt_path)
if args.env_mode != "docker-hub" and args.env_mode != "docker-image":
client = TextArenaEnv(base_url=env_url)

dataset = Dataset.from_dict({"prompt": [cli_args.dataset_prompt] * cli_args.dataset_size})
#env = TextArenaEnv(base_url=args.env_url)

system_prompt = resolve_system_prompt(args.system_prompt_path)

dataset = Dataset.from_dict({"prompt": [args.dataset_prompt] * args.dataset_size})

timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
default_output_dir = Path("outputs") / f"wordle-grpo-{sanitize_name(cli_args.model_id)}-{timestamp}"
output_dir = Path(cli_args.output_dir or default_output_dir)
default_output_dir = Path("outputs") / f"wordle-grpo-{sanitize_name(args.model_id)}-{timestamp}"
output_dir = Path(args.output_dir or default_output_dir)

grpo_config = GRPOConfig(
use_vllm=True,
vllm_mode=cli_args.vllm_mode,
vllm_server_url=cli_args.vllm_server_url if cli_args.vllm_mode == "server" else None,
vllm_mode=args.vllm_mode,
vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None,
output_dir=str(output_dir),
num_train_epochs=cli_args.num_epochs,
learning_rate=cli_args.learning_rate,
weight_decay=cli_args.weight_decay,
gradient_accumulation_steps=cli_args.gradient_accumulation_steps,
per_device_train_batch_size=cli_args.per_device_batch_size,
warmup_steps=cli_args.warmup_steps,
num_generations=cli_args.num_generations,
max_completion_length=cli_args.max_new_tokens,
logging_steps=cli_args.logging_steps,
num_train_epochs=args.num_epochs,
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
gradient_accumulation_steps=args.gradient_accumulation_steps,
per_device_train_batch_size=args.per_device_batch_size,
warmup_steps=args.warmup_steps,
num_generations=args.num_generations,
max_completion_length=args.max_new_tokens,
logging_steps=args.logging_steps,
save_strategy="steps",
save_steps=cli_args.save_interval,
save_total_limit=cli_args.save_total_limit,
temperature=cli_args.temperature,
top_k=cli_args.top_k,
top_p=cli_args.top_p,
save_steps=args.save_interval,
save_total_limit=args.save_total_limit,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
)

grpo_config.run_name = cli_args.run_name or f"run-{timestamp}"
grpo_config.project = cli_args.project or f"group-{sanitize_name(cli_args.model_id)}"
grpo_config.trackio_space_id = cli_args.trackio_space_id
grpo_config.run_name = args.run_name or f"run-{timestamp}"
grpo_config.project = args.project or f"group-{sanitize_name(args.model_id)}"
grpo_config.trackio_space_id = args.trackio_space_id

def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
episode_prompt_ids: list[list[int]] = []
Expand All @@ -464,11 +487,11 @@ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
for prompt_text in prompts:
episode = rollout_once(
trainer=trainer,
env=env,
env=client,
tokenizer=tokenizer,
dataset_prompt=prompt_text,
system_prompt=system_prompt,
max_turns=cli_args.max_turns,
max_turns=args.max_turns,
)
episode_prompt_ids.append(episode["prompt_ids"])
episode_completion_ids.append(episode["completion_ids"])
Expand All @@ -489,7 +512,7 @@ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
}

trainer = GRPOTrainer(
model=cli_args.model_id,
model=args.model_id,
processing_class=tokenizer,
reward_funcs=[
reward_correct,
Expand All @@ -503,12 +526,12 @@ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
)

print("Starting GRPO training with Wordle environment...")
print(f"Using {cli_args.num_generations} rollouts per dataset prompt")
print(f"Using {args.num_generations} rollouts per dataset prompt")

try:
trainer.train()
finally:
env.close()
client.close()


if __name__ == "__main__":
Expand Down
Loading