Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core][torch.compile] not compile for profiling #7796

Merged
merged 17 commits into from
Aug 27, 2024

Conversation

youkaichao
Copy link
Member

we only profile once to determine the space for kv-cache, and don't run profile anymore.

compiling this run will only add guards and code cache for useless code.

removing this compilation can reduce the overhead of dynamp.

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@youkaichao
Copy link
Member Author

my measurement on gemma 2b shows this can reduce the Dynamo overhead by about 0.1~0.2 ms.

Dockerfile.tpu Outdated Show resolved Hide resolved
@WoosukKwon WoosukKwon self-assigned this Aug 27, 2024
Copy link
Collaborator

@WoosukKwon WoosukKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! Left minor comments.

.buildkite/run-tpu-test.sh Outdated Show resolved Hide resolved
vllm/worker/tpu_model_runner.py Outdated Show resolved Hide resolved
vllm/worker/model_runner_base.py Outdated Show resolved Hide resolved
vllm/worker/model_runner_base.py Outdated Show resolved Hide resolved
@WoosukKwon
Copy link
Collaborator

@youkaichao On Gemma-2B, this PR reduced the KV cache space from 34552 to 25752, 25% drop.

@WoosukKwon
Copy link
Collaborator

@youkaichao Can we do this instead?

# Compiled model only used for initial memory profiling
tmp_for_profile = torch.compile(model, ...)
# Compiled model used for actual execution.
self.model = torch.compile(model, ...)

@youkaichao
Copy link
Member Author

@youkaichao On Gemma-2B, this PR reduced the KV cache space from 34552 to 25752, 25% drop.

you mean you also want the compilation for profiling stage?

@WoosukKwon
Copy link
Collaborator

you mean you also want the compilation for profiling stage?

Yeah, while I'm not sure, there's a chance that not using torch.compile can lead to suboptimal memory allocation, overestimating the memory usage in memory profiling.

@youkaichao
Copy link
Member Author

@WoosukKwon how do you like the current implementation? compilation and optimiation still happens for the profiling run, but it will be discarded and does not affect later runs.

@WoosukKwon
Copy link
Collaborator

@youkaichao Looks very good to me! Thanks for the quick fix!

@WoosukKwon
Copy link
Collaborator

WoosukKwon commented Aug 27, 2024

For gemma 2b, I didn't see much performance difference: the latency decreased from 1.59s to 1.58s (batch size 8, input len 1024, output len 128).

@youkaichao
Copy link
Member Author

I'm measuring the time spent in the decode run, and I see a clear overhead reduction.

Previously, it takes 8ms for every step, now it only takes 7.5~7.6 ms.

@WoosukKwon
Copy link
Collaborator

@youkaichao I see. How does it work without --enforce-eager? Will the complied model still have only 2 guards?

@youkaichao
Copy link
Member Author

How does it work without --enforce-eager? Will the complied model still have only 2 guards?

yes, you already marked symbolic shapes, and all shapes match the symbolic shapes.

@youkaichao youkaichao merged commit 64cc644 into vllm-project:main Aug 27, 2024
20 of 26 checks passed
@youkaichao youkaichao deleted the profile_no_compile branch August 27, 2024 04:34
triple-Mu pushed a commit to triple-Mu/vllm_official that referenced this pull request Sep 4, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants