diff --git a/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.h b/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.h index 85762a03aebf93..6f36a772b0be16 100644 --- a/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.h +++ b/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.h @@ -82,11 +82,6 @@ class CUDAMallocAsyncAllocator : public Allocator { void SetDefaultStream(gpuStream_t stream); void ClearFreeStream(bool sync = false); - ~CUDAMallocAsyncAllocator() { - VLOG(0) << "Async allocator is freed " << (this) - << " tid = " << std::this_thread::get_id(); - } - protected: void FreeImpl(phi::Allocation* allocation) override; phi::Allocation* AllocateImpl(size_t size) override; diff --git a/paddle/phi/backends/gpu/cuda/cuda_graph.cc b/paddle/phi/backends/gpu/cuda/cuda_graph.cc index 71181263c26a51..6c0c373dc8a6cf 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_graph.cc +++ b/paddle/phi/backends/gpu/cuda/cuda_graph.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/phi/backends/gpu/cuda/cuda_graph.h" +#include "glog/logging.h" #include "paddle/common/flags.h" #ifdef PADDLE_WITH_CUDA diff --git a/paddle/phi/backends/gpu/cuda/cuda_graph.h b/paddle/phi/backends/gpu/cuda/cuda_graph.h index f794cf0fa65366..538b4a52d03ac4 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_graph.h +++ b/paddle/phi/backends/gpu/cuda/cuda_graph.h @@ -27,8 +27,6 @@ #include #include -#include "glog/logging.h" - #include "paddle/common/errors.h" #include "paddle/common/macros.h" #include "paddle/phi/backends/context_pool.h" @@ -65,7 +63,6 @@ class CUDAGraphContextManager { DeviceContext *Get(int64_t pool_id, const Place &place, int stream_priority) { std::lock_guard lk(ctx_mtx_); - VLOG(6) << "Get cuda graph device context for " << place; DeviceContextMap &ctxs = cuda_graph_ctx_pool_[pool_id]; if (ctxs.find(place) == ctxs.end()) { diff --git a/python/paddle/device/cuda/cuda_graphed_layer.py b/python/paddle/device/cuda/cuda_graphed_layer.py index d2dee9b3422a78..069d8a6c9476c4 100644 --- a/python/paddle/device/cuda/cuda_graphed_layer.py +++ b/python/paddle/device/cuda/cuda_graphed_layer.py @@ -13,6 +13,7 @@ # limitations under the License. 
import logging +import os from collections import deque from enum import Enum @@ -21,11 +22,59 @@ from .graphs import CUDAGraph +# CUDAGraphedLayer Debug tools +enable_debug_print = bool( + int(os.getenv('PADDLE_DEBUG_ENABLE_CUDAGRAPH_LAYER_LOGGING', '0')) +) +debug_cudagraphedlayer_fallback_to_default = bool( + int(os.getenv('PADDLE_DEBUG_CUDAGRAPHEDLAYER_FALLBACK_TO_DEFAULT', '0')) +) + logger = log_helper.get_logger( __name__, logging.INFO, fmt='[%(levelname)s] %(message)s' ) +def debug_print(x): + if not enable_debug_print: + return + logger.info(x) + + +def print_tensor( + t, + name="Unnamed", + print_meta=True, + print_ptr=False, + print_hash=True, + hash=None, +): + output = [] + if name: + output.append(name) + if hash is None: + hash = lambda t: float((t.astype('float32') * 1000).sum()) + + if t is None: + debug_print(f"{name} is None") + elif isinstance(t, paddle.Tensor): + if print_meta: + output.append(f"shape = {t.shape}") + output.append(f"place = {t.place}") + if print_ptr: + output.append(f"ptr = {hex(t.data_ptr())}") + if print_hash: + output.append(f"hash = {hash(t)}") + debug_print(" | ".join(output)) + + +def printer(x, banner="printer"): + if not enable_debug_print: + return + debug_print(banner.center(100, "-")) + recursive_apply(print_tensor, x) + + # We need this function, for any kind of inputs with iterables # we recursively apply the function to the leave nodes def recursive_apply(function, input_var): @@ -58,8 +107,13 @@ def recursive_flatten(target): def append(arg): if isinstance(arg, paddle.Tensor): - if not arg.stop_gradient: - ret.append(arg) + # [NOTE] Sometimes unnecessary tensors, such as the constant `mask` tensor in the PP layer, are passed into subsequent layers. + # When a tensor is marked with `stop_gradient=True`, it indicates that it does not contribute to gradient calculations, + # suggesting it's unrelated to the main computational process. + # Therefore, we tried to skip copying such tensors to optimize performance. + # if not arg.stop_gradient: + # [NOTE] However, the `stop_gradient=True` propagation rules within the framework appear to be flawed, so filtering directly on stop_gradient may cause bugs; all tensors are kept for now. + ret.append(arg) recursive_apply(append, target) return ret @@ -94,6 +148,7 @@ def __init__(self, num_warmup_steps): self.has_recorded = False + self.has_preserved_inputs = False self.args_static = None self.kwargs_static = None @@ -102,7 +157,21 @@ def __init__(self, num_warmup_steps): self.outputs_static = None def preserve_or_copy(self, args, kwargs): - if self.args_static is None: + """ + For the CUDA Graph, it is crucial that the buffers remain address-stable, + meaning that the buffer addresses of any inputs to the CUDA Graph must not change. + One solution to achieve this is to preserve all input tensors. + + This function attempts to recursively flatten the input arguments and keyword arguments + to identify all tensors passed to the layer (though it may still miss some due to other implicit + ways inputs can be passed to a layer). It then preserves references to these input tensors + as `self.inputs_static` so that the buffer pointers can be reused later. + + On subsequent calls, it copies the new values into the preserved input tensors + to ensure the buffers are reused.
+ """ + if not self.has_preserved_inputs: + self.has_preserved_inputs = True self.args_static = args self.kwargs_static = kwargs self.inputs_static = recursive_flatten_args_kwargs( @@ -119,6 +188,9 @@ def record(self, f, *args, **kwargs): self.graph.capture_begin() self.outputs_static = f(*self.args_static, **self.kwargs_static) self.graph.capture_end() + debug_print( + "[CUDAGraph] Record-Replay Start (Graph is replayed for the first time)" + ) self.graph.replay() self.has_recorded = True @@ -134,6 +206,7 @@ def replay(self, *args, **kwargs): self.preserve_or_copy(args, kwargs) + debug_print("[CUDAGraph] Replay Start") self.graph.replay() return self.outputs_static @@ -278,8 +351,12 @@ def forward(ctx, context, arg_tuple, *grad_inputs): detached_grad_inputs = recursive_flatten_args_kwargs(args, kwargs) inputs = (grad_inputs, detached_grad_inputs) - if context.is_warmup_step(): - logger.debug("[CUDAGraph] Forward Step (Default)") + printer(detached_grad_inputs, "Forward input") + if ( + context.is_warmup_step() + or debug_cudagraphedlayer_fallback_to_default + ): + debug_print("[CUDAGraph] Forward Step (Default)") with paddle.enable_grad(): y = context.layer(*args, **kwargs) @@ -289,7 +366,7 @@ def forward(ctx, context, arg_tuple, *grad_inputs): graph = context.get_graph() if graph.is_record_step(): # In record step, record the forward pass in CUDA graph - logger.info("[CUDAGraph] Forward Step (Record)") + debug_print(f"[CUDAGraph] Forward Step (Record) id {id(graph)}") def forward(*args, **kwargs): with paddle.enable_grad(): @@ -301,14 +378,17 @@ def forward(*args, **kwargs): (CUDAGraphLayerStatus.RECORD, graph, inputs, y) ) else: - logger.debug(f"[CUDAGraph] Forward Step (Graph - {id(graph)})") + debug_print(f"[CUDAGraph] Forward Step (Graph) id {id(graph)}") y = graph.forward_graph.replay(*args, **kwargs) context.push_data( (CUDAGraphLayerStatus.CUDAGRAPH, graph, None, y) ) + debug_print("[CUDAGraph] Forward Step End") + ctx.save_for_backward(context) + printer(y, "Forward output") return detach(y) @staticmethod @@ -322,8 +402,10 @@ def backward(ctx, *dys): (status, graph, inputs, ys) = context.pop_data() y, dy = select_y_with_grad(ys, dys) + printer((y, dy), "Backward input") + if status == CUDAGraphLayerStatus.WARMUP: - logger.debug("[CUDAGraph] Backward Step (Default)") + debug_print("[CUDAGraph] Backward Step (Default)") # In warmup step, perform standard backward operation y.backward(dy) @@ -331,7 +413,7 @@ def backward(ctx, *dys): context.warmup_step() elif status == CUDAGraphLayerStatus.RECORD: - logger.info("[CUDAGraph] Backward Step (Record)") + debug_print(f"[CUDAGraph] Backward Step (Record) id {id(graph)}") # In record step, record the backward pass in CUDA graph def backward(y, dy): @@ -347,7 +429,7 @@ def backward(y, dy): context.reuse_graph(graph) elif status == CUDAGraphLayerStatus.CUDAGRAPH: - logger.debug(f"[CUDAGraph] Backward Step (Graph) - {id(graph)}") + debug_print(f"[CUDAGraph] Backward Step (Graph) id {id(graph)}") # In CUDA graph step, replay the recorded graph for backward pass args_grad = graph.backward_graph.replay(y, dy) @@ -355,6 +437,9 @@ def backward(y, dy): else: raise RuntimeError("Unknown cuda graph status") + debug_print("[CUDAGraph] Backward Step End") + + printer(args_grad, "Backward output") return args_grad diff --git a/python/paddle/distributed/fleet/launch_utils.py b/python/paddle/distributed/fleet/launch_utils.py index c2bd3b811428ed..801c997ab558b4 100755 --- a/python/paddle/distributed/fleet/launch_utils.py +++ 
b/python/paddle/distributed/fleet/launch_utils.py @@ -1744,7 +1744,9 @@ def start_pod_worker(self, args, pod): if args.log_dir is not None: os.makedirs(args.log_dir, exist_ok=True) - fn = open("%s/workerlog.%d" % (args.log_dir, idx), "w") + fn = open( + "%s/workerlog.%d" % (args.log_dir, cur_worker.rank), "w" + ) self.log_fns["worker"].append(fn) proc = subprocess.Popen( cmd, env=current_env, stdout=fn, stderr=fn diff --git a/python/paddle/distributed/fleet/meta_parallel/__init__.py b/python/paddle/distributed/fleet/meta_parallel/__init__.py index bdf76262157d40..f7b9a5c37b2e66 100644 --- a/python/paddle/distributed/fleet/meta_parallel/__init__.py +++ b/python/paddle/distributed/fleet/meta_parallel/__init__.py @@ -26,8 +26,10 @@ ) from .pipeline_parallel import ( # noqa: F401 PipelineParallel, + PipelineParallelMicroStepLocations, PipelineParallelWithInterleave, PipelineParallelWithInterleaveFthenB, + register_global_pipeline_parallel_hook, ) from .segment_parallel import SegmentParallel # noqa: F401 from .sharding_parallel import ShardingParallel # noqa: F401 diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index e6eab31cfbf446..f7723988f2f830 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -16,6 +16,8 @@ import time import warnings from collections import defaultdict +from enum import Enum +from typing import Callable, Dict, List import paddle from paddle import framework @@ -146,6 +148,84 @@ def _check_data_valid(self, data): ) +# Enum for specifying the pipeline parallel micro-step locations. +class PipelineParallelMicroStepLocations(Enum): + FORWARD_BEGIN = 'forward_begin' + FORWARD_END = 'forward_end' + BACKWARD_BEGIN = 'backward_begin' + BACKWARD_END = 'backward_end' + + +# A callback class for managing hooks at different stages of a pipeline parallel process. +class PipelineParallelMicroStepCallback: + def __init__(self): + # Initializes a dictionary to store hooks for each micro-step location in the pipeline. + self.hooks: Dict[PipelineParallelMicroStepLocations, List[Callable]] = { + PipelineParallelMicroStepLocations.FORWARD_BEGIN: [], + PipelineParallelMicroStepLocations.FORWARD_END: [], + PipelineParallelMicroStepLocations.BACKWARD_BEGIN: [], + PipelineParallelMicroStepLocations.BACKWARD_END: [], + } + + def register_hook( + self, location: PipelineParallelMicroStepLocations, hook: Callable + ): + """ + Registers a hook function to be called at a specified pipeline parallel micro-step location. + + Args: + location (PipelineParallelMicroStepLocations): The micro-step location where the hook should be registered. + hook (Callable): The hook function to be registered. The function should accept the following optional keyword arguments: + - input_tensor (paddle.Tensor): The input tensor to the current micro-step. + - output_tensor (paddle.Tensor): The output tensor from the current micro-step. + - input_tensor_grad (paddle.Tensor): The gradient of the input tensor. + - output_tensor_grad (paddle.Tensor): The gradient of the output tensor. + - step_id (int, optional): The index of the current micro-step in the pipeline. + + Raises: + AssertionError: If the specified location is not a valid micro-step location. + """ + assert ( + location in self.hooks + ), f"Invalid location '{location}'. Valid locations are 'forward_begin', 'forward_end', 'backward_begin', or 'backward_end'."
+ self.hooks[location].append(hook) + + def on_location( + self, location: PipelineParallelMicroStepLocations, **kwargs + ): + """ + Triggers all registered hooks at a specified pipeline parallel micro-step location. + + Args: + location (PipelineParallelMicroStepLocations): The micro-step location where the hooks should be triggered. + kwargs: Additional keyword arguments to be passed to the hook functions. + + Raises: + AssertionError: If the specified location is not a valid micro-step location. + """ + assert ( + location in self.hooks + ), f"Invalid location '{location}'. Valid locations are 'forward_begin', 'forward_end', 'backward_begin', or 'backward_end'." + for hook in self.hooks[location]: + hook(**kwargs) + + +pipeline_parallel_callbacks_ = PipelineParallelMicroStepCallback() + + +# It is typically very difficult for us to directly access the PipelineParallel object. +# Users may use fleet.distributed_model to wrap a model into a pipeline parallel model (PP model). +# We may not have access to the wrapped model when we want to register hooks, for example, when using PaddleNLP trainer to wrap around the PP model. +# Additionally, we usually have only one `PipelineParallel` model, so the callbacks are registered globally. +def register_global_pipeline_parallel_hook( + location: PipelineParallelMicroStepLocations, hook: Callable +): + """ + Registering global hooks for pipeline parallelism. + """ + pipeline_parallel_callbacks_.register_hook(location, hook) + + class PipelineParallel(MetaParallelBase): def __init__(self, layers, hcg, strategy): if not isinstance(layers, PipelineLayer): @@ -299,6 +379,7 @@ def __init__(self, layers, hcg, strategy): self.loss_fn_idx = 0 self._compute_loss = True + self.callbacks = pipeline_parallel_callbacks_ logger.info( f"Pipeline Info -- num_stages: {self.num_stages}, stage_id: {self.stage_id}" @@ -325,6 +406,11 @@ def __init__(self, layers, hcg, strategy): self._layers, self.dp_group, self.accumulate_steps, True ) + def register_hook( + self, location: PipelineParallelMicroStepLocations, hook: Callable + ): + self.callbacks.register_hook(location, hook) + def is_pipeline_first_stage(self, ignore_virtual=False): if not ignore_virtual: if self._virtual_pp_world_size is not None: @@ -508,7 +594,9 @@ def forward_backward_pipeline( ) self._record_stamp("F", step_id, '"B"', self._forward_color) - output_tensor = self._forward_step(input_tensor, micro_dataset) + output_tensor = self._forward_step( + input_tensor, micro_dataset, step_id=step_id + ) self._record_stamp("F", step_id, '"E"', self._forward_color) self._p2p_helper.send_forward( output_tensor, @@ -540,7 +628,9 @@ def forward_backward_pipeline( self._record_stamp( "F", startup_steps + i, '"B"', self._forward_color ) - output_tensor = self._forward_step(input_tensor, micro_dataset) + output_tensor = self._forward_step( + input_tensor, micro_dataset, step_id=startup_steps + i + ) self._record_stamp( "F", startup_steps + i, '"E"', self._forward_color ) @@ -563,7 +653,7 @@ def forward_backward_pipeline( self._record_stamp("B", i, '"B"', self._backward_color) input_tensor_grad = self._backward_step( - input_tensor, output_tensor, output_tensor_grad + input_tensor, output_tensor, output_tensor_grad, step_id=i ) self._record_stamp("B", i, '"E"', self._backward_color) @@ -598,7 +688,10 @@ def forward_backward_pipeline( "B", steady_steps + i, '"B"', self._backward_color ) input_tensor_grad = self._backward_step( - input_tensor, output_tensor, output_tensor_grad + input_tensor, + output_tensor, + 
output_tensor_grad, + step_id=steady_steps + i, ) self._record_stamp( "B", steady_steps + i, '"E"', self._backward_color @@ -758,7 +851,9 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0): self.is_pipeline_first_stage() ) - output_tensor = self._forward_step(input_tensor, micro_dataset) + output_tensor = self._forward_step( + input_tensor, micro_dataset, step_id=None + ) self._p2p_helper.send_forward( output_tensor, self.is_pipeline_last_stage(), @@ -776,7 +871,9 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0): for i in range(steady_steps): last_iter = i == (steady_steps - 1) - output_tensor = self._forward_step(input_tensor, micro_dataset) + output_tensor = self._forward_step( + input_tensor, micro_dataset, step_id=None + ) self._p2p_helper.send_forward( output_tensor, self.is_pipeline_last_stage(), @@ -798,7 +895,9 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0): return self.train_loss - def _forward_step(self, input_tensor, micro_dataset, chunk_id=None): + def _forward_step( + self, input_tensor, micro_dataset, chunk_id=None, step_id=None + ): if self._enable_timer: self.timers("forward_step").start() if self.is_pipeline_first_stage(): @@ -807,7 +906,18 @@ def _forward_step(self, input_tensor, micro_dataset, chunk_id=None): assert chunk_id is None or isinstance(chunk_id, int) + self.callbacks.on_location( + PipelineParallelMicroStepLocations.FORWARD_BEGIN, + input_tensor=input_tensor, + step_id=step_id, + ) output_tensor = self._layers.forward(input_tensor, chunk_id=chunk_id) + self.callbacks.on_location( + PipelineParallelMicroStepLocations.FORWARD_END, + input_tensor=input_tensor, + output_tensor=output_tensor, + step_id=step_id, + ) if self.is_pipeline_last_stage(): # train calculate loss for train @@ -849,10 +959,19 @@ def _forward_step(self, input_tensor, micro_dataset, chunk_id=None): return backward_loss_tensor return output_tensor - def _backward_step(self, input_tensor, output_tensor, output_tensor_grad): + def _backward_step( + self, input_tensor, output_tensor, output_tensor_grad, step_id=None + ): if self._enable_timer: self.timers("backward_step").start() with paddle.amp.auto_cast(enable=False): + self.callbacks.on_location( + PipelineParallelMicroStepLocations.BACKWARD_BEGIN, + input_tensor=input_tensor, + output_tensor=output_tensor, + output_tensor_grad=output_tensor_grad, + step_id=step_id, + ) if self.is_pipeline_last_stage(): assert output_tensor_grad is None if self.scaler: @@ -883,6 +1002,14 @@ def _backward_step(self, input_tensor, output_tensor, output_tensor_grad): input_tensor_grad = input_tensor.grad if self._enable_timer: self.timers("backward_step").stop() + self.callbacks.on_location( + PipelineParallelMicroStepLocations.BACKWARD_END, + input_tensor=input_tensor, + output_tensor=output_tensor, + input_tensor_grad=input_tensor_grad, + output_tensor_grad=output_tensor_grad, + step_id=step_id, + ) return input_tensor_grad def _check_micro_batch_data_valid(self, micro_batch_data): @@ -1131,7 +1258,7 @@ def _forward_step_helper(self, micro_dataset, micro_step): ) input_tensor = self.input_tensors[virtual_pp_rank][-1] output_tensor = self._forward_step( - input_tensor, micro_dataset, virtual_pp_rank + input_tensor, micro_dataset, virtual_pp_rank, step_id=micro_step ) self.output_tensors[virtual_pp_rank].append(output_tensor) @@ -1195,7 +1322,7 @@ def _backward_step_helper(self, micro_step): output_tensor = self.output_tensors[virtual_pp_rank].pop(0) output_tensor_grad = 
self.output_tensor_grads[virtual_pp_rank].pop(0) input_tensor_grad = self._backward_step( - input_tensor, output_tensor, output_tensor_grad + input_tensor, output_tensor, output_tensor_grad, step_id=micro_step ) self._overlap_comm_grads() diff --git a/python/paddle/distributed/launch/controllers/collective.py b/python/paddle/distributed/launch/controllers/collective.py index d76c3349806c18..ddfa8a9e82afde 100644 --- a/python/paddle/distributed/launch/controllers/collective.py +++ b/python/paddle/distributed/launch/controllers/collective.py @@ -148,7 +148,7 @@ def _build_pod_with_args(self): else: e.update({'PADDLE_DISTRI_BACKEND': 'gloo'}) - log_file = f"workerlog.{i}" + log_file = f"workerlog.{i + rank_offset}" self.add_container(envs=e, log_file=log_file) return True @@ -251,7 +251,7 @@ def _build_pod_with_master(self, reset_pod=True): e.update({'PADDLE_DISTRI_BACKEND': 'gloo'}) # log_file = "{}.{}.{}.log".format(self.job.id, self.pod.name, i) - log_file = f"workerlog.{i}" + log_file = f"workerlog.{i + rank_offset}" self.add_container(envs=e, log_file=log_file) return True diff --git a/python/paddle/distributed/launch/controllers/ps.py b/python/paddle/distributed/launch/controllers/ps.py index 0160625c1061fb..276faf196daf97 100644 --- a/python/paddle/distributed/launch/controllers/ps.py +++ b/python/paddle/distributed/launch/controllers/ps.py @@ -111,7 +111,7 @@ def _build_pod_with_args(self): "POD_IP": self.ctx.node.ip, } e.update(_gloo_envs) - log_file = f"workerlog.{i}" + log_file = f"workerlog.{i + trainer_rank_offset}" self.add_container(envs=e, log_file=log_file) def _build_pod_with_master(self): @@ -214,7 +214,7 @@ def _build_pod_with_master(self): "POD_IP": self.ctx.node.ip, } e.update(_gloo_envs) - log_file = f"workerlog.{i}" + log_file = f"workerlog.{i + trainer_rank_offset}" self.add_container(envs=e, log_file=log_file) ''' NEW VERSION for i in range(server_num): diff --git a/python/paddle/distributed/utils/launch_utils.py b/python/paddle/distributed/utils/launch_utils.py index 51a9ebb8a4dcf9..fa1deedc0bbda1 100644 --- a/python/paddle/distributed/utils/launch_utils.py +++ b/python/paddle/distributed/utils/launch_utils.py @@ -473,7 +473,7 @@ def start_local_trainers( fn = None if log_dir is not None: os.makedirs(log_dir, exist_ok=True) - fn = open("%s/workerlog.%d" % (log_dir, idx), "a") + fn = open("%s/workerlog.%d" % (log_dir, t.rank), "a") proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn) else: proc = subprocess.Popen(cmd, env=current_env)
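
Usage note (not part of the patch): the sketch below shows how the global pipeline-parallel micro-step hook API added in `pipeline_parallel.py` and re-exported from `paddle.distributed.fleet.meta_parallel` could be exercised. Only the imports, the location enum values, the hook keyword arguments, and `register_global_pipeline_parallel_hook` come from this change; the hook names and the logging they perform are illustrative assumptions.

```python
# Minimal sketch, assuming the exports added in this diff:
#   PipelineParallelMicroStepLocations, register_global_pipeline_parallel_hook
from paddle.distributed.fleet.meta_parallel import (
    PipelineParallelMicroStepLocations,
    register_global_pipeline_parallel_hook,
)


def on_forward_end(input_tensor=None, output_tensor=None, step_id=None, **kwargs):
    # Hooks are invoked with keyword arguments only; unused ones are absorbed by **kwargs.
    print(f"[hook] forward_end, micro-step {step_id}")


def on_backward_end(
    input_tensor=None,
    output_tensor=None,
    input_tensor_grad=None,
    output_tensor_grad=None,
    step_id=None,
    **kwargs,
):
    print(f"[hook] backward_end, micro-step {step_id}")


# Hooks are registered globally, so this works even when the PipelineParallel
# object is only created later (e.g. inside fleet.distributed_model or a trainer wrapper).
register_global_pipeline_parallel_hook(
    PipelineParallelMicroStepLocations.FORWARD_END, on_forward_end
)
register_global_pipeline_parallel_hook(
    PipelineParallelMicroStepLocations.BACKWARD_END, on_backward_end
)
```

Separately, the CUDAGraphedLayer debug printing added in `cuda_graphed_layer.py` is gated by the environment variables `PADDLE_DEBUG_ENABLE_CUDAGRAPH_LAYER_LOGGING` and `PADDLE_DEBUG_CUDAGRAPHEDLAYER_FALLBACK_TO_DEFAULT`; both are read once at module import time, so they must be set in the environment before the module is imported.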