@@ -83,6 +83,10 @@ def test_constant(self):
8383 c1_meta = Constant ("value" , 1 , dict (self .meta_data ))
8484 self .assertEqual (c1_meta .meta , self .meta_data )
8585
86+ # Test getting the size
87+ for constant in (c1 , c2 , c3 , c4 , c5 , c1_meta ):
88+ self .assertEqual (constant .get_size (), 1 )
89+
8690 def test_uniformfloat (self ):
8791 # TODO test non-equality
8892 # TODO test sampling from a log-distribution which has a negative
@@ -143,6 +147,12 @@ def test_uniformfloat(self):
143147 default_value = 1.0 , meta = dict (self .meta_data ))
144148 self .assertEqual (f_meta .meta , self .meta_data )
145149
150+ # Test get_size
151+ for float_hp in (f1 , f3 , f4 ):
152+ self .assertTrue (np .isinf (float_hp .get_size ()))
153+ self .assertEqual (f2 .get_size (), 101 )
154+ self .assertEqual (f5 .get_size (), 100 )
155+
146156 def test_uniformfloat_to_integer (self ):
147157 f1 = UniformFloatHyperparameter ("param" , 1 , 10 , q = 0.1 , log = True )
148158 with pytest .warns (UserWarning , match = "Setting quantization < 1 for Integer "
@@ -281,6 +291,11 @@ def test_normalfloat(self):
281291 default_value = 1.0 , meta = dict (self .meta_data ))
282292 self .assertEqual (f_meta .meta , self .meta_data )
283293
294+ # Test get_size
295+ for float_hp in (f1 , f2 , f3 , f4 , f5 ):
296+ self .assertTrue (np .isinf (float_hp .get_size ()))
297+ self .assertEqual (f6 .get_size (), 100 )
298+
284299 def test_normalfloat_to_uniformfloat (self ):
285300 f1 = NormalFloatHyperparameter ("param" , 0 , 10 , q = 0.1 )
286301 f1_expected = UniformFloatHyperparameter ("param" , - 30 , 30 , q = 0.1 )
@@ -343,12 +358,12 @@ def test_uniforminteger(self):
343358 "param, Type: UniformInteger, Range: [0, 10], Default: 5" ,
344359 str (f2 ))
345360
346- # f2_large_q = UniformIntegerHyperparameter("param", 0, 10, q=2)
347- # f2_large_q_ = UniformIntegerHyperparameter("param", 0, 10, q=2)
348- # self.assertEqual(f2_large_q, f2_large_q_)
349- # self.assertEqual(
350- # "param, Type: UniformInteger, Range: [0, 10], Default: 5, Q: 2",
351- # str(f2_large_q))
361+ f2_large_q = UniformIntegerHyperparameter ("param" , 0 , 10 , q = 2 )
362+ f2_large_q_ = UniformIntegerHyperparameter ("param" , 0 , 10 , q = 2 )
363+ self .assertEqual (f2_large_q , f2_large_q_ )
364+ self .assertEqual (
365+ "param, Type: UniformInteger, Range: [0, 10], Default: 5, Q: 2" ,
366+ str (f2_large_q ))
352367
353368 f3 = UniformIntegerHyperparameter ("param" , 1 , 10 , log = True )
354369 f3_ = UniformIntegerHyperparameter ("param" , 1 , 10 , log = True )
@@ -382,6 +397,13 @@ def test_uniforminteger(self):
382397 default_value = 1 , meta = dict (self .meta_data ))
383398 self .assertEqual (f_meta .meta , self .meta_data )
384399
400+ self .assertEqual (f1 .get_size (), 6 )
401+ self .assertEqual (f2 .get_size (), 11 )
402+ self .assertEqual (f2_large_q .get_size (), 6 )
403+ self .assertEqual (f3 .get_size (), 10 )
404+ self .assertEqual (f4 .get_size (), 10 )
405+ self .assertEqual (f5 .get_size (), 10 )
406+
385407 def test_uniformint_legal_float_values (self ):
386408 n_iter = UniformIntegerHyperparameter ("n_iter" , 5. , 1000. , default_value = 20.0 )
387409
@@ -475,6 +497,10 @@ def test_normalint(self):
475497 meta = dict (self .meta_data ))
476498 self .assertEqual (f_meta .meta , self .meta_data )
477499
500+ # Test get_size
501+ for int_hp in (f1 , f2 , f3 , f4 , f5 ):
502+ self .assertTrue (np .isinf (int_hp .get_size ()))
503+
478504 def test_normalint_legal_float_values (self ):
479505 n_iter = NormalIntegerHyperparameter ("n_iter" , 0 , 1. , default_value = 2.0 )
480506 self .assertIsInstance (n_iter .default_value , int )
@@ -556,6 +582,13 @@ def test_categorical(self):
556582 meta = dict (self .meta_data ))
557583 self .assertEqual (f_meta .meta , self .meta_data )
558584
585+ self .assertEqual (f1 .get_size (), 2 )
586+ self .assertEqual (f2 .get_size (), 1000 )
587+ self .assertEqual (f3 .get_size (), 999 )
588+ self .assertEqual (f4 .get_size (), 1000 )
589+ self .assertEqual (f5 .get_size (), 1000 )
590+ self .assertEqual (f6 .get_size (), 2 )
591+
559592 def test_categorical_strings (self ):
560593 f1 = CategoricalHyperparameter ("param" , ["a" , "b" ])
561594 f1_ = CategoricalHyperparameter ("param" , ["a" , "b" ])
@@ -1155,6 +1188,10 @@ def test_get_num_neighbors(self):
11551188 self .assertEqual (f1 .get_num_neighbors ("hot" ), 1 )
11561189 self .assertEqual (f1 .get_num_neighbors ("cold" ), 2 )
11571190
1191+ def test_ordinal_get_size (self ):
1192+ f1 = OrdinalHyperparameter ("temp" , ["freezing" , "cold" , "warm" , "hot" ])
1193+ self .assertEqual (f1 .get_size (), 4 )
1194+
11581195 def test_rvs (self ):
11591196 f1 = UniformFloatHyperparameter ("param" , 0 , 10 )
11601197
0 commit comments