Skip to content

Commit

Permalink
fix: can't convert DGL heterograph with edge attributes into ArangoDB (#21)
Browse files Browse the repository at this point in the history

* fix: #20

* new: test case for #20

* update: notebook version

* remove: duplicate docstring

* cleanup: test_dgl_to_adb

* fix: notebook typo
  • Loading branch information
aMahanna authored May 31, 2022
1 parent 940503b commit 9421afe
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 18 deletions.
19 changes: 14 additions & 5 deletions adbdgl_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
import logging
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Set, Union
from typing import Any, DefaultDict, Dict, List, Optional, Set, Union

from arango.cursor import Cursor
from arango.database import Database
Expand Down Expand Up @@ -251,6 +251,7 @@ def dgl_to_arangodb(
for v_col in adb_v_cols:
ntype = None if is_default else v_col
v_col_docs = adb_documents[v_col]
features = dgl_g.node_attr_schemes(ntype).keys()

logger.debug(f"Preparing {len(dgl_g.nodes(ntype))} '{v_col}' DGL nodes")
node: Tensor
Expand All @@ -259,7 +260,7 @@ def dgl_to_arangodb(
adb_vertex = {"_key": str(dgl_node_id)}
self.__prepare_adb_attributes(
dgl_g.ndata,
dgl_g.node_attr_schemes(ntype).keys(),
features,
dgl_node_id,
adb_vertex,
v_col,
Expand All @@ -275,11 +276,14 @@ def dgl_to_arangodb(
for e_col in adb_e_cols:
etype = None if is_default else e_col
e_col_docs = adb_documents[e_col]
features = dgl_g.edge_attr_schemes(etype).keys()

canonical_etype = None
if is_default:
from_col = to_col = adb_v_cols[0]
else:
from_col, _, to_col = dgl_g.to_canonical_etype(e_col)
canonical_etype = dgl_g.to_canonical_etype(e_col)
from_col, _, to_col = canonical_etype

from_nodes, to_nodes = dgl_g.edges(etype=etype)
logger.debug(f"Preparing {len(from_nodes)} '{e_col}' DGL edges")
Expand All @@ -293,11 +297,12 @@ def dgl_to_arangodb(
}
self.__prepare_adb_attributes(
dgl_g.edata,
dgl_g.edge_attr_schemes(etype).keys(),
features,
dgl_edge_id,
adb_edge,
e_col,
has_one_ecol,
canonical_etype,
)

self.__insert_adb_docs(e_col, e_col_docs, adb_edge, batch_size)
Expand Down Expand Up @@ -402,6 +407,7 @@ def __prepare_adb_attributes(
doc: Json,
col: str,
has_one_col: bool,
canonical_etype: Optional[DGLCanonicalEType] = None,
) -> None:
"""Convert DGL features into a set of ArangoDB attributes for a given document
Expand All @@ -419,9 +425,12 @@ def __prepare_adb_attributes(
:param has_one_col: Set to True if the ArangoDB graph has one
vertex collection or one edge collection only.
:type has_one_col: bool
:param canonical_etype: The DGL canonical edge type belonging to the current
**col**, provided that **col** is an edge collection (ignored otherwise).
:type canonical_etype: adbdgl_adapter.typings.DGLCanonicalEType
"""
for key in features:
tensor = data[key] if has_one_col else data[key][col]
tensor = data[key] if has_one_col else data[key][canonical_etype or col]
doc[key] = self.__cntrl._dgl_feature_to_adb_attribute(key, col, tensor[id])

def __insert_adb_docs(
Expand Down
8 changes: 4 additions & 4 deletions examples/ArangoDB_DGL_Adapter.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"id": "U1d45V4OeG89"
},
"source": [
"<a href=\"https://colab.research.google.com/github/arangoml/dgl-adapter/blob/2.0.0/examples/ArangoDB_DGL_Adapter.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
"<a href=\"https://colab.research.google.com/github/arangoml/dgl-adapter/blob/2.0.1/examples/ArangoDB_DGL_Adapter.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
Expand Down Expand Up @@ -57,9 +57,9 @@
"outputs": [],
"source": [
"%%capture\n",
"!pip install adbdgl-adapter==2.0.0\n",
"!pip install adbdgl-adapter==2.0.1\n",
"!pip install adb-cloud-connector\n",
"!git clone -b 2.0.0 --single-branch https://github.com/arangoml/dgl-adapter.git\n",
"!git clone -b 2.0.1 --single-branch https://github.com/arangoml/dgl-adapter.git\n",
"\n",
"## For drawing purposes \n",
"!pip install matplotlib\n",
Expand Down Expand Up @@ -987,7 +987,7 @@
"\n",
" if key == \"clique_ndata\":\n",
" try:\n",
" return [\"Eins\", \"Zwei\", \"Drei\", \"Vier\", \"Fünf\", \"Sechs\"][key-1]\n",
" return [\"Eins\", \"Zwei\", \"Drei\", \"Vier\", \"Fünf\", \"Sechs\"][val-1]\n",
" except:\n",
" return -1\n",
"\n",
Expand Down
17 changes: 16 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from arango import ArangoClient
from arango.database import StandardDatabase
from dgl import DGLGraph, remove_self_loop
from dgl import DGLGraph, heterograph, remove_self_loop
from dgl.data import KarateClubDataset, MiniGCDataset
from torch import ones, rand, tensor, zeros

Expand Down Expand Up @@ -114,3 +114,18 @@ def get_clique_graph() -> DGLGraph:
dgl_g.ndata["random_ndata"] = ones(dgl_g.num_nodes())
dgl_g.edata["random_edata"] = zeros(dgl_g.num_edges())
return dgl_g


def get_social_graph() -> DGLGraph:
    """Build a small heterogeneous "social" DGL graph for the test suite.

    The graph is created from three canonical edge types
    (``user-follows-user``, ``user-likes-game`` and ``user-plays-game``).
    An ``age`` feature is stored on the ``user`` nodes and an
    ``hours_played`` feature is stored on the ``plays`` edges.

    :return: The constructed DGL heterograph.
    :rtype: dgl.DGLGraph
    """
    # Each canonical edge type maps to its (source node ids, destination node ids).
    relations = {
        ("user", "follows", "user"): (tensor([0, 1]), tensor([1, 2])),
        ("user", "likes", "game"): (tensor([0, 1, 2]), tensor([0, 1, 2])),
        ("user", "plays", "game"): (tensor([1, 3]), tensor([1, 2])),
    }
    social_g = heterograph(relations)

    # One age value per user node.
    user_ages = tensor([21, 16, 38, 64])
    social_g.nodes["user"].data["age"] = user_ages

    # One value per "plays" edge.
    play_hours = tensor([3, 5])
    social_g.edges["plays"].data["hours_played"] = play_hours

    return social_g
18 changes: 10 additions & 8 deletions tests/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
get_hypercube_graph,
get_karate_graph,
get_lollipop_graph,
get_social_graph,
)


Expand Down Expand Up @@ -121,23 +122,23 @@ def test_adb_graph_to_dgl(adapter: ADBDGL_Adapter, name: str) -> None:


@pytest.mark.parametrize(
"adapter, name, dgl_g, is_default_type, batch_size",
"adapter, name, dgl_g, batch_size",
[
(adbdgl_adapter, "Clique", get_clique_graph(), True, 3),
(adbdgl_adapter, "Lollipop", get_lollipop_graph(), True, 1000),
(adbdgl_adapter, "Hypercube", get_hypercube_graph(), True, 1000),
(adbdgl_adapter, "Karate", get_karate_graph(), True, 1000),
(adbdgl_adapter, "Clique", get_clique_graph(), 3),
(adbdgl_adapter, "Lollipop", get_lollipop_graph(), 1000),
(adbdgl_adapter, "Hypercube", get_hypercube_graph(), 1000),
(adbdgl_adapter, "Karate", get_karate_graph(), 1000),
(adbdgl_adapter, "Social", get_social_graph(), 1000),
],
)
def test_dgl_to_adb(
adapter: ADBDGL_Adapter,
name: str,
dgl_g: Union[DGLGraph, DGLHeteroGraph],
is_default_type: bool,
batch_size: int,
) -> None:
adb_g = adapter.dgl_to_arangodb(name, dgl_g, batch_size)
assert_arangodb_data(name, dgl_g, adb_g, is_default_type)
assert_arangodb_data(name, dgl_g, adb_g)


def assert_dgl_data(
Expand Down Expand Up @@ -176,8 +177,9 @@ def assert_arangodb_data(
name: str,
dgl_g: Union[DGLGraph, DGLHeteroGraph],
adb_g: ArangoGraph,
is_default_type: bool,
) -> None:
is_default_type = dgl_g.canonical_etypes == adbdgl_adapter.DEFAULT_CANONICAL_ETYPE

for dgl_v_col in dgl_g.ntypes:
adb_v_col = name + dgl_v_col if is_default_type else dgl_v_col
attributes = dgl_g.node_attr_schemes(
Expand Down

0 comments on commit 9421afe

Please sign in to comment.