
Commit f8b4077

[nvbugs/5326453] Avoid nesting NCCL grouping in allgather OP (#5789)
Signed-off-by: junq <[email protected]>
1 parent 6062dc6 commit f8b4077

5 files changed: 54 additions, 61 deletions


cpp/tensorrt_llm/thop/allgatherOp.cpp

Lines changed: 10 additions & 2 deletions
@@ -97,13 +97,21 @@ class AllgatherOp
     {
         std::vector<torch::Tensor> output_list;
         output_list.reserve(input_list.size());
-        ncclGroupStart();
+        // NCCL Groups cannot be nested.
+        // Skip NCCL grouping when AllgatherV(variable-length AllGather) calls ncclGroupStart/ncclGroupEnd.
+        if (!sizes.has_value())
+        {
+            ncclGroupStart();
+        }
         for (auto const& input : input_list)
         {
             auto output = run(input, sizes);
             output_list.push_back(output);
         }
-        ncclGroupEnd();
+        if (!sizes.has_value())
+        {
+            ncclGroupEnd();
+        }
         return output_list;
     }
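
Why the guard is conditional: NCCL does not allow ncclGroupStart/ncclGroupEnd pairs to nest, and the comment added above says the variable-length path (AllgatherV) issues its own ncclGroupStart/ncclGroupEnd. The snippet below is only a toy Python model of that control flow, not the real op; nccl_group_start, nccl_group_end and run_single_allgather are hypothetical stand-ins that just track nesting depth.

```python
# Toy model of the conditional grouping above; these helpers are hypothetical
# stand-ins, not the real NCCL API.
_group_depth = 0


def nccl_group_start():
    global _group_depth
    if _group_depth > 0:
        raise RuntimeError("NCCL groups cannot be nested")
    _group_depth += 1


def nccl_group_end():
    global _group_depth
    _group_depth -= 1


def run_single_allgather(tensor, sizes=None):
    # The variable-length path (AllgatherV) opens its own group internally,
    # mirroring what the comment in the diff describes.
    if sizes is not None:
        nccl_group_start()
        # ... per-rank collectives for uneven sizes would be issued here ...
        nccl_group_end()
    return tensor


def allgather_list(tensors, sizes=None):
    # Only open the outer group when the per-tensor path will not open one
    # of its own; otherwise the two groups would nest.
    outputs = []
    if sizes is None:
        nccl_group_start()
    for t in tensors:
        outputs.append(run_single_allgather(t, sizes))
    if sizes is None:
        nccl_group_end()
    return outputs


# Both paths complete without tripping the nesting check:
allgather_list(["a", "b"])                # fixed-size: one outer group
allgather_list(["a", "b"], sizes=[3, 5])  # AllgatherV: inner groups only
```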

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 27 additions & 34 deletions
@@ -1197,16 +1197,16 @@ def _prepare_tp_inputs(
         new_tokens_lens_device = new_tensors_device.new_tokens_lens  # [batch]
         next_draft_tokens_device = new_tensors_device.next_draft_tokens  # [batch, draft_len]

-        # Requests with draft tokens are treated like extend requests. CUDA graph dummy extend
-        # requests should be at the end of extend_requests.
+        # Requests with draft tokens are treated like extend requests. Dummy extend requests should be
+        # at the end of extend_requests.
         extend_requests = []
-        extend_cuda_graph_dummy_requests = []
+        extend_dummy_requests = []
         generation_requests = []
         for request in scheduled_requests.generation_requests:
             if len(request.py_draft_tokens
                    ) > 0 or next_draft_tokens_device is not None:
-                if request.is_cuda_graph_dummy:
-                    extend_cuda_graph_dummy_requests.append(request)
+                if request.is_dummy:
+                    extend_dummy_requests.append(request)
                 else:
                     extend_requests.append(request)
             else:
@@ -1219,8 +1219,8 @@ def _prepare_tp_inputs(
                     pin_memory=True)
                 mrope_config['mrope_position_deltas'].append(
                     mrope_position_deltas.to('cuda', non_blocking=True))
+        extend_requests += extend_dummy_requests

-        extend_requests = extend_cuda_graph_dummy_requests + extend_requests
         if not self._disable_overlap_scheduler and self.is_spec_decode:
             spec_dec_mode = self.spec_config.spec_dec_mode
             assert spec_dec_mode.support_overlap_scheduler(
@@ -1229,18 +1229,18 @@ def _prepare_tp_inputs(
         # will contain previous batch incices of generation requests
         previous_batch_indices = []
         previous_pos_indices = []
-        request_ids_with_previous_batch = []
-        num_extend_reqs_wo_previous_batch = 0
         for request in extend_requests:
             # the request has no previous tensor:
             # (1) next_draft_tokens_device is None, which means overlap scheduler is disabled; or
             # (2) a dummy request; or
             # (3) the first step in the generation server of disaggregated serving
             if next_draft_tokens_device is None or request.is_dummy or request.py_batch_idx is None:
-                # get token ids, including input token ids and draft token ids
-                input_ids.append(request.get_last_tokens(0))
-                input_ids.extend(request.py_draft_tokens)
-                draft_tokens.extend(request.py_draft_tokens)
+                # get token ids, including input token ids and draft token ids. For these dummy requests,
+                # no need to copy the token ids.
+                if not request.is_dummy:
+                    input_ids.append(request.get_last_tokens(0))
+                    input_ids.extend(request.py_draft_tokens)
+                    draft_tokens.extend(request.py_draft_tokens)
                 # get other ids and lengths
                 num_draft_tokens = len(request.py_draft_tokens)
                 past_seen_token_num = request.max_beam_num_tokens - 1
@@ -1268,7 +1268,6 @@ def _prepare_tp_inputs(
                 # update batch index
                 request.py_batch_idx = batch_idx
                 batch_idx += 1
-                num_extend_reqs_wo_previous_batch += 1
             else:
                 # update batch index
                 previous_batch_idx = request.py_batch_idx
@@ -1295,10 +1294,7 @@ def _prepare_tp_inputs(
                 num_cached_tokens_per_seq.append(past_seen_token_num +
                                                  self.max_draft_len + 1)
                 prompt_lengths.append(request.py_prompt_len)
-                request_ids_with_previous_batch.append(request.py_request_id)
-
-        # move requests with previous batch to the end of the list
-        request_ids.extend(request_ids_with_previous_batch)
+            request_ids.append(request.py_request_id)

         sequence_lengths.extend([1] * len(generation_requests))
         gather_ids.extend(
@@ -1333,6 +1329,7 @@ def _prepare_tp_inputs(
         num_tokens = len(input_ids)
         num_draft_tokens = len(draft_tokens)
         previous_batchs = len(previous_batch_indices)
+        num_requests = len(request_ids)
         total_num_tokens = len(position_ids)
         assert total_num_tokens <= self.max_num_tokens, (
             "total_num_tokens should be less than or equal to max_num_tokens")
@@ -1374,31 +1371,27 @@ def _prepare_tp_inputs(
                 non_blocking=True)
             # prepare data for the preprocess inputs
             kv_len_offsets_device = new_tokens_lens_device - self.max_draft_len - 1
-            pre_tokens_start_idx = num_extend_reqs_wo_previous_batch * (
-                1 + self.max_draft_len)
-            pre_tokens_end_idx = pre_tokens_start_idx + previous_batch_tokens
-            pre_batch_start_idx = num_extend_reqs_wo_previous_batch
-            pre_batch_end_idx = pre_batch_start_idx + previous_batchs
             previous_pos_indices = torch.tensor(previous_pos_indices,
                                                 dtype=torch.int,
                                                 pin_memory=True)
-            self.previous_pos_indices_cuda[
-                pre_tokens_start_idx:pre_tokens_end_idx].copy_(
-                    previous_pos_indices, non_blocking=True)
+            self.previous_pos_indices_cuda[0:previous_batch_tokens].copy_(
+                previous_pos_indices, non_blocking=True)
             self.previous_pos_id_offsets_cuda[
-                pre_tokens_start_idx:pre_tokens_end_idx].copy_(
+                0:previous_batch_tokens].copy_(
                     new_tokens_lens_device[self.previous_pos_indices_cuda[
-                        pre_tokens_start_idx:pre_tokens_end_idx]],
-                    non_blocking=True)
-            self.previous_kv_lens_offsets_cuda[
-                pre_batch_start_idx:pre_batch_end_idx].copy_(
-                    kv_len_offsets_device[
-                        self.previous_batch_indices_cuda[:previous_batchs]],
+                        0:previous_batch_tokens]],
                     non_blocking=True)
+            self.previous_kv_lens_offsets_cuda[0:previous_batchs].copy_(
+                kv_len_offsets_device[
+                    self.previous_batch_indices_cuda[:previous_batchs]],
+                non_blocking=True)
             # for the requests that do not have previous batch, set the previous_pos_id_offsets and
             # previous_kv_lens_offsets to zeros to skip the value changes in _preprocess_inputs
-            self.previous_pos_id_offsets_cuda[:pre_tokens_start_idx] *= 0
-            self.previous_kv_lens_offsets_cuda[:pre_batch_start_idx] *= 0
+            self.previous_pos_id_offsets_cuda[
+                previous_batch_tokens:num_requests *
+                (1 + self.max_draft_len)] *= 0
+            self.previous_kv_lens_offsets_cuda[
+                previous_batchs:num_requests] *= 0
         else:
             # change the data to zeros to skip the value changes in _preprocess_inputs
             self.previous_pos_id_offsets_cuda *= 0
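
The new indexing in the overlap-scheduler branch can be hard to follow from the diff alone. Below is a small, self-contained sketch (made-up sizes, CPU tensors, simplified names; the real buffers are preallocated CUDA tensors on the engine) of the pattern the replacement code uses: copy the offsets for requests that have a previous batch into the front of the buffer, then zero the rest of the active region up to num_requests * (1 + max_draft_len) so that _preprocess_inputs sees no stale values for requests without a previous batch.

```python
import torch

max_draft_len = 2
num_requests = 4               # all extend requests in this iteration
previous_batchs = 3            # of which 3 have a previous batch
previous_batch_tokens = previous_batchs * (1 + max_draft_len)

# Preallocated, reused buffer (sized for the maximum token count).
previous_pos_id_offsets = torch.full((16,), 7, dtype=torch.int)  # stale values

# 1) Offsets for previous-batch requests fill the front of the buffer.
fresh_offsets = torch.arange(1, previous_batch_tokens + 1, dtype=torch.int)
previous_pos_id_offsets[0:previous_batch_tokens].copy_(fresh_offsets)

# 2) The rest of the active region is zeroed so requests without a previous
#    batch (including dummy requests) pick up no stale offset.
previous_pos_id_offsets[previous_batch_tokens:num_requests *
                        (1 + max_draft_len)] *= 0

active = previous_pos_id_offsets[:num_requests * (1 + max_draft_len)]
print(active.tolist())  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 0]
```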

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 16 additions & 17 deletions
@@ -1492,7 +1492,7 @@ def _check_disagg_gen_transfer_status(self):
     @nvtx_range("_pad_attention_dp_dummy_request")
     def _pad_attention_dp_dummy_request(self):
         """
-        Pad with dummy requests, if required, to avoid empty attention_dp rank.
+        Pad with a dummy request, if required, to ensure every attention_dp rank has at least one active request.
         """
         if not self.enable_attention_dp:
             return
@@ -1506,22 +1506,20 @@ def _pad_attention_dp_dummy_request(self):
             or req.is_disagg_generation_transmission_in_progress else 1
             for req in self.active_requests
         ])
-        num_dummy_request = self.expected_num_active_requests - num_active_request
-        if num_dummy_request > 0:
-            llm_request_list = self.kv_cache_manager.add_dummy_requests(
-                request_ids=list(range(num_dummy_request)),
+
+        if self.expected_num_active_requests - num_active_request > 0 and num_active_request == 0:
+            llm_request = self.kv_cache_manager.add_dummy_requests(
+                request_ids=[0],
                 is_gen=not self.has_context_request,
                 prepare_resource=not self.has_context_request,
                 max_num_draft_tokens=self.max_draft_tokens,
-            )
-            for llm_request in llm_request_list:
-                llm_request.is_attention_dp_dummy = True
+            )[0]
+            llm_request.is_attention_dp_dummy = True
             spec_resource_manager = self.resource_manager.get_resource_manager(
                 ResourceManagerType.SPEC_RESOURCE_MANAGER)
             if spec_resource_manager is not None:
-                spec_resource_manager.add_dummy_requests(
-                    list(range(num_dummy_request)))
-            self.active_requests += llm_request_list
+                spec_resource_manager.add_dummy_requests([0])
+            self.active_requests.append(llm_request)

     @nvtx_range("_prepare_disagg_gen_init")
     def _prepare_disagg_gen_init(self, fitting_disagg_gen_init_requests):
@@ -1645,12 +1643,13 @@ def forward(scheduled_requests, resource_manager, new_tensors_device,

     def _update_request_states_tp(self, scheduled_requests: ScheduledRequests):
         # handle potential attention dp dummy request
-        for request in self.active_requests[:]:
-            if request.is_attention_dp_dummy:
-                request.state = LlmRequestState.GENERATION_COMPLETE
-                self.inflight_req_ids.erase(request.py_request_id)
-                self._terminate_request(request)
-                self.active_requests.remove(request)
+        if self.active_requests and self.active_requests[
+                -1].is_attention_dp_dummy:
+            request = self.active_requests[-1]
+            request.state = LlmRequestState.GENERATION_COMPLETE
+            self.inflight_req_ids.erase(request.py_request_id)
+            self._terminate_request(request)
+            self.active_requests.remove(request)

         for request in scheduled_requests.context_requests:
             if request.state != LlmRequestState.GENERATION_COMPLETE:  # skip failed requests
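
A minimal sketch of the padding rule after this change, with hypothetical names standing in for the executor state: a rank adds exactly one attention-DP dummy request, and only when it is below the expected number of active requests and currently has no active request at all (previously it padded all the way up to the expected count).

```python
def needs_attention_dp_dummy(expected_num_active_requests: int,
                             num_active_request: int) -> bool:
    # Mirrors the condition in the diff: pad only an otherwise-empty rank,
    # and only with a single dummy request (request_ids=[0]).
    return (expected_num_active_requests - num_active_request > 0
            and num_active_request == 0)


assert needs_attention_dp_dummy(4, 0)          # empty rank -> one dummy
assert not needs_attention_dp_dummy(4, 2)      # partially busy rank -> none
assert not needs_attention_dp_dummy(0, 0)      # nothing expected -> none
```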

tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 0 additions & 6 deletions
@@ -267,12 +267,6 @@ def update_requests(self, state: SampleStateMTP) -> None:
             request.py_decoding_iter += 1
             idx += 1

-        # skip the results of cuda graph dummy requests
-        if idx == 0:
-            num_cuda_graph_dummy_requests = len(new_tokens_list) - len(
-                state.scheduled_requests.generation_requests)
-            idx += num_cuda_graph_dummy_requests
-
         for request in state.scheduled_requests.generation_requests:
             assert not request.py_return_context_logits, "return_context_logits not implemented for MTPSampler"
             assert not request.py_return_generation_logits, "return_generation_logits not implemented for MTPSampler"

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 1 addition & 2 deletions
@@ -646,8 +646,7 @@ def test_fp8_block_scales(self, mtp_nextn, fp8kv, attention_dp, cuda_graph,
     @pytest.mark.skip_device_not_contain(["H100"])
     @parametrize_with_ids("mtp_nextn", [0, 2])
     def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
-        # OOM on H100 with default free_gpu_memory_fraction=0.9
-        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
         mtp_config = None
         if mtp_nextn > 0:
             mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
