Skip to content

Commit

Permalink
Merge pull request #3423 from jsbrittain/jax_gpu
Browse files Browse the repository at this point in the history
JaxSolver fails when using GPU support with no input parameters
  • Loading branch information
Saransh-cpp authored Oct 31, 2023
2 parents 618b481 + 6cc3940 commit 138cbf2
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 1 deletion.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# [Unreleased](https://github.com/pybamm-team/PyBaMM/)

## Bug fixes

- Fixed a bug where the JaxSolver would fails when using GPU support with no input parameters ([#3423](https://github.com/pybamm-team/PyBaMM/pull/3423))

# [v23.9rc0](https://github.com/pybamm-team/PyBaMM/tree/v23.9rc0) - 2023-10-31

## Features
Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@
"navbar_end": ["theme-switcher", "navbar-icon-links"],
# add Algolia to the persistent navbar, this removes the default search icon
"navbar_persistent": "algolia-searchbox",
"navigation_with_keys": False,
"use_edit_page_button": True,
"analytics": {
"plausible_analytics_domain": "docs.pybamm.org",
Expand Down
2 changes: 1 addition & 1 deletion pybamm/solvers/jax_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def _integrate(self, model, t_eval, inputs=None):

y = []
platform = jax.lib.xla_bridge.get_backend().platform.casefold()
if platform.startswith("cpu"):
if len(inputs) <= 1 or platform.startswith("cpu"):
# cpu execution runs faster when multithreaded
async def solve_model_for_inputs():
async def solve_model_async(inputs_v):
Expand Down

0 comments on commit 138cbf2

Please sign in to comment.