Skip to content

Commit

Permalink
ax/ax/utils/testing/core_stubs.py readability
Browse files Browse the repository at this point in the history
Summary:
When I was adding to `core_stubs.py`, it was sometimes hard to find what I needed. How do we feel about these blockier comments and additional partitions?

Definitely feel free to shoot this down -- I put this up really quickly 6 days ago and just remembered that I made this.

Reviewed By: lena-kashtelyan

Differential Revision: D23142772

fbshipit-source-id: 9382db43954fcfb71e48853ee4be2ab9f1ae2cf8
  • Loading branch information
EricZLou authored and facebook-github-bot committed Aug 24, 2020
1 parent ddb6a0e commit fc4cd2a
Showing 1 changed file with 61 additions and 27 deletions.
88 changes: 61 additions & 27 deletions ax/utils/testing/core_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@
logger = get_logger(__name__)


##############################
# Experiments
##############################


def get_experiment() -> Experiment:
Expand Down Expand Up @@ -284,7 +286,9 @@ def get_experiment_with_scalarized_objective() -> Experiment:
)


##############################
# Search Spaces
##############################


def get_search_space() -> SearchSpace:
Expand Down Expand Up @@ -420,7 +424,9 @@ def get_discrete_search_space() -> SearchSpace:
)


##############################
# Trials
##############################


def get_batch_trial(abandon_arm: bool = True) -> BatchTrial:
Expand Down Expand Up @@ -490,7 +496,9 @@ def get_trial() -> Trial:
return trial


##############################
# Parameters
##############################


def get_range_parameter() -> RangeParameter:
Expand Down Expand Up @@ -523,7 +531,9 @@ def get_fixed_parameter() -> FixedParameter:
return FixedParameter(name="z", parameter_type=ParameterType.BOOL, value=True)


##############################
# Parameter Constraints
##############################


def get_order_constraint() -> OrderConstraint:
Expand All @@ -548,7 +558,9 @@ def get_sum_constraint2() -> SumConstraint:
return SumConstraint(parameters=[x, w], is_upper_bound=True, bound=10.0)


##############################
# Metrics
##############################


def get_metric() -> Metric:
Expand Down Expand Up @@ -594,7 +606,22 @@ def get_factorial_metric(name: str = "success_metric") -> FactorialMetric:
)


# Optimization Configs
##############################
# Outcome Constraints
##############################


def get_outcome_constraint() -> OutcomeConstraint:
return OutcomeConstraint(metric=Metric(name="m2"), op=ComparisonOp.GEQ, bound=-0.25)


def get_branin_outcome_constraint() -> OutcomeConstraint:
return OutcomeConstraint(metric=get_branin_metric(), op=ComparisonOp.LEQ, bound=0)


##############################
# Objectives
##############################


def get_objective() -> Objective:
Expand All @@ -616,18 +643,6 @@ def get_scalarized_objective() -> Objective:
)


def get_outcome_constraint() -> OutcomeConstraint:
return OutcomeConstraint(metric=Metric(name="m2"), op=ComparisonOp.GEQ, bound=-0.25)


def get_optimization_config() -> OptimizationConfig:
objective = get_objective()
outcome_constraints = [get_outcome_constraint()]
return OptimizationConfig(
objective=objective, outcome_constraints=outcome_constraints
)


def get_branin_objective() -> Objective:
return Objective(metric=get_branin_metric(), minimize=False)

Expand All @@ -638,8 +653,29 @@ def get_branin_multi_objective() -> Objective:
)


def get_branin_outcome_constraint() -> OutcomeConstraint:
return OutcomeConstraint(metric=get_branin_metric(), op=ComparisonOp.LEQ, bound=0)
def get_augmented_branin_objective() -> Objective:
return Objective(metric=get_augmented_branin_metric(), minimize=False)


def get_hartmann_objective() -> Objective:
return Objective(metric=get_hartmann_metric(), minimize=False)


def get_augmented_hartmann_objective() -> Objective:
return Objective(metric=get_augmented_hartmann_metric(), minimize=False)


##############################
# Optimization Configs
##############################


def get_optimization_config() -> OptimizationConfig:
objective = get_objective()
outcome_constraints = [get_outcome_constraint()]
return OptimizationConfig(
objective=objective, outcome_constraints=outcome_constraints
)


def get_optimization_config_no_constraints() -> OptimizationConfig:
Expand All @@ -654,22 +690,10 @@ def get_branin_multi_objective_optimization_config() -> OptimizationConfig:
return OptimizationConfig(objective=get_branin_multi_objective())


def get_augmented_branin_objective() -> Objective:
return Objective(metric=get_augmented_branin_metric(), minimize=False)


def get_augmented_branin_optimization_config() -> OptimizationConfig:
return OptimizationConfig(objective=get_augmented_branin_objective())


def get_hartmann_objective() -> Objective:
return Objective(metric=get_hartmann_metric(), minimize=False)


def get_augmented_hartmann_objective() -> Objective:
return Objective(metric=get_augmented_hartmann_metric(), minimize=False)


def get_hartmann_optimization_config() -> OptimizationConfig:
return OptimizationConfig(objective=get_hartmann_objective())

Expand All @@ -678,7 +702,9 @@ def get_augmented_hartmann_optimization_config() -> OptimizationConfig:
return OptimizationConfig(objective=get_augmented_hartmann_objective())


##############################
# Arms
##############################


def get_arm() -> Arm:
Expand Down Expand Up @@ -751,7 +777,9 @@ def get_abandoned_arm() -> AbandonedArm:
return AbandonedArm(name="0_0", reason="foobar", time=datetime.now())


##############################
# Generator Runs
##############################


def get_generator_run() -> GeneratorRun:
Expand Down Expand Up @@ -785,14 +813,18 @@ def get_generator_run2() -> GeneratorRun:
return GeneratorRun(arms=arms, weights=weights)


##############################
# Runners
##############################


def get_synthetic_runner() -> SyntheticRunner:
return SyntheticRunner(dummy_metadata="foobar")


##############################
# Data
##############################


def get_data(trial_index: int = 0) -> Data:
Expand Down Expand Up @@ -821,7 +853,9 @@ def get_branin_data(trial_indices: Optional[Iterable[int]] = None) -> Data:
return Data(df=pd.DataFrame.from_records(df_dicts))


##############################
# Instances of types from core/types.py
##############################


def get_model_mean() -> TModelMean:
Expand Down

0 comments on commit fc4cd2a

Please sign in to comment.