-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtest.py
47 lines (40 loc) · 1.31 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import datetime
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader
import os
import numpy as np
class SE(nn.Module):
    """Squeeze-and-Excitation (channel attention) block.

    Global-average-pools each channel of the input, feeds the per-channel
    descriptor through a two-layer bottleneck MLP (Linear -> ReLU -> Linear),
    and rescales the input channels by the sigmoid of the MLP output.

    Args:
        inchannels: Number of input channels C; the MLP maps C -> hidden -> C.
        se_ratio: Bottleneck reduction ratio. The hidden layer has
            ``int(inchannels / se_ratio)`` units, clamped to at least 1.
    """

    def __init__(self, inchannels, se_ratio):
        super().__init__()
        # Clamp to >= 1: a ratio larger than the channel count would otherwise
        # produce a zero-width bottleneck layer.
        hidden = max(1, int(inchannels / se_ratio))
        self.AvgPool = nn.AdaptiveAvgPool2d(1)
        self.SEblock = nn.Sequential(
            nn.Linear(inchannels, hidden),
            nn.ReLU(),
            nn.Linear(hidden, inchannels),
        )

    def forward(self, x):
        """Rescale the channels of ``x`` by learned attention weights.

        Args:
            x: Batched 4-D tensor (N, C, H, W) with C == ``inchannels``.

        Returns:
            Tensor of the same shape as ``x``: ``x * sigmoid(excitation)``,
            broadcast over the spatial dimensions.
        """
        n, c = x.size(0), x.size(1)
        # Squeeze: (N, C, H, W) -> (N, C, 1, 1) -> (N, C).
        weights = self.AvgPool(x).view(n, -1)
        # Excitation: bottleneck MLP, then reshape to broadcast over H and W.
        weights = self.SEblock(weights).view(n, c, 1, 1)
        return x * torch.sigmoid(weights)
# Smoke test: push a (2, 3, 2, 2) ramp tensor through an SE block with no
# channel reduction (se_ratio=1) and print both the input and the rescaled
# output, so the result of the forward pass is actually shown.
x = torch.from_numpy(np.arange(24).reshape((2, 3, 2, 2))).float()
net = SE(3, 1)
y = net(x)
print(x)
print(y)
# NOTE(review): the triple-quoted string below is disabled ("commented-out")
# CIFAR-100 training-loader code. It is evaluated as a bare string literal
# and has no runtime effect. It references CIFAR100_TRAIN_MEAN and
# CIFAR100_TRAIN_STD, which are not defined in this file; they must be
# supplied before the snippet is re-enabled. It also passes download=False,
# so it presumably expects the dataset to already exist under ./data
# (confirm before use).
'''
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=CIFAR100_TRAIN_MEAN, std=CIFAR100_TRAIN_STD)
])
traindata = torchvision.datasets.CIFAR100(root='./data', train=True, download=False, transform=transform_train)
trainloader = DataLoader(traindata, batch_size=128, shuffle=True, num_workers=2)
print(len(trainloader))
print(len(trainloader.dataset))
'''