Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 18 additions & 39 deletions examples/scripts/openenv/wordle.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,15 @@
Setup:

```sh
# uv pip install git+https://github.com/meta-pytorch/OpenEnv.git
# Hotfix: https://github.com/huggingface/trl/pull/4740
uv pip install git+https://github.com/meta-pytorch/OpenEnv.git@bf5e968286e0d49cdc03fd904d48faff4b15a437 openenv_core==0.1.1
uv pip install git+https://huggingface.co/spaces/burtenshaw/wordle
```

Usage:

# Start the environment only if using --env-mode docker-local; In other modes, the env is automatically managed by the script.
```sh
docker run -d -p 8001:8001 registry.hf.space/burtenshaw-textarena:latest
# or TEXTARENA_ENV_ID=Wordle-v0 TEXTARENA_NUM_PLAYERS=1 python -m src.envs.textarena_env.server.app
```

# Option 1: Colocated vLLM (1 GPU required)
# Option 1: HF Spaces + Colocated vLLM (1 GPU required)
```sh
python examples/scripts/openenv/wordle.py --vllm-mode colocate
```

# Option 2: Separate vLLM server (2 GPUs required)
# Option 2: HF Spaces + Separate vLLM server (2 GPUs required)

# Spin up vLLM server (Terminal 1)
```sh
Expand All @@ -59,6 +49,19 @@
```sh
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/wordle.py --vllm-mode server --vllm-server-url http://localhost:8000
```

# Option 3: Local + Colocated vLLM (1 GPU required)

Usage:

# Start the environment only if using --env-mode docker-local; In other modes, the env is automatically managed by the script.
```sh
docker run -d -p 8001:8001 registry.hf.space/burtenshaw-wordle:latest
```

```sh
python examples/scripts/openenv/wordle.py --vllm-mode colocate
```
"""

from __future__ import annotations
Expand Down Expand Up @@ -99,16 +102,8 @@ def parse_args() -> argparse.Namespace:
default="Qwen/Qwen3-1.7B",
help="Model identifier passed to GRPOTrainer for fine-tuning.",
)
parser.add_argument("--env-host", type=str, default="0.0.0.0", help="Host for the environment server.")
parser.add_argument("--env-port", type=int, default=8001, help="Port for the environment server.")
parser.add_argument(
"--env-mode",
choices=["docker-local", "docker-image", "docker-hub", "space"],
default="docker-image",
help="Where to run the environment: 'docker-local' if already running locally, 'docker-image' to run from a Docker image, 'docker-hub' to run from Docker Hub, or 'space' to use a remote Space URL.",
)
parser.add_argument(
"--env-image", type=str, default="textarena-env:latest", help="Docker image for the TextArena environment."
"--env-url", type=str, default="https://burtenshaw-wordle.hf.space", help="URL for the environment server."
)
parser.add_argument(
"--system-prompt-path",
Expand Down Expand Up @@ -436,23 +431,7 @@ def main() -> None:
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id)
tokenizer.pad_token = tokenizer.eos_token

# Select environment mode
if args.env_mode == "docker-local":
env_url = f"http://{args.env_host}:{args.env_port}"
client = TextArenaEnv(base_url=env_url)
print(f"🌍 Using existing TextArena Environment (Docker) at: {env_url}")
elif args.env_mode == "docker-image":
client = TextArenaEnv.from_docker_image(args.env_image)
print("🌍 Using TextArena Environment (Docker) from local Image")
elif args.env_mode == "docker-hub":
client = TextArenaEnv.from_hub(args.env_image)
print("🌍 Using existing TextArena Environment (Docker) from Hub Image")
elif args.env_mode == "space":
env_url = args.env_host
client = TextArenaEnv(base_url=env_url)
print(f"🌍 Using Hugging Face Space environment at: {env_url}")
else:
raise ValueError(f"Unknown environment mode: {args.env_mode}")
client = TextArenaEnv(base_url=args.env_url)

system_prompt = resolve_system_prompt(args.system_prompt_path)

Expand Down
8 changes: 4 additions & 4 deletions trl/experimental/openenv/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

if is_vllm_available():
from vllm import SamplingParams
from vllm.sampling_params import StructuredOutputsParams
from vllm.sampling_params import GuidedDecodingParams
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will have to revert this, see #4117

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah damn!



def _build_colocate_sampling_params(
Expand All @@ -33,17 +33,17 @@ def _build_colocate_sampling_params(
logprobs: bool = True,
) -> SamplingParams:
if trainer.structured_outputs_regex:
structured_outputs = StructuredOutputsParams(regex=trainer.structured_outputs_regex)
guided_decoding = GuidedDecodingParams(regex=trainer.structured_outputs_regex)
else:
structured_outputs = None
guided_decoding = None

generation_kwargs: dict[str, Any] = {
"n": 1,
"temperature": trainer.temperature,
"top_k": trainer.top_k,
"min_p": 0.0 if trainer.min_p is None else trainer.min_p,
"max_tokens": trainer.max_completion_length,
"structured_outputs": structured_outputs,
"guided_decoding": guided_decoding,
}
if trainer.repetition_penalty is not None:
generation_kwargs["repetition_penalty"] = trainer.repetition_penalty
Expand Down
Loading