diff --git a/jaxley/__init__.py b/jaxley/__init__.py
index 30c331ed..ce2c261a 100644
--- a/jaxley/__init__.py
+++ b/jaxley/__init__.py
@@ -11,3 +11,4 @@
from jaxley.modules import *
from jaxley.optimize import ParamTransform
from jaxley.stimulus import datapoint_to_step_currents, step_current
+from jaxley.utils.misc_utils import deprecated, deprecated_kwargs
diff --git a/jaxley/utils/misc_utils.py b/jaxley/utils/misc_utils.py
index d2a441f3..d78b1d40 100644
--- a/jaxley/utils/misc_utils.py
+++ b/jaxley/utils/misc_utils.py
@@ -1,6 +1,7 @@
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see
+import warnings
from typing import List, Optional, Union
import jax.numpy as jnp
@@ -34,3 +35,59 @@ def is_str_all(arg, force: bool = True) -> bool:
assert arg == "all", "Only 'all' is allowed"
return arg == "all"
return False
+
+
+class deprecated:
+ """Decorator to mark a function as deprecated.
+
+ Can be used to mark functions that will be removed in future versions. This will
+ also be tested in the CI pipeline to ensure that deprecated functions are removed.
+
+ Warns with: "func_name is deprecated and will be removed in version version."
+
+ Args:
+ version: The version in which the function will be removed, i.e. "0.1.0".
+ amend_msg: An optional message to append to the deprecation warning.
+ """
+
+ def __init__(self, version: str, amend_msg: str = ""):
+ self._version: str = version
+ self._amend_msg: str = amend_msg
+
+ def __call__(self, func):
+ def wrapper(*args, **kwargs):
+ msg = f"{func.__name__} is deprecated and will be removed in version {self._version}."
+ warnings.warn(msg + self._amend_msg)
+ return func(*args, **kwargs)
+
+ return wrapper
+
+
+class deprecated_kwargs:
+ """Decorator to mark a keyword arguemnt of a function as deprecated.
+
+ Can be used to mark kwargs that will be removed in future versions. This will
+ also be tested in the CI pipeline to ensure that deprecated kwargs are removed.
+
+ Warns with: "kwarg is deprecated and will be removed in version version."
+
+ Args:
+ version: The version in which the keyword argument will be removed, i.e. "0.1.0".
+ deprecated_kwargs: A list of keyword arguments that are deprecated.
+ amend_msg: An optional message to append to the deprecation warning.
+ """
+
+ def __init__(self, version: str, kwargs: List = [], amend_msg: str = ""):
+ self._version: str = version
+ self._amend_msg: str = amend_msg
+ self._depcrecated_kwargs: List = kwargs
+
+ def __call__(self, func):
+ def wrapper(*args, **kwargs):
+ for deprecated_kwarg in self._depcrecated_kwargs:
+ if deprecated_kwarg in kwargs and kwargs[deprecated_kwarg] is not None:
+ msg = f"{deprecated_kwarg} is deprecated and will be removed in version {self._version}."
+ warnings.warn(msg + self._amend_msg)
+ return func(*args, **kwargs)
+
+ return wrapper
diff --git a/tests/test_license.py b/tests/test_license.py
deleted file mode 100644
index 7c7041df..00000000
--- a/tests/test_license.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
-# licensed under the Apache License Version 2.0, see
-
-import os
-
-import pytest
-
-
-def list_files(directory):
- for root, dirs, files in os.walk(directory):
- for file in files:
- if file.endswith(".py"):
- yield os.path.join(root, file)
-
-
-license_txt = """# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
-# licensed under the Apache License Version 2.0, see """
-
-
-@pytest.mark.parametrize("dir", ["../jaxley", "."])
-def test_license(dir):
- for i, file in enumerate(list_files(dir)):
- with open(file, "r") as f:
- header = f.read(len(license_txt))
- assert (
- header == license_txt
- ), f"File {file} does not have the correct license header"
diff --git a/tests/test_misc.py b/tests/test_misc.py
new file mode 100644
index 00000000..75698747
--- /dev/null
+++ b/tests/test_misc.py
@@ -0,0 +1,60 @@
+# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
+# licensed under the Apache License Version 2.0, see
+
+import os
+import re
+from pathlib import Path
+from typing import List
+
+import numpy as np
+import pytest
+
+
+def list_files(directory):
+ for root, dirs, files in os.walk(directory):
+ for file in files:
+ if file.endswith(".py"):
+ yield os.path.join(root, file)
+
+
+license_txt = """# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
+# licensed under the Apache License Version 2.0, see """
+
+
+@pytest.mark.parametrize("dir", ["../jaxley", "."])
+def test_license(dir):
+ for i, file in enumerate(list_files(dir)):
+ with open(file, "r") as f:
+ header = f.read(len(license_txt))
+ assert (
+ header == license_txt
+ ), f"File {file} does not have the correct license header"
+
+
+def test_rm_all_deprecated_functions():
+ from jaxley.__version__ import __version__ as package_version
+
+ package_version = np.array([int(s) for s in package_version.split(".")])
+
+ decorator_pattern = r"@deprecated(?:_signature)?"
+ version_pattern = r"[v]?(\d+\.\d+\.\d+)"
+
+ package_dir = Path(__file__).parent.parent / "jaxley"
+
+ violations = []
+ for py_file in package_dir.rglob("*.py"):
+ with open(py_file, "r") as f:
+ for line_num, line in enumerate(f, 1):
+ if re.search(decorator_pattern, line):
+ version_match = re.search(version_pattern, line)
+ if version_match:
+ depr_version_str = version_match.group(1)
+ depr_version = np.array(
+ [int(s) for s in depr_version_str.split(".")]
+ )
+ if not np.all(package_version <= depr_version):
+ violations.append(f"{py_file}:L{line_num}")
+
+ assert not violations, "\n".join(
+ ["Found deprecated items that should have been removed:", *violations]
+ )