diff --git a/openmanus_rl/llm_agent/openmanus.py b/openmanus_rl/llm_agent/openmanus.py
index 9259b8bf..7f181852 100644
--- a/openmanus_rl/llm_agent/openmanus.py
+++ b/openmanus_rl/llm_agent/openmanus.py
@@ -30,7 +30,7 @@ class AgentConfig:
num_gpus: Number of GPUs to use
react_format: Whether to use ReAct format
env_name: Name of the environment (e.g., "webshop")
- env_port: Port number for environment server
+ env_ports: List of ports for parallel servers
env_server_base: Base URL for environment server
env_data_len: Number of data samples in the environment (used for client init)
rollout_strategy: Strategy to use for rollout (StandardReAct/ToT/MCTS)
@@ -48,7 +48,7 @@ class AgentConfig:
# Environment configuration (Now passed from trainer)
env_name: str
- env_port: int
+ env_ports: List[int] # List of ports for parallel servers
env_server_base: str
env_data_len: int = 200 # Default, might need adjustment
rollout_strategy: str = "StandardReAct" # Strategy is now internal logic
@@ -117,41 +117,39 @@ def __init__(
max_start_length=config.max_start_length
))
- # Initialize the environment client directly
- self.client = self._init_env_client()
- # Initialize thread pool for parallel rollouts
- self.executor = ThreadPoolExecutor(max_workers=self.config.max_workers)
+ # Initialize multiple environment clients
+ self.clients = self._init_env_clients() # Changed method name
- def _init_env_client(self):
+ # Adjust thread pool size based on number of clients, up to max_workers
+ num_clients = len(self.clients)
+ actual_workers = min(num_clients, self.config.max_workers)
+ if actual_workers < num_clients:
+ print(f"[Warning] Number of clients ({num_clients}) exceeds max_workers ({self.config.max_workers}). Using {actual_workers} workers.")
+ print(f"[Info] Initializing ThreadPoolExecutor with {actual_workers} workers for {num_clients} clients.")
+ self.executor = ThreadPoolExecutor(max_workers=actual_workers)
+
+ def _init_env_clients(self) -> List[Any]: # Renamed and return type changed
"""
- Initialize and return the specific AgentGym environment client based on config.
+ Initialize and return a list of specific AgentGym environment clients
+ based on the ports provided in the config.
"""
+ clients = []
+ env_name_lower = self.config.env_name.lower()
+
# Mapping from env_name (lowercase) to Task class name
- # We need the Task class to potentially get the right client or initial setup
ENV_TO_TASK_CLASS = {
- "academia": "AcademiaTask",
- "alfworld": "AlfWorldTask",
- "babyai": "BabyAITask",
- "maze": "MazeTask",
- "wordle": "WordleTask",
- "movie": "MovieTask",
- "sciworld": "SciworldTask",
- "sheet": "SheetTask",
- "sqlgym": "SqlGymTask",
- "textcraft": "TextCraftTask",
- "todo": "TodoTask",
- "weather": "WeatherTask",
- "webarena": "WebarenaTask",
- "webshop": "WebshopTask",
+ "academia": "AcademiaTask", "alfworld": "AlfWorldTask", "babyai": "BabyAITask",
+ "maze": "MazeTask", "wordle": "WordleTask", "movie": "MovieTask",
+ "sciworld": "SciworldTask", "sheet": "SheetTask", "sqlgym": "SqlGymTask",
+ "textcraft": "TextCraftTask", "todo": "TodoTask", "weather": "WeatherTask",
+ "webarena": "WebarenaTask", "webshop": "WebshopTask",
}
-
- env_name_lower = self.config.env_name.lower()
+
if env_name_lower not in ENV_TO_TASK_CLASS:
raise ValueError(f"Unsupported environment name: {self.config.env_name}. Supported: {list(ENV_TO_TASK_CLASS.keys())}")
task_class_name = ENV_TO_TASK_CLASS[env_name_lower]
- print(f"Initializing Env Client for: {self.config.env_name} (via Task: {task_class_name})")
- print(f"Connecting to AgentGym server at: {self.config.env_server_base}:{self.config.env_port}")
+ print(f"[Info] Initializing {len(self.config.env_ports)} Env Client(s) for: {self.config.env_name} (via Task: {task_class_name})")
# Dynamically import the Task class
try:
@@ -160,28 +158,41 @@ def _init_env_client(self):
except (ImportError, AttributeError) as e:
raise ImportError(f"Could not import Task class {task_class_name} from agentenv.envs: {e}")
- client_args={
- "env_server_base": f"{self.config.env_server_base}:{self.config.env_port}",
- "data_len": self.config.env_data_len,
- "timeout": 300,
- }
-
- # Instantiate the task to get the client.
- # Assuming Task object creates and holds the client(s) in a list `clients`.
- # This might need adjustment based on actual Task implementation.
- try:
- # We only need one client instance per agent worker typically.
- task_instance = TaskClass(client_args=client_args, n_clients=1)
- if hasattr(task_instance, 'clients') and task_instance.clients:
- client = task_instance.clients[0]
- print(f"Successfully obtained client: {type(client)}")
- return client
- else:
- raise ValueError(f"Task class {task_class_name} did not provide a client in 'clients' attribute.")
- except Exception as e:
- print(f"Error initializing Task or getting client for {task_class_name}: {e}")
- print(traceback.format_exc()) # Print detailed traceback
- raise
+ for i, port in enumerate(self.config.env_ports):
+ server_url = f"{self.config.env_server_base}:{port}"
+ print(f" - Client {i+1}: Connecting to {server_url}")
+
+ client_args={
+ "env_server_base": server_url,
+ "data_len": self.config.env_data_len,
+ "timeout": 300,
+ }
+
+ try:
+ # Instantiate the task to get the client.
+ # We need one client per specified port.
+ # Assuming TaskClass handles client creation correctly when n_clients=1.
+ # If TaskClass itself manages multiple internal clients, this might need adjustment.
+ task_instance = TaskClass(client_args=client_args, n_clients=1)
+ if hasattr(task_instance, 'clients') and task_instance.clients:
+ client = task_instance.clients[0]
+ print(f" - Client {i+1}: Successfully obtained client: {type(client)}")
+ clients.append(client)
+ else:
+ print(f" - Client {i+1}: Error - Task class {task_class_name} did not provide a client for port {port}.")
+ # Decide how to handle failure: raise error or skip this client? Skipping for now.
+ # raise ValueError(f"Task class {task_class_name} did not provide a client for port {port}.")
+ except Exception as e:
+ print(f" - Client {i+1}: Error initializing Task or getting client for port {port}: {e}")
+ print(traceback.format_exc()) # Print detailed traceback
+ # Decide how to handle failure: raise error or skip? Skipping for now.
+ # raise
+
+ if not clients:
+ raise RuntimeError("Failed to initialize any environment clients.")
+
+ print(f"[Info] Successfully initialized {len(clients)} environment clients.")
+ return clients
def _batch_tokenize(self, responses: List[str]) -> torch.Tensor:
"""Tokenize a batch of responses."""
@@ -398,13 +409,14 @@ def _save_trajectory(self, trajectory: List[Dict],
for filename in filenames:
self.logger.logger['wandb'].save(filename)
- def _run_single_rollout(self, initial_prompt_ids: torch.Tensor, task_idx: int) -> Dict[str, Any]:
+ def _run_single_rollout(self, initial_prompt_ids: torch.Tensor, task_idx: int, client: Any) -> Dict[str, Any]:
"""
- Runs the interaction loop for a single environment instance.
+ Runs the interaction loop for a single environment instance using the provided client.
Args:
initial_prompt_ids: Token IDs for the initial prompt/observation.
task_idx: The index for resetting the environment.
+ client: The specific environment client instance to use for this rollout.
Returns:
A dictionary containing the trajectory, step rewards, final reward, turns,
@@ -419,13 +431,13 @@ def _run_single_rollout(self, initial_prompt_ids: torch.Tensor, task_idx: int) -
current_input_ids = None
try:
- # Reset environment
- self.client.reset(task_idx)
- initial_obs_text = self.client.observe()
+ # Reset environment using the provided client
+ client.reset(task_idx)
+ initial_obs_text = client.observe()
# Handle initial observation
if not initial_obs_text:
- print(f"[Agent._run_single_rollout][{task_idx}] Warning: Received empty initial observation. Using initial prompt from batch.")
+ print(f"[Agent._run_single_rollout][{task_idx} @ {client.env_server_base}] Warning: Received empty initial observation. Using initial prompt from batch.")
initial_prompt_text = self.tokenizer.decode(initial_prompt_ids[0], skip_special_tokens=True)
trajectory.append({"from": "human", "value": initial_prompt_text})
current_input_ids = initial_prompt_ids
@@ -441,15 +453,17 @@ def _run_single_rollout(self, initial_prompt_ids: torch.Tensor, task_idx: int) -
# Handle input that exceeds max length
if current_input_ids.shape[1] > self.config.max_prompt_length:
current_input_ids = current_input_ids[:, -self.config.max_prompt_length:]
- print(f"[Agent._run_single_rollout][{task_idx}] Warning: Truncating input {current_input_ids.shape} > {self.config.max_prompt_length}.")
+ print(f"[Agent._run_single_rollout][{task_idx} @ {client.env_server_base}] Warning: Truncating input {current_input_ids.shape} > {self.config.max_prompt_length}.")
# Prepare input
current_attention_mask = self.tensor_fn.create_attention_mask(current_input_ids)
current_position_ids = self.tensor_fn.create_position_ids(current_attention_mask)
+ # Ensure input tensors are on the correct device for the actor model
+ device = next(self.actor_rollout_wg.actor_model.parameters()).device # Get model's device
gen_input_proto = DataProto.from_dict({
- 'input_ids': current_input_ids.to(self.actor_rollout_wg.device),
- 'attention_mask': current_attention_mask.to(self.actor_rollout_wg.device),
- 'position_ids': current_position_ids.to(self.actor_rollout_wg.device)
+ 'input_ids': current_input_ids.to(device),
+ 'attention_mask': current_attention_mask.to(device),
+ 'position_ids': current_position_ids.to(device)
})
# Generate response
@@ -460,6 +474,7 @@ def _run_single_rollout(self, initial_prompt_ids: torch.Tensor, task_idx: int) -
temperature=1.0,
do_sample=True
)
+ # Generation happens on the actor worker group's device
gen_output_proto = self.actor_rollout_wg.generate_sequences(gen_input_proto, generation_config=generation_config)
response_ids = gen_output_proto.batch['response_ids']
response_text = self.tokenizer.decode(response_ids[0], skip_special_tokens=True)
@@ -469,29 +484,30 @@ def _run_single_rollout(self, initial_prompt_ids: torch.Tensor, task_idx: int) -
action_types, action_contents = self.postprocess_predictions([response_text])
action_text = action_contents[0]
- # Execute environment step
+ # Execute environment step using the provided client
if action_text is None: action_text = ""
- next_obs_text, reward, done, info = self.client.step(action_text)
+ next_obs_text, reward, done, info = client.step(action_text)
# Record rewards
step_rewards.append(reward)
final_reward = reward
- final_env_score = info.get('score', 0.0)
+ final_env_score = info.get('score', 0.0) # Use .get for safety
# Process next observation
if not done:
trajectory.append({"from": "human", "value": next_obs_text})
next_obs_ids = self.tokenizer(next_obs_text, return_tensors='pt', add_special_tokens=False)['input_ids']
+ # Ensure tensors are concatenated on the same device (e.g., CPU or model's device if needed later)
current_input_ids = torch.cat([
- current_input_ids,
- response_ids.to(current_input_ids.device),
- next_obs_ids.to(current_input_ids.device)
+ current_input_ids.to(response_ids.device), # Move to same device as response_ids
+ response_ids,
+ next_obs_ids.to(response_ids.device) # Move to same device
], dim=1)
else:
break
except Exception as e:
- print(f"[Agent._run_single_rollout][{task_idx}] Error during rollout: {e}")
+ print(f"[Agent._run_single_rollout][{task_idx} @ {getattr(client, 'env_server_base', 'unknown_client')}] Error during rollout: {e}")
print(traceback.format_exc())
step_rewards = []
final_reward = 0.0
@@ -509,25 +525,28 @@ def _run_single_rollout(self, initial_prompt_ids: torch.Tensor, task_idx: int) -
def run_llm_loop(self, gen_batch: DataProto, output_dir: str = None, global_steps: int = 0) -> DataProto:
"""
- Run the LLM interaction loop for a batch of initial prompts.
- Updated to include trajectory visualization similar to generation.py.
-
+ Run the LLM interaction loop for a batch of initial prompts using multiple clients.
+
Args:
gen_batch: DataProto containing initial prompts
output_dir: Directory to save visualizations
global_steps: Current training step
-
+
Returns:
DataProto containing processed results
"""
initial_prompts_ids = gen_batch.batch['input_ids']
batch_size = initial_prompts_ids.shape[0]
- print(f"[Agent.run_llm_loop] Starting rollout for batch size: {batch_size}")
+ num_clients = len(self.clients)
+ if num_clients == 0:
+ raise RuntimeError("No environment clients available for rollout.")
+
+ print(f"[Agent.run_llm_loop] Starting rollout for batch size: {batch_size} using {num_clients} clients.")
# --- Setup Visualization ---
trajectory = self._setup_visualization()
-
- # --- Extract Task Indices ---
+
+ # --- Extract Task Indices ---
if 'idx' in gen_batch.meta_info:
task_idxs = gen_batch.meta_info['idx']
if isinstance(task_idxs, torch.Tensor):
@@ -539,74 +558,76 @@ def run_llm_loop(self, gen_batch: DataProto, output_dir: str = None, global_step
print("[Agent.run_llm_loop] Warning: 'idx' not found in gen_batch.meta_info. Using range(batch_size)." )
task_idxs = list(range(batch_size))
- # Create active_mask to track active tasks
- active_mask = torch.ones(batch_size, dtype=torch.bool)
- active_num_list = [active_mask.sum().item()]
-
- # --- Parallel Rollout Execution ---
+ # --- Parallel Rollout Execution ---
futures = {}
rollout_results_list = [None] * batch_size # Preallocate list to store results in order
-
- # Create task list
- envs = [] # Store environment instances
+ # Submit tasks to the thread pool, distributing across clients
for i in range(batch_size):
task_idx = task_idxs[i]
initial_prompt = initial_prompts_ids[i:i+1] # Keep batch dim
- env = self.client # Assume client is the environment instance
- envs.append(env)
- future = self.executor.submit(self._run_single_rollout, initial_prompt, task_idx)
- futures[future] = i # Store original index
+
+ # Select a client for this task (round-robin)
+ client_index = i % num_clients
+ selected_client = self.clients[client_index]
+
+ # Submit the rollout task with the selected client
+ future = self.executor.submit(self._run_single_rollout, initial_prompt, task_idx, selected_client)
+ futures[future] = i # Store original batch index
+
+ print(f"[Agent.run_llm_loop] Submitted {batch_size} rollout tasks to {self.executor._max_workers} workers.")
# Collect results
+ completed_count = 0
for future in as_completed(futures):
original_index = futures[future]
try:
result_dict = future.result()
rollout_results_list[original_index] = result_dict
-
+ completed_count += 1
+ # print(f"Completed task {original_index + 1}/{batch_size}") # Optional progress logging
+
# If visualization is enabled, update trajectory
- if trajectory and original_index < len(trajectory):
- for turn in result_dict.get('trajectory', []):
- if turn.get('from') == 'gpt':
- # Create active_mask with only this index active
- current_active_mask = torch.zeros(batch_size, dtype=torch.bool)
- current_active_mask[original_index] = True
- # Update trajectory
- self._update_trajectory(
- trajectory,
- envs,
- [turn.get('value', '')],
- current_active_mask
- )
+ # Note: Visualization logic might need adjustment if envs are not easily accessible
+ # or if you want per-client visualization. Current logic assumes a single env list.
+ # Consider passing necessary env state back from _run_single_rollout if needed.
+ # if trajectory and original_index < len(trajectory):
+ # # This part might be tricky with multiple clients unless you manage env state carefully
+ # pass # Placeholder for potential visualization update logic
+
except Exception as e:
- print(f"[Agent.run_llm_loop] Error collecting result for index {original_index}: {e}")
- # Store a placeholder or error indicator if needed
+ print(f"[Agent.run_llm_loop] Error collecting result for batch index {original_index} (task_idx {task_idxs[original_index]}): {e}")
+ print(traceback.format_exc())
+ # Store a placeholder or error indicator
rollout_results_list[original_index] = {
- 'trajectory': [], 'step_rewards': [], 'reward': 0.0,
- 'turns': 0, 'env_score': 0.0, 'task_idx': task_idxs[original_index],
+ 'trajectory': [], 'step_rewards': [], 'reward': 0.0,
+ 'turns': 0, 'env_score': 0.0, 'task_idx': task_idxs[original_index],
'error': str(e)
}
- # Update active_mask to mark current task as inactive
- active_mask[original_index] = False
- active_num_list.append(active_mask.sum().item())
- print(f"[Agent.run_llm_loop] Collected results from {len(futures)} rollouts. Active trajectory nums: {active_num_list}")
-
- # Save trajectory visualizations
- if output_dir and trajectory:
- self._save_trajectory(trajectory, output_dir, global_steps)
-
- # Filter out potential None entries if some tasks failed critically before returning dict
+ print(f"[Agent.run_llm_loop] Collected results from {completed_count}/{batch_size} rollouts.")
+
+ # Save trajectory visualizations (if implemented and needed)
+ # if output_dir and trajectory:
+ # self._save_trajectory(trajectory, output_dir, global_steps)
+
+ # Filter out potential None entries if some tasks failed critically
valid_results = [res for res in rollout_results_list if res is not None]
-
+
if not valid_results:
print("[Agent.run_llm_loop] Error: No valid rollout results collected.")
- return DataProto.from_dict({}) # Return empty DataProto
-
- # --- Format Results into DataProto ---
+ # Return empty DataProto but with correct structure if possible
+ return DataProto.from_dict({
+ "input_ids": torch.empty((0,0), dtype=torch.long),
+ "attention_mask": torch.empty((0,0), dtype=torch.long),
+ "position_ids": torch.empty((0,0), dtype=torch.long),
+ "info_mask": torch.empty((0,0), dtype=torch.long),
+ "token_level_rewards": torch.empty((0,0), dtype=torch.float)
+ })
+
+ # --- Format Results into DataProto ---
processed_data = self._convert_rollout_results_to_dataproto(valid_results, gen_batch)
-
+
print(f"[Agent.run_llm_loop] Finished processing rollout results.")
return processed_data
diff --git a/test/test_openmanus.py b/test/test_openmanus.py
new file mode 100644
index 00000000..43a0f0ce
--- /dev/null
+++ b/test/test_openmanus.py
@@ -0,0 +1,85 @@
+import pytest
+import os
+import sys
+from pathlib import Path
+
+# Import target functions/classes
+from openmanus_rl.llm_agent.openmanus import OpenManusAgent, create_react_prompt
+
+# 将 OpenManus-RL 路径加入 sys.path,确保能够找到 openmanus_rl 包
+PROJECT_ROOT = Path(__file__).resolve().parent.parent / "OpenManus-RL"
+if PROJECT_ROOT.exists() and str(PROJECT_ROOT) not in sys.path:
+ sys.path.insert(0, str(PROJECT_ROOT))
+
+
+class DummyToolManager:
+ """简易 ToolManager,仅返回固定的提示指令。"""
+
+ def get_prompt_instructions(self):
+ return "These are tool instructions."
+
+
+class MinimalAgent(OpenManusAgent):
+ """覆盖 __init__ 以绕过复杂依赖,仅用于测试无状态方法。"""
+
+ def __init__(self): # pylint: disable=super-init-not-called
+ pass
+
+
+# ------------------------------
+# postprocess_predictions 测试
+# ------------------------------
+
+def test_postprocess_action():
+ agent = MinimalAgent()
+ predictions = ["CLICK_BUTTON"]
+ actions, contents = agent.postprocess_predictions(predictions)
+ assert actions == ["action"]
+ assert contents == ["CLICK_BUTTON"]
+
+
+def test_postprocess_response():
+ agent = MinimalAgent()
+ predictions = ["Hello World"]
+ actions, contents = agent.postprocess_predictions(predictions)
+ assert actions == ["response"]
+ assert contents == ["Hello World"]
+
+
+def test_postprocess_no_tag():
+ agent = MinimalAgent()
+ predictions = ["Hello there"]
+ actions, contents = agent.postprocess_predictions(predictions)
+ assert actions == [None]
+ assert contents == [""]
+
+
+def test_postprocess_invalid_type():
+ """非字符串输入应抛出 ValueError。"""
+ agent = MinimalAgent()
+ with pytest.raises(ValueError):
+ agent.postprocess_predictions([123])
+
+
+def test_postprocess_multiple_tags():
+ """确保仅解析第一个 标签内容。"""
+ agent = MinimalAgent()
+ predictions = ["foobar"]
+ actions, contents = agent.postprocess_predictions(predictions)
+ assert actions == ["action"]
+ assert contents == ["foo"]
+
+
+# ------------------------------
+# create_react_prompt 测试
+# ------------------------------
+
+def test_create_react_prompt_contains_sections():
+ task_description = "Navigate to the red door."
+ tool_manager = DummyToolManager()
+ prompt = create_react_prompt(task_description, tool_manager)
+
+ # Prompt 应包含任务描述、工具指令和固定结尾
+ assert task_description in prompt
+ assert tool_manager.get_prompt_instructions() in prompt
+ assert "Let's solve this step by step." in prompt
\ No newline at end of file
diff --git a/train_ppo.sh b/train_ppo.sh
index addbc7f3..c574e580 100644
--- a/train_ppo.sh
+++ b/train_ppo.sh
@@ -4,19 +4,22 @@
export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}
WAND_PROJECT=${WAND_PROJECT:-'OpenManus-rl'}
export BASE_MODEL=${BASE_MODEL:-'meta-llama/Llama-3.2-3B'}
-AGENTGYM_HOST=${AGENTGYM_HOST:-'0.0.0.0'}
+AGENTGYM_HOST=${AGENTGYM_HOST:-'0.0.0.0'} # Default to 0.0.0.0 for external access
AGENTGYM_SQL_BIRD_PATH=${AGENTGYM_SQL_BIRD_PATH:-} # Used only for sqlgym
# --- Argument Parsing ---
usage() {
- echo "Usage: $0 --env_name [--port ] [--data_dir ] [--exp_name_suffix ]"
+ echo "Usage: $0 --env_name [--num_servers ] [--base_port ] [--data_dir ] [--exp_name_suffix ]"
echo "Supported env_names: webshop, webarena, maze, wordle, alfworld, sciworld, babyai, textcraft, weather, movie, academia, todo, sheet, sqlgym"
+ echo " --num_servers: Number of parallel AgentGym servers to launch (default: 1)."
+ echo " --base_port: Starting port number for servers (default varies by env)."
echo "Assumes dedicated conda environments like 'agentenv-webshop' are already created and set up."
exit 1
}
AGENTGYM_ENV_NAME="webshop" # Default environment
-AGENTGYM_PORT_OVERRIDE=""
+NUM_SERVERS=1 # Default number of servers
+BASE_PORT_OVERRIDE=""
DATA_DIR_OVERRIDE=""
EXP_NAME_SUFFIX=""
@@ -25,8 +28,10 @@ while [[ $# -gt 0 ]]; do
case $key in
--env_name)
AGENTGYM_ENV_NAME="$2"; shift; shift;;
- --port)
- AGENTGYM_PORT_OVERRIDE="$2"; shift; shift;;
+ --num_servers)
+ NUM_SERVERS="$2"; shift; shift;;
+ --base_port) # Changed from --port to --base_port
+ BASE_PORT_OVERRIDE="$2"; shift; shift;;
--data_dir)
DATA_DIR_OVERRIDE="$2"; shift; shift;;
--exp_name_suffix)
@@ -36,6 +41,11 @@ while [[ $# -gt 0 ]]; do
esac
done
+if ! [[ "$NUM_SERVERS" =~ ^[1-9][0-9]*$ ]]; then
+ echo "Error: --num_servers must be a positive integer."
+ usage
+fi
+
if [ -z "$AGENTGYM_ENV_NAME" ]; then
echo "Error: --env_name is required."; usage
fi
@@ -45,122 +55,149 @@ BASE_CONDA_ENV=${CONDA_DEFAULT_ENV:-openmanus-rl}
echo "[Info] Detected base conda environment: $BASE_CONDA_ENV"
echo "[Info] Verl trainer will run in this environment."
-# --- Environment Specific Setup (Determine LAUNCH_CMD, DEFAULT_PORT, URL_PATH) ---
+
+# --- Environment Specific Setup (Determine LAUNCH_CMD, DEFAULT_BASE_PORT, URL_PATH) ---
+
LAUNCH_CMD=""
-DEFAULT_PORT=""
+DEFAULT_BASE_PORT="" # Renamed from DEFAULT_PORT
URL_PATH=""
-# 不再使用Python -m 模块方式
# MODULE_LAUNCH_NAME=""
-# 设置默认主机为0.0.0.0以允许外部访问
AGENTGYM_HOST=${AGENTGYM_HOST:-'0.0.0.0'}
case $AGENTGYM_ENV_NAME in
webshop)
LAUNCH_CMD="webshop --host $AGENTGYM_HOST --port \$AGENTGYM_PORT"
- DEFAULT_PORT=36001;;
+ DEFAULT_BASE_PORT=36001;;
webarena)
LAUNCH_CMD="webarena --host $AGENTGYM_HOST --port \$AGENTGYM_PORT"
- DEFAULT_PORT=8000;;
+ DEFAULT_BASE_PORT=8000;;
maze)
LAUNCH_CMD="lmrlgym --host $AGENTGYM_HOST --port \$AGENTGYM_PORT"
- DEFAULT_PORT=36001; URL_PATH="/maze/";;
+ DEFAULT_BASE_PORT=36001; URL_PATH="/maze/";;
wordle)
LAUNCH_CMD="lmrlgym --host $AGENTGYM_HOST --port \$AGENTGYM_PORT"
- DEFAULT_PORT=36001; URL_PATH="/wordle/";;
+ DEFAULT_BASE_PORT=36001; URL_PATH="/wordle/";;
alfworld)
LAUNCH_CMD="alfworld --host $AGENTGYM_HOST --port \$AGENTGYM_PORT"
- DEFAULT_PORT=36001;;
+ DEFAULT_BASE_PORT=36001;;
sciworld)
LAUNCH_CMD="sciworld --host $AGENTGYM_HOST --port \$AGENTGYM_PORT"
- DEFAULT_PORT=36001;;
+ DEFAULT_BASE_PORT=36001;;
babyai)
LAUNCH_CMD="babyai --host $AGENTGYM_HOST --port \$AGENTGYM_PORT"
- DEFAULT_PORT=36001;;
+ DEFAULT_BASE_PORT=36001;;
textcraft)
LAUNCH_CMD="textcraft --host $AGENTGYM_HOST --port \$AGENTGYM_PORT"
- DEFAULT_PORT=36001;;
+ DEFAULT_BASE_PORT=36001;;
weather|movie|academia|todo|sheet)
- LAUNCH_CMD="\$AGENTGYM_ENV_NAME --host $AGENTGYM_HOST --port \$AGENTGYM_PORT"
- DEFAULT_PORT=8000;;
+ LAUNCH_CMD="\\\$AGENTGYM_ENV_NAME --host $AGENTGYM_HOST --port \\\$AGENTGYM_PORT" # Escaped env name var
+ DEFAULT_BASE_PORT=8000;;
sqlgym)
if [ -z "$AGENTGYM_SQL_BIRD_PATH" ]; then echo "Error: AGENTGYM_SQL_BIRD_PATH must be set for sqlgym."; exit 1; fi
- LAUNCH_CMD="AGENTENV_SQLGYM_BIRD_PATH=$AGENTGYM_SQL_BIRD_PATH sqlgym --host $AGENTGYM_HOST --port \$AGENTGYM_PORT"
- DEFAULT_PORT=36002;;
+ LAUNCH_CMD="AGENTENV_SQLGYM_BIRD_PATH=$AGENTGYM_SQL_BIRD_PATH sqlgym --host $AGENTGYM_HOST --port \\\$AGENTGYM_PORT"
+ DEFAULT_BASE_PORT=36002;;
*)
echo "Error: Unsupported environment name '$AGENTGYM_ENV_NAME'"; usage;;
esac
-# --- Start AgentGym Server in its Dedicated Environment ---
+# --- Start AgentGym Servers in Dedicated Environment ---
TARGET_ENV_NAME="agentenv-${AGENTGYM_ENV_NAME}"
-AGENTGYM_PID=""
+AGENTGYM_PIDS=() # Array to store PIDs
+AGENTGYM_PORTS=() # Array to store ports
# Check if target env exists
-if ! conda env list | grep -Eq "^${TARGET_ENV_NAME}\s"; then
+if ! conda env list | grep -Eq "^${TARGET_ENV_NAME}\\s"; then
echo "[Error] Dedicated environment '$TARGET_ENV_NAME' not found. Please create it first."
exit 1
fi
-# Prepare Launch Command (Prefer python -m style if defined)
-export AGENTGYM_PORT=${AGENTGYM_PORT_OVERRIDE:-$DEFAULT_PORT}
-
-# 直接使用命令行工具方式启动
-FINAL_LAUNCH_CMD=$(eval echo $LAUNCH_CMD) # Substitute $AGENTGYM_PORT
+# Determine base port
+AGENTGYM_BASE_PORT=${BASE_PORT_OVERRIDE:-$DEFAULT_BASE_PORT}
-echo -e "\\n[Server] Starting AgentGym server for ${AGENTGYM_ENV_NAME} in env '$TARGET_ENV_NAME'..."
-echo "[Server] Host: ${AGENTGYM_HOST}, Port: ${AGENTGYM_PORT}"
-echo "[Server] Launch command: $FINAL_LAUNCH_CMD"
+echo -e "\\n[Server] Starting $NUM_SERVERS AgentGym server(s) for ${AGENTGYM_ENV_NAME} in env '$TARGET_ENV_NAME'..."
+echo "[Server] Base Port: ${AGENTGYM_BASE_PORT}"
# Create logs directory
mkdir -p logs
-# Run server in background using conda run in the target environment
-LOG_FILE="logs/${TARGET_ENV_NAME}_server.log"
-echo "[Server] Logging to $LOG_FILE"
-
-# 简化启动命令,直接使用conda run启动服务器
-conda run -n "$TARGET_ENV_NAME" $FINAL_LAUNCH_CMD > "$LOG_FILE" 2>&1 &
-AGENTGYM_PID=$!
-
-# Check if PID was obtained
-if [ -z "$AGENTGYM_PID" ]; then
- echo "[Error] Failed to get PID for AgentGym server launch command."
- exit 1
-fi
-echo "[Server] AgentGym server launched in '$TARGET_ENV_NAME' (PID: $AGENTGYM_PID)."
+for (( i=0; i<$NUM_SERVERS; i++ )); do
+ # Calculate port for this server instance
+ export AGENTGYM_PORT=$((AGENTGYM_BASE_PORT + i))
+ AGENTGYM_PORTS+=($AGENTGYM_PORT) # Store port
+
+ # Prepare the specific launch command for this instance
+ CURRENT_LAUNCH_CMD=$(eval echo $LAUNCH_CMD) # Substitute $AGENTGYM_PORT
+
+ echo "[Server $(($i+1))/$NUM_SERVERS] Launching on ${AGENTGYM_HOST}:${AGENTGYM_PORT}..."
+ echo "[Server $(($i+1))/$NUM_SERVERS] Command: $CURRENT_LAUNCH_CMD"
+
+ # Run server in background using conda run
+ LOG_FILE="logs/${TARGET_ENV_NAME}_server_${AGENTGYM_PORT}.log"
+ echo "[Server $(($i+1))/$NUM_SERVERS] Logging to $LOG_FILE"
+
+ # Use bash -c to handle potential env vars in launch cmd (like for sqlgym)
+ conda run --no-capture-output -n "$TARGET_ENV_NAME" bash -c "$CURRENT_LAUNCH_CMD" > "$LOG_FILE" 2>&1 &
+ PID=$!
+
+ # Check if PID was obtained
+ if [ -z "$PID" ]; then
+ echo "[Error] Failed to get PID for AgentGym server instance $i on port $AGENTGYM_PORT."
+ # Attempt to kill already launched servers before exiting
+ for p in "${AGENTGYM_PIDS[@]}"; do kill $p 2>/dev/null; done
+ exit 1
+ fi
+ AGENTGYM_PIDS+=($PID) # Store PID
+ echo "[Server $(($i+1))/$NUM_SERVERS] Launched (PID: $PID)."
+ sleep 2 # Small delay between starting servers
+done
-# --- Wait and Check Server ---
-echo "[Server] Waiting for AgentGym server (PID: $AGENTGYM_PID) to initialize..."
+# --- Wait and Check Servers ---
+echo "[Server] Waiting for AgentGym servers (${AGENTGYM_PIDS[*]}) to initialize..."
sleep 15 # Adjust sleep time if needed
-# Check if the process with the captured PID is still running
-if ! kill -0 $AGENTGYM_PID > /dev/null 2>&1; then
- echo "[Error] AgentGym server (PID: $AGENTGYM_PID) failed to start or exited prematurely."
- echo "[Error] Check server log: $LOG_FILE"
+# Check if all server processes are still running
+ALL_SERVERS_RUNNING=true
+for PID in "${AGENTGYM_PIDS[@]}"; do
+ if ! kill -0 $PID > /dev/null 2>&1; then
+ echo "[Error] AgentGym server (PID: $PID) failed to start or exited prematurely."
+ # Attempt to find the corresponding log file (this is a bit heuristic)
+ PORT=$(grep -oP -- "--port\\s+\\K\\d+" "logs/"*"${PID}"* 2>/dev/null || echo "unknown")
+ echo "[Error] Check server log potentially named logs/${TARGET_ENV_NAME}_server_${PORT}.log or similar."
+ ALL_SERVERS_RUNNING=false
+ fi
+done
+
+if [ "$ALL_SERVERS_RUNNING" = false ]; then
+ echo "[Error] Not all servers started successfully. Exiting."
+ # Kill remaining servers
+ for p in "${AGENTGYM_PIDS[@]}"; do kill $p 2>/dev/null; done
exit 1
fi
-echo "[Server] AgentGym server appears to be running (PID: $AGENTGYM_PID)."
+echo "[Server] All AgentGym servers appear to be running."
-# Setup trap to kill the server process on script exit/interrupt
-trap "echo '[Cleanup] Stopping AgentGym server (PID: $AGENTGYM_PID)...'; kill $AGENTGYM_PID 2>/dev/null || echo '[Cleanup] Server already stopped.'" EXIT
+# Setup trap to kill all server processes on script exit/interrupt
+trap "echo '[Cleanup] Stopping AgentGym servers (PIDs: ${AGENTGYM_PIDS[*]})...'; kill ${AGENTGYM_PIDS[*]} 2>/dev/null || echo '[Cleanup] Servers already stopped.'; wait ${AGENTGYM_PIDS[*]} 2>/dev/null" EXIT
# --- Data and Experiment Naming ---
export DATA_DIR=${DATA_DIR_OVERRIDE:-"data/$AGENTGYM_ENV_NAME"} # Default data dir based on env name
export EXPERIMENT_NAME="OpenManus-rl-ppo-${BASE_MODEL##*/}-${AGENTGYM_ENV_NAME}${EXP_NAME_SUFFIX}"
+
# --- Run PPO Training in Base Environment ---
echo -e "\\n[Trainer] Running PPO training in base environment '$BASE_CONDA_ENV'..."
export VLLM_ATTENTION_BACKEND=${VLLM_ATTENTION_BACKEND:-XFORMERS}
# Construct server base URL, adding path if needed
-AGENTGYM_SERVER_BASE="http://$AGENTGYM_HOST"
-if [ -n "$URL_PATH" ]; then
- AGENTGYM_SERVER_BASE="$AGENTGYM_SERVER_BASE$URL_PATH"
-fi
+AGENTGYM_SERVER_BASE="http://$AGENTGYM_HOST" # Base URL without port
+# Construct the list of ports as a comma-separated string for OmegaConf
+AGENTGYM_PORTS_STR=$(IFS=,; echo "${AGENTGYM_PORTS[*]}")
echo "[Trainer] Using Data Directory: $DATA_DIR"
echo "[Trainer] Experiment Name: $EXPERIMENT_NAME"
-echo "[Trainer] AgentGym Base URL: $AGENTGYM_SERVER_BASE:$AGENTGYM_PORT"
+
+echo "[Trainer] AgentGym Base URL: $AGENTGYM_SERVER_BASE"
+echo "[Trainer] AgentGym Ports: $AGENTGYM_PORTS_STR" # Pass list of ports
# Check if train/test files exist
TRAIN_FILE="$DATA_DIR/train.parquet"
@@ -173,87 +210,88 @@ if [ ! -f "$TEST_FILE" ]; then
echo "[Warning] Test file not found at $TEST_FILE. Ensure data generation script was run for $AGENTGYM_ENV_NAME."
fi
-TRAINER_LOG_FILE="logs/${EXPERIMENT_NAME}.log"
-echo "[Trainer] Logging trainer output to $TRAINER_LOG_FILE"
-
-# 正确方式激活conda环境(使用source)
-echo "[Trainer] Activating base environment for training..."
-# 获取conda base路径并确保conda.sh可用
+# Ensure base environment is activated correctly for trainer
+echo "[Trainer] Ensuring base environment '$BASE_CONDA_ENV' is active..."
CONDA_BASE=$(conda info --base)
source "${CONDA_BASE}/etc/profile.d/conda.sh"
-conda activate openmanus-rl
-
-# 检查并安装必要的依赖
-echo "[Trainer] Checking and installing required dependencies..."
-if ! python -c "import tensordict" &>/dev/null; then
- echo "[Trainer] Installing missing dependency: tensordict"
- pip install tensordict
-fi
+conda activate "$BASE_CONDA_ENV" || { echo "Error: Failed to activate base env '$BASE_CONDA_ENV'"; exit 1; }
+
+# Check and install dependencies within the base environment
+echo "[Trainer] Checking and installing required dependencies in '$BASE_CONDA_ENV'..."
+for pkg in tensordict codetiming ray wandb transformers; do
+ if ! python -c "import $pkg" &>/dev/null; then
+ echo "[Trainer] Installing missing dependency: $pkg"
+ pip install $pkg
+ fi
+done
+TRAINER_LOG_FILE="logs/${EXPERIMENT_NAME}.log"
+echo "[Trainer] Logging trainer output to $TRAINER_LOG_FILE"
echo "[Trainer] Starting PPO training..."
-PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
- data.train_files=$TRAIN_FILE \
- data.val_files=$TEST_FILE \
- data.env_name=$AGENTGYM_ENV_NAME \
- data.env_server_base=$AGENTGYM_SERVER_BASE \
- data.env_port=$AGENTGYM_PORT \
- data.train_data_num=null \
- data.val_data_num=null \
- data.train_batch_size=512 \
- data.val_batch_size=256 \
- data.max_prompt_length=4096 \
- data.max_response_length=500 \
- data.max_start_length=2048 \
- data.max_obs_length=500 \
- data.shuffle_train_dataloader=True \
- algorithm.adv_estimator=gae \
- actor_rollout_ref.model.path=$BASE_MODEL \
- actor_rollout_ref.actor.optim.lr=1e-6 \
- actor_rollout_ref.model.enable_gradient_checkpointing=true \
- actor_rollout_ref.model.use_remove_padding=True \
- actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.95 \
- actor_rollout_ref.actor.ppo_mini_batch_size=256 \
- actor_rollout_ref.actor.ppo_micro_batch_size=64 \
- actor_rollout_ref.actor.fsdp_config.param_offload=true \
- actor_rollout_ref.actor.fsdp_config.grad_offload=true \
- actor_rollout_ref.actor.fsdp_config.optimizer_offload=true \
- actor_rollout_ref.rollout.log_prob_micro_batch_size=128 \
- actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
- actor_rollout_ref.rollout.name=vllm \
- actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
- actor_rollout_ref.ref.log_prob_micro_batch_size=128 \
- actor_rollout_ref.ref.fsdp_config.param_offload=True \
- actor_rollout_ref.rollout.n_agent=1 \
- actor_rollout_ref.rollout.temperature=1 \
- actor_rollout_ref.actor.state_masking=true \
- critic.optim.lr=1e-5 \
- critic.model.use_remove_padding=True \
- critic.optim.lr_warmup_steps_ratio=0.05 \
- critic.model.path=$BASE_MODEL \
- critic.model.enable_gradient_checkpointing=true \
- critic.ppo_micro_batch_size=8 \
- critic.model.fsdp_config.param_offload=true \
- critic.model.fsdp_config.grad_offload=true \
- critic.model.fsdp_config.optimizer_offload=true \
- algorithm.kl_ctrl.kl_coef=0.001 \
- algorithm.no_think_rl=false \
- algorithm.reward_score_fn=agentgym \
- trainer.critic_warmup=0 \
- trainer.logger=['wandb'] \
- +trainer.val_only=false \
- +trainer.val_before_train=true \
- trainer.default_hdfs_dir=null \
- trainer.n_gpus_per_node=8 \
- trainer.nnodes=1 \
- trainer.save_freq=100 \
- trainer.test_freq=50 \
- trainer.project_name=$WAND_PROJECT \
- trainer.experiment_name=$EXPERIMENT_NAME \
- trainer.total_epochs=15 \
- trainer.total_training_steps=305 \
- trainer.default_hdfs_dir=null \
- trainer.default_local_dir=verl_checkpoints/$EXPERIMENT_NAME \
- max_turns=2 \
+
+PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \\
+ data.train_files=$TRAIN_FILE \\
+ data.val_files=$TEST_FILE \\
+ data.env_name=$AGENTGYM_ENV_NAME \\
+ data.env_server_base=$AGENTGYM_SERVER_BASE \\
+ data.env_ports=[${AGENTGYM_PORTS_STR}] \\ // Pass ports as a list
+ data.train_data_num=null \\
+ data.val_data_num=null \\
+ data.train_batch_size=512 \\
+ data.val_batch_size=256 \\
+ data.max_prompt_length=4096 \\
+ data.max_response_length=500 \\
+ data.max_start_length=2048 \\
+ data.max_obs_length=500 \\
+ data.shuffle_train_dataloader=True \\
+ algorithm.adv_estimator=gae \\
+ actor_rollout_ref.model.path=$BASE_MODEL \\
+ actor_rollout_ref.actor.optim.lr=1e-6 \\
+ actor_rollout_ref.model.enable_gradient_checkpointing=true \\
+ actor_rollout_ref.model.use_remove_padding=True \\
+ actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.95 \\
+ actor_rollout_ref.actor.ppo_mini_batch_size=256 \\
+ actor_rollout_ref.actor.ppo_micro_batch_size=64 \\
+ actor_rollout_ref.actor.fsdp_config.param_offload=true \\
+ actor_rollout_ref.actor.fsdp_config.grad_offload=true \\
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=true \\
+ actor_rollout_ref.rollout.log_prob_micro_batch_size=128 \\
+ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\
+ actor_rollout_ref.rollout.name=vllm \\
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\
+ actor_rollout_ref.ref.log_prob_micro_batch_size=128 \\
+ actor_rollout_ref.ref.fsdp_config.param_offload=True \\
+ actor_rollout_ref.rollout.n_agent=1 \\
+ actor_rollout_ref.rollout.temperature=1 \\
+ actor_rollout_ref.actor.state_masking=true \\
+ critic.optim.lr=1e-5 \\
+ critic.model.use_remove_padding=True \\
+ critic.optim.lr_warmup_steps_ratio=0.05 \\
+ critic.model.path=$BASE_MODEL \\
+ critic.model.enable_gradient_checkpointing=true \\
+ critic.ppo_micro_batch_size=8 \\
+ critic.model.fsdp_config.param_offload=true \\
+ critic.model.fsdp_config.grad_offload=true \\
+ critic.model.fsdp_config.optimizer_offload=true \\
+ algorithm.kl_ctrl.kl_coef=0.001 \\
+ algorithm.no_think_rl=false \\
+ algorithm.reward_score_fn=agentgym \\
+ trainer.critic_warmup=0 \\
+ trainer.logger=['wandb'] \\
+ +trainer.val_only=false \\
+ +trainer.val_before_train=true \\
+ trainer.default_hdfs_dir=null \\
+ trainer.n_gpus_per_node=8 \\
+ trainer.nnodes=1 \\
+ trainer.save_freq=100 \\
+ trainer.test_freq=50 \\
+ trainer.project_name=$WAND_PROJECT \\
+ trainer.experiment_name=$EXPERIMENT_NAME \\
+ trainer.total_epochs=15 \\
+ trainer.total_training_steps=305 \\
+ trainer.default_hdfs_dir=null \\
+ trainer.default_local_dir=verl_checkpoints/$EXPERIMENT_NAME \\
+ max_turns=2 \\
2>&1 | tee "$TRAINER_LOG_FILE" # Log trainer output
TRAINER_EXIT_CODE=$?
@@ -262,4 +300,5 @@ echo "PPO training finished with exit code $TRAINER_EXIT_CODE."
# Cleanup is handled by the trap
+
exit $TRAINER_EXIT_CODE
\ No newline at end of file