[Improvement] Improve GraphAr spark writer performance and implement custom writer builder to bypass spark's write behavior (#92)
acezen authored Feb 20, 2023
1 parent d26d3b8 commit ad30121
Showing 16 changed files with 906 additions and 205 deletions.
2 changes: 2 additions & 0 deletions spark/src/main/java/com/alibaba/graphar/GeneralParams.java
@@ -26,5 +26,7 @@ public class GeneralParams {
public static final String vertexChunkIndexCol = "_graphArVertexChunkIndex";
public static final String edgeIndexCol = "_graphArEdgeIndex";
public static final String regularSeperator = "_";
public static final String offsetStartChunkIndexKey = "_graphar_offset_start_chunk_index";
public static final String aggNumListOfEdgeChunkKey = "_graphar_agg_num_list_of_edge_chunk";
}
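The two new keys added above are plain option names that the custom commit protocol (GarCommitProtocol, shown next) looks up when it names output files. A minimal illustration of the shape of the options, assuming they are threaded through as a simple string map; the exact plumbing in the GraphAr writer is not part of the hunks shown on this page:

// Hypothetical illustration only; not code from this commit.
// An offset-chunk write that should start numbering at chunk 4:
val offsetWriteOptions = Map(GeneralParams.offsetStartChunkIndexKey -> "4")
// An edge-chunk write carrying the aggregated (prefix-sum) edge-chunk counts
// per vertex chunk as a JSON array, used to build part<i>/chunk<j> file names:
val edgeWriteOptions = Map(GeneralParams.aggNumListOfEdgeChunkKey -> "[0, 3, 5]")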

@@ -0,0 +1,74 @@
/** Copyright 2022 Alibaba Group Holding Limited.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.alibaba.graphar.datasources

import com.alibaba.graphar.GeneralParams

import org.json4s._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import org.apache.hadoop.mapreduce._
import org.apache.spark.internal.Logging

object GarCommitProtocol {
  /** Binary-search the prefix-sum array aggNums for the range that contains key;
   *  returns (index of that range, offset of key within it). With the aggregated
   *  edge-chunk counts as input, this maps a Spark partition ID to a
   *  (vertex chunk index, edge chunk index) pair.
   */
  private def binarySearchPair(aggNums: Array[Int], key: Int): (Int, Int) = {
    var low = 0
    var high = aggNums.length - 1
    var mid = 0
    while (low <= high) {
      mid = (high + low) / 2
      if (aggNums(mid) <= key && (mid == aggNums.length - 1 || aggNums(mid + 1) > key)) {
        return (mid, key - aggNums(mid))
      } else if (aggNums(mid) > key) {
        high = mid - 1
      } else {
        low = mid + 1
      }
    }
    return (low, key - aggNums(low))
  }
}

class GarCommitProtocol(jobId: String,
                        path: String,
                        options: Map[String, String],
                        dynamicPartitionOverwrite: Boolean = false)
    extends SQLHadoopMapReduceCommitProtocol(jobId, path, dynamicPartitionOverwrite) with Serializable with Logging {

  /** Choose the output file name for a write task, based on the write options. */
  override def getFilename(taskContext: TaskAttemptContext, ext: String): String = {
    val partitionId = taskContext.getTaskAttemptID.getTaskID.getId
    if (options.contains(GeneralParams.offsetStartChunkIndexKey)) {
      // offset chunk file name, looks like chunk0
      val chunk_index = options.get(GeneralParams.offsetStartChunkIndexKey).get.toInt + partitionId
      return f"chunk$chunk_index"
    }
    if (options.contains(GeneralParams.aggNumListOfEdgeChunkKey)) {
      // edge chunk file name, looks like part0/chunk0
      val jValue = parse(options.get(GeneralParams.aggNumListOfEdgeChunkKey).get)
      implicit val formats = DefaultFormats // default formats for json4s extraction
      val aggNums: Array[Int] = Extraction.extract[Array[Int]](jValue)
      val chunkPair: (Int, Int) = GarCommitProtocol.binarySearchPair(aggNums, partitionId)
      val vertex_chunk_index: Int = chunkPair._1
      val edge_chunk_index: Int = chunkPair._2
      return f"part$vertex_chunk_index/chunk$edge_chunk_index"
    }
    // vertex chunk file name, looks like chunk0
    return f"chunk$partitionId"
  }
}
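To make the naming rules concrete, here is a small worked example against the code above (values chosen purely for illustration):

// Worked example for GarCommitProtocol.getFilename (illustrative values).
// Edge chunks, with aggregated (prefix-sum) counts per vertex chunk [0, 3, 5]:
//   binarySearchPair(Array(0, 3, 5), 4) == (1, 1)  -> file "part1/chunk1"
//   binarySearchPair(Array(0, 3, 5), 0) == (0, 0)  -> file "part0/chunk0"
// Offset chunks, with offsetStartChunkIndexKey = "4" and partition ID 2 -> "chunk6"
// Vertex chunks, with neither key set and partition ID 7 -> "chunk7"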
@@ -1,10 +1,11 @@
/** Copyright 2022 Alibaba Group Holding Limited.
/* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -15,38 +16,104 @@

package com.alibaba.graphar.datasources

import org.apache.spark.sql.connector.catalog.Table
import scala.collection.JavaConverters._
import java.util

import com.fasterxml.jackson.databind.ObjectMapper
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.connector.expressions.Transform

/** GarDataSource is a class to provide gar files as the data source for spark. */
class GarDataSource extends FileDataSourceV2 {
import com.alibaba.graphar.utils.Utils

object GarUtils
/** GarDataSource is a class to provide gar files as the data source for spark. */
class GarDataSource extends TableProvider with DataSourceRegister {
  /** The default fallback file format is Parquet. */
  override def fallbackFileFormat: Class[_ <: FileFormat] = classOf[ParquetFileFormat]
  def fallbackFileFormat: Class[_ <: FileFormat] = classOf[ParquetFileFormat]

  lazy val sparkSession = SparkSession.active

  /** The string that represents the format name. */
  override def shortName(): String = "gar"

  protected def getPaths(map: CaseInsensitiveStringMap): Seq[String] = {
    val objectMapper = new ObjectMapper()
    val paths = Option(map.get("paths")).map { pathStr =>
      objectMapper.readValue(pathStr, classOf[Array[String]]).toSeq
    }.getOrElse(Seq.empty)
    paths ++ Option(map.get("path")).toSeq
  }

  protected def getOptionsWithoutPaths(map: CaseInsensitiveStringMap): CaseInsensitiveStringMap = {
    val withoutPath = map.asCaseSensitiveMap().asScala.filterKeys { k =>
      !k.equalsIgnoreCase("path") && !k.equalsIgnoreCase("paths")
    }
    new CaseInsensitiveStringMap(withoutPath.toMap.asJava)
  }

  protected def getTableName(map: CaseInsensitiveStringMap, paths: Seq[String]): String = {
    val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(
      map.asCaseSensitiveMap().asScala.toMap)
    val name = shortName() + " " + paths.map(qualifiedPathName(_, hadoopConf)).mkString(",")
    Utils.redact(sparkSession.sessionState.conf.stringRedactionPattern, name)
  }

  private def qualifiedPathName(path: String, hadoopConf: Configuration): String = {
    val hdfsPath = new Path(path)
    val fs = hdfsPath.getFileSystem(hadoopConf)
    hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toString
  }

  /** Provide a table from the data source. */
  override def getTable(options: CaseInsensitiveStringMap): Table = {
  def getTable(options: CaseInsensitiveStringMap): Table = {
    val paths = getPaths(options)
    val tableName = getTableName(options, paths)
    val optionsWithoutPaths = getOptionsWithoutPaths(options)
    GarTable(tableName, sparkSession, optionsWithoutPaths, paths, None, getFallbackFileFormat(options))
  }

  /** Provide a table from the data source with specific schema. */
  override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = {
  def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = {
    val paths = getPaths(options)
    val tableName = getTableName(options, paths)
    val optionsWithoutPaths = getOptionsWithoutPaths(options)
    GarTable(tableName, sparkSession, optionsWithoutPaths, paths, Some(schema), getFallbackFileFormat(options))
    GarTable(tableName, sparkSession, optionsWithoutPaths, paths, Some(schema), getFallbackFileFormat(options))
  }

  override def supportsExternalMetadata(): Boolean = true

  private var t: Table = null

  override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
    if (t == null) t = getTable(options)
    t.schema()
  }

  override def inferPartitioning(options: CaseInsensitiveStringMap): Array[Transform] = {
    Array.empty
  }

  override def getTable(schema: StructType,
                        partitioning: Array[Transform],
                        properties: util.Map[String, String]): Table = {
    // If the table is already loaded during schema inference, return it directly.
    if (t != null) {
      t
    } else {
      getTable(new CaseInsensitiveStringMap(properties), schema)
    }
  }

// Get the actual fall back file format.
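The rest of GarDataSource is not expanded on this page. Because the source now registers itself through DataSourceRegister under the short name "gar", it can be addressed by that name from the DataFrame API. A hedged usage sketch, assuming a SparkSession named spark is in scope; the fileFormat option and the path below are illustrative assumptions, not taken from the hunks shown above:

// Illustrative only: "gar" comes from shortName() above; the option name
// "fileFormat" and the load path are assumed for this sketch.
val df = spark.read
  .format("gar")
  .option("fileFormat", "parquet")
  .load("/tmp/graphar/person/vertex")
df.printSchema()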
43 changes: 19 additions & 24 deletions spark/src/main/scala/com/alibaba/graphar/datasources/GarTable.scala
@@ -20,27 +20,28 @@ import scala.collection.JavaConverters._
import org.apache.hadoop.fs.FileStatus

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, Write, WriteBuilder}
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
import org.apache.spark.sql.catalyst.csv.CSVOptions
import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.execution.datasources.csv.CSVDataSource
import org.apache.spark.sql.execution.datasources.orc.OrcUtils
import org.apache.spark.sql.execution.datasources.parquet.ParquetUtils
import org.apache.spark.sql.execution.datasources.v2.FileTable
import org.apache.spark.sql.execution.datasources.v2.csv.CSVWrite
import org.apache.spark.sql.execution.datasources.v2.orc.OrcWrite
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetWrite
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap

import com.alibaba.graphar.datasources.csv.CSVWriteBuilder
import com.alibaba.graphar.datasources.parquet.ParquetWriteBuilder
import com.alibaba.graphar.datasources.orc.OrcWriteBuilder


/** GarTable is a class to represent the graph data in GraphAr as a table. */
case class GarTable(
name: String,
sparkSession: SparkSession,
options: CaseInsensitiveStringMap,
paths: Seq[String],
userSpecifiedSchema: Option[StructType],
fallbackFileFormat: Class[_ <: FileFormat])
case class GarTable(name: String,
sparkSession: SparkSession,
options: CaseInsensitiveStringMap,
paths: Seq[String],
userSpecifiedSchema: Option[StructType],
fallbackFileFormat: Class[_ <: FileFormat])
extends FileTable(sparkSession, options, paths, userSpecifiedSchema) {

/** Construct a new scan builder. */
@@ -51,28 +52,22 @@ case class GarTable(
  override def inferSchema(files: Seq[FileStatus]): Option[StructType] = formatName match {
    case "csv" => {
      val parsedOptions = new CSVOptions(
        options.asScala.toMap,
        columnPruning = sparkSession.sessionState.conf.csvColumnPruning,
        sparkSession.sessionState.conf.sessionLocalTimeZone)
        options.asScala.toMap,
        columnPruning = sparkSession.sessionState.conf.csvColumnPruning,
        sparkSession.sessionState.conf.sessionLocalTimeZone)

      CSVDataSource(parsedOptions).inferSchema(sparkSession, files, parsedOptions)
    }
    case "orc" => OrcUtils.inferSchema(sparkSession, files, options.asScala.toMap)
    case "parquet" => ParquetUtils.inferSchema(sparkSession, options.asScala.toMap, files)
    case _ => throw new IllegalArgumentException
  }

  /** Construct a new write builder according to the actual file format. */
  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = formatName match {
    case "csv" => new WriteBuilder {
      override def build(): Write = CSVWrite(paths, formatName, supportsDataType, info)
    }
    case "orc" => new WriteBuilder {
      override def build(): Write = OrcWrite(paths, formatName, supportsDataType, info)
    }
    case "parquet" => new WriteBuilder {
      override def build(): Write = ParquetWrite(paths, formatName, supportsDataType, info)
    }
    case "csv" => new CSVWriteBuilder(paths, formatName, supportsDataType, info)
    case "orc" => new OrcWriteBuilder(paths, formatName, supportsDataType, info)
    case "parquet" => new ParquetWriteBuilder(paths, formatName, supportsDataType, info)
    case _ => throw new IllegalArgumentException
  }
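The hunk above swaps Spark's anonymous WriteBuilder wrappers around CSVWrite, OrcWrite and ParquetWrite for GraphAr's own CSVWriteBuilder, OrcWriteBuilder and ParquetWriteBuilder, which is where the commit's "bypass spark's write behavior" happens; those builder classes live in other files of this PR and are not shown on this page. As a rough sketch of the shape such a builder takes under Spark's DataSource V2 write API, not of the actual GraphAr implementation:

// Conceptual sketch only; the real CSVWriteBuilder/OrcWriteBuilder/ParquetWriteBuilder
// from this PR are not reproduced here.
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, Write, WriteBuilder}
import org.apache.spark.sql.types.DataType

class ExampleGarWriteBuilder(
    paths: Seq[String],
    formatName: String,
    supportsDataType: DataType => Boolean,
    info: LogicalWriteInfo) extends WriteBuilder {

  override def build(): Write = {
    // A format-specific Write would be constructed here. The point of owning the
    // builder is to control how the physical write is set up (for example, plugging
    // in GarCommitProtocol so that getFilename decides the chunk file names)
    // instead of falling back to Spark's default file-writing behavior.
    ???
  }
}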

