[WIP] 58 add goggle #162

Merged 30 commits from 58-add-goggle into main on Apr 3, 2023.
Changes shown below are from 20 of the 30 commits.

Commits
c4c22d6
First goggle implementation, no cond
robsdavis Mar 22, 2023
ac05399
cleanup
robsdavis Mar 22, 2023
13944e7
Adding tests and some debugging
robsdavis Mar 23, 2023
ce78c8b
Debugging
robsdavis Mar 24, 2023
9010944
debugging
robsdavis Mar 27, 2023
301acb1
debugging
robsdavis Mar 27, 2023
9492bad
Debugging
robsdavis Mar 28, 2023
fe4ff06
Debugging
robsdavis Mar 28, 2023
afb57bd
Debugging
robsdavis Mar 28, 2023
33e043d
Enabled reproducible results
robsdavis Mar 29, 2023
86c4bd0
getting tests to pass
robsdavis Mar 29, 2023
0b6687e
Clean up
robsdavis Mar 29, 2023
e43b769
Added more argument validation
robsdavis Mar 29, 2023
1181db7
Add docstrings
robsdavis Mar 30, 2023
2a6b128
remove cond test
robsdavis Mar 30, 2023
bf46be4
Expose node_dim as parameter.
robsdavis Mar 30, 2023
973fac7
default to cpu device
robsdavis Mar 30, 2023
60eff8c
revert to previous device defaults
robsdavis Mar 30, 2023
06d7270
Added dependencies to setup.cfg
robsdavis Mar 31, 2023
459772a
Added download of torch_stable to prereq.txt
robsdavis Mar 31, 2023
32fc3d3
move goggle reqs to extra
robsdavis Mar 31, 2023
63817d0
Merge branch 'main' into 58-add-goggle
robsdavis Mar 31, 2023
06c526b
Fixing dependencies
robsdavis Mar 31, 2023
ac2319c
Swap airfoil dataset for diabetes in goggle tests
robsdavis Mar 31, 2023
c8e6c54
Fixing dependencies
robsdavis Mar 31, 2023
d61c32d
Handle failed plugin loading due to uninstalled extra
robsdavis Mar 31, 2023
b608658
Debugging dependencies
robsdavis Mar 31, 2023
f08198a
Debugging
robsdavis Mar 31, 2023
2d67dc3
debug goggle (#167)
bcebere Apr 1, 2023
bdcde45
Update README.md
bcebere Apr 2, 2023
1 change: 1 addition & 0 deletions prereq.txt
@@ -1,3 +1,4 @@
+-f https://download.pytorch.org/whl/torch_stable.html
 numpy
 torch
 tsai
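
Note on the added line: pip's -f / --find-links option adds an extra location to search for wheels, here PyTorch's stable wheel index, so the pinned torch build can be resolved when the prerequisites are installed:

    pip install -r prereq.txt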
4 changes: 4 additions & 0 deletions setup.cfg
@@ -41,6 +41,10 @@ install_requires =
 opacus>=1.3
 decaf-synthetic-data>=0.1.5
 optuna>=3.1
+dgl
+torch_geometric
+torch_sparse
+torch_scatter
 shap
 tqdm
 loguru
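
Note: commit 32fc3d3 ("move goggle reqs to extra") later moves these four graph packages out of install_requires, so this hunk reflects an intermediate state of the branch. Assuming the extra ends up named goggle (the name is not confirmed in the visible diff), installing it would look like:

    pip install synthcity[goggle]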
300 changes: 300 additions & 0 deletions src/synthcity/plugins/core/models/RGCNConv.py
@@ -0,0 +1,300 @@
# stdlib
from typing import Any, Optional, Tuple, Union

# third party
import torch
from torch import Tensor
from torch.nn import Parameter
from torch.nn import Parameter as Param
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.typing import Adj, OptTensor
from torch_sparse import SparseTensor, masked_select_nnz, matmul

try:
    # third party
    from pyg_lib.ops import segment_matmul  # C-implemented method

    _WITH_PYG_LIB = True
except ImportError:
    _WITH_PYG_LIB = False

    # Fallback stub so the name exists when pyg_lib is not installed.
    def segment_matmul(inputs: Tensor, ptr: Tensor, other: Tensor) -> Tensor:
        raise NotImplementedError


def masked_edge_index(
    edge_index: Tensor, edge_mask: Tensor
) -> Union[SparseTensor, Tensor]:
    # Restrict the edge index to the edges of a single relation type.
    if isinstance(edge_index, Tensor):
        return edge_index[:, edge_mask]
    else:
        return masked_select_nnz(edge_index, edge_mask, layout="coo")


class RGCNConv(MessagePassing):
    r"""The relational graph convolutional operator from the `"Modeling
    Relational Data with Graph Convolutional Networks"
    <https://arxiv.org/abs/1703.06103>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{\Theta}_{\textrm{root}} \cdot
        \mathbf{x}_i + \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_r(i)}
        \frac{1}{|\mathcal{N}_r(i)|} \mathbf{\Theta}_r \cdot \mathbf{x}_j,

    where :math:`\mathcal{R}` denotes the set of relations, *i.e.* edge
    types. Edge type needs to be a one-dimensional :obj:`torch.long` tensor
    which stores a relation identifier
    :math:`\in \{ 0, \ldots, |\mathcal{R}| - 1\}` for each edge.

    .. note::
        This implementation is as memory-efficient as possible by iterating
        over each individual relation type. Therefore, it may result in low
        GPU utilization in case the graph has a large number of relations.
        As an alternative approach, :class:`FastRGCNConv` does not iterate
        over each individual type, but may consume a large amount of memory
        to compensate. We advise checking out both implementations to see
        which one fits your needs.

    Args:
        in_channels (int or tuple): Size of each input sample. A tuple
            corresponds to the sizes of source and target dimensionalities.
            In case no input features are given, this argument should
            correspond to the number of nodes in your graph.
        out_channels (int): Size of each output sample.
        num_relations (int): Number of relations.
        num_bases (int, optional): If set, this layer will use the
            basis-decomposition regularization scheme where :obj:`num_bases`
            denotes the number of bases to use. (default: :obj:`None`)
        num_blocks (int, optional): If set, this layer will use the
            block-diagonal-decomposition regularization scheme where
            :obj:`num_blocks` denotes the number of blocks to use.
            (default: :obj:`None`)
        aggr (string, optional): The aggregation scheme to use
            (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
            (default: :obj:`"mean"`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add transformed root node features to the output.
            (default: :obj:`True`)
        is_sorted (bool, optional): If set to :obj:`True`, assumes that
            :obj:`edge_index` is sorted by :obj:`edge_type`. This avoids
            internal re-sorting of the data and can improve runtime and
            memory efficiency. (default: :obj:`False`)
        bias (bool, optional): If set to :obj:`False`, the layer will not
            learn an additive bias. (default: :obj:`True`)
        message_passing_node_dim (int, optional): The axis along which to
            propagate in message passing. (default: :obj:`0`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        num_relations: int,
        num_bases: Optional[int] = None,
        num_blocks: Optional[int] = None,
        aggr: str = "mean",
        root_weight: bool = True,
        is_sorted: bool = False,
        bias: bool = True,
        message_passing_node_dim: int = 0,
        **kwargs: Any,
    ) -> None:
        kwargs.setdefault("aggr", aggr)
        super().__init__(node_dim=message_passing_node_dim, **kwargs)
        # Use the pyg_lib fast path only when CUDA is available.
        self._WITH_PYG_LIB = torch.cuda.is_available() and _WITH_PYG_LIB

        if num_bases is not None and num_blocks is not None:
            raise ValueError(
                "Can not apply both basis-decomposition and "
                "block-diagonal-decomposition at the same time."
            )

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_relations = num_relations
        self.num_bases = num_bases
        self.num_blocks = num_blocks
        self.is_sorted = is_sorted

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)
        self.in_channels_l = in_channels[0]

        if num_bases is not None:
            self.weight = Parameter(
                torch.Tensor(num_bases, in_channels[0], out_channels)
            )
            self.comp = Parameter(torch.Tensor(num_relations, num_bases))

        elif num_blocks is not None:
            # Both channel counts must be divisible by num_blocks.
            if in_channels[0] % num_blocks != 0 or out_channels % num_blocks != 0:
                raise AssertionError(
                    "Channels must be divisible by num_blocks, for RGCNConv."
                )
            self.weight = Parameter(
                torch.Tensor(
                    num_relations,
                    num_blocks,
                    in_channels[0] // num_blocks,
                    out_channels // num_blocks,
                )
            )
            self.register_parameter("comp", None)

        else:
            self.weight = Parameter(
                torch.Tensor(num_relations, in_channels[0], out_channels)
            )
            self.register_parameter("comp", None)

        if root_weight:
            self.root = Param(torch.Tensor(in_channels[1], out_channels))
        else:
            self.register_parameter("root", None)

        if bias:
            self.bias = Param(torch.Tensor(out_channels))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        glorot(self.weight)
        glorot(self.comp)
        glorot(self.root)
        zeros(self.bias)

    def forward(
        self,
        x: Union[OptTensor, Tuple[OptTensor, Tensor]],
        edge_index: Adj,
        edge_type: OptTensor = None,
        edge_weight: OptTensor = None,
    ) -> Tensor:
        r"""
        Args:
            x: The input node features. Can be either a :obj:`[num_nodes,
                in_channels]` node feature matrix, or an optional
                one-dimensional node index tensor (in which case input
                features are treated as trainable node embeddings).
                Furthermore, :obj:`x` can be of type :obj:`tuple` denoting
                source and destination node features.
            edge_index (LongTensor or SparseTensor): The edge indices.
            edge_type: The one-dimensional relation type/index for each edge
                in :obj:`edge_index`.
                Should be only :obj:`None` in case :obj:`edge_index` is of
                type :class:`torch_sparse.tensor.SparseTensor`.
                (default: :obj:`None`)
        """
        x_l: OptTensor = None
        if isinstance(x, tuple):
            x_l = x[0]
        else:
            x_l = x

        if x_l is None:
            # No input features given: treat node indices as embedding ids.
            x_l = torch.arange(self.in_channels_l, device=self.weight.device)

        x_r: Tensor = x_l
        if isinstance(x, tuple):
            x_r = x[1]

        size = (x_l.size(0), x_r.size(0))

        if isinstance(edge_index, SparseTensor):
            edge_type = edge_index.storage.value()
        if edge_type is None:
            raise AssertionError("edge_type cannot be None for RGCNConv.")

        out = torch.zeros(x_r.size(0), self.out_channels, device=x_r.device)

        weight = self.weight
        if self.num_bases is not None:
            # Basis decomposition: compose per-relation weights from bases.
            weight = (self.comp @ weight.view(self.num_bases, -1)).view(
                self.num_relations, self.in_channels_l, self.out_channels
            )

        if self.num_blocks is not None:
            # Block-diagonal decomposition.
            if x_l.dtype == torch.long:
                raise ValueError(
                    "Block-diagonal decomposition not supported "
                    "for non-continuous input features."
                )

            for i in range(self.num_relations):
                tmp = masked_edge_index(edge_index, edge_type == i)
                h = self.propagate(tmp, x=x_l, edge_type_ptr=None, size=size)
                h = h.view(-1, weight.size(1), weight.size(2))
                h = torch.einsum("abc,bcd->abd", h, weight[i])
                out = out + h.contiguous().view(-1, self.out_channels)

        else:
            if self._WITH_PYG_LIB and isinstance(edge_index, Tensor):
                # The pyg_lib segment_matmul fast path is disabled in this
                # WIP; the commented-out implementation read:
                #
                #     if not self.is_sorted:
                #         if (edge_type[1:] < edge_type[:-1]).any():
                #             edge_type, perm = edge_type.sort()
                #             edge_index = edge_index[:, perm]
                #     edge_type_ptr = torch.ops.torch_sparse.ind2ptr(
                #         edge_type, self.num_relations)
                #     out = self.propagate(edge_index, x=x_l,
                #                          edge_type_ptr=edge_type_ptr,
                #                          size=size)
                pass
            else:
                for i in range(self.num_relations):
                    tmp = masked_edge_index(edge_index, edge_type == i)
                    if edge_weight is not None:
                        tmp_weight = edge_weight[edge_type == i]
                    else:
                        tmp_weight = None

                    if x_l.dtype == torch.long:
                        # Embedding-style (index) inputs are disabled in
                        # this WIP; the commented-out implementation read:
                        #
                        #     out = out + self.propagate(
                        #         tmp,
                        #         x=weight[i, x_l],
                        #         edge_type_ptr=None,
                        #         size=size,
                        #     )
                        pass
                    else:
                        h = self.propagate(
                            tmp,
                            x=x_l,
                            edge_type_ptr=None,
                            edge_weight=tmp_weight,
                            size=size,
                        )
                        out = out + (h @ weight[i])

        root = self.root
        if root is not None:
            out = out + (root[x_r] if x_r.dtype == torch.long else x_r @ root)

        if self.bias is not None:
            out = out + self.bias

        return out

    def message(
        self, x_j: Tensor, edge_type_ptr: OptTensor, edge_weight: OptTensor
    ) -> Tensor:
        if edge_type_ptr is not None:
            # Only reachable via the pyg_lib fast path, which is disabled
            # in this WIP.
            return segment_matmul(x_j, edge_type_ptr, self.weight)

        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        adj_t = adj_t.set_value(None)
        return matmul(adj_t, x, reduce=self.aggr)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}({self.in_channels}, "
            f"{self.out_channels}, num_relations={self.num_relations})"
        )
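
A minimal usage sketch for the layer above (illustrative only, not part of the PR; shapes and values are hypothetical, and it exercises the per-relation loop rather than the disabled pyg_lib fast path):

    import torch

    from synthcity.plugins.core.models.RGCNConv import RGCNConv

    x = torch.randn(3, 8)                          # [num_nodes, in_channels]
    edge_index = torch.tensor([[0, 1, 2],          # source nodes
                               [1, 2, 0]])         # target nodes
    edge_type = torch.tensor([0, 1, 0])            # relation id per edge

    conv = RGCNConv(in_channels=8, out_channels=4, num_relations=2)
    out = conv(x, edge_index, edge_type)           # -> shape [3, 4]

Each relation r contributes (mean over incoming r-type messages) @ weight[r]; the root transform and bias are then added, matching the formula in the class docstring.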