diff --git a/configen/conf/torchvision.yaml b/configen/conf/torchvision.yaml
new file mode 100644
index 0000000..53cc6e4
--- /dev/null
+++ b/configen/conf/torchvision.yaml
@@ -0,0 +1,58 @@
+defaults:
+  - configen_schema
+
+configen:
+  # output directory
+  output_dir: ${hydra:runtime.cwd}
+
+  header: |
+    # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+    #
+    # Generated by configen, do not edit.
+    # See https://github.com/facebookresearch/hydra/tree/master/tools/configen
+    # fmt: off
+    # isort:skip_file
+    # flake8: noqa
+
+  module_path_pattern: "hydra_configs/{{module_path}}.py"
+
+  # list of modules to generate configs for
+  modules:
+    - name: torchvision.datasets.vision
+      classes:
+        - VisionDataset
+        - StandardTransform
+
+    - name: torchvision.datasets.mnist
+      # mnist datasets
+      classes:
+        - MNIST
+        - FashionMNIST
+        - KMNIST
+        # TODO: The following need to be manually created for torchvision==0.7
+        # - EMNIST
+        # - QMNIST
+
+    - name: torchvision.models.alexnet
+      classes:
+        - AlexNet
+
+    - name: torchvision.models.densenet
+      classes:
+        - DenseNet
+
+    - name: torchvision.models.googlenet
+      classes:
+        - GoogLeNet
+
+    - name: torchvision.models.mnasnet
+      classes:
+        - MNASNet
+
+    - name: torchvision.models.squeezenet
+      classes:
+        - SqueezeNet
+
+    - name: torchvision.models.resnet
+      classes:
+        - ResNet
diff --git a/hydra-configs-torchvision/hydra_configs/torchvision/models/alexnet.py b/hydra-configs-torchvision/hydra_configs/torchvision/models/alexnet.py
new file mode 100644
index 0000000..320426b
--- /dev/null
+++ b/hydra-configs-torchvision/hydra_configs/torchvision/models/alexnet.py
@@ -0,0 +1,16 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+#
+# Generated by configen, do not edit.
+# See https://github.com/facebookresearch/hydra/tree/master/tools/configen
+# fmt: off
+# isort:skip_file
+# flake8: noqa
+
+from dataclasses import dataclass, field
+from typing import Any
+
+
+@dataclass
+class AlexNetConf:
+    _target_: str = "torchvision.models.alexnet.AlexNet"
+    num_classes: Any = 1000
diff --git a/hydra-configs-torchvision/hydra_configs/torchvision/models/densenet.py b/hydra-configs-torchvision/hydra_configs/torchvision/models/densenet.py
new file mode 100644
index 0000000..c1e3603
--- /dev/null
+++ b/hydra-configs-torchvision/hydra_configs/torchvision/models/densenet.py
@@ -0,0 +1,22 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+#
+# Generated by configen, do not edit.
+# See https://github.com/facebookresearch/hydra/tree/master/tools/configen
+# fmt: off
+# isort:skip_file
+# flake8: noqa
+
+from dataclasses import dataclass, field
+from typing import Any
+
+
+@dataclass
+class DenseNetConf:
+    _target_: str = "torchvision.models.densenet.DenseNet"
+    growth_rate: Any = 32
+    block_config: Any = (6, 12, 24, 16)
+    num_init_features: Any = 64
+    bn_size: Any = 4
+    drop_rate: Any = 0
+    num_classes: Any = 1000
+    memory_efficient: Any = False
diff --git a/hydra-configs-torchvision/hydra_configs/torchvision/models/googlenet.py b/hydra-configs-torchvision/hydra_configs/torchvision/models/googlenet.py
new file mode 100644
index 0000000..7ca48e1
--- /dev/null
+++ b/hydra-configs-torchvision/hydra_configs/torchvision/models/googlenet.py
@@ -0,0 +1,20 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+#
+# Generated by configen, do not edit.
+# See https://github.com/facebookresearch/hydra/tree/master/tools/configen
+# fmt: off
+# isort:skip_file
+# flake8: noqa
+
+from dataclasses import dataclass, field
+from typing import Any
+
+
+@dataclass
+class GoogLeNetConf:
+    _target_: str = "torchvision.models.googlenet.GoogLeNet"
+    num_classes: Any = 1000
+    aux_logits: Any = True
+    transform_input: Any = False
+    init_weights: Any = None
+    blocks: Any = None
diff --git a/hydra-configs-torchvision/hydra_configs/torchvision/models/mnasnet.py b/hydra-configs-torchvision/hydra_configs/torchvision/models/mnasnet.py
new file mode 100644
index 0000000..8eb872e
--- /dev/null
+++ b/hydra-configs-torchvision/hydra_configs/torchvision/models/mnasnet.py
@@ -0,0 +1,19 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+#
+# Generated by configen, do not edit.
+# See https://github.com/facebookresearch/hydra/tree/master/tools/configen
+# fmt: off
+# isort:skip_file
+# flake8: noqa
+
+from dataclasses import dataclass, field
+from omegaconf import MISSING
+from typing import Any
+
+
+@dataclass
+class MNASNetConf:
+    _target_: str = "torchvision.models.mnasnet.MNASNet"
+    alpha: Any = MISSING
+    num_classes: Any = 1000
+    dropout: Any = 0.2
diff --git a/hydra-configs-torchvision/hydra_configs/torchvision/models/resnet.py b/hydra-configs-torchvision/hydra_configs/torchvision/models/resnet.py
new file mode 100644
index 0000000..2521238
--- /dev/null
+++ b/hydra-configs-torchvision/hydra_configs/torchvision/models/resnet.py
@@ -0,0 +1,24 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+#
+# Generated by configen, do not edit.
+# See https://github.com/facebookresearch/hydra/tree/master/tools/configen
+# fmt: off
+# isort:skip_file
+# flake8: noqa
+
+from dataclasses import dataclass, field
+from omegaconf import MISSING
+from typing import Any
+
+
+@dataclass
+class ResNetConf:
+    _target_: str = "torchvision.models.resnet.ResNet"
+    block: Any = MISSING
+    layers: Any = MISSING
+    num_classes: Any = 1000
+    zero_init_residual: Any = False
+    groups: Any = 1
+    width_per_group: Any = 64
+    replace_stride_with_dilation: Any = None
+    norm_layer: Any = None
diff --git a/hydra-configs-torchvision/hydra_configs/torchvision/models/squeezenet.py b/hydra-configs-torchvision/hydra_configs/torchvision/models/squeezenet.py
new file mode 100644
index 0000000..34cae35
--- /dev/null
+++ b/hydra-configs-torchvision/hydra_configs/torchvision/models/squeezenet.py
@@ -0,0 +1,17 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+#
+# Generated by configen, do not edit.
+# See https://github.com/facebookresearch/hydra/tree/master/tools/configen
+# fmt: off
+# isort:skip_file
+# flake8: noqa
+
+from dataclasses import dataclass, field
+from typing import Any
+
+
+@dataclass
+class SqueezeNetConf:
+    _target_: str = "torchvision.models.squeezenet.SqueezeNet"
+    version: Any = "1_0"
+    num_classes: Any = 1000
diff --git a/hydra-configs-torchvision/tests/test_instantiate_models.py b/hydra-configs-torchvision/tests/test_instantiate_models.py
new file mode 100644
index 0000000..9823973
--- /dev/null
+++ b/hydra-configs-torchvision/tests/test_instantiate_models.py
@@ -0,0 +1,89 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import pytest
+from hydra.utils import get_class, instantiate
+from omegaconf import OmegaConf
+
+import torchvision.models as models
+
+
+from torchvision.models.resnet import BasicBlock
+from torchvision.models.resnet import Bottleneck
+from typing import Any
+
+bb = BasicBlock(10, 10)
+mnasnet_dict = {"alpha": 1.0, "num_classes": 1000}
+
+
+@pytest.mark.parametrize(
+    "modulepath, classname, cfg, passthrough_args, passthrough_kwargs, expected",
+    [
+        pytest.param(
+            "models.alexnet",
+            "AlexNet",
+            {},
+            [],
+            {},
+            models.AlexNet(),
+            id="AlexNetConf",
+        ),
+        pytest.param(
+            "models.resnet",
+            "ResNet",
+            {"layers": [2, 2, 2, 2]},
+            [],
+            {"block": Bottleneck},
+            models.ResNet(block=Bottleneck, layers=[2, 2, 2, 2]),
+            id="ResNetConf",
+        ),
+        pytest.param(
+            "models.densenet",
+            "DenseNet",
+            {},
+            [],
+            {},
+            models.DenseNet(),
+            id="DenseNetConf",
+        ),
+        pytest.param(
+            "models.squeezenet",
+            "SqueezeNet",
+            {},
+            [],
+            {},
+            models.SqueezeNet(),
+            id="SqueezeNetConf",
+        ),
+        pytest.param(
+            "models.mnasnet",
+            "MNASNet",
+            {"alpha": 1.0},
+            [],
+            {},
+            models.MNASNet(alpha=1.0),
+            id="MNASNetConf",
+        ),
+        pytest.param(
+            "models.googlenet",
+            "GoogLeNet",
+            {},
+            [],
+            {},
+            models.GoogLeNet(),
+            id="GoogleNetConf",
+        ),
+    ],
+)
+def test_instantiate_classes(
+    modulepath: str,
+    classname: str,
+    cfg: Any,
+    passthrough_args: Any,
+    passthrough_kwargs: Any,
+    expected: Any,
+) -> None:
+    full_class = f"hydra_configs.torchvision.{modulepath}.{classname}Conf"
+    schema = OmegaConf.structured(get_class(full_class))
+    cfg = OmegaConf.merge(schema, cfg)
+    obj = instantiate(cfg, *passthrough_args, **passthrough_kwargs)
+
+    assert isinstance(obj, type(expected))
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index c8b2675..ca01ba5 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -2,3 +2,4 @@ git+https://github.com/facebookresearch/hydra#subdirectory=tools/configen
 git+https://github.com/facebookresearch/hydra
 torch==1.6.0
 torchvision==0.7.0
+scipy
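For reference, a minimal usage sketch of the generated configs, mirroring what tests/test_instantiate_models.py does. Only ResNetConf and its _target_ path come from the generated files above; the override values and variable names here are illustrative:

    from hydra.utils import get_class, instantiate
    from omegaconf import OmegaConf

    from torchvision.models.resnet import Bottleneck

    # Load the generated structured config as a schema and override a serializable field.
    schema = OmegaConf.structured(get_class("hydra_configs.torchvision.models.resnet.ResNetConf"))
    cfg = OmegaConf.merge(schema, {"layers": [2, 2, 2, 2]})

    # Non-serializable arguments (here the residual block class) are passed through at call time.
    resnet = instantiate(cfg, block=Bottleneck)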