forked from BVLC/caffe
-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[pycaffe] basic, partial testing of Net and SGDSolver
- Loading branch information
Showing
2 changed files
with
126 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import unittest | ||
import tempfile | ||
import os | ||
import numpy as np | ||
|
||
import caffe | ||
|
||
def simple_net_file(num_output): | ||
"""Make a simple net prototxt, based on test_net.cpp, returning the name | ||
of the (temporary) file.""" | ||
|
||
f = tempfile.NamedTemporaryFile(delete=False) | ||
f.write("""name: 'testnet' force_backward: true | ||
layers { type: DUMMY_DATA name: 'data' top: 'data' top: 'label' | ||
dummy_data_param { num: 5 channels: 2 height: 3 width: 4 | ||
num: 5 channels: 1 height: 1 width: 1 | ||
data_filler { type: 'gaussian' std: 1 } | ||
data_filler { type: 'constant' } } } | ||
layers { type: CONVOLUTION name: 'conv' bottom: 'data' top: 'conv' | ||
convolution_param { num_output: 11 kernel_size: 2 pad: 3 | ||
weight_filler { type: 'gaussian' std: 1 } | ||
bias_filler { type: 'constant' value: 2 } } | ||
weight_decay: 1 weight_decay: 0 } | ||
layers { type: INNER_PRODUCT name: 'ip' bottom: 'conv' top: 'ip' | ||
inner_product_param { num_output: """ + str(num_output) + """ | ||
weight_filler { type: 'gaussian' std: 2.5 } | ||
bias_filler { type: 'constant' value: -3 } } } | ||
layers { type: SOFTMAX_LOSS name: 'loss' bottom: 'ip' bottom: 'label' | ||
top: 'loss' }""") | ||
f.close() | ||
return f.name | ||
|
||
class TestNet(unittest.TestCase): | ||
def setUp(self): | ||
self.num_output = 13 | ||
net_file = simple_net_file(self.num_output) | ||
self.net = caffe.Net(net_file) | ||
# fill in valid labels | ||
self.net.blobs['label'].data[...] = \ | ||
np.random.randint(self.num_output, | ||
size=self.net.blobs['label'].data.shape) | ||
os.remove(net_file) | ||
|
||
def test_memory(self): | ||
"""Check that holding onto blob data beyond the life of a Net is OK""" | ||
|
||
params = sum(map(list, self.net.params.itervalues()), []) | ||
blobs = self.net.blobs.values() | ||
del self.net | ||
|
||
# now sum everything (forcing all memory to be read) | ||
total = 0 | ||
for p in params: | ||
total += p.data.sum() + p.diff.sum() | ||
for bl in blobs: | ||
total += bl.data.sum() + bl.diff.sum() | ||
|
||
def test_forward_backward(self): | ||
self.net.forward() | ||
self.net.backward() | ||
|
||
def test_inputs_outputs(self): | ||
self.assertEqual(self.net.inputs, []) | ||
self.assertEqual(self.net.outputs, ['loss']) | ||
|
||
def test_save_and_read(self): | ||
f = tempfile.NamedTemporaryFile(delete=False) | ||
f.close() | ||
self.net.save(f.name) | ||
net_file = simple_net_file(self.num_output) | ||
net2 = caffe.Net(net_file, f.name) | ||
os.remove(net_file) | ||
os.remove(f.name) | ||
for name in self.net.params: | ||
for i in range(len(self.net.params[name])): | ||
self.assertEqual(abs(self.net.params[name][i].data | ||
- net2.params[name][i].data).sum(), 0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import unittest | ||
import tempfile | ||
import os | ||
import numpy as np | ||
|
||
import caffe | ||
from test_net import simple_net_file | ||
|
||
class TestSolver(unittest.TestCase): | ||
def setUp(self): | ||
self.num_output = 13 | ||
net_f = simple_net_file(self.num_output) | ||
f = tempfile.NamedTemporaryFile(delete=False) | ||
f.write("""net: '""" + net_f + """' | ||
test_iter: 10 test_interval: 10 base_lr: 0.01 momentum: 0.9 | ||
weight_decay: 0.0005 lr_policy: 'inv' gamma: 0.0001 power: 0.75 | ||
display: 100 max_iter: 100 snapshot_after_train: false""") | ||
f.close() | ||
self.solver = caffe.SGDSolver(f.name) | ||
self.solver.net.set_mode_cpu() | ||
# fill in valid labels | ||
self.solver.net.blobs['label'].data[...] = \ | ||
np.random.randint(self.num_output, | ||
size=self.solver.net.blobs['label'].data.shape) | ||
self.solver.test_nets[0].blobs['label'].data[...] = \ | ||
np.random.randint(self.num_output, | ||
size=self.solver.test_nets[0].blobs['label'].data.shape) | ||
os.remove(f.name) | ||
os.remove(net_f) | ||
|
||
def test_solve(self): | ||
self.assertEqual(self.solver.iter, 0) | ||
self.solver.solve() | ||
self.assertEqual(self.solver.iter, 100) | ||
|
||
def test_net_memory(self): | ||
"""Check that nets survive after the solver is destroyed.""" | ||
|
||
nets = [self.solver.net] + list(self.solver.test_nets) | ||
self.assertEqual(len(nets), 2) | ||
del self.solver | ||
|
||
total = 0 | ||
for net in nets: | ||
for ps in net.params.itervalues(): | ||
for p in ps: | ||
total += p.data.sum() + p.diff.sum() | ||
for bl in net.blobs.itervalues(): | ||
total += bl.data.sum() + bl.diff.sum() |