Skip to content

Commit 6dc2238

Browse files
zulissimetalbluque
andauthored
Hf model name download (#1048)
* first commit * add test * typo * typo * ruff * add hf readonly fg token for tests * address review * raise error if type not recognized * add pydantic to reqs * typo * correctly revert --------- Co-authored-by: lbluque <[email protected]>
1 parent 44c6f32 commit 6dc2238

File tree

9 files changed

+385
-92
lines changed

9 files changed

+385
-92
lines changed

.github/workflows/test.yml

+2
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ jobs:
5353
echo "$(readlink -f .)" >> $GITHUB_PATH
5454
5555
- name: Test core with pytest
56+
env:
57+
HF_TOKEN: ${{ secrets.HF_TOKEN_OMAT_READONLY }}
5658
run: |
5759
pytest tests -vv --ignore=tests/demo/ocpapi/tests/integration/ --cov-report=xml --cov=fairchem -c ./packages/fairchem-core/pyproject.toml
5860

packages/env.cpu.yml

+1
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,5 @@ dependencies:
2626
- submitit
2727
- tensorboard
2828
- wandb
29+
- huggingface_hub
2930
name: fair-chem

packages/env.gpu.yml

+1
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,5 @@ dependencies:
2929
- submitit
3030
- tensorboard
3131
- wandb
32+
- huggingface_hub
3233
name: fair-chem

packages/fairchem-core/pyproject.toml

+4-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "fairchem-core"
7-
description = "Machine learning models for use in catalysis as part of the Open Catalyst Project"
7+
description = "Machine learning models for chemistry and materials science by the FAIR Chemistry team"
88
license = {text = "MIT License"}
99
dynamic = ["version", "readme"]
1010
requires-python = ">=3.9, <3.13"
@@ -24,7 +24,9 @@ dependencies = [
2424
"tqdm",
2525
"submitit",
2626
"hydra-core",
27-
"torchtnt"
27+
"torchtnt",
28+
"huggingface_hub>=0.29.2",
29+
"pydantic>=2.10.0"
2830
]
2931

3032
[project.optional-dependencies] # add optional dependencies to be installed as pip install fairchem.core[dev]

packages/requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
torch==2.4.1
22
numpy==1.26.4
33
ase==3.24.0
4+
huggingface_hub==0.29.2

src/fairchem/core/models/model_registry.py

+55-23
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,43 @@
77

88
from __future__ import annotations
99

10+
import json
1011
import logging
1112
import os
1213
import shutil
1314
from importlib import resources
14-
from typing import TYPE_CHECKING
15+
from typing import TYPE_CHECKING, Literal
1516

1617
import requests
17-
import yaml
18+
from huggingface_hub import hf_hub_download
19+
from pydantic import AnyUrl, BaseModel
1820

1921
from fairchem.core import models
2022

2123
if TYPE_CHECKING:
2224
from pathlib import Path
2325

2426

25-
with (resources.files(models) / "pretrained_models.yml").open("rt") as f:
26-
MODEL_REGISTRY = yaml.safe_load(f)
27+
class HuggingFaceModel(BaseModel):
28+
type: Literal["huggingface_hub"]
29+
repo_id: Literal["fairchem/OMAT24"]
30+
filename: str
2731

2832

29-
available_pretrained_models = tuple(MODEL_REGISTRY.keys())
33+
class URLModel(BaseModel):
34+
url: str
35+
type: Literal["url"]
36+
37+
38+
class ModelRegistry(BaseModel):
39+
models: dict[str, AnyUrl | HuggingFaceModel | URLModel]
40+
41+
42+
with (resources.files(models) / "pretrained_models.json").open("rb") as f:
43+
MODEL_REGISTRY = ModelRegistry(models=json.load(f))
44+
45+
46+
available_pretrained_models = tuple(MODEL_REGISTRY.models.keys())
3047

3148

3249
def model_name_to_local_file(model_name: str, local_cache: str | Path) -> str:
@@ -40,25 +57,40 @@ def model_name_to_local_file(model_name: str, local_cache: str | Path) -> str:
4057
str: local path to checkpoint file
4158
"""
4259
logging.info(f"Checking local cache: {local_cache} for model {model_name}")
43-
if model_name not in MODEL_REGISTRY:
60+
if model_name not in available_pretrained_models:
4461
logging.error(f"Not a valid model name '{model_name}'")
4562
raise ValueError(
4663
f"Not a valid model name '{model_name}'. Model name must be one of {available_pretrained_models}"
4764
)
48-
if not os.path.exists(local_cache):
49-
os.makedirs(local_cache, exist_ok=True)
50-
if not os.path.exists(local_cache):
51-
logging.error(f"Failed to create local cache folder '{local_cache}'")
52-
raise RuntimeError(f"Failed to create local cache folder '{local_cache}'")
53-
model_url = MODEL_REGISTRY[model_name]
54-
local_path = os.path.join(local_cache, os.path.basename(model_url))
55-
56-
# download the file
57-
if not os.path.isfile(local_path):
58-
local_path_tmp = local_path + ".tmp" # download to a tmp file in case we fail
59-
with open(local_path_tmp, "wb") as out:
60-
response = requests.get(model_url, stream=True)
61-
response.raw.decode_content = True
62-
shutil.copyfileobj(response.raw, out)
63-
shutil.move(local_path_tmp, local_path)
64-
return local_path
65+
66+
if isinstance(MODEL_REGISTRY.models[model_name], URLModel):
67+
# We have a url to download
68+
69+
if not os.path.exists(local_cache):
70+
os.makedirs(local_cache, exist_ok=True)
71+
if not os.path.exists(local_cache):
72+
logging.error(f"Failed to create local cache folder '{local_cache}'")
73+
raise RuntimeError(f"Failed to create local cache folder '{local_cache}'")
74+
model_url = MODEL_REGISTRY.models[model_name].url
75+
local_path = os.path.join(local_cache, os.path.basename(model_url))
76+
77+
# download the file
78+
if not os.path.isfile(local_path):
79+
local_path_tmp = (
80+
local_path + ".tmp"
81+
) # download to a tmp file in case we fail
82+
with open(local_path_tmp, "wb") as out:
83+
response = requests.get(model_url, stream=True)
84+
response.raw.decode_content = True
85+
shutil.copyfileobj(response.raw, out)
86+
shutil.move(local_path_tmp, local_path)
87+
return local_path
88+
elif isinstance(MODEL_REGISTRY.models[model_name], HuggingFaceModel):
89+
return hf_hub_download(
90+
repo_id=MODEL_REGISTRY.models[model_name].repo_id,
91+
filename=MODEL_REGISTRY.models[model_name].filename,
92+
)
93+
else:
94+
raise NotImplementedError(
95+
f"{type(MODEL_REGISTRY.models[model_name])} is an unknown registry type."
96+
)

0 commit comments

Comments
 (0)