Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/models/modeling_hyperclovax.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
self.vision_config = self.pretrained_config.vision_config

model_path = self.pretrained_config._name_or_path
self.device = f"cuda:{model_config.mapping.rank}"
self.device = f"cuda:{torch.cuda.current_device()}"

hf_model_config = AutoConfig.from_pretrained(model_path,
trust_remote_code=True)
Expand Down
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/models/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,7 +998,7 @@ def __init__(self, model_config: ModelConfig[Llama4Config], *args,
**kwargs):
super().__init__()
self.pretrained_config = model_config.pretrained_config
self.device = f"cuda:{model_config.mapping.rank}"
self.device = f"cuda:{torch.cuda.current_device()}"

self.dtype = self.pretrained_config.text_config.torch_dtype

Expand Down
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/models/modeling_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
super().__init__()
self.model_config = model_config
self.pretrained_config = model_config.pretrained_config
self.device = f"cuda:{model_config.mapping.rank}"
self.device = f"cuda:{torch.cuda.current_device()}"
model_path = self.pretrained_config._name_or_path

# Determine the actual local path for model files
Expand Down
Loading