Skip to content

Commit ab258e8

Browse files
committed
[https://nvbugs/5549081][fix] Fix device id assignment for some vision models (NVIDIA#8070)
Signed-off-by: Chang Liu (Enterprise Products) <[email protected]> Signed-off-by: Chang Liu <[email protected]>
1 parent 0acdecb commit ab258e8

File tree

3 files changed

+6
-3
lines changed

3 files changed

+6
-3
lines changed

tensorrt_llm/_torch/models/modeling_hyperclovax.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -726,7 +726,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
726726
self.vision_config = self.pretrained_config.vision_config
727727

728728
model_path = self.pretrained_config._name_or_path
729-
self.device = f"cuda:{model_config.mapping.rank}"
729+
# TODO: use config.mapping.get_local_rank() instead
730+
self.device = f"cuda:{torch.cuda.current_device()}"
730731

731732
hf_model_config = AutoConfig.from_pretrained(model_path,
732733
trust_remote_code=True)

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -999,7 +999,8 @@ def __init__(self, model_config: ModelConfig[Llama4Config], *args,
999999
**kwargs):
10001000
super().__init__()
10011001
self.pretrained_config = model_config.pretrained_config
1002-
self.device = f"cuda:{model_config.mapping.rank}"
1002+
# TODO: use config.mapping.get_local_rank() instead
1003+
self.device = f"cuda:{torch.cuda.current_device()}"
10031004

10041005
self.dtype = self.pretrained_config.text_config.torch_dtype
10051006

tensorrt_llm/_torch/models/modeling_llava_next.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
295295
super().__init__()
296296
self.model_config = model_config
297297
self.pretrained_config = model_config.pretrained_config
298-
self.device = f"cuda:{model_config.mapping.rank}"
298+
# TODO: use config.mapping.get_local_rank() instead
299+
self.device = f"cuda:{torch.cuda.current_device()}"
299300
model_path = self.pretrained_config._name_or_path
300301

301302
# Determine the actual local path for model files

0 commit comments

Comments
 (0)