Skip to content

Commit

Permalink
refactor some docstring and params
Browse files Browse the repository at this point in the history
  • Loading branch information
Wollents committed Nov 14, 2024
1 parent 52b8086 commit c7df01e
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 21 deletions.
8 changes: 7 additions & 1 deletion pygod/detector/card.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ class CARD(DeepDetector):
gama: float, optional
The proportion of the local reconstruction in contrastive learning module.
Default: ``0.5``
alpha: float, optional
The proprotion of the community embedding in the conbine_encoder.
Default: ``0.1``
verbose : int, optional
Verbosity mode. Range in [0, 3]. Larger value for printing out
more log information. Default: ``0``.
Expand Down Expand Up @@ -115,6 +118,7 @@ def __init__(self,
subgraph_num_neigh=4,
fp=0.6,
gama=0.5,
alpha=0.1,
verbose=0,
save_emb=False,
compile_model=False,
Expand All @@ -138,6 +142,7 @@ def __init__(self,
self.subgraph_num_neigh = subgraph_num_neigh
self.fp = fp
self.gama = gama
self.alpha = alpha

def process_graph(self, data):
community_adj, self.diff_data = CARDBase.process_graph(data)
Expand All @@ -151,14 +156,15 @@ def init_model(self, **kwargs):
self.hid_dim)

return CARDBase(in_dim=self.in_dim,
subgraph_num_neigh=self.subgraph_num_neigh,
fp=self.fp,
gama=self.gama,
alpha=self.alpha,
hid_dim=self.hid_dim,
num_layers=self.num_layers,
dropout=self.dropout,
act=self.act,
backbone=self.backbone,
subgraph_num_neigh=self.subgraph_num_neigh,
**kwargs).to(self.device)

def forward_model(self, data):
Expand Down
36 changes: 16 additions & 20 deletions pygod/nn/card.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class CARDBase(nn.Module):
gama: float, optional
The proportion of the local reconstruction in contrastive learning module.
Default: ``0.5``
alpha: float, optional
The proprotion of the community embedding in the conbine_encoder.
Default: ``0.1``
hid_dim : int, optional
Hidden dimension of model. Default: ``64``.
num_layers : int, optional
Expand All @@ -54,9 +57,9 @@ class CARDBase(nn.Module):

def __init__(self,
in_dim,
subgraph_num_neigh=4,
fp=0.6,
gama=0.4,
subgraph_num_neigh=4,
alpha=0.1,
hid_dim=64,
num_layers=4,
Expand Down Expand Up @@ -171,18 +174,18 @@ def loss_func(self, logits, diff_logits, x_, local_x_, x, con_label):
Parameters
----------
logits : _type_
_description_
diff_logits : _type_
_description_
x_ : _type_
_description_
local_x_ : _type_
_description_
x : _type_
_description_
con_label : _type_
_description_
logits : torch.Tensor
Discriminator logits of positive subgraphs batch.
diff_logits : torch.Tensor
Discriminator logits of negative subgraphs batch.
x_ : torch.Tensor
Global reconstructed attribute embeddings.
local_x_ : torch.Tensor
Local reconstructed attribute embeddings.
x : torch.Tensor
Input attribute embeddings.
con_label : torch.Tensor
Contrastive learning pseudo label
Returns
-------
Expand Down Expand Up @@ -254,13 +257,6 @@ def _train_subgraph_network(self, data):
x = subgraph.x
edge_index = subgraph.edge_index

# diff_subgraphs = NeighborLoader(
# self.diff, num_neighbors=[-1] * self.num_layers)
# diff_subgraph = diff_subgraphs([index])
# diff_subgraph.x[0, :] = 0
# diff_x = diff_subgraph.x.to(self.device)
# diff_edge_index = diff_subgraph.edge_index.to(self.device)

ori_emb = self.encoder(x, edge_index)
community_emb = self.community_encoder(community_adj)
combine_emb = self.combine_encoder(
Expand Down

0 comments on commit c7df01e

Please sign in to comment.