Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support on load/save/update properties on SQA experiment #3093

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 25 additions & 12 deletions ax/storage/sqa_store/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,18 +180,12 @@ def _get_experiment_sqa(
skip_runners_and_metrics: bool = False,
) -> SQAExperiment:
"""Obtains SQLAlchemy experiment object from DB."""
with session_scope() as session:
query = (
session.query(exp_sqa_class)
.filter_by(name=experiment_name)
# Delay loading trials to a separate call to `_get_trials_sqa` below
.options(noload("trials"))
)

if skip_runners_and_metrics:
query = query.options(noload("runners")).options(noload("trials.runner"))

sqa_experiment = query.one_or_none()
# Delay loading trials to a separate call to `_get_trials_sqa` below
sqa_experiment = _get_experiment_sqa_no_trials(
experiment_name=experiment_name,
exp_sqa_class=exp_sqa_class,
skip_runners_and_metrics=skip_runners_and_metrics,
)

if sqa_experiment is None:
raise ExperimentNotFoundError(f"Experiment '{experiment_name}' not found.")
Expand All @@ -209,6 +203,25 @@ def _get_experiment_sqa(
return sqa_experiment


def _get_experiment_sqa_no_trials(
experiment_name: str,
exp_sqa_class: type[SQAExperiment],
skip_runners_and_metrics: bool = False,
) -> SQAExperiment:
with session_scope() as session:
query = (
session.query(exp_sqa_class)
.filter_by(name=experiment_name)
.options(noload("trials"))
)

if skip_runners_and_metrics:
query = query.options(noload("runners")).options(noload("trials.runner"))

sqa_experiment = query.one_or_none()
return sqa_experiment


def _get_trials_sqa(
experiment_id: int,
trial_sqa_class: type[SQATrial],
Expand Down
20 changes: 20 additions & 0 deletions ax/storage/sqa_store/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from ax.storage.sqa_store.encoder import Encoder
from ax.storage.sqa_store.sqa_classes import (
SQAData,
SQAExperiment,
SQAGeneratorRun,
SQAMetric,
SQARunner,
Expand Down Expand Up @@ -496,6 +497,25 @@ def update_properties_on_experiment(
)


def update_properties_on_sqa_experiment(
experiment_with_updated_properties: SQAExperiment,
config: SQAConfig | None = None,
) -> None:
config = SQAConfig() if config is None else config
exp_sqa_class = config.class_to_sqa_class[Experiment]

exp_id = experiment_with_updated_properties.id
if exp_id is None:
raise ValueError("Experiment must be saved before being updated.")

with session_scope() as session:
session.query(exp_sqa_class).filter_by(id=exp_id).update(
{
"properties": experiment_with_updated_properties.properties,
}
)


def update_properties_on_trial(
trial_with_updated_properties: BaseTrial,
config: SQAConfig | None = None,
Expand Down
21 changes: 21 additions & 0 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from ax.core.types import ComparisonOp
from ax.exceptions.core import ObjectNotFoundError
from ax.exceptions.storage import JSONDecodeError, SQADecodeError, SQAEncodeError
from ax.fb.storage.sqa_store.constants import FB_SQA_CONFIG
from ax.fb.storage.sqa_store.load_helper import load_sqa_experiment
from ax.metrics.branin import BraninMetric
from ax.modelbridge.dispatch_utils import choose_generation_strategy
from ax.modelbridge.registry import Models
Expand Down Expand Up @@ -70,6 +72,7 @@
save_or_update_trials,
update_generation_strategy,
update_properties_on_experiment,
update_properties_on_sqa_experiment,
update_properties_on_trial,
update_runner_on_experiment,
update_trial_status,
Expand Down Expand Up @@ -223,6 +226,24 @@ def test_ExperimentSaveAndLoad(self) -> None:
loaded_experiment = load_experiment(exp.name)
self.assertEqual(loaded_experiment, exp)

def test_UpdatePropertiesOnSQAExperimentAndLoadSQAExperiment(self) -> None:
for exp in [
self.experiment,
get_experiment_with_map_data_type(),
get_experiment_with_multi_objective(),
get_experiment_with_scalarized_objective_and_outcome_constraint(),
]:
save_experiment(exp)
sqa_exp = self.encoder.experiment_to_sqa(exp)
sqa_config = FB_SQA_CONFIG
sqa_exp.properties = sqa_exp.properties or {}
sqa_exp.properties["property1"] = "property1_value"
update_properties_on_sqa_experiment(sqa_exp, config=sqa_config)

loaded_sqa_exp = load_sqa_experiment(sqa_exp.name)
self.assertIsInstance(loaded_sqa_exp, SQAExperiment)
self.assertEqual(loaded_sqa_exp.properties["property1"], "property1_value")

def test_saving_and_loading_experiment_with_aux_exp(self) -> None:
aux_experiment = Experiment(
name="test_aux_exp_in_SQAStoreTest",
Expand Down
Loading