From 2082ef07d180f4a2a8fa8a081832f174d2862865 Mon Sep 17 00:00:00 2001 From: Anqi Date: Fri, 10 Feb 2023 13:45:19 +0800 Subject: [PATCH] add more example --- README-CN.md | 2 + README.md | 2 + .../algorithm/AlgoPerformanceTest.scala | 97 +++++++++++++++++++ 3 files changed, 101 insertions(+) create mode 100644 example/src/main/scala/com/vesoft/nebula/algorithm/AlgoPerformanceTest.scala diff --git a/README-CN.md b/README-CN.md index 493aa19..b83a3ac 100644 --- a/README-CN.md +++ b/README-CN.md @@ -21,6 +21,7 @@ nebula-algorithm 是一款基于 [GraphX](https://spark.apache.org/graphx/) 的 | ClusteringCoefficient | 聚集系数 |推荐、电信诈骗分析| | Jaccard |杰卡德相似度计算|相似度计算、推荐| | BFS |广度优先遍历 |层序遍历、最短路径规划| + | DFS |深度优先遍历 |层序遍历、最短路径规划| | Node2Vec | - |图分类| 使用 `nebula-algorithm`,可以通过提交 `Spark` 任务的形式使用完整的算法工具对 `Nebula Graph` 数据库中的数据执行图计算,也可以通过编程形式调用`lib`库下的算法针对DataFrame执行图计算。 @@ -101,6 +102,7 @@ nebula-algorithm 是一款基于 [GraphX](https://spark.apache.org/graphx/) 的 | closeness | closeness |double/string| | hanp | hanp | int/string | | bfs | bfs | string | +| dfs | dfs | string | | jaccard | jaccard | string | | node2vec | node2vec | string | diff --git a/README.md b/README.md index 3d1c030..ba4370e 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ nebula-algorithm is a Spark Application based on [GraphX](https://spark.apache.o | ClusteringCoefficient | recommended, telecom fraud analysis| | Jaccard | similarity calculation, recommendation| | BFS | sequence traversal, Shortest path plan| +| DFS | sequence traversal, Shortest path plan| | Node2Vec | graph machine learning, recommendation| @@ -111,6 +112,7 @@ If you want to write the algorithm execution result into NebulaGraph(`sink: nebu | closeness | closeness |double/string| | hanp | hanp | int/string | | bfs | bfs | string | +| bfs | dfs | string | | jaccard | jaccard | string | | node2vec | node2vec | string | diff --git a/example/src/main/scala/com/vesoft/nebula/algorithm/AlgoPerformanceTest.scala b/example/src/main/scala/com/vesoft/nebula/algorithm/AlgoPerformanceTest.scala new file mode 100644 index 0000000..7900b12 --- /dev/null +++ b/example/src/main/scala/com/vesoft/nebula/algorithm/AlgoPerformanceTest.scala @@ -0,0 +1,97 @@ +/* Copyright (c) 2022 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +package com.vesoft.nebula.algorithm + +import com.vesoft.nebula.connector.connector.{NebulaDataFrameReader} +import com.facebook.thrift.protocol.TCompactProtocol +import com.vesoft.nebula.algorithm.config.{CcConfig, LPAConfig, LouvainConfig, PRConfig} +import com.vesoft.nebula.algorithm.lib.{ + ConnectedComponentsAlgo, + LabelPropagationAlgo, + LouvainAlgo, + PageRankAlgo +} +import com.vesoft.nebula.connector.{NebulaConnectionConfig, ReadNebulaConfig} +import org.apache.spark.SparkConf +import org.apache.spark.sql.{DataFrame, SparkSession} + +object AlgoPerformanceTest { + + def main(args: Array[String]): Unit = { + val sparkConf = new SparkConf() + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .registerKryoClasses(Array[Class[_]](classOf[TCompactProtocol])) + val spark = SparkSession + .builder() + .config(sparkConf) + .getOrCreate() + + val df = readNebulaData(spark) + lpa(spark, df) + louvain(spark, df) + pagerank(spark, df) + wcc(spark, df) + + } + + def readNebulaData(spark: SparkSession): DataFrame = { + val start = System.currentTimeMillis() + val config = + NebulaConnectionConfig + .builder() + .withMetaAddress("127.0.0.0.1:9559") + .withTimeout(6000) + .withConenctionRetry(2) + .build() + val nebulaReadEdgeConfig: ReadNebulaConfig = ReadNebulaConfig + .builder() + .withSpace("twitter") + .withLabel("FOLLOW") + .withNoColumn(true) + .withLimit(20000) + .withPartitionNum(120) + .build() + val df: DataFrame = + spark.read.nebula(config, nebulaReadEdgeConfig).loadEdgesToDF() + df.cache() + df.count() + println(s"read data cost time ${(System.currentTimeMillis() - start)}") + df + } + + def lpa(spark: SparkSession, df: DataFrame): Unit = { + val start = System.currentTimeMillis() + val lpaConfig = LPAConfig(10) + val lpa = LabelPropagationAlgo.apply(spark, df, lpaConfig, false) + lpa.write.csv("hdfs://127.0.0.1:9000/tmp/lpa") + println(s"lpa compute and save cost ${System.currentTimeMillis() - start}") + } + + def pagerank(spark: SparkSession, df: DataFrame): Unit = { + val start = System.currentTimeMillis() + val pageRankConfig = PRConfig(10, 0.85) + val pr = PageRankAlgo.apply(spark, df, pageRankConfig, false) + pr.write.csv("hdfs://127.0.0.1:9000/tmp/pagerank") + println(s"pagerank compute and save cost ${System.currentTimeMillis() - start}") + } + + def wcc(spark: SparkSession, df: DataFrame): Unit = { + val start = System.currentTimeMillis() + val ccConfig = CcConfig(20) + val cc = ConnectedComponentsAlgo.apply(spark, df, ccConfig, false) + cc.write.csv("hdfs://127.0.0.1:9000/tmp/wcc") + println(s"wcc compute and save cost ${System.currentTimeMillis() - start}") + } + + def louvain(spark: SparkSession, df: DataFrame): Unit = { + val start = System.currentTimeMillis() + val louvainConfig = LouvainConfig(10, 5, 0.5) + val louvain = LouvainAlgo.apply(spark, df, louvainConfig, false) + louvain.write.csv("hdfs://127.0.0.1:9000/tmp/louvain") + println(s"louvain compute and save cost ${System.currentTimeMillis() - start}") + } + +}