Skip to content

Commit 2baa25e

Browse files
FIX beam_width upper bound (#193)
1 parent 085a225 commit 2baa25e

File tree

3 files changed

+20
-7
lines changed

3 files changed

+20
-7
lines changed

fastcan/_fastcan.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -216,17 +216,21 @@ def fit(self, X, y):
216216
"`indices_include` and `indices_exclude` should not have intersection."
217217
)
218218

219-
if (
220-
n_features - self.indices_exclude_.size
221-
< self.n_features_to_select + self.beam_width - 1
222-
):
219+
if n_features - self.indices_exclude_.size < self.n_features_to_select:
223220
raise ValueError(
224-
"n_features - n_exclusions should >= "
225-
"n_features_to_select + beam_width - 1."
221+
"n_features_to_select should <= n_features - n_exclusions."
226222
)
227223
if self.n_features_to_select < self.indices_include_.size:
228224
raise ValueError("n_features_to_select should >= n_inclusions.")
229225

226+
if (
227+
self.beam_width
228+
> n_features - self.indices_exclude_.size - self.indices_include_.size
229+
):
230+
raise ValueError(
231+
"beam_width should <= n_features - n_exclusions - n_inclusions."
232+
)
233+
230234
if self.eta:
231235
xy_hstack = np.hstack((X, y))
232236
xy_centered = xy_hstack - xy_hstack.mean(0)

tests/test_beam.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,15 @@ def test_beam_error():
3333
X_origin = rng.normal(size=(n_samples, n_features))
3434
y = rng.normal(size=n_samples)
3535

36+
X = X_origin.copy()
37+
# Should pass without error
38+
FastCan(n_features_to_select=1, beam_width=n_features).fit(X, y)
39+
# Should raise an error
40+
with pytest.raises(ValueError, match=r"beam_width should <= .*"):
41+
FastCan(n_features_to_select=1, indices_include=[0], beam_width=n_features).fit(
42+
X, y
43+
)
44+
3645
X = X_origin.copy()
3746
X[:, [0, 1, 2]] = 0 # Zero feature
3847
with pytest.raises(ValueError, match=r"Beam Search: Not enough valid candidates.*"):

tests/test_fastcan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def test_raise_errors():
237237
with pytest.raises(ValueError, match=r"`indices_include` and `indices_exclu.*"):
238238
selector_include_exclude_intersect.fit(X, y)
239239

240-
with pytest.raises(ValueError, match=r"n_features - n_exclusions should.*"):
240+
with pytest.raises(ValueError, match=r"n_features_to_select should <=.*"):
241241
selector_n_candidates.fit(X, y)
242242

243243
with pytest.raises(ValueError, match=r"n_features_to_select should.*"):

0 commit comments

Comments
 (0)