
Commit

…nto fix_sp_tokens
DesmonDay committed Feb 11, 2025
2 parents 2f9774d + 86286e0 commit 37040fe
Showing 19 changed files with 308 additions and 43 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -95,6 +95,7 @@
| [ChatGLM3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/chatglm2) | THUDM/chatglm3-6b |
| [DeepSeekV2](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-V2, deepseek-ai/DeepSeek-V2-Chat, deepseek-ai/DeepSeek-V2-Lite, deepseek-ai/DeepSeek-V2-Lite-Chat, deepseek-ai/DeepSeek-Coder-V2-Base, deepseek-ai/DeepSeek-Coder-V2-Instruct, deepseek-ai/DeepSeek-Coder-V2-Lite-Base, deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct |
| [DeepSeekV3](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-V3, deepseek-ai/DeepSeek-V3-Base |
| [DeepSeek-R1](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-R1, deepseek-ai/DeepSeek-R1-Zero, deepseek-ai/DeepSeek-R1-Distill-Llama-70B, deepseek-ai/DeepSeek-R1-Distill-Llama-8B, deepseek-ai/DeepSeek-R1-Distill-Qwen-14B, deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B, deepseek-ai/DeepSeek-R1-Distill-Qwen-32B, deepseek-ai/DeepSeek-R1-Distill-Qwen-7B |
| [Gemma](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/gemma) | google/gemma-7b, google/gemma-7b-it, google/gemma-2b, google/gemma-2b-it |
| [Mistral](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/mistral) | mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-7B-v0.1 |
| [Mixtral](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/mixtral) | mistralai/Mixtral-8x7B-Instruct-v0.1 |
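For reference, a minimal sketch of loading one of the newly listed DeepSeek-R1 distill checkpoints with the PaddleNLP auto classes; the checkpoint name comes from the table above, while the dtype choice is an illustrative assumption and not part of this commit.

```python
# Minimal sketch (not part of this commit): load a DeepSeek-R1 distill checkpoint
# listed above with the PaddleNLP auto classes. The dtype choice is illustrative.
from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer

name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, dtype="bfloat16")
```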
9 changes: 5 additions & 4 deletions llm/alignment/ppo/ppo_trainer.py
@@ -66,6 +66,7 @@
speed_metrics,
)
from paddlenlp.transformers import PretrainedModel, PretrainedTokenizer
from paddlenlp.utils import empty_device_cache


class StepTrainer(Trainer):
@@ -1032,7 +1033,7 @@ def gen_epoch_data():
ptx_batches = [None for _ in range(len(rl_batches))]
self.timers and self.timers("ptx-batch").stop()

paddle.device.cuda.empty_cache()
empty_device_cache()

self.set_train()
for _ in range(self.args.update_iters):
@@ -1152,7 +1153,7 @@ def train(

# ##### model and optimizer related setting #####
policy_model, value_model = self.init_train_model_opt(max_steps, resume_from_checkpoint)
paddle.device.cuda.empty_cache()
empty_device_cache()

# ##### training statistic logging #####
# Number of trainable parameters only accounts for policy_model
@@ -1208,7 +1209,7 @@ def train(
# with self.enable(self.value_trainer.optimizer):
with self.enable(): # put value optimizer guard in rl_step
rl_info = self.rl_step(rl_batch)
paddle.device.cuda.empty_cache()
empty_device_cache()
self.timers and self.timers("rl_step").stop()

if self.use_ptx:
@@ -1224,7 +1225,7 @@
ptx_info = self.ptx_step(ptx_batch)
rl_info.update(ptx_info)
self.timers and self.timers("ptx_step").stop()
paddle.device.cuda.empty_cache()
empty_device_cache()

self.state.global_step += 1
self.state.epoch = epoch + (step + 1) / steps_in_epoch
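The recurring edit in this and several later files swaps direct `paddle.device.cuda.empty_cache()` calls for a backend-agnostic `empty_device_cache()` helper imported from `paddlenlp.utils`. A rough sketch of what such a helper might look like; the actual implementation in `paddlenlp/utils/memory_utils.py` may differ.

```python
# Hypothetical sketch only; the real paddlenlp.utils.memory_utils.empty_device_cache
# may cover more device types or emit logging.
import paddle


def empty_device_cache():
    """Release cached allocator memory on the current device, if any."""
    if paddle.is_compiled_with_cuda():
        paddle.device.cuda.empty_cache()
    # Non-CUDA backends (CPU, XPU, NPU, ...) are assumed to be a no-op here.
```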
2 changes: 1 addition & 1 deletion llm/docs/predict/mixtral.md
@@ -12,7 +12,7 @@

|Model|
|:-|
|mistralai/Mixtral-8x7B-v0.1-Instruct|
|mistralai/Mixtral-8x7B-Instruct-v0.1|


## Model Inference
4 changes: 4 additions & 0 deletions llm/predict/predictor.py
@@ -221,11 +221,15 @@ def _preprocess(self, source):
source = [source] if isinstance(source, str) else source
source = [self.tokenizer.apply_chat_template(sentence, tokenize=False) for sentence in source]

return_attention_mask = False
if len(source) > 1:
return_attention_mask = True
tokenized_source = self.tokenizer(
source,
max_length=self.config.src_length,
truncation=True,
return_position_ids=True if not isinstance(self.tokenizer, ChatGLMTokenizer) else False,
return_attention_mask=return_attention_mask,
truncation_side="left",
return_tensors=self.return_tensors,
padding=True,
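The predictor change above only asks the tokenizer for an attention mask when more than one prompt is batched, since a single left-truncated sequence needs no padding mask. A standalone sketch of the same pattern; the checkpoint name and max length are placeholders, not taken from this commit.

```python
# Illustrative only: request an attention mask just for multi-prompt batches.
from paddlenlp.transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/llama-7b")  # placeholder checkpoint
prompts = ["Hello!", "Summarize this sentence."]

inputs = tokenizer(
    prompts,
    max_length=1024,                         # placeholder for config.src_length
    truncation=True,
    truncation_side="left",
    padding=True,
    return_attention_mask=len(prompts) > 1,  # skip the mask for a single prompt
    return_tensors="pd",
)
```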
3 changes: 2 additions & 1 deletion paddlenlp/quantization/quantization_utils.py
@@ -23,6 +23,7 @@
from paddle.nn.quant import weight_quantize

from ..utils.log import logger
from ..utils.memory_utils import empty_device_cache
from .quantization_linear import (
ColumnParallelQuantizationLinear,
QuantizationLinear,
@@ -150,7 +151,7 @@ def convert_to_quantize_state_dict_without_check(state_dict, quantization_linear
state_dict.update(qlora_state_dict)
del target_weight
gc.collect()
paddle.device.cuda.empty_cache()
empty_device_cache()
return state_dict


39 changes: 38 additions & 1 deletion paddlenlp/trainer/auto_trainer.py
@@ -713,8 +713,13 @@ def _save_checkpoint(self, model, metrics=None):
for key, value in model.state_dict("opt").items()
if not any(keyword in key for keyword in FREE_SVAE_LOAD_KEY_PATTERNS)
}
model_state_dict = model.state_dict("param")
if self.args.should_save_model_with_tensor_fusion:
model_state_dict = self._convert_state_dict_for_saving_tensor_fusion_ckpt(model_state_dict)
opt_state_dict = self._convert_state_dict_for_saving_tensor_fusion_ckpt(opt_state_dict)

state_dict = {
MODEL_NAME: model.state_dict("param"),
MODEL_NAME: model_state_dict,
OPTIMIZER_NAME: opt_state_dict,
}
else:
@@ -854,6 +859,9 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
for key, value in self.model_wrapped.state_dict("opt").items()
if not any(keyword in key for keyword in FREE_SVAE_LOAD_KEY_PATTERNS)
}
if self.args.should_load_model_with_tensor_fusion:
model_state_dict = self._convert_state_dict_for_loading_tensor_fusion_ckpt(model_state_dict)
optim_state_dict = self._convert_state_dict_for_loading_tensor_fusion_ckpt(optim_state_dict)
else:
model_state_dict = self.model_wrapped.state_dict()
optim_state_dict = self.optimizer.state_dict()
@@ -888,7 +896,36 @@
self._load_ckpt_func(state_dict, ckpt_path)

if self.args.to_static:
if self.args.should_load_model_with_tensor_fusion:
model_state_dict = self._convert_state_dict_for_loading_model_with_tensor_fusion(model_state_dict)
optim_state_dict = self._convert_state_dict_for_loading_model_with_tensor_fusion(optim_state_dict)

self.model_wrapped.set_state_dict(model_state_dict)
self.model_wrapped.set_state_dict(optim_state_dict)
# release memory
del state_dict

def _convert_state_dict_for_loading_tensor_fusion_ckpt(self, state_dict):
if self.args.load_model_with_sharding_tensor_fusion:
logger.info("load sharding tensor fusion unbalanced model")
state_dict = self.model_wrapped._convert_state_dict_with_rank_unique_name(state_dict)
else:
logger.info("load sharding tensor fusion balanced model")
state_dict = self.model_wrapped._convert_state_dict_without_tensor_fusion_param(state_dict)
return state_dict

def _convert_state_dict_for_loading_model_with_tensor_fusion(self, state_dict):
if self.args.load_model_with_sharding_tensor_fusion:
state_dict = self.model_wrapped._convert_state_dict_with_origin_name(state_dict)
else:
state_dict = self.model_wrapped._convert_state_dict_with_tensor_fusion_param(state_dict)
return state_dict

def _convert_state_dict_for_saving_tensor_fusion_ckpt(self, state_dict):
if self.args.save_model_with_sharding_tensor_fusion:
logger.info("save sharding tensor fusion unbalanced model")
state_dict = self.model_wrapped._convert_state_dict_with_rank_unique_name(state_dict)
else:
logger.info("save sharding tensor fusion balanced model")
state_dict = self.model_wrapped._convert_state_dict_without_tensor_fusion_param(state_dict)
return state_dict
35 changes: 34 additions & 1 deletion paddlenlp/trainer/auto_training_args.py
@@ -14,7 +14,7 @@
import json
from dataclasses import dataclass, field

from .trainer_utils import split_parallel_config
from .trainer_utils import ShardingOption, split_parallel_config
from .training_args import TrainingArguments
from .utils import add_start_docstrings

@@ -52,6 +52,29 @@ class AutoTrainingArguments(TrainingArguments):
metadata={"help": "Whether to use auto_parallel intermediate api"},
)
refined_ops_patterns: str = field(default=None, metadata={"help": "The pattern of refined recompute."})
load_model_with_sharding_tensor_fusion: bool = field(
default=False,
metadata={
"help": (
"When using sharding stage1, enabling tensor fusion, and setting `load_model_with_sharding_tensor_fusion` to `True`, "
"the model is loaded with unbalanced weights, meaning that the model weights are stored in an unbalanced format to avoid "
"additional memory overhead. If set to `False`, the model will be loaded with balanced weights, which may increase memory "
"consumption. This setting is only available in auto parallel to_static mode."
)
},
)
save_model_with_sharding_tensor_fusion: bool = field(
default=False,
metadata={
"help": (
"When using sharding stage1 and enabling tensor fusion, setting `save_model_with_sharding_tensor_fusion` to `True` "
"saves the model with unbalanced weights, which helps avoid additional memory consumption. Setting it to `False` "
"saves the model with balanced weights, which may increase memory usage but ensures uniform parameter distribution. "
"This option allows flexibility in choosing the save format based on memory requirements. "
"This setting is only available in auto parallel to_static mode."
)
},
)

def __post_init__(self):
super().__post_init__()
@@ -89,3 +112,13 @@ def __post_init__(self):
recompute.refined_ops_patterns = (
self.refined_ops_patterns if self.refined_ops_patterns is not None else []
)

@property
def should_load_model_with_tensor_fusion(self):
return (
self.enable_auto_parallel
and self.to_static
and ShardingOption.SHARD_OP in self.sharding
and self.sharding_parallel_degree > 1
and "enable_tensor_fusion" in self.sharding_parallel_config
)
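Taken together with the `should_load_model_with_tensor_fusion` property above, the new flags only take effect under sharding stage1 with tensor fusion enabled in auto parallel to_static mode. A hedged example of arguments that would satisfy those conditions; the concrete values are assumptions for illustration, not taken from this commit.

```python
# Hypothetical AutoTrainingArguments combining the new flags with the sharding
# settings they depend on; the values are illustrative, not from this commit.
from paddlenlp.trainer.auto_training_args import AutoTrainingArguments

args = AutoTrainingArguments(
    output_dir="./checkpoints",
    enable_auto_parallel=True,
    to_static=True,
    sharding="stage1",
    sharding_parallel_degree=8,
    sharding_parallel_config="enable_tensor_fusion enable_overlap",
    save_model_with_sharding_tensor_fusion=True,  # save unbalanced weights (lower memory)
    load_model_with_sharding_tensor_fusion=True,  # load them back without rebalancing
)
```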
18 changes: 12 additions & 6 deletions paddlenlp/trainer/training_args.py
Expand Up @@ -619,6 +619,7 @@ class TrainingArguments:
)
},
)

tensor_parallel_degree: int = field(
default=-1,
metadata={
@@ -740,7 +741,6 @@ class TrainingArguments:
"enable_stage2_overlap, overlap stage2 NCCL communication with computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for broadcast overlap and no other sync could be called during the training for broadcast overlap\n"
"enable_stage1_broadcast_overlap, overlap stage1 V1 broadcast with next step forward computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for broadcast overlap forward compute and no other sync could be called during the training for broadcast overlap.\n"
"enable_stage1_allgather_overlap, overlap stage1 V2 allgather with next step forward computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for allgather overlap forward compute and no other sync could be called during the training for allgather overlap.\n"
"enable_tensor_fusion_blanced_save_load, convert unbalanced optimizer state to balanced state when using tensor fusion strategy, which may increase the memory occupation."
)
},
)
@@ -1671,7 +1671,6 @@ def is_segment_parallel_supported():
"enable_tensor_fusion",
"enable_overlap",
"enable_release_grads",
"enable_tensor_fusion_blanced_save_load",
]:
if x in ["enable_stage1_overlap", "enable_stage2_overlap"]:
raise ValueError(
@@ -1686,7 +1685,7 @@
raise ValueError(
f"Found unknown sharding mode config {x}, "
f"accpet config is enable_tensor_fusion, "
"enable_overlap, enable_release_grads, enable_tensor_fusion_blanced_save_load."
"enable_overlap, enable_release_grads."
)

if "enable_overlap" in sharding_parallel_config:
Expand All @@ -1696,9 +1695,6 @@ def is_segment_parallel_supported():
sharding.grad_bucket_size_numel = 210355872
sharding.enable_tensor_fusion = True

if "enable_tensor_fusion_blanced_save_load" in sharding_parallel_config:
sharding.save_unbalanced_param = False

if "enable_release_grads" in sharding_parallel_config:
sharding.release_gradients = True

@@ -2273,3 +2269,13 @@ def print_config(self, args=None, key=""):
logger.debug("{:30}: {}".format(a, v))

logger.debug("")

@property
def should_save_model_with_tensor_fusion(self):
return (
self.enable_auto_parallel
and self.to_static
and ShardingOption.SHARD_OP in self.sharding
and self.sharding_parallel_degree > 1
and "enable_tensor_fusion" in self.sharding_parallel_config
)
20 changes: 10 additions & 10 deletions paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
@@ -29,7 +29,7 @@
unwrap_model,
)
from paddlenlp.transformers.utils import dtype_byte_size
from paddlenlp.utils import infohub
from paddlenlp.utils import empty_device_cache, infohub
from paddlenlp.utils.env import (
LORA_WEIGHTS_NAME,
MAX_QUANTIZATION_TIMES,
@@ -158,7 +158,7 @@ def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None)
if self.args.should_save:
save_model_config(model_to_save, save_directory)

paddle.device.cuda.empty_cache()
empty_device_cache()

if strtobool(os.getenv("FLAG_LLM_PDC", "False")) and self.args.should_save:
world_size = paddle.distributed.get_world_size()
@@ -195,7 +195,7 @@ def load_unified_checkpoint(self, model, resume_from_checkpoint: str):
load_unified_checkpoint_locally(self.args, model, resume_from_checkpoint, safe_serialization=True)

def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, output_dir, signal_dir):
paddle.device.cuda.empty_cache()
empty_device_cache()

# gather global master_weights status.
global_master_weights = reduce_master_weights_status(master_weights is not None)
@@ -375,7 +375,7 @@ def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir):
optim_state_dict, shard_optim_file, sharded_optim_index = results[0]
master_weight_state_dict, shard_master_weight_file, sharded_master_weight_index = results[1]

paddle.device.cuda.empty_cache()
empty_device_cache()
save_directory = output_dir
os.makedirs(save_directory, exist_ok=True)
if signal_dir is not None:
@@ -508,7 +508,7 @@ def unified_checkpoint_into_shards(
Returns:
tuple: state_dict, config, shard_file: file name, sharded_index: map for weight to file name.
"""
paddle.device.cuda.empty_cache()
empty_device_cache()
assert hasattr(model_to_save, "config")

state_dict = get_expected_state_dict(model_to_save, concat_additional_adapter=True)
@@ -560,7 +560,7 @@ def unified_checkpoint_into_shards(
elif isinstance(model_to_save, PrefixModelForCausalLM):
sharded_index["type"] = "ptuning"

paddle.device.cuda.empty_cache()
empty_device_cache()

return state_dict, shard_file, sharded_index

@@ -578,7 +578,7 @@ def unified_optimizer_into_shards(
optimizer (Optimizer): optimizer to save.
safe_serialization (bool, optional): safe serialization using safetensors. Defaults to False.
"""
paddle.device.cuda.empty_cache()
empty_device_cache()

# gather global master_weights status.
global_master_weights = reduce_master_weights_status(master_weights is not None)
@@ -645,7 +645,7 @@ def unified_optimizer_into_shards(
filter_optim_keys,
state_dict if args.use_expert_parallel else None,
)
paddle.device.cuda.empty_cache()
empty_device_cache()

if master_weights is not None:
logger.info("Unified master weight tensor parallel in shards")
@@ -655,7 +655,7 @@
filter_master_keys,
state_dict if args.use_expert_parallel else None,
)
paddle.device.cuda.empty_cache()
empty_device_cache()

# build index json file
index_optimizer_file, index_master_weight_file = {}, {}
@@ -706,7 +706,7 @@ def unified_optimizer_into_shards(
else:
sharded_optim_index["master_weights"] = False

paddle.device.cuda.empty_cache()
empty_device_cache()
if master_weights is None:
return [(optim_state_dict, shard_optimizer_file, sharded_optim_index)]
else: