Skip to content

Commit

Permalink
minor typo fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
YingtongDou authored Dec 20, 2023
1 parent 3f6f058 commit 88c2fed
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions pygod/detector/gadnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class GADNR(DeepDetector):
neighborhood distribution reconstruction. Default: ``3``.
neigh_loss : str, optional
The neighbor reconstruction loss. ``KL`` represents the KL divergence
loss, ``W2`` represents the W2 loss. Defualt: ``KL``.
loss, ``W2`` represents the W2 loss. Default: ``KL``.
lambda_loss1 : float, optional
The weight of the neighborhood reconstruction loss term.
Default: ``1e-2``.
Expand Down Expand Up @@ -320,7 +320,7 @@ def fit(self,

self.model.train()
self.decision_score_ = torch.zeros(data.x.shape[0])
for epoch in range(self.epoch):
for epoch in range(1, self.epoch+1, 1):
start_time = time.time()
epoch_loss = 0
epoch_loss_per_node = torch.zeros(data.x.shape[0])
Expand Down Expand Up @@ -413,7 +413,8 @@ def decision_function(self,
The three loss term weights must be the same as the fit function if
``real_loss`` is ``False``.
"""
if self.full_batch: # full batch inference

if self.full_batch: # full batch inference
if self.batch_size != data.x.shape[0]:
raise ValueError(data, 'should have the same number of nodes '
'as the training data under the full '
Expand Down

0 comments on commit 88c2fed

Please sign in to comment.