diff --git a/tensorrt_llm/_torch/models/modeling_hyperclovax.py b/tensorrt_llm/_torch/models/modeling_hyperclovax.py index 975ccdb26ed..3237ad296d8 100644 --- a/tensorrt_llm/_torch/models/modeling_hyperclovax.py +++ b/tensorrt_llm/_torch/models/modeling_hyperclovax.py @@ -726,7 +726,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]): self.vision_config = self.pretrained_config.vision_config model_path = self.pretrained_config._name_or_path - self.device = f"cuda:{model_config.mapping.rank}" + # TODO: use config.mapping.get_local_rank() instead + self.device = f"cuda:{torch.cuda.current_device()}" hf_model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index e4ef4b2dbb5..fbac3f00365 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -999,7 +999,8 @@ def __init__(self, model_config: ModelConfig[Llama4Config], *args, **kwargs): super().__init__() self.pretrained_config = model_config.pretrained_config - self.device = f"cuda:{model_config.mapping.rank}" + # TODO: use config.mapping.get_local_rank() instead + self.device = f"cuda:{torch.cuda.current_device()}" self.dtype = self.pretrained_config.text_config.torch_dtype diff --git a/tensorrt_llm/_torch/models/modeling_llava_next.py b/tensorrt_llm/_torch/models/modeling_llava_next.py index 9b4d7eaac6b..e23fc75d53f 100644 --- a/tensorrt_llm/_torch/models/modeling_llava_next.py +++ b/tensorrt_llm/_torch/models/modeling_llava_next.py @@ -295,7 +295,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args, super().__init__() self.model_config = model_config self.pretrained_config = model_config.pretrained_config - self.device = f"cuda:{model_config.mapping.rank}" + # TODO: use config.mapping.get_local_rank() instead + self.device = f"cuda:{torch.cuda.current_device()}" model_path = self.pretrained_config._name_or_path # Determine the actual local path for model files