diff --git a/examples/notebooks/openenv_wordle_grpo.ipynb b/examples/notebooks/openenv_wordle_grpo.ipynb index 67f0c4e9376..6711024fb68 100644 --- a/examples/notebooks/openenv_wordle_grpo.ipynb +++ b/examples/notebooks/openenv_wordle_grpo.ipynb @@ -44,7 +44,7 @@ "## Install dependencies\n", "\n", "We'll start by installing **TRL**, which automatically includes the main dependencies like **Transformers**. \n", - "We'll also install the **OpenEnv** framework via the remote deployent env at [burtenshaw/wordle](https://huggingface.co/spaces/burtenshaw/wordle) (for the environment), **trackio** (for logging and monitoring training runs), and **vLLM** (for efficient generation)." + "We'll also install the **OpenEnv** framework via the remote deployent env at [sergiopaniego/wordle](https://huggingface.co/spaces/sergiopaniego/wordle) (for the environment), **trackio** (for logging and monitoring training runs), and **vLLM** (for efficient generation)." ] }, { @@ -56,7 +56,7 @@ }, "outputs": [], "source": [ - "!pip install -Uq trl[vllm] git+https://huggingface.co/spaces/burtenshaw/wordle trackio bitsandbytes" + "!pip install -Uq trl[vllm] git+https://huggingface.co/spaces/sergiopaniego/wordle trackio" ] }, { @@ -97,7 +97,7 @@ "Let's begin by setting up the environment that will be used during training. \n", "For this task, we'll rely on the **TextArena** environment from **OpenEnv**, which exposes a familiar Gymnasium-style API (`reset()`, `step()`, etc.) to simplify interaction.\n", "\n", - "In this example, we'll connect to the hosted environment at [burtenshaw/textarena](https://huggingface.co/spaces/burtenshaw/textarena). \n", + "In this example, we'll connect to the hosted environment at [sergiopaniego/wordle](https://huggingface.co/spaces/sergiopaniego/wordle). \n", "For production use or custom configurations, we **strongly recommend** running the environment locally via Docker. The hosted versions on the Hub currently have limited concurrency support, so duplicating the Space to your own account is the preferred approach in those cases.\n", "\n", "For more information, refer to the [TRL-OpenEnv documentation](https://huggingface.co/docs/trl/main/en/openenv).\n" @@ -108,16 +108,26 @@ "execution_count": null, "id": "rZimqp1UTIV_", "metadata": { - "id": "rZimqp1UTIV_" + "id": "rZimqp1UTIV_", + "outputId": "e53c277c-6050-4380-84e1-983857f0b325" }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. 
Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n" + ] + } + ], "source": [ "from textarena_env import TextArenaEnv\n", "\n", - "textarena_url = \"https://burtenshaw-wordle.hf.space\" # Duplicate the Space and update this!\n", - "env = TextArenaEnv(base_url=textarena_url)\n", - "# textarena_url = \"burtenshaw/wordle\"\n", - "# env = TextArenaEnv.from_hub(repo_id=textarena_url)" + "wordle_url = \"https://sergiopaniego-wordle.hf.space\" # Duplicate the Space and update this!\n", + "env = TextArenaEnv(base_url=wordle_url)\n", + "# wordle_url = \"sergiopaniego/wordle\"\n", + "# env = TextArenaEnv.from_hub(repo_id=wordle_url)" ] }, { @@ -142,9 +152,24 @@ "execution_count": null, "id": "lR7usp2Dd-JK", "metadata": { - "id": "lR7usp2Dd-JK" + "id": "lR7usp2Dd-JK", + "outputId": "b8a60feb-e0c0-47c9-839e-2743a502341f" }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n", + "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:104: UserWarning: \n", + "Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.\n", + "You are not authenticated with the Hugging Face Hub in this notebook.\n", + "If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).\n", + " warnings.warn(\n" + ] + } + ], "source": [ "from transformers import AutoTokenizer\n", "\n", @@ -169,6 +194,7 @@ "- The function is called automatically by the **GRPOTrainer** during each training step. \n", "- It uses the trainer's built-in `generate_rollout_completions()` method for efficient generation with vLLM in colocate mode.\n", "- Each rollout represents a full interaction loop. The model guesses, receives feedback from Wordle, and updates based on reward signals.\n", + "- The **`env_mask`** tracks which tokens are model-generated vs environment-generated, ensuring only model tokens contribute to the training loss.\n", "\n", "The rewards track different aspects of the agent's performance. 
Helper functions (like `rollout_once`) handle one episode of interaction, keeping the main `rollout_func` clean and modular.\n", "\n", @@ -210,6 +236,7 @@ "[guess]\n", "```\n", "\n", + "\n", "## STRATEGIC APPROACH\n", "\n", "Do not repeat the same guess twice.\n", @@ -222,7 +249,7 @@ "### Mid-Game Strategy\n", "- Use confirmed GREEN letters in their correct positions\n", "- Place YELLOW letters in different positions than where they appeared\n", - "- Eliminate GRAY letters from consideration\n", + "- Eliminate GRAY letters entirely from consideration\n", "- If multiple letters are unknown, prioritize common letter combinations (TH, CH, ST, ER, etc.)\n", "- Consider letter frequency: E is most common, followed by A, R, I, O, T, N, S\n", "\n", @@ -235,12 +262,12 @@ "- Use \"sacrificial\" guesses to test multiple new letters if you have attempts to spare\n", "- Avoid repeating letter patterns unless you're certain (e.g., SPEED has two E's)\n", "- Think about word endings: -ER, -LY, -ED, -ING are common but may not fit the 5-letter constraint\n", - "- Consider less common letters (Q, X, Z, J) only when you've eliminated the most common options\n", + "- Consider less common letters (Q, X, Z, J) only when you've eliminated most common options\n", "\n", "### Common Pitfalls to Avoid\n", - "- Don't reuse letters marked GRAY (eliminated letters)\n", - "- Don't place YELLOW letters in the same position they appeared\n", - "- Don't ignore confirmed GREEN letters in future guesses\n", + "- Don't reuse X letters\n", + "- Don't place Y letters in the same position they appeared\n", + "- Don't ignore confirmed G letters\n", "- Don't guess words that contradict known information\n", "\n", "## EXAMPLES\n", @@ -260,7 +287,7 @@ "Previous guesses: CRANE (C=gray, R=yellow, A=green, N=gray, E=yellow), SPARE (S=gray, P=gray, A=green, R=green, E=green)\n", "Feedback summary: _ARE_ with R in position 4, A in position 2, E in position 5\n", "\n", - "\"I have _AR E_ confirmed. Positions 1 and 3 are unknown. Common letters to try: T, L, D, B, F, G. Testing with TARED.\"\n", + "\"I have _AR E_ confirmed. Position 1 and 3 are unknown. Common letters to try: T, L, D, B, F, G. 
Testing with TARED.\"\n", "[tared]\n", "\n", "### Example 4: Final Deduction\n", @@ -316,27 +343,30 @@ }, "outputs": [], "source": [ - "def rollout_func(prompts, trainer=None):\n", + "max_new_tokens = 8\n", + "max_turns = 6\n", + "\n", + "def rollout_func(prompts, trainer):\n", " \"\"\"\n", " Rollout function for GRPO training with environment interaction.\n", "\n", " This function is called by GRPOTrainer to generate completions and compute rewards.\n", - " In colocate mode, it uses trainer.generate_rollout_completions() for inference.\n", + " It uses trainer.generate_rollout_completions() for inference.\n", "\n", " Args:\n", " prompts: List of prompts to generate from\n", " trainer: GRPOTrainer instance containing context and configuration\n", "\n", " Returns:\n", - " Dictionary with prompt_ids, completion_ids, logprobs, and reward signals\n", + " Dictionary with prompt_ids, completion_ids, logprobs, env_mask, and reward signals\n", " \"\"\"\n", " episode_prompt_ids = []\n", " episode_completion_ids = []\n", " episode_logprobs = []\n", + " episode_env_masks = []\n", " correctness_rewards = []\n", - " green_rewards = []\n", - " yellow_rewards = []\n", - " repetition_rewards = []\n", + " position_rewards = []\n", + " format_rewards = []\n", "\n", " for prompt_text in prompts:\n", " episode = rollout_once(\n", @@ -345,24 +375,25 @@ " tokenizer=tokenizer,\n", " dataset_prompt=prompt_text,\n", " system_prompt=system_prompt,\n", - " max_turns=6,\n", + " max_turns=max_turns,\n", + " max_new_tokens=max_new_tokens,\n", " )\n", " episode_prompt_ids.append(episode[\"prompt_ids\"])\n", " episode_completion_ids.append(episode[\"completion_ids\"])\n", " episode_logprobs.append(episode[\"logprobs\"])\n", + " episode_env_masks.append(episode[\"env_mask\"])\n", " correctness_rewards.append(episode[\"correct_reward\"])\n", - " green_rewards.append(episode[\"green_reward\"])\n", - " yellow_rewards.append(episode[\"yellow_reward\"])\n", - " repetition_rewards.append(episode[\"repetition_reward\"])\n", + " position_rewards.append(episode[\"position_reward\"])\n", + " format_rewards.append(compute_format_reward(episode[\"model_outputs\"]))\n", "\n", " return {\n", " \"prompt_ids\": episode_prompt_ids,\n", " \"completion_ids\": episode_completion_ids,\n", " \"logprobs\": episode_logprobs,\n", + " \"env_mask\": episode_env_masks,\n", " \"correct_reward\": correctness_rewards,\n", - " \"green_reward\": green_rewards,\n", - " \"yellow_reward\": yellow_rewards,\n", - " \"repetition_reward\": repetition_rewards,\n", + " \"position_reward\": position_rewards,\n", + " \"format_reward\": format_rewards,\n", " }" ] }, @@ -385,7 +416,16 @@ "3. **Generation:** Use `trl.experimental.openenv.generate_rollout_completions()` to produce the model's guess efficiently. \n", "4. **Feedback extraction:** Parse the environment's response using helpers like `extract_guess()` and `extract_wordle_feedback()`. \n", "5. **Reward calculation:** Compute rewards based on correctness, green/yellow feedback, and repetition penalty.\n", - "6. **Return structured rollout data:** Includes prompt/completion IDs, logprobs, and all computed reward components.\n", + "6. 
**Return structured rollout data:** Includes prompt/completion IDs, logprobs, `env_mask`, and all computed reward components.\n", + "\n", + "**Important: The `env_mask` mechanism**\n", + "\n", + "In multi-turn environments like Wordle, the completion includes both:\n", + "- **Model-generated tokens** (the guesses): These should contribute to the loss during training.\n", + "- **Environment feedback tokens** (game responses): These should NOT contribute to the loss.\n", + "\n", + "The `env_mask` is a list of 1s and 0s that marks which tokens are model-generated (`1`) vs environment-generated (`0`). \n", + "The GRPOTrainer uses this mask to exclude environment tokens from the loss calculation, ensuring the model only learns from its own outputs.\n", "\n", "This modular design ensures that each episode can be processed independently while still providing rich feedback for the **GRPO training loop**." ] @@ -395,58 +435,60 @@ "execution_count": null, "id": "5c585602-5352-4e57-8d35-e5b95e05f6c5", "metadata": { - "id": "5c585602-5352-4e57-8d35-e5b95e05f6c5", - "outputId": "8a88e037-043c-4fb5-ffb2-1a5c6dd87924" + "id": "5c585602-5352-4e57-8d35-e5b95e05f6c5" }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/tmp/ipython-input-1881463685.py:4: UserWarning: You are importing from 'trl.experimental'. APIs here are unstable and may change or be removed without notice. Silence this warning by setting environment variable TRL_EXPERIMENTAL_SILENCE=1.\n", - " from trl.experimental.openenv import generate_rollout_completions\n" - ] - } - ], + "outputs": [], "source": [ - "from collections import defaultdict\n", + "import re\n", "from textarena_env import TextArenaAction\n", "from textarena_env.rewards import extract_feedback_counts, extract_guess, extract_wordle_feedback\n", "from trl.experimental.openenv import generate_rollout_completions\n", "\n", - "\n", - "def rollout_once(trainer, env, tokenizer, dataset_prompt, system_prompt, max_turns):\n", - " \"\"\"\n", - " Execute one full Wordle episode with the model.\n", - "\n", - " This function uses generate_rollout_completions() instead of manual vLLM handling,\n", - " making the code cleaner and more maintainable.\n", - " \"\"\"\n", + "def rollout_once(trainer, env, tokenizer, dataset_prompt, system_prompt, max_turns, max_new_tokens):\n", " result = env.reset()\n", " observation = result.observation\n", "\n", " prompt_ids = []\n", " completion_ids = []\n", " logprobs = []\n", + " env_mask = [] # 1 for model-generated tokens, 0 for environment tokens\n", + " model_outputs = []\n", " raw_rewards = []\n", - " green_scores = []\n", - " yellow_scores = []\n", - " repetition_scores = []\n", + " position_scores = []\n", " correct_scores = []\n", - " guess_counts = defaultdict(int)\n", + " prev_env_output_len = 0 # Track length to only add NEW portion each turn\n", + "\n", + " accumulated_messages: list[dict[str, str]] = [{\"role\": \"system\", \"content\": system_prompt}]\n", + " # Build initial prompt (only once, at the start)\n", + " # The initial env messages are included in the prompt, not completion\n", + " base_prompt = observation.prompt or dataset_prompt\n", + " initial_user_prompt = make_user_prompt(base_prompt, observation.messages)\n", + " # Track initial env output length so we don't add it again\n", + " initial_env_output = format_history(observation.messages) if observation.messages else \"\"\n", + " prev_env_output_len = len(initial_env_output)\n", + " initial_messages = accumulated_messages + [{\"role\": \"user\", 
\"content\": initial_user_prompt}]\n", + " initial_prompt_text = tokenizer.apply_chat_template(\n", + " initial_messages,\n", + " add_generation_prompt=True,\n", + " tokenize=False,\n", + " enable_thinking=False,\n", + " )\n", + " # Tokenize initial prompt once - this is the base prompt for the entire episode.\n", + " # GRPO expects one prompt-completion pair per episode, where:\n", + " # - prompt_ids = the initial/base prompt (what the model sees at episode start)\n", + " # - completion_ids = all model responses + env feedback from all turns concatenated\n", + " # Note: The actual prompts used for generation in each turn are longer (include conversation history),\n", + " # but we only count the initial prompt tokens here.\n", + " initial_prompt_ids = tokenizer.encode(initial_prompt_text, add_special_tokens=False)\n", + " prompt_ids.extend(initial_prompt_ids)\n", "\n", " for _turn in range(max_turns):\n", - " # when the game is over the environment will return a done=True\n", " if result.done:\n", " break\n", "\n", - " # set up the prompt for the model\n", " base_prompt = observation.prompt or dataset_prompt\n", " user_prompt = make_user_prompt(base_prompt, observation.messages)\n", - " messages = [\n", - " {\"role\": \"system\", \"content\": system_prompt},\n", - " {\"role\": \"user\", \"content\": user_prompt},\n", - " ]\n", + " messages = accumulated_messages + [{\"role\": \"user\", \"content\": user_prompt}]\n", " prompt_text = tokenizer.apply_chat_template(\n", " messages,\n", " add_generation_prompt=True,\n", @@ -454,55 +496,80 @@ " enable_thinking=False,\n", " )\n", "\n", - " # Generate using trainer's built-in method (much cleaner!)\n", - " rollout_outputs = generate_rollout_completions(trainer, [prompt_text])[0]\n", - " prompt_ids.extend(rollout_outputs[\"prompt_ids\"])\n", + " rollout_outputs = generate_rollout_completions(\n", + " trainer, [prompt_text], generation_overrides={\"max_tokens\": max_new_tokens}\n", + " )[0]\n", + " # Add model-generated completion tokens and logprobs with newlines for readability\n", + " newline_tokens = tokenizer.encode(\"\\n\", add_special_tokens=False)\n", + " completion_ids.extend(newline_tokens) # newline before guess\n", + " logprobs.extend([0.0] * len(newline_tokens))\n", + " env_mask.extend([1] * len(newline_tokens)) # newlines are part of model output format\n", + "\n", " completion_ids.extend(rollout_outputs[\"completion_ids\"])\n", " logprobs.extend(rollout_outputs[\"logprobs\"])\n", + " env_mask.extend([1] * len(rollout_outputs[\"completion_ids\"])) # model-generated tokens\n", + "\n", + " completion_ids.extend(newline_tokens) # newline after guess\n", + " logprobs.extend([0.0] * len(newline_tokens))\n", + " env_mask.extend([1] * len(newline_tokens)) # newlines are part of model output format\n", " completion_text = rollout_outputs.get(\"text\") or tokenizer.decode(\n", " rollout_outputs[\"completion_ids\"], skip_special_tokens=True\n", " )\n", - "\n", - " # extract the guess from the completion\n", " guess = extract_guess(completion_text)\n", + " model_outputs.append(completion_text.strip()) # Store raw model output for format reward\n", "\n", - " # step the environment with the guess\n", " result = env.step(TextArenaAction(message=guess))\n", + "\n", " raw_rewards.append(float(result.reward or 0.0))\n", " observation = result.observation\n", " correct_score = float(result.reward or 0.0)\n", " feedback = extract_wordle_feedback(observation)\n", "\n", - " # Update guess counts\n", - " previous_occurrences = guess_counts[guess]\n", - " 
repetition_score = scale_repetition_score(previous_occurrences, len(guess_counts))\n", - " guess_counts[guess] += 1\n", + " full_env_output = format_history(observation.messages) if observation.messages else \"\"\n", + " new_env_output = full_env_output[prev_env_output_len:].lstrip(\"\\n\")\n", + " prev_env_output_len = len(full_env_output)\n", + "\n", + " if new_env_output:\n", + " env_output_tokens = tokenizer.encode(new_env_output, add_special_tokens=False)\n", + " completion_ids.extend(env_output_tokens) # Add to completion_ids\n", + " logprobs.extend([0.0] * len(env_output_tokens)) # Placeholder (ignored via env_mask=0)\n", + " env_mask.extend([0] * len(env_output_tokens)) # Environment tokens - mask out from loss\n", + " completion_with_env = completion_text + \"\\n\" + new_env_output\n", + " else:\n", + " completion_with_env = completion_text\n", + "\n", + " accumulated_messages.append({\"role\": \"user\", \"content\": user_prompt})\n", + " accumulated_messages.append({\"role\": \"assistant\", \"content\": completion_with_env})\n", "\n", - " # calculate custom reward signals from the feedback\n", " if not feedback:\n", - " green_score = 0.0\n", - " yellow_score = 0.0\n", + " position_score = 0.0\n", " else:\n", " green_count, yellow_count = extract_feedback_counts(feedback)\n", - " green_score = green_count / 5.0\n", - " yellow_score = yellow_count / 5.0\n", + " position_score = (green_count + 0.5 * yellow_count) / 5.0\n", "\n", - " repetition_scores.append(repetition_score)\n", - " green_scores.append(green_score)\n", - " yellow_scores.append(yellow_score)\n", + " position_scores.append(position_score)\n", " correct_scores.append(correct_score)\n", "\n", + " # Use the final correct reward (win/lose is binary at end)\n", " correct_reward_value = correct_scores[-1] if correct_scores else (raw_rewards[-1] if raw_rewards else 0.0)\n", "\n", + " # Position reward as shaping signal:\n", + " # - If model WINS: position_reward = 1.0 (no penalty for winning fast)\n", + " # - If model LOSES: position_reward = last attempt (where it ended up)\n", + " if correct_reward_value >= 1.0:\n", + " final_position_reward = 1.0\n", + " else:\n", + " final_position_reward = position_scores[-1] if position_scores else 0.0\n", + "\n", " return {\n", " \"prompt_ids\": prompt_ids,\n", " \"completion_ids\": completion_ids,\n", " \"logprobs\": logprobs,\n", + " \"env_mask\": env_mask,\n", " \"raw_rewards\": raw_rewards,\n", " \"correct_reward\": correct_reward_value,\n", - " \"green_reward\": green_scores[-1] if green_scores else 0.0,\n", - " \"yellow_reward\": yellow_scores[-1] if yellow_scores else 0.0,\n", - " \"repetition_reward\": repetition_scores[-1] if repetition_scores else 0.0,\n", + " \"position_reward\": final_position_reward,\n", + " \"model_outputs\": model_outputs,\n", " }" ] }, @@ -517,9 +584,8 @@ "\n", "Supporting utilities used in `rollout_once`:\n", "\n", - "- **`make_user_prompt`**: builds the user prompt combining the base text and previous game messages.\n", - "- **`format_history`**: formats the conversation log for consistent context.\n", - "- **`scale_repetition_score`**: applies a penalty when guesses are repeated to encourage exploration." + "- **`make_user_prompt`**: builds the user prompt combining the conversation history.\n", + "- **`format_history`**: formats the conversation log for consistent context." 
] }, { @@ -532,19 +598,7 @@ "outputs": [], "source": [ "# @title Helpers definition (click to expand)\n", - "def make_user_prompt(prompt_text, messages):\n", - " \"\"\"Builds a structured user prompt combining the task description and message history\"\"\"\n", - " history = format_history(messages)\n", - " prompt_section = prompt_text.strip() if prompt_text.strip() else \"Wordle-v0\"\n", - " history_section = history if history else \"[PROMPT] Awaiting first feedback.\"\n", - " return (\n", - " f\"Game prompt:\\n{prompt_section}\\n\\n\"\n", - " f\"Conversation so far:\\n{history_section}\\n\\n\"\n", - " \"Reply with your next guess enclosed in square brackets.\"\n", - " )\n", - "\n", - "def format_history(messages):\n", - " \"\"\"Formats the message history with tags for clear conversational context\"\"\"\n", + "def format_history(messages) -> str:\n", " lines = []\n", " for message in messages:\n", " tag = message.category or \"MESSAGE\"\n", @@ -554,11 +608,12 @@ " lines.append(f\"[{tag}] {content}\")\n", " return \"\\n\".join(lines)\n", "\n", - "def scale_repetition_score(previous_occurrences, max_occurrences):\n", - " \"\"\"Scale the repetition score based on the number of previous occurrences from 0 to 1\"\"\"\n", - " if max_occurrences == 0:\n", - " return 0.0\n", - " return (max_occurrences - previous_occurrences) / max_occurrences" + "\n", + "def make_user_prompt(prompt_text, messages) -> str:\n", + " history = format_history(messages)\n", + " # Only use messages for conversation history - the prompt is already included as the first message\n", + " history_section = history if history else \"[PROMPT] Awaiting first feedback.\"\n", + " return f\"Conversation so far:\\n{history_section}\\n\\nReply with your next guess enclosed in square brackets.\"" ] }, { @@ -573,13 +628,12 @@ "To guide the agent's learning process, we define simple reward functions that map the feedback from the environment into numeric signals. \n", "Each function corresponds to a specific aspect of the **Wordle** game:\n", "\n", - "- ✅ **`reward_correct`**: rewards the model when it guesses the correct word. \n", - "- 🟩 **`reward_greens`**: rewards letters correctly placed (green feedback). \n", - "- 🟨 **`reward_yellows`**: rewards letters that are present but misplaced (yellow feedback). \n", - "- 🔁 **`reward_repetition`**: rewards diverse guessing by scoring based on guess uniqueness.\n", + "- ✅ **`reward_correct`**: rewards the model when it guesses the correct word (binary: 0 or 1). \n", + "- 🎯 **`reward_position`**: rewards progress based on letter feedback. Green letters worth 1.0, yellow worth 0.5, normalized by 5. If the model wins, this is set to 1.0.\n", + "- 📝 **`reward_format_strict`**: rewards correct output format `[xxxxx]`. Returns proportion of correctly formatted outputs across all turns.\n", "\n", "These functions return lists of float values that the **GRPOTrainer** uses during optimization. \n", - "By combining them, the model learns to balance correctness, information gathering, and exploration in its guessing strategy." + "By combining them, the model learns to balance correctness, information gathering, and proper formatting in its guessing strategy." 
] }, { @@ -592,28 +646,39 @@ "outputs": [], "source": [ "def reward_correct(completions, **kwargs):\n", + " \"\"\"Reward from environment (correct answer).\"\"\"\n", " rewards = kwargs.get(\"correct_reward\") if kwargs else None\n", " if rewards is None:\n", " return [0.0 for _ in completions]\n", " return [float(r) for r in rewards]\n", "\n", "\n", - "def reward_greens(completions, **kwargs):\n", - " rewards = kwargs.get(\"green_reward\") if kwargs else None\n", + "def reward_position(completions, **kwargs):\n", + " \"\"\"Position reward: green worth 1.0, yellow worth 0.5, normalized by 5.\"\"\"\n", + " rewards = kwargs.get(\"position_reward\") if kwargs else None\n", " if rewards is None:\n", " return [0.0 for _ in completions]\n", " return [float(r) for r in rewards]\n", "\n", "\n", - "def reward_yellows(completions, **kwargs):\n", - " rewards = kwargs.get(\"yellow_reward\") if kwargs else None\n", - " if rewards is None:\n", - " return [0.0 for _ in completions]\n", - " return [float(r) for r in rewards]\n", + "def compute_format_reward(model_outputs):\n", + " \"\"\"Compute format reward from a list of model outputs (one per turn).\n", + "\n", + " Each output should be exactly [5 letters] with optional whitespace.\n", + " Returns proportion of correctly formatted outputs.\n", + " \"\"\"\n", + " if not model_outputs:\n", + " return 0.0\n", "\n", + " exact_pattern = re.compile(r\"^\\s*\\[[A-Za-z]{5}\\]\\s*$\")\n", + " correct_count = sum(1 for output in model_outputs if exact_pattern.match(output))\n", "\n", - "def reward_repetition(completions, **kwargs):\n", - " rewards = kwargs.get(\"repetition_reward\") if kwargs else None\n", + " return correct_count / len(model_outputs)\n", + "\n", + "\n", + "def reward_format_strict(completions, **kwargs):\n", + " \"\"\"Format reward - pre-computed in rollout_func.\"\"\"\n", + " rewards = kwargs.get(\"format_reward\") if kwargs else None\n", " if rewards is None:\n", " return [0.0 for _ in completions]\n", " return [float(r) for r in rewards]" @@ -643,7 +708,7 @@ "source": [ "from datasets import Dataset\n", "\n", - "dataset_size = 1000\n", + "dataset_size = 3000\n", "dataset_prompt = \"Play Wordle like an expert.\"\n", "\n", "dataset = Dataset.from_dict({\"prompt\": [dataset_prompt] * dataset_size})" @@ -673,37 +738,43 @@ "source": [ "from trl import GRPOConfig\n", "\n", - "output_dir = \"wordle-grpo-Qwen3-1.7B\"\n", + "output_dir = \"wordle-grpo-Qwen3-1.7B-test\"\n", "\n", "grpo_config = GRPOConfig(\n", " # Training schedule / optimization\n", - " num_train_epochs = 1, # Number of full dataset passes\n", - " learning_rate = 5e-6, # Learning rate for the optimizer\n", - " gradient_accumulation_steps = 64, # Accumulate gradients over multiple steps\n", - " per_device_train_batch_size = 1, # Batch size per GPU (number of prompts processed together)\n", - " warmup_steps = 20, # Steps for learning rate warmup\n", + " num_train_epochs = 1, # Number of full dataset passes\n", + " learning_rate = 1e-6, # Learning rate for the optimizer\n", + " gradient_accumulation_steps = 64, # Accumulate gradients over multiple steps\n", + " per_device_train_batch_size = 1, # Batch size per GPU (number of prompts processed together)\n", + " warmup_steps = 10, # Steps for learning rate warmup\n", + " optim=\"adamw_torch\", # Optimizer\n", + " max_grad_norm=1.0, # Clip gradients to prevent explosion\n", "\n", " # GRPO configuration\n", - " num_generations = 2, # Number of rollout episodes per prompt (for variance reduction)\n", - " max_completion_length = 8, # Maximum 
tokens generated per model response\n", + " num_generations = 2, # Number of rollout episodes per prompt (for variance reduction)\n", + " max_completion_length=1024, # Full episode length, not per-turn\n", + " log_completions = False, # Log completions for debugging\n", "\n", " # vLLM configuration\n", - " use_vllm = True, # Enable vLLM for faster inference during rollouts\n", - " vllm_mode = \"colocate\", # Run vLLM in colocate mode (same process as training)\n", - " vllm_gpu_memory_utilization = 0.1, # Fraction of GPU memory reserved for vLLM inference\n", + " use_vllm = True, # Enable vLLM for faster inference during rollouts\n", + " vllm_mode = \"colocate\", # Run vLLM in colocate mode (same process as training)\n", + " vllm_gpu_memory_utilization = 0.15, # Fraction of GPU memory reserved for vLLM inference\n", + " vllm_max_model_length=3072, # Maximum context length for vLLM\n", + " vllm_importance_sampling_correction=False,\n", "\n", " # Logging / reporting\n", - " output_dir = output_dir, # Directory for checkpoints and logs\n", - " report_to=\"trackio\", # Experiment tracking tool (integrates with HF Spaces)\n", - " trackio_space_id = output_dir, # HF Space where experiment tracking will be saved\n", - " logging_steps = 1, # Log metrics every N steps\n", - " save_steps = 10, # Interval for saving checkpoints\n", + " output_dir = output_dir, # Directory for checkpoints and logs\n", + " report_to=\"trackio\", # Experiment tracking tool (integrates with HF Spaces)\n", + " trackio_space_id = output_dir, # HF Space where experiment tracking will be saved\n", + " logging_steps = 1, # Log metrics every N steps\n", + " save_steps = 10, # Interval for saving checkpoints\n", + " save_total_limit=1, # Max number of checkpoints to save\n", "\n", " # Memory optimization\n", - " gradient_checkpointing = True, # Enable activation recomputation to save memory\n", + " gradient_checkpointing = True, # Enable activation recomputation to save memory\n", "\n", " # Hub integration\n", - " push_to_hub = True, # Set True to automatically push model to Hugging Face Hub\n", + " push_to_hub = True, # Set True to automatically push model to Hugging Face Hub\n", ")" ] }, @@ -727,11 +798,90 @@ { "cell_type": "code", "execution_count": null, - "id": "1f7aceb9-fe9e-49ba-b976-a39c1e29d4e5", + "id": "FeBMCppH7rAc", "metadata": { - "id": "1f7aceb9-fe9e-49ba-b976-a39c1e29d4e5" + "id": "FeBMCppH7rAc" }, "outputs": [], + "source": [ + "import sys\n", + "sys.stdout.fileno = lambda: 1\n", + "sys.stderr.fileno = lambda: 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f7aceb9-fe9e-49ba-b976-a39c1e29d4e5", + "metadata": { + "colab": { + "referenced_widgets": [ + "f44d7bb668064bdb80e3904ff92da5ea", + "efa028ffbd704a489729c83af0647d68" + ] + }, + "id": "1f7aceb9-fe9e-49ba-b976-a39c1e29d4e5", + "outputId": "aa6f81a6-662c-4215-f091-bcf422f43f9c" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. 
Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f44d7bb668064bdb80e3904ff92da5ea", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipython-input-3695224185.py:3: UserWarning: You are importing from 'rollout_func', which is an experimental feature. This API may change or be removed at any time without prior notice. Silence this warning by setting environment variable TRL_EXPERIMENTAL_SILENCE=1.\n", + " trainer = GRPOTrainer(\n", + "The model is already on multiple devices. Skipping the move to device specified in `args`.\n", + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "efa028ffbd704a489729c83af0647d68", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading safetensors checkpoint shards: 0% Completed | 0/2 [00:00, ?it/s]\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 0%| | 0/19 [00:00, ?it/s]/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n", + "Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|██████████| 19/19 [00:00<00:00, 21.62it/s]\n", + "Capturing CUDA graphs (decode, FULL): 100%|██████████| 11/11 [00:00<00:00, 21.27it/s]\n" + ] + } + ], "source": [ "from trl import GRPOTrainer\n", "\n", @@ -740,9 +890,8 @@ " processing_class=tokenizer,\n", " reward_funcs=[\n", " reward_correct,\n", - " reward_greens,\n", - " reward_yellows,\n", - " reward_repetition,\n", + " reward_position,\n", + " reward_format_strict,\n", " ],\n", " train_dataset=dataset,\n", " args=grpo_config,\n", @@ -766,7 +915,7 @@ "id": "hxr5Rv0wVu_P", "metadata": { "id": "hxr5Rv0wVu_P", - "outputId": "4bbbea53-f54b-41e7-83e0-b11b6b6961ee" + "outputId": "8f638c93-ec50-487c-9dd4-dfd863b6b0ed" }, "outputs": [ { @@ -774,7 +923,7 @@ "output_type": "stream", "text": [ "GPU = NVIDIA A100-SXM4-40GB. 
Max memory = 39.557 GB.\n", - "10.516 GB of memory reserved.\n" + "12.484 GB of memory reserved.\n" ] } ], @@ -801,10 +950,10 @@ { "cell_type": "code", "execution_count": null, - "id": "55c52596-4082-405b-b626-4b0401c2ce9f", + "id": "5DS2TVNTGifL", "metadata": { - "id": "55c52596-4082-405b-b626-4b0401c2ce9f", - "outputId": "da209e0c-36b7-4c9b-a61e-682ba22fa094" + "id": "5DS2TVNTGifL", + "outputId": "5de14ea9-0fbd-4167-d5ad-3d3187b8d489" }, "outputs": [ { @@ -819,15 +968,15 @@ "output_type": "stream", "text": [ "* Trackio project initialized: huggingface\n", - "* Trackio metrics will be synced to Hugging Face Dataset: sergiopaniego/wordle-grpo-Qwen3-1.7B-dataset\n", - "* Creating new space: https://huggingface.co/spaces/sergiopaniego/wordle-grpo-Qwen3-1.7B\n", - "* View dashboard by going to: https://sergiopaniego-wordle-grpo-Qwen3-1.7B.hf.space/\n" + "* Trackio metrics will be synced to Hugging Face Dataset: sergiopaniego/wordle-grpo-Qwen3-1.7B-test-dataset\n", + "* Creating new space: https://huggingface.co/spaces/sergiopaniego/wordle-grpo-Qwen3-1.7B-test\n", + "* View dashboard by going to: https://sergiopaniego-wordle-grpo-Qwen3-1.7B-test.hf.space/\n" ] }, { "data": { "text/html": [ - "
" + "" ], "text/plain": [ "| 1 | \n", - "0.008300 | \n", + "0.009800 | \n", "
| 2 | \n", - "0.001900 | \n", + "0.016400 | \n", "
| 3 | \n", - "0.015100 | \n", + "0.005600 | \n", "
| 4 | \n", - "0.008700 | \n", + "0.014700 | \n", "
| 5 | \n", - "0.009800 | \n", + "0.019500 | \n", "
| 6 | \n", - "0.006700 | \n", + "0.002300 | \n", "
| 7 | \n", - "0.006100 | \n", + "0.005300 | \n", "
| 8 | \n", - "0.004400 | \n", + "0.025100 | \n", "
| 9 | \n", - "-0.002100 | \n", + "0.004500 | \n", "
| 10 | \n", - "0.007500 | \n", + "0.004200 | \n", "
| 11 | \n", - "0.008400 | \n", + "0.009600 | \n", "
| 12 | \n", - "0.008000 | \n", + "0.014900 | \n", "
| 13 | \n", - "0.007800 | \n", + "0.024500 | \n", "
| 14 | \n", - "-0.002400 | \n", + "0.012200 | \n", "
| 15 | \n", - "-0.003200 | \n", + "0.015500 | \n", "
| 16 | \n", - "-0.006000 | \n", + "0.007400 | \n", "
| 17 | \n", - "-0.008300 | \n", + "0.017500 | \n", "
| 18 | \n", - "-0.011000 | \n", + "0.014900 | \n", "
| 19 | \n", - "-0.004200 | \n", + "0.035600 | \n", "
| 20 | \n", - "-0.001700 | \n", + "0.014900 | \n", "
| 21 | \n", - "-0.004100 | \n", + "0.030000 | \n", "
| 22 | \n", - "-0.011600 | \n", + "0.014300 | \n", "
| 23 | \n", - "-0.006400 | \n", + "0.018000 | \n", "
| 24 | \n", - "-0.009100 | \n", + "0.014000 | \n", "
| 25 | \n", - "0.003200 | \n", + "0.016600 | \n", "
| 26 | \n", - "0.005100 | \n", + "0.015600 | \n", "
| 27 | \n", - "-0.002800 | \n", + "0.021300 | \n", "
| 28 | \n", - "0.001400 | \n", + "0.021000 | \n", "
| 29 | \n", - "0.011500 | \n", + "0.036900 | \n", "
| 30 | \n", - "-0.010500 | \n", + "0.006400 | \n", "
| 31 | \n", - "-0.006400 | \n", + "0.044800 | \n", + "
| 32 | \n", + "0.026400 | \n", + "|
| 33 | \n", + "0.038700 | \n", + "|
| 34 | \n", + "0.022000 | \n", + "|
| 35 | \n", + "0.013400 | \n", + "|
| 36 | \n", + "0.025000 | \n", + "|
| 37 | \n", + "0.042900 | \n", + "|
| 38 | \n", + "0.072700 | \n", + "|
| 39 | \n", + "0.070100 | \n", + "|
| 40 | \n", + "0.019900 | \n", + "|
| 41 | \n", + "0.058700 | \n", + "|
| 42 | \n", + "0.060100 | \n", + "|
| 43 | \n", + "-0.026700 | \n", + "|
| 44 | \n", + "0.038900 | \n", + "|
| 45 | \n", + "0.042400 | \n", + "|
| 46 | \n", + "-0.009100 | \n", + "|
| 47 | \n", + "0.001300 | \n", + "|
| 48 | \n", + "0.020200 | \n", + "|
| 49 | \n", + "0.078700 | \n", + "|
| 50 | \n", + "0.026300 | \n", + "|
| 51 | \n", + "0.045700 | \n", + "|
| 52 | \n", + "0.035300 | \n", + "|
| 53 | \n", + "-0.006700 | \n", + "|
| 54 | \n", + "0.025300 | \n", + "|
| 55 | \n", + "0.069500 | \n", + "|
| 56 | \n", + "0.092800 | \n", + "|
| 57 | \n", + "0.067900 | \n", + "|
| 58 | \n", + "0.035000 | \n", + "|
| 59 | \n", + "0.061300 | \n", + "|
| 60 | \n", + "0.048800 | \n", + "|
| 61 | \n", + "0.000600 | \n", + "|
| 62 | \n", + "0.028400 | \n", + "|
| 63 | \n", + "0.016200 | \n", + "|
| 64 | \n", + "0.010700 | \n", + "|
| 65 | \n", + "0.020200 | \n", + "|
| 66 | \n", + "0.041800 | \n", + "|
| 67 | \n", + "0.006800 | \n", + "|
| 68 | \n", + "0.014800 | \n", + "|
| 69 | \n", + "0.025100 | \n", + "|
| 70 | \n", + "-0.006600 | \n", + "|
| 71 | \n", + "0.041000 | \n", + "|
| 72 | \n", + "0.008300 | \n", + "|
| 73 | \n", + "0.045300 | \n", + "|
| 74 | \n", + "0.062800 | \n", + "|
| 75 | \n", + "0.048200 | \n", + "|
| 76 | \n", + "0.032800 | \n", + "|
| 77 | \n", + "0.053000 | \n", + "|
| 78 | \n", + "0.023100 | \n", + "|
| 79 | \n", + "0.014900 | \n", + "|
| 80 | \n", + "0.078200 | \n", + "|
| 81 | \n", + "-0.000700 | \n", + "|
| 82 | \n", + "0.013400 | \n", + "|
| 83 | \n", + "0.030200 | \n", + "|
| 84 | \n", + "-0.003600 | \n", + "|
| 85 | \n", + "0.051700 | \n", + "|
| 86 | \n", + "0.033500 | \n", + "|
| 87 | \n", + "0.021800 | \n", + "|
| 88 | \n", + "-0.003400 | \n", + "|
| 89 | \n", + "0.023200 | \n", + "|
| 90 | \n", + "-0.002900 | \n", + "|
| 91 | \n", + "0.030900 | \n", + "|
| 92 | \n", + "0.029200 | \n", + "|
| 93 | \n", + "0.002500 | \n", "
" @@ -995,42 +1392,46 @@ "metadata": {}, "output_type": "display_data" }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n", + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n", + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n", + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n", + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n", + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n", + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n", + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n", + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n", + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. 
Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "INFO 11-21 12:16:45 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 12:19:33 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 12:22:23 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 12:25:11 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 12:27:59 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 12:30:47 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 12:33:36 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 12:36:24 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 12:39:12 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 12:42:38 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 12:45:41 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 12:48:28 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 12:51:17 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 12:54:05 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 12:56:52 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 12:59:08 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 13:01:36 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 13:04:24 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 13:06:43 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 13:10:09 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 13:12:22 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 13:14:22 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 13:17:12 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 13:19:13 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 13:22:01 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 13:24:52 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 13:27:41 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 13:30:32 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 13:33:22 [block_pool.py:292] Successfully reset prefix cache\n", - "INFO 11-21 13:37:30 [block_pool.py:292] Successfully reset prefix cache\n", "* Run finished. Uploading logs to Trackio (please wait...)\n" ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. 
Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n" + ] } ], "source": [ @@ -1053,19 +1454,19 @@ "id": "zuHTwuxAVp8p", "metadata": { "id": "zuHTwuxAVp8p", - "outputId": "205f9e0e-e996-4ba7-fade-3b78977315e4" + "outputId": "fce9bdc8-d734-4382-bb26-7e03dbffa7a0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "5231.7046 seconds used for training.\n", - "87.2 minutes used for training.\n", - "Peak reserved memory = 36.68 GB.\n", - "Peak reserved memory for training = 26.164 GB.\n", - "Peak reserved memory % of max memory = 92.727 %.\n", - "Peak reserved memory for training % of max memory = 66.143 %.\n" + "12065.8973 seconds used for training.\n", + "201.1 minutes used for training.\n", + "Peak reserved memory = 38.139 GB.\n", + "Peak reserved memory for training = 25.655 GB.\n", + "Peak reserved memory % of max memory = 96.415 %.\n", + "Peak reserved memory for training % of max memory = 64.856 %.\n" ] } ], @@ -1090,30 +1491,28 @@ "metadata": { "colab": { "referenced_widgets": [ - "2e8053fa4a524b03842a23f987f0b09b", - "5b42c51a14384ea0b222a665ae0f35dd", - "df766eead3e94399918561f47e4c94c2", - "685bd447dd6d4e4b883e22311bca1982", - "af0d3d7a321c493c9a67eb7e7f9167d6", - "2ce0cdfe83624a458a2205925f88f4f8", - "4a3170d4bcbb4e40b97921cf57fb9d3d", - "7c9234157def4dddb7c08a49d9c83d4d", - "1e49dbc0f9a741a28b53930ece8de736", - "ff14a1ce9dcf4250add61cfb9ae262f5", - "a7b733c562d7432a8a18fee70f6a0248", - "0539d5ee34234fb08cb93996fa7a26ed", - "ed426ec7315e405e976465fdf34f0eb2", - "4a5073ad35954e4a96c80f3fedf91bc9" + "decd9f00c4da42bf92b72c327bd28278", + "2d924050f7bf4e7f88316c8fc202a763", + "d589783221084eb7833ae6cd742d277c", + "0e135c821b5744b287b4de7eeb15d419", + "a1839712ff344a409e6f7f48a1467fd5", + "e9ae0fcd43e34d7e916fe1bda0a38a49", + "75776d6523ef42df930ddfd7048b384e", + "e2e07a449d914bd39653b7cbbc5903e3", + "0eafc3f9bac14807866233f924793380", + "b64c487a9dff4108a66da9eee4e4ed66", + "17a3ba38cf7349269ea54df84faf30b7", + "7382295b99ee4db28de43e1451dd0d17" ] }, "id": "13e9fd4e-e7a5-468d-a25a-3f7d2794201f", - "outputId": "1c0fce8a-8039-4d1e-b63f-60e114d9169c" + "outputId": "7f703ed8-7874-4da1-8490-48222755ae11" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "2e8053fa4a524b03842a23f987f0b09b", + "model_id": "decd9f00c4da42bf92b72c327bd28278", "version_major": 2, "version_minor": 0 }, @@ -1127,7 +1526,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "5b42c51a14384ea0b222a665ae0f35dd", + "model_id": "2d924050f7bf4e7f88316c8fc202a763", "version_major": 2, "version_minor": 0 }, @@ -1141,12 +1540,12 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "df766eead3e94399918561f47e4c94c2", + "model_id": "d589783221084eb7833ae6cd742d277c", "version_major": 2, "version_minor": 0 }, "text/plain": [ - " ...n3-1.7B/training_args.bin: 100%|##########| 7.31kB / 7.31kB " + " ...7B-test/training_args.bin: 100%|##########| 7.70kB / 7.70kB " ] }, "metadata": {}, @@ -1155,12 +1554,12 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "685bd447dd6d4e4b883e22311bca1982", + "model_id": "0e135c821b5744b287b4de7eeb15d419", "version_major": 2, "version_minor": 0 }, "text/plain": [ - " ...Qwen3-1.7B/tokenizer.json: 100%|##########| 11.4MB / 11.4MB " + " ...-1.7B-test/tokenizer.json: 100%|##########| 11.4MB / 11.4MB " ] }, "metadata": {}, @@ -1169,26 +1568,12 @@ { "data": { 
"application/vnd.jupyter.widget-view+json": { - "model_id": "af0d3d7a321c493c9a67eb7e7f9167d6", + "model_id": "a1839712ff344a409e6f7f48a1467fd5", "version_major": 2, "version_minor": 0 }, "text/plain": [ - " ...adapter_model.safetensors: 100%|##########| 25.7MB / 25.7MB " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "2ce0cdfe83624a458a2205925f88f4f8", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " ...0002-of-00002.safetensors: 2%|2 | 41.9MB / 1.91GB " + " ...0002-of-00002.safetensors: 2%|1 | 33.5MB / 1.91GB " ] }, "metadata": {}, @@ -1197,7 +1582,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "4a3170d4bcbb4e40b97921cf57fb9d3d", + "model_id": "e9ae0fcd43e34d7e916fe1bda0a38a49", "version_major": 2, "version_minor": 0 }, @@ -1219,7 +1604,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "7c9234157def4dddb7c08a49d9c83d4d", + "model_id": "75776d6523ef42df930ddfd7048b384e", "version_major": 2, "version_minor": 0 }, @@ -1233,7 +1618,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "1e49dbc0f9a741a28b53930ece8de736", + "model_id": "e2e07a449d914bd39653b7cbbc5903e3", "version_major": 2, "version_minor": 0 }, @@ -1247,12 +1632,12 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "ff14a1ce9dcf4250add61cfb9ae262f5", + "model_id": "0eafc3f9bac14807866233f924793380", "version_major": 2, "version_minor": 0 }, "text/plain": [ - " ...n3-1.7B/training_args.bin: 100%|##########| 7.31kB / 7.31kB " + " ...7B-test/training_args.bin: 100%|##########| 7.70kB / 7.70kB " ] }, "metadata": {}, @@ -1261,12 +1646,12 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "a7b733c562d7432a8a18fee70f6a0248", + "model_id": "b64c487a9dff4108a66da9eee4e4ed66", "version_major": 2, "version_minor": 0 }, "text/plain": [ - " ...Qwen3-1.7B/tokenizer.json: 100%|##########| 11.4MB / 11.4MB " + " ...-1.7B-test/tokenizer.json: 100%|##########| 11.4MB / 11.4MB " ] }, "metadata": {}, @@ -1275,12 +1660,12 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "0539d5ee34234fb08cb93996fa7a26ed", + "model_id": "17a3ba38cf7349269ea54df84faf30b7", "version_major": 2, "version_minor": 0 }, "text/plain": [ - " ...0001-of-00002.safetensors: 1%| | 41.9MB / 4.97GB " + " ...0001-of-00002.safetensors: 1%| | 33.5MB / 4.97GB " ] }, "metadata": {}, @@ -1289,7 +1674,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "ed426ec7315e405e976465fdf34f0eb2", + "model_id": "7382295b99ee4db28de43e1451dd0d17", "version_major": 2, "version_minor": 0 }, @@ -1300,20 +1685,6 @@ "metadata": {}, "output_type": "display_data" }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "4a5073ad35954e4a96c80f3fedf91bc9", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " ...adapter_model.safetensors: 100%|##########| 25.7MB / 25.7MB " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, { "name": "stderr", "output_type": "stream", @@ -1328,7 +1699,7 @@ "type": "string" }, "text/plain": [ - "CommitInfo(commit_url='https://huggingface.co/sergiopaniego/wordle-grpo-Qwen3-1.7B/commit/b81b548867ab35601d3bda845ed5e18147550e30', commit_message='End of training', commit_description='', oid='b81b548867ab35601d3bda845ed5e18147550e30', pr_url=None, repo_url=RepoUrl('https://huggingface.co/sergiopaniego/wordle-grpo-Qwen3-1.7B', 
endpoint='https://huggingface.co', repo_type='model', repo_id='sergiopaniego/wordle-grpo-Qwen3-1.7B'), pr_revision=None, pr_num=None)" + "CommitInfo(commit_url='https://huggingface.co/sergiopaniego/wordle-grpo-Qwen3-1.7B-test/commit/2d7a27066ef244796a079cbf08fa6656af426145', commit_message='End of training', commit_description='', oid='2d7a27066ef244796a079cbf08fa6656af426145', pr_url=None, repo_url=RepoUrl('https://huggingface.co/sergiopaniego/wordle-grpo-Qwen3-1.7B-test', endpoint='https://huggingface.co', repo_type='model', repo_id='sergiopaniego/wordle-grpo-Qwen3-1.7B-test'), pr_revision=None, pr_num=None)" ] }, "execution_count": 15, @@ -1367,11 +1738,11 @@ ] }, "id": "JcTeeSBXxWWF", - "outputId": "b5b06184-11f6-4d01-d6ce-e63ad04419d0" + "outputId": "86efafc3-1161-471b-86b1-14c43e95908f" }, "outputs": [ { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:104: UserWarning: \n", @@ -1507,7 +1878,7 @@ "id": "JjOzWexUXmfW", "metadata": { "id": "JjOzWexUXmfW", - "outputId": "d7f2fed8-7d22-493e-d530-be5a6aff19a2" + "outputId": "1c6130af-fe89-4930-e53a-7329e0483ef0" }, "outputs": [ { @@ -1651,15 +2022,15 @@ } ], "metadata": { - "accelerator": "GPU", - "colab": { - "gpuType": "T4", - "provenance": [] - }, "language_info": { "name": "python" - } + }, + "colab": { + "provenance": [], + "gpuType": "A100" + }, + "accelerator": "GPU" }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/examples/scripts/openenv/wordle.py b/examples/scripts/openenv/wordle.py index 9f199b5f1a5..7053e089254 100644 --- a/examples/scripts/openenv/wordle.py +++ b/examples/scripts/openenv/wordle.py @@ -16,9 +16,10 @@ # dependencies = [ # "trl[vllm]", # "peft", -# "trackio", +# "trackio>=0.13.0", # "kernels", -# "openenv-textarena @ git+https://huggingface.co/spaces/burtenshaw/wordle", +# "openenv @ git+https://github.com/meta-pytorch/OpenEnv.git", +# "openenv_core", # ] # /// @@ -49,25 +50,54 @@ CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/wordle.py --vllm-mode server --vllm-server-url http://localhost:8000 ``` -# Option 3: Local + Colocated vLLM (1 GPU required) +# Option 3: Local Environment + Colocated vLLM (1 GPU required) -Usage: +To run the Wordle environment locally, you have several options: -# Start the environment only if using --env-mode docker-local; In other modes, the env is automatically managed by the script. +## Option 3a: Using Docker Image (Recommended) + +First, build the Docker image from the textarena_env directory: ```sh -docker run -d -p 8001:8001 registry.hf.space/burtenshaw-wordle:latest +cd 3rd_party/OpenEnv/envs/textarena_env +docker build -t textarena-env:latest -f server/Dockerfile . ``` +Then run the environment server: ```sh -python examples/scripts/openenv/wordle.py --vllm-mode colocate +docker run -d -p 8001:8001 textarena-env:latest ``` -""" -from __future__ import annotations +Finally, run training pointing to local server: +```sh +python examples/scripts/openenv/wordle.py --vllm-mode colocate --env-url http://localhost:8001 +``` + +## Option 3b: Running Server Directly + +From the textarena_env directory: +```sh +cd 3rd_party/OpenEnv/envs/textarena_env +uv venv && source .venv/bin/activate +uv pip install -e . 
+python -m uvicorn server.app:app --reload --port 8001 +``` + +Then in another terminal, run training: +```sh +python examples/scripts/openenv/wordle.py --vllm-mode colocate --env-url http://localhost:8001 +``` + +## Option 3c: Using Pre-built HF Space Image + +```sh +docker run -d -p 8001:8001 registry.hf.space/burtenshaw-wordle:latest +python examples/scripts/openenv/wordle.py --vllm-mode colocate --env-url http://localhost:8001 +``` +""" import argparse +import re import sys -from collections import defaultdict from collections.abc import Iterable from datetime import datetime from pathlib import Path @@ -102,7 +132,7 @@ def parse_args() -> argparse.Namespace: help="Model identifier passed to GRPOTrainer for fine-tuning.", ) parser.add_argument( - "--env-url", type=str, default="https://burtenshaw-wordle.hf.space", help="URL for the environment server." + "--env-url", type=str, default="https://sergiopaniego-wordle.hf.space", help="URL for the environment server." ) parser.add_argument( "--system-prompt-path", @@ -153,7 +183,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--learning-rate", type=float, - default=5e-6, + default=1e-6, help="Learning rate for GRPO training.", ) parser.add_argument( @@ -171,7 +201,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--warmup-steps", type=int, - default=20, + default=10, help="Warmup steps for the scheduler.", ) parser.add_argument( @@ -183,7 +213,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--num-generations", type=int, - default=2, + default=4, help="Number of rollout generations per dataset prompt.", ) parser.add_argument( @@ -242,12 +272,6 @@ def parse_args() -> argparse.Namespace: default=1, help="Frequency of logging steps for GRPO training.", ) - parser.add_argument( - "--debug", - action="store_true", - default=False, - help="Enable verbose debugging output during rollouts.", - ) return parser.parse_args() @@ -280,20 +304,9 @@ def format_history(messages: Iterable[TextArenaMessage]) -> str: def make_user_prompt(prompt_text: str, messages: Iterable[TextArenaMessage]) -> str: history = format_history(messages) - prompt_section = prompt_text.strip() if prompt_text.strip() else "Wordle-v0" + # Only use messages for conversation history - the prompt is already included as the first message history_section = history if history else "[PROMPT] Awaiting first feedback." - return ( - f"Game prompt:\n{prompt_section}\n\n" - f"Conversation so far:\n{history_section}\n\n" - "Reply with your next guess enclosed in square brackets." - ) - - -def scale_repetition_score(previous_occurrences: int, max_occurrences: int) -> float: - """Scale the repetition score based on the number of previous occurrences from 0 to 1""" - if max_occurrences == 0: - return 0.0 - return (max_occurrences - previous_occurrences) / max_occurrences + return f"Conversation so far:\n{history_section}\n\nReply with your next guess enclosed in square brackets." 
def rollout_once( @@ -303,6 +316,7 @@ def rollout_once( dataset_prompt: str, system_prompt: str, max_turns: int, + max_new_tokens: int = 16, ) -> dict[str, list]: result = env.reset() observation = result.observation @@ -310,25 +324,44 @@ def rollout_once( prompt_ids: list[int] = [] completion_ids: list[int] = [] logprobs: list[float] = [] + env_mask: list[int] = [] # 1 for model-generated tokens, 0 for environment tokens + model_outputs: list[str] = [] raw_rewards: list[float] = [] - green_scores: list[float] = [] - yellow_scores: list[float] = [] - repetition_scores: list[float] = [] + position_scores: list[float] = [] correct_scores: list[float] = [] - guess_counts: defaultdict[str, int] = defaultdict(int) + prev_env_output_len: int = 0 # Track length to only add NEW portion each turn + + accumulated_messages: list[dict[str, str]] = [{"role": "system", "content": system_prompt}] + # Build initial prompt (only once, at the start) + # The initial env messages are included in the prompt, not completion + base_prompt = observation.prompt or dataset_prompt + initial_user_prompt = make_user_prompt(base_prompt, observation.messages) + # Track initial env output length so we don't add it again + initial_env_output = format_history(observation.messages) if observation.messages else "" + prev_env_output_len = len(initial_env_output) + initial_messages = accumulated_messages + [{"role": "user", "content": initial_user_prompt}] + initial_prompt_text = tokenizer.apply_chat_template( + initial_messages, + add_generation_prompt=True, + tokenize=False, + enable_thinking=False, + ) + # Tokenize initial prompt once - this is the base prompt for the entire episode. + # GRPO expects one prompt-completion pair per episode, where: + # - prompt_ids = the initial/base prompt (what the model sees at episode start) + # - completion_ids = all model responses + env feedback from all turns concatenated + # Note: The actual prompts used for generation in each turn are longer (include conversation history), + # but we only count the initial prompt tokens here. 
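+    # Illustration of how completion_ids and env_mask line up across turns (toy lengths):
+    #   completion_ids: [ \n | guess 1 tokens | \n | feedback 1 tokens | \n | guess 2 tokens | ... ]
+    #   env_mask:       [  1 |   1 ...... 1   |  1 |    0 ....... 0    |  1 |   1 ...... 1   | ... ]
+    # Only positions with env_mask == 1 (the model's own tokens) count toward the GRPO loss;
+    # environment feedback stays in completion_ids for context but is masked out.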
+ initial_prompt_ids = tokenizer.encode(initial_prompt_text, add_special_tokens=False) + prompt_ids.extend(initial_prompt_ids) for _turn in range(max_turns): - # when the game is over the environment will return a done=True if result.done: break - # set up the prompt for the model base_prompt = observation.prompt or dataset_prompt user_prompt = make_user_prompt(base_prompt, observation.messages) - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ] + messages = accumulated_messages + [{"role": "user", "content": user_prompt}] prompt_text = tokenizer.apply_chat_template( messages, add_generation_prompt=True, @@ -336,53 +369,80 @@ def rollout_once( enable_thinking=False, ) - rollout_outputs = generate_rollout_completions(trainer, [prompt_text])[0] - prompt_ids.extend(rollout_outputs["prompt_ids"]) + rollout_outputs = generate_rollout_completions( + trainer, [prompt_text], generation_overrides={"max_tokens": max_new_tokens} + )[0] + # Add model-generated completion tokens and logprobs with newlines for readability + newline_tokens = tokenizer.encode("\n", add_special_tokens=False) + completion_ids.extend(newline_tokens) # newline before guess + logprobs.extend([0.0] * len(newline_tokens)) + env_mask.extend([1] * len(newline_tokens)) # newlines are part of model output format + completion_ids.extend(rollout_outputs["completion_ids"]) logprobs.extend(rollout_outputs["logprobs"]) + env_mask.extend([1] * len(rollout_outputs["completion_ids"])) # model-generated tokens + + completion_ids.extend(newline_tokens) # newline after guess + logprobs.extend([0.0] * len(newline_tokens)) + env_mask.extend([1] * len(newline_tokens)) # newlines are part of model output format completion_text = rollout_outputs.get("text") or tokenizer.decode( rollout_outputs["completion_ids"], skip_special_tokens=True ) - # extract the guess from the completion guess = extract_guess(completion_text) + model_outputs.append(completion_text.strip()) # Store raw model output for format reward - # step the environment with the guess result = env.step(TextArenaAction(message=guess)) + raw_rewards.append(float(result.reward or 0.0)) observation = result.observation correct_score = float(result.reward or 0.0) feedback = extract_wordle_feedback(observation) - # Update guess counts - previous_occurrences = guess_counts[guess] - repetition_score = scale_repetition_score(previous_occurrences, len(guess_counts)) - guess_counts[guess] += 1 + full_env_output = format_history(observation.messages) if observation.messages else "" + new_env_output = full_env_output[prev_env_output_len:].lstrip("\n") + prev_env_output_len = len(full_env_output) + + if new_env_output: + env_output_tokens = tokenizer.encode(new_env_output, add_special_tokens=False) + completion_ids.extend(env_output_tokens) # Add to completion_ids + logprobs.extend([0.0] * len(env_output_tokens)) # Placeholder (ignored via env_mask=0) + env_mask.extend([0] * len(env_output_tokens)) # Environment tokens - mask out from loss + completion_with_env = completion_text + "\n" + new_env_output + else: + completion_with_env = completion_text + + accumulated_messages.append({"role": "user", "content": user_prompt}) + accumulated_messages.append({"role": "assistant", "content": completion_with_env}) - # calculate custom reward signals from the feedback if not feedback: - green_score = 0.0 - yellow_score = 0.0 + position_score = 0.0 else: green_count, yellow_count = extract_feedback_counts(feedback) - green_score = green_count / 5.0 - 
yellow_score = yellow_count / 5.0 + position_score = (green_count + 0.5 * yellow_count) / 5.0 - repetition_scores.append(repetition_score) - green_scores.append(green_score) - yellow_scores.append(yellow_score) + position_scores.append(position_score) correct_scores.append(correct_score) + # Use the final correct reward (win/lose is binary at end) correct_reward_value = correct_scores[-1] if correct_scores else (raw_rewards[-1] if raw_rewards else 0.0) + # Position reward as shaping signal: + # - If model WINS: position_reward = 1.0 (no penalty for winning fast) + # - If model LOSES: position_reward = last attempt (where it ended up) + if correct_reward_value >= 1.0: + final_position_reward = 1.0 + else: + final_position_reward = position_scores[-1] if position_scores else 0.0 + return { "prompt_ids": prompt_ids, "completion_ids": completion_ids, "logprobs": logprobs, + "env_mask": env_mask, "raw_rewards": raw_rewards, "correct_reward": correct_reward_value, - "green_reward": green_scores[-1] if green_scores else 0.0, - "yellow_reward": yellow_scores[-1] if yellow_scores else 0.0, - "repetition_reward": repetition_scores[-1] if repetition_scores else 0.0, + "position_reward": final_position_reward, + "model_outputs": model_outputs, } @@ -392,28 +452,39 @@ def rollout_once( def reward_correct(completions: list[str], **kwargs) -> list[float]: + """Reward from environment (correct answer).""" rewards = kwargs.get("correct_reward") if kwargs else None if rewards is None: return [0.0 for _ in completions] return [float(r) for r in rewards] -def reward_greens(completions: list[str], **kwargs) -> list[float]: - rewards = kwargs.get("green_reward") if kwargs else None +def reward_position(completions: list[str], **kwargs) -> list[float]: + """Position reward: green worth 1.0, yellow worth 0.5, normalized by 5.""" + rewards = kwargs.get("position_reward") if kwargs else None if rewards is None: return [0.0 for _ in completions] return [float(r) for r in rewards] -def reward_yellows(completions: list[str], **kwargs) -> list[float]: - rewards = kwargs.get("yellow_reward") if kwargs else None - if rewards is None: - return [0.0 for _ in completions] - return [float(r) for r in rewards] +def compute_format_reward(model_outputs: list[str]) -> float: + """Compute format reward from a list of model outputs (one per turn). + + Each output should be exactly [5 letters] with optional whitespace. + Returns proportion of correctly formatted outputs. 
+ """ + if not model_outputs: + return 0.0 + + exact_pattern = re.compile(r"^\s*\[[A-Za-z]{5}\]\s*$") + correct_count = sum(1 for output in model_outputs if exact_pattern.match(output)) + + return correct_count / len(model_outputs) -def reward_repetition(completions: list[str], **kwargs) -> list[float]: - rewards = kwargs.get("repetition_reward") if kwargs else None +def reward_format_strict(completions: list[str], **kwargs) -> list[float]: + """Format reward - pre-computed in rollout_func.""" + rewards = kwargs.get("format_reward") if kwargs else None if rewards is None: return [0.0 for _ in completions] return [float(r) for r in rewards] @@ -452,8 +523,9 @@ def main() -> None: per_device_train_batch_size=args.per_device_batch_size, warmup_steps=args.warmup_steps, num_generations=args.num_generations, - max_completion_length=args.max_new_tokens, + max_completion_length=1024, # Full episode length, not per-turn logging_steps=args.logging_steps, + log_completions=True, report_to="trackio", trackio_space_id=f"wordle-grpo-{sanitize_name(args.model_id)}-{timestamp}", save_strategy="steps", @@ -462,20 +534,25 @@ def main() -> None: temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, + vllm_gpu_memory_utilization=0.25, + vllm_max_model_length=8192, + vllm_importance_sampling_correction=False, + optim="adamw_torch", + max_grad_norm=1.0, # Clip gradients to prevent explosion ) grpo_config.run_name = args.run_name or f"run-{timestamp}" - grpo_config.project = args.project or f"group-{sanitize_name(args.model_id)}" + grpo_config.project = args.project or f"wordle-grpo-{sanitize_name(args.model_id)}-{timestamp}" grpo_config.trackio_space_id = args.trackio_space_id def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: episode_prompt_ids: list[list[int]] = [] episode_completion_ids: list[list[int]] = [] episode_logprobs: list[list[float]] = [] + episode_env_masks: list[list[int]] = [] correctness_rewards: list[float] = [] - green_rewards: list[float] = [] - yellow_rewards: list[float] = [] - repetition_rewards: list[float] = [] + position_rewards: list[float] = [] + format_rewards: list[float] = [] for prompt_text in prompts: episode = rollout_once( @@ -485,23 +562,24 @@ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: dataset_prompt=prompt_text, system_prompt=system_prompt, max_turns=args.max_turns, + max_new_tokens=args.max_new_tokens, ) episode_prompt_ids.append(episode["prompt_ids"]) episode_completion_ids.append(episode["completion_ids"]) episode_logprobs.append(episode["logprobs"]) + episode_env_masks.append(episode["env_mask"]) correctness_rewards.append(episode["correct_reward"]) - green_rewards.append(episode["green_reward"]) - yellow_rewards.append(episode["yellow_reward"]) - repetition_rewards.append(episode["repetition_reward"]) + position_rewards.append(episode["position_reward"]) + format_rewards.append(compute_format_reward(episode["model_outputs"])) return { "prompt_ids": episode_prompt_ids, "completion_ids": episode_completion_ids, "logprobs": episode_logprobs, + "env_mask": episode_env_masks, "correct_reward": correctness_rewards, - "green_reward": green_rewards, - "yellow_reward": yellow_rewards, - "repetition_reward": repetition_rewards, + "position_reward": position_rewards, + "format_reward": format_rewards, } trainer = GRPOTrainer( @@ -509,9 +587,8 @@ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: processing_class=tokenizer, reward_funcs=[ reward_correct, - reward_greens, - 
reward_yellows, - reward_repetition, + reward_position, + reward_format_strict, ], train_dataset=dataset, args=grpo_config, diff --git a/trl/experimental/openenv/utils.py b/trl/experimental/openenv/utils.py index 5e7053da0c1..5c4c710132b 100644 --- a/trl/experimental/openenv/utils.py +++ b/trl/experimental/openenv/utils.py @@ -26,31 +26,22 @@ from vllm.sampling_params import StructuredOutputsParams -def _build_colocate_sampling_params( +def _build_base_generation_kwargs( trainer, overrides: dict[str, Any] | None = None, - *, - logprobs: bool = True, -) -> "SamplingParams": - if trainer.structured_outputs_regex: - structured_outputs = StructuredOutputsParams(regex=trainer.structured_outputs_regex) - else: - structured_outputs = None - +) -> dict[str, Any]: + """Build base generation kwargs common to both colocate and server modes.""" generation_kwargs: dict[str, Any] = { "n": 1, "temperature": trainer.temperature, "top_k": trainer.top_k, "min_p": 0.0 if trainer.min_p is None else trainer.min_p, "max_tokens": trainer.max_completion_length, - "structured_outputs": structured_outputs, } if trainer.repetition_penalty is not None: generation_kwargs["repetition_penalty"] = trainer.repetition_penalty if trainer.top_p is not None: generation_kwargs["top_p"] = trainer.top_p - if logprobs: - generation_kwargs["logprobs"] = 0 if trainer.args.generation_kwargs is not None: generation_kwargs.update(trainer.args.generation_kwargs) @@ -60,10 +51,38 @@ def _build_colocate_sampling_params( generation_kwargs = {key: value for key, value in generation_kwargs.items() if value is not None} - sampling_params = SamplingParams(**generation_kwargs) - if sampling_params.n != 1: - raise ValueError("generate_rollout_completions expects n=1 when using colocated vLLM.") - return sampling_params + if generation_kwargs.get("n", 1) != 1: + raise ValueError("generate_rollout_completions expects n=1.") + + return generation_kwargs + + +def _build_colocate_sampling_params( + trainer, + overrides: dict[str, Any] | None = None, + *, + logprobs: bool = True, +) -> "SamplingParams": + """Build SamplingParams for colocate mode.""" + generation_kwargs = _build_base_generation_kwargs(trainer, overrides) + + # Add colocate-specific parameters + if trainer.vllm_generation.structured_outputs_regex: + generation_kwargs["structured_outputs"] = StructuredOutputsParams( + regex=trainer.vllm_generation.structured_outputs_regex + ) + if logprobs: + generation_kwargs["logprobs"] = 0 + + return SamplingParams(**generation_kwargs) + + +def _build_server_generation_kwargs( + trainer, + overrides: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Build generation kwargs for server mode.""" + return _build_base_generation_kwargs(trainer, overrides) def generate_rollout_completions( @@ -74,7 +93,7 @@ def generate_rollout_completions( as_chat: bool | None = None, ) -> list[dict[str, Any]]: """ - Generate completions for custom rollouts when vLLM is running in colocate mode. + Generate completions for custom rollouts when vLLM is running in colocate or server mode. Returns one result per prompt, containing prompt and completion token ids along with per-token log probabilities and the generated text. 
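As a side note on the reward shaping changes above: the strict format reward simply measures the fraction of turns whose raw model output is exactly a bracketed five-letter guess. A minimal standalone sketch follows, reusing the same regex as `compute_format_reward`; the sample outputs are invented for the example:

```python
import re

# Same pattern as compute_format_reward in examples/scripts/openenv/wordle.py:
# the whole output must be "[five letters]", with optional surrounding whitespace.
EXACT_PATTERN = re.compile(r"^\s*\[[A-Za-z]{5}\]\s*$")


def toy_format_reward(model_outputs: list[str]) -> float:
    # Proportion of turns whose raw output matches the expected format.
    if not model_outputs:
        return 0.0
    correct = sum(1 for out in model_outputs if EXACT_PATTERN.match(out))
    return correct / len(model_outputs)


# Two well-formed guesses, one malformed -> reward of 2/3.
print(toy_format_reward(["[crane]", " [SLATE] ", "my guess is crane"]))
```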
@@ -83,31 +102,83 @@ def generate_rollout_completions( if not prompts: return [] - if not trainer.use_vllm or trainer.vllm_mode != "colocate": - raise RuntimeError("Custom rollouts require vLLM in colocate mode to call generate_rollout_completions.") + if not trainer.use_vllm: + raise RuntimeError("Custom rollouts require vLLM to call generate_rollout_completions.") + + if trainer.vllm_mode == "server": + return _generate_rollout_completions_server(trainer, prompts, generation_overrides, as_chat) + elif trainer.vllm_mode == "colocate": + return _generate_rollout_completions_colocate(trainer, prompts, generation_overrides, as_chat) + else: + raise ValueError(f"vllm_mode must be 'server' or 'colocate', got '{trainer.vllm_mode}'") + + +def _generate_rollout_completions_server( + trainer, + prompts: list[str], + generation_overrides: dict[str, Any] | None = None, + as_chat: bool | None = None, +) -> list[dict[str, Any]]: + """Generate completions using vLLM server mode.""" + generation_kwargs = _build_server_generation_kwargs(trainer, generation_overrides) + + if as_chat is None: + as_chat = prompts and is_conversational({"prompt": prompts[0]}) + + with profiling_context(trainer, "vLLM.generate_rollout_server"): + if as_chat: + # For chat mode, we need to pass messages format + # Since prompts are already formatted strings, we use generate instead + output = trainer.vllm_generation.vllm_client.generate(prompts=prompts, **generation_kwargs) + else: + output = trainer.vllm_generation.vllm_client.generate(prompts=prompts, **generation_kwargs) + + # Format results to match colocate output format + results: list[dict[str, Any]] = [] + for i in range(len(prompts)): + results.append( + { + "prompt_ids": output["prompt_ids"][i], + "completion_ids": list(output["completion_ids"][i]), + "logprobs": list(output["logprobs"][i]), + "text": trainer.processing_class.decode(output["completion_ids"][i], skip_special_tokens=True), + } + ) + + return results + +def _generate_rollout_completions_colocate( + trainer, + prompts: list[str], + generation_overrides: dict[str, Any] | None = None, + as_chat: bool | None = None, +) -> list[dict[str, Any]]: + """Generate completions using vLLM colocate mode.""" sampling_params = _build_colocate_sampling_params(trainer, generation_overrides) prompts_for_generation = prompts original_size = len(prompts) if trainer.vllm_tensor_parallel_size > 1: gathered_prompts = [None for _ in range(trainer.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_prompts, prompts, group=trainer.tp_group) + torch.distributed.all_gather_object(gathered_prompts, prompts, group=trainer.vllm_generation.tp_group) prompts_for_generation = [prompt for group_prompts in gathered_prompts for prompt in group_prompts] if as_chat is None: as_chat = prompts_for_generation and is_conversational({"prompt": prompts_for_generation[0]}) if trainer.args.vllm_enable_sleep_mode: - trainer.llm.wake_up(tags=["kv_cache"]) + trainer.vllm_generation.llm.wake_up(tags=["kv_cache"]) # Work around for https://github.com/vllm-project/vllm/issues/29341 - trainer.llm.collective_rpc("reload_weights") + trainer.vllm_generation.llm.collective_rpc("reload_weights") with profiling_context(trainer, "vLLM.generate_rollout"): if as_chat: - vllm_outputs = trainer.llm.chat(prompts_for_generation, sampling_params=sampling_params, use_tqdm=False) + vllm_outputs = trainer.vllm_generation.llm.chat( + prompts_for_generation, sampling_params=sampling_params, use_tqdm=False + ) else: - vllm_outputs = trainer.llm.generate( 
+ vllm_outputs = trainer.vllm_generation.llm.generate( prompts_for_generation, sampling_params=sampling_params, use_tqdm=False ) @@ -128,11 +199,11 @@ def generate_rollout_completions( ) if trainer.vllm_tensor_parallel_size > 1: - local_rank_in_group = torch.distributed.get_rank(group=trainer.tp_group) + local_rank_in_group = torch.distributed.get_rank(group=trainer.vllm_generation.tp_group) tp_slice = slice(local_rank_in_group * original_size, (local_rank_in_group + 1) * original_size) results = results[tp_slice] if trainer.args.vllm_enable_sleep_mode: - trainer.llm.sleep(level=2) + trainer.vllm_generation.llm.sleep(level=2) return results diff --git a/trl/extras/profiling.py b/trl/extras/profiling.py index 4a81e47255c..185ac72932c 100644 --- a/trl/extras/profiling.py +++ b/trl/extras/profiling.py @@ -169,6 +169,7 @@ def profiling_decorator(func: Callable) -> Callable: Decorator to profile a function and log execution time using [`extras.profiling.profiling_context`]. This decorator works with methods that have access to a trainer instance (typically as `self`). + For non-Trainer objects that have an `accelerator` attribute, it will use that for logging configuration. Args: func (`Callable`): @@ -195,7 +196,22 @@ def some_method(self): @functools.wraps(func) def wrapper(self, *args, **kwargs): - with profiling_context(self, func.__name__): + # Check if self is a Trainer-like object with required attributes + if hasattr(self, "state") and hasattr(self, "args"): + with profiling_context(self, func.__name__): + return func(self, *args, **kwargs) + # For non-Trainer objects (e.g., VLLMGeneration), use ProfilingContext directly + elif hasattr(self, "accelerator"): + context_name = f"{self.__class__.__name__}.{func.__name__}" + with ProfilingContext( + name=context_name, + report_to=[], # No reporting for non-Trainer objects without args + is_main_process=self.accelerator.is_main_process, + step=None, + ): + return func(self, *args, **kwargs) + else: + # No profiling available, just run the function return func(self, *args, **kwargs) return wrapper diff --git a/trl/generation/vllm_generation.py b/trl/generation/vllm_generation.py index 64f08d1d56f..489e817eb0a 100644 --- a/trl/generation/vllm_generation.py +++ b/trl/generation/vllm_generation.py @@ -540,7 +540,8 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte } with profiler: # TODO: profiling_context(trainer, "vLLM.generate"): if rollout_func is not None: - rollout_prompts = ordered_set_of_prompts + # Pass all prompts (with duplicates) to rollout_func for consistency with colocate mode + rollout_prompts = all_prompts if rollout_prompts and is_conversational({"prompt": rollout_prompts[0]}): rollout_prompts = [ apply_chat_template({"prompt": p}, processing_class, **chat_template_kwargs)["prompt"] @@ -570,8 +571,12 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte broadcast_object_list(obj_list, from_process=0) all_prompt_ids, all_completion_ids, all_logprobs, all_extra_fields = obj_list[0] - # At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times - all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(num_generations)] + # When using rollout_func, it handles its own generation logic and returns one result per prompt. + # When NOT using rollout_func, vllm_client.generate(n=num_generations) returns num_generations + # completions per prompt, so we need to duplicate prompt_ids to match. 
+ if self.rollout_func is None: + # At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times + all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(num_generations)] process_slice = slice( accelerator.process_index * len(prompts), diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index be795a666ac..702138ca42f 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1464,11 +1464,13 @@ def _generate(self, prompts: list): tool_failure_count, ) = self._tool_call_loop(prompts, prompt_ids, completion_ids, completions, logprobs) else: - tool_mask = None + # Support custom env_mask from rollout_func (e.g., for environment feedback masking) + # Internally treated as tool_mask - marks model tokens (1) vs external tokens (0) + tool_mask = extra_fields.pop("env_mask", None) # Get completion length per sequence, used for logging prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) - if tool_mask is not None: # count only non-tool tokens (tool_mask=1) + if tool_mask is not None: # count only model-generated tokens (tool_mask=1) completion_lengths = torch.tensor([sum(mask) for mask in tool_mask], device=device) else: completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device) @@ -1570,9 +1572,11 @@ def _generate_and_score_completions( sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right") else: sampling_per_token_logps = None - if self.tools: + if tool_mask_list is not None: tool_mask = [torch.tensor(mask, device=device) for mask in tool_mask_list] - tool_mask = pad(tool_mask, padding_value=1, padding_side="right") # 0 for tool result tokens, 1 elsewhere + tool_mask = pad(tool_mask, padding_value=1, padding_side="right") + else: + tool_mask = None # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask if self.mask_truncated_completions: @@ -1639,7 +1643,7 @@ def _generate_and_score_completions( # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch if self.use_vllm and self.vllm_importance_sampling_correction: - mask = completion_mask if not self.tools else completion_mask * tool_mask + mask = completion_mask if tool_mask is None else completion_mask * tool_mask per_token_logps_diff = (old_per_token_logps - sampling_per_token_logps) * mask sequence_level_is = self.vllm_importance_sampling_mode in ["sequence_mask", "sequence_truncate"] @@ -1793,7 +1797,7 @@ def _generate_and_score_completions( if self.use_vllm and self.vllm_importance_sampling_correction: delta = torch.abs(old_per_token_logps - sampling_per_token_logps) - mask = completion_mask.bool() if not self.tools else (completion_mask * tool_mask).bool() + mask = completion_mask.bool() if tool_mask is None else (completion_mask * tool_mask).bool() delta = delta[mask] mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) @@ -1856,7 +1860,7 @@ def _generate_and_score_completions( output["token_type_ids"] = forward_kwargs["token_type_ids"] if images is not None: output["num_images"] = num_images - if self.tools: + if tool_mask is not None: output["tool_mask"] = tool_mask return output @@ -1950,7 +1954,7 @@ def _compute_loss(self, model, inputs): input_ids = torch.cat([prompt_ids, completion_ids], dim=1) attention_mask = torch.cat([prompt_mask, 
completion_mask], dim=1) logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens - mask = completion_mask if not self.tools else completion_mask * inputs["tool_mask"] + mask = completion_mask if "tool_mask" not in inputs else completion_mask * inputs["tool_mask"] # Compute the per_token_logps and the entropy at each position in the completion per_token_logps, entropies = self._get_per_token_logps_and_entropies(