99 AscendAttentionState ,
1010 AscendMetadata ,
1111 CommonAttentionState )
12+ from vllm_ascend .attention .utils import AscendCommonAttentionMetadata
1213
1314
1415class TestAscendAttentionBackend (TestBase ):
@@ -67,8 +68,12 @@ def test_copy_blocks(self):
6768class TestAscendAttentionMetadataBuilder (TestBase ):
6869
6970 def setUp (self ):
70- self .mock_runner = MagicMock ()
71- self .builder = AscendAttentionMetadataBuilder (self .mock_runner )
71+ self .mock_vllm_config = MagicMock ()
72+ self .mock_vllm_config .model_config .max_model_len = 640
73+ self .mock_vllm_config .cache_config .block_size = 64
74+ self .mock_device = 'cpu:0'
75+ self .builder = AscendAttentionMetadataBuilder (self .mock_vllm_config ,
76+ self .mock_device )
7277
7378 def test_reorder_batch (self ):
7479 mock_input_batch = MagicMock ()
@@ -86,31 +91,28 @@ def test_reorder_batch(self):
8691 def test_build_prefill_no_cache (self , mock_is_310p , mock_nd_to_nz_2d ,
8792 mock_npu_format_cast ,
8893 mock_ascend_metadata ):
89- num_reqs = 2
90- num_actual_tokens = 10
91- max_query_len = 5
92-
93- self . mock_runner . input_batch . block_table = [ MagicMock ()]
94- self . mock_runner . input_batch . block_table [
95- 0 ]. get_device_tensor . return_value = torch . zeros (( 10 , 10 ))
96- self . mock_runner . max_num_blocks_per_req = 10
97- self . mock_runner . query_lens = torch .tensor ([ 3 , 4 ])
98- self . mock_runner . seq_lens_cpu = torch .tensor ([ 5 , 6 ])
99- self . mock_runner . slot_mapping_cpu = torch .tensor (range ( 20 ))
100- self . mock_runner . device = 'cpu:0'
101- self . mock_runner . attn_mask = torch .ones ((10 , 10 ))
102- self . mock_runner . attn_state = AscendAttentionState . PrefillNoCache
103- self . mock_runner . query_start_loc_cpu = torch . tensor ([ 0 , 3 , 7 ] )
94+ common_attn_metadata = AscendCommonAttentionMetadata (
95+ query_start_loc = torch . tensor ([ 0 , 3 , 7 ]),
96+ query_start_loc_cpu = torch . tensor ([ 0 , 3 , 7 ]),
97+ seq_lens_cpu = torch . tensor ([ 5 , 6 ]),
98+ num_reqs = 2 ,
99+ num_actual_tokens = 10 ,
100+ max_query_len = 5 ,
101+ decode_token_per_req = torch . tensor ([ 1 , 1 ]),
102+ block_table_tensor = torch .zeros (( 10 , 10 )),
103+ slot_mapping_cpu = torch .tensor (range ( 20 )),
104+ actual_seq_lengths_q = torch .tensor ([ 0 , 1 ]),
105+ positions = torch . tensor ([ 10 , 10 ]),
106+ attn_mask = torch .ones ((10 , 10 )),
107+ spec_attn_mask = None ,
108+ attn_state = AscendAttentionState . PrefillNoCache )
104109
105110 mock_nz_tensor = MagicMock ()
111+ mock_model = MagicMock ()
106112 mock_nd_to_nz_2d .return_value = mock_nz_tensor
107113 mock_npu_format_cast .return_value = mock_nz_tensor
108114
109- self .builder .build (
110- num_reqs ,
111- num_actual_tokens ,
112- max_query_len ,
113- )
115+ self .builder .build (common_attn_metadata , mock_model )
114116
115117 @patch ('vllm_ascend.attention.attention_v1.AscendMetadata' )
116118 @patch ('torch_npu.npu_format_cast' )
@@ -120,51 +122,53 @@ def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
120122 def test_build_chunked_prefill (self , mock_ascend_attention_state ,
121123 mock_is_310p , mock_nd_to_nz_spec ,
122124 mock_npu_format_cast , mock_ascend_metadata ):
123- num_reqs = 3
124- num_actual_tokens = 15
125- max_query_len = 6
126-
127- self . mock_runner . input_batch . block_table = [ MagicMock ()]
128- self . mock_runner . input_batch . block_table [
129- 0 ]. get_device_tensor . return_value = torch . zeros (( 10 , 10 ))
130- self . mock_runner . max_num_blocks_per_req = 10
131- self . mock_runner . query_lens = torch .tensor ([ 2 , 3 , 4 ])
132- self . mock_runner . seq_lens_cpu = torch .tensor ([ 4 , 5 , 6 ])
133- self . mock_runner . slot_mapping_cpu = torch .tensor (range ( 20 ))
134- self . mock_runner . device = 'cpu:0'
135- self . mock_runner . attn_mask = torch .ones ((15 , 15 ))
136- self . mock_runner . attn_state = AscendAttentionState . ChunkedPrefill
137- self . mock_runner . query_start_loc_cpu = torch . tensor ([ 0 , 2 , 5 , 9 ] )
125+ common_attn_metadata = AscendCommonAttentionMetadata (
126+ query_start_loc = torch . tensor ([ 0 , 2 , 5 , 9 ]),
127+ query_start_loc_cpu = torch . tensor ([ 0 , 2 , 5 , 9 ]),
128+ seq_lens_cpu = torch . tensor ([ 4 , 5 , 6 ]),
129+ num_reqs = 3 ,
130+ num_actual_tokens = 15 ,
131+ max_query_len = 6 ,
132+ decode_token_per_req = torch . tensor ([ 1 , 1 , 1 ]),
133+ block_table_tensor = torch .zeros (( 10 , 10 )),
134+ slot_mapping_cpu = torch .tensor (range ( 20 )),
135+ actual_seq_lengths_q = torch .tensor ([ 0 , 1 , 2 ]),
136+ positions = torch . tensor ([ 10 , 10 ]),
137+ attn_mask = torch .ones ((15 , 15 )),
138+ spec_attn_mask = None ,
139+ attn_state = AscendAttentionState . ChunkedPrefill )
138140
139141 mock_ascend_attention_state = MagicMock ()
140142 mock_ascend_attention_state .PrefillNoCache = 0
141143
142144 mock_nz_tensor = MagicMock ()
145+ mock_model = MagicMock ()
143146 mock_nd_to_nz_spec .return_value = mock_nz_tensor
144147 mock_npu_format_cast .return_value = mock_nz_tensor
145148
146- self .builder .build (num_reqs , num_actual_tokens , max_query_len )
149+ self .builder .build (common_attn_metadata , mock_model )
147150
148151 @patch ('vllm_ascend.attention.attention_v1.AscendMetadata' )
149152 @patch ('vllm_ascend.attention.attention_v1.is_310p' , return_value = False )
150153 def test_build_non_310p (self , mock_is_310p , mock_ascend_metadata ):
151- num_reqs = 3
152- num_actual_tokens = 15
153- max_query_len = 6
154-
155- self .mock_runner .input_batch .block_table = [MagicMock ()]
156- self .mock_runner .input_batch .block_table [
157- 0 ].get_device_tensor .return_value = torch .zeros ((10 , 10 ))
158- self .mock_runner .max_num_blocks_per_req = 10
159- self .mock_runner .query_lens = torch .tensor ([2 , 3 , 4 ])
160- self .mock_runner .seq_lens_cpu = torch .tensor ([4 , 5 , 6 ])
161- self .mock_runner .slot_mapping_cpu = torch .tensor (range (20 ))
162- self .mock_runner .device = 'cpu:0'
163- self .mock_runner .attn_mask = torch .ones ((15 , 15 ))
164- self .mock_runner .attn_state = AscendAttentionState .ChunkedPrefill
165- self .mock_runner .query_start_loc_cpu = torch .tensor ([0 , 2 , 5 , 9 ])
166-
167- self .builder .build (num_reqs , num_actual_tokens , max_query_len )
154+ common_attn_metadata = AscendCommonAttentionMetadata (
155+ query_start_loc = torch .tensor ([0 , 2 , 5 , 9 ]),
156+ query_start_loc_cpu = torch .tensor ([0 , 2 , 5 , 9 ]),
157+ seq_lens_cpu = torch .tensor ([4 , 5 , 6 ]),
158+ num_reqs = 3 ,
159+ num_actual_tokens = 15 ,
160+ max_query_len = 6 ,
161+ decode_token_per_req = torch .tensor ([1 , 1 , 1 ]),
162+ block_table_tensor = torch .zeros ((10 , 10 )),
163+ slot_mapping_cpu = torch .tensor (range (20 )),
164+ actual_seq_lengths_q = torch .tensor ([0 , 1 , 2 ]),
165+ positions = torch .tensor ([10 , 10 ]),
166+ attn_mask = torch .ones ((15 , 15 )),
167+ spec_attn_mask = None ,
168+ attn_state = AscendAttentionState .ChunkedPrefill )
169+ mock_model = MagicMock ()
170+
171+ self .builder .build (common_attn_metadata , mock_model )
168172
169173
170174class TestAscendAttentionBackendImpl (TestBase ):
0 commit comments