Implementation of a multi-class classifier using the Animals-10 dataset from Kaggle
This repository contains the offline training of the model with PyTorch, an API that serves inference with the trained model, and a basic front-end interface that clients can interact with.
Go to https://whichisit.netlify.app/ to make predictions on new animal images.
categories = ['sheep', 'cat', 'cow', 'butterfly', 'dog', 'squirrel', 'chicken', 'spider', 'elephant', 'horse']
We tested two models, a simple CNN and a ResNet-18; the best results so far were obtained with ResNet-18.
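Both models end in a 10-way output, one score per category listed above. As a minimal sketch of how the class names can be turned into integer targets and back, assuming the index order simply follows the list above (the actual training scripts may order the classes differently, for example alphabetically by folder name):

```python
# Hypothetical label encoding for the ten Animals-10 classes listed above.
# Assumption: indices follow the order of this list; the real scripts may differ.
CATEGORIES = ['sheep', 'cat', 'cow', 'butterfly', 'dog',
              'squirrel', 'chicken', 'spider', 'elephant', 'horse']

# Class name -> integer index (the target a 10-output classifier is trained on)
LABEL_TO_INDEX = {name: idx for idx, name in enumerate(CATEGORIES)}


def encode(label: str) -> int:
    """Return the class index for an animal name, e.g. 'dog' -> 4."""
    return LABEL_TO_INDEX[label]


def decode(index: int) -> str:
    """Return the animal name for a predicted class index."""
    return CATEGORIES[index]


if __name__ == "__main__":
    print(encode("dog"), decode(8))  # prints: 4 elephant
```

Whatever order is used, it has to be identical at training and inference time, otherwise predictions are decoded to the wrong animal.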
-
Create a virtual environment and install the required modules with pip:
Create the venv:
python3 -m venv .venv
Activate it:
. .venv/bin/activate
Install the dependencies:
pip install -r requirements-dev.txt
-
Run split_data.py to split the dataset into 80% for training, 10% for validation and 10% for final testing. You can change the ratios as you wish:
python3 ./app/split_data.py --ratio 0.8 0.1 0.1
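A ratio-based split can be sketched as follows. This is not the actual split_data.py: the function name split_files and the fixed seed are illustrative assumptions. The file list is shuffled once with a seeded generator so the split is reproducible, and any rounding remainder goes to the test set so no image is dropped.

```python
import random
from typing import List, Sequence, Tuple


def split_files(files: Sequence[str],
                ratios: Tuple[float, float, float] = (0.8, 0.1, 0.1),
                seed: int = 42) -> Tuple[List[str], List[str], List[str]]:
    """Shuffle file names and split them into train/val/test lists by ratio.

    Hypothetical helper; the seed value is an arbitrary choice.
    """
    if len(ratios) != 3 or abs(sum(ratios) - 1.0) > 1e-6:
        raise ValueError("expected three ratios summing to 1")
    shuffled = list(files)
    random.Random(seed).shuffle(shuffled)  # fixed seed -> same split every run
    n_train = int(len(shuffled) * ratios[0])
    n_val = int(len(shuffled) * ratios[1])
    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    test = shuffled[n_train + n_val:]  # remainder, so every file is assigned
    return train, val, test
```

Applying such a function separately to each class folder keeps the class proportions similar across the three sets (a stratified split), which matters if some animals have many more images than others.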
-
Run process_images.py to resize and crop the raw images (so that they all have the same square size) and to generate a CSV file with each image filename and its associated label (animal type):
python3 ./app/process_images.py --size 256 --set train
python3 ./app/process_images.py --size 256 --set val
python3 ./app/process_images.py --size 256 --set test
We chose 256x256 px, but that can be changed.
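The two pieces of this step can be sketched with the standard library alone. This is an illustration, not the actual process_images.py: it assumes a centered square crop followed by a resize (the script may resize first), and the CSV column names are hypothetical. The crop box is in the (left, upper, right, lower) format that Pillow's Image.crop() takes.

```python
import csv
import io
from typing import Iterable, TextIO, Tuple


def center_square_box(width: int, height: int) -> Tuple[int, int, int, int]:
    """Return (left, upper, right, lower) of the largest centered square.

    The cropped square can then be resized to the target size, e.g. 256x256.
    """
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return left, top, left + side, top + side


def write_labels_csv(rows: Iterable[Tuple[str, str]], out: TextIO) -> None:
    """Write (filename, label) pairs, preceded by a header row, to a text stream."""
    writer = csv.writer(out)
    writer.writerow(["filename", "label"])  # hypothetical column names
    writer.writerows(rows)


if __name__ == "__main__":
    print(center_square_box(300, 200))  # (50, 0, 250, 200)
    buf = io.StringIO()
    write_labels_csv([("dog_001.jpg", "dog")], buf)
    print(buf.getvalue())
```

Cropping to a centered square before resizing avoids distorting the animal's aspect ratio, at the cost of discarding the image borders on the longer side.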
-
Run the training script:
python3 ./app/train_model.py --model <model>
where <model> is either cnn (a simple CNN) or resnet18 (loaded from torchvision).
It is recommended to run the script in the background and redirect its printed output to a log file, like this:
python3 ./app/train_model.py --model <model> > ./logs/<log_file> &
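The --model option can be sketched with argparse, restricting it to the two architectures named above. This is an assumed shape of the command-line interface, not the script's actual code; parse_args and log are hypothetical names.

```python
import argparse
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the --model option described for train_model.py (sketch only)."""
    parser = argparse.ArgumentParser(description="Train an Animals-10 classifier")
    parser.add_argument("--model", choices=["cnn", "resnet18"], required=True,
                        help="architecture to train")
    return parser.parse_args(argv)


def log(message: str) -> None:
    """Print a progress line and flush it immediately."""
    # When stdout is redirected to a file, Python buffers it in blocks, so plain
    # print() output can show up in the log file long after it was produced.
    # flush=True writes each line as soon as it is printed.
    print(message, flush=True)
```

An unknown value such as --model vgg16 makes argparse exit with an error message listing the valid choices. If the script's prints are not flushed, running it with python3 -u (unbuffered output) has the same effect when tailing the log file.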
-
Test model
Run test_model.py to evaluate the model on the train, val or test set:
python3 ./app/test_model.py --model <model> --set <set>
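The exact metrics that test_model.py reports are not documented here; as an assumed example, overall accuracy and per-class accuracy can be computed from true and predicted class indices like this:

```python
from collections import Counter
from typing import Dict, Sequence


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Fraction of predictions that match the true label."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("need two non-empty sequences of equal length")
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


def per_class_accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> Dict[int, float]:
    """Accuracy restricted to the samples of each true class (per-class recall)."""
    totals = Counter(y_true)
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return {cls: hits[cls] / n for cls, n in totals.items()}
```

Comparing the train and val figures helps spot overfitting; the per-class view shows which animals the model confuses most.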
-
Inference
Make predictions on random images from the test set using the infer.py script:
python3 ./app/infer.py --model resnet18 --samples <num_samples>
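A classifier like these outputs one raw score (logit) per class; turning those scores into a predicted animal and a confidence is usually done with a softmax followed by an argmax. A minimal pure-Python sketch, assuming the class order of the category list at the top of this README (infer.py itself presumably does this with PyTorch tensors):

```python
import math
from typing import List, Sequence, Tuple

# Assumption: same class order as the category list in this README
CATEGORIES = ['sheep', 'cat', 'cow', 'butterfly', 'dog',
              'squirrel', 'chicken', 'spider', 'elephant', 'horse']


def softmax(logits: Sequence[float]) -> List[float]:
    """Convert raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max so math.exp cannot overflow
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_prediction(logits: Sequence[float]) -> Tuple[str, float]:
    """Return the most likely animal and its probability."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return CATEGORIES[best], probs[best]
```

For example, logits that are 0 everywhere except 5.0 at index 4 give ("dog", about 0.94).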
Based on the trained model, we implemented a REST API with Flask that exposes an endpoint for making predictions on a new image of an animal. The API is deployed on Heroku, and this endpoint only accepts POST requests with an image file in the body.
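A client can send such a request with the standard library by building a multipart/form-data body by hand. This is a sketch under assumptions: the URL is the one the frontend's .env points to, the form field name "file" is a guess (it must match whatever the Flask endpoint reads), and the response format is not documented here.

```python
import mimetypes
import urllib.request
import uuid

API_URL = "https://animal-classifier01.herokuapp.com/infer"  # from the frontend .env


def build_infer_request(image_bytes: bytes, filename: str,
                        field: str = "file") -> urllib.request.Request:
    """Build a multipart/form-data POST carrying one image file.

    Assumption: the API reads the upload from a form field named `field`.
    """
    boundary = uuid.uuid4().hex
    content_type = mimetypes.guess_type(filename)[0] or "application/octet-stream"
    head = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="{field}"; filename="{filename}"\r\n'
        f"Content-Type: {content_type}\r\n\r\n"
    ).encode()
    tail = f"\r\n--{boundary}--\r\n".encode()
    return urllib.request.Request(
        API_URL,
        data=head + image_bytes + tail,
        method="POST",
        headers={"Content-Type": f"multipart/form-data; boundary={boundary}"},
    )

# Usage (sends a real request):
# with open("cat.jpg", "rb") as f:
#     req = build_infer_request(f.read(), "cat.jpg")
# with urllib.request.urlopen(req) as resp:
#     print(resp.read().decode())
```

The multipart encoding is what a browser form or a JavaScript FormData object produces, which is why a Flask endpoint can read the upload from request.files.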
You can deploy the app either with the Heroku CLI or from the Heroku dashboard. With the CLI:
-
Log in to Heroku
heroku login -i
-
Create the Heroku app
heroku create animal-classifier01
-
Test the app locally
heroku local
-
Associate the Heroku app with the git repository
heroku git:remote -a animal-classifier01
-
Push to Heroku
git push heroku main
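On Heroku the web process does not choose its own port: the platform assigns one at runtime through the PORT environment variable, and heroku local sets it too. A minimal sketch of reading it, where server_port is a hypothetical helper and the 5000 fallback for plain local runs is an arbitrary assumption:

```python
import os
from typing import Mapping, Optional


def server_port(env: Optional[Mapping[str, str]] = None, default: int = 5000) -> int:
    """Return the port the web process should listen on.

    Uses PORT when it is set (as on Heroku), else the given default.
    """
    env = os.environ if env is None else env
    return int(env.get("PORT", default))
```

A Flask app started directly could then call app.run(host="0.0.0.0", port=server_port()); if the app is instead served by a WSGI server declared in a Procfile, that server has to bind to $PORT in the same way.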
So that users can make predictions on new images, we implemented a basic React app as the project's frontend.
-
Create the React app
With node and npm installed:
npx create-react-app whichisit
-
Start a development server:
npm start
-
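Set an environment variable with the API endpoint
i. Create a .env file with:
REACT_APP_API = "https://animal-classifier01.herokuapp.com/infer"
Create React App only exposes variables prefixed with REACT_APP_ and reads .env when the development server starts, so restart npm start after editing it.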
-
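Implement a fetch POST request to the endpoint, where image_data holds the image to classify (for example, a FormData object with the selected file appended):
fetch(`${process.env.REACT_APP_API}`, { method: "POST", body: image_data })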
-
When you are ready to deploy to production, create a build with npm:
npm run build
-
Deploy using Netlify
We found it easier to deploy a React app on Netlify than on Heroku.
i. In the Netlify dashboard, change the following build settings (note that the React root directory is a subfolder of our repository):
Base directory: client
Build command: npm run build
Publish directory: client/build