Graph Transformer Enhancement #9751

Open
xnuohz opened this issue Oct 30, 2024 · 5 comments

@xnuohz
Contributor

xnuohz commented Oct 30, 2024

🚀 The feature, motivation and pitch

  1. Exphormer: Sparse Transformers for Graphs
  2. SGFormer: Simplifying and Empowering Transformers for Large-Graph Representations
  3. Polynormer: Polynomial-Expressive Graph Transformer in Linear Time
  4. Gradformer: Graph Transformer with Exponential Decay
  5. CoBFormer

Alternatives

No response

Additional context

No response

xnuohz added the feature label Oct 30, 2024
@phoeenniixx

Hi @xnuohz, I tried implementing Exphormer. Could you tell me whether it needs any changes, and whether I can raise a PR for it?

import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax
from typing import Optional, Tuple, Union

class ExphormerAttention(MessagePassing):
    """Multi-head attention restricted to the edges of the interaction graph."""
    def __init__(
        self, 
        hidden_dim: int,
        num_heads: int,
        dropout: float = 0.1,
        edge_dim: Optional[int] = None
    ):
        # node_dim=0 so that [num_nodes, num_heads, head_dim] tensors are indexed by node
        super().__init__(aggr='add', node_dim=0)

        if hidden_dim % num_heads != 0:
            raise ValueError("hidden_dim must be divisible by num_heads")

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)
        
        self.edge_proj = nn.Linear(edge_dim, hidden_dim) if edge_dim is not None else None
        
        self.dropout = nn.Dropout(dropout)
        self.attention_weights = None  # Set by message() on every forward pass

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor] = None,
        return_attention_weights: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:

        # Project and split into heads: [num_nodes, num_heads, head_dim]
        q = self.q_proj(x).view(-1, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(-1, self.num_heads, self.head_dim)
        v = self.v_proj(x).view(-1, self.num_heads, self.head_dim)

        edge_features = None
        if edge_attr is not None and self.edge_proj is not None:
            edge_features = self.edge_proj(edge_attr).view(-1, self.num_heads, self.head_dim)

        # Propagate messages. q, k and v are passed as separate keyword
        # arguments so that message() receives q_i, k_j and v_j.
        out = self.propagate(
            edge_index,
            q=q,
            k=k,
            v=v,
            edge_attr=edge_features,
            size=None
        )

        # Merge heads back: [num_nodes, hidden_dim]
        out = out.view(-1, self.hidden_dim)
        out = self.o_proj(out)
        
        if return_attention_weights:
            return out, self.attention_weights
        return out

    def message(
        self,
        q_i: torch.Tensor,
        k_j: torch.Tensor,
        v_j: torch.Tensor,
        edge_attr: Optional[torch.Tensor],
        index: torch.Tensor,
        ptr: Optional[torch.Tensor],
        size_i: Optional[int]
    ) -> torch.Tensor:

        # Scaled dot-product score per edge and head: [num_edges, num_heads]
        alpha = (q_i * k_j).sum(dim=-1) * self.scale

        if edge_attr is not None:
            alpha = alpha + (q_i * edge_attr).sum(dim=-1)

        # Normalize over the incoming edges of each target node
        alpha = softmax(alpha, index, ptr, size_i)
        self.attention_weights = alpha  # Store for optional return

        alpha = self.dropout(alpha)

        return v_j * alpha.unsqueeze(-1)

class EXPHORMER(nn.Module):
    def __init__(
        self,
        hidden_dim: int,
        num_heads: int = 8,
        num_layers: int = 3,
        dropout: float = 0.1,
        num_virtual_nodes: int = 1,
        expander_degree: int = 4,
        use_expander: bool = True,
        use_global: bool = True,
        edge_dim: Optional[int] = None
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.use_expander = use_expander
        self.use_global = use_global
        self.num_virtual_nodes = num_virtual_nodes
        self.expander_degree = expander_degree

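        # Create attention layers. build_interaction_graph() returns the
        # edge-type embedding (hidden_dim), concatenated with the raw edge
        # features (edge_dim) when they are given, so the layers must project
        # from that combined width.
        attn_edge_dim = hidden_dim + (edge_dim if edge_dim is not None else 0)
        self.layers = nn.ModuleList([
            ExphormerAttention(
                hidden_dim=hidden_dim,
                num_heads=num_heads,
                dropout=dropout,
                edge_dim=attn_edge_dim
            ) for _ in range(num_layers)
        ])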
        
        # Virtual node embedding
        if use_global:
            self.virtual_node_embedding = nn.Parameter(
                torch.randn(num_virtual_nodes, hidden_dim)
            )
        
        # Edge type embeddings
        self.edge_type_embeddings = nn.Parameter(torch.randn(3, hidden_dim))  # local, expander, global
        
        # Layer norm and dropout
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(hidden_dim) for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout)

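    def generate_expander_edges(self, num_nodes: int) -> torch.Tensor:
        """Generate random expander graph edges from random permutations.

        The result may contain self-loops and duplicate edges.
        """
        device = self.edge_type_embeddings.device
        idx = torch.arange(num_nodes, device=device)
        edges = []
        for _ in range(self.expander_degree // 2):
            perm = torch.randperm(num_nodes, device=device)
            # Add both directions so the expander edges are symmetric
            edges.append(torch.stack([idx, perm]))
            edges.append(torch.stack([perm, idx]))

        if not edges:  # expander_degree < 2
            return torch.empty((2, 0), dtype=torch.long, device=device)
        return torch.cat(edges, dim=1)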

    def build_interaction_graph(
        self,
        edge_index: torch.Tensor,
        num_nodes: int,
        edge_attr: Optional[torch.Tensor] = None,
        batch: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build the complete interaction graph with all components."""
        # Create every new tensor on the device of the input graph
        device = edge_index.device
        edge_indices = [edge_index]  # Start with local edges
        edge_types = [torch.zeros(edge_index.size(1), dtype=torch.long, device=device)]
        
        # Add expander edges
        if self.use_expander:
            expander_edges = self.generate_expander_edges(num_nodes).to(device)
            edge_indices.append(expander_edges)
            edge_types.append(torch.ones(expander_edges.size(1), dtype=torch.long, device=device))
        
        # Add global attention edges
        if self.use_global:
            virtual_node_indices = []
            dst = torch.arange(num_nodes, dtype=torch.long, device=device)
            for v_idx in range(self.num_virtual_nodes):
                v_node = num_nodes + v_idx
                # Connect virtual node to all other nodes, in both directions
                src = torch.full((num_nodes,), v_node, dtype=torch.long, device=device)
                virtual_node_indices.extend([
                    torch.stack([src, dst]),
                    torch.stack([dst, src])
                ])
            
            if virtual_node_indices:
                virtual_edges = torch.cat(virtual_node_indices, dim=1)
                edge_indices.append(virtual_edges)
                edge_types.append(torch.full((virtual_edges.size(1),), 2, dtype=torch.long, device=device))
        
        # Combine all edges
        combined_edges = torch.cat(edge_indices, dim=1)
        combined_types = torch.cat(edge_types)
        
        # Create edge features from type embeddings: [num_edges, hidden_dim]
        edge_features = self.edge_type_embeddings[combined_types]
        
        # Combine with input edge features if they exist; expander and
        # virtual-node edges have no input features and are zero-padded
        if edge_attr is not None:
            num_local_edges = edge_index.size(1)
            padding = torch.zeros(
                combined_edges.size(1) - num_local_edges,
                edge_attr.size(1),
                dtype=edge_attr.dtype,
                device=edge_attr.device
            )
            edge_attr = torch.cat([edge_attr, padding])
            edge_features = torch.cat([edge_features, edge_attr], dim=-1)
        
        return combined_edges, edge_features

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor] = None,
        batch: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Forward pass of EXPHORMER.
        
        Args:
            x: Node features [num_nodes, hidden_dim]
            edge_index: Graph connectivity [2, num_edges]
            edge_attr: Edge features [num_edges, edge_dim]
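            batch: Batch assignment for nodes [num_nodes]. Currently not used
                when building expander and virtual-node edges, so these edges
                are not restricted to a single graph of a batch.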
        """
        num_nodes = x.size(0)
        
        # Add virtual nodes if using global attention
        if self.use_global:
            x = torch.cat([x, self.virtual_node_embedding], dim=0)
            if batch is not None:
                batch = torch.cat([
                    batch,
                    torch.zeros(self.num_virtual_nodes, dtype=torch.long, device=batch.device)
                ])
        
        # Build interaction graph
        interaction_edges, edge_features = self.build_interaction_graph(
            edge_index, num_nodes, edge_attr, batch
        )
        
        # Process layers
        for layer, layer_norm in zip(self.layers, self.layer_norms):
            # Attention layer
            out = layer(x, interaction_edges, edge_features)
            out = self.dropout(out)

            x = layer_norm(out + x)
        
        # Remove virtual nodes from output if they were added
        if self.use_global:
            x = x[:num_nodes]
        
        return x
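
A note on the attention normalization used in `message` above: the scores are passed through a softmax taken separately over the incoming edges of each target node, not over all edges at once. The following is a minimal, framework-free sketch of that grouped softmax. The function name `grouped_softmax` and the plain-list inputs are assumptions made for illustration only; this is not the PyG implementation.

```python
import math
from collections import defaultdict


def grouped_softmax(scores, targets):
    """Softmax over `scores`, normalized separately for each target node.

    scores[e] is the raw attention score of edge e and targets[e] is the
    node that edge e points to.
    """
    # Subtract the per-target maximum for numerical stability.
    max_per_target = defaultdict(lambda: -math.inf)
    for s, t in zip(scores, targets):
        max_per_target[t] = max(max_per_target[t], s)

    exps = [math.exp(s - max_per_target[t]) for s, t in zip(scores, targets)]

    # Sum the exponentials of all edges that share a target node.
    sum_per_target = defaultdict(float)
    for e, t in zip(exps, targets):
        sum_per_target[t] += e

    return [e / sum_per_target[t] for e, t in zip(exps, targets)]


# Three edges point at node 0 and one edge points at node 1.
weights = grouped_softmax([1.0, 1.0, 1.0, 5.0], [0, 0, 0, 1])
```

The weights of the edges that share a target node sum to one, which is why the layer can aggregate the weighted values with a plain sum (`aggr='add'`).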

@xnuohz
Contributor Author

xnuohz commented Nov 12, 2024

Thanks @phoeenniixx!
If you'd like to integrate Exphormer into PyG, I think you may need to do the following:

  • Add nn.attention.exphormer
  • Add nn.conv.exphormer_conv (or support ExphormerAttention directly in nn.conv.gps_conv; I am not sure whether this is feasible, so it needs your confirmation ^^)
  • Add unit tests and make sure they pass the CI
  • Add an example

cc @rusty1s

@phoeenniixx

Sorry, I am new to PyG 😅, so I have a few questions:

  • Do you want me to put the ExphormerAttention class in nn.attention.exphormer?
  • And the rest of the EXPHORMER class in nn.conv.exphormer_conv?
  • I think we could also break up the EXPHORMER class. Right now only the attention layer is a separate module; we could split the expander edges, global attention, and local neighborhood attention into their own modules and then create a main "parent" class that brings all the components together. (Although I am not sure whether splitting it into so many parts is useful.)
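
To make the proposed split more concrete, here is a rough, framework-free sketch in which each edge set is built by its own function and a small "parent" function only combines them and tags every edge with its type. All names (`virtual_node_edges`, `combine_edge_sets`) are hypothetical, and plain Python lists stand in for tensors; the edge type ids follow the local, expander, global ordering used in the code posted earlier in this thread.

```python
from typing import List, Tuple

Edge = Tuple[int, int]

# Edge type ids: local, expander, global.
LOCAL, EXPANDER, GLOBAL = 0, 1, 2


def virtual_node_edges(num_nodes: int, num_virtual_nodes: int = 1) -> List[Edge]:
    """Connect every virtual node to every real node, in both directions."""
    edges: List[Edge] = []
    for v in range(num_nodes, num_nodes + num_virtual_nodes):
        for n in range(num_nodes):
            edges.append((v, n))
            edges.append((n, v))
    return edges


def combine_edge_sets(
    local: List[Edge], expander: List[Edge], global_: List[Edge]
) -> Tuple[List[Edge], List[int]]:
    """Concatenate the three edge sets and record the type of every edge."""
    edges: List[Edge] = []
    types: List[int] = []
    for type_id, edge_set in ((LOCAL, local), (EXPANDER, expander), (GLOBAL, global_)):
        edges.extend(edge_set)
        types.extend([type_id] * len(edge_set))
    return edges, types


# A 3-node path graph, one hand-picked expander edge pair, one virtual node.
local = [(0, 1), (1, 0), (1, 2), (2, 1)]
expander = [(0, 2), (2, 0)]
edges, types = combine_edge_sets(local, expander, virtual_node_edges(3))
```

With this layout each component can be tested on its own, and the parent only has to know the order in which the edge sets are concatenated.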

@xnuohz
Contributor Author

xnuohz commented Nov 12, 2024

It sounds like the expander edges can be implemented as a utility in utils that generates the expanded graph. For local and global attention, you can refer to nn.conv.gcn_conv and transforms.VirtualNode.
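
As an illustration of what such a utility might look like, here is a minimal pure-Python sketch. The name `random_expander_edges` and its signature are assumptions, not an existing PyG function. It takes the union of random permutations, as in the code posted earlier in this thread, and additionally drops self-loops and duplicate edges. It does not verify any expansion property of the result, and node degrees can end up below `degree`.

```python
import random
from typing import List, Optional, Tuple


def random_expander_edges(
    num_nodes: int, degree: int, seed: Optional[int] = None
) -> List[Tuple[int, int]]:
    """Return symmetric edges built from `degree // 2` random permutations."""
    rng = random.Random(seed)
    edges = set()
    for _ in range(degree // 2):
        perm = list(range(num_nodes))
        rng.shuffle(perm)
        for i, j in enumerate(perm):
            if i != j:  # skip self-loops
                # Store both directions so the edge set stays symmetric.
                edges.add((i, j))
                edges.add((j, i))
    return sorted(edges)


example_edges = random_expander_edges(num_nodes=6, degree=4, seed=0)
```

A tensor-based version would return a `[2, num_edges]` index tensor instead of a list of pairs, and a fixed seed makes the generated edges reproducible in unit tests.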

@phoeenniixx

Thanks! I'll try to raise a PR in a few days. You can point out any changes you think I should make there. It is my first time contributing here, so please help me through the process :)
