|
4 | 4 | # This source code is licensed under the MIT license found in the
|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
6 | 6 |
|
| 7 | +import logging |
7 | 8 | from typing import cast, List
|
8 | 9 | from unittest import mock
|
9 | 10 | from unittest.mock import MagicMock, patch
|
|
21 | 22 | from ax.exceptions.core import DataRequiredError, UserInputError
|
22 | 23 | from ax.exceptions.generation_strategy import (
|
23 | 24 | GenerationStrategyCompleted,
|
| 25 | + GenerationStrategyMisconfiguredException, |
24 | 26 | GenerationStrategyRepeatedPoints,
|
25 | 27 | MaxParallelismReachedException,
|
26 | 28 | )
|
27 | 29 | from ax.modelbridge.discrete import DiscreteModelBridge
|
28 | 30 | from ax.modelbridge.factory import get_sobol
|
| 31 | +from ax.modelbridge.generation_node import GenerationNode |
29 | 32 | from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
|
30 | 33 | from ax.modelbridge.model_spec import ModelSpec
|
31 | 34 | from ax.modelbridge.random import RandomModelBridge
|
32 | 35 | from ax.modelbridge.registry import Cont_X_trans, MODEL_KEY_TO_MODEL_SETUP, Models
|
33 | 36 | from ax.modelbridge.torch import TorchModelBridge
|
| 37 | +from ax.modelbridge.transition_criterion import ( |
| 38 | + MaxGenerationParallelism, |
| 39 | + MaxTrials, |
| 40 | + MinTrials, |
| 41 | +) |
34 | 42 | from ax.models.random.sobol import SobolGenerator
|
35 | 43 | from ax.utils.common.equality import same_elements
|
36 | 44 | from ax.utils.common.mock import mock_patch_method_original
|
@@ -416,8 +424,8 @@ def test_clone_reset(self) -> None:
|
416 | 424 | ]
|
417 | 425 | )
|
418 | 426 | ftgs._curr = ftgs._steps[1]
|
419 |
| - self.assertEqual(ftgs._curr.index, 1) |
420 |
| - self.assertEqual(ftgs.clone_reset()._curr.index, 0) |
| 427 | + self.assertEqual(ftgs.current_step_index, 1) |
| 428 | + self.assertEqual(ftgs.clone_reset().current_step_index, 0) |
421 | 429 |
|
422 | 430 | def test_kwargs_passed(self) -> None:
|
423 | 431 | gs = GenerationStrategy(
|
@@ -527,10 +535,12 @@ def test_trials_as_df(self) -> None:
|
527 | 535 | # attach necessary trials to fill up the Generation Strategy
|
528 | 536 | trial = exp.new_trial(sobol_generation_strategy.gen(experiment=exp))
|
529 | 537 | self.assertEqual(
|
530 |
| - sobol_generation_strategy.trials_as_df.head()["Generation Step"][0], 0 |
| 538 | + sobol_generation_strategy.trials_as_df.head()["Generation Step"][0], |
| 539 | + "GenerationStep_0", |
531 | 540 | )
|
532 | 541 | self.assertEqual(
|
533 |
| - sobol_generation_strategy.trials_as_df.head()["Generation Step"][2], 1 |
| 542 | + sobol_generation_strategy.trials_as_df.head()["Generation Step"][2], |
| 543 | + "GenerationStep_1", |
534 | 544 | )
|
535 | 545 |
|
536 | 546 | def test_max_parallelism_reached(self) -> None:
|
@@ -883,6 +893,248 @@ def test_gen_for_multiple_trials_with_multiple_models(self) -> None:
|
883 | 893 | for p in original_pending[m]:
|
884 | 894 | self.assertIn(p, pending[m])
|
885 | 895 |
|
| 896 | + # ---------- Tests for GenerationStrategies composed of GenerationNodes -------- |
| 897 | + def test_gs_setup_with_nodes(self) -> None: |
| 898 | + """Test GS initalization and validation with nodes""" |
| 899 | + sobol_model_spec = ModelSpec( |
| 900 | + model_enum=Models.SOBOL, |
| 901 | + model_kwargs={}, |
| 902 | + model_gen_kwargs={"n": 2}, |
| 903 | + ) |
| 904 | + node_1_criterion = [ |
| 905 | + MaxTrials( |
| 906 | + threshold=4, |
| 907 | + block_gen_if_met=False, |
| 908 | + transition_to="node_2", |
| 909 | + only_in_statuses=None, |
| 910 | + not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], |
| 911 | + ), |
| 912 | + MinTrials( |
| 913 | + only_in_statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED], |
| 914 | + threshold=2, |
| 915 | + transition_to="node_2", |
| 916 | + ), |
| 917 | + MaxGenerationParallelism( |
| 918 | + threshold=1, |
| 919 | + only_in_statuses=[TrialStatus.RUNNING], |
| 920 | + block_gen_if_met=True, |
| 921 | + block_transition_if_unmet=False, |
| 922 | + ), |
| 923 | + ] |
| 924 | + |
| 925 | + # check error raised if node names are not unique |
| 926 | + with self.assertRaisesRegex( |
| 927 | + GenerationStrategyMisconfiguredException, "All node names" |
| 928 | + ): |
| 929 | + GenerationStrategy( |
| 930 | + nodes=[ |
| 931 | + GenerationNode( |
| 932 | + node_name="node_1", |
| 933 | + transition_criteria=node_1_criterion, |
| 934 | + model_specs=[sobol_model_spec], |
| 935 | + ), |
| 936 | + GenerationNode( |
| 937 | + node_name="node_1", |
| 938 | + model_specs=[sobol_model_spec], |
| 939 | + ), |
| 940 | + ], |
| 941 | + ) |
| 942 | + # check error raised if transition to arguemnt is not valid |
| 943 | + with self.assertRaisesRegex( |
| 944 | + GenerationStrategyMisconfiguredException, "`transition_to` argument" |
| 945 | + ): |
| 946 | + GenerationStrategy( |
| 947 | + nodes=[ |
| 948 | + GenerationNode( |
| 949 | + node_name="node_1", |
| 950 | + transition_criteria=node_1_criterion, |
| 951 | + model_specs=[sobol_model_spec], |
| 952 | + ), |
| 953 | + GenerationNode( |
| 954 | + node_name="node_3", |
| 955 | + model_specs=[sobol_model_spec], |
| 956 | + ), |
| 957 | + ], |
| 958 | + ) |
| 959 | + |
| 960 | + # check error raised if provided both steps and nodes |
| 961 | + with self.assertRaisesRegex( |
| 962 | + GenerationStrategyMisconfiguredException, "either steps or nodes" |
| 963 | + ): |
| 964 | + GenerationStrategy( |
| 965 | + nodes=[ |
| 966 | + GenerationNode( |
| 967 | + node_name="node_1", |
| 968 | + transition_criteria=node_1_criterion, |
| 969 | + model_specs=[sobol_model_spec], |
| 970 | + ), |
| 971 | + GenerationNode( |
| 972 | + node_name="node_3", |
| 973 | + model_specs=[sobol_model_spec], |
| 974 | + ), |
| 975 | + ], |
| 976 | + steps=[ |
| 977 | + GenerationStep( |
| 978 | + model=Models.SOBOL, |
| 979 | + num_trials=5, |
| 980 | + model_kwargs=self.step_model_kwargs, |
| 981 | + ), |
| 982 | + GenerationStep( |
| 983 | + model=Models.GPEI, |
| 984 | + num_trials=-1, |
| 985 | + model_kwargs=self.step_model_kwargs, |
| 986 | + ), |
| 987 | + ], |
| 988 | + ) |
| 989 | + |
| 990 | + # check error raised if provided both steps and nodes under node list |
| 991 | + with self.assertRaisesRegex( |
| 992 | + GenerationStrategyMisconfiguredException, "must either be a GenerationStep" |
| 993 | + ): |
| 994 | + GenerationStrategy( |
| 995 | + nodes=[ |
| 996 | + GenerationNode( |
| 997 | + node_name="node_1", |
| 998 | + transition_criteria=node_1_criterion, |
| 999 | + model_specs=[sobol_model_spec], |
| 1000 | + ), |
| 1001 | + GenerationStep( |
| 1002 | + model=Models.SOBOL, |
| 1003 | + num_trials=5, |
| 1004 | + model_kwargs=self.step_model_kwargs, |
| 1005 | + ), |
| 1006 | + GenerationNode( |
| 1007 | + node_name="node_2", |
| 1008 | + model_specs=[sobol_model_spec], |
| 1009 | + ), |
| 1010 | + ], |
| 1011 | + ) |
| 1012 | + # check that warning is logged if no nodes have transition arguments |
| 1013 | + with self.assertLogs(GenerationStrategy.__module__, logging.WARNING) as logger: |
| 1014 | + warning_msg = ( |
| 1015 | + "None of the nodes in this GenerationStrategy " |
| 1016 | + "contain a `transition_to` argument in their transition_criteria. " |
| 1017 | + ) |
| 1018 | + GenerationStrategy( |
| 1019 | + nodes=[ |
| 1020 | + GenerationNode( |
| 1021 | + node_name="node_1", |
| 1022 | + model_specs=[sobol_model_spec], |
| 1023 | + ), |
| 1024 | + GenerationNode( |
| 1025 | + node_name="node_3", |
| 1026 | + model_specs=[sobol_model_spec], |
| 1027 | + ), |
| 1028 | + ], |
| 1029 | + ) |
| 1030 | + self.assertTrue( |
| 1031 | + any(warning_msg in output for output in logger.output), |
| 1032 | + logger.output, |
| 1033 | + ) |
| 1034 | + |
| 1035 | + def test_gs_with_generation_nodes(self) -> None: |
| 1036 | + "Simple test of a SOBOL + GPEI GenerationStrategy composed of GenerationNodes" |
| 1037 | + sobol_criterion = [ |
| 1038 | + MaxTrials( |
| 1039 | + threshold=5, |
| 1040 | + transition_to="GPEI_node", |
| 1041 | + block_gen_if_met=True, |
| 1042 | + only_in_statuses=None, |
| 1043 | + not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], |
| 1044 | + ) |
| 1045 | + ] |
| 1046 | + gpei_criterion = [ |
| 1047 | + MaxTrials( |
| 1048 | + threshold=2, |
| 1049 | + transition_to=None, |
| 1050 | + block_gen_if_met=True, |
| 1051 | + only_in_statuses=None, |
| 1052 | + not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], |
| 1053 | + ) |
| 1054 | + ] |
| 1055 | + sobol_model_spec = ModelSpec( |
| 1056 | + model_enum=Models.SOBOL, |
| 1057 | + model_kwargs=self.step_model_kwargs, |
| 1058 | + model_gen_kwargs={}, |
| 1059 | + ) |
| 1060 | + gpei_model_spec = ModelSpec( |
| 1061 | + model_enum=Models.GPEI, |
| 1062 | + model_kwargs=self.step_model_kwargs, |
| 1063 | + model_gen_kwargs={}, |
| 1064 | + ) |
| 1065 | + sobol_node = GenerationNode( |
| 1066 | + node_name="sobol_node", |
| 1067 | + transition_criteria=sobol_criterion, |
| 1068 | + model_specs=[sobol_model_spec], |
| 1069 | + gen_unlimited_trials=False, |
| 1070 | + ) |
| 1071 | + gpei_node = GenerationNode( |
| 1072 | + node_name="GPEI_node", |
| 1073 | + transition_criteria=gpei_criterion, |
| 1074 | + model_specs=[gpei_model_spec], |
| 1075 | + gen_unlimited_trials=False, |
| 1076 | + ) |
| 1077 | + |
| 1078 | + sobol_GPEI_GS_nodes = GenerationStrategy( |
| 1079 | + name="Sobol+GPEI_Nodes", |
| 1080 | + nodes=[sobol_node, gpei_node], |
| 1081 | + ) |
| 1082 | + exp = get_branin_experiment() |
| 1083 | + self.assertEqual(sobol_GPEI_GS_nodes.name, "Sobol+GPEI_Nodes") |
| 1084 | + self.assertEqual(sobol_GPEI_GS_nodes.model_transitions, [5]) |
| 1085 | + |
| 1086 | + for i in range(7): |
| 1087 | + g = sobol_GPEI_GS_nodes.gen(exp) |
| 1088 | + exp.new_trial(generator_run=g).run() |
| 1089 | + self.assertEqual(len(sobol_GPEI_GS_nodes._generator_runs), i + 1) |
| 1090 | + if i > 4: |
| 1091 | + self.mock_torch_model_bridge.assert_called() |
| 1092 | + else: |
| 1093 | + self.assertEqual(g._model_key, "Sobol") |
| 1094 | + mkw = g._model_kwargs |
| 1095 | + self.assertIsNotNone(mkw) |
| 1096 | + if i > 0: |
| 1097 | + # Generated points are randomized, so checking that they're there. |
| 1098 | + self.assertIsNotNone(mkw.get("generated_points")) |
| 1099 | + else: |
| 1100 | + # This is the first GR, there should be no generated points yet. |
| 1101 | + self.assertIsNone(mkw.get("generated_points")) |
| 1102 | + # Remove the randomized generated points to compare the rest. |
| 1103 | + mkw = mkw.copy() |
| 1104 | + del mkw["generated_points"] |
| 1105 | + self.assertEqual( |
| 1106 | + mkw, |
| 1107 | + { |
| 1108 | + "seed": None, |
| 1109 | + "deduplicate": True, |
| 1110 | + "init_position": i, |
| 1111 | + "scramble": True, |
| 1112 | + "fallback_to_sample_polytope": False, |
| 1113 | + }, |
| 1114 | + ) |
| 1115 | + self.assertEqual( |
| 1116 | + g._bridge_kwargs, |
| 1117 | + { |
| 1118 | + "optimization_config": None, |
| 1119 | + "status_quo_features": None, |
| 1120 | + "status_quo_name": None, |
| 1121 | + "transform_configs": None, |
| 1122 | + "transforms": Cont_X_trans, |
| 1123 | + "fit_out_of_design": False, |
| 1124 | + "fit_abandoned": False, |
| 1125 | + "fit_tracking_metrics": True, |
| 1126 | + "fit_on_init": True, |
| 1127 | + }, |
| 1128 | + ) |
| 1129 | + ms = g._model_state_after_gen |
| 1130 | + self.assertIsNotNone(ms) |
| 1131 | + # Generated points are randomized, so just checking that they are there. |
| 1132 | + self.assertIn("generated_points", ms) |
| 1133 | + # Remove the randomized generated points to compare the rest. |
| 1134 | + ms = ms.copy() |
| 1135 | + del ms["generated_points"] |
| 1136 | + self.assertEqual(ms, {"init_position": i + 1}) |
| 1137 | + |
886 | 1138 | # ------------- Testing helpers (put tests above this line) -------------
|
887 | 1139 |
|
888 | 1140 | def _run_GS_for_N_rounds(
|
|
0 commit comments