diff --git a/spark/src/main/scala/com/alibaba/graphar/reader/VertexReader.scala b/spark/src/main/scala/com/alibaba/graphar/reader/VertexReader.scala
index 612380ca3..32b2325e0 100644
--- a/spark/src/main/scala/com/alibaba/graphar/reader/VertexReader.scala
+++ b/spark/src/main/scala/com/alibaba/graphar/reader/VertexReader.scala
@@ -18,8 +18,8 @@ package com.alibaba.graphar.reader
 
 import com.alibaba.graphar.utils.{IndexGenerator, DataFrameConcat}
 import com.alibaba.graphar.{GeneralParams, VertexInfo, FileType, PropertyGroup}
 import com.alibaba.graphar.datasources._
+import com.alibaba.graphar.utils.FileSystem
-import org.apache.hadoop.fs.{Path, FileSystem}
 import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.functions._
@@ -36,11 +36,7 @@ class VertexReader(prefix: String, vertexInfo: VertexInfo, spark: SparkSession)
 
   /** Load the total number of vertices for this vertex type. */
   def readVerticesNumber(): Long = {
     val file_path = prefix + "/" + vertexInfo.getVerticesNumFilePath()
-    val path = new Path(file_path)
-    val file_system = FileSystem.get(path.toUri(), spark.sparkContext.hadoopConfiguration)
-    val input = file_system.open(path)
-    val number = java.lang.Long.reverseBytes(input.readLong())
-    file_system.close()
+    val number = FileSystem.readValue(file_path, spark.sparkContext.hadoopConfiguration)
     return number
   }
diff --git a/spark/src/main/scala/com/alibaba/graphar/utils/FileSystem.scala b/spark/src/main/scala/com/alibaba/graphar/utils/FileSystem.scala
index f7a883055..e721b44cd 100644
--- a/spark/src/main/scala/com/alibaba/graphar/utils/FileSystem.scala
+++ b/spark/src/main/scala/com/alibaba/graphar/utils/FileSystem.scala
@@ -68,7 +68,18 @@ object FileSystem {
     val path = new Path(outputPath)
     val fs = path.getFileSystem(hadoopConfig)
     val output = fs.create(path, true) // create or overwrite
-    output.writeLong(value)
+    // consistent with c++ library, convert to little-endian
+    output.writeLong(java.lang.Long.reverseBytes(value))
     output.close()
   }
+
+  def readValue(inputPath: String, hadoopConfig: Configuration): Long = {
+    val path = new Path(inputPath)
+    val fs = path.getFileSystem(hadoopConfig)
+    val input = fs.open(path)
+    // consistent with c++ library, little-endian in file, convert to big-endian
+    val num = java.lang.Long.reverseBytes(input.readLong())
+    fs.close()
+    return num
+  }
 }
diff --git a/spark/src/main/scala/com/alibaba/graphar/writer/VertexWriter.scala b/spark/src/main/scala/com/alibaba/graphar/writer/VertexWriter.scala
index 4e64d6766..160397119 100644
--- a/spark/src/main/scala/com/alibaba/graphar/writer/VertexWriter.scala
+++ b/spark/src/main/scala/com/alibaba/graphar/writer/VertexWriter.scala
@@ -60,6 +60,7 @@ class VertexWriter(prefix: String, vertexInfo: VertexInfo, vertexDf: DataFrame,
     case None => vertexDf.count()
     case _ => numVertices.get
   }
+  writeVertexNum()
 
   private var chunks:DataFrame = VertexWriter.repartitionAndSort(vertexDf, vertexInfo.getChunk_size(), vertexNum)
 
diff --git a/spark/src/test/scala/com/alibaba/graphar/TestWriter.scala b/spark/src/test/scala/com/alibaba/graphar/TestWriter.scala
index 837be5748..da0b75380 100644
--- a/spark/src/test/scala/com/alibaba/graphar/TestWriter.scala
+++ b/spark/src/test/scala/com/alibaba/graphar/TestWriter.scala
@@ -15,8 +15,8 @@
 
 package com.alibaba.graphar
 
-import com.alibaba.graphar.utils.IndexGenerator
 import com.alibaba.graphar.writer.{VertexWriter, EdgeWriter}
+import com.alibaba.graphar.utils
 
 import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.scalatest.funsuite.AnyFunSuite
@@ -45,7 +45,7 @@ class WriterSuite extends AnyFunSuite {
     val vertex_info = vertex_yaml.load(vertex_input).asInstanceOf[VertexInfo]
 
     // generate vertex index column for vertex dataframe
-    val vertex_df_with_index = IndexGenerator.generateVertexIndexColumn(vertex_df)
+    val vertex_df_with_index = utils.IndexGenerator.generateVertexIndexColumn(vertex_df)
 
     // create writer object for person and generate the properties with GAR format
     val prefix : String = "/tmp/"
@@ -61,6 +61,9 @@ class WriterSuite extends AnyFunSuite {
     val chunk_path = new Path(prefix + vertex_info.getPrefix() + "*/*")
     val chunk_files = fs.globStatus(chunk_path)
     assert(chunk_files.length == 20)
+    val vertex_num_path = prefix + vertex_info.getVerticesNumFilePath()
+    val number = utils.FileSystem.readValue(vertex_num_path, spark.sparkContext.hadoopConfiguration)
+    assert(number.toInt == vertex_df.count())
 
     assertThrows[IllegalArgumentException](new VertexWriter(prefix, vertex_info, vertex_df))
     val invalid_property_group= new PropertyGroup()
@@ -89,7 +92,7 @@ class WriterSuite extends AnyFunSuite {
     val srcDf = edge_df.select("src").withColumnRenamed("src", "vertex")
     val dstDf = edge_df.select("dst").withColumnRenamed("dst", "vertex")
     val vertexNum = srcDf.union(dstDf).distinct().count()
-    val edge_df_with_index = IndexGenerator.generateSrcAndDstIndexUnitedlyForEdges(edge_df, "src", "dst")
+    val edge_df_with_index = utils.IndexGenerator.generateSrcAndDstIndexUnitedlyForEdges(edge_df, "src", "dst")
 
     // create writer object for person_knows_person and generate the adj list and properties with GAR format
     val writer = new EdgeWriter(prefix, edge_info, adj_list_type, vertexNum, edge_df_with_index)
@@ -156,10 +159,10 @@ class WriterSuite extends AnyFunSuite {
     val edge_info = edge_yaml.load(edge_input).asInstanceOf[EdgeInfo]
 
     // construct person vertex mapping with dataframe
-    val vertex_mapping = IndexGenerator.constructVertexIndexMapping(vertex_df, vertex_info.getPrimaryKey())
+    val vertex_mapping = utils.IndexGenerator.constructVertexIndexMapping(vertex_df, vertex_info.getPrimaryKey())
     // generate src index and dst index for edge datafram with vertex mapping
-    val edge_df_with_src_index = IndexGenerator.generateSrcIndexForEdgesFromMapping(edge_df, "src", vertex_mapping)
-    val edge_df_with_src_dst_index = IndexGenerator.generateDstIndexForEdgesFromMapping(edge_df_with_src_index, "dst", vertex_mapping)
+    val edge_df_with_src_index = utils.IndexGenerator.generateSrcIndexForEdgesFromMapping(edge_df, "src", vertex_mapping)
+    val edge_df_with_src_dst_index = utils.IndexGenerator.generateDstIndexForEdgesFromMapping(edge_df_with_src_index, "dst", vertex_mapping)
 
     // create writer object for person_knows_person and generate the adj list and properties with GAR format
     val writer = new EdgeWriter(prefix, edge_info, adj_list_type, vertex_num, edge_df_with_src_dst_index)
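Note on the byte reversal in writeValue/readValue: Java's DataOutput.writeLong and DataInput.readLong (which Hadoop's stream classes extend) use big-endian byte order, while the diff's comments say the C++ library expects the vertex count little-endian, so the value is reversed once on write and once on read. The following standalone sketch shows the same round trip with plain java.io streams; the EndianRoundTrip object, method names, and sample value are illustrative only and are not part of the patch.

import java.io.{DataInputStream, DataOutputStream, FileInputStream, FileOutputStream}

object EndianRoundTrip {
  // writeLong emits big-endian bytes; reversing the value first
  // makes the on-disk layout little-endian.
  def writeLittleEndian(path: String, value: Long): Unit = {
    val out = new DataOutputStream(new FileOutputStream(path))
    try out.writeLong(java.lang.Long.reverseBytes(value))
    finally out.close()
  }

  // readLong interprets the bytes as big-endian; reversing again
  // restores the original value.
  def readLittleEndian(path: String): Long = {
    val in = new DataInputStream(new FileInputStream(path))
    try java.lang.Long.reverseBytes(in.readLong())
    finally in.close()
  }

  def main(args: Array[String]): Unit = {
    val tmp = java.io.File.createTempFile("vertex_num", ".bin").getPath
    writeLittleEndian(tmp, 903L)
    assert(readLittleEndian(tmp) == 903L) // round trip preserves the value
  }
}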