@@ -452,6 +452,12 @@ def __init__(
452452 if kappa < 0 :
453453 error_msg = "kappa must be greater than or equal to 0."
454454 raise ValueError (error_msg )
455+ if exploration_decay is not None and not (0 < exploration_decay <= 1 ):
456+ error_msg = "exploration_decay must be greater than 0 and less than or equal to 1."
457+ raise ValueError (error_msg )
458+ if exploration_decay_delay is not None and (not isinstance (exploration_decay_delay , int ) or exploration_decay_delay < 0 ):
459+ error_msg = "exploration_decay_delay must be an integer greater than or equal to 0."
460+ raise ValueError (error_msg )
455461
456462 super ().__init__ (random_state = random_state )
457463 self .kappa = kappa
@@ -604,6 +610,16 @@ def __init__(
604610 exploration_decay_delay : int | None = None ,
605611 random_state : int | RandomState | None = None ,
606612 ) -> None :
613+ if xi <= 0 :
614+ error_msg = "xi must be greater than 0."
615+ raise ValueError (error_msg )
616+ if exploration_decay is not None and not (0 < exploration_decay <= 1 ):
617+ error_msg = "exploration_decay must be greater than 0 and less than or equal to 1."
618+ raise ValueError (error_msg )
619+ if exploration_decay_delay is not None and (not isinstance (exploration_decay_delay , int ) or exploration_decay_delay < 0 ):
620+ error_msg = "exploration_decay_delay must be an integer greater than or equal to 0."
621+ raise ValueError (error_msg )
622+
607623 super ().__init__ (random_state = random_state )
608624 self .xi = xi
609625 self .exploration_decay = exploration_decay
@@ -778,6 +794,16 @@ def __init__(
778794 exploration_decay_delay : int | None = None ,
779795 random_state : int | RandomState | None = None ,
780796 ) -> None :
797+ if xi <= 0 :
798+ error_msg = "xi must be greater than 0."
799+ raise ValueError (error_msg )
800+ if exploration_decay is not None and not (0 < exploration_decay <= 1 ):
801+ error_msg = "exploration_decay must be greater than 0 and less than or equal to 1."
802+ raise ValueError (error_msg )
803+ if exploration_decay_delay is not None and (not isinstance (exploration_decay_delay , int ) or exploration_decay_delay < 0 ):
804+ error_msg = "exploration_decay_delay must be an integer greater than or equal to 0."
805+ raise ValueError (error_msg )
806+
781807 super ().__init__ (random_state = random_state )
782808 self .xi = xi
783809 self .exploration_decay = exploration_decay
0 commit comments