
Training on batches of GraphsTuples? #147

Open
robertswil opened this issue Nov 17, 2021 · 5 comments

@robertswil

robertswil commented Nov 17, 2021

Let's say I want to train an LSTM or transformer on sequences of graphs using Sonnet2/TF2:

I want to represent the graphs in each sequence as one GraphsTuple, which means my batches are essentially an iterable of GraphsTuples, each with a variable number of graphs. This works well until it's time to get the input signature and compile the update step: it's unclear to me how to define the TensorSpec for this type of input. Is my best route to subclass collections.namedtuple(), similar to how you define GraphsTuple, or can you suggest a more elegant solution?

Thanks

@alvarosg
Collaborator

Thanks for your message!

There are two options here and hopefully at least one of them would work for you:

  • Option 1: if the graphs change structure over time, define the input as a sequence of GraphsTuples, as you mention. You can then use this method to get a signature, as in the TF demo, but instead of getting a single signature, get one signature per GraphsTuple in the sequence and pass the resulting list as the input_signature to the tf.function. Because specs passed to tf.function can be arbitrarily nested, this should work.
  • Option 2: if all graphs in the sequence share the same structure (nodes and edges), you can give the node and edge features shape [total_num_nodes/edges, sequence_length, feature_size] and pass a single GraphsTuple as the input to your update function. Inside the update function you can then do something like this to build the graph at each step:
def update_fn(input_graph_sequence, ...):

  def loop_body(accumulator, step_i):
    # Slice out the node/edge features for this step; the graph
    # structure (senders/receivers) is shared across all steps.
    graph_step_i = input_graph_sequence.replace(
        nodes=input_graph_sequence.nodes[:, step_i],
        edges=input_graph_sequence.edges[:, step_i])
    ...
    return ...

  num_steps = input_graph_sequence.nodes.shape.as_list()[1]
  tf.scan(loop_body, tf.range(num_steps), ...)
  ...
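For Option 1, the nested-signature idea can be sketched with plain TensorSpecs. This is a minimal sketch with hypothetical shapes and a placeholder update; in a real graph_nets setup each element's spec would come from utils_tf.specs_from_graphs_tuple rather than a bare tf.TensorSpec:

```python
import tensorflow as tf

# Sketch only (hypothetical shapes): tf.function accepts an arbitrarily
# nested input_signature, so a list with one spec per element of the
# sequence works. With graph_nets, each element's spec would come from
# utils_tf.specs_from_graphs_tuple(graphs_tuple) instead of tf.TensorSpec.
SEQUENCE_LENGTH = 3

sequence_signature = [
    tf.TensorSpec(shape=[None, 4], dtype=tf.float32)
    for _ in range(SEQUENCE_LENGTH)
]

@tf.function(input_signature=[sequence_signature])
def update_step(sequence):
    # Placeholder update: sum the per-step features.
    return tf.add_n(sequence)

out = update_step([tf.ones([2, 4])] * SEQUENCE_LENGTH)
```

The same pattern extends to a list of GraphsTuple specs, since tf.function flattens nested structures of specs when tracing.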

Hope this helps!

@robertswil
Author

This worked. Thanks @alvarosg !

@robertswil
Author

Follow-on issue:

I am passing batches to the model during training like so:

outputs = tf.convert_to_tensor([model(graphs_tuple) for graphs_tuple in inputs])

As a reminder, each batch is an iterable of GraphsTuples, and each GraphsTuple represents a sequence of graphs for one training data point.

The GraphIndependent object of the encoder block (EncodeProcessDecode) throws the error: AttributeError: 'GraphsTuple' object has no attribute 'replace'. According to the stack trace, the error occurs here, in modules.GraphIndependent._build().

Any ideas on how to solve?

@alvarosg
Collaborator

Could you check the type of the object being passed to the model?

My guess is that the GraphsTuple input you are passing is not actually a graphs.GraphsTuple. A serialization library (or something like that) has probably transformed it into a namedtuple that looks the same but is not actually the same class, and so does not have the extra methods.

A simple fix to get the right type is to do:

graphs_tuple = graphs.GraphsTuple(*graphs_tuple)

before passing it to EncodeProcessDecode, but it may be good to understand where the type gets messed up.
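The failure mode and the fix can be illustrated with plain namedtuples. GraphsTupleWithMethods and the field list below are stand-ins for illustration, not the real graph_nets classes:

```python
from collections import namedtuple

# Stand-ins: graph_nets.graphs.GraphsTuple is a namedtuple subclass with
# extra methods such as replace(); a serialization round-trip can silently
# downgrade it to a plain namedtuple with the same fields but no methods.
FIELDS = ("nodes", "edges", "receivers", "senders", "globals", "n_node", "n_edge")

class GraphsTupleWithMethods(namedtuple("GraphsTuple", FIELDS)):
    def replace(self, **kwargs):
        return self._replace(**kwargs)

# What deserialization may hand back: same fields, same name, wrong class.
PlainGraphsTuple = namedtuple("GraphsTuple", FIELDS)

plain = PlainGraphsTuple(*range(7))
assert not hasattr(plain, "replace")  # the AttributeError in the stack trace

# Re-wrapping restores the subclass, analogous to graphs.GraphsTuple(*graphs_tuple):
fixed = GraphsTupleWithMethods(*plain)
assert hasattr(fixed, "replace")
```

The unpacking works because both classes share the same field order, so `*` splats the plain tuple's values straight into the real constructor.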

@robertswil
Author

Your hunch was correct! It was being transformed into collections.GraphsTuple during the list comprehension:

outputs = tf.convert_to_tensor([model(inputs) for inputs in inputs_train])

Is that^^ the preferred way to feed a batch of graph sequences during the update step? I wonder if it isn't, since when I run graphs.GraphsTuple(*inputs) during training, I receive the following error:

OperatorNotAllowedInGraphError: iterating over 'tf.Tensor' is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.

Checking a bit more: when this error is thrown, type(inputs) is a symbolic tensor, not a GraphsTuple, which I guess means this happens before the backend actually runs the first batch through the model.
