Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Retain original data timestamp in Experiment.clone_with #2269

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 24 additions & 19 deletions ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1542,6 +1542,10 @@ def clone_with(
r"""
Return a copy of this experiment with some attributes replaced.

NOTE: This method only retains the latest data attached to the experiment.
This is the same data that would be accessed using common APIs such as
`Experiment.lookup_data()`.

Args:
search_space: New search space. If None, it uses the cloned search space
of the original experiment.
Expand All @@ -1561,7 +1565,7 @@ def clone_with(
trial_indices: If specified, only clones the specified trials. If None,
clones all trials.
data: If specified, attach this data to the cloned experiment. If None,
clones the data attached to the original experiment if
clones the latest data attached to the original experiment if
the experiment has any data.
"""
search_space = (
Expand Down Expand Up @@ -1611,39 +1615,40 @@ def clone_with(
default_data_type=self._default_data_type,
)

datas = []
# clone only the specified trials
# Clone only the specified trials.
original_trial_indices = self.trials.keys()
# pyre-fixme[9]: trial_indices has type `Optional[List[int]]`; used as
# `Set[int]`.
trial_indices = (
trial_indices_to_keep = (
set(original_trial_indices) if trial_indices is None else set(trial_indices)
)
if (
# pyre-fixme[16]: `Optional` has no attribute `difference`.
len(trial_indices_diff := trial_indices.difference(original_trial_indices))
> 0
if trial_indices_diff := trial_indices_to_keep.difference(
original_trial_indices
):
warnings.warn(
f"Trials indexed with {trial_indices_diff} are not a part "
"of the original experiment. ",
stacklevel=2,
)
# pyre-fixme[16]: `Optional` has no attribute `intersection`.
for trial_index in trial_indices.intersection(original_trial_indices):

data_by_trial = {}
for trial_index in trial_indices_to_keep.intersection(original_trial_indices):
trial = self.trials[trial_index]
if isinstance(trial, BatchTrial) or isinstance(trial, Trial):
trial.clone_to(cloned_experiment)
trial_data, storage_time = self.lookup_data_for_trial(trial_index)
if (trial_data is not None) and (storage_time is not None):
datas.append(trial_data)
# Get the data with the latest timestamp from the experiment.
all_trial_data = self._data_by_trial.get(trial_index, None)
if all_trial_data:
max_timestamp = max(all_trial_data.keys())
data_by_trial[trial_index] = OrderedDict(
[(max_timestamp, all_trial_data[max_timestamp])]
)
else:
raise NotImplementedError(f"Cloning of {type(trial)} is not supported.")

if (data is None) and (len(datas) > 0):
data = self.default_data_constructor.from_multiple_data(datas)
if data is not None:
if data:
# If the user passed in data, use it.
cloned_experiment.attach_data(data)
else:
# Otherwise, attach the data extracted from the original experiment.
cloned_experiment._data_by_trial = data_by_trial

return cloned_experiment

Expand Down
12 changes: 11 additions & 1 deletion ax/core/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
# pyre-strict

import logging
from collections import OrderedDict
from typing import Dict, List, Type
from unittest.mock import MagicMock, patch

import numpy as np

import pandas as pd
import torch
from ax.core import BatchTrial, Trial
Expand Down Expand Up @@ -1041,6 +1041,7 @@ def test_clone_with(self) -> None:
search_space=larger_search_space,
status_quo=new_status_quo,
)
self.assertEqual(cloned_experiment._data_by_trial, experiment._data_by_trial)
self.assertEqual(len(cloned_experiment.trials), 2)
x1 = checked_cast(
RangeParameter, cloned_experiment.search_space.parameters["x1"]
Expand Down Expand Up @@ -1105,7 +1106,16 @@ def test_clone_with(self) -> None:
status_quo=new_status_quo,
)
new_data = cloned_experiment.lookup_data()
self.assertNotEqual(cloned_experiment._data_by_trial, experiment._data_by_trial)
self.assertIsInstance(new_data, MapData)
expected_data_by_trial = {}
for trial_index in experiment.trials:
original_trial_data = experiment._data_by_trial.get(trial_index, None)
if original_trial_data:
expected_data_by_trial[trial_index] = OrderedDict(
list(original_trial_data.items())[-1:]
)
self.assertEqual(cloned_experiment.data_by_trial, expected_data_by_trial)

experiment = get_experiment()
cloned_experiment = experiment.clone_with()
Expand Down
Loading