Skip to content

Commit

Permalink
HANConv: NaN handling (#4841)
Browse files Browse the repository at this point in the history
* update

* changelog
  • Loading branch information
rusty1s authored Jun 22, 2022
1 parent 3d76627 commit f909d24
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added Python version requirement ([#4825](https://github.com/pyg-team/pytorch_geometric/pull/4825))
- Added TorchScript support to `JumpingKnowledge` module ([#4805](https://github.com/pyg-team/pytorch_geometric/pull/4805))
- Added a `max_sample` argument to `AddMetaPaths` in order to tackle very dense metapath edges ([#4750](https://github.com/pyg-team/pytorch_geometric/pull/4750))
- Test `HANConv` with empty tensors ([#4756](https://github.com/pyg-team/pytorch_geometric/pull/4756))
- Test `HANConv` with empty tensors ([#4756](https://github.com/pyg-team/pytorch_geometric/pull/4756), [#4841](https://github.com/pyg-team/pytorch_geometric/pull/4841))
- Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
- Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
- Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
Expand Down
2 changes: 2 additions & 0 deletions torch_geometric/nn/conv/han_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def group(xs: List[Tensor], q: nn.Parameter,
else:
num_edge_types = len(xs)
out = torch.stack(xs)
if out.numel() == 0:
return out.view(0, out.size(-1))
attn_score = (q * torch.tanh(k_lin(out)).mean(1)).sum(-1)
attn = F.softmax(attn_score, dim=0)
out = torch.sum(attn.view(num_edge_types, 1, -1) * out, dim=0)
Expand Down

0 comments on commit f909d24

Please sign in to comment.