
Support reduce-scatter when dp=tp#5988

Closed
EstherBear wants to merge 5 commits into sgl-project:main from EstherBear:feature/ds-opt

Conversation

@EstherBear

Motivation

This PR modifies the DeepSeek model to use reduce-scatter in place of separate all-reduce and scatter operations when data-parallel attention is enabled and dp == tp. This exposes opportunities to use optimized reduce-scatter kernels for performance gains in the future.

Modifications

  1. Modify the DeepSeek model to use reduce-scatter when data-parallel attention is enabled and dp == tp.
  2. Modify CudaGraphRunner and ForwardBatch to include the metadata needed by the reduce-scatter operations.
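Not part of the PR, but the identity it relies on can be checked numerically: when dp == tp, an all-reduce followed by taking the local dp_rank chunk produces the same result as a reduce-scatter. A minimal single-process sketch (plain torch, no distributed init; the rank loop simulates the TP group):

```python
import torch

def simulated_reduce_scatter(per_rank_inputs, rank):
    # Reduce-scatter semantics: rank r receives the element-wise sum
    # of every rank's r-th chunk.
    world = len(per_rank_inputs)
    chunks = [t.tensor_split(world)[rank] for t in per_rank_inputs]
    return torch.stack(chunks).sum(dim=0)

def allreduce_then_scatter(per_rank_inputs, rank):
    # The pattern this PR replaces: full all-reduce, then keep only
    # the local dp_rank slice.
    world = len(per_rank_inputs)
    reduced = torch.stack(per_rank_inputs).sum(dim=0)
    return reduced.tensor_split(world)[rank]

world = 4
inputs = [torch.randn(8, 3) for _ in range(world)]
for r in range(world):
    assert torch.allclose(
        simulated_reduce_scatter(inputs, r),
        allreduce_then_scatter(inputs, r),
    )
```

The fused collective communicates only 1/world of the data per rank, which is where the expected performance gain comes from.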

Checklist

@EstherBear
Author

Hi @ch-wan, is this ready to merge, or should I make any modifications?

@EstherBear
Author

Hi @ch-wan, any updates?

@ch-wan
Collaborator

ch-wan commented May 6, 2025

@EstherBear Sorry for the late response. I'm going to review it tonight.

if self.tp_size > 1:
    final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
if self.tp_size == self.dp_size:
    tensor_list = list(final_hidden_states.tensor_split(split_indices_cpu))
Collaborator


How about using Tensor.split() rather than Tensor.tensor_split()? This can reuse the info in forward_batch.global_num_tokens_gpu to make the code concise.
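To illustrate the reviewer's suggestion (the token counts below are made up for the example): Tensor.split() consumes per-rank sizes directly, which is the form forward_batch.global_num_tokens already stores, whereas Tensor.tensor_split() needs cut indices, forcing the caller to build a cumulative-sum index list first:

```python
import torch

hidden = torch.arange(20).reshape(10, 2)
num_tokens_per_rank = [3, 2, 5]  # illustrative per-rank token counts

# Tensor.split() takes the chunk sizes as-is.
by_split = hidden.split(num_tokens_per_rank)

# Tensor.tensor_split() takes cut positions: cumulative sums, last dropped.
indices = torch.tensor(num_tokens_per_rank).cumsum(0)[:-1]
by_tensor_split = hidden.tensor_split(indices)

# Both produce the same chunks; split() just skips the index bookkeeping.
for a, b in zip(by_split, by_tensor_split):
    assert torch.equal(a, b)
```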

return get_attention_tp_group().all_gather(input_, tensor_list=output_list)


def dp_reduce_scatter(
Collaborator

@ch-wan May 7, 2025


This function performs reduce-scatter across the global TP group. We may name it as tensor_model_parallel_reduce_scatter and place it to a different file.


I think just tp_reduce_scatter is sufficient.

if self.tp_size == self.dp_size:
    tensor_list = list(final_hidden_states.tensor_split(split_indices_cpu))
    final_hidden_states = tensor_list[self.dp_rank]
    dp_reduce_scatter(final_hidden_states, tensor_list)
Collaborator


tensor_model_parallel_reduce_scatter

],
hidden_states,
)
dp_copy(hidden_states, tmp_hidden_states, forward_batch)
Collaborator


If my understanding is correct, this line is doing the same thing as dp_scatter.
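For intuition on this point, a toy sketch (the real dp_scatter and dp_copy live in SGLang's data-parallel utilities and take a ForwardBatch; these standalone versions are illustrative only): scattering the global hidden states means each rank copies out its own token slice, which is exactly what an explicit per-rank copy of the same slice does:

```python
import torch

def toy_dp_scatter(local_out, global_hidden, num_tokens, rank):
    # dp_scatter-style: carve the global tensor by per-rank token
    # counts and copy this rank's slice into the local buffer.
    local_out.copy_(global_hidden.split(num_tokens)[rank])

def toy_dp_copy(local_out, global_hidden, num_tokens, rank):
    # The line under review: an explicit copy of the same slice.
    start = sum(num_tokens[:rank])
    local_out.copy_(global_hidden[start : start + num_tokens[rank]])

num_tokens = [3, 2, 5]  # illustrative per-rank token counts
global_hidden = torch.randn(sum(num_tokens), 4)
for rank, n in enumerate(num_tokens):
    a, b = torch.empty(n, 4), torch.empty(n, 4)
    toy_dp_scatter(a, global_hidden, num_tokens, rank)
    toy_dp_copy(b, global_hidden, num_tokens, rank)
    assert torch.equal(a, b)  # both produce this rank's token slice
```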

@ch-wan
Collaborator

ch-wan commented May 7, 2025

@EstherBear Thank you for your contribution. I have finished my review and left some comments. It appears that the PR does not pass some CI checks (e.g., https://github.com/sgl-project/sglang/actions/runs/14876617018/job/41775058885?pr=5988). Could you please fix them? Also, it is highly encouraged to share your benchmark results, so we can clearly see the performance gain from your PR. For example, you can run the following to compare efficiency:

python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2-Lite --disable-radix-cache --trust-remote-code --tp 8 --dp 8 --enable-dp-attention

python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 512 --random-input 1000 --random-output 1000 --random-range-ratio 1 --host 127.0.0.1 --port 30000 --max-concurrency 128

@ch-wan
Collaborator

ch-wan commented Aug 11, 2025

Done in #8539

@ch-wan ch-wan closed this Aug 11, 2025
