
[New Features] LoRA-GA supports pipeline parallel #9819

Open · wants to merge 2 commits into base: develop
143 changes: 120 additions & 23 deletions paddlenlp/peft/lora/loraga_utils.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Union

import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
@@ -23,6 +25,9 @@
except:
pass

import paddle.nn as nn
from paddle import framework

from paddlenlp.peft import LoRAModel
from paddlenlp.peft.lora.lora_layers import (
ColumnParallelLoRALinear,
@@ -53,6 +58,7 @@
logger.info(f"Initialization iterations for LoraGA: {loraga_init_iters}")
self.loraga_init_iters = loraga_init_iters
self.gradient_offload = gradient_offload
self.args.gradient_accumulation_steps = 1

def estimate_gradient(self, model: PretrainedModel):
"""
@@ -80,7 +86,6 @@
):
for batch in dataloader:
iters += 1
# Pipeline parallel not supported currently
self.training_step(model, batch)

if iters == self.loraga_init_iters:
@@ -101,9 +106,6 @@
in_sep_parallel_mode = self.args.sep_parallel_degree > 1
in_cp_parallel_mode = self.args.context_parallel_degree > 1

if in_pipeline_parallel_mode:
raise ValueError("LoRA-GA do not supported pipeline parallel currently.")

# Multi-gpu training
if self.args.world_size > 1 and (not self.args.use_hybrid_parallel):
# MOE use DDP to broadcaset parameters.
@@ -120,8 +122,18 @@
ddp_kwargs["find_unused_parameters"] = True
model = paddle.DataParallel(model, **ddp_kwargs)

# sharding
if in_sharding_parallel_mode:
# Pipeline mode
if in_pipeline_parallel_mode:
prepare_pipeline_inputs_func = (
model._prepare_pipeline_inputs_func if hasattr(model, "_prepare_pipeline_inputs_func") else None
)
if isinstance(model, LoRAModel):
model = model.model
model = fleet.distributed_model(model)
model._prepare_pipeline_inputs_func = prepare_pipeline_inputs_func

# No pipeline mode, sharding only
if not in_pipeline_parallel_mode and in_sharding_parallel_mode:
# Sharded DDP!
if self.args.tensor_parallel_degree > 1:
hcg = fleet.get_hybrid_communicate_group()
@@ -132,11 +144,80 @@
if ShardingOption.SHARD_OP in self.args.sharding:
model = fleet.distributed_model(model)

if not in_sharding_parallel_mode and (in_tensor_parallel_mode or in_sep_parallel_mode or in_cp_parallel_mode):
# pure tensor parallel mode: no pipeline parallel, no sharding.
if (
not in_pipeline_parallel_mode
and not in_sharding_parallel_mode
and (in_tensor_parallel_mode or in_sep_parallel_mode or in_cp_parallel_mode)
):
model = fleet.distributed_model(model)

return model

def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:

Review comment (Member):
What is the purpose of this rewrite? Will it affect LoRA?

Reply (greycooker, Contributor Author, Jan 24, 2025):
LoRAGATrainer does not need an optimizer, so I removed the optimizer-related parts of Trainer's training_pipeline_step. LoRA itself is not affected; only LoRA-GA goes through LoRAGATrainer during initialization. (A minimal sketch of this idea follows the function below.)

"""
Perform a training step on a batch of inputs.

Subclass and override to inject custom behavior.

Args:
model (`nn.Layer`):
The model to train.
inputs (`Dict[str, Union[paddle.Tensor, Any]]`):
The inputs and targets of the model.

The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument `labels`. Check your model's documentation for all accepted arguments.

Return:
`paddle.Tensor`: The tensor with training loss on this batch.
"""
# accumulation data
if not hasattr(self, "_pp_data_buffer"):
self._pp_data_buffer = []
self._pp_data_buffer.append(inputs)
if len(self._pp_data_buffer) != self.args.gradient_accumulation_steps:
return paddle.zeros([])

# for v in self._pp_data_buffer[0].values():
# assert isinstance(v, paddle.Tensor), f"Only support tensor as pipeline mode input, got type {type(v)}"

inputs = model._prepare_pipeline_inputs_func(self._pp_data_buffer)
self._pp_data_buffer = []

model.train()

# hack pipeline-layers
# since the pipeline layer checks that the input is valid on every iteration.
# in some cases, for example batch size warmup, we need to change gradient_accumulation_steps dynamically.
config_backup = model.micro_batch_size, model.accumulate_steps
model.micro_batch_size = self.args.per_device_train_batch_size
model.accumulate_steps = self.args.gradient_accumulation_steps

if model._dp_comm_overlap or model._sharding_comm_overlap:
for _, buffers in model._chunk_2_comm_buffers.items():
for buffer in buffers:
buffer._acc_steps = self.args.gradient_accumulation_steps

# reset the virtual pp rank for each run
model.set_virtual_pipeline_rank(0)
assert framework._dygraph_tracer()._has_grad, "Please enable the generation of gradients."

if model.is_pipeline_first_stage(ignore_virtual=True) or model.is_pipeline_last_stage(ignore_virtual=True):
assert inputs is not None, "For the first and the last stage, the data must be set."
else:
inputs = None
model._layers.train()

model.optimizer = None  # we do not use `PipelineParallel` to handle the optimizer step
model.lr_scheduler = None

with self.autocast_smart_context_manager():
loss = model.forward_backward_pipeline(inputs, self.scaler if self.do_grad_scaling else None)

model.micro_batch_size, model.accumulate_steps = config_backup

return loss.detach()
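As the reply above explains, LoRAGATrainer only needs the gradients collected during the LoRA-GA initialization pass and never performs an optimizer update. A minimal, self-contained sketch of that idea (the model and names below are illustrative placeholders, not the PR's API):

import paddle

# Stand-in for the (possibly pipeline-parallel) base model used during LoRA-GA init.
model = paddle.nn.Linear(8, 8)
x = paddle.randn([4, 8])

loss = model(x).mean()
loss.backward()                                # forward + backward only
estimated_grad = model.weight.grad.clone()     # this gradient is what later feeds the SVD
# No optimizer.step() and no lr_scheduler: weights are never updated here, which is
# why the optimizer-related parts of Trainer.training_pipeline_step could be dropped.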


def get_module_gradient(
grad_name,
@@ -147,6 +228,7 @@
sharding_degree,
dp_degree,
local_rank,
pp_to_single_mapping,
):
"""
Gather modules gradient in tensor parallel mode.
@@ -180,7 +262,12 @@

if tp_degree > 1:
# remove prefix and suffix in name
model_split_key = local_grad_name.split(base_model_prefix)[-1].rsplit(rank_suffix, 1)[0]
if pp_to_single_mapping is not None:
Review comment (Member):
In the pp + dp/sharding case, isn't pp_to_single_mapping needed?

Reply (greycooker, Contributor Author, Jan 24, 2025):
pp_to_single_mapping is only there to find the names that correspond to lora_split_mapping and base_model_split_mapping; when tensor parallel is not enabled, no split is needed. (An illustrative sketch of this lookup follows the function below.)

model_split_key = pp_to_single_mapping[".".join(grad_name.split(".")[1:]) + ".weight"]
model_split_key = model_split_key.split(base_model_prefix)[-1]
else:
model_split_key = local_grad_name.split(base_model_prefix)[-1].rsplit(rank_suffix, 1)[0]

if model_split_key in base_model_split_mappings:
merge_func = base_model_split_mappings[model_split_key]
output_tensors = []
@@ -197,12 +284,12 @@

# dp
if dp_degree > 1:
if data_parallel_group.nranks > 1:
if is_fleet_init:
dist.all_reduce(gradient, op=dist.ReduceOp.SUM, group=data_parallel_group)
else:
dist.all_reduce(gradient, op=dist.ReduceOp.SUM)
if is_fleet_init:
dist.all_reduce(gradient, op=dist.ReduceOp.SUM, group=data_parallel_group)
gradient /= data_parallel_group.nranks
else:
dist.all_reduce(gradient, op=dist.ReduceOp.SUM)
gradient /= dist.get_world_size()

return gradient
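To make the reply above concrete, here is an illustrative sketch of the pp_to_single_mapping lookup used in the tensor-parallel branch of get_module_gradient. The mapping contents and parameter names are hypothetical examples, not values taken from the PR:

# Hypothetical mapping from pipeline-parallel parameter names (keyed without the leading
# wrapper segment and with a ".weight" suffix) to single-card parameter names.
pp_to_single_mapping = {
    "model.0.self_attn.q_proj.weight": "llama.layers.0.self_attn.q_proj.weight",
}
base_model_prefix = "llama."
grad_name = "_layers.model.0.self_attn.q_proj"   # name as seen under pipeline parallel

# Same expression as in get_module_gradient: drop the wrapper segment, append ".weight",
# then strip the base-model prefix to obtain the key used with base_model_split_mappings.
single_name = pp_to_single_mapping[".".join(grad_name.split(".")[1:]) + ".weight"]
model_split_key = single_name.split(base_model_prefix)[-1]
print(model_split_key)   # -> layers.0.self_attn.q_proj.weight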


@@ -221,17 +308,20 @@
Returns:
None: Updates the model's weights and LoRA adapter weights in place.
"""
tensor_parallel_degree = training_args.tensor_parallel_degree
in_tensor_parallel_mode = tensor_parallel_degree > 1
lora_split_mapping = None
base_model_split_mappings = None
if in_tensor_parallel_mode:
base_model_split_mappings = model.model._get_tensor_parallel_mappings(config=model.config, is_split=False)
in_tensor_parallel_mode = training_args.tensor_parallel_degree > 1
in_pipleline_parallel_mode = training_args.pipeline_parallel_degree > 1

lora_split_mapping, base_model_split_mappings, pp_to_single_mapping = None, None, None

base_model_prefix = unwrap_model(model).base_model_prefix + "."
if in_tensor_parallel_mode:
base_model_split_mappings = model.model._get_tensor_parallel_mappings(config=model.config, is_split=False)
lora_split_mapping = model._get_tensor_parallel_mappings(model.config)

if in_pipleline_parallel_mode:
pp_to_single_mapping = model._pp_to_single_mapping

loraga_init_dict = {}
base_model_prefix = unwrap_model(model).base_model_prefix + "."

for name, module in model.named_sublayers():
if isinstance(
module,
@@ -253,6 +343,7 @@
training_args.sharding_parallel_degree,
training_args.data_parallel_degree,
training_args.local_rank,
pp_to_single_mapping,
)
# perform SVD to reinit base model weight and lora adapter weight
loraga_svd_module(
@@ -262,6 +353,7 @@
stable_gamma,
loraga_init_dict,
in_tensor_parallel_mode,
pp_to_single_mapping,
lora_split_mapping,
**kwargs,
)
@@ -276,14 +368,19 @@
stable_gamma,
loraga_init_dict,
in_tensor_parallel_mode=False,
pp_to_single_mapping=None,
lora_split_mapping=None,
**kwargs
):
with paddle.no_grad():
lora_r = module.r

loraA_name = ".".join(name.split(".")[1:]) + ".lora_A"
loraB_name = ".".join(name.split(".")[1:]) + ".lora_B"
if pp_to_single_mapping is not None:
name = pp_to_single_mapping[".".join(name.split(".")[1:]) + ".weight"]
loraA_name = name.replace("weight", "lora_A")
loraB_name = name.replace("weight", "lora_B")
else:
loraA_name = ".".join(name.split(".")[1:]) + ".lora_A"
loraB_name = ".".join(name.split(".")[1:]) + ".lora_B"

# Perform SVD on the gradients
U, S, V = paddle.linalg.svd_lowrank(grads.astype("float32"), q=4 * lora_r, niter=4)

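The diff is truncated at this point. For reference, the low-rank SVD call above factors the gradient as grads ≈ U diag(S) V^T; below is a small sketch of the call and the resulting factors, illustrative only — the slicing of these factors into lora_A / lora_B and the stable_gamma rescaling happen in the elided remainder of the function:

import paddle

m, n, lora_r = 64, 32, 8
grads = paddle.randn([m, n])

# Same call as in loraga_svd_module: a rank-limited SVD with q = 4 * lora_r components.
U, S, V = paddle.linalg.svd_lowrank(grads.astype("float32"), q=4 * lora_r, niter=4)
print(U.shape, S.shape, V.shape)   # factors of the rank-q approximation of grads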