Reference: *Dynamic Routing Between Capsules* by Sara Sabour, Nicholas Frosst, Geoffrey E. Hinton
Note: this implementation strictly follows the instructions of the paper; check the paper for details.
The key contribution of the paper is not how accurate the CapsNet is, but the novel idea of representing an image with capsules.
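As a reminder of what a capsule computes, here is a minimal sketch of the squash non-linearity from the paper, which rescales a capsule's output vector so that its length can be read as a probability while keeping its direction. It is written against the TF 1.x API used in this repo; the function name is mine, not the repo's.

```python
import tensorflow as tf

def squash(s, axis=-1, eps=1e-8):
    """Squash from the paper: v = (||s||^2 / (1 + ||s||^2)) * (s / ||s||)."""
    squared_norm = tf.reduce_sum(tf.square(s), axis=axis, keep_dims=True)
    scale = squared_norm / (1.0 + squared_norm)
    # eps avoids division by zero for an all-zero capsule input
    return scale * s / tf.sqrt(squared_norm + eps)
```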
- Code is tested on TensorFlow 1.3 and Python 2.7, but it should be compatible with Python 3.x.
- Other dependencies are as follows:
six>=1.11
matplotlib>=2.0.2
numpy>=1.7.1
scipy>=0.13.2
easydict>=1.6
tqdm>=4.17.1
Install them by running:
$ cd $ROOT
$ pip install -r requirements.txt
NOTE: all experiments were conducted with the checkpoint available from Jbox (SJTU) or Google_Drive.
By running:
$ cd code
$ python eval.py --ckpt 'path/to/ckpt' --mode reconstruct
Reconstruction results: (Note: the float numbers on the even-numbered rows are the max norms of the 10 digit capsules.)
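The max norm mentioned above is simply the length of the longest digit capsule, which is also what determines the predicted class. A small sketch of how those lengths could be computed; the tensor name and shape are assumptions, not the repo's actual variables.

```python
import tensorflow as tf

def capsule_lengths_and_prediction(digit_caps):
    """digit_caps: [batch, 10, 16] -- the 10 digit capsules (shape assumed)."""
    # L2 norm of each 16-D digit capsule; used both as the class score
    # and as the "max norm" printed next to the reconstructions.
    lengths = tf.sqrt(tf.reduce_sum(tf.square(digit_caps), axis=2) + 1e-8)  # [batch, 10]
    predicted = tf.argmax(lengths, axis=1)     # class of the longest capsule
    max_norm = tf.reduce_max(lengths, axis=1)  # value shown in the figure rows
    return lengths, predicted, max_norm
```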
By running:
$ cd code
$ python eval.py --ckpt 'path/to/ckpt' --mode cap_tweak
Results:
Note: images along the x-axis correspond to the units of the 16-D capsule vector, and the y-axis corresponds to the tweak range [-0.25, 0.25] with stride 0.05.
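For reference, a NumPy sketch of how such a tweak grid can be built for a single 16-D digit capsule before feeding it to the decoder; the function and argument names are illustrative and not taken from `eval.py`.

```python
import numpy as np

def tweak_capsule(capsule, low=-0.25, high=0.25, stride=0.05):
    """capsule: 1-D array of length 16 (one digit capsule).
    For each of the 16 dimensions (x-axis of the figure) and each offset
    in [-0.25, 0.25] with stride 0.05 (y-axis), add the offset to that
    single dimension and keep the rest fixed.
    """
    offsets = np.arange(low, high + 1e-9, stride)  # 11 tweak values
    dim = capsule.shape[0]                         # 16 for DigitCaps
    tweaked = np.tile(capsule, (dim, len(offsets), 1)).astype(np.float32)
    for d in range(dim):
        tweaked[d, :, d] += offsets
    return tweaked  # [16, 11, 16], one perturbed capsule per grid cell
```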
By running:
$ cd code
$ python eval.py --ckpt 'path/to/ckpt' --mode adversarial
Result:
The adversarial result is not as good as I expected; I was hoping that the capsule representation would be more robust to adversarial attacks.
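For context, here is a sketch of a one-step FGSM-style perturbation. Whether `eval.py` uses exactly this attack is an assumption on my part, and the tensor names are placeholders.

```python
import tensorflow as tf

def fgsm_perturbation(images, loss, epsilon=0.1):
    """Fast Gradient Sign Method sketch (a common attack; this repo's exact
    attack is an assumption). `images` is the input placeholder, `loss` the
    classification loss (e.g. the margin loss)."""
    grad = tf.gradients(loss, images)[0]            # dLoss/dInput
    adversarial = images + epsilon * tf.sign(grad)  # one-step perturbation
    return tf.clip_by_value(adversarial, 0.0, 1.0)  # keep a valid pixel range
```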
Note: all models are trained with batch_size = 100.
Latest commit, with 3 iterations of dynamic routing:
1. updated dynamic routing with `tf.while_loop` as well as a statically unrolled version (see the sketch below)
2. fixed the margin loss issue
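A statically unrolled sketch of routing-by-agreement as described in the paper (the `tf.while_loop` variant is equivalent). The shapes, the `stop_gradient` choice, and the names are assumptions rather than this repo's exact code.

```python
import tensorflow as tf

def squash(s, axis=-1, eps=1e-8):
    sq_norm = tf.reduce_sum(tf.square(s), axis=axis, keep_dims=True)
    return sq_norm / (1.0 + sq_norm) * s / tf.sqrt(sq_norm + eps)

def dynamic_routing(u_hat, num_iters=3):
    """u_hat: [batch, num_in_caps, num_out_caps, out_dim] prediction vectors."""
    # Routing logits b start at zero and are NOT trainable; only the
    # agreement update below changes them (fixing the "updating `prior`" bug).
    b = tf.zeros_like(u_hat[:, :, :, :1])      # [batch, in, out, 1]
    u_hat_stopped = tf.stop_gradient(u_hat)    # no gradients through intermediate iterations
    for i in range(num_iters):
        c = tf.nn.softmax(b, dim=2)            # coupling coefficients over output capsules
        if i == num_iters - 1:
            s = tf.reduce_sum(c * u_hat, axis=1, keep_dims=True)  # gradients flow on last pass
            v = squash(s)
        else:
            s = tf.reduce_sum(c * u_hat_stopped, axis=1, keep_dims=True)
            v = squash(s)
            # agreement between prediction and output updates the logits
            b += tf.reduce_sum(u_hat_stopped * v, axis=-1, keep_dims=True)
    return tf.squeeze(v, axis=[1])             # [batch, num_out_caps, out_dim]
```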
Result:

Iterations | 1k | 2k | 3k | 4k | 5k |
---|---|---|---|---|---|
val_acc (%) | 98.90 | 99.16 | 99.09 | 99.30 | 99.24 |
test_acc (%) | - | - | - | - | 99.21 |
Commit 8e3785d, with bugs:
1. wrong implementation of the margin loss (the correct form is sketched below)
2. updating the `prior` during routing
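For reference, a sketch of the margin loss as defined in the paper, which the later fix should correspond to; the argument names are mine.

```python
import tensorflow as tf

def margin_loss(lengths, labels, m_plus=0.9, m_minus=0.1, lam=0.5):
    """Margin loss from the paper:
    L_k = T_k * max(0, m+ - ||v_k||)^2 + lam * (1 - T_k) * max(0, ||v_k|| - m-)^2
    lengths: [batch, 10] capsule norms; labels: [batch, 10] one-hot."""
    present = labels * tf.square(tf.maximum(0.0, m_plus - lengths))
    absent = lam * (1.0 - labels) * tf.square(tf.maximum(0.0, lengths - m_minus))
    # sum over classes, then average over the batch
    return tf.reduce_mean(tf.reduce_sum(present + absent, axis=1))
```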
Result:

Iterations | 2k | 4k | 5k | 7k | 9k | 10k |
---|---|---|---|---|---|---|
val_acc (%) | 98.02 | 98.58 | - | 98.82 | 98.96 | - |
test_acc (%) | - | - | 98.89 | - | - | 99.09 |
- Clone the repo and set up the parameters in `code/config.py`.
- Then run:
$ cd $ROOT/code
$ python train.py --data_dir 'path/to/data' --max_iters 10000 --ckpt 'OPTIONAL:path/to/ckpt' --batch_size 100
Or train with logs by running (NOTE: set extra arguments in train.sh accordingly):
$ cd $ROOT/code
$ bash train.sh
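The flags shown above are the only ones documented here. A minimal sketch of how `train.py` might expose them; the actual script may parse or default them differently.

```python
import argparse

# Sketch of the documented command-line interface; only the flag names
# come from this README.
parser = argparse.ArgumentParser(description='Train CapsNet on MNIST')
parser.add_argument('--data_dir', type=str, required=True, help='path to the data')
parser.add_argument('--max_iters', type=int, default=10000, help='number of training iterations')
parser.add_argument('--ckpt', type=str, default=None, help='optional checkpoint to resume from')
parser.add_argument('--batch_size', type=int, default=100, help='batch size (100 in all experiments)')
args = parser.parse_args()
```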
The lower accuracy may be due to:
- Different input size.
- About 3M missing parameters (my implementation has 8M, compared to the 11M referred to in the paper).
- The model is still under-fitting.
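The 8M figure can be checked by counting trainable variables after the graph is built, for example with a helper like the one below (not part of the repo).

```python
import numpy as np
import tensorflow as tf

def count_params():
    """Total number of trainable parameters in the current default graph."""
    return int(np.sum([np.prod(v.get_shape().as_list())
                       for v in tf.trainable_variables()]))
```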
- Report more extensive experiment results.
- Try to fix the inefficiency.
- Keras implementation
- Discussion about the routing algorithm: issue and issue.