Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
xwhan committed Jun 17, 2019
1 parent eae3827 commit 531ac98
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 40 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
*.json
datasets/webqsp/
*.tar.gz
*.pt
__pycache__/
tf_logs/*
21 changes: 6 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,30 +15,21 @@ Model Overview:

### Prepare data
```
mkdir datasets && cd datasets && wget http://nlp.cs.ucsb.edu/data/webqsp.tar.gz && tar -xzvf webqsp.tar.gz
mkdir datasets && cd datasets && wget http://nlp.cs.ucsb.edu/data/webqsp.tar.gz && tar -xzvf webqsp.tar.gz && cd ..
```

### Full KB setting
**Training**
```
CUDA_VISIBLE_DEVICES=0 python train.py --model_id KAReader_full_kb --num_layer 1 --max_num_neighbors 50 --label_smooth 0.1 --data_folder datasets/webqsp/full/
```
**Testing**
```
CUDA_VISIBLE_DEVICES=0 python train.py --model_id KAReader_full_kb --num_layer 1 --max_num_neighbors 50 --label_smooth 0.1 --data_folder datasets/webqsp/full/ --mode test
CUDA_VISIBLE_DEVICES=0 python train.py --model_id KAReader_full_kb --max_num_neighbors 50 --label_smooth 0.1 --data_folder datasets/webqsp/full/
```

### Incomplete KB setting (50%)
**Training**
```
CUDA_VISIBLE_DEVICES=0 python train.py --model_id KAReader_kb_05 --num_layer 1 --max_num_neighbors 100 --use_doc --label_smooth 0.1 --data_folder datasets/webqsp/kb_05/
```
**Testing**
### Incomplete KB setting
#### 50% KB
```
CUDA_VISIBLE_DEVICES=0 python train.py --model_id KAReader_kb_05 --num_layer 1 --max_num_neighbors 100 --use_doc --label_smooth 0.1 --data_folder datasets/webqsp/kb_05/ --mode test --eps 0.12
CUDA_VISIBLE_DEVICES=0 python train.py --model_id KAReader_kb_05 --max_num_neighbors 50 --use_doc --data_folder datasets/webqsp/kb_05/ --eps 0.12
```

### Bibtex
### Citation
```
@article{xiong2019improving,
title={Improving Question Answering over Incomplete KBs with Knowledge-Aware Reader},
Expand Down
6 changes: 3 additions & 3 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ def forward(self, feed):

# prepare pagerank scores
ent_seed_info = feed['query_entities'].float() # seed entity will have 1.0 score
ent_pagerank = torch.cat([torch.zeros(1).to(torch.device('cuda')), ent_seed_info.view(-1)], dim=0)
pagerank = torch.index_select(ent_pagerank, dim=0, index=neighbor_ent_local_index).view(B*max_num_candidates, max_num_neighbors)
ent_is_seed = torch.cat([torch.zeros(1).to(torch.device('cuda')), ent_seed_info.view(-1)], dim=0)
ent_seed_indicator = torch.index_select(ent_is_seed, dim=0, index=neighbor_ent_local_index).view(B*max_num_candidates, max_num_neighbors)

# v0.0 more find-grained attention
q_emb_expand = q_emb.unsqueeze(1).expand(B, max_num_candidates, max_q_len, -1).contiguous()
Expand Down Expand Up @@ -163,7 +163,7 @@ def forward(self, feed):
neighbor_ent_emb = torch.index_select(ent_emb_for_lookup, dim=0, index=neighbor_ent_local_index)
neighbor_ent_emb = neighbor_ent_emb.view(B*max_num_candidates, max_num_neighbors, -1)
neighbor_vec = torch.cat([neighbor_rel_emb, neighbor_ent_emb], dim =-1).view(B*max_num_candidates, max_num_neighbors, -1) # for propagation
neighbor_scores = q_rel_simi * pagerank
neighbor_scores = q_rel_simi * ent_seed_indicator
neighbor_scores = neighbor_scores - (1 - neighbor_mask.view(B*max_num_candidates, max_num_neighbors)) * 1e8
attn_score = F.softmax(neighbor_scores, dim=1)
aggregate = self.kg_prop(neighbor_vec) * attn_score.unsqueeze(2)
Expand Down
23 changes: 4 additions & 19 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
import numpy as np
import random
import json

from data_generator import DataLoader
from model import KAReader
Expand Down Expand Up @@ -44,17 +43,14 @@ def get_best_ans(candidate2prob):
return best_ans

def train(cfg):
tf_logger = SummaryWriter('my_logs/' + cfg['model_id'])
tf_logger = SummaryWriter('tf_logs/' + cfg['model_id'])

# train and test share the same set of documents
documents = load_documents(cfg['data_folder'] + cfg['{}_documents'.format(cfg['mode'])])



# train data
train_data = DataLoader(cfg, documents)
valid_data = DataLoader(cfg, documents, mode='dev')
test_data = DataLoader(cfg, documents, mode='test')

model = KAReader(cfg)
model = model.to(torch.device('cuda'))
Expand Down Expand Up @@ -97,12 +93,6 @@ def train(cfg):
print('evaluation best f1:{} current:{}'.format(best_val_f1, val_f1))
print('evaluation best hits:{} current:{}'.format(best_val_hits, val_hits))

if epoch % 5 == 0:
test_f1, test_hits = test(model, test_data, cfg['eps'])
tf_logger.add_scalar('test_hits', test_hits, epoch)
tf_logger.add_scalar('test_f1', test_f1, epoch)
torch.save(model.state_dict(), 'model/{}/{}_{}.pt'.format(cfg['name'], cfg['model_id'], test_hits))

print('save final model')
torch.save(model.state_dict(), 'model/{}/{}_final.pt'.format(cfg['name'], cfg['model_id']))

Expand All @@ -111,6 +101,8 @@ def train(cfg):
model.load_state_dict(torch.load(model_save_path))
model.eval()

print('Testing....')
test_data = DataLoader(cfg, documents, mode='test')
test(model, test_data, cfg['eps'])

def test(model, test_data, eps):
Expand Down Expand Up @@ -148,18 +140,11 @@ def test(model, test_data, eps):
f1s.append(f1)
hits.append(hit)
print('evaluation.......')
print('how many samples......', len(f1s))
print('how many eval samples......', len(f1s))
print('avg_f1', np.mean(f1s))
print('avg_hits', np.mean(hits))

for q, f1, hit in zip(questions, f1s, hits):
q_to_metrics[q] = (f1, hit)
json.dump(q_to_metrics, open('q_to_metrics_kb_only.json', 'w'))

q_to_ans = {}
for q, best_ans in zip(questions, pred_answers):
q_to_ans[q] = best_ans
json.dump(q_to_ans, open('q_to_ans_kb_only.json', 'w'))

model.train()
return np.mean(f1s), np.mean(hits)
Expand Down
3 changes: 0 additions & 3 deletions util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from collections import Counter

from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence
from tqdm import tqdm

import argparse
Expand Down Expand Up @@ -61,7 +59,6 @@ def get_config(config_path=None):

# model options
parser.add_argument('--use_doc', action='store_true')
parser.add_argument('--use_kb', action='store_true')
parser.add_argument('--use_inverse_relation', action='store_true')
parser.add_argument('--model_id', default='debug', type=str)
parser.add_argument('--load_model_file', default=None, type=str)
Expand Down

0 comments on commit 531ac98

Please sign in to comment.