This section of the project codebase is a reworking of Renovamen/Text-Classification.

See the original code base for more dataset implementations.


Text Classification

PyTorch re-implementation of some text classification models.

Supported Models

Train any of the following models by setting the model_name item in your config file (here are some example config files). Click each model's link for details.

Requirements

First, make sure your environment has:

  • Python >= 3.5

Then install requirements:

pip install -r requirements.txt

N.B. We have included the specific requirements of the original repository here, so it is best to install them in a separate environment for this section.

Dataset

Currently, the following dataset is supported:

  • Incident report binary severity prediction

Set the path to the created binary severity classification dataset (dataset_path) in your config file.

We also include the code and configs for the following tasks:

  • AG News (click here for details of this dataset)

Download and unzip it first, then set its path (dataset_path) in your config file. To use other datasets, store them in the same format as the datasets above.
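As a rough illustration of what "the same format" can mean, the sketch below parses rows in the commonly distributed AG News CSV layout (a 1-based class index, a title, and a description). The function name read_ag_news_rows is hypothetical and not part of this codebase, the sample rows are made up, and the incident report dataset may use a different layout.

```python
import csv
import io
from typing import Iterable, List, Tuple


def read_ag_news_rows(lines: Iterable[str]) -> List[Tuple[int, str]]:
    """Hypothetical helper: parse AG News-style CSV rows.

    Each row is assumed to be: class index (1-based), title, description.
    Returns (label, text) pairs with 0-based labels and title + description joined.
    """
    examples = []
    for row in csv.reader(lines):
        if not row:
            continue  # csv.reader yields [] for blank lines
        label = int(row[0]) - 1  # shift 1..N to 0..N-1
        text = " ".join(field.strip() for field in row[1:])
        examples.append((label, text))
    return examples


# Made-up sample rows in the assumed layout.
sample = io.StringIO(
    '"2","Local Team Wins","The final score was close."\n'
    '"4","New Chip Announced","The processor is faster than before."\n'
)
print(read_ag_news_rows(sample))
```

In practice the lines would come from an opened CSV file (open(path, newline="", encoding="utf-8")) rather than a StringIO.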

Pre-trained Word Embeddings

If you would like to use pre-trained word embeddings (like GloVe), set emb_pretrain to True and specify the path to the pre-trained vectors (emb_folder and emb_filename) in your config file. You can also choose whether to fine-tune the word embeddings by editing the fine_tune_embeddings item.

Alternatively, to initialize the embedding layer's weights randomly, set emb_pretrain to False and specify the embedding size (embed_size).
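The two options above can be sketched as a single function that builds an embedding matrix for a vocabulary: with emb_pretrain it reads GloVe's text format (a word followed by its space-separated vector components), otherwise it draws random vectors of size embed_size. The function name build_embedding_matrix, the uniform initialization range, and the random fallback for words missing from the GloVe file are assumptions for illustration, not necessarily what this codebase does.

```python
import random
from typing import Dict, Iterable, List, Optional


def build_embedding_matrix(
    word_map: Dict[str, int],
    emb_pretrain: bool,
    emb_lines: Optional[Iterable[str]] = None,
    embed_size: int = 100,
    seed: int = 0,
) -> List[List[float]]:
    """Hypothetical helper: one row per word index in word_map."""
    rng = random.Random(seed)
    vectors = {}
    if emb_pretrain:
        # GloVe text format: "word v1 v2 ... vD" on each line.
        for line in emb_lines or []:
            parts = line.rstrip().split(" ")
            if len(parts) < 2:
                continue  # skip empty or malformed lines
            vectors[parts[0]] = [float(x) for x in parts[1:]]
        if not vectors:
            raise ValueError("no vectors found in the embedding file")
        # With pre-trained vectors, the dimension comes from the file.
        embed_size = len(next(iter(vectors.values())))

    # Assumed init range for random rows: uniform in [-sqrt(3/D), sqrt(3/D)].
    bound = (3.0 / embed_size) ** 0.5
    matrix = [[0.0] * embed_size for _ in range(len(word_map))]
    for word, idx in word_map.items():
        if word in vectors:
            matrix[idx] = vectors[word]
        else:
            # Words absent from GloVe (or all words, without emb_pretrain).
            matrix[idx] = [rng.uniform(-bound, bound) for _ in range(embed_size)]
    return matrix
```

With pre-trained vectors, the lines would typically come from open(os.path.join(emb_folder, emb_filename), encoding="utf-8"). In PyTorch, one common way to honour fine_tune_embeddings is torch.nn.Embedding.from_pretrained(torch.tensor(matrix), freeze=not fine_tune_embeddings).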

Preprocess

First, preprocess the data and store it locally (where configs/example.yaml is the path to your config file):

python preprocess.py --config configs/example.yaml
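Preprocessing for models like these usually includes building a word map (vocabulary) and converting tokens to indices. The sketch below shows that idea under stated assumptions: the helper names build_word_map and encode, the min_word_count threshold, and the reserved <pad> and <unk> tokens are hypothetical, and the input is assumed to be already tokenized.

```python
from collections import Counter
from typing import Dict, List


def build_word_map(docs: List[List[str]], min_word_count: int = 1) -> Dict[str, int]:
    """Hypothetical helper: map each frequent-enough word to an integer id."""
    counts = Counter(word for doc in docs for word in doc)
    word_map = {"<pad>": 0, "<unk>": 1}  # assumed reserved ids
    for word, count in sorted(counts.items()):  # sorted for deterministic ids
        if count >= min_word_count:
            word_map[word] = len(word_map)
    return word_map


def encode(doc: List[str], word_map: Dict[str, int]) -> List[int]:
    """Replace each word by its id; unknown or rare words become <unk>."""
    unk = word_map["<unk>"]
    return [word_map.get(word, unk) for word in doc]


docs = [["the", "engine", "failed"], ["the", "report", "was", "minor"]]
word_map = build_word_map(docs, min_word_count=2)
print(encode(["the", "engine", "failed"], word_map))  # [2, 1, 1]
```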

Then, during training, the data is loaded dynamically using PyTorch's DataLoader (see datasets/dataloader.py). Loading this way may take a little time, but it lets training occupy less memory (which allows a larger batch size) and take less time overall.
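A minimal sketch of dynamic loading, assuming preprocessed examples are stored one per line as JSON objects with hypothetical "ids" and "label" fields (the actual storage format in datasets/dataloader.py may differ). Only the byte offset of each line is kept in memory, and an example is parsed from disk when it is requested. Because the class defines __len__ and __getitem__, it is a map-style dataset that torch.utils.data.DataLoader can batch.

```python
import json
from typing import List, Tuple


class LazyJsonLinesDataset:
    """Hypothetical map-style dataset reading one example per __getitem__ call."""

    def __init__(self, path: str) -> None:
        self.path = path
        self.offsets = []  # byte offset of each non-empty line
        with open(path, "rb") as f:
            offset = 0
            for line in f:
                if line.strip():
                    self.offsets.append(offset)
                offset += len(line)

    def __len__(self) -> int:
        return len(self.offsets)

    def __getitem__(self, idx: int) -> Tuple[List[int], int]:
        # Seek straight to the requested example instead of keeping all data in memory.
        with open(self.path, "rb") as f:
            f.seek(self.offsets[idx])
            record = json.loads(f.readline().decode("utf-8"))
        return record["ids"], record["label"]
```

Reopening the file on every call keeps the sketch simple and safe with multiple DataLoader workers, at the cost of some I/O overhead per example.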

Train

To train a model, just run:

python train.py --config configs/example.yaml

If you have enabled TensorBoard (tensorboard: True in your config file), you can visualize the losses and accuracies during training with:

tensorboard --logdir=<your_log_dir>

Test

Test a checkpoint and compute its accuracy on the test set:

python test.py --config configs/example.yaml
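Test-set accuracy is the fraction of examples whose predicted class equals the true label. Since evaluation runs batch by batch, a running counter is a natural way to accumulate it; the AccuracyMeter class below is a hypothetical sketch of that, not the code in test.py.

```python
from typing import Sequence


class AccuracyMeter:
    """Hypothetical helper: accumulates correct predictions over batches."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, predictions: Sequence[int], labels: Sequence[int]) -> None:
        if len(predictions) != len(labels):
            raise ValueError("predictions and labels must have the same length")
        self.correct += sum(1 for p, y in zip(predictions, labels) if p == y)
        self.total += len(labels)

    @property
    def value(self) -> float:
        # Overall accuracy so far; 0.0 before any batch has been seen.
        return self.correct / self.total if self.total else 0.0


meter = AccuracyMeter()
meter.update([1, 0, 1], [1, 1, 1])  # 2 of 3 correct
meter.update([0], [0])              # 1 of 1 correct
print(meter.value)                  # 0.75
```

With a PyTorch model, the per-batch predictions would typically be scores.argmax(dim=1).tolist().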

Classify

To predict the category for a specific sentence:

First edit the following items in classify.py:

checkpoint_path = 'str: path_to_your_checkpoint'

# pad limits
# only makes sense when model_name == 'han'
sentence_limit_per_doc = 15
word_limit_per_sentence = 20
# only makes sense when model_name != 'han'
word_limit = 200

Then, run:

python classify.py
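The pad limits in classify.py control how input is cut or padded to a fixed shape: a flat sequence of word_limit ids for most models, or sentence_limit_per_doc sentences of word_limit_per_sentence ids each for HAN. The sketch below assumes right-padding with a pad id of 0 and also returns the true lengths, which hierarchical models commonly use to skip padding; the helper names pad_words and pad_doc are hypothetical.

```python
from typing import List, Tuple


def pad_words(ids: List[int], word_limit: int, pad_id: int = 0) -> Tuple[List[int], int]:
    """Truncate or right-pad one sentence; return (padded ids, true length)."""
    ids = ids[:word_limit]
    length = len(ids)
    return ids + [pad_id] * (word_limit - length), length


def pad_doc(
    sentences: List[List[int]],
    sentence_limit_per_doc: int,
    word_limit_per_sentence: int,
    pad_id: int = 0,
) -> Tuple[List[List[int]], int, List[int]]:
    """HAN-style padding: returns (doc, sentences_per_doc, words_per_sentence)."""
    doc, words_per_sentence = [], []
    for sentence in sentences[:sentence_limit_per_doc]:
        padded, length = pad_words(sentence, word_limit_per_sentence, pad_id)
        doc.append(padded)
        words_per_sentence.append(length)
    sentences_per_doc = len(doc)
    # Fill missing sentences with all-pad rows (built fresh, so rows are not aliased).
    while len(doc) < sentence_limit_per_doc:
        doc.append([pad_id] * word_limit_per_sentence)
        words_per_sentence.append(0)
    return doc, sentences_per_doc, words_per_sentence


print(pad_words([5, 6, 7], word_limit=5))  # ([5, 6, 7, 0, 0], 3)
```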

Acknowledgement

This codebase is based on Renovamen/Text-Classification and sgrvinod/a-PyTorch-Tutorial-to-Text-Classification.