
[Improve][Spark] Improve the performance of GraphAr Spark Reader #84

Merged: 10 commits, Jan 18, 2023
2 changes: 1 addition & 1 deletion spark/pom.xml
@@ -15,7 +15,7 @@
<scala.binary.version>2.12</scala.binary.version>
<PermGen>512m</PermGen>
<MaxPermGen>1024m</MaxPermGen>
<spark.version>3.2.0</spark.version>
<spark.version>3.2.2</spark.version>
<maven.compiler.release>8</maven.compiler.release>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
59 changes: 59 additions & 0 deletions spark/src/main/scala/com/alibaba/graphar/datasources/GarDataSource.scala
@@ -0,0 +1,59 @@
/** Copyright 2022 Alibaba Group Holding Limited.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.alibaba.graphar.datasources

import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

/** GarDataSource is a class that provides GAR files as a data source for Spark. */
class GarDataSource extends FileDataSourceV2 {

/** The default fallback file format is Parquet. */
override def fallbackFileFormat: Class[_ <: FileFormat] = classOf[ParquetFileFormat]

/** The string that represents the format name. */
override def shortName(): String = "gar"

/** Provide a table from the data source. */
override def getTable(options: CaseInsensitiveStringMap): Table = {
val paths = getPaths(options)
val tableName = getTableName(options, paths)
val optionsWithoutPaths = getOptionsWithoutPaths(options)
GarTable(tableName, sparkSession, optionsWithoutPaths, paths, None, getFallbackFileFormat(options))
}

/** Provide a table from the data source with a specific schema. */
override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = {
val paths = getPaths(options)
val tableName = getTableName(options, paths)
val optionsWithoutPaths = getOptionsWithoutPaths(options)
GarTable(tableName, sparkSession, optionsWithoutPaths, paths, Some(schema), getFallbackFileFormat(options))
}

// Get the actual fallback file format.
private def getFallbackFileFormat(options: CaseInsensitiveStringMap): Class[_ <: FileFormat] = options.get("fileFormat") match {
case "csv" => classOf[CSVFileFormat]
case "orc" => classOf[OrcFileFormat]
case "parquet" => classOf[ParquetFileFormat]
case _ => throw new IllegalArgumentException
}
}
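
For reference, a minimal usage sketch (not part of this diff) of the data source defined above. The chunk directory path is hypothetical; the reader is addressed by its fully qualified class name, since the short name "gar" only works if the class is registered for Spark's data source discovery, and the `fileFormat` option must be one of "csv", "orc" or "parquet", otherwise `getFallbackFileFormat` throws an IllegalArgumentException.

import org.apache.spark.sql.SparkSession

// Editorial usage sketch, not part of this PR; the chunk path below is hypothetical.
val spark = SparkSession.builder().appName("gar-reader-example").getOrCreate()

val df = spark.read
  .format("com.alibaba.graphar.datasources.GarDataSource") // or the short name "gar", if registered
  .option("fileFormat", "parquet")                         // required: "csv", "orc" or "parquet"
  .load("/tmp/graphar/person/vertex/chunk0")               // hypothetical path to GraphAr chunk files
df.printSchema()
df.show(5)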
237 changes: 237 additions & 0 deletions spark/src/main/scala/com/alibaba/graphar/datasources/GarScan.scala
@@ -0,0 +1,237 @@
/** Copyright 2022 Alibaba Group Holding Limited.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.alibaba.graphar.datasources

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.parquet.hadoop.ParquetInputFormat

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Expression, ExprUtils}
import org.apache.spark.sql.catalyst.csv.CSVOptions
import org.apache.spark.sql.connector.read.PartitionReaderFactory
import org.apache.spark.sql.execution.PartitionedFileUtil
import org.apache.spark.sql.execution.datasources.{FilePartition, PartitioningAwareFileIndex, PartitionedFile}
import org.apache.spark.sql.execution.datasources.csv.CSVDataSource
import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport}
import org.apache.spark.sql.execution.datasources.v2.FileScan
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetPartitionReaderFactory
import org.apache.spark.sql.execution.datasources.v2.orc.OrcPartitionReaderFactory
import org.apache.spark.sql.execution.datasources.v2.csv.CSVPartitionReaderFactory
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.SerializableConfiguration

/** GarScan is a class that implements the file scan for GarDataSource. */
case class GarScan(
sparkSession: SparkSession,
hadoopConf: Configuration,
fileIndex: PartitioningAwareFileIndex,
dataSchema: StructType,
readDataSchema: StructType,
readPartitionSchema: StructType,
pushedFilters: Array[Filter],
options: CaseInsensitiveStringMap,
formatName: String,
partitionFilters: Seq[Expression] = Seq.empty,
dataFilters: Seq[Expression] = Seq.empty) extends FileScan {

/** The GAR format is not splittable. */
override def isSplitable(path: Path): Boolean = false

/** Create the reader factory according to the actual file format. */
override def createReaderFactory(): PartitionReaderFactory = formatName match {
case "csv" => createCSVReaderFactory()
case "orc" => createOrcReaderFactory()
case "parquet" => createParquetReaderFactory()
case _ => throw new IllegalArgumentException
}

// Create the reader factory for the CSV format.
private def createCSVReaderFactory(): PartitionReaderFactory = {
val columnPruning = sparkSession.sessionState.conf.csvColumnPruning &&
!readDataSchema.exists(_.name == sparkSession.sessionState.conf.columnNameOfCorruptRecord)

val parsedOptions: CSVOptions = new CSVOptions(
options.asScala.toMap,
columnPruning = columnPruning,
sparkSession.sessionState.conf.sessionLocalTimeZone,
sparkSession.sessionState.conf.columnNameOfCorruptRecord)

// Check a field requirement for corrupt records here to throw an exception on the driver side.
ExprUtils.verifyColumnNameOfCorruptRecord(dataSchema, parsedOptions.columnNameOfCorruptRecord)
// Don't push any filter which refers to the "virtual" column, which cannot be present in the input.
// Such filters will be applied later on the upper layer.
val actualFilters =
pushedFilters.filterNot(_.references.contains(parsedOptions.columnNameOfCorruptRecord))

val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
// Hadoop Configurations are case sensitive.
val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
val broadcastedConf = sparkSession.sparkContext.broadcast(
new SerializableConfiguration(hadoopConf))
// The partition values are already truncated in `FileScan.partitions`.
// We should use `readPartitionSchema` as the partition schema here.
CSVPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
dataSchema, readDataSchema, readPartitionSchema, parsedOptions, actualFilters)
}

// Create the reader factory for the Orc format.
private def createOrcReaderFactory(): PartitionReaderFactory = {
val broadcastedConf = sparkSession.sparkContext.broadcast(
new SerializableConfiguration(hadoopConf))
// The partition values are already truncated in `FileScan.partitions`.
// We should use `readPartitionSchema` as the partition schema here.
OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
dataSchema, readDataSchema, readPartitionSchema, pushedFilters)
}

// Create the reader factory for the Parquet format.
private def createParquetReaderFactory(): PartitionReaderFactory = {
val readDataSchemaAsJson = readDataSchema.json
hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName)
hadoopConf.set(
ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA,
readDataSchemaAsJson)
hadoopConf.set(
ParquetWriteSupport.SPARK_ROW_SCHEMA,
readDataSchemaAsJson)
hadoopConf.set(
SQLConf.SESSION_LOCAL_TIMEZONE.key,
sparkSession.sessionState.conf.sessionLocalTimeZone)
hadoopConf.setBoolean(
SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key,
sparkSession.sessionState.conf.nestedSchemaPruningEnabled)
hadoopConf.setBoolean(
SQLConf.CASE_SENSITIVE.key,
sparkSession.sessionState.conf.caseSensitiveAnalysis)

ParquetWriteSupport.setSchema(readDataSchema, hadoopConf)

// Sets flags for `ParquetToSparkSchemaConverter`
hadoopConf.setBoolean(
SQLConf.PARQUET_BINARY_AS_STRING.key,
sparkSession.sessionState.conf.isParquetBinaryAsString)
hadoopConf.setBoolean(
SQLConf.PARQUET_INT96_AS_TIMESTAMP.key,
sparkSession.sessionState.conf.isParquetINT96AsTimestamp)

val broadcastedConf = sparkSession.sparkContext.broadcast(
new SerializableConfiguration(hadoopConf))
val sqlConf = sparkSession.sessionState.conf
ParquetPartitionReaderFactory(
sqlConf,
broadcastedConf,
dataSchema,
readDataSchema,
readPartitionSchema,
pushedFilters,
new ParquetOptions(options.asCaseSensitiveMap.asScala.toMap, sqlConf))
}

/**
* Override "partitions" of org.apache.spark.sql.execution.datasources.v2.FileScan
* to disable splitting and sort the files by file paths instead of by file sizes.
* Note: This implementation does not support partition attributes.
*/
override protected def partitions: Seq[FilePartition] = {
val selectedPartitions = fileIndex.listFiles(partitionFilters, dataFilters)
val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions)

val splitFiles = selectedPartitions.flatMap { partition =>
val partitionValues = partition.values
partition.files.flatMap { file =>
val filePath = file.getPath
PartitionedFileUtil.splitFiles(
sparkSession = sparkSession,
file = file,
filePath = filePath,
isSplitable = isSplitable(filePath),
maxSplitBytes = maxSplitBytes,
partitionValues = partitionValues
)
}.toArray.sortBy(_.filePath)
}

getFilePartitions(sparkSession, splitFiles)
}

/**
* Override "getFilePartitions" of org.apache.spark.sql.execution.datasources.FilePartition
* to assign each chunk file in GraphAr to a single partition.
*/
private def getFilePartitions(
sparkSession: SparkSession,
partitionedFiles: Seq[PartitionedFile]): Seq[FilePartition] = {
val partitions = new ArrayBuffer[FilePartition]
val currentFiles = new ArrayBuffer[PartitionedFile]

/** Close the current partition and move to the next. */
def closePartition(): Unit = {
if (currentFiles.nonEmpty) {
// Copy to a new Array.
val newPartition = FilePartition(partitions.size, currentFiles.toArray)
partitions += newPartition
}
currentFiles.clear()
}
// Assign each file to its own partition.
partitionedFiles.foreach { file =>
closePartition()
// Add the given file to the current partition.
currentFiles += file
}
closePartition()
partitions.toSeq
}

/** Check if two objects are equal. */
override def equals(obj: Any): Boolean = obj match {
case g: GarScan =>
super.equals(g) && dataSchema == g.dataSchema && options == g.options &&
equivalentFilters(pushedFilters, g.pushedFilters) && formatName == g.formatName
case _ => false
}

/** Get the hash code of the object. */
override def hashCode(): Int = formatName match {
case "csv" => super.hashCode()
case "orc" => getClass.hashCode()
case "parquet" => getClass.hashCode()
case _ => throw new IllegalArgumentException
}

/** Get the description string of the object. */
override def description(): String = {
super.description() + ", PushedFilters: " + seqToString(pushedFilters)
}

/** Get the metadata map of the object. */
override def getMetaData(): Map[String, String] = {
super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters))
}

/** Construct the file scan with filters. */
override def withFilters(
partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)
}
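
To make the effect of the one-file-per-partition assignment in `getFilePartitions` concrete, a small check (again not part of this diff, path hypothetical, reusing the `spark` session from the earlier sketch): since each chunk file becomes its own FilePartition, the number of RDD partitions of the loaded DataFrame should match the number of chunk files that were read.

// Editorial illustration, not part of this PR; the path is hypothetical.
val vertices = spark.read
  .format("com.alibaba.graphar.datasources.GarDataSource")
  .option("fileFormat", "parquet")
  .load("/tmp/graphar/person/vertex/chunk0")

// With the one-file-per-partition assignment, this should equal the number of chunk files read.
println(s"number of partitions = ${vertices.rdd.getNumPartitions}")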
65 changes: 65 additions & 0 deletions spark/src/main/scala/com/alibaba/graphar/datasources/GarScanBuilder.scala
@@ -0,0 +1,65 @@
/** Copyright 2022 Alibaba Group Holding Limited.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.alibaba.graphar.datasources

import scala.collection.JavaConverters._

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

/** GarScanBuilder is a class that builds the file scan for GarDataSource. */
case class GarScanBuilder(
sparkSession: SparkSession,
fileIndex: PartitioningAwareFileIndex,
schema: StructType,
dataSchema: StructType,
options: CaseInsensitiveStringMap,
formatName: String)
extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {
lazy val hadoopConf = {
val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
// Hadoop Configurations are case sensitive.
sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
}

// Whether the file format supports nested schema pruning.
override protected val supportsNestedSchemaPruning: Boolean = formatName match {
case "csv" => false
case "orc" => true
case "parquet" => true
case _ => throw new IllegalArgumentException
}

// Note: This scan builder does not implement "with SupportsPushDownFilters".
private var filters: Array[Filter] = Array.empty

// Note: To support pushdown filters, these two methods need to be implemented.

// override def pushFilters(filters: Array[Filter]): Array[Filter]

// override def pushedFilters(): Array[Filter]

/** Build the file scan for GarDataSource. */
override def build(): Scan = {
GarScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(),
readPartitionSchema(), filters, options, formatName)
}
}
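
The notes in GarScanBuilder above leave filter pushdown unimplemented. As a rough sketch of what that could look like (the trait and field names below are illustrative, not part of this PR), the builder would mix in Spark's SupportsPushDownFilters and hand the recorded filters to GarScan instead of the empty `filters` array:

import org.apache.spark.sql.connector.read.SupportsPushDownFilters
import org.apache.spark.sql.sources.Filter

// Editorial sketch, not part of this PR; the trait and field names are hypothetical.
trait GarFilterPushdown extends SupportsPushDownFilters {
  private var candidateFilters: Array[Filter] = Array.empty

  // Record every filter Spark offers, and return them all so Spark still
  // re-applies each filter after the scan (which is always safe).
  override def pushFilters(filters: Array[Filter]): Array[Filter] = {
    candidateFilters = filters
    filters
  }

  // Expose the recorded filters so the builder can forward them to GarScan as pushedFilters.
  override def pushedFilters(): Array[Filter] = candidateFilters
}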