Skip to content

Commit

Permalink
using the last hidden state for Fc
Browse files Browse the repository at this point in the history
using the last hidden state for Fc rather than average pooling
  • Loading branch information
acadTags authored May 10, 2018
1 parent a552e11 commit d64e229
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion a03_TextRNN/p8_TextRNN_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def inference(self):
print("outputs:===>",outputs) #outputs:(<tf.Tensor 'bidirectional_rnn/fw/fw/transpose:0' shape=(?, 5, 100) dtype=float32>, <tf.Tensor 'ReverseV2:0' shape=(?, 5, 100) dtype=float32>))
#3. concat output
output_rnn=tf.concat(outputs,axis=2) #[batch_size,sequence_length,hidden_size*2]
self.output_rnn_last=tf.reduce_mean(output_rnn,axis=1) #[batch_size,hidden_size*2] #output_rnn_last=output_rnn[:,-1,:] ##[batch_size,hidden_size*2] #TODO
#self.output_rnn_last=tf.reduce_mean(output_rnn,axis=1) #[batch_size,hidden_size*2]
self.output_rnn_last=output_rnn[:,-1,:] ##[batch_size,hidden_size*2] #TODO
print("output_rnn_last:", self.output_rnn_last) # <tf.Tensor 'strided_slice:0' shape=(?, 200) dtype=float32>
#4. logits(use linear layer)
with tf.name_scope("output"): #inputs: A `Tensor` of shape `[batch_size, dim]`. The forward activations of the input network.
Expand Down

0 comments on commit d64e229

Please sign in to comment.