@@ -62,10 +62,8 @@ def __init__(self, config, device_mesh: DeviceMesh):
62
62
self .device_mesh = device_mesh
63
63
# build tokenizer first
64
64
local_model_path = copy_local_path_from_hdfs (src = self .config .model .partial_pretrain , verbose = True )
65
- self .tokenizer = AutoTokenizer .from_pretrained (local_model_path ,
66
- trust_remote_code = self .config .model .trust_remote_code )
67
- from verl .utils import set_pad_token_id
68
- set_pad_token_id (self .tokenizer )
65
+ from verl .utils import hf_tokenizer
66
+ self .tokenizer = hf_tokenizer (local_model_path , trust_remote_code = self .config .model .trust_remote_code )
69
67
if self .config .data .chat_template is not None :
70
68
raise ValueError ('Apply Chat template from config is not supported yet.' )
71
69
@@ -77,6 +75,8 @@ def __init__(self, config, device_mesh: DeviceMesh):
77
75
self ._build_model_optimizer ()
78
76
79
77
# TODO: add checkpoint manager
78
+ if self .device_mesh .get_rank () == 0 :
79
+ print (self .config )
80
80
81
81
def _normalize_config_bsz (self ):
82
82
dp_size = self .device_mesh .size ()
@@ -95,13 +95,17 @@ def _build_dataloader(self):
95
95
self .train_dataset = SFTDataset (parquet_files = config .data .train_files ,
96
96
tokenizer = self .tokenizer ,
97
97
prompt_key = config .data .prompt_key ,
98
+ prompt_dict_keys = config .data .get ('prompt_dict_keys' , None ),
98
99
response_key = config .data .response_key ,
100
+ response_dict_keys = config .data .get ('response_dict_keys' , None ),
99
101
max_length = config .data .max_length ,
100
102
truncation = config .data .truncation )
101
103
self .val_dataset = SFTDataset (parquet_files = config .data .val_files ,
102
104
tokenizer = self .tokenizer ,
103
105
prompt_key = config .data .prompt_key ,
106
+ prompt_dict_keys = config .data .get ('prompt_dict_keys' , None ),
104
107
response_key = config .data .response_key ,
108
+ response_dict_keys = config .data .get ('response_dict_keys' , None ),
105
109
max_length = config .data .max_length ,
106
110
truncation = config .data .truncation )
107
111
@@ -292,10 +296,11 @@ def save_checkpoint(self, step):
292
296
# save huggingface model
293
297
if self .device_mesh .get_rank () == 0 :
294
298
os .makedirs (path , exist_ok = True )
295
- hdfs_io .makedirs (self .config .trainer .default_hdfs_dir )
296
299
self .model .save_pretrained (path , state_dict = state_dict )
297
300
self .tokenizer .save_pretrained (path )
298
- hdfs_io .copy (src = path , dst = self .config .trainer .default_hdfs_dir )
301
+ if self .config .trainer .default_hdfs_dir :
302
+ hdfs_io .makedirs (self .config .trainer .default_hdfs_dir , exist_ok = True )
303
+ hdfs_io .copy (src = path , dst = self .config .trainer .default_hdfs_dir , dirs_exist_ok = True )
299
304
torch .distributed .barrier ()
300
305
301
306
def fit (self ):
@@ -349,7 +354,6 @@ def main(config):
349
354
local_rank , rank , world_size = initialize_global_process_group ()
350
355
351
356
device_mesh = init_device_mesh (device_type = 'cuda' , mesh_shape = (world_size ,), mesh_dim_names = ('dp' ,))
352
-
353
357
trainer = FSDPSFTTrainer (config = config , device_mesh = device_mesh )
354
358
trainer .fit ()
355
359
0 commit comments