Skip to content

Commit

Permalink
install jax jaxlib at runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
thierrymoudiki committed Oct 6, 2024
1 parent aad3ad2 commit f99dd1c
Show file tree
Hide file tree
Showing 18 changed files with 1,737 additions and 1,635 deletions.
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -96,5 +96,8 @@ install: clean ## install the package to the active Python's site-packages
run-examples: ## run all examples with one command
find examples -maxdepth 2 -name "*.py" -exec python3 {} \;

run-booster: ## run all boosting estimators examples with one command
find examples -maxdepth 2 -name "*boost_*.py" -exec python3 {} \;

run-lazy: ## run all lazy estimators examples with one command
find examples -maxdepth 2 -name "*lazy*.py" -exec python3 {} \;
4 changes: 1 addition & 3 deletions mlsauce.egg-info/PKG-INFO
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: mlsauce
Version: 0.20.1
Version: 0.20.2
Summary: Miscellaneous Statistical/Machine Learning tools
Maintainer: T. Moudiki
Maintainer-email: [email protected]
Expand Down Expand Up @@ -29,8 +29,6 @@ Requires-Dist: requests
Requires-Dist: scikit-learn
Requires-Dist: scipy
Requires-Dist: tqdm
Requires-Dist: jax
Requires-Dist: jaxlib
Provides-Extra: alldeps
Requires-Dist: numpy>=1.13.0; extra == "alldeps"
Requires-Dist: scipy>=0.19.0; extra == "alldeps"
Expand Down
2 changes: 0 additions & 2 deletions mlsauce.egg-info/requires.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ requests
scikit-learn
scipy
tqdm
jax
jaxlib

[alldeps]
numpy>=1.13.0
Expand Down
5 changes: 4 additions & 1 deletion mlsauce/booster/_booster_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from . import _boosterc as boosterc
except ImportError:
import _boosterc as boosterc
from ..utils import cluster
from ..utils import cluster, check_and_install


class LSBoostClassifier(BaseEstimator, ClassifierMixin):
Expand Down Expand Up @@ -167,6 +167,9 @@ def __init__(
self.degree = degree
self.poly_ = None
self.weights_distr = weights_distr
if self.backend in ("gpu", "tpu"):
check_and_install("jax")
check_and_install("jaxlib")

def fit(self, X, y, **kwargs):
"""Fit Booster (classifier) to training data (X, y)
Expand Down
5 changes: 4 additions & 1 deletion mlsauce/booster/_booster_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
except ImportError:
import _boosterc as boosterc
from ..predictioninterval import PredictionInterval
from ..utils import cluster
from ..utils import cluster, check_and_install


class LSBoostRegressor(BaseEstimator, RegressorMixin):
Expand Down Expand Up @@ -183,6 +183,9 @@ def __init__(
self.degree = degree
self.poly_ = None
self.weights_distr = weights_distr
if self.backend in ("gpu", "tpu"):
check_and_install("jax")
check_and_install("jaxlib")

def fit(self, X, y, **kwargs):
"""Fit Booster (regressor) to training data (X, y)
Expand Down
9 changes: 7 additions & 2 deletions mlsauce/elasticnet/enet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
from sklearn.base import BaseEstimator
from sklearn.base import RegressorMixin
from numpy.linalg import inv
from ..utils import get_beta
from ..utils import get_beta, check_and_install
from ._enet import fit_elasticnet, predict_elasticnet

if platform.system() in ("Linux", "Darwin"):
try:
import jax.numpy as jnp
from jax import device_put
from jax.numpy.linalg import inv as jinv
except ImportError:
pass


class ElasticNetRegressor(BaseEstimator, RegressorMixin):
Expand Down Expand Up @@ -48,6 +50,9 @@ def __init__(self, reg_lambda=0.1, alpha=0.5, backend="cpu"):
self.reg_lambda = reg_lambda
self.alpha = alpha
self.backend = backend
if self.backend in ("gpu", "tpu"):
check_and_install("jax")
check_and_install("jaxlib")

def fit(self, X, y, **kwargs):
"""Fit matrixops (classifier) to training data (X, y)
Expand Down
9 changes: 7 additions & 2 deletions mlsauce/lasso/_lasso.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
from . import _lassoc as mo
except ImportError:
import _lassoc as mo
from ..utils import get_beta
from ..utils import get_beta, check_and_install

if platform.system() in ("Linux", "Darwin"):
try:
import jax.numpy as jnp
from jax import device_put
from jax.numpy.linalg import inv as jinv
except ImportError:
pass


class LassoRegressor(BaseEstimator, RegressorMixin):
Expand Down Expand Up @@ -56,6 +58,9 @@ def __init__(self, reg_lambda=0.1, max_iter=10, tol=1e-3, backend="cpu"):
self.max_iter = max_iter
self.tol = tol
self.backend = backend
if self.backend in ("gpu", "tpu"):
check_and_install("jax")
check_and_install("jaxlib")

def fit(self, X, y, **kwargs):
"""Fit matrixops (classifier) to training data (X, y)
Expand Down
Loading

0 comments on commit f99dd1c

Please sign in to comment.