Conversation
@zhangyuqin1998 zhangyuqin1998 commented Apr 9, 2025

PR Category

Distributed Strategy

PR Types

New features

Description

This PR adds support for the forward_backward_overlap mode in VPP (virtual pipeline parallelism).

To enable forward-backward overlap in VPP, follow these steps:

  1. Implement the overlapped_forward_backward method on your custom PipelineLayer. The framework will then use your overlapped_forward_backward instead of the default pipeline forward and backward passes.
  2. Update the model configuration by setting "pipeline_parallel_config": "forward_backward_overlap_scheduler".
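As a minimal sketch of step 2, the scheduler can be requested through the configuration dict. Only the "pipeline_parallel_config" key and its value come from this PR; the surrounding dict and the helper function are illustrative assumptions, not Paddle API.

```python
# Hypothetical configuration sketch. The key/value pair below is from this
# PR; the dict structure and the helper are illustrative only.
config = {
    "pipeline_parallel_config": "forward_backward_overlap_scheduler",
}


def overlap_enabled(cfg):
    # Check whether the forward-backward overlap scheduler is requested.
    return "forward_backward_overlap_scheduler" in cfg.get(
        "pipeline_parallel_config", ""
    )


print(overlap_enabled(config))  # True
print(overlap_enabled({}))      # False
```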

[Diagram: VPP forward-backward overlap schedule (vp_dp_overlap drawio, page 19)]

Additionally, overlapped_forward_backward can be implemented as follows. To achieve better overlap, consider breaking forward_chunk and backward_chunk into more granular parts.

def overlapped_forward_backward(
    self,
    forward_chunk,  # the module of the forward chunk
    forward_inputs,
    forward_loss_fn_node,  # may be None when this rank computes no loss
    backward_chunk,  # the module of the backward chunk
    backward_loss_fn_node,  # may be None when this rank computes no loss
    backward_input_grads,
    scaler,
    p2p_async_handle=None,  # only used for p2p comm overlap
):
    # If this rank holds the loss node, seed the backward pass from it.
    if backward_loss_fn_node is not None:
        if scaler:
            backward_input_grads = backward_loss_fn_node.backward(scaler=scaler)
        else:
            backward_input_grads = backward_loss_fn_node.backward()

    # Wait for the forward inputs to arrive, then run the forward chunk.
    if p2p_async_handle is not None:
        p2p_async_handle.forward_handle_wait()
    forward_inputs = forward_chunk.forward(forward_inputs)

    # Launch the async send of the forward outputs, then wait for the
    # backward gradients so communication overlaps with computation.
    if p2p_async_handle is not None:
        p2p_async_handle.forward_async_comm(forward_inputs)
        p2p_async_handle.backward_handle_wait()

    backward_input_grads = backward_chunk.backward(backward_input_grads)

    if p2p_async_handle is not None:
        p2p_async_handle.backward_async_comm(backward_input_grads)

    # Compute the loss if this rank holds the loss node for the forward chunk.
    if forward_loss_fn_node is not None:
        forward_loss = forward_loss_fn_node.forward(forward_inputs)
    else:
        forward_loss = None

    return forward_inputs, forward_loss, backward_input_grads
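To illustrate the call pattern without any Paddle dependency, here is a minimal runnable mock: stub chunks stand in for the forward and backward PipelineLayer modules, and a simplified step omits the loss nodes, scaler, and p2p handle. All names here are hypothetical, for illustration only.

```python
# Minimal mock of the overlapped step, for illustration only. Real chunks
# are PipelineLayer modules and a real handle manages async p2p comm.
class StubChunk:
    def __init__(self, scale):
        self.scale = scale

    def forward(self, x):
        # Stand-in for the forward compute of one pipeline chunk.
        return x * self.scale

    def backward(self, grad):
        # Stand-in for the backward compute of one pipeline chunk.
        return grad * self.scale


def overlapped_step(forward_chunk, forward_inputs,
                    backward_chunk, backward_input_grads):
    # Simplified overlapped_forward_backward: forward of one chunk
    # followed by backward of another, no loss nodes or p2p handle.
    forward_outputs = forward_chunk.forward(forward_inputs)
    backward_input_grads = backward_chunk.backward(backward_input_grads)
    return forward_outputs, backward_input_grads


outs, grads = overlapped_step(StubChunk(2), 3, StubChunk(10), 5)
print(outs, grads)  # 6 50
```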

cp from #71995

…ePaddle#71995)

* [Distributed] Support forward_backward_overlap mode for VPP

* add

* fix name

paddle-bot bot commented Apr 9, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@ForFishes ForFishes merged commit 9ed6fec into PaddlePaddle:incubate/fleety_20250403 Apr 9, 2025
2 of 4 checks passed
@zhangyuqin1998 zhangyuqin1998 deleted the vpp_overlap_fleety branch May 9, 2025 06:18