Skip to content

Commit

Permalink
Replace List with Sequence when typing MBM
Browse files Browse the repository at this point in the history
Summary: Sequence is a more permissive type hint that works better with Pyre. For example, Pyre complains if we pass `List[MultiTaskDataset]` to a function that expects `List[SupervisedDataset]` (even though `MultiTaskDataset` is a subclass of `SupervisedDataset`) but it is ok with this if we use `Sequnce[SupervisedDataset]` as the type hint.

Differential Revision: D56498406
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Apr 24, 2024
1 parent e279d68 commit 6f3f336
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 14 deletions.
5 changes: 3 additions & 2 deletions ax/models/torch/botorch_modular/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Mapping,
Optional,
OrderedDict,
Sequence,
Tuple,
Type,
TypeVar,
Expand Down Expand Up @@ -252,7 +253,7 @@ def botorch_acqf_class(self) -> Type[AcquisitionFunction]:

def fit(
self,
datasets: List[SupervisedDataset],
datasets: Sequence[SupervisedDataset],
search_space_digest: SearchSpaceDigest,
candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None,
# state dict by surrogate label
Expand Down Expand Up @@ -503,7 +504,7 @@ def evaluate_acquisition_function(
@copy_doc(TorchModel.cross_validate)
def cross_validate(
self,
datasets: List[SupervisedDataset],
datasets: Sequence[SupervisedDataset],
X_test: Tensor,
search_space_digest: SearchSpaceDigest,
**additional_model_inputs: Any,
Expand Down
8 changes: 4 additions & 4 deletions ax/models/torch/botorch_modular/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import inspect
from copy import deepcopy
from logging import Logger
from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Type, Union
from typing import Any, Dict, List, Optional, OrderedDict, Sequence, Tuple, Type, Union

import torch
from ax.core.search_space import SearchSpaceDigest
Expand Down Expand Up @@ -238,7 +238,7 @@ def model(self) -> Model:
return self._model

@property
def training_data(self) -> List[SupervisedDataset]:
def training_data(self) -> Sequence[SupervisedDataset]:
if self._training_data is None:
raise ValueError(NOT_YET_FIT_MSG)
return self._training_data
Expand Down Expand Up @@ -492,7 +492,7 @@ def _make_botorch_outcome_transform(

def fit(
self,
datasets: List[SupervisedDataset],
datasets: Sequence[SupervisedDataset],
search_space_digest: SearchSpaceDigest,
candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None,
state_dict: Optional[OrderedDict[str, Tensor]] = None,
Expand Down Expand Up @@ -544,7 +544,7 @@ def fit(

if not should_use_model_list and len(datasets) > 1:
datasets = convert_to_block_design(datasets=datasets, force=True)
self._training_data = datasets
self._training_data = list(datasets)

models = []
outcome_names = []
Expand Down
27 changes: 19 additions & 8 deletions ax/models/torch/botorch_modular/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,17 @@

import warnings
from logging import Logger
from typing import Any, Callable, Dict, List, Optional, OrderedDict, Tuple, Type
from typing import (
Any,
Callable,
Dict,
List,
Optional,
OrderedDict,
Sequence,
Tuple,
Type,
)

import torch
from ax.core.search_space import SearchSpaceDigest
Expand Down Expand Up @@ -41,7 +51,7 @@


def use_model_list(
datasets: List[SupervisedDataset],
datasets: Sequence[SupervisedDataset],
botorch_model_class: Type[Model],
allow_batched_models: bool = True,
) -> bool:
Expand All @@ -67,7 +77,7 @@ def use_model_list(


def choose_model_class(
datasets: List[SupervisedDataset],
datasets: Sequence[SupervisedDataset],
search_space_digest: SearchSpaceDigest,
) -> Type[Model]:
"""Chooses a BoTorch `Model` using the given data (currently just Yvars)
Expand Down Expand Up @@ -173,7 +183,7 @@ def construct_acquisition_and_optimizer_options(


def convert_to_block_design(
datasets: List[SupervisedDataset],
datasets: Sequence[SupervisedDataset],
force: bool = False,
) -> List[SupervisedDataset]:
# Convert data to "block design". TODO: Figure out a better
Expand Down Expand Up @@ -211,6 +221,7 @@ def convert_to_block_design(
"to block design by dropping observations that are not shared "
"between outcomes.",
AxWarning,
stacklevel=3,
)
X_shared, idcs_shared = _get_shared_rows(Xs=Xs)
Y = torch.cat([ds.Y[i] for ds, i in zip(datasets, idcs_shared)], dim=-1)
Expand Down Expand Up @@ -347,8 +358,8 @@ def combined_func(x: Tensor) -> Tensor:


def check_outcome_dataset_match(
outcome_names: List[str],
datasets: List[SupervisedDataset],
outcome_names: Sequence[str],
datasets: Sequence[SupervisedDataset],
exact_match: bool,
) -> None:
"""Check that the given outcome names match those of datasets.
Expand Down Expand Up @@ -390,8 +401,8 @@ def check_outcome_dataset_match(


def get_subset_datasets(
datasets: List[SupervisedDataset],
subset_outcome_names: List[str],
datasets: Sequence[SupervisedDataset],
subset_outcome_names: Sequence[str],
) -> List[SupervisedDataset]:
"""Get the list of datasets corresponding to the given subset of
outcome names. This is used to separate out datasets that are
Expand Down

0 comments on commit 6f3f336

Please sign in to comment.