Skip to content

Commit

Permalink
remove try to get error
Browse files Browse the repository at this point in the history
  • Loading branch information
thierrymoudiki committed Oct 9, 2024
1 parent 541f7bb commit bd692f4
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 46 deletions.
43 changes: 20 additions & 23 deletions mlsauce/booster/_booster_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,29 +411,26 @@ def fit(self, X, y, **kwargs):
)
X = np.column_stack((X, clustered_X))

try:
self.obj = boosterc.fit_booster_classifier(
np.asarray(X, order="C"),
np.asarray(y, order="C"),
n_estimators=self.n_estimators,
learning_rate=self.learning_rate,
n_hidden_features=self.n_hidden_features,
reg_lambda=self.reg_lambda,
alpha=self.alpha,
row_sample=self.row_sample,
col_sample=self.col_sample,
dropout=self.dropout,
tolerance=self.tolerance,
direct_link=self.direct_link,
verbose=self.verbose,
seed=self.seed,
backend=self.backend,
solver=self.solver,
activation=self.activation,
obj=self.base_model,
)
except ValueError:
pass
self.obj = boosterc.fit_booster_classifier(
np.asarray(X, order="C"),
np.asarray(y, order="C"),
n_estimators=self.n_estimators,
learning_rate=self.learning_rate,
n_hidden_features=self.n_hidden_features,
reg_lambda=self.reg_lambda,
alpha=self.alpha,
row_sample=self.row_sample,
col_sample=self.col_sample,
dropout=self.dropout,
tolerance=self.tolerance,
direct_link=self.direct_link,
verbose=self.verbose,
seed=self.seed,
backend=self.backend,
solver=self.solver,
activation=self.activation,
obj=self.base_model,
)

self.n_classes_ = len(np.unique(y)) # for compatibility with sklearn
self.n_estimators = self.obj["n_estimators"]
Expand Down
43 changes: 20 additions & 23 deletions mlsauce/booster/_booster_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,29 +281,26 @@ def fit(self, X, y, **kwargs):
)
X = np.column_stack((X, clustered_X))

try:
self.obj = boosterc.fit_booster_regressor(
X=np.asarray(X, order="C"),
y=np.asarray(y, order="C"),
n_estimators=self.n_estimators,
learning_rate=self.learning_rate,
n_hidden_features=self.n_hidden_features,
reg_lambda=self.reg_lambda,
alpha=self.alpha,
row_sample=self.row_sample,
col_sample=self.col_sample,
dropout=self.dropout,
tolerance=self.tolerance,
direct_link=self.direct_link,
verbose=self.verbose,
seed=self.seed,
backend=self.backend,
solver=self.solver,
activation=self.activation,
obj=self.base_model,
)
except ValueError:
pass
self.obj = boosterc.fit_booster_regressor(
X=np.asarray(X, order="C"),
y=np.asarray(y, order="C"),
n_estimators=self.n_estimators,
learning_rate=self.learning_rate,
n_hidden_features=self.n_hidden_features,
reg_lambda=self.reg_lambda,
alpha=self.alpha,
row_sample=self.row_sample,
col_sample=self.col_sample,
dropout=self.dropout,
tolerance=self.tolerance,
direct_link=self.direct_link,
verbose=self.verbose,
seed=self.seed,
backend=self.backend,
solver=self.solver,
activation=self.activation,
obj=self.base_model,
)

self.n_estimators = self.obj["n_estimators"]

Expand Down

0 comments on commit bd692f4

Please sign in to comment.