diff --git a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala
index 12d456a371d0..d3ed4d12099b 100644
--- a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala
+++ b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala
@@ -307,6 +307,8 @@ private[spark] object LogKeys {
   case object HIVE_METASTORE_VERSION extends LogKey
   case object HIVE_OPERATION_STATE extends LogKey
   case object HIVE_OPERATION_TYPE extends LogKey
+  case object HMS_CURRENT_BATCH_SIZE extends LogKey
+  case object HMS_INITIAL_BATCH_SIZE extends LogKey
   case object HOST extends LogKey
   case object HOSTS extends LogKey
   case object HOST_LOCAL_BLOCKS_SIZE extends LogKey
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index ea187c0316c1..be1d34037c85 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -5199,6 +5199,29 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val HMS_BATCH_SIZE = buildConf("spark.sql.hive.metastore.batchSize")
+    .internal()
+    .doc("This setting defines the batch size for fetching partition metadata from the " +
+      "Hive Metastore. A value of -1 (the default) disables batching. To enable batching, " +
+      "specify a positive integer, which is used as the partition fetch batch size."
+    )
+    .version("4.0.0")
+    .intConf
+    .createWithDefault(-1)
+
+  val METASTORE_PARTITION_BATCH_RETRY_COUNT = buildConf(
+    "spark.sql.metastore.partition.batch.retry.count")
+    .internal()
+    .doc(
+      "This setting specifies the number of retries for fetching partitions from the " +
+      "metastore when fetching a batch of partition metadata fails. The retry mechanism " +
+      "applies only when spark.sql.hive.metastore.batchSize is set to a positive value, " +
+      "i.e. when batching is enabled."
+    )
+    .version("4.0.0")
+    .intConf
+    .createWithDefault(3)
+
   /**
    * Holds information about keys that have been deprecated.
    *
@@ -6177,6 +6200,10 @@
   def legacyEvalCurrentTime: Boolean = getConf(SQLConf.LEGACY_EVAL_CURRENT_TIME)
 
+  def getHiveMetaStoreBatchSize: Int = getConf(HMS_BATCH_SIZE)
+
+  def metastorePartitionBatchRetryCount: Int = getConf(METASTORE_PARTITION_BATCH_RETRY_COUNT)
+
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index c03fed4cc318..5ef956823be3 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -22,6 +22,7 @@ import java.lang.reflect.{InvocationTargetException, Method}
 import java.util.{ArrayList => JArrayList, List => JList, Locale, Map => JMap}
 import java.util.concurrent.TimeUnit
 
+import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 import scala.util.control.NonFatal
 
@@ -37,7 +38,7 @@ import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory}
 import org.apache.hadoop.hive.ql.session.SessionState
 import org.apache.hadoop.hive.serde.serdeConstants
 
-import org.apache.spark.internal.{Logging, MDC}
+import org.apache.spark.internal.{Logging, LogKeys, MDC}
 import org.apache.spark.internal.LogKeys.{CONFIG, CONFIG2, CONFIG3}
 import org.apache.spark.metrics.source.HiveCatalogMetrics
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
@@ -389,6 +390,70 @@ private[client] class Shim_v2_0 extends Shim with Logging {
     partitions.asScala.toSeq
   }
 
+  private def getPartitionNamesWithCount(hive: Hive, table: Table): (Int, Seq[String]) = {
+    val partitionNames = hive.getPartitionNames(
+      table.getDbName, table.getTableName, -1).asScala.toSeq
+    (partitionNames.length, partitionNames)
+  }
+
+  private def getPartitionsInBatches(
+      hive: Hive,
+      table: Table,
+      initialBatchSize: Int,
+      partNames: Seq[String]): java.util.Collection[Partition] = {
+    val maxRetries = SQLConf.get.metastorePartitionBatchRetryCount
+    val decayingFactor = 2
+
+    if (initialBatchSize <= 0) {
+      throw new IllegalArgumentException(s"Invalid batch size $initialBatchSize provided " +
+        "for fetching partitions. Batch size must be greater than 0.")
+    }
+
+    if (maxRetries < 0) {
+      throw new IllegalArgumentException(s"Invalid number of maximum retries $maxRetries " +
+        "provided for fetching partitions. It must be a non-negative integer.")
+    }
+
+    logInfo(log"Breaking the partition metadata request into batches of size " +
+      log"${MDC(LogKeys.HMS_INITIAL_BATCH_SIZE, initialBatchSize)}.")
+
+    var batchSize = initialBatchSize
+    val processedPartitions = mutable.ListBuffer[Partition]()
+    var retryCount = 0
+    var index = 0
+
+    def getNextBatchSize(): Int = {
+      val currentBatchSize = batchSize
+      batchSize = (batchSize / decayingFactor) max 1
+      currentBatchSize
+    }
+
+    var currentBatchSize = getNextBatchSize()
+    var partitions: java.util.Collection[Partition] = null
+
+    while (index < partNames.size && retryCount <= maxRetries) {
+      val batch = partNames.slice(index, index + currentBatchSize)
+
+      try {
+        partitions = hive.getPartitionsByNames(table, batch.asJava)
+        processedPartitions ++= partitions.asScala
+        index += batch.size
+      } catch {
+        case ex: Exception =>
+          logWarning("Caught exception while fetching a partition batch, attempting retry.", ex)
+          retryCount += 1
+          currentBatchSize = getNextBatchSize()
+          logInfo(log"Reducing batch size to " +
+            log"${MDC(LogKeys.HMS_CURRENT_BATCH_SIZE, currentBatchSize)} and retrying.")
+          if (retryCount > maxRetries) {
+            logError("Failed to fetch all requested partitions: retry count exceeded.")
+          }
+      }
+    }
+
+    processedPartitions.asJava
+  }
+
   private def prunePartitionsFastFallback(
       hive: Hive,
       table: Table,
@@ -406,11 +471,19 @@ private[client] class Shim_v2_0 extends Shim with Logging {
       }
     }
 
+    val batchSize = SQLConf.get.getHiveMetaStoreBatchSize
+
     if (!SQLConf.get.metastorePartitionPruningFastFallback ||
         predicates.isEmpty ||
         predicates.exists(hasTimeZoneAwareExpression)) {
+      val (count, partNames) = getPartitionNamesWithCount(hive, table)
       recordHiveCall()
-      hive.getAllPartitionsOf(table)
+      if (count < batchSize || batchSize == -1) {
+        hive.getAllPartitionsOf(table)
+      }
+      else {
+        getPartitionsInBatches(hive, table, batchSize, partNames)
+      }
     } else {
       try {
         val partitionSchema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(
@@ -442,8 +515,14 @@
         case ex: HiveException if ex.getCause.isInstanceOf[MetaException] =>
           logWarning("Caught Hive MetaException attempting to get partition metadata by " +
             "filter from client side. Falling back to fetching all partition metadata", ex)
+          val (count, partNames) = getPartitionNamesWithCount(hive, table)
           recordHiveCall()
-          hive.getAllPartitionsOf(table)
+          if (count < batchSize || batchSize == -1) {
+            hive.getAllPartitionsOf(table)
+          }
+          else {
+            getPartitionsInBatches(hive, table, batchSize, partNames)
+          }
       }
     }
   }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala
index 1a4eb7554789..ce162a325f44 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala
@@ -726,6 +726,40 @@ class HivePartitionFilteringSuite(version: String)
     }
   }
 
+  test("getPartitionsByFilter: getPartitionsInBatches") {
+    var filteredPartitions: Seq[CatalogTablePartition] = Seq()
+    var filteredPartitionsNoBatch: Seq[CatalogTablePartition] = Seq()
+    var filteredPartitionsHighBatch: Seq[CatalogTablePartition] = Seq()
+
+    withSQLConf(SQLConf.HMS_BATCH_SIZE.key -> "1") {
+      filteredPartitions = client.getPartitionsByFilter(
+        client.getRawHiveTable("default", "test"),
+        Seq(attr("ds") === 20170101)
+      )
+    }
+    withSQLConf(SQLConf.HMS_BATCH_SIZE.key -> "-1") {
+      filteredPartitionsNoBatch = client.getPartitionsByFilter(
+        client.getRawHiveTable("default", "test"),
+        Seq(attr("ds") === 20170101)
+      )
+    }
+    withSQLConf(SQLConf.HMS_BATCH_SIZE.key -> "5000") {
+      filteredPartitionsHighBatch = client.getPartitionsByFilter(
+        client.getRawHiveTable("default", "test"),
+        Seq(attr("ds") === 20170101)
+      )
+    }
+
+    assert(filteredPartitions.size == filteredPartitionsNoBatch.size)
+    assert(filteredPartitions.size == filteredPartitionsHighBatch.size)
+    assert(
+      filteredPartitions.map(_.spec.toSet).toSet ==
+        filteredPartitionsNoBatch.map(_.spec.toSet).toSet)
+    assert(
+      filteredPartitions.map(_.spec.toSet).toSet ==
+        filteredPartitionsHighBatch.map(_.spec.toSet).toSet)
+  }
+
   private def testMetastorePartitionFiltering(
       filterExpr: Expression,
       expectedDs: Seq[Int],
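
For context, a minimal usage sketch (not part of the patch): how the two new settings could be exercised from a Hive-enabled SparkSession named `spark`, with `some_db.web_logs` as a placeholder partitioned table. Both confs are internal but can still be set at runtime.

  // Hypothetical usage sketch; session, table, and values are placeholders.
  // Enable HMS partition batching: fetch partition metadata 1000 names at a time,
  // retrying a failed batch up to 5 times with the batch size halved on each retry.
  spark.conf.set("spark.sql.hive.metastore.batchSize", "1000")
  spark.conf.set("spark.sql.metastore.partition.batch.retry.count", "5")

  // Partition listings that fall back to fetching all partition metadata now go through
  // getPartitionsInBatches when the table has at least `batchSize` partitions; smaller
  // tables (or batchSize = -1) keep using getAllPartitionsOf as before.
  spark.sql("SELECT count(*) FROM some_db.web_logs WHERE ds = 20170101").show()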