
[V1] Implement vLLM V1 [1/N] #9289

Merged: 101 commits merged into main on Oct 22, 2024

Conversation

WoosukKwon
Collaborator

No description provided.


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger the full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of tests to catch errors quickly. You can run the other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can do one of the following:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@robertgshaw2-neuralmagic
Collaborator

👀

Collaborator

@alexm-neuralmagic left a comment


@WoosukKwon thanks for the hard work on this; it looks like you made good progress. I left some comments/clarifications.

examples/offline_inference.py (outdated; resolved)
vllm/attention/selector.py (outdated; resolved)
vllm/commit_id.py (outdated; resolved)
vllm/entrypoints/llm.py (outdated; resolved)
# FIXME:
engine_args.max_num_seqs = max(engine_args.max_num_seqs, 2048)
engine_args.enable_chunked_prefill = False
self.llm_engine = LLMEngineV1.from_engine_args(
Collaborator


Is this where we want to switch between vllm_v1 and the old vllm?

Member


Consider adding an if here?

Collaborator Author


I introduced a new env variable VLLM_USE_V1, which is 0 by default. By setting this env variable, users can use the V1 code path.
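For reference, a minimal sketch of how an env-var gate like this is typically read at import time (the variable name VLLM_USE_V1 comes from this PR; the helper below is illustrative, not the PR's exact code):

import os

# Illustrative only: opt-in flag, off ("0") by default.
VLLM_USE_V1: bool = os.environ.get("VLLM_USE_V1", "0") == "1"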


# Calculate the slot mapping.
block_numbers = self.persistent_batch.block_table_cpu_tensor.flatten()[
    token_indices // self.block_size]
Collaborator


How does having M inside token_indices (to separate requests) affect the block_numbers we get here? Doesn't this result in a "jump"?
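A toy illustration of the indexing in question (not the PR's code): flattening a padded [num_reqs, max_blocks_per_req] block table and dividing token indices by the block size lines up only if each request's token indices carry a per-request stride, which is presumably the role of the M mentioned above; the resulting "jump" between rows is then intentional.

import torch

block_size = 4
block_table = torch.tensor([[10, 11],    # request 0 -> physical blocks 10, 11
                            [20, 21]])   # request 1 -> physical blocks 20, 21
max_blocks_per_req = block_table.shape[1]
stride = max_blocks_per_req * block_size  # per-request offset (the "M" here, illustratively)

token_indices = torch.tensor([0, 1, 5,            # request 0: tokens 0, 1, 5
                              1 * stride + 2])    # request 1: token 2

block_numbers = block_table.flatten()[token_indices // block_size]
print(block_numbers)  # tensor([10, 10, 11, 20])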

vllm_v1/worker/gpu_model_runner.py (4 outdated comments; resolved)
Member

@zhuohan123 left a comment


Done with most of the review; no brain left for GPUModelRunner. Will look into that more tomorrow.

vllm_v1/attention/backends/flash_attn.py (outdated; resolved)
vllm/config.py (outdated; resolved)
vllm/entrypoints/llm.py (2 outdated comments; resolved)
# FIXME:
engine_args.max_num_seqs = max(engine_args.max_num_seqs, 2048)
engine_args.enable_chunked_prefill = False
self.llm_engine = LLMEngineV1.from_engine_args(
Member


Consider adding an if here?

vllm_v1/tokenizer/detokenizer_utils.py (outdated; resolved)
Comment on lines 229 to 247
def _get_cache_block_size(
    cache_config: CacheConfig,
    model_config: ModelConfig,
    parallel_config: ParallelConfig,
) -> int:
    head_size = model_config.get_head_size()
    num_heads = model_config.get_num_kv_heads(parallel_config)
    num_attention_layers = model_config.get_num_attention_layers(
        parallel_config)

    key_cache_block = cache_config.block_size * num_heads * head_size
    value_cache_block = key_cache_block
    total = num_attention_layers * (key_cache_block + value_cache_block)
    if cache_config.cache_dtype == "auto":
        dtype = model_config.dtype
    else:
        dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
    dtype_size = get_dtype_size(dtype)
    return dtype_size * total
Member


Probably not in this PR/re-arch, but eventually should we move this to the model code?

Collaborator Author


Hmm, yes, perhaps. I actually didn't worry about it much because the code is small and doesn't add any complexity.
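For a sense of scale, a quick worked example with illustrative numbers (not from this PR): block_size=16, 8 KV heads, head_size=128, 32 attention layers, fp16.

block_size, num_kv_heads, head_size, num_layers, dtype_size = 16, 8, 128, 32, 2
key_cache_block = block_size * num_kv_heads * head_size   # 16,384 elements
value_cache_block = key_cache_block                        # V cache has the same shape
total = num_layers * (key_cache_block + value_cache_block)
print(total * dtype_size)  # 2,097,152 bytes, i.e. 2 MiB per KV-cache block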

Comment on lines 500 to 506
self.top_p = torch.empty((max_num_reqs, ),
                         dtype=torch.float32,
                         device=device)
self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
                                    dtype=torch.float32,
                                    device="cpu",
                                    pin_memory=pin_memory)
Member


Can we have an abstraction for the logic around self.x, self.x_cpu_tensor, self.x_cpu, and self.x_reqs for different x?

Collaborator Author


Can you please elaborate more?
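For readers following along, one possible shape of the abstraction being suggested is a small helper that owns each GPU tensor together with its pinned CPU staging buffer (hypothetical sketch, not part of this PR):

import torch

class CpuGpuBuffer:
    """Hypothetical helper pairing a GPU tensor with a pinned CPU staging tensor."""

    def __init__(self, *size: int, dtype: torch.dtype, device: torch.device,
                 pin_memory: bool):
        self.gpu = torch.empty(*size, dtype=dtype, device=device)
        self.cpu = torch.empty(*size, dtype=dtype, device="cpu",
                               pin_memory=pin_memory)

    def copy_to_gpu(self, n: int) -> torch.Tensor:
        # Async H2D copy of the first n elements; safe because self.cpu is pinned.
        self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
        return self.gpu[:n]

# e.g. self.top_p = CpuGpuBuffer(max_num_reqs, dtype=torch.float32,
#                                device=device, pin_memory=pin_memory)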

vllm_v1/worker/gpu_model_runner.py (2 outdated comments; resolved)
    scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
    self._update_states(scheduler_output)
    inputs = self._prepare_inputs(scheduler_output)
Member


Why do we need scheduler_output to prepare the inputs if we cache all the request states in the model runner?

Collaborator Author


The scheduler output contains 1) the scheduling decision (req id -> num_tokens), 2) all the data for new requests, and 3) the new block ids for in-flight requests.
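A rough sketch of the shape implied by that description (class and field names here are illustrative, not the PR's exact definitions):

from dataclasses import dataclass
from typing import Dict, List

@dataclass
class NewRequestData:            # illustrative
    req_id: str
    prompt_token_ids: List[int]
    block_ids: List[int]

@dataclass
class SchedulerOutputSketch:     # illustrative
    # 1) scheduling decision: req id -> number of tokens to run this step
    num_scheduled_tokens: Dict[str, int]
    # 2) full data for requests scheduled for the first time
    scheduled_new_reqs: List[NewRequestData]
    # 3) newly allocated block ids for requests already in flight
    new_block_ids: Dict[str, List[int]]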

# NOTE: CPU-GPU synchronization happens here.
sampled_token_ids = sampler_output.sampled_token_ids.cpu()
sampled_token_ids_list = sampled_token_ids.tolist()
# TODO: Optimize.
Member


Can you be a bit more specific on what to optimize?

Collaborator Author


Added more comments.
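One common pattern for shrinking this sync cost, sketched here for illustration (not necessarily what the PR ends up doing), is to stage the copy into a pre-allocated pinned buffer with non_blocking=True, overlap other CPU work, and only synchronize right before the Python list is needed:

import torch

def fetch_sampled_ids(sampled_token_ids: torch.Tensor,
                      pinned_buf: torch.Tensor) -> list:
    n = sampled_token_ids.numel()
    pinned_buf[:n].copy_(sampled_token_ids.flatten(), non_blocking=True)
    # ... overlap other CPU-side work here (e.g., preparing the next step) ...
    torch.cuda.synchronize()  # the only forced sync point
    return pinned_buf[:n].tolist()

# pinned_buf would be allocated once, e.g.:
# pinned_buf = torch.empty(max_num_reqs, dtype=torch.int64, pin_memory=True)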

vllm/config.py (outdated; resolved)
Comment on lines 34 to 35
from vllm_v1.engine.llm_engine import LLMEngine as LLMEngineV1
from vllm_v1.outputs import RequestOutput as RequestOutputV1
Collaborator


If the interface is compatible, would the following be easier?

if USE_V1:
    from vllm_v1.engine.llm_engine import LLMEngine
    from vllm_v1.outputs import RequestOutput
else:
    from vllm.engine.llm_engine import LLMEngine
    from vllm.outputs import RequestOutput  

Collaborator Author


Yeah, I introduced the VLLM_USE_V1 env variable and added a similar if statement. PTAL.

Comment on lines 32 to 37
def get_computed_blocks(self, request: Request) -> List[int]:
    if not self.enable_caching:
        # No prefix caching.
        return []
    # TODO(woosuk): Implement hash-based caching.
    return []
Collaborator


One thing to think about before implementing hash-based caching: where should the hash be calculated?

In block manager v1, the hash was calculated in the sequence (aka Request), while in block manager v2, it is calculated in the block manager. Calculating the hash in Request ensures it is computed only once during the Request's life cycle, but calculating it in the block manager makes more sense because the hash should be attached to cache blocks rather than to sequences.

cc @rickyyx, who is working on the prefix-caching-aware scheduler in v0.
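For context, prefix-cache hashing in this style typically chains each full block's hash on the previous block's hash, so that equal hashes imply equal prefixes rather than just equal token windows. A sketch, not this PR's implementation:

from typing import List, Optional, Tuple

def hash_block_tokens(prev_block_hash: Optional[int],
                      block_token_ids: Tuple[int, ...]) -> int:
    # Chaining on the previous hash ties a block's identity to its whole prefix.
    return hash((prev_block_hash, block_token_ids))

def hash_request_blocks(token_ids: List[int], block_size: int) -> List[int]:
    hashes: List[int] = []
    prev: Optional[int] = None
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        prev = hash_block_tokens(prev, tuple(token_ids[start:start + block_size]))
        hashes.append(prev)
    return hashes  # only full blocks are hashed / cacheable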

Comment on lines 46 to 49
num_blocks = cdiv(request.num_computed_tokens + num_tokens,
                  self.block_size)
req_block_ids = self.req_to_block_ids[request.request_id]
num_new_blocks = num_blocks - len(req_block_ids)
Collaborator


Can this be calculated incrementally? The only missing information needed to determine how many new blocks we need is the number of empty slots in the last block, and the block manager should have this information, so maybe we could do something like the following:

req_block_ids = self.req_to_block_ids[request.request_id]
empty_slots = req_block_ids[-1].empty_slots
if num_tokens <= empty_slots:
    # No new block is needed.
    return []

num_new_blocks = cdiv(num_tokens - empty_slots, self.block_size)
...

Collaborator Author


Sorry, could you explain why you prefer that over this implementation? I personally find the current implementation more concise and intuitive (unless I'm missing something).

vllm_v1/core/kv_cache_manager.py (outdated; resolved)
vllm_v1/sample/sampler.py (2 outdated comments; resolved)
self.max_num_tokens = scheduler_config.max_num_batched_tokens

# Lazy initialization
self.model: nn.Module # Set after load_model
Collaborator


Same question: is this a new feature supported by a later Python version? I get an error with Python 3.9:

>>> class A:
...     def __init__(self):
...             self.data: str
...
>>> a = A()
>>> a.data
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: 'A' object has no attribute 'data'
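For reference, a bare annotation like self.data: str never creates the attribute in any Python 3 version; it is only a hint for type checkers, and the attribute exists only once something assigns it. A sketch of the intended lazy-initialization pattern, with a placeholder model:

import torch.nn as nn

class RunnerSketch:
    def __init__(self) -> None:
        self.model: nn.Module  # annotation only; no attribute is created here

    def load_model(self) -> None:
        self.model = nn.Linear(8, 8)  # placeholder; the attribute exists from here on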

if removed_req_indices:
    self.persistent_batch.condense(removed_req_indices)

def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
Collaborator


This function is unfortunately already complicated, and I can't imagine how complicated it will become once other features (e.g., LoRA, multi-modal) are added... It also seems hard to extend it with those features afterwards, e.g.:

model_inputs = self._prepare_inputs(...)
model_inputs = self._prepare_lora_inputs(model_inputs)

vllm_v1/worker/gpu_worker.py (outdated; resolved)
@WoosukKwon changed the title from "Add vllm_v1" to "[V1] Implement vLLM V1 [1/N]" on Oct 21, 2024
@WoosukKwon
Collaborator Author

@zhuohan123 Can you please take another look?

@WoosukKwon marked this pull request as ready for review on October 21, 2024, 13:51
@WoosukKwon merged commit 6c5af09 into main on Oct 22, 2024
30 checks passed
@WoosukKwon deleted the re-arch-v1 branch on October 22, 2024, 08:24
charlifu pushed a commit to charlifu/vllm that referenced this pull request Oct 23, 2024
vrdn-23 pushed a commit to vrdn-23/vllm that referenced this pull request Oct 23, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
MErkinSag pushed a commit to MErkinSag/vllm that referenced this pull request Oct 26, 2024
garg-amit pushed a commit to garg-amit/vllm that referenced this pull request Oct 28, 2024
# OPTIMIZATION: Cache the request output and update it incrementally.
# This is used to avoid creating a new RequestOutput object every step.
# Request id -> RequestOutput
self.request_outputs: Dict[str, RequestOutput] = {}
Collaborator


One (very late) note: this caching may cause a bug. With AsyncLLMEngine, we put these RequestOutput objects into the per-request output queues, which the OpenAI server then uses to build the objects sent back to the client. If the LLMEngine gets ahead of the AsyncLLMEngine, we will mutate the object before the OpenAI server has a chance to emit its output.
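One possible mitigation, shown only as a sketch (not what the codebase does): enqueue an immutable snapshot so later in-place updates by the engine cannot race with the server reading the object.

import copy

def emit_request_output(output_queue, cached_output) -> None:
    # Snapshot before handing off; the engine may mutate cached_output next step.
    output_queue.put_nowait(copy.deepcopy(cached_output))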

FerdinandZhong pushed a commit to FerdinandZhong/vllm that referenced this pull request Oct 29, 2024
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Oct 31, 2024
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Oct 31, 2024
@lixiaolx mentioned this pull request Nov 13, 2024
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 20, 2024