Skip to content

Commit

Permalink
Update _adaopt.py
Browse files Browse the repository at this point in the history
  • Loading branch information
thierrymoudiki authored Apr 24, 2024
1 parent 1f34d62 commit b3b86f3
Showing 1 changed file with 18 additions and 16 deletions.
34 changes: 18 additions & 16 deletions mlsauce/adaopt/_adaopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
try:
from . import _adaoptc as adaoptc
except ImportError:
import _adaoptc as adaoptc
from ._adaoptc import *
except ImportError:
import _adaoptc



Expand Down Expand Up @@ -194,7 +196,7 @@ def fit(self, X, y, **kwargs):

assert n == len(y_), "must have X.shape[0] == len(y)"

res = adaoptc.fit_adaopt(
res = fit_adaopt(
X=np.asarray(X_).astype(np.float64),
y=np.asarray(y_).astype(np.int64),
n_iterations=self.n_iterations,
Expand Down Expand Up @@ -268,7 +270,7 @@ def predict_proba(self, X, **kwargs):
n_test = X.shape[0]

if self.n_jobs is None:
return adaoptc.predict_proba_adaopt(
return predict_proba_adaopt(
X_test=np.asarray(X, order="C").astype(np.float64),
scaled_X_train=np.asarray(
self.scaled_X_train, order="C"
Expand Down Expand Up @@ -310,17 +312,17 @@ def multiproc_func(i):
p_train,
)

kmin_test_i = adaoptc.find_kmin_x(
kmin_test_i = find_kmin_x(
dists_test_i, n_x=n_train, k=self.k, cache=self.cache
)

weights_test_i = adaoptc.calculate_weights(kmin_test_i[0])
weights_test_i = calculate_weights(kmin_test_i[0])

probs_test_i = adaoptc.calculate_probs(
probs_test_i = calculate_probs(
kmin_test_i[1], self.probs_training
)

return adaoptc.average_probs(
return average_probs(
probs=probs_test_i, weights=weights_test_i
)

Expand All @@ -329,7 +331,7 @@ def multiproc_func(i):
@delayed
@wrap_non_picklable_objects
def multiproc_func(i):
dists_test_i = adaoptc.distance_to_mat_manhattan2(
dists_test_i = distance_to_mat_manhattan2(
np.asarray(scaled_X_test.astype(np.float64), order="C")[
i, :
],
Expand All @@ -341,17 +343,17 @@ def multiproc_func(i):
p_train,
)

kmin_test_i = adaoptc.find_kmin_x(
kmin_test_i = find_kmin_x(
dists_test_i, n_x=n_train, k=self.k, cache=self.cache
)

weights_test_i = adaoptc.calculate_weights(kmin_test_i[0])
weights_test_i = calculate_weights(kmin_test_i[0])

probs_test_i = adaoptc.calculate_probs(
probs_test_i = calculate_probs(
kmin_test_i[1], self.probs_training
)

return adaoptc.average_probs(
return average_probs(
probs=probs_test_i, weights=weights_test_i
)

Expand All @@ -372,17 +374,17 @@ def multiproc_func(i, *args):
p_train,
)

kmin_test_i = adaoptc.find_kmin_x(
kmin_test_i = find_kmin_x(
dists_test_i, n_x=n_train, k=self.k, cache=self.cache
)

weights_test_i = adaoptc.calculate_weights(kmin_test_i[0])
weights_test_i = calculate_weights(kmin_test_i[0])

probs_test_i = adaoptc.calculate_probs(
probs_test_i = calculate_probs(
kmin_test_i[1], self.probs_training
)

return adaoptc.average_probs(
return average_probs(
probs=probs_test_i, weights=weights_test_i
)

Expand Down

0 comments on commit b3b86f3

Please sign in to comment.