diff --git a/spark/src/main/scala/com/alibaba/graphar/datasources/GarScan.scala b/spark/src/main/scala/com/alibaba/graphar/datasources/GarScan.scala
index 5fa9b99fa..7d538c75a 100644
--- a/spark/src/main/scala/com/alibaba/graphar/datasources/GarScan.scala
+++ b/spark/src/main/scala/com/alibaba/graphar/datasources/GarScan.scala
@@ -17,7 +17,6 @@
 package com.alibaba.graphar.datasources
 
 import scala.collection.JavaConverters._
-import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.hadoop.conf.Configuration
@@ -34,7 +33,6 @@ import org.apache.spark.sql.execution.datasources.{
   PartitioningAwareFileIndex,
   PartitionedFile
 }
-import org.apache.spark.sql.execution.datasources.csv.CSVDataSource
 import org.apache.spark.sql.execution.datasources.parquet.{
   ParquetOptions,
   ParquetReadSupport,
@@ -279,7 +277,7 @@ case class GarScan(
     super.description() + ", PushedFilters: " + seqToString(pushedFilters)
   }
 
-  /** Get the meata data map of the object. */
+  /** Get the meta data map of the object. */
   override def getMetaData(): Map[String, String] = {
     super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters))
   }
diff --git a/spark/src/main/scala/com/alibaba/graphar/datasources/GarScanBuilder.scala b/spark/src/main/scala/com/alibaba/graphar/datasources/GarScanBuilder.scala
index f5728b387..1a77997d6 100644
--- a/spark/src/main/scala/com/alibaba/graphar/datasources/GarScanBuilder.scala
+++ b/spark/src/main/scala/com/alibaba/graphar/datasources/GarScanBuilder.scala
@@ -16,16 +16,19 @@
 package com.alibaba.graphar.datasources
 
-import scala.collection.JavaConverters._
-
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connector.read.Scan
+import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters}
 import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
+
 import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
 import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
+import scala.collection.JavaConverters._
+import org.apache.spark.sql.execution.datasources.v2.orc.OrcScanBuilder
+import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScanBuilder
+
 /** GarScanBuilder is a class to build the file scan for GarDataSource. */
 case class GarScanBuilder(
     sparkSession: SparkSession,
@@ -34,30 +37,58 @@ case class GarScanBuilder(
     dataSchema: StructType,
     options: CaseInsensitiveStringMap,
     formatName: String
-) extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {
+) extends FileScanBuilder(sparkSession, fileIndex, dataSchema)
+    with SupportsPushDownFilters {
   lazy val hadoopConf = {
     val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
     // Hadoop Configurations are case sensitive.
     sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
   }
 
-  // Check if the file format supports nested schema pruning.
-  override protected val supportsNestedSchemaPruning: Boolean =
-    formatName match {
-      case "csv" => false
-      case "orc" => true
-      case "parquet" => true
-      case _ => throw new IllegalArgumentException
-    }
-
-  // Note: This scan builder does not implement "with SupportsPushDownFilters".
   private var filters: Array[Filter] = Array.empty
+  override def pushFilters(filters: Array[Filter]): Array[Filter] = {
+    this.filters = filters
+    filters
+  }
 
-  // Note: To support pushdown filters, these two methods need to be implemented.
+  override def pushedFilters(): Array[Filter] = formatName match {
+    case "csv" => Array.empty[Filter]
+    case "orc" => pushedOrcFilters
+    case "parquet" => pushedParquetFilters
+    case _ => throw new IllegalArgumentException
+  }
 
-  // override def pushFilters(filters: Array[Filter]): Array[Filter]
+  private lazy val pushedParquetFilters: Array[Filter] = {
+    if (!sparkSession.sessionState.conf.parquetFilterPushDown) {
+      Array.empty[Filter]
+    } else {
+      val builder =
+        ParquetScanBuilder(sparkSession, fileIndex, schema, dataSchema, options)
+      builder.pushFilters(this.filters)
+      builder.pushedFilters()
+    }
+  }
+
+  private lazy val pushedOrcFilters: Array[Filter] = {
+    if (!sparkSession.sessionState.conf.orcFilterPushDown) {
+      Array.empty[Filter]
+    } else {
+      val builder =
+        OrcScanBuilder(sparkSession, fileIndex, schema, dataSchema, options)
+      builder.pushFilters(this.filters)
+      builder.pushedFilters()
+    }
+  }
 
-  // override def pushedFilters(): Array[Filter]
+  // Check if the file format supports nested schema pruning.
+  override protected val supportsNestedSchemaPruning: Boolean =
+    formatName match {
+      case "csv" => false
+      case "orc" => sparkSession.sessionState.conf.nestedSchemaPruningEnabled
+      case "parquet" =>
+        sparkSession.sessionState.conf.nestedSchemaPruningEnabled
+      case _ => throw new IllegalArgumentException
+    }
 
   /** Build the file scan for GarDataSource. */
   override def build(): Scan = {
@@ -68,7 +99,7 @@ case class GarScanBuilder(
       dataSchema,
       readDataSchema(),
       readPartitionSchema(),
-      filters,
+      pushedFilters(),
       options,
       formatName
     )
diff --git a/spark/src/main/scala/com/alibaba/graphar/reader/VertexReader.scala b/spark/src/main/scala/com/alibaba/graphar/reader/VertexReader.scala
index 5ecf70bde..eb2bac23e 100644
--- a/spark/src/main/scala/com/alibaba/graphar/reader/VertexReader.scala
+++ b/spark/src/main/scala/com/alibaba/graphar/reader/VertexReader.scala
@@ -67,7 +67,7 @@ class VertexReader(
       propertyGroup: PropertyGroup,
       chunk_index: Long
   ): DataFrame = {
-    if (vertexInfo.containPropertyGroup(propertyGroup) == false) {
+    if (!vertexInfo.containPropertyGroup(propertyGroup)) {
       throw new IllegalArgumentException
     }
     val file_type = propertyGroup.getFile_type()
@@ -95,7 +95,7 @@ class VertexReader(
       propertyGroup: PropertyGroup,
       addIndex: Boolean = true
   ): DataFrame = {
-    if (vertexInfo.containPropertyGroup(propertyGroup) == false) {
+    if (!vertexInfo.containPropertyGroup(propertyGroup)) {
       throw new IllegalArgumentException
     }
     val file_type = propertyGroup.getFile_type()
@@ -107,9 +107,9 @@ class VertexReader(
       .load(file_path)
 
     if (addIndex) {
-      return IndexGenerator.generateVertexIndexColumn(df)
+      IndexGenerator.generateVertexIndexColumn(df)
     } else {
-      return df
+      df
     }
   }
 
@@ -145,7 +145,7 @@ class VertexReader(
 
     var rdd = df0.rdd
     var schema_array = df0.schema.fields
-    for (i <- 1 to len - 1) {
+    for (i <- 1 until len) {
       val pg: PropertyGroup = propertyGroups.get(i)
       val new_df = readVertexPropertyGroup(pg, false)
       schema_array = Array.concat(schema_array, new_df.schema.fields)
@@ -155,9 +155,9 @@ class VertexReader(
     val schema = StructType(schema_array)
     val df = spark.createDataFrame(rdd, schema)
     if (addIndex) {
-      return IndexGenerator.generateVertexIndexColumn(df)
+      IndexGenerator.generateVertexIndexColumn(df)
    } else {
-      return df
+      df
     }
   }
 
diff --git a/spark/src/test/scala/com/alibaba/graphar/TestReader.scala b/spark/src/test/scala/com/alibaba/graphar/TestReader.scala
index 36720eb54..665cdf3f5 100644
--- a/spark/src/test/scala/com/alibaba/graphar/TestReader.scala
+++ b/spark/src/test/scala/com/alibaba/graphar/TestReader.scala
@@ -16,14 +16,9 @@
 package com.alibaba.graphar
 
-import com.alibaba.graphar.datasources._
 import com.alibaba.graphar.reader.{VertexReader, EdgeReader}
 
-import java.io.{File, FileInputStream}
-import org.yaml.snakeyaml.Yaml
-import org.yaml.snakeyaml.constructor.Constructor
-import scala.beans.BeanProperty
-import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.SparkSession
 import org.scalatest.funsuite.AnyFunSuite
 
 class ReaderSuite extends AnyFunSuite {
@@ -33,7 +28,10 @@
     .master("local[*]")
     .getOrCreate()
 
+  spark.sparkContext.setLogLevel("Error")
+
   test("read chunk files directly") {
+    val cond = "id < 1000"
     // read vertex chunk files in Parquet
     val parquet_file_path = "gar-test/ldbc_sample/parquet/"
     val parquet_prefix =
@@ -46,7 +44,21 @@
     // validate reading results
     assert(df1.rdd.getNumPartitions == 10)
     assert(df1.count() == 903)
-    // println(df1.rdd.collect().mkString("\n"))
+    var df_pd = df1.filter(cond)
+
+    /**
+     * ==Physical Plan==
+     * (1) Filter (isnotnull(id#0L) AND (id#0L < 1000))
+     * +- *(1) ColumnarToRow
+     * +- BatchScan[id#0L] GarScan DataFilters: [isnotnull(id#0L), (id#0L <
+     * 1000)], Format: gar, Location: InMemoryFileIndex(1
+     * paths)[file:/path/to/code/cpp/GraphAr/spark/src/test/resources/gar-test/l...,
+     * PartitionFilters: [], PushedFilters: [IsNotNull(id), LessThan(id,1000)],
+     * ReadSchema: struct, PushedFilters: [IsNotNull(id),
+     * LessThan(id,1000)] RuntimeFilters: []
+     */
+    df_pd.explain()
+    df_pd.show()
 
     // read vertex chunk files in Orc
     val orc_file_path = "gar-test/ldbc_sample/orc/"
@@ -58,6 +70,21 @@
       .load(orc_read_path)
     // validate reading results
     assert(df2.rdd.collect().deep == df1.rdd.collect().deep)
+    df_pd = df2.filter(cond)
+
+    /**
+     * ==Physical Plan==
+     * (1) Filter (isnotnull(id#0L) AND (id#0L < 1000))
+     * +- *(1) ColumnarToRow
+     * +- BatchScan[id#0L] GarScan DataFilters: [isnotnull(id#0L), (id#0L <
+     * 1000)], Format: gar, Location: InMemoryFileIndex(1
+     * paths)[file:/path/to/GraphAr/spark/src/test/resources/gar-test/l...,
+     * PartitionFilters: [], PushedFilters: [IsNotNull(id), LessThan(id,1000)],
+     * ReadSchema: struct, PushedFilters: [IsNotNull(id),
+     * LessThan(id,1000)] RuntimeFilters: []
+     */
+    df_pd.explain()
+    df_pd.show()
 
     // read adjList chunk files recursively in CSV
     val csv_file_path = "gar-test/ldbc_sample/csv/"
@@ -85,7 +112,7 @@
 
   test("read vertex chunks") {
     // construct the vertex information
-    val file_path = "gar-test/ldbc_sample/csv/"
+    val file_path = "gar-test/ldbc_sample/parquet/"
     val prefix = getClass.getClassLoader.getResource(file_path).getPath
     val vertex_yaml = getClass.getClassLoader
       .getResource(file_path + "person.vertex.yml")
@@ -101,33 +128,75 @@
     // test reading a single property chunk
     val single_chunk_df = reader.readVertexPropertyChunk(property_group, 0)
-    assert(single_chunk_df.columns.size == 3)
+    assert(single_chunk_df.columns.length == 3)
     assert(single_chunk_df.count() == 100)
+    val cond = "gender = 'female'"
+    var df_pd = single_chunk_df.select("firstName", "gender").filter(cond)
 
-    // test reading chunks for a property group
-    val property_df = reader.readVertexPropertyGroup(property_group, false)
-    assert(property_df.columns.size == 3)
+    /**
+     * ==Physical Plan==
+     * (1) Filter (isnotnull(gender#2) AND (gender#2 = female))
+     * +- *(1) ColumnarToRow
+     * +- BatchScan[firstName#0, gender#2] GarScan DataFilters:
+     * [isnotnull(gender#2), (gender#2 = female)], Format: gar, Location:
+     * InMemoryFileIndex(1
+     * paths)[file:/path/to/GraphAr/spark/src/test/resources/gar-test/l...,
+     * PartitionFilters: [], PushedFilters: [IsNotNull(gender),
+     * EqualTo(gender,female)], ReadSchema:
+     * struct, PushedFilters:
+     * [IsNotNull(gender), EqualTo(gender,female)] RuntimeFilters: []
+     */
+    df_pd.explain()
+    df_pd.show()
+
+    // test reading all chunks for a property group
+    val property_df =
+      reader.readVertexPropertyGroup(property_group, addIndex = false)
+    assert(property_df.columns.length == 3)
     assert(property_df.count() == 903)
+    df_pd = property_df.select("firstName", "gender").filter(cond)
+
+    /**
+     * ==Physical Plan==
+     * (1) Filter (isnotnull(gender#31) AND (gender#31 = female))
+     * +- *(1) ColumnarToRow
+     * +- BatchScan[firstName#29, gender#31] GarScan DataFilters:
+     * [isnotnull(gender#31), (gender#31 = female)], Format: gar, Location:
+     * InMemoryFileIndex(1
+     * paths)[file:/path/to/code/cpp/GraphAr/spark/src/test/resources/gar-test/l...,
+     * PartitionFilters: [], PushedFilters: [IsNotNull(gender),
+     * EqualTo(gender,female)], ReadSchema:
+     * struct, PushedFilters:
+     * [IsNotNull(gender), EqualTo(gender,female)] RuntimeFilters: []
+     */
+    df_pd.explain()
+    df_pd.show()
 
     // test reading chunks for multiple property groups
     val property_group_1 = vertex_info.getPropertyGroup("id")
-    var property_groups = new java.util.ArrayList[PropertyGroup]()
+    val property_groups = new java.util.ArrayList[PropertyGroup]()
     property_groups.add(property_group_1)
     property_groups.add(property_group)
     val multiple_property_df =
-      reader.readMultipleVertexPropertyGroups(property_groups, false)
-    assert(multiple_property_df.columns.size == 4)
+      reader.readMultipleVertexPropertyGroups(property_groups, addIndex = false)
+    assert(multiple_property_df.columns.length == 4)
     assert(multiple_property_df.count() == 903)
-
+    df_pd = multiple_property_df.filter(cond)
+    df_pd.explain()
+    df_pd.show()
     // test reading chunks for all property groups and optionally adding indices
-    val vertex_df = reader.readAllVertexPropertyGroups(false)
-    vertex_df.show()
-    assert(vertex_df.columns.size == 4)
+    val vertex_df = reader.readAllVertexPropertyGroups(addIndex = false)
+    assert(vertex_df.columns.length == 4)
     assert(vertex_df.count() == 903)
+    df_pd = vertex_df.filter(cond)
+    df_pd.explain()
+    df_pd.show()
 
     val vertex_df_with_index = reader.readAllVertexPropertyGroups()
-    vertex_df_with_index.show()
-    assert(vertex_df_with_index.columns.size == 5)
+    assert(vertex_df_with_index.columns.length == 5)
     assert(vertex_df_with_index.count() == 903)
+    df_pd = vertex_df_with_index.filter(cond).select("firstName", "gender")
+    df_pd.explain()
+    df_pd.show()
 
     // throw an exception for non-existing property groups
     val invalid_property_group = new PropertyGroup()