Question about padding for transformer #7

Open

agave233 opened this issue Jun 24, 2022 · 1 comment

Hi,
Thanks for your excellent work.
I find the batch-padding implementation in your code confusing. The code for padding a batch is:

def pad_batch(h_node, batch, max_input_len, get_mask=False):
    # batch[i] is the graph index of node i, so the last entry + 1 is the batch size
    num_batch = batch[-1] + 1
    num_nodes = []
    masks = []
    for i in range(num_batch):
        mask = batch.eq(i)  # boolean mask selecting the nodes of graph i
        masks.append(mask)
        num_node = mask.sum()
        num_nodes.append(num_node)

    # logger.info(max(num_nodes))
    max_num_nodes = min(max(num_nodes), max_input_len)
    padded_h_node = h_node.data.new(max_num_nodes, num_batch, h_node.size(-1)).fill_(0)  # (S, B, h_d)
    src_padding_mask = h_node.data.new(num_batch, max_num_nodes).fill_(0).bool()  # (B, S)

    for i, mask in enumerate(masks):
        num_node = num_nodes[i]
        if num_node > max_num_nodes:
            num_node = max_num_nodes  # truncate graphs longer than max_input_len
        # graph nodes are right-aligned in the sequence; padding sits at the front
        padded_h_node[-num_node:, i] = h_node[mask][-num_node:]
        src_padding_mask[i, : max_num_nodes - num_node] = True  # [b, s]

    if get_mask:
        return padded_h_node, src_padding_mask, num_nodes, masks, max_num_nodes
    return padded_h_node, src_padding_mask

I think the masking line "src_padding_mask[i, : max_num_nodes - num_node] = True" should instead be:

src_padding_mask = h_node.data.new(num_batch, max_num_nodes).fill_(1).bool()
src_padding_mask[i, : max_num_nodes - num_node] = False

Otherwise, in the pooling part, the original mask can make the denominator of this line zero: src_padding_mask.sum(-1) counts the True (padding) positions, and the largest graph in a batch has no padding positions at all:

h_graph = transformer_out.sum(0) / src_padding_mask.sum(-1, keepdim=True)
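
For illustration, a minimal sketch (reusing the pad_batch above; the toy graph sizes are made up) of how the largest graph in a batch ends up with a zero denominator:

import torch

# Two toy graphs: graph 0 has 3 nodes, graph 1 has 2 nodes.
batch = torch.tensor([0, 0, 0, 1, 1])
h_node = torch.randn(5, 8)  # 5 nodes, hidden size 8

padded_h_node, src_padding_mask = pad_batch(h_node, batch, max_input_len=10)

# True marks padding slots, so summing the mask counts padding, not nodes.
print(src_padding_mask.sum(-1, keepdim=True))  # tensor([[0], [1]]) -> graph 0 divides by zero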
LUOyk1999 commented Oct 9, 2022

There is nothing wrong with the author's code:
src_padding_mask = h_node.data.new(num_batch, max_num_nodes).fill_(0).bool()
src_padding_mask[i, : max_num_nodes - num_node] = True

Notice that padded_h_node has shape (S, B, h_d), where S is split into padding positions followed by graph-node positions (the graph nodes sit at the back of the sequence because that makes it easy to append a 'cls' token).
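
For reference, this matches the convention of PyTorch's src_key_padding_mask, where True marks positions that attention should ignore. A minimal sketch with made-up sizes:

import torch
import torch.nn as nn

encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=8, nhead=2), num_layers=1
)

src = torch.randn(3, 2, 8)  # (S, B, h_d), nodes right-aligned as in pad_batch
# (B, S): True = padded slot at the front, ignored by the attention
pad_mask = torch.tensor([[False, False, False],
                         [True, False, False]])
out = encoder(src, src_key_padding_mask=pad_mask)  # (3, 2, 8)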
