Skip to content

Commit a344b8a

Browse files
committed
Add paramter validation
1 parent 8e7a286 commit a344b8a

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

bayes_opt/acquisition.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,12 @@ def __init__(
452452
if kappa < 0:
453453
error_msg = "kappa must be greater than or equal to 0."
454454
raise ValueError(error_msg)
455+
if exploration_decay is not None and not (0 < exploration_decay <= 1):
456+
error_msg = "exploration_decay must be greater than 0 and less than or equal to 1."
457+
raise ValueError(error_msg)
458+
if exploration_decay_delay is not None and (not isinstance(exploration_decay_delay, int) or exploration_decay_delay < 0):
459+
error_msg = "exploration_decay_delay must be an integer greater than or equal to 0."
460+
raise ValueError(error_msg)
455461

456462
super().__init__(random_state=random_state)
457463
self.kappa = kappa
@@ -604,6 +610,16 @@ def __init__(
604610
exploration_decay_delay: int | None = None,
605611
random_state: int | RandomState | None = None,
606612
) -> None:
613+
if xi <= 0:
614+
error_msg = "xi must be greater than 0."
615+
raise ValueError(error_msg)
616+
if exploration_decay is not None and not (0 < exploration_decay <= 1):
617+
error_msg = "exploration_decay must be greater than 0 and less than or equal to 1."
618+
raise ValueError(error_msg)
619+
if exploration_decay_delay is not None and (not isinstance(exploration_decay_delay, int) or exploration_decay_delay < 0):
620+
error_msg = "exploration_decay_delay must be an integer greater than or equal to 0."
621+
raise ValueError(error_msg)
622+
607623
super().__init__(random_state=random_state)
608624
self.xi = xi
609625
self.exploration_decay = exploration_decay
@@ -778,6 +794,16 @@ def __init__(
778794
exploration_decay_delay: int | None = None,
779795
random_state: int | RandomState | None = None,
780796
) -> None:
797+
if xi <= 0:
798+
error_msg = "xi must be greater than 0."
799+
raise ValueError(error_msg)
800+
if exploration_decay is not None and not (0 < exploration_decay <= 1):
801+
error_msg = "exploration_decay must be greater than 0 and less than or equal to 1."
802+
raise ValueError(error_msg)
803+
if exploration_decay_delay is not None and (not isinstance(exploration_decay_delay, int) or exploration_decay_delay < 0):
804+
error_msg = "exploration_decay_delay must be an integer greater than or equal to 0."
805+
raise ValueError(error_msg)
806+
781807
super().__init__(random_state=random_state)
782808
self.xi = xi
783809
self.exploration_decay = exploration_decay

0 commit comments

Comments
 (0)