Skip to content

Commit

Permalink
Upgrade networkx dependency (#9718)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Oct 21, 2024
1 parent 3ac9fda commit facf0c4
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 8 deletions.
2 changes: 1 addition & 1 deletion torch_geometric/config_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _recursive_from_config(value: Any) -> Any:
if is_dataclass(value):
if getattr(value, '_target_', None):
try:
cls = _locate_cls(value._target_) # type: ignore[attr-defined]
cls = _locate_cls(value._target_) # type: ignore
except ImportError:
pass # Keep the dataclass as it is.
else:
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/metrics/link_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ def update(
) + 1
arange = torch.arange(
start=0,
end=max_index * pred_index_mat.size(0),
step=max_index,
end=max_index * pred_index_mat.size(0), # type: ignore
step=max_index, # type: ignore
device=pred_index_mat.device,
).view(-1, 1)
flat_pred_index = (pred_index_mat + arange).view(-1)
Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/transforms/add_positional_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def forward(self, data: Data) -> Data:
from numpy.linalg import eig, eigh
eig_fn = eig if not self.is_undirected else eigh

eig_vals, eig_vecs = eig_fn(L.todense()) # type: ignore
eig_vals, eig_vecs = eig_fn(L.todense())
else:
from scipy.sparse.linalg import eigs, eigsh
eig_fn = eigs if not self.is_undirected else eigsh
Expand Down
6 changes: 5 additions & 1 deletion torch_geometric/transforms/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ def get_attrs_with_suffix(
return [key for key in store.keys() if key.endswith(suffix)]


def get_mask_size(attr: str, store: BaseStorage, size: Optional[int]) -> int:
def get_mask_size(
attr: str,
store: BaseStorage,
size: Optional[int],
) -> Optional[int]:
if size is not None:
return size
return store.num_edges if store.is_edge_attr(attr) else store.num_nodes
Expand Down
5 changes: 2 additions & 3 deletions torch_geometric/visualization/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,8 @@ def _visualize_graph_via_networkx(
),
)

nodes = nx.draw_networkx_nodes(g, pos, node_size=node_size,
node_color='white', margins=0.1)
nodes.set_edgecolor('black')
nx.draw_networkx_nodes(g, pos, node_size=node_size, node_color='white',
margins=0.1, edgecolors='black')
nx.draw_networkx_labels(g, pos, font_size=10)

if path is not None:
Expand Down

0 comments on commit facf0c4

Please sign in to comment.