4 changes: 2 additions & 2 deletions cpp/tensorrt_llm/batch_manager/pauseRequests.cpp
@@ -50,8 +50,8 @@ void tensorrt_llm::batch_manager::PauseRequests::operator()(RequestVector& reque
     for (auto& llmReq : requestsToPause)
     {
         auto const reqId = llmReq->mRequestId;
-        inflightReqIds.erase(reqId);
-        TLLM_LOG_DEBUG("request with ID %lu removed from DECODER model inflight set", reqId);
+        auto const removed = inflightReqIds.erase(reqId);
+        TLLM_LOG_DEBUG("request with ID %lu removed from DECODER model inflight set: %d", reqId, removed);
 
         // If a request in this context had been flagged to be paused, pause it right away
         if (reqIdsToPause.find(reqId) != reqIdsToPause.end())
17 changes: 12 additions & 5 deletions cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
@@ -881,8 +881,6 @@ void TrtGptModelInflightBatching::forwardSync()
             }
         }
 
-        (*mPauseRequests)(currRequests.contextRequests, mInflightReqIds, mReqIdsToPause, true, *mSeqSlotManager,
-            mKvCacheManager, mCrossKvCacheManager, mPeftCacheManager);
         (*mPauseRequests)(currRequests.generationRequests, mInflightReqIds, mReqIdsToPause, true, *mSeqSlotManager,
             mKvCacheManager, mCrossKvCacheManager, mPeftCacheManager);
 
@@ -1051,14 +1049,23 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests
         {
             NVTX3_SCOPED_RANGE(updateInflightReqIds);
             // Add requests to in-flight set, so they can be skipped in other micro batches
-            for (auto const& requests : {currRequests.contextRequests, currRequests.generationRequests})
+            for (auto const& llmReq : currRequests.contextRequests)
             {
-                for (auto const& llmReq : requests)
+                // Context requests that are chunking are not added to inflight set, so they are scheduled in the
+                // next micro batch.
+                if (llmReq->isLastContextChunk())
                 {
-                    TLLM_LOG_DEBUG("request with ID %lu added to DECODER model inflight set", llmReq->mRequestId);
+                    TLLM_LOG_DEBUG(
+                        "Context request with ID %lu added to DECODER model inflight set", llmReq->mRequestId);
                     mInflightReqIds.insert(llmReq->mRequestId);
                 }
             }
+            for (auto const& llmReq : currRequests.generationRequests)
+            {
+                TLLM_LOG_DEBUG(
+                    "Generation request with ID %lu added to DECODER model inflight set", llmReq->mRequestId);
+                mInflightReqIds.insert(llmReq->mRequestId);
+            }
         }
 
         (*mAssignReqSeqSlots)(*mSeqSlotManager, currRequests.contextRequests, currRequests.generationRequests);
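The comment in the hunk above captures the scheduling rule this PR introduces: a context request only enters the DECODER inflight set once its last context chunk is scheduled, while earlier chunks stay visible to the scheduler so the next micro batch can continue the chunked prefill. A minimal Python sketch of that rule follows; the Request dataclass and the plain set are hypothetical stand-ins for the real LlmRequest and ReqIdsSet types, not the library API.

from dataclasses import dataclass


@dataclass
class Request:
    # Stand-in for LlmRequest; only the fields needed for the rule are modeled.
    request_id: int
    is_last_context_chunk: bool = True
    is_generation: bool = False


def add_inflight_ids(scheduled: list[Request], inflight_req_ids: set[int]) -> list[Request]:
    # Context requests enter the inflight set only on their last chunk;
    # generation requests always enter it.
    finished_ctx_reqs = []
    for req in scheduled:
        if req.is_generation:
            inflight_req_ids.add(req.request_id)
        elif req.is_last_context_chunk:
            inflight_req_ids.add(req.request_id)
            finished_ctx_reqs.append(req)
        # A context request that is still chunking stays out of the set,
        # so the scheduler may pick it again for the next micro batch.
    return finished_ctx_reqs


inflight: set[int] = set()
batch = [Request(1, is_last_context_chunk=False), Request(2), Request(3, is_generation=True)]
finished = add_inflight_ids(batch, inflight)
print([r.request_id for r in finished], inflight)  # [2] {2, 3}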
6 changes: 3 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -674,8 +674,10 @@ def create_py_executor_instance(
 
     spec_config = model_engine.spec_config
 
+    max_num_sequences = max_batch_size * mapping.pp_size
+
     logger.info(
-        f"max_seq_len={max_seq_len}, max_num_requests={max_batch_size}, max_num_tokens={max_num_tokens}, max_batch_size={max_batch_size}"
+        f"max_seq_len={max_seq_len}, max_num_requests={max_num_sequences}, max_num_tokens={max_num_tokens}, max_batch_size={max_batch_size}"
     )
 
     for key, value in llm_args.extra_resource_managers.items():
@@ -760,8 +762,6 @@ def create_py_executor_instance(
             lora_config.trtllm_modules_to_hf_modules,
             lora_config.swap_gate_up_proj_lora_b_weight)
 
-    max_num_sequences = max_batch_size * mapping.pp_size
-
     resources[ResourceManagerType.SEQ_SLOT_MANAGER] = SeqSlotManager(
         max_num_sequences)
 
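The relocated line computes the executor-wide sequence budget before it is logged, so max_num_requests in the log now reports the pipeline-parallel capacity rather than the per-micro-batch batch size. A small sketch of the arithmetic, with a stand-in Mapping class and made-up numbers rather than tensorrt_llm's real objects:

from dataclasses import dataclass


@dataclass
class Mapping:
    # Stand-in for tensorrt_llm's Mapping; only pp_size matters for this calculation.
    pp_size: int = 1


def compute_max_num_sequences(max_batch_size: int, mapping: Mapping) -> int:
    # With pipeline parallelism there are pp_size micro batches in flight,
    # each holding up to max_batch_size sequences, so the executor needs
    # max_batch_size * pp_size sequence slots overall.
    return max_batch_size * mapping.pp_size


print(compute_max_num_sequences(max_batch_size=8, mapping=Mapping(pp_size=4)))  # 32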
81 changes: 65 additions & 16 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -106,6 +106,7 @@ class BatchState:
 class BatchStatePP(BatchState):
     microbatch_id: int = -1
     scheduled_ctx_reqs: list[LlmRequest] = None
+    finished_ctx_reqs: list[LlmRequest] = None
 
 
 class PyExecutor:
Expand Down Expand Up @@ -232,6 +233,8 @@ def __init__(self,
| None] = [None] * self.num_micro_batches
self.send_handles = [None] * self.num_micro_batches

# Set of request IDs that are currently in flight across all micro batches.
# The scheduler will avoid scheduling requests that are already in flight.
self.inflight_req_ids = ReqIdsSet()

# During warmup, we don't enable the profiler
@@ -694,7 +697,7 @@ def get_queued_req_stats(request_id: int) -> RequestStats:
             return req_stats
 
     def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
-                           scheduled_batch) -> IterationStats:
+                           scheduled_batch, micro_batch_id) -> IterationStats:
         stats.iter_latency_ms = iter_latency_ms
 
         stats.num_queued_requests = self.executor_request_queue.get_request_queue_size(
Expand Down Expand Up @@ -735,7 +738,7 @@ def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
stats.inflight_batching_stats.num_paused_requests = len(
scheduled_batch.paused_requests)
stats.inflight_batching_stats.avg_num_decoded_tokens_per_iter = 0
stats.inflight_batching_stats.micro_batch_id = 0
stats.inflight_batching_stats.micro_batch_id = micro_batch_id
if stats.specdec_stats is not None:
stats.specdec_stats.draft_overhead = 0.0 if iter_latency_ms <= 0.0 else float(
stats.specdec_stats.iter_latency_ms) / float(iter_latency_ms)
@@ -748,9 +751,13 @@ def _append_iter_stats(self,
         with self.stats_lock:
             self.stats.append((stats, req_stats))
 
-    def _process_iter_stats(self, finished_requests: list[LlmRequest],
-                            active_requests: List[LlmRequest],
-                            batch_state: BatchState):
+    def _process_iter_stats(
+        self,
+        finished_requests: list[LlmRequest],
+        active_requests: List[LlmRequest],
+        batch_state: BatchState,
+        micro_batch_id: int = 0,
+    ):
         iter_end_time = time.time()
         iter_latency_ms = (iter_end_time - batch_state.iter_start_time) * 1e3
         if batch_state.iter_stats is None:
@@ -763,9 +770,10 @@ def _process_iter_stats(self, finished_requests: list[LlmRequest],
                      and self.enable_iter_perf_stats) else None
 
         self._append_iter_stats(
-            self._update_iter_stats(
-                batch_state.iter_stats, iter_latency_ms, len(finished_requests),
-                batch_state.sample_state.scheduled_requests), req_stats)
+            self._update_iter_stats(batch_state.iter_stats, iter_latency_ms,
+                                    len(finished_requests),
+                                    batch_state.sample_state.scheduled_requests,
+                                    micro_batch_id), req_stats)
 
     def _executor_loop_cleanup(self):
 
@@ -825,6 +833,7 @@ def _executor_loop_pp(self):
                 self.num_scheduled_requests = scheduled_batch.batch_size
 
                 logger.debug(
+                    f'iteration {self.iter_counter}, microbatch {microbatch_id}, '
                     f'has {len(self.active_requests)} active_requests, '
                     f'scheduled {len(scheduled_batch.context_requests)} context requests and '
                     f'{len(scheduled_batch.generation_requests)} generation requests'
@@ -833,9 +842,13 @@ def _executor_loop_pp(self):
                 can_queue = self._can_queue(scheduled_batch)
 
                 if not can_queue:
+                    logger.debug(
+                        f"microbatch {microbatch_id} cannot be queued, skipping"
+                    )
                     self.micro_batches[microbatch_id] = None
                 else:
-                    self._add_inflight_ids(scheduled_batch)
+                    logger.debug(f"microbatch {microbatch_id} can be queued")
+                    finished_ctx_reqs = self._add_inflight_ids(scheduled_batch)
 
                     if self.kv_cache_transceiver:
                         # For generation requests which have completed KV cache transfer
@@ -895,6 +908,7 @@ def _executor_loop_pp(self):
                     iter_stats=iter_stats,
                     microbatch_id=microbatch_id,
                     scheduled_ctx_reqs=scheduled_batch.context_requests,
+                    finished_ctx_reqs=finished_ctx_reqs,
                 )
 
                 self.micro_batches[microbatch_id] = batch_state
@@ -949,6 +963,8 @@ def _executor_loop_pp(self):
                 finished_requests = []
                 if previous_batch is not None:
                     with torch.cuda.nvtx.range("_handle_previous_batch_pp"):
+                        sample_state = previous_batch.sample_state
+                        sample_state.scheduled_requests.context_requests = previous_batch.finished_ctx_reqs
                         self._update_requests(previous_batch.sample_state)
 
                         if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
@@ -980,7 +996,8 @@ def _executor_loop_pp(self):
                         self.resource_manager.update_resources(
                             previous_scheduled_batch, attn_metadata,
                             kv_cache_dtype_byte_size)
-                        self._remove_inflight_ids(previous_scheduled_batch)
+
+                        self._remove_inflight_ids(previous_batch)
 
                         self.wait_on_pp_send_handles(prev_microbatch_id)
                         self.micro_batches[prev_microbatch_id] = None
@@ -997,9 +1014,11 @@ def _executor_loop_pp(self):
                 microbatch_id = (microbatch_id + 1) % self.num_micro_batches
 
                 if self.enable_iter_perf_stats and previous_batch is not None:
+                    sample_state = previous_batch.sample_state
+                    sample_state.scheduled_requests.context_requests = previous_batch.scheduled_ctx_reqs
                     self._process_iter_stats(finished_requests,
                                              self.active_requests,
-                                             previous_batch)
+                                             previous_batch, microbatch_id)
 
                 self.iter_counter += 1
 
@@ -2485,13 +2504,43 @@ def _pause_requests(self, requests_to_pause):
                 self._terminate_request(req)
 
     def _add_inflight_ids(self, scheduled_requests):
-        """Add reqids of current requests to self.inflight_req_ids."""
-        for req in scheduled_requests.all_requests():
+        """Add request IDs of current requests to self.inflight_req_ids.
+
+        Non-final context chunks are not added to the inflight set, so the scheduler can keep scheduling further
+        context chunks while earlier ones are in the PP pipeline. Only context requests that finish context phase
+        are inserted into the inflight set and collected into finished_ctx_reqs.
+        All generation requests are still inserted into the inflight set.
+        """
+        finished_ctx_reqs = []
+        for req in scheduled_requests.context_requests:
+            if req.is_last_context_chunk:
+                logger.debug(
+                    f"Context request with ID {req.request_id} added to DECODER model inflight set"
+                )
+                self.inflight_req_ids.insert(req.request_id)
+                finished_ctx_reqs.append(req)
+        for req in scheduled_requests.generation_requests:
+            logger.debug(
+                f"Generation request with ID {req.request_id} added to DECODER model inflight set"
+            )
             self.inflight_req_ids.insert(req.request_id)
+        return finished_ctx_reqs
+
+    def _remove_inflight_ids(self, batch_state: BatchStatePP):
+        """Remove request IDs of current requests from self.inflight_req_ids.
 
-    def _remove_inflight_ids(self, scheduled_requests):
-        """Remove reqids of current requests from self.inflight_req_ids."""
-        for req in scheduled_requests.all_requests():
+        Context IDs are erased from the inflight set using batch_state.finished_ctx_reqs.
+        Generation IDs are erased using batch_state.sample_state.scheduled_requests.generation_requests.
+        """
+        for req in batch_state.finished_ctx_reqs:
+            logger.debug(
+                f"Context request with ID {req.request_id} removed from DECODER model inflight set"
+            )
+            self.inflight_req_ids.erase(req.request_id)
+        for req in batch_state.sample_state.scheduled_requests.generation_requests:
+            logger.debug(
+                f"Generation request with ID {req.request_id} removed from DECODER model inflight set"
+            )
             self.inflight_req_ids.erase(req.request_id)

def _handle_speculative_decoding(self, scheduled_batch, previous_tensors,
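Taken together, the py_executor.py changes make removal from the inflight set symmetric with insertion: each micro batch's BatchStatePP carries finished_ctx_reqs, and only those IDs (plus the generation requests) are erased when the micro batch completes. A rough, self-contained sketch of that lifecycle, using hypothetical MicroBatchState and InflightTracker classes rather than the real PyExecutor:

from dataclasses import dataclass, field


@dataclass
class MicroBatchState:
    # Hypothetical stand-in for BatchStatePP: it remembers which request IDs
    # actually entered the inflight set when this micro batch was queued.
    microbatch_id: int
    finished_ctx_req_ids: list[int] = field(default_factory=list)
    generation_req_ids: list[int] = field(default_factory=list)


class InflightTracker:
    def __init__(self) -> None:
        self.inflight_req_ids: set[int] = set()

    def queue_micro_batch(self, state: MicroBatchState) -> None:
        # Insertion records exactly the IDs kept in the per-micro-batch state.
        for rid in state.finished_ctx_req_ids + state.generation_req_ids:
            self.inflight_req_ids.add(rid)

    def finish_micro_batch(self, state: MicroBatchState) -> None:
        # Removal mirrors insertion, so an ID is never erased that was not
        # added (e.g. a context request that was still chunking when queued).
        for rid in state.finished_ctx_req_ids + state.generation_req_ids:
            self.inflight_req_ids.discard(rid)


tracker = InflightTracker()
mb = MicroBatchState(microbatch_id=0, finished_ctx_req_ids=[2], generation_req_ids=[3])
tracker.queue_micro_batch(mb)    # a request still chunking is never tracked here
print(tracker.inflight_req_ids)  # {2, 3}
tracker.finish_micro_batch(mb)
print(tracker.inflight_req_ids)  # set()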