33 changes: 18 additions & 15 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -504,7 +504,6 @@ def create_py_executor_instance(
resources,
mapping,
pytorch_backend_config,
executor_config,
ctx_chunk_config,
model_engine,
start_worker,
@@ -515,13 +514,19 @@ def create_py_executor_instance(
garbage_collection_gen0_threshold: Optional[int] = None,
kv_connector_manager: Optional[KvCacheConnectorManager] = None,
max_seq_len: Optional[int] = None,
max_batch_size: Optional[int] = None,
max_beam_width: Optional[int] = None,
max_num_tokens: Optional[int] = None,
peft_cache_config: Optional[trtllm.PeftCacheConfig] = None,
scheduler_config: Optional[trtllm.SchedulerConfig] = None,
cache_transceiver_config: Optional[trtllm.CacheTransceiverConfig] = None,
) -> PyExecutor:
kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)

spec_config = model_engine.spec_config

logger.info(
f"max_seq_len={executor_config.max_seq_len}, max_num_requests={executor_config.max_batch_size}, max_num_tokens={executor_config.max_num_tokens}, max_batch_size={executor_config.max_batch_size}"
f"max_seq_len={max_seq_len}, max_num_requests={max_batch_size}, max_num_tokens={max_num_tokens}, max_batch_size={max_batch_size}"
)

for key, value in pytorch_backend_config.extra_resource_managers.items():
@@ -578,16 +583,15 @@ def create_py_executor_instance(
len(lora_config.lora_target_modules + lora_config.missing_qkv_modules)

peft_cache_config_model = PeftCacheConfig.from_pybind(
executor_config.peft_cache_config
) if executor_config.peft_cache_config is not None else PeftCacheConfig(
)
peft_cache_config
) if peft_cache_config is not None else PeftCacheConfig()
if lora_config.max_loras is not None:
peft_cache_config_model.num_device_module_layer = \
max_lora_rank * num_lora_modules * lora_config.max_loras
if lora_config.max_cpu_loras is not None:
peft_cache_config_model.num_host_module_layer = \
max_lora_rank * num_lora_modules * lora_config.max_cpu_loras
executor_config.peft_cache_config = peft_cache_config_model._to_pybind()
peft_cache_config = peft_cache_config_model._to_pybind()

from tensorrt_llm.bindings import WorldConfig
world_config = WorldConfig(
@@ -598,7 +602,7 @@ def create_py_executor_instance(
gpus_per_node=dist.mapping.gpus_per_node,
)
peft_cache_manager = PeftCacheManager(
peft_cache_config=executor_config.peft_cache_config,
peft_cache_config=peft_cache_config,
lora_config=lora_config,
model_config=model_binding_config,
world_config=world_config,
@@ -609,7 +613,7 @@ def create_py_executor_instance(
lora_config.trtllm_modules_to_hf_modules,
lora_config.swap_gate_up_proj_lora_b_weight)

max_num_sequences = executor_config.max_batch_size * mapping.pp_size
max_num_sequences = max_batch_size * mapping.pp_size

resources[ResourceManagerType.SEQ_SLOT_MANAGER] = SeqSlotManager(
max_num_sequences)
@@ -632,17 +636,15 @@ def create_py_executor_instance(
scheduler_capacity,
kv_cache_manager.impl if kv_cache_manager is not None else None,
peft_cache_manager.impl if peft_cache_manager is not None else None,
executor_config.scheduler_config.capacity_scheduler_policy,
scheduler_config.capacity_scheduler_policy,
two_step_lookahead=mapping.has_pp())
mb_scheduler = BindMicroBatchScheduler(executor_config.max_batch_size,
executor_config.max_num_tokens,
mb_scheduler = BindMicroBatchScheduler(max_batch_size, max_num_tokens,
ctx_chunk_config)
scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler)

config = model_engine.model.model_config.pretrained_config
attention_type = AttentionTypeCpp.MLA if is_mla(
config) else AttentionTypeCpp.DEFAULT
cache_transceiver_config = executor_config.cache_transceiver_config
kv_cache_transceiver = create_kv_cache_transceiver(
mapping, kv_cache_manager, attention_type, cache_transceiver_config)
return PyExecutor(
@@ -655,16 +657,17 @@ def create_py_executor_instance(
max_num_sequences=max_num_sequences,
disable_overlap_scheduler=pytorch_backend_config.
disable_overlap_scheduler,
max_batch_size=executor_config.max_batch_size,
max_beam_width=executor_config.max_beam_width,
max_batch_size=max_batch_size,
max_beam_width=max_beam_width,
max_draft_len=spec_config.max_draft_len
if spec_config is not None else 0,
kv_cache_transceiver=kv_cache_transceiver,
guided_decoder=guided_decoder,
start_worker=start_worker,
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
kv_connector_manager=kv_connector_manager,
max_seq_len=max_seq_len)
max_seq_len=max_seq_len,
peft_cache_config=peft_cache_config)


def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,
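Aside for reviewers: the PEFT cache sizing re-derived above from lora_config is plain arithmetic, so moving from executor_config to an explicit peft_cache_config argument does not change the computed values. A minimal, self-contained sketch with hypothetical numbers (the real values come from lora_config and the model):

    # Hypothetical inputs; in _util.py these come from lora_config and the model.
    max_lora_rank = 64
    num_lora_modules = 7     # len(lora_target_modules + missing_qkv_modules)
    max_loras = 4            # lora_config.max_loras
    max_cpu_loras = 16       # lora_config.max_cpu_loras

    num_device_module_layer = max_lora_rank * num_lora_modules * max_loras    # 1792
    num_host_module_layer = max_lora_rank * num_lora_modules * max_cpu_loras  # 7168
    print(num_device_module_layer, num_host_module_layer)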
9 changes: 6 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -25,8 +25,8 @@
from tensorrt_llm.bindings.executor import (DisServingRequestStats,
FinishReason, InflightBatchingStats,
IterationStats, KvCacheStats,
RequestStage, RequestStats,
SpecDecodingStats,
PeftCacheConfig, RequestStage,
RequestStats, SpecDecodingStats,
StaticBatchingStats)
from tensorrt_llm.bindings.internal.batch_manager import (LlmRequestType,
ReqIdsSet)
@@ -157,11 +157,14 @@ def __init__(self,
garbage_collection_gen0_threshold: Optional[int] = None,
start_worker: bool = True,
kv_connector_manager: Optional[KvCacheConnectorManager] = None,
max_seq_len: Optional[int] = None):
max_seq_len: Optional[int] = None,
peft_cache_config: Optional[PeftCacheConfig] = None):
super(PyExecutor, self).__init__()
self.device_id = torch.cuda.current_device()
self.global_rank = global_mpi_rank()

self.peft_cache_config = peft_cache_config

# profile config
self.profile_start_iters, self.profile_stop_iters = _load_iteration_indexes(
PROFILE_START_STOP_ENV_VAR_NAME)
19 changes: 16 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -226,7 +226,7 @@ def create_py_executor(
mapping = _get_mapping(executor_config)

dist = MPIDist(mapping=mapping)

cache_transceiver_config = executor_config.cache_transceiver_config
spec_config = executor_config.speculative_config
has_draft_model_engine = False
has_spec_drafter = False
@@ -508,7 +508,6 @@ def drafting_loop_wrapper(model):
resources=resources,
mapping=mapping,
pytorch_backend_config=pytorch_backend_config,
executor_config=executor_config,
ctx_chunk_config=ctx_chunk_config,
model_engine=model_engine,
start_worker=False,
@@ -520,7 +519,16 @@ def drafting_loop_wrapper(model):
kv_connector_manager=kv_connector_manager
if not estimating_kv_cache else None,
max_seq_len=executor_config.max_seq_len,
max_batch_size=executor_config.max_batch_size,
max_beam_width=executor_config.max_beam_width,
max_num_tokens=executor_config.max_num_tokens,
peft_cache_config=executor_config.peft_cache_config,
scheduler_config=executor_config.scheduler_config,
cache_transceiver_config=cache_transceiver_config,
)
# Update executor_config.peft_cache_config with the value that may have been
# mutated inside create_py_executor_instance.
executor_config.peft_cache_config = py_executor.peft_cache_config

if estimating_kv_cache:
assert kv_cache_creator is not None
@@ -553,7 +561,6 @@ def drafting_loop_wrapper(model):
resources=resources,
mapping=mapping,
pytorch_backend_config=pytorch_backend_config,
executor_config=executor_config,
ctx_chunk_config=ctx_chunk_config,
model_engine=model_engine,
start_worker=False,
@@ -565,6 +572,12 @@ def drafting_loop_wrapper(model):
garbage_collection_gen0_threshold,
kv_connector_manager=kv_connector_manager,
max_seq_len=executor_config.max_seq_len,
max_batch_size=executor_config.max_batch_size,
max_beam_width=executor_config.max_beam_width,
max_num_tokens=executor_config.max_num_tokens,
peft_cache_config=executor_config.peft_cache_config,
scheduler_config=executor_config.scheduler_config,
cache_transceiver_config=cache_transceiver_config,
)

_adjust_torch_mem_fraction(executor_config.pytorch_backend_config)
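Aside for reviewers: the comment added above in create_py_executor explains why the caller re-reads peft_cache_config from the returned executor rather than trusting its own copy. A toy, self-contained sketch of that read-back pattern (every name below is a hypothetical stand-in, not a real TensorRT-LLM type):

    # Toy illustration only; FakePeftCacheConfig, FakeExecutor and
    # fake_create_executor are hypothetical stand-ins.
    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class FakePeftCacheConfig:
        num_device_module_layer: int = 0


    @dataclass
    class FakeExecutor:
        peft_cache_config: Optional[FakePeftCacheConfig]


    def fake_create_executor(
            peft_cache_config: Optional[FakePeftCacheConfig]) -> FakeExecutor:
        # Like create_py_executor_instance, the factory may replace a missing
        # config and write derived sizes into its own copy, so the caller's
        # variable is not updated in place.
        cfg = peft_cache_config if peft_cache_config is not None else FakePeftCacheConfig()
        cfg.num_device_module_layer = 64 * 7 * 4
        return FakeExecutor(peft_cache_config=cfg)


    caller_cfg = None
    executor = fake_create_executor(caller_cfg)
    # caller_cfg is still None; the final config must be read back from the
    # executor, which is what the new assignment in create_py_executor does.
    caller_cfg = executor.peft_cache_config
    print(caller_cfg.num_device_module_layer)  # 1792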