- Install PyTorch from http://pytorch.org
- Run the following command to install the additional dependencies:
pip install -r requirements.txt
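To confirm the installation worked before training, you can run a quick check like the one below. It is a minimal sketch. The module names it checks (torch, torchvision) are an assumption about what requirements.txt contains, so adjust the list to match that file.

```python
import importlib.util
from importlib import metadata


def check_packages(names):
    """Return {module name: version string, "installed", or None if missing}."""
    report = {}
    for name in names:
        if importlib.util.find_spec(name) is None:
            report[name] = None  # not importable in this environment
            continue
        try:
            # Works when the distribution name matches the module name (true for torch).
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Importable, but no distribution metadata under this name.
            report[name] = "installed"
    return report


if __name__ == "__main__":
    # Assumed package list; edit it to match requirements.txt.
    report = check_packages(["torch", "torchvision"])
    print(report)
    if report["torch"] is not None:
        import torch

        print("CUDA available:", torch.cuda.is_available())
```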
We will be using a dataset containing 250 different classes of sketches adapted from the classifysketch dataset. Download the training/validation/test images from here. The test image labels are not provided.
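After downloading, you may want to check that the training and validation splits contain all 250 classes. The sketch below assumes that each split is stored as one subfolder per class (the layout expected by torchvision's ImageFolder) and that the images use .png or .jpg extensions. Both are assumptions about the download, and the path in the usage line is hypothetical. The check does not apply to the test images, because their labels are not provided.

```python
from pathlib import Path

# Assumed image extensions; extend this set if the dataset uses others.
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg"}


def count_images_per_class(split_dir):
    """Count images in each class folder of split_dir/<class_name>/<image file>."""
    counts = {}
    for class_dir in sorted(Path(split_dir).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(
                1 for f in class_dir.iterdir() if f.suffix.lower() in IMAGE_EXTENSIONS
            )
    return counts


def check_split(split_dir, expected_classes=250):
    """Raise ValueError if class folders are missing or empty; else return the counts."""
    counts = count_images_per_class(split_dir)
    if len(counts) != expected_classes:
        raise ValueError(
            f"{split_dir}: found {len(counts)} class folders, expected {expected_classes}"
        )
    empty = [name for name, n in counts.items() if n == 0]
    if empty:
        raise ValueError(f"{split_dir}: classes with no images: {empty[:5]}")
    return counts
```

For example, `check_split("data/train_images")` (hypothetical path) returns the per-class image counts, or raises an error if the split looks incomplete.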
Run the script main.py to train your model.
Modify main.py, model.py and data.py for your assignment, with the aim of improving the validation score.
- By default, the images are loaded, resized to 64x64 pixels and normalized to zero mean and a standard deviation of 1. See data_transforms in data.py.
- When changing models, you should also add support for your model in the ModelFactory class in model_factory.py. This way, you do not have to modify the evaluation script after the model has finished training.
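One way to structure the factory is a registry that maps each name to a builder function. With this design, adding a model only requires a change in one place. This is a hypothetical illustration, not the actual contents of model_factory.py. The names register_model, MODEL_BUILDERS and build_basic_cnn are invented. The builder returns identity-function placeholders where a real one would return your nn.Module and the matching data transform.

```python
# Hypothetical registry mapping a model name (e.g. the --model_name value)
# to a zero-argument builder that returns (model, transform).
MODEL_BUILDERS = {}


def register_model(name):
    """Decorator that registers a builder under `name`."""
    def decorator(builder):
        if name in MODEL_BUILDERS:
            raise ValueError(f"model {name!r} is already registered")
        MODEL_BUILDERS[name] = builder
        return builder
    return decorator


class ModelFactory:
    """Look up a model by name so that training and evaluation build it identically."""

    def __init__(self, model_name):
        if model_name not in MODEL_BUILDERS:
            raise NotImplementedError(
                f"unknown model {model_name!r}; registered: {sorted(MODEL_BUILDERS)}"
            )
        self.model_name = model_name
        self.model, self.transform = MODEL_BUILDERS[model_name]()


@register_model("basic_cnn")
def build_basic_cnn():
    # Placeholders only: a real builder would return your network and its transform.
    def model(x):
        return x

    def transform(image):
        return image

    return model, transform
```

If evaluate.py obtains the model through the factory using its --model_name argument, as the bullet above describes, then registering a new builder is enough for evaluation to rebuild the same architecture before it loads a checkpoint.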
As the model trains, checkpoints such as model_x.pth are saved to the current working directory.
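When many checkpoints accumulate in the working directory, a small helper can list them in order. This sketch assumes that the x in model_x.pth is an integer, for example an epoch number. The latest checkpoint is not necessarily the one with the best validation score, so compare validation results before choosing which one to evaluate.

```python
import re
from pathlib import Path

# Assumes checkpoint names of the form model_<integer>.pth.
CHECKPOINT_NAME = re.compile(r"model_(\d+)\.pth")


def list_checkpoints(directory="."):
    """Return checkpoint paths in `directory`, sorted by their integer index."""
    found = []
    for path in Path(directory).iterdir():
        match = CHECKPOINT_NAME.fullmatch(path.name)
        if match:
            found.append((int(match.group(1)), path))
    # Sorting on the integer avoids placing model_10.pth before model_2.pth.
    return [path for _, path in sorted(found)]


def latest_checkpoint(directory="."):
    """Return the checkpoint with the highest index, or None if there is none."""
    checkpoints = list_checkpoints(directory)
    return checkpoints[-1] if checkpoints else None
```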
You can take one of the checkpoints and run:
python evaluate.py --data [data_dir] --model [model_file] --model_name [model_name]
This generates a file kaggle.csv that you can upload to the private Kaggle competition website.
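A quick sanity check before uploading can catch an incomplete submission. This sketch makes two assumptions. First, it assumes that kaggle.csv has one header row followed by one row per test image. Second, it assumes that the test images sit directly in a single folder. Neither is confirmed by this README, so check both against evaluate.py and the downloaded data.

```python
import csv
from pathlib import Path

# Assumed image extensions for the test folder.
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg"}


def check_submission(csv_path, test_dir):
    """Compare the number of prediction rows with the number of test images."""
    with open(csv_path, newline="") as f:
        rows = [row for row in csv.reader(f) if row]  # skip blank lines
    n_predictions = max(len(rows) - 1, 0)  # assumes exactly one header row
    n_images = sum(
        1 for p in Path(test_dir).iterdir() if p.suffix.lower() in IMAGE_EXTENSIONS
    )
    if n_predictions != n_images:
        raise ValueError(f"{n_predictions} predictions for {n_images} test images")
    return n_predictions
```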
We recommend using an online logger such as Weights and Biases to track your experiments. It lets you visualize and compare every experiment you run. This is particularly handy on Google Colab, where you can easily lose track of your experiments when your session ends.
Note that the code does not currently support such a logger, but setting one up should be fairly straightforward.
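A minimal sketch of such a logger follows. It uses the real wandb.init, wandb.log and run.finish calls when use_wandb=True. Otherwise, it appends records to a local JSON-lines file, which lets you try it without an account. The class name, default project name and file name are hypothetical.

```python
import json
from pathlib import Path


class ExperimentLogger:
    """Send metrics to Weights & Biases, or append them to a local JSON-lines file."""

    def __init__(self, config, use_wandb=False, project="sketch-classification",
                 log_path="metrics.jsonl"):
        self.use_wandb = use_wandb
        if use_wandb:
            import wandb  # needs `pip install wandb` and a prior `wandb login`

            self._wandb = wandb
            self._run = wandb.init(project=project, config=config)
        else:
            self._path = Path(log_path)
            self._write({"config": config})  # first record: run hyperparameters

    def _write(self, record):
        with self._path.open("a") as f:
            f.write(json.dumps(record) + "\n")

    def log(self, metrics, step):
        """Record a dict of scalar metrics for one step (e.g. one epoch)."""
        if self.use_wandb:
            self._wandb.log(metrics, step=step)
        else:
            self._write({"step": step, **metrics})

    def finish(self):
        """Close the W&B run; the local file needs no cleanup."""
        if self.use_wandb:
            self._run.finish()
```

In main.py, you would create the logger once with the run's hyperparameters, call logger.log({"val_accuracy": ...}, step=epoch) after each validation pass (the metric names are up to you), and call logger.finish() at the end. On Colab, the local file is lost when the session ends unless it is written to a mounted Google Drive, which is one reason to prefer the online option.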
Adapted from Rob Fergus and Soumith Chintala https://github.com/soumith/traffic-sign-detection-homework.
Original adaptation done by Gul Varol: https://github.com/gulvarol
New Sketch dataset and code adaptation done by Ricardo Garcia and Charles Raude: https://github.com/rjgpinel, http://imagine.enpc.fr/~raudec/