diff --git a/jaxley/__init__.py b/jaxley/__init__.py index 30c331ed..ce2c261a 100644 --- a/jaxley/__init__.py +++ b/jaxley/__init__.py @@ -11,3 +11,4 @@ from jaxley.modules import * from jaxley.optimize import ParamTransform from jaxley.stimulus import datapoint_to_step_currents, step_current +from jaxley.utils.misc_utils import deprecated, deprecated_kwargs diff --git a/jaxley/utils/misc_utils.py b/jaxley/utils/misc_utils.py index d2a441f3..d78b1d40 100644 --- a/jaxley/utils/misc_utils.py +++ b/jaxley/utils/misc_utils.py @@ -1,6 +1,7 @@ # This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is # licensed under the Apache License Version 2.0, see +import warnings from typing import List, Optional, Union import jax.numpy as jnp @@ -34,3 +35,59 @@ def is_str_all(arg, force: bool = True) -> bool: assert arg == "all", "Only 'all' is allowed" return arg == "all" return False + + +class deprecated: + """Decorator to mark a function as deprecated. + + Can be used to mark functions that will be removed in future versions. This will + also be tested in the CI pipeline to ensure that deprecated functions are removed. + + Warns with: "func_name is deprecated and will be removed in version version." + + Args: + version: The version in which the function will be removed, i.e. "0.1.0". + amend_msg: An optional message to append to the deprecation warning. + """ + + def __init__(self, version: str, amend_msg: str = ""): + self._version: str = version + self._amend_msg: str = amend_msg + + def __call__(self, func): + def wrapper(*args, **kwargs): + msg = f"{func.__name__} is deprecated and will be removed in version {self._version}." + warnings.warn(msg + self._amend_msg) + return func(*args, **kwargs) + + return wrapper + + +class deprecated_kwargs: + """Decorator to mark a keyword arguemnt of a function as deprecated. + + Can be used to mark kwargs that will be removed in future versions. This will + also be tested in the CI pipeline to ensure that deprecated kwargs are removed. + + Warns with: "kwarg is deprecated and will be removed in version version." + + Args: + version: The version in which the keyword argument will be removed, i.e. "0.1.0". + deprecated_kwargs: A list of keyword arguments that are deprecated. + amend_msg: An optional message to append to the deprecation warning. + """ + + def __init__(self, version: str, kwargs: List = [], amend_msg: str = ""): + self._version: str = version + self._amend_msg: str = amend_msg + self._depcrecated_kwargs: List = kwargs + + def __call__(self, func): + def wrapper(*args, **kwargs): + for deprecated_kwarg in self._depcrecated_kwargs: + if deprecated_kwarg in kwargs and kwargs[deprecated_kwarg] is not None: + msg = f"{deprecated_kwarg} is deprecated and will be removed in version {self._version}." + warnings.warn(msg + self._amend_msg) + return func(*args, **kwargs) + + return wrapper diff --git a/tests/test_license.py b/tests/test_license.py deleted file mode 100644 index 7c7041df..00000000 --- a/tests/test_license.py +++ /dev/null @@ -1,27 +0,0 @@ -# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is -# licensed under the Apache License Version 2.0, see - -import os - -import pytest - - -def list_files(directory): - for root, dirs, files in os.walk(directory): - for file in files: - if file.endswith(".py"): - yield os.path.join(root, file) - - -license_txt = """# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is -# licensed under the Apache License Version 2.0, see """ - - -@pytest.mark.parametrize("dir", ["../jaxley", "."]) -def test_license(dir): - for i, file in enumerate(list_files(dir)): - with open(file, "r") as f: - header = f.read(len(license_txt)) - assert ( - header == license_txt - ), f"File {file} does not have the correct license header" diff --git a/tests/test_misc.py b/tests/test_misc.py new file mode 100644 index 00000000..75698747 --- /dev/null +++ b/tests/test_misc.py @@ -0,0 +1,60 @@ +# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is +# licensed under the Apache License Version 2.0, see + +import os +import re +from pathlib import Path +from typing import List + +import numpy as np +import pytest + + +def list_files(directory): + for root, dirs, files in os.walk(directory): + for file in files: + if file.endswith(".py"): + yield os.path.join(root, file) + + +license_txt = """# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is +# licensed under the Apache License Version 2.0, see """ + + +@pytest.mark.parametrize("dir", ["../jaxley", "."]) +def test_license(dir): + for i, file in enumerate(list_files(dir)): + with open(file, "r") as f: + header = f.read(len(license_txt)) + assert ( + header == license_txt + ), f"File {file} does not have the correct license header" + + +def test_rm_all_deprecated_functions(): + from jaxley.__version__ import __version__ as package_version + + package_version = np.array([int(s) for s in package_version.split(".")]) + + decorator_pattern = r"@deprecated(?:_signature)?" + version_pattern = r"[v]?(\d+\.\d+\.\d+)" + + package_dir = Path(__file__).parent.parent / "jaxley" + + violations = [] + for py_file in package_dir.rglob("*.py"): + with open(py_file, "r") as f: + for line_num, line in enumerate(f, 1): + if re.search(decorator_pattern, line): + version_match = re.search(version_pattern, line) + if version_match: + depr_version_str = version_match.group(1) + depr_version = np.array( + [int(s) for s in depr_version_str.split(".")] + ) + if not np.all(package_version <= depr_version): + violations.append(f"{py_file}:L{line_num}") + + assert not violations, "\n".join( + ["Found deprecated items that should have been removed:", *violations] + )