support trajectory-based agents and OpenHands (#1184) #1548
@@ -28,6 +28,131 @@
MAX_NUM_RETRIES_PER_TRIAL = 2


class ChatHistoryExtractor:
    """Extracts a (chat_history, summarization_count, num_turns) tuple from Harbor trial results.

    Supports two extraction strategies, tried in order:
    1. all_messages agents (terminus-2, terminus-1, terminus): metadata["all_messages"]
    2. Trajectory-based agents (mini-swe-agent, swe-agent, openhands):
       trajectory.json converted to user/assistant messages
    """

    # Agents that write trajectory.json (ATIF format) instead of metadata["all_messages"].
    # OpenHands uses condensation (off-policy) - use reject_summarization=false to allow.
    TRAJECTORY_BASED_AGENTS = frozenset(
        {"mini-swe-agent", "swe-agent", "openhands", "openhands-host"})

    @classmethod
    def extract(cls, results) -> Optional[tuple]:
| """Return (chat_history, summarization_count, num_turns) or None on failure.""" | ||
| agent_result = results.agent_result | ||
| if agent_result is None: | ||
| return None | ||
|
|
||
| metadata = agent_result.metadata or {} | ||
| chat_history = metadata.get("all_messages") | ||
| if chat_history is not None: | ||
| return chat_history, metadata.get("summarization_count", 0), metadata.get("n_episodes", 0) | ||
|
|
||
| # Fallback: load from trajectory.json or completions for trajectory-based agents | ||
| agent_name = (getattr(results.config.agent, | ||
| "name", None) or "").lower() | ||
| if agent_name not in cls.TRAJECTORY_BASED_AGENTS: | ||
| return None | ||
|
|
||
| trial_path = cls._trial_path_from_uri( | ||
| getattr(results, "trial_uri", None) or "") | ||
| if trial_path is None: | ||
| return None | ||
|
|
||
| trajectory_path = trial_path / "agent" / "trajectory.json" | ||
| chat_history = cls._from_atif_trajectory(trajectory_path) | ||
| if chat_history is None: | ||
| return None | ||
|
|
||
| # Trajectory-based agents don't track summarization; use 0 for strictly appending | ||
| return chat_history, 0, cls._count_turns(chat_history) | ||
|
|
||
| # ------------------------------------------------------------------ | ||
| # Private helpers | ||
| # ------------------------------------------------------------------ | ||
|
|
||
| @staticmethod | ||
| def _count_turns(messages: List[dict]) -> int: | ||
| return sum(1 for m in messages if m["role"] == "assistant") | ||
|
|
||
| @staticmethod | ||
| def _trial_path_from_uri(trial_uri: str) -> Optional[Path]: | ||
| """Extract local filesystem path from trial_uri (e.g. file:///path/to/trial).""" | ||
| if not trial_uri: | ||
| return None | ||
| try: | ||
| parsed = urlparse(trial_uri) | ||
| if parsed.scheme == "file" and parsed.path: | ||
| return Path(parsed.path) | ||
| except Exception: | ||
| pass | ||
| return None | ||
|
|
||
| @classmethod | ||
| def _from_atif_trajectory(cls, trajectory_path: Path) -> Optional[List[dict]]: | ||
| """Convert ATIF trajectory JSON to user/assistant chat messages for SkyRL training. | ||
|
|
||
| Handles system steps (prepended to first user message), agent observations | ||
| (converted to user messages for alternating user/assistant pattern), and | ||
| tool_calls (serialized into assistant content). | ||
| """ | ||
| if not trajectory_path.exists(): | ||
| return None | ||
| try: | ||
| with open(trajectory_path) as f: | ||
| data = json.load(f) | ||
| except Exception as e: | ||
| logger.warning( | ||
| f"Failed to load trajectory from {trajectory_path}: {e}") | ||
| return None | ||
|
|
||
| messages: List[dict] = [] | ||
| pending_system: List[str] = [] | ||
|
|
||
| for step in data.get("steps", []): | ||
| source = step.get("source", "") | ||
| message = step.get("message", "") | ||
| observation = step.get("observation") | ||
|
|
||
| if source == "system": | ||
| if message: | ||
| pending_system.append(message) | ||
| continue | ||
|
|
||
| if source == "user": | ||
| content = message or "" | ||
| if pending_system: | ||
| content = "\n\n".join(pending_system) + "\n\n" + content | ||
| pending_system = [] | ||
| messages.append({"role": "user", "content": content}) | ||
|
|
||
| elif source == "agent": | ||
| content = message or "" | ||
| if step.get("tool_calls"): | ||
| content = content + "\n" + \ | ||
| json.dumps({"tool_calls": step["tool_calls"]}) | ||
| if not content: | ||
| continue | ||
| messages.append({"role": "assistant", "content": content}) | ||
|
|
||
| # Observations represent environment feedback; emit as user message | ||
| # to maintain the alternating user/assistant pattern required for RL. | ||
| if observation and observation.get("results"): | ||
| obs_parts = [r.get("content", "") | ||
| for r in observation["results"] if r.get("content")] | ||
| if obs_parts: | ||
| messages.append( | ||
| {"role": "user", "content": "\n".join(obs_parts)}) | ||
|
|
||
| return messages if messages else None | ||
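
To make the conversion concrete, here is a minimal sketch with hypothetical data (the step fields mirror what the converter above reads; it is not an official ATIF sample from the PR):

```python
# Hypothetical ATIF-style trajectory.json content (illustrative only).
example = {
    "steps": [
        {"source": "system", "message": "You are a coding agent."},
        {"source": "user", "message": "Fix the failing test."},
        {
            "source": "agent",
            "message": "Running the tests first.",
            "tool_calls": [{"name": "execute_bash", "arguments": {"command": "pytest -x"}}],
            "observation": {"results": [{"content": "1 failed, 3 passed"}]},
        },
    ]
}
# _from_atif_trajectory would yield three messages:
#   user:      "You are a coding agent.\n\nFix the failing test."
#   assistant: "Running the tests first.\n" + json.dumps({"tool_calls": [...]})
#   user:      "1 failed, 3 passed"
```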
Comment on lines +31 to +153 (Contributor):

🔴 ChatHistoryExtractor is defined but never called, making OpenHands integration non-functional. The PR adds …

@dataclass
class HarborAgentOutput:
    response_ids: List[int]
@@ -0,0 +1,54 @@
# @package harbor_trial_config
#
# OpenHands agent configuration for SkyRL RL training

reject_summarization: true

# Harbor TrialConfig fields below
# --------------------------------

trials_dir: ~/trials
timeout_multiplier: 1.0

agent:
  name: openhands
  override_timeout_sec: 1800

  kwargs:
    max_turns: 32
    suppress_max_turns_warning: true
    enable_plan_mode: false

    # Text-based tool invocation: model generates <execute_bash>, etc. in raw text.
    # Required for RL training (preserves raw LLM output for proper tokenization).
    disable_tool_calls: false

    # Preserve raw LLM responses in trajectory for accurate RL training.
    trajectory_config:
      raw_content: false

    # Disable reasoning effort to avoid thinking tokens.
    #reasoning_effort: null

    temperature: 1.0

    # Model info for token budgeting.
    # NOTE: max_input_tokens should match +generator.engine_init_kwargs.max_model_len
    # NOTE: max_output_tokens must be < max_input_tokens to leave room for the prompt,
    # otherwise every LLM call triggers ContextWindowExceededError.
    model_info:
      max_input_tokens: 32768
      max_output_tokens: 4096
      input_cost_per_token: 0.0
      output_cost_per_token: 0.0

environment:
  type: docker

  # OpenHands needs more resources than terminus-2 (runs its own venv, tools, etc.)
  override_cpus: 2
  override_memory_mb: 4096
  suppress_override_warnings: true

verifier:
  disable: false
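
A minimal sketch of the token-budget relationship described in the model_info NOTEs above (placeholder values, not a recommendation from the PR):

```python
# The engine context window bounds the prompt; the agent's output budget must
# leave room for that prompt, or every call raises ContextWindowExceededError.
max_model_len = 32768                # +generator.engine_init_kwargs.max_model_len
max_input_tokens = max_model_len     # model_info.max_input_tokens should match it
max_output_tokens = 4096             # model_info.max_output_tokens

assert max_output_tokens < max_input_tokens, "no room left for the prompt"
```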
@@ -0,0 +1,135 @@
set -ex

# wandb api key.
# export WANDB_API_KEY=YOUR_KEY_HERE

# Pick the sandbox provider and provide the credentials.
# export DAYTONA_API_KEY=YOUR_KEY_HERE
# export MODAL_TOKEN_ID=YOUR_KEY_HERE
# export MODAL_TOKEN_SECRET=YOUR_KEY_HERE

# ---- OpenHands-specific env vars ----
# Disable condensation to ensure strictly-appending chat history for RL.
# The Harbor OpenHands agent forwards OPENHANDS_* env vars (stripping prefix).
export OPENHANDS_ENABLE_DEFAULT_CONDENSER=false
# Disable history truncation to prevent infinite condensation loops when context
# is exceeded. With this off, ContextWindowExceededError is raised cleanly instead
# of looping through condenser requests that can never reduce essential events.
export OPENHANDS_AGENT_ENABLE_HISTORY_TRUNCATION=false

#-----------------------
# vLLM endpoint for Docker containers
#-----------------------
# OpenHands runs inside Docker containers (not on the host). The containers reach
# the host's vLLM server via the Docker bridge gateway (172.17.0.1 on Linux).
# Override VLLM_API_BASE if your Docker bridge uses a different gateway IP.
VLLM_PORT=8000
VLLM_API_BASE="${VLLM_API_BASE:-http://172.17.0.1:${VLLM_PORT}/v1}"
echo "vLLM API base for Docker containers: $VLLM_API_BASE"

#-----------------------
# Dataset setup
#-----------------------
# Prepare datasets first (downloads from HuggingFace and extracts tasks):
# uv run examples/train_integrations/harbor/prepare_harbor_dataset.py --dataset open-thoughts/CodeContests
# uv run examples/train_integrations/harbor/prepare_harbor_dataset.py --dataset open-thoughts/OpenThoughts-TB-dev
DATA_DIR="$HOME/data/harbor"
TRAIN_DATA="['$DATA_DIR/CodeContests']"
EVAL_DATA="['$DATA_DIR/OpenThoughts-TB-dev']"

#-----------------------
# Directory setup
#-----------------------
RUN_NAME="codecontest-openhands"
TRIALS_DIR="$HOME/$RUN_NAME/trials_run"
CKPTS_DIR="$HOME/$RUN_NAME/ckpts"
EXPORTS_DIR="$HOME/$RUN_NAME/exports"
# Logs (trainer + tee) go under my_logs/ in the repo root when run from SkyRL-main.
LOG_DIR="my_logs/$RUN_NAME"
mkdir -p "$LOG_DIR"
# To save the full run log when you interrupt: ... 2>&1 | stdbuf -oL tee "$LOG_DIR/training.log"

#-----------------------
# Training setup
#-----------------------
MINI_BATCH_SIZE=2
MAX_MODEL_LEN=16384
APPLY_OVERLONG_FILTERING=true

# Dr. GRPO parameters
LOSS_REDUCTION="seq_mean_token_sum_norm"
GRPO_NORM_BY_STD=false
USE_KL_LOSS=false

CHAT_TEMPLATE_PATH="$(dirname "$0")/../../../skyrl/train/utils/templates/qwen3_acc_thinking.jinja2"

#----------------
# Infrastructure setup
#----------------
NUM_GPUS=1
ENABLE_RATE_LIMITING=true
# OpenHands trials are heavier than terminus-2 but Docker runs locally.
TRAJECTORIES_PER_SECOND=2
MAX_CONCURRENCY=4

# Run SkyRL command with OpenHands agent
uv run --isolated --extra fsdp --extra harbor -m examples.train_integrations.harbor.entrypoints.main_harbor \
    data.train_data=$TRAIN_DATA \
    data.val_data=$EVAL_DATA \
    trainer.policy.model.path=Qwen/Qwen3-1.7B \
    generator.served_model_name=Qwen3-1.7B \
    hydra.searchpath=['file://examples/train_integrations/harbor'] \
    +harbor_trial_config=openhands \
    ++harbor_trial_config.trials_dir=$TRIALS_DIR \
    trainer.export_path=$EXPORTS_DIR \
    trainer.ckpt_path=$CKPTS_DIR \
    trainer.log_path=$LOG_DIR \
    trainer.algorithm.advantage_estimator=grpo \
    trainer.algorithm.loss_reduction=$LOSS_REDUCTION \
    trainer.algorithm.grpo_norm_by_std=$GRPO_NORM_BY_STD \
    trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
    trainer.placement.colocate_all=true \
    trainer.strategy=fsdp2 \
    trainer.placement.policy_num_nodes=1 \
    trainer.placement.ref_num_nodes=1 \
    trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
    trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
    generator.num_inference_engines=$NUM_GPUS \
    generator.inference_engine_tensor_parallel_size=1 \
    +generator.engine_init_kwargs.chat_template=$CHAT_TEMPLATE_PATH \
    +generator.engine_init_kwargs.max_model_len=$MAX_MODEL_LEN \
    +generator.engine_init_kwargs.enable_log_requests=false \
    trainer.epochs=1 \
    trainer.eval_batch_size=128 \
    trainer.eval_before_train=false \
    trainer.eval_interval=20 \
    trainer.update_epochs_per_batch=1 \
    trainer.train_batch_size=$MINI_BATCH_SIZE \
    trainer.policy_mini_batch_size=$MINI_BATCH_SIZE \
    trainer.micro_forward_batch_size_per_gpu=1 \
    trainer.micro_train_batch_size_per_gpu=1 \
    trainer.ckpt_interval=5 \
    trainer.hf_save_interval=5 \
    trainer.algorithm.max_seq_len=$MAX_MODEL_LEN \
    trainer.policy.optimizer_config.lr=1.0e-6 \
    generator.n_samples_per_prompt=8 \
    generator.eval_n_samples_per_prompt=4 \
    generator.apply_overlong_filtering=$APPLY_OVERLONG_FILTERING \
    generator.gpu_memory_utilization=0.5 \
    trainer.logger=wandb \
    trainer.project_name=harbor \
    trainer.run_name=$RUN_NAME \
    trainer.resume_mode=latest \
    generator.backend=vllm \
    generator.run_engines_locally=true \
    generator.weight_sync_backend=nccl \
    generator.async_engine=true \
    generator.batched=false \
    generator.enforce_eager=false \
    generator.enable_http_endpoint=true \
    generator.http_endpoint_host=0.0.0.0 \
    generator.http_endpoint_port=8000 \
    ++harbor_trial_config.agent.kwargs.api_base="${VLLM_API_BASE}" \
    +generator.rate_limit.enabled=$ENABLE_RATE_LIMITING \
    +generator.rate_limit.trajectories_per_second=$TRAJECTORIES_PER_SECOND \
    +generator.rate_limit.max_concurrency=$MAX_CONCURRENCY
Review comment:

Several required modules (json, pathlib.Path, and urllib.parse.urlparse) are used within the new ChatHistoryExtractor class but have not been imported in this file. This will result in a NameError at runtime when these methods are called.
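
A minimal sketch of the imports the comment refers to (the exact placement in the module is not shown in this diff; Optional, List, and logger appear to exist already, since the surrounding code references them):

```python
# Imports needed by ChatHistoryExtractor, per the review comment above.
# (Sketch only; place alongside the module's existing imports.)
import json
from pathlib import Path
from urllib.parse import urlparse
```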