diff --git a/verl/trainer/fsdp_sft_trainer.py b/verl/trainer/fsdp_sft_trainer.py index 7fe299dd4c4..77858d6f60a 100644 --- a/verl/trainer/fsdp_sft_trainer.py +++ b/verl/trainer/fsdp_sft_trainer.py @@ -89,6 +89,9 @@ def __init__(self, config, device_mesh: DeviceMesh, ulysses_device_mesh: DeviceM if self.config.data.chat_template is not None: raise ValueError('Apply Chat template from config is not supported yet.') + if self.tokenizer.chat_template is None: + logger.log(msg="Empty chat template!", level=logging.WARNING) + # normalize dp size self._normalize_config_bsz() diff --git a/verl/utils/dataset/sft_dataset.py b/verl/utils/dataset/sft_dataset.py index 1b8cec3057d..8117a0350c7 100644 --- a/verl/utils/dataset/sft_dataset.py +++ b/verl/utils/dataset/sft_dataset.py @@ -18,7 +18,7 @@ Each parquet file contains """ -from typing import List, Union +from typing import List, Union, Literal, Iterable import pandas as pd @@ -43,8 +43,8 @@ def __init__(self, prompt_dict_keys=None, response_key='response', response_dict_keys=None, - max_length=1024, - truncation='error'): + max_length: int = 1024, + truncation: Literal['error', 'left', 'right'] = 'error'): assert truncation in ['error', 'left', 'right'] self.truncation = truncation @@ -94,7 +94,7 @@ def series_to_item(ls): except Exception: print(f'self.prompts={self.prompts}') raise - self.prompts = self.prompts.tolist() + self.prompts = self.prompts.squeeze().tolist() self.responses = self.dataframe[self.response_key] for key in self.response_dict_keys: try: @@ -102,7 +102,7 @@ def series_to_item(ls): except Exception: print(f'self.responses={self.responses}') raise - self.responses = self.responses.tolist() + self.responses = self.responses.squeeze().tolist() def __len__(self): return len(self.prompts) @@ -113,8 +113,15 @@ def __getitem__(self, item): prompt = self.prompts[item] response = self.responses[item] + assert isinstance(response, str), f"invalid response type: {type(response)}" + # apply chat template - prompt_chat = [{'role': 'user', 'content': prompt}] + if isinstance(prompt, str): + prompt_chat = [{'role': 'user', 'content': prompt}] + elif isinstance(prompt, Iterable): + prompt_chat = prompt + else: + raise TypeError(f"invalid prompt type: {type(prompt)}") # string prompt_chat_str = tokenizer.apply_chat_template(prompt_chat, add_generation_prompt=True, tokenize=False) @@ -164,7 +171,7 @@ def __getitem__(self, item): # mask out prompt for SFT. loss_mask[:min(prompt_length, loss_mask.size(0)) - 1] = 0 # mask out the last token in response - loss_mask[min(prompt_length + response_length, loss_mask.size(0)) - 1] = 0 + # loss_mask[min(prompt_length + response_length, loss_mask.size(0)) - 1] = 0 return { 'input_ids': input_ids,