-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[New Features]LoRA-GA supports pipeline parallel #9819
Open
greycooker
wants to merge
2
commits into
PaddlePaddle:develop
Choose a base branch
from
greycooker:loraga_support_pp
base: develop
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,8 @@ | |
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Any, Dict, Union | ||
|
||
import paddle | ||
import paddle.distributed as dist | ||
from paddle.distributed import fleet | ||
|
@@ -23,6 +25,9 @@ | |
except: | ||
pass | ||
|
||
import paddle.nn as nn | ||
from paddle import framework | ||
|
||
from paddlenlp.peft import LoRAModel | ||
from paddlenlp.peft.lora.lora_layers import ( | ||
ColumnParallelLoRALinear, | ||
|
@@ -53,6 +58,7 @@ | |
logger.info(f"Initialization iterations for LoraGA: {loraga_init_iters}") | ||
self.loraga_init_iters = loraga_init_iters | ||
self.gradient_offload = gradient_offload | ||
self.args.gradient_accumulation_steps = 1 | ||
|
||
def estimate_gradient(self, model: PretrainedModel): | ||
""" | ||
|
@@ -80,7 +86,6 @@ | |
): | ||
for batch in dataloader: | ||
iters += 1 | ||
# Pipeline parallel not supported currently | ||
self.training_step(model, batch) | ||
|
||
if iters == self.loraga_init_iters: | ||
|
@@ -101,9 +106,6 @@ | |
in_sep_parallel_mode = self.args.sep_parallel_degree > 1 | ||
in_cp_parallel_mode = self.args.context_parallel_degree > 1 | ||
|
||
if in_pipeline_parallel_mode: | ||
raise ValueError("LoRA-GA do not supported pipeline parallel currently.") | ||
|
||
# Multi-gpu training | ||
if self.args.world_size > 1 and (not self.args.use_hybrid_parallel): | ||
# MOE use DDP to broadcaset parameters. | ||
|
@@ -120,8 +122,18 @@ | |
ddp_kwargs["find_unused_parameters"] = True | ||
model = paddle.DataParallel(model, **ddp_kwargs) | ||
|
||
# sharding | ||
if in_sharding_parallel_mode: | ||
# Pipeline mode | ||
if in_pipeline_parallel_mode: | ||
prepare_pipeline_inputs_func = ( | ||
model._prepare_pipeline_inputs_func if hasattr(model, "_prepare_pipeline_inputs_func") else None | ||
) | ||
if isinstance(model, LoRAModel): | ||
model = model.model | ||
model = fleet.distributed_model(model) | ||
model._prepare_pipeline_inputs_func = prepare_pipeline_inputs_func | ||
|
||
# No pipeline mode, sharding only | ||
if not in_pipeline_parallel_mode and in_sharding_parallel_mode: | ||
# Sharded DDP! | ||
if self.args.tensor_parallel_degree > 1: | ||
hcg = fleet.get_hybrid_communicate_group() | ||
|
@@ -132,11 +144,80 @@ | |
if ShardingOption.SHARD_OP in self.args.sharding: | ||
model = fleet.distributed_model(model) | ||
|
||
if not in_sharding_parallel_mode and (in_tensor_parallel_mode or in_sep_parallel_mode or in_cp_parallel_mode): | ||
# pure tesnor parallel mode, no pipeline_parallel, no sharding. | ||
if ( | ||
not in_pipeline_parallel_mode | ||
and not in_sharding_parallel_mode | ||
and (in_tensor_parallel_mode or in_sep_parallel_mode or in_cp_parallel_mode) | ||
): | ||
model = fleet.distributed_model(model) | ||
|
||
return model | ||
|
||
def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor: | ||
""" | ||
Perform a training step on a batch of inputs. | ||
|
||
Subclass and override to inject custom behavior. | ||
|
||
Args: | ||
model (`nn.Layer`): | ||
The model to train. | ||
inputs (`Dict[str, Union[paddle.Tensor, Any]]`): | ||
The inputs and targets of the model. | ||
|
||
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the | ||
argument `labels`. Check your model's documentation for all accepted arguments. | ||
|
||
Return: | ||
`paddle.Tensor`: The tensor with training loss on this batch. | ||
""" | ||
# accumulation data | ||
if not hasattr(self, "_pp_data_buffer"): | ||
self._pp_data_buffer = [] | ||
self._pp_data_buffer.append(inputs) | ||
if len(self._pp_data_buffer) != self.args.gradient_accumulation_steps: | ||
return paddle.zeros([]) | ||
|
||
# for v in self._pp_data_buffer[0].values(): | ||
# assert isinstance(v, paddle.Tensor), f"Only support tensor as pipeline mode input, got type {type(v)}" | ||
|
||
inputs = model._prepare_pipeline_inputs_func(self._pp_data_buffer) | ||
self._pp_data_buffer = [] | ||
|
||
model.train() | ||
# hack pipeline-layers | ||
# since the pipeline layer will check input is valid every iter. | ||
# in same case, for example, batch size warmup, we need dynamic change gradient_accumulation_steps to implement. | ||
config_backup = model.micro_batch_size, model.accumulate_steps | ||
model.micro_batch_size = self.args.per_device_train_batch_size | ||
model.accumulate_steps = self.args.gradient_accumulation_steps | ||
|
||
if model._dp_comm_overlap or model._sharding_comm_overlap: | ||
for _, buffers in model._chunk_2_comm_buffers.items(): | ||
for buffer in buffers: | ||
buffer._acc_steps = self.args.gradient_accumulation_steps | ||
|
||
# reset the virtual pp rank for each run | ||
model.set_virtual_pipeline_rank(0) | ||
assert framework._dygraph_tracer()._has_grad, "Please enable the generation of gradients." | ||
|
||
if model.is_pipeline_first_stage(ignore_virtual=True) or model.is_pipeline_last_stage(ignore_virtual=True): | ||
assert inputs is not None, "For the first and the last stage, the data must be set." | ||
else: | ||
inputs = None | ||
model._layers.train() | ||
|
||
model.optimizer = None # we do not use `PipelineParallel` to handler optimizer step | ||
model.lr_scheduler = None | ||
|
||
with self.autocast_smart_context_manager(): | ||
loss = model.forward_backward_pipeline(inputs, self.scaler if self.do_grad_scaling else None) | ||
|
||
model.micro_batch_size, model.accumulate_steps = config_backup | ||
|
||
return loss.detach() | ||
|
||
|
||
def get_module_gradient( | ||
grad_name, | ||
|
@@ -147,6 +228,7 @@ | |
sharding_degree, | ||
dp_degree, | ||
local_rank, | ||
pp_to_single_mapping, | ||
): | ||
""" | ||
Gather modules gradient in tensor parallel mode. | ||
|
@@ -180,7 +262,12 @@ | |
|
||
if tp_degree > 1: | ||
# remove prefix and suffix in name | ||
model_split_key = local_grad_name.split(base_model_prefix)[-1].rsplit(rank_suffix, 1)[0] | ||
if pp_to_single_mapping is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. pp + dp/sharding的情况下,不需要pp_to_single_mapping吗? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. pp_to_single_mapping是为了找到和lora_split_mapping和base_model_split_mapping对应的名称,不开tp的时候不需要split。 |
||
model_split_key = pp_to_single_mapping[".".join(grad_name.split(".")[1:]) + ".weight"] | ||
model_split_key = model_split_key.split(base_model_prefix)[-1] | ||
else: | ||
model_split_key = local_grad_name.split(base_model_prefix)[-1].rsplit(rank_suffix, 1)[0] | ||
|
||
if model_split_key in base_model_split_mappings: | ||
merge_func = base_model_split_mappings[model_split_key] | ||
output_tensors = [] | ||
|
@@ -197,12 +284,12 @@ | |
|
||
# dp | ||
if dp_degree > 1: | ||
if data_parallel_group.nranks > 1: | ||
if is_fleet_init: | ||
dist.all_reduce(gradient, op=dist.ReduceOp.SUM, group=data_parallel_group) | ||
else: | ||
dist.all_reduce(gradient, op=dist.ReduceOp.SUM) | ||
if is_fleet_init: | ||
dist.all_reduce(gradient, op=dist.ReduceOp.SUM, group=data_parallel_group) | ||
gradient /= data_parallel_group.nranks | ||
else: | ||
dist.all_reduce(gradient, op=dist.ReduceOp.SUM) | ||
gradient /= dist.get_world_size() | ||
return gradient | ||
|
||
|
||
|
@@ -221,17 +308,20 @@ | |
Returns: | ||
None: Updates the model's weights and LoRA adapter weights in place. | ||
""" | ||
tensor_parallel_degree = training_args.tensor_parallel_degree | ||
in_tensor_parallel_mode = tensor_parallel_degree > 1 | ||
lora_split_mapping = None | ||
base_model_split_mappings = None | ||
if in_tensor_parallel_mode: | ||
base_model_split_mappings = model.model._get_tensor_parallel_mappings(config=model.config, is_split=False) | ||
in_tensor_parallel_mode = training_args.tensor_parallel_degree > 1 | ||
in_pipleline_parallel_mode = training_args.pipeline_parallel_degree > 1 | ||
|
||
lora_split_mapping, base_model_split_mappings, pp_to_single_mapping = None, None, None | ||
|
||
base_model_prefix = unwrap_model(model).base_model_prefix + "." | ||
if in_tensor_parallel_mode: | ||
base_model_split_mappings = model.model._get_tensor_parallel_mappings(config=model.config, is_split=False) | ||
lora_split_mapping = model._get_tensor_parallel_mappings(model.config) | ||
|
||
if in_pipleline_parallel_mode: | ||
pp_to_single_mapping = model._pp_to_single_mapping | ||
|
||
loraga_init_dict = {} | ||
base_model_prefix = unwrap_model(model).base_model_prefix + "." | ||
for name, module in model.named_sublayers(): | ||
if isinstance( | ||
module, | ||
|
@@ -253,6 +343,7 @@ | |
training_args.sharding_parallel_degree, | ||
training_args.data_parallel_degree, | ||
training_args.local_rank, | ||
pp_to_single_mapping, | ||
) | ||
# perform SVD to reinit base model weight and lora adapter weight | ||
loraga_svd_module( | ||
|
@@ -262,6 +353,7 @@ | |
stable_gamma, | ||
loraga_init_dict, | ||
in_tensor_parallel_mode, | ||
pp_to_single_mapping, | ||
lora_split_mapping, | ||
**kwargs, | ||
) | ||
|
@@ -276,14 +368,19 @@ | |
stable_gamma, | ||
loraga_init_dict, | ||
in_tensor_parallel_mode=False, | ||
pp_to_single_mapping=None, | ||
lora_split_mapping=None, | ||
**kwargs | ||
): | ||
with paddle.no_grad(): | ||
lora_r = module.r | ||
|
||
loraA_name = ".".join(name.split(".")[1:]) + ".lora_A" | ||
loraB_name = ".".join(name.split(".")[1:]) + ".lora_B" | ||
if pp_to_single_mapping is not None: | ||
name = pp_to_single_mapping[".".join(name.split(".")[1:]) + ".weight"] | ||
loraA_name = name.replace("weight", "lora_A") | ||
loraB_name = name.replace("weight", "lora_B") | ||
else: | ||
loraA_name = ".".join(name.split(".")[1:]) + ".lora_A" | ||
loraB_name = ".".join(name.split(".")[1:]) + ".lora_B" | ||
# Perform SVD to gradients | ||
U, S, V = paddle.linalg.svd_lowrank(grads.astype("float32"), q=4 * lora_r, niter=4) | ||
|
||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这块重写的目的是?会影响到LoRA吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LoRAGATrainer不需要优化器,所以我去掉了Trainer里的training_pipeline_step优化器相关的部分。LoRA不受影响,只有LoRA-GA在初始化的时候会走到LoRAGATrainer。