Skip to content

Commit d60f843

Browse files
[sft] feat: fix sft dataset with latest preprocess code (Jiayi-Pan#49)
* api: rename tracking logger to wandb logger type * [sft] feat: add tests for sft dataset * refresh dataset * force refresh * use ds model for tokenizer * add option for trainer.val_only * fix path * fix lint * add sft test for cot and raw q&a * add hf_tokenizer api to patch gemma tokenizer * fix test
1 parent c7534db commit d60f843

File tree

18 files changed

+192
-101
lines changed

18 files changed

+192
-101
lines changed

examples/data_preprocess/gsm8k.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -52,17 +52,17 @@ def extract_solution(solution_str):
5252
def make_map_fn(split):
5353

5454
def process_fn(example, idx):
55-
question = example.pop('question')
55+
question_raw = example.pop('question')
5656

57-
question = question + ' ' + instruction_following
57+
question = question_raw + ' ' + instruction_following
5858

59-
answer = example.pop('answer')
60-
solution = extract_solution(answer)
59+
answer_raw = example.pop('answer')
60+
solution = extract_solution(answer_raw)
6161
data = {
6262
"data_source": data_source,
6363
"prompt": [{
6464
"role": "user",
65-
"content": question
65+
"content": question,
6666
}],
6767
"ability": "math",
6868
"reward_model": {
@@ -71,7 +71,9 @@ def process_fn(example, idx):
7171
},
7272
"extra_info": {
7373
'split': split,
74-
'index': idx
74+
'index': idx,
75+
'answer': answer_raw,
76+
"question": question_raw,
7577
}
7678
}
7779
return data

examples/sft/gsm8k/run_gemma_2b.sh

+11-8
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Tested in 4 GPUs
1+
# Tested with 2 & 4 GPUs
22

33
set -x
44

@@ -8,7 +8,7 @@ if [ "$#" -lt 2 ]; then
88
fi
99

1010
nproc_per_node=$1
11-
hdfs_path=$2
11+
save_path=$2
1212

1313
# Shift the arguments so $@ refers to the rest
1414
shift 2
@@ -17,12 +17,15 @@ torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \
1717
-m verl.trainer.fsdp_sft_trainer \
1818
data.train_files=$HOME/data/gsm8k/train.parquet \
1919
data.val_files=$HOME/data/gsm8k/test.parquet \
20-
data.prompt_key=prompt \
21-
data.response_key=answer \
22-
data.micro_batch_size=32 \
20+
data.prompt_key=extra_info \
21+
data.response_key=extra_info \
22+
+data.prompt_dict_keys=['question'] \
23+
+data.response_dict_keys=['answer'] \
24+
data.micro_batch_size=8 \
2325
model.partial_pretrain=google/gemma-2b-it \
24-
trainer.default_hdfs_dir=$hdfs_path \
26+
trainer.default_local_dir=$save_path \
2527
trainer.project_name=gsm8k-sft \
2628
trainer.experiment_name=gsm8k-sft-gemma-2b-it \
27-
trainer.total_epochs=3 \
28-
trainer.logger=['console','wandb'] $@
29+
trainer.total_epochs=2 \
30+
trainer.logger=['console','wandb'] \
31+
trainer.default_hdfs_dir=null $@

examples/split_placement/main_ppo_split.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,8 @@ def main_task(config):
113113
local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)
114114

115115
# instantiate tokenizer
116-
tokenizer = AutoTokenizer.from_pretrained(local_path)
117-
from verl.utils import set_pad_token_id
118-
set_pad_token_id(tokenizer)
116+
from verl.utils import hf_tokenizer
117+
tokenizer = hf_tokenizer(local_path)
119118

120119
# define worker classes
121120
if config.actor_rollout_ref.actor.strategy == 'fsdp':

tests/verl/utils/dataset/test_rl_dataset.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,13 @@ def get_gsm8k_data():
2323
local_folder = os.path.expanduser('~/verl-data/gsm8k/')
2424
local_path = os.path.join(local_folder, 'train.parquet')
2525
os.makedirs(local_folder, exist_ok=True)
26-
# import fsspec
27-
# with fsspec.open(url, mode='rb') as fin, fsspec.open(local_path, mode='wb') as fout:
28-
# content = fin.read()
29-
# fout.write(content)
3026
return local_path
3127

3228

3329
def test_rl_dataset():
3430
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
35-
tokenizer = AutoTokenizer.from_pretrained('deepseek-ai/deepseek-coder-1.3b-instruct')
36-
from verl.utils import set_pad_token_id
37-
set_pad_token_id(tokenizer)
31+
from verl.utils import hf_tokenizer
32+
tokenizer = hf_tokenizer('deepseek-ai/deepseek-coder-1.3b-instruct')
3833
local_path = get_gsm8k_data()
3934
dataset = RLHFDataset(parquet_files=local_path, tokenizer=tokenizer, prompt_key='prompt', max_prompt_length=256)
4035

tests/verl/utils/dataset/test_rm_dataset.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import os
1515

1616
from transformers import AutoTokenizer
17-
from verl.utils import set_pad_token_id
17+
from verl.utils import hf_tokenizer
1818
from verl.utils.dataset.rm_dataset import RMDataset
1919

2020

@@ -24,16 +24,11 @@ def get_rm_data():
2424
local_folder = os.path.expanduser('~/verl-data/full_hh_rlhf/rm/')
2525
local_path = os.path.join(local_folder, 'test.parquet')
2626
os.makedirs(local_folder, exist_ok=True)
27-
# import fsspec
28-
# with fsspec.open(url, mode='rb') as fin, fsspec.open(local_path, mode='wb') as fout:
29-
# content = fin.read()
30-
# fout.write(content)
3127
return local_path
3228

3329

3430
def test_rm_dataset():
35-
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
36-
set_pad_token_id(tokenizer)
31+
tokenizer = hf_tokenizer("facebook/opt-1.3b")
3732
local_path = get_rm_data()
3833
dataset = RMDataset(parquet_files=local_path, tokenizer=tokenizer, max_length=512)
3934
data = dataset[0]['input_ids']
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
16+
from transformers import AutoTokenizer
17+
from verl.utils import hf_tokenizer
18+
from verl.utils.dataset.sft_dataset import SFTDataset
19+
20+
21+
def get_gsm8k_data():
22+
# prepare test dataset
23+
url = "https://github.com/eric-haibin-lin/verl-data/raw/refs/heads/main/gsm8k/train.parquet"
24+
local_folder = os.path.expanduser('~/verl-data/gsm8k/')
25+
local_path = os.path.join(local_folder, 'train.parquet')
26+
return local_path
27+
28+
29+
def test_sft_cot_dataset():
30+
tokenizer = hf_tokenizer('deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct')
31+
local_path = get_gsm8k_data()
32+
dataset = SFTDataset(parquet_files=local_path,
33+
tokenizer=tokenizer,
34+
prompt_key='prompt',
35+
prompt_dict_keys=['content'],
36+
response_key='extra_info',
37+
response_dict_keys=['answer'],
38+
max_length=512)
39+
40+
data = dataset[0]['input_ids']
41+
output = tokenizer.batch_decode([data])[0]
42+
assert len(output) > 1
43+
assert type(output) == str
44+
45+
46+
def test_sft_dataset():
47+
tokenizer = hf_tokenizer('deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct')
48+
local_path = get_gsm8k_data()
49+
dataset = SFTDataset(parquet_files=local_path,
50+
tokenizer=tokenizer,
51+
prompt_key='extra_info',
52+
prompt_dict_keys=['question'],
53+
response_key='extra_info',
54+
response_dict_keys=['answer'],
55+
max_length=512)
56+
57+
data = dataset[0]['input_ids']
58+
output = tokenizer.batch_decode([data])[0]
59+
assert len(output) > 1
60+
assert type(output) == str

verl/trainer/fsdp_sft_trainer.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,8 @@ def __init__(self, config, device_mesh: DeviceMesh):
6262
self.device_mesh = device_mesh
6363
# build tokenizer first
6464
local_model_path = copy_local_path_from_hdfs(src=self.config.model.partial_pretrain, verbose=True)
65-
self.tokenizer = AutoTokenizer.from_pretrained(local_model_path,
66-
trust_remote_code=self.config.model.trust_remote_code)
67-
from verl.utils import set_pad_token_id
68-
set_pad_token_id(self.tokenizer)
65+
from verl.utils import hf_tokenizer
66+
self.tokenizer = hf_tokenizer(local_model_path, trust_remote_code=self.config.model.trust_remote_code)
6967
if self.config.data.chat_template is not None:
7068
raise ValueError('Apply Chat template from config is not supported yet.')
7169

@@ -77,6 +75,8 @@ def __init__(self, config, device_mesh: DeviceMesh):
7775
self._build_model_optimizer()
7876

7977
# TODO: add checkpoint manager
78+
if self.device_mesh.get_rank() == 0:
79+
print(self.config)
8080

8181
def _normalize_config_bsz(self):
8282
dp_size = self.device_mesh.size()
@@ -95,13 +95,17 @@ def _build_dataloader(self):
9595
self.train_dataset = SFTDataset(parquet_files=config.data.train_files,
9696
tokenizer=self.tokenizer,
9797
prompt_key=config.data.prompt_key,
98+
prompt_dict_keys=config.data.get('prompt_dict_keys', None),
9899
response_key=config.data.response_key,
100+
response_dict_keys=config.data.get('response_dict_keys', None),
99101
max_length=config.data.max_length,
100102
truncation=config.data.truncation)
101103
self.val_dataset = SFTDataset(parquet_files=config.data.val_files,
102104
tokenizer=self.tokenizer,
103105
prompt_key=config.data.prompt_key,
106+
prompt_dict_keys=config.data.get('prompt_dict_keys', None),
104107
response_key=config.data.response_key,
108+
response_dict_keys=config.data.get('response_dict_keys', None),
105109
max_length=config.data.max_length,
106110
truncation=config.data.truncation)
107111

@@ -292,10 +296,11 @@ def save_checkpoint(self, step):
292296
# save huggingface model
293297
if self.device_mesh.get_rank() == 0:
294298
os.makedirs(path, exist_ok=True)
295-
hdfs_io.makedirs(self.config.trainer.default_hdfs_dir)
296299
self.model.save_pretrained(path, state_dict=state_dict)
297300
self.tokenizer.save_pretrained(path)
298-
hdfs_io.copy(src=path, dst=self.config.trainer.default_hdfs_dir)
301+
if self.config.trainer.default_hdfs_dir:
302+
hdfs_io.makedirs(self.config.trainer.default_hdfs_dir, exist_ok=True)
303+
hdfs_io.copy(src=path, dst=self.config.trainer.default_hdfs_dir, dirs_exist_ok=True)
299304
torch.distributed.barrier()
300305

301306
def fit(self):
@@ -349,7 +354,6 @@ def main(config):
349354
local_rank, rank, world_size = initialize_global_process_group()
350355

351356
device_mesh = init_device_mesh(device_type='cuda', mesh_shape=(world_size,), mesh_dim_names=('dp',))
352-
353357
trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh)
354358
trainer.fit()
355359

verl/trainer/main_generation.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,8 @@ def main(config):
4343
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
4444
OmegaConf.resolve(config)
4545
local_path = copy_local_path_from_hdfs(config.model.path)
46-
tokenizer = AutoTokenizer.from_pretrained(local_path)
47-
from verl.utils import set_pad_token_id
48-
set_pad_token_id(tokenizer)
46+
from verl.utils import hf_tokenizer
47+
tokenizer = hf_tokenizer(local_path)
4948

5049
if config.rollout.temperature == 0.:
5150
assert config.data.n_samples == 1, 'When temperature=0, n_samples must be 1.'

verl/trainer/main_ppo.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ def _select_rm_score_fn(data_source):
3131

3232

3333
class RewardManager():
34+
"""The reward manager.
35+
"""
3436

3537
def __init__(self, tokenizer, num_examine) -> None:
3638
self.tokenizer = tokenizer
@@ -112,9 +114,8 @@ def main_task(config):
112114
local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)
113115

114116
# instantiate tokenizer
115-
tokenizer = AutoTokenizer.from_pretrained(local_path)
116-
from verl.utils import set_pad_token_id
117-
set_pad_token_id(tokenizer)
117+
from verl.utils import hf_tokenizer
118+
tokenizer = hf_tokenizer(local_path)
118119

119120
# define worker classes
120121
if config.actor_rollout_ref.actor.strategy == 'fsdp':

verl/trainer/ppo/ray_trainer.py

+4
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,9 @@ def fit(self):
420420
if self.val_reward_fn is not None:
421421
val_metrics = self._validate()
422422
pprint(f'Initial validation metrics: {val_metrics}')
423+
logger.log(data=val_metrics, step=global_steps)
424+
if self.config.trainer.get('val_only', False):
425+
return
423426

424427
for epoch in range(self.config.trainer.total_epochs):
425428
for batch_dict in self.train_dataloader:
@@ -527,3 +530,4 @@ def fit(self):
527530
if self.val_reward_fn is not None:
528531
val_metrics = self._validate()
529532
pprint(f'Final validation metrics: {val_metrics}')
533+
logger.log(data=val_metrics, step=global_steps)

verl/trainer/ppo/workers/fsdp_workers.py

+6-14
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from verl.utils.import_utils import import_external_libs
3636
from verl.utils.debug import log_gpu_memory_usage
3737
import verl.utils.hdfs_io as hdfs_io
38-
from verl.utils import set_pad_token_id
38+
from verl.utils import hf_tokenizer
3939

4040
logger = logging.getLogger(__file__)
4141
logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN'))
@@ -107,8 +107,7 @@ def _build_model_optimizer(self,
107107

108108
# note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect
109109
# TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. Support init with FSDP directly
110-
self.tokenizer = AutoTokenizer.from_pretrained(local_path, trust_remote_code=trust_remote_code)
111-
set_pad_token_id(self.tokenizer)
110+
self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
112111

113112
torch_dtype = fsdp_config.get('model_dtype', None)
114113
if torch_dtype is None:
@@ -467,9 +466,7 @@ def _build_critic_model_optimizer(self, config):
467466
from transformers import AutoTokenizer
468467

469468
tokenizer_path = copy_local_path_from_hdfs(config.model.tokenizer_path)
470-
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path,
471-
trust_remote_code=config.model.get('trust_remote_code', False))
472-
set_pad_token_id(self.tokenizer)
469+
self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get('trust_remote_code', False))
473470

474471
from omegaconf import OmegaConf
475472
override_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create()))
@@ -673,14 +670,9 @@ def _build_model(self, config):
673670
else:
674671
self._do_switch_chat_template = True
675672
input_tokenizer_local_path = copy_local_path_from_hdfs(config.model.input_tokenizer)
676-
self.input_tokenizer = AutoTokenizer.from_pretrained(input_tokenizer_local_path,
677-
trust_remote_code=config.model.get(
678-
'trust_remote_code', False))
679-
self.tokenizer = AutoTokenizer.from_pretrained(local_path,
680-
trust_remote_code=config.model.get(
681-
'trust_remote_code', False))
682-
set_pad_token_id(self.tokenizer)
683-
set_pad_token_id(self.input_tokenizer)
673+
self.input_tokenizer = hf_tokenizer(input_tokenizer_local_path,
674+
trust_remote_code=config.model.get('trust_remote_code', False))
675+
self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get('trust_remote_code', False))
684676

685677
trust_remote_code = config.model.get('trust_remote_code', False)
686678
model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)

0 commit comments

Comments
 (0)