diff --git a/jaxley/channels/channel.py b/jaxley/channels/channel.py index 18120ced..71791cee 100644 --- a/jaxley/channels/channel.py +++ b/jaxley/channels/channel.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from typing import Dict, Optional, Tuple from warnings import warn + import jax.numpy as jnp @@ -25,29 +26,19 @@ def __init__(self, name: Optional[str] = None): "michael.deistler@uni-tuebingen.de or create an issue on Github: " "https://github.com/jaxleyverse/jaxley/issues. Thank you!" ) - if not hasattr(self, "current_is_in_mA_per_cm2"): + if ( + not hasattr(self, "current_is_in_mA_per_cm2") + and self.current_is_in_mA_per_cm2 + ): raise ValueError( "The channel you are using is deprecated. " "In Jaxley version 0.5.0, we changed the unit of the current returned " "by `compute_current` of channels from `uA/cm^2` to `mA/cm^2`. Please " "update your channel model (by dividing the resulting current by 1000) " "and set `self.current_is_in_mA_per_cm2=True` as the first line " - "in the `__init__()` method of your channel. Alternatively, you can " - "make Jaxley use `uA/cm^2 by setting " - "`self.current_is_in_mA_per_cm2=False` as the " - f"first line in the `__init__()` method of your channel. {contact}" - ) - if not self.current_is_in_mA_per_cm2: - warn( - "You are using `current_is_in_mA_per_cm2=False`. This means that the " - "current through the channel will be considered in `uA/cm^2`. In " - "future versions of Jaxley, we will remove this option and force the " - "channel to return the current in `mA/cm^2`. We recommend to adapt " - "your `compute_current()` method now and return the current in " - "`mA/cm^2` (by dividing the current by 1000). After this has been " - "done, set `current_is_in_mA_per_cm2=True` to get rid of this " - f"warning. {contact}" + "in the `__init__()` method of your channel. {contact}" ) + self._name = name if name else self.__class__.__name__ @property diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 79102bd9..acf300e2 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1806,15 +1806,9 @@ def _channel_currents( voltage_term = (membrane_currents[1] - membrane_currents[0]) / diff constant_term = membrane_currents[0] - voltage_term * voltages[indices] - if channel.current_is_in_mA_per_cm2: - # * 1000 to convert from mA/cm^2 to uA/cm^2. - voltage_terms = voltage_terms.at[indices].add(voltage_term * 1000.0) - constant_terms = constant_terms.at[indices].add(-constant_term * 1000.0) - else: - # If `current_is_in_mA_per_cm2=False` then the current is assumed as - # `uA/cm^2`. - voltage_terms = voltage_terms.at[indices].add(voltage_term) - constant_terms = constant_terms.at[indices].add(-constant_term) + # * 1000 to convert from mA/cm^2 to uA/cm^2. + voltage_terms = voltage_terms.at[indices].add(voltage_term * 1000.0) + constant_terms = constant_terms.at[indices].add(-constant_term * 1000.0) # Save the current (for the unperturbed voltage) as a state that will # also be passed to the state update. diff --git a/tests/jaxley_identical/test_channel_current_units.py b/tests/jaxley_identical/test_channel_current_units.py deleted file mode 100644 index b2a1654c..00000000 --- a/tests/jaxley_identical/test_channel_current_units.py +++ /dev/null @@ -1,57 +0,0 @@ -from typing import Dict, Optional - -import jax.numpy as jnp - -import jaxley as jx -from jaxley.channels import Channel, Leak - - -class LeakOldConvention(Channel): - """Leak current""" - - def __init__(self, name: Optional[str] = None): - self.current_is_in_mA_per_cm2 = False - - super().__init__(name) - prefix = self._name - self.channel_params = { - f"{prefix}_gLeak": 1e-4, - f"{prefix}_eLeak": -70.0, - } - self.channel_states = {} - self.current_name = f"i_{prefix}" - - def update_states(self, states, dt, v, params): - """No state to update.""" - return {} - - def compute_current( - self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray] - ): - """Return current.""" - prefix = self._name - gLeak = params[f"{prefix}_gLeak"] # S/cm^2 - return gLeak * (v - params[f"{prefix}_eLeak"]) * 1000.0 # mA/cm^2 -> uA/cm^2. - - def init_state(self, states, v, params, delta_t): - return {} - - -def test_same_result_for_both_current_units(): - """Test whether two channels (with old and new unit convention) match.""" - current = jx.step_current(1.0, 2.0, 0.01, 0.025, 5.0) - comp1 = jx.Compartment() - comp2 = jx.Compartment() - - comp1.insert(LeakOldConvention()) - comp2.insert(Leak()) - - comp1.record() - comp1.stimulate(current) - comp2.record() - comp2.stimulate(current) - - v1 = jx.integrate(comp1) - v2 = jx.integrate(comp2) - - assert jnp.allclose(v1, v2)