Commit 1a68848

commit all programs
1 parent 42a74e1 commit 1a68848

173 files changed

+241121
-66771
lines changed

Qadpt_model.png

349 KB

Readme.md

+30-43
@@ -1,51 +1,38 @@
11
# DyKGChat
2-
This project is the implementation of our paper **DyKgChat: A Multi-domain Chit-chat Dialogue Generation Dataset Grounding on Dynamic Knowledge Graphs**.
2+
The project contains the collected data and code of our paper **Yi-Lin Tuan, Yun-Nung Chen, Hung-yi Lee. "DyKgChat: Benchmarking Dialogue Generation Grounding on Dynamic Knowledge Graphs", EMNLP 2019**.
33

4+
* our proposed approach: (Qadpt) **Q**uick **Ad**a**pt**ive Dynamic Knowledge-Grounded Neural Conversation Model (pronounced: Q-adapt)
45

5-
## Requirements
6-
* jieba
7-
* python3
6+
![Qadpt](Qadpt_model.png)
7+
8+
## Setup
9+
### Installation (my environment)
10+
* python3.6
811
* tensorflow r1.13
12+
* jieba
13+
* nltk3.2.5
914

10-
## Files
11-
* `data/`: the collected data `hgzhz/` and `friends/` as well as their trained TransE
12-
* `model_ckpts/`: the trained models
13-
* `Qadpt/`: the programs
15+
### Files
16+
* `data/`: the collected data `hgzhz/` and `friends/` as well as the trained TransE
17+
* `model_ckpts/`: the trained models in the paper
1418

1519

1620
## Usage
17-
* clone the repository and switch to directory `Qadpt/`
18-
```
19-
$cd Qadpt/
20-
```
21-
22-
* testing hgzhz (the following commands must be in order)
23-
```
24-
$bash run.sh -1 pred_acc Qadpt
25-
$bash run.sh -1 ifchange Qadpt
26-
$bash run.sh -1 eval_pred_acc Qadpt
27-
```
28-
29-
* testing friends
30-
```
31-
$bash frun.sh -1 pred_acc Qadpt
32-
$bash frun.sh -1 ifchange Qadpt
33-
$bash frun.sh -1 eval_pred_acc Qadpt
34-
```
35-
36-
The automatic evaluation results will be printed on the screen, and some files will be written to `Qadpt/hgzhz_results/` or `Qadpt/friends_results/`.
37-
38-
The default `ifchange` evaluates the **Last-1** score. To switch to **random** or **Last-2**, change line 464 of `main.py` to `level=-1` or `level=1`.
39-
40-
41-
* training hgzhz
42-
```
43-
$bash run.sh 0 None Qadpt_new
44-
```
45-
46-
* training friends
47-
```
48-
$bash frun.sh 0 None Qadpt_new
49-
```
50-
51-
The trained model will be stored in `model_ckpts/hgzhz/Qadpt_new/` or `model_ckpts/friends/Qadpt_new/`
21+
* clone the repository
22+
* run the script `run.sh`
23+
```
24+
$bash run.sh <GPU_ID> <method> <model> <data> <exp_name>
25+
```
26+
* for <GPU_ID>, check your device availability with `nvidia-smi`
27+
* for <method>, choose from `train`, `pred_acc`, `eval_pred_acc`, `ifchange`
28+
* for <model>, choose from `seq2seq`, `MemNet`, `TAware`, `KAware`, `Qadpt`
29+
* for <data>, choose from `friends`, `hgzhz_v1_0` (used in our paper), `hgzhz` (current newest version)
30+
* for <exp_name>, check the directory `model_ckpts`
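
As an illustration of the argument order listed above, here is a minimal Python sketch that validates the choices and assembles the `run.sh` command line. The helper `build_run_command` and the experiment name `my_experiment` are hypothetical and not part of the repository; only the positional order and the allowed values come from the README.

```python
# Allowed values copied from the README bullets.
METHODS = ("train", "pred_acc", "eval_pred_acc", "ifchange")
MODELS = ("seq2seq", "MemNet", "TAware", "KAware", "Qadpt")
DATASETS = ("friends", "hgzhz_v1_0", "hgzhz")


def build_run_command(gpu_id, method, model, data, exp_name):
    """Return the argv list for `bash run.sh ...` after checking each choice.

    Hypothetical helper for illustration; the repository does not ship it.
    """
    if method not in METHODS:
        raise ValueError(f"unknown method {method!r}; expected one of {METHODS}")
    if model not in MODELS:
        raise ValueError(f"unknown model {model!r}; expected one of {MODELS}")
    if data not in DATASETS:
        raise ValueError(f"unknown data {data!r}; expected one of {DATASETS}")
    # Positional order follows the README:
    # <GPU_ID> <method> <model> <data> <exp_name>
    return ["bash", "run.sh", str(gpu_id), method, model, data, exp_name]


if __name__ == "__main__":
    # "my_experiment" is a placeholder; real names live under model_ckpts/.
    cmd = build_run_command(0, "pred_acc", "Qadpt", "friends", "my_experiment")
    print(" ".join(cmd))
```

The returned list can be passed to `subprocess.run` from the repository root, assuming `run.sh` is present there.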
31+
32+
## More description
33+
* testing method
34+
* `pred_acc`: for metrics `Generated-KW`, `BLEU-2`, `distinct-n`
35+
* `eval_pred_acc`: for metrics `KW-Acc`, `KW/Generic`, `perplexity`
36+
* `ifchange`: for change rates / accurate change rates
37+
* script options
38+
* `hops_num` and `change_level` must be set in `run.sh`
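
The two script options correspond to the `--hops-num` and `--change-level` flags declared in `args.py` in this commit. The sketch below shows how `argparse` exposes such hyphenated flags as underscore attributes. It is a standalone illustration: the function `parse_script_options` is hypothetical, and whether `run.sh` forwards the values through exactly these flags is an assumption.

```python
import argparse


def parse_script_options(argv):
    """Parse only the two options mentioned above (illustrative subset)."""
    parser = argparse.ArgumentParser()
    # Same flag names, types and defaults as in args.py.
    parser.add_argument("--hops-num", type=int, default=1)
    parser.add_argument("--change-level", type=int, default=0)
    return parser.parse_args(argv)


if __name__ == "__main__":
    flags = parse_script_options(["--hops-num", "2", "--change-level", "1"])
    # argparse converts "--hops-num" into the attribute name "hops_num".
    print(flags.hops_num, flags.change_level)
```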

args.py

+52
@@ -0,0 +1,52 @@
import argparse
import re


def parse():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        description='You have to set the parameters for the model.')

    # directory related
    parser.add_argument("--model", type=str, default='Qadpt')
    parser.add_argument("--model-dir", type=str, default='model_ckpts')
    parser.add_argument("--results-dir", type=str, default='results')
    parser.add_argument("--data-dir", type=str, default='data')
    parser.add_argument("--data-path", type=str, default='data/friends/friends.txt')
    parser.add_argument("--data-type", type=str, default='test')
    # parameters related
    parser.add_argument("--size", type=int, default=128)
    parser.add_argument("--num-layers", type=int, default=1)
    parser.add_argument("--hops-num", type=int, default=1)
    parser.add_argument("--kgpath-len", type=int, default=6)
    parser.add_argument("--vocab-size", type=int, default=20000)
    parser.add_argument("--fact-size", type=int, default=100)
    # for training setting
    parser.add_argument("--lr", type=float, default=0.5)
    parser.add_argument("--lr-decay", type=float, default=0.99)
    parser.add_argument("--grad-norm", type=float, default=5.0)
    parser.add_argument("--buckets", type=str, default='[(10, 5)]')
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--max-seq-len", type=int, default=50)
    parser.add_argument("--max-train-data-size", type=int, default=0)  # 0: no limit
    parser.add_argument("--steps-per-checkpoint", type=int, default=200)
    # test
    parser.add_argument("--test-type", type=str, default='train')
    parser.add_argument("--change-level", type=int, default=0)

    return parser.parse_args()


def parse_buckets(str_buck):
    """Convert a string such as '[(10, 5)]' into a list of (int, int) tuples."""
    # Allow optional whitespace around the comma so that the default value
    # '[(10, 5)]' matches; the original pattern r"(\d+,\d+)" rejected it.
    _pair = re.compile(r"(\d+\s*,\s*\d+)")
    _num = re.compile(r"\d+")
    buck_list = _pair.findall(str_buck)
    if len(buck_list) < 1:
        raise ValueError("The bucket should have at least 1 component.")
    buckets = []
    for buck in buck_list:
        tmp = _num.findall(buck)
        d_tmp = (int(tmp[0]), int(tmp[1]))
        buckets.append(d_tmp)
    return buckets


# Parsed once at import time; other modules read these module-level values.
FLAGS = parse()
# Overrides --data-dir with the directory part of --data-path.
FLAGS.data_dir, _ = FLAGS.data_path.rsplit('/', 1)
_buckets = parse_buckets(FLAGS.buckets)
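
The `--buckets` value is written as a Python-style list of integer pairs. As a minimal sketch of an alternative to regex matching, and assuming the value is always a literal list of two-integer tuples, the string could be parsed with the standard library's `ast.literal_eval`, which ignores whitespace differences. The function name `parse_buckets_literal` is hypothetical and not part of the commit.

```python
import ast


def parse_buckets_literal(str_buck):
    """Parse a string like '[(10, 5)]' into a list of (int, int) tuples.

    Illustrative alternative to the regex-based parse_buckets in args.py.
    """
    try:
        value = ast.literal_eval(str_buck)
    except (ValueError, SyntaxError) as err:
        raise ValueError(f"cannot parse buckets from {str_buck!r}") from err
    if not isinstance(value, (list, tuple)) or len(value) < 1:
        raise ValueError("The bucket should have at least 1 component.")
    buckets = []
    for pair in value:
        # Each component must be a pair of integers.
        is_pair = isinstance(pair, (list, tuple)) and len(pair) == 2
        if not is_pair or not all(isinstance(n, int) for n in pair):
            raise ValueError(f"invalid bucket component: {pair!r}")
        buckets.append((pair[0], pair[1]))
    return buckets


if __name__ == "__main__":
    print(parse_buckets_literal("[(10, 5)]"))
```

Compared with pattern matching, this rejects malformed input such as a three-element tuple instead of silently reading its first two numbers.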
