
Commit 3a58db8

fix _pad_attention_dp_dummy_request (#5583)
Signed-off-by: junq <[email protected]>
1 parent 7524c77 commit 3a58db8

File tree

4 files changed: +59 −44 lines changed

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 34 additions & 27 deletions

@@ -1197,16 +1197,16 @@ def _prepare_tp_inputs(
         new_tokens_lens_device = new_tensors_device.new_tokens_lens  # [batch]
         next_draft_tokens_device = new_tensors_device.next_draft_tokens  # [batch, draft_len]
 
-        # Requests with draft tokens are treated like extend requests. Dummy extend requests should be
-        # at the end of extend_requests.
+        # Requests with draft tokens are treated like extend requests. CUDA graph dummy extend
+        # requests should be at the end of extend_requests.
         extend_requests = []
-        extend_dummy_requests = []
+        extend_cuda_graph_dummy_requests = []
         generation_requests = []
         for request in scheduled_requests.generation_requests:
             if len(request.py_draft_tokens
                    ) > 0 or next_draft_tokens_device is not None:
-                if request.is_dummy:
-                    extend_dummy_requests.append(request)
+                if request.is_cuda_graph_dummy:
+                    extend_cuda_graph_dummy_requests.append(request)
                 else:
                     extend_requests.append(request)
             else:
@@ -1219,8 +1219,8 @@ def _prepare_tp_inputs(
                         pin_memory=True)
                     mrope_config['mrope_position_deltas'].append(
                         mrope_position_deltas.to('cuda', non_blocking=True))
-        extend_requests += extend_dummy_requests
 
+        extend_requests = extend_cuda_graph_dummy_requests + extend_requests
         if not self._disable_overlap_scheduler and self.is_spec_decode:
             spec_dec_mode = self.spec_config.spec_dec_mode
             assert spec_dec_mode.support_overlap_scheduler(
@@ -1229,18 +1229,18 @@ def _prepare_tp_inputs(
         # will contain previous batch incices of generation requests
         previous_batch_indices = []
         previous_pos_indices = []
+        request_ids_with_previous_batch = []
+        num_extend_reqs_wo_previous_batch = 0
         for request in extend_requests:
             # the request has no previous tensor:
             # (1) next_draft_tokens_device is None, which means overlap scheduler is disabled; or
             # (2) a dummy request; or
             # (3) the first step in the generation server of disaggregated serving
             if next_draft_tokens_device is None or request.is_dummy or request.py_batch_idx is None:
-                # get token ids, including input token ids and draft token ids. For these dummy requests,
-                # no need to copy the token ids.
-                if not request.is_dummy:
-                    input_ids.append(request.get_last_tokens(0))
-                    input_ids.extend(request.py_draft_tokens)
-                    draft_tokens.extend(request.py_draft_tokens)
+                # get token ids, including input token ids and draft token ids
+                input_ids.append(request.get_last_tokens(0))
+                input_ids.extend(request.py_draft_tokens)
+                draft_tokens.extend(request.py_draft_tokens)
                 # get other ids and lengths
                 num_draft_tokens = len(request.py_draft_tokens)
                 past_seen_token_num = request.max_beam_num_tokens - 1
@@ -1268,6 +1268,7 @@ def _prepare_tp_inputs(
                 # update batch index
                 request.py_batch_idx = batch_idx
                 batch_idx += 1
+                num_extend_reqs_wo_previous_batch += 1
             else:
                 # update batch index
                 previous_batch_idx = request.py_batch_idx
@@ -1294,7 +1295,10 @@ def _prepare_tp_inputs(
                 num_cached_tokens_per_seq.append(past_seen_token_num +
                                                  self.max_draft_len + 1)
                 prompt_lengths.append(request.py_prompt_len)
-                request_ids.append(request.py_request_id)
+                request_ids_with_previous_batch.append(request.py_request_id)
+
+        # move requests with previous batch to the end of the list
+        request_ids.extend(request_ids_with_previous_batch)
 
         sequence_lengths.extend([1] * len(generation_requests))
         gather_ids.extend(
@@ -1329,7 +1333,6 @@ def _prepare_tp_inputs(
         num_tokens = len(input_ids)
         num_draft_tokens = len(draft_tokens)
         previous_batchs = len(previous_batch_indices)
-        num_requests = len(request_ids)
         total_num_tokens = len(position_ids)
         assert total_num_tokens <= self.max_num_tokens, (
             "total_num_tokens should be less than or equal to max_num_tokens")
@@ -1371,27 +1374,31 @@ def _prepare_tp_inputs(
                 non_blocking=True)
             # prepare data for the preprocess inputs
             kv_len_offsets_device = new_tokens_lens_device - self.max_draft_len - 1
+            pre_tokens_start_idx = num_extend_reqs_wo_previous_batch * (
+                1 + self.max_draft_len)
+            pre_tokens_end_idx = pre_tokens_start_idx + previous_batch_tokens
+            pre_batch_start_idx = num_extend_reqs_wo_previous_batch
+            pre_batch_end_idx = pre_batch_start_idx + previous_batchs
             previous_pos_indices = torch.tensor(previous_pos_indices,
                                                 dtype=torch.int,
                                                 pin_memory=True)
-            self.previous_pos_indices_cuda[0:previous_batch_tokens].copy_(
-                previous_pos_indices, non_blocking=True)
+            self.previous_pos_indices_cuda[
+                pre_tokens_start_idx:pre_tokens_end_idx].copy_(
+                    previous_pos_indices, non_blocking=True)
             self.previous_pos_id_offsets_cuda[
-                0:previous_batch_tokens].copy_(
+                pre_tokens_start_idx:pre_tokens_end_idx].copy_(
                     new_tokens_lens_device[self.previous_pos_indices_cuda[
-                        0:previous_batch_tokens]],
+                        pre_tokens_start_idx:pre_tokens_end_idx]],
+                    non_blocking=True)
+            self.previous_kv_lens_offsets_cuda[
+                pre_batch_start_idx:pre_batch_end_idx].copy_(
+                    kv_len_offsets_device[
+                        self.previous_batch_indices_cuda[:previous_batchs]],
                     non_blocking=True)
-            self.previous_kv_lens_offsets_cuda[0:previous_batchs].copy_(
-                kv_len_offsets_device[
-                    self.previous_batch_indices_cuda[:previous_batchs]],
-                non_blocking=True)
             # for the requests that do not have previous batch, set the previous_pos_id_offsets and
             # previous_kv_lens_offsets to zeros to skip the value changes in _preprocess_inputs
-            self.previous_pos_id_offsets_cuda[
-                previous_batch_tokens:num_requests *
-                (1 + self.max_draft_len)] *= 0
-            self.previous_kv_lens_offsets_cuda[
-                previous_batchs:num_requests] *= 0
+            self.previous_pos_id_offsets_cuda[:pre_tokens_start_idx] *= 0
+            self.previous_kv_lens_offsets_cuda[:pre_batch_start_idx] *= 0
         else:
             # change the data to zeros to skip the value changes in _preprocess_inputs
             self.previous_pos_id_offsets_cuda *= 0
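Taken together, these hunks count the extend requests that have no previous batch (CUDA graph dummies included) in num_extend_reqs_wo_previous_batch; their token slots occupy the front of the preallocated CUDA buffers, so previous-batch data is copied starting at an offset and only the leading slots are zeroed. A minimal, illustrative sketch of that index arithmetic with made-up counts (toy values only, not TensorRT-LLM code):

# Illustrative sketch only: how the new pre_* offsets partition the buffers.
max_draft_len = 4                         # assumed draft length
num_extend_reqs_wo_previous_batch = 2     # requests without a previous batch, counted first
previous_batchs = 3                       # extend requests that do have a previous batch
previous_batch_tokens = previous_batchs * (1 + max_draft_len)

# Previous-batch data starts right after the leading requests' token slots.
pre_tokens_start_idx = num_extend_reqs_wo_previous_batch * (1 + max_draft_len)  # 10
pre_tokens_end_idx = pre_tokens_start_idx + previous_batch_tokens                # 25
pre_batch_start_idx = num_extend_reqs_wo_previous_batch                          # 2
pre_batch_end_idx = pre_batch_start_idx + previous_batchs                        # 5

# Only the leading slots (no previous batch) are zeroed, e.g.:
#   self.previous_pos_id_offsets_cuda[:pre_tokens_start_idx] *= 0
#   self.previous_kv_lens_offsets_cuda[:pre_batch_start_idx] *= 0
print(pre_tokens_start_idx, pre_tokens_end_idx, pre_batch_start_idx, pre_batch_end_idx)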

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 17 additions & 16 deletions

@@ -1492,7 +1492,7 @@ def _check_disagg_gen_transfer_status(self):
     @nvtx_range("_pad_attention_dp_dummy_request")
     def _pad_attention_dp_dummy_request(self):
         """
-        Pad with a dummy request, if required, to ensure every attention_dp rank has at least one active request.
+        Pad with dummy requests, if required, to avoid empty attention_dp rank.
         """
         if not self.enable_attention_dp:
             return
@@ -1506,20 +1506,22 @@ def _pad_attention_dp_dummy_request(self):
             or req.is_disagg_generation_transmission_in_progress else 1
             for req in self.active_requests
         ])
-
-        if self.expected_num_active_requests - num_active_request > 0 and num_active_request == 0:
-            llm_request = self.kv_cache_manager.add_dummy_requests(
-                request_ids=[0],
+        num_dummy_request = self.expected_num_active_requests - num_active_request
+        if num_dummy_request > 0:
+            llm_request_list = self.kv_cache_manager.add_dummy_requests(
+                request_ids=list(range(num_dummy_request)),
                 is_gen=not self.has_context_request,
                 prepare_resource=not self.has_context_request,
                 max_num_draft_tokens=self.max_draft_tokens,
-            )[0]
-            llm_request.is_attention_dp_dummy = True
+            )
+            for llm_request in llm_request_list:
+                llm_request.is_attention_dp_dummy = True
             spec_resource_manager = self.resource_manager.get_resource_manager(
                 ResourceManagerType.SPEC_RESOURCE_MANAGER)
             if spec_resource_manager is not None:
-                spec_resource_manager.add_dummy_requests([0])
-            self.active_requests.append(llm_request)
+                spec_resource_manager.add_dummy_requests(
+                    list(range(num_dummy_request)))
+            self.active_requests += llm_request_list
 
     @nvtx_range("_prepare_disagg_gen_init")
     def _prepare_disagg_gen_init(self, fitting_disagg_gen_init_requests):
@@ -1643,13 +1645,12 @@ def forward(scheduled_requests, resource_manager, new_tensors_device,
 
     def _update_request_states_tp(self, scheduled_requests: ScheduledRequests):
         # handle potential attention dp dummy request
-        if self.active_requests and self.active_requests[
-                -1].is_attention_dp_dummy:
-            request = self.active_requests[-1]
-            request.state = LlmRequestState.GENERATION_COMPLETE
-            self.inflight_req_ids.erase(request.py_request_id)
-            self._terminate_request(request)
-            self.active_requests.remove(request)
+        for request in self.active_requests[:]:
+            if request.is_attention_dp_dummy:
+                request.state = LlmRequestState.GENERATION_COMPLETE
+                self.inflight_req_ids.erase(request.py_request_id)
+                self._terminate_request(request)
+                self.active_requests.remove(request)
 
         for request in scheduled_requests.context_requests:
             if request.state != LlmRequestState.GENERATION_COMPLETE:  # skip failed requests
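For context, a standalone sketch of the new padding behavior: each rank is padded with one dummy per missing slot up to expected_num_active_requests, and cleanup later removes every attention-DP dummy rather than only a trailing one. FakeRequest, pad_rank, and drop_dummies below are hypothetical, simplified stand-ins, not the PyExecutor API:

# Toy stand-ins; the real code uses LlmRequest and kv_cache_manager.add_dummy_requests,
# and counts active requests more carefully than a plain len().
from dataclasses import dataclass


@dataclass
class FakeRequest:
    py_request_id: int
    is_attention_dp_dummy: bool = False


def pad_rank(active_requests, expected_num_active_requests):
    # Pad with one dummy per missing slot (previously only an empty rank got a single dummy).
    num_dummy_request = expected_num_active_requests - len(active_requests)
    if num_dummy_request > 0:
        active_requests += [
            FakeRequest(i, is_attention_dp_dummy=True)
            for i in range(num_dummy_request)
        ]
    return active_requests


def drop_dummies(active_requests):
    # Mirrors the reworked _update_request_states_tp: remove every dummy, not just the last one.
    return [r for r in active_requests if not r.is_attention_dp_dummy]


reqs = pad_rank([FakeRequest(100)], expected_num_active_requests=3)
assert sum(r.is_attention_dp_dummy for r in reqs) == 2
assert [r.py_request_id for r in drop_dummies(reqs)] == [100]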

tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 6 additions & 0 deletions

@@ -267,6 +267,12 @@ def update_requests(self, state: SampleStateMTP) -> None:
             request.py_decoding_iter += 1
             idx += 1
 
+        # skip the results of cuda graph dummy requests
+        if idx == 0:
+            num_cuda_graph_dummy_requests = len(new_tokens_list) - len(
+                state.scheduled_requests.generation_requests)
+            idx += num_cuda_graph_dummy_requests
+
         for request in state.scheduled_requests.generation_requests:
             assert not request.py_return_context_logits, "return_context_logits not implemented for MTPSampler"
             assert not request.py_return_generation_logits, "return_generation_logits not implemented for MTPSampler"
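A toy illustration of the added skip (assumed shapes and values, not the MTPSampler code): with CUDA graph padding, new_tokens_list can hold leading entries produced by dummy requests that have no matching scheduled generation request, so idx is advanced past them before the real requests are updated:

# Hypothetical values for illustration only.
new_tokens_list = [[0], [0], [11, 12], [21, 22]]   # 2 dummy entries + 2 real ones
generation_requests = ["req0", "req1"]

idx = 0  # assumed: no earlier requests consumed any entries
if idx == 0:
    num_cuda_graph_dummy_requests = len(new_tokens_list) - len(generation_requests)
    idx += num_cuda_graph_dummy_requests   # skip the dummy results

for request, new_tokens in zip(generation_requests, new_tokens_list[idx:]):
    print(request, new_tokens)             # req0 [11, 12] / req1 [21, 22]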

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 2 additions & 1 deletion

@@ -646,7 +646,8 @@ def test_fp8_block_scales(self, mtp_nextn, fp8kv, attention_dp, cuda_graph,
     @pytest.mark.skip_device_not_contain(["H100"])
     @parametrize_with_ids("mtp_nextn", [0, 2])
     def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
-        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
+        # OOM on H100 with default free_gpu_memory_fraction=0.9
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
         mtp_config = None
         if mtp_nextn > 0:
             mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
