Skip to content

Commit 9856e62

Browse files
authored
Add get_size() (#208)
* Add get_size() * Improve documentation * Fix unit test name
1 parent cdad6b3 commit 9856e62

File tree

4 files changed

+123
-6
lines changed

4 files changed

+123
-6
lines changed

ConfigSpace/configuration_space.pyx

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1341,6 +1341,32 @@ class ConfigurationSpace(collections.abc.Mapping):
13411341
"""
13421342
self.random = np.random.RandomState(seed)
13431343

1344+
def estimate_size(self) -> Union[float, int]:
1345+
"""
1346+
Estimate the size of the current configuration space (i.e. unique configurations).
1347+
1348+
This is ``np.inf`` in case if there is a single hyperparameter of size ``np.inf`` (i.e. a
1349+
:class:`~ConfigSpace.hyperparameters.UniformFloatHyperparameter`), otherwise
1350+
it is the product of the size of all hyperparameters. The function correctly guesses the
1351+
number of unique configurations if there are no condition and forbidden statements in the
1352+
configuration spaces. Otherwise, this is an upper bound. Use
1353+
:func:`~ConfigSpace.util.generate_grid` to generate all valid configurations if required.
1354+
1355+
Returns
1356+
-------
1357+
Union[float, int]
1358+
"""
1359+
sizes = []
1360+
for hp in self._hyperparameters.values():
1361+
sizes.append(hp.get_size())
1362+
if len(sizes) == 0:
1363+
return 0.0
1364+
else:
1365+
size = sizes[0]
1366+
for i in range(1, len(sizes)):
1367+
size = size * sizes[i]
1368+
return size
1369+
13441370

13451371
class Configuration(collections.abc.Mapping):
13461372
def __init__(self, configuration_space: ConfigurationSpace,

ConfigSpace/hyperparameters.pyx

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,9 @@ cdef class Hyperparameter(object):
151151
cpdef int compare_vector(self, DTYPE_t value, DTYPE_t value2):
152152
raise NotImplementedError()
153153

154+
def get_size(self) -> float:
155+
raise NotImplementedError()
156+
154157

155158
cdef class Constant(Hyperparameter):
156159
cdef public value
@@ -255,6 +258,8 @@ cdef class Constant(Hyperparameter):
255258
transform: bool = False) -> List:
256259
return []
257260

261+
def get_size(self) -> float:
262+
return 1.0
258263

259264
cdef class UnParametrizedHyperparameter(Constant):
260265
pass
@@ -603,6 +608,12 @@ cdef class UniformFloatHyperparameter(FloatHyperparameter):
603608
neighbors.append(neighbor)
604609
return neighbors
605610

611+
def get_size(self) -> float:
612+
if self.q is None:
613+
return np.inf
614+
else:
615+
return np.rint((self.upper - self.lower) / self.q) + 1
616+
606617

607618
cdef class NormalFloatHyperparameter(FloatHyperparameter):
608619
cdef public mu
@@ -857,6 +868,14 @@ cdef class NormalFloatHyperparameter(FloatHyperparameter):
857868
neighbors.append(new_value)
858869
return neighbors
859870

871+
def get_size(self) -> float:
872+
if self.q is None:
873+
return np.inf
874+
elif self.lower is None:
875+
return np.inf
876+
else:
877+
return np.rint((self.upper - self.lower) / self.q) + 1
878+
860879

861880
cdef class UniformIntegerHyperparameter(IntegerHyperparameter):
862881
def __init__(self, name: str, lower: int, upper: int, default_value: Union[int, None] = None,
@@ -1108,6 +1127,13 @@ cdef class UniformIntegerHyperparameter(IntegerHyperparameter):
11081127

11091128
return neighbors
11101129

1130+
def get_size(self) -> float:
1131+
if self.q is None:
1132+
q = 1
1133+
else:
1134+
q = self.q
1135+
return np.rint((self.upper - self.lower) / q) + 1
1136+
11111137

11121138
cdef class NormalIntegerHyperparameter(IntegerHyperparameter):
11131139
cdef public mu
@@ -1358,6 +1384,16 @@ cdef class NormalIntegerHyperparameter(IntegerHyperparameter):
13581384
neighbors.append(new_value)
13591385
return neighbors
13601386

1387+
def get_size(self) -> float:
1388+
if self.lower is None:
1389+
return np.inf
1390+
else:
1391+
if self.q is None:
1392+
q = 1
1393+
else:
1394+
q = self.q
1395+
return np.rint((self.upper - self.lower) / self.q) + 1
1396+
13611397

13621398
cdef class CategoricalHyperparameter(Hyperparameter):
13631399
cdef public tuple choices
@@ -1623,6 +1659,9 @@ cdef class CategoricalHyperparameter(Hyperparameter):
16231659
"OrdinalHyperparameter, but is "
16241660
"<cdef class 'ConfigSpace.hyperparameters.CategoricalHyperparameter'>")
16251661

1662+
def get_size(self) -> float:
1663+
return len(self.choices)
1664+
16261665

16271666
cdef class OrdinalHyperparameter(Hyperparameter):
16281667
cdef public tuple sequence
@@ -1908,3 +1947,6 @@ cdef class OrdinalHyperparameter(Hyperparameter):
19081947

19091948
def allow_greater_less_comparison(self) -> bool:
19101949
return True
1950+
1951+
def get_size(self) -> float:
1952+
return len(self.sequence)

test/test_configuration_space.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,18 @@ def test_acts_as_mapping(self):
803803
assert list(d.items()) == list(zip(names, hyperparameters))
804804
assert len(d) == 5
805805

806+
def test_estimate_size(self):
807+
cs = ConfigurationSpace()
808+
self.assertEqual(cs.estimate_size(), 0)
809+
cs.add_hyperparameter(Constant('constant', 0))
810+
self.assertEqual(cs.estimate_size(), 1)
811+
cs.add_hyperparameter(UniformIntegerHyperparameter('integer', 0, 5))
812+
self.assertEqual(cs.estimate_size(), 6)
813+
cs.add_hyperparameter(CategoricalHyperparameter('cat', [0, 1, 2]))
814+
self.assertEqual(cs.estimate_size(), 18)
815+
cs.add_hyperparameter(UniformFloatHyperparameter('float', 0, 1))
816+
self.assertTrue(np.isinf(cs.estimate_size()))
817+
806818

807819
class ConfigurationTest(unittest.TestCase):
808820
def setUp(self):

test/test_hyperparameters.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ def test_constant(self):
8383
c1_meta = Constant("value", 1, dict(self.meta_data))
8484
self.assertEqual(c1_meta.meta, self.meta_data)
8585

86+
# Test getting the size
87+
for constant in (c1, c2, c3, c4, c5, c1_meta):
88+
self.assertEqual(constant.get_size(), 1)
89+
8690
def test_uniformfloat(self):
8791
# TODO test non-equality
8892
# TODO test sampling from a log-distribution which has a negative
@@ -143,6 +147,12 @@ def test_uniformfloat(self):
143147
default_value=1.0, meta=dict(self.meta_data))
144148
self.assertEqual(f_meta.meta, self.meta_data)
145149

150+
# Test get_size
151+
for float_hp in (f1, f3, f4):
152+
self.assertTrue(np.isinf(float_hp.get_size()))
153+
self.assertEqual(f2.get_size(), 101)
154+
self.assertEqual(f5.get_size(), 100)
155+
146156
def test_uniformfloat_to_integer(self):
147157
f1 = UniformFloatHyperparameter("param", 1, 10, q=0.1, log=True)
148158
with pytest.warns(UserWarning, match="Setting quantization < 1 for Integer "
@@ -281,6 +291,11 @@ def test_normalfloat(self):
281291
default_value=1.0, meta=dict(self.meta_data))
282292
self.assertEqual(f_meta.meta, self.meta_data)
283293

294+
# Test get_size
295+
for float_hp in (f1, f2, f3, f4, f5):
296+
self.assertTrue(np.isinf(float_hp.get_size()))
297+
self.assertEqual(f6.get_size(), 100)
298+
284299
def test_normalfloat_to_uniformfloat(self):
285300
f1 = NormalFloatHyperparameter("param", 0, 10, q=0.1)
286301
f1_expected = UniformFloatHyperparameter("param", -30, 30, q=0.1)
@@ -343,12 +358,12 @@ def test_uniforminteger(self):
343358
"param, Type: UniformInteger, Range: [0, 10], Default: 5",
344359
str(f2))
345360

346-
# f2_large_q = UniformIntegerHyperparameter("param", 0, 10, q=2)
347-
# f2_large_q_ = UniformIntegerHyperparameter("param", 0, 10, q=2)
348-
# self.assertEqual(f2_large_q, f2_large_q_)
349-
# self.assertEqual(
350-
# "param, Type: UniformInteger, Range: [0, 10], Default: 5, Q: 2",
351-
# str(f2_large_q))
361+
f2_large_q = UniformIntegerHyperparameter("param", 0, 10, q=2)
362+
f2_large_q_ = UniformIntegerHyperparameter("param", 0, 10, q=2)
363+
self.assertEqual(f2_large_q, f2_large_q_)
364+
self.assertEqual(
365+
"param, Type: UniformInteger, Range: [0, 10], Default: 5, Q: 2",
366+
str(f2_large_q))
352367

353368
f3 = UniformIntegerHyperparameter("param", 1, 10, log=True)
354369
f3_ = UniformIntegerHyperparameter("param", 1, 10, log=True)
@@ -382,6 +397,13 @@ def test_uniforminteger(self):
382397
default_value=1, meta=dict(self.meta_data))
383398
self.assertEqual(f_meta.meta, self.meta_data)
384399

400+
self.assertEqual(f1.get_size(), 6)
401+
self.assertEqual(f2.get_size(), 11)
402+
self.assertEqual(f2_large_q.get_size(), 6)
403+
self.assertEqual(f3.get_size(), 10)
404+
self.assertEqual(f4.get_size(), 10)
405+
self.assertEqual(f5.get_size(), 10)
406+
385407
def test_uniformint_legal_float_values(self):
386408
n_iter = UniformIntegerHyperparameter("n_iter", 5., 1000., default_value=20.0)
387409

@@ -475,6 +497,10 @@ def test_normalint(self):
475497
meta=dict(self.meta_data))
476498
self.assertEqual(f_meta.meta, self.meta_data)
477499

500+
# Test get_size
501+
for int_hp in (f1, f2, f3, f4, f5):
502+
self.assertTrue(np.isinf(int_hp.get_size()))
503+
478504
def test_normalint_legal_float_values(self):
479505
n_iter = NormalIntegerHyperparameter("n_iter", 0, 1., default_value=2.0)
480506
self.assertIsInstance(n_iter.default_value, int)
@@ -556,6 +582,13 @@ def test_categorical(self):
556582
meta=dict(self.meta_data))
557583
self.assertEqual(f_meta.meta, self.meta_data)
558584

585+
self.assertEqual(f1.get_size(), 2)
586+
self.assertEqual(f2.get_size(), 1000)
587+
self.assertEqual(f3.get_size(), 999)
588+
self.assertEqual(f4.get_size(), 1000)
589+
self.assertEqual(f5.get_size(), 1000)
590+
self.assertEqual(f6.get_size(), 2)
591+
559592
def test_categorical_strings(self):
560593
f1 = CategoricalHyperparameter("param", ["a", "b"])
561594
f1_ = CategoricalHyperparameter("param", ["a", "b"])
@@ -1155,6 +1188,10 @@ def test_get_num_neighbors(self):
11551188
self.assertEqual(f1.get_num_neighbors("hot"), 1)
11561189
self.assertEqual(f1.get_num_neighbors("cold"), 2)
11571190

1191+
def test_ordinal_get_size(self):
1192+
f1 = OrdinalHyperparameter("temp", ["freezing", "cold", "warm", "hot"])
1193+
self.assertEqual(f1.get_size(), 4)
1194+
11581195
def test_rvs(self):
11591196
f1 = UniformFloatHyperparameter("param", 0, 10)
11601197

0 commit comments

Comments
 (0)