Py3 tf 1.8 #28

Open · wants to merge 2 commits into base: master
7 changes: 4 additions & 3 deletions README.md

@@ -18,6 +18,7 @@ Prerequisites
 -------------

 - Python 2.7 or Python 3.3+
+> Note: this version only supports Python 3.x
 - [Tensorflow](https://www.tensorflow.org/)


@@ -26,15 +27,15 @@ Usage

 To train a model with `ptb` dataset:

-    $ python main.py --dataset ptb
+    $ python3 main.py --dataset ptb

 To test an existing model:

-    $ python main.py --dataset ptb --forward_only True
+    $ python3 main.py --dataset ptb --forward_only True

 To see all training options, run:

-    $ python main.py --help
+    $ python3 main.py --help

 which will print
12 changes: 6 additions & 6 deletions batch_loader.py

@@ -6,12 +6,12 @@

 def save(fname, obj):
-  with open(fname, 'w') as f:
+  with open(fname, 'wb') as f:
     pickle.dump(obj, f)


 def load(fname):
-  with open(fname, 'r') as f:
+  with open(fname, 'rb') as f:
     return pickle.load(f)


@@ -52,7 +52,7 @@ def __init__(self, data_dir, dataset_name, batch_size, seq_length, max_word_leng
     ydata[-1] = data[0].copy()
     data_char = np.zeros([data.shape[0], self.max_word_length])

-    for idx in xrange(data.shape[0]):
+    for idx in range(data.shape[0]):
       data_char[idx] = all_data_char[split][idx]

     if split < 2:

@@ -141,13 +141,13 @@ def text_to_tensor(self, input_files, vocab_fname, tensor_fname, char_fname, max
           word = word[2:]
           output_tensor[word_num] = word2idx['|']
         else:
-          if not word2idx.has_key(word):
+          if word not in word2idx:
             idx2word.append(word)
             word2idx[word] = len(idx2word) - 1
           output_tensor[word_num] = word2idx[word]

         for char in word:
-          if not char2idx.has_key(char):
+          if char not in char2idx:
             idx2char.append(char)
             char2idx[char] = len(idx2char) - 1
           chars.append(char2idx[char])

@@ -156,7 +156,7 @@ def text_to_tensor(self, input_files, vocab_fname, tensor_fname, char_fname, max
         if len(chars) == max_word_length:
           chars[-1] = char2idx['}']

-        for idx in xrange(min(len(chars), max_word_length)):
+        for idx in range(min(len(chars), max_word_length)):
          output_char[word_num][idx] = chars[idx]
         word_num += 1
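The `'w'`/`'r'` → `'wb'`/`'rb'` changes above matter because Python 3's `pickle` writes and reads `bytes`, not `str`; text-mode files raise `TypeError` on dump and decoding errors on load. A minimal standalone sketch of the pattern (the `cache.pkl` filename is hypothetical, for illustration only):

```python
import pickle

def save(fname, obj):
    # Python 3: pickle emits bytes, so the file must be opened in binary mode.
    with open(fname, 'wb') as f:
        pickle.dump(obj, f)

def load(fname):
    # Likewise, read back in binary mode.
    with open(fname, 'rb') as f:
        return pickle.load(f)

vocab = {'the': 0, 'cat': 1}
save('cache.pkl', vocab)  # hypothetical filename
assert load('cache.pkl') == vocab
```

The `word not in word2idx` rewrites in the same file are the Python 3 replacement for `dict.has_key()`, which was removed.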
84 changes: 57 additions & 27 deletions models/LSTMTDNN.py

@@ -2,12 +2,14 @@
 import numpy as np
 import tensorflow as tf

-from TDNN import TDNN
-from base import Model
+from tensorflow.contrib.rnn.python.ops import core_rnn_cell
+
+from .TDNN import TDNN
+from .base import Model

 from utils import progress
 from batch_loader import BatchLoader
-from ops import conv2d, batch_norm, highway
+from .ops import conv2d, batch_norm, highway

 class LSTMTDNN(Model):
   """
@@ -107,14 +109,21 @@ def prepare_model(self):
       word_W = tf.get_variable("word_embed",
           [self.word_vocab_size, self.word_embed_dim])

-    with tf.variable_scope("CNN") as scope:
+    with tf.variable_scope("CNN", reuse=tf.AUTO_REUSE) as scope:
       self.char_inputs = tf.placeholder(tf.int32, [self.batch_size, self.seq_length, self.max_word_length])
       self.word_inputs = tf.placeholder(tf.int32, [self.batch_size, self.seq_length])

-      char_indices = tf.split(1, self.seq_length, self.char_inputs)
-      word_indices = tf.split(1, self.seq_length, tf.expand_dims(self.word_inputs, -1))
+      char_inputs_shape = self.char_inputs.get_shape().as_list()
+      print("log 001 %s" % str(char_inputs_shape))
+
+      # old -- split(split_dim, num_split, value, name="split")
+      # new -- split(value, num_or_size_splits, axis=0, num=None, name='split')
+      # char_indices = tf.split(1, self.seq_length, self.char_inputs)
+      # word_indices = tf.split(1, self.seq_length, tf.expand_dims(self.word_inputs, -1))
+      char_indices = tf.split(self.char_inputs, self.seq_length, axis=1)
+      word_indices = tf.split(tf.expand_dims(self.word_inputs, -1), self.seq_length, axis=1)

-      for idx in xrange(self.seq_length):
+      for idx in range(self.seq_length):
         char_index = tf.reshape(char_indices[idx], [-1, self.max_word_length])
         word_index = tf.reshape(word_indices[idx], [-1, 1])
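For reference, the commented old/new signatures above record the TF 1.0 argument reorder: the value being split moved to the first position and the axis became a keyword. A small standalone sketch of what this call produces (TF 1.x; shapes chosen arbitrarily for illustration):

```python
import tensorflow as tf  # TF 1.x

batch_size, seq_length, max_word_length = 2, 3, 5
char_inputs = tf.placeholder(tf.int32, [batch_size, seq_length, max_word_length])

# TF 1.x signature: tf.split(value, num_or_size_splits, axis).
# Splitting along axis=1 yields seq_length tensors of shape [batch, 1, max_word_length].
char_indices = tf.split(char_inputs, seq_length, axis=1)
print(len(char_indices))            # 3
print(char_indices[0].get_shape())  # (2, 1, 5)

# Each slice is then flattened to [batch, max_word_length], as in the loop above:
char_index = tf.reshape(char_indices[0], [-1, max_word_length])
print(char_index.get_shape())       # (2, 5)
```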

@@ -129,7 +138,8 @@ def prepare_model(self):
         if self.use_word:
           word_embed = tf.nn.embedding_lookup(word_W, word_index)
-          cnn_output = tf.concat(1, [char_cnn.output, tf.squeeze(word_embed, [1])])
+          # cnn_output = tf.concat(1, [char_cnn.output, tf.squeeze(word_embed, [1])])
+          cnn_output = tf.concat([char_cnn.output, tf.squeeze(word_embed, [1])], 1)
         else:
           cnn_output = char_cnn.output
       else:
@@ -141,25 +151,33 @@ def prepare_model(self):
           cnn_output = tf.squeeze(norm_output)

         if highway:
-          #cnn_output = highway(input_, input_dim_length, self.highway_layers, 0)
+          # cnn_output = highway(input_, input_dim_length, self.highway_layers, 0)
           cnn_output = highway(cnn_output, cnn_output.get_shape()[1], self.highway_layers, 0)

         self.cnn_outputs.append(cnn_output)
+      cnn_output_shape = self.cnn_outputs[0].get_shape().as_list()
+      print("log 002 %d" % len(self.cnn_outputs))
+      print("log 003 %s" % str(cnn_output_shape))

     with tf.variable_scope("LSTM") as scope:
-      self.cell = tf.nn.rnn_cell.BasicLSTMCell(self.rnn_size)
-      self.stacked_cell = tf.nn.rnn_cell.MultiRNNCell([self.cell] * self.layer_depth)
+      # self.cell = tf.nn.rnn_cell.BasicLSTMCell(self.rnn_size)
+      # self.stacked_cell = tf.nn.rnn_cell.MultiRNNCell([self.cell] * self.layer_depth)
+      # self.cell = tf.contrib.rnn.BasicLSTMCell(self.rnn_size)
+      # self.stacked_cell = tf.contrib.rnn.MultiRNNCell([self.cell] * self.layer_depth)
+      def lstm_cell():
+        cell = tf.contrib.rnn.BasicLSTMCell(self.rnn_size)
+        return cell
+      self.stacked_cell = tf.contrib.rnn.MultiRNNCell([lstm_cell() for _ in range(self.layer_depth)])

-      outputs, _ = tf.nn.rnn(self.stacked_cell,
-          self.cnn_outputs,
-          dtype=tf.float32)
+      outputs, _ = tf.contrib.rnn.static_rnn(self.stacked_cell, self.cnn_outputs, dtype=tf.float32)

       self.lstm_outputs = []
       self.true_outputs = tf.placeholder(tf.int64,
           [self.batch_size, self.seq_length])

       loss = 0
-      true_outputs = tf.split(1, self.seq_length, self.true_outputs)
+      # true_outputs = tf.split(1, self.seq_length, self.true_outputs)
+      true_outputs = tf.split(self.true_outputs, self.seq_length, axis=1)

       for idx, (top_h, true_output) in enumerate(zip(outputs, true_outputs)):
         if self.dropout_prob > 0:
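The cell-factory rewrite in the LSTM block above is the standard fix for stacking cells in TF 1.x: with `[cell] * layer_depth`, every layer points at the same `BasicLSTMCell` object, so the layers try to reuse one set of variables, which newer TF 1.x versions typically reject with a reuse `ValueError`. A minimal sketch of the pattern (TF 1.x; sizes arbitrary):

```python
import tensorflow as tf  # TF 1.x

rnn_size, layer_depth = 128, 2

# One instance per layer; sharing a single instance across layers would
# tie (or clash on) the per-layer variables.
def lstm_cell():
    return tf.contrib.rnn.BasicLSTMCell(rnn_size)

stacked_cell = tf.contrib.rnn.MultiRNNCell(
    [lstm_cell() for _ in range(layer_depth)])

# static_rnn (the renamed tf.nn.rnn) consumes a Python list of
# [batch, input_dim] tensors, matching the per-timestep cnn_outputs above.
inputs = [tf.zeros([4, 32]) for _ in range(3)]  # 3 steps, batch 4, dim 32 (arbitrary)
outputs, state = tf.contrib.rnn.static_rnn(stacked_cell, inputs, dtype=tf.float32)
print(len(outputs), outputs[0].get_shape())     # 3 (4, 128)
```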
@@ -170,25 +188,30 @@ def prepare_model(self):
         else:
           if idx != 0:
             scope.reuse_variables()
-          proj = tf.nn.rnn_cell._linear(top_h, self.word_vocab_size, 0)
+          # proj = tf.nn.rnn_cell._linear(top_h, self.word_vocab_size, 0)
+          proj = core_rnn_cell._linear(top_h, self.word_vocab_size, 0)
         self.lstm_outputs.append(proj)

-        loss += tf.nn.sparse_softmax_cross_entropy_with_logits(self.lstm_outputs[idx], tf.squeeze(true_output))
+        # loss += tf.nn.sparse_softmax_cross_entropy_with_logits(self.lstm_outputs[idx], tf.squeeze(true_output))
+        loss += tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.lstm_outputs[idx], labels=tf.squeeze(true_output))

     self.loss = tf.reduce_mean(loss) / self.seq_length

-    tf.scalar_summary("loss", self.loss)
-    tf.scalar_summary("perplexity", tf.exp(self.loss))
+    # move to tf.summary.scalar()
+    # tf.scalar_summary("loss", self.loss)
+    # tf.scalar_summary("perplexity", tf.exp(self.loss))
+    tf.summary.scalar("loss", self.loss)
+    tf.summary.scalar("perplexity", tf.exp(self.loss))

   def train(self, epoch):
     cost = 0
     target = np.zeros([self.batch_size, self.seq_length])

     N = self.loader.sizes[0]
-    for idx in xrange(N):
+    for idx in range(N):
       target.fill(0)
       x, y, x_char = self.loader.next_batch(0)
-      for b in xrange(self.batch_size):
+      for b in range(self.batch_size):
        for t, w in enumerate(y[b]):
          target[b][t] = w
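Two migrations sit in the hunk above: the private `tf.nn.rnn_cell._linear` helper moved to `tensorflow.contrib.rnn.python.ops.core_rnn_cell`, and `sparse_softmax_cross_entropy_with_logits` now requires named `logits=`/`labels=` arguments, precisely so that callers migrating from the old positional order cannot silently swap the two. A small self-contained check of the latter (TF 1.x; values arbitrary):

```python
import numpy as np
import tensorflow as tf  # TF 1.x

logits = tf.constant(np.random.randn(4, 10), dtype=tf.float32)  # [batch, vocab]
labels = tf.constant([1, 3, 5, 7], dtype=tf.int64)              # [batch]

# TF 1.x: positional (logits, labels) calls raise an error; keywords are required.
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)

with tf.Session() as sess:
    # For random logits over 10 classes this averages roughly log(10) ~ 2.3.
    print(sess.run(tf.reduce_mean(loss)))
```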

@@ -226,11 +249,11 @@ def test(self, split_idx, max_batches=None):
     target = np.zeros([self.batch_size, self.seq_length])

     cost = 0
-    for idx in xrange(N):
+    for idx in range(N):
       target.fill(0)

       x, y, x_char = self.loader.next_batch(split_idx)
-      for b in xrange(self.batch_size):
+      for b in range(self.batch_size):
         for t, w in enumerate(y[b]):
           target[b][t] = w
@@ -275,22 +298,29 @@ def run(self, epoch=25,
         global_step=self.global_step)

     # ready for train
-    tf.initialize_all_variables().run()
+    # change init function
+    # tf.initialize_all_variables().run()
+    tf.global_variables_initializer().run()

     if self.load(self.checkpoint_dir, self.dataset_name):
       print("[*] SUCCESS to load model for %s." % self.dataset_name)
     else:
       print("[!] Failed to load model for %s." % self.dataset_name)

     self.saver = tf.train.Saver()
-    self.merged_summary = tf.merge_all_summaries()
-    self.writer = tf.train.SummaryWriter("./logs", self.sess.graph_def)
+    # move to tf.summary.merge_all()
+    # self.merged_summary = tf.merge_all_summaries()
+    self.merged_summary = tf.summary.merge_all()
+    # move to tf.summary.FileWriter
+    # Passing a `GraphDef` to the SummaryWriter is deprecated. Pass a `Graph` object instead.
+    # self.writer = tf.train.SummaryWriter("./logs", self.sess.graph_def)
+    self.writer = tf.summary.FileWriter("./logs", self.sess.graph)

     self.log_loss = []
     self.log_perp = []

     if not self.forward_only:
-      for idx in xrange(epoch):
+      for idx in range(epoch):
         train_loss = self.train(idx)
         valid_loss = self.test(1)
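The `run()` changes above are the standard TF 1.0 renames: `tf.initialize_all_variables` → `tf.global_variables_initializer`, `tf.merge_all_summaries` → `tf.summary.merge_all`, and `tf.train.SummaryWriter` → `tf.summary.FileWriter`, which takes the `Graph` itself rather than a `GraphDef`. A minimal end-to-end sketch of the updated API (TF 1.x; the loss tensor is a stand-in and the log directory mirrors the PR's `./logs`):

```python
import tensorflow as tf  # TF 1.x

loss = tf.reduce_mean(tf.square(tf.random_normal([8])))  # stand-in scalar
tf.summary.scalar("loss", loss)            # replaces tf.scalar_summary

init = tf.global_variables_initializer()   # replaces tf.initialize_all_variables
merged = tf.summary.merge_all()            # replaces tf.merge_all_summaries

with tf.Session() as sess:
    sess.run(init)
    # FileWriter replaces tf.train.SummaryWriter and expects the Graph object.
    writer = tf.summary.FileWriter("./logs", sess.graph)
    summary, _ = sess.run([merged, loss])
    writer.add_summary(summary, global_step=0)
    writer.close()
```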
7 changes: 5 additions & 2 deletions models/TDNN.py

@@ -1,7 +1,7 @@
 import tensorflow as tf

 from .ops import conv2d
-from base import Model
+from .base import Model

 class TDNN(Model):
   """Time-delayed Neural Network (cf. http://arxiv.org/abs/1508.06615v4)
@@ -38,6 +38,9 @@ def __init__(self, input_, embed_dim=650,
       layers.append(tf.squeeze(pool))

     if len(kernels) > 1:
-      self.output = tf.concat(1, layers)
+      # old -- concat(concat_dim, values, name="concat")
+      # new -- concat(values, axis, name='concat')
+      # self.output = tf.concat(1, layers)
+      self.output = tf.concat(layers, 1)
     else:
       self.output = layers[0]
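This `tf.concat` flip mirrors the one in LSTMTDNN.py: TF 1.0 moved the axis argument from first to second position. A tiny sketch of the behavior at this call site (TF 1.x; the per-kernel feature-map widths are arbitrary stand-ins):

```python
import tensorflow as tf  # TF 1.x

# Stand-ins for the pooled features from three kernel widths, batch of 4.
layers = [tf.zeros([4, 100]), tf.zeros([4, 150]), tf.zeros([4, 200])]

# TF 1.x signature: tf.concat(values, axis); the old form was tf.concat(axis, values).
output = tf.concat(layers, 1)
print(output.get_shape())  # (4, 450): feature maps joined along the feature axis
```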
4 changes: 2 additions & 2 deletions models/__init__.py

@@ -1,2 +1,2 @@
-from TDNN import TDNN
-from LSTMTDNN import LSTMTDNN
+from .TDNN import TDNN
+from .LSTMTDNN import LSTMTDNN
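The dotted forms here, and in LSTMTDNN.py and TDNN.py above, are required because Python 3 removed implicit relative imports: a module inside the `models` package can no longer find a sibling by its bare name. A minimal sketch of the rule, using this repo's layout:

```python
# Layout (from this repo):
#   models/__init__.py
#   models/TDNN.py       (defines TDNN)
#   models/LSTMTDNN.py
#   utils.py             (top-level module)

# Inside models/__init__.py or models/LSTMTDNN.py:
from .TDNN import TDNN      # explicit relative import: sibling in the same package
from utils import progress  # absolute import: utils.py sits at the top level

# Python 2's implicit form found the sibling without the dot:
# from TDNN import TDNN     # fails on Python 3 with ModuleNotFoundError
```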
9 changes: 6 additions & 3 deletions models/ops.py

@@ -3,6 +3,7 @@
 import tensorflow as tf

 from tensorflow.python.framework import ops
+from tensorflow.contrib.rnn.python.ops import core_rnn_cell

 from utils import *

@@ -23,11 +24,13 @@ def highway(input_, size, layer_size=1, bias=-2, f=tf.nn.relu):
   where g is nonlinearity, t is transform gate, and (1 - t) is carry gate.
   """
   output = input_
-  for idx in xrange(layer_size):
-    output = f(tf.nn.rnn_cell._linear(output, size, 0, scope='output_lin_%d' % idx))
+  for idx in range(layer_size):
+    # output = f(tf.nn.rnn_cell._linear(output, size, 0, scope='output_lin_%d' % idx))
+    output = f(core_rnn_cell._linear(output, size, 0))

     transform_gate = tf.sigmoid(
-      tf.nn.rnn_cell._linear(input_, size, 0, scope='transform_lin_%d' % idx) + bias)
+      core_rnn_cell._linear(input_, size, 0) + bias)
+      # tf.nn.rnn_cell._linear(input_, size, 0, scope='transform_lin_%d' % idx) + bias)
     carry_gate = 1. - transform_gate

     output = transform_gate * output + carry_gate * input_
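For reference, the layer being ported computes y = t * g(W_H x + b_H) + (1 - t) * x with transform gate t = sigmoid(W_T x + bias). One caution on the hunk above: dropping the per-layer `scope=` argument from `_linear` means repeated calls inside one variable scope can collide or silently share weights when `layer_size > 1`. A hedged sketch of the same layer that keeps distinct names, using the public `tf.layers.dense` in place of the private `_linear` helper (an assumption, not what the PR does):

```python
import tensorflow as tf  # TF 1.x

def highway(input_, size, layer_size=1, bias=-2.0, f=tf.nn.relu):
    """y = t * f(W_H x) + (1 - t) * x, with t = sigmoid(W_T x + bias)."""
    output = input_
    for idx in range(layer_size):
        # Distinct variable names per layer avoid the collisions that dropping
        # scope='..._%d' can cause.
        g = f(tf.layers.dense(output, size, name='output_lin_%d' % idx))
        t = tf.sigmoid(
            tf.layers.dense(input_, size, name='transform_lin_%d' % idx) + bias)
        output = t * g + (1.0 - t) * input_
    return output

x = tf.zeros([4, 650])             # batch 4, feature dim 650 (arbitrary)
y = highway(x, 650, layer_size=2)
print(y.get_shape())               # (4, 650)
```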
2 changes: 1 addition & 1 deletion utils.py

@@ -2,7 +2,7 @@
 import pprint

 try:
-  xrange
+  range
 except NameError:
   xrange = range
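One behavioral note on this last hunk: probing `range` in the `try` never raises `NameError` (the name exists on both Python 2 and 3), so after this change the `except` branch is dead and `xrange` is never aliased. That is harmless here only because every `xrange` call site in the PR was rewritten to `range`. The conventional two-way shim probes the name that might be missing:

```python
# Conventional Python 2/3 compatibility shim, for comparison with the hunk above:
try:
    xrange          # defined on Python 2 only
except NameError:
    xrange = range  # on Python 3, alias the lazy range

# After this, xrange(...) works on both interpreters.
assert sum(i for i in xrange(5)) == 10
```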