Skip to content

Commit

Permalink
wip: get new baselines
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Nov 22, 2024
1 parent 8e6de5a commit d30d790
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 37 deletions.
21 changes: 13 additions & 8 deletions .github/workflows/regression_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,19 @@ jobs:
python -m pip install --upgrade pip
pip install -e ".[dev]"
# - name: Run benchmarks and compare to baseline
# if: github.event.pull_request.base.ref == 'main'
# run: |
# # Check if regression test results exist in main branch
# if [ -f 'git cat-file -e main:tests/regression_test_baselines.json' ]; then
# git checkout main tests/regression_test_baselines.json
# else
# echo "No regression test results found in main branch"
# fi
# pytest -m regression
# git checkout

- name: Run benchmarks and compare to baseline
if: github.event.pull_request.base.ref == 'main'
run: |
# # Check if regression test results exist in main branch
# if [ -f 'git cat-file -e main:tests/regression_test_baselines.json' ]; then
# git checkout main tests/regression_test_baselines.json
# else
# echo "No regression test results found in main branch"
# fi
pytest -m regression
# git checkout
pytest -m regression
54 changes: 25 additions & 29 deletions tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,13 +202,13 @@ def build_net(num_cells, artificial=True, connect=True, connection_prob=0.0):
(
# Test a single SWC cell with both solvers.
pytest.param(1, False, False, 0.0, "jaxley.stone"),
# pytest.param(1, False, False, 0.0, "jax.sparse"),
# # Test a network of SWC cells with both solvers.
# pytest.param(10, False, True, 0.1, "jaxley.stone"),
# pytest.param(10, False, True, 0.1, "jax.sparse"),
# # Test a larger network of smaller neurons with both solvers.
# pytest.param(1000, True, True, 0.001, "jaxley.stone"),
# pytest.param(1000, True, True, 0.001, "jax.sparse"),
pytest.param(1, False, False, 0.0, "jax.sparse"),
# Test a network of SWC cells with both solvers.
pytest.param(10, False, True, 0.1, "jaxley.stone"),
pytest.param(10, False, True, 0.1, "jax.sparse"),
# Test a larger network of smaller neurons with both solvers.
pytest.param(1000, True, True, 0.001, "jaxley.stone"),
pytest.param(1000, True, True, 0.001, "jax.sparse"),
),
)
@compare_to_baseline(baseline_iters=3)
Expand All @@ -219,41 +219,37 @@ def test_runtime(
connection_prob: float,
voltage_solver: str,
):
import time
delta_t = 0.025
t_max = 100.0

# def simulate(params):
# return jx.integrate(
# net,
# params=params,
# t_max=t_max,
# delta_t=delta_t,
# voltage_solver=voltage_solver,
# )
def simulate(params):
return jx.integrate(
net,
params=params,
t_max=t_max,
delta_t=delta_t,
voltage_solver=voltage_solver,
)

runtimes = {}

start_time = time.time()
# net, params = build_net(
# num_cells,
# artificial=artificial,
# connect=connect,
# connection_prob=connection_prob,
# )
time.sleep(0.1)
net, params = build_net(
num_cells,
artificial=artificial,
connect=connect,
connection_prob=connection_prob,
)
runtimes["build_time"] = time.time() - start_time

# jitted_simulate = jit(simulate)
jitted_simulate = jit(simulate)

start_time = time.time()
time.sleep(0.31)
# _ = jitted_simulate(params).block_until_ready()
_ = jitted_simulate(params).block_until_ready()
runtimes["compile_time"] = time.time() - start_time
# params[0]["radius"] = params[0]["radius"].at[0].set(0.5)
params[0]["radius"] = params[0]["radius"].at[0].set(0.5)

start_time = time.time()
# _ = jitted_simulate(params).block_until_ready()
time.sleep(0.21)
_ = jitted_simulate(params).block_until_ready()
runtimes["run_time"] = time.time() - start_time
return runtimes # @compare_to_baseline decorator will compare this to the baseline

0 comments on commit d30d790

Please sign in to comment.