diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala index d9d7b06cdb8c..48d55aa244ee 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -92,6 +92,33 @@ abstract class FileCommitProtocol extends Logging { */ def newTaskTempFile(taskContext: TaskAttemptContext, dir: Option[String], ext: String): String + /** + * Notifies the commit protocol to add a new file, and gets back the full path that should be + * used. Must be called on the executors when running tasks. + * + * Note that the returned temp file may have an arbitrary path. The commit protocol only + * promises that the file will be at the location specified by the arguments after job commit. + * + * The "dir" parameter specifies the sub-directory within the base path, used to specify + * partitioning. The "spec" parameter specifies the file name. The rest are left to the commit + * protocol implementation to decide. + * + * Important: it is the caller's responsibility to add uniquely identifying content to "spec" + * if a task is going to write out multiple files to the same dir. The file commit protocol only + * guarantees that files written by different tasks will not conflict. + * + * @since 3.2.0 + */ + def newTaskTempFile( + taskContext: TaskAttemptContext, dir: Option[String], spec: FileNameSpec): String = { + if (spec.prefix.isEmpty) { + newTaskTempFile(taskContext, dir, spec.suffix) + } else { + throw new UnsupportedOperationException(s"${getClass.getSimpleName}.newTaskTempFile does " + + s"not support file name prefix: ${spec.prefix}") + } + } + /** * Similar to newTaskTempFile(), but allows files to committed to an absolute output location. * Depending on the implementation, there may be weaker guarantees around adding files this way. @@ -103,6 +130,31 @@ abstract class FileCommitProtocol extends Logging { def newTaskTempFileAbsPath( taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String + /** + * Similar to newTaskTempFile(), but allows files to committed to an absolute output location. + * Depending on the implementation, there may be weaker guarantees around adding files this way. + * + * The "absoluteDir" parameter specifies the final absolute directory of file. The "spec" + * parameter specifies the file name. The rest are left to the commit protocol implementation to + * decide. + * + * Important: it is the caller's responsibility to add uniquely identifying content to "spec" + * if a task is going to write out multiple files to the same dir. The file commit protocol only + * guarantees that files written by different tasks will not conflict. + * + * @since 3.2.0 + */ + def newTaskTempFileAbsPath( + taskContext: TaskAttemptContext, absoluteDir: String, spec: FileNameSpec): String = { + if (spec.prefix.isEmpty) { + newTaskTempFileAbsPath(taskContext, absoluteDir, spec.suffix) + } else { + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}.newTaskTempFileAbsPath does not support file name prefix: " + + s"${spec.prefix}") + } + } + /** * Commits a task after the writes succeed. Must be called on the executors when running tasks. */ @@ -140,6 +192,15 @@ object FileCommitProtocol extends Logging { object EmptyTaskCommitMessage extends TaskCommitMessage(null) + /** + * The specification for Spark output file name. 
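For context, a minimal standalone sketch of the dispatch behaviour introduced above (the LegacyCommitter trait and the staging path below are hypothetical, not Spark's real class hierarchy): a committer that only implements the old ext-based method keeps working whenever the caller passes an empty prefix, and a non-empty prefix surfaces as UnsupportedOperationException until the committer overrides the new overload.

final case class FileNameSpec(prefix: String, suffix: String)

trait LegacyCommitter {
  // Pre-existing API: only the file extension is configurable.
  def newTaskTempFile(dir: Option[String], ext: String): String

  // New API: falls back to the legacy method when no prefix is requested.
  def newTaskTempFile(dir: Option[String], spec: FileNameSpec): String = {
    if (spec.prefix.isEmpty) {
      newTaskTempFile(dir, spec.suffix)
    } else {
      throw new UnsupportedOperationException(
        s"${getClass.getSimpleName} does not support file name prefix: ${spec.prefix}")
    }
  }
}

object LegacyCommitterExample {
  private val committer = new LegacyCommitter {
    override def newTaskTempFile(dir: Option[String], ext: String): String =
      s"/tmp/staging/${dir.getOrElse("")}/part-00000$ext"
  }

  def main(args: Array[String]): Unit = {
    // Empty prefix: delegates to the ext-based method, behaviour is unchanged.
    println(committer.newTaskTempFile(Some("p=1"), FileNameSpec("", ".parquet")))
    // Non-empty prefix: would throw, because this committer never overrode the new overload.
    // committer.newTaskTempFile(Some("p=1"), FileNameSpec("00003_0_", ".parquet"))
  }
}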
+ * This is used by [[FileCommitProtocol]] to create full path of file. + * + * @param prefix Prefix of file. + * @param suffix Suffix of file. + */ + final case class FileNameSpec(prefix: String, suffix: String) + /** * Instantiates a FileCommitProtocol using the given className. */ diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 7ea0f01511f2..f2f9fc748464 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -118,7 +118,12 @@ class HadoopMapReduceCommitProtocol( override def newTaskTempFile( taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = { - val filename = getFilename(taskContext, ext) + newTaskTempFile(taskContext, dir, FileNameSpec("", ext)) + } + + override def newTaskTempFile( + taskContext: TaskAttemptContext, dir: Option[String], spec: FileNameSpec): String = { + val filename = getFilename(taskContext, spec) val stagingDir: Path = committer match { // For FileOutputCommitter it has its own staging path called "work path". @@ -141,7 +146,12 @@ class HadoopMapReduceCommitProtocol( override def newTaskTempFileAbsPath( taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = { - val filename = getFilename(taskContext, ext) + newTaskTempFileAbsPath(taskContext, absoluteDir, FileNameSpec("", ext)) + } + + override def newTaskTempFileAbsPath( + taskContext: TaskAttemptContext, absoluteDir: String, spec: FileNameSpec): String = { + val filename = getFilename(taskContext, spec) val absOutputPath = new Path(absoluteDir, filename).toString // Include a UUID here to prevent file collisions for one task writing to different dirs. @@ -152,12 +162,12 @@ class HadoopMapReduceCommitProtocol( tmpOutputPath } - protected def getFilename(taskContext: TaskAttemptContext, ext: String): String = { + protected def getFilename(taskContext: TaskAttemptContext, spec: FileNameSpec): String = { // The file name looks like part-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003-c000.parquet // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, // the file name is fine and won't overflow. 
val split = taskContext.getTaskAttemptID.getTaskID.getId - f"part-$split%05d-$jobId$ext" + f"${spec.prefix}part-$split%05d-$jobId${spec.suffix}" } override def setupJob(jobContext: JobContext): Unit = { diff --git a/hadoop-cloud/src/hadoop-3/main/scala/org/apache/spark/internal/io/cloud/PathOutputCommitProtocol.scala b/hadoop-cloud/src/hadoop-3/main/scala/org/apache/spark/internal/io/cloud/PathOutputCommitProtocol.scala index ce8680e7b44e..1834ca0f4205 100644 --- a/hadoop-cloud/src/hadoop-3/main/scala/org/apache/spark/internal/io/cloud/PathOutputCommitProtocol.scala +++ b/hadoop-cloud/src/hadoop-3/main/scala/org/apache/spark/internal/io/cloud/PathOutputCommitProtocol.scala @@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.TaskAttemptContext import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, PathOutputCommitter, PathOutputCommitterFactory} +import org.apache.spark.internal.io.FileCommitProtocol.FileNameSpec import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol /** @@ -122,20 +123,20 @@ class PathOutputCommitProtocol( * * @param taskContext task context * @param dir optional subdirectory - * @param ext file extension + * @param spec file naming specification * @return a path as a string */ override def newTaskTempFile( taskContext: TaskAttemptContext, dir: Option[String], - ext: String): String = { + spec: FileNameSpec): String = { val workDir = committer.getWorkPath val parent = dir.map { d => new Path(workDir, d) }.getOrElse(workDir) - val file = new Path(parent, getFilename(taskContext, ext)) - logTrace(s"Creating task file $file for dir $dir and ext $ext") + val file = new Path(parent, getFilename(taskContext, spec)) + logTrace(s"Creating task file $file for dir $dir and spec $spec") file.toString } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala index 6de9b1d7cea4..1bdba19e319e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -22,6 +22,7 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.TaskAttemptContext import org.apache.spark.internal.io.FileCommitProtocol +import org.apache.spark.internal.io.FileCommitProtocol.FileNameSpec import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec @@ -157,7 +158,7 @@ class DynamicPartitionDataWriter( private val isPartitioned = description.partitionColumns.nonEmpty /** Flag saying whether or not the data to be written out is bucketed. 
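A worked example of the name produced by the patched getFilename (REPL-style; the split number, job id, and extension below are made up, and FileNameSpec is the case class added to FileCommitProtocol above):

case class FileNameSpec(prefix: String, suffix: String)

val spec = FileNameSpec("00003_0_", "_00003.c000.snappy.parquet")  // Hive-style prefix
val split = 3                                                      // task id of this attempt
val jobId = "2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb"
val name = f"${spec.prefix}part-$split%05d-$jobId${spec.suffix}"
// 00003_0_part-00003-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.c000.snappy.parquet
// With FileNameSpec("", ".snappy.parquet") the result is byte-for-byte what the old
// ext-based code path produced, so existing Spark file names do not change.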
*/ - private val isBucketed = description.bucketIdExpression.isDefined + protected val isBucketed = description.bucketSpec.isDefined assert(isPartitioned || isBucketed, s"""DynamicPartitionWriteTask should be used for writing out data that's either @@ -196,7 +197,8 @@ class DynamicPartitionDataWriter( /** Given an input row, returns the corresponding `bucketId` */ private lazy val getBucketId: InternalRow => Int = { val proj = - UnsafeProjection.create(description.bucketIdExpression.toSeq, description.allColumns) + UnsafeProjection.create(Seq(description.bucketSpec.get.bucketIdExpression), + description.allColumns) row => proj(row).getInt(0) } @@ -222,17 +224,23 @@ class DynamicPartitionDataWriter( val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") - // This must be in a form that matches our bucketing format. See BucketingUtils. - val ext = f"$bucketIdStr.c$fileCounter%03d" + + // The prefix and suffix must be in a form that matches our bucketing format. + // See BucketingUtils. + val prefix = bucketId match { + case Some(id) => description.bucketSpec.get.bucketFileNamePrefix(id) + case _ => "" + } + val suffix = f"$bucketIdStr.c$fileCounter%03d" + description.outputWriterFactory.getFileExtension(taskAttemptContext) + val fileNameSpec = FileNameSpec(prefix, suffix) val customPath = partDir.flatMap { dir => description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) } val currentPath = if (customPath.isDefined) { - committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext) + committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, fileNameSpec) } else { - committer.newTaskTempFile(taskAttemptContext, partDir, ext) + committer.newTaskTempFile(taskAttemptContext, partDir, fileNameSpec) } currentWriter = description.outputWriterFactory.newInstance( @@ -277,6 +285,16 @@ class DynamicPartitionDataWriter( } } +/** + * Bucketing specification for all the write tasks. + * + * @param bucketIdExpression Expression to calculate bucket id based on bucket column(s). + * @param bucketFileNamePrefix Prefix of output file name based on bucket id. + */ +case class WriterBucketSpec( + bucketIdExpression: Expression, + bucketFileNamePrefix: Int => String) + /** A shared job description for all the write tasks. 
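To make the writer change concrete, a REPL-style sketch of how DynamicPartitionDataWriter assembles the FileNameSpec for one bucketed file (the extension and counter values are illustrative; bucketFileNamePrefix stands in for WriterBucketSpec.bucketFileNamePrefix):

val bucketFileNamePrefix: Int => String = id => f"$id%05d_0_"   // Hive-style table
val bucketId = 3
val fileCounter = 0
val bucketIdStr = f"_$bucketId%05d"                 // BucketingUtils.bucketIdToString(3)
val extension = ".snappy.parquet"                   // from outputWriterFactory.getFileExtension

val prefix = bucketFileNamePrefix(bucketId)         // "00003_0_"
val suffix = f"$bucketIdStr.c$fileCounter%03d" + extension  // "_00003.c000.snappy.parquet"
// FileNameSpec(prefix, suffix) is then passed to newTaskTempFile / newTaskTempFileAbsPath.
// For a Spark-native bucketed table the prefix function returns "", so only the "_00003"
// marker in the suffix identifies the bucket, exactly as before this change.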
*/ class WriteJobDescription( val uuid: String, // prevent collision between different (appending) write jobs @@ -285,7 +303,7 @@ class WriteJobDescription( val allColumns: Seq[Attribute], val dataColumns: Seq[Attribute], val partitionColumns: Seq[Attribute], - val bucketIdExpression: Option[Expression], + val bucketSpec: Option[WriterBucketSpec], val path: String, val customPartitionLocations: Map[TablePartitionSpec, String], val maxRecordsPerFile: Long, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 48ebd6f0c610..8fecdf5bc3a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StringType import org.apache.spark.unsafe.types.UTF8String @@ -113,12 +114,33 @@ object FileFormatWriter extends Logging { } val empty2NullPlan = if (needConvert) ProjectExec(projectList, plan) else plan - val bucketIdExpression = bucketSpec.map { spec => + val writerBucketSpec = bucketSpec.map { spec => val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get) - // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can - // guarantee the data distribution is same between shuffle and bucketed data source, which - // enables us to only shuffle one side when join a bucketed table and a normal one. - HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression + + if (options.getOrElse(DDLUtils.HIVE_PROVIDER, "false") == "true") { + // Hive bucketed table: use `HiveHash` and bitwise-and as bucket id expression. + // Without the extra bitwise-and operation, we can get wrong bucket id when hash value of + // columns is negative. See Hive implementation in + // `org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#getBucketNumber()`. + val hashId = BitwiseAnd(HiveHash(bucketColumns), Literal(Int.MaxValue)) + val bucketIdExpression = Pmod(hashId, Literal(spec.numBuckets)) + + // The bucket file name prefix is following Hive, Presto and Trino conversion, so this + // makes sure Hive bucketed table written by Spark, can be read by other SQL engines. + // + // Hive: `org.apache.hadoop.hive.ql.exec.Utilities#getBucketIdFromFile()`. + // Trino: `io.trino.plugin.hive.BackgroundHiveSplitLoader#BUCKET_PATTERNS`. + val fileNamePrefix = (bucketId: Int) => f"$bucketId%05d_0_" + WriterBucketSpec(bucketIdExpression, fileNamePrefix) + } else { + // Spark bucketed table: use `HashPartitioning.partitionIdExpression` as bucket id + // expression, so that we can guarantee the data distribution is same between shuffle and + // bucketed data source, which enables us to only shuffle one side when join a bucketed + // table and a normal one. 
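The bitwise-and in the Hive branch above matters because HiveHash can produce negative values and Scala's % keeps the sign of the dividend; a REPL-style illustration (the hash value is made up, not a real HiveHash output):

val numBuckets = 8
val hash = -1625389395                  // pretend HiveHash(bucketColumns) evaluated to this

hash % numBuckets                       // -3: not a valid bucket id
(hash & Int.MaxValue) % numBuckets      // 5: clear the sign bit first, as Hive's getBucketNumber does

// In Catalyst terms this is Pmod(BitwiseAnd(HiveHash(cols), Literal(Int.MaxValue)),
// Literal(numBuckets)), as built above. The non-Hive branch keeps using
// HashPartitioning(cols, numBuckets).partitionIdExpression (a Murmur3-based pmod), so
// Spark-native bucketed tables still line up with shuffle partitioning for one-sided joins.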
+ val bucketIdExpression = HashPartitioning(bucketColumns, spec.numBuckets) + .partitionIdExpression + WriterBucketSpec(bucketIdExpression, (_: Int) => "") + } } val sortColumns = bucketSpec.toSeq.flatMap { spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get) @@ -139,7 +161,7 @@ object FileFormatWriter extends Logging { allColumns = outputSpec.outputColumns, dataColumns = dataColumns, partitionColumns = partitionColumns, - bucketIdExpression = bucketIdExpression, + bucketSpec = writerBucketSpec, path = outputSpec.outputPath, customPartitionLocations = outputSpec.customPartitionLocations, maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong) @@ -150,7 +172,8 @@ object FileFormatWriter extends Logging { ) // We should first sort by partition columns, then bucket id, and finally sorting columns. - val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns + val requiredOrdering = + partitionColumns ++ writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns // the sort order doesn't matter val actualOrdering = empty2NullPlan.outputOrdering.map(_.child) val orderingMatched = if (requiredOrdering.length > actualOrdering.length) { @@ -265,7 +288,7 @@ object FileFormatWriter extends Logging { if (sparkPartitionId != 0 && !iterator.hasNext) { // In case of empty job, leave first partition to save meta for file format like parquet. new EmptyDirectoryDataWriter(description, taskAttemptContext, committer) - } else if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) { + } else if (description.partitionColumns.isEmpty && description.bucketSpec.isEmpty) { new SingleDirectoryDataWriter(description, taskAttemptContext, committer) } else { new DynamicPartitionDataWriter(description, taskAttemptContext, committer) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala index cd62ee7814bf..0bbc5f83646f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala @@ -129,7 +129,7 @@ abstract class FileWriteBuilder( allColumns = allColumns, dataColumns = allColumns, partitionColumns = Seq.empty, - bucketIdExpression = None, + bucketSpec = None, path = pathName, customPartitionLocations = Map.empty, maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 0a5feda1bd53..ae35f29c764f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.sources import java.io.File -import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import 
org.apache.spark.sql.execution.datasources.BucketingUtils import org.apache.spark.sql.functions._ @@ -136,29 +136,35 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") } - def tableDir: File = { - val identifier = spark.sessionState.sqlParser.parseTableIdentifier("bucketed_table") + def tableDir(table: String = "bucketed_table"): File = { + val identifier = spark.sessionState.sqlParser.parseTableIdentifier(table) new File(spark.sessionState.catalog.defaultTablePath(identifier)) } + private def bucketIdExpression(expressions: Seq[Expression], numBuckets: Int): Expression = + HashPartitioning(expressions, numBuckets).partitionIdExpression + /** * A helper method to check the bucket write functionality in low level, i.e. check the written * bucket files to see if the data are correct. User should pass in a data dir that these bucket * files are written to, and the format of data(parquet, json, etc.), and the bucketing * information. */ - private def testBucketing( + protected def testBucketing( dataDir: File, source: String, numBuckets: Int, bucketCols: Seq[String], - sortCols: Seq[String] = Nil): Unit = { + sortCols: Seq[String] = Nil, + inputDF: DataFrame = df, + bucketIdExpression: (Seq[Expression], Int) => Expression = bucketIdExpression, + getBucketIdFromFileName: String => Option[Int] = BucketingUtils.getBucketId): Unit = { val allBucketFiles = dataDir.listFiles().filterNot(f => f.getName.startsWith(".") || f.getName.startsWith("_") ) for (bucketFile <- allBucketFiles) { - val bucketId = BucketingUtils.getBucketId(bucketFile.getName).getOrElse { + val bucketId = getBucketIdFromFileName(bucketFile.getName).getOrElse { fail(s"Unable to find the related bucket files.") } @@ -167,7 +173,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { val selectedColumns = (bucketCols ++ sortCols).distinct // We may lose the type information after write(e.g. json format doesn't keep schema // information), here we get the types from the original dataframe. 
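For reference, a hedged sketch of the two file-name schemes the new getBucketIdFromFileName hook has to handle (the Spark-side pattern below is illustrative rather than copied from BucketingUtils; the Hive-side pattern mirrors the regex used in the Hive-support test later in this diff):

val hiveBucketedName = """^(\d+)_0_.*$""".r            // e.g. 00003_0_part-...
val sparkBucketedName = """.*_(\d+)(?:\..*)?$""".r     // e.g. part-..._00003.c000...

def bucketIdOf(fileName: String): Option[Int] = fileName match {
  case hiveBucketedName(id) => Some(id.toInt)
  case sparkBucketedName(id) => Some(id.toInt)
  case _ => None
}

bucketIdOf("part-00007-abc_00003.c000.snappy.parquet")          // Some(3)
bucketIdOf("00003_0_part-00007-abc_00003.c000.snappy.parquet")  // Some(3)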
- val types = df.select(selectedColumns.map(col): _*).schema.map(_.dataType) + val types = inputDF.select(selectedColumns.map(col): _*).schema.map(_.dataType) val columns = selectedColumns.zip(types).map { case (colName, dt) => col(colName).cast(dt) } @@ -188,7 +194,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { val qe = readBack.select(bucketCols.map(col): _*).queryExecution val rows = qe.toRdd.map(_.copy()).collect() val getBucketId = UnsafeProjection.create( - HashPartitioning(qe.analyzed.output, numBuckets).partitionIdExpression :: Nil, + bucketIdExpression(qe.analyzed.output, numBuckets) :: Nil, qe.analyzed.output) for (row <- rows) { @@ -208,7 +214,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { .saveAsTable("bucketed_table") for (i <- 0 until 5) { - testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j", "k")) + testBucketing(new File(tableDir(), s"i=$i"), source, 8, Seq("j", "k")) } } } @@ -225,7 +231,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { .saveAsTable("bucketed_table") for (i <- 0 until 5) { - testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j"), Seq("k")) + testBucketing(new File(tableDir(), s"i=$i"), source, 8, Seq("j"), Seq("k")) } } } @@ -255,7 +261,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { .bucketBy(8, "i", "j") .saveAsTable("bucketed_table") - testBucketing(tableDir, source, 8, Seq("i", "j")) + testBucketing(tableDir(), source, 8, Seq("i", "j")) } } } @@ -269,7 +275,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { .sortBy("k") .saveAsTable("bucketed_table") - testBucketing(tableDir, source, 8, Seq("i", "j"), Seq("k")) + testBucketing(tableDir(), source, 8, Seq("i", "j"), Seq("k")) } } } @@ -286,7 +292,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { .saveAsTable("bucketed_table") for (i <- 0 until 5) { - testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j", "k")) + testBucketing(new File(tableDir(), s"i=$i"), source, 8, Seq("j", "k")) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index a89243c331c7..9fb11474c770 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions} import org.apache.spark.sql.internal.SQLConf @@ -125,7 +126,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log private def isParquetProperty(key: String) = key.startsWith("parquet.") || key.contains(".parquet.") - def convert(relation: HiveTableRelation): LogicalRelation = { + def convert(relation: HiveTableRelation, isWrite: Boolean): LogicalRelation = { val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) // Consider table and storage properties. 
For properties existing in both sides, storage @@ -134,7 +135,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log val options = relation.tableMeta.properties.filterKeys(isParquetProperty).toMap ++ relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA -> SQLConf.get.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString) - convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet") + convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet", isWrite) } else { val options = relation.tableMeta.properties.filterKeys(isOrcProperty).toMap ++ relation.tableMeta.storage.properties @@ -143,13 +144,15 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log relation, options, classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat], - "orc") + "orc", + isWrite) } else { convertToLogicalRelation( relation, options, classOf[org.apache.spark.sql.hive.orc.OrcFileFormat], - "orc") + "orc", + isWrite) } } } @@ -158,7 +161,8 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log relation: HiveTableRelation, options: Map[String, String], fileFormatClass: Class[_ <: FileFormat], - fileType: String): LogicalRelation = { + fileType: String, + isWrite: Boolean): LogicalRelation = { val metastoreSchema = relation.tableMeta.schema val tableIdentifier = QualifiedTableName(relation.tableMeta.database, relation.tableMeta.identifier.table) @@ -166,6 +170,13 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log val lazyPruningEnabled = sparkSession.sqlContext.conf.manageFilesourcePartitions val tablePath = new Path(relation.tableMeta.location) val fileFormat = fileFormatClass.getConstructor().newInstance() + val bucketSpec = relation.tableMeta.bucketSpec + val (hiveOptions, hiveBucketSpec) = + if (isWrite) { + (options.updated(DDLUtils.HIVE_PROVIDER, "true"), bucketSpec) + } else { + (options, None) + } val result = if (relation.isPartitioned) { val partitionSchema = relation.tableMeta.partitionSchema @@ -207,16 +218,16 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } } - val updatedTable = inferIfNeeded(relation, options, fileFormat, Option(fileIndex)) + val updatedTable = inferIfNeeded(relation, hiveOptions, fileFormat, Option(fileIndex)) // Spark SQL's data source table now support static and dynamic partition insert. Source // table converted from Hive table should always use dynamic. 
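A condensed sketch of how the write-path tag added here is consumed further down the stack (DDLUtils.HIVE_PROVIDER is the existing "hive" provider constant; the option map contents are illustrative):

val HIVE_PROVIDER = "hive"                     // value of DDLUtils.HIVE_PROVIDER
val options = Map("mergeSchema" -> "false")    // relation options, illustrative
val isWrite = true

// HiveMetastoreCatalog (this file): only the write path tags the options and keeps
// the table's bucket spec; the read path still sees bucketSpec = None.
val hiveOptions = if (isWrite) options.updated(HIVE_PROVIDER, "true") else options

// FileFormatWriter (earlier in this diff): the same key picks the bucketing scheme.
val useHiveHash = hiveOptions.getOrElse(HIVE_PROVIDER, "false") == "true"
// useHiveHash == true here, so the WriterBucketSpec gets the HiveHash id expression
// and the "NNNNN_0_" file-name prefix instead of the Murmur3/empty-prefix pair.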
- val enableDynamicPartition = options.updated("partitionOverwriteMode", "dynamic") + val enableDynamicPartition = hiveOptions.updated("partitionOverwriteMode", "dynamic") val fsRelation = HadoopFsRelation( location = fileIndex, partitionSchema = partitionSchema, dataSchema = updatedTable.dataSchema, - bucketSpec = None, + bucketSpec = hiveBucketSpec, fileFormat = fileFormat, options = enableDynamicPartition)(sparkSession = sparkSession) val created = LogicalRelation(fsRelation, updatedTable) @@ -236,15 +247,15 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log fileFormatClass, None) val logicalRelation = cached.getOrElse { - val updatedTable = inferIfNeeded(relation, options, fileFormat) + val updatedTable = inferIfNeeded(relation, hiveOptions, fileFormat) val created = LogicalRelation( DataSource( sparkSession = sparkSession, paths = rootPath.toString :: Nil, userSpecifiedSchema = Option(updatedTable.dataSchema), - bucketSpec = None, - options = options, + bucketSpec = hiveBucketSpec, + options = hiveOptions, className = fileType).resolveRelation(), table = updatedTable) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index ff7dc58829fa..5881568bcc03 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -212,13 +212,13 @@ case class RelationConversions( if query.resolved && DDLUtils.isHiveTable(r.tableMeta) && (!r.isPartitioned || SQLConf.get.getConf(HiveUtils.CONVERT_INSERTING_PARTITIONED_TABLE)) && isConvertible(r) => - InsertIntoStatement(metastoreCatalog.convert(r), partition, cols, + InsertIntoStatement(metastoreCatalog.convert(r, isWrite = true), partition, cols, query, overwrite, ifPartitionNotExists) // Read path case relation: HiveTableRelation if DDLUtils.isHiveTable(relation.tableMeta) && isConvertible(relation) => - metastoreCatalog.convert(relation) + metastoreCatalog.convert(relation, isWrite = false) // CTAS case CreateTable(tableDesc, mode, Some(query)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index 283c254b3960..80672ba1a039 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -154,7 +154,7 @@ case class OptimizedCreateHiveTableAsSelectCommand( val metastoreCatalog = catalog.asInstanceOf[HiveSessionCatalog].metastoreCatalog val hiveTable = DDLUtils.readHiveTable(tableDesc) - val hadoopRelation = metastoreCatalog.convert(hiveTable) match { + val hadoopRelation = metastoreCatalog.convert(hiveTable, isWrite = true) match { case LogicalRelation(t: HadoopFsRelation, _, _, _) => t case _ => throw new AnalysisException(s"$tableIdentifier should be converted to " + "HadoopFsRelation.") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteWithHiveSupportSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteWithHiveSupportSuite.scala index bdbdcc295107..c12caaa2e0fd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteWithHiveSupportSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteWithHiveSupportSuite.scala @@ -17,6 +17,11 @@ package org.apache.spark.sql.sources +import java.io.File + +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, Expression, HiveHash, Literal, Pmod} +import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION @@ -27,4 +32,47 @@ class BucketedWriteWithHiveSupportSuite extends BucketedWriteSuite with TestHive } override protected def fileFormatsToTest: Seq[String] = Seq("parquet", "orc") + + test("write hive bucketed table") { + def bucketIdExpression(expressions: Seq[Expression], numBuckets: Int): Expression = + Pmod(BitwiseAnd(HiveHash(expressions), Literal(Int.MaxValue)), Literal(8)) + + def getBucketIdFromFileName(fileName: String): Option[Int] = { + val hiveBucketedFileName = """^(\d+)_0_.*$""".r + fileName match { + case hiveBucketedFileName(bucketId) => Some(bucketId.toInt) + case _ => None + } + } + + val table = "hive_bucketed_table" + + fileFormatsToTest.foreach { format => + withTable(table) { + sql( + s""" + |CREATE TABLE IF NOT EXISTS $table (i int, j string) + |PARTITIONED BY(k string) + |CLUSTERED BY (i, j) SORTED BY (i) INTO 8 BUCKETS + |STORED AS $format + """.stripMargin) + + val df = + (0 until 50).map(i => (i % 13, i.toString, i % 5)).toDF("i", "j", "k") + df.write.mode(SaveMode.Overwrite).insertInto(table) + + for (k <- 0 until 5) { + testBucketing( + new File(tableDir(table), s"k=$k"), + format, + 8, + Seq("i", "j"), + Seq("i"), + df, + bucketIdExpression, + getBucketIdFromFileName) + } + } + } + } }
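End-to-end usage sketch of what the new test exercises (illustrative: the table, column names, and sample file name are made up, and the exact suffix also depends on the codec and format):

import org.apache.spark.sql.SparkSession

object HiveBucketedWriteExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("hive-bucketed-write")
      .enableHiveSupport()
      .getOrCreate()
    import spark.implicits._

    spark.sql(
      """CREATE TABLE IF NOT EXISTS sales (id INT, item STRING)
        |CLUSTERED BY (id) SORTED BY (id) INTO 8 BUCKETS
        |STORED AS PARQUET""".stripMargin)

    // With this patch the insert uses HiveHash bucketing and Hive/Trino-style file
    // names such as 00005_0_part-00000-<uuid>_00005.c000.snappy.parquet, so other
    // engines can recognize and prune the buckets written by Spark.
    (0 until 50).map(i => (i, s"item-$i")).toDF("id", "item")
      .write.mode("overwrite").insertInto("sales")

    spark.stop()
  }
}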