Skip to content

Commit

Permalink
remove using self.batch_size explicitly in model
Browse files Browse the repository at this point in the history
This change removes the explicit use of self.batch_size in the model. This allows different batch sizes to be used for training and testing within the same run.
  • Loading branch information
acadTags authored May 20, 2020
1 parent 9b371d9 commit 0d7e40c
Showing 1 changed file with 5 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,7 @@ def gru_forward_word_level(self, embedded_words):
embedded_words_squeeze = [tf.squeeze(x, axis=1) for x in
embedded_words_splitted] # it is a list,length is sentence_length, each element is [batch_size*num_sentences,embed_size]
# demension_1=embedded_words_squeeze[0].get_shape().dims[0]
h_t = tf.ones((self.batch_size * self.num_sentences,
self.hidden_size)) #TODO self.hidden_size h_t =int(tf.get_shape(embedded_words_squeeze[0])[0]) # tf.ones([self.batch_size*self.num_sentences, self.hidden_size]) # [batch_size*num_sentences,embed_size]
h_t = tf.ones_like(embedded_words_squeeze[0])
h_t_forward_list = []
for time_step, Xt in enumerate(embedded_words_squeeze): # Xt: [batch_size*num_sentences,embed_size]
h_t = self.gru_single_step_word_level(Xt,h_t) # [batch_size*num_sentences,embed_size]<------Xt:[batch_size*num_sentences,embed_size];h_t:[batch_size*num_sentences,embed_size]
Expand All @@ -283,7 +282,7 @@ def gru_backward_word_level(self, embedded_words):
embedded_words_splitted] # it is a list,length is sentence_length, each element is [batch_size*num_sentences,embed_size]
embedded_words_squeeze.reverse() # it is a list,length is sentence_length, each element is [batch_size*num_sentences,embed_size]
# demension_1=int(tf.get_shape(embedded_words_squeeze[0])[0]) #h_t = tf.ones([self.batch_size*self.num_sentences, self.hidden_size])
h_t = tf.ones((self.batch_size * self.num_sentences, self.hidden_size))
h_t = tf.ones_like(embedded_words_squeeze[0])
h_t_backward_list = []
for time_step, Xt in enumerate(embedded_words_squeeze):
h_t = self.gru_single_step_word_level(Xt, h_t)
Expand All @@ -303,7 +302,7 @@ def gru_forward_sentence_level(self, sentence_representation):
sentence_representation_squeeze = [tf.squeeze(x, axis=1) for x in
sentence_representation_splitted] # it is a list.length is num_sentences,each element is [batch_size, hidden_size*2]
# demension_1 = int(tf.get_shape(sentence_representation_squeeze[0])[0]) #scalar: batch_size
h_t = tf.ones((self.batch_size, self.hidden_size * 2)) # TODO
h_t = tf.ones_like(sentence_representation_squeeze[0])
h_t_forward_list = []
for time_step, Xt in enumerate(sentence_representation_squeeze): # Xt:[batch_size, hidden_size*2]
h_t = self.gru_single_step_sentence_level(Xt,
Expand All @@ -324,7 +323,7 @@ def gru_backward_sentence_level(self, sentence_representation):
sentence_representation_splitted] # it is a list.length is num_sentences,each element is [batch_size, hidden_size*2]
sentence_representation_squeeze.reverse()
# demension_1 = int(tf.get_shape(sentence_representation_squeeze[0])[0]) # scalar: batch_size
h_t = tf.ones((self.batch_size, self.hidden_size * 2))
h_t = tf.ones_like(sentence_representation_squeeze[0])
h_t_forward_list = []
for time_step, Xt in enumerate(sentence_representation_squeeze): # Xt:[batch_size, hidden_size*2]
h_t = self.gru_single_step_sentence_level(Xt,h_t) # h_t:[batch_size,hidden_size]<---------Xt:[batch_size, hidden_size*2]; h_t:[batch_size, hidden_size*2]
Expand Down Expand Up @@ -423,4 +422,4 @@ def test():
textRNN.dropout_keep_prob: dropout_keep_prob})
print("loss:", loss, "acc:", acc, "label:", input_y, "prediction:", predict)
# print("W_projection_value_:",W_projection_value)
#test()
#test()

0 comments on commit 0d7e40c

Please sign in to comment.