From 8b65bbd5883a3afe03aac01fb8307194a9155083 Mon Sep 17 00:00:00 2001 From: Andrew O'Neill Date: Tue, 19 Aug 2025 03:17:55 +0000 Subject: [PATCH 1/3] benchmark fa2 vs tree attention --- cmake/external_projects/vllm_flash_attn.cmake | 4 +- tests/v1/spec_decode/test_tree_attention.py | 182 ++++++++++++++---- vllm/attention/utils/fa_utils.py | 3 +- vllm/v1/attention/backends/flash_attn.py | 95 ++++++--- 4 files changed, 218 insertions(+), 66 deletions(-) diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index 49defccbb1fa..2683a6b41eb6 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -37,8 +37,8 @@ if(VLLM_FLASH_ATTN_SRC_DIR) else() FetchContent_Declare( vllm-flash-attn - GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 57b4e68b9f9d94750b46de8f8dbd2bfcc86edd4f + GIT_REPOSITORY https://github.com/samsung-cnct/flash-attention.git + GIT_TAG feaab457d8d58243f19bf234a42a498647de0e6f GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py index 456ce712d36e..276607df021c 100644 --- a/tests/v1/spec_decode/test_tree_attention.py +++ b/tests/v1/spec_decode/test_tree_attention.py @@ -11,6 +11,31 @@ get_attention_backend) from vllm.config import ParallelConfig, SpeculativeConfig from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.attention.backends.flash_attn import (FlashAttentionMetadataBuilder, TreeMetadata) + +import torch.utils.benchmark as benchmark + + +def benchmark_forward( + fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs +): + """Use Pytorch Benchmark on the forward pass of an arbitrary function.""" + if verbose: + print(desc, "- Forward pass") + + def amp_wrapper(*inputs, **kwinputs): + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + fn(*inputs, **kwinputs) + + t = benchmark.Timer( + stmt="fn_amp(*inputs, **kwinputs)", + globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + if verbose: + print(m) + return t, m class MockAttentionLayer(torch.nn.Module): @@ -35,7 +60,9 @@ def forward_attention( seqlen_k: int, backend: _Backend, spec_token_tree: Optional[str] = None, + tree_mask: Optional[torch.Tensor] = None, num_spec_tokens: int = 0, + bench_desc: str = "", ) -> torch.Tensor: batch_size, q_len, num_heads, dim_per_head = q.shape num_kv_heads = k.shape[-2] @@ -86,10 +113,18 @@ def forward_attention( ) # Build attention metadata. - attn_metadata = builder.build( - common_prefix_len=0, - common_attn_metadata=common_attn_metadata, - ) + if isinstance(builder, FlashAttentionMetadataBuilder): + fa_tree_metadata = TreeMetadata(mask=tree_mask.repeat(batch_size), lens=torch.arange(0, (batch_size+1)*num_spec_tokens, num_spec_tokens, dtype=torch.int32, device=q.device)) if tree_mask is not None else None + attn_metadata = builder.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + tree_metadata=fa_tree_metadata, + ) + else: + attn_metadata = builder.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) # Initialize the backend implementation. instance = impl_cls( @@ -107,6 +142,23 @@ def forward_attention( key = k.view(-1, num_kv_heads, dim_per_head) value = v.view(-1, num_kv_heads, dim_per_head) output = torch.empty_like(query) + + measurement=None + + if bench_desc != "": + _, measurement = benchmark_forward( + instance.forward, + layer=layer, + query=query, + key=key, + value=value, + kv_cache=kv_cache.clone(), + attn_metadata=attn_metadata, + output=output, + desc=bench_desc, + verbose=False, + ) + return instance.forward( layer=layer, query=query, @@ -115,7 +167,7 @@ def forward_attention( kv_cache=kv_cache.clone(), attn_metadata=attn_metadata, output=output, - ) + ), measurement def test_tree_attn_correctness() -> None: @@ -125,31 +177,56 @@ def test_tree_attn_correctness() -> None: device = "cuda" tree_attn_masks = { # Chain. - "[(0,), (0, 0), (0, 0, 0)]": - torch.tensor( - [ - [1, 0, 0, 0], - [1, 1, 0, 0], - [1, 1, 1, 0], - [1, 1, 1, 1], - ], - device=device, - dtype=torch.int32, + ("[(0,), (0, 0), (0, 0, 0)]", (-1, 0, 1)): + ( + torch.tensor( + [ + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1], + ], + device=device, + dtype=torch.int32, + ), + torch.tensor( + [ + 0b100, + 0b110, + 0b111, + ], + device=device, + dtype=torch.uint64, + ), ), # Tree. - "[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]": - torch.tensor( - [ - [1, 0, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0, 0], - [1, 0, 1, 0, 0, 0, 0], - [1, 1, 0, 1, 0, 0, 0], - [1, 1, 0, 0, 1, 0, 0], - [1, 0, 1, 0, 0, 1, 0], - [1, 0, 1, 0, 0, 0, 1], - ], - device=device, - dtype=torch.int32, + ("[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]", (-1, -1, 0, 0, 1, 1)): + ( + torch.tensor( + [ + [1, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0], + [1, 0, 1, 0, 0, 0, 0], + [1, 1, 0, 1, 0, 0, 0], + [1, 1, 0, 0, 1, 0, 0], + [1, 0, 1, 0, 0, 1, 0], + [1, 0, 1, 0, 0, 0, 1], + ], + device=device, + dtype=torch.int32, + ), + torch.tensor( + [ + 0b100000, + 0b010000, + 0b101000, + 0b100100, + 0b010010, + 0b010001, + ], + device=device, + dtype=torch.uint64 + ), ), } @@ -167,7 +244,7 @@ def test_tree_attn_correctness() -> None: assert num_heads % num_kv_heads == 0 # Initialize q, k, and v. - tree_size_q = tree_attn_mask.shape[0] + tree_size_q = tree_attn_mask[0].shape[0] seqlen_k = sequence_position + tree_size_q q = torch.randn( (batch_size, tree_size_q, num_heads, dim_per_head), @@ -231,7 +308,7 @@ def test_tree_attn_correctness() -> None: tree_positions, block_table, block_size) # Compute attention for the tree. - tree_attn_output = forward_attention( + forward_attn_output1, tree_attn_measurement = forward_attention( q=q, k=k, v=v, @@ -240,16 +317,49 @@ def test_tree_attn_correctness() -> None: slot_mapping=tree_slot_mapping, seqlen_k=seqlen_k, backend=_Backend.TREE_ATTN, - spec_token_tree=spec_token_tree, + spec_token_tree=spec_token_tree[0], num_spec_tokens=tree_size_q - 1, - ).view(batch_size, -1, num_heads, dim_per_head) + bench_desc="tree_attention", + ) + tree_attn_output = forward_attn_output1.view(batch_size, -1, num_heads, dim_per_head) + + tree_time = tree_attn_measurement.mean + print(f"tree_attention average time: {tree_time:.6f} seconds") + + forward_attn_output2, fa2_attn_measurement = forward_attention( + q=q, + k=k, + v=v, + kv_cache=kv_cache, + block_table=block_table, + slot_mapping=tree_slot_mapping, + seqlen_k=seqlen_k, + backend=_Backend.FLASH_ATTN_VLLM_V1, + tree_mask=tree_attn_mask[1], + num_spec_tokens=tree_size_q - 1, + bench_desc="fa2_tree_attention", + ) + fa2_tree_attn_output = forward_attn_output2.view(batch_size, -1, num_heads, dim_per_head) + + fa2_time = fa2_attn_measurement.mean + print(f"fa2_tree_attention average time: {fa2_time:.6f} seconds") + + # Calculate speedup + speedup = tree_time / fa2_time + print(f"Speedup (tree/fa2): {speedup:.2f}x") + if speedup > 1: + print(f"fa2_tree_attention is {speedup:.2f}x faster\n") + else: + print(f"tree_attention is {1/speedup:.2f}x faster\n") + + assert torch.allclose(tree_attn_output, fa2_tree_attn_output, atol=7.81e-3) # Verify that the chain attention output for each # branch of the tree (computed using FA3) matches # the tree attention output. for q_index in range(tree_size_q): # Get the q, k, and v for the branch. - branch_mask = tree_attn_mask[q_index, :] + branch_mask = tree_attn_mask[0][q_index, :] branch_indices = torch.nonzero(branch_mask, as_tuple=True)[0] q_len = branch_indices.shape[0] @@ -268,7 +378,7 @@ def test_tree_attn_correctness() -> None: branch_positions, block_table, block_size) # Compute flash attention for the branch. - flash_attn_output = forward_attention( + forward_attn_output3, _ = forward_attention( q=q_branch, k=k_branch, v=v_branch, @@ -277,8 +387,8 @@ def test_tree_attn_correctness() -> None: slot_mapping=branch_slot_mapping, seqlen_k=sequence_position + q_len, backend=_Backend.FLASH_ATTN_VLLM_V1, - ).view(batch_size, -1, num_heads, dim_per_head) - + ) + flash_attn_output = forward_attn_output3.view(batch_size, -1, num_heads, dim_per_head) # Compare the outputs. assert torch.allclose( tree_attn_output[:, branch_indices], diff --git a/vllm/attention/utils/fa_utils.py b/vllm/attention/utils/fa_utils.py index f8b00565f051..69673ca6c68f 100644 --- a/vllm/attention/utils/fa_utils.py +++ b/vllm/attention/utils/fa_utils.py @@ -12,7 +12,8 @@ from vllm import _custom_ops as ops reshape_and_cache_flash = ops.reshape_and_cache_flash from vllm.vllm_flash_attn import (flash_attn_varlen_func, - get_scheduler_metadata) + get_scheduler_metadata, + tree_attention) elif current_platform.is_xpu(): from vllm._ipex_ops import ipex_ops as ops reshape_and_cache_flash = ops.reshape_and_cache_flash diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index ab7a71a399b3..9acd0ca6b0ef 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -20,7 +20,8 @@ if is_flash_attn_varlen_func_available(): from vllm.attention.utils.fa_utils import (flash_attn_varlen_func, get_scheduler_metadata, - reshape_and_cache_flash) + reshape_and_cache_flash, + tree_attention) from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger @@ -108,6 +109,14 @@ def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype: raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") +@dataclass +class TreeMetadata: + # mask is a 1-D Tensor of uint64. Each uint64 represents a row in the causal mask starting from the end. + mask: torch.Tensor + # lens is a 1-D Tensor of the cumulative lengths of the masks in the batch. + lens: torch.Tensor + + @dataclass class FlashAttentionMetadata: # NOTE(sang): Definition of context_len, query_len, and seq_len. @@ -136,9 +145,11 @@ class FlashAttentionMetadata: # Optional aot scheduling scheduler_metadata: Optional[torch.Tensor] = None prefix_scheduler_metadata: Optional[torch.Tensor] = None - max_num_splits: int = 0 + max_num_splits: int = 0 causal: bool = True + # Optional tree attention + tree_metadata: Optional[TreeMetadata] = None def _get_sliding_window_configs( @@ -225,7 +236,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> FlashAttentionMetadata: + fast_build: bool = False, + tree_metadata: TreeMetadata = None) -> FlashAttentionMetadata: """ fast_build disables AOT scheduling, used when there will be few iterations i.e. spec-decode @@ -358,7 +370,9 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, suffix_kv_lens=suffix_kv_lens, prefix_scheduler_metadata=prefix_scheduler_metadata, max_num_splits=max_num_splits, - causal=causal) + causal=causal, + tree_metadata=tree_metadata, + ) return attn_metadata def use_cascade_attention(self, *args, **kwargs) -> bool: @@ -530,29 +544,56 @@ def forward( descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) - flash_attn_varlen_func( - q=query[:num_actual_tokens], - k=key_cache, - v=value_cache, - out=output[:num_actual_tokens], - cu_seqlens_q=cu_seqlens_q, - max_seqlen_q=max_seqlen_q, - seqused_k=seqused_k, - max_seqlen_k=max_seqlen_k, - softmax_scale=self.scale, - causal=attn_metadata.causal, - alibi_slopes=self.alibi_slopes, - window_size=self.sliding_window, - block_table=block_table, - softcap=self.logits_soft_cap, - scheduler_metadata=scheduler_metadata, - fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - num_splits=attn_metadata.max_num_splits, - s_aux=self.sinks, - ) + # print(f"$$$$ DEBUG: metadata:\n") + # print(f"\tnum_actual_tokens={attn_metadata.num_actual_tokens}, max_query_len={attn_metadata.max_query_len}, \ + # query_start_loc shape={attn_metadata.query_start_loc.shape}, seq_lens shape={attn_metadata.seq_lens.shape},\ + # block_table shape={attn_metadata.block_table.shape}, slot_mapping shape={attn_metadata.slot_mapping.shape} ") + if attn_metadata.tree_metadata: + tree_attention( + q=query[:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + tree_mask=attn_metadata.tree_metadata.mask, + tree_mask_lens=attn_metadata.tree_metadata.lens, + seqused_k=seqused_k, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + alibi_slopes=self.alibi_slopes, + block_table=block_table, + softcap=self.logits_soft_cap, + scheduler_metadata=scheduler_metadata, + fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + else: + flash_attn_varlen_func( + q=query[:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + seqused_k=seqused_k, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=attn_metadata.causal, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, + softcap=self.logits_soft_cap, + scheduler_metadata=scheduler_metadata, + fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + num_splits=attn_metadata.max_num_splits, + # s_aux=self.sinks, + ) return output # Cascade attention (rare case). From 7a2ebfe40d182eaedc7d8de485fc777ecb600c77 Mon Sep 17 00:00:00 2001 From: Andrew O'Neill Date: Tue, 19 Aug 2025 16:17:07 +0000 Subject: [PATCH 2/3] use benchmark from fa2 module --- tests/v1/spec_decode/test_tree_attention.py | 25 +-------------------- 1 file changed, 1 insertion(+), 24 deletions(-) diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py index 276607df021c..27f091bc75d2 100644 --- a/tests/v1/spec_decode/test_tree_attention.py +++ b/tests/v1/spec_decode/test_tree_attention.py @@ -12,30 +12,7 @@ from vllm.config import ParallelConfig, SpeculativeConfig from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.flash_attn import (FlashAttentionMetadataBuilder, TreeMetadata) - -import torch.utils.benchmark as benchmark - - -def benchmark_forward( - fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs -): - """Use Pytorch Benchmark on the forward pass of an arbitrary function.""" - if verbose: - print(desc, "- Forward pass") - - def amp_wrapper(*inputs, **kwinputs): - with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): - fn(*inputs, **kwinputs) - - t = benchmark.Timer( - stmt="fn_amp(*inputs, **kwinputs)", - globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs}, - num_threads=torch.get_num_threads(), - ) - m = t.timeit(repeats) - if verbose: - print(m) - return t, m +from vllm.vllm_flash_attn.utils.benchmark import benchmark_forward class MockAttentionLayer(torch.nn.Module): From 98f31c6b234ddae4e276656304acc61343f946aa Mon Sep 17 00:00:00 2001 From: Andrew O'Neill Date: Tue, 26 Aug 2025 21:19:52 +0000 Subject: [PATCH 3/3] add simple propose for e2e test --- tests/v1/spec_decode/test_eagle.py | 14 ++++++++----- vllm/v1/spec_decode/eagle.py | 32 +++++++++++++++++++++++------- 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 7b8445a0b287..141da0b71f1b 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -18,6 +18,7 @@ from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.platforms import current_platform from vllm.v1.spec_decode.eagle import EagleProposer +import time model_dir = "meta-llama/Llama-3.1-8B-Instruct" eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" @@ -237,7 +238,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): pytest.skip("TRITON_ATTN_VLLM_V1 does not support " "multi-token eagle spec decode on current platform") - if (attn_backend == "TREE_ATTN"): + if (attn_backend in ("TREE_ATTN", "FLASH_ATTN_VLLM_V1")): pytest.skip("TREE_ATTN is tested separately in test_propose_tree" "because it requires special input mocking.") @@ -394,7 +395,9 @@ def create_deterministic_logits(token_ids): # Verify all tokens match our expectations assert torch.equal(result, expected_tokens) - +@pytest.mark.parametrize( + "backend", [_Backend.TREE_ATTN, _Backend.FLASH_ATTN_VLLM_V1] +) @pytest.mark.parametrize( "spec_token_tree", [ @@ -404,7 +407,7 @@ def create_deterministic_logits(token_ids): [(0, ), (1, ), (2, ), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)], # Tree ]) -def test_propose_tree(spec_token_tree): +def test_propose_tree(spec_token_tree, backend): # Get GPU device. device = torch.device(current_platform.device_type) @@ -477,7 +480,7 @@ def create_deterministic_logits(token_ids, k: int): proposer.attn_layer_names = ["layer.0"] # Get the tree attention metadata builder. - attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TREE_ATTN) + attn_metadata_builder_cls, _ = get_attention_backend(backend) attn_metadata_builder = attn_metadata_builder_cls( kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), layer_names=proposer.attn_layer_names, @@ -517,6 +520,7 @@ def create_deterministic_logits(token_ids, k: int): sampling_metadata = mock.MagicMock() # Propose draft tokens. + start = time.perf_counter() result = proposer.propose(target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, @@ -524,7 +528,7 @@ def create_deterministic_logits(token_ids, k: int): common_attn_metadata=common_attn_metadata, sampling_metadata=sampling_metadata) assert result.shape == (batch_size, num_speculative_tokens) - + print(f"backend {backend} took {time.perf_counter()-start}") # The tokens are expected to be consecutive integers starting # from the base token IDs. expected_tokens = base_token_ids[:, None] + torch.arange( diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index a8a160a0f995..35e1ac953f09 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -19,7 +19,7 @@ from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available -from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.attention.backends.flash_attn import (FlashAttentionMetadata, FlashAttentionMetadataBuilder, TreeMetadata) from vllm.v1.attention.backends.rocm_aiter_fa import ( AiterFlashAttentionMetadata) from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata, @@ -203,7 +203,7 @@ def propose( positions = target_positions[last_token_indices] hidden_states = hidden_states[last_token_indices] - if isinstance(attn_metadata, TreeAttentionMetadata): + if isinstance(attn_metadata, (TreeAttentionMetadata, FlashAttentionMetadata)): # Draft using tree attention. draft_token_ids_list = self.propose_tree( batch_size=batch_size, @@ -339,7 +339,7 @@ def propose_tree( tree_attn_metadata_builder = \ self.runner.attn_groups[0][0].metadata_builder assert isinstance(tree_attn_metadata_builder, - TreeAttentionMetadataBuilder) + (TreeAttentionMetadataBuilder, FlashAttentionMetadataBuilder)) total_num_drafts = self.cu_drafts_per_level[0] level_num_drafts = total_num_drafts @@ -409,10 +409,28 @@ def propose_tree( num_actual_tokens=batch_size * query_len, max_query_len=query_len, ) - attn_metadata = tree_attn_metadata_builder.build_for_drafting( - common_attn_metadata=common_attn_metadata, - draft_index=level + 1, - ) + if isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder): + attn_metadata = tree_attn_metadata_builder.build_for_drafting( + common_attn_metadata=common_attn_metadata, + draft_index=level + 1, + ) + elif isinstance(tree_attn_metadata_builder, FlashAttentionMetadataBuilder): + attn_metadata = tree_attn_metadata_builder.build( + 0, + common_attn_metadata=common_attn_metadata, + tree_metadata=TreeMetadata( + mask=torch.tensor( + [1 << i for i in range(level_num_drafts)]*batch_size, + dtype=torch.uint64, + device=common_attn_metadata.query_start_loc.device, + ), + lens=torch.tensor( + [i * level_num_drafts for i in range(batch_size+1)], + dtype=torch.int32, + device=common_attn_metadata.query_start_loc.device, + ), + ), + ) # Apply new attention metadata to all layers. per_layer_attn_metadata = {}