@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.catalyst.util.DateFormatter
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.parser.HoodieExtendedParserInterface
@@ -42,7 +43,7 @@ import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.{DataType, Metadata, StructType}
import org.apache.spark.storage.StorageLevel

import java.util.Locale
import java.util.{Locale, TimeZone}

/**
* Interface adapting discrepancies and incompatibilities between different Spark versions
@@ -115,6 +116,11 @@ trait SparkAdapter extends Serializable {
*/
def getSparkParsePartitionUtil: SparkParsePartitionUtil

/**
* Get the [[DateFormatter]].
*/
def getDateFormatter(tz: TimeZone): DateFormatter

/**
* Combine [[PartitionedFile]] to [[FilePartition]] according to `maxSplitBytes`.
*/
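A minimal usage sketch of the new getDateFormatter hook (the object name is hypothetical and the import paths are assumed; format takes the days-since-epoch Int that Spark uses to represent DateType values, as the Spark 2 and Spark 3 implementations further down illustrate):

import java.time.LocalDate
import org.apache.hudi.SparkAdapterSupport
import org.apache.spark.sql.catalyst.util.DateTimeUtils

object DateFormatterUsageSketch extends SparkAdapterSupport {
  def main(args: Array[String]): Unit = {
    // Spark stores a DateType value as an Int: days since 1970-01-01
    val days = LocalDate.of(2023, 3, 1).toEpochDay.toInt
    val tz = DateTimeUtils.getTimeZone("UTC")
    // Renders the calendar date regardless of the session time zone, e.g. "2023-03-01"
    println(sparkAdapter.getDateFormatter(tz).format(days))
  }
}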
@@ -35,12 +35,12 @@ import org.apache.hudi.util.JFunction
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, EmptyRow, EqualTo, Expression, InterpretedPredicate}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, EmptyRow, EqualTo, Expression, InterpretedPredicate, Literal}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.{InternalRow, expressions}
import org.apache.spark.sql.execution.datasources.{FileStatusCache, NoopCache}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.types.{ByteType, DataType, DateType, IntegerType, LongType, ShortType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

import javax.annotation.concurrent.NotThreadSafe
@@ -281,14 +281,15 @@ class SparkHoodieTableFileIndex(spark: SparkSession,
// Static partition-path prefix is defined as a prefix of the full partition-path where only
// first N partition columns (in-order) have proper (static) values bound in equality predicates,
// allowing in turn to build such prefix to be used in subsequent filtering
val staticPartitionColumnNameValuePairs: Seq[(String, Any)] = {
val staticPartitionColumnNameValuePairs: Seq[(String, (String, Any))] = {
// Extract from simple predicates of the form `date = '2022-01-01'` both
// partition column and corresponding (literal) value
val staticPartitionColumnValuesMap = extractEqualityPredicatesLiteralValues(partitionColumnPredicates)
val zoneId = configProperties.getString(DateTimeUtils.TIMEZONE_OPTION, SQLConf.get.sessionLocalTimeZone)
val staticPartitionColumnValuesMap = extractEqualityPredicatesLiteralValues(partitionColumnPredicates, zoneId)
// NOTE: For our purposes we can only construct partition-path prefix if proper prefix of the
// partition-schema has been bound by the partition-predicates
partitionColumnNames.takeWhile(colName => staticPartitionColumnValuesMap.contains(colName))
.map(colName => (colName, staticPartitionColumnValuesMap(colName).get))
.map(colName => (colName, (staticPartitionColumnValuesMap(colName)._1, staticPartitionColumnValuesMap(colName)._2.get)))
}

if (staticPartitionColumnNameValuePairs.isEmpty) {
@@ -301,7 +302,7 @@ class SparkHoodieTableFileIndex(spark: SparkSession,

if (staticPartitionColumnNameValuePairs.length == partitionColumnNames.length) {
// In case composed partition path is complete, we can return it directly avoiding extra listing operation
Seq(new PartitionPath(relativePartitionPathPrefix, staticPartitionColumnNameValuePairs.map(_._2.asInstanceOf[AnyRef]).toArray))
Seq(new PartitionPath(relativePartitionPathPrefix, staticPartitionColumnNameValuePairs.map(_._2._2.asInstanceOf[AnyRef]).toArray))
} else {
// Otherwise, compile extracted partition values (from query predicates) into a sub-path which is a prefix
// of the complete partition path, do listing for this prefix-path only
@@ -315,7 +316,7 @@ class SparkHoodieTableFileIndex(spark: SparkSession,
*
* @return relative partition path and a flag to indicate if the path is complete (i.e., not a prefix)
*/
private def composeRelativePartitionPath(staticPartitionColumnNameValuePairs: Seq[(String, Any)]): String = {
private def composeRelativePartitionPath(staticPartitionColumnNameValuePairs: Seq[(String, (String, Any))]): String = {
checkState(staticPartitionColumnNameValuePairs.nonEmpty)

// Since static partition values might not be available for all columns, we compile
@@ -331,7 +332,7 @@ class SparkHoodieTableFileIndex(spark: SparkSession,
)

partitionPathFormatter.combine(staticPartitionColumnNames.asJava,
staticPartitionColumnValues.map(_.asInstanceOf[AnyRef]): _*)
staticPartitionColumnValues.map(_._1): _*)
}

protected def doParsePartitionColumnValues(partitionColumns: Array[String], partitionPath: String): Array[Object] = {
@@ -407,24 +408,35 @@ class SparkHoodieTableFileIndex(spark: SparkSession,
metaClient.getTableConfig.getUrlEncodePartitioning.toBoolean
}

object SparkHoodieTableFileIndex {
object SparkHoodieTableFileIndex extends SparkAdapterSupport {

private def haveProperPartitionValues(partitionPaths: Seq[PartitionPath]) = {
partitionPaths.forall(_.values.length > 0)
}

private def extractEqualityPredicatesLiteralValues(predicates: Seq[Expression]): Map[String, Option[Any]] = {
private def extractEqualityPredicatesLiteralValues(predicates: Seq[Expression], zoneId: String): Map[String, (String, Option[Any])] = {
// TODO support coercible expressions (ie attr-references casted to particular type), similar
// to `MERGE INTO` statement

object ExtractableLiteral {
def unapply(exp: Expression): Option[String] = exp match {
case Literal(null, _) => None // `null`s can be cast as other types; we want to avoid NPEs.
case Literal(value, _: ByteType | IntegerType | LongType | ShortType) => Some(value.toString)
case Literal(value, _: StringType) => Some(value.toString)
case Literal(value, _: DateType) =>
Some(sparkAdapter.getDateFormatter(DateTimeUtils.getTimeZone(zoneId)).format(value.asInstanceOf[Int]))
case _ => None
}
}

// NOTE: To properly support predicates of the form `x = NULL`, we have to wrap result
// of the folded expression into [[Some]] (to distinguish it from the case when partition-column
// isn't bound to any value by the predicate)
predicates.flatMap {
case EqualTo(attr: AttributeReference, e: Expression) if e.foldable =>
Seq((attr.name, Some(e.eval(EmptyRow))))
case EqualTo(e: Expression, attr: AttributeReference) if e.foldable =>
Seq((attr.name, Some(e.eval(EmptyRow))))

case EqualTo(attr: AttributeReference, e @ ExtractableLiteral(valueString)) =>
Seq((attr.name, (valueString, Some(e.eval(EmptyRow)))))
case EqualTo(e @ ExtractableLiteral(valueString), attr: AttributeReference) =>
Seq((attr.name, (valueString, Some(e.eval(EmptyRow)))))
case _ => Seq.empty
}.toMap
}
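To make the intent of the ExtractableLiteral change concrete, here is a self-contained sketch (the names and the hive-style path layout are illustrative assumptions) of how an equality predicate on a DateType partition column is meant to contribute a calendar-date segment to the partition-path prefix, rather than Spark's internal days-since-epoch integer:

import java.time.LocalDate

object PartitionPrefixSketch {
  def main(args: Array[String]): Unit = {
    // Spark evaluates the literal `date '2023-03-01'` to 19417, i.e. days since the epoch
    val days = LocalDate.of(2023, 3, 1).toEpochDay.toInt
    // ExtractableLiteral pairs that evaluated value with a formatted string such as "2023-03-01";
    // composeRelativePartitionPath then feeds the string, not the Int, to the path formatter.
    val staticPairs = Seq("country" -> "ID", "date_par" -> "2023-03-01")
    val prefix = staticPairs.map { case (k, v) => s"$k=$v" }.mkString("/")
    println(s"days=$days, prefix=$prefix")  // days=19417, prefix=country=ID/date_par=2023-03-01
  }
}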
@@ -0,0 +1,97 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.hudi

class TestLazyPartitionPathFetching extends HoodieSparkSqlTestBase {

test("Test querying with string column + partition pruning") {
withTempDir { tmp =>
val tableName = generateTableName
spark.sql(
s"""
|create table $tableName (
| id int,
| name string,
| price double,
| ts long,
| date_par date
|) using hudi
| location '${tmp.getCanonicalPath}'
| tblproperties (
| primaryKey ='id',
| type = 'cow',
| preCombineField = 'ts'
| )
| PARTITIONED BY (date_par)
""".stripMargin)
spark.sql(s"insert into $tableName values(1, 'a1', 10, 1000, date '2023-02-27')")
spark.sql(s"insert into $tableName values(2, 'a2', 10, 1000, date '2023-02-28')")
spark.sql(s"insert into $tableName values(3, 'a3', 10, 1000, date '2023-03-01')")

checkAnswer(s"select id, name, price, ts from $tableName where date_par='2023-03-01' order by id")(
Seq(3, "a3", 10.0, 1000)
)

withSQLConf("spark.sql.session.timeZone" -> "UTC+2") {
checkAnswer(s"select id, name, price, ts from $tableName where date_par='2023-03-01' order by id")(
Seq(3, "a3", 10.0, 1000)
)
}
}
}

test("Test querying with date column + partition pruning (multi-level partitioning)") {
withTempDir { tmp =>
val tableName = generateTableName
spark.sql(
s"""
|create table $tableName (
| id int,
| name string,
| price double,
| ts long,
| country string,
| date_par date
|) using hudi
| location '${tmp.getCanonicalPath}'
| tblproperties (
| primaryKey ='id',
| type = 'cow',
| preCombineField = 'ts'
| )
| PARTITIONED BY (country, date_par)
""".stripMargin)
spark.sql(s"insert into $tableName values(1, 'a1', 10, 1000, 'ID', date '2023-02-27')")
spark.sql(s"insert into $tableName values(2, 'a2', 10, 1000, 'ID', date '2023-02-28')")
spark.sql(s"insert into $tableName values(3, 'a3', 10, 1000, 'ID', date '2023-03-01')")

// for lazy fetching partition path & file slice to be enabled, filter must be applied on all partitions
checkAnswer(s"select id, name, price, ts from $tableName " +
s"where date_par='2023-03-01' and country='ID' order by id")(
Seq(3, "a3", 10.0, 1000)
)

withSQLConf("spark.sql.session.timeZone" -> "UTC+2") {
checkAnswer(s"select id, name, price, ts from $tableName " +
s"where date_par='2023-03-01' and country='ID' order by id")(
Seq(3, "a3", 10.0, 1000)
)
}
}
}
}
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, InterpretedPredicate}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.{Command, DeleteFromTable, LogicalPlan}
import org.apache.spark.sql.catalyst.util.DateFormatter
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, Spark24HoodieParquetFileFormat}
import org.apache.spark.sql.execution.vectorized.MutableColumnarRow
@@ -45,7 +46,11 @@ import org.apache.spark.sql.types.{DataType, Metadata, MetadataBuilder, StructTy
import org.apache.spark.storage.StorageLevel
import org.apache.spark.storage.StorageLevel._

import java.time.ZoneId
import java.util.TimeZone
import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters.mapAsScalaMapConverter
import scala.collection.convert.Wrappers.JConcurrentMapWrapper
import scala.collection.mutable.ArrayBuffer

/**
@@ -64,6 +69,9 @@ class Spark2Adapter extends SparkAdapter {
// we simply produce an empty [[Metadata]] instance
new MetadataBuilder().build()

private val cache = JConcurrentMapWrapper(
new ConcurrentHashMap[ZoneId, DateFormatter](1))

override def getCatalogUtils: HoodieCatalogUtils = {
throw new UnsupportedOperationException("Catalog utilities are not supported in Spark 2.x");
}
@@ -90,6 +98,10 @@

override def getSparkParsePartitionUtil: SparkParsePartitionUtil = Spark2ParsePartitionUtil

override def getDateFormatter(tz: TimeZone): DateFormatter = {
cache.getOrElseUpdate(tz.toZoneId, DateFormatter())
}

/**
* Combine [[PartitionedFile]] to [[FilePartition]] according to `maxSplitBytes`.
*
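As a side note on the caching pattern used by both adapters, this standalone sketch (illustrative only, not the Hudi code itself) expresses the same idea, one formatter memoized per time zone, with a plain ConcurrentHashMap:

import java.time.ZoneId
import java.util.concurrent.ConcurrentHashMap
import java.util.function.{Function => JFunction}

class PerZoneCache[F](create: ZoneId => F) {
  private val cache = new ConcurrentHashMap[ZoneId, F](1)

  // computeIfAbsent builds the value at most once per zone, so repeated lookups are cheap
  def get(zoneId: ZoneId): F =
    cache.computeIfAbsent(zoneId, new JFunction[ZoneId, F] {
      override def apply(z: ZoneId): F = create(z)
    })
}

// Example use: new PerZoneCache[String](_.getId).get(ZoneId.of("UTC+2"))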
@@ -21,6 +21,7 @@ import org.apache.avro.Schema
import org.apache.hadoop.fs.Path
import org.apache.hudi.client.utils.SparkRowSerDe
import org.apache.hudi.common.table.HoodieTableMetaClient
import org.apache.hudi.spark3.internal.ReflectUtil
import org.apache.hudi.{AvroConversionUtils, DefaultSource, Spark3RowSerDe}
import org.apache.hudi.{AvroConversionUtils, DefaultSource, HoodieBaseRelation, Spark3RowSerDe}
import org.apache.spark.internal.Logging
@@ -32,6 +33,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.{Expression, InterpretedPredicate, Predicate}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.DateFormatter
import org.apache.spark.sql.connector.catalog.V2TableWithV1Fallback
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
@@ -42,13 +44,20 @@ import org.apache.spark.sql.{HoodieSpark3CatalogUtils, SQLContext, SparkSession}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.storage.StorageLevel._

import java.time.ZoneId
import java.util.TimeZone
import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters.mapAsScalaMapConverter
import scala.collection.convert.Wrappers.JConcurrentMapWrapper

/**
* Base implementation of [[SparkAdapter]] for Spark 3.x branch
*/
abstract class BaseSpark3Adapter extends SparkAdapter with Logging {

private val cache = JConcurrentMapWrapper(
new ConcurrentHashMap[ZoneId, DateFormatter](1))

def getCatalogUtils: HoodieSpark3CatalogUtils

override def createSparkRowSerDe(schema: StructType): SparkRowSerDe = {
@@ -74,6 +83,10 @@ abstract class BaseSpark3Adapter extends SparkAdapter with Logging {

override def getSparkParsePartitionUtil: SparkParsePartitionUtil = Spark3ParsePartitionUtil

override def getDateFormatter(tz: TimeZone): DateFormatter = {
cache.getOrElseUpdate(tz.toZoneId, ReflectUtil.getDateFormatter(tz.toZoneId))
}

/**
* Combine [[PartitionedFile]] to [[FilePartition]] according to `maxSplitBytes`.
*/
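The ReflectUtil indirection is presumably there because the DateFormatter factory changed shape across Spark 3 minor releases (some take a ZoneId, newer ones take no argument). That reading is an assumption, and the sketch below only illustrates that style of runtime dispatch; it is not ReflectUtil's actual implementation:

import java.time.ZoneId
import org.apache.spark.sql.catalyst.util.DateFormatter

object DateFormatterCompatSketch {
  def create(zoneId: ZoneId): DateFormatter = {
    val companion = DateFormatter
    try {
      // Zero-arg factory, present on newer Spark 3 releases
      companion.getClass.getMethod("apply").invoke(companion).asInstanceOf[DateFormatter]
    } catch {
      case _: NoSuchMethodException =>
        // ZoneId overload, present on older Spark 3 releases
        companion.getClass.getMethod("apply", classOf[ZoneId])
          .invoke(companion, zoneId).asInstanceOf[DateFormatter]
    }
  }
}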