[Improvement][Spark] Provide APIs for data reading and writing at the graph level #114

Merged · 9 commits · Mar 1, 2023
Changes from 6 commits
4 changes: 4 additions & 0 deletions spark/src/main/scala/com/alibaba/graphar/EdgeInfo.scala
@@ -426,6 +426,10 @@ class EdgeInfo() {
str = prefix + getAdjListPrefix(adj_list_type) + str
return str
}

def getConcatKey(): String = {
return getSrc_label + GeneralParams.regularSeperator + getEdge_label + GeneralParams.regularSeperator + getDst_label
}
}

/** Helper object to load edge info files */
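The new getConcatKey helper builds the unique lookup key for an edge triplet by joining the source label, edge label, and destination label with GeneralParams.regularSeperator. A minimal usage sketch, assuming an edge info for (person, knows, person) and a "_" separator (as the tests later in this PR suggest); the path is hypothetical:

val edge_info = EdgeInfo.loadEdgeInfo("/path/to/person_knows_person.edge.yml", spark)  // hypothetical path
edge_info.getConcatKey()  // yields "person_knows_person"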
45 changes: 45 additions & 0 deletions spark/src/main/scala/com/alibaba/graphar/GraphInfo.scala
@@ -203,6 +203,34 @@ class GraphInfo() {
@BeanProperty var vertices = new java.util.ArrayList[String]()
@BeanProperty var edges = new java.util.ArrayList[String]()
@BeanProperty var version: String = ""

var vertexInfos: Map[String, VertexInfo] = Map[String, VertexInfo]()
var edgeInfos: Map[String, EdgeInfo] = Map[String, EdgeInfo]()

def addVertexInfo(vertexInfo: VertexInfo): Unit = {
vertexInfos += (vertexInfo.getLabel -> vertexInfo)
}

def addEdgeInfo(edgeInfo: EdgeInfo): Unit = {
edgeInfos += (edgeInfo.getConcatKey() -> edgeInfo)
}

def getVertexInfo(label: String): VertexInfo = {
vertexInfos(label)
}

def getEdgeInfo(srcLabel: String, edgeLabel: String, dstLabel: String): EdgeInfo = {
val key = srcLabel + GeneralParams.regularSeperator + edgeLabel + GeneralParams.regularSeperator + dstLabel
edgeInfos(key)
}

def getVertexInfos(): Map[String, VertexInfo] = {
return vertexInfos
}

def getEdgeInfos(): Map[String, EdgeInfo] = {
return edgeInfos
}
}

/** Helper object to load graph info files */
@@ -221,6 +249,23 @@ object GraphInfo {
graph_info.setPrefix(prefix)
}
}
val prefix = graph_info.getPrefix
val vertices_yaml = graph_info.getVertices
val vertices_it = vertices_yaml.iterator
while (vertices_it.hasNext()) {
val file_name = vertices_it.next()
val path = prefix + file_name
val vertex_info = VertexInfo.loadVertexInfo(path, spark)
graph_info.addVertexInfo(vertex_info)
}
val edges_yaml = graph_info.getEdges
val edges_it = edges_yaml.iterator
while (edges_it.hasNext()) {
val file_name = edges_it.next()
val path = prefix + file_name
val edge_info = EdgeInfo.loadEdgeInfo(path, spark)
graph_info.addEdgeInfo(edge_info)
}
return graph_info
}
}
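With these changes, loadGraphInfo eagerly loads every vertex and edge info file listed in the graph YAML, so callers can fetch them by label afterwards. A short sketch, assuming the ldbc_sample data used in the tests below; the path is hypothetical:

val graph_info = GraphInfo.loadGraphInfo("/path/to/ldbc_sample.graph.yml", spark)  // hypothetical path
val person_info = graph_info.getVertexInfo("person")
val knows_info = graph_info.getEdgeInfo("person", "knows", "person")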
66 changes: 66 additions & 0 deletions spark/src/main/scala/com/alibaba/graphar/graph/GraphReader.scala
@@ -0,0 +1,66 @@
/** Copyright 2022 Alibaba Group Holding Limited.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.alibaba.graphar.graph

import com.alibaba.graphar.{GeneralParams, AdjListType, GraphInfo, VertexInfo, EdgeInfo}
import com.alibaba.graphar.reader.{VertexReader, EdgeReader}

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._

/** The helper object for reading graphs through the definitions of their infos. */
object GraphReader {
private def readAllVertices(prefix: String, vertexInfos: Map[String, VertexInfo], spark: SparkSession): Map[String, DataFrame] = {
val vertex_dataframes: Map[String, DataFrame] = vertexInfos.map { case (label, vertexInfo) => {
val reader = new VertexReader(prefix, vertexInfo, spark)
(label, reader.readAllVertexPropertyGroups(false))
}}
return vertex_dataframes
}

private def readAllEdges(prefix: String, edgeInfos: Map[String, EdgeInfo], spark: SparkSession): Map[String, Map[String, DataFrame]] = {
val edge_dataframes: Map[String, Map[String, DataFrame]] = edgeInfos.map { case (key, edgeInfo) => {
val adj_lists = edgeInfo.getAdj_lists
val adj_list_it = adj_lists.iterator
var adj_list_type_edge_df_map: Map[String, DataFrame] = Map[String, DataFrame]()
while (adj_list_it.hasNext()) {
val adj_list = adj_list_it.next()
val adj_list_type = adj_list.getAdjList_type_in_gar
val adj_list_type_str = adj_list.getAdjList_type
val reader = new EdgeReader(prefix, edgeInfo, adj_list_type, spark)
adj_list_type_edge_df_map += (adj_list_type_str -> reader.readEdges(false))
}
(key, adj_list_type_edge_df_map)
}}
return edge_dataframes
}

def read(graphInfo: GraphInfo, spark: SparkSession): Pair[Map[String, DataFrame], Map[String, Map[String, DataFrame]]] = {
val prefix = graphInfo.getPrefix
val vertex_infos = graphInfo.getVertexInfos()
val edge_infos = graphInfo.getEdgeInfos()
return (readAllVertices(prefix, vertex_infos, spark), readAllEdges(prefix, edge_infos, spark))
}

def read(graphInfoPath: String, spark: SparkSession): Pair[Map[String, DataFrame], Map[String, Map[String, DataFrame]]] = {
// load graph info
val graph_info = GraphInfo.loadGraphInfo(graphInfoPath, spark)

// conduct reading
read(graph_info, spark)
}
}
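GraphReader.read thus returns a pair of nested maps: vertex label -> property DataFrame, and concatenated edge key -> adjacency-list type -> edge DataFrame. A hedged usage sketch under the same ldbc_sample assumptions:

val (vertex_dfs, edge_dfs) = GraphReader.read("/path/to/ldbc_sample.graph.yml", spark)  // hypothetical path
val person_df = vertex_dfs("person")  // all vertex property groups, joined
val knows_df = edge_dfs("person_knows_person")("ordered_by_source")  // one DataFrame per adj list type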
99 changes: 99 additions & 0 deletions spark/src/main/scala/com/alibaba/graphar/graph/GraphWriter.scala
@@ -0,0 +1,99 @@
/** Copyright 2022 Alibaba Group Holding Limited.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.alibaba.graphar.graph

import com.alibaba.graphar.{AdjListType, GraphInfo, VertexInfo, EdgeInfo}
import com.alibaba.graphar.writer.{VertexWriter, EdgeWriter}
import com.alibaba.graphar.utils.IndexGenerator

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._

/** The helper object for writing graphs through the definitions of their infos. */
object GraphWriter {
private def writeAllVertices(prefix: String,
vertexInfos: Map[String, VertexInfo],
vertex_num_map: Map[String, Long],
vertexDataFrames: Map[String, DataFrame],
spark: SparkSession): Unit = {
vertexInfos.foreach { case (label, vertexInfo) => {
val vertex_num = vertex_num_map(label)
val df_with_index = IndexGenerator.generateVertexIndexColumn(vertexDataFrames(label))
val writer = new VertexWriter(prefix, vertexInfo, df_with_index, Some(vertex_num))
writer.writeVertexProperties()
}}
}

private def writeAllEdges(prefix: String,
vertexInfos: Map[String, VertexInfo],
edgeInfos: Map[String, EdgeInfo],
vertex_num_map: Map[String, Long],
vertexDataFrames: Map[String, DataFrame],
edgeDataFrames: Map[String, DataFrame],
spark: SparkSession): Unit = {
edgeInfos.foreach { case (key, edgeInfo) => {
val srcLabel = edgeInfo.getSrc_label
val dstLabel = edgeInfo.getDst_label
val edge_key = edgeInfo.getConcatKey()
val src_vertex_index_mapping = IndexGenerator.constructVertexIndexMapping(vertexDataFrames(srcLabel), vertexInfos(srcLabel).getPrimaryKey())
val dst_vertex_index_mapping = {
if (srcLabel == dstLabel)
src_vertex_index_mapping
else
IndexGenerator.constructVertexIndexMapping(vertexDataFrames(dstLabel), vertexInfos(dstLabel).getPrimaryKey())
}
val edge_df_with_index = IndexGenerator.generateSrcAndDstIndexForEdgesFromMapping(edgeDataFrames(edge_key), src_vertex_index_mapping, dst_vertex_index_mapping)

val adj_lists = edgeInfo.getAdj_lists
val adj_list_it = adj_lists.iterator
while (adj_list_it.hasNext()) {
val adj_list_type = adj_list_it.next().getAdjList_type_in_gar
val vertex_num = {
if (adj_list_type == AdjListType.ordered_by_source || adj_list_type == AdjListType.unordered_by_source) {
vertex_num_map(srcLabel)
} else {
vertex_num_map(dstLabel)
}
}
val writer = new EdgeWriter(prefix, edgeInfo, adj_list_type, vertex_num, edge_df_with_index)
writer.writeEdges()
}
}}
}

def write(graphInfo: GraphInfo, vertexDataFrames: Map[String, DataFrame], edgeDataFrames: Map[String, DataFrame], spark: SparkSession): Unit = {
// get the vertex count of each vertex DataFrame
val vertex_num_map: Map[String, Long] = vertexDataFrames.map { case (k, v) => (k, v.count()) }
val prefix = graphInfo.getPrefix
val vertex_infos = graphInfo.getVertexInfos()
val edge_infos = graphInfo.getEdgeInfos()

// write vertices
writeAllVertices(prefix, vertex_infos, vertex_num_map, vertexDataFrames, spark)

// write edges
writeAllEdges(prefix, vertex_infos, edge_infos, vertex_num_map, vertexDataFrames, edgeDataFrames, spark)
}

def write(graphInfoPath: String, vertexDataFrames: Map[String, DataFrame], edgeDataFrames: Map[String, DataFrame], spark: SparkSession): Unit = {
// load graph info
val graph_info = GraphInfo.loadGraphInfo(graphInfoPath, spark)

// conduct writing
write(graph_info, vertexDataFrames, edgeDataFrames, spark)
}
}
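GraphWriter.write expects the vertex DataFrames keyed by vertex label and the edge DataFrames keyed by the concatenated edge key; each edge DataFrame's first two columns must hold the source and destination primary keys (see the IndexGenerator helper below). A sketch, with person_df and knows_df as hypothetical input DataFrames and a hypothetical path:

val vertex_dfs = Map("person" -> person_df)
val edge_dfs = Map("person_knows_person" -> knows_df)
GraphWriter.write("/path/to/ldbc_sample.graph.yml", vertex_dfs, edge_dfs, spark)  // hypothetical path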
spark/src/main/scala/com/alibaba/graphar/utils/IndexGenerator.scala
@@ -142,6 +142,14 @@ object IndexGenerator {
generateDstIndexForEdgesFromMapping(df_with_src_index, dstColumnName, dstIndexMapping)
}

/** Assumes that the first and second columns are the src and dst columns */
def generateSrcAndDstIndexForEdgesFromMapping(edgeDf: DataFrame, srcIndexMapping: DataFrame, dstIndexMapping: DataFrame): DataFrame = {
val srcColName: String = edgeDf.columns(0)
val dstColName: String = edgeDf.columns(1)
val df_with_src_index = generateSrcIndexForEdgesFromMapping(edgeDf, srcColName, srcIndexMapping)
generateDstIndexForEdgesFromMapping(df_with_src_index, dstColName, dstIndexMapping)
}

/** Construct vertex index for source column. */
def generateSrcIndexForEdges(edgeDf: DataFrame, srcColumnName: String): DataFrame = {
val srcDf = edgeDf.select(srcColumnName).distinct()
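Note that generateSrcAndDstIndexForEdgesFromMapping is positional: whatever the first two columns of the edge DataFrame are, they are treated as the src and dst keys. A sketch, assuming "id" is the vertex primary key and reusing the hypothetical DataFrames from above:

val src_mapping = IndexGenerator.constructVertexIndexMapping(person_df, "id")  // "id" assumed as primary key
val indexed_edges = IndexGenerator.generateSrcAndDstIndexForEdgesFromMapping(knows_df, src_mapping, src_mapping)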
3 changes: 3 additions & 0 deletions spark/src/test/scala/com/alibaba/graphar/TestGraphInfo.scala
@@ -34,6 +34,9 @@ class GraphInfoSuite extends AnyFunSuite {
val prefix = getClass.getClassLoader.getResource("gar-test/ldbc_sample/csv/").getPath
val graph_info = GraphInfo.loadGraphInfo(yaml_path, spark)

val vertex_info = graph_info.getVertexInfo("person")
assert(vertex_info.getLabel == "person")

assert(graph_info.getName == "ldbc_sample")
assert(graph_info.getPrefix == prefix)
assert(graph_info.getVertices.size() == 1)
85 changes: 85 additions & 0 deletions spark/src/test/scala/com/alibaba/graphar/TestGraphReader.scala
@@ -0,0 +1,85 @@
/** Copyright 2022 Alibaba Group Holding Limited.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.alibaba.graphar

import com.alibaba.graphar.GraphInfo
import com.alibaba.graphar.graph.GraphReader

import java.io.{File, FileInputStream}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.hadoop.fs.{Path, FileSystem}
import org.scalatest.funsuite.AnyFunSuite

class TestGraphReaderSuite extends AnyFunSuite {
val spark = SparkSession.builder()
.enableHiveSupport()
.master("local[*]")
.getOrCreate()

test("read graphs by yaml paths") {
// conduct reading
val graph_path = getClass.getClassLoader.getResource("gar-test/ldbc_sample/parquet/ldbc_sample.graph.yml").getPath
val vertex_edge_df_pair = GraphReader.read(graph_path, spark)
val vertex_dataframes = vertex_edge_df_pair._1
val edge_dataframes = vertex_edge_df_pair._2

assert(vertex_dataframes.size == 1)
assert(vertex_dataframes contains "person")
val person_df = vertex_dataframes("person")
assert(person_df.columns.size == 4)
assert(person_df.count() == 903)

assert(edge_dataframes.size == 1)
assert(edge_dataframes contains "person_knows_person")
val adj_list_type_dataframes = edge_dataframes("person_knows_person")
assert(adj_list_type_dataframes.size == 3)
}

test("read graphs by graph infos") {
// load graph info
val path = getClass.getClassLoader.getResource("gar-test/ldbc_sample/parquet/ldbc_sample.graph.yml").getPath
val graph_info = GraphInfo.loadGraphInfo(path, spark)

// conduct reading
val vertex_edge_df_pair = GraphReader.read(graph_info, spark)
val vertex_dataframes = vertex_edge_df_pair._1
val edge_dataframes = vertex_edge_df_pair._2

assert(vertex_dataframes.size == 1)
assert(vertex_dataframes contains "person")
val person_df = vertex_dataframes("person")
assert(person_df.columns.size == 4)
assert(person_df.count() == 903)

val edgeInfos = graph_info.getEdgeInfos()
assert(edge_dataframes.size == edgeInfos.size)
edgeInfos.foreach { case (key, edgeInfo) => {
assert(edge_dataframes contains key)
val adj_list_type_dataframes = edge_dataframes(key)
val adj_lists = edgeInfo.getAdj_lists
assert(adj_list_type_dataframes.size == adj_lists.size)
val adj_list_it = adj_lists.iterator
while (adj_list_it.hasNext()) {
val adj_list = adj_list_it.next()
val adj_list_type = adj_list.getAdjList_type_in_gar
val adj_list_type_str = adj_list.getAdjList_type
assert(adj_list_type_dataframes contains adj_list_type_str)
val df = adj_list_type_dataframes(adj_list_type_str)
assert(df.count == 6626)
}
}}
}
}