Commit dae55c6

fix autodeploy ci failure
Signed-off-by: leslie-fang25 <[email protected]>
1 parent 872f1b5 commit dae55c6

3 files changed: +80 -83 lines changed

tensorrt_llm/executor/worker.py

Lines changed: 7 additions & 15 deletions
@@ -124,6 +124,10 @@ def _create_py_executor(executor_config):
                 args["lora_config"] = lora_config
                 args[
                     "garbage_collection_gen0_threshold"] = llm_args.garbage_collection_gen0_threshold
+            elif executor_config.backend == "_autodeploy":
+                from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
+                    create_autodeploy_executor
+                create_executor = create_autodeploy_executor
             else:
                 raise ValueError(
                     f"Unsupported backend config: {executor_config.backend}")
@@ -146,21 +150,9 @@ def _create_engine(executor_config):
                 executor_config=executor_config,
                 managed_weights=engine.managed_weights)
 
-            if not hasattr(executor_config, "backend"):
-                return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
-                                     executor_config)
-            args = {
-                "executor_config": executor_config,
-                "checkpoint_dir": executor_config.hf_model_dir,
-            }
-            if executor_config.backend == "_autodeploy":
-                from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
-                    create_autodeploy_executor
-                create_executor = create_autodeploy_executor
-            else:
-                raise ValueError(
-                    f"Unsupported backend config: {executor_config.backend}")
-            return create_executor(**args)
+            assert not hasattr(executor_config, "backend")
+            return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
+                                 executor_config)
 
         self.engine = _create_py_executor(
             executor_config) if llm_args is not None else _create_engine(

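Taken together, the two hunks move auto-deploy executor creation onto the py-executor path: _create_py_executor (the branch taken when llm_args is provided) now imports create_autodeploy_executor for the "_autodeploy" backend, while _create_engine handles only the plain TensorRT engine case and asserts that no custom backend is set. The sketch below condenses that dispatch into one standalone function; it is illustrative only, the pytorch branch is elided, and the keyword arguments given to create_autodeploy_executor mirror the removed _create_engine code rather than the full worker implementation.

# Illustrative sketch, not the actual worker.py code. The pytorch branch of
# _create_py_executor is elided, and the kwargs handed to
# create_autodeploy_executor follow the removed _create_engine block above.
from tensorrt_llm.bindings import executor as tllm  # executor bindings; worker.py refers to them as tllm


def create_executor_for(executor_config, llm_args=None, engine=None):
    if llm_args is not None:
        # _create_py_executor path: pytorch and, after this commit, _autodeploy.
        if executor_config.backend == "_autodeploy":
            from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
                create_autodeploy_executor
            return create_autodeploy_executor(
                executor_config=executor_config,
                checkpoint_dir=executor_config.hf_model_dir)
        raise ValueError(
            f"Unsupported backend config: {executor_config.backend}")
    # _create_engine path: no custom backend is expected here any more.
    assert not hasattr(executor_config, "backend")
    return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY, executor_config)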
tensorrt_llm/llmapi/llm.py

Lines changed: 0 additions & 1 deletion
@@ -957,7 +957,6 @@ def _build_model(self):
                                              self.tokenizer)
         self._tokenizer = self.input_processor.tokenizer
 
-        assert isinstance(self.args, TorchLlmArgs)
         # Update the tokenizer in TorchLlmArgs, so it can be used in GenerationExecutorWorker to init executor_config
         self.args.set_tokenizer(self.tokenizer)

tensorrt_llm/llmapi/llm_args.py

Lines changed: 73 additions & 67 deletions
@@ -1825,6 +1825,77 @@ def _load_config_from_ckpt(self, ckpt_dir: Path):
     def set_tokenizer(self, tokenizer):
         self.tokenizer = tokenizer
 
+    def get_executor_config(self,
+                            _hf_model_dir: Optional[Path] = None
+                            ) -> _ExecutorConfig:
+        executor_config = _ExecutorConfig(
+            max_beam_width=self.max_beam_width,
+            scheduler_config=PybindMirror.maybe_to_pybind(
+                self.scheduler_config),
+            max_batch_size=self.max_batch_size,
+            max_num_tokens=self.max_num_tokens,
+            gather_generation_logits=self.gather_generation_logits,
+            fail_fast_on_attention_window_too_large=getattr(
+                self, 'fail_fast_on_attention_window_too_large', False),
+        )
+
+        if self.kv_cache_config is not None:
+            executor_config.kv_cache_config = PybindMirror.maybe_to_pybind(
+                self.kv_cache_config)
+        if os.getenv("FORCE_DETERMINISTIC", "0") == "1":
+            # Disable KV cache reuse for deterministic mode
+            executor_config.kv_cache_config.enable_block_reuse = False
+            executor_config.kv_cache_config.enable_partial_reuse = False
+        if self.peft_cache_config is not None:
+            executor_config.peft_cache_config = PybindMirror.maybe_to_pybind(
+                self.peft_cache_config)
+        if self.decoding_config is not None:
+            executor_config.decoding_config = self.decoding_config
+        if self.guided_decoding_backend == 'xgrammar':
+            executor_config.guided_decoding_config = _GuidedDecodingConfig(
+                backend=_GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR,
+                **_xgrammar_tokenizer_info(self.tokenizer))
+        elif self.guided_decoding_backend == 'llguidance':
+            executor_config.guided_decoding_config = _GuidedDecodingConfig(
+                backend=_GuidedDecodingConfig.GuidedDecodingBackend.LLGUIDANCE,
+                **_llguidance_tokenizer_info(self.tokenizer))
+        elif self.guided_decoding_backend is not None:
+            raise ValueError(
+                f"Unsupported guided decoding backend {self.guided_decoding_backend}"
+            )
+
+        executor_config.enable_chunked_context = self.enable_chunked_prefill
+        executor_config.max_beam_width = self.max_beam_width
+        if self.cache_transceiver_config is not None:
+            executor_config.cache_transceiver_config = PybindMirror.maybe_to_pybind(
+                self.cache_transceiver_config)
+
+        from tensorrt_llm._torch.pyexecutor.config import update_executor_config
+
+        spec_config = self.speculative_config
+        max_batch_size = executor_config.max_batch_size
+
+        if spec_config is not None and spec_config.decoding_type == "AUTO":
+            from tensorrt_llm._torch.speculative import suggest_spec_config
+            spec_config = suggest_spec_config(max_batch_size)
+
+        update_executor_config(
+            executor_config,
+            backend=self.backend,
+            pytorch_backend_config=self.get_pytorch_backend_config()
+            if self.backend in ["pytorch", "_autodeploy"] else None,
+            mapping=self.parallel_config.to_mapping(),
+            speculative_config=spec_config,
+            hf_model_dir=_hf_model_dir,
+            max_input_len=self.max_input_len,
+            max_seq_len=self.max_seq_len,
+            checkpoint_format=None
+            if self.backend == "_autodeploy" else self.checkpoint_format,
+            checkpoint_loader=None
+            if self.backend == "_autodeploy" else self.checkpoint_loader)
+
+        return executor_config
+
 
 class TrtLlmArgs(BaseLlmArgs):
 
@@ -2377,73 +2448,8 @@ def set_mm_encoder_only(self, mm_encoder_only):
     def get_executor_config(self,
                             _hf_model_dir: Optional[Path] = None
                             ) -> _ExecutorConfig:
-        executor_config = _ExecutorConfig(
-            max_beam_width=self.max_beam_width,
-            scheduler_config=PybindMirror.maybe_to_pybind(
-                self.scheduler_config),
-            max_batch_size=self.max_batch_size,
-            max_num_tokens=self.max_num_tokens,
-            gather_generation_logits=self.gather_generation_logits,
-            fail_fast_on_attention_window_too_large=getattr(
-                self, 'fail_fast_on_attention_window_too_large', False),
-        )
-
-        if self.kv_cache_config is not None:
-            executor_config.kv_cache_config = PybindMirror.maybe_to_pybind(
-                self.kv_cache_config)
-        if os.getenv("FORCE_DETERMINISTIC", "0") == "1":
-            # Disable KV cache reuse for deterministic mode
-            executor_config.kv_cache_config.enable_block_reuse = False
-            executor_config.kv_cache_config.enable_partial_reuse = False
-        if self.peft_cache_config is not None:
-            executor_config.peft_cache_config = PybindMirror.maybe_to_pybind(
-                self.peft_cache_config)
-        if self.decoding_config is not None:
-            executor_config.decoding_config = self.decoding_config
-        if self.guided_decoding_backend == 'xgrammar':
-            executor_config.guided_decoding_config = _GuidedDecodingConfig(
-                backend=_GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR,
-                **_xgrammar_tokenizer_info(self.tokenizer))
-        elif self.guided_decoding_backend == 'llguidance':
-            executor_config.guided_decoding_config = _GuidedDecodingConfig(
-                backend=_GuidedDecodingConfig.GuidedDecodingBackend.LLGUIDANCE,
-                **_llguidance_tokenizer_info(self.tokenizer))
-        elif self.guided_decoding_backend is not None:
-            raise ValueError(
-                f"Unsupported guided decoding backend {self.guided_decoding_backend}"
-            )
-
-        executor_config.enable_chunked_context = self.enable_chunked_prefill
-        executor_config.max_beam_width = self.max_beam_width
-        if self.cache_transceiver_config is not None:
-            executor_config.cache_transceiver_config = PybindMirror.maybe_to_pybind(
-                self.cache_transceiver_config)
-
-        from tensorrt_llm._torch.pyexecutor.config import update_executor_config
-
-        spec_config = self.speculative_config
-        max_batch_size = executor_config.max_batch_size
-
-        if spec_config is not None and spec_config.decoding_type == "AUTO":
-            from tensorrt_llm._torch.speculative import suggest_spec_config
-            spec_config = suggest_spec_config(max_batch_size)
-
-        update_executor_config(
-            executor_config,
-            backend=self.backend,
-            pytorch_backend_config=self.get_pytorch_backend_config()
-            if self.backend in ["pytorch", "_autodeploy"] else None,
-            mapping=self.parallel_config.to_mapping(),
-            speculative_config=spec_config,
-            hf_model_dir=_hf_model_dir,
-            max_input_len=self.max_input_len,
-            max_seq_len=self.max_seq_len,
-            checkpoint_format=None
-            if self.backend == "_autodeploy" else self.checkpoint_format,
-            checkpoint_loader=None
-            if self.backend == "_autodeploy" else self.checkpoint_loader,
-            mm_encoder_only=self.mm_encoder_only)
-
+        executor_config = super().get_executor_config(_hf_model_dir)
+        executor_config.mm_encoder_only = self.mm_encoder_only
         return executor_config
 
     # TODO: Remove this after the PyTorch backend is fully migrated to TorchLlmArgs from ExecutorConfig

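The net effect in llm_args.py is that the shared get_executor_config construction now lives on BaseLlmArgs, and the TorchLlmArgs override shrinks to a super() call plus the pytorch-only mm_encoder_only flag. Below is a rough, self-contained sketch of that layout; SimpleNamespace stands in for the real _ExecutorConfig, and the placeholder attributes and defaults are invented for illustration.

# Rough sketch of the resulting class layout; most fields and the real
# _ExecutorConfig construction are elided, and the defaults are hypothetical.
from pathlib import Path
from types import SimpleNamespace
from typing import Optional


class BaseLlmArgs:
    max_beam_width = 1  # placeholder for the many shared knobs

    def get_executor_config(self, _hf_model_dir: Optional[Path] = None):
        # In llm_args.py this builds _ExecutorConfig and applies the kv-cache,
        # guided-decoding, chunked-prefill, and speculative settings shown above.
        return SimpleNamespace(max_beam_width=self.max_beam_width)


class TorchLlmArgs(BaseLlmArgs):
    mm_encoder_only = False  # placeholder default

    def get_executor_config(self, _hf_model_dir: Optional[Path] = None):
        # Reuse the shared base-class logic, then layer on the pytorch-only flag.
        executor_config = super().get_executor_config(_hf_model_dir)
        executor_config.mm_encoder_only = self.mm_encoder_only
        return executor_config


print(TorchLlmArgs().get_executor_config().mm_encoder_only)  # -> False

Keeping the construction in the base class presumably lets other BaseLlmArgs subclasses, including the args used by the auto-deploy backend, build an executor config through the same code path, which also fits the removal of the isinstance(self.args, TorchLlmArgs) assert in llm.py above.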