7
7
8
8
from __future__ import annotations
9
9
10
+ import json
10
11
import logging
11
12
import os
12
13
import shutil
13
14
from importlib import resources
14
- from typing import TYPE_CHECKING
15
+ from typing import TYPE_CHECKING , Literal
15
16
16
17
import requests
17
- import yaml
18
+ from huggingface_hub import hf_hub_download
19
+ from pydantic import AnyUrl , BaseModel
18
20
19
21
from fairchem .core import models
20
22
21
23
if TYPE_CHECKING :
22
24
from pathlib import Path
23
25
24
26
25
- with (resources .files (models ) / "pretrained_models.yml" ).open ("rt" ) as f :
26
- MODEL_REGISTRY = yaml .safe_load (f )
27
class HuggingFaceModel(BaseModel):
    """Registry entry for a checkpoint that is fetched from the Hugging Face Hub."""

    # Discriminator value identifying this entry kind in the registry file.
    type: Literal["huggingface_hub"]
    # Only this single repository id passes validation.
    repo_id: Literal["fairchem/OMAT24"]
    # Name of the checkpoint file inside the repository.
    filename: str
27
31
28
32
29
- available_pretrained_models = tuple (MODEL_REGISTRY .keys ())
33
class URLModel(BaseModel):
    """Registry entry for a checkpoint that is downloaded from a plain URL."""

    # Location the checkpoint is downloaded from.
    url: str
    # Discriminator value identifying this entry kind in the registry file.
    type: Literal["url"]
36
+
37
+
38
class ModelRegistry(BaseModel):
    """Validated mapping from pretrained model name to the description of where it lives."""

    # Each value is either a bare URL, a Hugging Face Hub entry, or a URL entry.
    models: dict[str, AnyUrl | HuggingFaceModel | URLModel]
40
+
41
+
42
# Parse and validate the registry file that ships inside the ``models`` package.
_registry_resource = resources.files(models) / "pretrained_models.json"
MODEL_REGISTRY = ModelRegistry(models=json.loads(_registry_resource.read_bytes()))

# Names accepted by ``model_name_to_local_file``, in registry order.
available_pretrained_models = tuple(MODEL_REGISTRY.models)
30
47
31
48
32
49
def _download_url_to_cache(model_url: str, local_cache: str | Path) -> str:
    """Download ``model_url`` into ``local_cache`` unless the file is already there; return its path."""
    if not os.path.exists(local_cache):
        os.makedirs(local_cache, exist_ok=True)
    if not os.path.exists(local_cache):
        message = f"Failed to create local cache folder '{local_cache}'"
        logging.error(message)
        raise RuntimeError(message)

    local_path = os.path.join(local_cache, os.path.basename(model_url))
    if os.path.isfile(local_path):
        return local_path

    # Download to a temporary name so that an interrupted transfer is never
    # mistaken for a complete checkpoint on the next call.
    local_path_tmp = local_path + ".tmp"
    try:
        with requests.get(model_url, stream=True) as response:
            # Without this check an HTTP error page would be written to disk and
            # then served from the cache as if it were the checkpoint.
            response.raise_for_status()
            response.raw.decode_content = True
            with open(local_path_tmp, "wb") as out:
                shutil.copyfileobj(response.raw, out)
        shutil.move(local_path_tmp, local_path)
    except BaseException:
        # Do not leave a partial file behind; the original error is re-raised.
        if os.path.exists(local_path_tmp):
            os.remove(local_path_tmp)
        raise
    return local_path


def model_name_to_local_file(model_name: str, local_cache: str | Path) -> str:
    """Return a local checkpoint path for a registered model, downloading it if needed.

    Args:
        model_name (str): name of the model; must be one of ``available_pretrained_models``
        local_cache (str | Path): directory where URL-hosted checkpoints are stored.
            It is not passed to ``hf_hub_download`` for Hugging Face Hub entries.

    Returns:
        str: local path to checkpoint file

    Raises:
        ValueError: if ``model_name`` is not in the registry.
        RuntimeError: if the local cache folder does not exist after trying to create it.
        requests.HTTPError: if the download URL answers with an error status.
        NotImplementedError: if the registry entry is of an unsupported kind.
    """
    logging.info(f"Checking local cache: {local_cache} for model {model_name}")
    if model_name not in available_pretrained_models:
        logging.error(f"Not a valid model name '{model_name}'")
        raise ValueError(
            f"Not a valid model name '{model_name}'. Model name must be one of {available_pretrained_models}"
        )

    # Look the entry up once instead of re-indexing the registry in every branch.
    entry = MODEL_REGISTRY.models[model_name]

    if isinstance(entry, URLModel):
        return _download_url_to_cache(entry.url, local_cache)
    if isinstance(entry, AnyUrl):
        # The registry schema also accepts a bare URL string as an entry;
        # treat it the same way as an explicit URL entry.
        return _download_url_to_cache(str(entry), local_cache)
    if isinstance(entry, HuggingFaceModel):
        return hf_hub_download(repo_id=entry.repo_id, filename=entry.filename)
    raise NotImplementedError(f"{type(entry)} is an unknown registry type.")
0 commit comments