diff --git a/examples/data_preprocess/multiturn.py b/examples/data_preprocess/multiturn.py
new file mode 100644
index 00000000000..98407bec2b5
--- /dev/null
+++ b/examples/data_preprocess/multiturn.py
@@ -0,0 +1,131 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Create a simple multi-turn dataset for testing
+"""
+
+import os
+import pandas as pd
+import argparse
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--local_dir', default='~/data/multiturn')
+    parser.add_argument('--hdfs_dir', default=None)
+    args = parser.parse_args()
+
+    # Create example conversations
+    conversations = []
+
+    # Conversation 1
+    conversations.append({
+        "messages": [{
+            "role": "system",
+            "content": "You are a helpful assistant."
+        }, {
+            "role": "user",
+            "content": "What is the capital of France?"
+        }, {
+            "role": "assistant",
+            "content": "The capital of France is Paris."
+        }, {
+            "role": "user",
+            "content": "And what about Germany?"
+        }, {
+            "role": "assistant",
+            "content": "The capital of Germany is Berlin."
+        }]
+    })
+
+    # Conversation 2
+    conversations.append({
+        "messages": [{
+            "role": "system",
+            "content": "You are a helpful assistant."
+        }, {
+            "role": "user",
+            "content": "Can you explain quantum computing?"
+        }, {
+            "role": "assistant",
+            "content": "Quantum computing is a type of computing that uses quantum-mechanical phenomena, such as superposition and entanglement, to perform operations on data."
+        }, {
+            "role": "user",
+            "content": "How is it different from classical computing?"
+        }, {
+            "role": "assistant",
+            "content": "Classical computing uses bits that are either 0 or 1, while quantum computing uses quantum bits or qubits that can exist in multiple states simultaneously due to superposition."
+        }]
+    })
+
+    # Conversation 3
+    conversations.append({
+        "messages": [{
+            "role": "system",
+            "content": "You are a helpful assistant."
+        }, {
+            "role": "user",
+            "content": "Write a simple Python function to calculate factorial."
+        }, {
+            "role": "assistant",
+            "content": "```python\ndef factorial(n):\n    if n == 0 or n == 1:\n        return 1\n    else:\n        return n * factorial(n-1)\n```\n\nThis is a recursive function to calculate the factorial of a number."
+        }, {
+            "role": "user",
+            "content": "Can you make it iterative instead?"
+        }, {
+            "role": "assistant",
+            "content": "```python\ndef factorial(n):\n    result = 1\n    for i in range(1, n+1):\n        result *= i\n    return result\n```\n\nThis is an iterative version of the factorial function."
+        }]
+    })
+
+    # Create train and test datasets
+    train_data = conversations[:2]  # First 2 conversations for training
+    test_data = conversations[2:]  # Last conversation for testing
+
+    # Create output directory
+    local_dir = os.path.expanduser(args.local_dir)
+    os.makedirs(local_dir, exist_ok=True)
+
+    # Save to parquet files
+    train_df = pd.DataFrame(train_data)
+    test_df = pd.DataFrame(test_data)
+
+    train_df.to_parquet(os.path.join(local_dir, 'train.parquet'))
+    test_df.to_parquet(os.path.join(local_dir, 'test.parquet'))
+
+    # Handle HDFS if specified
+    if args.hdfs_dir is not None:
+        try:
+            from verl.utils.hdfs_io import copy, makedirs
+            makedirs(args.hdfs_dir)
+            copy(src=local_dir, dst=args.hdfs_dir)
+        except ImportError:
+            print("Warning: HDFS support not available. Skipping HDFS copy.")
+
+    # Print statistics
+    print(f"Train dataset size: {len(train_df)}")
+    print(f"Test dataset size: {len(test_df)}")
+    print(f"Data saved to {local_dir}")
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/sft/multiturn/run_qwen_05_sp2.sh b/examples/sft/multiturn/run_qwen_05_sp2.sh
new file mode 100755
index 00000000000..1da72070b96
--- /dev/null
+++ b/examples/sft/multiturn/run_qwen_05_sp2.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+set -x
+
+if [ "$#" -lt 2 ]; then
+    echo "Usage: run_qwen_05_sp2.sh <nproc_per_node> <save_path> [other_configs...]"
+    exit 1
+fi
+
+nproc_per_node=$1
+save_path=$2
+
+# Shift the arguments so $@ refers to the rest
+shift 2
+
+torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \
+    -m verl.trainer.fsdp_sft_trainer \
+    data.train_files=$HOME/data/multiturn/train.parquet \
+    data.val_files=$HOME/data/multiturn/test.parquet \
+    data.multiturn.enable=true \
+    data.multiturn.messages_key=messages \
+    data.micro_batch_size=4 \
+    model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \
+    trainer.default_local_dir=$save_path \
+    trainer.project_name=multiturn-sft \
+    trainer.experiment_name=multiturn-sft-qwen-2.5-0.5b-instruct-sp2 \
+    trainer.logger=['console'] \
+    trainer.total_training_steps=1 \
+    trainer.default_hdfs_dir=null $@ \
+    ulysses_sequence_parallel_size=2 \
+    use_remove_padding=true
\ No newline at end of file
diff --git a/tests/sft/run_sft_multiturn.sh b/tests/sft/run_sft_multiturn.sh
new file mode 100755
index 00000000000..1da72070b96
--- /dev/null
+++ b/tests/sft/run_sft_multiturn.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+set -x
+
+if [ "$#" -lt 2 ]; then
+    echo "Usage: run_sft_multiturn.sh <nproc_per_node> <save_path> [other_configs...]"
+    exit 1
+fi
+
+nproc_per_node=$1
+save_path=$2
+
+# Shift the arguments so $@ refers to the rest
+shift 2
+
+torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \
+    -m verl.trainer.fsdp_sft_trainer \
+    data.train_files=$HOME/data/multiturn/train.parquet \
+    data.val_files=$HOME/data/multiturn/test.parquet \
+    data.multiturn.enable=true \
+    data.multiturn.messages_key=messages \
+    data.micro_batch_size=4 \
+    model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \
+    trainer.default_local_dir=$save_path \
+    trainer.project_name=multiturn-sft \
+    trainer.experiment_name=multiturn-sft-qwen-2.5-0.5b-instruct-sp2 \
+    trainer.logger=['console'] \
+    trainer.total_training_steps=1 \
+    trainer.default_hdfs_dir=null $@ \
+    ulysses_sequence_parallel_size=2 \
+    use_remove_padding=true
\ No newline at end of file
diff --git a/tests/verl/utils/dataset/test_multiturn_sft_dataset.py b/tests/verl/utils/dataset/test_multiturn_sft_dataset.py
new file mode 100644
index 00000000000..64594c63338
--- /dev/null
+++ b/tests/verl/utils/dataset/test_multiturn_sft_dataset.py
@@ -0,0 +1,193 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Test the MultiTurnSFTDataset implementation
+"""
+import os
+import pandas as pd
+import torch
+from transformers import AutoTokenizer
+from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset
+
+
+def test_multiturn_sft_dataset():
+    print("Starting test...")
+    # Create a temporary parquet file with test data
+    test_data = {
+        'messages': [
+            [{
+                "role": "system",
+                "content": "You are a helpful assistant."
+            }, {
+                "role": "user",
+                "content": "What is 2+2?"
+            }, {
+                "role": "assistant",
+                "content": "2+2 equals 4."
+            }, {
+                "role": "user",
+                "content": "And what is 4+4?"
+            }, {
+                "role": "assistant",
+                "content": "4+4 equals 8."
+            }],
+            [{
+                "role": "system",
+                "content": "You are a helpful assistant."
+            }, {
+                "role": "user",
+                "content": "Tell me a joke."
+            }, {
+                "role": "assistant",
+                "content": "Why did the chicken cross the road?"
+            }, {
+                "role": "user",
+                "content": "Why?"
+            }, {
+                "role": "assistant",
+                "content": "To get to the other side!"
+            }],
+        ]
+    }
+
+    # Create test directory if it doesn't exist
+    os.makedirs('test_data', exist_ok=True)
+    test_file = 'test_data/test.parquet'
+
+    # Save test data to parquet
+    df = pd.DataFrame(test_data)
+    df.to_parquet(test_file)
+
+    # Initialize tokenizer and dataset
+    tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-Coder-7B-Instruct')
+    config = {'max_length': 512, 'truncation': 'error', 'multiturn': {'messages_key': 'messages'}}
+    dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=config)
+
+    # Test 1: Dataset Length
+    assert len(dataset) == 2, f"Expected dataset length 2, got {len(dataset)}"
+
+    # Get items for testing
+    item0 = dataset[0]  # Math conversation
+    item1 = dataset[1]  # Joke conversation
+
+    # Test 2: Required Keys and Types
+    required_keys = ['input_ids', 'attention_mask', 'position_ids', 'loss_mask']
+    for key in required_keys:
+        assert key in item0, f"Missing key {key} in dataset item"
+        assert isinstance(item0[key], torch.Tensor), f"Expected torch.Tensor for {key}"
+        assert item0[key].dtype == torch.long, f"Expected torch.long for {key}, got {item0[key].dtype}"
+
+    # Test 3: Shape Consistency
+    assert item0['loss_mask'].shape == item0['input_ids'].shape, \
+        "Loss mask shape doesn't match input_ids shape"
+    assert item0['attention_mask'].shape == item0['input_ids'].shape, \
+        "Attention mask shape doesn't match input_ids shape"
+    assert item0['position_ids'].shape == item0['input_ids'].shape, \
+        "Position IDs shape doesn't match input_ids shape"
+
+    # Test 4: Loss Mask Pattern - Math Conversation
+    loss_mask0 = item0['loss_mask']
+    input_ids0 = item0['input_ids']
+
+    # Find assistant response positions
+    assistant_positions0 = torch.where(loss_mask0 == 1)[0]
+    assert len(assistant_positions0) > 0, "No assistant positions found in loss mask"
+
+    # Decode and verify assistant responses
+    assistant_text0 = tokenizer.decode(input_ids0[loss_mask0 == 1])
+    print(f"Math conversation assistant text: {assistant_text0}")
+    assert "2+2 equals 4" in assistant_text0, "First assistant response not found"
+    assert "4+4 equals 8" in assistant_text0, "Second assistant response not found"
+
+    # Test 5: Loss Mask Pattern - Joke Conversation
+    loss_mask1 = item1['loss_mask']
+    input_ids1 = item1['input_ids']
+
+    # Find assistant response positions
+    assistant_positions1 = torch.where(loss_mask1 == 1)[0]
+    assert len(assistant_positions1) > 0, "No assistant positions found in loss mask"
+
+    # Decode and verify assistant responses
+    assistant_text1 = tokenizer.decode(input_ids1[loss_mask1 == 1])
+    print(f"Joke conversation assistant text: {assistant_text1}")
+    assert "chicken cross the road" in assistant_text1, "First assistant response not found"
+    assert "other side" in assistant_text1, "Second assistant response not found"
+
+    # Test 6: Attention Mask Pattern
+    attention_mask0 = item0['attention_mask']
+    sequence_length = torch.sum(attention_mask0)
+    assert sequence_length > 0, "No tokens marked as attended in attention mask"
+    assert torch.all(attention_mask0[:sequence_length] == 1), "Incorrect attention mask pattern"
+    if sequence_length < len(attention_mask0):
+        assert torch.all(attention_mask0[sequence_length:] == 0), "Padding not properly masked"
+
+    # Test 7: Position IDs Pattern
+    position_ids0 = item0['position_ids']
+    assert torch.equal(position_ids0[:sequence_length], torch.arange(sequence_length)), \
+        "Position IDs not sequential for non-padded tokens"
+    if sequence_length < len(position_ids0):
+        assert torch.all(position_ids0[sequence_length:] == 0), "Padding position IDs not zero"
+
+    # Test 8: Verify loss mask for assistant responses
+    # Get the full conversation text
+    full_text = tokenizer.decode(input_ids0)
+    print(f"\nFull conversation text:\n{full_text}")
+
+    # Get the assistant responses
+    assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 1])
+    print(f"\nAssistant responses (from loss mask):\n{assistant_text}")
+
+    # Verify that loss mask is set for all assistant responses
+    for msg in test_data['messages'][0]:  # First conversation
+        if msg['role'] == 'assistant':
+            # The content should appear in the masked text
+            assert msg['content'] in assistant_text, \
+                f"Assistant message '{msg['content']}' not found in masked text"
+
+            # The content should NOT appear in the non-masked text
+            non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0])
+            assert msg['content'] not in non_assistant_text, \
+                f"Assistant message '{msg['content']}' found in non-assistant text"
+
+    # Test 9: Verify non-assistant parts have loss_mask=0
+    # Get non-assistant text
+    non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0])
+    print(f"\nNon-assistant text (from loss mask):\n{non_assistant_text}")
+
+    # Verify that system and user messages are in the non-assistant text
+    for msg in test_data['messages'][0]:  # First conversation
+        if msg['role'] in ['system', 'user']:
+            assert msg['content'] in non_assistant_text, \
+                f"{msg['role'].title()} message '{msg['content']}' not found in non-assistant text"
+
+            # And verify they're NOT in the assistant text
+            assert msg['content'] not in assistant_text, \
+                f"{msg['role'].title()} message '{msg['content']}' found in assistant text"
+
+    # Test 10: Verify padding behavior
+    padding_config = {'max_length': 1024, 'truncation': 'error', 'multiturn': {'messages_key': 'messages'}}
+    small_dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=padding_config)
+    padded_item = small_dataset[0]
+
+    # Get actual sequence length (before padding)
+    actual_length = torch.sum(padded_item['attention_mask'])
+
+    # Verify padding tokens
+    assert torch.all(padded_item['input_ids'][actual_length:] == tokenizer.pad_token_id), \
+        "Padding tokens not set correctly"
+    assert torch.all(padded_item['attention_mask'][actual_length:] == 0), \
+        "Attention mask not set correctly for padding"
+    assert torch.all(padded_item['loss_mask'][actual_length:] == 0), \
+        "Loss mask not set correctly for padding"
+
+    print("All tests passed!")
diff --git a/verl/trainer/config/sft_trainer.yaml b/verl/trainer/config/sft_trainer.yaml
index a18082456ac..4c50dda693c 100644
--- a/verl/trainer/config/sft_trainer.yaml
+++ b/verl/trainer/config/sft_trainer.yaml
@@ -4,8 +4,13 @@ data:
   micro_batch_size_per_gpu: 4 # this is also val batch size
   train_files: ~/data/gsm8k/train.parquet
   val_files: ~/data/gsm8k/test.parquet
+  # Single-turn settings
   prompt_key: question
   response_key: answer
+  # Multi-turn settings
+  multiturn:
+    enable: false # Set to true to use multi-turn dataset
+    messages_key: messages # Key for messages list in multi-turn mode
   max_length: 1024
   truncation: error
   balance_dp_token: False
diff --git a/verl/trainer/fsdp_sft_trainer.py b/verl/trainer/fsdp_sft_trainer.py
index 54089202afd..2efdd69eae8 100644
--- a/verl/trainer/fsdp_sft_trainer.py
+++ b/verl/trainer/fsdp_sft_trainer.py
@@ -39,6 +39,7 @@
 from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager
 from verl.utils.dataset import SFTDataset
+from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset
 from verl.utils.fs import copy_to_local
 from verl.utils.tracking import Tracking
 from verl.utils.ulysses import get_ulysses_sequence_parallel_world_size, set_ulysses_sequence_parallel_group
@@ -122,10 +123,18 @@ def _build_dataloader(self):
         config = self.config
         # build dataset
         from verl.utils.import_utils import load_extern_type
+
+        # First check if a custom dataset class is specified
         if config.data.custom_cls.get("path", None):
             dataset_cls = load_extern_type(config.data.custom_cls.path, config.data.custom_cls.name)
+        # Then check if multi-turn dataset should be used
+        elif config.data.get('multiturn', {}).get('enable', False):
+            dataset_cls = MultiTurnSFTDataset
+        # Default to single-turn dataset
         else:
             dataset_cls = SFTDataset
+
+        # Create datasets based on the selected class
         self.train_dataset = dataset_cls(parquet_files=config.data.train_files,
                                          tokenizer=self.tokenizer,
                                          config=config.data)
diff --git a/verl/utils/dataset/multiturn_sft_dataset.py b/verl/utils/dataset/multiturn_sft_dataset.py
new file mode 100644
index 00000000000..8ff29cb0a5c
--- /dev/null
+++ b/verl/utils/dataset/multiturn_sft_dataset.py
@@ -0,0 +1,155 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Multi-turn SFT dataset that supports training on conversation data with multiple turns
+"""
+
+from typing import List, Union
+
+import pandas as pd
+import torch
+from torch.utils.data import Dataset
+from transformers import PreTrainedTokenizer
+
+from verl.utils.fs import copy_local_path_from_hdfs
+from verl.utils.model import compute_position_id_with_mask
+from verl.utils import hf_tokenizer
+
+
+class MultiTurnSFTDataset(Dataset):
+    """
+    Dataset for multi-turn conversations where each assistant response contributes to the training loss
+    """
+
+    def __init__(self, parquet_files: Union[str, List[str]], tokenizer, config=None):
+        # Set defaults and extract parameters from config if provided
+        config = config or {}
+        self.truncation = config.get('truncation', 'error')
+        self.max_length = config.get('max_length', 1024)
+        # Get messages_key from the new multiturn config structure
+        multiturn_config = config.get('multiturn', {})
+        self.messages_key = multiturn_config.get('messages_key', 'messages')
+
+        assert self.truncation in ['error', 'left', 'right']
+
+        if not isinstance(parquet_files, List):
+            parquet_files = [parquet_files]
+
+        self.parquet_files = parquet_files
+        if isinstance(tokenizer, str):
+            tokenizer = hf_tokenizer(tokenizer)
+        self.tokenizer: PreTrainedTokenizer = tokenizer
+
+        self._download()
+        self._read_files_and_process()
+
+    def _download(self):
+        for i, parquet_file in enumerate(self.parquet_files):
+            self.parquet_files[i] = copy_local_path_from_hdfs(parquet_file, verbose=True)
+
+    def _read_files_and_process(self):
+
+        def series_to_item(ls):
+            import pandas, numpy
+            while isinstance(ls, (pandas.core.series.Series, numpy.ndarray)) and len(ls) == 1:
+                ls = ls[0]
+            return ls
+
+        dataframes = []
+        for parquet_file in self.parquet_files:
+            dataframe = pd.read_parquet(parquet_file)
+            dataframes.append(dataframe)
+        self.dataframe = pd.concat(dataframes)
+
+        # Extract messages list from dataframe
+        self.messages = self.dataframe[self.messages_key].apply(series_to_item).tolist()
+
+    def __len__(self):
+        return len(self.messages)
+
+    def __getitem__(self, item):
+        tokenizer = self.tokenizer
+        messages = self.messages[item]
+
+        # First, get the full conversation tokens
+        full_tokens = tokenizer.apply_chat_template(messages,
+                                                    tokenize=True,
+                                                    return_tensors='pt',
+                                                    add_generation_prompt=False)
+        input_ids = full_tokens[0]  # The output is already a tensor
+        attention_mask = torch.ones_like(input_ids)
+
+        # Create loss mask by identifying assistant responses
+        loss_mask = torch.zeros_like(input_ids, dtype=torch.long)
+
+        # Process each message to find assistant responses.
+        # NOTE: this assumes the chat template renders messages[:i] as an exact token prefix
+        # of messages[:i + 1], so prefix lengths can be used as span boundaries.
+        for i, msg in enumerate(messages):
+            # Get tokens for messages up to this point to find the end position
+            prefix_messages = messages[:i + 1]
+            prefix_tokens = tokenizer.apply_chat_template(prefix_messages,
+                                                          tokenize=True,
+                                                          return_tensors='pt',
+                                                          add_generation_prompt=False)
+
+            # Get tokens for messages up to the previous point
+            prev_tokens = tokenizer.apply_chat_template(
+                messages[:i], tokenize=True, return_tensors='pt', add_generation_prompt=False) if i > 0 else None
+
+            # Calculate start and end positions
+            start_pos = prev_tokens[0].shape[0] if prev_tokens is not None else 0
+            end_pos = prefix_tokens[0].shape[0]
+
+            # If this is an assistant message, set loss mask
+            if msg['role'] == 'assistant':
+                loss_mask[start_pos:end_pos] = 1
+
+        # Handle sequence length
+        sequence_length = input_ids.shape[0]
+        if sequence_length < self.max_length:
+            # Pad sequences
+            pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0
+            padded_input_ids = torch.ones(size=(self.max_length - sequence_length,),
+                                          dtype=input_ids.dtype) * pad_token_id
+            padded_attention_mask = torch.zeros(size=(self.max_length - sequence_length,), dtype=attention_mask.dtype)
+            padded_loss_mask = torch.zeros(size=(self.max_length - sequence_length,), dtype=loss_mask.dtype)
+
+            input_ids = torch.cat((input_ids, padded_input_ids))
+            attention_mask = torch.cat((attention_mask, padded_attention_mask))
+            loss_mask = torch.cat((loss_mask, padded_loss_mask))
+        elif sequence_length > self.max_length:
+            if self.truncation == 'left':
+                input_ids = input_ids[-self.max_length:]
+                attention_mask = attention_mask[-self.max_length:]
+                loss_mask = loss_mask[-self.max_length:]
+            elif self.truncation == 'right':
+                input_ids = input_ids[:self.max_length]
+                attention_mask = attention_mask[:self.max_length]
+                loss_mask = loss_mask[:self.max_length]
+            elif self.truncation == 'error':
+                raise ValueError(f'{sequence_length=} is larger than {self.max_length=}')
+            else:
+                raise ValueError(f'Unknown truncation method {self.truncation}')
+
+        # Create position IDs
+        position_ids = torch.arange(len(input_ids), dtype=torch.long)
+        # Zero out position IDs for padding
+        position_ids = position_ids * attention_mask
+
+        return {
+            'input_ids': input_ids,
+            'attention_mask': attention_mask,
+            'position_ids': position_ids,
+            'loss_mask': loss_mask
+        }
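Usage sketch (not part of the diff): the snippet below shows roughly how the new dataset is exercised once `examples/data_preprocess/multiturn.py` has written the parquet file. The tokenizer checkpoint and path are illustrative assumptions that mirror the example scripts above; the final decode simply confirms that only assistant turns carry loss, as the unit test asserts.

```python
# Illustrative sketch only: the path and tokenizer checkpoint are assumptions.
import os

from transformers import AutoTokenizer

from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
config = {"max_length": 1024, "truncation": "error", "multiturn": {"messages_key": "messages"}}
dataset = MultiTurnSFTDataset(parquet_files=os.path.expanduser("~/data/multiturn/train.parquet"),
                              tokenizer=tokenizer,
                              config=config)

item = dataset[0]
# Only assistant turns should survive this mask, mirroring the assertions in the test above.
print(tokenizer.decode(item["input_ids"][item["loss_mask"] == 1]))
```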