-
Notifications
You must be signed in to change notification settings - Fork 3.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Graph Transformer Enhancement #9751
Comments
Hi @xnuohz, I tried implementing Exphormer. Could you please tell me whether I need to make any changes to it, and whether I can raise a PR for it?
import torch
import torch.nn as nn
import torch_geometric.nn as geom_nn
from torch_geometric.data import Data
from torch_geometric.utils import add_self_loops, degree
from torch_geometric.nn import MessagePassing
from typing import Optional, Tuple
import numpy as np
import random
class ExphormerAttention(MessagePassing):
    """Multi-head scaled dot-product attention restricted to a given edge set.

    For every edge ``j -> i`` in ``edge_index`` and every head, the raw score
    is ``<q_i, k_j> * head_dim ** -0.5``, plus ``<q_i, e_ij>`` when a projected
    edge feature ``e_ij`` is available. Scores are softmax-normalised over the
    incoming edges of each target node ``i``. Dropout is applied to the
    coefficients, and the weighted values ``v_j`` are summed at ``i``. The
    concatenated heads then pass through a final linear projection.

    Args:
        hidden_dim: Node feature size (input and output). Must be divisible
            by ``num_heads``.
        num_heads: Number of attention heads (must be positive).
        dropout: Dropout probability applied to the attention coefficients.
        edge_dim: Width of the edge features given to :meth:`forward`. When
            ``None``, no edge projection is created and any ``edge_attr``
            passed to :meth:`forward` is ignored.

    Raises:
        ValueError: If ``num_heads`` is not positive or ``hidden_dim`` is not
            divisible by ``num_heads``.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_heads: int,
        dropout: float = 0.1,
        edge_dim: Optional[int] = None,
    ):
        # node_dim=0: node tensors are laid out as [num_nodes, heads, head_dim].
        super().__init__(aggr='add', node_dim=0)
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if hidden_dim % num_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)
        self.edge_proj = (
            nn.Linear(edge_dim, hidden_dim) if edge_dim is not None else None
        )
        self.dropout = nn.Dropout(dropout)
        # Softmax-normalised coefficients (before dropout) from the most
        # recent forward pass, shape [num_edges, num_heads].
        self.attention_weights: Optional[torch.Tensor] = None

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor] = None,
        return_attention_weights: bool = False,
    ):
        """Apply attention along ``edge_index``.

        Args:
            x: Node features of shape ``[num_nodes, hidden_dim]``.
            edge_index: Edge list of shape ``[2, num_edges]``.
            edge_attr: Optional edge features of shape
                ``[num_edges, edge_dim]``. Ignored when the layer was built
                with ``edge_dim=None``.
            return_attention_weights: If ``True``, also return the attention
                coefficients.

        Returns:
            Updated node features of shape ``[num_nodes, hidden_dim]``.
            When ``return_attention_weights`` is set, returns a tuple
            ``(out, attention_weights)`` instead, with coefficients of shape
            ``[num_edges, num_heads]``.
        """
        q = self.q_proj(x).view(-1, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(-1, self.num_heads, self.head_dim)
        v = self.v_proj(x).view(-1, self.num_heads, self.head_dim)

        edge_features = None
        if edge_attr is not None and self.edge_proj is not None:
            edge_features = self.edge_proj(edge_attr).view(
                -1, self.num_heads, self.head_dim
            )

        # q/k/v must be separate keyword arguments: PyG maps them onto the
        # ``q_i``, ``k_j`` and ``v_j`` parameters of ``message``.
        out = self.propagate(
            edge_index, q=q, k=k, v=v, edge_attr=edge_features, size=None
        )
        out = self.o_proj(out.view(-1, self.hidden_dim))

        if return_attention_weights:
            return out, self.attention_weights
        return out

    def message(
        self,
        q_i: torch.Tensor,
        k_j: torch.Tensor,
        v_j: torch.Tensor,
        edge_attr: Optional[torch.Tensor],
        index: torch.Tensor,
        ptr: Optional[torch.Tensor],
        size_i: Optional[int],
    ) -> torch.Tensor:
        """Return the attention-weighted value message for every edge."""
        # softmax lives in torch_geometric.utils (there is no
        # torch_geometric.nn.utils).
        from torch_geometric.utils import softmax

        alpha = (q_i * k_j).sum(dim=-1) * self.scale  # [num_edges, num_heads]
        if edge_attr is not None:
            # NOTE(review): unlike the q.k term this bias is not multiplied by
            # self.scale - confirm that is intended.
            alpha = alpha + (q_i * edge_attr).sum(dim=-1)
        # Normalise over the incoming edges of each target node.
        alpha = softmax(alpha, index, ptr, size_i)
        self.attention_weights = alpha  # exposed via return_attention_weights
        alpha = self.dropout(alpha)
        return v_j * alpha.unsqueeze(-1)
class EXPHORMER(nn.Module):
    """Exphormer-style graph transformer.

    Every layer runs :class:`ExphormerAttention` over one interaction graph.
    That graph has up to three kinds of edges, each tagged with a learned
    type embedding:

    * type 0 - the input (local) edges,
    * type 1 - random expander edges (when ``use_expander``),
    * type 2 - edges in both directions between every node and each of
      ``num_virtual_nodes`` learned virtual nodes (when ``use_global``).

    Each layer applies attention, dropout, a residual connection and
    LayerNorm, in that order.

    Args:
        hidden_dim: Node feature size; the input ``x`` must already have
            this width (there is no input projection).
        num_heads: Attention heads per layer.
        num_layers: Number of attention layers.
        dropout: Dropout probability for attention coefficients and layer
            outputs.
        num_virtual_nodes: Number of virtual (global) nodes.
        expander_degree: ``expander_degree // 2`` random permutations are
            drawn; each adds edges ``i -> perm[i]`` and ``perm[i] -> i``.
        use_expander: Whether to add expander edges.
        use_global: Whether to add virtual nodes.
        edge_dim: Width of the optional input ``edge_attr``.

    NOTE(review): expander and virtual-node edges ignore ``batch``. In a
    mini-batch of several graphs they therefore connect nodes of different
    graphs. Expander edges are also re-sampled on every forward call,
    including in eval mode, so outputs are not deterministic.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_heads: int = 8,
        num_layers: int = 3,
        dropout: float = 0.1,
        num_virtual_nodes: int = 1,
        expander_degree: int = 4,
        use_expander: bool = True,
        use_global: bool = True,
        edge_dim: Optional[int] = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.use_expander = use_expander
        self.use_global = use_global
        self.num_virtual_nodes = num_virtual_nodes
        self.expander_degree = expander_degree
        self.edge_dim = edge_dim

        # build_interaction_graph always emits a hidden_dim-wide edge-type
        # embedding, concatenated with edge_dim extra columns when edge_dim
        # is set. The attention layers must project that combined width.
        layer_edge_dim = hidden_dim + (edge_dim if edge_dim is not None else 0)
        self.layers = nn.ModuleList([
            ExphormerAttention(
                hidden_dim=hidden_dim,
                num_heads=num_heads,
                dropout=dropout,
                edge_dim=layer_edge_dim,
            )
            for _ in range(num_layers)
        ])

        # Learned features of the virtual nodes appended after the real nodes.
        if use_global:
            self.virtual_node_embedding = nn.Parameter(
                torch.randn(num_virtual_nodes, hidden_dim)
            )

        # Rows: 0 = local, 1 = expander, 2 = global (virtual-node) edges.
        self.edge_type_embeddings = nn.Parameter(torch.randn(3, hidden_dim))

        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(hidden_dim) for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout)

    def generate_expander_edges(
        self, num_nodes: int, device: Optional[torch.device] = None
    ) -> torch.Tensor:
        """Generate random expander edges among ``num_nodes`` nodes.

        Draws ``expander_degree // 2`` random permutations. For each one it
        adds ``i -> perm[i]`` for all ``i``, followed by the reversed edges.

        Args:
            num_nodes: Number of nodes to connect.
            device: Device for the returned tensor (default: CPU).

        Returns:
            A long tensor of shape ``[2, 2 * num_nodes * (expander_degree // 2)]``.
            It is ``[2, 0]`` when no edges are generated.
        """
        num_rounds = self.expander_degree // 2
        if num_nodes <= 0 or num_rounds <= 0:
            return torch.empty((2, 0), dtype=torch.long, device=device)

        nodes = torch.arange(num_nodes, dtype=torch.long, device=device)
        blocks = []
        for _ in range(num_rounds):
            perm = torch.randperm(num_nodes, device=device)
            blocks.append(torch.stack([nodes, perm]))
            blocks.append(torch.stack([perm, nodes]))
        return torch.cat(blocks, dim=1)

    def build_interaction_graph(
        self,
        edge_index: torch.Tensor,
        num_nodes: int,
        edge_attr: Optional[torch.Tensor] = None,
        batch: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Assemble the edge list and per-edge features shared by all layers.

        Args:
            edge_index: Local edges of shape ``[2, num_local_edges]``.
            num_nodes: Number of real nodes. Virtual nodes take the indices
                ``num_nodes .. num_nodes + num_virtual_nodes - 1``.
            edge_attr: Optional features of shape
                ``[num_local_edges, edge_dim]`` for the local edges. Other
                edges (and all edges when this is ``None``) get zeros in
                those columns.
            batch: Currently unused.

        Returns:
            A tuple ``(edges, features)``:

            * ``edges`` has shape ``[2, total_edges]``, ordered local, then
              expander, then virtual.
            * ``features`` has shape ``[total_edges, hidden_dim]``, or
              ``[total_edges, hidden_dim + edge_dim]`` when ``edge_dim`` is set.

        Raises:
            ValueError: If ``edge_attr`` is given but the model was built with
                ``edge_dim=None``, or if its shape does not match.
        """
        # Keep every generated tensor on the device of the input graph.
        device = edge_index.device
        edge_indices = [edge_index]
        edge_types = [
            torch.zeros(edge_index.size(1), dtype=torch.long, device=device)
        ]

        if self.use_expander:
            expander_edges = self.generate_expander_edges(num_nodes, device=device)
            edge_indices.append(expander_edges)
            edge_types.append(
                torch.ones(expander_edges.size(1), dtype=torch.long, device=device)
            )

        if self.use_global and self.num_virtual_nodes > 0:
            nodes = torch.arange(num_nodes, dtype=torch.long, device=device)
            virtual_blocks = []
            for v_idx in range(self.num_virtual_nodes):
                hub = torch.full_like(nodes, num_nodes + v_idx)
                virtual_blocks.append(torch.stack([hub, nodes]))
                virtual_blocks.append(torch.stack([nodes, hub]))
            virtual_edges = torch.cat(virtual_blocks, dim=1)
            edge_indices.append(virtual_edges)
            edge_types.append(
                torch.full(
                    (virtual_edges.size(1),), 2, dtype=torch.long, device=device
                )
            )

        combined_edges = torch.cat(edge_indices, dim=1)
        combined_types = torch.cat(edge_types)
        edge_features = self.edge_type_embeddings[combined_types]

        if self.edge_dim is None:
            if edge_attr is not None:
                raise ValueError(
                    "edge_attr was given but the model was built with edge_dim=None"
                )
            return combined_edges, edge_features

        num_local_edges = edge_index.size(1)
        if edge_attr is None:
            local_attr = edge_features.new_zeros(num_local_edges, self.edge_dim)
        else:
            if (
                edge_attr.dim() != 2
                or edge_attr.size(0) != num_local_edges
                or edge_attr.size(1) != self.edge_dim
            ):
                raise ValueError(
                    f"expected edge_attr of shape [{num_local_edges}, "
                    f"{self.edge_dim}], got {list(edge_attr.shape)}"
                )
            local_attr = edge_attr.to(edge_features.dtype)
        # Expander and virtual edges have no input features: pad with zeros.
        other_attr = edge_features.new_zeros(
            combined_edges.size(1) - num_local_edges, self.edge_dim
        )
        attr = torch.cat([local_attr, other_attr], dim=0)
        return combined_edges, torch.cat([edge_features, attr], dim=-1)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor] = None,
        batch: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run all Exphormer layers.

        Args:
            x: Node features of shape ``[num_nodes, hidden_dim]``.
            edge_index: Graph connectivity of shape ``[2, num_edges]``.
            edge_attr: Optional edge features of shape
                ``[num_edges, edge_dim]``.
            batch: Optional batch assignment of shape ``[num_nodes]``.
                Virtual nodes are appended to it with batch id 0, but it is
                not otherwise used.

        Returns:
            Updated features of the real nodes, shape
            ``[num_nodes, hidden_dim]``.
        """
        num_nodes = x.size(0)

        # Append the virtual nodes after the real nodes.
        if self.use_global:
            x = torch.cat([x, self.virtual_node_embedding], dim=0)
            if batch is not None:
                batch = torch.cat([
                    batch,
                    torch.zeros(
                        self.num_virtual_nodes, dtype=torch.long, device=batch.device
                    ),
                ])

        interaction_edges, edge_features = self.build_interaction_graph(
            edge_index, num_nodes, edge_attr, batch
        )

        # Attention -> dropout -> residual -> LayerNorm, per layer.
        for layer, layer_norm in zip(self.layers, self.layer_norms):
            out = layer(x, interaction_edges, edge_features)
            out = self.dropout(out)
            x = layer_norm(out + x)

        # Drop the virtual nodes from the output.
        if self.use_global:
            x = x[:num_nodes]
        return x
Thanks @phoeenniixx
cc @rusty1s
Sorry, I am new to PyG 😅 and have a few questions:
|
Sounds like |
Thanks! I'll try to raise a PR in a few days. There, you can point out any changes you think I should make. It's my first time contributing here, so please help me through the process :)
🚀 The feature, motivation and pitch
Alternatives
No response
Additional context
No response
The text was updated successfully, but these errors were encountered: