Refactored version of https://github.com/M-Nauta/ProtoTree and parts of https://github.com/cfchen-duke/ProtoPNet to make them more modular and easier to use. This will probably be turned into a Python package and moved to a new repository.
- Create a Python >= 3.11 environment.
- Install requirements from `requirements.txt` (e.g. `pip install -r requirements.txt`).
- Install Graphviz. With the current code you need to be able to call `dot` from the terminal.
- Install the project code in editable mode using `pip install -e .`.
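The two prerequisites above that are easiest to miss are the Python version and having Graphviz's `dot` on the `PATH`. Below is a minimal pre-flight sketch that checks both; the function name `check_environment` is hypothetical and not part of this repository.

```python
"""Hypothetical pre-flight check for the setup steps (not part of the repo)."""
import shutil
import sys


def check_environment(min_version: tuple = (3, 11)) -> list:
    """Return a list of human-readable problems; an empty list means all checks passed."""
    problems = []
    if sys.version_info[:2] < min_version:
        found = ".".join(map(str, sys.version_info[:3]))
        problems.append(
            f"Python >= {min_version[0]}.{min_version[1]} required, found {found}"
        )
    # Graphviz must be callable as `dot` from the terminal, i.e. be on PATH.
    if shutil.which("dot") is None:
        problems.append("Graphviz `dot` executable not found on PATH")
    return problems


if __name__ == "__main__":
    for issue in check_environment():
        print("setup problem:", issue)
```

An empty list means both checks passed; this only checks that `dot` resolves on the `PATH`, not which Graphviz version is installed.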
- You can train a model and see its performance on the test set with `python src/run_model.py --model_type protopnet` (or `--model_type prototree`). NOTE: `src/util/args.py` lists all the arguments that can be used to configure the run.
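The training command above can also be assembled from Python, for example when scripting several runs. This is a sketch under assumptions: `build_run_command` and `MODEL_TYPES` are hypothetical names, only the two model types mentioned in this README are accepted, and any further arguments are passed through unchanged (see `src/util/args.py` for what is actually supported).

```python
from __future__ import annotations

import shlex
import sys

# Model types mentioned in this README; src/util/args.py is the authoritative source.
MODEL_TYPES = ("protopnet", "prototree")


def build_run_command(model_type: str, extra_args: list[str] | None = None) -> list[str]:
    """Build the argv for src/run_model.py, rejecting unknown model types early."""
    if model_type not in MODEL_TYPES:
        raise ValueError(f"unknown model_type {model_type!r}; expected one of {MODEL_TYPES}")
    cmd = [sys.executable, "src/run_model.py", "--model_type", model_type]
    cmd.extend(extra_args or [])  # passed through verbatim, not validated here
    return cmd


if __name__ == "__main__":
    # Only prints the command; nothing is executed.
    print(shlex.join(build_run_command("prototree")))
```

To actually start a run from the repository root, pass the list to `subprocess.run(build_run_command("protopnet"), check=True)`.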

- Install requirements from `datasources/requirements-download.txt`.
- Run `python datasources/cub_download.py`.
- Run `python datasources/cub_preprocess.py`.
- (Optional, but recommended) Download a ResNet50 pretrained on iNaturalist2017 (filename on Google Drive: `BBN.iNaturalist2017.res50.180epoch.best_model.pth`) and place it in the folder `src/features/state_dicts`.
Currently, all of these steps are done manually on a development machine. We should set up a pipeline that runs them automatically and reproducibly.
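As a first step toward such a pipeline, the manual data-preparation commands could be chained in a small script. This is only a sketch: `run_data_pipeline` is a hypothetical helper, it uses `python -m pip` in place of a bare `pip` so the install targets the running interpreter, and it cannot fetch the Google Drive weights, so it just reports whether they are in place.

```python
from __future__ import annotations

import subprocess
import sys
from pathlib import Path

# The data-preparation commands listed above, in order.
DATA_STEPS: list[list[str]] = [
    [sys.executable, "-m", "pip", "install", "-r", "datasources/requirements-download.txt"],
    [sys.executable, "datasources/cub_download.py"],
    [sys.executable, "datasources/cub_preprocess.py"],
]

# Optional pretrained backbone; it still has to be downloaded by hand.
PRETRAINED_WEIGHTS = Path(
    "src/features/state_dicts/BBN.iNaturalist2017.res50.180epoch.best_model.pth"
)


def run_data_pipeline(repo_root: Path, dry_run: bool = False) -> list[list[str]]:
    """Run each step from repo_root in order, stopping at the first failure."""
    executed: list[list[str]] = []
    for cmd in DATA_STEPS:
        print("+", " ".join(cmd))
        if not dry_run:
            # check=True raises CalledProcessError, so a failing step stops the pipeline.
            subprocess.run(cmd, cwd=repo_root, check=True)
        executed.append(cmd)
    if not (repo_root / PRETRAINED_WEIGHTS).is_file():
        print(f"note: optional weights not found at {PRETRAINED_WEIGHTS}")
    return executed
```

Calling `run_data_pipeline(Path("."), dry_run=True)` from the repository root prints the commands without running them; a CI job could call it with `dry_run=False`.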

- Install requirements from `requirements-dev.txt`.
- You can format the code with `black src tests`.
- You can check types with `MYPYPATH=src mypy src tests --explicit-package-bases --check-untyped-defs`. It helps to run `mypy --install-types` beforehand.
- You can run the tests with `pytest`.
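The three development commands above can be run in one go with a helper like the following. It is a sketch, not an existing script: `run_dev_checks` and `DEV_CHECKS` are hypothetical names, and the `MYPYPATH=src` prefix from the shell command is reproduced by adding it to a copy of the environment for the mypy step only.

```python
from __future__ import annotations

import os
import subprocess

# (argv, extra environment variables) for each development command listed above.
DEV_CHECKS: list[tuple[list[str], dict[str, str]]] = [
    (["black", "src", "tests"], {}),
    (
        ["mypy", "src", "tests", "--explicit-package-bases", "--check-untyped-defs"],
        {"MYPYPATH": "src"},
    ),
    (["pytest"], {}),
]


def run_dev_checks(dry_run: bool = False) -> int:
    """Run every check in order and return how many of them failed."""
    failures = 0
    for cmd, extra_env in DEV_CHECKS:
        # Copy the current environment and overlay the extra variables (e.g. MYPYPATH).
        env = {**os.environ, **extra_env}
        print("+", " ".join(cmd))
        if dry_run:
            continue
        # Raises FileNotFoundError if the tool is not installed (see requirements-dev.txt).
        if subprocess.run(cmd, env=env).returncode != 0:
            failures += 1
    return failures
```

Unlike a `check=True` chain, this runs all three commands even if one fails, so a single invocation reports every problem. Note that `black src tests` rewrites files in place; `black --check src tests` only reports what would change.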