diff --git a/exo/inference/torch/model/hf.py b/exo/inference/torch/model/hf.py
index 3b6fefc28..20f17ee1c 100644
--- a/exo/inference/torch/model/hf.py
+++ b/exo/inference/torch/model/hf.py
@@ -92,17 +92,18 @@ def __init__(
         # this is needed because shard downloader just
         # appends and not redownloads the file
         os.remove(self.model_safetensors_path)
+
+        self.llm_model = AutoModelForCausalLM.from_config(self.llm_model_config).to(self.device)
+        self.model = self.llm_model.model.to(self.device)
       else:
-        self.llm_model_config = AutoConfig.from_pretrained(
+        self.llm_model = AutoModelForCausalLM.from_pretrained(
           pretrained_model_name_or_path=self.local_model_path,
           torch_dtype=self.dtype,
           device_map=self.device_map,
           offload_buffers=self.offload_buffers
         )
-
-        self.llm_model = AutoModelForCausalLM.from_config(self.llm_model_config).to(self.device)
-
-        self.model = self.llm_model.model.to(self.device)
+        self.model = self.llm_model.model
+
     except Exception as err:
       print(f"error loading and splitting model: {err}")
       raise