diff --git a/.github/workflows/spark.yaml b/.github/workflows/spark.yaml
index e315bdf0d..1489cbc5c 100644
--- a/.github/workflows/spark.yaml
+++ b/.github/workflows/spark.yaml
@@ -37,6 +37,17 @@ concurrency:
 jobs:
   GraphAr-spark:
     runs-on: ubuntu-20.04
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - mvn-profile: "datasources-32"
+            spark: "spark-3.2.2"
+            spark-hadoop: "spark-3.2.2-bin-hadoop3.2"
+          - mvn-profile: "datasources-33"
+            spark: "spark-3.3.4"
+            spark-hadoop: "spark-3.3.4-bin-hadoop3"
+
     steps:
       - uses: actions/checkout@v3
         with:
@@ -53,22 +64,24 @@ jobs:
         run: |
           export JAVA_HOME=${JAVA_HOME_11_X64}
           pushd spark
-          mvn --no-transfer-progress clean package -DskipTests -Dspotless.check.skip=true
+          echo "Build ${{ matrix.mvn-profile }}"
+          mvn --no-transfer-progress clean package -DskipTests -Dspotless.check.skip=true -P ${{ matrix.mvn-profile }}
           popd

       - name: Run test
         run: |
           export JAVA_HOME=${JAVA_HOME_11_X64}
           pushd spark
-          mvn --no-transfer-progress test -Dspotless.check.skip=true
+          echo "Test ${{ matrix.mvn-profile }}"
+          mvn --no-transfer-progress test -Dspotless.check.skip=true -P ${{ matrix.mvn-profile }}
           popd

       - name: Run Neo4j2GraphAr example
         run: |
           export JAVA_HOME=${JAVA_HOME_11_X64}
           pushd spark
-          scripts/get-spark-to-home.sh
-          export SPARK_HOME="${HOME}/spark-3.2.2-bin-hadoop3.2"
+          scripts/get-spark-to-home.sh ${{ matrix.spark }} ${{ matrix.spark-hadoop }}
+          export SPARK_HOME="${HOME}/${{ matrix.spark-hadoop }}"
           export PATH="${SPARK_HOME}/bin":"${PATH}"

           scripts/get-neo4j-to-home.sh
@@ -78,7 +91,7 @@ jobs:
           scripts/deploy-neo4j-movie-data.sh

-          scripts/build.sh
+          scripts/build.sh ${{ matrix.mvn-profile }}

           export NEO4J_USR="neo4j"
           export NEO4J_PWD="neo4j"
@@ -90,20 +103,20 @@ jobs:
           # stop and clean
           popd
-
+
       - name: Run Nebula2GraphAr example
         run: |
           export JAVA_HOME=${JAVA_HOME_11_X64}
           pushd spark
           scripts/get-nebula-to-home.sh
-          export SPARK_HOME="${HOME}/spark-3.2.2-bin-hadoop3.2"
+          export SPARK_HOME="${HOME}/${{ matrix.spark-hadoop }}"
           export PATH="${SPARK_HOME}/bin":"${PATH}"

           scripts/get-nebula-to-home.sh
           scripts/deploy-nebula-default-data.sh

-          scripts/build.sh
+          scripts/build.sh ${{ matrix.mvn-profile }}

           scripts/run-nebula2graphar.sh
@@ -113,7 +126,7 @@ jobs:
             --name nebula-console-loader \
             --network nebula-docker-env_nebula-net \
             vesoft/nebula-console:nightly -addr 172.28.3.1 -port 9669 -u root -p nebula -e "use basketballplayer; clear space basketballplayer;"
-
+
           # import from GraphAr
           scripts/run-graphar2nebula.sh
@@ -124,11 +137,10 @@ jobs:
         run: |
           export JAVA_HOME=${JAVA_HOME_11_X64}
           pushd spark
-          scripts/get-spark-to-home.sh
-          export SPARK_HOME="${HOME}/spark-3.2.2-bin-hadoop3.2"
+          export SPARK_HOME="${HOME}/${{ matrix.spark-hadoop }}"
           export PATH="${SPARK_HOME}/bin":"${PATH}"

-          scripts/build.sh
+          scripts/build.sh ${{ matrix.mvn-profile }}

           # run the importer
           cd import
diff --git a/.licenserc.yaml b/.licenserc.yaml
index 45b89c4cf..c6b69dc01 100644
--- a/.licenserc.yaml
+++ b/.licenserc.yaml
@@ -37,6 +37,7 @@ header:
     - '**/.scalafmt.conf'
     - 'cpp/apidoc'
    - 'spark/datasources-32/src/main/scala/com/alibaba/graphar/datasources'
+    - 'spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources'
     - '*.md'
     - '*.rst'
     - '**/*.json'
diff --git a/spark/datasources-33/.scalafmt.conf b/spark/datasources-33/.scalafmt.conf
new file mode 120000
index 000000000..4cb05e831
--- /dev/null
+++ b/spark/datasources-33/.scalafmt.conf
@@ -0,0 +1 @@
+../.scalafmt.conf
\ No newline at end of file
diff --git a/spark/datasources-33/pom.xml b/spark/datasources-33/pom.xml
new file mode 100644
index
000000000..e1af90c01 --- /dev/null +++ b/spark/datasources-33/pom.xml @@ -0,0 +1,188 @@ + + + + + 4.0.0 + + + com.alibaba + graphar + ${graphar.version} + + + com.alibaba + graphar-datasources + ${graphar.version} + jar + + + + org.apache.spark + spark-core_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-mllib_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-hive_${scala.binary.version} + ${spark.version} + provided + + + + + + + org.scala-tools + maven-scala-plugin + 2.15.2 + + ${scala.version} + + -target:jvm-1.8 + + + -Xss4096K + + + + + scala-compile + + compile + + + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + scala-test-compile + + testCompile + + + + + + net.alchim31.maven + scala-maven-plugin + 4.8.0 + + + + compile + testCompile + + + + + + -Xms64m + -Xmx1024m + + + -Ywarn-unused + + + + org.scalameta + semanticdb-scalac_2.12.10 + 4.3.24 + + + + + + com.diffplug.spotless + spotless-maven-plugin + 2.20.0 + + + + + + + 1.13.0 + + + + + + ${project.basedir}/.scalafmt.conf + + + + + + io.github.evis + scalafix-maven-plugin_2.13 + 0.1.8_0.11.0 + + + org.apache.maven.plugins + maven-source-plugin + + + attach-sources + + jar + + + + + + org.apache.maven.plugins + maven-javadoc-plugin + + + attach-javadocs + + jar + + + + + + maven-site-plugin + 3.7.1 + + + + diff --git a/spark/datasources-33/src/main/java/com/alibaba/graphar/GeneralParams.java b/spark/datasources-33/src/main/java/com/alibaba/graphar/GeneralParams.java new file mode 120000 index 000000000..972663dd8 --- /dev/null +++ b/spark/datasources-33/src/main/java/com/alibaba/graphar/GeneralParams.java @@ -0,0 +1 @@ +../../../../../../../graphar/src/main/java/com/alibaba/graphar/GeneralParams.java \ No newline at end of file diff --git a/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/GarCommitProtocol.scala b/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/GarCommitProtocol.scala new file mode 100644 index 000000000..527a3bc5c --- /dev/null +++ b/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/GarCommitProtocol.scala @@ -0,0 +1,95 @@ +/* Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.graphar.datasources + +import com.alibaba.graphar.GeneralParams + +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol +import org.apache.hadoop.mapreduce._ +import org.apache.spark.internal.Logging +import org.apache.spark.internal.io.FileNameSpec + +object GarCommitProtocol { + private def binarySearchPair(aggNums: Array[Int], key: Int): (Int, Int) = { + var low = 0 + var high = aggNums.length - 1 + var mid = 0 + while (low <= high) { + mid = (high + low) / 2; + if ( + aggNums(mid) <= key && (mid == aggNums.length - 1 || aggNums( + mid + 1 + ) > key) + ) { + return (mid, key - aggNums(mid)) + } else if (aggNums(mid) > key) { + high = mid - 1 + } else { + low = mid + 1 + } + } + return (low, key - aggNums(low)) + } +} + +class GarCommitProtocol( + jobId: String, + path: String, + options: Map[String, String], + dynamicPartitionOverwrite: Boolean = false +) extends SQLHadoopMapReduceCommitProtocol( + jobId, + path, + dynamicPartitionOverwrite + ) + with Serializable + with Logging { + + override def getFilename( + taskContext: TaskAttemptContext, + spec: FileNameSpec + ): String = { + val partitionId = taskContext.getTaskAttemptID.getTaskID.getId + if (options.contains(GeneralParams.offsetStartChunkIndexKey)) { + // offset chunk file name, looks like chunk0 + val chunk_index = options + .get(GeneralParams.offsetStartChunkIndexKey) + .get + .toInt + partitionId + return f"chunk$chunk_index" + } + if (options.contains(GeneralParams.aggNumListOfEdgeChunkKey)) { + // edge chunk file name, looks like part0/chunk0 + val jValue = parse( + options.get(GeneralParams.aggNumListOfEdgeChunkKey).get + ) + implicit val formats = + DefaultFormats // initialize a default formats for json4s + val aggNums: Array[Int] = Extraction.extract[Array[Int]](jValue) + val chunkPair: (Int, Int) = + GarCommitProtocol.binarySearchPair(aggNums, partitionId) + val vertex_chunk_index: Int = chunkPair._1 + val edge_chunk_index: Int = chunkPair._2 + return f"part$vertex_chunk_index/chunk$edge_chunk_index" + } + // vertex chunk file name, looks like chunk0 + return f"chunk$partitionId" + } +} diff --git a/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/GarDataSource.scala b/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/GarDataSource.scala new file mode 100644 index 000000000..d4fe44fd1 --- /dev/null +++ b/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/GarDataSource.scala @@ -0,0 +1,178 @@ +/* Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.graphar.datasources + +import scala.collection.JavaConverters._ +import scala.util.matching.Regex +import java.util + +import com.fasterxml.jackson.databind.ObjectMapper +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.connector.catalog.{Table, TableProvider} +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat +import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.connector.expressions.Transform + +object GarUtils + +/** + * GarDataSource is a class to provide gar files as the data source for spark. + */ +class GarDataSource extends TableProvider with DataSourceRegister { + private val REDACTION_REPLACEMENT_TEXT = "*********(redacted)" + + /** + * Redact the sensitive information in the given string. + */ + // Copy of redact from graphar Utils + private def redact(regex: Option[Regex], text: String): String = { + regex match { + case None => text + case Some(r) => + if (text == null || text.isEmpty) { + text + } else { + r.replaceAllIn(text, REDACTION_REPLACEMENT_TEXT) + } + } + } + + /** The default fallback file format is Parquet. */ + def fallbackFileFormat: Class[_ <: FileFormat] = classOf[ParquetFileFormat] + + lazy val sparkSession = SparkSession.active + + /** The string that represents the format name. */ + override def shortName(): String = "gar" + + protected def getPaths(map: CaseInsensitiveStringMap): Seq[String] = { + val objectMapper = new ObjectMapper() + val paths = Option(map.get("paths")) + .map { pathStr => + objectMapper.readValue(pathStr, classOf[Array[String]]).toSeq + } + .getOrElse(Seq.empty) + paths ++ Option(map.get("path")).toSeq + } + + protected def getOptionsWithoutPaths( + map: CaseInsensitiveStringMap + ): CaseInsensitiveStringMap = { + val withoutPath = map.asCaseSensitiveMap().asScala.filterKeys { k => + !k.equalsIgnoreCase("path") && !k.equalsIgnoreCase("paths") + } + new CaseInsensitiveStringMap(withoutPath.toMap.asJava) + } + + protected def getTableName( + map: CaseInsensitiveStringMap, + paths: Seq[String] + ): String = { + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions( + map.asCaseSensitiveMap().asScala.toMap + ) + val name = shortName() + " " + paths + .map(qualifiedPathName(_, hadoopConf)) + .mkString(",") + redact(sparkSession.sessionState.conf.stringRedactionPattern, name) + } + + private def qualifiedPathName( + path: String, + hadoopConf: Configuration + ): String = { + val hdfsPath = new Path(path) + val fs = hdfsPath.getFileSystem(hadoopConf) + hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toString + } + + /** Provide a table from the data source. */ + def getTable(options: CaseInsensitiveStringMap): Table = { + val paths = getPaths(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + GarTable( + tableName, + sparkSession, + optionsWithoutPaths, + paths, + None, + getFallbackFileFormat(options) + ) + } + + /** Provide a table from the data source with specific schema. 
*/ + def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { + val paths = getPaths(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + GarTable( + tableName, + sparkSession, + optionsWithoutPaths, + paths, + Some(schema), + getFallbackFileFormat(options) + ) + } + + override def supportsExternalMetadata(): Boolean = true + + private var t: Table = null + + override def inferSchema(options: CaseInsensitiveStringMap): StructType = { + if (t == null) t = getTable(options) + t.schema() + } + + override def inferPartitioning( + options: CaseInsensitiveStringMap + ): Array[Transform] = { + Array.empty + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String] + ): Table = { + // If the table is already loaded during schema inference, return it directly. + if (t != null) { + t + } else { + getTable(new CaseInsensitiveStringMap(properties), schema) + } + } + + // Get the actual fall back file format. + private def getFallbackFileFormat( + options: CaseInsensitiveStringMap + ): Class[_ <: FileFormat] = options.get("fileFormat") match { + case "csv" => classOf[CSVFileFormat] + case "orc" => classOf[OrcFileFormat] + case "parquet" => classOf[ParquetFileFormat] + case _ => throw new IllegalArgumentException + } +} diff --git a/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/GarScan.scala b/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/GarScan.scala new file mode 100644 index 000000000..5d1653dc7 --- /dev/null +++ b/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/GarScan.scala @@ -0,0 +1,303 @@ +/* Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.graphar.datasources + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.ParquetInputFormat + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.{Expression, ExprUtils} +import org.apache.spark.sql.catalyst.csv.CSVOptions +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.PartitionedFileUtil +import org.apache.spark.sql.execution.datasources.{ + FilePartition, + PartitioningAwareFileIndex, + PartitionedFile +} +import org.apache.spark.sql.execution.datasources.parquet.{ + ParquetOptions, + ParquetReadSupport, + ParquetWriteSupport +} +import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetPartitionReaderFactory +import org.apache.spark.sql.execution.datasources.v2.orc.OrcPartitionReaderFactory +import org.apache.spark.sql.execution.datasources.v2.csv.CSVPartitionReaderFactory +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration + +/** GarScan is a class to implement the file scan for GarDataSource. */ +case class GarScan( + sparkSession: SparkSession, + hadoopConf: Configuration, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + readDataSchema: StructType, + readPartitionSchema: StructType, + pushedFilters: Array[Filter], + options: CaseInsensitiveStringMap, + formatName: String, + partitionFilters: Seq[Expression] = Seq.empty, + dataFilters: Seq[Expression] = Seq.empty +) extends FileScan { + + /** The gar format is not splitable. */ + override def isSplitable(path: Path): Boolean = false + + /** Create the reader factory according to the actual file format. */ + override def createReaderFactory(): PartitionReaderFactory = + formatName match { + case "csv" => createCSVReaderFactory() + case "orc" => createOrcReaderFactory() + case "parquet" => createParquetReaderFactory() + case _ => + throw new IllegalArgumentException("Invalid format name: " + formatName) + } + + // Create the reader factory for the CSV format. + private def createCSVReaderFactory(): PartitionReaderFactory = { + val columnPruning = sparkSession.sessionState.conf.csvColumnPruning && + !readDataSchema.exists( + _.name == sparkSession.sessionState.conf.columnNameOfCorruptRecord + ) + + val parsedOptions: CSVOptions = new CSVOptions( + options.asScala.toMap, + columnPruning = columnPruning, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.columnNameOfCorruptRecord + ) + + // Check a field requirement for corrupt records here to throw an exception in a driver side + ExprUtils.verifyColumnNameOfCorruptRecord( + dataSchema, + parsedOptions.columnNameOfCorruptRecord + ) + // Don't push any filter which refers to the "virtual" column which cannot present in the input. + // Such filters will be applied later on the upper layer. + val actualFilters = + pushedFilters.filterNot( + _.references.contains(parsedOptions.columnNameOfCorruptRecord) + ) + + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + // Hadoop Configurations are case sensitive. 
+    val hadoopConf =
+      sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
+    val broadcastedConf = sparkSession.sparkContext.broadcast(
+      new SerializableConfiguration(hadoopConf)
+    )
+    // The partition values are already truncated in `FileScan.partitions`.
+    // We should use `readPartitionSchema` as the partition schema here.
+    CSVPartitionReaderFactory(
+      sparkSession.sessionState.conf,
+      broadcastedConf,
+      dataSchema,
+      readDataSchema,
+      readPartitionSchema,
+      parsedOptions,
+      actualFilters
+    )
+  }
+
+  // Create the reader factory for the Orc format.
+  private def createOrcReaderFactory(): PartitionReaderFactory = {
+    val broadcastedConf = sparkSession.sparkContext.broadcast(
+      new SerializableConfiguration(hadoopConf)
+    )
+    // The partition values are already truncated in `FileScan.partitions`.
+    // We should use `readPartitionSchema` as the partition schema here.
+    OrcPartitionReaderFactory(
+      sqlConf = sparkSession.sessionState.conf,
+      broadcastedConf = broadcastedConf,
+      dataSchema = dataSchema,
+      readDataSchema = readDataSchema,
+      partitionSchema = readPartitionSchema,
+      filters = pushedFilters,
+      aggregation = None
+    )
+  }
+
+  // Create the reader factory for the Parquet format.
+  private def createParquetReaderFactory(): PartitionReaderFactory = {
+    val readDataSchemaAsJson = readDataSchema.json
+    hadoopConf.set(
+      ParquetInputFormat.READ_SUPPORT_CLASS,
+      classOf[ParquetReadSupport].getName
+    )
+    hadoopConf.set(
+      ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA,
+      readDataSchemaAsJson
+    )
+    hadoopConf.set(ParquetWriteSupport.SPARK_ROW_SCHEMA, readDataSchemaAsJson)
+    hadoopConf.set(
+      SQLConf.SESSION_LOCAL_TIMEZONE.key,
+      sparkSession.sessionState.conf.sessionLocalTimeZone
+    )
+    hadoopConf.setBoolean(
+      SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key,
+      sparkSession.sessionState.conf.nestedSchemaPruningEnabled
+    )
+    hadoopConf.setBoolean(
+      SQLConf.CASE_SENSITIVE.key,
+      sparkSession.sessionState.conf.caseSensitiveAnalysis
+    )
+
+    ParquetWriteSupport.setSchema(readDataSchema, hadoopConf)
+
+    // Sets flags for `ParquetToSparkSchemaConverter`
+    hadoopConf.setBoolean(
+      SQLConf.PARQUET_BINARY_AS_STRING.key,
+      sparkSession.sessionState.conf.isParquetBinaryAsString
+    )
+    hadoopConf.setBoolean(
+      SQLConf.PARQUET_INT96_AS_TIMESTAMP.key,
+      sparkSession.sessionState.conf.isParquetINT96AsTimestamp
+    )
+    hadoopConf.setBoolean(
+      SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.key,
+      sparkSession.sessionState.conf.legacyParquetNanosAsLong
+    )
+    hadoopConf.setBoolean(
+      SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key,
+      sparkSession.sessionState.conf.parquetFieldIdReadEnabled
+    )
+
+    val broadcastedConf = sparkSession.sparkContext.broadcast(
+      new SerializableConfiguration(hadoopConf)
+    )
+    val sqlConf = sparkSession.sessionState.conf
+    ParquetPartitionReaderFactory(
+      sqlConf = sqlConf,
+      broadcastedConf = broadcastedConf,
+      dataSchema = dataSchema,
+      readDataSchema = readDataSchema,
+      partitionSchema = readPartitionSchema,
+      filters = pushedFilters,
+      aggregation = None,
+      new ParquetOptions(options.asCaseSensitiveMap.asScala.toMap, sqlConf)
+    )
+  }
+
+  /**
+   * Override "partitions" of
+   * org.apache.spark.sql.execution.datasources.v2.FileScan to disable splitting
+   * and sort the files by file paths instead of by file sizes. Note: This
+   * implementation does not support partition attributes.
+ */ + override protected def partitions: Seq[FilePartition] = { + val selectedPartitions = fileIndex.listFiles(partitionFilters, dataFilters) + val maxSplitBytes = + FilePartition.maxSplitBytes(sparkSession, selectedPartitions) + + val splitFiles = selectedPartitions.flatMap { partition => + val partitionValues = partition.values + partition.files + .flatMap { file => + val filePath = file.getPath + PartitionedFileUtil.splitFiles( + sparkSession = sparkSession, + file = file, + filePath = filePath, + isSplitable = isSplitable(filePath), + maxSplitBytes = maxSplitBytes, + partitionValues = partitionValues + ) + } + .toArray + .sortBy(_.filePath) + } + + getFilePartitions(sparkSession, splitFiles) + } + + /** + * Override "getFilePartitions" of + * org.apache.spark.sql.execution.datasources.FilePartition to assign each + * chunk file in GraphAr to a single partition. + */ + private def getFilePartitions( + sparkSession: SparkSession, + partitionedFiles: Seq[PartitionedFile] + ): Seq[FilePartition] = { + val partitions = new ArrayBuffer[FilePartition] + val currentFiles = new ArrayBuffer[PartitionedFile] + + /** Close the current partition and move to the next. */ + def closePartition(): Unit = { + if (currentFiles.nonEmpty) { + // Copy to a new Array. + val newPartition = FilePartition(partitions.size, currentFiles.toArray) + partitions += newPartition + } + currentFiles.clear() + } + // Assign a file to each partition + partitionedFiles.foreach { file => + closePartition() + // Add the given file to the current partition. + currentFiles += file + } + closePartition() + partitions.toSeq + } + + /** Check if two objects are equal. */ + override def equals(obj: Any): Boolean = obj match { + case g: GarScan => + super.equals(g) && dataSchema == g.dataSchema && options == g.options && + equivalentFilters( + pushedFilters, + g.pushedFilters + ) && formatName == g.formatName + case _ => false + } + + /** Get the hash code of the object. */ + override def hashCode(): Int = formatName match { + case "csv" => super.hashCode() + case "orc" => getClass.hashCode() + case "parquet" => getClass.hashCode() + case _ => + throw new IllegalArgumentException("Invalid format name: " + formatName) + } + + /** Get the description string of the object. */ + override def description(): String = { + super.description() + ", PushedFilters: " + seqToString(pushedFilters) + } + + /** Get the meta data map of the object. */ + override def getMetaData(): Map[String, String] = { + super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) + } + + /** Construct the file scan with filters. */ + def withFilters( + partitionFilters: Seq[Expression], + dataFilters: Seq[Expression] + ): FileScan = + this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) +} diff --git a/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/GarScanBuilder.scala b/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/GarScanBuilder.scala new file mode 100644 index 000000000..75d517211 --- /dev/null +++ b/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/GarScanBuilder.scala @@ -0,0 +1,106 @@ +/* Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.graphar.datasources + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.Scan +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex + +import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +import scala.collection.JavaConverters._ +import org.apache.spark.sql.execution.datasources.v2.orc.OrcScanBuilder +import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScanBuilder + +/** GarScanBuilder is a class to build the file scan for GarDataSource. */ +case class GarScanBuilder( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + schema: StructType, + dataSchema: StructType, + options: CaseInsensitiveStringMap, + formatName: String +) extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + lazy val hadoopConf = { + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + // Hadoop Configurations are case sensitive. + sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + } + + private var filters: Array[Filter] = Array.empty + + override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = { + this.filters = dataFilters + formatName match { + case "csv" => Array.empty[Filter] + case "orc" => pushedOrcFilters + case "parquet" => pushedParquetFilters + case _ => + throw new IllegalArgumentException("Invalid format name: " + formatName) + } + } + + private lazy val pushedParquetFilters: Array[Filter] = { + if (!sparkSession.sessionState.conf.parquetFilterPushDown) { + Array.empty[Filter] + } else { + val builder = + ParquetScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + builder.pushDataFilters(this.filters) + builder.pushedParquetFilters + } + } + + private lazy val pushedOrcFilters: Array[Filter] = { + if (!sparkSession.sessionState.conf.orcFilterPushDown) { + Array.empty[Filter] + } else { + val builder = + OrcScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + builder.pushDataFilters(this.filters) + } + } + + // Check if the file format supports nested schema pruning. + override protected val supportsNestedSchemaPruning: Boolean = + formatName match { + case "csv" => false + case "orc" => sparkSession.sessionState.conf.nestedSchemaPruningEnabled + case "parquet" => + sparkSession.sessionState.conf.nestedSchemaPruningEnabled + case _ => + throw new IllegalArgumentException("Invalid format name: " + formatName) + } + + /** Build the file scan for GarDataSource. 
*/ + override def build(): Scan = { + GarScan( + sparkSession, + hadoopConf, + fileIndex, + dataSchema, + readDataSchema(), + readPartitionSchema(), + pushedDataFilters, + options, + formatName + ) + } +} diff --git a/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/GarTable.scala b/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/GarTable.scala new file mode 100644 index 000000000..66c710026 --- /dev/null +++ b/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/GarTable.scala @@ -0,0 +1,131 @@ +/* Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.graphar.datasources + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.FileStatus + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.catalyst.csv.CSVOptions +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.csv.CSVDataSource +import org.apache.spark.sql.execution.datasources.orc.OrcUtils +import org.apache.spark.sql.execution.datasources.parquet.ParquetUtils +import org.apache.spark.sql.execution.datasources.v2.FileTable +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +import com.alibaba.graphar.datasources.csv.CSVWriteBuilder +import com.alibaba.graphar.datasources.parquet.ParquetWriteBuilder +import com.alibaba.graphar.datasources.orc.OrcWriteBuilder + +/** GarTable is a class to represent the graph data in GraphAr as a table. */ +case class GarTable( + name: String, + sparkSession: SparkSession, + options: CaseInsensitiveStringMap, + paths: Seq[String], + userSpecifiedSchema: Option[StructType], + fallbackFileFormat: Class[_ <: FileFormat] +) extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + + /** Construct a new scan builder. */ + override def newScanBuilder( + options: CaseInsensitiveStringMap + ): GarScanBuilder = + new GarScanBuilder( + sparkSession, + fileIndex, + schema, + dataSchema, + options, + formatName + ) + + /** + * Infer the schema of the table through the methods of the actual file + * format. 
+ */ + override def inferSchema(files: Seq[FileStatus]): Option[StructType] = + formatName match { + case "csv" => { + val parsedOptions = new CSVOptions( + options.asScala.toMap, + columnPruning = sparkSession.sessionState.conf.csvColumnPruning, + sparkSession.sessionState.conf.sessionLocalTimeZone + ) + + CSVDataSource(parsedOptions).inferSchema( + sparkSession, + files, + parsedOptions + ) + } + case "orc" => + OrcUtils.inferSchema(sparkSession, files, options.asScala.toMap) + case "parquet" => + ParquetUtils.inferSchema(sparkSession, options.asScala.toMap, files) + case _ => + throw new IllegalArgumentException("Invalid format name: " + formatName) + } + + /** Construct a new write builder according to the actual file format. */ + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = + formatName match { + case "csv" => + new CSVWriteBuilder(paths, formatName, supportsDataType, info) + case "orc" => + new OrcWriteBuilder(paths, formatName, supportsDataType, info) + case "parquet" => + new ParquetWriteBuilder(paths, formatName, supportsDataType, info) + case _ => + throw new IllegalArgumentException("Invalid format name: " + formatName) + } + + /** + * Check if a data type is supported. Note: Currently, the GraphAr data source + * only supports several atomic data types. To support additional data types + * such as Struct, Array and Map, revise this function to handle them case by + * case as the commented code shows. + */ + override def supportsDataType(dataType: DataType): Boolean = dataType match { + // case _: AnsiIntervalType => false + + case _: AtomicType => true + + // case st: StructType => st.forall { f => supportsDataType(f.dataType) } + + case ArrayType(elementType, _) => + formatName match { + case "orc" => supportsDataType(elementType) + case "parquet" => supportsDataType(elementType) + case _ => false + } + + // case MapType(keyType, valueType, _) => + // supportsDataType(keyType) && supportsDataType(valueType) + + // case udt: UserDefinedType[_] => supportsDataType(udt.sqlType) + + case _ => false + } + + /** The actual file format for storing the data in GraphAr. */ + override def formatName: String = options.get("fileFormat") +} diff --git a/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/GarWriterBuilder.scala b/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/GarWriterBuilder.scala new file mode 100644 index 000000000..55af4e4ca --- /dev/null +++ b/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/GarWriterBuilder.scala @@ -0,0 +1,176 @@ +/* Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + * The implementation of GarWriteBuilder is referred from FileWriteBuilder of spark 3.1.1 + */ + +package com.alibaba.graphar.datasources + +import java.util.UUID + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat +import org.apache.hadoop.mapreduce.Job + +import org.apache.spark.sql.execution.datasources.OutputWriterFactory +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.connector.write.{ + BatchWrite, + LogicalWriteInfo, + WriteBuilder +} +import org.apache.spark.sql.execution.datasources.{ + BasicWriteJobStatsTracker, + DataSource, + OutputWriterFactory, + WriteJobDescription +} +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.sql.execution.datasources.v2.FileBatchWrite +import org.apache.spark.sql.catalyst.expressions.AttributeReference + +abstract class GarWriteBuilder( + paths: Seq[String], + formatName: String, + supportsDataType: DataType => Boolean, + info: LogicalWriteInfo +) extends WriteBuilder { + private val schema = info.schema() + private val queryId = info.queryId() + private val options = info.options() + + override def buildForBatch(): BatchWrite = { + val sparkSession = SparkSession.active + validateInputs(sparkSession.sessionState.conf.caseSensitiveAnalysis) + val path = new Path(paths.head) + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + // Hadoop Configurations are case sensitive. + val hadoopConf = + sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + val job = getJobInstance(hadoopConf, path) + val committer = new GarCommitProtocol( + java.util.UUID.randomUUID().toString, + paths.head, + options.asScala.toMap, + false + ) + lazy val description = + createWriteJobDescription( + sparkSession, + hadoopConf, + job, + paths.head, + options.asScala.toMap + ) + + committer.setupJob(job) + new FileBatchWrite(job, description, committer) + } + + def prepareWrite( + sqlConf: SQLConf, + job: Job, + options: Map[String, String], + dataSchema: StructType + ): OutputWriterFactory + + private def validateInputs(caseSensitiveAnalysis: Boolean): Unit = { + assert(schema != null, "Missing input data schema") + assert(queryId != null, "Missing query ID") + + if (paths.length != 1) { + throw new IllegalArgumentException( + "Expected exactly one path to be specified, but " + + s"got: ${paths.mkString(", ")}" + ) + } + val pathName = paths.head + DataSource.validateSchema(schema) + + schema.foreach { field => + if (!supportsDataType(field.dataType)) { + throw new IllegalArgumentException( + s"$formatName data source does not support ${field.dataType.catalogString} data type." 
+ ) + } + } + } + + private def getJobInstance(hadoopConf: Configuration, path: Path): Job = { + val job = Job.getInstance(hadoopConf) + job.setOutputKeyClass(classOf[Void]) + job.setOutputValueClass(classOf[InternalRow]) + FileOutputFormat.setOutputPath(job, path) + job + } + + private def createWriteJobDescription( + sparkSession: SparkSession, + hadoopConf: Configuration, + job: Job, + pathName: String, + options: Map[String, String] + ): WriteJobDescription = { + val caseInsensitiveOptions = CaseInsensitiveMap(options) + // Note: prepareWrite has side effect. It sets "job". + val outputWriterFactory = + prepareWrite( + sparkSession.sessionState.conf, + job, + caseInsensitiveOptions, + schema + ) + // same as schema.toAttributes which is private of spark package + val allColumns: Seq[AttributeReference] = schema.map(f => + AttributeReference(f.name, f.dataType, f.nullable, f.metadata)() + ) + val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics + val serializableHadoopConf = new SerializableConfiguration(hadoopConf) + val statsTracker = + new BasicWriteJobStatsTracker(serializableHadoopConf, metrics) + // TODO: after partitioning is supported in V2: + // 1. filter out partition columns in `dataColumns`. + // 2. Don't use Seq.empty for `partitionColumns`. + new WriteJobDescription( + uuid = UUID.randomUUID().toString, + serializableHadoopConf = + new SerializableConfiguration(job.getConfiguration), + outputWriterFactory = outputWriterFactory, + allColumns = allColumns, + dataColumns = allColumns, + partitionColumns = Seq.empty, + bucketSpec = None, + path = pathName, + customPartitionLocations = Map.empty, + maxRecordsPerFile = caseInsensitiveOptions + .get("maxRecordsPerFile") + .map(_.toLong) + .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile), + timeZoneId = caseInsensitiveOptions + .get(DateTimeUtils.TIMEZONE_OPTION) + .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone), + statsTrackers = Seq(statsTracker) + ) + } +} diff --git a/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/csv/CSVWriterBuilder.scala b/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/csv/CSVWriterBuilder.scala new file mode 100644 index 000000000..977dd05a3 --- /dev/null +++ b/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/csv/CSVWriterBuilder.scala @@ -0,0 +1,72 @@ +/* Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *
+ * The implementation of CSVWriteBuilder is referred from CSVWriteBuilder of spark 3.1.1
+ */
+
+package com.alibaba.graphar.datasources.csv
+
+import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
+import org.apache.spark.sql.catalyst.csv.CSVOptions
+import org.apache.spark.sql.catalyst.util.CompressionCodecs
+import org.apache.spark.sql.connector.write.LogicalWriteInfo
+import org.apache.spark.sql.execution.datasources.{
+  CodecStreams,
+  OutputWriter,
+  OutputWriterFactory
+}
+import org.apache.spark.sql.execution.datasources.csv.CsvOutputWriter
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DataType, StructType}
+
+import com.alibaba.graphar.datasources.GarWriteBuilder
+
+class CSVWriteBuilder(
+    paths: Seq[String],
+    formatName: String,
+    supportsDataType: DataType => Boolean,
+    info: LogicalWriteInfo
+) extends GarWriteBuilder(paths, formatName, supportsDataType, info) {
+  override def prepareWrite(
+      sqlConf: SQLConf,
+      job: Job,
+      options: Map[String, String],
+      dataSchema: StructType
+  ): OutputWriterFactory = {
+    val conf = job.getConfiguration
+    val csvOptions = new CSVOptions(
+      options,
+      columnPruning = sqlConf.csvColumnPruning,
+      sqlConf.sessionLocalTimeZone
+    )
+    csvOptions.compressionCodec.foreach { codec =>
+      CompressionCodecs.setCodecConfiguration(conf, codec)
+    }
+
+    new OutputWriterFactory {
+      override def newInstance(
+          path: String,
+          dataSchema: StructType,
+          context: TaskAttemptContext
+      ): OutputWriter = {
+        new CsvOutputWriter(path, dataSchema, context, csvOptions)
+      }
+
+      override def getFileExtension(context: TaskAttemptContext): String = {
+        ".csv" + CodecStreams.getCompressionExtension(context)
+      }
+    }
+  }
+}
diff --git a/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/orc/OrcOutputWriter.scala b/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/orc/OrcOutputWriter.scala
new file mode 100644
index 000000000..addb7bdd9
--- /dev/null
+++ b/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/orc/OrcOutputWriter.scala
@@ -0,0 +1,68 @@
+/* Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * + * The implementation of OrcOutputWriter is referred from OrcOutputWriter of spark 3.1.1 + */ + +package com.alibaba.graphar.datasources.orc + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.NullWritable +import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.orc.OrcFile +import org.apache.orc.mapred.{ + OrcOutputFormat => OrcMapRedOutputFormat, + OrcStruct +} +import org.apache.orc.mapreduce.{OrcMapreduceRecordWriter, OrcOutputFormat} + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.OutputWriter +import org.apache.spark.sql.execution.datasources.orc.{OrcSerializer, OrcUtils} +import org.apache.spark.sql.types._ + +class OrcOutputWriter( + val path: String, + dataSchema: StructType, + context: TaskAttemptContext +) extends OutputWriter { + + private[this] val serializer = new OrcSerializer(dataSchema) + + private val recordWriter = { + val orcOutputFormat = new OrcOutputFormat[OrcStruct]() { + override def getDefaultWorkFile( + context: TaskAttemptContext, + extension: String + ): Path = { + new Path(path) + } + } + val filename = orcOutputFormat.getDefaultWorkFile(context, ".orc") + val options = OrcMapRedOutputFormat.buildOptions(context.getConfiguration) + val writer = OrcFile.createWriter(filename, options) + val recordWriter = new OrcMapreduceRecordWriter[OrcStruct](writer) + OrcUtils.addSparkVersionMetadata(writer) + recordWriter + } + + override def write(row: InternalRow): Unit = { + recordWriter.write(NullWritable.get(), serializer.serialize(row)) + } + + override def close(): Unit = { + recordWriter.close(context) + } +} diff --git a/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/orc/OrcWriteBuilder.scala b/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/orc/OrcWriteBuilder.scala new file mode 100644 index 000000000..1fe41738d --- /dev/null +++ b/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/orc/OrcWriteBuilder.scala @@ -0,0 +1,103 @@ +/* Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + * The implementation of OrcWriteBuilder is referred from OrcWriteBuilder of spark 3.1.1 + */ + +package com.alibaba.graphar.datasources.orc + +import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} +import org.apache.orc.OrcConf.{COMPRESS, MAPRED_OUTPUT_SCHEMA} +import org.apache.orc.mapred.OrcStruct + +import org.apache.spark.sql.connector.write.LogicalWriteInfo +import org.apache.spark.sql.execution.datasources.{ + OutputWriter, + OutputWriterFactory +} +import org.apache.spark.sql.execution.datasources.orc.{OrcOptions, OrcUtils} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +import com.alibaba.graphar.datasources.GarWriteBuilder + +object OrcWriteBuilder { + // the getQuotedSchemaString method of spark OrcFileFormat + private def getQuotedSchemaString(dataType: DataType): String = + dataType match { + case StructType(fields) => + fields + .map(f => s"`${f.name}`:${getQuotedSchemaString(f.dataType)}") + .mkString("struct<", ",", ">") + case ArrayType(elementType, _) => + s"array<${getQuotedSchemaString(elementType)}>" + case MapType(keyType, valueType, _) => + s"map<${getQuotedSchemaString(keyType)},${getQuotedSchemaString(valueType)}>" + case _ => // UDT and others + dataType.catalogString + } +} + +class OrcWriteBuilder( + paths: Seq[String], + formatName: String, + supportsDataType: DataType => Boolean, + info: LogicalWriteInfo +) extends GarWriteBuilder(paths, formatName, supportsDataType, info) { + + override def prepareWrite( + sqlConf: SQLConf, + job: Job, + options: Map[String, String], + dataSchema: StructType + ): OutputWriterFactory = { + val orcOptions = new OrcOptions(options, sqlConf) + + val conf = job.getConfiguration + + conf.set( + MAPRED_OUTPUT_SCHEMA.getAttribute, + OrcWriteBuilder.getQuotedSchemaString(dataSchema) + ) + + conf.set(COMPRESS.getAttribute, orcOptions.compressionCodec) + + conf + .asInstanceOf[JobConf] + .setOutputFormat( + classOf[org.apache.orc.mapred.OrcOutputFormat[OrcStruct]] + ) + + new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext + ): OutputWriter = { + new OrcOutputWriter(path, dataSchema, context) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + val compressionExtension: String = { + val name = context.getConfiguration.get(COMPRESS.getAttribute) + OrcUtils.extensionsForCompressionCodecNames.getOrElse(name, "") + } + + compressionExtension + ".orc" + } + } + } +} diff --git a/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/parquet/ParquetWriterBuilder.scala b/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/parquet/ParquetWriterBuilder.scala new file mode 100644 index 000000000..1a5b8bfff --- /dev/null +++ b/spark/datasources-33/src/main/scala/com/alibaba/graphar/datasources/parquet/ParquetWriterBuilder.scala @@ -0,0 +1,151 @@ +/* Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * The implementation of ParquetWriteBuilder is referred from ParquetWriteBuilder of spark 3.1.1 + */ + +package com.alibaba.graphar.datasources.parquet + +import org.apache.hadoop.mapreduce.{Job, OutputCommitter, TaskAttemptContext} +import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat} +import org.apache.parquet.hadoop.ParquetOutputFormat.JobSummaryLevel +import org.apache.parquet.hadoop.codec.CodecConfig +import org.apache.parquet.hadoop.util.ContextUtil + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.Row +import org.apache.spark.sql.connector.write.LogicalWriteInfo +import org.apache.spark.sql.execution.datasources.{ + OutputWriter, + OutputWriterFactory +} +import org.apache.spark.sql.execution.datasources.parquet._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +import com.alibaba.graphar.datasources.GarWriteBuilder + +class ParquetWriteBuilder( + paths: Seq[String], + formatName: String, + supportsDataType: DataType => Boolean, + info: LogicalWriteInfo +) extends GarWriteBuilder(paths, formatName, supportsDataType, info) + with Logging { + + override def prepareWrite( + sqlConf: SQLConf, + job: Job, + options: Map[String, String], + dataSchema: StructType + ): OutputWriterFactory = { + val parquetOptions = new ParquetOptions(options, sqlConf) + + val conf = ContextUtil.getConfiguration(job) + + val committerClass = + conf.getClass( + SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, + classOf[ParquetOutputCommitter], + classOf[OutputCommitter] + ) + + if (conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) == null) { + logInfo( + "Using default output committer for Parquet: " + + classOf[ParquetOutputCommitter].getCanonicalName + ) + } else { + logInfo( + "Using user defined output committer for Parquet: " + committerClass.getCanonicalName + ) + } + + conf.setClass( + SQLConf.OUTPUT_COMMITTER_CLASS.key, + committerClass, + classOf[OutputCommitter] + ) + + // We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override + // it in `ParquetOutputWriter` to support appending and dynamic partitioning. The reason why + // we set it here is to setup the output committer class to `ParquetOutputCommitter`, which is + // bundled with `ParquetOutputFormat[Row]`. + job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) + + ParquetOutputFormat.setWriteSupportClass(job, classOf[ParquetWriteSupport]) + + // This metadata is useful for keeping UDTs like Vector/Matrix. + ParquetWriteSupport.setSchema(dataSchema, conf) + + // Sets flags for `ParquetWriteSupport`, which converts Catalyst schema to Parquet + // schema and writes actual rows to Parquet files. 
+ conf.set( + SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, + sqlConf.writeLegacyParquetFormat.toString + ) + + conf.set( + SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key, + sqlConf.parquetOutputTimestampType.toString + ) + + // Sets compression scheme + conf.set( + ParquetOutputFormat.COMPRESSION, + parquetOptions.compressionCodecClassName + ) + + // ParquetOutputWriter required fields starting from 3.3.x + conf.set( + SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key, + sqlConf.parquetFieldIdWriteEnabled.toString + ) + + // SPARK-15719: Disables writing Parquet summary files by default. + if ( + conf.get(ParquetOutputFormat.JOB_SUMMARY_LEVEL) == null + && conf.get(ParquetOutputFormat.ENABLE_JOB_SUMMARY) == null + ) { + conf.setEnum(ParquetOutputFormat.JOB_SUMMARY_LEVEL, JobSummaryLevel.NONE) + } + + if ( + ParquetOutputFormat.getJobSummaryLevel(conf) == JobSummaryLevel.NONE + && !classOf[ParquetOutputCommitter].isAssignableFrom(committerClass) + ) { + // output summary is requested, but the class is not a Parquet Committer + logWarning( + s"Committer $committerClass is not a ParquetOutputCommitter and cannot" + + s" create job summaries. " + + s"Set Parquet option ${ParquetOutputFormat.JOB_SUMMARY_LEVEL} to NONE." + ) + } + + new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext + ): OutputWriter = { + new ParquetOutputWriter(path, context) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + CodecConfig.from(context).getCodec.getExtension + ".parquet" + } + } + } +} diff --git a/spark/pom.xml b/spark/pom.xml index 4fc8235ab..8586aa029 100644 --- a/spark/pom.xml +++ b/spark/pom.xml @@ -49,6 +49,27 @@ true + + datasources-33 + + graphar + UTF-8 + UTF-8 + 2.12.12 + 2.12 + 512m + 1024m + 3.3.4 + 1.8 + 1.8 + 3.3.8-public + 0.1.0-SNAPSHOT + + + graphar + datasources-33 + + diff --git a/spark/scripts/build.sh b/spark/scripts/build.sh index 04bb7c474..5022ad8bb 100755 --- a/spark/scripts/build.sh +++ b/spark/scripts/build.sh @@ -18,4 +18,4 @@ set -eu cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" cd .. -mvn clean package -DskipTests +mvn --no-transfer-progress clean package -DskipTests -P ${1:-'datasources-32'} diff --git a/spark/scripts/get-spark-to-home.sh b/spark/scripts/get-spark-to-home.sh index 773c6e030..caa81299a 100755 --- a/spark/scripts/get-spark-to-home.sh +++ b/spark/scripts/get-spark-to-home.sh @@ -17,4 +17,4 @@ set -eu cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" -curl https://archive.apache.org/dist/spark/spark-3.2.2/spark-3.2.2-bin-hadoop3.2.tgz | tar -xz -C ${HOME}/ +curl https://archive.apache.org/dist/spark/${1}/${2}.tgz | tar -xz -C ${HOME}/
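
For local verification of the new profile, a minimal read sketch follows. Assumptions: the jar produced by scripts/build.sh datasources-33 is on the classpath of a Spark 3.3 session, and the chunk path below is hypothetical; the fileFormat option must be one of csv, orc or parquet, as GarTable.formatName requires.

import org.apache.spark.sql.SparkSession

object GarReadSketch {
  def main(args: Array[String]): Unit = {
    // Local session only for illustration; any Spark 3.3 session works the same way.
    val spark = SparkSession.builder()
      .appName("gar-read-sketch")
      .master("local[*]")
      .getOrCreate()

    // Load chunk files through the GraphAr data source. The fully qualified
    // provider class name is used here; GarDataSource.shortName() is "gar".
    val df = spark.read
      .format("com.alibaba.graphar.datasources.GarDataSource")
      .option("fileFormat", "parquet") // selects the Parquet reader factory in GarScan
      .load("/tmp/graphar/person/vertex/chunk0") // hypothetical chunk path

    df.show()
    spark.stop()
  }
}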