This repository was archived by the owner on May 15, 2024. It is now read-only.

Commit 2443fd7

init submit

1 parent 006a68a commit 2443fd7

23 files changed: +2613 -2 lines changed

README.md (+112 -2)

@@ -1,2 +1,112 @@
The previous two-line README ("# DiffKS", "Difference-aware Knowledge Selection") is replaced with:
# DiffKS: Difference-aware Knowledge Selection

Code for the paper: **Difference-aware Knowledge Selection for Knowledge-grounded Conversation Generation**

Please cite this repository using the following reference:

```bib
@inproceedings{diffks-zheng-2020,
    title="{D}ifference-aware Knowledge Selection for Knowledge-grounded Conversation Generation",
    author="Zheng, Chujie and
      Cao, Yunbo and
      Jiang, Daxin and
      Huang, Minlie",
    booktitle="Findings of EMNLP",
    year="2020"
}
```
## Prepare Data

Download the [Wizard of Wikipedia](https://drive.google.com/drive/folders/1eowwYSfJKaDtYgKHZVqh8alNmqP3jv9A?usp=sharing) dataset (obtained with [ParlAI](https://github.com/facebookresearch/ParlAI); see [Sequential Latent Knowledge Selection](https://github.com/bckim92/sequential-knowledge-transformer) for download details) and put the files in the folder `./Wizard-of-Wikipedia`, or download the [Holl-E](https://drive.google.com/drive/folders/1xQBRDs5q_2xLOdOpbq7UeAmUM0Ht370A?usp=sharing) dataset and put the files in the folder `./Holl-E`.
For Wizard of Wikipedia (WoW):

```bash
python prepare_wow_data.py
```

For Holl-E:

```bash
python prepare_holl_data.py
```

In addition, download the pretrained [GloVe word vectors](https://apache-mxnet.s3.cn-north-1.amazonaws.com.cn/gluon/embeddings/glove/glove.6B.zip), unzip the files into `./`, and rename the 300-d embedding file to `glove.txt`.
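For example, a shell sequence along the following lines would do this; the name of the 300-d file inside the archive, `glove.6B.300d.txt`, follows the standard GloVe 6B naming and is an assumption here:

```bash
# Fetch and unpack the GloVe 6B embeddings into the repository root
wget https://apache-mxnet.s3.cn-north-1.amazonaws.com.cn/gluon/embeddings/glove/glove.6B.zip
unzip glove.6B.zip -d ./
# Rename the 300-d embedding file to the name expected by the code
mv glove.6B.300d.txt glove.txt
```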
## Training

Our code currently supports only single-GPU training, which requires at least 12 GB of GPU memory. The `--disentangle` flag trains the disentangled model; delete that line to train the fused model.

For Wizard of Wikipedia:

```bash
python run.py \
    --mode train \
    --dataset WizardOfWiki \
    --datapath ./Wizard-of-Wikipedia/prepared_data \
    --wvpath ./ \
    --cuda 0 \
    --droprate 0.5 \
    --disentangle \
    --hist_len 2 \
    --hist_weights 0.7 0.3 \
    --out_dir ./output \
    --model_dir ./model \
    --cache
```
For Holl-E:

```bash
python run.py \
    --mode train \
    --dataset HollE \
    --datapath ./Holl-E/prepared_data \
    --wvpath ./ \
    --cuda 0 \
    --droprate 0.5 \
    --disentangle \
    --hist_len 2 \
    --hist_weights 0.7 0.3 \
    --out_dir ./output \
    --model_dir ./model \
    --cache
```
You can modify `run.py` and `myCoTK/dataloader.py` to change other hyperparameters.
## Evaluation

As in training, the `--disentangle` flag selects the disentangled model; delete that line to evaluate the fused model.

For Wizard of Wikipedia:

```bash
python run.py \
    --mode test \
    --dataset WizardOfWiki \
    --cuda 0 \
    --restore best \
    --disentangle \
    --hist_len 2 \
    --hist_weights 0.7 0.3 \
    --out_dir ./output \
    --model_dir ./model \
    --cache
```
For Holl-E:

```bash
python run.py \
    --mode test \
    --dataset HollE \
    --cuda 0 \
    --restore best \
    --disentangle \
    --hist_len 2 \
    --hist_weights 0.7 0.3 \
    --out_dir ./output \
    --model_dir ./model \
    --cache
```

main.py (+75)

@@ -0,0 +1,75 @@
```python
# coding:utf-8
import logging
import json
import os

from cotk.wordvector import WordVector, Glove
from myCoTK.dataloader import WizardOfWiki, HollE
from utils import debug, try_cache, cuda_init, Storage
from seq2seq import Seq2seq


def main(args):
    # filename=0 is falsy, so basicConfig falls back to the default stream handler
    logging.basicConfig(filename=0,
                        level=logging.DEBUG,
                        format='%(asctime)s %(filename)s[line:%(lineno)d] %(message)s',
                        datefmt='%H:%M:%S')

    if args.debug:
        debug()
    logging.info(json.dumps(args, indent=2))

    cuda_init(args.cuda_num, args.cuda)

    # Non-hyperparameter runtime state passed to the model
    volatile = Storage()
    volatile.load_exclude_set = args.load_exclude_set
    volatile.restoreCallback = args.restoreCallback

    # Select the data loader class for the chosen dataset
    if args.dataset == 'WizardOfWiki':
        data_class = WizardOfWiki
    elif args.dataset == 'HollE':
        data_class = HollE
    else:
        raise ValueError("Unknown dataset: %s" % args.dataset)
    wordvec_class = WordVector.load_class(args.wvclass)
    if wordvec_class is None:
        wordvec_class = Glove

    # Create per-dataset cache, output, and model directories
    if not os.path.exists(args.cache_dir):
        os.mkdir(args.cache_dir)
    args.cache_dir = os.path.join(args.cache_dir, args.dataset)

    if not os.path.exists(args.out_dir):
        os.mkdir(args.out_dir)
    args.out_dir = os.path.join(args.out_dir, args.dataset)

    if not os.path.exists(args.model_dir):
        os.mkdir(args.model_dir)
    if args.dataset not in args.model_dir:
        args.model_dir = os.path.join(args.model_dir, args.dataset)

    # Load the dataset and word vectors, optionally through the on-disk cache
    if args.cache:
        dm = try_cache(data_class, (args.datapath,), args.cache_dir)
        volatile.wordvec = try_cache(
            lambda wv, ez, vl: wordvec_class(wv).load_matrix(ez, vl),
            (args.wvpath, args.embedding_size, dm.vocab_list),
            args.cache_dir, wordvec_class.__name__)
    else:
        dm = data_class(args.datapath)
        wv = wordvec_class(args.wvpath)
        volatile.wordvec = wv.load_matrix(args.embedding_size, dm.vocab_list)

    volatile.dm = dm

    param = Storage()
    param.args = args
    param.volatile = volatile

    # Build the model and dispatch on the requested mode
    model = Seq2seq(param)
    if args.mode == "train":
        model.train_process()
    elif args.mode == "test":
        model.test_process()
    elif args.mode == 'dev':
        model.test_dev()
    else:
        raise ValueError("Unknown mode")
```
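`main(args)` is driven by `run.py` (not included in this commit), which builds an attribute-style `args` object from the command-line flags shown in the README. As a rough illustration of the fields `main` reads, a hypothetical minimal driver might look like the sketch below; all field values and defaults here are assumptions for illustration, not the repository's actual `run.py`:

```python
# Hypothetical driver for illustration only; the real CLI lives in run.py.
from utils import Storage
from main import main

args = Storage()
args.mode = 'train'            # 'train', 'test', or 'dev'
args.dataset = 'WizardOfWiki'  # or 'HollE'
args.datapath = './Wizard-of-Wikipedia/prepared_data'
args.wvclass = 'Glove'
args.wvpath = './'             # directory containing glove.txt
args.embedding_size = 300      # matches the 300-d GloVe file
args.cuda = True
args.cuda_num = 0
args.debug = False
args.cache = True
args.cache_dir = './cache'     # assumed default
args.out_dir = './output'
args.model_dir = './model'
args.load_exclude_set = []     # assumed defaults for checkpoint restoring
args.restoreCallback = None

main(args)
```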

myCoTK/__init__.py

Whitespace-only changes.
