Skip to content

Commit 9546904

Browse files
committed
WIP: Script for generating algo endpoints code
1 parent b253559 commit 9546904

File tree

1 file changed

+100
-0
lines changed

1 file changed

+100
-0
lines changed

scripts/generate_algo_endpoints.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#!/usr/bin/env python3
2+
3+
from graphdatascience import GraphDataScience
4+
from collections import defaultdict
5+
import os
6+
7+
ALGO_MODES = {"mutate", "stats", "stream", "train", "write"}
8+
9+
URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
10+
AUTH = ("neo4j", "password")
11+
if os.environ.get("NEO4J_USER"):
12+
AUTH = (
13+
os.environ.get("NEO4J_USER", "DUMMY"),
14+
os.environ.get("NEO4J_PASSWORD", "neo4j"),
15+
)
16+
DB = os.environ.get("NEO4J_DB", "neo4j")
17+
18+
gds = GraphDataScience(URI, auth=AUTH, database=DB)
19+
20+
21+
all_endpoints = gds.list()["name"].tolist()
22+
alpha_algo_endpoints = defaultdict(lambda: list())
23+
beta_algo_endpoints = defaultdict(lambda: list())
24+
prod_algo_endpoints = defaultdict(lambda: list())
25+
asp_algo_endpoints = defaultdict(lambda: list())
26+
sp_algo_endpoints = defaultdict(lambda: list())
27+
28+
29+
def add_mode(algo_name, mode, endpoints):
30+
if mode == "mutate":
31+
endpoints[algo_name].append("MutateEndpoint")
32+
elif mode == "stats":
33+
endpoints[algo_name].append("StatsEndpoint")
34+
elif mode == "stream":
35+
endpoints[algo_name].append("StreamEndpoint")
36+
elif mode == "train":
37+
endpoints[algo_name].append("TrainEndpoint")
38+
elif mode == "write":
39+
endpoints[algo_name].append("WriteEndpoint")
40+
41+
42+
def collect_algo_endpoints(all_endpoints):
43+
for e in all_endpoints:
44+
ep_components = e.split(".")
45+
46+
if ep_components[-1] not in ALGO_MODES:
47+
continue
48+
49+
if "graph" in ep_components:
50+
continue
51+
52+
if "pipeline" in ep_components:
53+
continue
54+
55+
ep_components = e.split(".")
56+
if len(ep_components) == 3:
57+
add_mode(ep_components[1], ep_components[2], prod_algo_endpoints)
58+
elif len(ep_components) == 4:
59+
if ep_components[1] == "alpha":
60+
add_mode(ep_components[2], ep_components[3], alpha_algo_endpoints)
61+
elif ep_components[1] == "beta":
62+
add_mode(ep_components[2], ep_components[3], beta_algo_endpoints)
63+
elif ep_components[1] == "allShortestPaths":
64+
add_mode(ep_components[2], ep_components[3], asp_algo_endpoints)
65+
elif ep_components[1] == "shortestPath":
66+
add_mode(ep_components[2], ep_components[3], sp_algo_endpoints)
67+
else:
68+
raise RuntimeError(f"Unable to handle algo endpoint '{e}'")
69+
else:
70+
# raise RuntimeError(f"Unable to handle algo endpoint '{e}'")
71+
print(e)
72+
73+
74+
def generate_algo_endpoint_builder(algo_name, algo_memberships):
75+
return f"""
76+
@property
77+
def {algo_name}(self):
78+
return CallerBase({', '.join(algo_memberships)})
79+
"""
80+
81+
82+
def populate_class(class_base, algo_pairs):
83+
algo_eps = os.linesep.join([generate_algo_endpoint_builder(name, classes) for name, classes in algo_pairs.items()])
84+
# for name, super_classes in algo_pairs.items():
85+
# class_base.append(generate_algo_endpoint_builder(name, super_classes))
86+
87+
return f"{class_base}{algo_eps}"
88+
89+
90+
collect_algo_endpoints(all_endpoints)
91+
92+
alpha_algos_class = "class AlphaAlgos(UncallableNamespace):"
93+
beta_algos_class = "class BetaAlgos(UncallableNamespace):"
94+
prod_algos_class = "class ProdAlgos(UncallableNamespace):"
95+
asp_algos_class = "class ASPAlgos(UncallableNamespace):"
96+
sp_algos_class = "class SPAlgos(UncallableNamespace):"
97+
98+
print(populate_class(alpha_algos_class, alpha_algo_endpoints))
99+
100+
# print(alpha_algos_class)

0 commit comments

Comments
 (0)