diff --git a/data/webshop/webshop.py b/data/webshop/webshop.py new file mode 100644 index 00000000..a1adb741 --- /dev/null +++ b/data/webshop/webshop.py @@ -0,0 +1,18 @@ +import json +from datasets import load_dataset + +# Load the full AgentEval dataset +ds = load_dataset("AgentGym/AgentEval", split="test") + +# Filter only the entries with item_id starting with "webshop_" +webshop_ds = ds.filter(lambda x: x["item_id"].startswith("webshop_")) + +# Preview the result +print(webshop_ds) + +output_file = "webshop_inference.json" + +data = [{"item_id": x["item_id"], "conversations": []} for x in webshop_ds] + +with open(output_file, "w") as f: + json.dump(data, f, indent=2) diff --git a/openmanus_rl/agentgym/agentenv/examples/basic/base_eval_webshop.sh b/openmanus_rl/agentgym/agentenv/examples/basic/base_eval_webshop.sh new file mode 100644 index 00000000..9c707ccc --- /dev/null +++ b/openmanus_rl/agentgym/agentenv/examples/basic/base_eval_webshop.sh @@ -0,0 +1,19 @@ +# Evaluation args +model_path="/data1/models/openmanus_rl/Qwen/Qwen3-3b-sft/global_step_1" +inference_file="/home/user/muxin/OpenManus-RL/data/webshop/webshop_inference.json" +output_file="/data1/models/openmanus_rl/Qwen/Qwen3-3b-sft/output/qwen2.5-3b-webshop.log" +task_name="webshop" +seed="42" + +# environment parameters +max_round="6" +env_server_base="http://127.0.0.1:36001" + +python -u base_eval_template.py \ + --model_path "${model_path}" \ + --inference_file "${inference_file}" \ + --output_file "${output_file}" \ + --task_name "${task_name}" \ + --seed "${seed}" \ + --max_round "${max_round}" \ + --env_server_base "${env_server_base}" diff --git a/openmanus_rl/agentgym/agentenv/examples/distributed_eval_scripts/distributed_eval_webshop.sh b/openmanus_rl/agentgym/agentenv/examples/distributed_eval_scripts/distributed_eval_webshop.sh new file mode 100755 index 00000000..1c2e564b --- /dev/null +++ b/openmanus_rl/agentgym/agentenv/examples/distributed_eval_scripts/distributed_eval_webshop.sh @@ -0,0 +1,46 @@ +exp_name="eval_webshop" +inference_file='/home/user/muxin/OpenManus-RL/data/webshop/webshop_inference.json' # Path to the trainset file which contains idxs for the task. + +num_processes='8' +main_process_port='8877' +weight_decay="0" + +### Default variables +task_name="webshop" # change this to evaluate on a different task +output_dir="/data1/models/openmanus_rl/Qwen/Qwen3-3b-sft/output" + +# agent model +#model_path="/data1/models/openmanus_rl/Qwen/Qwen3-3b-sft/global_step_1" +model_path="/data1/models/Qwen/Qwen2.5-3B" +eval_batch_size="1" +num_workers="8" +seed="42" +do_sample="False" +temperature="1.0" + +max_round="6" +env_server_base="http://127.0.0.1:36001" # Set this to the base url of the EnvServer. 
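+# NOTE (assumption): the AgentGym WebShop environment server (agentenv-webshop) is expected to be
+# already running and reachable at ${env_server_base} before this script is launched; the evaluation
+# only sends requests to it and does not start the server itself.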
+timeout="2400" + + +######### +mkdir -p "${output_dir}" +export PYTHONPATH=/home/user/muxin/OpenManus-RL/openmanus_rl/agentgym/agentenv:$PYTHONPATH # You need to modify this as your agentgym/agentenv absolute path + +accelerate launch \ + --num_processes=${num_processes} \ + --main_process_port=${main_process_port} \ + ../../utils/distributed_eval_task.py \ + --model_path "${model_path}" \ + --output_file "${output_dir}/inference.jsonl" \ + --inference_file "${inference_file}" \ + --task_name "${task_name}" \ + --eval_batch_size "${eval_batch_size}" \ + --num_workers "${num_workers}" \ + --seed "${seed}" \ + --do_sample "${do_sample}" \ + --temperature "${temperature}" \ + --max_round "${max_round}" \ + --env_server_base "${env_server_base}" \ + --data_len 200 \ + --timeout "${timeout}" diff --git a/requirements.txt b/requirements.txt index 2bf1d121..8f153662 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,4 @@ vllm<=0.6.3 wandb IPython matplotlib -omegaconf \ No newline at end of file +omegaconf diff --git a/scripts/offline_rollout.sh b/scripts/offline_rollout.sh new file mode 100644 index 00000000..d9022b3c --- /dev/null +++ b/scripts/offline_rollout.sh @@ -0,0 +1,96 @@ +CONFIG_FILE="" # fulfill the config yaml file here +MODEL_PATH="" +OUTPUT_DIR="" +TASK_NAMES="" +DATA_LEN=200 +TIMEOUT=2400 +DO_SAMPLE="False" +TEMPERATURE=1.0 +SEED=42 +DEBUG=false + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --config) + CONFIG_FILE="$2" + shift 2 + ;; + --model_path) + MODEL_PATH="$2" + shift 2 + ;; + --output_dir) + OUTPUT_DIR="$2" + shift 2 + ;; + --task_names) + TASK_NAMES="$2" + shift 2 + ;; + --data_len) + DATA_LEN="$2" + shift 2 + ;; + --timeout) + TIMEOUT="$2" + shift 2 + ;; + --do_sample) + DO_SAMPLE="$2" + shift 2 + ;; + --temperature) + TEMPERATURE="$2" + shift 2 + ;; + --seed) + SEED="$2" + shift 2 + ;; + --debug) + DEBUG=true + shift + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac +done + +# Build command +CMD="python traj_generation/rollout_eval.py --config $CONFIG_FILE" + +if [ ! -z "$MODEL_PATH" ]; then + CMD="$CMD --model_path $MODEL_PATH" +fi + +if [ ! -z "$OUTPUT_DIR" ]; then + CMD="$CMD --output_dir $OUTPUT_DIR" +fi + +if [ ! -z "$TASK_NAMES" ]; then + CMD="$CMD --task_names $TASK_NAMES" +fi + +CMD="$CMD --data_len $DATA_LEN --timeout $TIMEOUT --do_sample $DO_SAMPLE --temperature $TEMPERATURE --seed $SEED" + +if [ "$DEBUG" = true ]; then + CMD="$CMD --debug" +fi + +# Create log directory +TIMESTAMP=$(date +"%Y%m%d_%H%M%S") +LOG_DIR="./logs" +mkdir -p $LOG_DIR +LOG_FILE="$LOG_DIR/offline_rollout_$TIMESTAMP.log" + +# Print the command +echo "Running: $CMD" +echo "Logging to: $LOG_FILE" + +# Execute with logging +eval "$CMD | tee $LOG_FILE" + +echo "Evaluation complete! Results saved to the output directory." 
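+
+# Example invocation (paths and values below are placeholders; adjust to your setup):
+#   bash scripts/offline_rollout.sh --config traj_generation/config.yaml \
+#     --task_names "webshop" --data_len 200 --do_sample False --debug
+# Console output is also saved to ./logs/offline_rollout_<timestamp>.log via tee.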
\ No newline at end of file diff --git a/scripts/run_sft.sh b/scripts/run_sft.sh new file mode 100644 index 00000000..9d4e3010 --- /dev/null +++ b/scripts/run_sft.sh @@ -0,0 +1,30 @@ +#!/bin/bash +set -x + +if [ "$#" -lt 2 ]; then + echo "Usage: run_sft.sh [other_configs...]" + exit 1 +fi + +nproc_per_node=$1 +save_path=$2 + +# Shift the arguments so $@ refers to the rest +shift 2 + +torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ + -m verl.trainer.fsdp_sft_trainer \ + data.train_files=OpenManus-RL/data/train_split.parquet \ + data.val_files=OpenManus-RL/data/test_split.parquet \ + data.multiturn.enable=true \ + data.multiturn.messages_key=prompt \ + data.micro_batch_size=4 \ + model.partial_pretrain=/data1/models/Qwen/Qwen3-4B \ + trainer.default_local_dir=$save_path \ + trainer.project_name=multiturn-sft \ + trainer.experiment_name=multiturn-sft-qwen-3-4b \ + trainer.logger=['console'] \ + trainer.total_training_steps=1 \ + trainer.default_hdfs_dir=null $@ \ + ulysses_sequence_parallel_size=2 \ + use_remove_padding=true \ No newline at end of file diff --git a/traj_generation/rollout_eval.py b/traj_generation/rollout_eval.py new file mode 100644 index 00000000..32ec3149 --- /dev/null +++ b/traj_generation/rollout_eval.py @@ -0,0 +1,466 @@ +import os +import sys +import json +import jsonlines +import time +import argparse +import yaml +from dataclasses import dataclass, field +from typing import List, Dict, Any, Optional, Tuple +import torch +from tqdm import tqdm +import logging + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler(sys.stdout), + ] +) +logger = logging.getLogger("offline_rollout") + +# Import OpenManus and AgentGym components +try: + from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + from agentenv.controller import Agent, Evaluator + from agentenv.envs import ( + AlfWorldTask, + BabyAITask, + SciworldTask, + TextCraftTask, + WebarenaTask, + WebshopTask, + SqlGymTask, + MazeTask, + WordleTask, + WeatherTask, + TodoTask, + MovieTask, + SheetTask, + AcademiaTask + ) +except ImportError as e: + logger.error(f"Failed to import required modules: {e}") + logger.error("Please ensure AgentGym and its dependencies are installed") + sys.exit(1) + +def parse_args(): + parser = argparse.ArgumentParser(description="Generate offline rollout trajectories for AgentGym environments") + parser.add_argument("--config", type=str, default=None, help="Path to configuration YAML file") + parser.add_argument("--model_path", type=str, required="--config" not in sys.argv, + help="Path to the model") + parser.add_argument("--output_dir", type=str, default="./offline_rollout_results", + help="Directory to save results") + parser.add_argument("--task_names", nargs="+", default=["webshop"], + help="Space-separated list of task names to evaluate") + parser.add_argument("--inference_files", nargs="+", default=None, + help="Space-separated list of inference files for each task") + parser.add_argument("--max_rounds", nargs="+", type=int, default=None, + help="Maximum interaction rounds for each task") + parser.add_argument("--env_server_bases", nargs="+", default=None, + help="Environment server base URLs for each task") + parser.add_argument("--data_len", type=int, default=200, + help="Number of data samples in environment") + parser.add_argument("--timeout", type=int, default=2400, + help="Timeout value for environment interactions") + 
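+    # NOTE: --do_sample is kept as a string ("True"/"False") rather than a boolean flag so the shell
+    # wrappers can forward it verbatim; main() converts it with args.do_sample.lower() == "true".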
parser.add_argument("--do_sample", type=str, default="False", + help="Whether to use sampling for generation") + parser.add_argument("--temperature", type=float, default=1.0, + help="Temperature for sampling") + parser.add_argument("--seed", type=int, default=42, + help="Random seed") + parser.add_argument("--debug", action="store_true", default=False, + help="Enable debug mode") + return parser.parse_args() + +def read_config(config_path: str) -> Dict[str, Any]: + """Read and process configuration file""" + try: + with open(config_path, 'r') as f: + config = yaml.safe_load(f) + + # Validate required fields + required_fields = ["model_path", "output_dir", "tasks"] + for field in required_fields: + if field not in config: + raise ValueError(f"Config file missing required field: {field}") + + # Process tasks configuration + if not isinstance(config["tasks"], dict) or not config["tasks"]: + raise ValueError("Config must contain at least one task under 'tasks' key") + + return config + except Exception as e: + logger.error(f"Error reading config file: {e}") + raise + +def get_task_class(task_name: str): + """Get task class based on task name""" + task_classes = { + "webshop": WebshopTask, + "alfworld": AlfWorldTask, + "babyai": BabyAITask, + "sciworld": SciworldTask, + "textcraft": TextCraftTask, + "webarena": WebarenaTask, + "sqlgym": SqlGymTask, + 'maze': MazeTask, + 'wordle': WordleTask, + "weather": WeatherTask, + "todo": TodoTask, + "movie": MovieTask, + "sheet": SheetTask, + "academia": AcademiaTask + } + + task_class = task_classes.get(task_name.lower(), None) + if task_class is None: + raise ValueError(f"Unsupported task name: {task_name}") + + return task_class + +def load_data(data_path: str) -> List[Dict[str, Any]]: + """Load data from the specified file path""" + logger.info(f"Loading data from {data_path}") + + try: + if data_path.endswith('.json'): + with open(data_path, 'r') as f: + data = json.load(f) + elif data_path.endswith('.jsonl'): + data = [] + with jsonlines.open(data_path) as reader: + for item in reader: + data.append(item) + else: + raise ValueError(f"Unsupported file format: {data_path}") + + logger.info(f"Loaded {len(data)} samples") + return data + except Exception as e: + logger.error(f"Error loading data: {e}") + raise + +def save_summary(output_dir: str, task_name: str, success_rate: float, score: float, + total_samples: int, success_count: int, process_time: float): + """Save evaluation summary to file""" + summary_path = os.path.join(output_dir, f"{task_name}_summary.json") + + summary = { + "task_name": task_name, + "success_rate": success_rate, + "score": score, + "total_samples": total_samples, + "success_count": success_count, + "process_time": process_time, + "timestamp": time.strftime("%Y-%m-%d %H:%M:%S") + } + + with open(summary_path, 'w') as f: + json.dump(summary, f, indent=2) + + logger.info(f"Summary saved to {summary_path}") + +def extract_data_idxs(data: List[Dict[str, Any]], task_name: str) -> List[List[int]]: + """Extract data indices from the data""" + data_idxs = [] + + for item in data: + # Handle different data formats + if "item_id" in item: + # Format: "task_123" or just "123" + item_id = item["item_id"] + if isinstance(item_id, str): + if task_name in item_id: + # Extract number after task_name_ + idx = int(item_id.split("_")[-1]) + else: + # If no task prefix, just use the ID directly + try: + idx = int(item_id) + except ValueError: + idx = int(item_id.split("_")[-1]) + else: + idx = item_id + elif "id" in item: + idx = item["id"] + 
elif "index" in item: + idx = item["index"] + else: + # Default to using the position in the list + idx = data.index(item) + + data_idxs.append([idx]) + + return data_idxs + +def generate_trajectories( + model_path: str, + task_name: str, + inference_file: str, + output_file: str, + max_rounds: int, + env_server_base: str, + data_len: int = 200, + timeout: int = 2400, + do_sample: bool = False, + temperature: float = 1.0, + seed: int = 42 +) -> Dict[str, Any]: + """Generate trajectories for a specific task""" + logger.info(f"Generating trajectories for {task_name}") + logger.info(f"Using model: {model_path}") + logger.info(f"Environment server: {env_server_base}") + + # Load model and tokenizer + tokenizer = AutoTokenizer.from_pretrained(model_path) + model = AutoModelForCausalLM.from_pretrained( + model_path, + device_map="auto", + trust_remote_code=True + ).eval() + + # Get task class + task_class = get_task_class(task_name) + + # Set environment parameters + env_args = { + "env_server_base": env_server_base, + "data_len": data_len, + "timeout": timeout, + } + + # Set up evaluator + evaluator = Evaluator( + Agent(model, tokenizer), + [task_class(client_args=env_args, n_clients=1)], + ) + + # Load data + test_data = load_data(inference_file) + data_idxs = extract_data_idxs(test_data, task_name) + + # Set generation config + gen_config = GenerationConfig( + max_length=4096, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.unk_token_id, + do_sample=do_sample, + temperature=temperature if do_sample else 1.0, + ) + + # Initialize output file + os.makedirs(os.path.dirname(output_file), exist_ok=True) + if os.path.exists(output_file): + os.remove(output_file) + + # Run evaluation + total_score = 0.0 + total_success = 0.0 + start_time = time.time() + + for data_idx in tqdm(data_idxs, total=len(data_idxs), desc=f"[Evaluating {task_name}]"): + try: + exps = evaluator.eval( + generation_config=gen_config, + max_rounds=max_rounds, + idxs=data_idx + ) + + total_score += exps.score + total_success += exps.success + + cur_experiences = exps.experiences + + # Write results to file + with jsonlines.open(output_file, mode="a") as f: + for exp in cur_experiences: + conversation = exp.conversation + cur_reward = exp.reward + cur_success = 1 if exp.reward == 1 else 0 + item_id = f"{task_name}_{data_idx[0]}" + + f.write({ + "conversations": conversation, + "item_id": item_id, + "reward": cur_reward, + "success": cur_success, + }) + + except Exception as e: + logger.error(f"Error evaluating sample {data_idx}: {e}") + # Write error result to file + with jsonlines.open(output_file, mode="a") as f: + f.write({ + "conversations": [{"error": str(e)}], + "item_id": f"{task_name}_{data_idx[0]}", + "reward": 0.0, + "success": 0, + "error": str(e) + }) + + process_time = time.time() - start_time + + score = total_score / len(data_idxs) if data_idxs else 0 + success_rate = total_success / len(data_idxs) if data_idxs else 0 + success_count = int(total_success) + + logger.info(f"Task: {task_name}") + logger.info(f"Score: {score:.4f}") + logger.info(f"Success rate: {success_rate:.4f} ({success_count}/{len(data_idxs)})") + logger.info(f"Time: {process_time:.2f} seconds") + + # Return results + return { + "task_name": task_name, + "score": score, + "success_rate": success_rate, + "total_samples": len(data_idxs), + "success_count": success_count, + "process_time": process_time + } + +def create_output_filepaths(output_dir: str, task_names: 
List[str]) -> Dict[str, str]: + """Create output file paths for each task""" + output_files = {} + + for task_name in task_names: + task_dir = os.path.join(output_dir, task_name) + os.makedirs(task_dir, exist_ok=True) + output_files[task_name] = os.path.join(task_dir, f"{task_name}_trajectories.jsonl") + + return output_files + +def main(): + args = parse_args() + + # Set logging level + if args.debug: + logging.getLogger().setLevel(logging.DEBUG) + + # Process configuration + if args.config: + config = read_config(args.config) + model_path = args.model_path or config["model_path"] + output_dir = args.output_dir or config["output_dir"] + task_names = args.task_names or list(config["tasks"].keys()) + + # Get task-specific configurations + inference_files = {} + max_rounds = {} + env_server_bases = {} + + for task_name in task_names: + if task_name in config["tasks"]: + task_config = config["tasks"][task_name] + inference_files[task_name] = task_config.get("inference_file") + max_rounds[task_name] = task_config.get("max_rounds") + env_server_bases[task_name] = task_config.get("env_server_base") + else: + model_path = args.model_path + output_dir = args.output_dir + task_names = args.task_names + + # Process lists from command line + inference_files = {} + max_rounds = {} + env_server_bases = {} + + if args.inference_files: + for i, task_name in enumerate(task_names): + if i < len(args.inference_files): + inference_files[task_name] = args.inference_files[i] + + if args.max_rounds: + for i, task_name in enumerate(task_names): + if i < len(args.max_rounds): + max_rounds[task_name] = args.max_rounds[i] + + if args.env_server_bases: + for i, task_name in enumerate(task_names): + if i < len(args.env_server_bases): + env_server_bases[task_name] = args.env_server_bases[i] + + # Set default server bases if not specified + for task_name in task_names: + if task_name not in env_server_bases or not env_server_bases[task_name]: + env_server_bases[task_name] = "http://localhost:8000" + + # Create output directory + os.makedirs(output_dir, exist_ok=True) + + # Create output files for each task + output_files = create_output_filepaths(output_dir, task_names) + + # Run evaluation for each task + all_results = {} + + for task_name in task_names: + if task_name not in inference_files or not inference_files[task_name]: + logger.error(f"No inference file specified for {task_name}, skipping") + continue + + try: + logger.info(f"Processing task: {task_name}") + + # Set default max_rounds if not specified + task_max_rounds = max_rounds.get(task_name, 10) + + # Generate trajectories + task_results = generate_trajectories( + model_path=model_path, + task_name=task_name, + inference_file=inference_files[task_name], + output_file=output_files[task_name], + max_rounds=task_max_rounds, + env_server_base=env_server_bases[task_name], + data_len=args.data_len, + timeout=args.timeout, + do_sample=args.do_sample.lower() == "true", + temperature=args.temperature, + seed=args.seed + ) + + all_results[task_name] = task_results + + # Save summary + save_summary( + output_dir=os.path.join(output_dir, task_name), + task_name=task_name, + success_rate=task_results["success_rate"], + score=task_results["score"], + total_samples=task_results["total_samples"], + success_count=task_results["success_count"], + process_time=task_results["process_time"] + ) + + except Exception as e: + logger.error(f"Error processing task {task_name}: {e}") + import traceback + logger.error(traceback.format_exc()) + + # Save overall results + 
    overall_results_path = os.path.join(output_dir, "overall_results.json")
+    with open(overall_results_path, 'w') as f:
+        json.dump({
+            "model_path": model_path,
+            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
+            "results": all_results
+        }, f, indent=2)
+
+    # Print summary table
+    logger.info("\n" + "="*60)
+    logger.info("EVALUATION SUMMARY")
+    logger.info("="*60)
+    logger.info(f"{'Task':<15} | {'Success Rate':<15} | {'Score':<10} | {'Success/Total':<15}")
+    logger.info("-"*60)
+
+    for task_name, results in all_results.items():
+        # Note: width and percent/precision belong in one format spec, e.g. :<15.2%.
+        logger.info(f"{task_name:<15} | {results['success_rate']:<15.2%} | {results['score']:<10.4f} | {results['success_count']}/{results['total_samples']}")
+
+    logger.info("="*60)
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/verl/utils/dataset/__init__.py b/verl/utils/dataset/__init__.py
index a7f9b71c..02160771 100644
--- a/verl/utils/dataset/__init__.py
+++ b/verl/utils/dataset/__init__.py
@@ -13,4 +13,7 @@
 # limitations under the License.
 
 from .rl_dataset import RLHFDataset
-from .rm_dataset import RMDataset
\ No newline at end of file
+from .rm_dataset import RMDataset
+from .sft_dataset import SFTDataset
+
+__all__ = ["RLHFDataset", "RMDataset", "SFTDataset"]
\ No newline at end of file
diff --git a/verl/utils/dataset/multiturn_sft_dataset.py b/verl/utils/dataset/multiturn_sft_dataset.py
new file mode 100644
index 00000000..89be513b
--- /dev/null
+++ b/verl/utils/dataset/multiturn_sft_dataset.py
@@ -0,0 +1,146 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" +Multi-turn SFT dataset that supports training on conversation data with multiple turns +""" + +from typing import List, Union + +import pandas as pd +import torch +from torch.utils.data import Dataset +from transformers import PreTrainedTokenizer + +from verl.utils import hf_tokenizer +from verl.utils.fs import copy_local_path_from_hdfs + + +class MultiTurnSFTDataset(Dataset): + """ + Dataset for multi-turn conversations where each assistant response should be trained + """ + + def __init__(self, parquet_files: Union[str, List[str]], tokenizer, config=None): + # Set defaults and extract parameters from config if provided + config = config or {} + self.truncation = config.get("truncation", "error") + self.max_length = config.get("max_length", 1024) + # Get messages_key from the new multiturn config structure + multiturn_config = config.get("multiturn", {}) + self.messages_key = multiturn_config.get("messages_key", "messages") + + assert self.truncation in ["error", "left", "right"] + + if not isinstance(parquet_files, List): + parquet_files = [parquet_files] + + self.parquet_files = parquet_files + if isinstance(tokenizer, str): + tokenizer = hf_tokenizer(tokenizer) + self.tokenizer: PreTrainedTokenizer = tokenizer + + self._download() + self._read_files_and_process() + + def _download(self): + for i, parquet_file in enumerate(self.parquet_files): + self.parquet_files[i] = copy_local_path_from_hdfs(parquet_file, verbose=True) + + def _read_files_and_process(self): + def series_to_item(ls): + import numpy + import pandas + + while isinstance(ls, (pandas.core.series.Series, numpy.ndarray)) and len(ls) == 1: + ls = ls[0] + return ls + + dataframes = [] + for parquet_file in self.parquet_files: + dataframe = pd.read_parquet(parquet_file) + dataframes.append(dataframe) + self.dataframe = pd.concat(dataframes) + + # Extract messages list from dataframe + self.messages = self.dataframe[self.messages_key].apply(series_to_item).tolist() + + def __len__(self): + return len(self.messages) + + def __getitem__(self, item): + tokenizer = self.tokenizer + messages = self.messages[item] + + # First, get the full conversation tokens + full_tokens = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt", add_generation_prompt=False) + input_ids = full_tokens[0] # The output is already a tensor + attention_mask = torch.ones_like(input_ids) + + # Create loss mask by identifying assistant responses + loss_mask = torch.zeros_like(input_ids, dtype=torch.long) + + # Process each message to find assistant responses + for i, msg in enumerate(messages): + # Get tokens for messages up to this point to find the start position + prefix_messages = messages[: i + 1] + prefix_tokens = tokenizer.apply_chat_template(prefix_messages, tokenize=True, return_tensors="pt", add_generation_prompt=False) + + # Get tokens for messages up to previous point + prev_tokens = tokenizer.apply_chat_template(messages[:i], tokenize=True, return_tensors="pt", add_generation_prompt=False) if i > 0 else None + + # Calculate start and end positions + start_pos = prev_tokens[0].shape[0] if prev_tokens is not None else 0 + end_pos = prefix_tokens[0].shape[0] + + # If this is an assistant message, set loss mask + if msg["role"] == "assistant": + loss_mask[start_pos:end_pos] = 1 + + # Handle sequence length + sequence_length = input_ids.shape[0] + if sequence_length < self.max_length: + # Pad sequences + pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0 + padded_input_ids = 
torch.ones(size=(self.max_length - sequence_length,), dtype=input_ids.dtype) * pad_token_id + padded_attention_mask = torch.zeros(size=(self.max_length - sequence_length,), dtype=attention_mask.dtype) + padded_loss_mask = torch.zeros(size=(self.max_length - sequence_length,), dtype=loss_mask.dtype) + + input_ids = torch.cat((input_ids, padded_input_ids)) + attention_mask = torch.cat((attention_mask, padded_attention_mask)) + loss_mask = torch.cat((loss_mask, padded_loss_mask)) + elif sequence_length > self.max_length: + if self.truncation == "left": + input_ids = input_ids[-self.max_length :] + attention_mask = attention_mask[-self.max_length :] + loss_mask = loss_mask[-self.max_length :] + elif self.truncation == "right": + input_ids = input_ids[: self.max_length] + attention_mask = attention_mask[: self.max_length] + loss_mask = loss_mask[: self.max_length] + elif self.truncation == "error": + raise ValueError(f"{sequence_length=} is larger than {self.max_length=}") + else: + raise ValueError(f"Unknown truncation method {self.truncation}") + + # Create position IDs + position_ids = torch.arange(len(input_ids), dtype=torch.long) + # Zero out position IDs for padding + position_ids = position_ids * attention_mask + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + "loss_mask": loss_mask, + } \ No newline at end of file diff --git a/verl/utils/dataset/rl_dataset.py b/verl/utils/dataset/rl_dataset.py index 6b5f65f4..7da194a6 100644 --- a/verl/utils/dataset/rl_dataset.py +++ b/verl/utils/dataset/rl_dataset.py @@ -1,4 +1,6 @@ # Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,35 +14,36 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from omegaconf import ListConfig +import copy +import logging import os -from typing import List, Union +import re +from collections import defaultdict +from typing import List, Optional, Union -import pandas as pd - -import torch +import datasets import numpy as np -from torch.utils.data import Dataset, DataLoader -from transformers import AutoTokenizer, PreTrainedTokenizer -from verl.utils.fs import copy_local_path_from_hdfs +import torch +from omegaconf import DictConfig, ListConfig +from torch.utils.data import Dataset +from transformers import PreTrainedTokenizer, ProcessorMixin -from verl.utils.model import compute_position_id_with_mask import verl.utils.torch_functional as verl_F +from verl.utils.model import compute_position_id_with_mask + +logger = logging.getLogger(__name__) def collate_fn(data_list: list[dict]) -> dict: - tensors = {} - non_tensors = {} + """Collate a batch of data.""" + tensors = defaultdict(list) + non_tensors = defaultdict(list) for data in data_list: for key, val in data.items(): if isinstance(val, torch.Tensor): - if key not in tensors: - tensors[key] = [] tensors[key].append(val) else: - if key not in non_tensors: - non_tensors[key] = [] non_tensors[key].append(val) for key, val in tensors.items(): @@ -49,10 +52,7 @@ def collate_fn(data_list: list[dict]) -> dict: for key, val in non_tensors.items(): non_tensors[key] = np.array(val, dtype=object) - output = {} - output.update(tensors) - output.update(non_tensors) - return output + return {**tensors, **non_tensors} class RLHFDataset(Dataset): @@ -60,96 +60,205 @@ class RLHFDataset(Dataset): We assume the dataset contains a column that contains prompts and other information """ - def __init__(self, - parquet_files: Union[str, List[str]], - tokenizer: PreTrainedTokenizer, - prompt_key='prompt', - max_prompt_length=1024, - filter_prompts=True, - cache_dir='~/.cache/verl/rlhf', - chat_template_func=None, - return_raw_chat=False, - truncation='error'): - if not isinstance(parquet_files, (List, ListConfig)): - parquet_files = [parquet_files] - - self.parquet_files = parquet_files - self.cache_dir = os.path.expanduser(cache_dir) + def __init__( + self, + data_files: Union[str, List[str]], + tokenizer: PreTrainedTokenizer, + config: DictConfig, + processor: Optional[ProcessorMixin] = None, + ): + if not isinstance(data_files, (List, ListConfig)): + data_files = [data_files] + + self.data_files = copy.deepcopy(data_files) + self.original_data_files = copy.deepcopy(data_files) # use for resume self.tokenizer = tokenizer - - self.prompt_key = prompt_key - self.max_prompt_length = max_prompt_length - self.filter_prompts = filter_prompts - - self.return_raw_chat = return_raw_chat - self.chat_template_func = chat_template_func - self.truncation = truncation - + self.processor = processor + self.config = config + + self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf")) + self.prompt_key = config.get("prompt_key", "prompt") + self.image_key = config.get("image_key", "images") + self.video_key = config.get("video_key", "videos") + self.max_prompt_length = config.get("max_prompt_length", 1024) + self.return_raw_chat = config.get("return_raw_chat", False) + self.truncation = config.get("truncation", "error") + self.filter_overlong_prompts = config.get("filter_overlong_prompts", True) + + self.num_workers = config.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4)) + self.num_workers = min(self.num_workers, os.cpu_count()) + self.chat_template_func = config.get("chat_template_func", 
None) + self.need_tools_kwargs = config.get("need_tools_kwargs", False) + self.filter_prompts = config.get("filter_prompts", True) + self.serialize_dataset = False self._download() self._read_files_and_tokenize() - def _download(self): - from verl.utils.fs import copy_local_path_from_hdfs - for i, parquet_file in enumerate(self.parquet_files): - self.parquet_files[i] = copy_local_path_from_hdfs(src=parquet_file, cache_dir=self.cache_dir) + def _download(self, use_origin_parquet=False): + from verl.utils.fs import copy_to_local + + data_files = self.data_files if not use_origin_parquet else self.original_data_files + for i, parquet_file in enumerate(data_files): + self.data_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir) def _read_files_and_tokenize(self): dataframes = [] - for parquet_file in self.parquet_files: + for parquet_file in self.data_files: # read parquet files and cache - dataframe = pd.read_parquet(parquet_file) + dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"] dataframes.append(dataframe) - self.dataframe = pd.concat(dataframes) + self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes) - print(f'original dataset len: {len(self.dataframe)}') + print(f"dataset len: {len(self.dataframe)}") # filter out too long prompts - tokenizer = self.tokenizer - prompt_key = self.prompt_key - - # nvm if prompt is too long - # self.dataframe = self.dataframe[self.dataframe.apply(lambda doc: len( - # tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True)) <= self.max_prompt_length, - # axis=1)] - - print(f'filter dataset len: {len(self.dataframe)}') + if self.filter_overlong_prompts: + tokenizer = self.tokenizer + prompt_key = self.prompt_key + self.dataframe = self.dataframe.filter( + lambda doc: len(tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True)) <= self.max_prompt_length, + num_proc=self.num_workers, + desc=f"Filtering prompts longer than {self.max_prompt_length} tokens", + ) + + print(f"filter dataset len: {len(self.dataframe)}") + + def resume_dataset_state(self): + self.serialize_dataset = not hasattr(self, "original_data_files") + # resume dataframe if not it's serialized in data.pt + if not self.serialize_dataset: + self._download(use_origin_parquet=True) # download and resume from original parquet files + self._read_files_and_tokenize() + else: + print(r"old dataloader ckpt file is used, please train from scratch for better ckpt performance") def __len__(self): return len(self.dataframe) + def _build_messages(self, example: dict): + messages: list = example.pop(self.prompt_key) + + if self.image_key in example or self.video_key in example: + for message in messages: + content = message["content"] + content_list = [] + for segment in re.split("(|