diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index ae5583ffe061..7a3e302204bb 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -124,4 +124,4 @@ def run_to_completion(profile_dir: Optional[str] = None): 'with ui.perfetto.dev or Tensorboard.' )) args = parser.parse_args() - main(args) + main(args) \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 9f0cf5bdc989..232e0141f24d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -152,6 +152,8 @@ def __init__( model_name: str, tokenizer_name: Optional[str] = None, dtype: str = "half", + draft_model: str = None, + propose_cnt: int = 1, ) -> None: self.model = LLM( model=model_name, @@ -159,6 +161,8 @@ def __init__( trust_remote_code=True, dtype=dtype, swap_space=0, + draft_model=draft_model, + propose_cnt=propose_cnt, ) def generate( diff --git a/tests/models/test_spec_dec.py b/tests/models/test_spec_dec.py new file mode 100644 index 000000000000..4d00731232e8 --- /dev/null +++ b/tests/models/test_spec_dec.py @@ -0,0 +1,50 @@ +"""Compare the outputs of Speculative Decoding and original vLLM

Run `pytest tests/models/test_spec_dec.py --forked`. 
+""" +from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel +from vllm.config import FLAGS +import pytest + +MODELS = [ + "lmsys/vicuna-7b-v1.3", +] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [50]) +@pytest.mark.parametrize("draft_model", ["JackFram/llama-160m"]) +@pytest.mark.parametrize("propose_cnt", [5]) +def test_models( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + draft_model: str, + propose_cnt: int, +) -> None: + spec_vllm_model = vllm_runner(model, + dtype=dtype, + draft_model=draft_model, + propose_cnt=propose_cnt) + spec_vllm_outputs = spec_vllm_model.generate_greedy( + example_prompts, max_tokens) + del spec_vllm_model + destroy_model_parallel() + + FLAGS.ENABLE_SD = False + vllm_model = vllm_runner(model, dtype=dtype) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + del vllm_model + + for i in range(len(example_prompts)): + spec_output_ids, spec_output_str = spec_vllm_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert spec_output_str == vllm_output_str, ( + f"Test{i}:\nSpec: {len(spec_output_str)}\nvLLM: {len(vllm_output_str)}" + ) + assert spec_output_ids == vllm_output_ids, ( + f"Test{i}:\nSpec: {len(spec_output_ids)}\nvLLM: {len(vllm_output_ids)}" + ) diff --git a/vllm/block.py b/vllm/block.py index 435aa50ca22e..57f938573a2d 100644 --- a/vllm/block.py +++ b/vllm/block.py @@ -46,6 +46,14 @@ def get_last_token_id(self) -> int: assert self.num_tokens > 0 return self.token_ids[self.num_tokens - 1] + # delete num tokens from the end in the same block + def delete_last_tokens(self, num: int) -> None: + assert num > 0 + assert num <= self.num_tokens + self.num_tokens -= num + for i in range(self.num_tokens, len(self.token_ids)): + self.token_ids[i] = _BLANK_TOKEN_ID + class PhysicalTokenBlock: """Represents the state of a block in the KV cache.""" diff 
--git a/vllm/config.py b/vllm/config.py index cd92d361d33c..20be4513ffb7 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -13,6 +13,10 @@ _GB = 1 << 30 +class FLAGS: + ENABLE_SD = False + + class ModelConfig: """Configuration for the model. @@ -356,6 +360,14 @@ def _verify_args(self) -> None: f"({self.max_num_seqs}).") +class SpecDecConfig: + + def __init__(self, draft_model_config: ModelConfig, + propose_cnt: int) -> None: + self.draft_model_config = draft_model_config + self.propose_cnt = propose_cnt + + _STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.float16, "float16": torch.float16, diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 8b26319b88cd..d7988ee576d9 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -179,6 +179,15 @@ def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]: self.gpu_allocator.free(last_block) return last_block.block_number, new_block.block_number + def free_tailing_blocks(self, seq: Sequence) -> None: + block_table = self.block_tables[seq.seq_id] + free_cnt = len(seq.logical_token_blocks) - len(block_table) + while free_cnt > 0: + block = block_table.pop() + self.gpu_allocator.free(block) + free_cnt -= 1 + self.block_tables[seq.seq_id] = block_table + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: # NOTE: fork does not allocate a new physical block. # Thus, it is always safe from OOM. 
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index ca28bbdc2fb9..715e1bf8b196 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -7,7 +7,8 @@ from vllm.core.policy import PolicyFactory from vllm.logger import init_logger from vllm.sequence import (Sequence, SequenceData, SequenceGroup, - SequenceGroupMetadata, SequenceStatus) + SequenceGroupMetadata, SequenceStatus, + SequenceOutput) logger = init_logger(__name__) @@ -309,6 +310,27 @@ def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: def free_seq(self, seq: Sequence) -> None: self.block_manager.free(seq) + def free_invalid_kv(self, seq: Sequence, seq_out: SequenceOutput): + # if all the tokens are accepted + # draft_token_ids: [A, B, C], accepted_tokens: [A, B, C, D], invalid_token_cnt = 3 + 1 - 4 = 0 + # if part of the tokens are accepted + # draft_token_ids: [A, B, C], accepted_tokens: [A, B, D], invalid_token_cnt = 3 + 1 - 3 = 1 + invalid_token_cnt = len(seq.data.get_draft_token_ids()) + 1 - len( + seq_out.accepted_tokens) + assert invalid_token_cnt >= 0 + + if invalid_token_cnt == 0: + return invalid_token_cnt + + # delete data + seq.data.output_token_ids = seq.data.output_token_ids[: + -invalid_token_cnt] + # delete from logical table + seq.delete_tailing_tokens(invalid_token_cnt) + # delete from physical table + self.block_manager.free_tailing_blocks(seq) + return invalid_token_cnt + def free_finished_seq_groups(self) -> None: self.running = [ seq_group for seq_group in self.running diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8dec696e7fb6..2a9be9c62c8c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -4,7 +4,7 @@ from typing import Optional, Tuple from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) + SchedulerConfig, SpecDecConfig) @dataclass @@ -25,7 +25,7 @@ class EngineArgs: max_parallel_loading_workers: Optional[int] = None block_size: int = 16 swap_space: int = 4 # GiB 
- gpu_memory_utilization: float = 0.90 + gpu_memory_utilization: float = 0.80 max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 max_paddings: int = 256 @@ -34,6 +34,9 @@ class EngineArgs: tokenizer_revision: Optional[str] = None quantization: Optional[str] = None + draft_model: Optional[str] = None + propose_cnt: Optional[int] = None + def __post_init__(self): if self.tokenizer is None: self.tokenizer = self.model @@ -182,6 +185,21 @@ def add_cli_args( choices=['awq', 'squeezellm', None], default=None, help='Method used to quantize the weights') + + # speculative decoding setting + parser.add_argument( + '--draft-model', + type=str, + default=None, + help= + 'name or path of the huggingface model to use as the draft model') + parser.add_argument( + '--propose-cnt', + type=int, + default=5, + help= + 'for speculative decoding, number of tokens to propose each step') + return parser @classmethod @@ -213,7 +231,20 @@ def create_engine_configs( self.max_num_seqs, model_config.max_model_len, self.max_paddings) - return model_config, cache_config, parallel_config, scheduler_config + + spec_dec_config: SpecDecConfig = None + if self.draft_model: + # assume the draft model and target model share the same tokenizer + # for now, share the same seed as the target + draft_model_config = ModelConfig(self.draft_model, self.tokenizer, + self.tokenizer_mode, + self.trust_remote_code, + self.download_dir, + self.load_format, 'auto', + self.seed) + spec_dec_config = SpecDecConfig(draft_model_config, + self.propose_cnt) + return model_config, cache_config, parallel_config, scheduler_config, spec_dec_config @dataclass diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a1acdfde449a..82ae623a1e0e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) + 
SchedulerConfig, SpecDecConfig, FLAGS) from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics import record_metrics @@ -18,6 +18,7 @@ from vllm.transformers_utils.tokenizer import (detokenize_incrementally, get_tokenizer) from vllm.utils import Counter +from vllm.engine.spec_dec import SpecDecWorker if ray: from ray.air.util.torch_dist import init_torch_dist_process_group @@ -66,6 +67,7 @@ def __init__( cache_config: CacheConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, + spec_dec_config: Optional[SpecDecConfig], distributed_init_method: str, placement_group: Optional["PlacementGroup"], log_stats: bool, @@ -121,6 +123,12 @@ def __init__( # List of (timestamp, num_tokens) self.num_generation_tokens: List[Tuple[float, int]] = [] + self.spec_dec_worker: SpecDecWorker = None + if spec_dec_config: + self.spec_dec_worker = SpecDecWorker(spec_dec_config, + self.scheduler) + FLAGS.ENABLE_SD = True + def _init_workers(self, distributed_init_method: str): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker @@ -408,11 +416,24 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # We reuse the parent sequence here to reduce redundant memory # copies, especially when using non-beam search sampling methods. 
last_child_sample = child_samples[-1] - parent.append_token_id(last_child_sample.output_token, - last_child_sample.logprobs) + if last_child_sample.accepted_tokens: + # Speculative Decoding enabled: invalidate kv cache for non-accepted tokens + self.scheduler.free_invalid_kv(parent, last_child_sample) + # add the last accept token to the output_token_ids + # TODO: we need to get the logprob of the last token + last_token_id = last_child_sample.accepted_tokens[-1] + parent.append_token_id(last_token_id, {last_token_id: -1}) + # always clear draft tokens + parent.data.draft_token_probs = [] + parent.step_gen_token_ids = last_child_sample.accepted_tokens + else: + parent.append_token_id(last_child_sample.output_token, + last_child_sample.logprobs) + parent.step_gen_token_ids = [last_child_sample.output_token] child_seqs.append((parent, parent)) for seq, _ in child_seqs: + self._truncate_sequence(seq, seq_group.sampling_params) self._decode_sequence(seq, seq_group.sampling_params) self._check_stop(seq, seq_group.sampling_params) @@ -573,6 +594,11 @@ def step(self) -> List[RequestOutput]: if scheduler_outputs.is_empty(): return ignored + # only enable speculative decoding for generation run + if self.spec_dec_worker and (not scheduler_outputs.prompt_run): + self.spec_dec_worker.set_draft_tokens(seq_group_metadata_list, + scheduler_outputs) + # Execute the model. 
output = self._run_workers( "execute_model", @@ -582,6 +608,10 @@ def step(self) -> List[RequestOutput]: blocks_to_copy=scheduler_outputs.blocks_to_copy, ) + if self.spec_dec_worker and (not scheduler_outputs.prompt_run): + # accept will set accepted_token_ids and accepted_token_probs in output + self.spec_dec_worker.accept(output, scheduler_outputs) + return self._process_model_outputs(output, scheduler_outputs) def _log_system_stats( @@ -657,12 +687,77 @@ def _log_system_stats( f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%") self.last_logging_time = now + def _truncate_step_gen_token_ids(self, seq: Sequence, + truncate_len: int) -> None: + if truncate_len > 0: + seq.step_gen_token_ids = seq.step_gen_token_ids[:-truncate_len] + + def _truncate_sequence(self, seq: Sequence, + sampling_params: SamplingParams) -> None: + + output_token_ids = seq.get_output_token_ids() + for stop_token_id in sampling_params.stop_token_ids: + if stop_token_id in seq.get_token_ids(): + # seq: [p1, p2, p3, A, B, C], stop_token: B, p1, p2, p3 are prompt tokens + # truncate_len = 4 + 1 - 3 = 2 + # we need to include the stop_token in the output + truncated_output_len = seq.get_token_ids().index( + stop_token_id) + 1 - seq.get_prompt_len() + self._truncate_step_gen_token_ids( + seq, + len(output_token_ids) - truncated_output_len) + # we don't modify logical/physical block here + seq.data.output_token_ids = output_token_ids[: + truncated_output_len] + seq.status = SequenceStatus.FINISHED_STOPPED + return + + # Check if the sequence has reached max_model_len. + if seq.get_len() > self.scheduler_config.max_model_len: + truncated_output_len = self.scheduler_config.max_model_len - seq.get_prompt_len( + ) + self._truncate_step_gen_token_ids( + seq, + len(output_token_ids) - truncated_output_len) + seq.data.output_token_ids = output_token_ids[:truncated_output_len] + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + # Check if the sequence has reached max_tokens. 
+ if seq.get_output_len() >= sampling_params.max_tokens: + truncated_output_len = sampling_params.max_tokens + self._truncate_step_gen_token_ids( + seq, + len(output_token_ids) - truncated_output_len) + seq.data.output_token_ids = output_token_ids[:truncated_output_len] + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + # Check if the sequence has generated the EOS token. + if ((not sampling_params.ignore_eos) + and self.tokenizer.eos_token_id in seq.get_output_token_ids()): + truncated_output_len = output_token_ids.index( + self.tokenizer.eos_token_id) + 1 + self._truncate_step_gen_token_ids( + seq, + len(output_token_ids) - truncated_output_len) + seq.data.output_token_ids = output_token_ids[:truncated_output_len] + seq.status = SequenceStatus.FINISHED_STOPPED + return + def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None: + if seq.tokens is None: + # prefill phase + new_token_ids = seq.get_token_ids() + else: + gen_len = len(seq.step_gen_token_ids) + new_token_ids = seq.get_output_token_ids()[-gen_len:] """Decodes the new token for a sequence.""" (new_tokens, new_output_text, prefix_offset, read_offset) = detokenize_incrementally( self.tokenizer, - all_input_ids=seq.get_token_ids(), + prompt_len=seq.get_prompt_len(), + new_token_ids=new_token_ids, prev_tokens=seq.tokens, prefix_offset=seq.prefix_offset, read_offset=seq.read_offset, @@ -681,31 +776,13 @@ def _check_stop(self, seq: Sequence, sampling_params: SamplingParams) -> None: """Stop the finished sequences.""" for stop_str in sampling_params.stop: - if seq.output_text.endswith(stop_str): + if stop_str in seq.output_text: # Truncate the output text so that the stop string is # not included in the output. - seq.output_text = seq.output_text[:-len(stop_str)] + seq.output_text = seq.output_text[:seq.output_text. 
+ index(stop_str)] seq.status = SequenceStatus.FINISHED_STOPPED return - if seq.get_last_token_id() in sampling_params.stop_token_ids: - seq.status = SequenceStatus.FINISHED_STOPPED - return - - # Check if the sequence has reached max_model_len. - if seq.get_len() > self.scheduler_config.max_model_len: - seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return - - # Check if the sequence has reached max_tokens. - if seq.get_output_len() == sampling_params.max_tokens: - seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return - - # Check if the sequence has generated the EOS token. - if ((not sampling_params.ignore_eos) - and seq.get_last_token_id() == self.tokenizer.eos_token_id): - seq.status = SequenceStatus.FINISHED_STOPPED - return def _run_workers_in_batch( self, diff --git a/vllm/engine/spec_dec.py b/vllm/engine/spec_dec.py new file mode 100644 index 000000000000..9c7e2a159b82 --- /dev/null +++ b/vllm/engine/spec_dec.py @@ -0,0 +1,206 @@ +from vllm.config import SpecDecConfig +from vllm.sequence import SequenceGroupMetadata +from transformers import AutoModelForCausalLM +import torch +from typing import List +from vllm.sequence import SamplerOutput, SequenceOutput, Sequence, SequenceGroupOutput +from vllm.core.scheduler import SchedulerOutputs, Scheduler +from vllm.worker.worker import Worker +from vllm.logger import init_logger + +logger = init_logger(__name__) + +# FIXME: we should get pad_token_id from tokenizer +PAD_TOKEN_ID = 0 +logger.setLevel("WARNING") + + +class SpecDecWorker(Worker): + + def __init__(self, config: SpecDecConfig, scheduler: Scheduler) -> None: + self.propose_cnt = config.propose_cnt + self.draft_model_config = config.draft_model_config + self.scheduler = scheduler + + # self.draft_model = get_model(self.draft_model_config) + logger.info("Initializing speculative decoding worker: " + f"model={self.draft_model_config.model!r}, " + f"tokenizer={self.draft_model_config.tokenizer!r}, " + f"propose_cnt={self.propose_cnt}, " + 
f"seed={self.draft_model_config.seed})") + self.draft_model = AutoModelForCausalLM.from_pretrained( + self.draft_model_config.model).cuda() + + self.alphas = [] + + ##### values to be set ##### + self.draft_kvs = None # if we use hf stype kvs + + def _prepare_inputs( + self, seq_group_metadata_list: List[SequenceGroupMetadata] + ) -> List[torch.Tensor]: + input_ids_list = [] + for seq_group_metadata in seq_group_metadata_list: + assert len( + seq_group_metadata.seq_data + ) == 1, f"Speculative Decoding does nor beam search for now: {len(seq_group_metadata.seq_data)}" + seq_id = next(iter(seq_group_metadata.seq_data)) + seq = seq_group_metadata.seq_data[seq_id] + input_ids_list.append(seq.get_token_ids()) + max_len = max([len(input_ids) for input_ids in input_ids_list]) + input_ids_list = [ + _pad_left_to_max(input_ids, max_len, PAD_TOKEN_ID) + for input_ids in input_ids_list + ] + return torch.tensor(input_ids_list, dtype=torch.long, device='cuda') + + # TODO: we need to align draft and target model's sampler + def _sample_method(self, logits) -> torch.Tensor: + temperature = 0.0001 + return torch.softmax(logits / temperature, dim=-1) + + # propose draft tokens + # the function will run the draft model and set draft_tokens and draft_token_probs of each seq + def set_draft_tokens(self, seq_group_list: List[SequenceGroupMetadata], + scheduler_outputs: SchedulerOutputs) -> None: + logger.info(f"# of input request: {len(seq_group_list)}") + input_tensor = self._prepare_inputs(seq_group_list) + draft_logits, draft_distributions, draft_tokens = [], [], [] + # recompute for now + attention_mask = (input_tensor != PAD_TOKEN_ID) + past_key_values = None + for _ in range(self.propose_cnt): + with torch.no_grad(): + outputs = self.draft_model(input_tensor, + past_key_values=past_key_values, + attention_mask=attention_mask, + use_cache=True) + + past_key_values = outputs.past_key_values + next_token_logits = outputs.logits[:, -1, :] + distribution = 
self._sample_method(next_token_logits) + attention_mask = torch.cat([ + attention_mask, + torch.ones(input_tensor.shape[0], 1, device='cuda') + ], + dim=1) + input_tensor = torch.multinomial(distribution, num_samples=1) + input_tensor = torch.argmax(distribution, dim=-1).reshape(-1, 1) + + draft_logits.append(next_token_logits) + draft_distributions.append(distribution) + draft_tokens.append(input_tensor) + + # seq_id -> Sequence + seqs = {} + for seq_group in scheduler_outputs.scheduled_seq_groups: + for id in seq_group.seqs_dict: + assert id not in seqs + seqs[id] = seq_group.seqs_dict[id] + for i, seq_group_metadata in enumerate(seq_group_list): + seq_id = next(iter(seq_group_metadata.seq_data)) + seq_data = seq_group_metadata.seq_data[seq_id] + for j in range(self.propose_cnt): + draft_token = draft_tokens[j][i].item() + seq_data.draft_token_probs.append( + {draft_token: draft_distributions[j][i]}) + # need to update seqs and seq_metadata + # update seqs to allocate logical block + # update seq_metadata to align with seqs, seq_metadata will be used in the next step to prepare inputs + seqs[seq_id].append_token_id( + draft_token, + { + draft_token: + -1 # We don't have the logprob yet, it should come from the target model + }) + seq_group_metadata.seq_data[seq_id] = seqs[seq_id].data + # allocate physical block + self.scheduler.block_manager.append_slot(seqs[seq_id]) + seq_group_metadata.block_tables[ + seq_id] = self.scheduler.block_manager.get_block_table( + seqs[seq_id]) + + logger.info(f"Seq draft tokens: {seq_data.get_draft_token_ids()}") + logger.info(f"All tokens: {seq_data.get_token_ids()}") + + def _extract_target_prob_dis(self, seq_group_output: SequenceGroupOutput, + pos: int) -> torch.Tensor: + # generation phase + sample_probdis = seq_group_output.samples[0].probdis + dis = self._sample_method(sample_probdis[pos]) + return dis.cuda() + + # Accept draft tokens based on draft probabilities and target probabilities + # The implementation strictly 
follows rejection sampling: + # r = rand(0, 1) + # accept if r <= p/q + # reject and sample from a new distribution if r > p/q + # The function reads draft tokens/probs from scheduler_outputs and sets accepted token_ids + # in target_outputs + def _accept_tokens(self, seq: Sequence, + seq_group_output: SequenceOutput) -> List[int]: + accepted_token_ids = [] + for i, token_prob in enumerate(seq.data.draft_token_probs): + token_id = list(token_prob.keys())[0] + draft_prob_dis = seq.get_draft_probdis(token_id, i) + target_prob_dis = self._extract_target_prob_dis( + seq_group_output, i) + q, p = draft_prob_dis[token_id].item( + ), target_prob_dis[token_id].item() + self.alphas.append(min(p, q)) + if len(self.alphas) % 20 == 0: + logger.warning( + f"alpha: {len(self.alphas)}, {sum(self.alphas) / len(self.alphas)}" + ) + r = torch.rand(1).item() + logger.info(f"p: {p}, q: {q}, r: {r}") + if r <= p / q: # accept + accepted_token_ids.append(token_id) + else: # reject and resample + new_dis = torch.clamp(target_prob_dis - draft_prob_dis, min=0) + new_dis = new_dis / new_dis.sum(dim=-1, keepdim=True) + # next_token = torch.multinomial(new_dis, num_samples=1) + next_token = torch.argmax(new_dis, dim=-1) + # logger.warning(( + # f"next_token token: {next_token},", + # f"{torch.argmax(target_prob_dis, dim=-1)}", + # f"{torch.argmax(draft_prob_dis, dim=-1)}", + # )) + accepted_token_ids.append(next_token.item()) + break + + return accepted_token_ids + + def accept(self, target_outputs: List[SamplerOutput], + scheduler_outputs: SchedulerOutputs) -> None: + scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups + assert len(scheduled_seq_groups) == len(target_outputs) + for seq_group, seq_group_output in zip(scheduled_seq_groups, + target_outputs): + assert seq_group.num_seqs() == 1 + sample: SequenceOutput = seq_group_output.samples[0] + seq_id = list(seq_group.seqs_dict.keys())[0] + cur_seq = seq_group.seqs_dict[seq_id] + logger.info( + f"sample output: 
{sample.output_token}, {cur_seq.data.get_draft_token_ids()}" + ) + assert seq_id == sample.parent_seq_id, \ + (f"seq_group: {seq_id} and", + f"seq_group_output: {sample.parent_seq_id} are not aligned") + + accepted_token_ids = self._accept_tokens(cur_seq, seq_group_output) + + # all proposed tokens are accepted + if accepted_token_ids == cur_seq.data.get_draft_token_ids(): + if isinstance(sample.output_token, int): + last_token = sample.output_token + else: + last_token = sample.output_token[-1] + accepted_token_ids.append(last_token) + logger.info( + f"accept tokens: {accepted_token_ids}, {sample.output_token}") + sample.accepted_tokens = accepted_token_ids + + +def _pad_left_to_max(x: List[int], max_len: int, pad: int) -> List[int]: + return [pad] * (max_len - len(x)) + x diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b05ba71c6d35..b677bfa0cd51 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -69,7 +69,7 @@ def __init__( revision: Optional[str] = None, tokenizer_revision: Optional[str] = None, seed: int = 0, - gpu_memory_utilization: float = 0.9, + gpu_memory_utilization: float = 0.8, swap_space: int = 4, **kwargs, ) -> None: diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index e4ddf08cd9a0..f2d38b537094 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -14,19 +14,23 @@ class InputMetadata: block_tables: The block tables. 
(Seq id -> list of physical block) """ - def __init__( - self, - prompt_lens: List[int], - slot_mapping: torch.Tensor, - max_context_len: Optional[int], - context_lens: Optional[torch.Tensor], - block_tables: Optional[torch.Tensor], - ) -> None: + def __init__(self, + prompt_lens: List[int], + slot_mapping: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, + block_tables: torch.Tensor, + start_loc: Optional[torch.Tensor] = None, + sd_len_to_gen: Optional[int] = None, + sd_prompt_lens: Optional[List[int]] = None) -> None: self.prompt_lens = prompt_lens self.max_context_len = max_context_len self.slot_mapping = slot_mapping + self.start_loc = start_loc self.context_lens = context_lens self.block_tables = block_tables + self.sd_len_to_gen = sd_len_to_gen + self.sd_prompt_lens = sd_prompt_lens self.is_prompt = len(prompt_lens) > 0 # Set during the execution of the first attention op. diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index b84af362efca..2b69fb0748d2 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -10,6 +10,8 @@ from vllm._C import ops from vllm._C import cache_ops from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.kv_mqa import context_attention_fwd +from vllm.config import FLAGS _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. @@ -163,16 +165,20 @@ def forward( ) output = out.view_as(query) else: - # Decoding run. - output = _paged_attention( - query, - key_cache, - value_cache, - input_metadata, - self.head_mapping, - self.scale, - self.alibi_slopes, - ) + if FLAGS.ENABLE_SD: + output = _multi_query_cached_kv_attention( + query, key, value, key_cache, value_cache, input_metadata) + else: + # Decoding run. 
+ output = _paged_attention( + query, + key_cache, + value_cache, + input_metadata, + self.head_mapping, + self.scale, + self.alibi_slopes, + ) # Reshape the output tensor. return output.view(batch_size, seq_len, hidden_size) @@ -280,3 +286,25 @@ def _paged_attention( alibi_slopes, ) return output + + +def _multi_query_cached_kv_attention(query: torch.Tensor, key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + input_metadata: InputMetadata): + # Pre-allocate the output tensor. + output = torch.empty_like(query) + context_attention_fwd( + query, + key, + value, + output, + key_cache, + value_cache, + input_metadata.block_tables, # [BS, max_block_per_request] + input_metadata.start_loc, + input_metadata.context_lens, + input_metadata.sd_prompt_lens, + input_metadata.sd_len_to_gen) + return output diff --git a/vllm/model_executor/layers/kv_mqa.py b/vllm/model_executor/layers/kv_mqa.py new file mode 100644 index 000000000000..b76c28eeeff5 --- /dev/null +++ b/vllm/model_executor/layers/kv_mqa.py @@ -0,0 +1,262 @@ +import torch +import triton +import triton.language as tl + + +def gc_torch(): + pass + + +if triton.__version__ >= "2.1.0": + + @triton.jit + def _fwd_kernel( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = 
tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = (cur_batch_in_all_start_index + + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[ + None, :] * stride_qd + + q = tl.load( + Q + off_q, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = bn[ + None, :] * stride_k_cache_bs + cur_head * stride_k_cache_h + ( + offs_d[:, None] // x) * stride_k_cache_d + ( + (start_n + offs_n[None, :]) % + block_size) * stride_k_cache_bl + ( + offs_d[:, None] % x) * stride_k_cache_x + off_v = bn[:, + None] * stride_v_cache_bs + cur_head * stride_v_cache_h + offs_d[ + None, :] * stride_v_cache_d + ( + start_n + offs_n[:, None] + ) % block_size * stride_v_cache_bl + k = tl.load(K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- 
update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = offs_n[ + None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, + None] * stride_kd + off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[ + None, :] * stride_vd + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # 
update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = (cur_batch_in_all_start_index + + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[ + None, :] * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + return + + @torch.inference_mode() + def context_attention_fwd(q, k, v, o, k_cache, v_cache, b_loc, b_start_loc, + b_seq_len, b_ctx_len, max_input_len): + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + + sm_scale = 1.0 / (Lq**0.5) # compute the scale factor + batch, head = b_seq_len.shape[0], q.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, + + num_warps = 8 if Lk <= 64 else 8 + _fwd_kernel[grid]( + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + b_start_loc, + b_seq_len, + b_ctx_len, + v_cache.shape[3], + 8, + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride( + 3), #[num_blocks, num_kv_heads, head_size, block_size] + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 13da9aa38af0..83c2a1179f8c 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn +from vllm.config import FLAGS from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_gather) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -39,6 +40,21 @@ def forward( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, embedding_bias: Optional[torch.Tensor] = None, + ) -> SamplerOutput: + prompt_run = sampling_metadata.num_prompts > 0 + if FLAGS.ENABLE_SD and (not prompt_run): + return self._sd_forward(embedding, hidden_states, + sampling_metadata, embedding_bias) + else: + return self._forward(embedding, hidden_states, sampling_metadata, + embedding_bias) + + def _forward( + self, + embedding: torch.Tensor, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + embedding_bias: Optional[torch.Tensor] = None, ) -> SamplerOutput: # Get the hidden states that we use for sampling. hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) @@ -97,6 +113,45 @@ def forward( return _build_sampler_output(sample_results, sampling_metadata, prompt_logprobs, sample_logprobs) + def _sd_forward( + self, + embedding: torch.Tensor, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + embedding_bias: Optional[torch.Tensor] = None, + ) -> SamplerOutput: + # Sampler forward for speculative decoding. + # It is a simplified version of the original forward + # and only supports argmax sampling + batch_size = hidden_states.shape[0] + len_to_gen = hidden_states.shape[1] + hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) + + # Get the logits for the next tokens. + logits = _get_logits(hidden_states, embedding, embedding_bias, + self.vocab_size) + + # Do not apply temperature since we only support greedy sampling + + # We use float32 for probabilities and log probabilities. + # Compute the probabilities. + probs = torch.softmax(logits, dim=-1, dtype=torch.float) + # Compute the log probabilities. + # Use log_softmax to ensure numerical stability.
+ logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) + + # sample_results = torch.argmax(logprobs, dim=-1).cpu().reshape(batch_size, -1) + sample_results = _greedy_sample(sampling_metadata.seq_groups, logprobs, + len_to_gen) + prompt_logprobs, sample_logprobs = _get_logprobs( + logprobs, sampling_metadata, sample_results) + + probdis = probs.reshape(batch_size, len_to_gen, -1) + # change probs to a list of lists + probdis = [list(tensor.unbind(0)) for tensor in probdis.unbind(0)] + return _build_sampler_output(sample_results, sampling_metadata, + prompt_logprobs, sample_logprobs, probdis) + def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor, embedding_bias: Optional[torch.Tensor], @@ -362,11 +417,12 @@ def _apply_min_p( return logits -def _greedy_sample( - selected_seq_groups: List[Tuple[List[int], SamplingParams]], - logprobs: torch.Tensor, -) -> List[Tuple[List[int], List[int]]]: +def _greedy_sample(selected_seq_groups: List[Tuple[List[int], SamplingParams]], + logprobs: torch.Tensor, + len_to_gen: int = 1) -> List[Tuple[List[int], List[int]]]: samples = torch.argmax(logprobs, dim=-1).cpu() + if len_to_gen > 1: + samples = samples.reshape(-1, len_to_gen) sample_idx = 0 results = [] for seq_group in selected_seq_groups: @@ -375,10 +431,13 @@ def _greedy_sample( assert num_parent_seqs == 1, ( "Greedy sampling should have only one seq.") parent_ids = list(range(num_parent_seqs)) - next_token_ids = [samples[sample_idx].item()] + if len_to_gen > 1: + next_token_ids = samples[sample_idx].tolist() + else: + next_token_ids = [samples[sample_idx].item()] results.append((next_token_ids, parent_ids)) sample_idx += num_parent_seqs - assert sample_idx == logprobs.size(0) + # assert sample_idx == logprobs.size(0) return results @@ -545,13 +604,14 @@ def _get_logprobs( token_id for token_id in prompt_tokens[1:]) sample_idx += prompt_len - 1 batched_logprobs_query_seq_indices.extend( - [sample_idx + parent_id for parent_id in parent_ids]) + 
[sample_idx + parent_id + for parent_id in parent_ids] * len(next_token_ids)) batched_logprobs_query_token_indices.extend(next_token_ids) if sampling_params.logprobs is not None: largest_num_logprobs = max(largest_num_logprobs, sampling_params.logprobs) sample_idx += num_parent_seqs - assert sample_idx == logprobs.size(0) + # assert sample_idx == logprobs.size(0) # Batched query for logprobs of selected token batched_logprobs_query_result = logprobs[[ @@ -633,20 +693,34 @@ def _build_sampler_output( sampling_metadata: SamplingMetadata, prompt_logprobs: List[Optional[PromptLogprobs]], sample_logprobs: List[SampleLogprobs], + sample_probdis: List[List[torch.Tensor]] = None, ) -> SamplerOutput: + if sample_probdis is None: + sample_probdis = [[None]] * len(sample_logprobs) sampler_output = [] for (seq_group, sample_result, group_prompt_logprobs, - group_sample_logprobs) in zip(sampling_metadata.seq_groups, - sample_results, prompt_logprobs, - sample_logprobs): + group_sample_logprobs, + group_sample_probdis) in zip(sampling_metadata.seq_groups, + sample_results, prompt_logprobs, + sample_logprobs, sample_probdis): seq_ids, _ = seq_group next_token_ids, parent_ids = sample_result seq_outputs = [] - for parent_id, next_token_id, logprobs in zip(parent_ids, - next_token_ids, - group_sample_logprobs): + + if FLAGS.ENABLE_SD and len(next_token_ids) > 1: + assert len(parent_ids) == 1 + parent_id = parent_ids[0] + # FIXME: group_sample_logprobs incorrect seq_outputs.append( - SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)) + SequenceOutput(seq_ids[parent_id], next_token_ids, + group_sample_logprobs, group_sample_probdis)) + else: + for parent_id, next_token_id, logprobs, probdis in zip( + parent_ids, next_token_ids, group_sample_logprobs, + group_sample_probdis): + seq_outputs.append( + SequenceOutput(seq_ids[parent_id], next_token_id, logprobs, + probdis)) sampler_output.append( SequenceGroupOutput(seq_outputs, group_prompt_logprobs)) return sampler_output diff 
--git a/vllm/sequence.py b/vllm/sequence.py index 7d36eeac0aa0..bd73cde34141 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -2,6 +2,7 @@ import copy import enum from typing import Dict, List, Optional, Union +import torch from vllm.block import LogicalTokenBlock from vllm.sampling_params import SamplingParams @@ -66,6 +67,11 @@ def __init__( ) -> None: self.prompt_token_ids = prompt_token_ids self.output_token_ids: List[int] = [] + # we use a list here because + # we can generate the same token multiple times in different locations + # for each entry in the list, it's a map of + # token_id -> probability distribution + self.draft_token_probs: List[Dict[int, torch.Tensor]] = [] self.cumulative_logprob = 0.0 def append_token_id(self, token_id: int, logprob: float) -> None: @@ -89,6 +95,17 @@ def get_last_token_id(self) -> int: return self.prompt_token_ids[-1] return self.output_token_ids[-1] + def get_draft_token_ids(self) -> List[int]: + draft_tokens = [list(tp.keys())[0] for tp in self.draft_token_probs] + return draft_tokens + + def get_verified_token_ids(self) -> List[int]: + draft_token_ids = self.get_draft_token_ids() + assert self.output_token_ids[-len(draft_token_ids):] == draft_token_ids + if len(draft_token_ids) == len(self.output_token_ids): + return [self.prompt_token_ids[-1]] + draft_token_ids + return self.output_token_ids[-len(draft_token_ids) - 1:] + def __repr__(self) -> str: return (f"SequenceData(" f"prompt_token_ids={self.prompt_token_ids}, " @@ -119,6 +136,7 @@ def __init__( self.block_size = block_size self.data = SequenceData(prompt_token_ids) + self.step_gen_token_ids: List[int] = [] self.output_logprobs: SampleLogprobs = [] self.output_text = "" @@ -140,6 +158,9 @@ def _append_logical_block(self) -> None: ) self.logical_token_blocks.append(block) + def _delete_logical_block(self, block: LogicalTokenBlock) -> None: + self.logical_token_blocks.remove(block) + def _append_tokens_to_blocks(self, token_ids: List[int]) -> None: cursor = 0 
while cursor < len(token_ids): @@ -166,6 +187,18 @@ def append_token_id( self.output_logprobs.append(logprobs) self.data.append_token_id(token_id, logprobs[token_id]) + # delete n tokens from the end of the sequence + def delete_tailing_tokens(self, n: int) -> None: + while n > 0: + assert len(self.logical_token_blocks) > 0 + last_block = self.logical_token_blocks[-1] + if last_block.num_tokens < n: + n -= last_block.num_tokens + self._delete_logical_block(last_block) + else: + last_block.delete_last_tokens(n) + break + def get_len(self) -> int: return self.data.get_len() @@ -219,6 +252,11 @@ def __repr__(self) -> str: f"status={self.status.name}, " f"num_blocks={len(self.logical_token_blocks)})") + def get_draft_probdis(self, token_id: int, pos: int) -> torch.Tensor: + token_probdis = self.data.draft_token_probs[pos] + assert token_id in token_probdis + return token_probdis[token_id] + class SequenceGroup: """A group of sequences that are generated from the same prompt. @@ -361,18 +399,24 @@ class SequenceOutput: output_token: The output token ID. logprobs: The logprobs of the output token. (Token id -> logP(x_i+1 | x_0, ..., x_i)) + accepted_tokens: The tokens that are accepted by the speculative decoding. + probdis: The probability distribution of the output token. It is used in speculative decoding. 
""" def __init__( self, parent_seq_id: int, - output_token: int, - logprobs: Dict[int, float], + output_token: Union[int, List[int]], + logprobs: Union[Dict[int, float], List[Dict[int, float]]], + probdis: Optional[List[Dict[int, torch.Tensor]]] = None, ) -> None: self.parent_seq_id = parent_seq_id self.output_token = output_token self.logprobs = logprobs + self.accepted_tokens = None + self.probdis = probdis + def __repr__(self) -> str: return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " f"output_token={self.output_token}, " diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 717232c4cab3..ccfc8061ab9c 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -121,32 +121,33 @@ def _convert_tokens_to_string_with_added_encoders( # under Apache 2.0 license def detokenize_incrementally( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - all_input_ids: List[int], + prompt_len: int, + new_token_ids: List[int], prev_tokens: Optional[List[str]], prefix_offset: int = 0, read_offset: int = 0, skip_special_tokens: bool = False, spaces_between_special_tokens: bool = True, ) -> Tuple[List[str], str, int, int]: - new_token_id = all_input_ids[-1] # This is the first iteration for this sequence if prev_tokens is None: new_tokens = tokenizer.convert_ids_to_tokens( - all_input_ids, skip_special_tokens=skip_special_tokens) + new_token_ids, skip_special_tokens=skip_special_tokens) output_tokens = new_tokens # 5 is an arbitrary value that should work for all # tokenizers (bigger = more conservative). # Subtract 1 extra to account for the generated token. 
- prefix_offset = max(len(output_tokens) - 6, 0) + prefix_offset = max(prompt_len - 5, 0) # If the first new token is a special token, we can't skip 1 extra token - if skip_special_tokens and new_token_id in tokenizer.all_special_ids: - read_offset = max(len(output_tokens), 0) + first_new_token_id = new_token_ids[0] + if skip_special_tokens and first_new_token_id in tokenizer.all_special_ids: + read_offset = max(prompt_len - 1, 0) else: - read_offset = max(len(output_tokens) - 1, 0) + read_offset = max(prompt_len, 0) else: # Put new_token_id in a list so skip_special_tokens is respected new_tokens = tokenizer.convert_ids_to_tokens( - [new_token_id], skip_special_tokens=skip_special_tokens) + new_token_ids, skip_special_tokens=skip_special_tokens) output_tokens = prev_tokens + new_tokens # The prefix text is necessary only to defeat cleanup algorithms in diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2209c994e2b8..a08c5a0415a8 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -2,7 +2,7 @@ import torch -from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig +from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, FLAGS from vllm.logger import init_logger from vllm.model_executor import get_model, InputMetadata, SamplingMetadata from vllm.sampling_params import SamplingParams, SamplingType @@ -114,6 +114,12 @@ def _prepare_prompt( ) return input_tokens, input_positions, input_metadata + def _get_slot(self, block_table: List[int], position: int) -> int: + block_number = block_table[position // self.block_size] + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + return slot + def _prepare_decode( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -124,6 +130,8 @@ def _prepare_decode( slot_mapping: List[List[int]] = [] context_lens: List[int] = [] block_tables: List[List[int]] = [] + sd_prompt_lens: List[int] = [] + 
len_to_gen: int = 1 for seq_group_metadata in seq_group_metadata_list: assert not seq_group_metadata.is_prompt @@ -131,22 +139,40 @@ def _prepare_decode( seq_ids = list(seq_group_metadata.seq_data.keys()) for seq_id in seq_ids: seq_data = seq_group_metadata.seq_data[seq_id] - generation_token = seq_data.get_last_token_id() - input_tokens.append([generation_token]) + if FLAGS.ENABLE_SD: + verify_tokens = seq_data.get_verified_token_ids() + len_to_gen = len(verify_tokens) + input_tokens.append(verify_tokens) - context_len = seq_data.get_len() - if self.sliding_window is not None: - context_len = min(context_len, self.sliding_window) - context_lens.append(context_len) + context_len = seq_data.get_len() + context_lens.append(context_len) - position = context_len - 1 - input_positions.append([position]) + positions = list( + range(context_len - len_to_gen, context_len)) + input_positions.append(positions) - block_table = seq_group_metadata.block_tables[seq_id] - block_number = block_table[position // self.block_size] - block_offset = position % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append([slot]) + block_table = seq_group_metadata.block_tables[seq_id] + slots = [] + for position in positions: + slots.append(self._get_slot(block_table, position)) + slot_mapping.append(slots) + + sd_prompt_lens.append(context_len - len_to_gen) + else: + generation_token = seq_data.get_last_token_id() + input_tokens.append([generation_token]) + + context_len = seq_data.get_len() + if self.sliding_window is not None: + context_len = min(context_len, self.sliding_window) + context_lens.append(context_len) + + position = context_len - 1 + input_positions.append([position]) + + block_table = seq_group_metadata.block_tables[seq_id] + slot = self._get_slot(block_table, position) + slot_mapping.append([slot]) if self.sliding_window is not None: sliding_window_blocks = (self.sliding_window // @@ -155,15 +181,15 @@ def _prepare_decode( 
block_tables.append(block_table) input_tokens = _make_tensor_with_pad(input_tokens, - max_len=1, + max_len=len_to_gen, pad=0, dtype=torch.long) input_positions = _make_tensor_with_pad(input_positions, - max_len=1, + max_len=len_to_gen, pad=0, dtype=torch.long) slot_mapping = _make_tensor_with_pad(slot_mapping, - max_len=1, + max_len=len_to_gen, pad=_PAD_SLOT_ID, dtype=torch.long) max_context_len = max(context_lens) @@ -176,13 +202,25 @@ def _prepare_decode( pad=0, dtype=torch.int) - input_metadata = InputMetadata( - prompt_lens=[], - slot_mapping=slot_mapping, - max_context_len=max_context_len, - context_lens=context_lens, - block_tables=block_tables, - ) + start_loc_tensor = None + if FLAGS.ENABLE_SD: + start_loc_tensor = torch.arange(0, + len(sd_prompt_lens) * len_to_gen, + len_to_gen, + dtype=torch.long, + device='cuda') + sd_prompt_lens = torch.tensor(sd_prompt_lens, + dtype=torch.long, + device='cuda') + + input_metadata = InputMetadata(prompt_lens=[], + slot_mapping=slot_mapping, + max_context_len=max_context_len, + context_lens=context_lens, + block_tables=block_tables, + start_loc=start_loc_tensor, + sd_len_to_gen=len_to_gen, + sd_prompt_lens=sd_prompt_lens) return input_tokens, input_positions, input_metadata def _prepare_sample( @@ -223,10 +261,17 @@ def _prepare_sample( selected_token_start_idx += max_prompt_len else: num_seqs = len(seq_ids) + if FLAGS.ENABLE_SD: + assert len(seq_ids) == 1 + seq = seq_group_metadata.seq_data[seq_ids[0]] + selected_token_end_idx = selected_token_start_idx + len( + seq.get_draft_token_ids()) + 1 + else: + selected_token_end_idx = selected_token_start_idx + num_seqs + selected_token_indices.extend( - range(selected_token_start_idx, - selected_token_start_idx + num_seqs)) - selected_token_start_idx += num_seqs + range(selected_token_start_idx, selected_token_end_idx)) + selected_token_start_idx = selected_token_end_idx categorized_sample_indices[ sampling_params.sampling_type].extend( diff --git a/vllm/worker/worker.py 
b/vllm/worker/worker.py index 6f5e16f0011f..2d9de65049ba 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -116,13 +116,11 @@ def init_cache_engine(self, cache_config: CacheConfig) -> None: self.model_runner.set_block_size(self.cache_engine.block_size) @torch.inference_mode() - def execute_model( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - ) -> SamplerOutput: + def execute_model(self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: # Issue cache operations. issued_cache_op = False if blocks_to_swap_in: