Skip to content

Commit 1a0dfc2

Browse files
authored
Merge pull request #13 from SWxTREC/flattened-input-arrays
Handle flattened input arrays
2 parents 263d2d6 + b6bca2e commit 1a0dfc2

File tree

6 files changed

+58
-24
lines changed

6 files changed

+58
-24
lines changed

.coveragerc

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[run]
2+
branch = true
3+
source = pymsis
4+
omit = pymsis/__init__.py

.github/workflows/ci.yml

+4-10
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@ jobs:
5353
- name: Install dependencies
5454
run: |
5555
python -m pip install --upgrade pip
56-
pip install flake8 pytest
57-
pip install -r requirements.txt
56+
pip install -r requirements-test.txt
5857
5958
- name: Install pymsis
6059
run: pip install -v .
@@ -65,12 +64,7 @@ jobs:
6564
flake8 . --count --show-source --statistics
6665
6766
- name: Test with pytest
68-
run: pytest --junitxml=junit/test-results-${{ matrix.python-version }}.xml
67+
run: pytest --color=yes --cov --cov-report=xml
6968

70-
- name: Upload pytest test results
71-
uses: actions/upload-artifact@v2
72-
with:
73-
name: pytest-results-${{ matrix.python-version }}
74-
path: junit/test-results-${{ matrix.python-version }}.xml
75-
# Use always() to always run this step to publish test results when there are test failures
76-
if: ${{ always() }}
69+
- name: Upload code coverage
70+
uses: codecov/codecov-action@v2

codecov.yml

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
comment: false

pymsis/msis.py

+30-9
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,13 @@
77

88
def run(dates, lons, lats, alts, f107s, f107as, aps,
99
options=None, version=2):
10-
"""Call MSIS looping over all possible inputs.
10+
"""
11+
Call MSIS looping over all possible inputs. If ndates is
12+
the same as nlons, nlats, and nalts, then a flattened
13+
multi-point input array is assumed. Otherwise, the data
14+
will be expanded in a grid-like fashion. The possible
15+
return shapes are therefore (ndates, 11) and
16+
(ndates, nlons, nlats, nalts, 11).
1117
1218
Parameters
1319
----------
@@ -32,7 +38,7 @@ def run(dates, lons, lats, alts, f107s, f107as, aps,
3238
3339
Returns
3440
-------
35-
ndarray (ndates, nlons, nlats, nalts, 11)
41+
ndarray (ndates, nlons, nlats, nalts, 11) or (ndates, 11)
3642
| The data calculated at each grid point:
3743
| [Total mass density (kg/m3)
3844
| N2 # density (m-3),
@@ -168,7 +174,8 @@ def create_input(dates, lons, lats, alts, f107s, f107as, aps):
168174
(shape, flattened_input)
169175
The shape of the data as a tuple (ndates, nlons, nlats, nalts) and
170176
the flattened version of the input data
171-
(ndates*nlons*nlats*nalts, 14).
177+
(ndates*nlons*nlats*nalts, 14). If the input array was preflattened
178+
(ndates == nlons == nlats == nalts), then the shape is (ndates,).
172179
"""
173180
# Turn everything into arrays
174181
dates = np.atleast_1d(np.array(dates, dtype='datetime64'))
@@ -194,18 +201,30 @@ def create_input(dates, lons, lats, alts, f107s, f107as, aps):
194201
nlons = len(lons)
195202
nlats = len(lats)
196203
nalts = len(alts)
197-
shape = (ndates, nlons, nlats, nalts)
198204

199205
if not (ndates == len(f107s) == len(f107as) == len(aps)):
200206
raise ValueError(f"The length of dates ({ndates}), f107s "
201207
f"({len(f107s)}), f107as ({len(f107as)}), "
202208
f"and aps ({len(aps)}) must all be equal")
203209

210+
if ndates == nlons == nlats == nalts:
211+
# This means the data came in preflattened, from a satellite
212+
# trajectory for example, where we don't want to make a grid
213+
# out of the input data, we just want to stack it together.
214+
arr = np.stack([dyear, dseconds, lons, lats, alts, f107s, f107as], -1)
215+
216+
# ap has 7 components, so we need to concatenate it onto the
217+
# arrays rather than stack
218+
flattened_input = np.concatenate([arr, aps], axis=1,
219+
dtype=np.float32)
220+
shape = (ndates,)
221+
return shape, flattened_input
222+
204223
# Make a grid of indices
205224
indices = np.stack(np.meshgrid(np.arange(ndates),
206-
np.arange(nlons),
207-
np.arange(nlats),
208-
np.arange(nalts), indexing='ij'),
225+
np.arange(nlons),
226+
np.arange(nlats),
227+
np.arange(nalts), indexing='ij'),
209228
-1).reshape(-1, 4)
210229

211230
# Now stack all of the arrays, indexing by the proper indices
@@ -215,5 +234,7 @@ def create_input(dates, lons, lats, alts, f107s, f107as, aps):
215234
f107s[indices[:, 0]], f107as[indices[:, 0]]], -1)
216235
# ap has 7 components, so we need to concatenate it onto the
217236
# arrays rather than stack
218-
return shape, np.concatenate([arr, aps[indices[:, 0], :]], axis=1,
219-
dtype=np.float32)
237+
flattened_input = np.concatenate([arr, aps[indices[:, 0], :]], axis=1,
238+
dtype=np.float32)
239+
shape = (ndates, nlons, nlats, nalts)
240+
return shape, flattened_input
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
-r requirements.txt
22
flake8
33
pytest
4+
pytest-cov

tests/test_msis.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def test_create_options():
6060

6161
def test_create_input_single_point(input_data, expected_input):
6262
shape, data = msis.create_input(*input_data)
63-
assert shape == (1, 1, 1, 1)
63+
assert shape == (1,)
6464
assert data.shape == (1, 14)
6565
assert_array_equal(data[0, :], expected_input)
6666

@@ -70,7 +70,7 @@ def test_create_input_datetime(input_data, expected_input):
7070
# .item() gets the datetime object from the np.datetime64 object
7171
input_data = (input_data[0].item(),) + input_data[1:]
7272
shape, data = msis.create_input(*input_data)
73-
assert shape == (1, 1, 1, 1)
73+
assert shape == (1,)
7474
assert data.shape == (1, 14)
7575
assert_array_equal(data[0, :], expected_input)
7676

@@ -144,11 +144,11 @@ def test_run_options(input_data):
144144

145145
def test_run_single_point(input_data, expected_output):
146146
output = msis.run(*input_data)
147-
assert output.shape == (1, 1, 1, 1, 11)
147+
assert output.shape == (1, 11)
148148
assert_allclose(np.squeeze(output), expected_output, rtol=1e-5)
149149

150150

151-
def test_run_multi_point(input_data, expected_output):
151+
def test_run_gridded_multi_point(input_data, expected_output):
152152
date, lon, lat, alt, f107, f107a, ap = input_data
153153
# 5 x 5 surface
154154
input_data = (date, [lon]*5, [lat]*5, alt, f107, f107a, ap)
@@ -158,6 +158,19 @@ def test_run_multi_point(input_data, expected_output):
158158
assert_allclose(np.squeeze(output), expected, rtol=1e-5)
159159

160160

161+
def test_run_multi_point(input_data, expected_output):
162+
# test multi-point run, like a satellite fly-through
163+
# and make sure we don't grid the input data
164+
# 5 input points
165+
date, lon, lat, alt, f107, f107a, ap = input_data
166+
input_data = ([date]*5, [lon]*5, [lat]*5, [alt]*5,
167+
[f107]*5, [f107a]*5, ap*5)
168+
output = msis.run(*input_data)
169+
assert output.shape == (5, 11)
170+
expected = np.tile(expected_output, (5, 1))
171+
assert_allclose(np.squeeze(output), expected, rtol=1e-5)
172+
173+
161174
def test_run_wrapped_lon(input_data, expected_output):
162175
date, _, lat, alt, f107, f107a, ap = input_data
163176

@@ -204,7 +217,7 @@ def test_run_versions(input_data):
204217

205218
def test_run_version00(input_data, expected_output00):
206219
output = msis.run(*input_data, version=0)
207-
assert output.shape == (1, 1, 1, 1, 11)
220+
assert output.shape == (1, 11)
208221
assert_allclose(np.squeeze(output), expected_output00, rtol=1e-5)
209222

210223

0 commit comments

Comments
 (0)