From 64aa1e8ba583b84f294114f15dc3e915e42e2c1d Mon Sep 17 00:00:00 2001 From: sang Date: Tue, 30 Apr 2024 04:47:13 -0700 Subject: [PATCH 01/12] ip --- .../test_basic_correctness.py | 5 ++- vllm/attention/backends/xformers.py | 2 +- vllm/worker/model_runner.py | 39 +++++++++++++------ 3 files changed, 31 insertions(+), 15 deletions(-) diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 97cff623c5e1..764344ac9436 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -6,14 +6,15 @@ MODELS = [ "facebook/opt-125m", - "meta-llama/Llama-2-7b-hf", + # "meta-llama/Llama-2-7b-hf", ] @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) -@pytest.mark.parametrize("enforce_eager", [False, True]) +# @pytest.mark.parametrize("enforce_eager", [False, True]) +@pytest.mark.parametrize("enforce_eager", [True]) def test_models( hf_runner, vllm_runner, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 572a4dc79a71..e5a2e34f8e6d 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -256,7 +256,7 @@ def forward( key_cache, value_cache, decode_meta.block_tables, - decode_meta.context_lens, + decode_meta.prompt_lens_tensor, decode_meta.max_context_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 0704f5fec54d..9e3c10eb1e8b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -240,11 +240,13 @@ def _prepare_prompt( if len(seq_group_metadata_list) == 0: return PreparePromptMetadata.empty() + is_prompt = False for seq_group_metadata in seq_group_metadata_list: - assert seq_group_metadata.is_prompt + # assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) - assert len(seq_ids) == 1 + # assert len(seq_ids) == 1 seq_id = seq_ids[0] + is_prompt = seq_group_metadata.is_prompt computed_block_nums = seq_group_metadata.computed_block_nums if (self.scheduler_config is not None @@ -273,7 +275,8 @@ def _prepare_prompt( computed_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[computed_len:] prefix_block_tables.append(computed_block_nums) - elif self.scheduler_config.chunked_prefill_enabled: + # elif self.scheduler_config.chunked_prefill_enabled: + else: if seq_group_metadata.block_tables is not None: # Prefill has chunked before. block_table = seq_group_metadata.block_tables[seq_id] @@ -281,11 +284,11 @@ def _prepare_prompt( else: # The first prefill. prefix_block_tables.append([]) - else: - prefix_block_tables.append([]) - # Right now, prefill start is always 0. However, this - # assumption can be changed once chunked prefill is introduced. - assert computed_len == 0 + # else: + # prefix_block_tables.append([]) + # # Right now, prefill start is always 0. However, this + # # assumption can be changed once chunked prefill is introduced. 
+ # assert computed_len == 0 # actual prompt lens context_lens.append(computed_len) @@ -377,7 +380,7 @@ def _prepare_prompt( device=self.device) prompt_lens_tensor = torch.tensor(prompt_lens, - dtype=torch.long, + dtype=torch.int, device=self.device) seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1, dtype=torch.int32, @@ -394,11 +397,11 @@ def _prepare_prompt( out=seq_start_loc[1:]) attn_metadata = self.attn_backend.make_metadata( - is_prompt=True, + is_prompt=is_prompt, prompt_lens=prompt_lens, prompt_lens_tensor=prompt_lens_tensor, max_subquery_len=max_subquery_len, - max_context_len=None, + max_context_len=max(context_lens), max_prompt_len=max_prompt_len, subquery_start_loc=subquery_start_loc, seq_start_loc=seq_start_loc, @@ -573,15 +576,27 @@ def prepare_input_tensors( multi_modal_input, slot_mapping, ) = self._prepare_prompt(prefill_reqs) + # ( + # decode_input_tokens, + # decode_input_positions, + # decode_attn_metadata, + # decode_lora_index_mapping, + # decode_lora_prompt_mapping, + # decode_lora_requests, + # decode_slot_mapping, + # ) = self._prepare_decode(decode_reqs) ( decode_input_tokens, decode_input_positions, decode_attn_metadata, + _, + _, decode_lora_index_mapping, decode_lora_prompt_mapping, decode_lora_requests, + _, decode_slot_mapping, - ) = self._prepare_decode(decode_reqs) + ) = self._prepare_prompt(decode_reqs) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, prompt_lens, subquery_lens, self.device, self.pin_memory) From 5760bc6088833869a258f55c439d2a83c940a364 Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 1 May 2024 01:16:08 -0700 Subject: [PATCH 02/12] . --- tests/kernels/test_prefix_prefill.py | 16 +- tests/samplers/test_sampler.py | 34 +-- tests/spec_decode/utils.py | 8 +- tests/test_logits_processor.py | 8 +- tests/worker/test_model_runner.py | 88 +++---- vllm/attention/backends/flash_attn.py | 30 +-- vllm/attention/backends/rocm_flash_attn.py | 46 ++-- vllm/attention/backends/torch_sdpa.py | 34 +-- vllm/attention/backends/xformers.py | 75 ++++-- vllm/attention/ops/paged_attn.py | 25 +- vllm/config.py | 17 +- vllm/engine/arg_utils.py | 9 +- vllm/entrypoints/llm.py | 5 + vllm/model_executor/layers/sampler.py | 6 +- vllm/model_executor/sampling_metadata.py | 53 ++-- vllm/worker/cpu_model_runner.py | 36 +-- vllm/worker/model_input.py | 290 +++++++++++++++++++++ vllm/worker/model_runner.py | 153 +++++------ vllm/worker/neuron_model_runner.py | 30 +-- 19 files changed, 640 insertions(+), 323 deletions(-) create mode 100644 vllm/worker/model_input.py diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index ad31b0a7c2a1..9f5104ad2544 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -48,12 +48,12 @@ def test_contexted_kv_attention( cache_size = 640 block_size = 32 max_block_per_request = 64 - subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] + query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] - seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)] + seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] num_kv_heads = num_heads // num_queries_per_kv - num_tokens = sum(subquery_lens) + num_tokens = sum(query_lens) query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) query.uniform_(-1e-3, 1e-3) output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) @@ -72,15 +72,15 @@ def test_contexted_kv_attention( num_kv_heads, 
head_size, dtype=dtype) - k = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype) - v = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype) + k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) + v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) values = torch.arange(0, cache_size, dtype=torch.long) values = values[torch.randperm(cache_size)] block_table = values[:BS * max_block_per_request].view( BS, max_block_per_request) b_seq_len = torch.tensor(seq_lens, dtype=torch.long) b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1], + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], dtype=torch.long), dim=0) max_input_len = MAX_SEQ_LEN @@ -89,7 +89,7 @@ def test_contexted_kv_attention( dtype=torch.long), dim=0) for i in range(BS): - for j in range(subquery_lens[i]): + for j in range(query_lens[i]): k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j]) v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + @@ -155,7 +155,7 @@ def test_contexted_kv_attention( value = value.unsqueeze(0) attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens( - subquery_lens, seq_lens) + query_lens, seq_lens) output_ref = xops.memory_efficient_attention_forward( query, key, diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 7859f0b21812..e4fea165a4d4 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -58,7 +58,7 @@ def _do_sample( device: str, ): seq_group_metadata_list = [] - prompt_lens = [] + seq_lens = [] for i in range(batch_size): seq_group_metadata_list.append( SequenceGroupMetadata( @@ -68,12 +68,12 @@ def _do_sample( sampling_params=sampling_params, block_tables={0: [1]}, )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens, + seq_lens, + query_lens=seq_lens, device=device, pin_memory=model_runner.pin_memory) return sampler(logits=input_tensor, sampling_metadata=sampling_metadata) @@ -421,7 +421,7 @@ def run_test_case(*, "Invalid test case, need seq_group_metadata_list" batch_size = 0 - prompt_lens = [] + seq_lens = [] sampling_params_per_row = [] for sgm in seq_group_metadata_list: sampling_params = sgm.sampling_params @@ -431,7 +431,7 @@ def run_test_case(*, # a prompt seq_group has only one sequence seq_data = next(iter(sgm.seq_data.values())) prompt_len = seq_data.get_prompt_len() - prompt_lens.append(prompt_len) + seq_lens.append(prompt_len) if sgm.sampling_params.prompt_logprobs: # with prompt_logprobs each token in the prompt has a row in @@ -451,8 +451,8 @@ def run_test_case(*, _, fake_logits, sampler, model_runner = _prepare_test(batch_size) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens=prompt_lens if prompt_lens else None, - subquery_lens=prompt_lens if prompt_lens else None, + seq_lens=seq_lens if seq_lens else None, + query_lens=seq_lens if seq_lens else None, device=device, pin_memory=model_runner.pin_memory) # the logits tensor is modified in-place by the sampler @@ -497,7 +497,7 @@ def test_sampler_mixed(seed: int, device: str): seq_group_metadata_list = [] expected_tokens: List[Optional[List[int]]] = [] - prompt_lens = [] + seq_lens = [] for i in range(batch_size): expected: Optional[List[int]] = None 
sampling_type = random.randint(0, 3) @@ -532,13 +532,13 @@ def test_sampler_mixed(seed: int, device: str): sampling_params=sampling_params, block_tables={0: [1]}, )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) def test_sampling(model_runner: ModelRunner): sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens, + seq_lens, + query_lens=seq_lens, device=device, pin_memory=model_runner.pin_memory) sampler_output = sampler(logits=fake_logits, @@ -575,7 +575,7 @@ def test_sampling(model_runner: ModelRunner): # Shuffle the batch and resample target_index = list(range(batch_size)) for list_to_shuffle in (target_index, seq_group_metadata_list, - expected_tokens, prompt_lens): + expected_tokens, seq_lens): random.Random(seed).shuffle(list_to_shuffle) target_index = torch.tensor(target_index) input_tensor.data = input_tensor.index_select(0, target_index) @@ -620,7 +620,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): assert len(warpers) == 2 # top_p and top_k seq_group_metadata_list = [] - prompt_lens = [] + seq_lens = [] for i in range(batch_size): seq_group_metadata_list.append( SequenceGroupMetadata( @@ -634,12 +634,12 @@ def test_sampler_top_k_top_p(seed: int, device: str): ), block_tables={0: [1]}, )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens, + seq_lens, + query_lens=seq_lens, device=device, pin_memory=model_runner.pin_memory) diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 4f8295d25cf4..87c7d88a80f4 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -144,7 +144,7 @@ def create_seq_group_metadata_from_prompts( prompts: List[List[int]], num_gpu_blocks: int, block_size: int, - final_seq_lens: List[int], + final_prompt_lens: List[int], continuations: Optional[List[List[int]]] = None, seq_ids: Optional[List[int]] = None, ) -> List[SequenceGroupMetadata]: @@ -162,7 +162,7 @@ def create_seq_group_metadata_from_prompts( free_gpu_blocks.pop() for _ in range(round_up_to_next_block(final_len, block_size)) ] - for i, final_len in enumerate(final_seq_lens) + for i, final_len in enumerate(final_prompt_lens) } return [ @@ -251,13 +251,13 @@ def create_batch(batch_size, prev_output_tokens = [[ next(iterator) for _ in range(prev_output_token_len) ] for _ in range(batch_size)] - final_seq_lens = [ + final_prompt_lens = [ len(prompt) + len(prev_output_token) + k + 1 for prompt, prev_output_token in zip(prompts, prev_output_tokens) ] execute_model_data = create_execute_model_data( create_seq_group_metadata_from_prompts(prompts, num_gpu_blocks, - block_size, final_seq_lens, + block_size, final_prompt_lens, prev_output_tokens, seq_ids), ) return execute_model_data, prompts, prev_output_tokens diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index dbaeb4de1825..179e8d25a341 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -70,7 +70,7 @@ def pick_ith(token_ids, logits): return logits seq_group_metadata_list = [] - prompt_lens = [] + seq_lens = [] for i in range(batch_size): seq_group_metadata_list.append( SequenceGroupMetadata( @@ -81,12 +81,12 @@ def pick_ith(token_ids, logits): logits_processors=[pick_ith]), block_tables={0: 
[1]}, )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens, + seq_lens, + query_lens=seq_lens, device=model_runner.device, pin_memory=model_runner.pin_memory) logits_processor_output = logits_processor( diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 56fe6db589f1..4c7d3673ca95 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -23,14 +23,14 @@ def test_prepare_prompt(batch_size): lora_config=None) model_runner.set_block_size(16) - prompt_lens = [] + seq_lens = [] seq_group_metadata_list = [] block_tables = {0: [1]} for i in range(batch_size): # make sure all tokens fit into one block - prompt_len = i % (model_runner.block_size - 1) + 1 - prompt_lens.append(prompt_len) - seq_data = SequenceData(list(range(prompt_len))) + seqlen = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seqlen) + seq_data = SequenceData(list(range(seqlen))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -43,29 +43,29 @@ def test_prepare_prompt(batch_size): expected_selected_token_indices = [] selected_token_start_idx = 0 - for prompt_len in prompt_lens: + for seqlen in seq_lens: expected_selected_token_indices.append(selected_token_start_idx + - prompt_len - 1) - selected_token_start_idx += prompt_len - (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, + seqlen - 1) + selected_token_start_idx += seqlen + (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _, _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) - assert return_prompt_lens == prompt_lens + assert return_seq_lens == seq_lens assert len(slot_mapping) == len(input_tokens) # Verify input metadata is correct for prompts. device = model_runner.device assert attn_metadata.is_prompt is True - assert torch.allclose(attn_metadata.prompt_lens_tensor, - torch.tensor(prompt_lens, device=device)) - assert attn_metadata.prompt_lens == prompt_lens - assert attn_metadata.max_prompt_len == max(prompt_lens) + assert torch.allclose(attn_metadata.seq_lens_tensor, + torch.tensor(seq_lens, device=device)) + assert attn_metadata.seq_lens == seq_lens + assert attn_metadata.max_seqlen == max(seq_lens) # Test subquery start locs. start_idx = 0 start_loc = [start_idx] - for prompt_len in prompt_lens: - start_idx += prompt_len + for seqlen in seq_lens: + start_idx += seqlen start_loc.append(start_idx) assert torch.allclose( attn_metadata.subquery_start_loc, @@ -75,8 +75,8 @@ def test_prepare_prompt(batch_size): # equivalent to subquery_start_loc. start_idx = 0 seq_start_loc = [start_idx] - for prompt_len in prompt_lens: - start_idx += prompt_len + for seqlen in seq_lens: + start_idx += seqlen seq_start_loc.append(start_idx) assert torch.allclose( @@ -96,18 +96,18 @@ def test_prepare_prompt(batch_size): # Cuda graph should not be used for prerill. 
assert attn_metadata.use_cuda_graph is False - assert len(input_tokens) == sum(prompt_lens) - assert len(input_positions) == sum(prompt_lens) + assert len(input_tokens) == sum(seq_lens) + assert len(input_positions) == sum(seq_lens) torch.testing.assert_close(input_tokens, input_positions) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens, + seq_lens, + query_lens=seq_lens, device=model_runner.device, pin_memory=model_runner.pin_memory) - assert len(input_tokens) == sum(prompt_lens) - assert len(input_positions) == sum(prompt_lens) + assert len(input_tokens) == sum(seq_lens) + assert len(input_positions) == sum(seq_lens) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, device=actual.device, @@ -146,13 +146,13 @@ def test_prepare_decode_cuda_graph(batch_size): lora_config=None) model_runner.set_block_size(16) - prompt_lens = [] + seq_lens = [] seq_group_metadata_list = [] for i in range(batch_size): # make sure all tokens fit into one block - prompt_len = i % (model_runner.block_size - 1) + 1 - prompt_lens.append(prompt_len) - seq_data = list(range(prompt_len)) + seqlen = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seqlen) + seq_data = list(range(seqlen)) seq_data = SequenceData(seq_data) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", @@ -172,14 +172,14 @@ def test_prepare_decode_cuda_graph(batch_size): # Verify input metadata is correct for prompts. device = model_runner.device assert attn_metadata.is_prompt is False - assert attn_metadata.prompt_lens is None - assert attn_metadata.max_prompt_len is None + assert attn_metadata.seq_lens is None + assert attn_metadata.max_seqlen is None assert attn_metadata.subquery_start_loc is None assert attn_metadata.seq_start_loc is None - assert attn_metadata.max_context_len == max(prompt_lens) + assert attn_metadata.max_context_len == max(seq_lens) assert torch.allclose( - attn_metadata.context_lens[:len(prompt_lens)], - torch.tensor(prompt_lens, dtype=torch.int, device=device)) + attn_metadata.context_lens[:len(seq_lens)], + torch.tensor(seq_lens, dtype=torch.int, device=device)) # block table's first index corresponds to each batch, meaning in # decoding it is each token. 
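For reference, the subquery/seq start-loc tensors that test_prepare_prompt checks above are exclusive prefix sums of the per-sequence lengths, built the same way _prepare_prompt builds them with torch.cumsum. A minimal sketch with made-up lengths (not part of the patch):

    import torch

    seq_lens = [3, 5, 2]                        # hypothetical per-sequence lengths
    seq_start_loc = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
    torch.cumsum(torch.tensor(seq_lens, dtype=torch.int32),
                 dim=0, out=seq_start_loc[1:])
    # seq_start_loc is now [0, 3, 8, 10]; for a full (non-chunked) prefill the
    # subquery start locs match it exactly because query_len == seq_len.
    assert seq_start_loc.tolist() == [0, 3, 8, 10]
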
@@ -198,13 +198,13 @@ def test_prepare_decode_cuda_graph(batch_size): # Verify Sampling expected_selected_token_indices = [] selected_token_start_idx = 0 - for prompt_len in prompt_lens: + for seqlen in seq_lens: expected_selected_token_indices.append(selected_token_start_idx) selected_token_start_idx += 1 sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens, + seq_lens, + query_lens=seq_lens, device=model_runner.device, pin_memory=model_runner.pin_memory) actual = sampling_metadata.selected_token_indices @@ -241,14 +241,14 @@ def test_empty_seq_group(): assert attn_metadata is None assert len(slot_mapping) == 0 - (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, + (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _, _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) assert len(input_tokens) == 0 assert len(input_positions) == 0 assert attn_metadata is None assert len(slot_mapping) == 0 - assert len(return_prompt_lens) == 0 + assert len(return_seq_lens) == 0 @pytest.fixture @@ -288,7 +288,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): model_runner.set_block_size(16) # Add prefill requests. - prompt_lens = [] + seq_lens = [] seq_group_metadata_list = [] prefill_metadata_list = [] decode_metadata_list = [] @@ -297,9 +297,9 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): decode_batch_size = batch_size - prefill_batch_size for i in range(prefill_batch_size): # make sure all tokens fit into one block - prompt_len = i % (model_runner.block_size - 1) + 1 - prompt_lens.append(prompt_len) - seq_data = SequenceData(list(range(prompt_len))) + seqlen = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seqlen) + seq_data = SequenceData(list(range(seqlen))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -314,8 +314,8 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): # Add decode requests for i in range(prefill_batch_size, batch_size): # make sure all tokens fit into one block - prompt_len = i % (model_runner.block_size - 1) + 1 - prompt_toks = list(range(prompt_len)) + seqlen = i % (model_runner.block_size - 1) + 1 + prompt_toks = list(range(seqlen)) seq_data = SequenceData(prompt_toks) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", @@ -343,7 +343,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): else: assert attn_metadata.num_decode_tokens == _get_graph_batch_size( decode_batch_size) - assert attn_metadata.num_prefill_tokens == sum(prompt_lens) + assert attn_metadata.num_prefill_tokens == sum(seq_lens) # Verify attn metadata is consistent. We don't need to test individual # values here because they are tested above. diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 12e8c4404b94..d665e7b71a20 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -67,26 +67,26 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, # or all decoding. True if all sequences are prompts. is_prompt: bool # (batch_size,). The prompt length per sequence. None if it is a decoding. - prompt_lens: Optional[List[int]] - # prompt_lens stored as a tensor. - prompt_lens_tensor: Optional[torch.Tensor] + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. 
+ seq_lens_tensor: Optional[torch.Tensor] - # NOTE(sang): Definition of context_len, subquery_len, and seqlen. + # NOTE(sang): Definition of context_len, query_len, and seqlen. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| # |---------- context_len ----------| # |-------------------- seqlen ----------------------| - # |- subquery_len -| + # |-- query_len ---| # WARNING(sang): context_len has different definition depending on if it is # prefill vs decoding. When it is prefill, it doesn't include new tokens. # When it is for decoding, it includes a new token. - # Maximum subquery length in the batch. - max_subquery_len: Optional[int] - # Maximum prompt length in the batch. - max_prompt_len: Optional[int] + # Maximum query length in the batch. + max_query_len: Optional[int] + # Maximum sequence length in the batch. + max_seqlen: Optional[int] # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. @@ -223,8 +223,8 @@ def forward( v=value, cu_seqlens_q=prefill_meta.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prompt_len, - max_seqlen_k=prefill_meta.max_prompt_len, + max_seqlen_q=prefill_meta.max_seqlen, + max_seqlen_k=prefill_meta.max_seqlen, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, @@ -245,9 +245,9 @@ def forward( value_cache, prefill_meta.block_tables, prefill_meta.subquery_start_loc, - prefill_meta.prompt_lens_tensor, + prefill_meta.seq_lens_tensor, prefill_meta.context_lens, - prefill_meta.max_subquery_len, + prefill_meta.max_query_len, self.alibi_slopes, ) if decode_meta := attn_metadata.decode_metadata: @@ -257,8 +257,8 @@ def forward( key_cache, value_cache, decode_meta.block_tables, - decode_meta.context_lens, - decode_meta.max_context_len, + decode_meta.seqlens, + decode_meta.max_seqlen, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 934acea0a3d6..4c76d7ab384c 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -65,26 +65,26 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, # or all decoding. True if all sequences are prompts. is_prompt: bool # (batch_size,). The prompt length per sequence. None if it is a decoding. - prompt_lens: Optional[List[int]] - # prompt_lens stored as a tensor. - prompt_lens_tensor: Optional[torch.Tensor] + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] - # NOTE(sang): Definition of context_len, subquery_len, and seqlen. + # NOTE(sang): Definition of context_len, query_len, and seqlen. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| # |---------- context_len ----------| # |-------------------- seqlen ----------------------| - # |- subquery_len -| + # |-- query_len ---| # WARNING(sang): context_len has different definition depending on if it is # prefill vs decoding. When it is prefill, it doesn't include new tokens. # When it is for decoding, it includes a new token. - # Maximum subquery length in the batch. - max_subquery_len: Optional[int] - # Maximum prompt length in the batch. 
- max_prompt_len: Optional[int] + # Maximum query length in the batch. + max_query_len: Optional[int] + # Maximum sequence length in the batch. + max_seqlen: Optional[int] # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. @@ -248,7 +248,7 @@ def forward( if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - assert prefill_meta.prompt_lens is not None + assert prefill_meta.seq_lens is not None if kv_cache is None or prefill_meta.block_tables.numel() == 0: # triton attention # When block_tables are not filled, it means q and k are the @@ -261,8 +261,8 @@ def forward( None, prefill_meta.seq_start_loc, prefill_meta.seq_start_loc, - prefill_meta.max_prompt_len, - prefill_meta.max_prompt_len, + prefill_meta.max_seqlen, + prefill_meta.max_seqlen, True, self.scale, ) @@ -275,7 +275,7 @@ def forward( query, key, value, - prefill_meta.prompt_lens, + prefill_meta.seq_lens, self.scale, ) else: @@ -285,8 +285,8 @@ def forward( v=value, cu_seqlens_q=prefill_meta.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prompt_len, - max_seqlen_k=prefill_meta.max_prompt_len, + max_seqlen_q=prefill_meta.max_seqlen, + max_seqlen_k=prefill_meta.max_seqlen, softmax_scale=self.scale, causal=True, ) @@ -304,9 +304,9 @@ def forward( value_cache, prefill_meta.block_tables, prefill_meta.subquery_start_loc, - prefill_meta.prompt_lens_tensor, + prefill_meta.seq_lens_tensor, prefill_meta.context_lens, - prefill_meta.max_subquery_len, + prefill_meta.max_query_len, self.alibi_slopes, ) @@ -317,8 +317,8 @@ def forward( key_cache, value_cache, decode_meta.block_tables, - decode_meta.context_lens, - decode_meta.max_context_len, + decode_meta.seqlens, + decode_meta.max_seqlen, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -334,13 +334,13 @@ def _naive_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - prompt_lens: List[int], + seq_lens: List[int], scale: float, ) -> torch.Tensor: output = torch.empty_like(query) start = 0 - for _, prompt_len in enumerate(prompt_lens): - end = start + prompt_len + for _, seqlen in enumerate(seq_lens): + end = start + seqlen out = _naive_masked_attention( query[start:end], key[start:end], @@ -349,7 +349,7 @@ def _naive_attention( ) # TODO(woosuk): Unnecessary copy. Optimize. output[start:end].copy_(out) - start += prompt_len + start += seqlen return output diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 55a7ce59ac6e..a1710994e6a9 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -58,7 +58,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata, # or all decoding. True if all sequences are prompts. is_prompt: bool slot_mapping: torch.Tensor - prompt_lens: Optional[List[int]] + seq_lens: Optional[List[int]] def __post_init__(self): # Set during the execution of the first attention op. 
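The _naive_attention loop above (and the per-prompt loops in the other backends touched by this patch) walk a packed token layout in which each sequence owns a contiguous slice. A small sketch of that slicing pattern, with hypothetical shapes:

    import torch

    def split_packed(packed: torch.Tensor, seq_lens: list):
        # Each sequence owns a contiguous [start:end) slice of the packed
        # (total_tokens, num_heads, head_size) tensor.
        chunks, start = [], 0
        for seqlen in seq_lens:
            chunks.append(packed[start:start + seqlen])
            start += seqlen
        return chunks

    q = torch.randn(10, 4, 8)                   # 3 + 5 + 2 tokens packed together
    assert [c.shape[0] for c in split_packed(q, [3, 5, 2])] == [3, 5, 2]
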
@@ -136,7 +136,7 @@ def forward( kv_scale) if attn_metadata.is_prompt: - assert attn_metadata.prompt_lens is not None + assert attn_metadata.seq_lens is not None if (kv_cache is None or attn_metadata.block_tables.numel() == 0): if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) @@ -147,13 +147,13 @@ def forward( if self.alibi_slopes is not None: att_masks = _make_alibi_bias( self.alibi_slopes, query.dtype, - attn_metadata.prompt_lens) # type: ignore + attn_metadata.seq_lens) # type: ignore elif self.sliding_window is not None: att_masks = _make_sliding_window_bias( - attn_metadata.prompt_lens, self.sliding_window, + attn_metadata.seq_lens, self.sliding_window, query.dtype) # type: ignore else: - att_masks = [None] * len(attn_metadata.prompt_lens) + att_masks = [None] * len(attn_metadata.seq_lens) attn_metadata.attn_bias = att_masks query = query.movedim(0, query.dim() - 2) @@ -164,9 +164,9 @@ def forward( output = torch.empty( (num_tokens, self.num_heads, self.head_size), dtype=query.dtype) - for prompt_len, mask in zip(attn_metadata.prompt_lens, + for seqlen, mask in zip(attn_metadata.seq_lens, attn_metadata.attn_bias): - end = start + prompt_len + end = start + seqlen sub_out = scaled_dot_product_attention( query[:, start:end, :], key[:, start:end, :], @@ -189,8 +189,8 @@ def forward( key_cache, value_cache, attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, + attn_metadata.seq_lens, + attn_metadata.max_seqlen, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -205,13 +205,13 @@ def forward( def _make_alibi_bias( alibi_slopes: torch.Tensor, dtype: torch.dtype, - prompt_lens: List[int], + seq_lens: List[int], ) -> List[torch.Tensor]: attn_biases = [] - for prompt_len in prompt_lens: - bias = torch.arange(prompt_len, dtype=dtype) + for seqlen in seq_lens: + bias = torch.arange(seqlen, dtype=dtype) # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(prompt_len, 1)` + # `bias = bias[None, :].repeat(seqlen, 1)` # here. We find that both biases give the same results, but # the bias below more accurately follows the original ALiBi # paper. @@ -221,7 +221,7 @@ def _make_alibi_bias( bias = bias[None, :].repeat((num_heads, 1, 1)) bias.mul_(alibi_slopes[:, None, None]) inf_mask = torch.empty( - (1, prompt_len, prompt_len), + (1, seqlen, seqlen), dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) attn_biases.append((bias + inf_mask).to(dtype)) @@ -229,14 +229,14 @@ def _make_alibi_bias( def _make_sliding_window_bias( - prompt_lens: List[int], + seq_lens: List[int], window_size: Optional[int], dtype: torch.dtype, ) -> List[torch.Tensor]: attn_biases = [] - for prompt_len in prompt_lens: + for seqlen in seq_lens: tensor = torch.full( - (1, prompt_len, prompt_len), + (1, seqlen, seqlen), dtype=dtype, fill_value=1, ) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index e5a2e34f8e6d..30a6dd3410ea 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -53,6 +53,20 @@ def copy_blocks( ) -> None: PagedAttention.copy_blocks(kv_caches, src_to_dists) +class AttentionMetadataBuilder: + def add_sequence_group(self, seq_group): + pass + + def build(self) -> AttentionMetadata: + pass + +def prepare_input(seq_group_list): + builder = AttentionMetadataBuilder() + for seq_group in seq_group_list: + # update positions, input_tokens, etc. 
+ builder.add_seq_group(seq_group) + attn_metadata = builder.build() + @dataclass class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): @@ -67,27 +81,36 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): # or all decoding. True if all sequences are prompts. is_prompt: bool # (batch_size,). The prompt length per sequence. None if it is a decoding. - prompt_lens: Optional[List[int]] - # prompt_lens stored as a tensor. - prompt_lens_tensor: Optional[torch.Tensor] + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] - # NOTE(sang): Definition of context_len, subquery_len, and seqlen. + # NOTE(sang): Definition of context_len, query_len, and seqlen. + # Before + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len/computed_len ----------| -> context len for prefill, computed len for prefix caching + # |------------ prompt_len/context_len --------------| -> context len for decode, prompt len for prefill + # |- subquery_len -| + + # After # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| # |---------- context_len ----------| # |-------------------- seqlen ----------------------| - # |- subquery_len -| + # |-- query_len ---| # WARNING(sang): context_len has different definition depending on if it is # prefill vs decoding. When it is prefill, it doesn't include new tokens. # When it is for decoding, it includes a new token. - # Maximum subquery length in the batch. - max_subquery_len: Optional[int] + # Maximum query length in the batch. + max_query_len: Optional[int] # FIXME: It is for flash attn. - # Maximum prompt length in the batch. - max_prompt_len: Optional[int] + # Maximum sequence length in the batch. + max_seqlen: Optional[int] # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. @@ -242,9 +265,9 @@ def forward( value_cache, prefill_meta.block_tables, prefill_meta.subquery_start_loc, - prefill_meta.prompt_lens_tensor, + prefill_meta.seq_lens_tensor, prefill_meta.context_lens, - prefill_meta.max_subquery_len, + prefill_meta.max_query_len, self.alibi_slopes, ) assert output[:num_prefill_tokens].shape == out.shape @@ -256,8 +279,8 @@ def forward( key_cache, value_cache, decode_meta.block_tables, - decode_meta.prompt_lens_tensor, - decode_meta.max_context_len, + decode_meta.seqlens, + decode_meta.max_seqlen, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -288,7 +311,7 @@ def _run_memory_efficient_xformers_forward( value: shape = [num_prefill_tokens, num_kv_heads, head_size] attn_metadata: Metadata for attention. """ - assert attn_metadata.prompt_lens is not None + assert attn_metadata.seq_lens is not None original_query = query if self.num_kv_heads != self.num_heads: # GQA/MQA requires the shape [B, M, G, H, K]. 
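A worked instance of the renamed length definitions in the NOTE above (numbers are hypothetical):

    # One chunked-prefill step for a sequence whose prompt has 12 tokens,
    # 8 of which were computed in earlier steps.
    context_len = 8                   # tokens already held in the KV cache
    query_len = 4                     # new tokens processed in this step
    seq_len = context_len + query_len
    assert seq_len == 12              # seq_len covers all tokens seen so far
    # For a plain (non-chunked, non-prefix-cached) prefill, context_len == 0
    # and query_len == seq_len, which is why max_seqlen replaces max_prompt_len.
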
@@ -309,7 +332,7 @@ def _run_memory_efficient_xformers_forward( if attn_metadata.attn_bias is None: if self.alibi_slopes is None: attn_bias = BlockDiagonalCausalMask.from_seqlens( - attn_metadata.prompt_lens) + attn_metadata.seq_lens) if self.sliding_window is not None: attn_bias = attn_bias.make_local_attention( self.sliding_window) @@ -317,7 +340,7 @@ def _run_memory_efficient_xformers_forward( else: attn_metadata.attn_bias = _make_alibi_bias( self.alibi_slopes, self.num_kv_heads, query.dtype, - attn_metadata.prompt_lens) + attn_metadata.seq_lens) # No alibi slopes. # TODO(woosuk): Too many view operations. Let's try to reduce @@ -342,8 +365,8 @@ def _run_memory_efficient_xformers_forward( # one. This is inefficient, especially when we have many short prompts. output = torch.empty_like(original_query) start = 0 - for i, prompt_len in enumerate(attn_metadata.prompt_lens): - end = start + prompt_len + for i, seqlen in enumerate(attn_metadata.seq_lens): + end = start + seqlen out = xops.memory_efficient_attention_forward( query[None, start:end], key[None, start:end], @@ -353,7 +376,7 @@ def _run_memory_efficient_xformers_forward( scale=self.scale) # TODO(woosuk): Unnecessary copy. Optimize. output[start:end].copy_(out.view_as(original_query[start:end])) - start += prompt_len + start += seqlen return output @@ -361,13 +384,13 @@ def _make_alibi_bias( alibi_slopes: torch.Tensor, num_kv_heads: int, dtype: torch.dtype, - prompt_lens: List[int], + seq_lens: List[int], ) -> LowerTriangularMaskWithTensorBias: attn_biases = [] - for prompt_len in prompt_lens: - bias = torch.arange(prompt_len, dtype=dtype) + for seqlen in seq_lens: + bias = torch.arange(seqlen, dtype=dtype) # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(prompt_len, 1)` + # `bias = bias[None, :].repeat(seqlen, 1)` # here. We find that both biases give the same results, but # the bias below more accurately follows the original ALiBi # paper. @@ -375,16 +398,16 @@ def _make_alibi_bias( # element. bias = bias[None, :] - bias[:, None] - padded_len = (prompt_len + 7) // 8 * 8 + padded_len = (seqlen + 7) // 8 * 8 num_heads = alibi_slopes.shape[0] bias = torch.empty( 1, # batch size num_heads, - prompt_len, + seqlen, padded_len, device=alibi_slopes.device, dtype=dtype, - )[:, :, :, :prompt_len].copy_(bias) + )[:, :, :, :seqlen].copy_(bias) bias.mul_(alibi_slopes[:, None, None]) if num_heads != num_kv_heads: bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index cd0690a4ba95..987d00c939d2 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -13,12 +13,11 @@ @dataclass class PagedAttentionMetadata: """Metadata for PagedAttention.""" - # (batch_size,). The length of context (tokens stored in KV cache) per - # sequence. WARNING: When it is a prefill request, it doesn't include new - # tokens. When it is for decoding, it includes a new token. - context_lens: Optional[torch.Tensor] - # Maximum context length in the batch. - max_context_len: Optional[int] + # (batch_size,). The length of sequences (entire tokens seen so far) per + # sequence. + seqlens: Optional[torch.Tensor] + # Maximum sequence length in the batch. + max_seqlen: Optional[int] # (batch_size, max_blocks_per_seq). # Block addresses per sequence. 
(Seq id -> list of physical block) # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks @@ -85,7 +84,7 @@ def forward_decode( key_cache: torch.Tensor, value_cache: torch.Tensor, block_tables: torch.Tensor, - context_lens: torch.Tensor, + seqlens: torch.Tensor, max_context_len: int, kv_cache_dtype: str, num_kv_heads: int, @@ -118,7 +117,7 @@ def forward_decode( num_kv_heads, scale, block_tables, - context_lens, + seqlens, block_size, max_context_len, alibi_slopes, @@ -150,7 +149,7 @@ def forward_decode( num_kv_heads, scale, block_tables, - context_lens, + seqlens, block_size, max_context_len, alibi_slopes, @@ -168,9 +167,9 @@ def forward_prefix( value_cache: torch.Tensor, block_tables: torch.Tensor, subquery_start_loc: torch.Tensor, - prompt_lens_tensor: torch.Tensor, + seq_lens_tensor: torch.Tensor, context_lens: torch.Tensor, - max_subquery_len: int, + max_query_len: int, alibi_slopes: Optional[torch.Tensor], ) -> torch.Tensor: output = torch.empty_like(query) @@ -184,9 +183,9 @@ def forward_prefix( block_tables, # subquery_start_loc is (batch_size + 1,) subquery_start_loc[:-1], - prompt_lens_tensor, + seq_lens_tensor, context_lens, - max_subquery_len, + max_query_len, alibi_slopes, ) return output diff --git a/vllm/config.py b/vllm/config.py index a5512c657e03..457d4dce8f51 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -68,7 +68,10 @@ class ModelConfig: If False, we will use CUDA graph and eager execution in hybrid. max_context_len_to_capture: Maximum context len covered by CUDA graphs. When a sequence has context length larger than this, we fall back - to eager mode. + to eager mode (DEPRECATED). + max_seqlen_to_capture: Maximum sequence len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode skip_tokenizer_init: If true, skip initialization of tokenizer and detokenizer. """ @@ -89,6 +92,7 @@ def __init__( quantization_param_path: Optional[str] = None, enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, + max_seqlen_to_capture: Optional[int] = None, max_logprobs: int = 5, skip_tokenizer_init: bool = False, ) -> None: @@ -104,6 +108,7 @@ def __init__( self.quantization_param_path = quantization_param_path self.enforce_eager = enforce_eager self.max_context_len_to_capture = max_context_len_to_capture + self.max_seqlen_to_capture = max_seqlen_to_capture or max_context_len_to_capture self.max_logprobs = max_logprobs self.skip_tokenizer_init = skip_tokenizer_init @@ -195,9 +200,9 @@ def _verify_quantization(self) -> None: "non-quantized models.", self.quantization) def _verify_cuda_graph(self) -> None: - if self.max_context_len_to_capture is None: - self.max_context_len_to_capture = self.max_model_len - self.max_context_len_to_capture = min(self.max_context_len_to_capture, + if self.max_seqlen_to_capture is None: + self.max_seqlen_to_capture = self.max_model_len + self.max_seqlen_to_capture = min(self.max_seqlen_to_capture, self.max_model_len) def verify_with_parallel_config( @@ -754,8 +759,8 @@ def maybe_create_spec_config( max_model_len=None, quantization=draft_quantization, enforce_eager=target_model_config.enforce_eager, - max_context_len_to_capture=target_model_config. - max_context_len_to_capture, + max_seqlen_to_capture=target_model_config. 
+ max_seqlen_to_capture, max_logprobs=target_model_config.max_logprobs, ) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index bd6437ee44c2..8955a516d393 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -45,6 +45,7 @@ class EngineArgs: quantization: Optional[str] = None enforce_eager: bool = False max_context_len_to_capture: int = 8192 + max_seqlen_to_capture: int = 8192 disable_custom_all_reduce: bool = False tokenizer_pool_size: int = 0 tokenizer_pool_type: str = "ray" @@ -319,6 +320,12 @@ def add_cli_args( type=int, default=EngineArgs.max_context_len_to_capture, help='Maximum context length covered by CUDA ' + 'graphs. When a sequence has context length (DEPRECATED)' + 'larger than this, we fall back to eager mode.') + parser.add_argument('--max-seqlen-to-capture', + type=int, + default=EngineArgs.max_seqlen_to_capture, + help='Maximum sequence length covered by CUDA ' 'graphs. When a sequence has context length ' 'larger than this, we fall back to eager mode.') parser.add_argument('--disable-custom-all-reduce', @@ -475,7 +482,7 @@ def create_engine_config(self, ) -> EngineConfig: self.trust_remote_code, self.dtype, self.seed, self.revision, self.code_revision, self.tokenizer_revision, self.max_model_len, self.quantization, self.quantization_param_path, - self.enforce_eager, self.max_context_len_to_capture, + self.enforce_eager, self.max_context_len_to_capture, self.max_seqlen_to_capture, self.max_logprobs, self.skip_tokenizer_init) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b022707794a7..4934cd09e7c8 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -69,6 +69,9 @@ class LLM: disable CUDA graph and always execute the model in eager mode. If False, we will use CUDA graph and eager execution in hybrid. max_context_len_to_capture: Maximum context len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode (DEPRECATED). + max_seqlen_to_capture: Maximum sequence len covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode. disable_custom_all_reduce: See ParallelConfig @@ -91,6 +94,7 @@ def __init__( swap_space: int = 4, enforce_eager: bool = False, max_context_len_to_capture: int = 8192, + max_seqlen_to_capture: int = 8192, disable_custom_all_reduce: bool = False, **kwargs, ) -> None: @@ -112,6 +116,7 @@ def __init__( swap_space=swap_space, enforce_eager=enforce_eager, max_context_len_to_capture=max_context_len_to_capture, + max_seqlen_to_capture=max_seqlen_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, **kwargs, ) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index d79c99e5d0a4..2de7763605df 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -1033,8 +1033,8 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: assert seq_group.is_prompt, ( "Caller should ensure the sequence group is in a prefill stage.") seq_ids = seq_group.seq_ids - subquery_len = seq_group.subquery_len - assert subquery_len is not None + query_len = seq_group.query_len + assert query_len is not None # prompt has only 1 seq id. 
assert len(seq_ids) == 1 seq_data = seq_group.seq_data[seq_ids[0]] @@ -1042,7 +1042,7 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: prompt_tokens = seq_data.prompt_token_ids # +1 because we are looking for a next prompt token. next_token_index_start = computed_len + 1 - next_token_index_end = min(computed_len + subquery_len + 1, + next_token_index_end = min(computed_len + query_len + 1, len(prompt_tokens)) next_prompt_tokens = prompt_tokens[ next_token_index_start:next_token_index_end] diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 12156b2ba1aa..561952d11fd9 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -21,12 +21,13 @@ class SequenceGroupToSample: sampling_params: SamplingParams # seq_id -> sequence data. seq_data: Dict[int, SequenceData] - # The length of the prompt of the sequence group. None if it is in a decode + # The length of the sequence of the sequence group. None if it is in a decode # stage. - prompt_len: Optional[int] + seqlen: Optional[int] # The length of the query tokens to compute in the current step. None if it - # is in a decode stage. The length of subquery_len <= prompt_len. - subquery_len: Optional[int] + # is in a decode stage. The length of query_len <= seqlen if chunked prefill + # is enabled. + query_len: Optional[int] # A random number generator for sampling. generator: Optional[torch.Generator] # True if the sequence group is in prefill stage. False if it is in a @@ -46,8 +47,8 @@ def __post_init__(self): if len(self.prompt_logprob_indices) > 0: assert self.sampling_params.prompt_logprobs is not None if self.is_prompt: - assert self.prompt_len is not None - assert self.subquery_len is not None + assert self.seqlen is not None + assert self.query_len is not None class SamplingMetadata: @@ -94,8 +95,8 @@ def __init__( @staticmethod def prepare( seq_group_metadata_list: List[SequenceGroupMetadata], - prompt_lens: List[int], - subquery_lens: Optional[List[int]], + seq_lens: List[int], + query_lens: Optional[List[int]], device: str, pin_memory: bool, ) -> "SamplingMetadata": @@ -104,8 +105,8 @@ def prepare( selected_token_indices, categorized_sample_indices, num_prompts, - ) = _prepare_seq_groups(seq_group_metadata_list, prompt_lens, - subquery_lens, device) + ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, + query_lens, device) selected_token_indices = async_tensor_h2d(selected_token_indices, dtype=torch.long, target_device=device, @@ -137,8 +138,8 @@ def __repr__(self) -> str: def _prepare_seq_groups( seq_group_metadata_list: List[SequenceGroupMetadata], - prompt_lens: List[int], - subquery_lens: Optional[List[int]], + seq_lens: List[int], + query_lens: Optional[List[int]], device: str, ) -> Tuple[List[SequenceGroupToSample], List[int], Dict[ SamplingType, List[Tuple[int, int]]], int]: @@ -146,9 +147,9 @@ def _prepare_seq_groups( Args: seq_group_metadata_list: A list of sequence group to batch. - prompt_lens: A list of prompt lens per sequence group. + seq_lens: A list of sequence lens per sequence group. Index of prompt len should match with seq_group_metadata_list. - subquery_lens: A list of query lengths. Prompt lens include the length + query_lens: A list of query lengths. Prompt lens include the length of entire prompt tokens, and it could be shorter. device: A device to use for random number generator, `SequenceGroupToSample.generator`. 
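The prompt-logprob indexing in _get_next_prompt_tokens above can be traced with hypothetical values; the +1 shift comes from needing the logprob of the token that follows each prompt position:

    prompt_tokens = [101, 7, 42, 9, 55, 88]      # made-up token ids
    computed_len, query_len = 2, 3               # 2 tokens done earlier, 3 now
    next_token_index_start = computed_len + 1
    next_token_index_end = min(computed_len + query_len + 1, len(prompt_tokens))
    assert prompt_tokens[next_token_index_start:next_token_index_end] == [9, 55, 88]
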
@@ -189,8 +190,8 @@ def _prepare_seq_groups( is_prompt = seq_group_metadata.is_prompt generator: Optional[torch.Generator] = None # If the current seq group is in decode stage, it is None. - prompt_len: Optional[int] = None - subquery_len: Optional[int] = None + seqlen: Optional[int] = None + query_len: Optional[int] = None prompt_logprob_indices: List[int] = [] sample_indices: List[int] = [] do_sample = seq_group_metadata.do_sample @@ -203,12 +204,12 @@ def _prepare_seq_groups( num_prompts += 1 num_prefill_sample = len(seq_ids) assert num_prefill_sample == 1 - assert subquery_lens is not None and prompt_lens is not None - subquery_len, prompt_len = subquery_lens[i], prompt_lens[i] + assert query_lens is not None and seq_lens is not None + query_len, seqlen = query_lens[i], seq_lens[i] # If we need sampling, exclude num_prefill_sample tokens from # prompt logprob. - prompt_logprob_len = (subquery_len - num_prefill_sample - if do_sample else subquery_len) + prompt_logprob_len = (query_len - num_prefill_sample + if do_sample else query_len) sample_len = num_prefill_sample if do_sample else 0 else: # Decode @@ -267,8 +268,8 @@ def sample(logits): seq_ids=seq_ids, sampling_params=sampling_params, seq_data=seq_group_metadata.seq_data, - prompt_len=prompt_len, - subquery_len=subquery_len, + seqlen=seqlen, + query_len=query_len, generator=generator, is_prompt=is_prompt, prompt_logprob_indices=list(prompt_logprob_indices), @@ -367,8 +368,8 @@ def from_sampling_metadata( and sampling_params.prompt_logprobs is not None): # For tokens in the prompt that we only need to get # their logprobs - subquery_len = seq_group.subquery_len - assert subquery_len is not None + query_len = seq_group.query_len + assert query_len is not None prefill_len = len(seq_group.prompt_logprob_indices) temperatures += [temperature] * prefill_len top_ps += [top_p] * prefill_len @@ -397,8 +398,8 @@ def from_sampling_metadata( if is_prompt: prompt_best_of.append(sampling_params.best_of) - subquery_len = seq_group.subquery_len - assert subquery_len is not None + query_len = seq_group.query_len + assert query_len is not None for seq_id in seq_ids: seq_data = seq_group.seq_data[seq_id] diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 34d7d3dffea1..2c6acded05b8 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -80,7 +80,7 @@ def _prepare_prompt( input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] - prompt_lens: List[int] = [] + seq_lens: List[int] = [] multi_modal_input_list: List[torch.Tensor] = [] for seq_group_metadata in seq_group_metadata_list: @@ -92,15 +92,15 @@ def _prepare_prompt( seq_data = seq_group_metadata.seq_data[seq_id] prompt_tokens = seq_data.get_token_ids() computed_len = seq_data.get_num_computed_tokens() - prompt_len = len(prompt_tokens) + seqlen = len(prompt_tokens) - prompt_lens.append(prompt_len) # Prompt token num + seq_lens.append(seqlen) # Prompt token num input_tokens.extend(prompt_tokens) # Token ids # Token position ids # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. - input_positions.extend(list(range(computed_len, prompt_len))) + input_positions.extend(list(range(computed_len, seqlen))) if seq_group_metadata.multi_modal_data: multi_modal_input_list.append( @@ -109,15 +109,15 @@ def _prepare_prompt( # Compute the slot mapping. 
block_table = seq_group_metadata.block_tables[seq_id] # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, - # where start_idx is max(0, prompt_len - sliding_window). + # where start_idx is max(0, seqlen - sliding_window). # For example, if the prompt len is 10, sliding window is 8, and # block size is 4, the first two tokens are masked and the slot # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. start_idx = 0 if self.sliding_window is not None: - start_idx = max(0, prompt_len - self.sliding_window) + start_idx = max(0, seqlen - self.sliding_window) - for i in range(computed_len, prompt_len): + for i in range(computed_len, seqlen): if i < start_idx: slot_mapping.append(_PAD_SLOT_ID) continue @@ -151,8 +151,8 @@ def _prepare_prompt( attn_metadata = self.attn_backend.make_metadata( is_prompt=True, - prompt_lens=prompt_lens, - num_prefills=len(prompt_lens), + seq_lens=seq_lens, + num_prefills=len(seq_lens), num_prefill_tokens=num_prompt_tokens, num_decode_tokens=0, prefill_metadata=None, @@ -163,7 +163,7 @@ def _prepare_prompt( slot_mapping=slot_mapping, kv_cache_dtype=self.kv_cache_dtype, ) - return (input_tokens, input_positions, attn_metadata, prompt_lens, + return (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_input) def _prepare_decode( @@ -236,7 +236,7 @@ def _prepare_decode( attn_metadata = self.attn_backend.make_metadata( is_prompt=False, slot_mapping=slot_mapping, - prompt_lens=None, + seq_lens=None, num_prefill_tokens=0, num_decode_tokens=len(input_tokens), max_context_len=max_context_len, @@ -265,20 +265,20 @@ def prepare_input_tensors( is_prompt = seq_group_metadata_list[0].is_prompt # Prepare input tensors. if is_prompt: - (input_tokens, input_positions, attn_metadata, prompt_lens, + (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_input ) = self._prepare_prompt(seq_group_metadata_list) else: (input_tokens, input_positions, attn_metadata) = self._prepare_decode(seq_group_metadata_list) - prompt_lens = [] + seq_lens = [] sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - # subquery_lens is not needed if chunked prefill is not + seq_lens, + # query_lens is not needed if chunked prefill is not # supported. Since CPU worker doesn't support chunked prefill - # just use prompt_lens instead. - prompt_lens, + # just use seq_lens instead. + seq_lens, self.device, pin_memory=False) # Broadcast the metadata. @@ -300,7 +300,7 @@ def prepare_input_tensors( sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, - prompt_lens=None, + seq_lens=None, selected_token_indices=selected_token_indices, categorized_sample_indices=None, generators=None, diff --git a/vllm/worker/model_input.py b/vllm/worker/model_input.py new file mode 100644 index 000000000000..b2fb83a46fd5 --- /dev/null +++ b/vllm/worker/model_input.py @@ -0,0 +1,290 @@ +import torch +from dataclasses import dataclass +from typing import List, Set, Optional, Type + +from vllm.vllm.sequence import SequenceGroupMetadata +from vllm.attention.backends.abstract import AttentionMetadata, AttentionBackend +from vllm.lora.request import LoRARequest +from vllm.lora.layers import LoRAMapping +from vllm.config import SchedulerConfig, LoRAConfig, VisionLanguageConfig +from vllm.utils import make_tensor_with_pad + + +_PAD_SLOT_ID = -1 + + +@dataclass +class GpuModelInput: + """Input to run a model. + + Input tensors include inputs across multiple sequence groups. + It assumes inputs are ordered by prefill -> decode sequences. 
+ """ + # (num_tokens,) 1D Flattened input token IDs. + input_tokens: torch.Tensor + # (num_tokens,) Positions of a token in its sequence. Used for RoPE. + input_positions: torch.Tensor + # Attention metadata to run attention kernels. + attn_metadata: AttentionMetadata + # (batch_size,) A sequence length for each sequence group in a batch. + seq_lens: List[int] + # (batch_size,) A query length for eaach sequence group in a batch. + query_lens: List[int] + # Set of lora requests. + lora_requests: Set[LoRARequest] + # Inputs used for multi modality. + multi_modal_input: Optional[torch.Tensor] + # (num_tokens,) A page index per token. Each slot index is flattened. For + # example, if slot mapping is 15 and block size is 8, it means block index + # 1 and offset 3. + slot_mapping: torch.Tensor + # Lora mapping. None if lora is not used. + lora_mapping: Optional[LoRAMapping] + + @classmethod + def from_sequence_groups( + cls, + seq_group_metadata_list: List[SequenceGroupMetadata], + scheduler_config: SchedulerConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + block_size: int, + device: str, + attn_backend: Type[AttentionBackend], + sliding_window: Optional[int]) -> "GpuModelInput": + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + lora_index_mapping: List[int] = [] + lora_prompt_mapping: List[int] = [] + lora_requests: Set[LoRARequest] = set() + + seq_lens: List[int] = [] + context_lens: List[int] = [] + query_lens: List[int] = [] + prefix_block_tables: List[List[int]] = [] + multi_modal_input_list: List[torch.Tensor] = [] + + is_prompt = False + for seq_group_metadata in seq_group_metadata_list: + # assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + # assert len(seq_ids) == 1 + seq_id = seq_ids[0] + is_prompt = seq_group_metadata.is_prompt + + computed_block_nums = seq_group_metadata.computed_block_nums + if (scheduler_config is not None + and scheduler_config.chunked_prefill_enabled + and not (computed_block_nums is None + or computed_block_nums == [])): + raise RuntimeError( + "chunked prefill cannot be used with prefix caching " + "now.") + + token_chunk_size = seq_group_metadata.token_chunk_size + seq_data = seq_group_metadata.seq_data[seq_id] + computed_len = seq_data.get_num_computed_tokens() + # We should use get_len here because in case of preemption + # it contains output tokens. + prefill_end = min(seq_data.get_len(), + computed_len + token_chunk_size) + prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end] + seqlen = prefill_end + seq_lens.append(seqlen) + + # NOTE: This only works for oooooooxxx style attention. + if computed_block_nums is not None and len( + computed_block_nums) > 0 and sliding_window is None: + # Prefix is not supported with sliding_window + computed_len = len(computed_block_nums) * block_size + prompt_tokens = prompt_tokens[computed_len:] + prefix_block_tables.append(computed_block_nums) + elif scheduler_config.chunked_prefill_enabled or not is_prompt: + if seq_group_metadata.block_tables is not None: + # Prefill has chunked before. + block_table = seq_group_metadata.block_tables[seq_id] + prefix_block_tables.append(block_table) + else: + # The first prefill. + prefix_block_tables.append([]) + else: + prefix_block_tables.append([]) + # Right now, prefill start is always 0. However, this + # assumption can be changed once chunked prefill is introduced. 
+ assert computed_len == 0 + + # actual prompt lens + context_lens.append(computed_len) + query_lens.append(seqlen - computed_len) + + input_tokens.extend(prompt_tokens) + # NOTE(woosuk): Here we assume that the first token in the prompt + # is always the first token in the sequence. + input_positions.extend(list(range(computed_len, prefill_end))) + lora_id = seq_group_metadata.lora_int_id + + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + + lora_index_mapping += [lora_id] * (seqlen - computed_len) + lora_prompt_mapping.extend( + [lora_id] * + (seqlen - computed_len + if seq_group_metadata.sampling_params.prompt_logprobs else 1)) + + if seq_group_metadata.multi_modal_data: + multi_modal_input_list.append( + seq_group_metadata.multi_modal_data.data) + + if seq_group_metadata.block_tables is None: + # During memory profiling, the block tables are not initialized + # yet. In this case, we just use a dummy slot mapping. + slot_mapping.extend([_PAD_SLOT_ID] * seqlen) + continue + + # Compute the slot mapping. + block_table = seq_group_metadata.block_tables[seq_id] + # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, + # where start_idx is max(0, seqlen - sliding_window). + # For example, if the prompt len is 10, sliding window is 8, and + # block size is 4, the first two tokens are masked and the slot + # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + start_idx = 0 + if sliding_window is not None: + assert computed_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention") + start_idx = max(0, seqlen - sliding_window) + + for i in range(computed_len, prefill_end): + if i < start_idx: + slot_mapping.append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // block_size] + block_offset = i % block_size + slot = block_number * block_size + block_offset + slot_mapping.append(slot) + + max_query_len = max(query_lens) + max_seqlen = max(seq_lens) + assert max_query_len > 0 + + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int, + device=device) + + if multi_modal_input_list: + assert vision_language_config, ( + "Multi-modal inputs are only supported by " + "vision language models.") + multi_modal_input = torch.cat(multi_modal_input_list, + dim=0).to(device) + else: + multi_modal_input = None + + # Prepare prefix block tables + max_prompt_block_table_len = max(len(t) for t in prefix_block_tables) + block_tables = make_tensor_with_pad( + prefix_block_tables, + max_len=max_prompt_block_table_len, + pad=0, + dtype=torch.int, + device=device, + ) + + # Query length can be shorter than key (i.e., prompt) when prefill + # is chunked or prefix cached. 
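The slot-mapping loop above turns each token position into a flat slot index, block_number * block_size + offset, and pads tokens that fall outside the sliding window with _PAD_SLOT_ID. A standalone sketch, not part of the patch, that reproduces the [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1] example from the comments (the block table [0, 1, 0] is an assumed toy value):

    _PAD_SLOT_ID = -1

    def compute_slot_mapping(block_table, seqlen, block_size, sliding_window=None):
        # Tokens older than the sliding window get the pad slot; the rest map
        # into the paged KV cache as block_number * block_size + offset.
        start_idx = 0 if sliding_window is None else max(0, seqlen - sliding_window)
        slots = []
        for i in range(seqlen):
            if i < start_idx:
                slots.append(_PAD_SLOT_ID)
                continue
            block_number = block_table[i // block_size]
            slots.append(block_number * block_size + i % block_size)
        return slots

    # 10-token prompt, block size 4, sliding window 8, toy block table.
    print(compute_slot_mapping([0, 1, 0], 10, 4, sliding_window=8))
    # -> [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]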
+ query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=device) + subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + + torch.cumsum(query_lens_tensor, + dim=0, + dtype=subquery_start_loc.dtype, + out=subquery_start_loc[1:]) + + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + + attn_metadata = attn_backend.make_metadata( + is_prompt=is_prompt, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_context_len=max(context_lens), + max_seqlen=max_seqlen, + subquery_start_loc=subquery_start_loc, + seq_start_loc=seq_start_loc, + context_lens=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + ) + + # Decode + # attn_metadata = self.attn_backend.make_metadata( + # is_prompt=False, + # seq_lens=None, + # seq_lens_tensor=None, + # max_query_len=None, + # max_context_len=max_context_len, + # max_seqlen=None, + # subquery_start_loc=None, + # seq_start_loc=None, + # context_lens=context_lens_tensor, + # block_tables=block_tables, + # use_cuda_graph=use_captured_graph, + # ) + + input_tokens_tensor = torch.tensor(input_tokens, + dtype=torch.long, + device=device) + input_positions_tensor = torch.tensor(input_positions, + dtype=torch.long, + device=device) + slot_mapping_tensor = torch.tensor(slot_mapping, + dtype=torch.long, + device=device) + + if lora_config: + lora_mapping = LoRAMapping( + lora_index_mapping, + lora_prompt_mapping, + ) + else: + lora_mapping = None + + attn_metadata = AttentionMetadata( + num_prefills=num_prefills, + slot_mapping=slot_mapping, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + prefill_metadata=attn_metadata, + decode_metadata=attn_metadata, + kv_cache_dtype=self.kv_cache_dtype, + ) + + return ModelInput( + input_tokens=input_tokens_tensor, + input_positions=input_positions_tensor, + attn_metadata=attn_metadata, + seq_lens=seq_lens, + query_lens=query_lens, + lora_requests=lora_requests, + multi_modal_input=multi_modal_input, + slot_mapping=slot_mapping_tensor, + lora_mapping=lora_mapping, + ) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9e3c10eb1e8b..06d1c31fac29 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -42,8 +42,8 @@ class PreparePromptMetadata(NamedTuple): input_tokens: List[int] input_positions: List[int] attn_metadata: Optional[AttentionMetadataPerStage] - prompt_lens: List[int] - subquery_lens: List[int] + seq_lens: List[int] + query_lens: List[int] lora_index_mapping: List[int] lora_prompt_mapping: List[int] lora_requests: Set[LoRARequest] @@ -56,8 +56,8 @@ def empty(cls): input_tokens=[], input_positions=[], attn_metadata=None, - prompt_lens=[], - subquery_lens=[], + seq_lens=[], + query_lens=[], lora_index_mapping=[], lora_prompt_mapping=[], lora_requests=set(), @@ -134,8 +134,8 @@ def __init__( self.graph_memory_pool: Optional[Tuple[ int, int]] = None # Set during graph capture. 
- self.max_context_len_to_capture = ( - self.model_config.max_context_len_to_capture + self.max_seqlen_to_capture = ( + self.model_config.max_seqlen_to_capture if self.model_config is not None else 0) self.pin_memory = is_pin_memory_available() @@ -149,7 +149,7 @@ def __init__( self.model: torch.nn.Module # Set after load_model self.block_size: int # Set after initial profiling. # When using CUDA graph, the input block tables must be padded to - # max_context_len_to_capture. However, creating the block table in + # max_seqlen_to_capture. However, creating the block table in # Python can be expensive. To optimize this, we cache the block table # in numpy and only copy the actual input content at every iteration. # The shape of the cached block table will be @@ -218,7 +218,7 @@ def set_block_size(self, block_size: int) -> None: def get_max_block_per_batch(self) -> int: block_size = self.block_size - return (self.max_context_len_to_capture + block_size - 1) // block_size + return (self.max_seqlen_to_capture + block_size - 1) // block_size def _prepare_prompt( self, @@ -231,9 +231,9 @@ def _prepare_prompt( lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() - prompt_lens: List[int] = [] + seq_lens: List[int] = [] context_lens: List[int] = [] - subquery_lens: List[int] = [] + query_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] multi_modal_input_list: List[torch.Tensor] = [] @@ -265,8 +265,8 @@ def _prepare_prompt( prefill_end = min(seq_data.get_len(), computed_len + token_chunk_size) prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end] - prompt_len = prefill_end - prompt_lens.append(prompt_len) + seqlen = prefill_end + seq_lens.append(seqlen) # NOTE: This only works for oooooooxxx style attention. if computed_block_nums is not None and len( @@ -275,8 +275,7 @@ def _prepare_prompt( computed_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[computed_len:] prefix_block_tables.append(computed_block_nums) - # elif self.scheduler_config.chunked_prefill_enabled: - else: + elif self.scheduler_config.chunked_prefill_enabled: if seq_group_metadata.block_tables is not None: # Prefill has chunked before. block_table = seq_group_metadata.block_tables[seq_id] @@ -284,15 +283,15 @@ def _prepare_prompt( else: # The first prefill. prefix_block_tables.append([]) - # else: - # prefix_block_tables.append([]) - # # Right now, prefill start is always 0. However, this - # # assumption can be changed once chunked prefill is introduced. - # assert computed_len == 0 + else: + prefix_block_tables.append([]) + # Right now, prefill start is always 0. However, this + # assumption can be changed once chunked prefill is introduced. 
+ assert computed_len == 0 # actual prompt lens context_lens.append(computed_len) - subquery_lens.append(prompt_len - computed_len) + query_lens.append(seqlen - computed_len) input_tokens.extend(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt @@ -303,10 +302,10 @@ def _prepare_prompt( if lora_id > 0: lora_requests.add(seq_group_metadata.lora_request) - lora_index_mapping += [lora_id] * (prompt_len - computed_len) + lora_index_mapping += [lora_id] * (seqlen - computed_len) lora_prompt_mapping.extend( [lora_id] * - (prompt_len - computed_len + (seqlen - computed_len if seq_group_metadata.sampling_params.prompt_logprobs else 1)) if seq_group_metadata.multi_modal_data: @@ -316,13 +315,13 @@ def _prepare_prompt( if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized # yet. In this case, we just use a dummy slot mapping. - slot_mapping.extend([_PAD_SLOT_ID] * prompt_len) + slot_mapping.extend([_PAD_SLOT_ID] * seqlen) continue # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, - # where start_idx is max(0, prompt_len - sliding_window). + # where start_idx is max(0, seqlen - sliding_window). # For example, if the prompt len is 10, sliding window is 8, and # block size is 4, the first two tokens are masked and the slot # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. @@ -331,7 +330,7 @@ def _prepare_prompt( assert computed_len == 0, ( "Prefix caching is currently not supported with " "sliding window attention") - start_idx = max(0, prompt_len - self.sliding_window) + start_idx = max(0, seqlen - self.sliding_window) for i in range(computed_len, prefill_end): if i < start_idx: @@ -343,9 +342,9 @@ def _prepare_prompt( slot = block_number * self.block_size + block_offset slot_mapping.append(slot) - max_subquery_len = max(subquery_lens) - max_prompt_len = max(prompt_lens) - assert max_subquery_len > 0 + max_query_len = max(query_lens) + max_seqlen = max(seq_lens) + assert max_query_len > 0 context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, @@ -372,37 +371,37 @@ def _prepare_prompt( # Query length can be shorter than key (i.e., prompt) when prefill # is chunked or prefix cached. 
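Because query length can differ from sequence length, two cumulative start-location tensors are kept: subquery_start_loc indexes into the flattened new tokens while seq_start_loc indexes into the full sequences (e.g. seq_lens [4, 6] gives [0, 4, 10]). A small sketch of the construction that follows, using assumed toy lengths:

    import torch

    query_lens = [3, 5]  # new tokens per sequence (toy values)
    seq_lens = [4, 6]    # computed + new tokens per sequence

    query_lens_tensor = torch.tensor(query_lens, dtype=torch.long)
    seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int)
    subquery_start_loc = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
    seq_start_loc = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
    torch.cumsum(query_lens_tensor, dim=0, dtype=subquery_start_loc.dtype,
                 out=subquery_start_loc[1:])
    torch.cumsum(seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype,
                 out=seq_start_loc[1:])
    print(subquery_start_loc.tolist())  # [0, 3, 8]
    print(seq_start_loc.tolist())       # [0, 4, 10]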
- subquery_lens_tensor = torch.tensor(subquery_lens, + query_lens_tensor = torch.tensor(query_lens, dtype=torch.long, device=self.device) - subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1, + subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, dtype=torch.int32, device=self.device) - prompt_lens_tensor = torch.tensor(prompt_lens, + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=self.device) - seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1, + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=self.device) - torch.cumsum(subquery_lens_tensor, + torch.cumsum(query_lens_tensor, dim=0, dtype=subquery_start_loc.dtype, out=subquery_start_loc[1:]) - torch.cumsum(prompt_lens_tensor, + torch.cumsum(seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) attn_metadata = self.attn_backend.make_metadata( is_prompt=is_prompt, - prompt_lens=prompt_lens, - prompt_lens_tensor=prompt_lens_tensor, - max_subquery_len=max_subquery_len, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, max_context_len=max(context_lens), - max_prompt_len=max_prompt_len, + max_seqlen=max_seqlen, subquery_start_loc=subquery_start_loc, seq_start_loc=seq_start_loc, context_lens=context_lens_tensor, @@ -414,8 +413,8 @@ def _prepare_prompt( input_tokens=input_tokens, input_positions=input_positions, attn_metadata=attn_metadata, - prompt_lens=prompt_lens, - subquery_lens=subquery_lens, + seq_lens=seq_lens, + query_lens=query_lens, lora_index_mapping=lora_index_mapping, lora_prompt_mapping=lora_prompt_mapping, lora_requests=lora_requests, @@ -430,7 +429,7 @@ def _prepare_decode( input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] - context_lens: List[int] = [] + seqlens: List[int] = [] block_tables: List[List[int]] = [] lora_index_mapping: List[int] = [] lora_prompt_mapping: List[int] = [] @@ -454,13 +453,13 @@ def _prepare_decode( generation_token = seq_data.get_last_token_id() input_tokens.append(generation_token) - seq_len = seq_data.get_len() - position = seq_len - 1 + seqlen = seq_data.get_len() + position = seqlen - 1 input_positions.append(position) - context_len = seq_len if self.sliding_window is None else min( - seq_len, self.sliding_window) - context_lens.append(context_len) + seqlen = seqlen if self.sliding_window is None else min( + seqlen, self.sliding_window) + seqlens.append(seqlen) block_table = seq_group_metadata.block_tables[seq_id] block_number = block_table[position // self.block_size] @@ -480,11 +479,11 @@ def _prepare_decode( # See `capture_model` API for more details. # For decoding requests, batch_size == input_tokens. 
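The decode path below replays pre-captured CUDA graphs, so the runtime batch is padded up to the nearest captured batch size with dummy tokens, positions, pad slots, and empty block tables. A rough illustration of that rounding, assuming toy capture sizes rather than the real _BATCH_SIZES_TO_CAPTURE table:

    _BATCH_SIZES_TO_CAPTURE = [1, 2, 4, 8, 16, 32]  # illustrative values only

    def get_graph_batch_size(batch_size: int) -> int:
        # Round up to the smallest captured size that can hold the batch; the
        # caller fills the extra slots with dummy entries before replay.
        for size in _BATCH_SIZES_TO_CAPTURE:
            if size >= batch_size:
                return size
        raise ValueError(f"batch of {batch_size} exceeds the largest captured graph")

    print(get_graph_batch_size(3))  # -> 4, so one dummy sequence is appended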
batch_size = len(input_tokens) - max_context_len = max(context_lens) + max_seqlen = max(seqlens) use_captured_graph = ( not self.model_config.enforce_eager and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] - and max_context_len <= self.max_context_len_to_capture) + and max_seqlen <= self.max_seqlen_to_capture) if use_captured_graph: graph_batch_size = _get_graph_batch_size(batch_size) assert graph_batch_size >= batch_size @@ -492,21 +491,21 @@ def _prepare_decode( input_tokens.append(0) input_positions.append(0) slot_mapping.append(_PAD_SLOT_ID) - context_lens.append(1) + seqlens.append(1) block_tables.append([]) lora_index_mapping.append(0) batch_size = graph_batch_size - context_lens_tensor = torch.tensor(context_lens, + seqlens_tensor = torch.tensor(seqlens, dtype=torch.int, device=self.device) if use_captured_graph: # When using cuda-graph all these tensors should be # padded. - assert context_lens_tensor.shape[0] == len(input_tokens) - assert context_lens_tensor.shape[0] == len(input_positions) - assert context_lens_tensor.shape[0] == len(slot_mapping) + assert seqlens_tensor.shape[0] == len(input_tokens) + assert seqlens_tensor.shape[0] == len(input_positions) + assert seqlens_tensor.shape[0] == len(slot_mapping) # The shape of graph_block_tables is # [max batch size, max context len // block size]. @@ -528,14 +527,14 @@ def _prepare_decode( attn_metadata = self.attn_backend.make_metadata( is_prompt=False, - prompt_lens=None, - prompt_lens_tensor=None, - max_subquery_len=None, - max_context_len=max_context_len, - max_prompt_len=None, + seq_lens=None, + seq_lens_tensor=seqlens_tensor, + max_query_len=None, + max_context_len=None, + max_seqlen=max_seqlen, subquery_start_loc=None, seq_start_loc=None, - context_lens=context_lens_tensor, + context_lens=None, block_tables=block_tables, use_cuda_graph=use_captured_graph, ) @@ -568,43 +567,31 @@ def prepare_input_tensors( input_tokens, input_positions, prefill_attn_metadata, - prompt_lens, - subquery_lens, + seq_lens, + query_lens, lora_index_mapping, lora_prompt_mapping, lora_requests, multi_modal_input, slot_mapping, ) = self._prepare_prompt(prefill_reqs) - # ( - # decode_input_tokens, - # decode_input_positions, - # decode_attn_metadata, - # decode_lora_index_mapping, - # decode_lora_prompt_mapping, - # decode_lora_requests, - # decode_slot_mapping, - # ) = self._prepare_decode(decode_reqs) ( decode_input_tokens, decode_input_positions, decode_attn_metadata, - _, - _, decode_lora_index_mapping, decode_lora_prompt_mapping, decode_lora_requests, - _, decode_slot_mapping, - ) = self._prepare_prompt(decode_reqs) + ) = self._prepare_decode(decode_reqs) sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, prompt_lens, subquery_lens, + seq_group_metadata_list, seq_lens, query_lens, self.device, self.pin_memory) if not self.scheduler_config.chunked_prefill_enabled: assert (len(prefill_reqs) and len(decode_reqs)) == 0 - num_prefills = len(prompt_lens) + num_prefills = len(seq_lens) num_prefill_tokens = len(input_tokens) num_decode_tokens = len(decode_input_tokens) @@ -901,7 +888,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() slot_mapping.fill_(_PAD_SLOT_ID) - context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() + seqlens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() graph_batch_size = 
_get_graph_batch_size( @@ -923,14 +910,14 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: # Create dummy attn_metadata. decode_metadata = self.attn_backend.make_metadata( is_prompt=False, - prompt_lens=None, - prompt_lens_tensor=None, - max_subquery_len=None, - max_context_len=self.max_context_len_to_capture, - max_prompt_len=None, + seq_lens=None, + seq_lens_tensor=seqlens[:batch_size], + max_query_len=None, + max_context_len=None, + max_seqlen=self.max_seqlen_to_capture, subquery_start_loc=None, seq_start_loc=None, - context_lens=context_lens[:batch_size], + context_lens=None, block_tables=block_tables[:batch_size], use_cuda_graph=True, ) diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index a974e85c22f4..3078a48dbe8b 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -52,7 +52,7 @@ def _prepare_prompt( input_positions: List[List[int]] = [] input_block_ids: List[int] = [] - prompt_lens: List[int] = [] + seq_lens: List[int] = [] for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -61,26 +61,26 @@ def _prepare_prompt( seq_data = seq_group_metadata.seq_data[seq_id] prompt_tokens = seq_data.get_token_ids() - prompt_len = len(prompt_tokens) - prompt_lens.append(prompt_len) + seqlen = len(prompt_tokens) + seq_lens.append(seqlen) input_tokens.append(prompt_tokens) - input_positions.append(list(range(prompt_len))) + input_positions.append(list(range(seqlen))) assert seq_group_metadata.block_tables is not None block_table = seq_group_metadata.block_tables[seq_id] assert len(block_table) == 1 input_block_ids.append(block_table[0]) - max_prompt_len = max(prompt_lens) - assert max_prompt_len > 0 + max_seqlen = max(seq_lens) + assert max_seqlen > 0 input_tokens = make_tensor_with_pad(input_tokens, - max_prompt_len, + max_seqlen, pad=0, dtype=torch.long, device=self.device) input_positions = make_tensor_with_pad(input_positions, - max_prompt_len, + max_seqlen, pad=0, dtype=torch.long, device=self.device) @@ -88,7 +88,7 @@ def _prepare_prompt( dtype=torch.long, device=self.device) - return input_tokens, input_positions, input_block_ids, prompt_lens + return input_tokens, input_positions, input_block_ids, seq_lens def _prepare_decode( self, @@ -149,18 +149,18 @@ def prepare_input_tensors( # Prepare input tensors. if is_prompt: (input_tokens, input_positions, input_block_ids, - prompt_lens) = self._prepare_prompt(seq_group_metadata_list) + seq_lens) = self._prepare_prompt(seq_group_metadata_list) else: (input_tokens, input_positions, input_block_ids) = self._prepare_decode(seq_group_metadata_list) - prompt_lens = [] + seq_lens = [] sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - # subquery_lens is not needed if chunked prefill is not + seq_lens, + # query_lens is not needed if chunked prefill is not # supported. Since neuron worker doesn't support chunked prefill - # just use prompt_lens instead. - prompt_lens, + # just use seq_lens instead. + seq_lens, self.device, self.pin_memory) From 7d5025366a99baa2f2f6e570c91c6133b71e84a6 Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 1 May 2024 06:44:05 -0700 Subject: [PATCH 03/12] working. 
--- .../kernels/benchmark_paged_attention.py | 25 +- csrc/attention/attention_kernels.cu | 76 ++--- csrc/cpu/attention.cpp | 92 +++--- csrc/ops.h | 8 +- .../test_basic_correctness.py | 5 +- .../basic_correctness/test_chunked_prefill.py | 8 +- tests/kernels/test_attention.py | 35 +-- tests/spec_decode/e2e/conftest.py | 4 +- tests/worker/test_model_runner.py | 10 +- vllm/_custom_ops.py | 19 +- vllm/attention/backends/flash_attn.py | 22 +- vllm/attention/backends/rocm_flash_attn.py | 12 +- vllm/attention/backends/torch_sdpa.py | 2 +- vllm/attention/backends/xformers.py | 58 +--- vllm/attention/ops/paged_attn.py | 10 +- vllm/config.py | 13 +- vllm/engine/arg_utils.py | 13 +- vllm/entrypoints/llm.py | 4 +- vllm/model_executor/sampling_metadata.py | 16 +- vllm/worker/cpu_model_runner.py | 32 +- vllm/worker/model_input.py | 290 ------------------ vllm/worker/model_runner.py | 125 ++++---- vllm/worker/neuron_model_runner.py | 24 +- 23 files changed, 295 insertions(+), 608 deletions(-) delete mode 100644 vllm/worker/model_input.py diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 5c3650fa72d1..eb7120c2a6e7 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -16,7 +16,7 @@ def main( version: str, num_seqs: int, - context_len: int, + seqlen: int, num_query_heads: int, num_kv_heads: int, head_size: int, @@ -48,12 +48,12 @@ def main( dtype=torch.float, device=device) - context_lens = [context_len for _ in range(num_seqs)] - max_context_len = max(context_lens) - context_lens = torch.tensor(context_lens, dtype=torch.int, device=device) + seqlens = [seqlen for _ in range(num_seqs)] + max_seqlen = max(seqlens) + seqlens = torch.tensor(seqlens, dtype=torch.int, device=device) # Create the block tables. - max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + max_num_blocks_per_seq = (max_seqlen + block_size - 1) // block_size block_tables = [] for _ in range(num_seqs): block_table = [ @@ -77,8 +77,7 @@ def main( # Prepare for the paged attention kernel. 
output = torch.empty_like(query) if version == "v2": - num_partitions = ((max_context_len + PARTITION_SIZE - 1) // - PARTITION_SIZE) + num_partitions = ((max_seqlen + PARTITION_SIZE - 1) // PARTITION_SIZE) tmp_output = torch.empty( size=(num_seqs, num_query_heads, num_partitions, head_size), dtype=output.dtype, @@ -110,9 +109,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: num_kv_heads, scale, block_tables, - context_lens, + seqlens, block_size, - max_context_len, + max_seqlen, alibi_slopes, kv_cache_dtype, kv_scale, @@ -129,9 +128,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: num_kv_heads, scale, block_tables, - context_lens, + seqlens, block_size, - max_context_len, + max_seqlen, alibi_slopes, kv_cache_dtype, kv_scale, @@ -166,7 +165,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: choices=["v1", "v2"], default="v2") parser.add_argument("--batch-size", type=int, default=8) - parser.add_argument("--context-len", type=int, default=4096) + parser.add_argument("--seqlen", type=int, default=4096) parser.add_argument("--num-query-heads", type=int, default=64) parser.add_argument("--num-kv-heads", type=int, default=8) parser.add_argument("--head-size", @@ -199,7 +198,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: main( version=args.version, num_seqs=args.batch_size, - context_len=args.context_len, + seqlen=args.seqlen, num_query_heads=args.num_query_heads, num_kv_heads=args.num_kv_heads, head_size=args.head_size, diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index f3a5bbfd3098..0c521d11c69d 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -104,7 +104,7 @@ __device__ void paged_attention_kernel( const int num_kv_heads, // [num_heads] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seqlens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, @@ -115,23 +115,23 @@ __device__ void paged_attention_kernel( const int partition_idx = blockIdx.z; const int max_num_partitions = gridDim.z; constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; - const int context_len = context_lens[seq_idx]; - if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) { + const int seqlen = seqlens[seq_idx]; + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seqlen) { // No work to do. Terminate the thread block. return; } - const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); - const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; + const int num_seq_blocks = DIVIDE_ROUND_UP(seqlen, BLOCK_SIZE); + const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; // [start_block_idx, end_block_idx) is the range of blocks to process. const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; - const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); + const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); const int num_blocks = end_block_idx - start_block_idx; // [start_token_idx, end_token_idx) is the range of tokens to process. 
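This partitioning is what lets the v2 kernel parallelize over long contexts: each (sequence, partition) pair handles one contiguous run of KV blocks. A host-side Python sketch of the same index arithmetic, with assumed toy sizes:

    def partition_ranges(seqlen, block_size, partition_size, partition_idx):
        # Mirrors the block/token range computation in paged_attention_kernel
        # when PARTITION_SIZE > 0.
        num_seq_blocks = -(-seqlen // block_size)            # DIVIDE_ROUND_UP
        blocks_per_partition = partition_size // block_size
        start_block = partition_idx * blocks_per_partition
        end_block = min(start_block + blocks_per_partition, num_seq_blocks)
        start_token = start_block * block_size
        end_token = min(start_token + (end_block - start_block) * block_size, seqlen)
        return (start_block, end_block), (start_token, end_token)

    # seqlen=1000, block_size=16, PARTITION_SIZE=512: partition 1 covers the tail.
    print(partition_ranges(1000, 16, 512, 1))  # -> ((32, 63), (512, 1000))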
const int start_token_idx = start_block_idx * BLOCK_SIZE; - const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); + const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seqlen); const int num_tokens = end_token_idx - start_token_idx; constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); @@ -245,12 +245,12 @@ __device__ void paged_attention_kernel( // This includes a reduction across the threads in the same thread group. float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); // Add the ALiBi bias if slopes are given. - qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; + qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seqlen + 1) : 0; if (thread_group_offset == 0) { // Store the partial reductions to shared memory. // NOTE(woosuk): It is required to zero out the masked logits. - const bool mask = token_idx >= context_len; + const bool mask = token_idx >= seqlen; logits[token_idx - start_token_idx] = mask ? 0.f : qk; // Update the max value. qk_max = mask ? qk_max : fmaxf(qk_max, qk); @@ -364,14 +364,14 @@ __device__ void paged_attention_kernel( } else { v_vec = *reinterpret_cast(v_ptr + offset); } - if (block_idx == num_context_blocks - 1) { + if (block_idx == num_seq_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the context, // we should explicitly zero out the values since they may contain NaNs. // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); #pragma unroll for (int j = 0; j < V_VEC_SIZE; j++) { - v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value; + v_vec_ptr[j] = token_idx + j < seqlen ? v_vec_ptr[j] : zero_value; } } accs[i] += dot(logits_vec, v_vec); @@ -457,7 +457,7 @@ __global__ void paged_attention_v1_kernel( const int num_kv_heads, // [num_heads] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seqlens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, @@ -466,7 +466,7 @@ __global__ void paged_attention_v1_kernel( const float kv_scale) { paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, - out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, + out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seqlens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); } @@ -489,7 +489,7 @@ __global__ void paged_attention_v2_kernel( const int num_kv_heads, // [num_heads] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seqlens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, @@ -498,7 +498,7 @@ __global__ void paged_attention_v2_kernel( const float kv_scale) { paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, - block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, + block_tables, seqlens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); } @@ -513,13 +513,13 @@ __global__ void paged_attention_v2_reduce_kernel( const float* __restrict__ exp_sums, // [num_seqs, num_heads, 
max_num_partitions] const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seqlens, // [num_seqs] const int max_num_partitions) { const int num_heads = gridDim.x; const int head_idx = blockIdx.x; const int seq_idx = blockIdx.y; - const int context_len = context_lens[seq_idx]; - const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + const int seqlen = seqlens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(seqlen, PARTITION_SIZE); if (num_partitions == 1) { // No need to reduce. Only copy tmp_out to out. scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; @@ -616,7 +616,7 @@ __global__ void paged_attention_v2_reduce_kernel( num_kv_heads, \ scale, \ block_tables_ptr, \ - context_lens_ptr, \ + seqlens_ptr, \ max_num_blocks_per_seq, \ alibi_slopes_ptr, \ q_stride, \ @@ -639,8 +639,8 @@ void paged_attention_v1_launcher( int num_kv_heads, float scale, torch::Tensor& block_tables, - torch::Tensor& context_lens, - int max_context_len, + torch::Tensor& seqlens, + int max_seqlen, const c10::optional& alibi_slopes, float kv_scale) { int num_seqs = query.size(0); @@ -664,11 +664,11 @@ void paged_attention_v1_launcher( CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); - int* context_lens_ptr = context_lens.data_ptr(); + int* seqlens_ptr = seqlens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE; - int logits_size = padded_max_context_len * sizeof(float); + int padded_max_seqlen = DIVIDE_ROUND_UP(max_seqlen, BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_seqlen * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len // Keep that in sync with the logic here! 
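The Python-side check mentioned above exists because the v1 launcher's dynamic shared memory grows with the padded context length. A back-of-the-envelope sketch of that sizing; the block size, warp count, and head size are assumed defaults, and the final allocation is taken as the larger of the two buffers:

    def v1_shared_mem_bytes(max_seqlen, block_size=16, num_warps=4, head_size=128):
        # One float of softmax logits per padded context token...
        padded_max_seqlen = -(-max_seqlen // block_size) * block_size
        logits_size = padded_max_seqlen * 4
        # ...plus a per-warp float workspace for the output reduction.
        outputs_size = (num_warps // 2) * head_size * 4
        return max(logits_size, outputs_size)

    print(v1_shared_mem_bytes(8192))  # -> 32768 bytes of logits for an 8K context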
@@ -715,8 +715,8 @@ void paged_attention_v1_launcher( num_kv_heads, \ scale, \ block_tables, \ - context_lens, \ - max_context_len, \ + seqlens, \ + max_seqlen, \ alibi_slopes, \ kv_scale); @@ -746,9 +746,9 @@ void paged_attention_v1( int num_kv_heads, // [num_heads] float scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& context_lens, // [num_seqs] + torch::Tensor& seqlens, // [num_seqs] int block_size, - int max_context_len, + int max_seqlen, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale) { @@ -790,7 +790,7 @@ void paged_attention_v1( num_kv_heads, \ scale, \ block_tables_ptr, \ - context_lens_ptr, \ + seqlens_ptr, \ max_num_blocks_per_seq, \ alibi_slopes_ptr, \ q_stride, \ @@ -803,7 +803,7 @@ void paged_attention_v1( exp_sums_ptr, \ max_logits_ptr, \ tmp_out_ptr, \ - context_lens_ptr, \ + seqlens_ptr, \ max_num_partitions); template< @@ -824,8 +824,8 @@ void paged_attention_v2_launcher( int num_kv_heads, float scale, torch::Tensor& block_tables, - torch::Tensor& context_lens, - int max_context_len, + torch::Tensor& seqlens, + int max_seqlen, const c10::optional& alibi_slopes, float kv_scale) { int num_seqs = query.size(0); @@ -852,10 +852,10 @@ void paged_attention_v2_launcher( CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); - int* context_lens_ptr = context_lens.data_ptr(); + int* seqlens_ptr = seqlens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + int max_num_partitions = DIVIDE_ROUND_UP(max_seqlen, PARTITION_SIZE); int logits_size = PARTITION_SIZE * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); @@ -909,8 +909,8 @@ void paged_attention_v2_launcher( num_kv_heads, \ scale, \ block_tables, \ - context_lens, \ - max_context_len, \ + seqlens, \ + max_seqlen, \ alibi_slopes, \ kv_scale); @@ -943,9 +943,9 @@ void paged_attention_v2( int num_kv_heads, // [num_heads] float scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& context_lens, // [num_seqs] + torch::Tensor& seqlens, // [num_seqs] int block_size, - int max_context_len, + int max_seqlen, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale) { diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index 365bbd5e2372..e41d356a645b 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -70,11 +70,11 @@ template FORCE_INLINE std::pair reduceSoftmaxAlibi(T *data, const int size, const int capacity, const float alibi_slope, const int start_index, - const int context_len) { - data[0] += alibi_slope * (start_index - context_len + 1); + const int seqlen) { + data[0] += alibi_slope * (start_index - seqlen + 1); T max = data[0]; for (int i = 1; i < size; ++i) { - T qk = data[i] + alibi_slope * (start_index + i - context_len + 1); + T qk = data[i] + alibi_slope * (start_index + i - seqlen + 1); data[i] = qk; max = max >= qk ? 
max : qk; } @@ -225,7 +225,7 @@ struct paged_attention_v1_impl { const int num_kv_heads, const float scale, const int *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int *__restrict__ context_lens, // [num_seqs] + const int *__restrict__ seqlens, // [num_seqs] const int max_num_blocks_per_seq, const float *__restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, @@ -235,32 +235,32 @@ struct paged_attention_v1_impl { static_assert(BLOCK_SIZE == 16); - int max_context_len = max_num_blocks_per_seq * BLOCK_SIZE; - int max_context_len_padded = (max_context_len + 15) & 0xFFFFFFF0; - TORCH_CHECK((max_context_len_padded * sizeof(float)) % 64 == 0); + int max_seqlen = max_num_blocks_per_seq * BLOCK_SIZE; + int max_seqlen_padded = (max_seqlen + 15) & 0xFFFFFFF0; + TORCH_CHECK((max_seqlen_padded * sizeof(float)) % 64 == 0); const int parallel_work_item_num = omp_get_max_threads(); size_t logits_bytes = - parallel_work_item_num * max_context_len_padded * sizeof(float); + parallel_work_item_num * max_seqlen_padded * sizeof(float); float *logits = (float *)std::aligned_alloc( 64, logits_bytes); // Cacheline alignment for each context token. - // [parallel_work_item_num, max_context_len_padded] + // [parallel_work_item_num, max_seqlen_padded] #pragma omp parallel for collapse(2) schedule(dynamic, 1) for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - int context_len = context_lens[seq_idx]; + int seqlen = seqlens[seq_idx]; const int *seq_block_table = block_tables + max_num_blocks_per_seq * seq_idx; - const int block_num = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int block_num = (seqlen + BLOCK_SIZE - 1) / BLOCK_SIZE; const int64_t kv_head_idx = head_idx / num_queries_per_kv; const scalar_t *__restrict__ q_vec_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; const int last_block_token_num = - context_len - (block_num - 1) * BLOCK_SIZE; + seqlen - (block_num - 1) * BLOCK_SIZE; float *__restrict__ thread_block_logits = - logits + omp_get_thread_num() * max_context_len_padded; + logits + omp_get_thread_num() * max_seqlen_padded; // Compute logits for (int block_idx = 0; block_idx < block_num; ++block_idx) { @@ -278,11 +278,11 @@ struct paged_attention_v1_impl { // Compute softmax if (alibi_slopes) { - reduceSoftmaxAlibi(thread_block_logits, context_len, + reduceSoftmaxAlibi(thread_block_logits, seqlen, block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, - context_len); + seqlen); } else { - reduceSoftmax(thread_block_logits, context_len, + reduceSoftmax(thread_block_logits, seqlen, block_num * BLOCK_SIZE); } @@ -340,7 +340,7 @@ struct paged_attention_v1_impl { #define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ paged_attention_v1_impl::call( \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ - block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ + block_tables_ptr, seqlens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ num_heads); @@ -348,8 +348,8 @@ template void paged_attention_v1_impl_launcher( torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor &value_cache, int num_kv_heads, float scale, - torch::Tensor &block_tables, torch::Tensor &context_lens, - int max_context_len, const c10::optional &alibi_slopes) { + torch::Tensor &block_tables, torch::Tensor &seqlens, + int max_seqlen, const c10::optional &alibi_slopes) 
{ int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -369,7 +369,7 @@ void paged_attention_v1_impl_launcher( T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int *block_tables_ptr = block_tables.data_ptr(); - int *context_lens_ptr = context_lens.data_ptr(); + int *seqlens_ptr = seqlens.data_ptr(); switch (head_size) { case 64: @@ -399,7 +399,7 @@ void paged_attention_v1_impl_launcher( #define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ paged_attention_v1_impl_launcher( \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ - context_lens, max_context_len, alibi_slopes); + seqlens, max_seqlen, alibi_slopes); #define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ switch (block_size) { \ @@ -416,8 +416,8 @@ void paged_attention_v1(torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor &value_cache, int num_kv_heads, float scale, torch::Tensor &block_tables, - torch::Tensor &context_lens, int block_size, - int max_context_len, + torch::Tensor &seqlens, int block_size, + int max_seqlen, const c10::optional &alibi_slopes, const std::string &kv_cache_dtype, float kv_scale) { TORCH_CHECK(kv_scale == 1.0f); @@ -448,7 +448,7 @@ struct paged_attention_v2_impl { const int num_kv_heads, const float scale, const int *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int *__restrict__ context_lens, // [num_seqs] + const int *__restrict__ seqlens, // [num_seqs] const int max_num_blocks_per_seq, const float *__restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, @@ -465,22 +465,22 @@ struct paged_attention_v2_impl { for (int partition_idx = 0; partition_idx < max_num_partitions; ++partition_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - const int context_len = context_lens[seq_idx]; + const int seqlen = seqlens[seq_idx]; const int start_token_idx = partition_idx * PARTITION_SIZE; - if (start_token_idx >= context_len) + if (start_token_idx >= seqlen) continue; const int partition_num = - (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + (seqlen + PARTITION_SIZE - 1) / PARTITION_SIZE; const bool no_reduce = (partition_num == 1); - const int context_token_num = - (std::min(context_len, start_token_idx + PARTITION_SIZE) - + const int token_num = + (std::min(seqlen, start_token_idx + PARTITION_SIZE) - start_token_idx); const int block_num = - (context_token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; + (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; const int last_block_token_num = - context_token_num - (block_num - 1) * BLOCK_SIZE; + token_num - (block_num - 1) * BLOCK_SIZE; const int *seq_block_table = block_tables + max_num_blocks_per_seq * seq_idx + start_token_idx / BLOCK_SIZE; @@ -507,10 +507,10 @@ struct paged_attention_v2_impl { std::pair max_and_sum; if (alibi_slopes) { max_and_sum = reduceSoftmaxAlibi( - logits, context_token_num, block_num * BLOCK_SIZE, - alibi_slopes[head_idx], start_token_idx, context_len); + logits, token_num, block_num * BLOCK_SIZE, + alibi_slopes[head_idx], start_token_idx, seqlen); } else { - max_and_sum = reduceSoftmax(logits, context_token_num, + max_and_sum = reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE); } @@ -583,9 +583,9 @@ struct paged_attention_v2_impl { #pragma omp parallel for collapse(2) schedule(static, 1) for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; 
++head_idx) { - const int context_len = context_lens[seq_idx]; + const int seqlen = seqlens[seq_idx]; const int partition_num = - (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + (seqlen + PARTITION_SIZE - 1) / PARTITION_SIZE; if (partition_num == 1) continue; @@ -612,9 +612,9 @@ struct paged_attention_v2_impl { for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int group_idx = 0; group_idx < head_group_num; ++group_idx) { - const int context_len = context_lens[seq_idx]; + const int seqlen = seqlens[seq_idx]; const int partition_num = - (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + (seqlen + PARTITION_SIZE - 1) / PARTITION_SIZE; if (partition_num == 1) continue; @@ -649,7 +649,7 @@ struct paged_attention_v2_impl { paged_attention_v2_impl::call( \ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ - context_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + seqlens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ kv_block_stride, kv_head_stride, num_seqs, num_heads, \ max_num_partitions); @@ -658,8 +658,8 @@ void paged_attention_v2_impl_launcher( torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits, torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor &value_cache, int num_kv_heads, float scale, - torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size, - int max_context_len, const c10::optional &alibi_slopes) { + torch::Tensor &block_tables, torch::Tensor &seqlens, int block_size, + int max_seqlen, const c10::optional &alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -683,7 +683,7 @@ void paged_attention_v2_impl_launcher( T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int *block_tables_ptr = block_tables.data_ptr(); - int *context_lens_ptr = context_lens.data_ptr(); + int *seqlens_ptr = seqlens.data_ptr(); switch (head_size) { case 64: @@ -713,8 +713,8 @@ void paged_attention_v2_impl_launcher( #define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ paged_attention_v2_impl_launcher( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, context_lens, block_size, \ - max_context_len, alibi_slopes); + num_kv_heads, scale, block_tables, seqlens, block_size, \ + max_seqlen, alibi_slopes); #define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ switch (block_size) { \ @@ -732,8 +732,8 @@ void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor &value_cache, int num_kv_heads, float scale, torch::Tensor &block_tables, - torch::Tensor &context_lens, int block_size, - int max_context_len, + torch::Tensor &seqlens, int block_size, + int max_seqlen, const c10::optional &alibi_slopes, const std::string &kv_cache_dtype, float kv_scale) { TORCH_CHECK(kv_scale == 1.0f); diff --git a/csrc/ops.h b/csrc/ops.h index 04b97d1784cd..4ce1c8df71d2 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -10,9 +10,9 @@ void paged_attention_v1( int num_kv_heads, float scale, torch::Tensor& block_tables, - torch::Tensor& context_lens, + torch::Tensor& seqlens, int block_size, - int max_context_len, + int max_seqlen, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale); @@ -28,9 +28,9 @@ void 
paged_attention_v2( int num_kv_heads, float scale, torch::Tensor& block_tables, - torch::Tensor& context_lens, + torch::Tensor& seqlens, int block_size, - int max_context_len, + int max_seqlen, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale); diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 764344ac9436..97cff623c5e1 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -6,15 +6,14 @@ MODELS = [ "facebook/opt-125m", - # "meta-llama/Llama-2-7b-hf", + "meta-llama/Llama-2-7b-hf", ] @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) -# @pytest.mark.parametrize("enforce_eager", [False, True]) -@pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.parametrize("enforce_eager", [False, True]) def test_models( hf_runner, vllm_runner, diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index d83416eb51b4..7acd4bc4e5a6 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -10,15 +10,17 @@ MODELS = [ "facebook/opt-125m", - "meta-llama/Llama-2-7b-hf", + # "meta-llama/Llama-2-7b-hf", ] @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) -@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) -@pytest.mark.parametrize("enforce_eager", [False, True]) +# @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) +@pytest.mark.parametrize("chunked_prefill_token_size", [16]) +# @pytest.mark.parametrize("enforce_eager", [False, True]) +@pytest.mark.parametrize("enforce_eager", [False]) # NOTE: Increasing this in this suite will fail CI because we currently cannot # reset distributed env properly. Use a value > 1 just when you test. @pytest.mark.parametrize("tensor_parallel_size", [1]) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 9b1f3e30b6dc..0bf4f1810deb 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -61,7 +61,7 @@ def ref_single_query_cached_kv_attention( key_cache: torch.Tensor, value_cache: torch.Tensor, block_tables: torch.Tensor, - context_lens: torch.Tensor, + seqlens: torch.Tensor, scale: float, alibi_slopes: Optional[torch.Tensor], ) -> None: @@ -72,15 +72,15 @@ def ref_single_query_cached_kv_attention( num_seqs = query.shape[0] block_tables = block_tables.cpu().tolist() - context_lens = context_lens.cpu().tolist() + seqlens = seqlens.cpu().tolist() for i in range(num_seqs): q = query[i].unsqueeze(0) block_table = block_tables[i] - context_len = int(context_lens[i]) + seqlen = int(seqlens[i]) keys = [] values = [] - for j in range(context_len): + for j in range(seqlen): block_number = int(block_table[j // block_size]) block_offset = j % block_size @@ -100,8 +100,8 @@ def ref_single_query_cached_kv_attention( alibi_bias = None if alibi_slopes is not None: # Create the ALiBi bias used in the paged attention kernel. 
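Both the CUDA kernel (qk += alibi_slope * (token_idx - seqlen + 1)) and the reference implementation below apply the same ALiBi penalty: zero for the newest token and increasingly negative for older ones. A quick numeric illustration with an assumed slope:

    import torch

    seqlen, slope = 4, 0.5  # toy values
    position_ids = torch.arange(seqlen)
    bias = slope * (position_ids - seqlen + 1).float()
    print(bias.tolist())  # [-1.5, -1.0, -0.5, 0.0]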
- position_ids = torch.arange(context_len).int() - alibi_bias = (position_ids - context_len + 1).float() + position_ids = torch.arange(seqlen).int() + alibi_bias = (position_ids - seqlen + 1).float() alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( 1, 1, -1) @@ -149,13 +149,13 @@ def test_paged_attention( if use_alibi: alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) - context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] - context_lens[-1] = MAX_SEQ_LEN - max_context_len = max(context_lens) - context_lens = torch.tensor(context_lens, dtype=torch.int) + seqlens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + seqlens[-1] = MAX_SEQ_LEN + max_seqlen = max(seqlens) + seqlens = torch.tensor(seqlens, dtype=torch.int) # Create the block tables. - max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + max_num_blocks_per_seq = (max_seqlen + block_size - 1) // block_size block_tables = [] for _ in range(num_seqs): block_table = [ @@ -186,16 +186,15 @@ def test_paged_attention( num_kv_heads, scale, block_tables, - context_lens, + seqlens, block_size, - max_context_len, + max_seqlen, alibi_slopes, kv_cache_dtype, kv_scale, ) elif version == "v2": - num_partitions = ((max_context_len + PARTITION_SIZE - 1) // - PARTITION_SIZE) + num_partitions = ((max_seqlen + PARTITION_SIZE - 1) // PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape tmp_output = torch.empty( @@ -218,9 +217,9 @@ def test_paged_attention( num_kv_heads, scale, block_tables, - context_lens, + seqlens, block_size, - max_context_len, + max_seqlen, alibi_slopes, kv_cache_dtype, kv_scale, @@ -255,7 +254,7 @@ def test_paged_attention( key_cache, value_cache, block_tables, - context_lens, + seqlens, scale, alibi_slopes, ) diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 5d3469c4210e..d019ddea7aa7 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -44,7 +44,7 @@ def __init__( gpu_memory_utilization: float = 0.9, swap_space: int = 4, enforce_eager: bool = False, - max_context_len_to_capture: int = 8192, + max_seqlen_to_capture: int = 8192, disable_custom_all_reduce: bool = False, **kwargs, ) -> None: @@ -65,7 +65,7 @@ def __init__( gpu_memory_utilization=gpu_memory_utilization, swap_space=swap_space, enforce_eager=enforce_eager, - max_context_len_to_capture=max_context_len_to_capture, + max_seqlen_to_capture=max_seqlen_to_capture, engine_use_ray=True, disable_custom_all_reduce=disable_custom_all_reduce, **kwargs, diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 4c7d3673ca95..dc1014adf2da 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -47,9 +47,8 @@ def test_prepare_prompt(batch_size): expected_selected_token_indices.append(selected_token_start_idx + seqlen - 1) selected_token_start_idx += seqlen - (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, - _, _, - slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) + (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _, + _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) assert return_seq_lens == seq_lens assert len(slot_mapping) == len(input_tokens) @@ -241,9 +240,8 @@ def test_empty_seq_group(): assert attn_metadata is None assert len(slot_mapping) == 0 - (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, - 
_, _, - slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) + (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _, + _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) assert len(input_tokens) == 0 assert len(input_positions) == 0 assert attn_metadata is None diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 4af8b09b1e16..aa87325eaa0f 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -39,17 +39,17 @@ def paged_attention_v1( num_kv_heads: int, scale: float, block_tables: torch.Tensor, - context_lens: torch.Tensor, + seqlens: torch.Tensor, block_size: int, - max_context_len: int, + max_seqlen: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, kv_scale: float, ) -> None: vllm_ops.paged_attention_v1(out, query, key_cache, value_cache, - num_kv_heads, scale, block_tables, - context_lens, block_size, max_context_len, - alibi_slopes, kv_cache_dtype, kv_scale) + num_kv_heads, scale, block_tables, seqlens, + block_size, max_seqlen, alibi_slopes, + kv_cache_dtype, kv_scale) def paged_attention_v2( @@ -63,18 +63,17 @@ def paged_attention_v2( num_kv_heads: int, scale: float, block_tables: torch.Tensor, - context_lens: torch.Tensor, + seqlens: torch.Tensor, block_size: int, - max_context_len: int, + max_seqlen: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, kv_scale: float, ) -> None: vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, - block_tables, context_lens, block_size, - max_context_len, alibi_slopes, kv_cache_dtype, - kv_scale) + block_tables, seqlens, block_size, max_seqlen, + alibi_slopes, kv_cache_dtype, kv_scale) # pos encoding ops diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index d665e7b71a20..04b190a53968 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -66,10 +66,11 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool - # (batch_size,). The prompt length per sequence. None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seqlens: Optional[List[int]] + # seqlens stored as a tensor. + seqlens_tensor: Optional[torch.Tensor] # NOTE(sang): Definition of context_len, query_len, and seqlen. # |---------- N-1 iteration --------| @@ -79,10 +80,6 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, # |-------------------- seqlen ----------------------| # |-- query_len ---| - # WARNING(sang): context_len has different definition depending on if it is - # prefill vs decoding. When it is prefill, it doesn't include new tokens. - # When it is for decoding, it includes a new token. - # Maximum query length in the batch. max_query_len: Optional[int] # Maximum sequence length in the batch. @@ -95,6 +92,9 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. seq_start_loc: Optional[torch.Tensor] + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] # Whether or not if cuda graph is enabled. 
# Cuda-graph is currently enabled for decoding only. @@ -245,8 +245,8 @@ def forward( value_cache, prefill_meta.block_tables, prefill_meta.subquery_start_loc, - prefill_meta.seq_lens_tensor, - prefill_meta.context_lens, + prefill_meta.seqlens_tensor, + prefill_meta.context_lens_tensor, prefill_meta.max_query_len, self.alibi_slopes, ) @@ -257,7 +257,7 @@ def forward( key_cache, value_cache, decode_meta.block_tables, - decode_meta.seqlens, + decode_meta.seqlens_tensor, decode_meta.max_seqlen, attn_metadata.kv_cache_dtype, self.num_kv_heads, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 4c76d7ab384c..1be1c4ed8eff 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -64,7 +64,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool - # (batch_size,). The prompt length per sequence. None if it is a decoding. + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. seq_lens: Optional[List[int]] # seq_lens stored as a tensor. seq_lens_tensor: Optional[torch.Tensor] @@ -77,10 +78,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, # |-------------------- seqlen ----------------------| # |-- query_len ---| - # WARNING(sang): context_len has different definition depending on if it is - # prefill vs decoding. When it is prefill, it doesn't include new tokens. - # When it is for decoding, it includes a new token. - # Maximum query length in the batch. max_query_len: Optional[int] # Maximum sequence length in the batch. @@ -98,6 +95,9 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] class ROCmFlashAttentionImpl(AttentionImpl): @@ -305,7 +305,7 @@ def forward( prefill_meta.block_tables, prefill_meta.subquery_start_loc, prefill_meta.seq_lens_tensor, - prefill_meta.context_lens, + prefill_meta.context_lens_tensor, prefill_meta.max_query_len, self.alibi_slopes, ) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index a1710994e6a9..adce25545feb 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -165,7 +165,7 @@ def forward( (num_tokens, self.num_heads, self.head_size), dtype=query.dtype) for seqlen, mask in zip(attn_metadata.seq_lens, - attn_metadata.attn_bias): + attn_metadata.attn_bias): end = start + seqlen sub_out = scaled_dot_product_attention( query[:, start:end, :], diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 30a6dd3410ea..09fc4fab5c72 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -53,20 +53,6 @@ def copy_blocks( ) -> None: PagedAttention.copy_blocks(kv_caches, src_to_dists) -class AttentionMetadataBuilder: - def add_sequence_group(self, seq_group): - pass - - def build(self) -> AttentionMetadata: - pass - -def prepare_input(seq_group_list): - builder = AttentionMetadataBuilder() - for seq_group in seq_group_list: - # update positions, input_tokens, etc. 
- builder.add_seq_group(seq_group) - attn_metadata = builder.build() - @dataclass class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): @@ -80,21 +66,12 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool - # (batch_size,). The prompt length per sequence. None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] - - # NOTE(sang): Definition of context_len, query_len, and seqlen. - # Before - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len/computed_len ----------| -> context len for prefill, computed len for prefix caching - # |------------ prompt_len/context_len --------------| -> context len for decode, prompt len for prefill - # |- subquery_len -| + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seqlens: Optional[List[int]] + # seqlens stored as a tensor. + seqlens_tensor: Optional[torch.Tensor] - # After # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| @@ -102,10 +79,6 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): # |-------------------- seqlen ----------------------| # |-- query_len ---| - # WARNING(sang): context_len has different definition depending on if it is - # prefill vs decoding. When it is prefill, it doesn't include new tokens. - # When it is for decoding, it includes a new token. - # Maximum query length in the batch. max_query_len: Optional[int] # FIXME: It is for flash attn. @@ -120,6 +93,9 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. seq_start_loc: Optional[torch.Tensor] + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. @@ -265,8 +241,8 @@ def forward( value_cache, prefill_meta.block_tables, prefill_meta.subquery_start_loc, - prefill_meta.seq_lens_tensor, - prefill_meta.context_lens, + prefill_meta.seqlens_tensor, + prefill_meta.context_lens_tensor, prefill_meta.max_query_len, self.alibi_slopes, ) @@ -279,7 +255,7 @@ def forward( key_cache, value_cache, decode_meta.block_tables, - decode_meta.seqlens, + decode_meta.seqlens_tensor, decode_meta.max_seqlen, attn_metadata.kv_cache_dtype, self.num_kv_heads, @@ -311,7 +287,7 @@ def _run_memory_efficient_xformers_forward( value: shape = [num_prefill_tokens, num_kv_heads, head_size] attn_metadata: Metadata for attention. """ - assert attn_metadata.seq_lens is not None + assert attn_metadata.seqlens is not None original_query = query if self.num_kv_heads != self.num_heads: # GQA/MQA requires the shape [B, M, G, H, K]. 
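For reference, the renamed length fields in the metadata above all encode one invariant: for every sequence, seqlen = context_len (tokens already computed and held in the KV cache) + query_len (new tokens processed this step). A minimal sketch with made-up per-sequence lengths, not taken from the patch:

    # Sketch of the length bookkeeping behind the renamed fields.
    # The per-sequence numbers are invented; only the relationships matter.
    context_lens = [7, 7, 9, 0]   # tokens already computed (in the KV cache)
    query_lens   = [5, 5, 1, 4]   # new tokens to compute this step
    seqlens      = [c + q for c, q in zip(context_lens, query_lens)]

    assert all(q <= s for q, s in zip(query_lens, seqlens))
    print(seqlens)  # [12, 12, 10, 4]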
@@ -332,7 +308,7 @@ def _run_memory_efficient_xformers_forward( if attn_metadata.attn_bias is None: if self.alibi_slopes is None: attn_bias = BlockDiagonalCausalMask.from_seqlens( - attn_metadata.seq_lens) + attn_metadata.seqlens) if self.sliding_window is not None: attn_bias = attn_bias.make_local_attention( self.sliding_window) @@ -340,7 +316,7 @@ def _run_memory_efficient_xformers_forward( else: attn_metadata.attn_bias = _make_alibi_bias( self.alibi_slopes, self.num_kv_heads, query.dtype, - attn_metadata.seq_lens) + attn_metadata.seqlens) # No alibi slopes. # TODO(woosuk): Too many view operations. Let's try to reduce @@ -365,7 +341,7 @@ def _run_memory_efficient_xformers_forward( # one. This is inefficient, especially when we have many short prompts. output = torch.empty_like(original_query) start = 0 - for i, seqlen in enumerate(attn_metadata.seq_lens): + for i, seqlen in enumerate(attn_metadata.seqlens): end = start + seqlen out = xops.memory_efficient_attention_forward( query[None, start:end], @@ -384,10 +360,10 @@ def _make_alibi_bias( alibi_slopes: torch.Tensor, num_kv_heads: int, dtype: torch.dtype, - seq_lens: List[int], + seqlens: List[int], ) -> LowerTriangularMaskWithTensorBias: attn_biases = [] - for seqlen in seq_lens: + for seqlen in seqlens: bias = torch.arange(seqlen, dtype=dtype) # NOTE(zhuohan): HF uses # `bias = bias[None, :].repeat(seqlen, 1)` diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 987d00c939d2..f798e535a8c8 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -85,7 +85,7 @@ def forward_decode( value_cache: torch.Tensor, block_tables: torch.Tensor, seqlens: torch.Tensor, - max_context_len: int, + max_seqlen: int, kv_cache_dtype: str, num_kv_heads: int, scale: float, @@ -96,7 +96,7 @@ def forward_decode( block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape - max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1) // + max_num_partitions = ((max_seqlen + _PARTITION_SIZE - 1) // _PARTITION_SIZE) # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use @@ -105,7 +105,7 @@ def forward_decode( # to parallelize. # TODO(woosuk): Tune this heuristic. # For context len > 8192, use V2 kernel to avoid shared memory shortage. - use_v1 = (max_context_len <= 8192 + use_v1 = (max_seqlen <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)) if use_v1: # Run PagedAttention V1. @@ -119,7 +119,7 @@ def forward_decode( block_tables, seqlens, block_size, - max_context_len, + max_seqlen, alibi_slopes, kv_cache_dtype, kv_scale, @@ -151,7 +151,7 @@ def forward_decode( block_tables, seqlens, block_size, - max_context_len, + max_seqlen, alibi_slopes, kv_cache_dtype, kv_scale, diff --git a/vllm/config.py b/vllm/config.py index 457d4dce8f51..bd4602cb53c3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -68,7 +68,7 @@ class ModelConfig: If False, we will use CUDA graph and eager execution in hybrid. max_context_len_to_capture: Maximum context len covered by CUDA graphs. When a sequence has context length larger than this, we fall back - to eager mode (DEPRECATED). + to eager mode (DEPRECATED. Use max_seqlen_to_capture instead). max_seqlen_to_capture: Maximum sequence len covered by CUDA graphs. 
When a sequence has context length larger than this, we fall back to eager mode @@ -108,7 +108,11 @@ def __init__( self.quantization_param_path = quantization_param_path self.enforce_eager = enforce_eager self.max_context_len_to_capture = max_context_len_to_capture - self.max_seqlen_to_capture = max_seqlen_to_capture or max_context_len_to_capture + if self.max_context_len_to_capture is not None: + logger.warning("`max_context_len_to_capture` is deprecated. " + "Use `max_seqlen_to_capture` instead.") + self.max_seqlen_to_capture = (max_seqlen_to_capture + or max_context_len_to_capture) self.max_logprobs = max_logprobs self.skip_tokenizer_init = skip_tokenizer_init @@ -203,7 +207,7 @@ def _verify_cuda_graph(self) -> None: if self.max_seqlen_to_capture is None: self.max_seqlen_to_capture = self.max_model_len self.max_seqlen_to_capture = min(self.max_seqlen_to_capture, - self.max_model_len) + self.max_model_len) def verify_with_parallel_config( self, @@ -759,8 +763,7 @@ def maybe_create_spec_config( max_model_len=None, quantization=draft_quantization, enforce_eager=target_model_config.enforce_eager, - max_seqlen_to_capture=target_model_config. - max_seqlen_to_capture, + max_seqlen_to_capture=target_model_config.max_seqlen_to_capture, max_logprobs=target_model_config.max_logprobs, ) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8955a516d393..122e77c07dde 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -44,7 +44,7 @@ class EngineArgs: tokenizer_revision: Optional[str] = None quantization: Optional[str] = None enforce_eager: bool = False - max_context_len_to_capture: int = 8192 + max_context_len_to_capture: Optional[int] = None max_seqlen_to_capture: int = 8192 disable_custom_all_reduce: bool = False tokenizer_pool_size: int = 0 @@ -320,8 +320,10 @@ def add_cli_args( type=int, default=EngineArgs.max_context_len_to_capture, help='Maximum context length covered by CUDA ' - 'graphs. When a sequence has context length (DEPRECATED)' - 'larger than this, we fall back to eager mode.') + 'graphs. When a sequence has context length ' + 'larger than this, we fall back to eager mode. ' + '(DEPRECATED. Use --max-seqlen-to-capture instead' + ')') parser.add_argument('--max-seqlen-to-capture', type=int, default=EngineArgs.max_seqlen_to_capture, @@ -482,8 +484,9 @@ def create_engine_config(self, ) -> EngineConfig: self.trust_remote_code, self.dtype, self.seed, self.revision, self.code_revision, self.tokenizer_revision, self.max_model_len, self.quantization, self.quantization_param_path, - self.enforce_eager, self.max_context_len_to_capture, self.max_seqlen_to_capture, - self.max_logprobs, self.skip_tokenizer_init) + self.enforce_eager, self.max_context_len_to_capture, + self.max_seqlen_to_capture, self.max_logprobs, + self.skip_tokenizer_init) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 4934cd09e7c8..f3d5b98c87ab 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -70,7 +70,7 @@ class LLM: If False, we will use CUDA graph and eager execution in hybrid. max_context_len_to_capture: Maximum context len covered by CUDA graphs. When a sequence has context length larger than this, we fall back - to eager mode (DEPRECATED). + to eager mode (DEPRECATED. Use `max_seqlen_to_capture` instead). max_seqlen_to_capture: Maximum sequence len covered by CUDA graphs. 
When a sequence has context length larger than this, we fall back to eager mode. @@ -93,7 +93,7 @@ def __init__( gpu_memory_utilization: float = 0.9, swap_space: int = 4, enforce_eager: bool = False, - max_context_len_to_capture: int = 8192, + max_context_len_to_capture: Optional[int] = None, max_seqlen_to_capture: int = 8192, disable_custom_all_reduce: bool = False, **kwargs, diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 561952d11fd9..42bae6e78e6e 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -16,15 +16,23 @@ @dataclass class SequenceGroupToSample: + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seqlen ----------------------| + # |-- query_len ---| + # Sequence ids for the sequence group in a previous step. seq_ids: List[int] sampling_params: SamplingParams # seq_id -> sequence data. seq_data: Dict[int, SequenceData] - # The length of the sequence of the sequence group. None if it is in a decode + # The length of the sequence (all tokens seen in the past + new token to + # compute attention) of the sequence group. None if it is in a decode # stage. seqlen: Optional[int] - # The length of the query tokens to compute in the current step. None if it + # The length of new query tokens to compute in the current step. None if it # is in a decode stage. The length of query_len <= seqlen if chunked prefill # is enabled. query_len: Optional[int] @@ -105,8 +113,8 @@ def prepare( selected_token_indices, categorized_sample_indices, num_prompts, - ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, - query_lens, device) + ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens, + device) selected_token_indices = async_tensor_h2d(selected_token_indices, dtype=torch.long, target_device=device, diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 2c6acded05b8..ae920d94235d 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -80,7 +80,7 @@ def _prepare_prompt( input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] - seq_lens: List[int] = [] + seqlens: List[int] = [] multi_modal_input_list: List[torch.Tensor] = [] for seq_group_metadata in seq_group_metadata_list: @@ -94,7 +94,7 @@ def _prepare_prompt( computed_len = seq_data.get_num_computed_tokens() seqlen = len(prompt_tokens) - seq_lens.append(seqlen) # Prompt token num + seqlens.append(seqlen) # Prompt token num input_tokens.extend(prompt_tokens) # Token ids # Token position ids @@ -151,8 +151,8 @@ def _prepare_prompt( attn_metadata = self.attn_backend.make_metadata( is_prompt=True, - seq_lens=seq_lens, - num_prefills=len(seq_lens), + seqlens=seqlens, + num_prefills=len(seqlens), num_prefill_tokens=num_prompt_tokens, num_decode_tokens=0, prefill_metadata=None, @@ -163,7 +163,7 @@ def _prepare_prompt( slot_mapping=slot_mapping, kv_cache_dtype=self.kv_cache_dtype, ) - return (input_tokens, input_positions, attn_metadata, seq_lens, + return (input_tokens, input_positions, attn_metadata, seqlens, multi_modal_input) def _prepare_decode( @@ -188,12 +188,12 @@ def _prepare_decode( generation_token = seq_data.get_last_token_id() input_tokens.append(generation_token) - seq_len = seq_data.get_len() - position = seq_len - 1 + seqlen = seq_data.get_len() + 
position = seqlen - 1 input_positions.append(position) - context_len = seq_len if self.sliding_window is None else min( - seq_len, self.sliding_window) + context_len = seqlen if self.sliding_window is None else min( + seqlen, self.sliding_window) context_lens.append(context_len) block_table = seq_group_metadata.block_tables[seq_id] @@ -236,7 +236,7 @@ def _prepare_decode( attn_metadata = self.attn_backend.make_metadata( is_prompt=False, slot_mapping=slot_mapping, - seq_lens=None, + seqlens=None, num_prefill_tokens=0, num_decode_tokens=len(input_tokens), max_context_len=max_context_len, @@ -265,20 +265,20 @@ def prepare_input_tensors( is_prompt = seq_group_metadata_list[0].is_prompt # Prepare input tensors. if is_prompt: - (input_tokens, input_positions, attn_metadata, seq_lens, + (input_tokens, input_positions, attn_metadata, seqlens, multi_modal_input ) = self._prepare_prompt(seq_group_metadata_list) else: (input_tokens, input_positions, attn_metadata) = self._prepare_decode(seq_group_metadata_list) - seq_lens = [] + seqlens = [] sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - seq_lens, + seqlens, # query_lens is not needed if chunked prefill is not # supported. Since CPU worker doesn't support chunked prefill - # just use seq_lens instead. - seq_lens, + # just use seqlens instead. + seqlens, self.device, pin_memory=False) # Broadcast the metadata. @@ -300,7 +300,7 @@ def prepare_input_tensors( sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, - seq_lens=None, + seqlens=None, selected_token_indices=selected_token_indices, categorized_sample_indices=None, generators=None, diff --git a/vllm/worker/model_input.py b/vllm/worker/model_input.py deleted file mode 100644 index b2fb83a46fd5..000000000000 --- a/vllm/worker/model_input.py +++ /dev/null @@ -1,290 +0,0 @@ -import torch -from dataclasses import dataclass -from typing import List, Set, Optional, Type - -from vllm.vllm.sequence import SequenceGroupMetadata -from vllm.attention.backends.abstract import AttentionMetadata, AttentionBackend -from vllm.lora.request import LoRARequest -from vllm.lora.layers import LoRAMapping -from vllm.config import SchedulerConfig, LoRAConfig, VisionLanguageConfig -from vllm.utils import make_tensor_with_pad - - -_PAD_SLOT_ID = -1 - - -@dataclass -class GpuModelInput: - """Input to run a model. - - Input tensors include inputs across multiple sequence groups. - It assumes inputs are ordered by prefill -> decode sequences. - """ - # (num_tokens,) 1D Flattened input token IDs. - input_tokens: torch.Tensor - # (num_tokens,) Positions of a token in its sequence. Used for RoPE. - input_positions: torch.Tensor - # Attention metadata to run attention kernels. - attn_metadata: AttentionMetadata - # (batch_size,) A sequence length for each sequence group in a batch. - seq_lens: List[int] - # (batch_size,) A query length for eaach sequence group in a batch. - query_lens: List[int] - # Set of lora requests. - lora_requests: Set[LoRARequest] - # Inputs used for multi modality. - multi_modal_input: Optional[torch.Tensor] - # (num_tokens,) A page index per token. Each slot index is flattened. For - # example, if slot mapping is 15 and block size is 8, it means block index - # 1 and offset 3. - slot_mapping: torch.Tensor - # Lora mapping. None if lora is not used. 
- lora_mapping: Optional[LoRAMapping] - - @classmethod - def from_sequence_groups( - cls, - seq_group_metadata_list: List[SequenceGroupMetadata], - scheduler_config: SchedulerConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig], - block_size: int, - device: str, - attn_backend: Type[AttentionBackend], - sliding_window: Optional[int]) -> "GpuModelInput": - input_tokens: List[int] = [] - input_positions: List[int] = [] - slot_mapping: List[int] = [] - lora_index_mapping: List[int] = [] - lora_prompt_mapping: List[int] = [] - lora_requests: Set[LoRARequest] = set() - - seq_lens: List[int] = [] - context_lens: List[int] = [] - query_lens: List[int] = [] - prefix_block_tables: List[List[int]] = [] - multi_modal_input_list: List[torch.Tensor] = [] - - is_prompt = False - for seq_group_metadata in seq_group_metadata_list: - # assert seq_group_metadata.is_prompt - seq_ids = list(seq_group_metadata.seq_data.keys()) - # assert len(seq_ids) == 1 - seq_id = seq_ids[0] - is_prompt = seq_group_metadata.is_prompt - - computed_block_nums = seq_group_metadata.computed_block_nums - if (scheduler_config is not None - and scheduler_config.chunked_prefill_enabled - and not (computed_block_nums is None - or computed_block_nums == [])): - raise RuntimeError( - "chunked prefill cannot be used with prefix caching " - "now.") - - token_chunk_size = seq_group_metadata.token_chunk_size - seq_data = seq_group_metadata.seq_data[seq_id] - computed_len = seq_data.get_num_computed_tokens() - # We should use get_len here because in case of preemption - # it contains output tokens. - prefill_end = min(seq_data.get_len(), - computed_len + token_chunk_size) - prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end] - seqlen = prefill_end - seq_lens.append(seqlen) - - # NOTE: This only works for oooooooxxx style attention. - if computed_block_nums is not None and len( - computed_block_nums) > 0 and sliding_window is None: - # Prefix is not supported with sliding_window - computed_len = len(computed_block_nums) * block_size - prompt_tokens = prompt_tokens[computed_len:] - prefix_block_tables.append(computed_block_nums) - elif scheduler_config.chunked_prefill_enabled or not is_prompt: - if seq_group_metadata.block_tables is not None: - # Prefill has chunked before. - block_table = seq_group_metadata.block_tables[seq_id] - prefix_block_tables.append(block_table) - else: - # The first prefill. - prefix_block_tables.append([]) - else: - prefix_block_tables.append([]) - # Right now, prefill start is always 0. However, this - # assumption can be changed once chunked prefill is introduced. - assert computed_len == 0 - - # actual prompt lens - context_lens.append(computed_len) - query_lens.append(seqlen - computed_len) - - input_tokens.extend(prompt_tokens) - # NOTE(woosuk): Here we assume that the first token in the prompt - # is always the first token in the sequence. 
- input_positions.extend(list(range(computed_len, prefill_end))) - lora_id = seq_group_metadata.lora_int_id - - if lora_id > 0: - lora_requests.add(seq_group_metadata.lora_request) - - lora_index_mapping += [lora_id] * (seqlen - computed_len) - lora_prompt_mapping.extend( - [lora_id] * - (seqlen - computed_len - if seq_group_metadata.sampling_params.prompt_logprobs else 1)) - - if seq_group_metadata.multi_modal_data: - multi_modal_input_list.append( - seq_group_metadata.multi_modal_data.data) - - if seq_group_metadata.block_tables is None: - # During memory profiling, the block tables are not initialized - # yet. In this case, we just use a dummy slot mapping. - slot_mapping.extend([_PAD_SLOT_ID] * seqlen) - continue - - # Compute the slot mapping. - block_table = seq_group_metadata.block_tables[seq_id] - # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, - # where start_idx is max(0, seqlen - sliding_window). - # For example, if the prompt len is 10, sliding window is 8, and - # block size is 4, the first two tokens are masked and the slot - # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. - start_idx = 0 - if sliding_window is not None: - assert computed_len == 0, ( - "Prefix caching is currently not supported with " - "sliding window attention") - start_idx = max(0, seqlen - sliding_window) - - for i in range(computed_len, prefill_end): - if i < start_idx: - slot_mapping.append(_PAD_SLOT_ID) - continue - - block_number = block_table[i // block_size] - block_offset = i % block_size - slot = block_number * block_size + block_offset - slot_mapping.append(slot) - - max_query_len = max(query_lens) - max_seqlen = max(seq_lens) - assert max_query_len > 0 - - context_lens_tensor = torch.tensor(context_lens, - dtype=torch.int, - device=device) - - if multi_modal_input_list: - assert vision_language_config, ( - "Multi-modal inputs are only supported by " - "vision language models.") - multi_modal_input = torch.cat(multi_modal_input_list, - dim=0).to(device) - else: - multi_modal_input = None - - # Prepare prefix block tables - max_prompt_block_table_len = max(len(t) for t in prefix_block_tables) - block_tables = make_tensor_with_pad( - prefix_block_tables, - max_len=max_prompt_block_table_len, - pad=0, - dtype=torch.int, - device=device, - ) - - # Query length can be shorter than key (i.e., prompt) when prefill - # is chunked or prefix cached. 
- query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=device) - subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) - - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=device) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) - - torch.cumsum(query_lens_tensor, - dim=0, - dtype=subquery_start_loc.dtype, - out=subquery_start_loc[1:]) - - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) - - attn_metadata = attn_backend.make_metadata( - is_prompt=is_prompt, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_context_len=max(context_lens), - max_seqlen=max_seqlen, - subquery_start_loc=subquery_start_loc, - seq_start_loc=seq_start_loc, - context_lens=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - ) - - # Decode - # attn_metadata = self.attn_backend.make_metadata( - # is_prompt=False, - # seq_lens=None, - # seq_lens_tensor=None, - # max_query_len=None, - # max_context_len=max_context_len, - # max_seqlen=None, - # subquery_start_loc=None, - # seq_start_loc=None, - # context_lens=context_lens_tensor, - # block_tables=block_tables, - # use_cuda_graph=use_captured_graph, - # ) - - input_tokens_tensor = torch.tensor(input_tokens, - dtype=torch.long, - device=device) - input_positions_tensor = torch.tensor(input_positions, - dtype=torch.long, - device=device) - slot_mapping_tensor = torch.tensor(slot_mapping, - dtype=torch.long, - device=device) - - if lora_config: - lora_mapping = LoRAMapping( - lora_index_mapping, - lora_prompt_mapping, - ) - else: - lora_mapping = None - - attn_metadata = AttentionMetadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - prefill_metadata=attn_metadata, - decode_metadata=attn_metadata, - kv_cache_dtype=self.kv_cache_dtype, - ) - - return ModelInput( - input_tokens=input_tokens_tensor, - input_positions=input_positions_tensor, - attn_metadata=attn_metadata, - seq_lens=seq_lens, - query_lens=query_lens, - lora_requests=lora_requests, - multi_modal_input=multi_modal_input, - slot_mapping=slot_mapping_tensor, - lora_mapping=lora_mapping, - ) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 06d1c31fac29..714ba946cf98 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -42,7 +42,7 @@ class PreparePromptMetadata(NamedTuple): input_tokens: List[int] input_positions: List[int] attn_metadata: Optional[AttentionMetadataPerStage] - seq_lens: List[int] + seqlens: List[int] query_lens: List[int] lora_index_mapping: List[int] lora_prompt_mapping: List[int] @@ -56,7 +56,7 @@ def empty(cls): input_tokens=[], input_positions=[], attn_metadata=None, - seq_lens=[], + seqlens=[], query_lens=[], lora_index_mapping=[], lora_prompt_mapping=[], @@ -134,9 +134,8 @@ def __init__( self.graph_memory_pool: Optional[Tuple[ int, int]] = None # Set during graph capture. 
- self.max_seqlen_to_capture = ( - self.model_config.max_seqlen_to_capture - if self.model_config is not None else 0) + self.max_seqlen_to_capture = (self.model_config.max_seqlen_to_capture + if self.model_config is not None else 0) self.pin_memory = is_pin_memory_available() self.kv_cache_dtype = kv_cache_dtype @@ -231,7 +230,7 @@ def _prepare_prompt( lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() - seq_lens: List[int] = [] + seqlens: List[int] = [] context_lens: List[int] = [] query_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] @@ -240,13 +239,11 @@ def _prepare_prompt( if len(seq_group_metadata_list) == 0: return PreparePromptMetadata.empty() - is_prompt = False for seq_group_metadata in seq_group_metadata_list: - # assert seq_group_metadata.is_prompt + assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) - # assert len(seq_ids) == 1 + assert len(seq_ids) == 1 seq_id = seq_ids[0] - is_prompt = seq_group_metadata.is_prompt computed_block_nums = seq_group_metadata.computed_block_nums if (self.scheduler_config is not None @@ -259,21 +256,19 @@ def _prepare_prompt( token_chunk_size = seq_group_metadata.token_chunk_size seq_data = seq_group_metadata.seq_data[seq_id] - computed_len = seq_data.get_num_computed_tokens() + context_len = seq_data.get_num_computed_tokens() # We should use get_len here because in case of preemption # it contains output tokens. - prefill_end = min(seq_data.get_len(), - computed_len + token_chunk_size) - prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end] - seqlen = prefill_end - seq_lens.append(seqlen) + seqlen = min(seq_data.get_len(), context_len + token_chunk_size) + prompt_tokens = seq_data.get_token_ids()[context_len:seqlen] + seqlens.append(seqlen) # NOTE: This only works for oooooooxxx style attention. if computed_block_nums is not None and len( computed_block_nums) > 0 and self.sliding_window is None: # Prefix is not supported with sliding_window - computed_len = len(computed_block_nums) * self.block_size - prompt_tokens = prompt_tokens[computed_len:] + context_len = len(computed_block_nums) * self.block_size + prompt_tokens = prompt_tokens[context_len:] prefix_block_tables.append(computed_block_nums) elif self.scheduler_config.chunked_prefill_enabled: if seq_group_metadata.block_tables is not None: @@ -287,25 +282,25 @@ def _prepare_prompt( prefix_block_tables.append([]) # Right now, prefill start is always 0. However, this # assumption can be changed once chunked prefill is introduced. - assert computed_len == 0 + assert context_len == 0 # actual prompt lens - context_lens.append(computed_len) - query_lens.append(seqlen - computed_len) + context_lens.append(context_len) + query_lens.append(seqlen - context_len) input_tokens.extend(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. 
- input_positions.extend(list(range(computed_len, prefill_end))) + input_positions.extend(list(range(context_len, seqlen))) lora_id = seq_group_metadata.lora_int_id if lora_id > 0: lora_requests.add(seq_group_metadata.lora_request) - lora_index_mapping += [lora_id] * (seqlen - computed_len) + lora_index_mapping += [lora_id] * (seqlen - context_len) lora_prompt_mapping.extend( [lora_id] * - (seqlen - computed_len + (seqlen - context_len if seq_group_metadata.sampling_params.prompt_logprobs else 1)) if seq_group_metadata.multi_modal_data: @@ -327,12 +322,12 @@ def _prepare_prompt( # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. start_idx = 0 if self.sliding_window is not None: - assert computed_len == 0, ( + assert context_len == 0, ( "Prefix caching is currently not supported with " "sliding window attention") start_idx = max(0, seqlen - self.sliding_window) - for i in range(computed_len, prefill_end): + for i in range(context_len, seqlen): if i < start_idx: slot_mapping.append(_PAD_SLOT_ID) continue @@ -343,7 +338,7 @@ def _prepare_prompt( slot_mapping.append(slot) max_query_len = max(query_lens) - max_seqlen = max(seq_lens) + max_seqlen = max(seqlens) assert max_query_len > 0 context_lens_tensor = torch.tensor(context_lens, @@ -372,16 +367,16 @@ def _prepare_prompt( # Query length can be shorter than key (i.e., prompt) when prefill # is chunked or prefix cached. query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=self.device) + dtype=torch.long, + device=self.device) subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, dtype=torch.int32, device=self.device) - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=self.device) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + seqlens_tensor = torch.tensor(seqlens, + dtype=torch.int, + device=self.device) + seq_start_loc = torch.zeros(seqlens_tensor.shape[0] + 1, dtype=torch.int32, device=self.device) @@ -390,21 +385,20 @@ def _prepare_prompt( dtype=subquery_start_loc.dtype, out=subquery_start_loc[1:]) - torch.cumsum(seq_lens_tensor, + torch.cumsum(seqlens_tensor, dim=0, dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) attn_metadata = self.attn_backend.make_metadata( - is_prompt=is_prompt, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, + is_prompt=True, + seqlens=seqlens, + seqlens_tensor=seqlens_tensor, max_query_len=max_query_len, - max_context_len=max(context_lens), max_seqlen=max_seqlen, subquery_start_loc=subquery_start_loc, seq_start_loc=seq_start_loc, - context_lens=context_lens_tensor, + context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, ) @@ -413,7 +407,7 @@ def _prepare_prompt( input_tokens=input_tokens, input_positions=input_positions, attn_metadata=attn_metadata, - seq_lens=seq_lens, + seqlens=seqlens, query_lens=query_lens, lora_index_mapping=lora_index_mapping, lora_prompt_mapping=lora_prompt_mapping, @@ -480,10 +474,9 @@ def _prepare_decode( # For decoding requests, batch_size == input_tokens. 
batch_size = len(input_tokens) max_seqlen = max(seqlens) - use_captured_graph = ( - not self.model_config.enforce_eager - and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] - and max_seqlen <= self.max_seqlen_to_capture) + use_captured_graph = (not self.model_config.enforce_eager + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and max_seqlen <= self.max_seqlen_to_capture) if use_captured_graph: graph_batch_size = _get_graph_batch_size(batch_size) assert graph_batch_size >= batch_size @@ -497,8 +490,8 @@ def _prepare_decode( batch_size = graph_batch_size seqlens_tensor = torch.tensor(seqlens, - dtype=torch.int, - device=self.device) + dtype=torch.int, + device=self.device) if use_captured_graph: # When using cuda-graph all these tensors should be @@ -527,14 +520,13 @@ def _prepare_decode( attn_metadata = self.attn_backend.make_metadata( is_prompt=False, - seq_lens=None, - seq_lens_tensor=seqlens_tensor, + seqlens=None, + seqlens_tensor=seqlens_tensor, max_query_len=None, - max_context_len=None, max_seqlen=max_seqlen, subquery_start_loc=None, seq_start_loc=None, - context_lens=None, + context_lens_tensor=None, block_tables=block_tables, use_cuda_graph=use_captured_graph, ) @@ -567,7 +559,7 @@ def prepare_input_tensors( input_tokens, input_positions, prefill_attn_metadata, - seq_lens, + seqlens, query_lens, lora_index_mapping, lora_prompt_mapping, @@ -585,13 +577,13 @@ def prepare_input_tensors( decode_slot_mapping, ) = self._prepare_decode(decode_reqs) sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, seq_lens, query_lens, - self.device, self.pin_memory) + seq_group_metadata_list, seqlens, query_lens, self.device, + self.pin_memory) if not self.scheduler_config.chunked_prefill_enabled: assert (len(prefill_reqs) and len(decode_reqs)) == 0 - num_prefills = len(seq_lens) + num_prefills = len(seqlens) num_prefill_tokens = len(input_tokens) num_decode_tokens = len(decode_input_tokens) @@ -803,10 +795,10 @@ def profile_run(self) -> None: int(max_num_batched_tokens / self.vision_language_config.image_feature_size)) for group_id in range(max_num_seqs): - seq_len = (max_num_batched_tokens // max_num_seqs + - (group_id < max_num_batched_tokens % max_num_seqs)) + seqlen = (max_num_batched_tokens // max_num_seqs + + (group_id < max_num_batched_tokens % max_num_seqs)) seq_data, fake_multi_modal_input = _prepare_fake_inputs( - seq_len, self.vision_language_config) + seqlen, self.vision_language_config) seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, @@ -910,14 +902,13 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: # Create dummy attn_metadata. 
decode_metadata = self.attn_backend.make_metadata( is_prompt=False, - seq_lens=None, - seq_lens_tensor=seqlens[:batch_size], + seqlens=None, + seqlens_tensor=seqlens[:batch_size], max_query_len=None, - max_context_len=None, max_seqlen=self.max_seqlen_to_capture, subquery_start_loc=None, seq_start_loc=None, - context_lens=None, + context_lens_tensor=None, block_tables=block_tables[:batch_size], use_cuda_graph=True, ) @@ -1027,7 +1018,7 @@ def capture( "positions": positions, "kv_caches": kv_caches, "slot_mapping": attn_metadata.slot_mapping, - "context_lens": attn_metadata.decode_metadata.context_lens, + "seqlens_tensor": attn_metadata.decode_metadata.seqlens_tensor, "block_tables": attn_metadata.decode_metadata.block_tables, } self.output_buffers = {"hidden_states": hidden_states} @@ -1049,8 +1040,8 @@ def forward( self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, non_blocking=True) - self.input_buffers["context_lens"].copy_( - attn_metadata.decode_metadata.context_lens, non_blocking=True) + self.input_buffers["seqlens_tensor"].copy_( + attn_metadata.decode_metadata.seqlens_tensor, non_blocking=True) self.input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) # Run the graph. @@ -1089,18 +1080,18 @@ def _get_graph_batch_size(batch_size: int) -> int: def _prepare_fake_inputs( - seq_len: int, vision_language_config: Optional[VisionLanguageConfig]): + seqlen: int, vision_language_config: Optional[VisionLanguageConfig]): """Prepare fake inputs for profile run.""" if vision_language_config: prompt_tokens = [ vision_language_config.image_token_id ] * vision_language_config.image_feature_size + [0] * ( - seq_len - vision_language_config.image_feature_size) + seqlen - vision_language_config.image_feature_size) fake_image_input = MultiModalData( type=MultiModalData.Type.IMAGE, data=torch.zeros(vision_language_config.image_input_shape, dtype=torch.float16)) else: - prompt_tokens = [0] * seq_len + prompt_tokens = [0] * seqlen fake_image_input = None return SequenceData(prompt_tokens), fake_image_input diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 3078a48dbe8b..e6fe82e1f710 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -52,7 +52,7 @@ def _prepare_prompt( input_positions: List[List[int]] = [] input_block_ids: List[int] = [] - seq_lens: List[int] = [] + seqlens: List[int] = [] for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -62,7 +62,7 @@ def _prepare_prompt( seq_data = seq_group_metadata.seq_data[seq_id] prompt_tokens = seq_data.get_token_ids() seqlen = len(prompt_tokens) - seq_lens.append(seqlen) + seqlens.append(seqlen) input_tokens.append(prompt_tokens) input_positions.append(list(range(seqlen))) @@ -72,7 +72,7 @@ def _prepare_prompt( assert len(block_table) == 1 input_block_ids.append(block_table[0]) - max_seqlen = max(seq_lens) + max_seqlen = max(seqlens) assert max_seqlen > 0 input_tokens = make_tensor_with_pad(input_tokens, max_seqlen, @@ -88,7 +88,7 @@ def _prepare_prompt( dtype=torch.long, device=self.device) - return input_tokens, input_positions, input_block_ids, seq_lens + return input_tokens, input_positions, input_block_ids, seqlens def _prepare_decode( self, @@ -110,10 +110,10 @@ def _prepare_decode( generation_token = seq_data.get_last_token_id() 
input_tokens.append([generation_token]) - seq_len = seq_data.get_len() - position = seq_len - 1 + seqlen = seq_data.get_len() + position = seqlen - 1 input_positions.append([position]) - context_lens.append(seq_len) + context_lens.append(seqlen) assert seq_group_metadata.block_tables is not None block_table = seq_group_metadata.block_tables[seq_id] @@ -149,18 +149,18 @@ def prepare_input_tensors( # Prepare input tensors. if is_prompt: (input_tokens, input_positions, input_block_ids, - seq_lens) = self._prepare_prompt(seq_group_metadata_list) + seqlens) = self._prepare_prompt(seq_group_metadata_list) else: (input_tokens, input_positions, input_block_ids) = self._prepare_decode(seq_group_metadata_list) - seq_lens = [] + seqlens = [] sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - seq_lens, + seqlens, # query_lens is not needed if chunked prefill is not # supported. Since neuron worker doesn't support chunked prefill - # just use seq_lens instead. - seq_lens, + # just use seqlens instead. + seqlens, self.device, self.pin_memory) From 93b9ed15310008bfb8aeecf09e1b79fbcb4caa74 Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 1 May 2024 06:50:43 -0700 Subject: [PATCH 04/12] should work now. --- tests/worker/test_model_runner.py | 69 +++++++++++----------- vllm/attention/backends/rocm_flash_attn.py | 16 ++--- vllm/attention/ops/paged_attn.py | 4 +- 3 files changed, 44 insertions(+), 45 deletions(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index dc1014adf2da..3c156c5790d1 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -23,13 +23,13 @@ def test_prepare_prompt(batch_size): lora_config=None) model_runner.set_block_size(16) - seq_lens = [] + seqlens = [] seq_group_metadata_list = [] block_tables = {0: [1]} for i in range(batch_size): # make sure all tokens fit into one block seqlen = i % (model_runner.block_size - 1) + 1 - seq_lens.append(seqlen) + seqlens.append(seqlen) seq_data = SequenceData(list(range(seqlen))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", @@ -43,27 +43,28 @@ def test_prepare_prompt(batch_size): expected_selected_token_indices = [] selected_token_start_idx = 0 - for seqlen in seq_lens: + for seqlen in seqlens: expected_selected_token_indices.append(selected_token_start_idx + seqlen - 1) selected_token_start_idx += seqlen - (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _, + (input_tokens, input_positions, attn_metadata, return_seqlens, _, _, _, _, _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) - assert return_seq_lens == seq_lens + assert return_seqlens == seqlens assert len(slot_mapping) == len(input_tokens) # Verify input metadata is correct for prompts. device = model_runner.device assert attn_metadata.is_prompt is True - assert torch.allclose(attn_metadata.seq_lens_tensor, - torch.tensor(seq_lens, device=device)) - assert attn_metadata.seq_lens == seq_lens - assert attn_metadata.max_seqlen == max(seq_lens) + assert torch.allclose( + attn_metadata.seqlens_tensor, + torch.tensor(seqlens, device=device, dtype=torch.int)) + assert attn_metadata.seqlens == seqlens + assert attn_metadata.max_seqlen == max(seqlens) # Test subquery start locs. start_idx = 0 start_loc = [start_idx] - for seqlen in seq_lens: + for seqlen in seqlens: start_idx += seqlen start_loc.append(start_idx) assert torch.allclose( @@ -74,17 +75,16 @@ def test_prepare_prompt(batch_size): # equivalent to subquery_start_loc. 
start_idx = 0 seq_start_loc = [start_idx] - for seqlen in seq_lens: + for seqlen in seqlens: start_idx += seqlen seq_start_loc.append(start_idx) assert torch.allclose( attn_metadata.seq_start_loc, torch.tensor(start_loc, dtype=torch.int32, device=device)) - assert attn_metadata.max_context_len is None assert torch.allclose( - attn_metadata.context_lens, - torch.zeros(attn_metadata.context_lens.shape[0], + attn_metadata.context_lens_tensor, + torch.zeros(attn_metadata.context_lens_tensor.shape[0], dtype=torch.int, device=device)) @@ -95,18 +95,18 @@ def test_prepare_prompt(batch_size): # Cuda graph should not be used for prerill. assert attn_metadata.use_cuda_graph is False - assert len(input_tokens) == sum(seq_lens) - assert len(input_positions) == sum(seq_lens) + assert len(input_tokens) == sum(seqlens) + assert len(input_positions) == sum(seqlens) torch.testing.assert_close(input_tokens, input_positions) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - seq_lens, - query_lens=seq_lens, + seqlens, + query_lens=seqlens, device=model_runner.device, pin_memory=model_runner.pin_memory) - assert len(input_tokens) == sum(seq_lens) - assert len(input_positions) == sum(seq_lens) + assert len(input_tokens) == sum(seqlens) + assert len(input_positions) == sum(seqlens) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, device=actual.device, @@ -145,12 +145,12 @@ def test_prepare_decode_cuda_graph(batch_size): lora_config=None) model_runner.set_block_size(16) - seq_lens = [] + seqlens = [] seq_group_metadata_list = [] for i in range(batch_size): # make sure all tokens fit into one block seqlen = i % (model_runner.block_size - 1) + 1 - seq_lens.append(seqlen) + seqlens.append(seqlen) seq_data = list(range(seqlen)) seq_data = SequenceData(seq_data) seq_group_metadata = SequenceGroupMetadata( @@ -171,14 +171,13 @@ def test_prepare_decode_cuda_graph(batch_size): # Verify input metadata is correct for prompts. device = model_runner.device assert attn_metadata.is_prompt is False - assert attn_metadata.seq_lens is None - assert attn_metadata.max_seqlen is None + assert attn_metadata.seqlens is None assert attn_metadata.subquery_start_loc is None assert attn_metadata.seq_start_loc is None - assert attn_metadata.max_context_len == max(seq_lens) + assert attn_metadata.max_seqlen == max(seqlens) assert torch.allclose( - attn_metadata.context_lens[:len(seq_lens)], - torch.tensor(seq_lens, dtype=torch.int, device=device)) + attn_metadata.seqlens_tensor[:len(seqlens)], + torch.tensor(seqlens, dtype=torch.int, device=device)) # block table's first index corresponds to each batch, meaning in # decoding it is each token. 
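The subquery_start_loc / seq_start_loc tensors checked in this test are zero-prefixed cumulative sums of the per-sequence lengths, as in the metadata comment (lengths [4, 6] give [0, 4, 10]). A minimal sketch with arbitrary lengths:

    import torch

    # Sketch: start locations are an exclusive cumulative sum of the lengths.
    seqlens = torch.tensor([4, 6], dtype=torch.int32)
    seq_start_loc = torch.zeros(seqlens.shape[0] + 1, dtype=torch.int32)
    torch.cumsum(seqlens, dim=0, dtype=seq_start_loc.dtype,
                 out=seq_start_loc[1:])
    print(seq_start_loc)  # tensor([ 0,  4, 10], dtype=torch.int32)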
@@ -197,13 +196,13 @@ def test_prepare_decode_cuda_graph(batch_size): # Verify Sampling expected_selected_token_indices = [] selected_token_start_idx = 0 - for seqlen in seq_lens: + for seqlen in seqlens: expected_selected_token_indices.append(selected_token_start_idx) selected_token_start_idx += 1 sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - seq_lens, - query_lens=seq_lens, + seqlens, + query_lens=seqlens, device=model_runner.device, pin_memory=model_runner.pin_memory) actual = sampling_metadata.selected_token_indices @@ -240,13 +239,13 @@ def test_empty_seq_group(): assert attn_metadata is None assert len(slot_mapping) == 0 - (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _, + (input_tokens, input_positions, attn_metadata, return_seqlens, _, _, _, _, _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) assert len(input_tokens) == 0 assert len(input_positions) == 0 assert attn_metadata is None assert len(slot_mapping) == 0 - assert len(return_seq_lens) == 0 + assert len(return_seqlens) == 0 @pytest.fixture @@ -286,7 +285,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): model_runner.set_block_size(16) # Add prefill requests. - seq_lens = [] + seqlens = [] seq_group_metadata_list = [] prefill_metadata_list = [] decode_metadata_list = [] @@ -296,7 +295,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): for i in range(prefill_batch_size): # make sure all tokens fit into one block seqlen = i % (model_runner.block_size - 1) + 1 - seq_lens.append(seqlen) + seqlens.append(seqlen) seq_data = SequenceData(list(range(seqlen))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", @@ -341,7 +340,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): else: assert attn_metadata.num_decode_tokens == _get_graph_batch_size( decode_batch_size) - assert attn_metadata.num_prefill_tokens == sum(seq_lens) + assert attn_metadata.num_prefill_tokens == sum(seqlens) # Verify attn metadata is consistent. We don't need to test individual # values here because they are tested above. diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 1be1c4ed8eff..5bab9c2fea36 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -66,9 +66,9 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, is_prompt: bool # (batch_size,). The sequence length per sequence. Sequence length means # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] + seqlens: Optional[List[int]] + # seqlens stored as a tensor. + seqlens_tensor: Optional[torch.Tensor] # NOTE(sang): Definition of context_len, query_len, and seqlen. # |---------- N-1 iteration --------| @@ -248,7 +248,7 @@ def forward( if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. 
- assert prefill_meta.seq_lens is not None + assert prefill_meta.seqlens is not None if kv_cache is None or prefill_meta.block_tables.numel() == 0: # triton attention # When block_tables are not filled, it means q and k are the @@ -275,7 +275,7 @@ def forward( query, key, value, - prefill_meta.seq_lens, + prefill_meta.seqlens, self.scale, ) else: @@ -304,7 +304,7 @@ def forward( value_cache, prefill_meta.block_tables, prefill_meta.subquery_start_loc, - prefill_meta.seq_lens_tensor, + prefill_meta.seqlens_tensor, prefill_meta.context_lens_tensor, prefill_meta.max_query_len, self.alibi_slopes, @@ -334,12 +334,12 @@ def _naive_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - seq_lens: List[int], + seqlens: List[int], scale: float, ) -> torch.Tensor: output = torch.empty_like(query) start = 0 - for _, seqlen in enumerate(seq_lens): + for _, seqlen in enumerate(seqlens): end = start + seqlen out = _naive_masked_attention( query[start:end], diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index f798e535a8c8..26245c962777 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -167,7 +167,7 @@ def forward_prefix( value_cache: torch.Tensor, block_tables: torch.Tensor, subquery_start_loc: torch.Tensor, - seq_lens_tensor: torch.Tensor, + seqlens_tensor: torch.Tensor, context_lens: torch.Tensor, max_query_len: int, alibi_slopes: Optional[torch.Tensor], @@ -183,7 +183,7 @@ def forward_prefix( block_tables, # subquery_start_loc is (batch_size + 1,) subquery_start_loc[:-1], - seq_lens_tensor, + seqlens_tensor, context_lens, max_query_len, alibi_slopes, From 6cb2ead2117e646d3440c81a2ff4153909ef9a00 Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 1 May 2024 17:41:01 -0700 Subject: [PATCH 05/12] . 
--- .buildkite/test-pipeline.yaml | 1 + tests/spec_decode/test_multi_step_worker.py | 24 ++++++++++++--------- tests/spec_decode/test_ngram_worker.py | 24 +++++++++++++-------- vllm/attention/backends/rocm_flash_attn.py | 2 +- vllm/worker/cpu_model_runner.py | 19 ++++++++-------- 5 files changed, 40 insertions(+), 30 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 11cda053260e..0d8ab466f0af 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -17,6 +17,7 @@ steps: - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py + - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_basic_correctness.py - label: Core Test command: pytest -v -s core diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index 98f2731de9aa..cc0427633e68 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -34,7 +34,7 @@ def test_assert_enough_kv_space(num_steps: int): list(range(block_size * 2)), ] - final_seq_lens = [ + final_prompt_lens = [ len(prompt + output) + num_steps for prompt, output in zip(prompts, prev_output_tokens) ] @@ -43,7 +43,7 @@ def test_assert_enough_kv_space(num_steps: int): prompts, num_gpu_blocks, block_size, - final_seq_lens, + final_prompt_lens, continuations=prev_output_tokens) assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space # pylint: disable=protected-access @@ -103,17 +103,21 @@ def test_same_output_for_single_step(): [6, 7, 8, 9, 10], ] - final_seq_lens = [len(prompt) + num_steps for prompt in prompts] + final_prompt_lens = [len(prompt) + num_steps for prompt in prompts] multi_step_execute_model_data = create_execute_model_data( seq_group_metadata_list=create_seq_group_metadata_from_prompts( - prompts, num_gpu_blocks, block_size, - final_seq_lens=final_seq_lens)) + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens)) single_step_execute_model_data = create_execute_model_data( seq_group_metadata_list=create_seq_group_metadata_from_prompts( - prompts, num_gpu_blocks, block_size, - final_seq_lens=final_seq_lens)) + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens)) zero_kv_cache(multi_step_worker.cache_engine) set_random_seed(seed) @@ -181,7 +185,7 @@ def test_same_output_for_multi_step(): random.randint(0, 1000) for _ in range(random.randint(10, 20)) ] for _ in range(10)] - final_seq_lens = [len(prompt) + num_steps for prompt in prompts] + final_prompt_lens = [len(prompt) + num_steps for prompt in prompts] rand_seeds = list(random.randint(0, 100) for _ in range(num_steps)) multi_step_worker.execute_model = patch_execute_model_with_seeds( @@ -195,7 +199,7 @@ def test_same_output_for_multi_step(): num_gpu_blocks, block_size, continuations=continuations, - final_seq_lens=final_seq_lens), ) + final_prompt_lens=final_prompt_lens), ) # Run multi-step. 
zero_kv_cache(multi_step_worker.cache_engine) @@ -217,7 +221,7 @@ def test_same_output_for_multi_step(): num_gpu_blocks, block_size, continuations=continuations, - final_seq_lens=final_seq_lens)) + final_prompt_lens=final_prompt_lens)) single_step_output.extend( worker.execute_model(**execute_model_data.to_dict(), )) diff --git a/tests/spec_decode/test_ngram_worker.py b/tests/spec_decode/test_ngram_worker.py index ee4135015713..e7e2e87f599d 100644 --- a/tests/spec_decode/test_ngram_worker.py +++ b/tests/spec_decode/test_ngram_worker.py @@ -43,11 +43,13 @@ def test_ngram_algo_correctness_for_single_no_match(): ] proposal_len = 5 - final_seq_lens = [len(prompt) + proposal_len for prompt in prompts] + final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts] ngram_sampler_output_data = create_execute_model_data( seq_group_metadata_list=create_seq_group_metadata_from_prompts( - prompts, num_gpu_blocks, block_size, - final_seq_lens=final_seq_lens)) + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens)) proposals = proposer.get_proposals( **ngram_sampler_output_data.to_dict(), @@ -110,11 +112,13 @@ def test_ngram_algo_correctness_for_batches_not_match_all(): ] proposal_len = 5 - final_seq_lens = [len(prompt) + proposal_len for prompt in prompts] + final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts] ngram_sampler_output_data = create_execute_model_data( seq_group_metadata_list=create_seq_group_metadata_from_prompts( - prompts, num_gpu_blocks, block_size, - final_seq_lens=final_seq_lens)) + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens)) proposals = proposer.get_proposals( **ngram_sampler_output_data.to_dict(), @@ -180,11 +184,13 @@ def test_ngram_algo_correctness_for_batches_match_all(): ] proposal_len = 5 - final_seq_lens = [len(prompt) + proposal_len for prompt in prompts] + final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts] ngram_sampler_output_data = create_execute_model_data( seq_group_metadata_list=create_seq_group_metadata_from_prompts( - prompts, num_gpu_blocks, block_size, - final_seq_lens=final_seq_lens)) + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens)) proposals = proposer.get_proposals( **ngram_sampler_output_data.to_dict(), diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 5bab9c2fea36..b99b1e67bce4 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -317,7 +317,7 @@ def forward( key_cache, value_cache, decode_meta.block_tables, - decode_meta.seqlens, + decode_meta.seqlens_tensor, decode_meta.max_seqlen, attn_metadata.kv_cache_dtype, self.num_kv_heads, diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index ae920d94235d..468a62f5c610 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -174,7 +174,7 @@ def _prepare_decode( input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] - context_lens: List[int] = [] + seqlens: List[int] = [] block_tables: List[List[int]] = [] for seq_group_metadata in seq_group_metadata_list: @@ -192,9 +192,9 @@ def _prepare_decode( position = seqlen - 1 input_positions.append(position) - context_len = seqlen if self.sliding_window is None else min( + seqlen = seqlen if self.sliding_window is None else min( seqlen, self.sliding_window) - context_lens.append(context_len) + seqlens.append(seqlen) block_table = 
seq_group_metadata.block_tables[seq_id] block_number = block_table[position // self.block_size] @@ -208,7 +208,7 @@ def _prepare_decode( block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) - max_context_len = max(context_lens) + max_seqlen = max(seqlens) input_tokens = torch.tensor(input_tokens, dtype=torch.long, @@ -219,9 +219,7 @@ def _prepare_decode( slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) - context_lens = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) + seqlens = torch.tensor(seqlens, dtype=torch.int, device=self.device) max_block_table_len = max( len(block_table) for block_table in block_tables) @@ -236,14 +234,15 @@ def _prepare_decode( attn_metadata = self.attn_backend.make_metadata( is_prompt=False, slot_mapping=slot_mapping, - seqlens=None, + seqlens=seqlens, + max_seqlen=max_seqlen, num_prefill_tokens=0, num_decode_tokens=len(input_tokens), - max_context_len=max_context_len, + max_context_len=None, num_prefills=0, prefill_metadata=None, decode_metadata=None, - context_lens=context_lens, + context_lens=None, block_tables=block_tables, kv_cache_dtype=self.kv_cache_dtype, ) From 9c451679505b3bc925a109814e455d9000a252dc Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 1 May 2024 18:33:50 -0700 Subject: [PATCH 06/12] fixed --- .buildkite/test-pipeline.yaml | 1 - vllm/attention/backends/torch_sdpa.py | 3 +-- vllm/attention/ops/paged_attn.py | 2 +- vllm/worker/cpu_model_runner.py | 3 +-- 4 files changed, 3 insertions(+), 6 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 0d8ab466f0af..11cda053260e 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -17,7 +17,6 @@ steps: - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_basic_correctness.py - label: Core Test command: pytest -v -s core diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index adce25545feb..2f5fd11ab1c0 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -58,7 +58,6 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata, # or all decoding. True if all sequences are prompts. is_prompt: bool slot_mapping: torch.Tensor - seq_lens: Optional[List[int]] def __post_init__(self): # Set during the execution of the first attention op. @@ -189,7 +188,7 @@ def forward( key_cache, value_cache, attn_metadata.block_tables, - attn_metadata.seq_lens, + attn_metadata.seqlens_tensor, attn_metadata.max_seqlen, attn_metadata.kv_cache_dtype, self.num_kv_heads, diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 26245c962777..14082cb1f947 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -15,7 +15,7 @@ class PagedAttentionMetadata: """Metadata for PagedAttention.""" # (batch_size,). The length of sequences (entire tokens seen so far) per # sequence. - seqlens: Optional[torch.Tensor] + seqlens_tensor: Optional[torch.Tensor] # Maximum sequence length in the batch. max_seqlen: Optional[int] # (batch_size, max_blocks_per_seq). 
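The decode preparation shown above flattens (block_number, block_offset) into a single slot index of the paged KV cache. A small sketch with a made-up block table and position:

    # Sketch: mapping a token position to a flat KV-cache slot index.
    # block_table, block_size, and position are invented values.
    block_size = 16
    block_table = [3, 7]   # physical block ids for one sequence
    position = 20          # 0-based position of the token being decoded

    block_number = block_table[position // block_size]   # 7
    block_offset = position % block_size                 # 4
    slot = block_number * block_size + block_offset      # 116
    print(slot)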
diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 468a62f5c610..f6abff1bc499 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -234,11 +234,10 @@ def _prepare_decode( attn_metadata = self.attn_backend.make_metadata( is_prompt=False, slot_mapping=slot_mapping, - seqlens=seqlens, + seqlens_tensor=seqlens, max_seqlen=max_seqlen, num_prefill_tokens=0, num_decode_tokens=len(input_tokens), - max_context_len=None, num_prefills=0, prefill_metadata=None, decode_metadata=None, From f1547db7edfa666e107c5be3581eafff15ff68a0 Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 1 May 2024 19:19:29 -0700 Subject: [PATCH 07/12] make cpu work --- vllm/attention/backends/torch_sdpa.py | 1 + vllm/worker/cpu_model_runner.py | 12 +++++++----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 2f5fd11ab1c0..32e86d5722b2 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -58,6 +58,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata, # or all decoding. True if all sequences are prompts. is_prompt: bool slot_mapping: torch.Tensor + seqlens: Optional[List[int]] def __post_init__(self): # Set during the execution of the first attention op. diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index f6abff1bc499..a1227d6b878d 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -152,13 +152,13 @@ def _prepare_prompt( attn_metadata = self.attn_backend.make_metadata( is_prompt=True, seqlens=seqlens, + seqlens_tensor=None, + max_seqlen=None, num_prefills=len(seqlens), num_prefill_tokens=num_prompt_tokens, num_decode_tokens=0, prefill_metadata=None, decode_metadata=None, - max_context_len=None, - context_lens=None, block_tables=torch.tensor([]), slot_mapping=slot_mapping, kv_cache_dtype=self.kv_cache_dtype, @@ -219,7 +219,9 @@ def _prepare_decode( slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) - seqlens = torch.tensor(seqlens, dtype=torch.int, device=self.device) + seqlens_tensor = torch.tensor(seqlens, + dtype=torch.int, + device=self.device) max_block_table_len = max( len(block_table) for block_table in block_tables) @@ -234,14 +236,14 @@ def _prepare_decode( attn_metadata = self.attn_backend.make_metadata( is_prompt=False, slot_mapping=slot_mapping, - seqlens_tensor=seqlens, + seqlens=seqlens, + seqlens_tensor=seqlens_tensor, max_seqlen=max_seqlen, num_prefill_tokens=0, num_decode_tokens=len(input_tokens), num_prefills=0, prefill_metadata=None, decode_metadata=None, - context_lens=None, block_tables=block_tables, kv_cache_dtype=self.kv_cache_dtype, ) From 22bc8fbead9d3c122390b6bfe6eccd0d9ec5051a Mon Sep 17 00:00:00 2001 From: sang Date: Thu, 2 May 2024 07:05:05 -0700 Subject: [PATCH 08/12] fixed --- vllm/attention/backends/torch_sdpa.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 32e86d5722b2..04f2b2bd6ea0 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -136,7 +136,7 @@ def forward( kv_scale) if attn_metadata.is_prompt: - assert attn_metadata.seq_lens is not None + assert attn_metadata.seqlens is not None if (kv_cache is None or attn_metadata.block_tables.numel() == 0): if self.num_kv_heads != self.num_heads: key = 
key.repeat_interleave(self.num_queries_per_kv, dim=1) @@ -147,13 +147,13 @@ def forward( if self.alibi_slopes is not None: att_masks = _make_alibi_bias( self.alibi_slopes, query.dtype, - attn_metadata.seq_lens) # type: ignore + attn_metadata.seqlens) # type: ignore elif self.sliding_window is not None: att_masks = _make_sliding_window_bias( - attn_metadata.seq_lens, self.sliding_window, + attn_metadata.seqlens, self.sliding_window, query.dtype) # type: ignore else: - att_masks = [None] * len(attn_metadata.seq_lens) + att_masks = [None] * len(attn_metadata.seqlens) attn_metadata.attn_bias = att_masks query = query.movedim(0, query.dim() - 2) @@ -164,7 +164,7 @@ def forward( output = torch.empty( (num_tokens, self.num_heads, self.head_size), dtype=query.dtype) - for seqlen, mask in zip(attn_metadata.seq_lens, + for seqlen, mask in zip(attn_metadata.seqlens, attn_metadata.attn_bias): end = start + seqlen sub_out = scaled_dot_product_attention( @@ -205,10 +205,10 @@ def forward( def _make_alibi_bias( alibi_slopes: torch.Tensor, dtype: torch.dtype, - seq_lens: List[int], + seqlens: List[int], ) -> List[torch.Tensor]: attn_biases = [] - for seqlen in seq_lens: + for seqlen in seqlens: bias = torch.arange(seqlen, dtype=dtype) # NOTE(zhuohan): HF uses # `bias = bias[None, :].repeat(seqlen, 1)` @@ -229,12 +229,12 @@ def _make_alibi_bias( def _make_sliding_window_bias( - seq_lens: List[int], + seqlens: List[int], window_size: Optional[int], dtype: torch.dtype, ) -> List[torch.Tensor]: attn_biases = [] - for seqlen in seq_lens: + for seqlen in seqlens: tensor = torch.full( (1, seqlen, seqlen), dtype=dtype, From a53be4c17bea55924868ff45e17c8b49c8d6aa0e Mon Sep 17 00:00:00 2001 From: sang Date: Fri, 3 May 2024 02:13:22 -0700 Subject: [PATCH 09/12] working. --- csrc/attention/attention_kernels.cu | 2 +- tests/basic_correctness/test_chunked_prefill.py | 8 +++----- vllm/config.py | 4 ++-- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 0c521d11c69d..02a5469fbfca 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -616,7 +616,7 @@ __global__ void paged_attention_v2_reduce_kernel( num_kv_heads, \ scale, \ block_tables_ptr, \ - seqlens_ptr, \ + seqlens_ptr, \ max_num_blocks_per_seq, \ alibi_slopes_ptr, \ q_stride, \ diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index f1f0f9eafcb8..47d582c726c6 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -10,17 +10,15 @@ MODELS = [ "facebook/opt-125m", - # "meta-llama/Llama-2-7b-hf", + "meta-llama/Llama-2-7b-hf", ] @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) -# @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) -@pytest.mark.parametrize("chunked_prefill_token_size", [16]) -# @pytest.mark.parametrize("enforce_eager", [False, True]) -@pytest.mark.parametrize("enforce_eager", [False]) +@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) +@pytest.mark.parametrize("enforce_eager", [False, True]) # NOTE: Increasing this in this suite will fail CI because we currently cannot # reset distributed env properly. Use a value > 1 just when you test. 
@pytest.mark.parametrize("tensor_parallel_size", [1]) diff --git a/vllm/config.py b/vllm/config.py index 74b84906f3ac..463782aee0dd 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -104,8 +104,8 @@ def __init__( self.enforce_eager = enforce_eager self.max_context_len_to_capture = max_context_len_to_capture if self.max_context_len_to_capture is not None: - logger.warning("`max_context_len_to_capture` is deprecated. " - "Use `max_seqlen_to_capture` instead.") + raise ValueError("`max_context_len_to_capture` is deprecated. " + "Use `max_seqlen_to_capture` instead.") self.max_seqlen_to_capture = (max_seqlen_to_capture or max_context_len_to_capture) self.max_logprobs = max_logprobs From 1a3109a40e2e3ea4900535531bccd24129bf98ae Mon Sep 17 00:00:00 2001 From: sang Date: Fri, 3 May 2024 02:59:27 -0700 Subject: [PATCH 10/12] refactor seqlen -> seq_len --- .../kernels/benchmark_paged_attention.py | 24 ++-- csrc/attention/attention_kernels.cu | 70 +++++------ csrc/cpu/attention.cpp | 82 ++++++------- csrc/ops.h | 8 +- tests/kernels/test_attention.py | 34 +++--- tests/spec_decode/e2e/conftest.py | 4 +- tests/worker/test_model_runner.py | 86 +++++++------- vllm/_custom_ops.py | 14 +-- vllm/attention/backends/flash_attn.py | 26 ++-- vllm/attention/backends/rocm_flash_attn.py | 42 +++---- vllm/attention/backends/torch_sdpa.py | 34 +++--- vllm/attention/backends/xformers.py | 42 +++---- vllm/attention/ops/paged_attn.py | 24 ++-- vllm/config.py | 20 ++-- vllm/engine/arg_utils.py | 10 +- vllm/entrypoints/llm.py | 8 +- vllm/model_executor/sampling_metadata.py | 14 +-- vllm/worker/cpu_model_runner.py | 58 ++++----- vllm/worker/model_runner.py | 112 +++++++++--------- vllm/worker/neuron_model_runner.py | 34 +++--- 20 files changed, 373 insertions(+), 373 deletions(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index eb7120c2a6e7..ca7967c1ab0d 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -16,7 +16,7 @@ def main( version: str, num_seqs: int, - seqlen: int, + seq_len: int, num_query_heads: int, num_kv_heads: int, head_size: int, @@ -48,12 +48,12 @@ def main( dtype=torch.float, device=device) - seqlens = [seqlen for _ in range(num_seqs)] - max_seqlen = max(seqlens) - seqlens = torch.tensor(seqlens, dtype=torch.int, device=device) + seq_lens = [seq_len for _ in range(num_seqs)] + max_seq_len = max(seq_lens) + seq_lens = torch.tensor(seq_lens, dtype=torch.int, device=device) # Create the block tables. - max_num_blocks_per_seq = (max_seqlen + block_size - 1) // block_size + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = [] for _ in range(num_seqs): block_table = [ @@ -77,7 +77,7 @@ def main( # Prepare for the paged attention kernel. 
output = torch.empty_like(query) if version == "v2": - num_partitions = ((max_seqlen + PARTITION_SIZE - 1) // PARTITION_SIZE) + num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) tmp_output = torch.empty( size=(num_seqs, num_query_heads, num_partitions, head_size), dtype=output.dtype, @@ -109,9 +109,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: num_kv_heads, scale, block_tables, - seqlens, + seq_lens, block_size, - max_seqlen, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, @@ -128,9 +128,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: num_kv_heads, scale, block_tables, - seqlens, + seq_lens, block_size, - max_seqlen, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, @@ -165,7 +165,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: choices=["v1", "v2"], default="v2") parser.add_argument("--batch-size", type=int, default=8) - parser.add_argument("--seqlen", type=int, default=4096) + parser.add_argument("--seq_len", type=int, default=4096) parser.add_argument("--num-query-heads", type=int, default=64) parser.add_argument("--num-kv-heads", type=int, default=8) parser.add_argument("--head-size", @@ -198,7 +198,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: main( version=args.version, num_seqs=args.batch_size, - seqlen=args.seqlen, + seq_len=args.seq_len, num_query_heads=args.num_query_heads, num_kv_heads=args.num_kv_heads, head_size=args.head_size, diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 02a5469fbfca..8b1b5e098015 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -104,7 +104,7 @@ __device__ void paged_attention_kernel( const int num_kv_heads, // [num_heads] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ seqlens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, @@ -115,13 +115,13 @@ __device__ void paged_attention_kernel( const int partition_idx = blockIdx.z; const int max_num_partitions = gridDim.z; constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; - const int seqlen = seqlens[seq_idx]; - if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seqlen) { + const int seq_len = seq_lens[seq_idx]; + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) { // No work to do. Terminate the thread block. return; } - const int num_seq_blocks = DIVIDE_ROUND_UP(seqlen, BLOCK_SIZE); + const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; // [start_block_idx, end_block_idx) is the range of blocks to process. @@ -131,7 +131,7 @@ __device__ void paged_attention_kernel( // [start_token_idx, end_token_idx) is the range of tokens to process. const int start_token_idx = start_block_idx * BLOCK_SIZE; - const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seqlen); + const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); const int num_tokens = end_token_idx - start_token_idx; constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); @@ -245,12 +245,12 @@ __device__ void paged_attention_kernel( // This includes a reduction across the threads in the same thread group. 
float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); // Add the ALiBi bias if slopes are given. - qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seqlen + 1) : 0; + qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; if (thread_group_offset == 0) { // Store the partial reductions to shared memory. // NOTE(woosuk): It is required to zero out the masked logits. - const bool mask = token_idx >= seqlen; + const bool mask = token_idx >= seq_len; logits[token_idx - start_token_idx] = mask ? 0.f : qk; // Update the max value. qk_max = mask ? qk_max : fmaxf(qk_max, qk); @@ -371,7 +371,7 @@ __device__ void paged_attention_kernel( scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); #pragma unroll for (int j = 0; j < V_VEC_SIZE; j++) { - v_vec_ptr[j] = token_idx + j < seqlen ? v_vec_ptr[j] : zero_value; + v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; } } accs[i] += dot(logits_vec, v_vec); @@ -457,7 +457,7 @@ __global__ void paged_attention_v1_kernel( const int num_kv_heads, // [num_heads] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ seqlens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, @@ -466,7 +466,7 @@ __global__ void paged_attention_v1_kernel( const float kv_scale) { paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, - out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seqlens, + out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); } @@ -489,7 +489,7 @@ __global__ void paged_attention_v2_kernel( const int num_kv_heads, // [num_heads] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ seqlens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, @@ -498,7 +498,7 @@ __global__ void paged_attention_v2_kernel( const float kv_scale) { paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, - block_tables, seqlens, max_num_blocks_per_seq, alibi_slopes, + block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); } @@ -513,13 +513,13 @@ __global__ void paged_attention_v2_reduce_kernel( const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const int* __restrict__ seqlens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int max_num_partitions) { const int num_heads = gridDim.x; const int head_idx = blockIdx.x; const int seq_idx = blockIdx.y; - const int seqlen = seqlens[seq_idx]; - const int num_partitions = DIVIDE_ROUND_UP(seqlen, PARTITION_SIZE); + const int seq_len = seq_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); if (num_partitions == 1) { // No need to reduce. Only copy tmp_out to out. 
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; @@ -616,7 +616,7 @@ __global__ void paged_attention_v2_reduce_kernel( num_kv_heads, \ scale, \ block_tables_ptr, \ - seqlens_ptr, \ + seq_lens_ptr, \ max_num_blocks_per_seq, \ alibi_slopes_ptr, \ q_stride, \ @@ -639,8 +639,8 @@ void paged_attention_v1_launcher( int num_kv_heads, float scale, torch::Tensor& block_tables, - torch::Tensor& seqlens, - int max_seqlen, + torch::Tensor& seq_lens, + int max_seq_len, const c10::optional& alibi_slopes, float kv_scale) { int num_seqs = query.size(0); @@ -664,11 +664,11 @@ void paged_attention_v1_launcher( CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); - int* seqlens_ptr = seqlens.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_seqlen = DIVIDE_ROUND_UP(max_seqlen, BLOCK_SIZE) * BLOCK_SIZE; - int logits_size = padded_max_seqlen * sizeof(float); + int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_seq_len * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len // Keep that in sync with the logic here! @@ -715,8 +715,8 @@ void paged_attention_v1_launcher( num_kv_heads, \ scale, \ block_tables, \ - seqlens, \ - max_seqlen, \ + seq_lens, \ + max_seq_len, \ alibi_slopes, \ kv_scale); @@ -746,9 +746,9 @@ void paged_attention_v1( int num_kv_heads, // [num_heads] float scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& seqlens, // [num_seqs] + torch::Tensor& seq_lens, // [num_seqs] int block_size, - int max_seqlen, + int max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale) { @@ -790,7 +790,7 @@ void paged_attention_v1( num_kv_heads, \ scale, \ block_tables_ptr, \ - seqlens_ptr, \ + seq_lens_ptr, \ max_num_blocks_per_seq, \ alibi_slopes_ptr, \ q_stride, \ @@ -803,7 +803,7 @@ void paged_attention_v1( exp_sums_ptr, \ max_logits_ptr, \ tmp_out_ptr, \ - seqlens_ptr, \ + seq_lens_ptr, \ max_num_partitions); template< @@ -824,8 +824,8 @@ void paged_attention_v2_launcher( int num_kv_heads, float scale, torch::Tensor& block_tables, - torch::Tensor& seqlens, - int max_seqlen, + torch::Tensor& seq_lens, + int max_seq_len, const c10::optional& alibi_slopes, float kv_scale) { int num_seqs = query.size(0); @@ -852,10 +852,10 @@ void paged_attention_v2_launcher( CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); - int* seqlens_ptr = seqlens.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int max_num_partitions = DIVIDE_ROUND_UP(max_seqlen, PARTITION_SIZE); + int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); int logits_size = PARTITION_SIZE * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); @@ -909,8 +909,8 @@ void paged_attention_v2_launcher( num_kv_heads, \ scale, \ block_tables, \ - seqlens, \ - max_seqlen, \ + seq_lens, \ + max_seq_len, \ alibi_slopes, \ kv_scale); @@ -943,9 +943,9 @@ void paged_attention_v2( int num_kv_heads, // [num_heads] float scale, torch::Tensor& block_tables, // [num_seqs, 
max_num_blocks_per_seq] - torch::Tensor& seqlens, // [num_seqs] + torch::Tensor& seq_lens, // [num_seqs] int block_size, - int max_seqlen, + int max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale) { diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index e41d356a645b..c1d765be0559 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -70,11 +70,11 @@ template FORCE_INLINE std::pair reduceSoftmaxAlibi(T *data, const int size, const int capacity, const float alibi_slope, const int start_index, - const int seqlen) { - data[0] += alibi_slope * (start_index - seqlen + 1); + const int seq_len) { + data[0] += alibi_slope * (start_index - seq_len + 1); T max = data[0]; for (int i = 1; i < size; ++i) { - T qk = data[i] + alibi_slope * (start_index + i - seqlen + 1); + T qk = data[i] + alibi_slope * (start_index + i - seq_len + 1); data[i] = qk; max = max >= qk ? max : qk; } @@ -225,7 +225,7 @@ struct paged_attention_v1_impl { const int num_kv_heads, const float scale, const int *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int *__restrict__ seqlens, // [num_seqs] + const int *__restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float *__restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, @@ -235,32 +235,32 @@ struct paged_attention_v1_impl { static_assert(BLOCK_SIZE == 16); - int max_seqlen = max_num_blocks_per_seq * BLOCK_SIZE; - int max_seqlen_padded = (max_seqlen + 15) & 0xFFFFFFF0; - TORCH_CHECK((max_seqlen_padded * sizeof(float)) % 64 == 0); + int max_seq_len = max_num_blocks_per_seq * BLOCK_SIZE; + int max_seq_len_padded = (max_seq_len + 15) & 0xFFFFFFF0; + TORCH_CHECK((max_seq_len_padded * sizeof(float)) % 64 == 0); const int parallel_work_item_num = omp_get_max_threads(); size_t logits_bytes = - parallel_work_item_num * max_seqlen_padded * sizeof(float); + parallel_work_item_num * max_seq_len_padded * sizeof(float); float *logits = (float *)std::aligned_alloc( 64, logits_bytes); // Cacheline alignment for each context token. 
- // [parallel_work_item_num, max_seqlen_padded] + // [parallel_work_item_num, max_seq_len_padded] #pragma omp parallel for collapse(2) schedule(dynamic, 1) for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - int seqlen = seqlens[seq_idx]; + int seq_len = seq_lens[seq_idx]; const int *seq_block_table = block_tables + max_num_blocks_per_seq * seq_idx; - const int block_num = (seqlen + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE; const int64_t kv_head_idx = head_idx / num_queries_per_kv; const scalar_t *__restrict__ q_vec_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; const int last_block_token_num = - seqlen - (block_num - 1) * BLOCK_SIZE; + seq_len - (block_num - 1) * BLOCK_SIZE; float *__restrict__ thread_block_logits = - logits + omp_get_thread_num() * max_seqlen_padded; + logits + omp_get_thread_num() * max_seq_len_padded; // Compute logits for (int block_idx = 0; block_idx < block_num; ++block_idx) { @@ -278,11 +278,11 @@ struct paged_attention_v1_impl { // Compute softmax if (alibi_slopes) { - reduceSoftmaxAlibi(thread_block_logits, seqlen, + reduceSoftmaxAlibi(thread_block_logits, seq_len, block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, - seqlen); + seq_len); } else { - reduceSoftmax(thread_block_logits, seqlen, + reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE); } @@ -340,7 +340,7 @@ struct paged_attention_v1_impl { #define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ paged_attention_v1_impl::call( \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ - block_tables_ptr, seqlens_ptr, max_num_blocks_per_seq, \ + block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ num_heads); @@ -348,8 +348,8 @@ template void paged_attention_v1_impl_launcher( torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor &value_cache, int num_kv_heads, float scale, - torch::Tensor &block_tables, torch::Tensor &seqlens, - int max_seqlen, const c10::optional &alibi_slopes) { + torch::Tensor &block_tables, torch::Tensor &seq_lens, + int max_seq_len, const c10::optional &alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -369,7 +369,7 @@ void paged_attention_v1_impl_launcher( T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int *block_tables_ptr = block_tables.data_ptr(); - int *seqlens_ptr = seqlens.data_ptr(); + int *seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { case 64: @@ -399,7 +399,7 @@ void paged_attention_v1_impl_launcher( #define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ paged_attention_v1_impl_launcher( \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ - seqlens, max_seqlen, alibi_slopes); + seq_lens, max_seq_len, alibi_slopes); #define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ switch (block_size) { \ @@ -416,8 +416,8 @@ void paged_attention_v1(torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor &value_cache, int num_kv_heads, float scale, torch::Tensor &block_tables, - torch::Tensor &seqlens, int block_size, - int max_seqlen, + torch::Tensor &seq_lens, int block_size, + int max_seq_len, const c10::optional &alibi_slopes, const std::string &kv_cache_dtype, float kv_scale) { TORCH_CHECK(kv_scale == 1.0f); @@ -448,7 +448,7 @@ struct 
paged_attention_v2_impl { const int num_kv_heads, const float scale, const int *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int *__restrict__ seqlens, // [num_seqs] + const int *__restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float *__restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, @@ -465,17 +465,17 @@ struct paged_attention_v2_impl { for (int partition_idx = 0; partition_idx < max_num_partitions; ++partition_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - const int seqlen = seqlens[seq_idx]; + const int seq_len = seq_lens[seq_idx]; const int start_token_idx = partition_idx * PARTITION_SIZE; - if (start_token_idx >= seqlen) + if (start_token_idx >= seq_len) continue; const int partition_num = - (seqlen + PARTITION_SIZE - 1) / PARTITION_SIZE; + (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; const bool no_reduce = (partition_num == 1); const int token_num = - (std::min(seqlen, start_token_idx + PARTITION_SIZE) - + (std::min(seq_len, start_token_idx + PARTITION_SIZE) - start_token_idx); const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; @@ -508,7 +508,7 @@ struct paged_attention_v2_impl { if (alibi_slopes) { max_and_sum = reduceSoftmaxAlibi( logits, token_num, block_num * BLOCK_SIZE, - alibi_slopes[head_idx], start_token_idx, seqlen); + alibi_slopes[head_idx], start_token_idx, seq_len); } else { max_and_sum = reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE); @@ -583,9 +583,9 @@ struct paged_attention_v2_impl { #pragma omp parallel for collapse(2) schedule(static, 1) for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - const int seqlen = seqlens[seq_idx]; + const int seq_len = seq_lens[seq_idx]; const int partition_num = - (seqlen + PARTITION_SIZE - 1) / PARTITION_SIZE; + (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; if (partition_num == 1) continue; @@ -612,9 +612,9 @@ struct paged_attention_v2_impl { for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int group_idx = 0; group_idx < head_group_num; ++group_idx) { - const int seqlen = seqlens[seq_idx]; + const int seq_len = seq_lens[seq_idx]; const int partition_num = - (seqlen + PARTITION_SIZE - 1) / PARTITION_SIZE; + (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; if (partition_num == 1) continue; @@ -649,7 +649,7 @@ struct paged_attention_v2_impl { paged_attention_v2_impl::call( \ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ - seqlens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ kv_block_stride, kv_head_stride, num_seqs, num_heads, \ max_num_partitions); @@ -658,8 +658,8 @@ void paged_attention_v2_impl_launcher( torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits, torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor &value_cache, int num_kv_heads, float scale, - torch::Tensor &block_tables, torch::Tensor &seqlens, int block_size, - int max_seqlen, const c10::optional &alibi_slopes) { + torch::Tensor &block_tables, torch::Tensor &seq_lens, int block_size, + int max_seq_len, const c10::optional &alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -683,7 +683,7 @@ 
void paged_attention_v2_impl_launcher( T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int *block_tables_ptr = block_tables.data_ptr(); - int *seqlens_ptr = seqlens.data_ptr(); + int *seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { case 64: @@ -713,8 +713,8 @@ void paged_attention_v2_impl_launcher( #define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ paged_attention_v2_impl_launcher( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, seqlens, block_size, \ - max_seqlen, alibi_slopes); + num_kv_heads, scale, block_tables, seq_lens, block_size, \ + max_seq_len, alibi_slopes); #define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ switch (block_size) { \ @@ -732,8 +732,8 @@ void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor &value_cache, int num_kv_heads, float scale, torch::Tensor &block_tables, - torch::Tensor &seqlens, int block_size, - int max_seqlen, + torch::Tensor &seq_lens, int block_size, + int max_seq_len, const c10::optional &alibi_slopes, const std::string &kv_cache_dtype, float kv_scale) { TORCH_CHECK(kv_scale == 1.0f); diff --git a/csrc/ops.h b/csrc/ops.h index 8c8278f223fe..9541adcb3de8 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -10,9 +10,9 @@ void paged_attention_v1( int num_kv_heads, float scale, torch::Tensor& block_tables, - torch::Tensor& seqlens, + torch::Tensor& seq_lens, int block_size, - int max_seqlen, + int max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale); @@ -28,9 +28,9 @@ void paged_attention_v2( int num_kv_heads, float scale, torch::Tensor& block_tables, - torch::Tensor& seqlens, + torch::Tensor& seq_lens, int block_size, - int max_seqlen, + int max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale); diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 0bf4f1810deb..84539205e0ae 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -61,7 +61,7 @@ def ref_single_query_cached_kv_attention( key_cache: torch.Tensor, value_cache: torch.Tensor, block_tables: torch.Tensor, - seqlens: torch.Tensor, + seq_lens: torch.Tensor, scale: float, alibi_slopes: Optional[torch.Tensor], ) -> None: @@ -72,15 +72,15 @@ def ref_single_query_cached_kv_attention( num_seqs = query.shape[0] block_tables = block_tables.cpu().tolist() - seqlens = seqlens.cpu().tolist() + seq_lens = seq_lens.cpu().tolist() for i in range(num_seqs): q = query[i].unsqueeze(0) block_table = block_tables[i] - seqlen = int(seqlens[i]) + seq_len = int(seq_lens[i]) keys = [] values = [] - for j in range(seqlen): + for j in range(seq_len): block_number = int(block_table[j // block_size]) block_offset = j % block_size @@ -100,8 +100,8 @@ def ref_single_query_cached_kv_attention( alibi_bias = None if alibi_slopes is not None: # Create the ALiBi bias used in the paged attention kernel. 
- position_ids = torch.arange(seqlen).int() - alibi_bias = (position_ids - seqlen + 1).float() + position_ids = torch.arange(seq_len).int() + alibi_bias = (position_ids - seq_len + 1).float() alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( 1, 1, -1) @@ -149,13 +149,13 @@ def test_paged_attention( if use_alibi: alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) - seqlens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] - seqlens[-1] = MAX_SEQ_LEN - max_seqlen = max(seqlens) - seqlens = torch.tensor(seqlens, dtype=torch.int) + seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + seq_lens[-1] = MAX_SEQ_LEN + max_seq_len = max(seq_lens) + seq_lens = torch.tensor(seq_lens, dtype=torch.int) # Create the block tables. - max_num_blocks_per_seq = (max_seqlen + block_size - 1) // block_size + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = [] for _ in range(num_seqs): block_table = [ @@ -186,15 +186,15 @@ def test_paged_attention( num_kv_heads, scale, block_tables, - seqlens, + seq_lens, block_size, - max_seqlen, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, ) elif version == "v2": - num_partitions = ((max_seqlen + PARTITION_SIZE - 1) // PARTITION_SIZE) + num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape tmp_output = torch.empty( @@ -217,9 +217,9 @@ def test_paged_attention( num_kv_heads, scale, block_tables, - seqlens, + seq_lens, block_size, - max_seqlen, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, @@ -254,7 +254,7 @@ def test_paged_attention( key_cache, value_cache, block_tables, - seqlens, + seq_lens, scale, alibi_slopes, ) diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index da58eba3006f..492620cf6e2c 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -45,7 +45,7 @@ def __init__( gpu_memory_utilization: float = 0.9, swap_space: int = 4, enforce_eager: bool = False, - max_seqlen_to_capture: int = 8192, + max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, **kwargs, ) -> None: @@ -66,7 +66,7 @@ def __init__( gpu_memory_utilization=gpu_memory_utilization, swap_space=swap_space, enforce_eager=enforce_eager, - max_seqlen_to_capture=max_seqlen_to_capture, + max_seq_len_to_capture=max_seq_len_to_capture, engine_use_ray=True, disable_custom_all_reduce=disable_custom_all_reduce, **kwargs, diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 3c156c5790d1..e7975d0ef48b 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -23,14 +23,14 @@ def test_prepare_prompt(batch_size): lora_config=None) model_runner.set_block_size(16) - seqlens = [] + seq_lens = [] seq_group_metadata_list = [] block_tables = {0: [1]} for i in range(batch_size): # make sure all tokens fit into one block - seqlen = i % (model_runner.block_size - 1) + 1 - seqlens.append(seqlen) - seq_data = SequenceData(list(range(seqlen))) + seq_len = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seq_len) + seq_data = SequenceData(list(range(seq_len))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -43,29 +43,29 @@ def test_prepare_prompt(batch_size): expected_selected_token_indices = [] selected_token_start_idx = 0 - for seqlen in seqlens: + for seq_len in seq_lens: 
expected_selected_token_indices.append(selected_token_start_idx + - seqlen - 1) - selected_token_start_idx += seqlen - (input_tokens, input_positions, attn_metadata, return_seqlens, _, _, _, _, + seq_len - 1) + selected_token_start_idx += seq_len + (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _, _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) - assert return_seqlens == seqlens + assert return_seq_lens == seq_lens assert len(slot_mapping) == len(input_tokens) # Verify input metadata is correct for prompts. device = model_runner.device assert attn_metadata.is_prompt is True assert torch.allclose( - attn_metadata.seqlens_tensor, - torch.tensor(seqlens, device=device, dtype=torch.int)) - assert attn_metadata.seqlens == seqlens - assert attn_metadata.max_seqlen == max(seqlens) + attn_metadata.seq_lens_tensor, + torch.tensor(seq_lens, device=device, dtype=torch.int)) + assert attn_metadata.seq_lens == seq_lens + assert attn_metadata.max_seq_len == max(seq_lens) # Test subquery start locs. start_idx = 0 start_loc = [start_idx] - for seqlen in seqlens: - start_idx += seqlen + for seq_len in seq_lens: + start_idx += seq_len start_loc.append(start_idx) assert torch.allclose( attn_metadata.subquery_start_loc, @@ -75,8 +75,8 @@ def test_prepare_prompt(batch_size): # equivalent to subquery_start_loc. start_idx = 0 seq_start_loc = [start_idx] - for seqlen in seqlens: - start_idx += seqlen + for seq_len in seq_lens: + start_idx += seq_len seq_start_loc.append(start_idx) assert torch.allclose( @@ -95,18 +95,18 @@ def test_prepare_prompt(batch_size): # Cuda graph should not be used for prerill. assert attn_metadata.use_cuda_graph is False - assert len(input_tokens) == sum(seqlens) - assert len(input_positions) == sum(seqlens) + assert len(input_tokens) == sum(seq_lens) + assert len(input_positions) == sum(seq_lens) torch.testing.assert_close(input_tokens, input_positions) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - seqlens, - query_lens=seqlens, + seq_lens, + query_lens=seq_lens, device=model_runner.device, pin_memory=model_runner.pin_memory) - assert len(input_tokens) == sum(seqlens) - assert len(input_positions) == sum(seqlens) + assert len(input_tokens) == sum(seq_lens) + assert len(input_positions) == sum(seq_lens) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, device=actual.device, @@ -145,13 +145,13 @@ def test_prepare_decode_cuda_graph(batch_size): lora_config=None) model_runner.set_block_size(16) - seqlens = [] + seq_lens = [] seq_group_metadata_list = [] for i in range(batch_size): # make sure all tokens fit into one block - seqlen = i % (model_runner.block_size - 1) + 1 - seqlens.append(seqlen) - seq_data = list(range(seqlen)) + seq_len = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seq_len) + seq_data = list(range(seq_len)) seq_data = SequenceData(seq_data) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", @@ -171,13 +171,13 @@ def test_prepare_decode_cuda_graph(batch_size): # Verify input metadata is correct for prompts. 
device = model_runner.device assert attn_metadata.is_prompt is False - assert attn_metadata.seqlens is None + assert attn_metadata.seq_lens is None assert attn_metadata.subquery_start_loc is None assert attn_metadata.seq_start_loc is None - assert attn_metadata.max_seqlen == max(seqlens) + assert attn_metadata.max_seq_len == max(seq_lens) assert torch.allclose( - attn_metadata.seqlens_tensor[:len(seqlens)], - torch.tensor(seqlens, dtype=torch.int, device=device)) + attn_metadata.seq_lens_tensor[:len(seq_lens)], + torch.tensor(seq_lens, dtype=torch.int, device=device)) # block table's first index corresponds to each batch, meaning in # decoding it is each token. @@ -196,13 +196,13 @@ def test_prepare_decode_cuda_graph(batch_size): # Verify Sampling expected_selected_token_indices = [] selected_token_start_idx = 0 - for seqlen in seqlens: + for seq_len in seq_lens: expected_selected_token_indices.append(selected_token_start_idx) selected_token_start_idx += 1 sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - seqlens, - query_lens=seqlens, + seq_lens, + query_lens=seq_lens, device=model_runner.device, pin_memory=model_runner.pin_memory) actual = sampling_metadata.selected_token_indices @@ -239,13 +239,13 @@ def test_empty_seq_group(): assert attn_metadata is None assert len(slot_mapping) == 0 - (input_tokens, input_positions, attn_metadata, return_seqlens, _, _, _, _, + (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _, _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) assert len(input_tokens) == 0 assert len(input_positions) == 0 assert attn_metadata is None assert len(slot_mapping) == 0 - assert len(return_seqlens) == 0 + assert len(return_seq_lens) == 0 @pytest.fixture @@ -285,7 +285,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): model_runner.set_block_size(16) # Add prefill requests. - seqlens = [] + seq_lens = [] seq_group_metadata_list = [] prefill_metadata_list = [] decode_metadata_list = [] @@ -294,9 +294,9 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): decode_batch_size = batch_size - prefill_batch_size for i in range(prefill_batch_size): # make sure all tokens fit into one block - seqlen = i % (model_runner.block_size - 1) + 1 - seqlens.append(seqlen) - seq_data = SequenceData(list(range(seqlen))) + seq_len = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seq_len) + seq_data = SequenceData(list(range(seq_len))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -311,8 +311,8 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): # Add decode requests for i in range(prefill_batch_size, batch_size): # make sure all tokens fit into one block - seqlen = i % (model_runner.block_size - 1) + 1 - prompt_toks = list(range(seqlen)) + seq_len = i % (model_runner.block_size - 1) + 1 + prompt_toks = list(range(seq_len)) seq_data = SequenceData(prompt_toks) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", @@ -340,7 +340,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): else: assert attn_metadata.num_decode_tokens == _get_graph_batch_size( decode_batch_size) - assert attn_metadata.num_prefill_tokens == sum(seqlens) + assert attn_metadata.num_prefill_tokens == sum(seq_lens) # Verify attn metadata is consistent. We don't need to test individual # values here because they are tested above. 
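Note: the user-visible effect of this rename on the paged attention entry points is easier to read without diff markers. A rough sketch of the v1 call shape after this commit, with argument order and names as in the `vllm/_custom_ops.py` and benchmark hunks; the import path and tensor construction are illustrative assumptions, not code from the patch:

    # Sketch only: calling the renamed wrapper defined in vllm/_custom_ops.py.
    from vllm import _custom_ops as ops

    seq_lens = torch.tensor(seq_lens_list, dtype=torch.int, device=device)
    ops.paged_attention_v1(
        output, query, key_cache, value_cache,
        num_kv_heads, scale,
        block_tables,
        seq_lens,        # previously `seqlens`
        block_size,
        max_seq_len,     # previously `max_seqlen`
        alibi_slopes, kv_cache_dtype, kv_scale,
    )
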
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index f2bb13b95de1..c53d00fc9943 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -39,16 +39,16 @@ def paged_attention_v1( num_kv_heads: int, scale: float, block_tables: torch.Tensor, - seqlens: torch.Tensor, + seq_lens: torch.Tensor, block_size: int, - max_seqlen: int, + max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, kv_scale: float, ) -> None: vllm_ops.paged_attention_v1(out, query, key_cache, value_cache, - num_kv_heads, scale, block_tables, seqlens, - block_size, max_seqlen, alibi_slopes, + num_kv_heads, scale, block_tables, seq_lens, + block_size, max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale) @@ -63,16 +63,16 @@ def paged_attention_v2( num_kv_heads: int, scale: float, block_tables: torch.Tensor, - seqlens: torch.Tensor, + seq_lens: torch.Tensor, block_size: int, - max_seqlen: int, + max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, kv_scale: float, ) -> None: vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, - block_tables, seqlens, block_size, max_seqlen, + block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 1f89588d19e9..07d9713188f5 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -68,22 +68,22 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, is_prompt: bool # (batch_size,). The sequence length per sequence. Sequence length means # the computed tokens + new tokens None if it is a decoding. - seqlens: Optional[List[int]] - # seqlens stored as a tensor. - seqlens_tensor: Optional[torch.Tensor] + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] - # NOTE(sang): Definition of context_len, query_len, and seqlen. + # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| # |---------- context_len ----------| - # |-------------------- seqlen ----------------------| + # |-------------------- seq_len ----------------------| # |-- query_len ---| # Maximum query length in the batch. max_query_len: Optional[int] # Maximum sequence length in the batch. - max_seqlen: Optional[int] + max_seq_len: Optional[int] # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. 
@@ -221,10 +221,10 @@ def forward( q=query, k=key, v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_seqlen, - max_seqlen_k=prefill_meta.max_seqlen, + cu_seq_lens_q=prefill_meta.seq_start_loc, + cu_seq_lens_k=prefill_meta.seq_start_loc, + max_seq_len_q=prefill_meta.max_seq_len, + max_seq_len_k=prefill_meta.max_seq_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, @@ -245,7 +245,7 @@ def forward( value_cache, prefill_meta.block_tables, prefill_meta.subquery_start_loc, - prefill_meta.seqlens_tensor, + prefill_meta.seq_lens_tensor, prefill_meta.context_lens_tensor, prefill_meta.max_query_len, self.alibi_slopes, @@ -258,8 +258,8 @@ def forward( key_cache, value_cache, decode_meta.block_tables, - decode_meta.seqlens_tensor, - decode_meta.max_seqlen, + decode_meta.seq_lens_tensor, + decode_meta.max_seq_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index df18a7754aca..adcfe8d303de 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -66,22 +66,22 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, is_prompt: bool # (batch_size,). The sequence length per sequence. Sequence length means # the computed tokens + new tokens None if it is a decoding. - seqlens: Optional[List[int]] - # seqlens stored as a tensor. - seqlens_tensor: Optional[torch.Tensor] + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] - # NOTE(sang): Definition of context_len, query_len, and seqlen. + # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| # |---------- context_len ----------| - # |-------------------- seqlen ----------------------| + # |-------------------- seq_len ----------------------| # |-- query_len ---| # Maximum query length in the batch. max_query_len: Optional[int] # Maximum sequence length in the batch. - max_seqlen: Optional[int] + max_seq_len: Optional[int] # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. @@ -247,7 +247,7 @@ def forward( if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. 
- assert prefill_meta.seqlens is not None + assert prefill_meta.seq_lens is not None if kv_cache is None or prefill_meta.block_tables.numel() == 0: # triton attention # When block_tables are not filled, it means q and k are the @@ -260,8 +260,8 @@ def forward( None, prefill_meta.seq_start_loc, prefill_meta.seq_start_loc, - prefill_meta.max_seqlen, - prefill_meta.max_seqlen, + prefill_meta.max_seq_len, + prefill_meta.max_seq_len, True, self.scale, ) @@ -274,7 +274,7 @@ def forward( query, key, value, - prefill_meta.seqlens, + prefill_meta.seq_lens, self.scale, ) else: @@ -282,10 +282,10 @@ def forward( q=query, k=key, v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_seqlen, - max_seqlen_k=prefill_meta.max_seqlen, + cu_seq_lens_q=prefill_meta.seq_start_loc, + cu_seq_lens_k=prefill_meta.seq_start_loc, + max_seq_len_q=prefill_meta.max_seq_len, + max_seq_len_k=prefill_meta.max_seq_len, softmax_scale=self.scale, causal=True, ) @@ -303,7 +303,7 @@ def forward( value_cache, prefill_meta.block_tables, prefill_meta.subquery_start_loc, - prefill_meta.seqlens_tensor, + prefill_meta.seq_lens_tensor, prefill_meta.context_lens_tensor, prefill_meta.max_query_len, self.alibi_slopes, @@ -317,8 +317,8 @@ def forward( key_cache, value_cache, decode_meta.block_tables, - decode_meta.seqlens_tensor, - decode_meta.max_seqlen, + decode_meta.seq_lens_tensor, + decode_meta.max_seq_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -334,13 +334,13 @@ def _naive_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - seqlens: List[int], + seq_lens: List[int], scale: float, ) -> torch.Tensor: output = torch.empty_like(query) start = 0 - for _, seqlen in enumerate(seqlens): - end = start + seqlen + for _, seq_len in enumerate(seq_lens): + end = start + seq_len out = _naive_masked_attention( query[start:end], key[start:end], @@ -349,7 +349,7 @@ def _naive_attention( ) # TODO(woosuk): Unnecessary copy. Optimize. output[start:end].copy_(out) - start += seqlen + start += seq_len return output diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 04f2b2bd6ea0..2a1554f64b8e 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -58,7 +58,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata, # or all decoding. True if all sequences are prompts. is_prompt: bool slot_mapping: torch.Tensor - seqlens: Optional[List[int]] + seq_lens: Optional[List[int]] def __post_init__(self): # Set during the execution of the first attention op. 
@@ -136,7 +136,7 @@ def forward( kv_scale) if attn_metadata.is_prompt: - assert attn_metadata.seqlens is not None + assert attn_metadata.seq_lens is not None if (kv_cache is None or attn_metadata.block_tables.numel() == 0): if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) @@ -147,13 +147,13 @@ def forward( if self.alibi_slopes is not None: att_masks = _make_alibi_bias( self.alibi_slopes, query.dtype, - attn_metadata.seqlens) # type: ignore + attn_metadata.seq_lens) # type: ignore elif self.sliding_window is not None: att_masks = _make_sliding_window_bias( - attn_metadata.seqlens, self.sliding_window, + attn_metadata.seq_lens, self.sliding_window, query.dtype) # type: ignore else: - att_masks = [None] * len(attn_metadata.seqlens) + att_masks = [None] * len(attn_metadata.seq_lens) attn_metadata.attn_bias = att_masks query = query.movedim(0, query.dim() - 2) @@ -164,9 +164,9 @@ def forward( output = torch.empty( (num_tokens, self.num_heads, self.head_size), dtype=query.dtype) - for seqlen, mask in zip(attn_metadata.seqlens, + for seq_len, mask in zip(attn_metadata.seq_lens, attn_metadata.attn_bias): - end = start + seqlen + end = start + seq_len sub_out = scaled_dot_product_attention( query[:, start:end, :], key[:, start:end, :], @@ -189,8 +189,8 @@ def forward( key_cache, value_cache, attn_metadata.block_tables, - attn_metadata.seqlens_tensor, - attn_metadata.max_seqlen, + attn_metadata.seq_lens_tensor, + attn_metadata.max_seq_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -205,13 +205,13 @@ def forward( def _make_alibi_bias( alibi_slopes: torch.Tensor, dtype: torch.dtype, - seqlens: List[int], + seq_lens: List[int], ) -> List[torch.Tensor]: attn_biases = [] - for seqlen in seqlens: - bias = torch.arange(seqlen, dtype=dtype) + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(seqlen, 1)` + # `bias = bias[None, :].repeat(seq_len, 1)` # here. We find that both biases give the same results, but # the bias below more accurately follows the original ALiBi # paper. @@ -221,7 +221,7 @@ def _make_alibi_bias( bias = bias[None, :].repeat((num_heads, 1, 1)) bias.mul_(alibi_slopes[:, None, None]) inf_mask = torch.empty( - (1, seqlen, seqlen), + (1, seq_len, seq_len), dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) attn_biases.append((bias + inf_mask).to(dtype)) @@ -229,14 +229,14 @@ def _make_alibi_bias( def _make_sliding_window_bias( - seqlens: List[int], + seq_lens: List[int], window_size: Optional[int], dtype: torch.dtype, ) -> List[torch.Tensor]: attn_biases = [] - for seqlen in seqlens: + for seq_len in seq_lens: tensor = torch.full( - (1, seqlen, seqlen), + (1, seq_len, seq_len), dtype=dtype, fill_value=1, ) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 572f9069715c..60f6d43f2eaa 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -68,22 +68,22 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): is_prompt: bool # (batch_size,). The sequence length per sequence. Sequence length means # the computed tokens + new tokens None if it is a decoding. - seqlens: Optional[List[int]] - # seqlens stored as a tensor. - seqlens_tensor: Optional[torch.Tensor] + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. 
+ seq_lens_tensor: Optional[torch.Tensor] # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| # |---------- context_len ----------| - # |-------------------- seqlen ----------------------| + # |-------------------- seq_len ----------------------| # |-- query_len ---| # Maximum query length in the batch. max_query_len: Optional[int] # FIXME: It is for flash attn. # Maximum sequence length in the batch. - max_seqlen: Optional[int] + max_seq_len: Optional[int] # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. @@ -241,7 +241,7 @@ def forward( value_cache, prefill_meta.block_tables, prefill_meta.subquery_start_loc, - prefill_meta.seqlens_tensor, + prefill_meta.seq_lens_tensor, prefill_meta.context_lens_tensor, prefill_meta.max_query_len, self.alibi_slopes, @@ -256,8 +256,8 @@ def forward( key_cache, value_cache, decode_meta.block_tables, - decode_meta.seqlens_tensor, - decode_meta.max_seqlen, + decode_meta.seq_lens_tensor, + decode_meta.max_seq_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -288,7 +288,7 @@ def _run_memory_efficient_xformers_forward( value: shape = [num_prefill_tokens, num_kv_heads, head_size] attn_metadata: Metadata for attention. """ - assert attn_metadata.seqlens is not None + assert attn_metadata.seq_lens is not None original_query = query if self.num_kv_heads != self.num_heads: # GQA/MQA requires the shape [B, M, G, H, K]. @@ -309,7 +309,7 @@ def _run_memory_efficient_xformers_forward( if attn_metadata.attn_bias is None: if self.alibi_slopes is None: attn_bias = BlockDiagonalCausalMask.from_seqlens( - attn_metadata.seqlens) + attn_metadata.seq_lens) if self.sliding_window is not None: attn_bias = attn_bias.make_local_attention( self.sliding_window) @@ -317,7 +317,7 @@ def _run_memory_efficient_xformers_forward( else: attn_metadata.attn_bias = _make_alibi_bias( self.alibi_slopes, self.num_kv_heads, query.dtype, - attn_metadata.seqlens) + attn_metadata.seq_lens) # No alibi slopes. # TODO(woosuk): Too many view operations. Let's try to reduce @@ -342,8 +342,8 @@ def _run_memory_efficient_xformers_forward( # one. This is inefficient, especially when we have many short prompts. output = torch.empty_like(original_query) start = 0 - for i, seqlen in enumerate(attn_metadata.seqlens): - end = start + seqlen + for i, seq_len in enumerate(attn_metadata.seq_lens): + end = start + seq_len out = xops.memory_efficient_attention_forward( query[None, start:end], key[None, start:end], @@ -353,7 +353,7 @@ def _run_memory_efficient_xformers_forward( scale=self.scale) # TODO(woosuk): Unnecessary copy. Optimize. output[start:end].copy_(out.view_as(original_query[start:end])) - start += seqlen + start += seq_len return output @@ -361,13 +361,13 @@ def _make_alibi_bias( alibi_slopes: torch.Tensor, num_kv_heads: int, dtype: torch.dtype, - seqlens: List[int], + seq_lens: List[int], ) -> LowerTriangularMaskWithTensorBias: attn_biases = [] - for seqlen in seqlens: - bias = torch.arange(seqlen, dtype=dtype) + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(seqlen, 1)` + # `bias = bias[None, :].repeat(seq_len, 1)` # here. We find that both biases give the same results, but # the bias below more accurately follows the original ALiBi # paper. 
@@ -375,16 +375,16 @@ def _make_alibi_bias( # element. bias = bias[None, :] - bias[:, None] - padded_len = (seqlen + 7) // 8 * 8 + padded_len = (seq_len + 7) // 8 * 8 num_heads = alibi_slopes.shape[0] bias = torch.empty( 1, # batch size num_heads, - seqlen, + seq_len, padded_len, device=alibi_slopes.device, dtype=dtype, - )[:, :, :, :seqlen].copy_(bias) + )[:, :, :, :seq_len].copy_(bias) bias.mul_(alibi_slopes[:, None, None]) if num_heads != num_kv_heads: bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 578ac7f9a07a..00a0f10c0950 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -15,9 +15,9 @@ class PagedAttentionMetadata: """Metadata for PagedAttention.""" # (batch_size,). The length of sequences (entire tokens seen so far) per # sequence. - seqlens_tensor: Optional[torch.Tensor] + seq_lens_tensor: Optional[torch.Tensor] # Maximum sequence length in the batch. - max_seqlen: Optional[int] + max_seq_len: Optional[int] # (batch_size, max_blocks_per_seq). # Block addresses per sequence. (Seq id -> list of physical block) # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks @@ -84,8 +84,8 @@ def forward_decode( key_cache: torch.Tensor, value_cache: torch.Tensor, block_tables: torch.Tensor, - seqlens: torch.Tensor, - max_seqlen: int, + seq_lens: torch.Tensor, + max_seq_len: int, kv_cache_dtype: str, num_kv_heads: int, scale: float, @@ -96,7 +96,7 @@ def forward_decode( block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape - max_num_partitions = ((max_seqlen + _PARTITION_SIZE - 1) // + max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE) # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use @@ -105,7 +105,7 @@ def forward_decode( # to parallelize. # TODO(woosuk): Tune this heuristic. # For context len > 8192, use V2 kernel to avoid shared memory shortage. - use_v1 = (max_seqlen <= 8192 + use_v1 = (max_seq_len <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)) if use_v1: # Run PagedAttention V1. @@ -117,9 +117,9 @@ def forward_decode( num_kv_heads, scale, block_tables, - seqlens, + seq_lens, block_size, - max_seqlen, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, @@ -149,9 +149,9 @@ def forward_decode( num_kv_heads, scale, block_tables, - seqlens, + seq_lens, block_size, - max_seqlen, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, @@ -167,7 +167,7 @@ def forward_prefix( value_cache: torch.Tensor, block_tables: torch.Tensor, subquery_start_loc: torch.Tensor, - seqlens_tensor: torch.Tensor, + seq_lens_tensor: torch.Tensor, context_lens: torch.Tensor, max_query_len: int, alibi_slopes: Optional[torch.Tensor], @@ -184,7 +184,7 @@ def forward_prefix( block_tables, # subquery_start_loc is (batch_size + 1,) subquery_start_loc[:-1], - seqlens_tensor, + seq_lens_tensor, context_lens, max_query_len, alibi_slopes, diff --git a/vllm/config.py b/vllm/config.py index 463782aee0dd..4a505c787f2a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -63,8 +63,8 @@ class ModelConfig: If False, we will use CUDA graph and eager execution in hybrid. max_context_len_to_capture: Maximum context len covered by CUDA graphs. When a sequence has context length larger than this, we fall back - to eager mode (DEPRECATED. Use max_seqlen_to_capture instead). 
- max_seqlen_to_capture: Maximum sequence len covered by CUDA graphs. + to eager mode (DEPRECATED. Use max_seq_len_to_capture instead). + max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode skip_tokenizer_init: If true, skip initialization of tokenizer and @@ -87,7 +87,7 @@ def __init__( quantization_param_path: Optional[str] = None, enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, - max_seqlen_to_capture: Optional[int] = None, + max_seq_len_to_capture: Optional[int] = None, max_logprobs: int = 5, skip_tokenizer_init: bool = False, ) -> None: @@ -105,8 +105,8 @@ def __init__( self.max_context_len_to_capture = max_context_len_to_capture if self.max_context_len_to_capture is not None: raise ValueError("`max_context_len_to_capture` is deprecated. " - "Use `max_seqlen_to_capture` instead.") - self.max_seqlen_to_capture = (max_seqlen_to_capture + "Use `max_seq_len_to_capture` instead.") + self.max_seq_len_to_capture = (max_seq_len_to_capture or max_context_len_to_capture) self.max_logprobs = max_logprobs self.skip_tokenizer_init = skip_tokenizer_init @@ -199,9 +199,9 @@ def _verify_quantization(self) -> None: "non-quantized models.", self.quantization) def _verify_cuda_graph(self) -> None: - if self.max_seqlen_to_capture is None: - self.max_seqlen_to_capture = self.max_model_len - self.max_seqlen_to_capture = min(self.max_seqlen_to_capture, + if self.max_seq_len_to_capture is None: + self.max_seq_len_to_capture = self.max_model_len + self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, self.max_model_len) def verify_with_parallel_config( @@ -781,8 +781,8 @@ def maybe_create_spec_config( max_model_len=None, quantization=draft_quantization, enforce_eager=target_model_config.enforce_eager, - max_seqlen_to_capture=target_model_config. - max_seqlen_to_capture, + max_seq_len_to_capture=target_model_config. + max_seq_len_to_capture, max_logprobs=target_model_config.max_logprobs, ) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 4d8730421a57..1c8e1079bed5 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -45,7 +45,7 @@ class EngineArgs: quantization: Optional[str] = None enforce_eager: bool = False max_context_len_to_capture: Optional[int] = None - max_seqlen_to_capture: int = 8192 + max_seq_len_to_capture: int = 8192 disable_custom_all_reduce: bool = False tokenizer_pool_size: int = 0 tokenizer_pool_type: str = "ray" @@ -324,11 +324,11 @@ def add_cli_args( help='Maximum context length covered by CUDA ' 'graphs. When a sequence has context length ' 'larger than this, we fall back to eager mode. ' - '(DEPRECATED. Use --max-seqlen-to-capture instead' + '(DEPRECATED. Use --max-seq_len-to-capture instead' ')') - parser.add_argument('--max-seqlen-to-capture', + parser.add_argument('--max-seq_len-to-capture', type=int, - default=EngineArgs.max_seqlen_to_capture, + default=EngineArgs.max_seq_len_to_capture, help='Maximum sequence length covered by CUDA ' 'graphs. 
When a sequence has context length ' 'larger than this, we fall back to eager mode.') @@ -501,7 +501,7 @@ def create_engine_config(self, ) -> EngineConfig: self.code_revision, self.tokenizer_revision, self.max_model_len, self.quantization, self.quantization_param_path, self.enforce_eager, self.max_context_len_to_capture, - self.max_seqlen_to_capture, self.max_logprobs, + self.max_seq_len_to_capture, self.max_logprobs, self.skip_tokenizer_init) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index f3d5b98c87ab..3ed660e18336 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -70,8 +70,8 @@ class LLM: If False, we will use CUDA graph and eager execution in hybrid. max_context_len_to_capture: Maximum context len covered by CUDA graphs. When a sequence has context length larger than this, we fall back - to eager mode (DEPRECATED. Use `max_seqlen_to_capture` instead). - max_seqlen_to_capture: Maximum sequence len covered by CUDA graphs. + to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead). + max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode. disable_custom_all_reduce: See ParallelConfig @@ -94,7 +94,7 @@ def __init__( swap_space: int = 4, enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, - max_seqlen_to_capture: int = 8192, + max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, **kwargs, ) -> None: @@ -116,7 +116,7 @@ def __init__( swap_space=swap_space, enforce_eager=enforce_eager, max_context_len_to_capture=max_context_len_to_capture, - max_seqlen_to_capture=max_seqlen_to_capture, + max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, **kwargs, ) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 42bae6e78e6e..56ef76130af8 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -20,7 +20,7 @@ class SequenceGroupToSample: # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| # |---------- context_len ----------| - # |-------------------- seqlen ----------------------| + # |-------------------- seq_len ----------------------| # |-- query_len ---| # Sequence ids for the sequence group in a previous step. @@ -31,9 +31,9 @@ class SequenceGroupToSample: # The length of the sequence (all tokens seen in the past + new token to # compute attention) of the sequence group. None if it is in a decode # stage. - seqlen: Optional[int] + seq_len: Optional[int] # The length of new query tokens to compute in the current step. None if it - # is in a decode stage. The length of query_len <= seqlen if chunked prefill + # is in a decode stage. The length of query_len <= seq_len if chunked prefill # is enabled. query_len: Optional[int] # A random number generator for sampling. @@ -55,7 +55,7 @@ def __post_init__(self): if len(self.prompt_logprob_indices) > 0: assert self.sampling_params.prompt_logprobs is not None if self.is_prompt: - assert self.seqlen is not None + assert self.seq_len is not None assert self.query_len is not None @@ -198,7 +198,7 @@ def _prepare_seq_groups( is_prompt = seq_group_metadata.is_prompt generator: Optional[torch.Generator] = None # If the current seq group is in decode stage, it is None. 
- seqlen: Optional[int] = None + seq_len: Optional[int] = None query_len: Optional[int] = None prompt_logprob_indices: List[int] = [] sample_indices: List[int] = [] @@ -213,7 +213,7 @@ def _prepare_seq_groups( num_prefill_sample = len(seq_ids) assert num_prefill_sample == 1 assert query_lens is not None and seq_lens is not None - query_len, seqlen = query_lens[i], seq_lens[i] + query_len, seq_len = query_lens[i], seq_lens[i] # If we need sampling, exclude num_prefill_sample tokens from # prompt logprob. prompt_logprob_len = (query_len - num_prefill_sample @@ -276,7 +276,7 @@ def sample(logits): seq_ids=seq_ids, sampling_params=sampling_params, seq_data=seq_group_metadata.seq_data, - seqlen=seqlen, + seq_len=seq_len, query_len=query_len, generator=generator, is_prompt=is_prompt, diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index a1227d6b878d..b22dc9905c2b 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -80,7 +80,7 @@ def _prepare_prompt( input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] - seqlens: List[int] = [] + seq_lens: List[int] = [] multi_modal_input_list: List[torch.Tensor] = [] for seq_group_metadata in seq_group_metadata_list: @@ -92,15 +92,15 @@ def _prepare_prompt( seq_data = seq_group_metadata.seq_data[seq_id] prompt_tokens = seq_data.get_token_ids() computed_len = seq_data.get_num_computed_tokens() - seqlen = len(prompt_tokens) + seq_len = len(prompt_tokens) - seqlens.append(seqlen) # Prompt token num + seq_lens.append(seq_len) # Prompt token num input_tokens.extend(prompt_tokens) # Token ids # Token position ids # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. - input_positions.extend(list(range(computed_len, seqlen))) + input_positions.extend(list(range(computed_len, seq_len))) if seq_group_metadata.multi_modal_data: multi_modal_input_list.append( @@ -109,15 +109,15 @@ def _prepare_prompt( # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, - # where start_idx is max(0, seqlen - sliding_window). + # where start_idx is max(0, seq_len - sliding_window). # For example, if the prompt len is 10, sliding window is 8, and # block size is 4, the first two tokens are masked and the slot # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. 
start_idx = 0 if self.sliding_window is not None: - start_idx = max(0, seqlen - self.sliding_window) + start_idx = max(0, seq_len - self.sliding_window) - for i in range(computed_len, seqlen): + for i in range(computed_len, seq_len): if i < start_idx: slot_mapping.append(_PAD_SLOT_ID) continue @@ -151,10 +151,10 @@ def _prepare_prompt( attn_metadata = self.attn_backend.make_metadata( is_prompt=True, - seqlens=seqlens, - seqlens_tensor=None, - max_seqlen=None, - num_prefills=len(seqlens), + seq_lens=seq_lens, + seq_lens_tensor=None, + max_seq_len=None, + num_prefills=len(seq_lens), num_prefill_tokens=num_prompt_tokens, num_decode_tokens=0, prefill_metadata=None, @@ -163,7 +163,7 @@ def _prepare_prompt( slot_mapping=slot_mapping, kv_cache_dtype=self.kv_cache_dtype, ) - return (input_tokens, input_positions, attn_metadata, seqlens, + return (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_input) def _prepare_decode( @@ -174,7 +174,7 @@ def _prepare_decode( input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] - seqlens: List[int] = [] + seq_lens: List[int] = [] block_tables: List[List[int]] = [] for seq_group_metadata in seq_group_metadata_list: @@ -188,13 +188,13 @@ def _prepare_decode( generation_token = seq_data.get_last_token_id() input_tokens.append(generation_token) - seqlen = seq_data.get_len() - position = seqlen - 1 + seq_len = seq_data.get_len() + position = seq_len - 1 input_positions.append(position) - seqlen = seqlen if self.sliding_window is None else min( - seqlen, self.sliding_window) - seqlens.append(seqlen) + seq_len = seq_len if self.sliding_window is None else min( + seq_len, self.sliding_window) + seq_lens.append(seq_len) block_table = seq_group_metadata.block_tables[seq_id] block_number = block_table[position // self.block_size] @@ -208,7 +208,7 @@ def _prepare_decode( block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) - max_seqlen = max(seqlens) + max_seq_len = max(seq_lens) input_tokens = torch.tensor(input_tokens, dtype=torch.long, @@ -219,7 +219,7 @@ def _prepare_decode( slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) - seqlens_tensor = torch.tensor(seqlens, + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=self.device) @@ -236,9 +236,9 @@ def _prepare_decode( attn_metadata = self.attn_backend.make_metadata( is_prompt=False, slot_mapping=slot_mapping, - seqlens=seqlens, - seqlens_tensor=seqlens_tensor, - max_seqlen=max_seqlen, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_seq_len=max_seq_len, num_prefill_tokens=0, num_decode_tokens=len(input_tokens), num_prefills=0, @@ -265,20 +265,20 @@ def prepare_input_tensors( is_prompt = seq_group_metadata_list[0].is_prompt # Prepare input tensors. if is_prompt: - (input_tokens, input_positions, attn_metadata, seqlens, + (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_input ) = self._prepare_prompt(seq_group_metadata_list) else: (input_tokens, input_positions, attn_metadata) = self._prepare_decode(seq_group_metadata_list) - seqlens = [] + seq_lens = [] sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - seqlens, + seq_lens, # query_lens is not needed if chunked prefill is not # supported. Since CPU worker doesn't support chunked prefill - # just use seqlens instead. - seqlens, + # just use seq_lens instead. + seq_lens, self.device, pin_memory=False) # Broadcast the metadata. 
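# Illustrative bookkeeping behind the seq_lens / query_lens arguments above
# (numbers are made up). seq_len counts all tokens attended to in this step
# (context + new tokens), query_len counts only the new tokens being computed.
# Without chunked prefill a prompt is processed in one shot, so during prefill
# context_len == 0 and query_len == seq_len, which is why the CPU worker can
# simply pass seq_lens for both arguments of SamplingMetadata.prepare.
prompt_len = 10

# CPU worker / no chunked prefill: the whole prompt is the query.
context_len = 0
query_len = prompt_len - context_len
seq_len = context_len + query_len
assert (query_len, seq_len) == (10, 10)

# With chunked prefill (GPU path), a later chunk has already-computed context,
# so query_len < seq_len while seq_len <= prompt_len still holds.
context_len, chunk_size = 4, 4
query_len = chunk_size
seq_len = context_len + query_len
assert query_len <= seq_len <= prompt_len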
@@ -300,7 +300,7 @@ def prepare_input_tensors( sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, - seqlens=None, + seq_lens=None, selected_token_indices=selected_token_indices, categorized_sample_indices=None, generators=None, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 714ba946cf98..c9a0e2191b05 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -42,7 +42,7 @@ class PreparePromptMetadata(NamedTuple): input_tokens: List[int] input_positions: List[int] attn_metadata: Optional[AttentionMetadataPerStage] - seqlens: List[int] + seq_lens: List[int] query_lens: List[int] lora_index_mapping: List[int] lora_prompt_mapping: List[int] @@ -56,7 +56,7 @@ def empty(cls): input_tokens=[], input_positions=[], attn_metadata=None, - seqlens=[], + seq_lens=[], query_lens=[], lora_index_mapping=[], lora_prompt_mapping=[], @@ -134,7 +134,7 @@ def __init__( self.graph_memory_pool: Optional[Tuple[ int, int]] = None # Set during graph capture. - self.max_seqlen_to_capture = (self.model_config.max_seqlen_to_capture + self.max_seq_len_to_capture = (self.model_config.max_seq_len_to_capture if self.model_config is not None else 0) self.pin_memory = is_pin_memory_available() @@ -148,7 +148,7 @@ def __init__( self.model: torch.nn.Module # Set after load_model self.block_size: int # Set after initial profiling. # When using CUDA graph, the input block tables must be padded to - # max_seqlen_to_capture. However, creating the block table in + # max_seq_len_to_capture. However, creating the block table in # Python can be expensive. To optimize this, we cache the block table # in numpy and only copy the actual input content at every iteration. # The shape of the cached block table will be @@ -217,7 +217,7 @@ def set_block_size(self, block_size: int) -> None: def get_max_block_per_batch(self) -> int: block_size = self.block_size - return (self.max_seqlen_to_capture + block_size - 1) // block_size + return (self.max_seq_len_to_capture + block_size - 1) // block_size def _prepare_prompt( self, @@ -230,7 +230,7 @@ def _prepare_prompt( lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() - seqlens: List[int] = [] + seq_lens: List[int] = [] context_lens: List[int] = [] query_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] @@ -259,9 +259,9 @@ def _prepare_prompt( context_len = seq_data.get_num_computed_tokens() # We should use get_len here because in case of preemption # it contains output tokens. - seqlen = min(seq_data.get_len(), context_len + token_chunk_size) - prompt_tokens = seq_data.get_token_ids()[context_len:seqlen] - seqlens.append(seqlen) + seq_len = min(seq_data.get_len(), context_len + token_chunk_size) + prompt_tokens = seq_data.get_token_ids()[context_len:seq_len] + seq_lens.append(seq_len) # NOTE: This only works for oooooooxxx style attention. if computed_block_nums is not None and len( @@ -286,21 +286,21 @@ def _prepare_prompt( # actual prompt lens context_lens.append(context_len) - query_lens.append(seqlen - context_len) + query_lens.append(seq_len - context_len) input_tokens.extend(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. 
- input_positions.extend(list(range(context_len, seqlen))) + input_positions.extend(list(range(context_len, seq_len))) lora_id = seq_group_metadata.lora_int_id if lora_id > 0: lora_requests.add(seq_group_metadata.lora_request) - lora_index_mapping += [lora_id] * (seqlen - context_len) + lora_index_mapping += [lora_id] * (seq_len - context_len) lora_prompt_mapping.extend( [lora_id] * - (seqlen - context_len + (seq_len - context_len if seq_group_metadata.sampling_params.prompt_logprobs else 1)) if seq_group_metadata.multi_modal_data: @@ -310,13 +310,13 @@ def _prepare_prompt( if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized # yet. In this case, we just use a dummy slot mapping. - slot_mapping.extend([_PAD_SLOT_ID] * seqlen) + slot_mapping.extend([_PAD_SLOT_ID] * seq_len) continue # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, - # where start_idx is max(0, seqlen - sliding_window). + # where start_idx is max(0, seq_len - sliding_window). # For example, if the prompt len is 10, sliding window is 8, and # block size is 4, the first two tokens are masked and the slot # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. @@ -325,9 +325,9 @@ def _prepare_prompt( assert context_len == 0, ( "Prefix caching is currently not supported with " "sliding window attention") - start_idx = max(0, seqlen - self.sliding_window) + start_idx = max(0, seq_len - self.sliding_window) - for i in range(context_len, seqlen): + for i in range(context_len, seq_len): if i < start_idx: slot_mapping.append(_PAD_SLOT_ID) continue @@ -338,7 +338,7 @@ def _prepare_prompt( slot_mapping.append(slot) max_query_len = max(query_lens) - max_seqlen = max(seqlens) + max_seq_len = max(seq_lens) assert max_query_len > 0 context_lens_tensor = torch.tensor(context_lens, @@ -373,10 +373,10 @@ def _prepare_prompt( dtype=torch.int32, device=self.device) - seqlens_tensor = torch.tensor(seqlens, + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=self.device) - seq_start_loc = torch.zeros(seqlens_tensor.shape[0] + 1, + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=self.device) @@ -385,17 +385,17 @@ def _prepare_prompt( dtype=subquery_start_loc.dtype, out=subquery_start_loc[1:]) - torch.cumsum(seqlens_tensor, + torch.cumsum(seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) attn_metadata = self.attn_backend.make_metadata( is_prompt=True, - seqlens=seqlens, - seqlens_tensor=seqlens_tensor, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, - max_seqlen=max_seqlen, + max_seq_len=max_seq_len, subquery_start_loc=subquery_start_loc, seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, @@ -407,7 +407,7 @@ def _prepare_prompt( input_tokens=input_tokens, input_positions=input_positions, attn_metadata=attn_metadata, - seqlens=seqlens, + seq_lens=seq_lens, query_lens=query_lens, lora_index_mapping=lora_index_mapping, lora_prompt_mapping=lora_prompt_mapping, @@ -423,7 +423,7 @@ def _prepare_decode( input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] - seqlens: List[int] = [] + seq_lens: List[int] = [] block_tables: List[List[int]] = [] lora_index_mapping: List[int] = [] lora_prompt_mapping: List[int] = [] @@ -447,13 +447,13 @@ def _prepare_decode( generation_token = seq_data.get_last_token_id() input_tokens.append(generation_token) - 
seqlen = seq_data.get_len() - position = seqlen - 1 + seq_len = seq_data.get_len() + position = seq_len - 1 input_positions.append(position) - seqlen = seqlen if self.sliding_window is None else min( - seqlen, self.sliding_window) - seqlens.append(seqlen) + seq_len = seq_len if self.sliding_window is None else min( + seq_len, self.sliding_window) + seq_lens.append(seq_len) block_table = seq_group_metadata.block_tables[seq_id] block_number = block_table[position // self.block_size] @@ -473,10 +473,10 @@ def _prepare_decode( # See `capture_model` API for more details. # For decoding requests, batch_size == input_tokens. batch_size = len(input_tokens) - max_seqlen = max(seqlens) + max_seq_len = max(seq_lens) use_captured_graph = (not self.model_config.enforce_eager and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] - and max_seqlen <= self.max_seqlen_to_capture) + and max_seq_len <= self.max_seq_len_to_capture) if use_captured_graph: graph_batch_size = _get_graph_batch_size(batch_size) assert graph_batch_size >= batch_size @@ -484,21 +484,21 @@ def _prepare_decode( input_tokens.append(0) input_positions.append(0) slot_mapping.append(_PAD_SLOT_ID) - seqlens.append(1) + seq_lens.append(1) block_tables.append([]) lora_index_mapping.append(0) batch_size = graph_batch_size - seqlens_tensor = torch.tensor(seqlens, + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=self.device) if use_captured_graph: # When using cuda-graph all these tensors should be # padded. - assert seqlens_tensor.shape[0] == len(input_tokens) - assert seqlens_tensor.shape[0] == len(input_positions) - assert seqlens_tensor.shape[0] == len(slot_mapping) + assert seq_lens_tensor.shape[0] == len(input_tokens) + assert seq_lens_tensor.shape[0] == len(input_positions) + assert seq_lens_tensor.shape[0] == len(slot_mapping) # The shape of graph_block_tables is # [max batch size, max context len // block size]. 
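# Rough sketch of the decode-batch padding performed above when CUDA graphs
# are used. The rounding rule below (1, 2, 4, then multiples of 8) is an
# assumption for illustration, not a copy of _get_graph_batch_size.
def graph_batch_size(batch_size: int) -> int:
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return (batch_size + 7) // 8 * 8

input_tokens = [17, 42, 99]                    # 3 real decode requests
seq_lens = [20, 5, 33]
target = graph_batch_size(len(input_tokens))   # -> 4
while len(input_tokens) < target:
    input_tokens.append(0)                     # dummy token id
    seq_lens.append(1)                         # dummy sequence of length 1
assert len(input_tokens) == len(seq_lens) == target == 4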
@@ -520,10 +520,10 @@ def _prepare_decode( attn_metadata = self.attn_backend.make_metadata( is_prompt=False, - seqlens=None, - seqlens_tensor=seqlens_tensor, + seq_lens=None, + seq_lens_tensor=seq_lens_tensor, max_query_len=None, - max_seqlen=max_seqlen, + max_seq_len=max_seq_len, subquery_start_loc=None, seq_start_loc=None, context_lens_tensor=None, @@ -559,7 +559,7 @@ def prepare_input_tensors( input_tokens, input_positions, prefill_attn_metadata, - seqlens, + seq_lens, query_lens, lora_index_mapping, lora_prompt_mapping, @@ -577,13 +577,13 @@ def prepare_input_tensors( decode_slot_mapping, ) = self._prepare_decode(decode_reqs) sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, seqlens, query_lens, self.device, + seq_group_metadata_list, seq_lens, query_lens, self.device, self.pin_memory) if not self.scheduler_config.chunked_prefill_enabled: assert (len(prefill_reqs) and len(decode_reqs)) == 0 - num_prefills = len(seqlens) + num_prefills = len(seq_lens) num_prefill_tokens = len(input_tokens) num_decode_tokens = len(decode_input_tokens) @@ -795,10 +795,10 @@ def profile_run(self) -> None: int(max_num_batched_tokens / self.vision_language_config.image_feature_size)) for group_id in range(max_num_seqs): - seqlen = (max_num_batched_tokens // max_num_seqs + + seq_len = (max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs)) seq_data, fake_multi_modal_input = _prepare_fake_inputs( - seqlen, self.vision_language_config) + seq_len, self.vision_language_config) seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, @@ -880,7 +880,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() slot_mapping.fill_(_PAD_SLOT_ID) - seqlens = torch.ones(max_batch_size, dtype=torch.int32).cuda() + seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() graph_batch_size = _get_graph_batch_size( @@ -902,10 +902,10 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: # Create dummy attn_metadata. decode_metadata = self.attn_backend.make_metadata( is_prompt=False, - seqlens=None, - seqlens_tensor=seqlens[:batch_size], + seq_lens=None, + seq_lens_tensor=seq_lens[:batch_size], max_query_len=None, - max_seqlen=self.max_seqlen_to_capture, + max_seq_len=self.max_seq_len_to_capture, subquery_start_loc=None, seq_start_loc=None, context_lens_tensor=None, @@ -1018,7 +1018,7 @@ def capture( "positions": positions, "kv_caches": kv_caches, "slot_mapping": attn_metadata.slot_mapping, - "seqlens_tensor": attn_metadata.decode_metadata.seqlens_tensor, + "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, "block_tables": attn_metadata.decode_metadata.block_tables, } self.output_buffers = {"hidden_states": hidden_states} @@ -1040,8 +1040,8 @@ def forward( self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, non_blocking=True) - self.input_buffers["seqlens_tensor"].copy_( - attn_metadata.decode_metadata.seqlens_tensor, non_blocking=True) + self.input_buffers["seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) self.input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) # Run the graph. 
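# Generic capture/replay pattern behind the input_buffers copies above: a CUDA
# graph replays fixed kernel launches on fixed storage, so each step the new
# inputs are copied in place into the tensors that were captured. The toy
# "model" below is a stand-in; shapes and the doubling op are placeholders.
import torch

if torch.cuda.is_available():
    static_in = torch.zeros(8, 16, device="cuda")
    static_out = torch.zeros(8, 16, device="cuda")

    # Warm up once outside the graph, then capture the computation.
    static_out.copy_(static_in * 2.0)
    torch.cuda.synchronize()
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out.copy_(static_in * 2.0)

    # Per step: overwrite the captured input buffer, then replay the graph.
    new_batch = torch.randn(8, 16, device="cuda")
    static_in.copy_(new_batch, non_blocking=True)
    graph.replay()
    torch.cuda.synchronize()
    assert torch.allclose(static_out, new_batch * 2.0)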
@@ -1080,18 +1080,18 @@ def _get_graph_batch_size(batch_size: int) -> int: def _prepare_fake_inputs( - seqlen: int, vision_language_config: Optional[VisionLanguageConfig]): + seq_len: int, vision_language_config: Optional[VisionLanguageConfig]): """Prepare fake inputs for profile run.""" if vision_language_config: prompt_tokens = [ vision_language_config.image_token_id ] * vision_language_config.image_feature_size + [0] * ( - seqlen - vision_language_config.image_feature_size) + seq_len - vision_language_config.image_feature_size) fake_image_input = MultiModalData( type=MultiModalData.Type.IMAGE, data=torch.zeros(vision_language_config.image_input_shape, dtype=torch.float16)) else: - prompt_tokens = [0] * seqlen + prompt_tokens = [0] * seq_len fake_image_input = None return SequenceData(prompt_tokens), fake_image_input diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index e6fe82e1f710..a336be04e124 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -52,7 +52,7 @@ def _prepare_prompt( input_positions: List[List[int]] = [] input_block_ids: List[int] = [] - seqlens: List[int] = [] + seq_lens: List[int] = [] for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -61,26 +61,26 @@ def _prepare_prompt( seq_data = seq_group_metadata.seq_data[seq_id] prompt_tokens = seq_data.get_token_ids() - seqlen = len(prompt_tokens) - seqlens.append(seqlen) + seq_len = len(prompt_tokens) + seq_lens.append(seq_len) input_tokens.append(prompt_tokens) - input_positions.append(list(range(seqlen))) + input_positions.append(list(range(seq_len))) assert seq_group_metadata.block_tables is not None block_table = seq_group_metadata.block_tables[seq_id] assert len(block_table) == 1 input_block_ids.append(block_table[0]) - max_seqlen = max(seqlens) - assert max_seqlen > 0 + max_seq_len = max(seq_lens) + assert max_seq_len > 0 input_tokens = make_tensor_with_pad(input_tokens, - max_seqlen, + max_seq_len, pad=0, dtype=torch.long, device=self.device) input_positions = make_tensor_with_pad(input_positions, - max_seqlen, + max_seq_len, pad=0, dtype=torch.long, device=self.device) @@ -88,7 +88,7 @@ def _prepare_prompt( dtype=torch.long, device=self.device) - return input_tokens, input_positions, input_block_ids, seqlens + return input_tokens, input_positions, input_block_ids, seq_lens def _prepare_decode( self, @@ -110,10 +110,10 @@ def _prepare_decode( generation_token = seq_data.get_last_token_id() input_tokens.append([generation_token]) - seqlen = seq_data.get_len() - position = seqlen - 1 + seq_len = seq_data.get_len() + position = seq_len - 1 input_positions.append([position]) - context_lens.append(seqlen) + context_lens.append(seq_len) assert seq_group_metadata.block_tables is not None block_table = seq_group_metadata.block_tables[seq_id] @@ -149,18 +149,18 @@ def prepare_input_tensors( # Prepare input tensors. if is_prompt: (input_tokens, input_positions, input_block_ids, - seqlens) = self._prepare_prompt(seq_group_metadata_list) + seq_lens) = self._prepare_prompt(seq_group_metadata_list) else: (input_tokens, input_positions, input_block_ids) = self._prepare_decode(seq_group_metadata_list) - seqlens = [] + seq_lens = [] sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - seqlens, + seq_lens, # query_lens is not needed if chunked prefill is not # supported. Since neuron worker doesn't support chunked prefill - # just use seqlens instead. 
- seqlens, + # just use seq_lens instead. + seq_lens, self.device, self.pin_memory) From 476ed1dbd1172158426ed6cd1e94d1017f82cb05 Mon Sep 17 00:00:00 2001 From: sang Date: Fri, 3 May 2024 03:00:49 -0700 Subject: [PATCH 11/12] done --- vllm/attention/backends/flash_attn.py | 8 ++++---- vllm/attention/backends/rocm_flash_attn.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 07d9713188f5..fc7501ed5e91 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -221,10 +221,10 @@ def forward( q=query, k=key, v=value, - cu_seq_lens_q=prefill_meta.seq_start_loc, - cu_seq_lens_k=prefill_meta.seq_start_loc, - max_seq_len_q=prefill_meta.max_seq_len, - max_seq_len_k=prefill_meta.max_seq_len, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_seq_len, + max_seqlen_k=prefill_meta.max_seq_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index adcfe8d303de..c411b3971b8f 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -282,10 +282,10 @@ def forward( q=query, k=key, v=value, - cu_seq_lens_q=prefill_meta.seq_start_loc, - cu_seq_lens_k=prefill_meta.seq_start_loc, - max_seq_len_q=prefill_meta.max_seq_len, - max_seq_len_k=prefill_meta.max_seq_len, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_seq_len, + max_seqlen_k=prefill_meta.max_seq_len, softmax_scale=self.scale, causal=True, ) From 1e16379c2fe78ba314de66ae1f0773934927ff97 Mon Sep 17 00:00:00 2001 From: sang Date: Fri, 3 May 2024 03:45:13 -0700 Subject: [PATCH 12/12] lint --- vllm/_custom_ops.py | 5 +++-- vllm/attention/backends/torch_sdpa.py | 2 +- vllm/config.py | 4 ++-- vllm/model_executor/sampling_metadata.py | 4 ++-- vllm/worker/cpu_model_runner.py | 4 ++-- vllm/worker/model_runner.py | 12 ++++++------ 6 files changed, 16 insertions(+), 15 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index c53d00fc9943..b43f646fec88 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -72,8 +72,9 @@ def paged_attention_v2( ) -> None: vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, - block_tables, seq_lens, block_size, max_seq_len, - alibi_slopes, kv_cache_dtype, kv_scale) + block_tables, seq_lens, block_size, + max_seq_len, alibi_slopes, kv_cache_dtype, + kv_scale) # pos encoding ops diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 2a1554f64b8e..f75a279086a2 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -165,7 +165,7 @@ def forward( (num_tokens, self.num_heads, self.head_size), dtype=query.dtype) for seq_len, mask in zip(attn_metadata.seq_lens, - attn_metadata.attn_bias): + attn_metadata.attn_bias): end = start + seq_len sub_out = scaled_dot_product_attention( query[:, start:end, :], diff --git a/vllm/config.py b/vllm/config.py index 4a505c787f2a..3bdd3f774bc2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -107,7 +107,7 @@ def __init__( raise ValueError("`max_context_len_to_capture` is deprecated. 
" "Use `max_seq_len_to_capture` instead.") self.max_seq_len_to_capture = (max_seq_len_to_capture - or max_context_len_to_capture) + or max_context_len_to_capture) self.max_logprobs = max_logprobs self.skip_tokenizer_init = skip_tokenizer_init @@ -202,7 +202,7 @@ def _verify_cuda_graph(self) -> None: if self.max_seq_len_to_capture is None: self.max_seq_len_to_capture = self.max_model_len self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, - self.max_model_len) + self.max_model_len) def verify_with_parallel_config( self, diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 56ef76130af8..9969c45963e9 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -33,8 +33,8 @@ class SequenceGroupToSample: # stage. seq_len: Optional[int] # The length of new query tokens to compute in the current step. None if it - # is in a decode stage. The length of query_len <= seq_len if chunked prefill - # is enabled. + # is in a decode stage. The length of query_len <= seq_len if chunked + # prefill is enabled. query_len: Optional[int] # A random number generator for sampling. generator: Optional[torch.Generator] diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index b22dc9905c2b..193b021b7a11 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -220,8 +220,8 @@ def _prepare_decode( dtype=torch.long, device=self.device) seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=self.device) + dtype=torch.int, + device=self.device) max_block_table_len = max( len(block_table) for block_table in block_tables) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c9a0e2191b05..bbb1f5205af5 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -135,7 +135,7 @@ def __init__( int, int]] = None # Set during graph capture. self.max_seq_len_to_capture = (self.model_config.max_seq_len_to_capture - if self.model_config is not None else 0) + if self.model_config is not None else 0) self.pin_memory = is_pin_memory_available() self.kv_cache_dtype = kv_cache_dtype @@ -374,8 +374,8 @@ def _prepare_prompt( device=self.device) seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=self.device) + dtype=torch.int, + device=self.device) seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=self.device) @@ -490,8 +490,8 @@ def _prepare_decode( batch_size = graph_batch_size seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=self.device) + dtype=torch.int, + device=self.device) if use_captured_graph: # When using cuda-graph all these tensors should be @@ -796,7 +796,7 @@ def profile_run(self) -> None: self.vision_language_config.image_feature_size)) for group_id in range(max_num_seqs): seq_len = (max_num_batched_tokens // max_num_seqs + - (group_id < max_num_batched_tokens % max_num_seqs)) + (group_id < max_num_batched_tokens % max_num_seqs)) seq_data, fake_multi_modal_input = _prepare_fake_inputs( seq_len, self.vision_language_config) seq = SequenceGroupMetadata(