Feat: filter pushdown for Spark
Signed-off-by: Ziy1-Tan <[email protected]>
Ziy1-Tan authored and 伏聆 committed Sep 8, 2023
1 parent d095ee6 commit 44f2e80
Showing 4 changed files with 88 additions and 74 deletions.
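This change makes data filters applied to a GarDataSource DataFrame reach the underlying Parquet/ORC scans; CSV chunks keep their previous behaviour. Below is a minimal sketch, not part of the commit, of how the effect can be observed. It mirrors the updated TestReader further down; the object name and the chunk directory are assumptions (the path follows the gar-test sample layout), so adjust them to your data.

import org.apache.spark.sql.SparkSession

object FilterPushdownSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()

    // Assumed vertex chunk directory in the gar-test sample data.
    val vertexChunkDir = "gar-test/ldbc_sample/parquet/vertex/person/id"
    val df = spark.read
      .option("fileFormat", "parquet")
      .format("com.alibaba.graphar.datasources.GarDataSource")
      .load(vertexChunkDir)

    // After this commit the physical plan lists the condition under "PushedFilters"
    // for Parquet/ORC-backed chunks instead of filtering only after the scan.
    df.filter("id < 1000").explain()

    spark.stop()
  }
}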
spark/src/main/scala/com/alibaba/graphar/datasources/GarScan.scala
@@ -17,7 +17,6 @@
package com.alibaba.graphar.datasources

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.hadoop.conf.Configuration
@@ -40,6 +39,8 @@ import org.apache.spark.sql.execution.datasources.parquet.{
ParquetReadSupport,
ParquetWriteSupport
}
import org.apache.spark.sql.execution.datasources.{FilePartition, PartitioningAwareFileIndex, PartitionedFile}
import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport}
import org.apache.spark.sql.execution.datasources.v2.FileScan
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetPartitionReaderFactory
import org.apache.spark.sql.execution.datasources.v2.orc.OrcPartitionReaderFactory
@@ -279,7 +280,7 @@ case class GarScan(
super.description() + ", PushedFilters: " + seqToString(pushedFilters)
}

/** Get the meata data map of the object. */
/** Get the meta data map of the object. */
override def getMetaData(): Map[String, String] = {
super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters))
}
spark/src/main/scala/com/alibaba/graphar/datasources/GarScanBuilder.scala
@@ -16,16 +16,19 @@

package com.alibaba.graphar.datasources

import scala.collection.JavaConverters._

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters}
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex

import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

import scala.collection.JavaConverters._
import org.apache.spark.sql.execution.datasources.v2.orc.OrcScanBuilder
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScanBuilder

/** GarScanBuilder is a class to build the file scan for GarDataSource. */
case class GarScanBuilder(
sparkSession: SparkSession,
@@ -34,13 +37,37 @@ case class GarScanBuilder(
dataSchema: StructType,
options: CaseInsensitiveStringMap,
formatName: String
) extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {
) extends FileScanBuilder(sparkSession, fileIndex, dataSchema)
with SupportsPushDownFilters {
lazy val hadoopConf = {
val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
// Hadoop Configurations are case sensitive.
sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
}

private var filters: Array[Filter] = Array.empty
override def pushFilters(filters: Array[Filter]): Array[Filter] = {
this.filters = filters
filters
}

override def pushedFilters(): Array[Filter] = {
formatName match {
case "csv" => Array.empty
case "orc" => pushedOrcFilters
case "parquet" => pushedParquetFilters
case _ => throw new IllegalArgumentException
}
}

private lazy val pushedParquetFilters =
ParquetScanBuilder(sparkSession, fileIndex, schema, dataSchema, options)
.pushFilters(filters)

private lazy val pushedOrcFilters =
OrcScanBuilder(sparkSession, fileIndex, schema, dataSchema, options)
.pushFilters(filters)

// Check if the file format supports nested schema pruning.
override protected val supportsNestedSchemaPruning: Boolean =
formatName match {
@@ -50,15 +77,6 @@
case _ => throw new IllegalArgumentException
}

// Note: This scan builder does not implement "with SupportsPushDownFilters".
private var filters: Array[Filter] = Array.empty

// Note: To support pushdown filters, these two methods need to be implemented.

// override def pushFilters(filters: Array[Filter]): Array[Filter]

// override def pushedFilters(): Array[Filter]

/** Build the file scan for GarDataSource. */
override def build(): Scan = {
GarScan(
@@ -68,7 +86,7 @@
dataSchema,
readDataSchema(),
readPartitionSchema(),
filters,
pushedFilters(),
options,
formatName
)
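For reference, the GarScanBuilder changes above mix in Spark's SupportsPushDownFilters, record the filters Spark offers, and report pushed filters per format: none for CSV, and the result of handing the filters to the built-in OrcScanBuilder or ParquetScanBuilder otherwise; pushFilters() returns all filters, so Spark still evaluates them after the scan. The standalone builder below is illustrative only (it is not GraphAr code, and the class name is made up); it just shows the contract such a builder has to satisfy.

import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters}
import org.apache.spark.sql.sources.{Filter, GreaterThan, IsNotNull}

// Hypothetical builder: it accepts a couple of simple predicates and hands the
// rest back to Spark for post-scan evaluation.
class IllustrativeScanBuilder extends ScanBuilder with SupportsPushDownFilters {
  private var accepted: Array[Filter] = Array.empty

  override def pushFilters(filters: Array[Filter]): Array[Filter] = {
    val (supported, unsupported) = filters.partition {
      case _: GreaterThan | _: IsNotNull => true
      case _                             => false
    }
    accepted = supported
    unsupported // filters Spark must still evaluate after the scan
  }

  // This is what shows up as "PushedFilters" in the scan's metadata.
  override def pushedFilters(): Array[Filter] = accepted

  override def build(): Scan =
    throw new UnsupportedOperationException("sketch only, no real Scan provided")
}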
30 changes: 10 additions & 20 deletions spark/src/main/scala/com/alibaba/graphar/reader/VertexReader.scala
@@ -63,11 +63,8 @@ class VertexReader(
* vertex property chunk DataFrame. Raise IllegalArgumentException if the
* property group not contained.
*/
def readVertexPropertyChunk(
propertyGroup: PropertyGroup,
chunk_index: Long
): DataFrame = {
if (vertexInfo.containPropertyGroup(propertyGroup) == false) {
def readVertexPropertyChunk(propertyGroup: PropertyGroup, chunk_index: Long): DataFrame = {
if (!vertexInfo.containPropertyGroup(propertyGroup)) {
throw new IllegalArgumentException
}
val file_type = propertyGroup.getFile_type()
@@ -91,25 +88,18 @@
* DataFrame that contains all chunks of property group. Raise
* IllegalArgumentException if the property group not contained.
*/
def readVertexPropertyGroup(
propertyGroup: PropertyGroup,
addIndex: Boolean = true
): DataFrame = {
if (vertexInfo.containPropertyGroup(propertyGroup) == false) {
def readVertexPropertyGroup(propertyGroup: PropertyGroup, addIndex: Boolean = true): DataFrame = {
if (!vertexInfo.containPropertyGroup(propertyGroup)) {
throw new IllegalArgumentException
}
val file_type = propertyGroup.getFile_type()
val file_path = prefix + vertexInfo.getPathPrefix(propertyGroup)
val df = spark.read
.option("fileFormat", file_type)
.option("header", "true")
.format("com.alibaba.graphar.datasources.GarDataSource")
.load(file_path)
val df = spark.read.option("fileFormat", file_type).option("header", "true").format("com.alibaba.graphar.datasources.GarDataSource").load(file_path)

if (addIndex) {
return IndexGenerator.generateVertexIndexColumn(df)
IndexGenerator.generateVertexIndexColumn(df)
} else {
return df
df
}
}

@@ -145,7 +135,7 @@

var rdd = df0.rdd
var schema_array = df0.schema.fields
for (i <- 1 to len - 1) {
for ( i <- 1 until len) {
val pg: PropertyGroup = propertyGroups.get(i)
val new_df = readVertexPropertyGroup(pg, false)
schema_array = Array.concat(schema_array, new_df.schema.fields)
@@ -155,9 +145,9 @@
val schema = StructType(schema_array)
val df = spark.createDataFrame(rdd, schema)
if (addIndex) {
return IndexGenerator.generateVertexIndexColumn(df)
IndexGenerator.generateVertexIndexColumn(df)
} else {
return df
df
}
}

81 changes: 43 additions & 38 deletions spark/src/test/scala/com/alibaba/graphar/TestReader.scala
@@ -16,14 +16,9 @@

package com.alibaba.graphar

import com.alibaba.graphar.datasources._
import com.alibaba.graphar.reader.{VertexReader, EdgeReader}

import java.io.{File, FileInputStream}
import org.yaml.snakeyaml.Yaml
import org.yaml.snakeyaml.constructor.Constructor
import scala.beans.BeanProperty
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.SparkSession
import org.scalatest.funsuite.AnyFunSuite

class ReaderSuite extends AnyFunSuite {
@@ -33,7 +28,10 @@ class ReaderSuite extends AnyFunSuite {
.master("local[*]")
.getOrCreate()

spark.sparkContext.setLogLevel("Error")

test("read chunk files directly") {
val cond = "id < 1000"
// read vertex chunk files in Parquet
val parquet_file_path = "gar-test/ldbc_sample/parquet/"
val parquet_prefix =
@@ -46,7 +44,9 @@
// validate reading results
assert(df1.rdd.getNumPartitions == 10)
assert(df1.count() == 903)
// println(df1.rdd.collect().mkString("\n"))
var df_pd = df1.filter(cond)
df_pd.explain()
df_pd.show()

// read vertex chunk files in Orc
val orc_file_path = "gar-test/ldbc_sample/orc/"
@@ -58,6 +58,9 @@
.load(orc_read_path)
// validate reading results
assert(df2.rdd.collect().deep == df1.rdd.collect().deep)
df_pd = df1.filter(cond)
df_pd.explain()
df_pd.show()

// read adjList chunk files recursively in CSV
val csv_file_path = "gar-test/ldbc_sample/csv/"
@@ -85,7 +88,7 @@

test("read vertex chunks") {
// construct the vertex information
val file_path = "gar-test/ldbc_sample/csv/"
val file_path = "gar-test/ldbc_sample/parquet/"
val prefix = getClass.getClassLoader.getResource(file_path).getPath
val vertex_yaml = getClass.getClassLoader
.getResource(file_path + "person.vertex.yml")
@@ -101,42 +104,50 @@

// test reading a single property chunk
val single_chunk_df = reader.readVertexPropertyChunk(property_group, 0)
assert(single_chunk_df.columns.size == 3)
assert(single_chunk_df.columns.length == 3)
assert(single_chunk_df.count() == 100)

// test reading chunks for a property group
val property_df = reader.readVertexPropertyGroup(property_group, false)
assert(property_df.columns.size == 3)
val cond = "gender = 'female'"
var df_pd = single_chunk_df.select("firstName","gender").filter(cond)
df_pd.explain()
df_pd.show()

// test reading all chunks for a property group
val property_df = reader.readVertexPropertyGroup(property_group, addIndex = false)
assert(property_df.columns.length == 3)
assert(property_df.count() == 903)
df_pd = property_df.select("firstName", "gender").filter(cond)
df_pd.explain()
df_pd.show()

// test reading chunks for multiple property groups
val property_group_1 = vertex_info.getPropertyGroup("id")
var property_groups = new java.util.ArrayList[PropertyGroup]()
val property_groups = new java.util.ArrayList[PropertyGroup]()
property_groups.add(property_group_1)
property_groups.add(property_group)
val multiple_property_df =
reader.readMultipleVertexPropertyGroups(property_groups, false)
assert(multiple_property_df.columns.size == 4)
val multiple_property_df = reader.readMultipleVertexPropertyGroups(property_groups, addIndex = false)
assert(multiple_property_df.columns.length == 4)
assert(multiple_property_df.count() == 903)

df_pd = multiple_property_df.filter(cond)
df_pd.explain()
df_pd.show()
// test reading chunks for all property groups and optionally adding indices
val vertex_df = reader.readAllVertexPropertyGroups(false)
vertex_df.show()
assert(vertex_df.columns.size == 4)
val vertex_df = reader.readAllVertexPropertyGroups(addIndex = false)
assert(vertex_df.columns.length == 4)
assert(vertex_df.count() == 903)
df_pd = vertex_df.filter(cond)
df_pd.explain()
df_pd.show()
val vertex_df_with_index = reader.readAllVertexPropertyGroups()
vertex_df_with_index.show()
assert(vertex_df_with_index.columns.size == 5)
assert(vertex_df_with_index.columns.length == 5)
assert(vertex_df_with_index.count() == 903)
df_pd = vertex_df_with_index.filter(cond).select("firstName", "gender")
df_pd.explain("formatted")
df_pd.show()

// throw an exception for non-existing property groups
val invalid_property_group = new PropertyGroup()
assertThrows[IllegalArgumentException](
reader.readVertexPropertyChunk(invalid_property_group, 0)
)
assertThrows[IllegalArgumentException](
reader.readVertexPropertyGroup(invalid_property_group)
)
assertThrows[IllegalArgumentException](reader.readVertexPropertyChunk(invalid_property_group, 0))
assertThrows[IllegalArgumentException](reader.readVertexPropertyGroup(invalid_property_group))
}

test("read edge chunks") {
@@ -229,15 +240,9 @@

// throw an exception for non-existing property groups
val invalid_property_group = new PropertyGroup()
assertThrows[IllegalArgumentException](
reader.readEdgePropertyChunk(invalid_property_group, 0, 0)
)
assertThrows[IllegalArgumentException](
reader.readEdgePropertyGroupForVertexChunk(invalid_property_group, 0)
)
assertThrows[IllegalArgumentException](
reader.readEdgePropertyGroup(invalid_property_group)
)
assertThrows[IllegalArgumentException](reader.readEdgePropertyChunk(invalid_property_group, 0, 0))
assertThrows[IllegalArgumentException](reader.readEdgePropertyGroupForVertexChunk(invalid_property_group, 0))
assertThrows[IllegalArgumentException](reader.readEdgePropertyGroup(invalid_property_group))

// throw an exception for non-existing adjList types
val invalid_adj_list_type = AdjListType.unordered_by_dest