Skip to content

Commit 9c96e0a

Browse files
committed
Add get_model_spec module for API access to model spec files
1 parent 64b8108 commit 9c96e0a

File tree

3 files changed

+224
-0
lines changed

3 files changed

+224
-0
lines changed

docs/api.rst

+6
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,9 @@ Mask classes
2727

2828
.. automodule:: xija.component.mask
2929
:members:
30+
31+
Get model spec
32+
--------------
33+
.. automodule:: xija.get_model_spec
34+
:members:
35+

xija/get_model_spec.py

+164
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# Licensed under a 3-clause BSD style license - see LICENSE.rst
2+
"""
3+
Get Chandra model specifications
4+
"""
5+
import os
6+
import re
7+
from pathlib import Path
8+
from typing import List, Optional
9+
10+
from git import Repo
11+
import requests
12+
from Ska.File import get_globfiles
13+
14+
__all__ = ['get_xija_model_file', 'get_xija_model_names', 'get_repo_version',
15+
'check_github_version']
16+
17+
REPO_PATH = Path(os.environ['SKA'], 'data', 'chandra_models')
18+
MODELS_PATH = REPO_PATH / 'chandra_models' / 'xija'
19+
CHANDRA_MODELS_URL = 'https://api.github.com/repos/sot/chandra_models/releases'
20+
21+
22+
def get_xija_model_file(model_name, models_path=MODELS_PATH) -> str:
23+
"""
24+
Get file name of Xija model specification for the specified ``model_name``.
25+
26+
Supported model names include (but are not limited to): ``'aca'``,
27+
``'acisfp'``, ``'dea'``, ``'dpa'``, ``'psmc'``, ``'minusyz'``, and
28+
``'pftank2t'``.
29+
30+
Use ``get_xija_model_names()`` for the full list.
31+
32+
Examples
33+
--------
34+
>>> import xija
35+
>>> from xija.get_model_spec import get_xija_model_file
36+
>>> model_spec = get_xija_model_file('acisfp')
37+
>>> model = xija.XijaModel('acisfp', model_spec=model_spec, start='2012:001', stop='2012:010')
38+
>>> model.make()
39+
>>> model.calc()
40+
41+
Parameters
42+
----------
43+
model_name : str
44+
Name of model
45+
models_path : str, Path
46+
Path to directory containing xija model spec files (default is
47+
$SKA/data/chandra_models)
48+
49+
Returns
50+
-------
51+
str
52+
File name of the corresponding Xija model specification
53+
"""
54+
models_path = Path(models_path)
55+
56+
if not models_path.exists():
57+
raise FileNotFoundError(f'xija models directory {models_path} does not exist')
58+
59+
file_glob = str(models_path / '*' / f'{model_name.lower()}_spec.json')
60+
try:
61+
# get_globfiles() default requires exactly one file match and returns a list
62+
file_name = get_globfiles(file_glob)[0]
63+
except ValueError:
64+
names = get_xija_model_names()
65+
raise ValueError(f'no models matched {model_name}. Available models are: '
66+
f'{", ".join(names)}')
67+
68+
return file_name
69+
70+
71+
def get_xija_model_names(models_path=MODELS_PATH) -> List[str]:
72+
"""Return list of available xija model names.
73+
74+
Parameters
75+
----------
76+
models_path : str, Path
77+
Path to directory containing xija model spec files (default is
78+
$SKA/data/chandra_models)
79+
80+
Examples
81+
--------
82+
>>> from xija.get_model_spec import get_xija_model_names
83+
>>> names = get_xija_model_names()
84+
['aca',
85+
'acisfp',
86+
'dea',
87+
'dpa',
88+
'4rt700t',
89+
'minusyz',
90+
'pm1thv2t',
91+
'pm2thv1t',
92+
'pm2thv2t',
93+
'pftank2t',
94+
'pline03t_model',
95+
'pline04t_model',
96+
'psmc',
97+
'tcylaft6']
98+
99+
Returns
100+
-------
101+
list
102+
List of available xija model names
103+
"""
104+
models_path = Path(models_path)
105+
106+
fns = get_globfiles(str(models_path / '*' / '*_spec.json'), minfiles=0, maxfiles=None)
107+
names = [re.sub(r'_spec\.json', '', Path(fn).name) for fn in sorted(fns)]
108+
109+
return names
110+
111+
112+
def get_repo_version() -> str:
113+
"""Return version (most recent tag) of models repository.
114+
115+
Returns
116+
-------
117+
str
118+
Version (most recent tag) of models repository
119+
"""
120+
repo = Repo(REPO_PATH)
121+
122+
if repo.is_dirty():
123+
raise ValueError('repo is dirty')
124+
125+
tags = sorted(repo.tags, key=lambda tag: tag.commit.committed_datetime)
126+
tag_repo = tags[-1]
127+
if tag_repo.commit != repo.head.commit:
128+
raise ValueError(f'repo tip is not at tag {tag_repo}')
129+
130+
return tag_repo.name
131+
132+
133+
def check_github_version(tag_name, url=CHANDRA_MODELS_URL, timeout=5) -> Optional[bool]:
134+
"""Check that latest chandra_models GitHub repo release matches ``tag_name``.
135+
136+
This queries GitHub for the latest release of chandra_models.
137+
138+
Parameters
139+
----------
140+
tag_name : str
141+
Tag name e.g. '3.32'
142+
url : str
143+
URL for chandra_models releases on GitHub API
144+
timeout : int, float
145+
Request timeout (sec, default=5)
146+
147+
Returns
148+
-------
149+
bool, None
150+
True if chandra_models release on GitHub matches tag_name.
151+
None if the request timed out, indicating indeterminate answer.
152+
"""
153+
try:
154+
req = requests.get(url, timeout=timeout)
155+
except requests.ConnectTimeout:
156+
return None
157+
158+
if req.status_code != requests.codes.ok:
159+
req.raise_for_status()
160+
161+
tags_gh = sorted(req.json(), key=lambda tag: tag['published_at'])
162+
tag_gh_name = tags_gh[-1]['tag_name']
163+
164+
return tag_gh_name == tag_name

xija/tests/test_get_model_spec.py

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Licensed under a 3-clause BSD style license - see LICENSE.rst
2+
3+
import os
4+
import json
5+
from pathlib import Path
6+
import re
7+
import pytest
8+
import requests
9+
10+
from ..get_model_spec import (get_xija_model_file, get_xija_model_names,
11+
get_repo_version, check_github_version)
12+
13+
14+
def test_get_model_file_aca():
15+
fn = get_xija_model_file('aca')
16+
assert fn.startswith(os.environ['SKA'])
17+
assert Path(fn).name == 'aca_spec.json'
18+
spec = json.load(open(fn))
19+
assert spec['name'] == 'aacccdpt'
20+
21+
22+
def test_get_model_file_fail():
23+
with pytest.raises(ValueError, match='no models matched xxxyyyzzz'):
24+
get_xija_model_file('xxxyyyzzz')
25+
26+
with pytest.raises(FileNotFoundError, match='xija models directory'):
27+
get_xija_model_file('aca', models_path='__NOT_A_DIRECTORY__')
28+
29+
30+
def test_get_xija_model_names():
31+
names = get_xija_model_names()
32+
assert all(name in names for name in ('aca', 'acisfp', 'dea', 'dpa', 'pftank2t'))
33+
34+
35+
def test_get_repo_version():
36+
version = get_repo_version()
37+
assert isinstance(version, str)
38+
assert re.match(r'^[0-9.]+$', version)
39+
40+
41+
def test_check_github_version():
42+
version = get_repo_version()
43+
status = check_github_version(version)
44+
assert status is True
45+
46+
status = check_github_version('asdf')
47+
assert status is False
48+
49+
# Force timeout
50+
status = check_github_version(version, timeout=0.00001)
51+
assert status is None
52+
53+
with pytest.raises(requests.ConnectionError):
54+
check_github_version(version, 'https://______bad_url______')

0 commit comments

Comments
 (0)