
What is the purpose of seq_output = tf.concat(self.lstm_outputs, 1)? #11

Open
Eddiechiu opened this issue Jul 6, 2018 · 7 comments

Comments

@Eddiechiu

Hi, I'd like to ask a question.
My run fails at

y_reshaped = tf.reshape(y_one_hot, self.logits.get_shape())

because y_one_hot and self.logits have different total numbers of elements, so the reshape is impossible.
I traced the shapes:

  1. inputs has shape (num_seqs, num_steps); after tf.one_hot, lstm_inputs has shape (num_seqs, num_steps, num_classes).

  2. My cell is a single-layer LSTM; after tf.nn.dynamic_rnn(cell, lstm_inputs, initial_state=self.initial_state), lstm_outputs has shape (num_seqs, num_steps, lstm_size).

  3. After tf.concat(lstm_outputs, 1), the shape does not change at all, and after the subsequent series of operations the shapes go wrong.

So my question is: what is the tf.concat(lstm_outputs, 1) step supposed to do?
Thanks!
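For reference, the three shape steps above can be checked with a small numpy sketch (the sizes are hypothetical; numpy arrays stand in for the TF tensors, and concatenating a one-element list mimics tf.concat applied to a single tensor, which leaves it unchanged):

```python
import numpy as np

# Hypothetical sizes for illustration
num_seqs, num_steps, num_classes, lstm_size = 32, 50, 64, 128

# Step 1: one-hot encoding of integer inputs, (32, 50) -> (32, 50, 64)
inputs = np.zeros((num_seqs, num_steps), dtype=int)
lstm_inputs = np.eye(num_classes)[inputs]

# Step 2: a dynamic_rnn-style output is a single tensor (32, 50, 128)
lstm_outputs = np.zeros((num_seqs, num_steps, lstm_size))

# Step 3: concatenating a single tensor along axis 1 is a no-op;
# the shape is exactly the same before and after
seq_output = np.concatenate([lstm_outputs], axis=1)
print(lstm_inputs.shape, seq_output.shape)  # (32, 50, 64) (32, 50, 128)
```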

@charmpeng

In my view, after embedding_lookup, self.lstm_inputs should have shape (num_seqs, num_steps, embedding_size); that is, each input token is represented by a vector of size embedding_size.

@Eddiechiu
Author

Right, your embedding_size corresponds to my num_classes.
But I still don't understand the tf.concat(lstm_outputs, 1) step, and the program runs fine if I skip it.

@charmpeng

charmpeng commented Jul 12, 2018

```python
with tf.name_scope('lstm'):
    cell = tf.nn.rnn_cell.MultiRNNCell(
        [get_a_cell(self.lstm_size, self.keep_prob) for _ in range(self.num_layers)]
    )
    self.initial_state = cell.zero_state(self.num_seqs, tf.float32)

    # unroll the cell over the time dimension with dynamic_rnn
    self.lstm_outputs, self.final_state = tf.nn.dynamic_rnn(cell, self.lstm_inputs, initial_state=self.initial_state)
    print("self.lstm_outputs.get_shape()", self.lstm_outputs.get_shape())  # (32, 50, 128)
    seq_output = tf.concat(self.lstm_outputs, 1)
    print("seq_output.get_shape()", seq_output.get_shape())  # (32, 50, 128)
    x = tf.reshape(seq_output, [-1, self.lstm_size])
    print("x.get_shape()", x.get_shape())  # (1600, 128)
```

I printed the shapes. It looks like the line seq_output = tf.concat(self.lstm_outputs, 1) serves no purpose here: lstm_outputs and seq_output have identical shapes before and after the concat.
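One plausible origin for the line (an assumption, not confirmed anywhere in this thread): code written for the older list-based RNN API, such as tf.nn.static_rnn, whose output is a Python list of per-step tensors of shape (num_seqs, lstm_size). In that case the concat does real work before the reshape. A numpy sketch of the list case:

```python
import numpy as np

num_seqs, num_steps, lstm_size = 32, 50, 128

# list-based RNN output: num_steps tensors, each of shape (num_seqs, lstm_size)
static_outputs = [np.zeros((num_seqs, lstm_size)) for _ in range(num_steps)]

# here the concat is required: it joins the steps into (32, 50 * 128)
seq_output = np.concatenate(static_outputs, axis=1)

# the reshape then yields the same (1600, 128) matrix as the dynamic_rnn path
x = seq_output.reshape(-1, lstm_size)
print(seq_output.shape, x.shape)  # (32, 6400) (1600, 128)
```

With dynamic_rnn the output is already a single tensor, so the concat degenerates into a no-op, which would explain the observations above.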

@Eddiechiu
Author

Yep, that's right, so I just reshape directly =D

@sunnima

sunnima commented Mar 19, 2019

This step is indeed useless.

@Natumsol

Indeed useless...

@cherish6092

tf.concat(values, 1)  # values must be a sequence; here it has no effect
