
[Kernel] Replaced blockReduce[...] functions with cub::BlockReduce #7233

Merged: 9 commits merged into vllm-project:main on Aug 22, 2024

Conversation

ProExpertProg
Contributor

@ProExpertProg ProExpertProg commented Aug 7, 2024

Replace all uses of the custom (and buggy) blockReduce function with the idiomatic cub::BlockReduce.

I also removed the runtime dispatching for block sizes per @LucasWilkinson's suggestion.
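
For reference, here is a minimal sketch of the cub::BlockReduce idiom this PR adopts, in the shape of an RMS-norm-style kernel (illustrative names only, not the actual vLLM kernel code):

#include <cub/cub.cuh>

// Sketch: block-wide sum of squares via cub::BlockReduce, then normalization.
// NUM_THREADS is a compile-time constant, so there is no runtime dispatching
// over block sizes, and the temporary shared memory is allocated explicitly.
template <int NUM_THREADS>
__global__ void rms_norm_sketch(const float* __restrict__ in,
                                float* __restrict__ out,
                                int hidden_size, float eps) {
  using BlockReduce = cub::BlockReduce<float, NUM_THREADS>;
  __shared__ typename BlockReduce::TempStorage tmp;
  __shared__ float s_scale;

  // Each thread accumulates a partial sum of squares over its slice of the row.
  float sq = 0.0f;
  for (int i = threadIdx.x; i < hidden_size; i += NUM_THREADS) {
    float x = in[blockIdx.x * hidden_size + i];
    sq += x * x;
  }
  float total = BlockReduce(tmp).Sum(sq);  // result is valid in thread 0 only
  if (threadIdx.x == 0) {
    s_scale = rsqrtf(total / hidden_size + eps);
  }
  __syncthreads();
  for (int i = threadIdx.x; i < hidden_size; i += NUM_THREADS) {
    out[blockIdx.x * hidden_size + i] = in[blockIdx.x * hidden_size + i] * s_scale;
  }
}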

Here are the results of the layernorm and quant microbenchmarks.

The summary is that the cub version is 0.2% faster (geomean) for layernorm and 1.3% slower for quant, but that is most likely just noise: different shapes and dtypes produce vastly different ratios between the cub and main runtimes, with no apparent pattern.

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain code quality and improves the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title should be prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] for changes to the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR needs to meet the following code quality standards:

  • We adhere to the Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code needs to be well-documented to ensure future contributors can easily understand it.
  • Include sufficient tests to ensure the project stays correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies user-facing behavior of vLLM. This helps vLLM users understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag the PR with rfc-required and may not review it.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and to make sure no contributor feels confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide a status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!


github-actions bot commented Aug 7, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which consists of a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of the default ones by unblocking the steps in your fast-check build on the Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@ProExpertProg ProExpertProg marked this pull request as ready for review August 7, 2024 00:06
@ProExpertProg
Contributor Author

/ready

@github-actions github-actions bot added the ready (ONLY add when PR is ready to merge/full CI is needed) label Aug 7, 2024
@ProExpertProg
Contributor Author

cc @WoosukKwon

Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Nice cleanup, thanks!

@tlrmchlsmth
Collaborator

@ProExpertProg could you check if the kernels-test-3 failure is related to your changes?

@mawong-amd
Contributor

mawong-amd commented Aug 7, 2024

I also removed the runtime dispatching for block sizes per @LucasWilkinson's suggestion.

Would appreciate a link to this suggestion/discussion if it's public! Is there a reason why this was done? The runtime dispatch over block sizes was done for performance reasons: smaller block sizes allow better block occupancy on CUs for memory-latency hiding during the prefill phase (or, from a more agnostic view, when num_tokens, i.e. the kernel grid size, is sufficiently large that enough blocks can be scheduled). If CUB avoids this issue or provides a performance improvement over the current implementation, I'm happy to have this change. I'm curious whether we'll still see the same gains in prefill with smaller block sizes.

(and buggy) blockReduce function

Out of curiosity, what bugs have you seen?

@ProExpertProg
Contributor Author

ProExpertProg commented Aug 7, 2024

@mawong-amd there were a few bugs:

  • It used the values from all threads instead of only the active ones; the values read from inactive threads are undefined, but they happened to be 0, which worked for sum and max (but not min).
  • Similarly, it used 0 as the sentinel value, which worked for sum and absmax but not for min or max when negative values are involved (see the sketch below).
  • It used static shared memory, which meant two separate invocations could interact with each other, which was unintuitive. Perhaps not a bug per se, but it took a while to diagnose. cub::BlockReduce is more explicit about its shared memory use.
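
To illustrate the sentinel point, here is a rough sketch (made-up names, not the vLLM kernels) of a min reduction with cub::BlockReduce: only the first num_valid threads' inputs participate, so no neutral element has to be faked with 0.

#include <cub/cub.cuh>
#include <cfloat>

// Sketch: block-wide min without a 0 sentinel. Threads beyond num_valid do not
// contribute to the reduction, so negative values are handled correctly.
// Assumes num_valid <= NUM_THREADS.
template <int NUM_THREADS>
__global__ void block_min_sketch(const float* in, float* out, int num_valid) {
  using BlockReduce = cub::BlockReduce<float, NUM_THREADS>;
  __shared__ typename BlockReduce::TempStorage tmp;

  float v = (threadIdx.x < num_valid) ? in[threadIdx.x] : FLT_MAX;
  float block_min = BlockReduce(tmp).Reduce(v, cub::Min(), num_valid);
  if (threadIdx.x == 0) {
    out[0] = block_min;  // the aggregate is valid in thread 0 only
  }
}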

smaller block sizes allow better block occupancy on CUs for memory latency hiding during the prefill phase

We still use a smaller block size; we just have a single block-reduction size instead of 2 separate ones, which would still take up the same amount of shared memory (because we pick between them at kernel runtime). See 6050cae and ff5b44c.

If we care about the size of shared memory in the kernel, I can add a template parameter so we can avoid the branch and specialize.
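
A hypothetical sketch of that template-parameter approach (names and the dispatch threshold are made up for illustration): the host picks the block size at launch, and each kernel instantiation only reserves the shared memory its own cub::BlockReduce needs, with no branch inside the kernel.

#include <cub/cub.cuh>

// Sketch: the block size is a template parameter, so the cub temp storage is
// sized per instantiation. The dispatch threshold is hypothetical, not vLLM's logic.
template <int NUM_THREADS>
__global__ void row_sum_sketch(const float* in, float* out, int row_len) {
  using BlockReduce = cub::BlockReduce<float, NUM_THREADS>;
  __shared__ typename BlockReduce::TempStorage tmp;

  float partial = 0.0f;
  for (int i = threadIdx.x; i < row_len; i += NUM_THREADS) {
    partial += in[blockIdx.x * row_len + i];
  }
  float total = BlockReduce(tmp).Sum(partial);
  if (threadIdx.x == 0) out[blockIdx.x] = total;
}

void launch_row_sum(const float* in, float* out, int num_tokens, int row_len,
                    cudaStream_t stream) {
  dim3 grid(num_tokens);
  // Smaller blocks when many rows are in flight (prefill-like), larger otherwise.
  if (num_tokens > 256) {
    row_sum_sketch<256><<<grid, 256, 0, stream>>>(in, out, row_len);
  } else {
    row_sum_sketch<1024><<<grid, 1024, 0, stream>>>(in, out, row_len);
  }
}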

@LucasWilkinson
Contributor

LucasWilkinson commented Aug 7, 2024

Would appreciate a link to this suggestion/discussion if it's public! Is there a reason why this was done? The runtime dispatch over block sizes was done for performance reasons: smaller block sizes allow better block occupancy on CUs for memory-latency hiding during the prefill phase (or, from a more agnostic view, when num_tokens, i.e. the kernel grid size, is sufficiently large that enough blocks can be scheduled). If CUB avoids this issue or provides a performance improvement over the current implementation, I'm happy to have this change. I'm curious whether we'll still see the same gains in prefill with smaller block sizes.

The kernels are still launched based on this; @ProExpertProg just removed the templating of 2 separate block-reduce implementations and the runtime dispatching between them here. With cub, both cases can be handled by simply setting num_valid correctly.

Out of curiosity, what bugs have you seen?

It's not buggy per se, but its opaque API design can easily lead people to introduce bugs. This happened to us when doing something like:

float a = ....;
float b = ....;
float x = blockReduceSum<float>(a);
float y = blockReduceSum<float>(b);
if (threadIdx.x == 0) {
   // do something with x and y
}

In this case, x could non-deterministically contain an incorrect result, since float y = blockReduceSum<float>(b); could trash the shared memory used by float x = blockReduceSum<float>(a);. This is fixed by doing:

float a = ....;
float b = ....;
float x = blockReduceSum<float>(a);
__syncthreads();
float y = blockReduceSum<float>(b);
if (threadIdx.x == 0) {
   // do something with x and y
}

but it's hard for a user to know from the call site alone that blockReduceSum is allocating and using shared memory. cub fixes this by forcing the user to allocate the shared memory outside of the reduce, making it more obvious to the reader that a __syncthreads(); may be required.
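
For comparison, roughly how the same two back-to-back reductions look with cub::BlockReduce (a sketch mirroring the snippet above; BLOCK_SIZE stands in for the kernel's compile-time block size):

using BlockReduce = cub::BlockReduce<float, BLOCK_SIZE>;
__shared__ typename BlockReduce::TempStorage tmp;  // the shared memory is visible at the call site

float a = ....;
float b = ....;
float x = BlockReduce(tmp).Sum(a);
__syncthreads();  // tmp is reused below, so the required barrier is hard to miss
float y = BlockReduce(tmp).Sum(b);
if (threadIdx.x == 0) {
   // do something with x and y (the aggregates are valid in thread 0 only)
}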

Also, just from a software engineering perspective, I think it makes sense for us to leverage hardened Nvidia implementations of common routines when possible. (As a bonus, we benefit from their docs, which in this case explain the need for __syncthreads(): https://nvidia.github.io/cccl/cub/developer_overview.html#id5)

@mawong-amd
Contributor

mawong-amd commented Aug 7, 2024

It used the values from all threads instead of only the active ones; the values read from inactive threads are undefined, but they happened to be 0, which worked for sum and max (but not min).
Similarly, it used 0 as the sentinel value, which worked for sum and absmax but not for min or max when negative values are involved.

Great catch!

It used static shared memory, which meant two separate invocations could interact with each other, which was unintuitive. Perhaps not a bug per se, but it took a while to diagnose. cub::BlockReduce is more explicit about its shared memory use.

Coincidentally, I thought about this a few weeks ago while reviewing code, but I eventually reasoned that static __shared__ shouldn't differ from __shared__. The reason being, shared memory is implemented in the L1 cache and is reserved per block when the block is assigned to an EU, and is released when the block completes execution (with some caveats for Compute Capability 9.0+). In particular, it's not persistent across different kernel invocations, so there shouldn't be a clash between simultaneously executing kernels. Feel free to correct me if I'm mistaken.

Edit: Ahhh you mean clashes between different calls in the same kernel. Apologies! Yes, that's a bit unexpected and it makes sense to be more explicit about its shared memory use.

@mawong-amd
Contributor

mawong-amd commented Aug 7, 2024

with cub both cases can be handled by simply setting num_valid correctly

I see! I'm not too familiar with CUB, but in that case, this removal makes sense.

I think it makes sense for us to leverage hardened Nvidia implementations of common routines when possible. (As a bonus, we benefit from their docs, which in this case explain the need for __syncthreads(): https://nvidia.github.io/cccl/cub/developer_overview.html#id5)

I agree completely. I believe there are still a few places in vLLM where people implement their own unoptimized blockReduce (ignoring the vLLM implementation), so by using higher-quality implementations like the one in this PR, hopefully that fragmentation will be reduced.

@ProExpertProg
Contributor Author

@mawong-amd if you point me to those places I'd be happy to consolidate and try to replace them in this PR (or another) if we want to do that.

@ProExpertProg ProExpertProg changed the title Replaced blockReduce[...] functions with cub::BlockReduce [Kernel] Replaced blockReduce[...] functions with cub::BlockReduce Aug 7, 2024
@ProExpertProg
Contributor Author

ProExpertProg commented Aug 7, 2024

@tlrmchlsmth I checked and test_flash_attn fails on current main as well, so it's not related to this PR. And it does not use any kernels affected in this PR.

@tlrmchlsmth
Collaborator

@ProExpertProg could you share some performance numbers with the benchmark you wrote?

@ProExpertProg
Contributor Author

Yep, sorry, took a little longer than expected. Here are the results of the layernorm and quant microbenchmarks.

The summary is that the cub version is 0.2% faster (geomean) for layernorm and 1.3% slower for quant, but that is most likely just noise: different shapes and dtypes produce vastly different ratios between the cub and main runtimes, with no apparent pattern.

@tlrmchlsmth tlrmchlsmth merged commit 7937009 into vllm-project:main Aug 22, 2024
64 checks passed
@mgoin mgoin deleted the luka/block-reduce branch August 22, 2024 00:56
omrishiv pushed a commit to omrishiv/vllm that referenced this pull request Aug 26, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024