Skip to content

Commit

Permalink
Merge pull request #62 from greglucas/perf-avoid-stack
Browse files Browse the repository at this point in the history
PERF: avoid stacking in input creation and isclose comparisons
  • Loading branch information
greglucas authored Nov 15, 2024
2 parents be2cb86 + 0471f96 commit 76b2d92
Showing 1 changed file with 29 additions and 41 deletions.
70 changes: 29 additions & 41 deletions pymsis/msis.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,8 @@ def run(
)

# The Fortran code puts 9.9e-38 in as NaN
# Have to make sure this doesn't overlap 0 due to really small values
# so atol should be less than the comparison value
output[np.isclose(output, 9.9e-38, atol=1e-38)] = np.nan
# or 9.99e-38, or 9.999e-38, so lets just bound the 9s
output[(output >= 9.9e-38) & (output < 1e-37)] = np.nan # noqa: PLR2004

return output.reshape(*input_shape, 11)

Expand Down Expand Up @@ -364,41 +363,30 @@ def create_input(
# This means the data came in preflattened, from a satellite
# trajectory for example, where we don't want to make a grid
# out of the input data, we just want to stack it together.
arr = np.stack([dyear, dseconds, lons, lats, alts, f107s, f107as], -1)

# ap has 7 components, so we need to concatenate it onto the
# arrays rather than stack
flattened_input = np.concatenate([arr, aps], axis=1, dtype=np.float32)
return (ndates,), flattened_input

# Make a grid of indices
indices = np.stack(
np.meshgrid(
np.arange(ndates),
np.arange(nlons),
np.arange(nlats),
np.arange(nalts),
indexing="ij",
),
-1,
).reshape(-1, 4)

# Now stack all of the arrays, indexing by the proper indices
arr = np.stack(
[
dyear[indices[:, 0]],
dseconds[indices[:, 0]],
lons[indices[:, 1]],
lats[indices[:, 2]],
alts[indices[:, 3]],
f107s[indices[:, 0]],
f107as[indices[:, 0]],
],
-1,
)
# ap has 7 components, so we need to concatenate it onto the
# arrays rather than stack
flattened_input = np.concatenate(
[arr, aps[indices[:, 0], :]], axis=1, dtype=np.float32
)
return (ndates, nlons, nlats, nalts), flattened_input
# Create an array to hold all of the data
# F-ordering so we can pass by reference to the Fortran code
arr = np.empty((ndates, 14), dtype=np.float32, order="F")
arr[:, 0] = dyear
arr[:, 1] = dseconds
arr[:, 2] = lons
arr[:, 3] = lats
arr[:, 4] = alts
arr[:, 5] = f107s
arr[:, 6] = f107as
arr[:, 7:] = aps
return (ndates,), arr

# Use broadcasting to fill each column directly
# This is much faster than creating an indices array and then
# using that to fill the columns
arr = np.empty((ndates * nlons * nlats * nalts, 14), dtype=np.float32, order="F")
arr[:, 0] = np.repeat(dyear, nlons * nlats * nalts) # dyear
arr[:, 1] = np.repeat(dseconds, nlons * nlats * nalts) # dseconds
arr[:, 2] = np.tile(np.repeat(lons, nlats * nalts), ndates) # lons
arr[:, 3] = np.tile(np.repeat(lats, nalts), ndates * nlons) # lats
arr[:, 4] = np.tile(alts, ndates * nlons * nlats) # alts
arr[:, 5] = np.repeat(f107s, nlons * nlats * nalts) # f107s
arr[:, 6] = np.repeat(f107as, nlons * nlats * nalts) # f107as
arr[:, 7:] = np.repeat(aps, nlons * nlats * nalts, axis=0) # aps

return (ndates, nlons, nlats, nalts), arr

0 comments on commit 76b2d92

Please sign in to comment.