Replace List with Sequence when typing MBM #2394

Closed · wants to merge 1 commit
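A quick note on why `Sequence` is the better annotation here (general typing behavior, not wording from the PR itself): `typing.List` is invariant and implies mutability, whereas `typing.Sequence` is covariant and read-only, so a parameter typed `Sequence[SupervisedDataset]` accepts a list, a tuple, or a list of a `SupervisedDataset` subclass without type-checker complaints. A minimal sketch with stand-in classes (`Dataset`/`SpecialDataset` are placeholders, not the real Ax/BoTorch types):

```python
from typing import List, Sequence


class Dataset:  # placeholder for SupervisedDataset
    pass


class SpecialDataset(Dataset):  # placeholder subclass
    pass


def fit_list(datasets: List[Dataset]) -> int:
    return len(datasets)


def fit_seq(datasets: Sequence[Dataset]) -> int:
    return len(datasets)


special: List[SpecialDataset] = [SpecialDataset()]

# List is invariant and mutable, so mypy rejects both of these calls:
fit_list(special)       # error: List[SpecialDataset] is not List[Dataset]
fit_list((Dataset(),))  # error: a tuple is not a list

# Sequence is covariant and read-only, so both are accepted:
fit_seq(special)        # OK
fit_seq((Dataset(),))   # OK
```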
ax/models/torch/botorch_modular/model.py (5 changes: 3 additions & 2 deletions)
@@ -19,6 +19,7 @@
Mapping,
Optional,
OrderedDict,
+ Sequence,
Tuple,
Type,
TypeVar,
@@ -252,7 +253,7 @@ def botorch_acqf_class(self) -> Type[AcquisitionFunction]:

def fit(
self,
- datasets: List[SupervisedDataset],
+ datasets: Sequence[SupervisedDataset],
search_space_digest: SearchSpaceDigest,
candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None,
# state dict by surrogate label
@@ -503,7 +504,7 @@ def evaluate_acquisition_function(
@copy_doc(TorchModel.cross_validate)
def cross_validate(
self,
- datasets: List[SupervisedDataset],
+ datasets: Sequence[SupervisedDataset],
X_test: Tensor,
search_space_digest: SearchSpaceDigest,
**additional_model_inputs: Any,
ax/models/torch/botorch_modular/surrogate.py (8 changes: 4 additions & 4 deletions)
@@ -11,7 +11,7 @@
import inspect
from copy import deepcopy
from logging import Logger
- from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Type, Union
+ from typing import Any, Dict, List, Optional, OrderedDict, Sequence, Tuple, Type, Union

import torch
from ax.core.search_space import SearchSpaceDigest
@@ -238,7 +238,7 @@ def model(self) -> Model:
return self._model

@property
- def training_data(self) -> List[SupervisedDataset]:
+ def training_data(self) -> Sequence[SupervisedDataset]:
if self._training_data is None:
raise ValueError(NOT_YET_FIT_MSG)
return self._training_data
@@ -492,7 +492,7 @@ def _make_botorch_outcome_transform(

def fit(
self,
- datasets: List[SupervisedDataset],
+ datasets: Sequence[SupervisedDataset],
search_space_digest: SearchSpaceDigest,
candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None,
state_dict: Optional[OrderedDict[str, Tensor]] = None,
@@ -544,7 +544,7 @@ def fit(

if not should_use_model_list and len(datasets) > 1:
datasets = convert_to_block_design(datasets=datasets, force=True)
- self._training_data = datasets
+ self._training_data = list(datasets)

models = []
outcome_names = []
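The only line in this file that changes behavior rather than just annotations is `self._training_data = list(datasets)`: since the parameter is now only guaranteed to be a `Sequence`, copying it into a concrete `list` keeps the stored training data decoupled from whatever container the caller passed. A rough sketch of the pattern, assuming a heavily simplified surrogate class (not the real `Surrogate`):

```python
from typing import List, Optional, Sequence


class Dataset:  # placeholder for SupervisedDataset
    pass


class TinySurrogate:  # simplified sketch, not ax's Surrogate
    def __init__(self) -> None:
        self._training_data: Optional[List[Dataset]] = None

    def fit(self, datasets: Sequence[Dataset]) -> None:
        # Materialize the incoming Sequence (list, tuple, ...) into a list so
        # downstream code can rely on list semantics and on owning its copy.
        self._training_data = list(datasets)

    @property
    def training_data(self) -> Sequence[Dataset]:
        if self._training_data is None:
            raise ValueError("Surrogate has not been fit yet.")
        return self._training_data


surrogate = TinySurrogate()
surrogate.fit((Dataset(), Dataset()))  # passing a tuple is now fine
assert len(surrogate.training_data) == 2
```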
ax/models/torch/botorch_modular/utils.py (27 changes: 19 additions & 8 deletions)
@@ -8,7 +8,17 @@

import warnings
from logging import Logger
- from typing import Any, Callable, Dict, List, Optional, OrderedDict, Tuple, Type
+ from typing import (
+     Any,
+     Callable,
+     Dict,
+     List,
+     Optional,
+     OrderedDict,
+     Sequence,
+     Tuple,
+     Type,
+ )

import torch
from ax.core.search_space import SearchSpaceDigest
@@ -41,7 +51,7 @@


def use_model_list(
- datasets: List[SupervisedDataset],
+ datasets: Sequence[SupervisedDataset],
botorch_model_class: Type[Model],
allow_batched_models: bool = True,
) -> bool:
@@ -67,7 +77,7 @@


def choose_model_class(
- datasets: List[SupervisedDataset],
+ datasets: Sequence[SupervisedDataset],
search_space_digest: SearchSpaceDigest,
) -> Type[Model]:
"""Chooses a BoTorch `Model` using the given data (currently just Yvars)
@@ -173,7 +183,7 @@ def construct_acquisition_and_optimizer_options(


def convert_to_block_design(
- datasets: List[SupervisedDataset],
+ datasets: Sequence[SupervisedDataset],
force: bool = False,
) -> List[SupervisedDataset]:
# Convert data to "block design". TODO: Figure out a better
@@ -211,6 +221,7 @@ def convert_to_block_design(
"to block design by dropping observations that are not shared "
"between outcomes.",
AxWarning,
+ stacklevel=3,
)
X_shared, idcs_shared = _get_shared_rows(Xs=Xs)
Y = torch.cat([ds.Y[i] for ds, i in zip(datasets, idcs_shared)], dim=-1)
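Aside from the typing change, this hunk adds `stacklevel=3` to the warning. With the default `stacklevel=1`, `warnings.warn` attributes the warning to the line inside `convert_to_block_design` itself; a larger stacklevel walks up the call stack so the reported file and line point at user-facing code instead. A small illustration of the mechanism (hypothetical function names):

```python
import warnings


def _internal_helper() -> None:
    # stacklevel=1 (default) would report this line; stacklevel=2 reports
    # public_api's call to _internal_helper; stacklevel=3 reports the line
    # where public_api itself was called.
    warnings.warn(
        "dropping observations not shared between outcomes", stacklevel=3
    )


def public_api() -> None:
    _internal_helper()


public_api()  # the warning is attributed to this line
```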
@@ -347,8 +358,8 @@ def combined_func(x: Tensor) -> Tensor:


def check_outcome_dataset_match(
- outcome_names: List[str],
- datasets: List[SupervisedDataset],
+ outcome_names: Sequence[str],
+ datasets: Sequence[SupervisedDataset],
exact_match: bool,
) -> None:
"""Check that the given outcome names match those of datasets.
@@ -390,8 +401,8 @@ def get_subset_datasets(


def get_subset_datasets(
- datasets: List[SupervisedDataset],
- subset_outcome_names: List[str],
+ datasets: Sequence[SupervisedDataset],
+ subset_outcome_names: Sequence[str],
) -> List[SupervisedDataset]:
"""Get the list of datasets corresponding to the given subset of
outcome names. This is used to separate out datasets that are
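A pattern worth noting across these helpers: parameters are broadened to `Sequence`, while return annotations such as `-> List[SupervisedDataset]` stay concrete, i.e. accept the most general type the function can handle and return the specific type it actually builds. A hedged sketch of that shape with illustrative names (not the real `get_subset_datasets`):

```python
from typing import List, Sequence


class Dataset:  # placeholder for SupervisedDataset
    def __init__(self, outcome_names: Sequence[str]) -> None:
        self.outcome_names = list(outcome_names)


def subset_by_outcomes(
    datasets: Sequence[Dataset],          # any sequence is accepted
    subset_outcome_names: Sequence[str],  # list, tuple, ...
) -> List[Dataset]:                       # ...but a concrete list comes back
    wanted = set(subset_outcome_names)
    return [ds for ds in datasets if wanted.intersection(ds.outcome_names)]


data = (Dataset(["a"]), Dataset(["b"]))   # tuple input type-checks and runs
assert len(subset_by_outcomes(data, ("a",))) == 1
```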