Skip to content

Commit

Permalink
PERF: Avoid stacking numpy arrays and pre-allocate our needed arrays
Browse files Browse the repository at this point in the history
We can stack the arrays manually by setting the contents directly
through Numpy's indexing. Additionally, this way we can allocate
the arrow with Fortran ordering so it should be faster to pass
to the underlying codes.
  • Loading branch information
greglucas committed Nov 15, 2024
1 parent be2cb86 commit 6f8d2d7
Showing 1 changed file with 27 additions and 38 deletions.
65 changes: 27 additions & 38 deletions pymsis/msis.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,41 +364,30 @@ def create_input(
# This means the data came in preflattened, from a satellite
# trajectory for example, where we don't want to make a grid
# out of the input data, we just want to stack it together.
arr = np.stack([dyear, dseconds, lons, lats, alts, f107s, f107as], -1)

# ap has 7 components, so we need to concatenate it onto the
# arrays rather than stack
flattened_input = np.concatenate([arr, aps], axis=1, dtype=np.float32)
return (ndates,), flattened_input

# Make a grid of indices
indices = np.stack(
np.meshgrid(
np.arange(ndates),
np.arange(nlons),
np.arange(nlats),
np.arange(nalts),
indexing="ij",
),
-1,
).reshape(-1, 4)

# Now stack all of the arrays, indexing by the proper indices
arr = np.stack(
[
dyear[indices[:, 0]],
dseconds[indices[:, 0]],
lons[indices[:, 1]],
lats[indices[:, 2]],
alts[indices[:, 3]],
f107s[indices[:, 0]],
f107as[indices[:, 0]],
],
-1,
)
# ap has 7 components, so we need to concatenate it onto the
# arrays rather than stack
flattened_input = np.concatenate(
[arr, aps[indices[:, 0], :]], axis=1, dtype=np.float32
)
return (ndates, nlons, nlats, nalts), flattened_input
# Create an array to hold all of the data
# F-ordering so we can pass by reference to the Fortran code
arr = np.empty((ndates, 14), dtype=np.float32, order="F")
arr[:, 0] = dyear
arr[:, 1] = dseconds
arr[:, 2] = lons
arr[:, 3] = lats
arr[:, 4] = alts
arr[:, 5] = f107s
arr[:, 6] = f107as
arr[:, 7:] = aps
return (ndates,), arr

# Use broadcasting to fill each column directly
# This is much faster than creating an indices array and then
# using that to fill the columns
arr = np.empty((ndates * nlons * nlats * nalts, 14), dtype=np.float32, order="F")
arr[:, 0] = np.repeat(dyear, nlons * nlats * nalts) # dyear
arr[:, 1] = np.repeat(dseconds, nlons * nlats * nalts) # dseconds
arr[:, 2] = np.tile(np.repeat(lons, nlats * nalts), ndates) # lons
arr[:, 3] = np.tile(np.repeat(lats, nalts), ndates * nlons) # lats
arr[:, 4] = np.tile(alts, ndates * nlons * nlats) # alts
arr[:, 5] = np.repeat(f107s, nlons * nlats * nalts) # f107s
arr[:, 6] = np.repeat(f107as, nlons * nlats * nalts) # f107as
arr[:, 7:] = np.repeat(aps, nlons * nlats * nalts, axis=0) # aps

return (ndates, nlons, nlats, nalts), arr

0 comments on commit 6f8d2d7

Please sign in to comment.