model_utils.py
import copy
import os

import matplotlib.pyplot as plt
import numpy as np

import mindspore
import mindspore.dataset as ds
import mindspore.ops as ops
from mindspore import Model, nn
from mindspore.dataset import transforms, vision
from mindspore.train.callback import Callback, CheckpointConfig, LossMonitor, ModelCheckpoint
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from data_utils import FemnistValTest


def mkdirs(dirpath):
    """
    Create the directory at dirpath; errors (e.g. if it already exists) are ignored.
    """
    try:
        os.makedirs(dirpath)
    except Exception as _:
        pass


def get_model_list(root, name, models_ini_list, models):
    """
    Load saved model checkpoints for testing.

    For each checkpoint file name in `name`, instantiate the network described
    by the matching entry of `models_ini_list` and restore its parameters from
    the checkpoint stored under `root`.
    """
    model_list = []
    for i in range(len(name)):
        data_path = os.path.join(root, name[i])
        param_dict = load_checkpoint(data_path)
        net = models[models_ini_list[i]["model_type"]](models_ini_list[i]["params"])
        load_param_into_net(net, param_dict)
        model_list.append(net)
    return model_list
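

# Example (a hypothetical call, sketched only to show the expected argument
# shapes; the "cnn" key, FemnistCNN class and checkpoint name are assumptions):
#
#   models = {"cnn": FemnistCNN}
#   models_ini_list = [{"model_type": "cnn", "params": {"num_classes": 62}}]
#   nets = get_model_list("./checkpoints", ["client_0.ckpt"], models_ini_list, models)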


def get_femnist_model_list(root, name, models_ini_list, models):
    """
    Load saved FEMNIST model checkpoints for testing.

    Unlike get_model_list, `root` is indexed per model, and only the first
    five checkpoints are loaded.
    """
    model_list = []
    for i in range(len(name)):
        if i in [0, 1, 2, 3, 4]:
            param_dict = load_checkpoint(os.path.join(root[i], name[i]))
            net = models[models_ini_list[i]["model_type"]](models_ini_list[i]["params"])
            load_param_into_net(net, param_dict)
            model_list.append(net)
    return model_list


def test_models_femnist(device, models_list, test_x, test_y, savelurl):
    """
    Evaluate each model in models_list on the FEMNIST test data and return
    the list of accuracies.
    """
    # Decide dataset sink mode from the configured device target.
    dataset_sink = mindspore.context.get_context('device_target') == 'CPU'
    # Normalise the test images with the usual MNIST/FEMNIST statistics.
    apply_transform = transforms.py_transforms.Compose(
        [vision.py_transforms.ToTensor(),
         vision.py_transforms.Normalize((0.1307,), (0.3081,))])
    femnist_bal_data_test = FemnistValTest(test_x, test_y, apply_transform)
    testloader = ds.GeneratorDataset(femnist_bal_data_test, ["data", "label"], shuffle=True)
    testloader = testloader.batch(batch_size=128)
    accuracy_list = []
    loss = NLLLoss()
    for n, model in enumerate(models_list):
        model = Model(model, loss, metrics={"accuracy"})
        acc = model.eval(testloader, dataset_sink_mode=dataset_sink)
        accuracy = acc['accuracy']
        accuracy_list.append(accuracy)
    print(accuracy_list)
    return accuracy_list
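

# Example (hypothetical arrays and output path, sketched only to show the call;
# test_x / test_y are assumed to be the arrays consumed by FemnistValTest):
#
#   acc_list = test_models_femnist("CPU", nets, test_x, test_y, "./results")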


def average_weights(w):
    """
    Average model weights (element-wise mean over a list of parameter
    dictionaries, as in FedAvg).
    """
    w_avg = copy.deepcopy(w[0])
    for key in w_avg.keys():
        # Sum this parameter over every model, then divide by the number of models.
        for i in range(1, len(w)):
            w_avg[key] += w[i][key]
        div = ops.Div()
        w_avg[key] = div(w_avg[key], len(w))
    return w_avg
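

# Example (a minimal sketch of averaging two clients' weights; net_a / net_b
# are assumed to share the same architecture, and parameters_dict() is used
# here only as one way to obtain a name -> parameter mapping):
#
#   w_local = [net_a.parameters_dict(), net_b.parameters_dict()]
#   w_global = average_weights(w_local)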


class NLLLoss(nn.LossBase):
    """
    Negative log-likelihood loss.

    Expects `logits` to already be log-probabilities (e.g. the output of a
    LogSoftmax layer); the per-sample loss is the negative log-probability
    of the target class, reduced according to `reduction`.
    """
    def __init__(self, reduction='mean'):
        super(NLLLoss, self).__init__(reduction)
        self.one_hot = ops.OneHot()
        self.reduce_sum = ops.ReduceSum()

    def construct(self, logits, label):
        # One-hot encode the labels and keep only the log-probability of the
        # target class for each sample.
        label_one_hot = self.one_hot(label, ops.shape(logits)[-1],
                                     ops.scalar_to_array(1.0), ops.scalar_to_array(0.0))
        loss = self.reduce_sum(-1.0 * logits * label_one_hot, (1,))
        return self.get_loss(loss)
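

# Example (a minimal sketch; assumes the network ends in nn.LogSoftmax so that
# its outputs are the log-probabilities NLLLoss expects):
#
#   net = nn.SequentialCell([nn.Flatten(), nn.Dense(784, 62), nn.LogSoftmax(axis=1)])
#   criterion = NLLLoss(reduction='mean')
#   model = Model(net, criterion, metrics={"accuracy"})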


class EarlyStop(Callback):
    """
    Early-stopping callback: request that training stop once the step loss
    falls below `control_loss`.
    """
    def __init__(self, control_loss=1):
        super(EarlyStop, self).__init__()
        self._control_loss = control_loss

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs
        if loss.asnumpy() < self._control_loss:
            # Ask the training loop to stop.
            run_context.request_stop()
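

# Example (a minimal sketch of attaching the callback during training; the
# network, optimizer, epoch count and train_loader names are assumptions):
#
#   opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
#   model = Model(net, NLLLoss(), opt)
#   model.train(epochs, train_loader,
#               callbacks=[LossMonitor(), EarlyStop(control_loss=0.05)],
#               dataset_sink_mode=False)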