[Spark] Support property filter pushdown by utilizing payload file formats (#221)
Ziy1-Tan authored Sep 25, 2023
1 parent fbe03ed commit 1263e73
Showing 4 changed files with 147 additions and 49 deletions.
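
In short, the change makes GarScanBuilder implement SupportsPushDownFilters and delegate filter pushdown to the scan builder of the underlying payload format: Parquet and ORC filters are forwarded to ParquetScanBuilder and OrcScanBuilder respectively, while CSV reports no pushed filters. A condensed sketch of that delegation, abridged from the GarScanBuilder diff below (imports and the surrounding class declaration omitted):

// Sketch abridged from GarScanBuilder below: cache the filters Spark asks to
// push, hand them to the payload format's own scan builder, and report back
// whatever that builder accepted.
private var filters: Array[Filter] = Array.empty

override def pushFilters(filters: Array[Filter]): Array[Filter] = {
  this.filters = filters
  filters // return all filters so Spark still re-evaluates them after the scan
}

override def pushedFilters(): Array[Filter] = formatName match {
  case "csv"     => Array.empty[Filter] // CSV payloads get no pushdown
  case "orc"     => pushedOrcFilters
  case "parquet" => pushedParquetFilters
  case _         => throw new IllegalArgumentException
}

private lazy val pushedParquetFilters: Array[Filter] = {
  if (!sparkSession.sessionState.conf.parquetFilterPushDown) {
    Array.empty[Filter]
  } else {
    // Delegate to Spark's ParquetScanBuilder to decide which filters it accepts.
    val builder =
      ParquetScanBuilder(sparkSession, fileIndex, schema, dataSchema, options)
    builder.pushFilters(this.filters)
    builder.pushedFilters()
  }
}
// pushedOrcFilters is analogous, guarded by orcFilterPushDown and built on OrcScanBuilder.

The accepted filters are then handed to GarScan in build(), which is how they surface as PushedFilters in the physical plan.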
@@ -17,7 +17,6 @@
package com.alibaba.graphar.datasources

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.hadoop.conf.Configuration
@@ -34,7 +33,6 @@ import org.apache.spark.sql.execution.datasources.{
PartitioningAwareFileIndex,
PartitionedFile
}
import org.apache.spark.sql.execution.datasources.csv.CSVDataSource
import org.apache.spark.sql.execution.datasources.parquet.{
ParquetOptions,
ParquetReadSupport,
@@ -279,7 +277,7 @@ case class GarScan(
super.description() + ", PushedFilters: " + seqToString(pushedFilters)
}

/** Get the meata data map of the object. */
/** Get the meta data map of the object. */
override def getMetaData(): Map[String, String] = {
super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters))
}
@@ -16,16 +16,19 @@

package com.alibaba.graphar.datasources

import scala.collection.JavaConverters._

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters}
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex

import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

import scala.collection.JavaConverters._
import org.apache.spark.sql.execution.datasources.v2.orc.OrcScanBuilder
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScanBuilder

/** GarScanBuilder is a class to build the file scan for GarDataSource. */
case class GarScanBuilder(
sparkSession: SparkSession,
@@ -34,30 +37,58 @@ case class GarScanBuilder(
dataSchema: StructType,
options: CaseInsensitiveStringMap,
formatName: String
) extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {
) extends FileScanBuilder(sparkSession, fileIndex, dataSchema)
with SupportsPushDownFilters {
lazy val hadoopConf = {
val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
// Hadoop Configurations are case sensitive.
sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
}

// Check if the file format supports nested schema pruning.
override protected val supportsNestedSchemaPruning: Boolean =
formatName match {
case "csv" => false
case "orc" => true
case "parquet" => true
case _ => throw new IllegalArgumentException
}

// Note: This scan builder does not implement "with SupportsPushDownFilters".
private var filters: Array[Filter] = Array.empty
override def pushFilters(filters: Array[Filter]): Array[Filter] = {
this.filters = filters
filters
}

// Note: To support pushdown filters, these two methods need to be implemented.
override def pushedFilters(): Array[Filter] = formatName match {
case "csv" => Array.empty[Filter]
case "orc" => pushedOrcFilters
case "parquet" => pushedParquetFilters
case _ => throw new IllegalArgumentException
}

// override def pushFilters(filters: Array[Filter]): Array[Filter]
private lazy val pushedParquetFilters: Array[Filter] = {
if (!sparkSession.sessionState.conf.parquetFilterPushDown) {
Array.empty[Filter]
} else {
val builder =
ParquetScanBuilder(sparkSession, fileIndex, schema, dataSchema, options)
builder.pushFilters(this.filters)
builder.pushedFilters()
}
}

private lazy val pushedOrcFilters: Array[Filter] = {
if (!sparkSession.sessionState.conf.orcFilterPushDown) {
Array.empty[Filter]
} else {
val builder =
OrcScanBuilder(sparkSession, fileIndex, schema, dataSchema, options)
builder.pushFilters(this.filters)
builder.pushedFilters()
}
}

// override def pushedFilters(): Array[Filter]
// Check if the file format supports nested schema pruning.
override protected val supportsNestedSchemaPruning: Boolean =
formatName match {
case "csv" => false
case "orc" => sparkSession.sessionState.conf.nestedSchemaPruningEnabled
case "parquet" =>
sparkSession.sessionState.conf.nestedSchemaPruningEnabled
case _ => throw new IllegalArgumentException
}

/** Build the file scan for GarDataSource. */
override def build(): Scan = {
@@ -68,7 +99,7 @@ case class GarScanBuilder(
dataSchema,
readDataSchema(),
readPartitionSchema(),
filters,
pushedFilters(),
options,
formatName
)
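
With the builder wired up, pushdown can be verified from the physical plan: filter a DataFrame loaded through the GAR data source and look for the predicates under PushedFilters, exactly as the updated test suite below does via explain(). A minimal standalone sketch of that check follows; the format class name, the fileFormat option, and the chunk path are illustrative assumptions (the actual read options are defined by GarDataSource and the test resources), not a definitive API:

import org.apache.spark.sql.SparkSession

object PushdownCheck {
  def main(args: Array[String]): Unit = {
    // Local session, mirroring the test setup.
    val spark = SparkSession.builder()
      .appName("gar-pushdown-check")
      .master("local[*]")
      .getOrCreate()

    // NOTE: the format class, the "fileFormat" option, and the path below are
    // assumptions for illustration; consult GarDataSource and TestReader for
    // the exact values used in this repository.
    val df = spark.read
      .option("fileFormat", "parquet")
      .format("com.alibaba.graphar.datasources.GarDataSource")
      .load("src/test/resources/gar-test/ldbc_sample/parquet/vertex/person/id")

    // The plan should report something like:
    //   PushedFilters: [IsNotNull(id), LessThan(id,1000)]
    df.filter("id < 1000").explain()

    spark.stop()
  }
}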
@@ -67,7 +67,7 @@ class VertexReader(
propertyGroup: PropertyGroup,
chunk_index: Long
): DataFrame = {
if (vertexInfo.containPropertyGroup(propertyGroup) == false) {
if (!vertexInfo.containPropertyGroup(propertyGroup)) {
throw new IllegalArgumentException
}
val file_type = propertyGroup.getFile_type()
@@ -95,7 +95,7 @@
propertyGroup: PropertyGroup,
addIndex: Boolean = true
): DataFrame = {
if (vertexInfo.containPropertyGroup(propertyGroup) == false) {
if (!vertexInfo.containPropertyGroup(propertyGroup)) {
throw new IllegalArgumentException
}
val file_type = propertyGroup.getFile_type()
@@ -107,9 +107,9 @@ class VertexReader(
.load(file_path)

if (addIndex) {
return IndexGenerator.generateVertexIndexColumn(df)
IndexGenerator.generateVertexIndexColumn(df)
} else {
return df
df
}
}

@@ -145,7 +145,7 @@

var rdd = df0.rdd
var schema_array = df0.schema.fields
for (i <- 1 to len - 1) {
for (i <- 1 until len) {
val pg: PropertyGroup = propertyGroups.get(i)
val new_df = readVertexPropertyGroup(pg, false)
schema_array = Array.concat(schema_array, new_df.schema.fields)
@@ -155,9 +155,9 @@
val schema = StructType(schema_array)
val df = spark.createDataFrame(rdd, schema)
if (addIndex) {
return IndexGenerator.generateVertexIndexColumn(df)
IndexGenerator.generateVertexIndexColumn(df)
} else {
return df
df
}
}

111 changes: 90 additions & 21 deletions spark/src/test/scala/com/alibaba/graphar/TestReader.scala
@@ -16,14 +16,9 @@

package com.alibaba.graphar

import com.alibaba.graphar.datasources._
import com.alibaba.graphar.reader.{VertexReader, EdgeReader}

import java.io.{File, FileInputStream}
import org.yaml.snakeyaml.Yaml
import org.yaml.snakeyaml.constructor.Constructor
import scala.beans.BeanProperty
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.SparkSession
import org.scalatest.funsuite.AnyFunSuite

class ReaderSuite extends AnyFunSuite {
@@ -33,7 +28,10 @@ class ReaderSuite extends AnyFunSuite {
.master("local[*]")
.getOrCreate()

spark.sparkContext.setLogLevel("Error")

test("read chunk files directly") {
val cond = "id < 1000"
// read vertex chunk files in Parquet
val parquet_file_path = "gar-test/ldbc_sample/parquet/"
val parquet_prefix =
@@ -46,7 +44,21 @@
// validate reading results
assert(df1.rdd.getNumPartitions == 10)
assert(df1.count() == 903)
// println(df1.rdd.collect().mkString("\n"))
var df_pd = df1.filter(cond)

/**
* ==Physical Plan==
* (1) Filter (isnotnull(id#0L) AND (id#0L < 1000))
* +- *(1) ColumnarToRow
* +- BatchScan[id#0L] GarScan DataFilters: [isnotnull(id#0L), (id#0L <
* 1000)], Format: gar, Location: InMemoryFileIndex(1
* paths)[file:/path/to/code/cpp/GraphAr/spark/src/test/resources/gar-test/l...,
* PartitionFilters: [], PushedFilters: [IsNotNull(id), LessThan(id,1000)],
* ReadSchema: struct<id:bigint>, PushedFilters: [IsNotNull(id),
* LessThan(id,1000)] RuntimeFilters: []
*/
df_pd.explain()
df_pd.show()

// read vertex chunk files in Orc
val orc_file_path = "gar-test/ldbc_sample/orc/"
@@ -58,6 +70,21 @@
.load(orc_read_path)
// validate reading results
assert(df2.rdd.collect().deep == df1.rdd.collect().deep)
df_pd = df1.filter(cond)

/**
* ==Physical Plan==
* (1) Filter (isnotnull(id#0L) AND (id#0L < 1000))
* +- *(1) ColumnarToRow
* +- BatchScan[id#0L] GarScan DataFilters: [isnotnull(id#0L), (id#0L <
* 1000)], Format: gar, Location: InMemoryFileIndex(1
* paths)[file:/path/to/GraphAr/spark/src/test/resources/gar-test/l...,
* PartitionFilters: [], PushedFilters: [IsNotNull(id), LessThan(id,1000)],
* ReadSchema: struct<id:bigint>, PushedFilters: [IsNotNull(id),
* LessThan(id,1000)] RuntimeFilters: []
*/
df_pd.explain()
df_pd.show()

// read adjList chunk files recursively in CSV
val csv_file_path = "gar-test/ldbc_sample/csv/"
@@ -85,7 +112,7 @@

test("read vertex chunks") {
// construct the vertex information
val file_path = "gar-test/ldbc_sample/csv/"
val file_path = "gar-test/ldbc_sample/parquet/"
val prefix = getClass.getClassLoader.getResource(file_path).getPath
val vertex_yaml = getClass.getClassLoader
.getResource(file_path + "person.vertex.yml")
@@ -101,33 +128,75 @@

// test reading a single property chunk
val single_chunk_df = reader.readVertexPropertyChunk(property_group, 0)
assert(single_chunk_df.columns.size == 3)
assert(single_chunk_df.columns.length == 3)
assert(single_chunk_df.count() == 100)
val cond = "gender = 'female'"
var df_pd = single_chunk_df.select("firstName", "gender").filter(cond)

// test reading chunks for a property group
val property_df = reader.readVertexPropertyGroup(property_group, false)
assert(property_df.columns.size == 3)
/**
* ==Physical Plan==
* (1) Filter (isnotnull(gender#2) AND (gender#2 = female))
* +- *(1) ColumnarToRow
* +- BatchScan[firstName#0, gender#2] GarScan DataFilters:
* [isnotnull(gender#2), (gender#2 = female)], Format: gar, Location:
* InMemoryFileIndex(1
* paths)[file:/path/to/GraphAr/spark/src/test/resources/gar-test/l...,
* PartitionFilters: [], PushedFilters: [IsNotNull(gender),
* EqualTo(gender,female)], ReadSchema:
* struct<firstName:string,gender:string>, PushedFilters:
* [IsNotNull(gender), EqualTo(gender,female)] RuntimeFilters: []
*/
df_pd.explain()
df_pd.show()

// test reading all chunks for a property group
val property_df =
reader.readVertexPropertyGroup(property_group, addIndex = false)
assert(property_df.columns.length == 3)
assert(property_df.count() == 903)
df_pd = property_df.select("firstName", "gender").filter(cond)

/**
* ==Physical Plan==
* (1) Filter (isnotnull(gender#31) AND (gender#31 = female))
* +- *(1) ColumnarToRow
* +- BatchScan[firstName#29, gender#31] GarScan DataFilters:
* [isnotnull(gender#31), (gender#31 = female)], Format: gar, Location:
* InMemoryFileIndex(1
* paths)[file:/path/to/code/cpp/GraphAr/spark/src/test/resources/gar-test/l...,
* PartitionFilters: [], PushedFilters: [IsNotNull(gender),
* EqualTo(gender,female)], ReadSchema:
* struct<firstName:string,gender:string>, PushedFilters:
* [IsNotNull(gender), EqualTo(gender,female)] RuntimeFilters: []
*/
df_pd.explain()
df_pd.show()

// test reading chunks for multiple property groups
val property_group_1 = vertex_info.getPropertyGroup("id")
var property_groups = new java.util.ArrayList[PropertyGroup]()
val property_groups = new java.util.ArrayList[PropertyGroup]()
property_groups.add(property_group_1)
property_groups.add(property_group)
val multiple_property_df =
reader.readMultipleVertexPropertyGroups(property_groups, false)
assert(multiple_property_df.columns.size == 4)
reader.readMultipleVertexPropertyGroups(property_groups, addIndex = false)
assert(multiple_property_df.columns.length == 4)
assert(multiple_property_df.count() == 903)

df_pd = multiple_property_df.filter(cond)
df_pd.explain()
df_pd.show()
// test reading chunks for all property groups and optionally adding indices
val vertex_df = reader.readAllVertexPropertyGroups(false)
vertex_df.show()
assert(vertex_df.columns.size == 4)
val vertex_df = reader.readAllVertexPropertyGroups(addIndex = false)
assert(vertex_df.columns.length == 4)
assert(vertex_df.count() == 903)
df_pd = vertex_df.filter(cond)
df_pd.explain()
df_pd.show()
val vertex_df_with_index = reader.readAllVertexPropertyGroups()
vertex_df_with_index.show()
assert(vertex_df_with_index.columns.size == 5)
assert(vertex_df_with_index.columns.length == 5)
assert(vertex_df_with_index.count() == 903)
df_pd = vertex_df_with_index.filter(cond).select("firstName", "gender")
df_pd.explain()
df_pd.show()

// throw an exception for non-existing property groups
val invalid_property_group = new PropertyGroup()
