-
Notifications
You must be signed in to change notification settings - Fork 68
/
Copy pathtest_adversarial_examples.py
64 lines (54 loc) · 2.56 KB
/
test_adversarial_examples.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
import torchvision.datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import models
from models import MNIST_target_net
# Run-time configuration for the evaluation script.
use_cuda = True          # request the GPU when one is present
image_nc = 1             # number of image channels
batch_size = 128
gen_input_nc = image_nc  # generator input channels mirror the image channels

# Pick the compute device: CUDA only when it is both requested and available.
_cuda_available = torch.cuda.is_available()
print("CUDA Available: ", _cuda_available)
device = torch.device("cuda" if (use_cuda and _cuda_available) else "cpu")
# Load the pretrained target classifier and put it in eval mode.
# map_location=device remaps tensors stored on another device (e.g. a
# checkpoint saved from a GPU run) onto `device`, so loading still works
# when `device` resolved to CPU.
pretrained_model = "./MNIST_target_model.pth"
target_model = MNIST_target_net().to(device)
target_model.load_state_dict(torch.load(pretrained_model, map_location=device))
target_model.eval()

# Load the generator that produces the adversarial perturbations.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
pretrained_generator_path = './models/netG_epoch_60.pth'
pretrained_G = models.Generator(gen_input_nc, image_nc).to(device)
pretrained_G.load_state_dict(torch.load(pretrained_generator_path, map_location=device))
pretrained_G.eval()
def _count_correct_on_adversarial(dataloader, classifier, generator, eval_device, epsilon=0.3):
    """Return how many adversarial images ``classifier`` still labels correctly.

    For every ``(images, labels)`` batch yielded by ``dataloader``, the
    generator output is clamped to ``[-epsilon, epsilon]``, added to the clean
    images, and the sum is clamped to ``[0, 1]`` before classification. The
    predicted label is the argmax over dimension 1 of the classifier output.

    Args:
        dataloader: iterable of ``(images, labels)`` batches.
        classifier: target model under attack.
        generator: model mapping clean images to perturbations (assumed to
            return a tensor broadcast-compatible with ``images`` — TODO confirm).
        eval_device: device every batch is moved to before evaluation.
        epsilon: bound applied to each perturbation element (0.3 by default).

    Returns:
        Number of correct predictions as a Python ``int`` (0 when the
        dataloader yields no batches).
    """
    # Accumulate on the device to avoid a host sync per batch; starting from a
    # zero tensor also keeps the empty-dataloader case well-defined.
    num_correct = torch.zeros((), dtype=torch.long, device=eval_device)
    # Pure inference: no autograd graph is needed for generator or classifier.
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(eval_device), labels.to(eval_device)
            perturbation = torch.clamp(generator(images), -epsilon, epsilon)
            adv_img = torch.clamp(perturbation + images, 0, 1)
            pred_lab = torch.argmax(classifier(adv_img), 1)
            num_correct += torch.sum(pred_lab == labels, 0)
    return int(num_correct.item())


# Evaluate adversarial examples on the MNIST training split.
mnist_dataset = torchvision.datasets.MNIST('./dataset', train=True, transform=transforms.ToTensor(), download=True)
train_dataloader = DataLoader(mnist_dataset, batch_size=batch_size, shuffle=False, num_workers=1)
num_correct = _count_correct_on_adversarial(train_dataloader, target_model, pretrained_G, device)
print('MNIST training dataset:')
print('num_correct: ', num_correct)
print('accuracy of adv imgs in training set: %f\n' % (num_correct / len(mnist_dataset)))

# Evaluate adversarial examples on the MNIST test split.
mnist_dataset_test = torchvision.datasets.MNIST('./dataset', train=False, transform=transforms.ToTensor(), download=True)
test_dataloader = DataLoader(mnist_dataset_test, batch_size=batch_size, shuffle=False, num_workers=1)
num_correct = _count_correct_on_adversarial(test_dataloader, target_model, pretrained_G, device)
print('MNIST testing dataset:')
print('num_correct: ', num_correct)
print('accuracy of adv imgs in testing set: %f\n' % (num_correct / len(mnist_dataset_test)))