Replace List with Sequence when typing MBM (#2394)
Summary:
Pull Request resolved: #2394

Sequence is a more permissive type hint that works better with Pyre. For example, Pyre complains if we pass `List[MultiTaskDataset]` to a function that expects `List[SupervisedDataset]` (even though `MultiTaskDataset` is a subclass of `SupervisedDataset`, `List` is invariant in its element type), but it accepts the same call if the type hint is `Sequence[SupervisedDataset]`.

Reviewed By: esantorella

Differential Revision: D56498406

fbshipit-source-id: bd7c48f0bc81c1481162282d2d037b00fba7d6b9
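
For illustration, a minimal sketch of the variance issue described in the summary. The classes and functions below are hypothetical stand-ins, not the actual Ax/BoTorch APIs: `List[T]` is invariant, so a `List[MultiTaskDataset]` is rejected where a `List[SupervisedDataset]` is expected, while `Sequence[T]` is covariant (and read-only), so the same argument type-checks.

from collections.abc import Sequence
from typing import List


class SupervisedDataset:  # hypothetical stand-in for the real dataset container
    ...


class MultiTaskDataset(SupervisedDataset):  # subclass, as in the example above
    ...


def fit_with_list(datasets: List[SupervisedDataset]) -> None:
    ...  # List[T] is invariant: only exactly List[SupervisedDataset] is accepted


def fit_with_sequence(datasets: Sequence[SupervisedDataset]) -> None:
    ...  # Sequence[T] is covariant: sequences of any SupervisedDataset subclass work


multi_task_data: List[MultiTaskDataset] = [MultiTaskDataset()]
fit_with_list(multi_task_data)      # flagged by Pyre/mypy: incompatible parameter type
fit_with_sequence(multi_task_data)  # accepted
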
saitcakmak authored and facebook-github-bot committed Apr 24, 2024
1 parent e279d68 commit 2d9c738
Showing 3 changed files with 15 additions and 11 deletions.
5 changes: 3 additions & 2 deletions ax/models/torch/botorch_modular/model.py
@@ -7,6 +7,7 @@
 # pyre-strict
 
 import dataclasses
+from collections.abc import Sequence
 from copy import deepcopy
 from dataclasses import dataclass, field
 from functools import wraps
@@ -252,7 +253,7 @@ def botorch_acqf_class(self) -> Type[AcquisitionFunction]:
 
     def fit(
         self,
-        datasets: List[SupervisedDataset],
+        datasets: Sequence[SupervisedDataset],
         search_space_digest: SearchSpaceDigest,
         candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None,
         # state dict by surrogate label
@@ -503,7 +504,7 @@ def evaluate_acquisition_function(
     @copy_doc(TorchModel.cross_validate)
     def cross_validate(
         self,
-        datasets: List[SupervisedDataset],
+        datasets: Sequence[SupervisedDataset],
         X_test: Tensor,
         search_space_digest: SearchSpaceDigest,
         **additional_model_inputs: Any,
5 changes: 3 additions & 2 deletions ax/models/torch/botorch_modular/surrogate.py
@@ -9,6 +9,7 @@
 from __future__ import annotations
 
 import inspect
+from collections.abc import Sequence
 from copy import deepcopy
 from logging import Logger
 from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Type, Union
@@ -492,7 +493,7 @@ def _make_botorch_outcome_transform(
 
     def fit(
         self,
-        datasets: List[SupervisedDataset],
+        datasets: Sequence[SupervisedDataset],
         search_space_digest: SearchSpaceDigest,
         candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None,
         state_dict: Optional[OrderedDict[str, Tensor]] = None,
@@ -544,7 +545,7 @@ def fit(
 
         if not should_use_model_list and len(datasets) > 1:
             datasets = convert_to_block_design(datasets=datasets, force=True)
-        self._training_data = datasets
+        self._training_data = list(datasets)  # So that it can be modified if needed.
 
         models = []
         outcome_names = []
16 changes: 9 additions & 7 deletions ax/models/torch/botorch_modular/utils.py
@@ -7,6 +7,7 @@
 # pyre-strict
 
 import warnings
+from collections.abc import Sequence
 from logging import Logger
 from typing import Any, Callable, Dict, List, Optional, OrderedDict, Tuple, Type
 
@@ -41,7 +42,7 @@
 
 
 def use_model_list(
-    datasets: List[SupervisedDataset],
+    datasets: Sequence[SupervisedDataset],
     botorch_model_class: Type[Model],
     allow_batched_models: bool = True,
 ) -> bool:
@@ -67,7 +68,7 @@
 
 
 def choose_model_class(
-    datasets: List[SupervisedDataset],
+    datasets: Sequence[SupervisedDataset],
     search_space_digest: SearchSpaceDigest,
 ) -> Type[Model]:
     """Chooses a BoTorch `Model` using the given data (currently just Yvars)
@@ -173,7 +174,7 @@ def construct_acquisition_and_optimizer_options(
 
 
 def convert_to_block_design(
-    datasets: List[SupervisedDataset],
+    datasets: Sequence[SupervisedDataset],
    force: bool = False,
 ) -> List[SupervisedDataset]:
     # Convert data to "block design". TODO: Figure out a better
@@ -211,6 +212,7 @@ def convert_to_block_design(
             "to block design by dropping observations that are not shared "
             "between outcomes.",
             AxWarning,
+            stacklevel=3,
         )
         X_shared, idcs_shared = _get_shared_rows(Xs=Xs)
         Y = torch.cat([ds.Y[i] for ds, i in zip(datasets, idcs_shared)], dim=-1)
@@ -347,8 +349,8 @@ def combined_func(x: Tensor) -> Tensor:
 
 
 def check_outcome_dataset_match(
-    outcome_names: List[str],
-    datasets: List[SupervisedDataset],
+    outcome_names: Sequence[str],
+    datasets: Sequence[SupervisedDataset],
     exact_match: bool,
 ) -> None:
     """Check that the given outcome names match those of datasets.
@@ -390,8 +392,8 @@
 
 
 def get_subset_datasets(
-    datasets: List[SupervisedDataset],
-    subset_outcome_names: List[str],
+    datasets: Sequence[SupervisedDataset],
+    subset_outcome_names: Sequence[str],
 ) -> List[SupervisedDataset]:
     """Get the list of datasets corresponding to the given subset of
     outcome names. This is used to separate out datasets that are
