Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Import OrderedDict from collections #2414

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ax/models/torch/botorch_modular/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
from __future__ import annotations

import inspect
from collections import OrderedDict
from collections.abc import Sequence
from copy import deepcopy
from logging import Logger
from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import torch
from ax.core.search_space import SearchSpaceDigest
Expand Down
5 changes: 3 additions & 2 deletions ax/models/torch/botorch_modular/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
# pyre-strict

import warnings
from collections import OrderedDict
from collections.abc import Sequence
from logging import Logger
from typing import Any, Callable, Dict, List, Optional, OrderedDict, Tuple, Type
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

import torch
from ax.core.search_space import SearchSpaceDigest
Expand Down Expand Up @@ -463,4 +464,4 @@ def subset_state_dict(
for k, v in state_dict.items()
if k.startswith(expected_substring)
]
return OrderedDict(new_items) # pyre-ignore [29]: T168826187
return OrderedDict(new_items)
4 changes: 2 additions & 2 deletions ax/models/torch/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
# pyre-strict

import dataclasses
from collections import OrderedDict
from contextlib import ExitStack
from copy import deepcopy
from typing import Dict, OrderedDict, Type
from typing import Dict, Type
from unittest import mock
from unittest.mock import Mock

Expand Down Expand Up @@ -397,7 +398,6 @@ def test_cross_validate(self, mock_fit: Mock) -> None:

old_surrogate = self.model.surrogates[Keys.ONLY_SURROGATE]
old_surrogate._model = mock.MagicMock()
# pyre-ignore [29]: T168826187
old_surrogate._model.state_dict.return_value = OrderedDict({"key": "val"})

for refit_on_cv, warm_start_refit in [
Expand Down
4 changes: 2 additions & 2 deletions ax/models/torch/tests/test_surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

import dataclasses
import math
from typing import Any, Dict, OrderedDict, Tuple, Type
from collections import OrderedDict
from typing import Any, Dict, Tuple, Type
from unittest.mock import MagicMock, Mock, patch

import numpy as np
Expand Down Expand Up @@ -954,7 +955,6 @@ def test_fit(

# Should `load_state_dict` when `state_dict` is not `None`
# and `refit` is `False`.
# pyre-ignore [29]: T168826187
state_dict = OrderedDict({"state_attribute": torch.ones(2)})
surrogate._submodels = {} # Prevent re-use of fitted model.
surrogate.fit(
Expand Down
3 changes: 1 addition & 2 deletions ax/models/torch/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# pyre-strict

import warnings
from typing import OrderedDict
from collections import OrderedDict

import numpy as np
import torch
Expand Down Expand Up @@ -615,7 +615,6 @@ def test_subset_state_dict(self) -> None:
m0 = SingleTaskGP(train_X=torch.rand(5, 2), train_Y=torch.rand(5, 1))
m1 = SingleTaskGP(train_X=torch.rand(5, 2), train_Y=torch.rand(5, 1))
model_list = ModelListGP(m0, m1)
# pyre-ignore [6]: T168826187
model_list_state_dict = checked_cast(OrderedDict, model_list.state_dict())
# Subset the model dict from model list and check that it is correct.
m0_state_dict = model_list.models[0].state_dict()
Expand Down