From 8915b8efee56c4dad63003cd1d8c124110adf24c Mon Sep 17 00:00:00 2001 From: Michal Adamczyk Date: Tue, 28 Jan 2025 12:25:06 +0200 Subject: [PATCH 1/3] Enable delayed sampling --- vllm/worker/hpu_model_runner.py | 92 ++++++++++++++++++++++++++++++--- 1 file changed, 84 insertions(+), 8 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index a543cd709a9e..2afdb45cc924 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -79,6 +79,9 @@ LORA_WARMUP_RANK = 8 +VLLM_DELAYED_SAMPLING = os.environ.get('VLLM_DELAYED_SAMPLING', 'false').lower() == 'true' +DUMMY_TOKEN_ID = -1 + def subtuple(obj: object, typename: str, @@ -708,9 +711,11 @@ def __init__( raise ValueError( "Speculative decoding is not supported with " "contiguous PA, please set VLLM_CONTIGUOUS_PA=false") - # For multi-step scheduling + # For both multi-step scheduling and delayed sampling self.cached_step_outputs: List[torch.Tensor] = [] self.is_pooler = False + # For delayed sampling + self.cached_step_inputs: List[ModelInputForHPUWithSamplingMetadata] = [] def _set_gc_threshold(self) -> None: # Read https://docs.python.org/3/library/gc.html#gc.set_threshold @@ -850,7 +855,7 @@ def _add_dummy_seq(self, seq_group_metadata_list, is_prompt): for seq_group_metadata in seq_group_metadata_list) temperature = 0.0 if has_greedy_samples else 1.0 dummy_seq_group_metadata = self.create_dummy_seq_group_metadata( - 0, 0, is_prompt, temperature=temperature) + -1, 0, is_prompt, temperature=temperature) seq_group_metadata_list.extend(dummy_seq_group_metadata for _ in range(batch_size_padding)) return seq_group_metadata_list, real_batch_size, batch_size_padded @@ -2288,6 +2293,20 @@ def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int], return lora_mask, lora_logits_mask + def _get_seq_ids(self, model_input): + return ([sg.seq_ids[0] + for sg in model_input.sampling_metadata.seq_groups]) + + def _pad_to_max_num_seqs(self, tensor, value): + padding_needed = self.max_num_seqs - tensor.size(0) + if padding_needed: + padding = torch.full((padding_needed, *tensor.shape[1:]), + value, + device=tensor.device, + dtype=tensor.dtype) + tensor = torch.cat([tensor, padding]) + return tensor + @torch.inference_mode() def execute_model( self, @@ -2299,6 +2318,26 @@ def execute_model( previous_hidden_states: Optional[torch.Tensor] = None, seqs=None, ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: + use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode + assert not (use_delayed_sampling and num_steps != 1), \ + 'Delayed sampling is not compatible with MSS!' 
+ if use_delayed_sampling and not model_input.is_prompt: + num_cached = len(self.cached_step_outputs) + assert num_cached > 0 + cur_seq_ids = self._get_seq_ids(model_input) + cur_seq_id_pos = {sid: idx for idx, sid in enumerate(cur_seq_ids) if sid >= 0} + htorch.core.mark_step() + for i in range(num_cached): + prev_seq_ids = self._get_seq_ids(self.cached_step_inputs[i]) + target_indices = [cur_seq_id_pos.get(psi, -1) for psi in prev_seq_ids] + padding = self.cached_step_outputs[i].size(0) - len(target_indices) + target_indices.extend([-1] * padding) + target_indices = torch.tensor(target_indices, + device=model_input.input_tokens.device, + dtype=model_input.input_tokens.dtype) + model_input.input_tokens.index_copy_(0, target_indices, self.cached_step_outputs[i]) + htorch.core.mark_step() + if not model_input.is_first_multi_step: if not model_input.is_last_step: # not first or last multi-step @@ -2365,7 +2404,7 @@ def execute_model( f"graphs{'T' if use_graphs else 'F'}") else: model_event_name = 'model_executable' - if num_steps > 1: + if num_steps > 1 or use_delayed_sampling: # in case of multi-step scheduling # we only want to pythonize in the last step sampling_metadata.skip_sampler_cpu_output = True @@ -2433,9 +2472,9 @@ def try_revert_dummy_output_tokens(): if not self.is_driver_worker: continue - if model_input.async_callback is not None: - model_input.async_callback() - # Sample the next token. + if use_delayed_sampling: + fake_output = self._delayed_sampler_outputs(model_input) + with self.profiler.record_event( 'internal', ('sample_' f'{"prompt" if is_prompt else "decode"}_' @@ -2448,9 +2487,16 @@ def try_revert_dummy_output_tokens(): ) if num_steps > 1: output = output.sampled_token_ids - self.cached_step_outputs.append( - output.detach().clone()) + self.cached_step_outputs.append(output) + if use_delayed_sampling: + self._patch_prev_output() + output = self._pad_to_max_num_seqs( + output.sampled_token_ids, DUMMY_TOKEN_ID) + self.cached_step_outputs.append(output) + self.cached_step_inputs.append(model_input) htorch.core.mark_step() + if model_input.async_callback is not None: + model_input.async_callback() if i < num_steps - 1: if i == 0: if model_input.async_callback is not None: @@ -2544,12 +2590,21 @@ def try_revert_dummy_output_tokens(): if model_input.is_prompt: output.prefill_hidden_states = hidden_states output.hidden_states = hidden_states + if use_delayed_sampling: + return [fake_output] + return [output] if self.is_driver_worker else [] else: return [] return output if type(output) is list else [output] + def _delayed_sampler_outputs(self, model_input): + next_token_ids = [[DUMMY_TOKEN_ID]] * len(model_input.sampling_metadata.seq_groups) + sampler_output = self._make_decode_output( + next_token_ids, model_input.sampling_metadata.seq_groups) + return sampler_output + def _decode_sampler_outputs(self, model_input): use_async_out_proc = model_input.async_callback is not None sampler_outputs = [] @@ -2599,3 +2654,24 @@ def _make_decode_output( sampler_outputs.append( CompletionSequenceGroupOutput(seq_outputs, None)) return SamplerOutput(sampler_outputs) + + def _patch_prev_output(self): + assert len(self.cached_step_inputs) == len(self.cached_step_outputs), \ + f'Inputs and outputs are out of sync! 
{len(self.cached_step_inputs)} vs {len(self.cached_step_outputs)}' + if len(self.cached_step_inputs) == 0: + return + model_input = self.cached_step_inputs.pop(0) + delayed_output = self.cached_step_outputs.pop(0).cpu().squeeze(-1).tolist() + ctx = model_input.async_callback.keywords["ctx"] + assert len(ctx.output_queue) == 1, 'There should be exactly 1 output waiting!' + output_data = ctx.output_queue[0] + assert len(output_data.outputs) == 1 + for fake_out, real_out in zip(output_data.outputs[0], delayed_output): + fake_out.samples[0].output_token = real_out + for sg, real_out in zip(output_data.seq_group_metadata_list, delayed_output): + assert len(sg.seq_data) == 1 + seq_data = list(sg.seq_data.values())[0] + # This is a hack. Assigning output_token_ids triggers + # a cache recomputation and we only need to update the last token + seq_data.output_token_ids_array[-1] = real_out + seq_data._cached_all_token_ids[-1] = real_out From b152335dd83932eb24c642dc71e18eca51be47f5 Mon Sep 17 00:00:00 2001 From: Kamil Kaczor Date: Wed, 5 Mar 2025 11:42:51 +0100 Subject: [PATCH 2/3] Cherry-pick of: "Delayed sampling tp fix #834" (#885) Cherry-pick of: https://github.com/HabanaAI/vllm-fork/pull/834 which should fix TP>1 issues. --------- Co-authored-by: Tianmu Li --- vllm/worker/hpu_model_runner.py | 77 ++++++++++++++++++++++++--------- 1 file changed, 56 insertions(+), 21 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 2afdb45cc924..28b3d364a7e2 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -79,7 +79,8 @@ LORA_WARMUP_RANK = 8 -VLLM_DELAYED_SAMPLING = os.environ.get('VLLM_DELAYED_SAMPLING', 'false').lower() == 'true' +VLLM_DELAYED_SAMPLING = os.environ.get('VLLM_DELAYED_SAMPLING', + 'false').lower() == 'true' DUMMY_TOKEN_ID = -1 @@ -715,7 +716,8 @@ def __init__( self.cached_step_outputs: List[torch.Tensor] = [] self.is_pooler = False # For delayed sampling - self.cached_step_inputs: List[ModelInputForHPUWithSamplingMetadata] = [] + self.cached_step_inputs: List[ + ModelInputForHPUWithSamplingMetadata] = [] def _set_gc_threshold(self) -> None: # Read https://docs.python.org/3/library/gc.html#gc.set_threshold @@ -2294,8 +2296,9 @@ def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int], return lora_mask, lora_logits_mask def _get_seq_ids(self, model_input): - return ([sg.seq_ids[0] - for sg in model_input.sampling_metadata.seq_groups]) + return ([ + sg.seq_ids[0] for sg in model_input.sampling_metadata.seq_groups + ]) def _pad_to_max_num_seqs(self, tensor, value): padding_needed = self.max_num_seqs - tensor.size(0) @@ -2321,21 +2324,31 @@ def execute_model( use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode assert not (use_delayed_sampling and num_steps != 1), \ 'Delayed sampling is not compatible with MSS!' 
- if use_delayed_sampling and not model_input.is_prompt: + assert model_input.input_tokens is not None + if use_delayed_sampling and not model_input.is_prompt and \ + self.is_driver_worker: num_cached = len(self.cached_step_outputs) assert num_cached > 0 cur_seq_ids = self._get_seq_ids(model_input) - cur_seq_id_pos = {sid: idx for idx, sid in enumerate(cur_seq_ids) if sid >= 0} + cur_seq_id_pos = { + sid: idx + for idx, sid in enumerate(cur_seq_ids) if sid >= 0 + } htorch.core.mark_step() for i in range(num_cached): prev_seq_ids = self._get_seq_ids(self.cached_step_inputs[i]) - target_indices = [cur_seq_id_pos.get(psi, -1) for psi in prev_seq_ids] - padding = self.cached_step_outputs[i].size(0) - len(target_indices) + target_indices = [ + cur_seq_id_pos.get(psi, -1) for psi in prev_seq_ids + ] + padding = self.cached_step_outputs[i].size(0) - len( + target_indices) target_indices.extend([-1] * padding) - target_indices = torch.tensor(target_indices, - device=model_input.input_tokens.device, - dtype=model_input.input_tokens.dtype) - model_input.input_tokens.index_copy_(0, target_indices, self.cached_step_outputs[i]) + target_indices = torch.tensor( + target_indices, + device=model_input.input_tokens.device, + dtype=model_input.input_tokens.dtype) + model_input.input_tokens.index_copy_( + 0, target_indices, self.cached_step_outputs[i]) htorch.core.mark_step() if not model_input.is_first_multi_step: @@ -2353,7 +2366,21 @@ def execute_model( assert model_input.lora_mapping is not None self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) - input_tokens = model_input.input_tokens + # Rank!=0 workers has is_prompt==None + if use_delayed_sampling and not model_input.is_prompt and \ + model_input.input_tokens.size(1) == 1: + if self.is_driver_worker: + model_kwargs_broadcast_data = { + "input_tokens": model_input.input_tokens + } + broadcast_tensor_dict(model_kwargs_broadcast_data, src=0) + input_tokens = model_input.input_tokens + + else: + model_kwargs_broadcast_data = broadcast_tensor_dict(src=0) + input_tokens = model_kwargs_broadcast_data["input_tokens"] + else: + input_tokens = model_input.input_tokens input_positions = model_input.input_positions attn_metadata = model_input.attn_metadata sampling_metadata = model_input.sampling_metadata @@ -2488,7 +2515,7 @@ def try_revert_dummy_output_tokens(): if num_steps > 1: output = output.sampled_token_ids self.cached_step_outputs.append(output) - if use_delayed_sampling: + if use_delayed_sampling and self.is_driver_worker: self._patch_prev_output() output = self._pad_to_max_num_seqs( output.sampled_token_ids, DUMMY_TOKEN_ID) @@ -2591,7 +2618,10 @@ def try_revert_dummy_output_tokens(): output.prefill_hidden_states = hidden_states output.hidden_states = hidden_states if use_delayed_sampling: - return [fake_output] + if self.is_driver_worker: + return [fake_output] + else: + return [] return [output] if self.is_driver_worker else [] else: @@ -2600,7 +2630,8 @@ def try_revert_dummy_output_tokens(): return output if type(output) is list else [output] def _delayed_sampler_outputs(self, model_input): - next_token_ids = [[DUMMY_TOKEN_ID]] * len(model_input.sampling_metadata.seq_groups) + next_token_ids = [[DUMMY_TOKEN_ID]] * len( + model_input.sampling_metadata.seq_groups) sampler_output = self._make_decode_output( next_token_ids, model_input.sampling_metadata.seq_groups) return sampler_output @@ -2657,18 +2688,22 @@ def _make_decode_output( def _patch_prev_output(self): assert len(self.cached_step_inputs) == 
len(self.cached_step_outputs), \ - f'Inputs and outputs are out of sync! {len(self.cached_step_inputs)} vs {len(self.cached_step_outputs)}' + f'''Inputs and outputs are out of sync! + {len(self.cached_step_inputs)} vs {len(self.cached_step_outputs)}''' if len(self.cached_step_inputs) == 0: return model_input = self.cached_step_inputs.pop(0) - delayed_output = self.cached_step_outputs.pop(0).cpu().squeeze(-1).tolist() - ctx = model_input.async_callback.keywords["ctx"] - assert len(ctx.output_queue) == 1, 'There should be exactly 1 output waiting!' + delayed_output = self.cached_step_outputs.pop(0).cpu().squeeze( + -1).tolist() + ctx = model_input.async_callback.keywords["ctx"] # type: ignore + assert len( + ctx.output_queue) == 1, 'There should be exactly 1 output waiting!' output_data = ctx.output_queue[0] assert len(output_data.outputs) == 1 for fake_out, real_out in zip(output_data.outputs[0], delayed_output): fake_out.samples[0].output_token = real_out - for sg, real_out in zip(output_data.seq_group_metadata_list, delayed_output): + for sg, real_out in zip(output_data.seq_group_metadata_list, + delayed_output): assert len(sg.seq_data) == 1 seq_data = list(sg.seq_data.values())[0] # This is a hack. Assigning output_token_ids triggers From 43b50bd89978168d3f21d9ab48ee9791eafdb2e7 Mon Sep 17 00:00:00 2001 From: Kamil Kaczor Date: Wed, 5 Mar 2025 13:22:11 +0100 Subject: [PATCH 3/3] Cherry-pick "Fixes delayed sampling for sequential requests #845" (#888) Cherry-pick of: https://github.com/HabanaAI/vllm-fork/pull/845 fixing issue in fe. static benchmarks --- vllm/worker/hpu_model_runner.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 28b3d364a7e2..df6889a25af2 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -2696,6 +2696,10 @@ def _patch_prev_output(self): delayed_output = self.cached_step_outputs.pop(0).cpu().squeeze( -1).tolist() ctx = model_input.async_callback.keywords["ctx"] # type: ignore + # If there's no output to patch with, which is usually the case when + # we're starting a new request after all requests are completed. + if len(ctx.output_queue) == 0: + return assert len( ctx.output_queue) == 1, 'There should be exactly 1 output waiting!' output_data = ctx.output_queue[0]
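
Note on the mechanism implemented by this series: delayed sampling lets step N return DUMMY_TOKEN_ID placeholders to the scheduler while the real sample is still pending on the device; in step N+1 the driver scatters the tokens actually sampled in step N into the new input batch (matching rows by sequence id) and then patches the placeholders still sitting in the async output queue. The sketch below illustrates only that bookkeeping. The helper names (patch_step_inputs, patch_prev_placeholders) and the toy tensors are assumptions for illustration, not the hpu_model_runner API; the real code additionally pads the batch to max_num_seqs with DUMMY_TOKEN_ID and runs the patching only on the driver worker.

    import torch

    DUMMY_TOKEN_ID = -1  # placeholder handed to the scheduler before the real token is known

    def patch_step_inputs(input_tokens: torch.Tensor,
                          cur_seq_ids: list,
                          cached_seq_ids: list,
                          cached_tokens: torch.Tensor) -> None:
        # Map each sequence id still present in the batch to its row index.
        pos = {sid: i for i, sid in enumerate(cur_seq_ids) if sid >= 0}
        # Keep only cached rows whose sequence is still alive; finished or
        # padded sequences are dropped here (the diff instead pads the
        # target indices with -1 so shapes stay static for HPU graphs).
        keep = [i for i, sid in enumerate(cached_seq_ids) if sid in pos]
        if not keep:
            return
        index = torch.tensor([pos[cached_seq_ids[i]] for i in keep],
                             dtype=torch.long, device=input_tokens.device)
        # Write the token sampled last step for each surviving sequence
        # into its row of the current step's input tokens.
        input_tokens.index_copy_(0, index,
                                 cached_tokens[keep].to(input_tokens.dtype))

    def patch_prev_placeholders(queued_token_ids: list, real_token_ids: list) -> None:
        # Replace the DUMMY_TOKEN_ID placeholders returned in the previous
        # step with the tokens that were actually sampled.
        for i, tok in enumerate(real_token_ids):
            if queued_token_ids[i] == DUMMY_TOKEN_ID:
                queued_token_ids[i] = tok

    # Toy usage: sequences 7 and 9 survive into the next step, sequence 5 finished.
    inp = torch.zeros(3, 1, dtype=torch.long)      # decode input, one token per row
    cached = torch.tensor([[11], [22], [33]])      # tokens sampled last step for ids 5, 7, 9
    patch_step_inputs(inp, cur_seq_ids=[7, 9, -1],
                      cached_seq_ids=[5, 7, 9], cached_tokens=cached)
    queued = [DUMMY_TOKEN_ID] * 3
    patch_prev_placeholders(queued, [11, 22, 33])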