Commit

Allennlp (#10)
* switch to allennlp

* update readme
matthew-z authored Dec 12, 2018
1 parent c09bf93 commit 8c931cb
Showing 37 changed files with 1,389 additions and 1,767 deletions.
58 changes: 40 additions & 18 deletions README.md
@@ -1,40 +1,62 @@
**Work In Progress.**


An unofficial implementation of R-net in [PyTorch](https://github.com/pytorch/pytorch) and [AllenNLP](https://github.com/allenai/allennlp).

[Natural Language Computing Group, MSRA: R-NET: Machine Reading Comprehension with Self-matching Networks](https://www.microsoft.com/en-us/research/publication/mrc/)

However, I failed to reproduce the paper's results with the model as described, because some details are not clear to me and the dynamic attention in self-matching requires too much memory.

Thus, I implemented the R-Net variant from [HKUST-KnowComp/R-Net](https://github.com/HKUST-KnowComp/R-Net) (in Tensorflow).

Some notes about [HKUST-KnowComp/R-Net](https://github.com/HKUST-KnowComp/R-Net) (the model is in `configs/squad/r-net/hkust.jsonnet`):
* The question and passage share the same GRU sentence encoder instead of using separate encoders.
* The sentence encoders have three layers, and the output is the concatenation of all three layers' outputs instead of the top layer's output alone.
* Attention in the pair encoder and the self-matching encoder is calculated once before the RNN (static attention) instead of at each RNN step (dynamic attention); see the sketch below.
* The GRUs in the pair encoder and the self-matching encoder have only one layer instead of three.
* Variational dropout is applied to (1) the inputs of the RNNs and (2) the inputs of the attentions.
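
To make the static-attention bullet concrete, here is a minimal PyTorch sketch (the class and parameter names are illustrative, not the repo's actual modules): the attended context for every position is computed in one batched matmul, and only then is a single GRU run over the concatenated sequence, which is what keeps memory manageable.

```
import torch
import torch.nn as nn
import torch.nn.functional as F

class StaticDotAttentionEncoder(nn.Module):
    """Illustrative stand-in for the static pair/self-matching encoders."""

    def __init__(self, memory_size, input_size, hidden_size, attention_size):
        super().__init__()
        self.project_input = nn.Linear(input_size, attention_size, bias=False)
        self.project_memory = nn.Linear(memory_size, attention_size, bias=False)
        self.gru = nn.GRU(input_size + memory_size, hidden_size,
                          batch_first=True, bidirectional=True)

    def forward(self, inputs, memory):
        # inputs: (batch, seq_len, input_size); memory: (batch, mem_len, memory_size)
        scores = torch.bmm(self.project_input(inputs),
                           self.project_memory(memory).transpose(1, 2))
        weights = F.softmax(scores, dim=-1)       # attention over memory positions
        context = torch.bmm(weights, memory)      # (batch, seq_len, memory_size)
        # Attention is finished *before* the RNN ("static"), so one fused GRU
        # call processes the whole sequence instead of attending at every step.
        outputs, _ = self.gru(torch.cat([inputs, context], dim=-1))
        return outputs
```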


### Dependencies

* Python == 3.6
* [AllenNLP](https://github.com/allenai/allennlp) == 0.7.2
* PyTorch == 1.0



### Usage

```
cd R-net
python main.py train ./configs/squad/r-net/hkust.jsonnet -o '{"iterator.batch_size": 128}'
```

### Configuration

The models and hyperparameters are declared in `configs/`:

* the HKUST-R-Net: `configs/squad/r-net/hkust.jsonnet`
* the original R-Net: `configs/squad/r-net/original.jsonnet` (currently not workable)



### Performance

The HKUST-R-Net obtains a 79.1 F1 score (70.1 EM) on the dev set.



Red: Training score

Green: Validation score

<img src="img/f1.png" width="400">
<img src="img/em.png" width="400">

### Acknowledgement

Thanks to [HKUST-KnowComp/R-Net](https://github.com/HKUST-KnowComp/R-Net) for sharing their Tensorflow implementation of R-net. This repo is based on their work.
136 changes: 136 additions & 0 deletions configs/squad/r-net/hkust.jsonnet
@@ -0,0 +1,136 @@
local embedding_size = 500;
local hidden_size = 75;
local attention_size = 75;
local num_layers = 3;
local dropout = 0.3;
local bidirectional = true;

{
dataset_reader: {
type: 'squad',
token_indexers: {
tokens: {
type: 'single_id',
lowercase_tokens: true,
},
token_characters: {
type: 'characters',
character_tokenizer: {
byte_encoding: 'utf-8',
start_tokens: [259],
end_tokens: [260],
},
// min_padding_length: 5,
},
},
},

train_data_path: 'https://s3-us-west-2.amazonaws.com/allennlp/datasets/squad/squad-train-v1.1.json',
validation_data_path: 'https://s3-us-west-2.amazonaws.com/allennlp/datasets/squad/squad-dev-v1.1.json',
model: {
type: 'r_net',
share_encoder: true,
text_field_embedder: {
token_embedders: {
tokens: {
type: 'embedding',
pretrained_file: 'https://s3-us-west-2.amazonaws.com/allennlp/datasets/glove/glove.840B.300d.txt.gz',
embedding_dim: 300,
trainable: false,
},
token_characters: {
type: 'character_encoding',
embedding: {
num_embeddings: 262,
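// Note: num_embeddings must cover every byte value (0-255) plus the
// special start/end ids 259/260 configured in the tokenizer above;
// 262 covers indices 0..261.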
embedding_dim: 8,
},
encoder: {
type: 'gru',
input_size: 8,
hidden_size: 100,
bidirectional: true,
dropout: dropout,
},
},
},
},

question_encoder: {
type: 'concat_rnn',
input_size: embedding_size,
hidden_size: hidden_size,
num_layers: num_layers,
bidirectional: bidirectional,
dropout: dropout,
},

passage_encoder: {
type: 'concat_rnn',
input_size: embedding_size,
hidden_size: hidden_size,
num_layers: num_layers,
bidirectional: bidirectional,
dropout: dropout,
},

pair_encoder: {
type: 'static_pair_encoder',
memory_size: hidden_size * 2 * num_layers,
input_size: hidden_size * 2 * num_layers,
hidden_size: hidden_size,
attention_size: attention_size,
bidirectional: bidirectional,
dropout: dropout,
batch_first: true,
},

self_encoder: {
type: 'static_self_encoder',
memory_size: hidden_size * 2,
input_size: hidden_size * 2,
hidden_size: hidden_size,
attention_size: attention_size,
bidirectional: bidirectional,
dropout: dropout,
batch_first: true,
},

output_layer: {
type: 'pointer_network',
question_size: hidden_size * 2 * num_layers,
passage_size: hidden_size * 2,
attention_size: attention_size,
dropout: dropout,
batch_first: true,
},
},

iterator: {
type: 'basic',
// sorting_keys: [['passage', 'num_tokens'], ['question', 'num_tokens']],
batch_size: 128,
// padding_noise: 0.2,
},

trainer: {
num_epochs: 120,
num_serialized_models_to_keep: 5,
grad_norm: 5.0,
patience: 10,
validation_metric: '+f1',
cuda_device: [0],
learning_rate_scheduler: {
type: 'reduce_on_plateau',
factor: 0.5,
mode: 'max',
patience: 3,
},
optimizer: {
type: 'adadelta',
lr: 0.5,
rho: 0.95,
},
},
}
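
The `local` constants chain together: the 300-d frozen GloVe vectors concatenated with the 2 × 100-d bidirectional character GRU give the 500-d `embedding_size`, and `concat_rnn` concatenates all three bidirectional layers, which produces the `hidden_size * 2 * num_layers` sizes used by the pair encoder and pointer network. A quick sanity check of the arithmetic in plain Python (not repo code):

```
glove_dim = 300                       # frozen GloVe 840B vectors
char_dim = 2 * 100                    # bidirectional char GRU: forward + backward
assert glove_dim + char_dim == 500    # matches `local embedding_size = 500`

hidden_size, num_layers = 75, 3
# concat_rnn concatenates the outputs of all three bidirectional layers:
assert hidden_size * 2 * num_layers == 450  # pair_encoder memory_size/input_size
assert hidden_size * 2 == 150               # self_encoder memory_size (one bi-layer)
```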
133 changes: 133 additions & 0 deletions configs/squad/r-net/original.jsonnet
@@ -0,0 +1,133 @@
local embedding_size = 500;
local hidden_size = 75;
local attention_size = 75;
local num_layers = 3;
local dropout = 0.3;
local bidirectional = true;

{
dataset_reader: {
type: 'squad',
token_indexers: {
tokens: {
type: 'single_id',
lowercase_tokens: true,
},
token_characters: {
type: 'characters',
character_tokenizer: {
byte_encoding: 'utf-8',
start_tokens: [259],
end_tokens: [260],
},
// min_padding_length: 5,
},
},
},

train_data_path: 'https://s3-us-west-2.amazonaws.com/allennlp/datasets/squad/squad-train-v1.1.json',
validation_data_path: 'https://s3-us-west-2.amazonaws.com/allennlp/datasets/squad/squad-dev-v1.1.json',
model: {
type: 'r_net',
share_encoder: false,
text_field_embedder: {
token_embedders: {
tokens: {
type: 'embedding',
pretrained_file: 'https://s3-us-west-2.amazonaws.com/allennlp/datasets/glove/glove.840B.300d.txt.gz',
embedding_dim: 300,
trainable: false,
},
token_characters: {
type: 'character_encoding',
embedding: {
num_embeddings: 262,
embedding_dim: 8,
},
encoder: {
type: 'gru',
input_size: 8,
hidden_size: 100,
bidirectional: true,
dropout: dropout,
},
},
},
},

question_encoder: {
type: 'gru',
input_size: embedding_size,
hidden_size: hidden_size,
num_layers: num_layers,
bidirectional: bidirectional,
dropout: dropout,
},

passage_encoder: {
type: 'gru',
input_size: embedding_size,
hidden_size: hidden_size,
num_layers: num_layers,
bidirectional: bidirectional,
dropout: dropout,
},

pair_encoder: {
type: 'dynamic_pair_encoder',
memory_size: hidden_size * 2 * num_layers,
input_size: hidden_size * 2 * num_layers,
hidden_size: hidden_size,
attention_size: attention_size,
bidirectional: bidirectional,
dropout: dropout,
batch_first: true,
},

self_encoder: {
type: 'dynamic_self_encoder',
memory_size: hidden_size * 2,
input_size: hidden_size * 2,
hidden_size: hidden_size,
attention_size: attention_size,
bidirectional: bidirectional,
dropout: dropout,
batch_first: true,
},

output_layer: {
type: 'pointer_network',
question_size: hidden_size * 2 * num_layers,
passage_size: hidden_size * 2,
attention_size: attention_size,
dropout: dropout,
batch_first: true,
},
},

iterator: {
type: 'basic',
batch_size: 128,
},

trainer: {
num_epochs: 120,
num_serialized_models_to_keep: 5,
grad_norm: 5.0,
patience: 10,
validation_metric: '+f1',
cuda_device: [0],
learning_rate_scheduler: {
type: 'reduce_on_plateau',
factor: 0.5,
mode: 'max',
patience: 3,
},
optimizer: {
type: 'adadelta',
lr: 0.5,
rho: 0.95,
},
},
}
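
Compared with `hkust.jsonnet`, the substantive differences are `share_encoder: false`, plain `gru` sentence encoders, and the `dynamic_*` encoder types. "Dynamic" means attention is recomputed inside every RNN step, conditioned on the previous hidden state, as in the original paper. A minimal sketch of the idea (illustrative module, not the repo's implementation):

```
import torch
import torch.nn as nn
import torch.nn.functional as F

class DynamicAttentionGRU(nn.Module):
    """Illustrative sketch of dynamic (per-step) additive attention."""

    def __init__(self, memory_size, input_size, hidden_size, attention_size):
        super().__init__()
        self.cell = nn.GRUCell(input_size + memory_size, hidden_size)
        self.w_input = nn.Linear(input_size, attention_size, bias=False)
        self.w_memory = nn.Linear(memory_size, attention_size, bias=False)
        self.w_state = nn.Linear(hidden_size, attention_size, bias=False)
        self.v = nn.Linear(attention_size, 1, bias=False)

    def forward(self, inputs, memory):
        # inputs: (batch, seq_len, input_size); memory: (batch, mem_len, memory_size)
        batch, seq_len, _ = inputs.shape
        state = inputs.new_zeros(batch, self.cell.hidden_size)
        proj_memory = self.w_memory(memory)  # reusable part of additive attention
        outputs = []
        for t in range(seq_len):
            x = inputs[:, t]
            # Attention conditioned on the current input AND the hidden state,
            # so it must be recomputed at every step.
            scores = self.v(torch.tanh(
                proj_memory + (self.w_input(x) + self.w_state(state)).unsqueeze(1)
            )).squeeze(-1)                                   # (batch, mem_len)
            context = torch.bmm(F.softmax(scores, dim=-1).unsqueeze(1),
                                memory).squeeze(1)           # (batch, memory_size)
            state = self.cell(torch.cat([x, context], dim=-1), state)
            outputs.append(state)
        return torch.stack(outputs, dim=1)
```

Because each step depends on the previous hidden state, the loop cannot be fused into a single cuDNN GRU call, which is why this variant is slower and far more memory-hungry than the static encoders.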
Binary file added img/em.png
Binary file added img/f1.png
Binary file removed img/tensorboard.png