forked from mhjabreel/CharCnn_Keras
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
70 lines (66 loc) · 3.84 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import json
import sys

import tensorflow as tf

from data_utils import Data
from models.char_cnn_kim import CharCNNKim
from models.char_cnn_zhang import CharCNNZhang
from models.char_tcn import CharTCN
tf.flags.DEFINE_string(
    "model", "char_cnn_zhang",
    "Specifies which model to use: char_cnn_zhang, char_cnn_kim or char_tcn "
    "(the short aliases zhang, kim and tcn are also accepted)")
FLAGS = tf.flags.FLAGS
try:
    # TensorFlow < 1.5: tf.flags is argparse-based and only exposes this
    # private parsing hook.
    FLAGS._parse_flags()
except AttributeError:
    # TensorFlow >= 1.5: tf.flags is backed by absl.flags, which has no
    # `_parse_flags` (attribute lookup raises AttributeError); calling the
    # FlagValues object with argv parses the command line instead.
    FLAGS(sys.argv)

# Maps every accepted --model value to the config.json section that holds
# that model's hyper-parameters. Both the full section names (as advertised
# in the flag help) and short aliases are accepted.
_MODEL_ALIASES = {
    "char_cnn_zhang": "char_cnn_zhang",
    "zhang": "char_cnn_zhang",
    "char_cnn_kim": "char_cnn_kim",
    "kim": "char_cnn_kim",
    "char_tcn": "char_tcn",
    "tcn": "char_tcn",
}


def _resolve_model_key(model_name):
    """Return the config section name for the --model value *model_name*.

    Raises:
        ValueError: if *model_name* is not a recognised model. Failing loudly
            avoids silently training a different model than the one requested.
    """
    model_key = _MODEL_ALIASES.get(model_name)
    if model_key is None:
        raise ValueError("Unknown model {!r}; expected one of: {}".format(
            model_name, ", ".join(sorted(_MODEL_ALIASES))))
    return model_key


def _load_config(path="config.json"):
    """Read and parse the JSON configuration file at *path*."""
    # `with` guarantees the file handle is closed after parsing.
    with open(path) as config_file:
        return json.load(config_file)


def _load_split(config, source_key):
    """Load the dataset whose source is ``config["data"][source_key]``.

    Returns the result of ``Data.get_all_data()``, which callers unpack as an
    (inputs, labels) pair.
    """
    data_cfg = config["data"]
    split = Data(data_source=data_cfg[source_key],
                 alphabet=data_cfg["alphabet"],
                 input_size=data_cfg["input_size"],
                 num_of_classes=data_cfg["num_of_classes"])
    split.load_data()
    return split.get_all_data()


def _build_model(model_key, config):
    """Instantiate the model for *model_key* using ``config[model_key]``."""
    data_cfg = config["data"]
    model_cfg = config[model_key]
    # Constructor arguments shared by all three model classes.
    common_kwargs = dict(
        input_size=data_cfg["input_size"],
        alphabet_size=data_cfg["alphabet_size"],
        embedding_size=model_cfg["embedding_size"],
        conv_layers=model_cfg["conv_layers"],
        fully_connected_layers=model_cfg["fully_connected_layers"],
        num_of_classes=data_cfg["num_of_classes"],
        dropout_p=model_cfg["dropout_p"],
        optimizer=model_cfg["optimizer"],
        loss=model_cfg["loss"],
    )
    if model_key == "char_cnn_kim":
        return CharCNNKim(**common_kwargs)
    if model_key == "char_tcn":
        return CharTCN(**common_kwargs)
    # Only the Zhang model takes an extra `threshold` argument.
    return CharCNNZhang(threshold=model_cfg["threshold"], **common_kwargs)


def main():
    """Load config and data, build the model selected by --model and train it."""
    # Validate --model before loading any data.
    model_key = _resolve_model_key(FLAGS.model)
    config = _load_config()
    training_inputs, training_labels = _load_split(
        config, "training_data_source")
    validation_inputs, validation_labels = _load_split(
        config, "validation_data_source")
    model = _build_model(model_key, config)
    training_cfg = config["training"]
    model.train(training_inputs=training_inputs,
                training_labels=training_labels,
                validation_inputs=validation_inputs,
                validation_labels=validation_labels,
                epochs=training_cfg["epochs"],
                batch_size=training_cfg["batch_size"],
                checkpoint_every=training_cfg["checkpoint_every"])


if __name__ == "__main__":
    main()