Skip to content

Commit

Permalink
Import OrderedDict from collections
Browse files Browse the repository at this point in the history
Summary: `typing.OrderedDict` is a deprecated alias of `collections.OrderedDict`: https://docs.python.org/3/library/typing.html#typing.OrderedDict

Differential Revision: D56795174
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Apr 30, 2024
1 parent 1be47e1 commit dd5c66f
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 9 deletions.
3 changes: 2 additions & 1 deletion ax/models/torch/botorch_modular/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
from __future__ import annotations

import inspect
from collections import OrderedDict
from collections.abc import Sequence
from copy import deepcopy
from logging import Logger
from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import torch
from ax.core.search_space import SearchSpaceDigest
Expand Down
5 changes: 3 additions & 2 deletions ax/models/torch/botorch_modular/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
# pyre-strict

import warnings
from collections import OrderedDict
from collections.abc import Sequence
from logging import Logger
from typing import Any, Callable, Dict, List, Optional, OrderedDict, Tuple, Type
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

import torch
from ax.core.search_space import SearchSpaceDigest
Expand Down Expand Up @@ -463,4 +464,4 @@ def subset_state_dict(
for k, v in state_dict.items()
if k.startswith(expected_substring)
]
return OrderedDict(new_items) # pyre-ignore [29]: T168826187
return OrderedDict(new_items)
4 changes: 2 additions & 2 deletions ax/models/torch/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
# pyre-strict

import dataclasses
from collections import OrderedDict
from contextlib import ExitStack
from copy import deepcopy
from typing import Dict, OrderedDict, Type
from typing import Dict, Type
from unittest import mock
from unittest.mock import Mock

Expand Down Expand Up @@ -397,7 +398,6 @@ def test_cross_validate(self, mock_fit: Mock) -> None:

old_surrogate = self.model.surrogates[Keys.ONLY_SURROGATE]
old_surrogate._model = mock.MagicMock()
# pyre-ignore [29]: T168826187
old_surrogate._model.state_dict.return_value = OrderedDict({"key": "val"})

for refit_on_cv, warm_start_refit in [
Expand Down
4 changes: 2 additions & 2 deletions ax/models/torch/tests/test_surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

import dataclasses
import math
from typing import Any, Dict, OrderedDict, Tuple, Type
from collections import OrderedDict
from typing import Any, Dict, Tuple, Type
from unittest.mock import MagicMock, Mock, patch

import numpy as np
Expand Down Expand Up @@ -954,7 +955,6 @@ def test_fit(

# Should `load_state_dict` when `state_dict` is not `None`
# and `refit` is `False`.
# pyre-ignore [29]: T168826187
state_dict = OrderedDict({"state_attribute": torch.ones(2)})
surrogate._submodels = {} # Prevent re-use of fitted model.
surrogate.fit(
Expand Down
3 changes: 1 addition & 2 deletions ax/models/torch/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# pyre-strict

import warnings
from typing import OrderedDict
from collections import OrderedDict

import numpy as np
import torch
Expand Down Expand Up @@ -615,7 +615,6 @@ def test_subset_state_dict(self) -> None:
m0 = SingleTaskGP(train_X=torch.rand(5, 2), train_Y=torch.rand(5, 1))
m1 = SingleTaskGP(train_X=torch.rand(5, 2), train_Y=torch.rand(5, 1))
model_list = ModelListGP(m0, m1)
# pyre-ignore [6]: T168826187
model_list_state_dict = checked_cast(OrderedDict, model_list.state_dict())
# Subset the model dict from model list and check that it is correct.
m0_state_dict = model_list.models[0].state_dict()
Expand Down

0 comments on commit dd5c66f

Please sign in to comment.