
Commit a8521a3

[None][chore] remove executor_config in create_py_executor_instance

Signed-off-by: leslie-fang25 <[email protected]>
1 parent 32cb853 · commit a8521a3

File tree: 3 files changed, +46 −20 lines changed


tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 23 additions & 15 deletions
```diff
@@ -9,7 +9,11 @@
 import tensorrt_llm.bindings.executor as trtllm
 from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
+from tensorrt_llm.bindings.executor import \
+    CacheTransceiverConfig as _CacheTransceiverConfig
 from tensorrt_llm.bindings.executor import DecodingMode, ExecutorConfig
+from tensorrt_llm.bindings.executor import PeftCacheConfig as _PeftCacheConfig
+from tensorrt_llm.bindings.executor import SchedulerConfig as _SchedulerConfig
 from tensorrt_llm.llmapi.llm_args import PeftCacheConfig, SamplerType
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_helper import (LoraConfig,
@@ -504,7 +508,6 @@ def create_py_executor_instance(
         resources,
         mapping,
         pytorch_backend_config,
-        executor_config,
         ctx_chunk_config,
         model_engine,
         start_worker,
@@ -515,13 +518,19 @@ def create_py_executor_instance(
         garbage_collection_gen0_threshold: Optional[int] = None,
         kv_connector_manager: Optional[KvCacheConnectorManager] = None,
         max_seq_len: Optional[int] = None,
+        max_batch_size: Optional[int] = None,
+        max_beam_width: Optional[int] = None,
+        max_num_tokens: Optional[int] = None,
+        peft_cache_config: Optional[_PeftCacheConfig] = None,
+        scheduler_config: Optional[_SchedulerConfig] = None,
+        cache_transceiver_config: Optional[_CacheTransceiverConfig] = None,
 ) -> PyExecutor:
     kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)
 
     spec_config = model_engine.spec_config
 
     logger.info(
-        f"max_seq_len={executor_config.max_seq_len}, max_num_requests={executor_config.max_batch_size}, max_num_tokens={executor_config.max_num_tokens}, max_batch_size={executor_config.max_batch_size}"
+        f"max_seq_len={max_seq_len}, max_num_requests={max_batch_size}, max_num_tokens={max_num_tokens}, max_batch_size={max_batch_size}"
     )
 
     for key, value in pytorch_backend_config.extra_resource_managers.items():
@@ -578,16 +587,15 @@ def create_py_executor_instance(
             len(lora_config.lora_target_modules + lora_config.missing_qkv_modules)
 
         peft_cache_config_model = PeftCacheConfig.from_pybind(
-            executor_config.peft_cache_config
-        ) if executor_config.peft_cache_config is not None else PeftCacheConfig(
-        )
+            peft_cache_config
+        ) if peft_cache_config is not None else PeftCacheConfig()
         if lora_config.max_loras is not None:
             peft_cache_config_model.num_device_module_layer = \
                 max_lora_rank * num_lora_modules * lora_config.max_loras
         if lora_config.max_cpu_loras is not None:
             peft_cache_config_model.num_host_module_layer = \
                 max_lora_rank * num_lora_modules * lora_config.max_cpu_loras
-        executor_config.peft_cache_config = peft_cache_config_model._to_pybind()
+        peft_cache_config = peft_cache_config_model._to_pybind()
 
         from tensorrt_llm.bindings import WorldConfig
         world_config = WorldConfig(
@@ -598,7 +606,7 @@ def create_py_executor_instance(
             gpus_per_node=dist.mapping.gpus_per_node,
         )
         peft_cache_manager = PeftCacheManager(
-            peft_cache_config=executor_config.peft_cache_config,
+            peft_cache_config=peft_cache_config,
             lora_config=lora_config,
             model_config=model_binding_config,
             world_config=world_config,
@@ -609,7 +617,7 @@ def create_py_executor_instance(
             lora_config.trtllm_modules_to_hf_modules,
             lora_config.swap_gate_up_proj_lora_b_weight)
 
-    max_num_sequences = executor_config.max_batch_size * mapping.pp_size
+    max_num_sequences = max_batch_size * mapping.pp_size
 
     resources[ResourceManagerType.SEQ_SLOT_MANAGER] = SeqSlotManager(
         max_num_sequences)
@@ -632,17 +640,16 @@ def create_py_executor_instance(
         scheduler_capacity,
         kv_cache_manager.impl if kv_cache_manager is not None else None,
         peft_cache_manager.impl if peft_cache_manager is not None else None,
-        executor_config.scheduler_config.capacity_scheduler_policy,
+        scheduler_config.capacity_scheduler_policy,
         two_step_lookahead=mapping.has_pp())
-    mb_scheduler = BindMicroBatchScheduler(executor_config.max_batch_size,
-                                           executor_config.max_num_tokens,
+    mb_scheduler = BindMicroBatchScheduler(max_batch_size, max_num_tokens,
                                            ctx_chunk_config)
     scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler)
 
     config = model_engine.model.model_config.pretrained_config
     attention_type = AttentionTypeCpp.MLA if is_mla(
         config) else AttentionTypeCpp.DEFAULT
-    cache_transceiver_config = executor_config.cache_transceiver_config
+    cache_transceiver_config = cache_transceiver_config
     kv_cache_transceiver = create_kv_cache_transceiver(
         mapping, kv_cache_manager, attention_type, cache_transceiver_config)
     return PyExecutor(
@@ -655,16 +662,17 @@ def create_py_executor_instance(
         max_num_sequences=max_num_sequences,
         disable_overlap_scheduler=pytorch_backend_config.
         disable_overlap_scheduler,
-        max_batch_size=executor_config.max_batch_size,
-        max_beam_width=executor_config.max_beam_width,
+        max_batch_size=max_batch_size,
+        max_beam_width=max_beam_width,
         max_draft_len=spec_config.max_draft_len
         if spec_config is not None else 0,
         kv_cache_transceiver=kv_cache_transceiver,
         guided_decoder=guided_decoder,
         start_worker=start_worker,
         garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
         kv_connector_manager=kv_connector_manager,
-        max_seq_len=max_seq_len)
+        max_seq_len=max_seq_len,
+        peft_cache_config=peft_cache_config)
 
 
 def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,
```
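The hunks above replace every `executor_config.*` read inside `create_py_executor_instance` with the new explicit parameters, and the LoRA path now rebinds the local `peft_cache_config` instead of writing back into `executor_config`. Below is a minimal, self-contained sketch of that sizing pattern; `PeftCacheConfigSketch` and `size_peft_cache` are hypothetical stand-ins for illustration, not the real bindings types:

```python
from dataclasses import dataclass
from typing import Optional


# Simplified stand-in for the pybind PeftCacheConfig; the real type is
# imported above as _PeftCacheConfig from tensorrt_llm.bindings.executor.
@dataclass
class PeftCacheConfigSketch:
    num_device_module_layer: int = 0
    num_host_module_layer: int = 0


def size_peft_cache(peft_cache_config: Optional[PeftCacheConfigSketch],
                    max_lora_rank: int, num_lora_modules: int,
                    max_loras: Optional[int],
                    max_cpu_loras: Optional[int]) -> PeftCacheConfigSketch:
    # Start from the caller-provided config or a fresh default, mirroring the
    # `if peft_cache_config is not None else PeftCacheConfig()` branch above.
    cfg = peft_cache_config if peft_cache_config is not None \
        else PeftCacheConfigSketch()
    if max_loras is not None:
        cfg.num_device_module_layer = max_lora_rank * num_lora_modules * max_loras
    if max_cpu_loras is not None:
        cfg.num_host_module_layer = max_lora_rank * num_lora_modules * max_cpu_loras
    # Returning/rebinding the value replaces the old in-place write to
    # executor_config.peft_cache_config.
    return cfg


sized = size_peft_cache(None, max_lora_rank=64, num_lora_modules=7,
                        max_loras=8, max_cpu_loras=16)
assert sized.num_device_module_layer == 64 * 7 * 8
```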

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 7 additions & 3 deletions
```diff
@@ -24,8 +24,9 @@
                                     is_trace_enabled, nvtx_range, trace_func)
 from tensorrt_llm.bindings.executor import (DisServingRequestStats,
                                             FinishReason, InflightBatchingStats,
-                                            IterationStats, KvCacheStats,
-                                            RequestStage, RequestStats,
+                                            IterationStats, KvCacheStats)
+from tensorrt_llm.bindings.executor import PeftCacheConfig as _PeftCacheConfig
+from tensorrt_llm.bindings.executor import (RequestStage, RequestStats,
                                             SpecDecodingStats,
                                             StaticBatchingStats)
 from tensorrt_llm.bindings.internal.batch_manager import (LlmRequestType,
@@ -157,11 +158,14 @@ def __init__(self,
                  garbage_collection_gen0_threshold: Optional[int] = None,
                  start_worker: bool = True,
                  kv_connector_manager: Optional[KvCacheConnectorManager] = None,
-                 max_seq_len: Optional[int] = None):
+                 max_seq_len: Optional[int] = None,
+                 peft_cache_config: Optional[_PeftCacheConfig] = None):
         super(PyExecutor, self).__init__()
         self.device_id = torch.cuda.current_device()
         self.global_rank = global_mpi_rank()
 
+        self.peft_cache_config = peft_cache_config
+
         # profile config
         self.profile_start_iters, self.profile_stop_iters = _load_iteration_indexes(
             PROFILE_START_STOP_ENV_VAR_NAME)
```
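With no `executor_config` left to mutate, `PyExecutor` now stores the (possibly resized) `peft_cache_config` as an attribute so the creator can read it back after construction. A minimal sketch of that hand-back pattern, using a hypothetical `PyExecutorSketch` stand-in rather than the real class:

```python
from typing import Optional


class PyExecutorSketch:
    """Stand-in for PyExecutor, reduced to the new hand-back attribute."""

    def __init__(self,
                 max_seq_len: Optional[int] = None,
                 peft_cache_config: Optional[object] = None):
        self.max_seq_len = max_seq_len
        # Stored so the creator can propagate a config that was resized for
        # LoRA inside create_py_executor_instance back into executor_config.
        self.peft_cache_config = peft_cache_config


executor = PyExecutorSketch(max_seq_len=4096, peft_cache_config={"sized": True})
assert executor.peft_cache_config == {"sized": True}
```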

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 16 additions & 2 deletions
```diff
@@ -495,7 +495,6 @@ def drafting_loop_wrapper(model):
         resources=resources,
         mapping=mapping,
         pytorch_backend_config=pytorch_backend_config,
-        executor_config=executor_config,
         ctx_chunk_config=ctx_chunk_config,
         model_engine=model_engine,
         start_worker=False,
@@ -507,7 +506,16 @@ def drafting_loop_wrapper(model):
         kv_connector_manager=kv_connector_manager
         if not estimating_kv_cache else None,
         max_seq_len=executor_config.max_seq_len,
+        max_batch_size=executor_config.max_batch_size,
+        max_beam_width=executor_config.max_beam_width,
+        max_num_tokens=executor_config.max_num_tokens,
+        peft_cache_config=executor_config.peft_cache_config,
+        scheduler_config=executor_config.scheduler_config,
+        cache_transceiver_config=executor_config.cache_transceiver_config,
     )
+    # Update executor_config.peft_cache_config, which might have been mutated
+    # inside create_py_executor_instance
+    executor_config.peft_cache_config = py_executor.peft_cache_config
 
     if estimating_kv_cache:
         assert kv_cache_creator is not None
@@ -540,7 +548,6 @@ def drafting_loop_wrapper(model):
         resources=resources,
         mapping=mapping,
         pytorch_backend_config=pytorch_backend_config,
-        executor_config=executor_config,
         ctx_chunk_config=ctx_chunk_config,
         model_engine=model_engine,
         start_worker=False,
@@ -552,6 +559,13 @@ def drafting_loop_wrapper(model):
         garbage_collection_gen0_threshold,
         kv_connector_manager=kv_connector_manager,
         max_seq_len=executor_config.max_seq_len,
+        max_batch_size=executor_config.max_batch_size,
+        max_beam_width=executor_config.max_beam_width,
+        max_num_tokens=executor_config.max_num_tokens,
+        peft_cache_config=executor_config.peft_cache_config,
+        scheduler_config=executor_config.scheduler_config,
+        cache_transceiver_config=executor_config.
+        cache_transceiver_config,
     )
 
     _adjust_torch_mem_fraction(executor_config.pytorch_backend_config)
```
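Taken together, the creator now unpacks the scalar and config fields from `executor_config` at each call site and syncs the possibly-mutated PEFT cache config back afterwards. A hedged sketch of that calling pattern; `build_py_executor` and `create_instance` are illustrative names introduced here, not part of the codebase:

```python
def build_py_executor(executor_config, create_instance, **common_kwargs):
    # Pass individual fields instead of handing over the whole
    # executor_config object.
    py_executor = create_instance(
        max_seq_len=executor_config.max_seq_len,
        max_batch_size=executor_config.max_batch_size,
        max_beam_width=executor_config.max_beam_width,
        max_num_tokens=executor_config.max_num_tokens,
        peft_cache_config=executor_config.peft_cache_config,
        scheduler_config=executor_config.scheduler_config,
        cache_transceiver_config=executor_config.cache_transceiver_config,
        **common_kwargs,
    )
    # The PEFT cache config may have been resized for LoRA inside the
    # factory; write it back so later readers of executor_config agree.
    executor_config.peft_cache_config = py_executor.peft_cache_config
    return py_executor
```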
