Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Add search space validation for choice types (#3975)
Browse files Browse the repository at this point in the history
  • Loading branch information
liuzhe-lz authored Jul 26, 2021
1 parent ef9e27b commit b0de7c9
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
5 changes: 5 additions & 0 deletions nni/common/hpo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ def validate_search_space(
raise ValueError(f'search space "{name}"\'s value is not a list : {spec}')

if type_ == 'choice':
if not all(isinstance(arg, (float, int, str)) for arg in args):
# FIXME: need further check for each algorithm which types are actually supported
# for now validation only prints warning so it doesn't harm
if not isinstance(args[0], dict) or '_name' not in args[0]: # not nested search space
raise ValueError(f'search space "{name}" (choice) should only contain numbers or strings : {spec}')
continue

if type_.startswith('q'):
Expand Down
12 changes: 12 additions & 0 deletions test/ut/sdk/test_hpo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,22 @@
'choice': good['choice'],
'randint': good['randint'],
}
good_nested = {
'outer': {
'_type': 'choice',
'_value': [
{ '_name': 'empty' },
{ '_name': 'a', 'a_1': { '_type': 'choice', '_value': ['a', 'b'] } }
]
}
}

bad_type = 'x'
bad_spec_type = { 'x': [1, 2, 3] }
bad_fields = { 'x': { 'type': 'choice', 'value': ['a', 'b'] } }
bad_type_name = { 'x': { '_type': 'choic', '_value': ['a'] } }
bad_value = { 'x': { '_type': 'choice', '_value': 'ab' } }
bad_choice_args = { 'x': { '_type': 'choice', 'value': [ 'a', object() ] } }
bad_2_args = { 'x': { '_type': 'randint', '_value': [1, 2, 3] } }
bad_3_args = { 'x': { '_type': 'quniform', '_value': [0] } }
bad_int_args = { 'x': { '_type': 'randint', '_value': [1.0, 2.0] } }
Expand All @@ -32,11 +42,13 @@

def test_hpo_utils():
assert validate_search_space(good, raise_exception=False)
assert validate_search_space(good_nested, raise_exception=False)
assert not validate_search_space(bad_type, raise_exception=False)
assert not validate_search_space(bad_spec_type, raise_exception=False)
assert not validate_search_space(bad_fields, raise_exception=False)
assert not validate_search_space(bad_type_name, raise_exception=False)
assert not validate_search_space(bad_value, raise_exception=False)
assert not validate_search_space(bad_choice_args, raise_exception=False)
assert not validate_search_space(bad_2_args, raise_exception=False)
assert not validate_search_space(bad_3_args, raise_exception=False)
assert not validate_search_space(bad_int_args, raise_exception=False)
Expand Down

0 comments on commit b0de7c9

Please sign in to comment.