Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CLI args #24

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
save/
**/__pycache__
experiment-log-*.csv
75 changes: 33 additions & 42 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import getopt
import sys

from colorama import Fore
import tensorflow as tf

from models.gsgan.Gsgan import Gsgan
from models.leakgan.Leakgan import Leakgan
Expand All @@ -12,17 +12,21 @@
from models.textGan_MMD.Textgan import TextganMmd


supported_gans = {
'seqgan': Seqgan,
'gsgan': Gsgan,
'textgan': TextganMmd,
'leakgan': Leakgan,
'rankgan': Rankgan,
'maligan': Maligan,
'mle': Mle
}
supported_training = {'oracle', 'cfg', 'real'}


def set_gan(gan_name):
gans = dict()
gans['seqgan'] = Seqgan
gans['gsgan'] = Gsgan
gans['textgan'] = TextganMmd
gans['leakgan'] = Leakgan
gans['rankgan'] = Rankgan
gans['maligan'] = Maligan
gans['mle'] = Mle
try:
Gan = gans[gan_name.lower()]
Gan = supported_gans[gan_name.lower()]
gan = Gan()
gan.vocab_size = 5000
gan.generate_num = 10000
Expand All @@ -32,7 +36,6 @@ def set_gan(gan_name):
sys.exit(-2)



def set_training(gan, training_method):
try:
if training_method == 'oracle':
Expand All @@ -50,36 +53,24 @@ def set_training(gan, training_method):
return gan_func


def parse_cmd(argv):
try:
opts, args = getopt.getopt(argv, "hg:t:d:")

opt_arg = dict(opts)
if '-h' in opt_arg.keys():
print('usage: python main.py -g <gan_type>')
print(' python main.py -g <gan_type> -t <train_type>')
print(' python main.py -g <gan_type> -t realdata -d <your_data_location>')
sys.exit(0)
if not '-g' in opt_arg.keys():
print('unspecified GAN type, use MLE training only...')
gan = set_gan('mle')
else:
gan = set_gan(opt_arg['-g'])
if not '-t' in opt_arg.keys():
gan.train_oracle()
else:
gan_func = set_training(gan, opt_arg['-t'])
if opt_arg['-t'] == 'real' and '-d' in opt_arg.keys():
gan_func(opt_arg['-d'])
else:
gan_func()
except getopt.GetoptError:
print('invalid arguments!')
print('`python main.py -h` for help')
sys.exit(-1)
pass

def parse_cmd():
flags = tf.app.flags
flags.DEFINE_enum('gan_type', 'mle', list(supported_gans.keys()),
'Type of GAN to use')
flags.DEFINE_enum('train_type', 'oracle', supported_training,
'Type of training to use')
flags.DEFINE_string('data', 'data/image_coco.txt', '')
return

if __name__ == '__main__':
gan = None
parse_cmd(sys.argv[1:])
parse_cmd()
tf.app.flags.DEFINE_string('oracle_file', 'save/oracle.txt', '')
tf.app.flags.DEFINE_string('generator_file', 'save/generator.txt', '')
tf.app.flags.DEFINE_string('test_file', 'save/test_file.txt', '')
flags = tf.app.flags.FLAGS
gan = set_gan(flags.gan_type)
train_f = set_training(gan, flags.train_type)
if flags.train_type == 'real':
train_f(flags.data)
else:
train_f()
7 changes: 7 additions & 0 deletions models/Gan.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import abstractmethod
from tensorflow.app import flags

from utils.utils import init_sess

Expand All @@ -19,6 +20,12 @@ def __init__(self):
self.log = None
self.reward = None

FLAGS = flags.FLAGS
self.oracle_file = FLAGS.oracle_file
self.generator_file = FLAGS.generator_file
self.test_file = FLAGS.test_file


def set_oracle(self, oracle):
self.oracle = oracle

Expand Down
4 changes: 0 additions & 4 deletions models/gsgan/Gsgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ def __init__(self, oracle=None):
self.generate_num = 128
self.start_token = 0

self.oracle_file = 'save/oracle.txt'
self.generator_file = 'save/generator.txt'
self.test_file = 'save/test_file.txt'

def init_oracle_trainng(self, oracle=None):
if oracle is None:
oracle = OracleLstm(num_vocabulary=self.vocab_size, batch_size=self.batch_size, emb_dim=self.emb_dim,
Expand Down
4 changes: 0 additions & 4 deletions models/leakgan/Leakgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,6 @@ def __init__(self, oracle=None):
self.dis_embedding_dim = 64
self.goal_size = 16

self.oracle_file = 'save/oracle.txt'
self.generator_file = 'save/generator.txt'
self.test_file = 'save/test_file.txt'

def init_oracle_trainng(self, oracle=None):
goal_out_size = sum(self.num_filters)

Expand Down
4 changes: 0 additions & 4 deletions models/maligan_basic/Maligan.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,6 @@ def __init__(self, oracle=None):
self.generate_num = 128
self.start_token = 0

self.oracle_file = 'save/oracle.txt'
self.generator_file = 'save/generator.txt'
self.test_file = 'save/test_file.txt'

def init_oracle_trainng(self, oracle=None):
if oracle is None:
oracle = OracleLstm(num_vocabulary=self.vocab_size, batch_size=self.batch_size, emb_dim=self.emb_dim,
Expand Down
4 changes: 0 additions & 4 deletions models/mle/Mle.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@ def __init__(self, oracle=None):
self.generate_num = 128
self.start_token = 0

self.oracle_file = 'save/oracle.txt'
self.generator_file = 'save/generator.txt'
self.test_file = 'save/test_file.txt'

def init_oracle_trainng(self, oracle=None):
if oracle is None:
oracle = OracleLstm(num_vocabulary=self.vocab_size, batch_size=self.batch_size, emb_dim=self.emb_dim,
Expand Down
3 changes: 0 additions & 3 deletions models/pg_bleu/Pgbleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@ def __init__(self, oracle=None):
self.generate_num = 128
self.start_token = 0

self.oracle_file = 'save/oracle.txt'
self.generator_file = 'save/generator.txt'

def init_oracle_trainng(self, oracle=None):
if oracle is None:
oracle = OracleLstm(num_vocabulary=self.vocab_size, batch_size=self.batch_size, emb_dim=self.emb_dim,
Expand Down
4 changes: 0 additions & 4 deletions models/rankgan/Rankgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,6 @@ def __init__(self, oracle=None):
self.generate_num = 128
self.start_token = 0

self.oracle_file = 'save/oracle.txt'
self.generator_file = 'save/generator.txt'
self.test_file = 'save/test_file.txt'

def init_oracle_trainng(self, oracle=None):
if oracle is None:
oracle = OracleLstm(num_vocabulary=self.vocab_size, batch_size=self.batch_size, emb_dim=self.emb_dim,
Expand Down
4 changes: 0 additions & 4 deletions models/seqgan/Seqgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ def __init__(self, oracle=None):
self.generate_num = 128
self.start_token = 0

self.oracle_file = 'save/oracle.txt'
self.generator_file = 'save/generator.txt'
self.test_file = 'save/test_file.txt'

def init_metric(self):
nll = Nll(data_loader=self.oracle_data_loader, rnn=self.oracle, sess=self.sess)
self.add_metric(nll)
Expand Down
4 changes: 0 additions & 4 deletions models/textGan_MMD/Textgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,6 @@ def __init__(self, oracle=None):
self.generate_num = 128
self.start_token = 0

self.oracle_file = 'save/oracle.txt'
self.generator_file = 'save/generator.txt'
self.test_file = 'save/test_file.txt'

def init_oracle_trainng(self, oracle=None):
if oracle is None:
oracle = OracleLstm(num_vocabulary=self.vocab_size, batch_size=self.batch_size, emb_dim=self.emb_dim,
Expand Down