From a566c64e8d2575bb33d7a6d980376564b0e26b85 Mon Sep 17 00:00:00 2001 From: Chris McBride Date: Tue, 6 Jun 2023 17:23:47 -0400 Subject: [PATCH] refactor random_permutations for simplicity alpha-reorder imports fix broken import update changelog --- doc/changelog.rst | 2 ++ smartsim/entity/strategies.py | 25 +++++++------------------ 2 files changed, 9 insertions(+), 18 deletions(-) diff --git a/doc/changelog.rst b/doc/changelog.rst index 5900b23bf..f37386608 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -37,6 +37,7 @@ A full list of changes and detailed notes can be found below: Detailed notes +- Simplify code in `random_permutations` parameter generation strategy (PR300_) - Remove wait time associated with Experiment launch summary (PR298_) - Update Redis conf file to conform with Redis v7.0.5 conf file (PR293_) - Migrate from redis-py-cluster to redis-py for cluster status checks (PR292_) @@ -55,6 +56,7 @@ argument name is still `interface` for backward compatibility reasons. (PR281_) - Typehints have been added to public APIs. A makefile target to execute static analysis with mypy is available `make check-mypy`. (PR295_) +.. _PR300: https://github.com/CrayLabs/SmartSim/pull/300 .. _PR298: https://github.com/CrayLabs/SmartSim/pull/298 .. _PR293: https://github.com/CrayLabs/SmartSim/pull/293 .. _PR292: https://github.com/CrayLabs/SmartSim/pull/292 diff --git a/smartsim/entity/strategies.py b/smartsim/entity/strategies.py index 0ef47e13d..803b7934d 100644 --- a/smartsim/entity/strategies.py +++ b/smartsim/entity/strategies.py @@ -56,21 +56,10 @@ def step_values( def random_permutations( param_names: t.List[str], param_values: t.List[t.List[str]], n_models: int = 0 ) -> t.List[t.Dict[str, str]]: - # first, check if we've requested more values than possible. - perms = list(product(*param_values)) - if n_models >= len(perms): - return create_all_permutations(param_names, param_values) - else: - permutations: t.List[t.Dict[str, str]] = [] - permutation_strings = set() - while len(permutations) < n_models: - model_dict = dict( - zip( - param_names, - map(lambda x: x[random.randint(0, len(x) - 1)], param_values), - ) - ) - if str(model_dict) not in permutation_strings: - permutation_strings.add(str(model_dict)) - permutations.append(model_dict) - return permutations + permutations = create_all_permutations(param_names, param_values) + + # sample from available permutations if n_models is specified + if n_models and n_models < len(permutations): + permutations = random.sample(permutations, n_models) + + return permutations