[WIP] 58 add goggle #162
Merged
Changes from 12 commits (30 commits total)

Commits
c4c22d6  First goggle implementation, no cond (robsdavis)
ac05399  cleanup (robsdavis)
13944e7  Adding tests and some debugging (robsdavis)
ce78c8b  Debugging (robsdavis)
9010944  debugging (robsdavis)
301acb1  debugging (robsdavis)
9492bad  Debugging (robsdavis)
fe4ff06  Debugging (robsdavis)
afb57bd  Debugging (robsdavis)
33e043d  Enabled reproducible results (robsdavis)
86c4bd0  getting tests to pass (robsdavis)
0b6687e  Clean up (robsdavis)
e43b769  Added more argument validation (robsdavis)
1181db7  Add docstrings (robsdavis)
2a6b128  remove cond test (robsdavis)
bf46be4  Expose node_dim as parameter. (robsdavis)
973fac7  default to cpu device (robsdavis)
60eff8c  revert to previous device defaults (robsdavis)
06d7270  Added dependencies to setup.cfg (robsdavis)
459772a  Added download of torch_stable to prereq.txt (robsdavis)
32fc3d3  move goggle reqs to extra (robsdavis)
63817d0  Merge branch 'main' into 58-add-goggle (robsdavis)
06c526b  Fixing dependencies (robsdavis)
ac2319c  Swap airfoil dataset for diabetes in goggle tests (robsdavis)
c8e6c54  Fixing dependencies (robsdavis)
d61c32d  Handle failed plugin loading due to uninstalled extra (robsdavis)
b608658  Debugging dependencies (robsdavis)
f08198a  Debugging (robsdavis)
2d67dc3  debug goggle (#167) (bcebere)
bdcde45  Update README.md (bcebere)
@@ -0,0 +1,297 @@
# stdlib
from typing import Any, Optional, Tuple, Union

# third party
import torch
from torch import Tensor
from torch.nn import Parameter
from torch.nn import Parameter as Param
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.typing import Adj, OptTensor
from torch_sparse import SparseTensor, masked_select_nnz, matmul

try:
    # third party
    from pyg_lib.ops import segment_matmul  # C implemented method

    _WITH_PYG_LIB = True
except ImportError:
    _WITH_PYG_LIB = False

    def segment_matmul(inputs: Tensor, ptr: Tensor, other: Tensor) -> Tensor:
        raise NotImplementedError


def masked_edge_index(
    edge_index: Tensor, edge_mask: Tensor
) -> Union[SparseTensor, Tensor]:
    # Keep only the edges that belong to the currently processed relation.
    if isinstance(edge_index, Tensor):
        return edge_index[:, edge_mask]
    else:
        return masked_select_nnz(edge_index, edge_mask, layout="coo")


class RGCNConv(MessagePassing):
    r"""The relational graph convolutional operator from the `"Modeling
    Relational Data with Graph Convolutional Networks"
    <https://arxiv.org/abs/1703.06103>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{\Theta}_{\textrm{root}} \cdot
        \mathbf{x}_i + \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_r(i)}
        \frac{1}{|\mathcal{N}_r(i)|} \mathbf{\Theta}_r \cdot \mathbf{x}_j,

    where :math:`\mathcal{R}` denotes the set of relations, *i.e.* edge types.
    Edge type needs to be a one-dimensional :obj:`torch.long` tensor which
    stores a relation identifier
    :math:`\in \{ 0, \ldots, |\mathcal{R}| - 1\}` for each edge.

    .. note::
        This implementation is as memory-efficient as possible by iterating
        over each individual relation type.
        Therefore, it may result in low GPU utilization in case the graph has
        a large number of relations.
        As an alternative approach, :class:`FastRGCNConv` does not iterate
        over each individual type, but may consume a large amount of memory
        to compensate.
        We advise checking out both implementations to see which one fits
        your needs.

    Args:
        in_channels (int or tuple): Size of each input sample. A tuple
            corresponds to the sizes of source and target dimensionalities.
            In case no input features are given, this argument should
            correspond to the number of nodes in your graph.
        out_channels (int): Size of each output sample.
        num_relations (int): Number of relations.
        num_bases (int, optional): If set, this layer will use the
            basis-decomposition regularization scheme where :obj:`num_bases`
            denotes the number of bases to use. (default: :obj:`None`)
        num_blocks (int, optional): If set, this layer will use the
            block-diagonal-decomposition regularization scheme where
            :obj:`num_blocks` denotes the number of blocks to use.
            (default: :obj:`None`)
        aggr (string, optional): The aggregation scheme to use
            (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
            (default: :obj:`"mean"`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add transformed root node features to the output.
            (default: :obj:`True`)
        is_sorted (bool, optional): If set to :obj:`True`, assumes that
            :obj:`edge_index` is sorted by :obj:`edge_type`. This avoids
            internal re-sorting of the data and can improve runtime and
            memory efficiency. (default: :obj:`False`)
        bias (bool, optional): If set to :obj:`False`, the layer will not
            learn an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        num_relations: int,
        num_bases: Optional[int] = None,
        num_blocks: Optional[int] = None,
        aggr: str = "mean",
        root_weight: bool = True,
        is_sorted: bool = False,
        bias: bool = True,
        **kwargs: Any,
    ) -> None:
        kwargs.setdefault("aggr", aggr)
        super().__init__(node_dim=0, **kwargs)
        # The pyg_lib fast path is only used when CUDA is available.
        self._WITH_PYG_LIB = torch.cuda.is_available() and _WITH_PYG_LIB

        if num_bases is not None and num_blocks is not None:
            raise ValueError(
                "Can not apply both basis-decomposition and "
                "block-diagonal-decomposition at the same time."
            )

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_relations = num_relations
        self.num_bases = num_bases
        self.num_blocks = num_blocks
        self.is_sorted = is_sorted

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)
        self.in_channels_l = in_channels[0]

        if num_bases is not None:
            self.weight = Parameter(
                torch.Tensor(num_bases, in_channels[0], out_channels)
            )
            self.comp = Parameter(torch.Tensor(num_relations, num_bases))

        elif num_blocks is not None:
            # Both the input and the output channels must be divisible by
            # num_blocks for the block-diagonal decomposition to work.
            if in_channels[0] % num_blocks != 0 or out_channels % num_blocks != 0:
                raise AssertionError(
                    "Channels must be divisible by num_blocks, for RGCNConv."
                )
            self.weight = Parameter(
                torch.Tensor(
                    num_relations,
                    num_blocks,
                    in_channels[0] // num_blocks,
                    out_channels // num_blocks,
                )
            )
            self.register_parameter("comp", None)

        else:
            self.weight = Parameter(
                torch.Tensor(num_relations, in_channels[0], out_channels)
            )
            self.register_parameter("comp", None)

        if root_weight:
            self.root = Param(torch.Tensor(in_channels[1], out_channels))
        else:
            self.register_parameter("root", None)

        if bias:
            self.bias = Param(torch.Tensor(out_channels))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        glorot(self.weight)
        glorot(self.comp)
        glorot(self.root)
        zeros(self.bias)

    def forward(
        self,
        x: Union[OptTensor, Tuple[OptTensor, Tensor]],
        edge_index: Adj,
        edge_type: OptTensor = None,
        edge_weight: OptTensor = None,
    ) -> Tensor:
        r"""
        Args:
            x: The input node features. Can be either a :obj:`[num_nodes,
                in_channels]` node feature matrix, or an optional
                one-dimensional node index tensor (in which case input
                features are treated as trainable node embeddings).
                Furthermore, :obj:`x` can be of type :obj:`tuple` denoting
                source and destination node features.
            edge_index (LongTensor or SparseTensor): The edge indices.
            edge_type: The one-dimensional relation type/index for each edge
                in :obj:`edge_index`.
                Should be only :obj:`None` in case :obj:`edge_index` is of
                type :class:`torch_sparse.tensor.SparseTensor`.
                (default: :obj:`None`)
            edge_weight: Optional one-dimensional weight for each edge in
                :obj:`edge_index`; messages are scaled by these weights.
                (default: :obj:`None`)
        """
        x_l: OptTensor = None
        if isinstance(x, tuple):
            x_l = x[0]
        else:
            x_l = x

        if x_l is None:
            x_l = torch.arange(self.in_channels_l, device=self.weight.device)

        x_r: Tensor = x_l
        if isinstance(x, tuple):
            x_r = x[1]

        size = (x_l.size(0), x_r.size(0))

        if isinstance(edge_index, SparseTensor):
            edge_type = edge_index.storage.value()
        if edge_type is None:
            raise AssertionError("edge_type cannot be None for RGCNConv.")

        out = torch.zeros(x_r.size(0), self.out_channels, device=x_r.device)

        weight = self.weight
        if self.num_bases is not None:
            # Basis decomposition: mix the shared bases into one weight
            # matrix per relation.
            weight = (self.comp @ weight.view(self.num_bases, -1)).view(
                self.num_relations, self.in_channels_l, self.out_channels
            )

        if self.num_blocks is not None:
            if x_l.dtype == torch.long:
                raise ValueError(
                    "Block-diagonal decomposition not supported "
                    "for non-continuous input features."
                )

            for i in range(self.num_relations):
                tmp = masked_edge_index(edge_index, edge_type == i)
                h = self.propagate(tmp, x=x_l, edge_type_ptr=None, size=size)
                h = h.view(-1, weight.size(1), weight.size(2))
                h = torch.einsum("abc,bcd->abd", h, weight[i])
                out = out + h.contiguous().view(-1, self.out_channels)

        else:
            if self._WITH_PYG_LIB and isinstance(edge_index, Tensor):
                # Fast path: sort edges by relation type and transform all
                # messages with a single segment_matmul call from pyg_lib.
                if not self.is_sorted:
                    if (edge_type[1:] < edge_type[:-1]).any():
                        edge_type, perm = edge_type.sort()
                        edge_index = edge_index[:, perm]
                edge_type_ptr = torch.ops.torch_sparse.ind2ptr(
                    edge_type, self.num_relations
                )
                out = self.propagate(
                    edge_index, x=x_l, edge_type_ptr=edge_type_ptr, size=size
                )
            else:
                for i in range(self.num_relations):
                    tmp = masked_edge_index(edge_index, edge_type == i)
                    if edge_weight is not None:
                        tmp_weight = edge_weight[edge_type == i]
                    else:
                        tmp_weight = None

                    if x_l.dtype == torch.long:
                        # Index tensor input: look up the trainable
                        # embeddings directly from the relation weights.
                        out = out + self.propagate(
                            tmp,
                            x=weight[i, x_l],
                            edge_type_ptr=None,
                            size=size,
                        )
                    else:
                        h = self.propagate(
                            tmp,
                            x=x_l,
                            edge_type_ptr=None,
                            edge_weight=tmp_weight,
                            size=size,
                        )
                        out = out + (h @ weight[i])

        root = self.root
        if root is not None:
            out = out + (root[x_r] if x_r.dtype == torch.long else x_r @ root)

        if self.bias is not None:
            out = out + self.bias

        return out

    def message(
        self, x_j: Tensor, edge_type_ptr: OptTensor, edge_weight: OptTensor = None
    ) -> Tensor:
        if edge_type_ptr is not None:
            return segment_matmul(x_j, edge_type_ptr, self.weight)

        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        adj_t = adj_t.set_value(None)
        return matmul(adj_t, x, reduce=self.aggr)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}({self.in_channels}, "
            f"{self.out_channels}, num_relations={self.num_relations})"
        )
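
For reviewers unfamiliar with the layer, here is a minimal usage sketch. It is not part of this PR, and the node count, channel sizes, and relation count are invented for illustration:

# Hypothetical smoke test for RGCNConv; all shapes below are made up.
import torch

conv = RGCNConv(in_channels=16, out_channels=32, num_relations=3)

x = torch.randn(10, 16)                     # 10 nodes with 16 features each
edge_index = torch.randint(0, 10, (2, 40))  # 40 random directed edges
edge_type = torch.randint(0, 3, (40,))      # one relation id per edge
edge_weight = torch.rand(40)                # optional per-edge weights

out = conv(x, edge_index, edge_type, edge_weight)
print(out.shape)  # torch.Size([10, 32])

Passing num_bases or num_blocks to the constructor would shrink the per-relation weight tensor via basis or block-diagonal decomposition, trading expressiveness for fewer parameters.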
Review comment: Is this node_dim a parameter, or does it need to be hardcoded to 3?

Reply: I have exposed it as a parameter now.
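
For reference, the exposed parameter presumably ends up looking something like the sketch below. This is hypothetical; the actual change is in commit bf46be4, "Expose node_dim as parameter.":

# Hypothetical sketch only; see commit bf46be4 for the real diff.
def __init__(
    self,
    in_channels: Union[int, Tuple[int, int]],
    out_channels: int,
    num_relations: int,
    node_dim: int = 0,  # axis along which propagation gathers node features
    **kwargs: Any,
) -> None:
    kwargs.setdefault("aggr", "mean")
    super().__init__(node_dim=node_dim, **kwargs)  # no longer hardcoded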