
Commit 0fbdf51 (1 parent: d02d49b)

[refactor] Move iter_counter handling to PyExecutor

- Moved iter_counter into PyExecutor to keep iteration tracking consistent.
- This allows tracking iterations in which no requests are scheduled.

Signed-off-by: Robin Kobus <[email protected]>
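For context, a minimal sketch of the counting behavior this changes (placeholder names, not the actual executor code): previously the counter lived on the model engine and was incremented inside forward(), which never runs when the scheduler produces an empty batch; now PyExecutor owns the counter and advances it once per loop iteration regardless.

    # Sketch only: executor_running and schedule_requests are assumed
    # placeholder names, not the real TensorRT-LLM API.
    while executor_running:
        batch = schedule_requests()       # may be empty
        if batch:
            model_engine.forward(batch)   # no longer increments any counter
        executor.iter_counter += 1        # counted even when nothing was scheduled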

5 files changed: +25 -18 lines

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (0 additions, 1 deletion)

@@ -153,7 +153,6 @@ def __init__(
         self.llm_args.batch_wait_timeout_iters = 0
         self.llm_args.batch_wait_max_tokens_ratio = 0.0
         self.llm_args.max_num_tokens = seq_info.max_num_tokens
-        self.iter_counter = 0

         # NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...
         self.max_beam_width = max_beam_width

tensorrt_llm/_torch/expert_statistic.py (11 additions, 8 deletions)

@@ -29,11 +29,15 @@ def create(rank_id: int):
             rank_id, start, stop)

     @staticmethod
-    def set_iter(iter_id: int) -> bool:
+    def should_record() -> bool:
         if ExpertStatistic.expert_statistic_obj is not None:
-            return ExpertStatistic.expert_statistic_obj._set_iter(iter_id)
-        else:
-            return False
+            return ExpertStatistic.expert_statistic_obj._should_record
+        return False
+
+    @staticmethod
+    def set_iter(iter_id: int) -> None:
+        if ExpertStatistic.expert_statistic_obj is not None:
+            ExpertStatistic.expert_statistic_obj._set_iter(iter_id)

     @staticmethod
     def set_layer(layer_id: int) -> None:
@@ -57,10 +61,10 @@ def __init__(self, rank_id: int, start: int, stop: int) -> None:
         self._records = {}

     @property
-    def should_record(self) -> bool:
+    def _should_record(self) -> bool:
         return self.current_iter_id is not None and self.start <= self.current_iter_id < self.stop

-    def _set_iter(self, iter_id: int) -> bool:
+    def _set_iter(self, iter_id: int) -> None:
         self.current_iter_id = iter_id
         if iter_id == self.stop:
             logger.info(
@@ -74,14 +78,13 @@ def _set_iter(self, iter_id: int) -> bool:
                 json.dump(self._meta_info, f)
             safetensors.torch.save_file(
                 self._records, f"{path}/rank{self.rank_id}.safetensors")
-        return self.should_record

     def _set_layer(self, layer: int) -> None:
         self.current_layer = layer

     def _maybe_add_info(self, expert_count: int,
                         token_selected_experts: torch.Tensor) -> None:
-        if not self.should_record:
+        if not self._should_record:
             return

         if self._meta_info is None:
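The old set_iter(iter_id) -> bool both stored the iteration id and reported whether recording was active; the new API splits that into a command and a side-effect-free query. A hedged sketch of the resulting call pattern, condensed from the call sites changed in the other files:

    # Producer side (PyExecutor): stamp the iteration before the forward pass.
    ExpertStatistic.set_iter(iter_counter)   # returns None now

    # Consumer side (e.g. CudaGraphRunner): pure query, no hidden state change.
    if ExpertStatistic.should_record():
        pass  # fall back to eager execution while statistics are recorded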

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (1 addition, 2 deletions)

@@ -164,7 +164,6 @@ def __del__(self):
     def maybe_get_cuda_graph(
         self,
         batch: ScheduledRequests,
-        iter_counter: int,
         enable_spec_decode: bool,
         attn_metadata: Any,
         spec_metadata: Optional[Any] = None,
@@ -180,7 +179,7 @@ def maybe_get_cuda_graph(
             - The key for the graph, if applicable.
         """
         # disable when doing statistic
-        if ExpertStatistic.set_iter(iter_counter):
+        if ExpertStatistic.should_record():
             return None, None, None

         can_run_cuda_graph = batch.can_run_cuda_graph

tensorrt_llm/_torch/pyexecutor/model_engine.py (0 additions, 3 deletions)

@@ -364,7 +364,6 @@ def __init__(
         if self.use_mrope:
             self.mrope_position_ids_cuda = torch.empty(
                 (3, 1, self.max_num_tokens), dtype=torch.int, device='cuda')
-        self.iter_counter = 0

         # Pre-allocated buffers for draft model to avoid implicit synchronization
         # These are used to build index tensors without creating tensors from Python lists
@@ -2572,7 +2571,6 @@ def forward(self,

         maybe_attn_metadata, maybe_spec_metadata, key = self.cuda_graph_runner.maybe_get_cuda_graph(
             padded_requests,
-            iter_counter=self.iter_counter,
             enable_spec_decode=self.enable_spec_decode,
             attn_metadata=attn_metadata,
             spec_metadata=spec_metadata,
@@ -2596,7 +2594,6 @@ def forward(self,
                 new_tensors_device, cache_indirection_buffer,
                 num_accepted_tokens_device, req_id_to_old_request)

-        self.iter_counter += 1
         with with_shared_pool(self.cuda_graph_runner.get_graph_pool()):
             if not can_run_graph:
                 # Fallback to eager execution if graph was not used

tensorrt_llm/_torch/pyexecutor/py_executor.py (13 additions, 4 deletions)

@@ -11,6 +11,7 @@

 import torch

+from tensorrt_llm._torch.expert_statistic import ExpertStatistic
 from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds

 try:
@@ -137,6 +138,7 @@ def __init__(self,

         self.peft_cache_config = peft_cache_config

+        self.iter_counter = 0
         # profile config
         self.profile_start_iters, self.profile_stop_iters = _load_iteration_indexes(
             PROFILE_START_STOP_ENV_VAR_NAME)
@@ -575,7 +577,7 @@ def profile_step():
             formatted_timestamp = datetime.datetime.now().strftime(
                 "%Y-%m-%d %H:%M:%S")
             logger.info(
-                f"iter = {self.model_engine.iter_counter}, "
+                f"iter = {self.iter_counter}, "
                 f"global_rank = {self.global_rank}, "
                 f"rank = {self.dist.rank}, "
                 f"currank_total_requests = {self.executor_request_queue.num_fetch_requests_cur_rank}/"
@@ -705,7 +707,7 @@ def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
         stats.cpu_mem_usage = 0
         stats.pinned_mem_usage = 0

-        stats.iter = self.model_engine.iter_counter
+        stats.iter = self.iter_counter

         kv_cache_manager = self.resource_manager.resource_managers.get(
             ResourceManagerType.KV_CACHE_MANAGER)
@@ -1004,6 +1006,8 @@ def _executor_loop_pp(self):
                     self.active_requests,
                     previous_batch)

+            self.iter_counter += 1
+
     def wait_on_pp_send_handles(self, microbatch_id):
         if self.send_handles[microbatch_id] is not None:
             self.send_handles[microbatch_id].wait()
@@ -1240,6 +1244,8 @@ def _executor_loop(self):
                         iter_stats=iter_stats,
                         iter_start_time=iter_start_time))

+            self.iter_counter += 1
+
     def _prepare_draft_requests(self):
         try:
             # Set draft tokens here to make the KV cache manager
@@ -1473,6 +1479,8 @@ def _executor_loop_overlap(self):

                 self._kv_connector_terminate_requests()

+            self.iter_counter += 1
+
     def _accept_draft_tokens(
             self, scheduled_batch: ScheduledRequests,
             target_outputs: SampleStateTensors,
@@ -1964,9 +1972,10 @@ def _check_disagg_gen_cache_transfer_status(self, atLeastNum: int = 0):
     def _forward_step(self,
                       scheduled_requests,
                       new_tensors_device: Optional[SampleStateTensors] = None):
+        ExpertStatistic.set_iter(self.iter_counter)

         @nvtx_range(
-            f"[Executor] _forward_step {self.model_engine.iter_counter + 1}: {len(scheduled_requests.context_requests)} ctx reqs, {len(scheduled_requests.generation_requests)} gen reqs"
+            f"[Executor] _forward_step {self.iter_counter}: {len(scheduled_requests.context_requests)} ctx reqs, {len(scheduled_requests.generation_requests)} gen reqs"
         )
         def forward(scheduled_requests, resource_manager, new_tensors_device,
                     gather_context_logits, cache_indirection_buffer):
@@ -2304,7 +2313,7 @@ def _handle_responses(self):

             # Skip active requests that are not scheduled
             if request.return_perf_metrics and request.py_decoding_iter >= 1:
-                request.update_perf_metrics(self.model_engine.iter_counter)
+                request.update_perf_metrics(self.iter_counter)

             request_done = False
             if request.py_decoding_iter == 1 or request.is_finished or \
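Taken together, the py_executor.py hunks establish one per-iteration contract across all three loop variants (_executor_loop, _executor_loop_pp, _executor_loop_overlap); a condensed sketch, with everything except the counter logic elided:

    # Condensed sketch; the real loops do scheduling, stats, and KV-cache work
    # between these two points.
    def _forward_step(self, scheduled_requests, new_tensors_device=None):
        # Stamp the iteration id first, so ExpertStatistic.should_record()
        # (checked inside maybe_get_cuda_graph) sees the current iteration.
        ExpertStatistic.set_iter(self.iter_counter)
        ...

    # At the end of each loop body, in all three loop variants:
    self.iter_counter += 1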
