From 44f2e801ca47c955861a7294f7ba5f072f6c8364 Mon Sep 17 00:00:00 2001 From: Ziy1-Tan Date: Mon, 14 Aug 2023 00:43:08 +0800 Subject: [PATCH] Feat: filter pushdown for spark Signed-off-by: Ziy1-Tan --- .../alibaba/graphar/datasources/GarScan.scala | 5 +- .../graphar/datasources/GarScanBuilder.scala | 46 +++++++---- .../alibaba/graphar/reader/VertexReader.scala | 30 +++---- .../com/alibaba/graphar/TestReader.scala | 81 ++++++++++--------- 4 files changed, 88 insertions(+), 74 deletions(-) diff --git a/spark/src/main/scala/com/alibaba/graphar/datasources/GarScan.scala b/spark/src/main/scala/com/alibaba/graphar/datasources/GarScan.scala index 5fa9b99fa..629a769bb 100644 --- a/spark/src/main/scala/com/alibaba/graphar/datasources/GarScan.scala +++ b/spark/src/main/scala/com/alibaba/graphar/datasources/GarScan.scala @@ -17,7 +17,6 @@ package com.alibaba.graphar.datasources import scala.collection.JavaConverters._ -import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.conf.Configuration @@ -40,6 +39,8 @@ import org.apache.spark.sql.execution.datasources.parquet.{ ParquetReadSupport, ParquetWriteSupport } +import org.apache.spark.sql.execution.datasources.{FilePartition, PartitioningAwareFileIndex, PartitionedFile} +import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport} import org.apache.spark.sql.execution.datasources.v2.FileScan import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetPartitionReaderFactory import org.apache.spark.sql.execution.datasources.v2.orc.OrcPartitionReaderFactory @@ -279,7 +280,7 @@ case class GarScan( super.description() + ", PushedFilters: " + seqToString(pushedFilters) } - /** Get the meata data map of the object. */ + /** Get the meta data map of the object. */ override def getMetaData(): Map[String, String] = { super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) } diff --git a/spark/src/main/scala/com/alibaba/graphar/datasources/GarScanBuilder.scala b/spark/src/main/scala/com/alibaba/graphar/datasources/GarScanBuilder.scala index f5728b387..e9345fbfa 100644 --- a/spark/src/main/scala/com/alibaba/graphar/datasources/GarScanBuilder.scala +++ b/spark/src/main/scala/com/alibaba/graphar/datasources/GarScanBuilder.scala @@ -16,16 +16,19 @@ package com.alibaba.graphar.datasources -import scala.collection.JavaConverters._ - import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.read.Scan +import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex + import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap +import scala.collection.JavaConverters._ +import org.apache.spark.sql.execution.datasources.v2.orc.OrcScanBuilder +import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScanBuilder + /** GarScanBuilder is a class to build the file scan for GarDataSource. 
*/ case class GarScanBuilder( sparkSession: SparkSession, @@ -34,13 +37,37 @@ case class GarScanBuilder( dataSchema: StructType, options: CaseInsensitiveStringMap, formatName: String -) extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { +) extends FileScanBuilder(sparkSession, fileIndex, dataSchema) + with SupportsPushDownFilters { lazy val hadoopConf = { val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap // Hadoop Configurations are case sensitive. sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) } + private var filters: Array[Filter] = Array.empty + override def pushFilters(filters: Array[Filter]): Array[Filter] = { + this.filters = filters + filters + } + + override def pushedFilters(): Array[Filter] = { + formatName match { + case "csv" => Array.empty + case "orc" => pushedOrcFilters + case "parquet" => pushedParquetFilters + case _ => throw new IllegalArgumentException + } + } + + private lazy val pushedParquetFilters = + ParquetScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + .pushFilters(filters) + + private lazy val pushedOrcFilters = + OrcScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + .pushFilters(filters) + // Check if the file format supports nested schema pruning. override protected val supportsNestedSchemaPruning: Boolean = formatName match { @@ -50,15 +77,6 @@ case class GarScanBuilder( case _ => throw new IllegalArgumentException } - // Note: This scan builder does not implement "with SupportsPushDownFilters". - private var filters: Array[Filter] = Array.empty - - // Note: To support pushdown filters, these two methods need to be implemented. - - // override def pushFilters(filters: Array[Filter]): Array[Filter] - - // override def pushedFilters(): Array[Filter] - /** Build the file scan for GarDataSource. */ override def build(): Scan = { GarScan( @@ -68,7 +86,7 @@ case class GarScanBuilder( dataSchema, readDataSchema(), readPartitionSchema(), - filters, + pushedFilters(), options, formatName ) diff --git a/spark/src/main/scala/com/alibaba/graphar/reader/VertexReader.scala b/spark/src/main/scala/com/alibaba/graphar/reader/VertexReader.scala index 5ecf70bde..fb55b033e 100644 --- a/spark/src/main/scala/com/alibaba/graphar/reader/VertexReader.scala +++ b/spark/src/main/scala/com/alibaba/graphar/reader/VertexReader.scala @@ -63,11 +63,8 @@ class VertexReader( * vertex property chunk DataFrame. Raise IllegalArgumentException if the * property group not contained. */ - def readVertexPropertyChunk( - propertyGroup: PropertyGroup, - chunk_index: Long - ): DataFrame = { - if (vertexInfo.containPropertyGroup(propertyGroup) == false) { + def readVertexPropertyChunk(propertyGroup: PropertyGroup, chunk_index: Long): DataFrame = { + if (!vertexInfo.containPropertyGroup(propertyGroup)) { throw new IllegalArgumentException } val file_type = propertyGroup.getFile_type() @@ -91,25 +88,18 @@ class VertexReader( * DataFrame that contains all chunks of property group. Raise * IllegalArgumentException if the property group not contained. 
   */
-  def readVertexPropertyGroup(
-      propertyGroup: PropertyGroup,
-      addIndex: Boolean = true
-  ): DataFrame = {
-    if (vertexInfo.containPropertyGroup(propertyGroup) == false) {
+  def readVertexPropertyGroup(propertyGroup: PropertyGroup, addIndex: Boolean = true): DataFrame = {
+    if (!vertexInfo.containPropertyGroup(propertyGroup)) {
       throw new IllegalArgumentException
     }
     val file_type = propertyGroup.getFile_type()
     val file_path = prefix + vertexInfo.getPathPrefix(propertyGroup)
-    val df = spark.read
-      .option("fileFormat", file_type)
-      .option("header", "true")
-      .format("com.alibaba.graphar.datasources.GarDataSource")
-      .load(file_path)
+    val df = spark.read.option("fileFormat", file_type).option("header", "true").format("com.alibaba.graphar.datasources.GarDataSource").load(file_path)
 
     if (addIndex) {
-      return IndexGenerator.generateVertexIndexColumn(df)
+      IndexGenerator.generateVertexIndexColumn(df)
     } else {
-      return df
+      df
     }
   }
 
@@ -145,7 +135,7 @@
 
     var rdd = df0.rdd
     var schema_array = df0.schema.fields
-    for (i <- 1 to len - 1) {
+    for (i <- 1 until len) {
       val pg: PropertyGroup = propertyGroups.get(i)
       val new_df = readVertexPropertyGroup(pg, false)
       schema_array = Array.concat(schema_array, new_df.schema.fields)
@@ -155,9 +145,9 @@
     val schema = StructType(schema_array)
     val df = spark.createDataFrame(rdd, schema)
     if (addIndex) {
-      return IndexGenerator.generateVertexIndexColumn(df)
+      IndexGenerator.generateVertexIndexColumn(df)
     } else {
-      return df
+      df
     }
   }
 
diff --git a/spark/src/test/scala/com/alibaba/graphar/TestReader.scala b/spark/src/test/scala/com/alibaba/graphar/TestReader.scala
index 36720eb54..78c0de07e 100644
--- a/spark/src/test/scala/com/alibaba/graphar/TestReader.scala
+++ b/spark/src/test/scala/com/alibaba/graphar/TestReader.scala
@@ -16,14 +16,9 @@
 
 package com.alibaba.graphar
 
-import com.alibaba.graphar.datasources._
 import com.alibaba.graphar.reader.{VertexReader, EdgeReader}
 
-import java.io.{File, FileInputStream}
-import org.yaml.snakeyaml.Yaml
-import org.yaml.snakeyaml.constructor.Constructor
-import scala.beans.BeanProperty
-import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.SparkSession
 import org.scalatest.funsuite.AnyFunSuite
 
 class ReaderSuite extends AnyFunSuite {
@@ -33,7 +28,10 @@
     .master("local[*]")
     .getOrCreate()
 
+  spark.sparkContext.setLogLevel("Error")
+
   test("read chunk files directly") {
+    val cond = "id < 1000"
     // read vertex chunk files in Parquet
     val parquet_file_path = "gar-test/ldbc_sample/parquet/"
     val parquet_prefix =
@@ -46,7 +44,9 @@
     // validate reading results
     assert(df1.rdd.getNumPartitions == 10)
     assert(df1.count() == 903)
-    // println(df1.rdd.collect().mkString("\n"))
+    var df_pd = df1.filter(cond)
+    df_pd.explain()
+    df_pd.show()
 
     // read vertex chunk files in Orc
     val orc_file_path = "gar-test/ldbc_sample/orc/"
@@ -58,6 +58,9 @@
       .load(orc_read_path)
     // validate reading results
     assert(df2.rdd.collect().deep == df1.rdd.collect().deep)
+    df_pd = df2.filter(cond)
+    df_pd.explain()
+    df_pd.show()
 
     // read adjList chunk files recursively in CSV
     val csv_file_path = "gar-test/ldbc_sample/csv/"
@@ -85,7 +88,7 @@
 
   test("read vertex chunks") {
     // construct the vertex information
-    val file_path = "gar-test/ldbc_sample/csv/"
+    val file_path = "gar-test/ldbc_sample/parquet/"
     val prefix =
getClass.getClassLoader.getResource(file_path).getPath val vertex_yaml = getClass.getClassLoader .getResource(file_path + "person.vertex.yml") @@ -101,42 +104,50 @@ class ReaderSuite extends AnyFunSuite { // test reading a single property chunk val single_chunk_df = reader.readVertexPropertyChunk(property_group, 0) - assert(single_chunk_df.columns.size == 3) + assert(single_chunk_df.columns.length == 3) assert(single_chunk_df.count() == 100) - - // test reading chunks for a property group - val property_df = reader.readVertexPropertyGroup(property_group, false) - assert(property_df.columns.size == 3) + val cond = "gender = 'female'" + var df_pd = single_chunk_df.select("firstName","gender").filter(cond) + df_pd.explain() + df_pd.show() + + // test reading all chunks for a property group + val property_df = reader.readVertexPropertyGroup(property_group, addIndex = false) + assert(property_df.columns.length == 3) assert(property_df.count() == 903) + df_pd = property_df.select("firstName", "gender").filter(cond) + df_pd.explain() + df_pd.show() // test reading chunks for multiple property groups val property_group_1 = vertex_info.getPropertyGroup("id") - var property_groups = new java.util.ArrayList[PropertyGroup]() + val property_groups = new java.util.ArrayList[PropertyGroup]() property_groups.add(property_group_1) property_groups.add(property_group) - val multiple_property_df = - reader.readMultipleVertexPropertyGroups(property_groups, false) - assert(multiple_property_df.columns.size == 4) + val multiple_property_df = reader.readMultipleVertexPropertyGroups(property_groups, addIndex = false) + assert(multiple_property_df.columns.length == 4) assert(multiple_property_df.count() == 903) - + df_pd = multiple_property_df.filter(cond) + df_pd.explain() + df_pd.show() // test reading chunks for all property groups and optionally adding indices - val vertex_df = reader.readAllVertexPropertyGroups(false) - vertex_df.show() - assert(vertex_df.columns.size == 4) + val vertex_df = reader.readAllVertexPropertyGroups(addIndex = false) + assert(vertex_df.columns.length == 4) assert(vertex_df.count() == 903) + df_pd = vertex_df.filter(cond) + df_pd.explain() + df_pd.show() val vertex_df_with_index = reader.readAllVertexPropertyGroups() - vertex_df_with_index.show() - assert(vertex_df_with_index.columns.size == 5) + assert(vertex_df_with_index.columns.length == 5) assert(vertex_df_with_index.count() == 903) + df_pd = vertex_df_with_index.filter(cond).select("firstName", "gender") + df_pd.explain("formatted") + df_pd.show() // throw an exception for non-existing property groups val invalid_property_group = new PropertyGroup() - assertThrows[IllegalArgumentException]( - reader.readVertexPropertyChunk(invalid_property_group, 0) - ) - assertThrows[IllegalArgumentException]( - reader.readVertexPropertyGroup(invalid_property_group) - ) + assertThrows[IllegalArgumentException](reader.readVertexPropertyChunk(invalid_property_group, 0)) + assertThrows[IllegalArgumentException](reader.readVertexPropertyGroup(invalid_property_group)) } test("read edge chunks") { @@ -229,15 +240,9 @@ class ReaderSuite extends AnyFunSuite { // throw an exception for non-existing property groups val invalid_property_group = new PropertyGroup() - assertThrows[IllegalArgumentException]( - reader.readEdgePropertyChunk(invalid_property_group, 0, 0) - ) - assertThrows[IllegalArgumentException]( - reader.readEdgePropertyGroupForVertexChunk(invalid_property_group, 0) - ) - assertThrows[IllegalArgumentException]( - 
reader.readEdgePropertyGroup(invalid_property_group) - ) + assertThrows[IllegalArgumentException](reader.readEdgePropertyChunk(invalid_property_group, 0, 0)) + assertThrows[IllegalArgumentException](reader.readEdgePropertyGroupForVertexChunk(invalid_property_group, 0)) + assertThrows[IllegalArgumentException](reader.readEdgePropertyGroup(invalid_property_group)) // throw an exception for non-existing adjList types val invalid_adj_list_type = AdjListType.unordered_by_dest
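
Note: the sketch below is not part of the patch; it is a minimal illustration of how the new pushdown path can be exercised from user code. The object name and the load path are hypothetical placeholders, while the "fileFormat" option, the data source class name, and the column and filter values follow the TestReader changes above.

    import org.apache.spark.sql.SparkSession

    object GarPushdownSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .appName("gar-filter-pushdown-sketch")
          .master("local[*]")
          .getOrCreate()

        // Read vertex property chunks through the GAR data source;
        // "fileFormat" selects the underlying format (csv, orc or parquet).
        val df = spark.read
          .option("fileFormat", "parquet")
          .format("com.alibaba.graphar.datasources.GarDataSource")
          .load("/path/to/ldbc_sample/parquet/vertex-chunks") // placeholder path

        // With GarScanBuilder implementing SupportsPushDownFilters, predicates on
        // orc/parquet chunks should now be forwarded to the scan and reported as
        // "PushedFilters" in the physical plan (csv pushes nothing down).
        val filtered = df.filter("gender = 'female'").select("firstName", "gender")
        filtered.explain("formatted")
        filtered.show()

        spark.stop()
      }
    }

Comparing the explain() output with and without this patch applied should show the predicate moving from purely post-scan filtering into the scan's PushedFilters list for orc and parquet chunks.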