diff --git a/openmanus_rl/llm_agent/openmanus.py b/openmanus_rl/llm_agent/openmanus.py
index 9259b8bf..7f181852 100644
--- a/openmanus_rl/llm_agent/openmanus.py
+++ b/openmanus_rl/llm_agent/openmanus.py
@@ -30,7 +30,7 @@ class AgentConfig:
         num_gpus: Number of GPUs to use
         react_format: Whether to use ReAct format
         env_name: Name of the environment (e.g., "webshop")
-        env_port: Port number for environment server
+        env_ports: List of ports for parallel servers
         env_server_base: Base URL for environment server
         env_data_len: Number of data samples in the environment (used for client init)
         rollout_strategy: Strategy to use for rollout (StandardReAct/ToT/MCTS)
@@ -48,7 +48,7 @@ class AgentConfig:
 
     # Environment configuration (Now passed from trainer)
     env_name: str
-    env_port: int
+    env_ports: List[int]  # List of ports for parallel servers
     env_server_base: str
     env_data_len: int = 200  # Default, might need adjustment
     rollout_strategy: str = "StandardReAct"  # Strategy is now internal logic
@@ -117,41 +117,39 @@ def __init__(
             max_start_length=config.max_start_length
         ))
 
-        # Initialize the environment client directly
-        self.client = self._init_env_client()
-        # Initialize thread pool for parallel rollouts
-        self.executor = ThreadPoolExecutor(max_workers=self.config.max_workers)
+        # Initialize multiple environment clients
+        self.clients = self._init_env_clients()
 
-    def _init_env_client(self):
+        # Adjust thread pool size based on number of clients, up to max_workers
+        num_clients = len(self.clients)
+        actual_workers = min(num_clients, self.config.max_workers)
+        if actual_workers < num_clients:
+            print(f"[Warning] Number of clients ({num_clients}) exceeds max_workers ({self.config.max_workers}). Using {actual_workers} workers.")
+        print(f"[Info] Initializing ThreadPoolExecutor with {actual_workers} workers for {num_clients} clients.")
+        self.executor = ThreadPoolExecutor(max_workers=actual_workers)
+
+    def _init_env_clients(self) -> List[Any]:
         """
-        Initialize and return the specific AgentGym environment client based on config.
+        Initialize and return a list of specific AgentGym environment clients
+        based on the ports provided in the config.
        """
+        clients = []
+        env_name_lower = self.config.env_name.lower()
+
         # Mapping from env_name (lowercase) to Task class name
-        # We need the Task class to potentially get the right client or initial setup
         ENV_TO_TASK_CLASS = {
-            "academia": "AcademiaTask",
-            "alfworld": "AlfWorldTask",
-            "babyai": "BabyAITask",
-            "maze": "MazeTask",
-            "wordle": "WordleTask",
-            "movie": "MovieTask",
-            "sciworld": "SciworldTask",
-            "sheet": "SheetTask",
-            "sqlgym": "SqlGymTask",
-            "textcraft": "TextCraftTask",
-            "todo": "TodoTask",
-            "weather": "WeatherTask",
-            "webarena": "WebarenaTask",
-            "webshop": "WebshopTask",
+            "academia": "AcademiaTask", "alfworld": "AlfWorldTask", "babyai": "BabyAITask",
+            "maze": "MazeTask", "wordle": "WordleTask", "movie": "MovieTask",
+            "sciworld": "SciworldTask", "sheet": "SheetTask", "sqlgym": "SqlGymTask",
+            "textcraft": "TextCraftTask", "todo": "TodoTask", "weather": "WeatherTask",
+            "webarena": "WebarenaTask", "webshop": "WebshopTask",
        }
-
-        env_name_lower = self.config.env_name.lower()
+
         if env_name_lower not in ENV_TO_TASK_CLASS:
             raise ValueError(f"Unsupported environment name: {self.config.env_name}. Supported: {list(ENV_TO_TASK_CLASS.keys())}")
 
         task_class_name = ENV_TO_TASK_CLASS[env_name_lower]
-        print(f"Initializing Env Client for: {self.config.env_name} (via Task: {task_class_name})")
-        print(f"Connecting to AgentGym server at: {self.config.env_server_base}:{self.config.env_port}")
+        print(f"[Info] Initializing {len(self.config.env_ports)} Env Client(s) for: {self.config.env_name} (via Task: {task_class_name})")
 
         # Dynamically import the Task class
         try:
@@ -160,28 +158,41 @@ def _init_env_client(self):
         except (ImportError, AttributeError) as e:
             raise ImportError(f"Could not import Task class {task_class_name} from agentenv.envs: {e}")
 
-        client_args={
-            "env_server_base": f"{self.config.env_server_base}:{self.config.env_port}",
-            "data_len": self.config.env_data_len,
-            "timeout": 300,
-        }
-
-        # Instantiate the task to get the client.
-        # Assuming Task object creates and holds the client(s) in a list `clients`.
-        # This might need adjustment based on actual Task implementation.
-        try:
-            # We only need one client instance per agent worker typically.
-            task_instance = TaskClass(client_args=client_args, n_clients=1)
-            if hasattr(task_instance, 'clients') and task_instance.clients:
-                client = task_instance.clients[0]
-                print(f"Successfully obtained client: {type(client)}")
-                return client
-            else:
-                raise ValueError(f"Task class {task_class_name} did not provide a client in 'clients' attribute.")
-        except Exception as e:
-            print(f"Error initializing Task or getting client for {task_class_name}: {e}")
-            print(traceback.format_exc())  # Print detailed traceback
-            raise
+        for i, port in enumerate(self.config.env_ports):
+            server_url = f"{self.config.env_server_base}:{port}"
+            print(f"  - Client {i+1}: Connecting to {server_url}")
+
+            client_args = {
+                "env_server_base": server_url,
+                "data_len": self.config.env_data_len,
+                "timeout": 300,
+            }
+
+            try:
+                # Instantiate the task to get the client; we need one client per specified port.
+                # Assuming TaskClass handles client creation correctly when n_clients=1.
+                # If TaskClass itself manages multiple internal clients, this might need adjustment.
+                task_instance = TaskClass(client_args=client_args, n_clients=1)
+                if hasattr(task_instance, 'clients') and task_instance.clients:
+                    client = task_instance.clients[0]
+                    print(f"  - Client {i+1}: Successfully obtained client: {type(client)}")
+                    clients.append(client)
+                else:
+                    print(f"  - Client {i+1}: Error - Task class {task_class_name} did not provide a client for port {port}.")
+                    # Decide how to handle failure: raise error or skip this client? Skipping for now.
+                    # raise ValueError(f"Task class {task_class_name} did not provide a client for port {port}.")
+            except Exception as e:
+                print(f"  - Client {i+1}: Error initializing Task or getting client for port {port}: {e}")
+                print(traceback.format_exc())  # Print detailed traceback
+                # Decide how to handle failure: raise error or skip? Skipping for now.
+                # raise
+
+        if not clients:
+            raise RuntimeError("Failed to initialize any environment clients.")
+
+        print(f"[Info] Successfully initialized {len(clients)} environment clients.")
+        return clients
 
     def _batch_tokenize(self, responses: List[str]) -> torch.Tensor:
         """Tokenize a batch of responses."""
@@ -398,13 +409,14 @@ def _save_trajectory(self, trajectory: List[Dict],
             for filename in filenames:
                 self.logger.logger['wandb'].save(filename)
 
-    def _run_single_rollout(self, initial_prompt_ids: torch.Tensor, task_idx: int) -> Dict[str, Any]:
+    def _run_single_rollout(self, initial_prompt_ids: torch.Tensor, task_idx: int, client: Any) -> Dict[str, Any]:
         """
-        Runs the interaction loop for a single environment instance.
+        Runs the interaction loop for a single environment instance using the provided client.
 
         Args:
             initial_prompt_ids: Token IDs for the initial prompt/observation.
             task_idx: The index for resetting the environment.
+            client: The specific environment client instance to use for this rollout.
 
         Returns:
             A dictionary containing the trajectory, step rewards, final reward, turns,
@@ -419,13 +431,13 @@ def _run_single_rollout(self, initial_prompt_ids: torch.Tensor, task_idx: int) -
         current_input_ids = None
 
         try:
-            # Reset environment
-            self.client.reset(task_idx)
-            initial_obs_text = self.client.observe()
+            # Reset environment using the provided client
+            client.reset(task_idx)
+            initial_obs_text = client.observe()
 
             # Handle initial observation
             if not initial_obs_text:
-                print(f"[Agent._run_single_rollout][{task_idx}] Warning: Received empty initial observation. Using initial prompt from batch.")
+                print(f"[Agent._run_single_rollout][{task_idx} @ {client.env_server_base}] Warning: Received empty initial observation. Using initial prompt from batch.")
                 initial_prompt_text = self.tokenizer.decode(initial_prompt_ids[0], skip_special_tokens=True)
                 trajectory.append({"from": "human", "value": initial_prompt_text})
                 current_input_ids = initial_prompt_ids
@@ -441,15 +453,17 @@ def _run_single_rollout(self, initial_prompt_ids: torch.Tensor, task_idx: int) -
                 # Handle input that exceeds max length
                 if current_input_ids.shape[1] > self.config.max_prompt_length:
                     current_input_ids = current_input_ids[:, -self.config.max_prompt_length:]
-                    print(f"[Agent._run_single_rollout][{task_idx}] Warning: Truncating input {current_input_ids.shape} > {self.config.max_prompt_length}.")
+                    print(f"[Agent._run_single_rollout][{task_idx} @ {client.env_server_base}] Warning: Truncating input {current_input_ids.shape} > {self.config.max_prompt_length}.")
 
                 # Prepare input
                 current_attention_mask = self.tensor_fn.create_attention_mask(current_input_ids)
                 current_position_ids = self.tensor_fn.create_position_ids(current_attention_mask)
+                # Ensure input tensors are on the correct device for the actor model
+                device = next(self.actor_rollout_wg.actor_model.parameters()).device  # Get model's device
                 gen_input_proto = DataProto.from_dict({
-                    'input_ids': current_input_ids.to(self.actor_rollout_wg.device),
-                    'attention_mask': current_attention_mask.to(self.actor_rollout_wg.device),
-                    'position_ids': current_position_ids.to(self.actor_rollout_wg.device)
+                    'input_ids': current_input_ids.to(device),
+                    'attention_mask': current_attention_mask.to(device),
+                    'position_ids': current_position_ids.to(device)
                 })
 
                 # Generate response
@@ -460,6 +474,7 @@ def _run_single_rollout(self, initial_prompt_ids: torch.Tensor, task_idx: int) -
                     temperature=1.0,
                     do_sample=True
                 )
+                # Generation happens on the actor worker group's device
                 gen_output_proto = self.actor_rollout_wg.generate_sequences(gen_input_proto, generation_config=generation_config)
                 response_ids = gen_output_proto.batch['response_ids']
                 response_text = self.tokenizer.decode(response_ids[0], skip_special_tokens=True)
@@ -469,29 +484,30 @@ def _run_single_rollout(self, initial_prompt_ids: torch.Tensor, task_idx: int) -
                 action_types, action_contents = self.postprocess_predictions([response_text])
                 action_text = action_contents[0]
 
-                # Execute environment step
+                # Execute environment step using the provided client
                 if action_text is None:
                     action_text = ""
-                next_obs_text, reward, done, info = self.client.step(action_text)
+                next_obs_text, reward, done, info = client.step(action_text)
 
                 # Record rewards
                 step_rewards.append(reward)
                 final_reward = reward
-                final_env_score = info.get('score', 0.0)
+                final_env_score = info.get('score', 0.0)  # Use .get for safety
 
                 # Process next observation
                 if not done:
                     trajectory.append({"from": "human", "value": next_obs_text})
                     next_obs_ids = self.tokenizer(next_obs_text, return_tensors='pt', add_special_tokens=False)['input_ids']
+                    # Ensure tensors are concatenated on the same device (e.g., CPU or model's device if needed later)
                     current_input_ids = torch.cat([
-                        current_input_ids,
-                        response_ids.to(current_input_ids.device),
-                        next_obs_ids.to(current_input_ids.device)
+                        current_input_ids.to(response_ids.device),  # Move to same device as response_ids
+                        response_ids,
+                        next_obs_ids.to(response_ids.device)  # Move to same device
                     ], dim=1)
                 else:
                     break
 
         except Exception as e:
-            print(f"[Agent._run_single_rollout][{task_idx}] Error during rollout: {e}")
+            print(f"[Agent._run_single_rollout][{task_idx} @ {getattr(client, 'env_server_base', 'unknown_client')}] Error during rollout: {e}")
             print(traceback.format_exc())
             step_rewards = []
             final_reward = 0.0
@@ -509,25 +525,28 @@ def _run_single_rollout(self, initial_prompt_ids: torch.Tensor, task_idx: int) -
 
     def run_llm_loop(self, gen_batch: DataProto, output_dir: str = None, global_steps: int = 0) -> DataProto:
         """
-        Run the LLM interaction loop for a batch of initial prompts.
-        Updated to include trajectory visualization similar to generation.py.
+        Run the LLM interaction loop for a batch of initial prompts using multiple clients.
 
         Args:
             gen_batch: DataProto containing initial prompts
             output_dir: Directory to save visualizations
             global_steps: Current training step
 
         Returns:
             DataProto containing processed results
         """
         initial_prompts_ids = gen_batch.batch['input_ids']
         batch_size = initial_prompts_ids.shape[0]
-        print(f"[Agent.run_llm_loop] Starting rollout for batch size: {batch_size}")
+        num_clients = len(self.clients)
+        if num_clients == 0:
+            raise RuntimeError("No environment clients available for rollout.")
+
+        print(f"[Agent.run_llm_loop] Starting rollout for batch size: {batch_size} using {num_clients} clients.")
 
         # --- Setup Visualization ---
         trajectory = self._setup_visualization()
 
         # --- Extract Task Indices ---
         if 'idx' in gen_batch.meta_info:
             task_idxs = gen_batch.meta_info['idx']
             if isinstance(task_idxs, torch.Tensor):
@@ -539,74 +558,76 @@ def run_llm_loop(self, gen_batch: DataProto, output_dir: str = None, global_step
             print("[Agent.run_llm_loop] Warning: 'idx' not found in gen_batch.meta_info. Using range(batch_size).")
             task_idxs = list(range(batch_size))
 
-        # Create active_mask to track active tasks
-        active_mask = torch.ones(batch_size, dtype=torch.bool)
-        active_num_list = [active_mask.sum().item()]
-
-        # --- Parallel Rollout Execution ---
+        # --- Parallel Rollout Execution ---
         futures = {}
         rollout_results_list = [None] * batch_size  # Preallocate list to store results in order
 
-        # Create task list
-        envs = []  # Store environment instances
-
+        # Submit tasks to the thread pool, distributing across clients
         for i in range(batch_size):
             task_idx = task_idxs[i]
             initial_prompt = initial_prompts_ids[i:i+1]  # Keep batch dim
-            env = self.client  # Assume client is the environment instance
-            envs.append(env)
-            future = self.executor.submit(self._run_single_rollout, initial_prompt, task_idx)
-            futures[future] = i  # Store original index
+
+            # Select a client for this task (round-robin)
+            client_index = i % num_clients
+            selected_client = self.clients[client_index]
+
+            # Submit the rollout task with the selected client
+            future = self.executor.submit(self._run_single_rollout, initial_prompt, task_idx, selected_client)
+            futures[future] = i  # Store original batch index
+
+        print(f"[Agent.run_llm_loop] Submitted {batch_size} rollout tasks to {self.executor._max_workers} workers.")
 
         # Collect results
+        completed_count = 0
         for future in as_completed(futures):
             original_index = futures[future]
             try:
                 result_dict = future.result()
                 rollout_results_list[original_index] = result_dict
+                completed_count += 1
+                # print(f"Completed task {original_index + 1}/{batch_size}")  # Optional progress logging
 
                 # If visualization is enabled, update trajectory
-                if trajectory and original_index < len(trajectory):
-                    for turn in result_dict.get('trajectory', []):
-                        if turn.get('from') == 'gpt':
-                            # Create active_mask with only this index active
-                            current_active_mask = torch.zeros(batch_size, dtype=torch.bool)
-                            current_active_mask[original_index] = True
-                            # Update trajectory
-                            self._update_trajectory(
-                                trajectory,
-                                envs,
-                                [turn.get('value', '')],
-                                current_active_mask
-                            )
+                # Note: Visualization logic might need adjustment if envs are not easily accessible
+                # or if you want per-client visualization. Current logic assumes a single env list.
+                # Consider passing necessary env state back from _run_single_rollout if needed.
+                # if trajectory and original_index < len(trajectory):
+                #     # This part might be tricky with multiple clients unless you manage env state carefully
+                #     pass  # Placeholder for potential visualization update logic
 
             except Exception as e:
-                print(f"[Agent.run_llm_loop] Error collecting result for index {original_index}: {e}")
-                # Store a placeholder or error indicator if needed
+                print(f"[Agent.run_llm_loop] Error collecting result for batch index {original_index} (task_idx {task_idxs[original_index]}): {e}")
+                print(traceback.format_exc())
+                # Store a placeholder or error indicator
                 rollout_results_list[original_index] = {
                     'trajectory': [], 'step_rewards': [], 'reward': 0.0,
                     'turns': 0, 'env_score': 0.0, 'task_idx': task_idxs[original_index],
+                    'error': str(e)
                 }
 
-            # Update active_mask to mark current task as inactive
-            active_mask[original_index] = False
-            active_num_list.append(active_mask.sum().item())
-
-        print(f"[Agent.run_llm_loop] Collected results from {len(futures)} rollouts. Active trajectory nums: {active_num_list}")
-
-        # Save trajectory visualizations
-        if output_dir and trajectory:
-            self._save_trajectory(trajectory, output_dir, global_steps)
-
-        # Filter out potential None entries if some tasks failed critically before returning dict
+        print(f"[Agent.run_llm_loop] Collected results from {completed_count}/{batch_size} rollouts.")
+
+        # Save trajectory visualizations (if implemented and needed)
+        # if output_dir and trajectory:
+        #     self._save_trajectory(trajectory, output_dir, global_steps)
+
+        # Filter out potential None entries if some tasks failed critically
         valid_results = [res for res in rollout_results_list if res is not None]
 
         if not valid_results:
             print("[Agent.run_llm_loop] Error: No valid rollout results collected.")
-            return DataProto.from_dict({})  # Return empty DataProto
-
-        # --- Format Results into DataProto ---
+            # Return empty DataProto but with correct structure if possible
+            return DataProto.from_dict({
+                "input_ids": torch.empty((0, 0), dtype=torch.long),
+                "attention_mask": torch.empty((0, 0), dtype=torch.long),
+                "position_ids": torch.empty((0, 0), dtype=torch.long),
+                "info_mask": torch.empty((0, 0), dtype=torch.long),
+                "token_level_rewards": torch.empty((0, 0), dtype=torch.float)
+            })
+
+        # --- Format Results into DataProto ---
         processed_data = self._convert_rollout_results_to_dataproto(valid_results, gen_batch)
 
         print(f"[Agent.run_llm_loop] Finished processing rollout results.")
         return processed_data
diff --git a/test/test_openmanus.py b/test/test_openmanus.py
new file mode 100644
index 00000000..43a0f0ce
--- /dev/null
+++ b/test/test_openmanus.py
@@ -0,0 +1,85 @@
+import pytest
+import os
+import sys
+from pathlib import Path
+
+# Add the OpenManus-RL path to sys.path so the openmanus_rl package can be found
+PROJECT_ROOT = Path(__file__).resolve().parent.parent / "OpenManus-RL"
+if PROJECT_ROOT.exists() and str(PROJECT_ROOT) not in sys.path:
+    sys.path.insert(0, str(PROJECT_ROOT))
+
+# Import target functions/classes (after the path setup above)
+from openmanus_rl.llm_agent.openmanus import OpenManusAgent, create_react_prompt
+
+
+class DummyToolManager:
+    """Minimal ToolManager stub that returns fixed prompt instructions."""
+
+    def get_prompt_instructions(self):
+        return "These are tool instructions."
+
+
+class MinimalAgent(OpenManusAgent):
+    """Overrides __init__ to bypass heavy dependencies; used only to test stateless methods."""
+
+    def __init__(self):  # pylint: disable=super-init-not-called
+        pass
+
+
+# ------------------------------
+# postprocess_predictions tests
+# ------------------------------
+
+def test_postprocess_action():
+    agent = MinimalAgent()
+    predictions = ["<action>CLICK_BUTTON</action>"]
+    actions, contents = agent.postprocess_predictions(predictions)
+    assert actions == ["action"]
+    assert contents == ["CLICK_BUTTON"]
+
+
+def test_postprocess_response():
+    agent = MinimalAgent()
+    predictions = ["<response>Hello World</response>"]
+    actions, contents = agent.postprocess_predictions(predictions)
+    assert actions == ["response"]
+    assert contents == ["Hello World"]
+
+
+def test_postprocess_no_tag():
+    agent = MinimalAgent()
+    predictions = ["Hello there"]
+    actions, contents = agent.postprocess_predictions(predictions)
+    assert actions == [None]
+    assert contents == [""]
+
+
+def test_postprocess_invalid_type():
+    """Non-string input should raise a ValueError."""
+    agent = MinimalAgent()
+    with pytest.raises(ValueError):
+        agent.postprocess_predictions([123])
+
+
+def test_postprocess_multiple_tags():
+    """Ensure only the first <action> tag's content is parsed."""
+    agent = MinimalAgent()
+    predictions = ["<action>foo</action><action>bar</action>"]
+    actions, contents = agent.postprocess_predictions(predictions)
+    assert actions == ["action"]
+    assert contents == ["foo"]
+
+
+# ------------------------------
+# create_react_prompt tests
+# ------------------------------
+
+def test_create_react_prompt_contains_sections():
+    task_description = "Navigate to the red door."
+    tool_manager = DummyToolManager()
+    prompt = create_react_prompt(task_description, tool_manager)
+
+    # The prompt should contain the task description, tool instructions, and the fixed closing line
+    assert task_description in prompt
+    assert tool_manager.get_prompt_instructions() in prompt
+    assert "Let's solve this step by step." in prompt
\ No newline at end of file
diff --git a/train_ppo.sh b/train_ppo.sh
index addbc7f3..c574e580 100644
--- a/train_ppo.sh
+++ b/train_ppo.sh
@@ -4,19 +4,22 @@ export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}
 WAND_PROJECT=${WAND_PROJECT:-'OpenManus-rl'}
 export BASE_MODEL=${BASE_MODEL:-'meta-llama/Llama-3.2-3B'}
 
-AGENTGYM_HOST=${AGENTGYM_HOST:-'0.0.0.0'}
+AGENTGYM_HOST=${AGENTGYM_HOST:-'0.0.0.0'}  # Default to 0.0.0.0 for external access
 AGENTGYM_SQL_BIRD_PATH=${AGENTGYM_SQL_BIRD_PATH:-}  # Used only for sqlgym
 
 # --- Argument Parsing ---
 usage() {
-    echo "Usage: $0 --env_name <env_name> [--port <port>] [--data_dir <path>] [--exp_name_suffix <suffix>]"
+    echo "Usage: $0 --env_name <env_name> [--num_servers <n>] [--base_port <port>] [--data_dir <path>] [--exp_name_suffix <suffix>]"
     echo "Supported env_names: webshop, webarena, maze, wordle, alfworld, sciworld, babyai, textcraft, weather, movie, academia, todo, sheet, sqlgym"
+    echo "  --num_servers: Number of parallel AgentGym servers to launch (default: 1)."
+    echo "  --base_port: Starting port number for servers (default varies by env)."
     echo "Assumes dedicated conda environments like 'agentenv-webshop' are already created and set up."
     exit 1
 }
 
 AGENTGYM_ENV_NAME="webshop"  # Default environment
-AGENTGYM_PORT_OVERRIDE=""
+NUM_SERVERS=1  # Default number of servers
+BASE_PORT_OVERRIDE=""
 DATA_DIR_OVERRIDE=""
 EXP_NAME_SUFFIX=""
 
@@ -25,8 +28,10 @@ while [[ $# -gt 0 ]]; do
     case $key in
         --env_name)
             AGENTGYM_ENV_NAME="$2"; shift; shift;;
-        --port)
-            AGENTGYM_PORT_OVERRIDE="$2"; shift; shift;;
+        --num_servers)
+            NUM_SERVERS="$2"; shift; shift;;
+        --base_port)  # Changed from --port to --base_port
+            BASE_PORT_OVERRIDE="$2"; shift; shift;;
         --data_dir)
             DATA_DIR_OVERRIDE="$2"; shift; shift;;
         --exp_name_suffix)
@@ -36,6 +41,11 @@ while [[ $# -gt 0 ]]; do
     esac
 done
 
+if ! [[ "$NUM_SERVERS" =~ ^[1-9][0-9]*$ ]]; then
+    echo "Error: --num_servers must be a positive integer."
+    usage
+fi
+
 if [ -z "$AGENTGYM_ENV_NAME" ]; then
     echo "Error: --env_name is required."; usage
 fi
@@ -45,122 +55,149 @@ BASE_CONDA_ENV=${CONDA_DEFAULT_ENV:-openmanus-rl}
 echo "[Info] Detected base conda environment: $BASE_CONDA_ENV"
 echo "[Info] Verl trainer will run in this environment."
 
-# --- Environment Specific Setup (Determine LAUNCH_CMD, DEFAULT_PORT, URL_PATH) ---
+
+# --- Environment Specific Setup (Determine LAUNCH_CMD, DEFAULT_BASE_PORT, URL_PATH) ---
+
 LAUNCH_CMD=""
-DEFAULT_PORT=""
+DEFAULT_BASE_PORT=""  # Renamed from DEFAULT_PORT
 URL_PATH=""
-# No longer launching via the "python -m <module>" style
 # MODULE_LAUNCH_NAME=""
-# Set the default host to 0.0.0.0 to allow external access
 AGENTGYM_HOST=${AGENTGYM_HOST:-'0.0.0.0'}
 
 case $AGENTGYM_ENV_NAME in
     webshop)
         LAUNCH_CMD="webshop --host $AGENTGYM_HOST --port \$AGENTGYM_PORT"
-        DEFAULT_PORT=36001;;
+        DEFAULT_BASE_PORT=36001;;
     webarena)
         LAUNCH_CMD="webarena --host $AGENTGYM_HOST --port \$AGENTGYM_PORT"
-        DEFAULT_PORT=8000;;
+        DEFAULT_BASE_PORT=8000;;
     maze)
         LAUNCH_CMD="lmrlgym --host $AGENTGYM_HOST --port \$AGENTGYM_PORT"
-        DEFAULT_PORT=36001; URL_PATH="/maze/";;
+        DEFAULT_BASE_PORT=36001; URL_PATH="/maze/";;
     wordle)
         LAUNCH_CMD="lmrlgym --host $AGENTGYM_HOST --port \$AGENTGYM_PORT"
-        DEFAULT_PORT=36001; URL_PATH="/wordle/";;
+        DEFAULT_BASE_PORT=36001; URL_PATH="/wordle/";;
     alfworld)
         LAUNCH_CMD="alfworld --host $AGENTGYM_HOST --port \$AGENTGYM_PORT"
-        DEFAULT_PORT=36001;;
+        DEFAULT_BASE_PORT=36001;;
     sciworld)
         LAUNCH_CMD="sciworld --host $AGENTGYM_HOST --port \$AGENTGYM_PORT"
-        DEFAULT_PORT=36001;;
+        DEFAULT_BASE_PORT=36001;;
     babyai)
         LAUNCH_CMD="babyai --host $AGENTGYM_HOST --port \$AGENTGYM_PORT"
-        DEFAULT_PORT=36001;;
+        DEFAULT_BASE_PORT=36001;;
    textcraft)
         LAUNCH_CMD="textcraft --host $AGENTGYM_HOST --port \$AGENTGYM_PORT"
-        DEFAULT_PORT=36001;;
+        DEFAULT_BASE_PORT=36001;;
     weather|movie|academia|todo|sheet)
         LAUNCH_CMD="\$AGENTGYM_ENV_NAME --host $AGENTGYM_HOST --port \$AGENTGYM_PORT"
-        DEFAULT_PORT=8000;;
+        DEFAULT_BASE_PORT=8000;;
     sqlgym)
         if [ -z "$AGENTGYM_SQL_BIRD_PATH" ]; then echo "Error: AGENTGYM_SQL_BIRD_PATH must be set for sqlgym."; exit 1; fi
         LAUNCH_CMD="AGENTENV_SQLGYM_BIRD_PATH=$AGENTGYM_SQL_BIRD_PATH sqlgym --host $AGENTGYM_HOST --port \$AGENTGYM_PORT"
-        DEFAULT_PORT=36002;;
+        DEFAULT_BASE_PORT=36002;;
     *)
         echo "Error: Unsupported environment name '$AGENTGYM_ENV_NAME'"; usage;;
esac
 
-# --- Start AgentGym Server in its Dedicated Environment ---
+# --- Start AgentGym Servers in Dedicated Environment ---
 TARGET_ENV_NAME="agentenv-${AGENTGYM_ENV_NAME}"
-AGENTGYM_PID=""
+AGENTGYM_PIDS=()   # Array to store PIDs
+AGENTGYM_PORTS=()  # Array to store ports
 
 # Check if target env exists
-if ! conda env list | grep -Eq "^${TARGET_ENV_NAME}\s"; then
+if ! conda env list | grep -Eq "^${TARGET_ENV_NAME}\s"; then
     echo "[Error] Dedicated environment '$TARGET_ENV_NAME' not found. Please create it first."
     exit 1
 fi
 
-# Prepare Launch Command (Prefer python -m style if defined)
-export AGENTGYM_PORT=${AGENTGYM_PORT_OVERRIDE:-$DEFAULT_PORT}
-
-# Launch directly via the CLI tool
-FINAL_LAUNCH_CMD=$(eval echo $LAUNCH_CMD)  # Substitute $AGENTGYM_PORT
+# Determine base port
+AGENTGYM_BASE_PORT=${BASE_PORT_OVERRIDE:-$DEFAULT_BASE_PORT}
 
-echo -e "\n[Server] Starting AgentGym server for ${AGENTGYM_ENV_NAME} in env '$TARGET_ENV_NAME'..."
-echo "[Server] Host: ${AGENTGYM_HOST}, Port: ${AGENTGYM_PORT}"
-echo "[Server] Launch command: $FINAL_LAUNCH_CMD"
+echo -e "\n[Server] Starting $NUM_SERVERS AgentGym server(s) for ${AGENTGYM_ENV_NAME} in env '$TARGET_ENV_NAME'..."
+echo "[Server] Base Port: ${AGENTGYM_BASE_PORT}"
 
 # Create logs directory
 mkdir -p logs
 
-# Run server in background using conda run in the target environment
-LOG_FILE="logs/${TARGET_ENV_NAME}_server.log"
-echo "[Server] Logging to $LOG_FILE"
-
-# Simplified launch: start the server directly with conda run
-conda run -n "$TARGET_ENV_NAME" $FINAL_LAUNCH_CMD > "$LOG_FILE" 2>&1 &
-AGENTGYM_PID=$!
-
-# Check if PID was obtained
-if [ -z "$AGENTGYM_PID" ]; then
-    echo "[Error] Failed to get PID for AgentGym server launch command."
-    exit 1
-fi
-echo "[Server] AgentGym server launched in '$TARGET_ENV_NAME' (PID: $AGENTGYM_PID)."
+for (( i=0; i<$NUM_SERVERS; i++ )); do
+    # Calculate port for this server instance
+    export AGENTGYM_PORT=$((AGENTGYM_BASE_PORT + i))
+    AGENTGYM_PORTS+=($AGENTGYM_PORT)  # Store port
+
+    # Prepare the specific launch command for this instance
+    CURRENT_LAUNCH_CMD=$(eval echo $LAUNCH_CMD)  # Substitute $AGENTGYM_PORT
+
+    echo "[Server $(($i+1))/$NUM_SERVERS] Launching on ${AGENTGYM_HOST}:${AGENTGYM_PORT}..."
+    echo "[Server $(($i+1))/$NUM_SERVERS] Command: $CURRENT_LAUNCH_CMD"
+
+    # Run server in background using conda run
+    LOG_FILE="logs/${TARGET_ENV_NAME}_server_${AGENTGYM_PORT}.log"
+    echo "[Server $(($i+1))/$NUM_SERVERS] Logging to $LOG_FILE"
+
+    # Use bash -c to handle potential env vars in launch cmd (like for sqlgym)
+    conda run --no-capture-output -n "$TARGET_ENV_NAME" bash -c "$CURRENT_LAUNCH_CMD" > "$LOG_FILE" 2>&1 &
+    PID=$!
+
+    # Check if PID was obtained
+    if [ -z "$PID" ]; then
+        echo "[Error] Failed to get PID for AgentGym server instance $i on port $AGENTGYM_PORT."
+        # Attempt to kill already launched servers before exiting
+        for p in "${AGENTGYM_PIDS[@]}"; do kill $p 2>/dev/null; done
+        exit 1
+    fi
+    AGENTGYM_PIDS+=($PID)  # Store PID
+    echo "[Server $(($i+1))/$NUM_SERVERS] Launched (PID: $PID)."
+    sleep 2  # Small delay between starting servers
+done
 
-# --- Wait and Check Server ---
-echo "[Server] Waiting for AgentGym server (PID: $AGENTGYM_PID) to initialize..."
+# --- Wait and Check Servers ---
+echo "[Server] Waiting for AgentGym servers (${AGENTGYM_PIDS[*]}) to initialize..."
 sleep 15  # Adjust sleep time if needed
 
-# Check if the process with the captured PID is still running
-if ! kill -0 $AGENTGYM_PID > /dev/null 2>&1; then
-    echo "[Error] AgentGym server (PID: $AGENTGYM_PID) failed to start or exited prematurely."
-    echo "[Error] Check server log: $LOG_FILE"
+# Check if all server processes are still running. PIDs and ports are parallel
+# arrays, so iterate by index to report the matching port and log file directly.
+ALL_SERVERS_RUNNING=true
+for idx in "${!AGENTGYM_PIDS[@]}"; do
+    PID=${AGENTGYM_PIDS[$idx]}
+    PORT=${AGENTGYM_PORTS[$idx]}
+    if ! kill -0 $PID > /dev/null 2>&1; then
+        echo "[Error] AgentGym server (PID: $PID, port: $PORT) failed to start or exited prematurely."
+        echo "[Error] Check server log: logs/${TARGET_ENV_NAME}_server_${PORT}.log"
+        ALL_SERVERS_RUNNING=false
+    fi
+done
+
+if [ "$ALL_SERVERS_RUNNING" = false ]; then
+    echo "[Error] Not all servers started successfully. Exiting."
+    # Kill remaining servers
+    for p in "${AGENTGYM_PIDS[@]}"; do kill $p 2>/dev/null; done
     exit 1
 fi
-echo "[Server] AgentGym server appears to be running (PID: $AGENTGYM_PID)."
+echo "[Server] All AgentGym servers appear to be running."
 
-# Setup trap to kill the server process on script exit/interrupt
-trap "echo '[Cleanup] Stopping AgentGym server (PID: $AGENTGYM_PID)...'; kill $AGENTGYM_PID 2>/dev/null || echo '[Cleanup] Server already stopped.'" EXIT
+# Setup trap to kill all server processes on script exit/interrupt
+trap "echo '[Cleanup] Stopping AgentGym servers (PIDs: ${AGENTGYM_PIDS[*]})...'; kill ${AGENTGYM_PIDS[*]} 2>/dev/null || echo '[Cleanup] Servers already stopped.'; wait ${AGENTGYM_PIDS[*]} 2>/dev/null" EXIT
 
 # --- Data and Experiment Naming ---
 export DATA_DIR=${DATA_DIR_OVERRIDE:-"data/$AGENTGYM_ENV_NAME"}  # Default data dir based on env name
 export EXPERIMENT_NAME="OpenManus-rl-ppo-${BASE_MODEL##*/}-${AGENTGYM_ENV_NAME}${EXP_NAME_SUFFIX}"
 
+
 # --- Run PPO Training in Base Environment ---
 echo -e "\n[Trainer] Running PPO training in base environment '$BASE_CONDA_ENV'..."
 export VLLM_ATTENTION_BACKEND=${VLLM_ATTENTION_BACKEND:-XFORMERS}
 
 # Construct server base URL, adding path if needed
-AGENTGYM_SERVER_BASE="http://$AGENTGYM_HOST"
-if [ -n "$URL_PATH" ]; then
-    AGENTGYM_SERVER_BASE="$AGENTGYM_SERVER_BASE$URL_PATH"
-fi
+AGENTGYM_SERVER_BASE="http://$AGENTGYM_HOST"  # Base URL without port (note: URL_PATH is no longer appended here)
+# Construct the list of ports as a comma-separated string for OmegaConf
+AGENTGYM_PORTS_STR=$(IFS=,; echo "${AGENTGYM_PORTS[*]}")
 
 echo "[Trainer] Using Data Directory: $DATA_DIR"
 echo "[Trainer] Experiment Name: $EXPERIMENT_NAME"
-echo "[Trainer] AgentGym Base URL: $AGENTGYM_SERVER_BASE:$AGENTGYM_PORT"
+
+echo "[Trainer] AgentGym Base URL: $AGENTGYM_SERVER_BASE"
+echo "[Trainer] AgentGym Ports: $AGENTGYM_PORTS_STR"  # Pass list of ports
 
 # Check if train/test files exist
 TRAIN_FILE="$DATA_DIR/train.parquet"
@@ -173,87 +210,88 @@ if [ ! -f "$TEST_FILE" ]; then
     echo "[Warning] Test file not found at $TEST_FILE. Ensure data generation script was run for $AGENTGYM_ENV_NAME."
 fi
 
-TRAINER_LOG_FILE="logs/${EXPERIMENT_NAME}.log"
-echo "[Trainer] Logging trainer output to $TRAINER_LOG_FILE"
-
-# Activate the conda environment the proper way (via source)
-echo "[Trainer] Activating base environment for training..."
-# Get the conda base path and make sure conda.sh is available
+# Ensure base environment is activated correctly for trainer
+echo "[Trainer] Ensuring base environment '$BASE_CONDA_ENV' is active..."
 CONDA_BASE=$(conda info --base)
 source "${CONDA_BASE}/etc/profile.d/conda.sh"
-conda activate openmanus-rl
-
-# Check and install required dependencies
-echo "[Trainer] Checking and installing required dependencies..."
-if ! python -c "import tensordict" &>/dev/null; then
-    echo "[Trainer] Installing missing dependency: tensordict"
-    pip install tensordict
-fi
+conda activate "$BASE_CONDA_ENV" || { echo "Error: Failed to activate base env '$BASE_CONDA_ENV'"; exit 1; }
+
+# Check and install dependencies within the base environment
+echo "[Trainer] Checking and installing required dependencies in '$BASE_CONDA_ENV'..."
+for pkg in tensordict codetiming ray wandb transformers; do
+    if ! python -c "import $pkg" &>/dev/null; then
+        echo "[Trainer] Installing missing dependency: $pkg"
+        pip install $pkg
+    fi
+done
 
+TRAINER_LOG_FILE="logs/${EXPERIMENT_NAME}.log"
+echo "[Trainer] Logging trainer output to $TRAINER_LOG_FILE"
 echo "[Trainer] Starting PPO training..."
-PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
-    data.train_files=$TRAIN_FILE \
-    data.val_files=$TEST_FILE \
-    data.env_name=$AGENTGYM_ENV_NAME \
-    data.env_server_base=$AGENTGYM_SERVER_BASE \
-    data.env_port=$AGENTGYM_PORT \
-    data.train_data_num=null \
-    data.val_data_num=null \
-    data.train_batch_size=512 \
-    data.val_batch_size=256 \
-    data.max_prompt_length=4096 \
-    data.max_response_length=500 \
-    data.max_start_length=2048 \
-    data.max_obs_length=500 \
-    data.shuffle_train_dataloader=True \
-    algorithm.adv_estimator=gae \
-    actor_rollout_ref.model.path=$BASE_MODEL \
-    actor_rollout_ref.actor.optim.lr=1e-6 \
-    actor_rollout_ref.model.enable_gradient_checkpointing=true \
-    actor_rollout_ref.model.use_remove_padding=True \
-    actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.95 \
-    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
-    actor_rollout_ref.actor.ppo_micro_batch_size=64 \
-    actor_rollout_ref.actor.fsdp_config.param_offload=true \
-    actor_rollout_ref.actor.fsdp_config.grad_offload=true \
-    actor_rollout_ref.actor.fsdp_config.optimizer_offload=true \
-    actor_rollout_ref.rollout.log_prob_micro_batch_size=128 \
-    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
-    actor_rollout_ref.rollout.name=vllm \
-    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
-    actor_rollout_ref.ref.log_prob_micro_batch_size=128 \
-    actor_rollout_ref.ref.fsdp_config.param_offload=True \
-    actor_rollout_ref.rollout.n_agent=1 \
-    actor_rollout_ref.rollout.temperature=1 \
-    actor_rollout_ref.actor.state_masking=true \
-    critic.optim.lr=1e-5 \
-    critic.model.use_remove_padding=True \
-    critic.optim.lr_warmup_steps_ratio=0.05 \
-    critic.model.path=$BASE_MODEL \
-    critic.model.enable_gradient_checkpointing=true \
-    critic.ppo_micro_batch_size=8 \
-    critic.model.fsdp_config.param_offload=true \
-    critic.model.fsdp_config.grad_offload=true \
-    critic.model.fsdp_config.optimizer_offload=true \
-    algorithm.kl_ctrl.kl_coef=0.001 \
-    algorithm.no_think_rl=false \
-    algorithm.reward_score_fn=agentgym \
-    trainer.critic_warmup=0 \
-    trainer.logger=['wandb'] \
-    +trainer.val_only=false \
-    +trainer.val_before_train=true \
-    trainer.default_hdfs_dir=null \
-    trainer.n_gpus_per_node=8 \
-    trainer.nnodes=1 \
-    trainer.save_freq=100 \
-    trainer.test_freq=50 \
-    trainer.project_name=$WAND_PROJECT \
-    trainer.experiment_name=$EXPERIMENT_NAME \
-    trainer.total_epochs=15 \
-    trainer.total_training_steps=305 \
-    trainer.default_hdfs_dir=null \
-    trainer.default_local_dir=verl_checkpoints/$EXPERIMENT_NAME \
-    max_turns=2 \
+
+# data.env_ports passes the full list of server ports; OmegaConf parses it as a list
+PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
+    data.train_files=$TRAIN_FILE \
+    data.val_files=$TEST_FILE \
+    data.env_name=$AGENTGYM_ENV_NAME \
+    data.env_server_base=$AGENTGYM_SERVER_BASE \
+    data.env_ports=[${AGENTGYM_PORTS_STR}] \
+    data.train_data_num=null \
+    data.val_data_num=null \
+    data.train_batch_size=512 \
+    data.val_batch_size=256 \
+    data.max_prompt_length=4096 \
+    data.max_response_length=500 \
+    data.max_start_length=2048 \
+    data.max_obs_length=500 \
+    data.shuffle_train_dataloader=True \
+    algorithm.adv_estimator=gae \
+    actor_rollout_ref.model.path=$BASE_MODEL \
+    actor_rollout_ref.actor.optim.lr=1e-6 \
+    actor_rollout_ref.model.enable_gradient_checkpointing=true \
+    actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.95 \
+    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
+    actor_rollout_ref.actor.ppo_micro_batch_size=64 \
+    actor_rollout_ref.actor.fsdp_config.param_offload=true \
+    actor_rollout_ref.actor.fsdp_config.grad_offload=true \
+    actor_rollout_ref.actor.fsdp_config.optimizer_offload=true \
+    actor_rollout_ref.rollout.log_prob_micro_batch_size=128 \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+    actor_rollout_ref.rollout.name=vllm \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
+    actor_rollout_ref.ref.log_prob_micro_batch_size=128 \
+    actor_rollout_ref.ref.fsdp_config.param_offload=True \
+    actor_rollout_ref.rollout.n_agent=1 \
+    actor_rollout_ref.rollout.temperature=1 \
+    actor_rollout_ref.actor.state_masking=true \
+    critic.optim.lr=1e-5 \
+    critic.model.use_remove_padding=True \
+    critic.optim.lr_warmup_steps_ratio=0.05 \
+    critic.model.path=$BASE_MODEL \
+    critic.model.enable_gradient_checkpointing=true \
+    critic.ppo_micro_batch_size=8 \
+    critic.model.fsdp_config.param_offload=true \
+    critic.model.fsdp_config.grad_offload=true \
+    critic.model.fsdp_config.optimizer_offload=true \
+    algorithm.kl_ctrl.kl_coef=0.001 \
+    algorithm.no_think_rl=false \
+    algorithm.reward_score_fn=agentgym \
+    trainer.critic_warmup=0 \
+    trainer.logger=['wandb'] \
+    +trainer.val_only=false \
+    +trainer.val_before_train=true \
+    trainer.default_hdfs_dir=null \
+    trainer.n_gpus_per_node=8 \
+    trainer.nnodes=1 \
+    trainer.save_freq=100 \
+    trainer.test_freq=50 \
+    trainer.project_name=$WAND_PROJECT \
+    trainer.experiment_name=$EXPERIMENT_NAME \
+    trainer.total_epochs=15 \
+    trainer.total_training_steps=305 \
+    trainer.default_local_dir=verl_checkpoints/$EXPERIMENT_NAME \
+    max_turns=2 \
     2>&1 | tee "$TRAINER_LOG_FILE"  # Log trainer output
 
TRAINER_EXIT_CODE=$?
@@ -262,4 +300,5 @@ echo "PPO training finished with exit code $TRAINER_EXIT_CODE."
 
 # Cleanup is handled by the trap
+
 exit $TRAINER_EXIT_CODE
\ No newline at end of file
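
Note for reviewers (not part of the diff): after this change, a multi-server run looks like `bash train_ppo.sh --env_name webshop --num_servers 4 --base_port 36001`, which starts four servers on ports 36001-36004 and passes `data.env_ports=[36001,36002,36003,36004]` to the trainer. The trainer-side wiring that turns that override into the new `AgentConfig.env_ports` field is not shown in this diff; below is a minimal sketch of what it might look like, assuming `verl.trainer.main_ppo` receives an OmegaConf config with the `data.*` keys used above (the helper name and config layout are illustrative, not the actual verl API).

    # Sketch only: `extract_env_settings` is a hypothetical helper, and the
    # `config.data.*` field names are assumptions about the trainer config layout.
    from omegaconf import OmegaConf

    def extract_env_settings(config) -> dict:
        # OmegaConf parses data.env_ports=[36001,36002] into a ListConfig;
        # normalize it to a plain List[int], which AgentConfig.env_ports expects.
        ports = [int(p) for p in OmegaConf.to_container(config.data.env_ports, resolve=True)]
        if not ports:
            raise ValueError("data.env_ports must contain at least one port")
        return {
            "env_name": config.data.env_name,
            "env_ports": ports,  # one AgentGym client per port; rollouts are round-robined over them
            "env_server_base": config.data.env_server_base,
        }

    if __name__ == "__main__":
        cfg = OmegaConf.create({
            "data": {
                "env_name": "webshop",
                "env_server_base": "http://0.0.0.0",
                "env_ports": [36001, 36002],
            }
        })
        print(extract_env_settings(cfg))
        # -> {'env_name': 'webshop', 'env_ports': [36001, 36002], 'env_server_base': 'http://0.0.0.0'}

These keyword arguments would then be merged into the `AgentConfig` the trainer constructs, alongside the existing model and rollout fields (not shown here).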