diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index 4e03892d17a8d1..45d9a72ade3de4 100755
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -92,6 +92,7 @@ message PpConfig {
   optional bool enable_offload_queue = 11 [ default = false ];
   optional bool enable_dynamic_shape = 12 [ default = false ];
   optional bool use_dualpipev = 13 [ default = false ];
+  optional bool forward_backward_overlap_scheduler = 14 [ default = false ];
 }
 
 message DygraphShardingConfig {
diff --git a/python/paddle/distributed/fleet/meta_parallel/dualpipev.py b/python/paddle/distributed/fleet/meta_parallel/dualpipev.py
index 0e039a40763893..b1276429da1fbd 100644
--- a/python/paddle/distributed/fleet/meta_parallel/dualpipev.py
+++ b/python/paddle/distributed/fleet/meta_parallel/dualpipev.py
@@ -150,9 +150,7 @@ def _forward_compute(self, phase: int, micro_datasets=None) -> None:
         inputs = self._get_forward_inputs(micro_datasets, phase, acc_id)
 
         if self.overlapped_forward_backward:
-            schedule_chunk = self._layers.forward(
-                inputs, chunk_id=phase, overlap_schedule_mode=True
-            )
+            schedule_chunk = self._layers.get_schedule_chunk(chunk_id=phase)
             outputs = schedule_chunk.forward(inputs)
         else:
             schedule_chunk = None
@@ -296,9 +294,7 @@ def _forward_backward_compute(
         )
 
         # forward & backward
-        forward_chunk = self._layers.forward(
-            None, chunk_id=forward_phase, overlap_schedule_mode=True
-        )
+        forward_chunk = self._layers.get_schedule_chunk(chunk_id=forward_phase)
        backward_chunk = self.schedule_chunks[backward_phase][backward_acc_id]
         forward_outputs, forward_loss, backward_input_grads = (
             self._layers.overlapped_forward_backward(
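Note: the new `forward_backward_overlap_scheduler` switch is read from the `pp_configs` section of fleet's hybrid configs (see its uses in pipeline_parallel.py below). A minimal sketch of turning it on, assuming the usual fleet setup; the degrees and the surrounding training script are placeholders:

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": 1,
        "mp_degree": 1,
        "pp_degree": 4,
    }
    # The flag added by this patch; it defaults to False.
    strategy.hybrid_configs["pp_configs"].forward_backward_overlap_scheduler = True

    fleet.init(is_collective=True, strategy=strategy)
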
diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
index 6a0044904509be..978dee98fb8105 100755
--- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
+++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
@@ -1017,7 +1017,7 @@ def execute_func(*x):
 
         return execute_func
 
-    def forward(self, input, chunk_id=None, overlap_schedule_mode=False):
+    def update_run_function(self, chunk_id):
         if chunk_id is not None:
             assert isinstance(chunk_id, int), "chunk_id should be an int"
             assert (
@@ -1035,9 +1035,15 @@ def forward(self, input, chunk_id=None, overlap_schedule_mode=False):
             # But for interleave, self.run_function will keep updating to the target functions at every run.
             self.run_function = model_chunk.get_run_function()
 
-        if overlap_schedule_mode:
-            assert self._recompute_interval == 0
-            return self.build_schedule_nodes(0, len(self.run_function))
+    def get_schedule_chunk(self, chunk_id):
+        self.update_run_function(chunk_id)
+
+        assert self._recompute_interval == 0
+        return self.build_schedule_nodes(0, len(self.run_function))
+
+    def forward(self, input, chunk_id=None):
+        self.update_run_function(chunk_id)
+
         if self._recompute_interval == 0:
             input = self.forward_function(0, len(self.run_function))(input)
         else:
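The pp_layers.py change splits the old `forward(..., overlap_schedule_mode=True)` into an explicit `update_run_function` plus `get_schedule_chunk`. A sketch of the resulting caller-side contract, mirroring the dualpipev.py call sites; `layers`, `inputs`, and `output_grads` are placeholders, and recompute must be disabled since `get_schedule_chunk` asserts `_recompute_interval == 0`:

    chunk = layers.get_schedule_chunk(chunk_id=phase)  # builds schedule nodes, runs nothing
    outputs = chunk.forward(inputs)                    # forward through the chunk's nodes
    # ... later, when the matching backward is scheduled:
    input_grads = chunk.backward(output_grads)         # backward through the same nodes
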
diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index 0837fe329fb937..b63b5ebbce6d51 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -18,7 +18,9 @@
 import time
 import warnings
 from collections import defaultdict, deque
+from dataclasses import dataclass
 from enum import Enum
+from functools import partial
 from typing import Callable
 
 import paddle
@@ -754,7 +756,7 @@ def forward_backward_pipeline(
                 )
 
             self._record_stamp("F", step_id, '"B"', self._forward_color)
-            output_tensor = self._forward_step(
+            output_tensor, _, _ = self._forward_step(
                 input_tensor, micro_dataset, step_id=step_id
             )
             self._record_stamp("F", step_id, '"E"', self._forward_color)
@@ -788,7 +790,7 @@ def forward_backward_pipeline(
             self._record_stamp(
                 "F", startup_steps + i, '"B"', self._forward_color
             )
-            output_tensor = self._forward_step(
+            output_tensor, _, _ = self._forward_step(
                 input_tensor, micro_dataset, step_id=startup_steps + i
             )
             self._record_stamp(
@@ -1018,7 +1020,7 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
             batch_p2p_comm=self._use_batch_p2p_comm,
         )
 
-        output_tensor = self._forward_step(
+        output_tensor, _, _ = self._forward_step(
             input_tensor, micro_dataset, step_id=None
         )
         self._p2p_helper.send_forward(
@@ -1040,7 +1042,7 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
         for i in range(steady_steps):
             last_iter = i == (steady_steps - 1)
 
-            output_tensor = self._forward_step(
+            output_tensor, _, _ = self._forward_step(
                 input_tensor, micro_dataset, step_id=None
             )
             self._p2p_helper.send_forward(
@@ -1066,8 +1068,67 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
 
         return self.train_loss
 
+    def _maybe_loss_compute(
+        self, output_tensor, micro_dataset, overlap_schedule_mode=False
+    ):
+        backward_loss_tensor = None
+        backward_loss_fn_node = None
+        loss_fn_node = None
+
+        if self.is_pipeline_last_stage():
+            # calculate loss for train on the last stage
+            if self._compute_loss:
+                assert (
+                    self._layers._loss_fn[self.loss_fn_idx] is not None
+                ), "loss function should exist to compute loss"
+                labels = next(micro_dataset)[1]
+                self._check_micro_batch_data_valid(labels)
+                for idx, loss_fn in enumerate(self._layers._loss_fn):
+                    if overlap_schedule_mode:
+                        loss_fn_node = loss_fn.build_schedule_node()
+                        loss_fn_node.labels = labels
+                        if (
+                            self.accumulate_steps > 1
+                            and not self._delay_scale_loss
+                        ):
+                            loss_fn_node.scale_loss_factor = (
+                                self.accumulate_steps
+                            )
+                        loss_tensor = loss_fn_node.forward(output_tensor)
+                    else:
+                        loss_tensor = loss_fn(output_tensor, labels)
+                        assert isinstance(
+                            loss_tensor, paddle.Tensor
+                        ), "Currently, loss_fn should return a paddle.Tensor"
+
+                        with paddle.amp.auto_cast(enable=False):
+                            if (
+                                self.accumulate_steps > 1
+                                and not self._delay_scale_loss
+                            ):
+                                loss_tensor = (
+                                    loss_tensor / self.accumulate_steps
+                                )
+
+                    if self.total_loss is None:
+                        self.total_loss = []
+                    # when self.total_loss length is less than idx, append a new list
+                    if len(self.total_loss) <= idx:
+                        self.total_loss.append([])
+                    self.total_loss[idx].append(loss_tensor.detach())
+
+                    if idx == self.loss_fn_idx:
+                        backward_loss_tensor = loss_tensor
+                        backward_loss_fn_node = loss_fn_node
+        return backward_loss_tensor, backward_loss_fn_node
+
     def _forward_step(
-        self, input_tensor, micro_dataset, chunk_id=None, step_id=None
+        self,
+        input_tensor,
+        micro_dataset,
+        chunk_id=None,
+        step_id=None,
+        overlap_schedule_mode=False,
     ):
         if self.user_hooks_enabled:
             self.forward_hooks.run_hook()
@@ -1086,7 +1147,16 @@ def _forward_step(
             input_tensor=input_tensor,
             step_id=step_id,
         )
-        output_tensor = self._layers.forward(input_tensor, chunk_id=chunk_id)
+
+        schedule_chunk = None
+        if overlap_schedule_mode:
+            schedule_chunk = self._layers.get_schedule_chunk(chunk_id=chunk_id)
+            output_tensor = schedule_chunk.forward(input_tensor)
+        else:
+            output_tensor = self._layers.forward(
+                input_tensor, chunk_id=chunk_id
+            )
+
         self.callbacks.on_location(
             PipelineParallelMicroStepLocations.FORWARD_END,
             input_tensor=input_tensor,
@@ -1094,36 +1164,10 @@ def _forward_step(
             step_id=step_id,
         )
 
-        if self.is_pipeline_last_stage():
-            # train calculate loss for train
-            if self._compute_loss:
-                assert (
-                    self._layers._loss_fn[self.loss_fn_idx] is not None
-                ), "loss function should exist to compute loss"
-                labels = next(micro_dataset)[1]
-                self._check_micro_batch_data_valid(labels)
-                for idx, loss_fn in enumerate(self._layers._loss_fn):
-                    loss_tensor = loss_fn(output_tensor, labels)
-                    assert isinstance(
-                        loss_tensor, paddle.Tensor
-                    ), "Currently, loss_fn should obtain Paddle.Tensor dtype"
-
-                    with paddle.amp.auto_cast(enable=False):
-                        if (
-                            self.accumulate_steps > 1
-                            and not self._delay_scale_loss
-                        ):
-                            loss_tensor = loss_tensor / self.accumulate_steps
-
-                    if self.total_loss is None:
-                        self.total_loss = []
-                    # when self.total_loss length is less than idx, append a new tensor
-                    if len(self.total_loss) <= idx:
-                        self.total_loss.append([])
-                    self.total_loss[idx].append(loss_tensor.detach())
+        backward_loss_tensor, backward_loss_fn_node = self._maybe_loss_compute(
+            output_tensor, micro_dataset, overlap_schedule_mode
+        )
 
-                    if idx == self.loss_fn_idx:
-                        backward_loss_tensor = loss_tensor
         if self.is_pipeline_first_stage() or self.is_pipeline_last_stage():
             # Only increase micro batch id at virtual first/last pp stage.
             # The micro batch id is used to load data, therefore, only increase it when load data.
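`_maybe_loss_compute` keeps the existing scaling rule in both branches: with `accumulate_steps > 1` and no delayed scaling, each micro-batch loss is divided by `accumulate_steps` (via `scale_loss_factor` on the schedule node, or by plain division otherwise), so accumulated gradients average over micro-batches instead of summing. A toy check with made-up numbers:

    import paddle

    accumulate_steps = 4
    micro_losses = [paddle.to_tensor(x) for x in (2.0, 4.0, 6.0, 8.0)]

    # Scale each micro-batch loss the way the scheduler does.
    scaled = [loss / accumulate_steps for loss in micro_losses]
    total = paddle.add_n(scaled)  # equals the mean of the micro-batch losses
    assert float(total) == float(paddle.add_n(micro_losses)) / accumulate_steps
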
@@ -1133,11 +1177,18 @@ def _forward_step(
         if self.processed_steps < g_profile_pipeline_details_steps:
             get_sync_logger().info("After forward_step")
         if self.is_pipeline_last_stage() and self._compute_loss:
-            return backward_loss_tensor
-        return output_tensor
+            return backward_loss_tensor, schedule_chunk, backward_loss_fn_node
+        return output_tensor, schedule_chunk, backward_loss_fn_node
 
     def _backward_step(
-        self, input_tensor, output_tensor, output_tensor_grad, step_id=None
+        self,
+        input_tensor,
+        output_tensor,
+        output_tensor_grad,
+        step_id=None,
+        overlap_schedule_mode=False,
+        schedule_chunk=None,
+        loss_fn_node=None,
     ):
         if self.user_hooks_enabled:
             self.backward_hooks.run_hook()
@@ -1155,35 +1206,61 @@ def _backward_step(
         )
         if self.is_pipeline_last_stage():
             assert output_tensor_grad is None
-            # In align mode, we scale the grad directly after forward
-            if paddle.distributed.in_auto_parallel_align_mode():
-                output_tensor = output_tensor / _get_align_mode_scale()
-            if self.scaler:
-                paddle.autograd.backward(self.scaler.scale(output_tensor))
+            if overlap_schedule_mode:
+                assert (
+                    loss_fn_node is not None and schedule_chunk is not None
+                ), "loss_fn_node and schedule_chunk should not be None in overlap_schedule_mode"
+                input_tensor_grad = loss_fn_node.backward(
+                    scaler=self.scaler
+                )
+                input_tensor_grad = schedule_chunk.backward(
+                    input_tensor_grad
+                )
             else:
-                paddle.autograd.backward(output_tensor)
+                # In align mode, we scale the grad directly after forward
+                if paddle.distributed.in_auto_parallel_align_mode():
+                    output_tensor = output_tensor / _get_align_mode_scale()
+                if self.scaler:
+                    paddle.autograd.backward(
+                        self.scaler.scale(output_tensor)
+                    )
+                else:
+                    paddle.autograd.backward(output_tensor)
         else:
             if isinstance(output_tensor, tuple):
                 outputs = [t for t in output_tensor if not t.stop_gradient]
                 assert len(outputs) == len(output_tensor_grad)
-                paddle.autograd.backward(
-                    tensors=outputs,
-                    grad_tensors=list(output_tensor_grad),
-                )
+                grad_tensors = list(output_tensor_grad)
+            else:
+                outputs = [output_tensor]
+                grad_tensors = [output_tensor_grad]
+
+            if overlap_schedule_mode:
+                assert (
+                    schedule_chunk is not None
+                ), "schedule_chunk should not be None in overlap_schedule_mode"
+                input_tensor_grad = schedule_chunk.backward(grad_tensors)
             else:
                 paddle.autograd.backward(
-                    tensors=[output_tensor],
-                    grad_tensors=[output_tensor_grad],
+                    tensors=outputs,
+                    grad_tensors=grad_tensors,
                 )
 
-        input_tensor_grad = None
-        if input_tensor is not None:
-            if isinstance(input_tensor, tuple):
-                input_tensor_grad = tuple(
-                    [t.grad for t in input_tensor if not t.stop_gradient]
-                )
-            else:
-                input_tensor_grad = input_tensor.grad
+        if not overlap_schedule_mode:
+            # Extract input_tensor_grad from the input tensor. In overlap_schedule_mode,
+            # the input_tensor_grad is extracted inside the schedule_chunk.
+            input_tensor_grad = None
+            if input_tensor is not None:
+                if isinstance(input_tensor, tuple):
+                    input_tensor_grad = tuple(
+                        [
+                            t.grad
+                            for t in input_tensor
+                            if not t.stop_gradient
+                        ]
+                    )
+                else:
+                    input_tensor_grad = input_tensor.grad
         if self._enable_timer:
             self.timers("backward_step").stop()
         self.callbacks.on_location(
@@ -1320,11 +1397,63 @@ def get_static_scheduler(self):
         return self.forward_backward_pipeline(data=None, static_scheduler=True)
 
 
+@dataclass
+class P2PAsyncHandle:
+    # funcs
+    forward_handle_wait_fn: Callable
+    forward_async_comm_fn: Callable
+    backward_handle_wait_fn: Callable
+    backward_async_comm_fn: Callable
+
+    # outputs
+    next_forward_virtual_pp_rank = None
+    input_tensor = None
+    out_fwd_wait_handles = None
+    next_backward_virtual_pp_rank = None
+    output_tensor_grad = None
+    recv_next = None
+    out_bwd_wait_handles = None
+
+    def forward_handle_wait(self):
+        self.forward_handle_wait_fn()
+
+    def forward_async_comm(self, output_tensor):
+        (
+            self.next_forward_virtual_pp_rank,
+            self.input_tensor,
+            self.out_fwd_wait_handles,
+        ) = self.forward_async_comm_fn(output_tensor=output_tensor)
+
+    def backward_handle_wait(self):
+        self.backward_handle_wait_fn()
+
+    def backward_async_comm(self, input_tensor_grad):
+        (
+            self.next_backward_virtual_pp_rank,
+            self.output_tensor_grad,
+            self.recv_next,
+            self.out_bwd_wait_handles,
+        ) = self.backward_async_comm_fn(input_tensor_grad=input_tensor_grad)
+
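`P2PAsyncHandle` bundles four callbacks plus the values they produce, so the overlapped step can trigger waits and async sends/recvs at the right points without threading many arguments through. A standalone sketch of the calling convention with stub closures; the real closures are built with `functools.partial` further down, and all values here are dummies:

    handle = P2PAsyncHandle(
        forward_handle_wait_fn=lambda: None,
        forward_async_comm_fn=lambda output_tensor: (0, None, None),
        backward_handle_wait_fn=lambda: None,
        backward_async_comm_fn=lambda input_tensor_grad: (0, None, True, None),
    )
    handle.forward_handle_wait()      # wait on the previous forward p2p handles
    handle.forward_async_comm(None)   # launch this step's forward send/recv
    handle.backward_handle_wait()     # wait on the previous backward p2p handles
    handle.backward_async_comm(None)  # launch this step's backward send/recv
    assert handle.recv_next is True   # results are stashed on the handle
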
+ logger.info(f"Using {self._get_scheduler_name()}") + self._record_format = ( '"name": "{}{}_VP{}", "cat": "virtual pipeline timeline", "ph": {}, "pid": 0, "tid": ' + str(self.stage_id + 1) @@ -1368,6 +1497,9 @@ def __init__(self, layers, hcg, strategy): # reinit user hook since now we have virtual stages self._init_user_hooks() + def _get_scheduler_name(self): + return f"PipelineParallelWithInterleave with overlapping forward backward={self.overlap_schedule_mode}, overlap p2p comm={self._overlap_p2p_comm}" + def _init_user_bubble_hooks(self): # initialize bubble hooks self.bubble_hooks = PipelineHook() @@ -1509,10 +1641,7 @@ def _get_virtual_pp_rank(self, micro_step, forward): return virtual_pp_stage - def _forward_step_helper(self, micro_dataset, micro_step): - virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True) - self.set_virtual_pipeline_rank(virtual_pp_rank) - + def _get_forward_input(self, virtual_pp_rank): # some checkers assert hasattr(self, 'input_tensors') assert hasattr(self, 'output_tensors') @@ -1522,16 +1651,45 @@ def _forward_step_helper(self, micro_dataset, micro_step): len(self.output_tensors[virtual_pp_rank]) + 1 ) input_tensor = self.input_tensors[virtual_pp_rank][-1] - output_tensor = self._forward_step( - input_tensor, micro_dataset, virtual_pp_rank, step_id=micro_step - ) + return input_tensor + + def _store_forward_outputs( + self, + virtual_pp_rank, + output_tensor, + schedule_chunk=None, + loss_fn_node=None, + ): self.output_tensors[virtual_pp_rank].append(output_tensor) + # If overlap_schedule_mode eq False, the schedule chunk is a None + self.schedule_chunks[virtual_pp_rank].append(schedule_chunk) + if self.is_pipeline_last_stage(): + self.loss_fn_chunks.append(loss_fn_node) if self._forward_only: # no need to store tensor for backward self.input_tensors[virtual_pp_rank].pop() self.output_tensors[virtual_pp_rank].pop() + def _forward_step_helper( + self, micro_dataset, micro_step, overlap_schedule_mode=False + ): + virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True) + self.set_virtual_pipeline_rank(virtual_pp_rank) + + input_tensor = self._get_forward_input(virtual_pp_rank) + + output_tensor, schedule_chunk, loss_fn_node = self._forward_step( + input_tensor, + micro_dataset, + virtual_pp_rank, + step_id=micro_step, + overlap_schedule_mode=overlap_schedule_mode, + ) + + self._store_forward_outputs( + virtual_pp_rank, output_tensor, schedule_chunk, loss_fn_node + ) return output_tensor def _overlap_comm_grads(self): @@ -1567,10 +1725,7 @@ def _sync_overlap_grads(self): for buffer in buffers: buffer.scale_grads() - def _backward_step_helper(self, micro_step): - virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False) - self.set_virtual_pipeline_rank(virtual_pp_rank) - + def _get_backward_input(self, virtual_pp_rank): # some checkers assert hasattr(self, 'input_tensors') assert hasattr(self, 'output_tensors') @@ -1586,14 +1741,190 @@ def _backward_step_helper(self, micro_step): input_tensor = self.input_tensors[virtual_pp_rank].pop(0) output_tensor = self.output_tensors[virtual_pp_rank].pop(0) output_tensor_grad = self.output_tensor_grads[virtual_pp_rank].pop(0) + schedule_chunk = self.schedule_chunks[virtual_pp_rank].pop(0) + if self.is_pipeline_last_stage(): + loss_fn_node = self.loss_fn_chunks.pop(0) + else: + loss_fn_node = None + + return ( + input_tensor, + output_tensor, + output_tensor_grad, + schedule_chunk, + loss_fn_node, + ) + + def _backward_step_helper(self, micro_step, overlap_schedule_mode=False): + 
+    def _backward_step_helper(self, micro_step, overlap_schedule_mode=False):
+        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False)
+        self.set_virtual_pipeline_rank(virtual_pp_rank)
+
+        (
+            input_tensor,
+            output_tensor,
+            output_tensor_grad,
+            schedule_chunk,
+            loss_fn_node,
+        ) = self._get_backward_input(virtual_pp_rank)
+
         input_tensor_grad = self._backward_step(
-            input_tensor, output_tensor, output_tensor_grad, step_id=micro_step
+            input_tensor,
+            output_tensor,
+            output_tensor_grad,
+            step_id=micro_step,
+            overlap_schedule_mode=overlap_schedule_mode,
+            schedule_chunk=schedule_chunk,
+            loss_fn_node=loss_fn_node,
         )
 
         self._overlap_comm_grads()
 
         return input_tensor_grad
 
+    def _forward_backward_helper(
+        self,
+        micro_dataset,
+        forward_micro_step_id,
+        backward_micro_step_id,
+        p2p_async_handle=None,
+    ):
+        if not self.overlap_schedule_mode:
+            if p2p_async_handle is not None:
+                p2p_async_handle.forward_handle_wait()
+
+            self._record_stamp("F", forward_micro_step_id, '"B"', forward=True)
+            output_tensor = self._forward_step_helper(
+                micro_dataset,
+                forward_micro_step_id,
+            )
+            self._record_stamp("F", forward_micro_step_id, '"E"', forward=True)
+
+            if p2p_async_handle is not None:
+                p2p_async_handle.forward_async_comm(output_tensor)
+                p2p_async_handle.backward_handle_wait()
+
+            # backward
+            self._record_stamp(
+                "B", backward_micro_step_id, '"B"', forward=False
+            )
+            input_tensor_grad = self._backward_step_helper(
+                backward_micro_step_id,
+            )
+            self._record_stamp(
+                "B", backward_micro_step_id, '"E"', forward=False
+            )
+
+            if p2p_async_handle is not None:
+                p2p_async_handle.backward_async_comm(input_tensor_grad)
+                return
+            else:
+                return output_tensor, input_tensor_grad
+        else:
+            # 1. prepare forward inputs
+            forward_virtual_pp_rank = self._get_virtual_pp_rank(
+                forward_micro_step_id, forward=True
+            )
+            self.set_virtual_pipeline_rank(forward_virtual_pp_rank)
+
+            if self.user_hooks_enabled:
+                self.forward_hooks.run_hook()
+
+            forward_inputs = self._get_forward_input(forward_virtual_pp_rank)
+            if self.is_pipeline_first_stage():
+                forward_inputs = next(micro_dataset)[0]
+                self._check_micro_batch_data_valid(forward_inputs)
+            if self.is_pipeline_last_stage():
+                labels = next(micro_dataset)[1]
+
+            # 2. get forward chunks
+            forward_chunk = self._layers.get_schedule_chunk(
+                chunk_id=forward_virtual_pp_rank
+            )
+
+            if self.is_pipeline_last_stage():
+                assert len(self._layers._loss_fn) == 1
+                forward_loss_fn_node = self._layers._loss_fn[
+                    0
+                ].build_schedule_node()
+                forward_loss_fn_node.labels = labels
+                if self.accumulate_steps > 1 and not self._delay_scale_loss:
+                    forward_loss_fn_node.scale_loss_factor = (
+                        self.accumulate_steps
+                    )
+            else:
+                forward_loss_fn_node = None
+
+            # 3. prepare backward inputs & get backward chunks
+            backward_virtual_pp_rank = self._get_virtual_pp_rank(
+                backward_micro_step_id, forward=False
+            )
+            self.set_virtual_pipeline_rank(backward_virtual_pp_rank)
+
+            if self.user_hooks_enabled:
+                self.backward_hooks.run_hook()
+
+            (
+                _,
+                _,
+                backward_grads,
+                backward_chunk,
+                backward_loss_fn_node,
+            ) = self._get_backward_input(backward_virtual_pp_rank)
+
forward & backward + if self.processed_steps < g_profile_pipeline_details_steps: + get_sync_logger().info("Before forward_backward_step") + if self._enable_timer: + self.timers("forward_backward_step").start() + output_tensor, forward_loss, input_tensor_grad = ( + self._layers.overlapped_forward_backward( + forward_chunk, + forward_inputs, + forward_loss_fn_node, + backward_chunk, + backward_loss_fn_node, + backward_grads, + self.scaler, + p2p_async_handle=p2p_async_handle, + ) + ) + if self.processed_steps < g_profile_pipeline_details_steps: + get_sync_logger().info("After forward_backward_step") + if self._enable_timer: + self.timers("forward_backward_step").stop() + + # 5. process forward outputs + forward_virtual_pp_rank = self._get_virtual_pp_rank( + forward_micro_step_id, forward=True + ) + self.set_virtual_pipeline_rank(forward_virtual_pp_rank) + self._store_forward_outputs( + forward_virtual_pp_rank, + output_tensor, + forward_chunk, + forward_loss_fn_node, + ) + + if self.is_pipeline_first_stage() or self.is_pipeline_last_stage(): + # Only increase micro batch id at virtual first/last pp stage. + # The micro batch id is used to load data, therefore, only increase it when load data. + self.micro_batch_id += 1 + + if self.is_pipeline_last_stage(): + # In overlap mode, only one loss_fn is supported. + if self.total_loss is None: + self.total_loss = [[]] + self.total_loss[0].append(forward_loss.detach()) + + # 6. process backward outputs + backward_virtual_pp_rank = self._get_virtual_pp_rank( + backward_micro_step_id, forward=False + ) + self.set_virtual_pipeline_rank(backward_virtual_pp_rank) + self._overlap_comm_grads() + + return output_tensor, input_tensor_grad + def bw_hook_func(self, buffer, param): # For pipeline with interleave, we need to add grad to buffer without communication. # Use communication where appropriate to avoid dp communication and pp scheduling conflicts. 
@@ -1609,6 +1940,14 @@ def register_allreduce_overlap_hook(self, model, comm_group, acc_steps, dp):
             model, comm_group, acc_steps, dp, group_size=sys.maxsize
         )
 
+    def _init_buffers(self):
+        # init some data buffers for interleave scheduler
+        self.input_tensors = [[] for _ in range(self.num_model_chunks)]
+        self.output_tensors = [[] for _ in range(self.num_model_chunks)]
+        self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]
+        self.schedule_chunks = [[] for _ in range(self.num_model_chunks)]
+        self.loss_fn_chunks = []
+
     def forward_backward_pipeline(
         self,
         data,
@@ -1649,6 +1988,15 @@ def forward_backward_pipeline(
             self._using_cache
         ), "cache should be enabled for pipeline with interleave"
 
+        self.overlap_schedule_mode = (
+            hasattr(type(self._layers), "overlapped_forward_backward")
+            and self._strategy.hybrid_configs[
+                "pp_configs"
+            ].forward_backward_overlap_scheduler
+        )
+        if forward_only:
+            self.overlap_schedule_mode = False
+
         # init some attributes for this batch run
         self.scaler = scaler
         self.total_loss = None
@@ -1733,10 +2081,7 @@ def _process_bwd_buffer(step_id, tensor):
                 * self.num_model_chunks
             )
 
-        # init some data buffers for interleave scheduler
-        self.input_tensors = [[] for _ in range(self.num_model_chunks)]
-        self.output_tensors = [[] for _ in range(self.num_model_chunks)]
-        self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]
+        self._init_buffers()
 
         micro_dataset = self._wrap_data(data)
 
@@ -1753,6 +2098,9 @@ def _process_bwd_buffer(step_id, tensor):
             startup_steps += (self.num_model_chunks - 1) * first_chunk_acc
             startup_steps = min(startup_steps, num_steps)
 
+        # An additional micro step is needed for the overlapping schedule
+        if self.overlap_schedule_mode:
+            startup_steps += 1
         steady_steps = num_steps - startup_steps
 
         for location in range(self.stage_id):
@@ -1795,7 +2143,11 @@ def _process_bwd_buffer(step_id, tensor):
                 continue
 
             self._record_stamp("F", micro_step, '"B"', forward=True)
-            output_tensor = self._forward_step_helper(micro_dataset, micro_step)
+            output_tensor = self._forward_step_helper(
+                micro_dataset,
+                micro_step,
+                overlap_schedule_mode=self.overlap_schedule_mode,
+            )
             self._record_stamp("F", micro_step, '"E"', forward=True)
 
             if micro_step >= startup_steps - rest_bubble_times:
@@ -1932,100 +2284,119 @@ def _process_bwd_buffer(step_id, tensor):
                 forward_micro_step_id = micro_step + startup_steps
 
                 if self._overlap_p2p_comm:
-                    if fwd_wait_handles is not None:
-                        for req in fwd_wait_handles:
-                            req.wait()
-
-                    self._release_output(output_tensor)
-                    output_tensor = self._forward_step_helper(
-                        micro_dataset, forward_micro_step_id
-                    )
+                    backward_micro_step_id = micro_step
 
-                    forward_virtual_pp_rank = self._get_virtual_pp_rank(
-                        forward_micro_step_id, forward=True
-                    )
-                    self.set_virtual_pipeline_rank(forward_virtual_pp_rank)
-                    if self.is_pipeline_last_stage(ignore_virtual=True):
-                        output_tensor = _process_fwd_buffer(
-                            forward_micro_step_id, output_tensor
-                        )
+                    def forward_handle_wait(fwd_wait_handles, output_tensor):
+                        if fwd_wait_handles is not None:
+                            for req in fwd_wait_handles:
+                                req.wait()
+                        self._release_output(output_tensor)
 
-                    # determine whether to recv input tensor from upstream
-                    recv_prev = True
-                    if self.is_pipeline_first_stage(ignore_virtual=True):
-                        next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
-                            forward_micro_step_id + 1, forward=True
-                        )
-                        if next_forward_virtual_pp_rank == 0:
-                            # next chunk is the first chunk, not need to pre recv an input tensor
-                            recv_prev = False
-                    else:
-                        next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
-                            forward_micro_step_id + 1, forward=True
-                        )
+                    def forward_async_comm(forward_micro_step_id, output_tensor):
+                        forward_virtual_pp_rank = self._get_virtual_pp_rank(
+                            forward_micro_step_id, forward=True
+                        )
+                        self.set_virtual_pipeline_rank(forward_virtual_pp_rank)
 
-                    # last iteration doesn't need recv from upstream
-                    if micro_step == (steady_steps - 1):
-                        recv_prev = False
+                        # determine whether to recv input tensor from upstream
+                        recv_prev = True
+                        if self.is_pipeline_first_stage(ignore_virtual=True):
+                            next_forward_virtual_pp_rank = (
+                                self._get_virtual_pp_rank(
+                                    forward_micro_step_id + 1, forward=True
+                                )
+                            )
+                            if next_forward_virtual_pp_rank == 0:
+                                # next chunk is the first chunk, no need to pre-recv an input tensor
+                                recv_prev = False
+                        else:
+                            next_forward_virtual_pp_rank = (
+                                self._get_virtual_pp_rank(
+                                    forward_micro_step_id + 1, forward=True
+                                )
+                            )
 
-                    # Send activation tensor to the next stage and receive activation tensor from the
-                    # previous stage
-                    (
-                        input_tensor,
-                        fwd_wait_handles,
-                    ) = self._p2p_helper.send_forward_recv_forward(
-                        output_tensor,
-                        recv_prev=recv_prev,
-                        batch_p2p_comm=self._use_batch_p2p_comm,
-                        overlap_p2p_comm=True,
-                        skip_check_meta=not self.training,
-                    )
+                        # last iteration doesn't need recv from upstream
+                        if micro_step == (steady_steps - 1):
+                            recv_prev = False
 
-                    if bwd_wait_handles is not None:
-                        for req in bwd_wait_handles:
-                            req.wait()
+                        if self.is_pipeline_last_stage(ignore_virtual=True):
+                            output_tensor = _process_fwd_buffer(
+                                forward_micro_step_id, output_tensor
+                            )
 
-                    # backward pass
-                    backward_micro_step_id = micro_step
-                    input_tensor_grad = self._backward_step_helper(
-                        backward_micro_step_id
-                    )
+                        # Send activation tensor to the next stage and receive activation tensor from the
+                        # previous stage
+                        (
+                            input_tensor,
+                            fwd_wait_handles,
+                        ) = self._p2p_helper.send_forward_recv_forward(
+                            output_tensor,
+                            recv_prev=recv_prev,
+                            batch_p2p_comm=self._use_batch_p2p_comm,
+                            overlap_p2p_comm=True,
+                            skip_check_meta=not self.training,
+                        )
+                        return (
+                            next_forward_virtual_pp_rank,
+                            input_tensor,
+                            fwd_wait_handles,
+                        )
 
-                    if (
-                        self._best_unbalanced_scheduler
-                        and self.is_pipeline_last_stage(ignore_virtual=True)
-                    ):
-                        cur_pp_rank = self._get_virtual_pp_rank(
-                            backward_micro_step_id, forward=False
-                        )
-                        if cur_pp_rank != 0:
-                            last_stage_recv_queue.append(
-                                (backward_micro_step_id, cur_pp_rank)
-                            )
+                    def backward_handle_wait(bwd_wait_handles):
+                        if bwd_wait_handles is not None:
+                            for req in bwd_wait_handles:
+                                req.wait()
 
-                    # first stage doesn't send grad to upstream
-                    backward_virtual_pp_rank = self._get_virtual_pp_rank(
-                        backward_micro_step_id, forward=False
-                    )
-                    self.set_virtual_pipeline_rank(backward_virtual_pp_rank)
-                    if self.is_pipeline_first_stage(ignore_virtual=True):
-                        input_tensor_grad = _process_bwd_buffer(
-                            backward_micro_step_id, input_tensor_grad
-                        )
+                    def backward_async_comm(
+                        backward_micro_step_id, input_tensor_grad
+                    ):
+                        if (
+                            self._best_unbalanced_scheduler
+                            and self.is_pipeline_last_stage(ignore_virtual=True)
+                        ):
+                            cur_pp_rank = self._get_virtual_pp_rank(
+                                backward_micro_step_id, forward=False
+                            )
+                            if cur_pp_rank != 0:
+                                last_stage_recv_queue.append(
+                                    (backward_micro_step_id, cur_pp_rank)
+                                )
 
+                        # first stage doesn't send grad to upstream
+                        backward_virtual_pp_rank = self._get_virtual_pp_rank(
+                            backward_micro_step_id, forward=False
+                        )
+                        self.set_virtual_pipeline_rank(backward_virtual_pp_rank)
+                        if self.is_pipeline_first_stage(ignore_virtual=True):
+                            input_tensor_grad = _process_bwd_buffer(
+                                backward_micro_step_id, input_tensor_grad
+                            )
 
-                    recv_next = True
-                    if self.is_pipeline_last_stage(ignore_virtual=True):
-                        if self._best_unbalanced_scheduler:
-                            next_backward_virtual_pp_rank = (
-                                self._get_virtual_pp_rank(
-                                    backward_micro_step_id + 1,
-                                    forward=False,
-                                )
-                            )
-                            if self.is_pipeline_last_stage(ignore_virtual=True):
-                                recv_next = _last_stage_need_recv_next(
-                                    backward_micro_step_id + 1
-                                )
-                        else:
-                            next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
-                                backward_micro_step_id + 1,
-                                forward=False,
-                            )
-                            if next_backward_virtual_pp_rank == (
-                                self.num_model_chunks - 1
-                            ):
-                                # next chunk is the last chunk, not need to pre recv an output tensor grad
-                                recv_next = False
-                    else:
-                        next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
-                            backward_micro_step_id + 1,
-                            forward=False,
-                        )
+                        recv_next = True
+                        if self.is_pipeline_last_stage(ignore_virtual=True):
+                            if self._best_unbalanced_scheduler:
+                                next_backward_virtual_pp_rank = (
+                                    self._get_virtual_pp_rank(
+                                        backward_micro_step_id + 1,
+                                        forward=False,
+                                    )
+                                )
+                                if self.is_pipeline_last_stage(ignore_virtual=True):
+                                    recv_next = _last_stage_need_recv_next(
+                                        backward_micro_step_id + 1
+                                    )
+                            else:
+                                next_backward_virtual_pp_rank = (
+                                    self._get_virtual_pp_rank(
+                                        backward_micro_step_id + 1,
+                                        forward=False,
+                                    )
+                                )
+                                if next_backward_virtual_pp_rank == (
+                                    self.num_model_chunks - 1
+                                ):
+                                    # next chunk is the last chunk, no need to pre-recv an output tensor grad
+                                    recv_next = False
+                        else:
+                            next_backward_virtual_pp_rank = (
+                                self._get_virtual_pp_rank(
+                                    backward_micro_step_id + 1,
+                                    forward=False,
+                                )
+                            )
 
-                    (
-                        output_tensor_grad,
-                        bwd_wait_handles,
-                    ) = self._p2p_helper.send_backward_recv_backward(
-                        input_tensor_grad,
-                        recv_next=recv_next,
-                        batch_p2p_comm=self._use_batch_p2p_comm,
-                        overlap_p2p_comm=True,
-                    )
+                        (
+                            output_tensor_grad,
+                            bwd_wait_handles,
+                        ) = self._p2p_helper.send_backward_recv_backward(
+                            input_tensor_grad,
+                            recv_next=recv_next,
+                            batch_p2p_comm=self._use_batch_p2p_comm,
+                            overlap_p2p_comm=True,
+                        )
+                        return (
+                            next_backward_virtual_pp_rank,
+                            output_tensor_grad,
+                            recv_next,
+                            bwd_wait_handles,
+                        )
 
-                else:
-                    self._record_stamp(
-                        "F", forward_micro_step_id, '"B"', forward=True
-                    )
-                    output_tensor = self._forward_step_helper(
-                        micro_dataset, forward_micro_step_id
-                    )
-                    self._record_stamp(
-                        "F", forward_micro_step_id, '"E"', forward=True
-                    )
+                    # Package the closure functions and parameters into a
+                    # `P2PAsyncHandle` to simplify parameter passing
+                    p2p_async_handle = P2PAsyncHandle(
+                        partial(
+                            forward_handle_wait,
+                            fwd_wait_handles=fwd_wait_handles,
+                            output_tensor=output_tensor,
+                        ),
+                        partial(
+                            forward_async_comm,
+                            forward_micro_step_id=forward_micro_step_id,
+                        ),
+                        partial(
+                            backward_handle_wait, bwd_wait_handles=bwd_wait_handles
+                        ),
+                        partial(
+                            backward_async_comm,
+                            backward_micro_step_id=backward_micro_step_id,
+                        ),
+                    )
 
-                    # backward
-                    backward_micro_step_id = micro_step
-                    self._record_stamp(
-                        "B", backward_micro_step_id, '"B"', forward=False
-                    )
-                    input_tensor_grad = self._backward_step_helper(
-                        backward_micro_step_id
-                    )
-                    self._record_stamp(
-                        "B", backward_micro_step_id, '"E"', forward=False
-                    )
+                    self._forward_backward_helper(
+                        micro_dataset,
+                        forward_micro_step_id,
+                        backward_micro_step_id,
+                        p2p_async_handle,
+                    )
 
+                    # Information that needs to be updated
+                    next_forward_virtual_pp_rank = (
+                        p2p_async_handle.next_forward_virtual_pp_rank
+                    )
+                    input_tensor = p2p_async_handle.input_tensor
+                    fwd_wait_handles = p2p_async_handle.out_fwd_wait_handles
+                    next_backward_virtual_pp_rank = (
+                        p2p_async_handle.next_backward_virtual_pp_rank
+                    )
+                    output_tensor_grad = p2p_async_handle.output_tensor_grad
+                    recv_next = p2p_async_handle.recv_next
+                    bwd_wait_handles = p2p_async_handle.out_bwd_wait_handles
+                else:
+                    backward_micro_step_id = micro_step
+                    output_tensor, input_tensor_grad = (
+                        self._forward_backward_helper(
+                            micro_dataset,
+                            forward_micro_step_id,
+                            backward_micro_step_id,
+                        )
+                    )
 
                 if (
                     self._best_unbalanced_scheduler
                     and self.is_pipeline_last_stage(ignore_virtual=True)
@@ -2245,7 +2641,9 @@ def _process_bwd_buffer(step_id, tensor):
             # cooldown loop
             self._record_stamp("B", micro_step, '"B"', forward=False)
-            input_tensor_grad = self._backward_step_helper(micro_step)
+            input_tensor_grad = self._backward_step_helper(
+                micro_step, overlap_schedule_mode=self.overlap_schedule_mode
+            )
             self._record_stamp("B", micro_step, '"E"', forward=False)
             next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                 micro_step + 1,
@@ -2421,6 +2819,10 @@ def get_static_scheduler(self):
 class PipelineParallelWithInterleaveFthenB(PipelineParallelWithInterleave):
     def __init__(self, layers, hcg, strategy):
         super().__init__(layers=layers, hcg=hcg, strategy=strategy)
+        self.overlap_schedule_mode = False
+
+    def _get_scheduler_name(self):
+        return "PipelineParallelWithInterleaveFthenB"
 
     def _init_user_bubble_hooks(self):
         # (TODO:gexiao) support bubble hooks if needed
@@ -2521,10 +2923,7 @@ def forward_backward_pipeline(
         skip_steps = self.accumulate_steps - self.num_stages
         send_recv_buffer_queue = queue.Queue()
 
-        # init some data buffers for interleave scheduler
-        self.input_tensors = [[] for _ in range(self.num_model_chunks)]
-        self.output_tensors = [[] for _ in range(self.num_model_chunks)]
-        self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]
+        self._init_buffers()
 
         micro_dataset = self._wrap_data(data)
         num_steps = self.accumulate_steps * self.num_model_chunks
@@ -2708,7 +3107,10 @@ def get(self, *args, **kwargs):
 class VPPFhenBInBalancedMemory(PipelineParallelWithInterleaveFthenB):
     def __init__(self, layers, hcg, strategy):
         super().__init__(layers=layers, hcg=hcg, strategy=strategy)
-        logger.info("Using VPPFhenBInBalancedMemory")
+        self.overlap_schedule_mode = False
+
+    def _get_scheduler_name(self):
+        return "VPPFhenBInBalancedMemory"
 
     def _init_user_bubble_hooks(self):
         # (TODO:gexiao) support bubble hooks if needed
@@ -2751,10 +3153,8 @@ def forward_backward_pipeline(
         self.micro_batch_id = 0
         self._forward_only = forward_only
 
-        # init some data buffers for interleave scheduler
-        self.input_tensors = [[] for _ in range(self.num_model_chunks)]
-        self.output_tensors = [[] for _ in range(self.num_model_chunks)]
-        self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]
+        self._init_buffers()
+
         backward_send_recv_buffer_queue = OffloadQueue(
             offload=self._enable_offload_queue
         )
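The two FthenB-style schedulers above inherit the interleave machinery but force `overlap_schedule_mode = False` after the parent constructor has computed it, and report their own name through `_get_scheduler_name`. A sketch of the pattern for a hypothetical subclass:

    class MyFthenBScheduler(PipelineParallelWithInterleave):
        def __init__(self, layers, hcg, strategy):
            super().__init__(layers=layers, hcg=hcg, strategy=strategy)
            self.overlap_schedule_mode = False  # overlap unsupported here

        def _get_scheduler_name(self):
            return "MyFthenBScheduler"
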
diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/forward_backward_overlap_utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/forward_backward_overlap_utils.py
index c97a25cd4c68bd..6a3a1da1279435 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/forward_backward_overlap_utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/forward_backward_overlap_utils.py
@@ -38,16 +38,42 @@ def _check_nodes_valid(self):
 
 
 def detach_and_requires_grad(inputs):
-    ret = []
-    for input in inputs:
-        if isinstance(input, (tuple, list)):
-            ret.append(detach_and_requires_grad(input))
-        else:
-            tmp = input.detach() if input is not None else None
-            if tmp is not None:
-                tmp.stop_gradient = input.stop_gradient
-            ret.append(tmp)
-    return ret
+    if isinstance(inputs, (tuple, list)):
+        is_tuple = isinstance(inputs, tuple)
+        ret = []
+        for input in inputs:
+            if isinstance(input, (tuple, list)):
+                ret.append(detach_and_requires_grad(input))
+            else:
+                tmp = input.detach() if input is not None else None
+                if tmp is not None:
+                    tmp.stop_gradient = input.stop_gradient
+                ret.append(tmp)
+        if is_tuple:
+            ret = tuple(ret)
+        return ret
+    else:
+        tmp = inputs.detach()
+        tmp.stop_gradient = inputs.stop_gradient
+        return tmp
+
+
+def clone_and_clear_dataptr(outputs, clear_dataptr=False):
+    if isinstance(outputs, (tuple, list)):
+        is_tuple = isinstance(outputs, tuple)
+        ret = [FakeClone.apply(o) for o in outputs if o is not None]
+
+        if clear_dataptr:
+            for o in ret:
+                o._clear_dataptr()
+        if is_tuple:
+            ret = tuple(ret)
+        return ret
+    else:
+        ret = FakeClone.apply(outputs)
+        if clear_dataptr:
+            ret._clear_dataptr()
+        return ret
 
 
 class FakeClone(paddle.autograd.PyLayer):
@@ -76,8 +102,6 @@ def __init__(self, fwd_func, name=""):
         self.scale_loss_factor = None
 
     def forward(self, inputs=()):
-        if not isinstance(inputs, (tuple, list)):
-            inputs = (inputs,)
         detached_inputs = detach_and_requires_grad(inputs)
         self.inputs = detached_inputs
         if self.labels is not None:
@@ -86,13 +110,10 @@ def forward(self, inputs=()):
         outputs = self.fwd_func(self.inputs)
         if self.scale_loss_factor is not None:
             outputs /= self.scale_loss_factor
-        if not isinstance(outputs, (tuple, list)):
-            outputs = (outputs,)
-        self.outputs = [FakeClone.apply(o) for o in outputs if o is not None]
-        if self.labels is None:
-            # Do not release the loss tensor.
-            for o in self.outputs:
-                o._clear_dataptr()
+
+        # Do not release the loss tensor.
+        clear_dataptr = self.labels is None
+        self.outputs = clone_and_clear_dataptr(outputs, clear_dataptr)
         return outputs
 
     def backward(self, output_grad=None, scaler=None):
@@ -100,6 +121,8 @@ def backward(self, output_grad=None, scaler=None):
         if isinstance(self.outputs, (tuple, list)):
             assert len(self.outputs) == 1
             outputs = self.outputs[0]
+        else:
+            outputs = self.outputs
         assert isinstance(outputs, paddle.Tensor)
         if scaler is not None:
             paddle.autograd.backward(scaler.scale(outputs))
@@ -117,7 +140,12 @@ def backward(self, output_grad=None, scaler=None):
             ), f"{len(outputs)} of {type(outputs[0])} vs {len(output_grad)} of {type(output_grad[0])}"
             paddle.autograd.backward(outputs, output_grad)
 
-        grad = tuple([e.grad if e is not None else None for e in self.inputs])
+
+        if not isinstance(self.inputs, (tuple, list)):
+            inputs = (self.inputs,)
+        else:
+            inputs = self.inputs
+        grad = tuple([e.grad if e is not None else None for e in inputs])
         self._reset_states()
 
         if len(grad) == 1:
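A quick check of the `detach_and_requires_grad` behavior after this change: it now accepts a bare tensor as well as (possibly nested) tuples and lists, preserves each tensor's `stop_gradient`, keeps `None` entries, and returns the same container kind it was given. This sketch assumes the function is imported from the patched `forward_backward_overlap_utils` module:

    import paddle

    x = paddle.ones([2, 2])
    x.stop_gradient = False

    d = detach_and_requires_grad(x)  # bare tensor in, bare tensor out
    assert isinstance(d, paddle.Tensor) and not d.stop_gradient

    dt = detach_and_requires_grad((x, None))  # tuple in, tuple out
    assert isinstance(dt, tuple) and dt[1] is None and not dt[0].stop_gradient
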