
[Core] Chunked Prefill support for Multi Step Scheduling #7814

Conversation

varun-sundar-rabindranath (Contributor) commented Aug 23, 2024:

Add Chunked Prefill support for Multi Step Scheduling.

Based on #7528, this PR adds both the Force Single Step and Ignore Prefill policies.

The Ignore Prefill policy is implemented as the default; setting the VLLM_MULTI_STEP_CHUNKED_PREFILL_SINGLE_STEP_POLICY environment variable switches to the Force Single Step policy.
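For illustration, a minimal sketch of how the boolean switch could select between the two policies (the enum mirrors the one added by this PR; the exact env-var parsing here is an assumption):

```python
import os
from enum import Enum

class MultiStepChunkedPrefillPolicy(Enum):
    # DEFAULT (Ignore Prefill): run prompts and decodes together only on
    # the first step, then run just the decodes for the remaining steps.
    DEFAULT = 1
    # FORCE_SINGLE_STEP: force the scheduled sequences to run a single
    # step and then re-schedule.
    FORCE_SINGLE_STEP = 2

def chosen_policy() -> MultiStepChunkedPrefillPolicy:
    # Assumption: the env var is read as a boolean flag ("1" enables it).
    flag = os.environ.get(
        "VLLM_MULTI_STEP_CHUNKED_PREFILL_SINGLE_STEP_POLICY", "0")
    if flag == "1":
        return MultiStepChunkedPrefillPolicy.FORCE_SINGLE_STEP
    return MultiStepChunkedPrefillPolicy.DEFAULT
```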

Benchmark:

Please find the A100 benchmark results here.

PR Checklist

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] for changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR needs to meet the following code quality standards:

  • We adhere to the Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code needs to be well documented to ensure future contributors can easily understand it.
  • Include sufficient tests to ensure the project stays correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies user-facing behavior of vLLM. It helps vLLM users understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag the PR with rfc-required and may not review it.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feels confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status updates every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs will not trigger a full CI run by default. Instead, only the fastcheck CI will run, which consists of a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of the default ones by unblocking the steps in your fast-check build on the Buildkite UI.

Once the PR is approved and ready to go, please make sure to run the full CI, as it is required for merging (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

```diff
@@ -1202,7 +1217,8 @@ def _append_slots(
             the new source and destination block indices for the appended
             slots.
         """
-        num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
+        num_lookahead_slots = self._get_num_lookahead_slots(
+            is_prefill=seq_group.is_prefill())
```
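For context, a plausible sketch of the lookahead-slot logic this diff feeds into (simplified; the real _get_num_lookahead_slots lives in vLLM's scheduler and may differ in detail):

```python
def _get_num_lookahead_slots(self, is_prefill: bool) -> int:
    # Multi-step decoding pre-allocates extra KV-cache slots so a sequence
    # can run several decode steps without being re-scheduled. A sequence
    # still in prefill does not need them for its current step.
    if is_prefill:
        return 0
    return self.scheduler_config.num_lookahead_slots
```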
@SolitaryThinker (Contributor) commented Aug 23, 2024:

Spent some time looking into PP. The assertion issue may be related to this change. Something is causing the batch to generate logits for the prefills in the batch, and since len(sample_indices) == 0 for them (they don't perform sampling), the assertion assert logits_applied == logits.shape[0] fails in _apply_min_tokens_penalty, because logits_applied is the sum of all sample_indices in the batch.
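To illustrate the failing invariant, a simplified sketch (not the exact vLLM code): every logits row must be claimed by some sequence group's sample_indices, so prefills that produce logits without sampling break the count:

```python
def check_all_logits_claimed(sampling_metadata, logits) -> None:
    # Simplified sketch of the check in _apply_min_tokens_penalty.
    logits_applied = 0
    for seq_group in sampling_metadata.seq_groups:
        # Non-sampling prefill groups have len(sample_indices) == 0, so
        # any logits rows produced for them are never counted here.
        logits_applied += len(seq_group.sample_indices)

    # Fails when the batch generated logits rows for non-sampling prefills.
    assert logits_applied == logits.shape[0]
```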

varun-sundar-rabindranath (Contributor, Author) replied:

I see. Thanks @SolitaryThinker. I believe it should be

```python
num_lookahead_slots = self._get_num_lookahead_slots(
    is_prefill=seq_group.is_prefill() and seq_group.do_sample)
```

instead of just

```python
num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
```

I haven't been able to reproduce this assertion yet. I'll keep at it. When you find some time, can you try this out as well? Thanks.

varun-sundar-rabindranath (Contributor, Author) commented:

976d032 fixes the PP issue. The SamplingMetadata objects were being clobbered due to the SamplingMetadataCache reset.
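A hypothetical sketch of the clobbering pattern (this is not vLLM's actual SamplingMetadataCache, just an illustration of the hazard): with pipeline parallelism, an in-flight stage can still hold references to pooled objects when the cache is reset, so reuse overwrites live metadata:

```python
class ObjectPoolCache:
    """Hypothetical object pool illustrating the hazard."""

    def __init__(self) -> None:
        self._pool: list = []
        self._next = 0

    def get(self):
        # Hand out pooled objects in order, growing the pool on demand.
        if self._next == len(self._pool):
            self._pool.append({})  # stand-in for a SamplingMetadata object
        obj = self._pool[self._next]
        self._next += 1
        return obj

    def reset(self) -> None:
        # Hazard: after reset(), get() re-issues objects that an in-flight
        # pipeline stage may still reference, clobbering its metadata.
        self._next = 0
```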

@comaniac (Collaborator) left a comment:

Thanks for the great work! Leaving some comments.

Comment on lines +350 to +351
```python
if seq_group.state.remaining_steps > 0:
    seq_group.finish_step()
```
Collaborator:

Could you add a comment explaining this?

```python
# The prefill sequences's remaining_step is 1 when they are
# scheduled initially. After the first step their remaining_step
# becomes 0.
if any([sgml.state.remaining_steps not in [0, 1, remaining_steps] \
```
Collaborator:

Suggested change:

```diff
-if any([sgml.state.remaining_steps not in [0, 1, remaining_steps] \
+if any([sgml.state.remaining_steps not in (0, 1, remaining_steps) \
```

```diff
@@ -60,6 +60,7 @@
 VLLM_ALLOW_ENGINE_USE_RAY: bool = False
 VLLM_PLUGINS: Optional[List[str]] = None
 VLLM_TORCH_PROFILER_DIR: Optional[str] = None
+VLLM_MULTI_STEP_CHUNKED_PREFILL_SINGLE_STEP_POLICY: bool = False
```
Collaborator:

Would it be better to make this more extensible, like the following?

Suggested change:

```diff
-VLLM_MULTI_STEP_CHUNKED_PREFILL_SINGLE_STEP_POLICY: bool = False
+VLLM_MULTI_STEP_CHUNKED_PREFILL_POLICY: str = "let-prefill-wait"
```

And another policy would be "single-step-with-prefill".
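For illustration, a minimal sketch of how the suggested string-valued env var could map onto the policy enum (names follow the suggestion above; the mapping itself is an assumption, and the MultiStepChunkedPrefillPolicy enum quoted below is assumed to be in scope):

```python
import os

_POLICIES = {
    "let-prefill-wait": MultiStepChunkedPrefillPolicy.DEFAULT,
    "single-step-with-prefill": MultiStepChunkedPrefillPolicy.FORCE_SINGLE_STEP,
}

def chunked_prefill_policy() -> "MultiStepChunkedPrefillPolicy":
    # Fail loudly on unknown values instead of silently picking a default.
    value = os.environ.get("VLLM_MULTI_STEP_CHUNKED_PREFILL_POLICY",
                           "let-prefill-wait")
    try:
        return _POLICIES[value]
    except KeyError:
        raise ValueError(
            f"Unknown VLLM_MULTI_STEP_CHUNKED_PREFILL_POLICY: {value!r}")
```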

Comment on lines +13 to +22
```python
class MultiStepChunkedPrefillPolicy(Enum):
    # When prompt and decode sequences are scheduled together,
    # the DEFAULT policy is to run the prompt and decodes sequences
    # together only for the first step and run just the decode sequences
    # in the rest of the steps.
    DEFAULT = 1
    # In FORCE_SINGLE_STEP policy, we force the scheduled sequences to
    # run a single step and then re-schedule.
    FORCE_SINGLE_STEP = 2
    INVALID = 3
```
Collaborator:

We should define this in vLLM instead of tests.

```diff
@@ -208,11 +277,42 @@ def prepare_model_input(
         frozen_model_input = self._base_model_runner.prepare_model_input(
             seq_group_metadata_list, virtual_engine, finished_requests_ids)

+        num_prompts = len(
```
Collaborator:

Change all "prompts" to "prefills".

Suggested change:

```diff
-        num_prompts = len(
+        num_prefills = len(
```

Comment on lines +134 to +136
```python
def without_prefills(m: "ModelInputForGPUWithSamplingMetadata",
                     sampling_metadata_decodes: SamplingMetadata) \
        -> "ModelInputForGPUWithSamplingMetadata":
```
Collaborator:

Add a docstring:

"""Remove prefill sequences from this model input metadata.
This is used by multi-step runner to execute the rest steps
of decode sequences.
"""

Comment on lines +154 to +155
```python
assert (m.input_tokens is not None)
assert (m.input_positions is not None)
```
Collaborator:

Suggested change:

```diff
-assert (m.input_tokens is not None)
-assert (m.input_positions is not None)
+assert m.input_tokens is not None
+assert m.input_positions is not None
```

```diff
@@ -23,17 +48,36 @@
 ]


-async def completions_with_server_args(prompts: List[str], model_name: str,
-                                       server_cli_args: List[str]):
+class EnvContextManager():
```
Collaborator:

Can you look around the tests and see if we already have a similar utility? If so please reuse it; otherwise please move this to the right place so that other tests in the future can use it.

```python
        os.environ.update(self.os_env)


async def completions_with_server_args(prompts: List[str],
```
Collaborator:

#7651 also introduces this utility. Please coordinate with @afeldman-nm to better organize these.

Collaborator:

@comaniac do you mean #7652?

Collaborator:

Oops yes sorry for the typo.

Comment on lines +307 to +308
```python
# TODO (varun) : Try using decode_metadata here. We hit some asserts in
# advance_step - but that seems resolvable.
```
Collaborator:

Can you elaborate on this? Specifically, how can it be solved and what's the plan?

```python
                                       server_cli_args: List[str],
                                       with_env: dict = {}):  # noqa: B006
    # env setup
    os.environ.update(with_env)
```
varun-sundar-rabindranath (Contributor, Author):

TODO: fix the stray update.
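For reference, the usual pattern scopes the environment mutation inside the context manager and restores prior values on exit, which avoids the stray update flagged above. A minimal sketch (not the PR's EnvContextManager):

```python
import os
from contextlib import contextmanager

@contextmanager
def env_context(overrides: dict):
    """Temporarily set environment variables, restoring them on exit."""
    saved = {key: os.environ.get(key) for key in overrides}
    os.environ.update(overrides)
    try:
        yield
    finally:
        for key, old in saved.items():
            if old is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = old
```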

varun-sundar-rabindranath (Contributor, Author) commented:

Bad performance: multi-step alone is better than multi-step + chunked-prefill. Closing this in favor of #8378.

varun-sundar-rabindranath (Contributor, Author) commented:

@tlrmchlsmth @comaniac @SolitaryThinker I'll handle the relevant comments in the new PR
