
Commit

fixes
Varun Sundar Rabindranath committed Sep 16, 2024
1 parent ff5529a commit 500a2d3
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions vllm/worker/multi_step_model_runner.py
@@ -15,6 +15,8 @@
 
 from vllm import _custom_ops as ops
 from vllm.distributed import get_pp_group
+from vllm.engine.llm_engine import (SchedulerContext,
+                                    SchedulerContextOuptutQueueData)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
                                                 SamplerOutput,
@@ -29,8 +31,6 @@
     BroadcastableModelInput, _init_attn_metadata_from_tensor_dict,
     _init_frozen_model_input_from_tensor_dict,
     _init_sampling_metadata_from_tensor_dict)
-from vllm.engine.llm_engine import (SchedulerContext,
-                                    SchedulerContextOuptutQueueData)
 
 from ..model_executor.model_loader.tensorizer import TensorizerConfig
 
@@ -429,7 +429,8 @@ def execute_model(
         # path for warm up runs
         if not model_input.is_multi_step:
             return self._base_model_runner.execute_model(
-                model_input.frozen_model_input, kv_caches, intermediate_tensors, num_steps)
+                model_input.frozen_model_input, kv_caches,
+                intermediate_tensors, num_steps)
 
         # make sure we skip the sampler on the last rank and only pythonize
         # if CPU is ahead.
@@ -444,8 +445,8 @@ def execute_model(
             self._base_model_runner.model.sampler.include_gpu_probs_tensor = (
                 True)
             if model_input.frozen_model_input.sampling_metadata:
-                model_input.frozen_model_input.sampling_metadata.skip_sampler_cpu_output = (
-                    True)
+                model_input.frozen_model_input.sampling_metadata.\
+                    skip_sampler_cpu_output = (True)
 
         # some pre-execute model logic for multi-step:
         #   - if it's the first step, we need to reset the sampling tensors
@@ -468,9 +469,11 @@ def execute_model(
                 model_input, model_input.cached_outputs[-1].sampler_output)
 
         if model_input.base_output_proc_callback is None:
+            assert model_input.frozen_model_input is not None
             model_input.base_output_proc_callback = \
                 model_input.frozen_model_input.async_callback
 
+        assert model_input.frozen_model_input is not None
         if model_input.frozen_model_input.async_callback is not None:
             assert model_input.base_output_proc_callback is not None
             async_callback = functools.partial(
@@ -483,10 +486,11 @@ def execute_model(
                 async_callback=async_callback)
 
         # Execute the model
-        output = self._base_model_runner.execute_model(model_input.frozen_model_input,
-                                                       kv_caches,
-                                                       intermediate_tensors,
-                                                       num_steps=1)
+        output = self._base_model_runner.execute_model(
+            model_input.frozen_model_input,
+            kv_caches,
+            intermediate_tensors,
+            num_steps=1)
 
         # record the event for the current step so that the next step can sync
         model_input.record_step_event(current_stream)
@@ -517,6 +521,7 @@ def execute_model(
 
         # Pythonize the output if CPU is ahead and the previous step is
         # ready.
+        assert model_input.frozen_model_input is not None
         if model_input.frozen_model_input.async_callback is None:
             for model_output in model_input.cached_outputs:
                 model_output.maybe_pythonize(model_input,
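Most of the additions in this commit guard the Optional field model_input.frozen_model_input with an assert before it is dereferenced. Below is a minimal standalone sketch of that Optional-narrowing pattern; FrozenInput and ModelInput are made-up stand-ins, not the real vLLM types.

from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class FrozenInput:
    # Stand-in for a frozen model input; the callback may be unset.
    async_callback: Optional[Callable[[], None]] = None


@dataclass
class ModelInput:
    # frozen_model_input is Optional, so every access must be guarded.
    frozen_model_input: Optional[FrozenInput] = None


def execute(model_input: ModelInput) -> None:
    # The assert documents the invariant and lets a type checker such as
    # mypy narrow Optional[FrozenInput] to FrozenInput for the lines below.
    assert model_input.frozen_model_input is not None
    if model_input.frozen_model_input.async_callback is not None:
        model_input.frozen_model_input.async_callback()


execute(ModelInput(FrozenInput(lambda: print("outputs processed"))))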

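One hunk also rebinds the per-step output callback with functools.partial before storing it on the frozen input. A simplified, self-contained illustration of that wrapping follows; _per_step_callback, bind_step_callback, and their arguments are hypothetical and do not reflect vLLM's actual signatures.

import functools
from typing import Callable


def _per_step_callback(base_callback: Callable[[], None], step: int) -> None:
    # Hypothetical wrapper: do per-step bookkeeping, then delegate to the
    # original output-processing callback.
    print(f"finished step {step}")
    base_callback()


def bind_step_callback(base_callback: Callable[[], None],
                       step: int) -> Callable[[], None]:
    # functools.partial freezes the extra arguments so the result can be
    # stored and later invoked with no arguments, like the original callback.
    return functools.partial(_per_step_callback, base_callback, step)


async_callback = bind_step_callback(lambda: print("outputs processed"), step=3)
async_callback()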