Skip to content

Commit

Permalink
restore best hyperparameters for 30% settings
Browse files Browse the repository at this point in the history
  • Loading branch information
xwhan committed Jun 20, 2019
1 parent 531ac98 commit 49b347e
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 17 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ CUDA_VISIBLE_DEVICES=0 python train.py --model_id KAReader_full_kb --max_num_nei
```

### Incomplete KB setting
#### 50% KB
The Hits@1 should match or be slightly better than the number reported in the paper. Further tuning of the threshold should give you a better F1 score.
#### 30% KB
```
CUDA_VISIBLE_DEVICES=0 python train.py --model_id KAReader_kb_05 --max_num_neighbors 50 --use_doc --data_folder datasets/webqsp/kb_05/ --eps 0.12
CUDA_VISIBLE_DEVICES=0 python train.py --model_id KAReader_kb_03 --max_num_neighbors 50 --use_doc --data_folder datasets/webqsp/kb_03/ --eps 0.05
```

### Citation
Expand Down
9 changes: 3 additions & 6 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ def __init__(self, args):
# question and doc encoder
self.question_encoder = Packed(nn.LSTM(self.word_dim, self.hidden_dim // 2, batch_first=True, bidirectional=True))

# for shared encoder ablation
# self.relation_encoder = Packed(nn.LSTM(self.word_dim, self.hidden_dim // 2, batch_first=True, bidirectional=True))

self.self_att_r = AttnEncoder(self.hidden_dim)
self.self_att_q = AttnEncoder(self.hidden_dim)
Expand Down Expand Up @@ -130,7 +128,6 @@ def forward(self, feed):
neighbor_ent_local_index = (neighbor_ent_local_index + 1) * neighbor_ent_local_mask
neighbor_ent_local_index = neighbor_ent_local_index.view(-1)

# prepare pagerank scores
ent_seed_info = feed['query_entities'].float() # seed entity will have 1.0 score
ent_is_seed = torch.cat([torch.zeros(1).to(torch.device('cuda')), ent_seed_info.view(-1)], dim=0)
ent_seed_indicator = torch.index_select(ent_is_seed, dim=0, index=neighbor_ent_local_index).view(B*max_num_candidates, max_num_neighbors)
Expand All @@ -141,7 +138,7 @@ def forward(self, feed):
q_mask_expand = q_mask.unsqueeze(1).expand(B, max_num_candidates, -1).contiguous()
q_mask_expand = q_mask_expand.view(B*max_num_candidates, -1)
q_n_affinity = torch.bmm(q_emb_expand, neighbor_rel_emb.transpose(1, 2)) # (bsize*max_num_candidates, q_len, max_num_neighbors)
q_n_affinity_mask_q = q_n_affinity - (1 - q_mask_expand.unsqueeze(2)) * 1e8
q_n_affinity_mask_q = q_n_affinity - (1 - q_mask_expand.unsqueeze(2)) * 1e20
q_n_affinity_mask_n = q_n_affinity - (1 - neighbor_mask.view(B*max_num_candidates, 1, max_num_neighbors))
normalize_over_q = F.softmax(q_n_affinity_mask_q, dim=1)
normalize_over_n = F.softmax(q_n_affinity_mask_n, dim=2)
Expand Down Expand Up @@ -229,8 +226,8 @@ def forward(self, feed):
# refine KB ent_emb
# refined_ent_emb = self.refine_ent(ent_emb, ent_emb_from_doc)
if self.use_doc:
ent_emb = self.attn_match(torch.cat([ent_emb, ent_emb_from_doc, ent_emb_from_span], dim=-1)).relu()
# q_node_emb = self.attn_match_q(q_node_emb).tanh()
ent_emb = l_relu(self.attn_match(torch.cat([ent_emb, ent_emb_from_doc, ent_emb_from_span], dim=-1)))
# q_node_emb = self.attn_match_q(q_node_emb)

ent_scores = (q_node_emb * ent_emb).sum(2)

Expand Down
3 changes: 0 additions & 3 deletions modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import math

from attention import SimpleEncoder

class Packed(nn.Module):

Expand All @@ -31,8 +30,6 @@ def gelu(x):

def l_relu(x, n_slope=0.01):
    """Apply a leaky ReLU activation to ``x``.

    Thin wrapper that forwards to ``F.leaky_relu``.

    Args:
        x: Input passed straight through to ``F.leaky_relu``.
        n_slope: Slope used for negative inputs (defaults to 0.01).

    Returns:
        Whatever ``F.leaky_relu`` returns for ``x`` with the given slope.
    """
    activated = F.leaky_relu(x, negative_slope=n_slope)
    return activated
# return gelu(x)
# return F.relu(x)

class ConditionGate(nn.Module):
"""docstring for ConditionGate"""
Expand Down
2 changes: 1 addition & 1 deletion run_with_doc.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=$1 python train.py --model_id $2 --num_layer 1 --max_num_neighbors 100 --use_doc --label_smooth 0.1 --data_folder datasets/webqsp/full/
CUDA_VISIBLE_DEVICES=$1 python train.py --model_id $2 --num_layer 1 --max_num_neighbors 100 --use_doc --data_folder datasets/webqsp/kb_05/ --eps 0.12
19 changes: 14 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,22 +96,31 @@ def train(cfg):
print('save final model')
torch.save(model.state_dict(), 'model/{}/{}_final.pt'.format(cfg['name'], cfg['model_id']))

print('finished training, testing...')

# model_save_path = 'model/{}/{}_best.pt'.format(cfg['name'], cfg['model_id'])
# model.load_state_dict(torch.load(model_save_path))


print('..........Finished training, start testing.......')

test_data = DataLoader(cfg, documents, mode='test')
model.eval()
print('finished training, testing final model...')
test(model, test_data, cfg['eps'])

print('testing best model...')
model_save_path = 'model/{}/{}_best.pt'.format(cfg['name'], cfg['model_id'])
model.load_state_dict(torch.load(model_save_path))
model.eval()

print('Testing....')
test_data = DataLoader(cfg, documents, mode='test')
test(model, test_data, cfg['eps'])


def test(model, test_data, eps):

model.eval()
batcher = test_data.batcher()
id2entity = test_data.id2entity
f1s, hits = [], []
q_to_metrics = {}
questions = []
pred_answers = []
for feed in batcher:
Expand Down

0 comments on commit 49b347e

Please sign in to comment.