
Commit

…nto add_split_param
DesmonDay committed Oct 16, 2024
2 parents 223e089 + 697a4cc commit ae9ddce
Showing 8 changed files with 114 additions and 41 deletions.
8 changes: 4 additions & 4 deletions llm/README.md
@@ -115,15 +115,15 @@ PaddleNLP supports fine-tuning strategies such as SFT, LoRA, and Prefix Tuning for many mainstream large models
Sample data:

```text
{"src": "类型#裙*颜色#蓝色*风格#清新*图案#蝴蝶结", "tgt": "裙身处采用立体蝴蝶结装饰辅以蓝色条带点缀,令衣身造型饱满富有层次的同时为其注入一丝甜美气息。将女孩清新娇俏的一面衬托而出。"}
{"src": "Give three tips for staying healthy.", "tgt": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."}
...
```

For convenience of testing, we also provide an advertisement-generation dataset that can be used directly:
For convenience of testing, we also provide the [tatsu-lab/alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca) demo dataset, which can be used directly:

```shell
wget https://bj.bcebos.com/paddlenlp/datasets/examples/AdvertiseGen.tar.gz
tar -zxvf AdvertiseGen.tar.gz
wget https://bj.bcebos.com/paddlenlp/datasets/examples/alpaca_demo.gz
tar -xvf alpaca_demo.gz
```

#### 2.2 Full-parameter fine-tuning: SFT
2 changes: 1 addition & 1 deletion llm/config/llama/lora_argument.json
@@ -6,7 +6,7 @@
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"num_train_epochs": 3,
"num_train_epochs": 1,
"learning_rate": 3e-04,
"warmup_steps": 30,
"logging_steps": 1,
2 changes: 1 addition & 1 deletion llm/config/llama/sft_argument.json
@@ -6,7 +6,7 @@
"gradient_accumulation_steps": 2,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"num_train_epochs": 3,
"num_train_epochs": 1,
"learning_rate": 3e-05,
"warmup_steps": 30,
"logging_steps": 1,
54 changes: 45 additions & 9 deletions paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -136,21 +136,25 @@ def __init__(self, args):
self._process_master_weight = None
self._process_optimizer_weight = None
self._lock = None
self._shared_save_path = None
self._shared_save_model_flag = None
self._shared_save_master_weight_flag = None
self._shared_save_optimizer_flag = None

if "async_save" in self.args.unified_checkpoint_config:
self._lock = multiprocessing.Lock()
self._shared_save_model_path = multiprocessing.Array("c", 100000)
self._shared_save_model_signal_path = multiprocessing.Array("c", 100000)
self._shared_save_master_weight_path = multiprocessing.Array("c", 100000)
self._shared_save_master_weight_signal_path = multiprocessing.Array("c", 100000)
self._shared_save_optimizer_path = multiprocessing.Array("c", 100000)
self._shared_save_optimizer_signal_path = multiprocessing.Array("c", 100000)
self._shared_save_model_flag = multiprocessing.Array("i", 1)
self._shared_save_master_weight_flag = multiprocessing.Array("i", 1)
self._shared_save_optimizer_flag = multiprocessing.Array("i", 1)

def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_type="model_weight"):
def _file_save_async_or_sync(
self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight"
):
if is_sync:
for k in list(state_dict.keys()):
if isinstance(state_dict[k], paddle.Tensor):
@@ -165,6 +169,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
meta_dict = self._meta_dict_model
shared_save_flag = self._shared_save_model_flag
shared_save_path = self._shared_save_model_path
shared_save_signal_path = self._shared_save_model_signal_path
if self._process_model_weight is None:
self._process_model_weight = multiprocessing.Process(
target=self._save_file_async_in_process,
@@ -173,12 +178,14 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
self._shm_model_weight.name,
self._shared_save_model_flag,
self._shared_save_model_path,
self._shared_save_model_signal_path,
self._lock,
state_dict_type,
self.global_rank,
),
)
self._process_model_weight.start()
process = self._process_model_weight
elif state_dict_type == "master_weight":
if self._shm_master_weight is None:
self._meta_dict_master_weight, buffer_size = create_meta_dict(state_dict)
@@ -187,6 +194,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
meta_dict = self._meta_dict_master_weight
shared_save_flag = self._shared_save_master_weight_flag
shared_save_path = self._shared_save_master_weight_path
shared_save_signal_path = self._shared_save_master_weight_signal_path
if self._process_master_weight is None:
self._process_master_weight = multiprocessing.Process(
target=self._save_file_async_in_process,
@@ -195,6 +203,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
self._shm_master_weight.name,
self._shared_save_master_weight_flag,
self._shared_save_master_weight_path,
self._shared_save_master_weight_signal_path,
self._lock,
"model_weight"
if "skip_save_model_weight" in self.args.unified_checkpoint_config
@@ -203,6 +212,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
),
)
self._process_master_weight.start()
process = self._process_master_weight
elif state_dict_type == "optimizer_weight":
if self._shm_optimizer_weight is None:
self._meta_dict_optim, buffer_size = create_meta_dict(state_dict)
@@ -211,6 +221,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
meta_dict = self._meta_dict_optim
shared_save_flag = self._shared_save_optimizer_flag
shared_save_path = self._shared_save_optimizer_path
shared_save_signal_path = self._shared_save_optimizer_signal_path
if self._process_optimizer_weight is None:
self._process_optimizer_weight = multiprocessing.Process(
target=self._save_file_async_in_process,
@@ -219,21 +230,26 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
self._shm_optimizer_weight.name,
self._shared_save_optimizer_flag,
self._shared_save_optimizer_path,
self._shared_save_optimizer_signal_path,
self._lock,
state_dict_type,
self.global_rank,
),
)
self._process_optimizer_weight.start()
process = self._process_optimizer_weight

while True: # wait until no process is saving.
flag_value = shared_save_flag[0]
if flag_value == 0:
break
if not process.is_alive():
raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.")
time.sleep(0.5)
logger.info(f"Wait for the previous save process to finish saving {state_dict_type}")
# only save model weight or save master weight, we enter this loop.
self._reset_and_update(shared_save_path, path)
self._reset_and_update(shared_save_signal_path, signal_path)
_traverse_copy_to_shm(state_dict, meta_dict, shm_state_dict.buf)
with self._lock:
shared_save_flag[0] = 1
@@ -244,6 +260,7 @@ def _save_file_async_in_process(
shm_name,
shared_save_flag,
shared_save_path,
shared_save_signal_path,
lock,
state_dict_type,
global_rank,
@@ -257,11 +274,12 @@ def _save_file_async_in_process(
continue
if flag_value == 1: # need to save
path = shared_save_path[:].decode("utf-8").rstrip("\x00")
signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00")
logger.info(f"Start to async save {path}")
state_dict = _read_state_dict_from_shm(meta_dict, shm) # numpy array
safe_save_file(state_dict, path, {"format": "np"})
del state_dict
saved_signal_path = os.path.join(os.path.dirname(path), f".{state_dict_type}.done.{global_rank}")
saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}")
paddle.save(global_rank, saved_signal_path)
with lock:
shared_save_flag[0] = 0
@@ -276,7 +294,7 @@ def _reset_and_update(self, shared_array, new_value):
encoded_value = new_value.encode("utf-8")
shared_array[: len(encoded_value)] = encoded_value

def save_unified_checkpoint(self, model, optimizer, output_dir):
def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None):
"""save unified checkpoint
Args:
@@ -313,6 +331,8 @@ def save_unified_checkpoint(self, model, optimizer, output_dir):

save_directory = output_dir
os.makedirs(save_directory, exist_ok=True)
if signal_dir is not None:
os.makedirs(signal_dir, exist_ok=True) # only for async save

# save model weights
if not skip_save_model_weight:
@@ -325,6 +345,7 @@ def save_unified_checkpoint(self, model, optimizer, output_dir):
self._file_save_async_or_sync(
state_dict,
path=os.path.join(save_directory, shard_file),
signal_path=signal_dir,
is_sync=is_sync_save,
state_dict_type="model_weight",
)
@@ -394,7 +415,7 @@ def load_unified_checkpoint(self, model, optimizer, resume_from_checkpoint: str)
if self.args.dataset_rank == 0 or self.args.use_expert_parallel:
load_unified_checkpoint_locally(self.args, model, resume_from_checkpoint, safe_serialization=True)

def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, output_dir):
def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, output_dir, signal_dir):
paddle.device.cuda.empty_cache()

# gather global master_weights status.
@@ -446,12 +467,14 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp
self._file_save_async_or_sync(
optim_state_dict,
path=os.path.join(output_dir, optimizer_name),
signal_path=signal_dir,
is_sync=is_sync_save,
state_dict_type="optimizer_weight",
)
self._file_save_async_or_sync(
master_weights,
path=os.path.join(output_dir, master_weights_name),
signal_path=signal_dir,
is_sync=is_sync_save,
state_dict_type="master_weight",
)
@@ -501,17 +524,19 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint):

return returned_optim_state_dict

def save_unified_optimizer(self, model, optimizer, output_dir):
def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir):
"""save unified optimizer
Args:
model (PretrainedModel): model used to get key mapping.
optimizer (Optimizer): optimizer to save
output_dir (str): Save directory.
signal_dir (str): Asynchronous saving signal directory.
"""

if paddle.distributed.get_world_size() <= 1:
save_single_card_optimizer(model, optimizer, output_dir)
save_single_card_optimizer(model, optimizer, output_dir) # no need to save signal

return

if (

@@ -530,7 +555,7 @@ def save_unified_optimizer(self, model, optimizer, output_dir):
optim_state_dict.pop("LR_Scheduler")

if "ignore_merge_optimizer" in self.args.unified_checkpoint_config:
self.save_non_merge_optimizer(model, optim_state_dict, master_weights, output_dir)
self.save_non_merge_optimizer(model, optim_state_dict, master_weights, output_dir, signal_dir)

return

# Split into naive optimizer params and master weights.
@@ -547,20 +572,24 @@ def save_unified_optimizer(self, model, optimizer, output_dir):
paddle.device.cuda.empty_cache()
save_directory = output_dir
os.makedirs(save_directory, exist_ok=True)
if signal_dir is not None:
os.makedirs(signal_dir, exist_ok=True)

is_sync_save = True
if "async_save" in self.args.unified_checkpoint_config:
is_sync_save = False
self._file_save_async_or_sync(
optim_state_dict,
path=os.path.join(save_directory, shard_optim_file),
signal_path=signal_dir,
is_sync=is_sync_save,
state_dict_type="optimizer_weight",
)
if master_weight_state_dict is not None:
self._file_save_async_or_sync(
master_weight_state_dict,
path=os.path.join(save_directory, shard_master_weight_file),
signal_path=signal_dir,
is_sync=is_sync_save,
state_dict_type="master_weight",
)
@@ -634,14 +663,20 @@ def unlink_shared_memory(self):

if self._shared_save_model_flag is not None:
while self._shared_save_model_flag[0] > 0: # async process is saving
if not self._process_model_weight.is_alive():
raise RuntimeError("The process that saves model_weight has been killed unexpectedly.")
time.sleep(0.5)
self._shared_save_model_flag[0] = -1
if self._shared_save_master_weight_flag is not None:
while self._shared_save_master_weight_flag[0] > 0:
if not self._process_master_weight.is_alive():
raise RuntimeError("The process that saves master_weight has been killed unexpectedly.")
time.sleep(0.5)
self._shared_save_master_weight_flag[0] = -1
if self._shared_save_optimizer_flag is not None:
while self._shared_save_optimizer_flag[0] > 0:
if not self._process_optimizer_weight.is_alive():
raise RuntimeError("The process that saves optimizer_weight has been killed unexpectedly.")
time.sleep(0.5)
self._shared_save_optimizer_flag[0] = -1

@@ -658,7 +693,8 @@ def unlink_shared_memory(self):
self._shm_optimizer_weight.unlink()
self._shm_optimizer_weight = None

dist.barrier()
if paddle.distributed.get_world_size() > 1:
dist.barrier()


def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, safe_serialization=False):
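Taken together, these hunks thread a new `signal_path`/`signal_dir` through the asynchronous save path: weights are still written to `path`, while the per-rank `.{state_dict_type}.done.{global_rank}` markers now go into the separate signal directory, and the wait loops raise a RuntimeError if the background save process has died. Below is a minimal, self-contained sketch of that handshake, assuming the shared-flag protocol visible in the diff (0 = idle, 1 = save requested, -1 = shut down); the names `toy_async_save` and `worker`, the 4096-byte path buffers, and the plain-file "payload" write are illustrative stand-ins, not PaddleNLP's actual API.

```python
# Simplified sketch of the async-save handshake (not PaddleNLP API):
# the trainer shares a save path, a signal directory, and an int flag with a
# background process; the background process writes the payload and then drops
# a ".<state_dict_type>.done.<rank>" marker into the signal directory.
import multiprocessing
import os
import time


def worker(shared_flag, shared_path, shared_signal_path, lock, state_dict_type, rank):
    # Background process: wait for the flag, read the two shared paths,
    # write the payload, drop a done-marker, then lower the flag.
    while True:
        time.sleep(0.1)
        if shared_flag[0] == -1:  # shutdown requested by the parent
            break
        if shared_flag[0] == 1:  # a save was requested
            path = shared_path[:].decode("utf-8").rstrip("\x00")
            signal_path = shared_signal_path[:].decode("utf-8").rstrip("\x00")
            with open(path, "w") as f:  # stand-in for safe_save_file(state_dict, path, ...)
                f.write("payload")
            done = os.path.join(signal_path, f".{state_dict_type}.done.{rank}")
            open(done, "w").close()  # stand-in for paddle.save(global_rank, saved_signal_path)
            with lock:
                shared_flag[0] = 0


def toy_async_save(output_dir, signal_dir, rank=0):
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(signal_dir, exist_ok=True)  # signal dir is created separately from the weights dir
    lock = multiprocessing.Lock()
    flag = multiprocessing.Array("i", 1)            # 0: idle, 1: save requested, -1: shut down
    path_arr = multiprocessing.Array("c", 4096)     # where the weights go
    signal_arr = multiprocessing.Array("c", 4096)   # where the ".done" marker goes
    proc = multiprocessing.Process(
        target=worker, args=(flag, path_arr, signal_arr, lock, "model_weight", rank)
    )
    proc.start()

    # Publish the destination and the signal directory, then raise the flag.
    target = os.path.join(output_dir, "model.safetensors").encode("utf-8")
    path_arr[: len(target)] = target
    signal = signal_dir.encode("utf-8")
    signal_arr[: len(signal)] = signal
    with lock:
        flag[0] = 1

    # Wait for completion, mirroring the "while flag[0] > 0" loops in the diff.
    while flag[0] > 0:
        if not proc.is_alive():
            raise RuntimeError("async save process died unexpectedly")
        time.sleep(0.1)

    flag[0] = -1  # ask the worker to exit, as unlink_shared_memory does
    proc.join()


if __name__ == "__main__":
    toy_async_save("ckpt_out", "ckpt_signals")
```

The effect of the separate `signal_dir` is visible in `save_unified_checkpoint` and `save_unified_optimizer` above: the weights directory and the completion-marker directory are created independently, while the shared flag remains what `unlink_shared_memory` polls before tearing down the shared memory.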
