5 changes: 0 additions & 5 deletions paddle/fluid/memory/allocation/cuda_malloc_async_allocator.h
@@ -82,11 +82,6 @@ class CUDAMallocAsyncAllocator : public Allocator {
void SetDefaultStream(gpuStream_t stream);
void ClearFreeStream(bool sync = false);

~CUDAMallocAsyncAllocator() {
VLOG(0) << "Async allocator is freed " << (this)
<< " tid = " << std::this_thread::get_id();
}

protected:
void FreeImpl(phi::Allocation* allocation) override;
phi::Allocation* AllocateImpl(size_t size) override;
1 change: 1 addition & 0 deletions paddle/phi/backends/gpu/cuda/cuda_graph.cc
@@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/phi/backends/gpu/cuda/cuda_graph.h"
#include "glog/logging.h"
#include "paddle/common/flags.h"

#ifdef PADDLE_WITH_CUDA
3 changes: 0 additions & 3 deletions paddle/phi/backends/gpu/cuda/cuda_graph.h
@@ -27,8 +27,6 @@
#include <unordered_set>
#include <vector>

#include "glog/logging.h"

#include "paddle/common/errors.h"
#include "paddle/common/macros.h"
#include "paddle/phi/backends/context_pool.h"
@@ -65,7 +63,6 @@ class CUDAGraphContextManager {

DeviceContext *Get(int64_t pool_id, const Place &place, int stream_priority) {
std::lock_guard<std::mutex> lk(ctx_mtx_);
VLOG(6) << "Get cuda graph device context for " << place;

DeviceContextMap &ctxs = cuda_graph_ctx_pool_[pool_id];
if (ctxs.find(place) == ctxs.end()) {
105 changes: 95 additions & 10 deletions python/paddle/device/cuda/cuda_graphed_layer.py
@@ -13,6 +13,7 @@
# limitations under the License.

import logging
import os
from collections import deque
from enum import Enum

@@ -21,11 +22,59 @@

from .graphs import CUDAGraph

# CUDAGraphedLayer Debug tools
enable_debug_print = bool(
int(os.getenv('PADDLE_DEBUG_ENABLE_CUDAGRAPH_LAYER_LOGGING', '0'))
)
debug_cudagraphedlayer_fallback_to_default = bool(
int(os.getenv('PADDLE_DEBUG_CUDAGRAPHEDLAYER_FALLBACK_TO_DEFAULT', '0'))
)

logger = log_helper.get_logger(
__name__, logging.INFO, fmt='[%(levelname)s] %(message)s'
)


def debug_print(x):
if not enable_debug_print:
return
logger.info(x)


def print_tensor(
t,
name="Unnamed",
print_meta=True,
print_ptr=False,
print_hash=True,
hash=None,
):
output = []
if name:
output.append(name)
if hash is None:
hash = lambda t: float((t.astype('float32') * 1000).sum())

if t is None:
debug_print(f"{name} is None")
elif isinstance(t, paddle.Tensor):
if print_meta:
output.append(f"shape = {t.shape}")
output.append(f"place = {t.place}")
if print_ptr:
output.append(f"ptr = {hex(t.data_ptr())}")
if print_hash:
output.append(f"hash = {hash(t)}")
debug_print(" | ".join(output))


def printer(x, banner="printer"):
if not enable_debug_print:
return
debug_print(banner.center(100, "-"))
recursive_apply(print_tensor, x)


# We need this function, for any kind of inputs with iterables
# we recursively apply the function to the leave nodes
def recursive_apply(function, input_var):
@@ -58,8 +107,13 @@ def recursive_flatten(target):

def append(arg):
if isinstance(arg, paddle.Tensor):
if not arg.stop_gradient:
ret.append(arg)
# [NOTE] Sometimes unnecessary tensors, such as the constant `mask` tensor in the PP layer, are passed into subsequent layers.
# When a tensor is marked with `stop_gradient=True`, it does not contribute to gradient computation,
# suggesting it is unrelated to the main computational process.
# Therefore, we tried to skip copying such tensors to optimize performance:
# if not arg.stop_gradient:
# [NOTE] However, the `stop_gradient=True` propagation rules within the framework appear to be flawed, so filtering directly on stop_gradient may cause bugs.
ret.append(arg)

recursive_apply(append, target)
return ret
@@ -94,6 +148,7 @@ def __init__(self, num_warmup_steps):

self.has_recorded = False

self.has_preserved_inputs = False
self.args_static = None
self.kwargs_static = None

@@ -102,7 +157,21 @@ def __init__(self, num_warmup_steps):
self.outputs_static = None

def preserve_or_copy(self, args, kwargs):
if self.args_static is None:
"""
For CUDA Graphs, it is crucial that the input buffers remain address-stable,
meaning that the buffer addresses of any inputs to the captured graph must not change.
One way to achieve this is to preserve all input tensors.

This function attempts to recursively flatten the input arguments and keyword arguments
to identify all tensors passed to the layer (though it may still miss some due to other implicit
ways inputs can be passed to a layer). It then preserves references to these input tensors
as `self.inputs_static` so that the buffer pointers can be reused later.

When this method is called subsequently, it copies the values back to the preserved input tensors
to ensure the buffers are reused.
"""
if not self.has_preserved_inputs:
self.has_preserved_inputs = True
self.args_static = args
self.kwargs_static = kwargs
self.inputs_static = recursive_flatten_args_kwargs(
@@ -119,6 +188,9 @@ def record(self, f, *args, **kwargs):
self.graph.capture_begin()
self.outputs_static = f(*self.args_static, **self.kwargs_static)
self.graph.capture_end()
debug_print(
"[CUDAGraph] Record-Replay Start (Graph is replayed for the first time)"
)
self.graph.replay()

self.has_recorded = True
@@ -134,6 +206,7 @@ def replay(self, *args, **kwargs):

self.preserve_or_copy(args, kwargs)

debug_print("[CUDAGraph] Replay Start")
self.graph.replay()
return self.outputs_static

@@ -278,8 +351,12 @@ def forward(ctx, context, arg_tuple, *grad_inputs):
detached_grad_inputs = recursive_flatten_args_kwargs(args, kwargs)
inputs = (grad_inputs, detached_grad_inputs)

if context.is_warmup_step():
logger.debug("[CUDAGraph] Forward Step (Default)")
printer(detached_grad_inputs, "Forward input")
if (
context.is_warmup_step()
or debug_cudagraphedlayer_fallback_to_default
):
debug_print("[CUDAGraph] Forward Step (Default)")

with paddle.enable_grad():
y = context.layer(*args, **kwargs)
@@ -289,7 +366,7 @@ def forward(ctx, context, arg_tuple, *grad_inputs):
graph = context.get_graph()
if graph.is_record_step():
# In record step, record the forward pass in CUDA graph
logger.info("[CUDAGraph] Forward Step (Record)")
debug_print(f"[CUDAGraph] Forward Step (Record) id {id(graph)}")

def forward(*args, **kwargs):
with paddle.enable_grad():
@@ -301,14 +378,17 @@ def forward(*args, **kwargs):
(CUDAGraphLayerStatus.RECORD, graph, inputs, y)
)
else:
logger.debug(f"[CUDAGraph] Forward Step (Graph - {id(graph)})")
debug_print(f"[CUDAGraph] Forward Step (Graph) id {id(graph)}")
y = graph.forward_graph.replay(*args, **kwargs)

context.push_data(
(CUDAGraphLayerStatus.CUDAGRAPH, graph, None, y)
)

debug_print("[CUDAGraph] Forward Step End")

ctx.save_for_backward(context)
printer(y, "Forward output")
return detach(y)

@staticmethod
@@ -322,16 +402,18 @@ def backward(ctx, *dys):
(status, graph, inputs, ys) = context.pop_data()
y, dy = select_y_with_grad(ys, dys)

printer((y, dy), "Backward input")

if status == CUDAGraphLayerStatus.WARMUP:
logger.debug("[CUDAGraph] Backward Step (Default)")
debug_print("[CUDAGraph] Backward Step (Default)")

# In warmup step, perform standard backward operation
y.backward(dy)
args_grad = get_args_grad(inputs)

context.warmup_step()
elif status == CUDAGraphLayerStatus.RECORD:
logger.info("[CUDAGraph] Backward Step (Record)")
debug_print(f"[CUDAGraph] Backward Step (Record) id {id(graph)}")

# In record step, record the backward pass in CUDA graph
def backward(y, dy):
@@ -347,14 +429,17 @@ def backward(y, dy):

context.reuse_graph(graph)
elif status == CUDAGraphLayerStatus.CUDAGRAPH:
logger.debug(f"[CUDAGraph] Backward Step (Graph) - {id(graph)}")
debug_print(f"[CUDAGraph] Backward Step (Graph) id {id(graph)}")

# In CUDA graph step, replay the recorded graph for backward pass
args_grad = graph.backward_graph.replay(y, dy)
context.reuse_graph(graph)
else:
raise RuntimeError("Unknown cuda graph status")

debug_print("[CUDAGraph] Backward Step End")

printer(args_grad, "Backward output")
return args_grad


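For readers who want to try the new debug switches, here is a minimal usage sketch. The environment variable names are the ones introduced in this file; the import path and the `CUDAGraphedLayer(layer, num_warmup_steps=...)` constructor signature are assumptions inferred from the module location and the `__init__` shown above, not something this diff confirms.

```python
import os

# Both switches are read when cuda_graphed_layer.py is imported, so set them first.
os.environ["PADDLE_DEBUG_ENABLE_CUDAGRAPH_LAYER_LOGGING"] = "1"       # enable debug_print/printer output
os.environ["PADDLE_DEBUG_CUDAGRAPHEDLAYER_FALLBACK_TO_DEFAULT"] = "0" # keep graph capture enabled

import paddle
# Assumed import path, based on python/paddle/device/cuda/cuda_graphed_layer.py.
from paddle.device.cuda.cuda_graphed_layer import CUDAGraphedLayer

paddle.set_device("gpu")
layer = paddle.nn.Linear(64, 64)
graphed = CUDAGraphedLayer(layer, num_warmup_steps=3)  # constructor signature assumed

x = paddle.randn([8, 64])
x.stop_gradient = False
for _ in range(5):
    y = graphed(x)       # warmup steps run eagerly; afterwards the forward pass is captured and replayed
    y.sum().backward()   # the backward pass is captured into its own graph on the record step
```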
4 changes: 3 additions & 1 deletion python/paddle/distributed/fleet/launch_utils.py
@@ -1744,7 +1744,9 @@ def start_pod_worker(self, args, pod):

if args.log_dir is not None:
os.makedirs(args.log_dir, exist_ok=True)
fn = open("%s/workerlog.%d" % (args.log_dir, idx), "w")
fn = open(
"%s/workerlog.%d" % (args.log_dir, cur_worker.rank), "w"
)
self.log_fns["worker"].append(fn)
proc = subprocess.Popen(
cmd, env=current_env, stdout=fn, stderr=fn
2 changes: 2 additions & 0 deletions python/paddle/distributed/fleet/meta_parallel/__init__.py
@@ -26,8 +26,10 @@
)
from .pipeline_parallel import ( # noqa: F401
PipelineParallel,
PipelineParallelMicroStepLocations,
PipelineParallelWithInterleave,
PipelineParallelWithInterleaveFthenB,
register_global_pipeline_parallel_hook,
)
from .segment_parallel import SegmentParallel # noqa: F401
from .sharding_parallel import ShardingParallel # noqa: F401
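As background for the `preserve_or_copy` docstring in cuda_graphed_layer.py above: the address-stability requirement it describes is the usual CUDA graph idiom of capturing work against fixed "static" tensors and copying fresh values into them before each replay. Below is a conceptual sketch of that idiom using the same `CUDAGraph` class the layer imports (`capture_begin`, `capture_end`, and `replay` all appear in this diff); real capture additionally requires a CUDA-graph-capable allocator and stream setup that is not shown here.

```python
import paddle
from paddle.device.cuda.graphs import CUDAGraph  # same class imported by cuda_graphed_layer.py

paddle.set_device("gpu")

# Static (address-stable) input buffer; the captured graph bakes in its address.
x_static = paddle.randn([8, 64])

graph = CUDAGraph()
graph.capture_begin()
y_static = (x_static * 2.0).sum()   # placeholder computation captured into the graph
graph.capture_end()
graph.replay()                       # first replay, mirroring record()

# On later steps, write new values into the preserved buffer instead of rebinding the
# Python name, so the address baked into the graph stays valid, then replay again.
x_new = paddle.randn([8, 64])
paddle.assign(x_new, output=x_static)
graph.replay()                       # y_static now reflects x_new
```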