Skip to content

Commit

Permalink
Merge pull request #609 from MilesCranmer/cleanup
Browse files Browse the repository at this point in the history
More extensive typing stubs and associated refactoring
  • Loading branch information
MilesCranmer committed Jun 16, 2024
2 parents 476f573 + 291dc85 commit f653388
Show file tree
Hide file tree
Showing 14 changed files with 378 additions and 182 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,5 @@ site
**/*.code-workspace
**/*.tar.gz
venv
requirements-dev.lock
requirements.lock
1 change: 0 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,3 @@ dependencies:
- scikit-learn>=1.0.0,<2.0.0
- pyjuliacall>=0.9.15,<0.10.0
- click>=7.0.0,<9.0.0
- typing_extensions>=4.0.0,<5.0.0
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,10 @@ dev-dependencies = [
"pre-commit>=3.7.0",
"ipython>=8.23.0",
"ipykernel>=6.29.4",
"mypy>=1.10.0",
"jax[cpu]>=0.4.26",
"torch>=2.3.0",
"pandas-stubs>=2.2.1.240316",
"types-pytz>=2024.1.0.20240417",
"types-openpyxl>=3.1.0.20240428",
]
21 changes: 17 additions & 4 deletions pysr/denoising.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
"""Functions for denoising data during preprocessing."""

from typing import Optional, Tuple, cast

import numpy as np
from numpy import ndarray


def denoise(X, y, Xresampled=None, random_state=None):
def denoise(
X: ndarray,
y: ndarray,
Xresampled: Optional[ndarray] = None,
random_state: Optional[np.random.RandomState] = None,
) -> Tuple[ndarray, ndarray]:
"""Denoise the dataset using a Gaussian process."""
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel, WhiteKernel
Expand All @@ -15,12 +23,17 @@ def denoise(X, y, Xresampled=None, random_state=None):
gpr.fit(X, y)

if Xresampled is not None:
return Xresampled, gpr.predict(Xresampled)
return Xresampled, cast(ndarray, gpr.predict(Xresampled))

return X, gpr.predict(X)
return X, cast(ndarray, gpr.predict(X))


def multi_denoise(X, y, Xresampled=None, random_state=None):
def multi_denoise(
X: ndarray,
y: ndarray,
Xresampled: Optional[ndarray] = None,
random_state: Optional[np.random.RandomState] = None,
):
"""Perform `denoise` along each column of `y` independently."""
y = np.stack(
[
Expand Down
12 changes: 12 additions & 0 deletions pysr/export_latex.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,15 @@ def sympy2multilatextable(
]

return "\n\n".join(latex_tables)


def with_preamble(table_string: str) -> str:
preamble_string = [
r"\usepackage{breqn}",
r"\usepackage{booktabs}",
"",
"...",
"",
table_string,
]
return "\n".join(preamble_string)
12 changes: 10 additions & 2 deletions pysr/export_numpy.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Code for exporting discovered expressions to numpy"""

import warnings
from typing import List, Union

import numpy as np
import pandas as pd
from sympy import lambdify
from numpy.typing import NDArray
from sympy import Expr, Symbol, lambdify


def sympy2numpy(eqn, sympy_symbols, *, selection=None):
Expand All @@ -14,6 +16,10 @@ def sympy2numpy(eqn, sympy_symbols, *, selection=None):
class CallableEquation:
"""Simple wrapper for numpy lambda functions built with sympy"""

_sympy: Expr
_sympy_symbols: List[Symbol]
_selection: Union[NDArray[np.bool_], None]

def __init__(self, eqn, sympy_symbols, selection=None):
self._sympy = eqn
self._sympy_symbols = sympy_symbols
Expand All @@ -29,15 +35,17 @@ def __call__(self, X):
return self._lambda(
**{k: X[k].values for k in map(str, self._sympy_symbols)}
) * np.ones(expected_shape)

if self._selection is not None:
if X.shape[1] != len(self._selection):
if X.shape[1] != self._selection.sum():
warnings.warn(
"`X` should be of shape (n_samples, len(self._selection)). "
"Automatically filtering `X` to selection. "
"Note: Filtered `X` column order may not match column order in fit "
"this may lead to incorrect predictions and other errors."
)
X = X[:, self._selection]

return self._lambda(*X.T) * np.ones(expected_shape)

@property
Expand Down
12 changes: 7 additions & 5 deletions pysr/export_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import sympy
from sympy import sympify

from .utils import ArrayLike

sympy_mappings = {
"div": lambda x, y: x / y,
"mult": lambda x, y: x * y,
Expand All @@ -30,8 +32,8 @@
"acosh": lambda x: sympy.acosh(x),
"acosh_abs": lambda x: sympy.acosh(abs(x) + 1),
"asinh": sympy.asinh,
"atanh": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1),
"atanh_clip": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1),
"atanh": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - sympy.S(1)),
"atanh_clip": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - sympy.S(1)),
"abs": abs,
"mod": sympy.Mod,
"erf": sympy.erf,
Expand Down Expand Up @@ -60,21 +62,21 @@


def create_sympy_symbols_map(
feature_names_in: List[str],
feature_names_in: ArrayLike[str],
) -> Dict[str, sympy.Symbol]:
return {variable: sympy.Symbol(variable) for variable in feature_names_in}


def create_sympy_symbols(
feature_names_in: List[str],
feature_names_in: ArrayLike[str],
) -> List[sympy.Symbol]:
return [sympy.Symbol(variable) for variable in feature_names_in]


def pysr2sympy(
equation: str,
*,
feature_names_in: Optional[List[str]] = None,
feature_names_in: Optional[ArrayLike[str]] = None,
extra_sympy_mappings: Optional[Dict[str, Callable]] = None,
):
if feature_names_in is None:
Expand Down
22 changes: 19 additions & 3 deletions pysr/feature_selection.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
"""Functions for doing feature selection during preprocessing."""

from typing import Optional, cast

import numpy as np
from numpy import ndarray
from numpy.typing import NDArray

from .utils import ArrayLike


def run_feature_selection(X, y, select_k_features, random_state=None):
def run_feature_selection(
X: ndarray,
y: ndarray,
select_k_features: int,
random_state: Optional[np.random.RandomState] = None,
) -> NDArray[np.bool_]:
"""
Find most important features.
Expand All @@ -21,11 +32,16 @@ def run_feature_selection(X, y, select_k_features, random_state=None):
selector = SelectFromModel(
clf, threshold=-np.inf, max_features=select_k_features, prefit=True
)
return selector.get_support(indices=True)
return cast(NDArray[np.bool_], selector.get_support(indices=False))


# Function has not been removed only due to usage in module tests
def _handle_feature_selection(X, select_k_features, y, variable_names):
def _handle_feature_selection(
X: ndarray,
select_k_features: Optional[int],
y: ndarray,
variable_names: ArrayLike[str],
):
if select_k_features is not None:
selection = run_feature_selection(X, y, select_k_features)
print(f"Using features {[variable_names[i] for i in selection]}")
Expand Down
22 changes: 17 additions & 5 deletions pysr/julia_helpers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
"""Functions for initializing the Julia environment and installing deps."""

from typing import Any, Callable, Union, cast

import numpy as np
from juliacall import convert as jl_convert # type: ignore
from numpy.typing import NDArray

from .deprecated import init_julia, install
from .julia_import import jl

jl_convert = cast(Callable[[Any, Any], Any], jl_convert)

jl.seval("using Serialization: Serialization")
jl.seval("using PythonCall: PythonCall")

Expand All @@ -22,24 +27,31 @@ def _escape_filename(filename):
return str_repr


def _load_cluster_manager(cluster_manager):
def _load_cluster_manager(cluster_manager: str):
jl.seval(f"using ClusterManagers: addprocs_{cluster_manager}")
return jl.seval(f"addprocs_{cluster_manager}")


def jl_array(x):
def jl_array(x, dtype=None):
if x is None:
return None
return jl_convert(jl.Array, x)
elif dtype is None:
return jl_convert(jl.Array, x)
else:
return jl_convert(jl.Array[dtype], x)


def jl_is_function(f) -> bool:
return cast(bool, jl.seval("op -> op isa Function")(f))


def jl_serialize(obj):
def jl_serialize(obj: Any) -> NDArray[np.uint8]:
buf = jl.IOBuffer()
Serialization.serialize(buf, obj)
return np.array(jl.take_b(buf))


def jl_deserialize(s):
def jl_deserialize(s: Union[NDArray[np.uint8], None]):
if s is None:
return s
buf = jl.IOBuffer()
Expand Down
5 changes: 5 additions & 0 deletions pysr/julia_import.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import sys
import warnings
from types import ModuleType
from typing import cast

# Check if JuliaCall is already loaded, and if so, warn the user
# about the relevant environment variables. If not loaded,
Expand Down Expand Up @@ -42,6 +44,9 @@

from juliacall import Main as jl # type: ignore

jl = cast(ModuleType, jl)


jl_version = (jl.VERSION.major, jl.VERSION.minor, jl.VERSION.patch)

jl.seval("using SymbolicRegression")
Expand Down
Loading

0 comments on commit f653388

Please sign in to comment.