
Commit 7ab8112

[None][fix] Refactoring to avoid circular import when importing torch models (#6720)
Signed-off-by: Rakib Hasan <[email protected]>
1 parent c9fe07e commit 7ab8112

33 files changed: +159 -105 lines
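
The refactor follows the standard way of breaking an import cycle in Python: the definitions that many modules need but that themselves need nothing heavy (`LoraConfig`, `get_default_trtllm_modules_to_hf_modules`) move into a small leaf module, `tensorrt_llm.lora_helper`, while `tensorrt_llm.lora_manager` keeps the machinery that pulls in the rest of the package (`LoraManager`, `LoraModelConfig`, `load_torch_lora`). Both the torch model code and `lora_manager` can then import the shared pieces without importing each other. Below is a minimal sketch of that layout; the class fields and the module mapping are placeholders, not the actual TensorRT-LLM definitions.

```python
# lora_helper.py (sketch) -- leaf module: standard-library imports only, so any
# other module may import it without risking a cycle. Contents are illustrative.
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class LoraConfig:
    lora_dir: List[str] = field(default_factory=list)
    lora_target_modules: List[str] = field(default_factory=list)
    max_lora_rank: int = 64


def get_default_trtllm_modules_to_hf_modules() -> Dict[str, str]:
    # Placeholder mapping of TRT-LLM module names to HF module names.
    return {"attn_q": "q_proj", "attn_k": "k_proj", "attn_v": "v_proj"}


# lora_manager.py (sketch) keeps the heavyweight logic and re-imports the
# shared types from the leaf module instead of defining them itself:
#
#     from tensorrt_llm.lora_helper import (LoraConfig,
#                                           get_default_trtllm_modules_to_hf_modules)
```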

docs/source/torch/features/lora.md

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ The PyTorch backend provides LoRA support, allowing you to:
 
 ```python
 from tensorrt_llm import LLM
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.executor.request import LoRARequest
 from tensorrt_llm.sampling_params import SamplingParams

examples/llm-api/llm_multilora.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 
 from tensorrt_llm import LLM
 from tensorrt_llm.executor import LoRARequest
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 
 
 def main():
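
For user code such as this example, the only visible change is the import path for `LoraConfig`. A hedged usage sketch of the new path follows; the constructor fields and the `lora_config` argument are assumptions based on the LoRA documentation rather than this diff, and the paths are placeholders.

```python
from tensorrt_llm import LLM
from tensorrt_llm.lora_helper import LoraConfig  # previously tensorrt_llm.lora_manager

# Placeholder adapter directory and settings -- adjust to your checkpoints.
lora_config = LoraConfig(lora_dir=["/path/to/lora_adapter"], max_lora_rank=64)
llm = LLM(model="/path/to/base_model", lora_config=lora_config)
```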

tensorrt_llm/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,7 @@ def _add_trt_llm_dll_directory():
 # otherwise `MemoryError: std::bad_alloc` pattern error will be raised.
 import xgrammar  # noqa
 
+import tensorrt_llm._torch.models as torch_models
 import tensorrt_llm.functional as functional
 import tensorrt_llm.math_utils as math_utils
 import tensorrt_llm.models as models
@@ -82,6 +83,7 @@ def _add_trt_llm_dll_directory():
     'default_trtnet',
     'precision',
     'net_guard',
+    'torch_models',
     'Network',
     'Mapping',
     'MnnvlMemory',
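
With `tensorrt_llm._torch.models` imported eagerly in `__init__.py` and re-exported as `torch_models`, the torch model definitions are loaded once, early, as part of `import tensorrt_llm`, in line with the commit's goal of avoiding cycles when importing torch models. A small usage check, assuming an installed build that includes this commit:

```python
import tensorrt_llm

# The submodule is now bound on the package under a stable alias and listed
# in its exports, so both access styles resolve to the same module object.
print(tensorrt_llm.torch_models)

from tensorrt_llm import torch_models  # noqa: E402
assert torch_models is tensorrt_llm.torch_models
```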

tensorrt_llm/_torch/models/modeling_phi4mm.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@
 from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
                        register_input_processor)
 from ...logger import logger
-from ...lora_manager import LoraConfig
+from ...lora_helper import LoraConfig
 from ...sampling_params import SamplingParams
 from ..attention_backend import AttentionMetadata
 from ..model_config import ModelConfig

tensorrt_llm/_torch/modules/fused_moe/quantization.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 import torch.nn.functional as F
 from torch import nn
 
-from tensorrt_llm import logger
+import tensorrt_llm.logger as trtllm_logger
 from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.quantization.utils.fp4_utils import (
     float4_sf_dtype, get_reorder_rows_for_gated_act_gemm_row_indices,
@@ -743,7 +743,7 @@ def load_weights(self, module: torch.nn.Module, weights: List[Dict],
         if int(name.split(".")[0]) not in expert_ids:
             continue
         weight_name = name.replace("weight_scale_inv", "weight")
-        logger.debug(f"Resmoothing {weight_name}")
+        trtllm_logger.logger.debug(f"Resmoothing {weight_name}")
         weight = weights[weight_name][:]
         scale = weights[name][:]
         weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
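
The logger change in this file swaps an attribute import for a submodule import: `from tensorrt_llm import logger` needs the `logger` attribute to already exist on the (possibly half-initialized) package at import time, whereas `import tensorrt_llm.logger as trtllm_logger` binds the submodule and defers the `trtllm_logger.logger` lookup to the moment `debug()` is called, which is the more cycle-tolerant form. A minimal sketch of the adopted call pattern, assuming an installed TensorRT-LLM; the weight name is made up:

```python
import tensorrt_llm.logger as trtllm_logger


def log_resmooth(weight_name: str) -> None:
    # Attribute lookup on the submodule happens here, at call time, not while
    # the tensorrt_llm package is still initializing.
    trtllm_logger.logger.debug(f"Resmoothing {weight_name}")


log_resmooth("0.w1.weight")  # hypothetical weight name for illustration
```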

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 3 additions & 3 deletions
@@ -13,9 +13,9 @@
 from tensorrt_llm.bindings.executor import DecodingMode, ExecutorConfig
 from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
 from tensorrt_llm.logger import logger
-from tensorrt_llm.lora_manager import (LoraConfig,
-                                       get_default_trtllm_modules_to_hf_modules,
-                                       load_torch_lora)
+from tensorrt_llm.lora_helper import (LoraConfig,
+                                      get_default_trtllm_modules_to_hf_modules)
+from tensorrt_llm.lora_manager import load_torch_lora
 from tensorrt_llm.mapping import Mapping
 
 from ..model_config import ModelConfig

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 2 additions & 1 deletion
@@ -27,7 +27,8 @@
 from tensorrt_llm.inputs.multimodal import (MultimodalParams,
                                             MultimodalRuntimeData)
 from tensorrt_llm.logger import logger
-from tensorrt_llm.lora_manager import LoraConfig, LoraModelConfig
+from tensorrt_llm.lora_helper import LoraConfig
+from tensorrt_llm.lora_manager import LoraModelConfig
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantAlgo
 from tensorrt_llm.quantization.utils.fp4_utils import float4_e2m1x2

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 from tensorrt_llm.bindings.executor import ContextChunkingPolicy, ExecutorConfig
 from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig
 from tensorrt_llm.logger import logger
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.quantization import QuantAlgo

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,8 @@
 import tensorrt_llm
 import tensorrt_llm.bindings
 from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
-from tensorrt_llm.lora_manager import LoraConfig, LoraManager, LoraModelConfig
+from tensorrt_llm.lora_helper import LoraConfig
+from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig
 from tensorrt_llm.sampling_params import SamplingParams
 
 from ..._utils import binding_dtype_size, binding_to_str_dtype, nvtx_range

tensorrt_llm/builder.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@
 from .functional import PositionEmbeddingType
 from .graph_rewriting import optimize
 from .logger import logger
-from .lora_manager import LoraConfig
+from .lora_helper import LoraConfig
 from .models import PretrainedConfig, PretrainedModel
 from .models.modeling_utils import SpeculativeDecodingMode, optimize_model
 from .network import Network, net_guard
