Skip to content

Commit 2067009

Browse files
authored
Merge pull request #366 from adrn/reorder-nbody
Fix inefficient nbody orbit reorder
2 parents a8be7d3 + 1a21825 commit 2067009

File tree

3 files changed

+39
-17
lines changed

3 files changed

+39
-17
lines changed

CHANGES.rst

+2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ Bug fixes
1515
class to enable using the new (correct) parameter values, but the default will
1616
continue to use the Gala modified values (for backwards compatibility).
1717

18+
- Improved internal efficiency of ``DirectNBody``.
19+
1820

1921
API changes
2022
-----------

gala/dynamics/nbody/core.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -259,13 +259,7 @@ def integrate_orbit(self, Integrator=None, Integrator_kwargs=dict(), **time_spec
259259
frame=self.frame,
260260
)
261261

262-
# Reorder orbits:
263-
remap_idx = np.zeros((orbits.shape[-1], orbits.shape[-1]), dtype=int)
264-
remap_idx[idx, np.arange(orbits.shape[-1])] = 1
265-
_, undo_idx = np.where(remap_idx == 1)
266-
267-
return orbits[..., undo_idx]
268-
remap_idx[idx, np.arange(orbits.shape[-1])] = 1
269-
_, undo_idx = np.where(remap_idx == 1)
262+
# Reorder orbits to original order:
263+
undo_idx = np.argsort(idx)
270264

271265
return orbits[..., undo_idx]

gala/dynamics/nbody/tests/test_nbody.py

+35-9
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,23 @@
33
import numpy as np
44
import pytest
55

6-
# Custom
7-
from gala.potential import (
8-
NullPotential,
9-
NFWPotential,
10-
HernquistPotential,
11-
ConstantRotatingFrame,
12-
StaticFrame,
13-
)
146
from gala.dynamics import PhaseSpacePosition, combine
15-
from gala.units import UnitSystem, galactic
167
from gala.integrate import (
178
DOPRI853Integrator,
189
LeapfrogIntegrator,
1910
Ruth4Integrator,
2011
)
2112

13+
# Custom
14+
from gala.potential import (
15+
ConstantRotatingFrame,
16+
HernquistPotential,
17+
NFWPotential,
18+
NullPotential,
19+
StaticFrame,
20+
)
21+
from gala.units import UnitSystem, galactic
22+
2223
# Project
2324
from ..core import DirectNBody
2425

@@ -220,3 +221,28 @@ def test_directnbody_integrate_rotframe(self, Integrator):
220221

221222
assert u.allclose(orbits_static.xyz, orbits_static.xyz)
222223
assert u.allclose(orbits2.v_xyz, orbits2.v_xyz)
224+
225+
@pytest.mark.parametrize("Integrator", [DOPRI853Integrator])
226+
def test_nbody_reorder(self, Integrator):
227+
N = 16
228+
rng = np.random.default_rng(seed=42)
229+
w0 = PhaseSpacePosition(
230+
pos=rng.normal(0, 5, size=(3, N)) * u.kpc,
231+
vel=rng.normal(0, 50, size=(3, N)) * u.km / u.s,
232+
)
233+
pots = [
234+
(
235+
HernquistPotential(1e9 * u.Msun, 1.0 * u.pc, units=galactic)
236+
if rng.uniform() > 0.5
237+
else None
238+
)
239+
for _ in range(N)
240+
]
241+
sim = DirectNBody(
242+
w0,
243+
pots,
244+
external_potential=HernquistPotential(1e12, 10, units=galactic),
245+
units=galactic,
246+
)
247+
orbits = sim.integrate_orbit(dt=1.0 * u.Myr, t1=0, t2=100 * u.Myr)
248+
assert np.allclose(orbits.pos[0].xyz, w0.pos.xyz)

0 commit comments

Comments
 (0)