support recurrent with no states. #1113
I'll let someone more familiar with |
Ok, I'll write the test case. |
# Ensure that all initial states are available.
initial_states = brick.initial_states(batch_size, as_dict=True,
                                      *args, **kwargs)
for state_name in application.states:
It seems like the code starting from this line can be moved out of the `if` clause, and the `else` part is not really necessary. Right now we pay the high price of an additional level of indentation for this new feature, and it would be great to keep the complexity of the code down.
I suggest adding the following lines before it:
else:
    initial_states = OrderedDict()
In the original code, the |
You can reproduce the error with the following code. The error occurs when the class does not contain a recurrent method named
import numpy
import theano
from numpy.testing import assert_allclose
from theano import tensor
from blocks.bricks import Brick
from blocks.bricks.recurrent import BaseRecurrent, recurrent
# from recurrent import recurrent


class RecurrentWrapperNoStatesClass(BaseRecurrent):
    def __init__(self, dim, **kwargs):
        super(RecurrentWrapperNoStatesClass, self).__init__(**kwargs)
        self.dim = dim

    def get_dim(self, name):
        if name in ['inputs', 'outputs', 'outputs_2']:
            return self.dim
        if name == 'mask':
            return 0
        return super(RecurrentWrapperNoStatesClass, self).get_dim(name)

    # A recurrent application with sequences and outputs but no states.
    @recurrent(sequences=['inputs', 'mask'], states=[],
               outputs=['outputs', 'outputs_2'], contexts=[])
    def apply2(self, inputs, mask=None):
        outputs = inputs * 10
        outputs_2 = tensor.sqr(inputs)
        # Compare against None: a symbolic mask has no truth value.
        if mask is not None:
            outputs *= mask
            outputs_2 *= mask
        return outputs, outputs_2


if __name__ == '__main__':
    recurrent_examples = RecurrentWrapperNoStatesClass(
        dim=11, name='test_example')
    X = tensor.tensor3('X')
    out, out_2 = recurrent_examples.apply2(inputs=X, mask=None)
    x_val = numpy.random.uniform(size=(5, 1, 1))
    x_val = numpy.asarray(x_val, dtype=theano.config.floatX)
    out_eval = out.eval({X: x_val})
    out_2_eval = out_2.eval({X: x_val})
    assert_allclose(x_val * 10, out_eval)
    assert_allclose(numpy.square(x_val), out_2_eval)
                            state_name, brick.name))
    states_given = dict_subset(kwargs, application.states)
else:
    states_given = {}
If I remember right, it should be an `OrderedDict`.
Since `states_given` in the `else` clause is never used, it does not matter whether it is an `OrderedDict`, a `dict` or `None`.
@@ -104,7 +104,15 @@ def auxiliary_variables(self):
    @property
    def scan_variables(self):
        """Variables of Scan ops."""
        return list(chain(*[g.variables for g in self._scan_graphs]))
This code assumes that no recurrent class is nested. #1115
@@ -104,7 +104,15 @@ def auxiliary_variables(self):
    @property
    def scan_variables(self):
        """Variables of Scan ops."""
        return list(chain(*[g.variables for g in self._scan_graphs]))
        # BFS
        scan_graphs = self._scan_graphs
You probably want to copy `scan_graphs` here, e.g. `scan_graphs = list(self._scan_graphs)`.
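For illustration, here is a minimal sketch of why the copy matters in a breadth-first walk. It uses plain Python stand-ins rather than the real Blocks classes: the `Graph` class and the `collect_scan_variables` function are hypothetical, and only the `_scan_graphs` and `variables` attribute names are taken from the diff above.

```python
from collections import deque


class Graph(object):
    """Hypothetical stand-in for a graph that may hold nested scan graphs."""

    def __init__(self, variables, inner=()):
        self.variables = list(variables)
        self._scan_graphs = list(inner)


def collect_scan_variables(graph):
    """Collect variables of all nested scan graphs, breadth first.

    The work queue is built from a copy of ``graph._scan_graphs``, so
    extending the queue never mutates the list stored on the graph.
    """
    queue = deque(graph._scan_graphs)  # a copy, not an alias
    variables = []
    while queue:
        current = queue.popleft()
        variables.extend(current.variables)
        # Nested scans are appended to the queue only.
        queue.extend(current._scan_graphs)
    return variables


innermost = Graph(['c'])
middle = Graph(['a', 'b'], inner=[innermost])
top = Graph(['x'], inner=[middle, Graph(['d'])])
result = collect_scan_variables(top)
```

If the queue were the same list object as `_scan_graphs`, every nested graph found during the walk would be appended to the attribute itself, and a second call would see a longer list.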
@@ -46,6 +46,9 @@ def initial_states(self, batch_size, *args, **kwargs):
            The keyword arguments of the application call.

        """
        if not hasattr(self, 'apply') or not self.apply.states:
            return

Can you explain how it works? I cannot immediately see it.
When a subclass calls the default `initial_states` function in the `BaseRecurrent` class, this line checks whether it is necessary to return the initial states. If the subclass does not have an `apply` method, or its `apply` method does not declare any `states`, `initial_states` returns nothing.
This line makes it possible to support a `recurrent` class with no `apply` function or with no `states`.
Why do you want to have a class without `apply`? It's a mistake if a user forgot to define `apply`, and the best thing is to crash early.
If `apply.states` is empty, `initial_states` would return an empty list before this change. Why is that wrong?
If this line is added, the code above, which contains a recurrent brick with no `apply` method, runs fine.
But you are right about the `apply` method. A `Brick` subclass should follow some design rules. The problem is that no code currently checks whether a `Brick` subclass has an `apply` method.
@Beronx86, checking `apply.states` in `BaseRecurrent.initial_states` is not a solution. There are quite a few places in Blocks-dependent code where the `initial_states` method is overridden. Instead, as in your previous solution, `initial_states` should not be called if `application` does not have `states`. Can you please revert to the previous version of your fix?
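A minimal sketch of the call-site guard being asked for here, written with plain Python stand-ins. The function `gather_initial_states` and its `make_initial_states` argument are hypothetical: the latter stands in for a call such as `brick.initial_states(batch_size, as_dict=True, ...)`, and none of this is the actual Blocks implementation.

```python
from collections import OrderedDict


def gather_initial_states(state_names, given, make_initial_states):
    """Map each declared state name to its initial value.

    state_names: names declared by the application (may be empty).
    given: states explicitly passed in by the caller.
    make_initial_states: hypothetical callable returning a dict of defaults.
    """
    if not state_names:
        # No states declared: the brick's initial_states is never called,
        # so overridden versions of it need no special handling.
        return OrderedDict()
    defaults = make_initial_states()
    # Values supplied by the caller take precedence over the defaults.
    return OrderedDict(
        (name, given.get(name, defaults[name])) for name in state_names)


calls = []


def fake_initial_states():
    """Record each call so the guard can be observed."""
    calls.append(1)
    return {'h': 0, 'c': 1}


no_states = gather_initial_states([], {}, fake_initial_states)
calls_after_empty = len(calls)
with_states = gather_initial_states(['h', 'c'], {'c': 5}, fake_initial_states)
```

Keeping the guard at the call site means the check lives in one place, and the early return avoids a separate `else` branch and its extra level of indentation.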
@rizar I think this check could be carried out in the `Brick.__init__` method, so we can make sure all `Brick` subclasses contain `apply` methods. I reverted the changes in `BaseRecurrent`.
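A rough sketch of what such a constructor check might look like. The `BrickSketch` base class below is hypothetical and is not the real `blocks.bricks.Brick`; whether every brick should be required to define `apply` is exactly the design question under discussion.

```python
class BrickSketch(object):
    """Hypothetical base class that fails early when apply is missing."""

    def __init__(self, name=None):
        # Crash at construction time rather than at first use.
        if not callable(getattr(self, 'apply', None)):
            raise TypeError('{} must define an apply method'.format(
                type(self).__name__))
        self.name = name


class WithApply(BrickSketch):
    def apply(self, inputs):
        return inputs * 10


class WithoutApply(BrickSketch):
    # Only a differently named method, as in the reproduction code.
    def apply2(self, inputs):
        return inputs * 10


ok_brick = WithApply(name='ok')
try:
    WithoutApply(name='broken')
    raised = False
except TypeError:
    raised = True
```

With a check like this, the reproduction class that only defines `apply2` would fail at construction instead of failing later inside `initial_states`.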
I don't understand. Now you have removed your fix, and having no states property is again not supported. Why not just implement it like you did in the first place, but with gentler changes to the code, as I suggested?
The recurrent wrapper does not support a loop with no states, but this kind of loop may be useful, so I modified the code.
Fixes #1112