diff --git a/spark/src/main/scala/com/alibaba/graphar/EdgeInfo.scala b/spark/src/main/scala/com/alibaba/graphar/EdgeInfo.scala
index 5ba1e82d9..7b351c150 100644
--- a/spark/src/main/scala/com/alibaba/graphar/EdgeInfo.scala
+++ b/spark/src/main/scala/com/alibaba/graphar/EdgeInfo.scala
@@ -426,6 +426,10 @@ class EdgeInfo() {
     str = prefix + getAdjListPrefix(adj_list_type) + str
     return str
   }
+
+  def getConcatKey(): String = {
+    return getSrc_label + GeneralParams.regularSeperator + getEdge_label + GeneralParams.regularSeperator + getDst_label
+  }
 }
 
 /** Helper object to load edge info files */
diff --git a/spark/src/main/scala/com/alibaba/graphar/GraphInfo.scala b/spark/src/main/scala/com/alibaba/graphar/GraphInfo.scala
index 86fccd0a1..a60ff858f 100644
--- a/spark/src/main/scala/com/alibaba/graphar/GraphInfo.scala
+++ b/spark/src/main/scala/com/alibaba/graphar/GraphInfo.scala
@@ -203,6 +203,34 @@ class GraphInfo() {
   @BeanProperty var vertices = new java.util.ArrayList[String]()
   @BeanProperty var edges = new java.util.ArrayList[String]()
   @BeanProperty var version: String = ""
+
+  var vertexInfos: Map[String, VertexInfo] = Map[String, VertexInfo]()
+  var edgeInfos: Map[String, EdgeInfo] = Map[String, EdgeInfo]()
+
+  def addVertexInfo(vertexInfo: VertexInfo): Unit = {
+    vertexInfos += (vertexInfo.getLabel -> vertexInfo)
+  }
+
+  def addEdgeInfo(edgeInfo: EdgeInfo): Unit = {
+    edgeInfos += (edgeInfo.getConcatKey() -> edgeInfo)
+  }
+
+  def getVertexInfo(label: String): VertexInfo = {
+    vertexInfos(label)
+  }
+
+  def getEdgeInfo(srcLabel: String, edgeLabel: String, dstLabel: String): EdgeInfo = {
+    val key = srcLabel + GeneralParams.regularSeperator + edgeLabel + GeneralParams.regularSeperator + dstLabel
+    edgeInfos(key)
+  }
+
+  def getVertexInfos(): Map[String, VertexInfo] = {
+    return vertexInfos
+  }
+
+  def getEdgeInfos(): Map[String, EdgeInfo] = {
+    return edgeInfos
+  }
 }
 
 /** Helper object to load graph info files */
@@ -221,6 +249,23 @@ object GraphInfo {
         graph_info.setPrefix(prefix)
       }
     }
+    val prefix = graph_info.getPrefix
+    val vertices_yaml = graph_info.getVertices
+    val vertices_it = vertices_yaml.iterator
+    while (vertices_it.hasNext()) {
+      val file_name = vertices_it.next()
+      val path = prefix + file_name
+      val vertex_info = VertexInfo.loadVertexInfo(path, spark)
+      graph_info.addVertexInfo(vertex_info)
+    }
+    val edges_yaml = graph_info.getEdges
+    val edges_it = edges_yaml.iterator
+    while (edges_it.hasNext()) {
+      val file_name = edges_it.next()
+      val path = prefix + file_name
+      val edge_info = EdgeInfo.loadEdgeInfo(path, spark)
+      graph_info.addEdgeInfo(edge_info)
+    }
     return graph_info
   }
 }
diff --git a/spark/src/main/scala/com/alibaba/graphar/graph/GraphReader.scala b/spark/src/main/scala/com/alibaba/graphar/graph/GraphReader.scala
new file mode 100644
index 000000000..b336fabe3
--- /dev/null
+++ b/spark/src/main/scala/com/alibaba/graphar/graph/GraphReader.scala
@@ -0,0 +1,94 @@
+/** Copyright 2022 Alibaba Group Holding Limited.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.graphar.graph
+
+import com.alibaba.graphar.{GeneralParams, AdjListType, GraphInfo, VertexInfo, EdgeInfo}
+import com.alibaba.graphar.reader.{VertexReader, EdgeReader}
+
+import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.functions._
+
+/** The helper object for reading graphs through the definitions of graph info. */
+object GraphReader {
+  /** Loads the vertex chunks as DataFrames with the vertex infos.
+   *
+   * @param prefix The absolute prefix.
+   * @param vertexInfos The map of (vertex label -> VertexInfo) for the graph.
+   * @param spark The Spark session for the reading.
+   * @return The map of (vertex label -> DataFrame)
+   */
+  private def readAllVertices(prefix: String, vertexInfos: Map[String, VertexInfo], spark: SparkSession): Map[String, DataFrame] = {
+    val vertex_dataframes: Map[String, DataFrame] = vertexInfos.map { case (label, vertexInfo) => {
+      val reader = new VertexReader(prefix, vertexInfo, spark)
+      (label, reader.readAllVertexPropertyGroups(false))
+    }}
+    return vertex_dataframes
+  }
+
+  /** Loads the edge chunks as DataFrames with the edge infos.
+   *
+   * @param prefix The absolute prefix.
+   * @param edgeInfos The map of (srcLabel_edgeLabel_dstLabel -> EdgeInfo) for the graph.
+   * @param spark The Spark session for the reading.
+   * @return The map of (srcLabel_edgeLabel_dstLabel -> (adj_list_type_str -> DataFrame))
+   */
+  private def readAllEdges(prefix: String, edgeInfos: Map[String, EdgeInfo], spark: SparkSession): Map[String, Map[String, DataFrame]] = {
+    val edge_dataframes: Map[String, Map[String, DataFrame]] = edgeInfos.map { case (key, edgeInfo) => {
+      val adj_lists = edgeInfo.getAdj_lists
+      val adj_list_it = adj_lists.iterator
+      var adj_list_type_edge_df_map: Map[String, DataFrame] = Map[String, DataFrame]()
+      while (adj_list_it.hasNext()) {
+        val adj_list = adj_list_it.next()
+        val adj_list_type = adj_list.getAdjList_type_in_gar
+        val adj_list_type_str = adj_list.getAdjList_type
+        val reader = new EdgeReader(prefix, edgeInfo, adj_list_type, spark)
+        adj_list_type_edge_df_map += (adj_list_type_str -> reader.readEdges(false))
+      }
+      (key, adj_list_type_edge_df_map)
+    }}
+    return edge_dataframes
+  }
+
+  /** Reads the graph as DataFrames with the graph info object.
+   *
+   * @param graphInfo The info object for the graph.
+   * @param spark The Spark session for the loading.
+   * @return Pair of vertex DataFrames and edge DataFrames: the vertex DataFrames are stored as a map of (vertex_label -> DataFrame),
+   *         and the edge DataFrames are stored as a map of (srcLabel_edgeLabel_dstLabel -> (adj_list_type_str -> DataFrame)).
+   */
+  def read(graphInfo: GraphInfo, spark: SparkSession): Pair[Map[String, DataFrame], Map[String, Map[String, DataFrame]]] = {
+    val prefix = graphInfo.getPrefix
+    val vertex_infos = graphInfo.getVertexInfos()
+    val edge_infos = graphInfo.getEdgeInfos()
+    return (readAllVertices(prefix, vertex_infos, spark), readAllEdges(prefix, edge_infos, spark))
+  }
+
+  /** Reads the graph as DataFrames with the graph info yaml file.
+   *
+   * @param graphInfoPath The path of the graph info yaml.
+   * @param spark The Spark session for the loading.
+   * @return Pair of vertex DataFrames and edge DataFrames: the vertex DataFrames are stored as a map of (vertex_label -> DataFrame),
+   *         and the edge DataFrames are stored as a map of (srcLabel_edgeLabel_dstLabel -> (adj_list_type_str -> DataFrame)).
+   */
+  def read(graphInfoPath: String, spark: SparkSession): Pair[Map[String, DataFrame], Map[String, Map[String, DataFrame]]] = {
+    // load graph info
+    val graph_info = GraphInfo.loadGraphInfo(graphInfoPath, spark)
+
+    // conduct reading
+    read(graph_info, spark)
+  }
+}
diff --git a/spark/src/main/scala/com/alibaba/graphar/graph/GraphTransformer.scala b/spark/src/main/scala/com/alibaba/graphar/graph/GraphTransformer.scala
index bb2345ed2..02023bca6 100644
--- a/spark/src/main/scala/com/alibaba/graphar/graph/GraphTransformer.scala
+++ b/spark/src/main/scala/com/alibaba/graphar/graph/GraphTransformer.scala
@@ -25,35 +25,6 @@ import org.apache.spark.sql.functions._
 
 /** The helper object for transforming graphs through the definitions of their infos. */
 object GraphTransformer {
-  /** Construct the map of (vertex label -> VertexInfo) for a graph. */
-  private def constructVertexInfoMap(prefix: String, graphInfo: GraphInfo, spark: SparkSession): Map[String, VertexInfo] = {
-    var vertex_infos_map: Map[String, VertexInfo] = Map()
-    val vertices_yaml = graphInfo.getVertices
-    val vertices_it = vertices_yaml.iterator
-    while (vertices_it.hasNext()) {
-      val file_name = vertices_it.next()
-      val path = prefix + file_name
-      val vertex_info = VertexInfo.loadVertexInfo(path, spark)
-      vertex_infos_map += (vertex_info.getLabel -> vertex_info)
-    }
-    return vertex_infos_map
-  }
-
-  /** Construct the map of (edge label -> EdgeInfo) for a graph. */
-  private def constructEdgeInfoMap(prefix: String, graphInfo: GraphInfo, spark: SparkSession): Map[String, EdgeInfo] = {
-    var edge_infos_map: Map[String, EdgeInfo] = Map()
-    val edges_yaml = graphInfo.getEdges
-    val edges_it = edges_yaml.iterator
-    while (edges_it.hasNext()) {
-      val file_name = edges_it.next()
-      val path = prefix + file_name
-      val edge_info = EdgeInfo.loadEdgeInfo(path, spark)
-      val key = edge_info.getSrc_label + GeneralParams.regularSeperator + edge_info.getEdge_label + GeneralParams.regularSeperator + edge_info.getDst_label
-      edge_infos_map += (key -> edge_info)
-    }
-    return edge_infos_map
-  }
-
   /** Transform the vertex chunks following the meta data defined in graph info objects.
    *
    * @param sourceGraphInfo The info object for the source graph.
@@ -105,7 +76,7 @@
       val path = dest_prefix + dest_edges_it.next()
       val dest_edge_info = EdgeInfo.loadEdgeInfo(path, spark)
       // load source edge info
-      val key = dest_edge_info.getSrc_label + GeneralParams.regularSeperator + dest_edge_info.getEdge_label + GeneralParams.regularSeperator + dest_edge_info.getDst_label
+      val key = dest_edge_info.getConcatKey()
       if (!sourceEdgeInfosMap.contains(key)) {
         throw new IllegalArgumentException
       }
@@ -160,19 +131,11 @@
    * @param spark The Spark session for the transformer.
    */
   def transform(sourceGraphInfo: GraphInfo, destGraphInfo: GraphInfo, spark: SparkSession): Unit = {
-    val source_prefix = sourceGraphInfo.getPrefix
-    val dest_prefix = destGraphInfo.getPrefix
-
-    // construct the (vertex label -> vertex info) map for the source graph
-    val source_vertex_infos_map = constructVertexInfoMap(source_prefix, sourceGraphInfo, spark)
-    // construct the (edge label -> edge info) map for the source graph
-    val source_edge_infos_map = constructEdgeInfoMap(source_prefix, sourceGraphInfo, spark)
-
     // transform and generate vertex data chunks
-    transformAllVertices(sourceGraphInfo, destGraphInfo, source_vertex_infos_map, spark)
+    transformAllVertices(sourceGraphInfo, destGraphInfo, sourceGraphInfo.getVertexInfos(), spark)
 
     // transform and generate edge data chunks
-    transformAllEdges(sourceGraphInfo, destGraphInfo, source_vertex_infos_map, source_edge_infos_map, spark)
+    transformAllEdges(sourceGraphInfo, destGraphInfo, sourceGraphInfo.getVertexInfos(), sourceGraphInfo.getEdgeInfos(), spark)
   }
 
   /** Transform the graphs following the meta data defined in info files.
diff --git a/spark/src/main/scala/com/alibaba/graphar/graph/GraphWriter.scala b/spark/src/main/scala/com/alibaba/graphar/graph/GraphWriter.scala
new file mode 100644
index 000000000..ea7a5dab5
--- /dev/null
+++ b/spark/src/main/scala/com/alibaba/graphar/graph/GraphWriter.scala
@@ -0,0 +1,134 @@
+/** Copyright 2022 Alibaba Group Holding Limited.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.graphar.graph
+
+import com.alibaba.graphar.{AdjListType, GraphInfo, VertexInfo, EdgeInfo}
+import com.alibaba.graphar.writer.{VertexWriter, EdgeWriter}
+import com.alibaba.graphar.utils.IndexGenerator
+
+import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.functions._
+
+/** The helper object for writing graphs through the definitions of graph info. */
+object GraphWriter {
+  /** Writes the vertex DataFrames to GAR with the vertex infos.
+   *
+   * @param prefix The absolute prefix.
+   * @param vertexInfos The map of (vertex label -> VertexInfo) for the graph.
+   * @param vertexNumMap The vertex numbers, a map of (vertex label -> vertex num).
+   * @param vertexDataFrames The vertex DataFrames, a map of (vertex label -> DataFrame).
+   * @param spark The Spark session for the writing.
+   */
+  private def writeAllVertices(prefix: String,
+                               vertexInfos: Map[String, VertexInfo],
+                               vertexNumMap: Map[String, Long],
+                               vertexDataFrames: Map[String, DataFrame],
+                               spark: SparkSession): Unit = {
+    vertexInfos.foreach { case (label, vertexInfo) => {
+      val vertex_num = vertexNumMap(label)
+      val df_with_index = IndexGenerator.generateVertexIndexColumn(vertexDataFrames(label))
+      val writer = new VertexWriter(prefix, vertexInfo, df_with_index, Some(vertex_num))
+      writer.writeVertexProperties()
+    }}
+  }
+
+  /** Writes the edge DataFrames to GAR with the vertex and edge infos.
+   *
+   * @param prefix The absolute prefix.
+   * @param vertexInfos The map of (vertex label -> VertexInfo) for the graph.
+   * @param edgeInfos The map of (srcLabel_edgeLabel_dstLabel -> EdgeInfo) for the graph.
+   * @param vertexNumMap The vertex numbers, a map of (vertex label -> vertex num).
+   * @param vertexDataFrames The vertex DataFrames, a map of (vertex label -> DataFrame).
+   * @param edgeDataFrames The edge DataFrames, a map of (srcLabel_edgeLabel_dstLabel -> DataFrame).
+   * @param spark The Spark session for the writing.
+   */
+  private def writeAllEdges(prefix: String,
+                            vertexInfos: Map[String, VertexInfo],
+                            edgeInfos: Map[String, EdgeInfo],
+                            vertexNumMap: Map[String, Long],
+                            vertexDataFrames: Map[String, DataFrame],
+                            edgeDataFrames: Map[String, DataFrame],
+                            spark: SparkSession): Unit = {
+    edgeInfos.foreach { case (key, edgeInfo) => {
+      val srcLabel = edgeInfo.getSrc_label
+      val dstLabel = edgeInfo.getDst_label
+      val edge_key = edgeInfo.getConcatKey()
+      val src_vertex_index_mapping = IndexGenerator.constructVertexIndexMapping(vertexDataFrames(srcLabel), vertexInfos(srcLabel).getPrimaryKey())
+      val dst_vertex_index_mapping = {
+        if (srcLabel == dstLabel)
+          src_vertex_index_mapping
+        else
+          IndexGenerator.constructVertexIndexMapping(vertexDataFrames(dstLabel), vertexInfos(dstLabel).getPrimaryKey())
+      }
+      val edge_df_with_index = IndexGenerator.generateSrcAndDstIndexForEdgesFromMapping(edgeDataFrames(edge_key), src_vertex_index_mapping, dst_vertex_index_mapping)
+
+      val adj_lists = edgeInfo.getAdj_lists
+      val adj_list_it = adj_lists.iterator
+      while (adj_list_it.hasNext()) {
+        val adj_list_type = adj_list_it.next().getAdjList_type_in_gar
+        val vertex_num = {
+          if (adj_list_type == AdjListType.ordered_by_source || adj_list_type == AdjListType.unordered_by_source) {
+            vertexNumMap(srcLabel)
+          } else {
+            vertexNumMap(dstLabel)
+          }
+        }
+        val writer = new EdgeWriter(prefix, edgeInfo, adj_list_type, vertex_num, edge_df_with_index)
+        writer.writeEdges()
+      }
+    }}
+  }
+
+  /** Writes the graph DataFrames to GAR with the graph info object.
+   *
+   * @param graphInfo The graph info object of the graph.
+   * @param vertexDataFrames The vertex DataFrames, a map of (vertex label -> DataFrame).
+   * @param edgeDataFrames The edge DataFrames, a map of (srcLabel_edgeLabel_dstLabel -> DataFrame).
+   * @param spark The Spark session for the writing.
+   */
+  def write(graphInfo: GraphInfo,
+            vertexDataFrames: Map[String, DataFrame],
+            edgeDataFrames: Map[String, DataFrame],
+            spark: SparkSession): Unit = {
+    // get the vertex num of each vertex dataframe
+    val vertex_num_map: Map[String, Long] = vertexDataFrames.map { case (k, v) => (k, v.count()) }
+    val prefix = graphInfo.getPrefix
+    val vertex_infos = graphInfo.getVertexInfos()
+    val edge_infos = graphInfo.getEdgeInfos()
+
+    // write vertices
+    writeAllVertices(prefix, vertex_infos, vertex_num_map, vertexDataFrames, spark)
+
+    // write edges
+    writeAllEdges(prefix, vertex_infos, edge_infos, vertex_num_map, vertexDataFrames, edgeDataFrames, spark)
+  }
+
+  /** Writes the graph DataFrames to GAR with the graph info yaml file.
+   *
+   * @param graphInfoPath The path of the graph info yaml.
+   * @param vertexDataFrames The vertex DataFrames, a map of (vertex label -> DataFrame).
+   * @param edgeDataFrames The edge DataFrames, a map of (srcLabel_edgeLabel_dstLabel -> DataFrame).
+   * @param spark The Spark session for the writing.
+   */
+  def write(graphInfoPath: String, vertexDataFrames: Map[String, DataFrame], edgeDataFrames: Map[String, DataFrame], spark: SparkSession): Unit = {
+    // load graph info
+    val graph_info = GraphInfo.loadGraphInfo(graphInfoPath, spark)
+
+    // conduct writing
+    write(graph_info, vertexDataFrames, edgeDataFrames, spark)
+  }
+}
diff --git a/spark/src/main/scala/com/alibaba/graphar/utils/IndexGenerator.scala b/spark/src/main/scala/com/alibaba/graphar/utils/IndexGenerator.scala
index 32c04c4d8..d70cfd390 100644
--- a/spark/src/main/scala/com/alibaba/graphar/utils/IndexGenerator.scala
+++ b/spark/src/main/scala/com/alibaba/graphar/utils/IndexGenerator.scala
@@ -142,6 +142,14 @@ object IndexGenerator {
     generateDstIndexForEdgesFromMapping(df_with_src_index, dstColumnName, dstIndexMapping)
   }
 
+  /** Add src and dst index columns for an edge DataFrame, assuming that its first and second columns are the src and dst columns. */
+  def generateSrcAndDstIndexForEdgesFromMapping(edgeDf: DataFrame, srcIndexMapping: DataFrame, dstIndexMapping: DataFrame): DataFrame = {
+    val srcColName: String = edgeDf.columns(0)
+    val dstColName: String = edgeDf.columns(1)
+    val df_with_src_index = generateSrcIndexForEdgesFromMapping(edgeDf, srcColName, srcIndexMapping)
+    generateDstIndexForEdgesFromMapping(df_with_src_index, dstColName, dstIndexMapping)
+  }
+
   /** Construct vertex index for source column. */
   def generateSrcIndexForEdges(edgeDf: DataFrame, srcColumnName: String): DataFrame = {
     val srcDf = edgeDf.select(srcColumnName).distinct()
diff --git a/spark/src/test/scala/com/alibaba/graphar/TestGraphInfo.scala b/spark/src/test/scala/com/alibaba/graphar/TestGraphInfo.scala
index 0c33eb319..bc3b25945 100644
--- a/spark/src/test/scala/com/alibaba/graphar/TestGraphInfo.scala
+++ b/spark/src/test/scala/com/alibaba/graphar/TestGraphInfo.scala
@@ -34,6 +34,9 @@ class GraphInfoSuite extends AnyFunSuite {
     val prefix = getClass.getClassLoader.getResource("gar-test/ldbc_sample/csv/").getPath
     val graph_info = GraphInfo.loadGraphInfo(yaml_path, spark)
 
+    val vertex_info = graph_info.getVertexInfo("person")
+    assert(vertex_info.getLabel == "person")
+
     assert(graph_info.getName == "ldbc_sample")
     assert(graph_info.getPrefix == prefix )
     assert(graph_info.getVertices.size() == 1)
diff --git a/spark/src/test/scala/com/alibaba/graphar/TestGraphReader.scala b/spark/src/test/scala/com/alibaba/graphar/TestGraphReader.scala
new file mode 100644
index 000000000..1073e697a
--- /dev/null
+++ b/spark/src/test/scala/com/alibaba/graphar/TestGraphReader.scala
@@ -0,0 +1,85 @@
+/** Copyright 2022 Alibaba Group Holding Limited.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.graphar
+
+import com.alibaba.graphar.GraphInfo
+import com.alibaba.graphar.graph.GraphReader
+
+import java.io.{File, FileInputStream}
+import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.hadoop.fs.{Path, FileSystem}
+import org.scalatest.funsuite.AnyFunSuite
+
+class TestGraphReaderSuite extends AnyFunSuite {
+  val spark = SparkSession.builder()
+    .enableHiveSupport()
+    .master("local[*]")
+    .getOrCreate()
+
+  test("read graphs by yaml paths") {
+    // conduct reading
+    val graph_path = getClass.getClassLoader.getResource("gar-test/ldbc_sample/parquet/ldbc_sample.graph.yml").getPath
+    val vertex_edge_df_pair = GraphReader.read(graph_path, spark)
+    val vertex_dataframes = vertex_edge_df_pair._1
+    val edge_dataframes = vertex_edge_df_pair._2
+
+    assert(vertex_dataframes.size == 1)
+    assert(vertex_dataframes contains "person")
+    val person_df = vertex_dataframes("person")
+    assert(person_df.columns.size == 4)
+    assert(person_df.count() == 903)
+
+    assert(edge_dataframes.size == 1)
+    assert(edge_dataframes contains "person_knows_person")
+    val adj_list_type_dataframes = edge_dataframes("person_knows_person")
+    assert(adj_list_type_dataframes.size == 3)
+  }
+
+  test("read graphs by graph infos") {
+    // load graph info
+    val path = getClass.getClassLoader.getResource("gar-test/ldbc_sample/parquet/ldbc_sample.graph.yml").getPath
+    val graph_info = GraphInfo.loadGraphInfo(path, spark)
+
+    // conduct reading
+    val vertex_edge_df_pair = GraphReader.read(graph_info, spark)
+    val vertex_dataframes = vertex_edge_df_pair._1
+    val edge_dataframes = vertex_edge_df_pair._2
+
+    assert(vertex_dataframes.size == 1)
+    assert(vertex_dataframes contains "person")
+    val person_df = vertex_dataframes("person")
+    assert(person_df.columns.size == 4)
+    assert(person_df.count() == 903)
+
+    val edgeInfos = graph_info.getEdgeInfos()
+    assert(edge_dataframes.size == edgeInfos.size)
+    edgeInfos.foreach { case (key, edgeInfo) => {
+      assert(edge_dataframes contains key)
+      val adj_list_type_dataframes = edge_dataframes(key)
+      val adj_lists = edgeInfo.getAdj_lists
+      assert(adj_list_type_dataframes.size == adj_lists.size)
+      val adj_list_it = adj_lists.iterator
+      while (adj_list_it.hasNext()) {
+        val adj_list = adj_list_it.next()
+        val adj_list_type = adj_list.getAdjList_type_in_gar
+        val adj_list_type_str = adj_list.getAdjList_type
+        assert(adj_list_type_dataframes contains adj_list_type_str)
+        val df = adj_list_type_dataframes(adj_list_type_str)
+        assert(df.count == 6626)
+      }
+    }}
+  }
+}
diff --git a/spark/src/test/scala/com/alibaba/graphar/TestGraphWriter.scala b/spark/src/test/scala/com/alibaba/graphar/TestGraphWriter.scala
new file mode 100644
index 000000000..55881b70c
--- /dev/null
+++ b/spark/src/test/scala/com/alibaba/graphar/TestGraphWriter.scala
@@ -0,0 +1,82 @@
+/** Copyright 2022 Alibaba Group Holding Limited.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.graphar
+
+import com.alibaba.graphar.GraphInfo
+import com.alibaba.graphar.graph.GraphWriter
+import com.alibaba.graphar.utils
+
+import java.io.{File, FileInputStream}
+import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.hadoop.fs.{Path, FileSystem}
+import org.scalatest.funsuite.AnyFunSuite
+
+class TestGraphWriterSuite extends AnyFunSuite {
+  val spark = SparkSession.builder()
+    .enableHiveSupport()
+    .master("local[*]")
+    .getOrCreate()
+
+  test("write graphs by graph infos") {
+    // load graph info
+    val path = getClass.getClassLoader.getResource("gar-test/ldbc_sample/parquet/ldbc_sample.graph.yml").getPath
+    val graph_info = GraphInfo.loadGraphInfo(path, spark)
+    val prefix = "/tmp/test_graph_writer"
+    graph_info.setPrefix(prefix)  // avoid overwriting the gar-test files
+    val fs = FileSystem.get(new Path(prefix).toUri(), spark.sparkContext.hadoopConfiguration)
+
+    // prepare the dataframes
+    val vertex_file_path = getClass.getClassLoader.getResource("gar-test/ldbc_sample/person_0_0.csv").getPath
+    val vertex_df = spark.read.option("delimiter", "|").option("header", "true").csv(vertex_file_path)
+    val vertex_dataframes: Map[String, DataFrame] = Map("person" -> vertex_df)
+    val file_path = getClass.getClassLoader.getResource("gar-test/ldbc_sample/person_knows_person_0_0.csv").getPath
+    val edge_df = spark.read.option("delimiter", "|").option("header", "true").csv(file_path)
+    val edge_dataframes: Map[String, DataFrame] = Map("person_knows_person" -> edge_df)
+
+    // conduct writing
+    GraphWriter.write(graph_info, vertex_dataframes, edge_dataframes, spark)
+    val vertex_info = graph_info.getVertexInfo("person")
+    val chunk_path = new Path(prefix + vertex_info.getPrefix() + "*/*")
+    val chunk_files = fs.globStatus(chunk_path)
+    assert(chunk_files.length == 20)
+    val vertex_num_path = prefix + vertex_info.getVerticesNumFilePath()
+    val number = utils.FileSystem.readValue(vertex_num_path, spark.sparkContext.hadoopConfiguration)
+    assert(number.toInt == vertex_df.count())
+
+    val edgeInfos = graph_info.getEdgeInfos()
+    edgeInfos.foreach { case (key, edgeInfo) => {
+      val adj_lists = edgeInfo.getAdj_lists
+      val adj_list_it = adj_lists.iterator
+      while (adj_list_it.hasNext()) {
+        val adj_list = adj_list_it.next()
+        val adj_list_type = adj_list.getAdjList_type_in_gar
+        val adj_list_type_str = adj_list.getAdjList_type
+        val adj_list_path_pattern = new Path(prefix + edgeInfo.getAdjListPathPrefix(adj_list_type) + "*/*")
+        val adj_list_chunk_files = fs.globStatus(adj_list_path_pattern)
+        assert(adj_list_chunk_files.length > 0)
+        if (adj_list_type == AdjListType.ordered_by_source || adj_list_type == AdjListType.ordered_by_dest) {
+          val offset_path_pattern = new Path(prefix + edgeInfo.getOffsetPathPrefix(adj_list_type) + "*")
+          val offset_chunk_files = fs.globStatus(offset_path_pattern)
+          assert(offset_chunk_files.length > 0)
+        }
+      }
+    }}
+
+    // cleaning generated files
+    fs.delete(new Path(prefix))
+    fs.close()
+  }
+}
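
Reviewer note (not part of the patch): a minimal sketch of how the new GraphReader entry point could be driven end to end, based only on the signatures added above. The SparkSession settings and the graph info path are illustrative; the "person" and "person_knows_person" keys match the ldbc_sample data used in TestGraphReader.scala.

import com.alibaba.graphar.graph.GraphReader
import org.apache.spark.sql.SparkSession

object GraphReaderExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("graphar-read-example")
      .master("local[*]")
      .getOrCreate()

    // Illustrative path to a graph info yaml; point it at your own GAR data.
    val graphInfoPath = "/path/to/ldbc_sample.graph.yml"

    // read() returns a pair: (vertex label -> DataFrame,
    //                         srcLabel_edgeLabel_dstLabel -> (adj_list_type_str -> DataFrame))
    val (vertexDataFrames, edgeDataFrames) = GraphReader.read(graphInfoPath, spark)

    vertexDataFrames("person").show(5)
    edgeDataFrames("person_knows_person").foreach { case (adjListType, df) =>
      println(s"$adjListType: ${df.count()} edges")
    }

    spark.stop()
  }
}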
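
Reviewer note (not part of the patch): a corresponding sketch for the new GraphWriter entry point, under the same assumptions. Keys must match the labels declared in the graph info yaml, and (per the new IndexGenerator helper) the first two columns of each edge DataFrame are taken as the src and dst key columns. The CSV paths and delimiter mirror the new test suites and are illustrative.

import com.alibaba.graphar.graph.GraphWriter
import org.apache.spark.sql.SparkSession

object GraphWriterExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("graphar-write-example")
      .master("local[*]")
      .getOrCreate()

    // Source tables; paths are illustrative.
    val personDf = spark.read.option("delimiter", "|").option("header", "true").csv("/path/to/person_0_0.csv")
    val knowsDf = spark.read.option("delimiter", "|").option("header", "true").csv("/path/to/person_knows_person_0_0.csv")

    // Vertex DataFrames are keyed by vertex label, edge DataFrames by srcLabel_edgeLabel_dstLabel.
    GraphWriter.write(
      "/path/to/ldbc_sample.graph.yml",
      Map("person" -> personDf),
      Map("person_knows_person" -> knowsDf),
      spark)

    spark.stop()
  }
}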