diff --git a/docs/content/docs/harbor/index.mdx b/docs/content/docs/harbor/index.mdx
index 9a3eaab4e4..a3e55d524c 100644
--- a/docs/content/docs/harbor/index.mdx
+++ b/docs/content/docs/harbor/index.mdx
@@ -212,7 +212,7 @@ agent:
   override_timeout_sec: 1200 # Time (seconds) given for a single Trial to run
   kwargs:
     max_turns: 32 # Max agent iterations per trial
-    store_all_messages: true # Required for SkyRL to extract training data
+    collect_rollout_details: true # Required for SkyRL to extract training data
     temperature: 1.0 # Sampling temperature (higher = more exploration)
     enable_summarize: false # Context summarization when nearing token limits
     model_info:
@@ -221,7 +221,7 @@ agent:
 ```
 
-`store_all_messages: true` is **required** for training. Without it, SkyRL cannot extract the chat history needed to compute loss masks and train the model.
+`collect_rollout_details: true` is **required** for training. It makes Harbor return the per-turn `prompt_token_ids`, `completion_token_ids`, and `logprobs` that SkyRL needs to build training data.
 
 ### Key Knobs for RL Training
 
diff --git a/examples/train_integrations/harbor/entrypoints/main_harbor_fully_async.py b/examples/train_integrations/harbor/entrypoints/main_harbor_fully_async.py
new file mode 100644
index 0000000000..76150d98ed
--- /dev/null
+++ b/examples/train_integrations/harbor/entrypoints/main_harbor_fully_async.py
@@ -0,0 +1,74 @@
+"""
+Fully-async entrypoint for training on Harbor tasks.
+
+Reuses HarborExp's generator/dataset overrides and swaps in
+``FullyAsyncRayPPOTrainer``. This is the moral equivalent of
+``examples/train/fully_async/main_fully_async.py`` for Harbor.
+"""
+
+import asyncio
+import sys
+
+import ray
+import yaml
+
+from skyrl.train.fully_async_trainer import FullyAsyncRayPPOTrainer
+from skyrl.train.utils import validate_cfg
+from skyrl.train.utils.utils import initialize_ray
+
+from .main_harbor import HARBOR_DEFAULT_CONFIG, HarborExp, HarborSkyRLConfig, _deep_merge
+
+
+class HarborFullyAsyncExp(HarborExp):
+    def get_trainer(
+        self,
+        cfg,
+        tracker,
+        tokenizer,
+        train_dataset,
+        eval_dataset,
+        inference_engine_client,
+        generator,
+        colocate_pg,
+    ):
+        return FullyAsyncRayPPOTrainer(
+            cfg=cfg,
+            tracker=tracker,
+            tokenizer=tokenizer,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            inference_engine_client=inference_engine_client,
+            generator=generator,
+            colocate_pg=colocate_pg,
+        )
+
+    def run(self):
+        trainer = self._setup_trainer()
+        asyncio.run(trainer.train())
+
+
+@ray.remote(num_cpus=1)
+def skyrl_entrypoint(cfg):
+    exp = HarborFullyAsyncExp(cfg)
+    exp.run()
+
+
+def main() -> None:
+    cfg = HarborSkyRLConfig.from_cli_overrides(sys.argv[1:])
+
+    with open(HARBOR_DEFAULT_CONFIG) as f:
+        defaults = yaml.safe_load(f)
+    cfg.harbor_trial_config = _deep_merge(defaults, cfg.harbor_trial_config)
+
+    validate_cfg(cfg)
+    if cfg.trainer.algorithm.max_seq_len is None:
+        raise ValueError(
+            "trainer.algorithm.max_seq_len must be explicitly set for Harbor training; "
+            "it is required to truncate responses to the maximum allowed length."
+ ) + initialize_ray(cfg) + ray.get(skyrl_entrypoint.remote(cfg)) + + +if __name__ == "__main__": + main() diff --git a/examples/train_integrations/harbor/harbor_generator.py b/examples/train_integrations/harbor/harbor_generator.py index 2232f11817..4841846bde 100644 --- a/examples/train_integrations/harbor/harbor_generator.py +++ b/examples/train_integrations/harbor/harbor_generator.py @@ -5,14 +5,15 @@ from loguru import logger from uuid import uuid4 from skyrl.train.generators.base import GeneratorInterface, GeneratorInput, GeneratorOutput, TrajectoryID -from skyrl.train.generators.utils import get_rollout_metrics, get_response_ids_and_loss_mask_from_messages +from skyrl.train.generators.utils import get_rollout_metrics from skyrl.backends.skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient from skyrl.backends.skyrl_train.inference_engines.base import ConversationType from skyrl.train.utils.rate_limiter import create_rate_limiter from tqdm import tqdm -from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig from harbor.trial.trial import Trial from harbor.models.trial.config import TrialConfig +from harbor.models.agent.rollout_detail import RolloutDetail # Suppress LiteLLM verbose logging @@ -28,141 +29,159 @@ MAX_NUM_RETRIES_PER_TRIAL = 2 -class ChatHistoryExtractor: - """Extracts a (chat_history, summarization_count, num_turns) tuple from Harbor trial results. +@dataclass +class HarborTrajectoryOutput: + """One trajectory's raw output from Harbor. - Supports two extraction strategies, tried in order: - 1. all_messages agents (terminus-2, terminus-1, terminus): metadata["all_messages"] - 2. Trajectory-based agents (mini-swe-agent, swe-agent, openhands): - trajectory.json converted to user/assistant messages + Holds the entire ``rollout_details`` from ``agent_result``. Per-step interpretation + (loss-mask / reward broadcast / overlong filtering) is done downstream in + `build_step_wise_generator_output`. """ - # Agents that write trajectory.json (ATIF format) instead of metadata["all_messages"]. - # OpenHands uses condensation (off-policy) - use reject_summarization=false to allow. 
- TRAJECTORY_BASED_AGENTS = frozenset( - {"mini-swe-agent", "swe-agent", "openhands", "openhands-host"}) - - @classmethod - def extract(cls, results) -> Optional[tuple]: - """Return (chat_history, summarization_count, num_turns) or None on failure.""" - agent_result = results.agent_result - if agent_result is None: - return None - - metadata = agent_result.metadata or {} - chat_history = metadata.get("all_messages") - if chat_history is not None: - return chat_history, metadata.get("summarization_count", 0), metadata.get("n_episodes", 0) - - # Fallback: load from trajectory.json or completions for trajectory-based agents - agent_name = (getattr(results.config.agent, - "name", None) or "").lower() - if agent_name not in cls.TRAJECTORY_BASED_AGENTS: - return None - - trial_path = cls._trial_path_from_uri( - getattr(results, "trial_uri", None) or "") - if trial_path is None: - return None - - trajectory_path = trial_path / "agent" / "trajectory.json" - chat_history = cls._from_atif_trajectory(trajectory_path) - if chat_history is None: - return None - - # Trajectory-based agents don't track summarization; use 0 for strictly appending - return chat_history, 0, cls._count_turns(chat_history) - - # ------------------------------------------------------------------ - # Private helpers - # ------------------------------------------------------------------ - - @staticmethod - def _count_turns(messages: List[dict]) -> int: - return sum(1 for m in messages if m["role"] == "assistant") - - @staticmethod - def _trial_path_from_uri(trial_uri: str) -> Optional[Path]: - """Extract local filesystem path from trial_uri (e.g. file:///path/to/trial).""" - if not trial_uri: - return None - try: - parsed = urlparse(trial_uri) - if parsed.scheme == "file" and parsed.path: - return Path(parsed.path) - except Exception: - pass - return None - - @classmethod - def _from_atif_trajectory(cls, trajectory_path: Path) -> Optional[List[dict]]: - """Convert ATIF trajectory JSON to user/assistant chat messages for SkyRL training. - - Handles system steps (prepended to first user message), agent observations - (converted to user messages for alternating user/assistant pattern), and - tool_calls (serialized into assistant content). - """ - if not trajectory_path.exists(): - return None - try: - with open(trajectory_path) as f: - data = json.load(f) - except Exception as e: - logger.warning( - f"Failed to load trajectory from {trajectory_path}: {e}") - return None - - messages: List[dict] = [] - pending_system: List[str] = [] - - for step in data.get("steps", []): - source = step.get("source", "") - message = step.get("message", "") - observation = step.get("observation") - - if source == "system": - if message: - pending_system.append(message) - continue - - if source == "user": - content = message or "" - if pending_system: - content = "\n\n".join(pending_system) + "\n\n" + content - pending_system = [] - messages.append({"role": "user", "content": content}) - - elif source == "agent": - content = message or "" - if step.get("tool_calls"): - content = content + "\n" + \ - json.dumps({"tool_calls": step["tool_calls"]}) - if not content: - continue - messages.append({"role": "assistant", "content": content}) - - # Observations represent environment feedback; emit as user message - # to maintain the alternating user/assistant pattern required for RL. 
- if observation and observation.get("results"): - obs_parts = [r.get("content", "") - for r in observation["results"] if r.get("content")] - if obs_parts: - messages.append( - {"role": "user", "content": "\n".join(obs_parts)}) - - return messages if messages else None - - -@dataclass -class HarborAgentOutput: - response_ids: List[int] - reward: float - stop_reason: str - loss_mask: List[int] - prompt_ids: List[int] trajectory_id: TrajectoryID - summarization_count: Optional[int] = None - num_turns: Optional[int] = None + # Entire rollout_details list as returned by harbor's agent_result. None for failed trajectories + # (agent_timeout / error) that we will mask in `build_step_wise_generator_output`. + rollout_details: Optional[List[RolloutDetail]] = None + reward: float = 0.0 + num_turns: int = 0 + # One of: "complete", "context_length", "agent_timeout", "error". Used by + # `build_step_wise_generator_output` to decide whether to skip the entire prompt group. + stop_reason: str = "complete" + + +def build_step_wise_generator_output( + trajectory_outputs: List[HarborTrajectoryOutput], overlong_filtering: bool +) -> GeneratorOutput: + """Flatten per-trajectory rollout details into one entry per LLM turn. + + Steps for one trajectory are emitted contiguously and the last step has + ``is_last_step=True``. Failures (timeout / unknown error / empty rollout + details) are batched per ``instance_id``: if any rollout for prompt P + failed, all rollouts for P are replaced with single zeroed-out + placeholder steps. + """ + # 1. Identify failed instances. If any rollout for prompt P failed, mask all rollouts for P (conservative). + timeout_instance_ids = set() + error_instance_ids = set() + all_instance_ids = set() + num_timeout_trajectories = 0 + num_error_trajectories = 0 + for traj in trajectory_outputs: + instance_id = traj.trajectory_id.instance_id + all_instance_ids.add(instance_id) + if traj.stop_reason == "agent_timeout": + num_timeout_trajectories += 1 + timeout_instance_ids.add(instance_id) + elif traj.stop_reason == "error" or traj.rollout_details is None: + num_error_trajectories += 1 + error_instance_ids.add(instance_id) + masked_instance_ids = timeout_instance_ids | error_instance_ids + + # 2. Walk trajectories and emit one entry of GeneratorOutput per step. + prompt_token_ids: List[List[int]] = [] + response_ids: List[List[int]] = [] + rewards: List[float] = [] + loss_masks: List[List[int]] = [] + stop_reasons: List[str] = [] + is_last_step_list: List[bool] = [] + out_trajectory_ids: List[TrajectoryID] = [] + rollout_logprobs_list: List[List[float]] = [] + + successful_trajectories: List[HarborTrajectoryOutput] = [] + response_ids_for_metrics: List[List[int]] = [] + rewards_for_metrics: List[float] = [] + for traj in trajectory_outputs: + tid = traj.trajectory_id + + # 2.1. For failed trajectories, set loss mask to [0] and stop reason to "error". + if tid.instance_id in masked_instance_ids: + prompt_token_ids.append([0]) + response_ids.append([0]) + rewards.append(0.0) + loss_masks.append([0]) + stop_reasons.append("error") + is_last_step_list.append(True) + out_trajectory_ids.append(tid) + rollout_logprobs_list.append([0.0]) + continue + + # 2.2. For successful trajectories, emit one entry per step. + successful_trajectories.append(traj) + + # 2.3. Check rollout_details expected format. + # Expect no summarization; rollout_details is a single linear chat segment from the main agent. + # TODO(Charlie): Support summarization. 
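+        # Illustrative shape sketch (an assumption inferred from the accesses below, not
+        # Harbor's documented schema): each RolloutDetail maps a field name to a per-turn list,
+        # indexed by the turn number t of the trajectory, e.g.
+        #   rollout_detail["prompt_token_ids"][t]     -> List[int], turn t's full prompt
+        #   rollout_detail["completion_token_ids"][t] -> List[int], turn t's completion
+        #   rollout_detail["logprobs"][t]             -> List[float], one logprob per completion token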
+        assert len(traj.rollout_details) == 1, f"Expected exactly one rollout segment, got {len(traj.rollout_details)}."
+        rollout_detail = traj.rollout_details[0]
+        prompt_token_ids_per_turn = rollout_detail["prompt_token_ids"]
+        completion_token_ids_per_turn = rollout_detail["completion_token_ids"]
+        logprobs_per_turn = rollout_detail["logprobs"]
+        n_turns = len(completion_token_ids_per_turn)
+        assert len(prompt_token_ids_per_turn) == n_turns and len(logprobs_per_turn) == n_turns, (
+            f"Malformed rollout_details (prompts={len(prompt_token_ids_per_turn)}, completions={n_turns}, "
+            f"logprobs={len(logprobs_per_turn)})."
+        )
+
+        # 2.4. Emit one entry per step, following SkyRL's step-wise convention.
+        for t in range(n_turns):
+            comp_ids = completion_token_ids_per_turn[t]
+            p_ids = prompt_token_ids_per_turn[t]
+            lp = logprobs_per_turn[t]
+            assert len(lp) == len(comp_ids), "logprobs and completion token ids must have the same length."
+
+            # Record the actual reward on the last turn, and zero on all other turns.
+            is_last = t == n_turns - 1
+            reward = traj.reward if is_last else 0.0
+
+            # Loss mask.
+            step_loss_mask = [1] * len(comp_ids)
+            step_stop_reason = "complete"
+            if traj.stop_reason == "context_length":
+                step_stop_reason = "context_length"
+                if overlong_filtering:
+                    step_loss_mask = [0] * len(comp_ids)
+
+            prompt_token_ids.append(p_ids)
+            response_ids.append(comp_ids)
+            rewards.append(reward)
+            loss_masks.append(step_loss_mask)
+            stop_reasons.append(step_stop_reason)
+            is_last_step_list.append(is_last)
+            out_trajectory_ids.append(tid)
+            rollout_logprobs_list.append(lp)
+
+        # 2.5. For trajectory-level metrics, record the last turn's prompt IDs and response IDs,
+        # which together contain the entire trajectory.
+        response_ids_for_metrics.append(prompt_token_ids_per_turn[-1] + completion_token_ids_per_turn[-1])
+        rewards_for_metrics.append(traj.reward)
+
+    # 3. Aggregate trajectory-level metrics for logging.
+    if successful_trajectories:
+        rollout_metrics = get_rollout_metrics(response_ids_for_metrics, rewards_for_metrics)
+        rollout_metrics["generate/trajectories_context_length_exceeded"] = sum(
+            1 for t in successful_trajectories if t.stop_reason == "context_length"
+        )
+        rollout_metrics["generate/avg_num_turns"] = sum(t.num_turns for t in successful_trajectories) / len(
+            successful_trajectories
+        )
+    else:
+        rollout_metrics = {}
+
+    rollout_metrics["generate/num_timeout_trajectories"] = num_timeout_trajectories
+    rollout_metrics["generate/num_error_trajectories"] = num_error_trajectories
+    rollout_metrics["generate/num_masked_instances"] = len(masked_instance_ids)
+
+    return GeneratorOutput(
+        prompt_token_ids=prompt_token_ids,
+        response_ids=response_ids,
+        rewards=rewards,
+        loss_masks=loss_masks,
+        stop_reasons=stop_reasons,
+        rollout_metrics=rollout_metrics,
+        rollout_logprobs=rollout_logprobs_list,
+        trajectory_ids=out_trajectory_ids,
+        is_last_step=is_last_step_list,
+    )
 
 
 class HarborGenerator(GeneratorInterface):
@@ -188,8 +207,16 @@ def __init__(
         self.tokenizer = tokenizer
         self.max_seq_len = max_seq_len
 
-        # Harbor config template - users can specify any Harbor TrialConfig options in YAML or command line.
-        # SkyRL injects: model_name and api_base (once at init), task.path and session_id (per trial)
+        if not getattr(generator_cfg, "step_wise_trajectories", False):
+            raise ValueError(
+                "HarborGenerator only supports step-wise training. " "Set generator.step_wise_trajectories=true."
+ ) + if not getattr(generator_cfg, "merge_stepwise_output", False): + logger.warning( + "merge_stepwise_output=true is not set; will not merge step-wise outputs. This " + "may result in much slower training." + ) + self._harbor_trial_config_template = deepcopy(harbor_cfg) # Set model_name and api_base once (constant across all trials) @@ -202,26 +229,29 @@ def __init__( ] = f"hosted_vllm/{ie_cfg.served_model_name}" self._harbor_trial_config_template["agent"].setdefault("kwargs", {})["api_base"] = f"{self.base_url}/v1" + # Step-wise needs per-turn token IDs and logprobs from vLLM via Harbor. + agent_kwargs = self._harbor_trial_config_template["agent"]["kwargs"] + if not agent_kwargs.get("collect_rollout_details", False): + logger.warning("step_wise_trajectories=true requires collect_rollout_details=true; enabling automatically.") + agent_kwargs["collect_rollout_details"] = True + + # Can support summarization in future. + if agent_kwargs.get("enable_summarize", False): + raise ValueError( + "step_wise_trajectories=true is incompatible with enable_summarize=true. " + "Set harbor_trial_config.agent.kwargs.enable_summarize=false." + ) + logger.info( f"HarborGenerator initialized with Harbor config. " f"Agent: {self._harbor_trial_config_template.get('agent', {}).get('name')}, " f"Trials dir: {self._harbor_trial_config_template.get('trials_dir', 'trials')}" ) - # Read custom chat template - custom_chat_template_path = ie_cfg.engine_init_kwargs.get("chat_template", None) - if custom_chat_template_path: - with open(custom_chat_template_path, "r") as f: - self.custom_chat_template_content = f.read() - logger.info(f"HarborGenerator initialized with custom chat template read from: {custom_chat_template_path}") - else: - self.custom_chat_template_content = None - - # Initialize rate limiter from generator config (not part of Harbor TrialConfig) rate_limit_config = getattr(generator_cfg, "rate_limit", None) self._rate_limiter = create_rate_limiter(rate_limit_config) - async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput: + async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False) -> GeneratorOutput: prompts = input_batch["prompts"] trajectory_ids = input_batch["trajectory_ids"] @@ -229,11 +259,12 @@ async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput: raise ValueError("`trajectory_ids` is required in the input batch") if len(prompts) != len(trajectory_ids): raise ValueError( - f"Prompt count ({len(prompts)}) doesn't match " f"trajectory_ids count ({len(trajectory_ids)})" + f"Prompt count ({len(prompts)}) doesn't match trajectory_ids count ({len(trajectory_ids)})" ) - all_outputs: List[HarborAgentOutput] = [None] * len(prompts) # type: ignore[list-item] + all_outputs: List[HarborTrajectoryOutput] = [None] * len(prompts) # type: ignore[list-item] progress = tqdm( + disable=disable_tqdm, # disable for fully async training total=len(prompts), desc="Generating Trajectories", miniters=max(1, len(prompts) // 10), @@ -241,7 +272,7 @@ async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput: ) async def _worker(idx, prompt, trajectory_id): - result = await self.harbor_agent_loop(prompt=prompt, trajectory_id=trajectory_id) + result = await self._harbor_agent_loop(prompt=prompt, trajectory_id=trajectory_id) all_outputs[idx] = result progress.update(1) @@ -251,110 +282,27 @@ async def _worker(idx, prompt, trajectory_id): tg.create_task(_worker(idx, prompt, trajectory_id)) finally: progress.close() - all_outputs, rollout_metrics = 
self._mask_failed_instances_and_compute_metrics(all_outputs) - - generator_output: GeneratorOutput = { - "prompt_token_ids": [output.prompt_ids for output in all_outputs], - "response_ids": [output.response_ids for output in all_outputs], - "rewards": [output.reward for output in all_outputs], - "loss_masks": [output.loss_mask for output in all_outputs], - "stop_reasons": [output.stop_reason for output in all_outputs], - "rollout_metrics": rollout_metrics, - "rollout_logprobs": None, - } - - return generator_output - - @staticmethod - def _mask_failed_instances_and_compute_metrics( - all_outputs: List[HarborAgentOutput], - ) -> tuple[List[HarborAgentOutput], dict]: - """Mutates all_outputs in-place: zeros out every output belonging to a failed instance. - - For a group of trajectories (n_samples_per_prompt for the same prompt), - if one trajectory fails we skip training the entire group. - - Returns: - all_outputs: The same list, with failed-instance outputs zeroed out. - rollout_metrics: Dict of rollout metrics for logging. - """ - # Count failures by type before grouping overwrites stop_reason. - num_timeout_trajectories = 0 - num_error_trajectories = 0 - timeout_instance_ids = set() - error_instance_ids = set() - all_instance_ids = set() - for output in all_outputs: - cur_instance_id = output.trajectory_id.instance_id - all_instance_ids.add(cur_instance_id) - if output.stop_reason == "agent_timeout": - num_timeout_trajectories += 1 - timeout_instance_ids.add(cur_instance_id) - elif output.stop_reason == "error": - num_error_trajectories += 1 - error_instance_ids.add(cur_instance_id) - - masked_instance_ids = timeout_instance_ids | error_instance_ids - - # Zero-out all outputs belonging to any timeout or error instance so we skip training on them. - successful_outputs: List[HarborAgentOutput] = [] - for output in all_outputs: - if output.trajectory_id.instance_id in masked_instance_ids: - output.response_ids = [0] - output.stop_reason = "error" - output.loss_mask = [0] - output.prompt_ids = [0] - output.reward = 0 - else: - successful_outputs.append(output) - - # Rollout metrics for successful outputs. - if len(successful_outputs) > 0: - rollout_metrics = get_rollout_metrics( - [output.response_ids for output in successful_outputs], - [output.reward for output in successful_outputs], - ) - rollout_metrics["generate/trajectories_summarized"] = sum( - 1 for output in successful_outputs if output.summarization_count > 0 - ) - rollout_metrics["generate/trajectories_context_length_exceeded"] = sum( - 1 for output in successful_outputs if output.stop_reason == "context_length" - ) - rollout_metrics["generate/avg_num_turns"] = sum(output.num_turns for output in successful_outputs) / len( - successful_outputs - ) - else: - rollout_metrics = {} - - # Failure metrics: timeout vs unknown error trajectories, and masked instances. 
- rollout_metrics["generate/num_timeout_trajectories"] = num_timeout_trajectories - rollout_metrics["generate/num_error_trajectories"] = num_error_trajectories - rollout_metrics["generate/num_masked_instances"] = len(masked_instance_ids) - logger.info( - f"\n# of masked instances: {len(masked_instance_ids)} / {len(all_instance_ids)}\n" - f"# of timeout trajectories: {num_timeout_trajectories}\n" - f"# of error trajectories: {num_error_trajectories}" + return build_step_wise_generator_output( + all_outputs, overlong_filtering=self.generator_cfg.apply_overlong_filtering ) - return all_outputs, rollout_metrics - - async def harbor_agent_loop( + async def _harbor_agent_loop( self, prompt: ConversationType, trajectory_id: TrajectoryID, - ) -> HarborAgentOutput: + ) -> HarborTrajectoryOutput: + """Run a single Harbor trial and return the rollout details plus a trajectory-level reward. + Retries on unknown errors; context length errors train with reward=0; agent timeouts mask the trajectory. """ - Run a single harbor agent. - """ - # Run the trial to get `reward`, `chat_history`, `summarization_count`, and `num_turns` reward = None - chat_history = None - summarization_count = None + results = None + rollout_details = None num_turns = None successful = False is_context_length_error = False is_agent_timeout_error = False + for i in range(MAX_NUM_RETRIES_PER_TRIAL): prefix = f"Trajectory {trajectory_id} attempt {i+1}/{MAX_NUM_RETRIES_PER_TRIAL}" results = None @@ -374,101 +322,52 @@ async def harbor_agent_loop( is_context_length_error = exc_type == "ContextLengthExceededError" is_agent_timeout_error = exc_type == "AgentTimeoutError" - # --- Determine reward --- + # Determine reward. if is_agent_timeout_error: # AgentTimeoutError: not successful, no retry, loss-masked logger.debug(f"{prefix} hit AgentTimeoutError (no retry). Results: {results}") break elif is_context_length_error: # ContextLengthExceededError: always train with reward=0. - logger.debug( - f"{prefix} hit ContextLengthExceededError, will train with reward=0. Results: {results}" - ) - reward = 0 + logger.debug(f"{prefix} hit ContextLengthExceededError, setting reward=0. Results: {results}") + reward = 0.0 elif not results.verifier_result: # Does not have a verifier result, so it's not successful, will retry logger.warning(f"{prefix} failed: Exception info: {results.exception_info}. Results: {results}") continue else: - reward = results.verifier_result.rewards["reward"] + reward = float(results.verifier_result.rewards["reward"]) - # --- Extract chat history and check for success --- - chat_history = results.agent_result.metadata["all_messages"] - summarization_count = results.agent_result.metadata["summarization_count"] + # Extract rollout details and check for success + rollout_details = results.agent_result.rollout_details num_turns = results.agent_result.metadata["n_episodes"] - if len(chat_history) > 1 and chat_history[0]["role"] == "user": + + if ( + rollout_details + and len(rollout_details) >= 1 + and len(rollout_details[0].get("completion_token_ids", [])) > 0 + ): successful = True - logger.debug(f"{prefix} successful: reward={reward}. Results: {results}") + logger.debug(f"{prefix} successful: reward={reward}.") break else: - logger.warning( - f"{prefix} failed: Did not return a chat history with a user message. chat_history: {chat_history}\nResults: {results}" - ) + logger.warning(f"{prefix} failed: empty/missing rollout_details. 
Results: {results}") except Exception as e: logger.warning(f"{prefix} failed: Error running trial: {e}. Results: {results}") continue if not successful: - # We make loss mask 0 so it does not contribute to model updates stop_reason = "agent_timeout" if is_agent_timeout_error else "error" error_message = f"Trajectory {trajectory_id} failed (stop_reason={stop_reason}), will set loss mask to [0]." if stop_reason == "error": error_message += f" Results: {results}" logger.warning(error_message) - return HarborAgentOutput( - response_ids=[0], - reward=0, - stop_reason=stop_reason, - loss_mask=[0], - prompt_ids=[0], + return HarborTrajectoryOutput(trajectory_id=trajectory_id, rollout_details=None, stop_reason=stop_reason) + else: + return HarborTrajectoryOutput( trajectory_id=trajectory_id, + rollout_details=rollout_details, + reward=reward, + num_turns=num_turns, + stop_reason="context_length" if is_context_length_error else "complete", ) - - # Use the first message as the prompt. We assume to be no systems messages. - assert chat_history[0]["role"] == "user", "The first message should be a user message" - prompt = [chat_history[0]] - prompt_ids = self.tokenizer.apply_chat_template( - prompt, - add_generation_prompt=False, # the message below will add it themselves - return_dict=False, - tokenize=True, - chat_template=self.custom_chat_template_content, - ) - initial_prompt_length = len(prompt_ids) - - # Process response messages (everything after the first message) - response_messages = chat_history[1:] - assistant_logprobs = getattr(results.agent_result, "output_logprobs", None) - response_ids, loss_mask, rollout_logprobs = get_response_ids_and_loss_mask_from_messages( - response_messages, self.tokenizer, assistant_logprobs, chat_template=self.custom_chat_template_content - ) - - # Determine stop reason - max_response_tokens = max(0, self.max_seq_len - initial_prompt_length) - if is_context_length_error or len(response_ids) > max_response_tokens: - stop_reason = "context_length" - else: - stop_reason = "complete" - - # Apply overlong filtering. - # TODO(Charlie): should this also apply when the end reason is max_turns in Harbor? - # Revisit. We would like to reuse `utils.py`'s implementation for overlong filtering. - if self.generator_cfg.apply_overlong_filtering and stop_reason == "context_length": - loss_mask = [0] * len(loss_mask) - - # Truncate to maximum allowed length. - # NOTE(Charlie): though it shouldn't happen since it'd reach `ContextLengthExceededError` - # from Harbor first. We do it anyway to be safe. 
- response_ids = response_ids[:max_response_tokens] - loss_mask = loss_mask[:max_response_tokens] - - return HarborAgentOutput( - response_ids=response_ids, - reward=reward, - stop_reason=stop_reason, - loss_mask=loss_mask, - prompt_ids=prompt_ids, - trajectory_id=trajectory_id, - summarization_count=summarization_count, - num_turns=num_turns, - ) diff --git a/examples/train_integrations/harbor/harbor_trial_config/default.yaml b/examples/train_integrations/harbor/harbor_trial_config/default.yaml index e38d069294..9d2a920a0c 100644 --- a/examples/train_integrations/harbor/harbor_trial_config/default.yaml +++ b/examples/train_integrations/harbor/harbor_trial_config/default.yaml @@ -44,8 +44,8 @@ agent: # Whether to enable context summarization when approaching token limits enable_summarize: false - # Store all messages in the trial output (required for SkyRL training) - store_all_messages: true + # Collect per-turn rollout details (required for step-wise training) + collect_rollout_details: true # The only sampling param that directly gets passed to Terminus temperature: 1.0 diff --git a/examples/train_integrations/harbor/harbor_trial_config/openhands.yaml b/examples/train_integrations/harbor/harbor_trial_config/openhands.yaml deleted file mode 100644 index c84e6b73e9..0000000000 --- a/examples/train_integrations/harbor/harbor_trial_config/openhands.yaml +++ /dev/null @@ -1,54 +0,0 @@ -# @package harbor_trial_config -# -# OpenHands agent configuration for SkyRL RL training - -reject_summarization: true - -# Harbor TrialConfig fields below -# -------------------------------- - -trials_dir: ~/trials -timeout_multiplier: 1.0 - -agent: - name: openhands - override_timeout_sec: 1800 - - kwargs: - max_turns: 32 - suppress_max_turns_warning: true - enable_plan_mode: false - - # Text-based tool invocation: model generates , etc. in raw text. - # Required for RL training (preserves raw LLM output for proper tokenization). - disable_tool_calls: false - - # Preserve raw LLM responses in trajectory for accurate RL training. - trajectory_config: - raw_content: false - - # Disable reasoning effort to avoid thinking tokens. - #reasoning_effort: null - - temperature: 1.0 - - # Model info for token budgeting. - # NOTE: max_input_tokens should match +generator.engine_init_kwargs.max_model_len - # NOTE: max_output_tokens must be < max_input_tokens to leave room for the prompt, - # otherwise every LLM call triggers ContextWindowExceededError. - model_info: - max_input_tokens: 32768 - max_output_tokens: 4096 - input_cost_per_token: 0.0 - output_cost_per_token: 0.0 - -environment: - type: docker - - # OpenHands needs more resources than terminus-2 (runs its own venv, tools, etc.) 
-  override_cpus: 2
-  override_memory_mb: 4096
-  suppress_override_warnings: true
-
-verifier:
-  disable: false
diff --git a/examples/train_integrations/harbor/run_codecontest.sh b/examples/train_integrations/harbor/run_codecontest.sh
index eac1fc77ce..38ab9db573 100644
--- a/examples/train_integrations/harbor/run_codecontest.sh
+++ b/examples/train_integrations/harbor/run_codecontest.sh
@@ -22,30 +22,41 @@ EVAL_DATA="['$DATA_DIR/OpenThoughts-TB-dev']"
 #-----------------------
 # Directory setup
 #-----------------------
 RUN_NAME="codecontest"
-TRIALS_DIR="$HOME/$RUN_NAME/trials_run"
-CKPTS_DIR="$HOME/$RUN_NAME/ckpts"
-EXPORTS_DIR="$HOME/$RUN_NAME/exports"
-LOG_DIR="/tmp/skyrl-logs/$RUN_NAME"
+STORAGE_ROOT="/mnt/local_storage/$RUN_NAME"
+TRIALS_DIR="$STORAGE_ROOT/trials_run"
+CKPTS_DIR="$STORAGE_ROOT/ckpts"
+EXPORTS_DIR="$STORAGE_ROOT/exports"
+LOG_DIR="$STORAGE_ROOT/logs"
 
 #-----------------------
 # Training setup
 #-----------------------
+N_SAMPLES_PER_PROMPT=8
 MINI_BATCH_SIZE=32
 MAX_MODEL_LEN=32768
-APPLY_OVERLONG_FILTERING=true
 
-# Dr. GRPO parameters
-LOSS_REDUCTION="seq_mean_token_sum_norm"
+# Algorithmic parameters
+LOSS_REDUCTION="token_mean" # with step-wise training, we have to use token_mean to be prefix-merge-invariant
 GRPO_NORM_BY_STD=false
 USE_KL_LOSS=false
+APPLY_OVERLONG_FILTERING=true
 
-# Essentially achieves interleaved thinking and hence on-policy training without step-wise training.
+# Essentially achieves interleaved thinking (does not strip thinking tokens), which lets step-wise
+# training merge more step-wise outputs and hence speeds up training.
+# If you train a different model, switch to a matching chat template and adjust the settings
+# as needed.
 CHAT_TEMPLATE_PATH="$(dirname "$0")/../../../skyrl/train/utils/templates/qwen3_acc_thinking.jinja2"
 
+# TIS corrections
+TIS_TYPE=token
+TIS_IMP_RATIO_CAP=2.0
+
 #----------------
 # Infrastructure setup
 #----------------
-NUM_GPUS=8
+NUM_POLICY_GPUS=8
+NUM_INFERENCE_ENGINES=4
+TP_SIZE=2
 ENABLE_RATE_LIMITING=true # Enable rate/concurrency limiting for trajectory submissions
 TRAJECTORIES_PER_SECOND=5 # Maximum trajectories per second (must be >= 1.0, fractional values like 1.5 are supported). null or omit to disable rate limiting
 MAX_CONCURRENCY=512 # Maximum concurrent trial.run() calls allowed (must be >= 1). 
null or omit to disable concurrency limiting @@ -64,14 +75,16 @@ uv run --isolated --extra fsdp --extra harbor -m examples.train_integrations.har trainer.algorithm.loss_reduction=$LOSS_REDUCTION \ trainer.algorithm.grpo_norm_by_std=$GRPO_NORM_BY_STD \ trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.placement.colocate_all=true \ trainer.strategy=fsdp2 \ trainer.placement.policy_num_nodes=1 \ trainer.placement.ref_num_nodes=1 \ - trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \ - trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \ - generator.inference_engine.num_engines=$NUM_GPUS \ - generator.inference_engine.tensor_parallel_size=1 \ + trainer.placement.policy_num_gpus_per_node=$NUM_POLICY_GPUS \ + trainer.placement.ref_num_gpus_per_node=$NUM_POLICY_GPUS \ + generator.inference_engine.num_engines=$NUM_INFERENCE_ENGINES \ + generator.inference_engine.tensor_parallel_size=$TP_SIZE \ generator.inference_engine.engine_init_kwargs.chat_template=$CHAT_TEMPLATE_PATH \ generator.inference_engine.engine_init_kwargs.max_model_len=$MAX_MODEL_LEN \ generator.inference_engine.engine_init_kwargs.enable_log_requests=false \ @@ -85,11 +98,14 @@ uv run --isolated --extra fsdp --extra harbor -m examples.train_integrations.har trainer.micro_forward_batch_size_per_gpu=1 \ trainer.micro_train_batch_size_per_gpu=1 \ trainer.ckpt_interval=5 \ + trainer.max_ckpts_to_keep=5 \ trainer.hf_save_interval=5 \ trainer.algorithm.max_seq_len=$MAX_MODEL_LEN \ trainer.policy.optimizer_config.lr=1.0e-6 \ - generator.n_samples_per_prompt=8 \ - generator.eval_n_samples_per_prompt=4 \ + generator.step_wise_trajectories=true \ + generator.merge_stepwise_output=true \ + generator.n_samples_per_prompt=$N_SAMPLES_PER_PROMPT \ + generator.eval_n_samples_per_prompt=2 \ generator.apply_overlong_filtering=$APPLY_OVERLONG_FILTERING \ generator.inference_engine.gpu_memory_utilization=0.8 \ trainer.logger=wandb \ diff --git a/examples/train_integrations/harbor/run_codecontest_fully_async.sh b/examples/train_integrations/harbor/run_codecontest_fully_async.sh new file mode 100644 index 0000000000..1c7a564613 --- /dev/null +++ b/examples/train_integrations/harbor/run_codecontest_fully_async.sh @@ -0,0 +1,139 @@ +set -ex + +# wandb api key. +# export WANDB_API_KEY=YOUR_KEY_HERE + +# Pick the sandbox provider and provide the credentials. 
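+# (Only one provider is needed: set credentials for whichever sandbox your Harbor
+# trial config uses; the Daytona and Modal variables below are examples.)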
+# export DAYTONA_API_KEY=YOUR_KEY_HERE
+# export MODAL_TOKEN_ID=YOUR_KEY_HERE
+# export MODAL_TOKEN_SECRET=YOUR_KEY_HERE
+
+#-----------------------
+# Dataset setup
+#-----------------------
+# Prepare datasets first (downloads from HuggingFace and extracts tasks):
+# uv run examples/train_integrations/harbor/prepare_harbor_dataset.py --dataset open-thoughts/CodeContests
+# uv run examples/train_integrations/harbor/prepare_harbor_dataset.py --dataset open-thoughts/OpenThoughts-TB-dev
+DATA_DIR="$HOME/data/harbor"
+TRAIN_DATA="['$DATA_DIR/CodeContests']"
+EVAL_DATA="['$DATA_DIR/OpenThoughts-TB-dev']"
+
+#-----------------------
+# Directory setup
+#-----------------------
+RUN_NAME="codecontest-fullyasync"
+STORAGE_ROOT="/mnt/local_storage/$RUN_NAME"
+TRIALS_DIR="$STORAGE_ROOT/trials_run"
+CKPTS_DIR="$STORAGE_ROOT/ckpts"
+EXPORTS_DIR="$STORAGE_ROOT/exports"
+LOG_DIR="$STORAGE_ROOT/logs"
+
+#-----------------------
+# Training setup
+#-----------------------
+N_SAMPLES_PER_PROMPT=8
+MINI_BATCH_SIZE=16
+MAX_MODEL_LEN=32768
+
+# Algorithmic parameters
+LOSS_REDUCTION="token_mean" # with step-wise training, we have to use token_mean to be prefix-merge-invariant
+GRPO_NORM_BY_STD=false
+USE_KL_LOSS=false
+APPLY_OVERLONG_FILTERING=true
+
+# Essentially achieves interleaved thinking (does not strip thinking tokens), which lets step-wise
+# training merge more step-wise outputs and hence speeds up training.
+# If you train a different model, switch to a matching chat template and adjust the settings
+# as needed.
+CHAT_TEMPLATE_PATH="$(dirname "$0")/../../../skyrl/train/utils/templates/qwen3_acc_thinking.jinja2"
+
+# TIS corrections
+TIS_TYPE=token
+TIS_IMP_RATIO_CAP=2.0
+
+# -------------------------
+# Fully-async knobs.
+# All knobs are tuned for one 8xH100 node running Qwen3-8B; adjust for your setup.
+# Constraint: mini_batch_size <= num_parallel_generation_workers <= mini_batch_size * (max_staleness_steps + 1)
+# You can increase num_parallel_generation_workers based on your hardware resources (e.g. KV cache size).
+# -------------------------
+MAX_STALENESS_STEPS=4
+NUM_PARALLEL_GENERATION_WORKERS=$(( MINI_BATCH_SIZE * 2 ))
+
+#----------------
+# Infrastructure setup.
+# All knobs are tuned for one 8xH100 node running Qwen3-8B; adjust for your setup.
+#----------------
+NUM_INFERENCE_ENGINES=2
+TP_SIZE=2
+NUM_POLICY_GPUS=4
+ENABLE_RATE_LIMITING=true # Enable rate/concurrency limiting for trajectory submissions
+TRAJECTORIES_PER_SECOND=5 # Maximum trajectories per second (must be >= 1.0, fractional values like 1.5 are supported). null or omit to disable rate limiting
+MAX_CONCURRENCY=128 # Maximum concurrent trial.run() calls allowed (must be >= 1). 
null or omit to disable concurrency limiting + +# Run SkyRL command +uv run --isolated --extra fsdp --extra harbor -m examples.train_integrations.harbor.entrypoints.main_harbor_fully_async \ + data.train_data=$TRAIN_DATA \ + data.val_data=$EVAL_DATA \ + trainer.policy.model.path=Qwen/Qwen3-8B \ + generator.inference_engine.served_model_name=Qwen3-8B \ + harbor_trial_config.trials_dir=$TRIALS_DIR \ + trainer.export_path=$EXPORTS_DIR \ + trainer.ckpt_path=$CKPTS_DIR \ + trainer.log_path=$LOG_DIR \ + trainer.fully_async.max_staleness_steps=$MAX_STALENESS_STEPS \ + trainer.fully_async.num_parallel_generation_workers=$NUM_PARALLEL_GENERATION_WORKERS \ + trainer.algorithm.advantage_estimator=grpo \ + trainer.algorithm.loss_reduction=$LOSS_REDUCTION \ + trainer.algorithm.grpo_norm_by_std=$GRPO_NORM_BY_STD \ + trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ + trainer.placement.colocate_all=false \ + trainer.strategy=fsdp2 \ + trainer.placement.policy_num_nodes=1 \ + trainer.placement.ref_num_nodes=1 \ + trainer.placement.policy_num_gpus_per_node=$NUM_POLICY_GPUS \ + trainer.placement.ref_num_gpus_per_node=$NUM_POLICY_GPUS \ + generator.inference_engine.num_engines=$NUM_INFERENCE_ENGINES \ + generator.inference_engine.tensor_parallel_size=$TP_SIZE \ + generator.inference_engine.engine_init_kwargs.chat_template=$CHAT_TEMPLATE_PATH \ + generator.inference_engine.engine_init_kwargs.max_model_len=$MAX_MODEL_LEN \ + generator.inference_engine.engine_init_kwargs.enable_log_requests=false \ + trainer.epochs=3 \ + trainer.eval_batch_size=128 \ + trainer.eval_before_train=false \ + trainer.eval_interval=100 \ + trainer.update_epochs_per_batch=1 \ + trainer.train_batch_size=$MINI_BATCH_SIZE \ + trainer.policy_mini_batch_size=$MINI_BATCH_SIZE \ + trainer.micro_forward_batch_size_per_gpu=1 \ + trainer.micro_train_batch_size_per_gpu=1 \ + trainer.ckpt_interval=5 \ + trainer.max_ckpts_to_keep=5 \ + trainer.hf_save_interval=5 \ + trainer.algorithm.max_seq_len=$MAX_MODEL_LEN \ + trainer.policy.optimizer_config.lr=1.0e-6 \ + generator.step_wise_trajectories=true \ + generator.merge_stepwise_output=true \ + generator.n_samples_per_prompt=$N_SAMPLES_PER_PROMPT \ + generator.eval_n_samples_per_prompt=2 \ + generator.apply_overlong_filtering=$APPLY_OVERLONG_FILTERING \ + generator.inference_engine.gpu_memory_utilization=0.9 \ + trainer.logger=wandb \ + trainer.project_name=harbor \ + trainer.run_name=$RUN_NAME \ + trainer.resume_mode=latest \ + generator.inference_engine.backend=vllm \ + generator.inference_engine.run_engines_locally=true \ + generator.inference_engine.weight_sync_backend=nccl \ + generator.inference_engine.async_engine=true \ + generator.batched=false \ + generator.inference_engine.enforce_eager=false \ + generator.inference_engine.enable_http_endpoint=true \ + generator.inference_engine.http_endpoint_host=127.0.0.1 \ + generator.inference_engine.http_endpoint_port=8000 \ + generator.rate_limit.enabled=$ENABLE_RATE_LIMITING \ + generator.rate_limit.trajectories_per_second=$TRAJECTORIES_PER_SECOND \ + generator.rate_limit.max_concurrency=$MAX_CONCURRENCY \ + "$@" diff --git a/examples/train_integrations/harbor/run_codecontest_openhands.sh b/examples/train_integrations/harbor/run_codecontest_openhands.sh deleted file mode 100644 index bed0f51d27..0000000000 --- a/examples/train_integrations/harbor/run_codecontest_openhands.sh +++ /dev/null @@ -1,135 
+0,0 @@ -set -ex - -# wandb api key. -# export WANDB_API_KEY=YOUR_KEY_HERE - -# Pick the sandbox provider and provide the credentials. -# export DAYTONA_API_KEY=YOUR_KEY_HERE -# export MODAL_TOKEN_ID=YOUR_KEY_HERE -# export MODAL_TOKEN_SECRET=YOUR_KEY_HERE - -# ---- OpenHands-specific env vars ---- -# Disable condensation to ensure strictly-appending chat history for RL. -# The Harbor OpenHands agent forwards OPENHANDS_* env vars (stripping prefix). -export OPENHANDS_ENABLE_DEFAULT_CONDENSER=false -# Disable history truncation to prevent infinite condensation loops when context -# is exceeded. With this off, ContextWindowExceededError is raised cleanly instead -# of looping through condenser requests that can never reduce essential events. -export OPENHANDS_AGENT_ENABLE_HISTORY_TRUNCATION=false - -#----------------------- -# vLLM endpoint for Docker containers -#----------------------- -# OpenHands runs inside Docker containers (not on the host). The containers reach -# the host's vLLM server via the Docker bridge gateway (172.17.0.1 on Linux). -# Override VLLM_API_BASE if your Docker bridge uses a different gateway IP. -VLLM_PORT=8000 -VLLM_API_BASE="${VLLM_API_BASE:-http://172.17.0.1:${VLLM_PORT}/v1}" -echo "vLLM API base for Docker containers: $VLLM_API_BASE" - -#----------------------- -# Dataset setup -#----------------------- -# Prepare datasets first (downloads from HuggingFace and extracts tasks): -# uv run examples/train_integrations/harbor/prepare_harbor_dataset.py --dataset open-thoughts/CodeContests -# uv run examples/train_integrations/harbor/prepare_harbor_dataset.py --dataset open-thoughts/OpenThoughts-TB-dev -DATA_DIR="$HOME/data/harbor" -TRAIN_DATA="['$DATA_DIR/CodeContests']" -EVAL_DATA="['$DATA_DIR/OpenThoughts-TB-dev']" - -#----------------------- -# Directory setup -#----------------------- -RUN_NAME="codecontest-openhands" -TRIALS_DIR="$HOME/$RUN_NAME/trials_run" -CKPTS_DIR="$HOME/$RUN_NAME/ckpts" -EXPORTS_DIR="$HOME/$RUN_NAME/exports" -# Logs (trainer + tee) go under my_logs/ in the repo root when run from SkyRL-main. -LOG_DIR="my_logs/$RUN_NAME" -mkdir -p "$LOG_DIR" -# To save the full run log when you interrupt: ... 2>&1 | stdbuf -oL tee "$LOG_DIR/training.log" - -#----------------------- -# Training setup -#----------------------- -MINI_BATCH_SIZE=2 -MAX_MODEL_LEN=16384 -APPLY_OVERLONG_FILTERING=true - -# Dr. GRPO parameters -LOSS_REDUCTION="seq_mean_token_sum_norm" -GRPO_NORM_BY_STD=false -USE_KL_LOSS=false - -CHAT_TEMPLATE_PATH="$(dirname "$0")/../../../skyrl/train/utils/templates/qwen3_acc_thinking.jinja2" - -#---------------- -# Infrastructure setup -#---------------- -NUM_GPUS=1 -ENABLE_RATE_LIMITING=true -# OpenHands trials are heavier than terminus-2 but Docker runs locally. 
-TRAJECTORIES_PER_SECOND=2 -MAX_CONCURRENCY=4 - -# Run SkyRL command with OpenHands agent -uv run --isolated --extra fsdp --extra harbor -m examples.train_integrations.harbor.entrypoints.main_harbor \ - data.train_data=$TRAIN_DATA \ - data.val_data=$EVAL_DATA \ - trainer.policy.model.path=Qwen/Qwen3-1.7B \ - generator.served_model_name=Qwen3-1.7B \ - hydra.searchpath=['file://examples/train_integrations/harbor'] \ - +harbor_trial_config=openhands \ - ++harbor_trial_config.trials_dir=$TRIALS_DIR \ - trainer.export_path=$EXPORTS_DIR \ - trainer.ckpt_path=$CKPTS_DIR \ - trainer.log_path=$LOG_DIR \ - trainer.algorithm.advantage_estimator=grpo \ - trainer.algorithm.loss_reduction=$LOSS_REDUCTION \ - trainer.algorithm.grpo_norm_by_std=$GRPO_NORM_BY_STD \ - trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ - trainer.placement.colocate_all=true \ - trainer.strategy=fsdp2 \ - trainer.placement.policy_num_nodes=1 \ - trainer.placement.ref_num_nodes=1 \ - trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \ - trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \ - generator.num_inference_engines=$NUM_GPUS \ - generator.inference_engine_tensor_parallel_size=1 \ - +generator.engine_init_kwargs.chat_template=$CHAT_TEMPLATE_PATH \ - +generator.engine_init_kwargs.max_model_len=$MAX_MODEL_LEN \ - +generator.engine_init_kwargs.enable_log_requests=false \ - trainer.epochs=1 \ - trainer.eval_batch_size=128 \ - trainer.eval_before_train=false \ - trainer.eval_interval=20 \ - trainer.update_epochs_per_batch=1 \ - trainer.train_batch_size=$MINI_BATCH_SIZE \ - trainer.policy_mini_batch_size=$MINI_BATCH_SIZE \ - trainer.micro_forward_batch_size_per_gpu=1 \ - trainer.micro_train_batch_size_per_gpu=1 \ - trainer.ckpt_interval=5 \ - trainer.hf_save_interval=5 \ - trainer.algorithm.max_seq_len=$MAX_MODEL_LEN \ - trainer.policy.optimizer_config.lr=1.0e-6 \ - generator.n_samples_per_prompt=8 \ - generator.eval_n_samples_per_prompt=4 \ - generator.apply_overlong_filtering=$APPLY_OVERLONG_FILTERING \ - generator.gpu_memory_utilization=0.5 \ - trainer.logger=wandb \ - trainer.project_name=harbor \ - trainer.run_name=$RUN_NAME \ - trainer.resume_mode=latest \ - generator.backend=vllm \ - generator.run_engines_locally=true \ - generator.weight_sync_backend=nccl \ - generator.async_engine=true \ - generator.batched=false \ - generator.enforce_eager=false \ - generator.enable_http_endpoint=true \ - generator.http_endpoint_host=0.0.0.0 \ - generator.http_endpoint_port=8000 \ - ++harbor_trial_config.agent.kwargs.api_base="${VLLM_API_BASE}" \ - +generator.rate_limit.enabled=$ENABLE_RATE_LIMITING \ - +generator.rate_limit.trajectories_per_second=$TRAJECTORIES_PER_SECOND \ - +generator.rate_limit.max_concurrency=$MAX_CONCURRENCY diff --git a/skyrl/train/fully_async_trainer.py b/skyrl/train/fully_async_trainer.py index 63284da242..a720565290 100644 --- a/skyrl/train/fully_async_trainer.py +++ b/skyrl/train/fully_async_trainer.py @@ -649,6 +649,7 @@ def convert_generation_group_mini_batch_to_training_input( ) assert generator_output["rollout_metrics"] is not None, "Rollout metrics should be non-null." 
self.all_metrics.update(generator_output["rollout_metrics"]) + generator_output.pop("rollout_metrics", None) # Log staleness statistics for this step self.all_metrics.update( diff --git a/skyrl/train/generators/utils.py b/skyrl/train/generators/utils.py index ea908cbe9a..9bd9f4bf8b 100644 --- a/skyrl/train/generators/utils.py +++ b/skyrl/train/generators/utils.py @@ -274,6 +274,25 @@ def concatenate_generator_outputs(generator_outputs: List[GeneratorOutput], step # Re-aggregate rollout metrics rollout_metrics = get_rollout_metrics(result["response_ids"], result["rewards"]) + + # Preserve generator-specific metrics from per-group rollout_metrics. get_rollout_metrics only + # computes basic stats (response length, reward); generators may add custom keys, which we + # aggregate by inferring from the key name. TODO(Charlie): hacky, to be removed soon. + extra_keys: dict = {} + for go in generator_outputs: + per_group = go.get("rollout_metrics") or {} + for k, v in per_group.items(): + if k not in rollout_metrics and isinstance(v, (int, float)): + extra_keys.setdefault(k, []).append(v) + for k, values in extra_keys.items(): + if "avg" in k or "mean" in k: + rollout_metrics[k] = sum(values) / len(values) + elif "min" in k: + rollout_metrics[k] = min(values) + elif "max" in k: + rollout_metrics[k] = max(values) + else: + rollout_metrics[k] = sum(values) result["rollout_metrics"] = rollout_metrics # Validate the generator output using the number of prompts @@ -607,7 +626,8 @@ def _slice_generator_output(generator_output: GeneratorOutput, indices: List[int sliced: GeneratorOutput = {} for key, value in generator_output.items(): if key == "rollout_metrics": - sliced[key] = value + # Skip since metrics are already recorded before calling `merge_stepwise_output()`. + continue elif value is None: sliced[key] = None else: @@ -716,7 +736,6 @@ def flush(): "rewards": out_rewards, "loss_masks": out_loss_masks, "stop_reasons": out_stop_reasons, - "rollout_metrics": gen_out.get("rollout_metrics", None), "rollout_logprobs": out_logprobs, "trajectory_ids": out_trajectory_ids, "rollout_expert_indices": None, @@ -740,6 +759,9 @@ def merge_stepwise_output(generator_output: GeneratorOutput) -> GeneratorOutput: When the prefix condition fails between two consecutive turns, the current merge group is flushed and a new group starts (greedy merging). + The returned GeneratorOutput's rollout_metrics should be ignored. We already recorded it before + calling this function. + Args: generator_output: Step-wise GeneratorOutput with trajectory_ids and is_last_step. @@ -765,5 +787,5 @@ def merge_stepwise_output(generator_output: GeneratorOutput) -> GeneratorOutput: start = i + 1 merged_slices = [_merge_single_trajectory(s) for s in trajectory_slices] - # concatenate_generator_outputs re-aggregates rollout_metrics and validates + return concatenate_generator_outputs(merged_slices, step_wise=True) diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index af1b65ed11..727e078cce 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -743,6 +743,7 @@ async def generate( # add rollout metrics to self.all_metrics if generator_output["rollout_metrics"] is not None: self.all_metrics.update(generator_output["rollout_metrics"]) + generator_output.pop("rollout_metrics", None) validate_generator_output( len(input_batch["prompts"]),