Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions verl/trainer/fsdp_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ def __init__(self, config, device_mesh: DeviceMesh, ulysses_device_mesh: DeviceM
if self.config.data.chat_template is not None:
raise ValueError('Apply Chat template from config is not supported yet.')

if self.tokenizer.chat_template is None:
logger.log(msg="Empty chat template!", level=logging.WARNING)

# normalize dp size
self._normalize_config_bsz()

Expand Down
21 changes: 14 additions & 7 deletions verl/utils/dataset/sft_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
Each parquet file contains
"""

from typing import List, Union
from typing import List, Union, Literal, Iterable

import pandas as pd

Expand All @@ -43,8 +43,8 @@ def __init__(self,
prompt_dict_keys=None,
response_key='response',
response_dict_keys=None,
max_length=1024,
truncation='error'):
max_length: int = 1024,
truncation: Literal['error', 'left', 'right'] = 'error'):
assert truncation in ['error', 'left', 'right']
self.truncation = truncation

Expand Down Expand Up @@ -94,15 +94,15 @@ def series_to_item(ls):
except Exception:
print(f'self.prompts={self.prompts}')
raise
self.prompts = self.prompts.tolist()
self.prompts = self.prompts.squeeze().tolist()
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With pandas==2.2.2, this line raises an error because a DataFrame does not have a `.tolist()` method, so I added a `.squeeze()` call before `.tolist()`.

self.responses = self.dataframe[self.response_key]
for key in self.response_dict_keys:
try:
self.responses = self.responses.apply(lambda x: series_to_item(x)[key], axis=1)
except Exception:
print(f'self.responses={self.responses}')
raise
self.responses = self.responses.tolist()
self.responses = self.responses.squeeze().tolist()

def __len__(self):
return len(self.prompts)
Expand All @@ -113,8 +113,15 @@ def __getitem__(self, item):
prompt = self.prompts[item]
response = self.responses[item]

assert isinstance(response, str), f"invalid response type: {type(response)}"

# apply chat template
prompt_chat = [{'role': 'user', 'content': prompt}]
if isinstance(prompt, str):
prompt_chat = [{'role': 'user', 'content': prompt}]
elif isinstance(prompt, Iterable):
prompt_chat = prompt
else:
raise TypeError(f"invalid prompt type: {type(prompt)}")

# string
prompt_chat_str = tokenizer.apply_chat_template(prompt_chat, add_generation_prompt=True, tokenize=False)
Expand Down Expand Up @@ -164,7 +171,7 @@ def __getitem__(self, item):
# mask out prompt for SFT.
loss_mask[:min(prompt_length, loss_mask.size(0)) - 1] = 0
# mask out the last token in response
loss_mask[min(prompt_length + response_length, loss_mask.size(0)) - 1] = 0
# loss_mask[min(prompt_length + response_length, loss_mask.size(0)) - 1] = 0
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why the eos_token needs to be masked, so I commented out the line. I'll revert this if the masking is needed.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vermouth1992 could you help take a look?


return {
'input_ids': input_ids,
Expand Down