From caa2003ab82904e2cb3ff4337cd0b94b41539421 Mon Sep 17 00:00:00 2001 From: fuqianya Date: Wed, 13 Oct 2021 11:12:34 +0800 Subject: [PATCH] [PaddlePaddle Hackathon] add AlexNet (#36058) * add alexnet --- python/paddle/tests/test_pretrained_model.py | 4 +- python/paddle/tests/test_vision_models.py | 4 +- python/paddle/vision/__init__.py | 2 + python/paddle/vision/models/__init__.py | 6 +- python/paddle/vision/models/alexnet.py | 192 +++++++++++++++++++ 5 files changed, 205 insertions(+), 3 deletions(-) create mode 100644 python/paddle/vision/models/alexnet.py diff --git a/python/paddle/tests/test_pretrained_model.py b/python/paddle/tests/test_pretrained_model.py index b24b51555c581..fba1435c75e9c 100644 --- a/python/paddle/tests/test_pretrained_model.py +++ b/python/paddle/tests/test_pretrained_model.py @@ -52,7 +52,9 @@ def infer(self, arch): np.testing.assert_allclose(res['dygraph'], res['static']) def test_models(self): - arches = ['mobilenet_v1', 'mobilenet_v2', 'resnet18', 'vgg16'] + arches = [ + 'mobilenet_v1', 'mobilenet_v2', 'resnet18', 'vgg16', 'alexnet' + ] for arch in arches: self.infer(arch) diff --git a/python/paddle/tests/test_vision_models.py b/python/paddle/tests/test_vision_models.py index a25a8f373c29c..ea42c22e289ed 100644 --- a/python/paddle/tests/test_vision_models.py +++ b/python/paddle/tests/test_vision_models.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import unittest import numpy as np @@ -71,6 +70,9 @@ def test_resnet101(self): def test_resnet152(self): self.models_infer('resnet152') + def test_alexnet(self): + self.models_infer('alexnet') + def test_vgg16_num_classes(self): vgg16 = models.__dict__['vgg16'](pretrained=False, num_classes=10) diff --git a/python/paddle/vision/__init__.py b/python/paddle/vision/__init__.py index 76393865ded04..b8ac548a96663 100644 --- a/python/paddle/vision/__init__.py +++ b/python/paddle/vision/__init__.py @@ -44,6 +44,8 @@ from .models import vgg16 # noqa: F401 from .models import vgg19 # noqa: F401 from .models import LeNet # noqa: F401 +from .models import AlexNet # noqa: F401 +from .models import alexnet # noqa: F401 from .transforms import BaseTransform # noqa: F401 from .transforms import Compose # noqa: F401 from .transforms import Resize # noqa: F401 diff --git a/python/paddle/vision/models/__init__.py b/python/paddle/vision/models/__init__.py index d38f3b1722ee8..b85333614637f 100644 --- a/python/paddle/vision/models/__init__.py +++ b/python/paddle/vision/models/__init__.py @@ -28,6 +28,8 @@ from .vgg import vgg16 # noqa: F401 from .vgg import vgg19 # noqa: F401 from .lenet import LeNet # noqa: F401 +from .alexnet import AlexNet # noqa: F401 +from .alexnet import alexnet # noqa: F401 __all__ = [ #noqa 'ResNet', @@ -45,5 +47,7 @@ 'mobilenet_v1', 'MobileNetV2', 'mobilenet_v2', - 'LeNet' + 'LeNet', + 'AlexNet', + 'alexnet' ] diff --git a/python/paddle/vision/models/alexnet.py b/python/paddle/vision/models/alexnet.py new file mode 100644 index 0000000000000..1d36ef37b6ced --- /dev/null +++ b/python/paddle/vision/models/alexnet.py @@ -0,0 +1,192 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from paddle.nn import Linear, Dropout, ReLU +from paddle.nn import Conv2D, MaxPool2D +from paddle.nn.initializer import Uniform +from paddle.fluid.param_attr import ParamAttr +from paddle.utils.download import get_weights_path_from_url + +model_urls = { + "alexnet": ( + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/AlexNet_pretrained.pdparams", + "7f0f9f737132e02732d75a1459d98a43", ) +} + +__all__ = [] + + +class ConvPoolLayer(nn.Layer): + def __init__(self, + input_channels, + output_channels, + filter_size, + stride, + padding, + stdv, + groups=1, + act=None): + super(ConvPoolLayer, self).__init__() + + self.relu = ReLU() if act == "relu" else None + + self._conv = Conv2D( + in_channels=input_channels, + out_channels=output_channels, + kernel_size=filter_size, + stride=stride, + padding=padding, + groups=groups, + weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)), + bias_attr=ParamAttr(initializer=Uniform(-stdv, stdv))) + self._pool = MaxPool2D(kernel_size=3, stride=2, padding=0) + + def forward(self, inputs): + x = self._conv(inputs) + if self.relu is not None: + x = self.relu(x) + x = self._pool(x) + return x + + +class AlexNet(nn.Layer): + """AlexNet model from + `"ImageNet Classification with Deep Convolutional Neural Networks" + `_ + + Args: + num_classes (int): Output dim of last fc layer. Default: 1000. + + Examples: + .. code-block:: python + + from paddle.vision.models import AlexNet + + alexnet = AlexNet() + + """ + + def __init__(self, num_classes=1000): + super(AlexNet, self).__init__() + self.num_classes = num_classes + stdv = 1.0 / math.sqrt(3 * 11 * 11) + self._conv1 = ConvPoolLayer(3, 64, 11, 4, 2, stdv, act="relu") + stdv = 1.0 / math.sqrt(64 * 5 * 5) + self._conv2 = ConvPoolLayer(64, 192, 5, 1, 2, stdv, act="relu") + stdv = 1.0 / math.sqrt(192 * 3 * 3) + self._conv3 = Conv2D( + 192, + 384, + 3, + stride=1, + padding=1, + weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)), + bias_attr=ParamAttr(initializer=Uniform(-stdv, stdv))) + stdv = 1.0 / math.sqrt(384 * 3 * 3) + self._conv4 = Conv2D( + 384, + 256, + 3, + stride=1, + padding=1, + weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)), + bias_attr=ParamAttr(initializer=Uniform(-stdv, stdv))) + stdv = 1.0 / math.sqrt(256 * 3 * 3) + self._conv5 = ConvPoolLayer(256, 256, 3, 1, 1, stdv, act="relu") + + if self.num_classes > 0: + stdv = 1.0 / math.sqrt(256 * 6 * 6) + self._drop1 = Dropout(p=0.5, mode="downscale_in_infer") + self._fc6 = Linear( + in_features=256 * 6 * 6, + out_features=4096, + weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)), + bias_attr=ParamAttr(initializer=Uniform(-stdv, stdv))) + + self._drop2 = Dropout(p=0.5, mode="downscale_in_infer") + self._fc7 = Linear( + in_features=4096, + out_features=4096, + weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)), + bias_attr=ParamAttr(initializer=Uniform(-stdv, stdv))) + self._fc8 = Linear( + in_features=4096, + out_features=num_classes, + weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)), + bias_attr=ParamAttr(initializer=Uniform(-stdv, stdv))) + + def forward(self, inputs): + x = self._conv1(inputs) + x = self._conv2(x) + x = self._conv3(x) + x = F.relu(x) + x = self._conv4(x) + x = F.relu(x) + x = self._conv5(x) + + if self.num_classes > 0: + x = paddle.flatten(x, start_axis=1, stop_axis=-1) + x = self._drop1(x) + x = self._fc6(x) + x = F.relu(x) + x = self._drop2(x) + x = self._fc7(x) + x = F.relu(x) + x = self._fc8(x) + + return x + + +def _alexnet(arch, pretrained, **kwargs): + model = AlexNet(**kwargs) + + if pretrained: + assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format( + arch) + weight_path = get_weights_path_from_url(model_urls[arch][0], + model_urls[arch][1]) + + param = paddle.load(weight_path) + model.load_dict(param) + + return model + + +def alexnet(pretrained=False, **kwargs): + """AlexNet model + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False. + + Examples: + .. code-block:: python + + from paddle.vision.models import alexnet + + # build model + model = alexnet() + + # build model and load imagenet pretrained weight + # model = alexnet(pretrained=True) + """ + return _alexnet('alexnet', pretrained, **kwargs)