diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/HoodieFileIndex.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/HoodieFileIndex.scala
index 61c2f3ad575ae..8e7e1f85bb90d 100644
--- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/HoodieFileIndex.scala
+++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/HoodieFileIndex.scala
@@ -262,7 +262,14 @@ case class HoodieFileIndex(
       // If the partition column size is not equal to the partition fragment size
       // and the partition column size is 1, we map the whole partition path
       // to the partition column which can benefit from the partition prune.
-      InternalRow.fromSeq(Seq(UTF8String.fromString(partitionPath)))
+      val prefix = s"${partitionSchema.fieldNames.head}="
+      val partitionValue = if (partitionPath.startsWith(prefix)) {
+        // support hive style partition path
+        partitionPath.substring(prefix.length)
+      } else {
+        partitionPath
+      }
+      InternalRow.fromSeq(Seq(UTF8String.fromString(partitionValue)))
     } else if (partitionFragments.length != partitionSchema.fields.length
       && partitionSchema.fields.length > 1) {
       // If the partition column size is not equal to the partition fragments size
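For illustration, a minimal standalone sketch of the prefix stripping performed by the hunk above, assuming a single partition column; the helper name below is hypothetical and not part of the Hudi API:

// Hypothetical helper mirroring the change above: with one partition column,
// a Hive-style segment "dt=2021-03-01" and a plain "2021-03-01" both resolve
// to the same partition value "2021-03-01".
def stripHiveStylePrefix(partitionPath: String, partitionColumn: String): String = {
  val prefix = s"$partitionColumn="
  if (partitionPath.startsWith(prefix)) partitionPath.substring(prefix.length) else partitionPath
}

assert(stripHiveStylePrefix("dt=2021-03-01", "dt") == "2021-03-01")
assert(stripHiveStylePrefix("2021-03-01", "dt") == "2021-03-01")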
diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala
index 72b26be4fc1fd..ee83cf4463e38 100644
--- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala
+++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala
@@ -28,8 +28,10 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row, SparkSession}
 import org.apache.spark.sql.avro.SchemaConverters
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Literal}
 import org.apache.spark.sql.execution.datasources.{FileStatusCache, InMemoryFileIndex, Spark2ParsePartitionUtil, Spark3ParsePartitionUtil, SparkParsePartitionUtil}
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources.{And, EqualNullSafe, EqualTo, Filter, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Not, Or, StringContains, StringEndsWith, StringStartsWith}
 import org.apache.spark.sql.types.{StringType, StructField, StructType}
 
 import scala.collection.JavaConverters._
@@ -128,4 +130,98 @@
       new Spark3ParsePartitionUtil(conf)
     }
   }
+
+  /**
+   * Converts the given Filters to Catalyst Expressions and combines them with And.
+   * Returns a non-empty Option[Expression] only if every filter can be converted, otherwise None.
+   */
+  def convertToCatalystExpressions(filters: Array[Filter],
+                                   tableSchema: StructType): Option[Expression] = {
+    val expressions = filters.map(convertToCatalystExpression(_, tableSchema))
+    if (expressions.forall(p => p.isDefined)) {
+      if (expressions.isEmpty) {
+        None
+      } else if (expressions.length == 1) {
+        expressions(0)
+      } else {
+        Some(expressions.map(_.get).reduce(org.apache.spark.sql.catalyst.expressions.And))
+      }
+    } else {
+      None
+    }
+  }
+
+  /**
+   * Converts a single Filter to a Catalyst Expression.
+   * Returns a non-empty Option[Expression] if the conversion succeeds, otherwise None.
+   */
+  def convertToCatalystExpression(filter: Filter, tableSchema: StructType): Option[Expression] = {
+    Option(
+      filter match {
+        case EqualTo(attribute, value) =>
+          org.apache.spark.sql.catalyst.expressions.EqualTo(toAttribute(attribute, tableSchema), Literal.create(value))
+        case EqualNullSafe(attribute, value) =>
+          org.apache.spark.sql.catalyst.expressions.EqualNullSafe(toAttribute(attribute, tableSchema), Literal.create(value))
+        case GreaterThan(attribute, value) =>
+          org.apache.spark.sql.catalyst.expressions.GreaterThan(toAttribute(attribute, tableSchema), Literal.create(value))
+        case GreaterThanOrEqual(attribute, value) =>
+          org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual(toAttribute(attribute, tableSchema), Literal.create(value))
+        case LessThan(attribute, value) =>
+          org.apache.spark.sql.catalyst.expressions.LessThan(toAttribute(attribute, tableSchema), Literal.create(value))
+        case LessThanOrEqual(attribute, value) =>
+          org.apache.spark.sql.catalyst.expressions.LessThanOrEqual(toAttribute(attribute, tableSchema), Literal.create(value))
+        case In(attribute, values) =>
+          val attrExp = toAttribute(attribute, tableSchema)
+          val valuesExp = values.map(v => Literal.create(v))
+          org.apache.spark.sql.catalyst.expressions.In(attrExp, valuesExp)
+        case IsNull(attribute) =>
+          org.apache.spark.sql.catalyst.expressions.IsNull(toAttribute(attribute, tableSchema))
+        case IsNotNull(attribute) =>
+          org.apache.spark.sql.catalyst.expressions.IsNotNull(toAttribute(attribute, tableSchema))
+        case And(left, right) =>
+          val leftExp = convertToCatalystExpression(left, tableSchema)
+          val rightExp = convertToCatalystExpression(right, tableSchema)
+          if (leftExp.isEmpty || rightExp.isEmpty) {
+            null
+          } else {
+            org.apache.spark.sql.catalyst.expressions.And(leftExp.get, rightExp.get)
+          }
+        case Or(left, right) =>
+          val leftExp = convertToCatalystExpression(left, tableSchema)
+          val rightExp = convertToCatalystExpression(right, tableSchema)
+          if (leftExp.isEmpty || rightExp.isEmpty) {
+            null
+          } else {
+            org.apache.spark.sql.catalyst.expressions.Or(leftExp.get, rightExp.get)
+          }
+        case Not(child) =>
+          val childExp = convertToCatalystExpression(child, tableSchema)
+          if (childExp.isEmpty) {
+            null
+          } else {
+            org.apache.spark.sql.catalyst.expressions.Not(childExp.get)
+          }
+        case StringStartsWith(attribute, value) =>
+          val leftExp = toAttribute(attribute, tableSchema)
+          val rightExp = Literal.create(s"$value%")
+          org.apache.spark.sql.catalyst.expressions.Like(leftExp, rightExp)
+        case StringEndsWith(attribute, value) =>
+          val leftExp = toAttribute(attribute, tableSchema)
+          val rightExp = Literal.create(s"%$value")
+          org.apache.spark.sql.catalyst.expressions.Like(leftExp, rightExp)
+        case StringContains(attribute, value) =>
+          val leftExp = toAttribute(attribute, tableSchema)
+          val rightExp = Literal.create(s"%$value%")
+          org.apache.spark.sql.catalyst.expressions.Like(leftExp, rightExp)
+        case _ => null
+      }
+    )
+  }
+
+  private def toAttribute(columnName: String, tableSchema: StructType): AttributeReference = {
+    val field = tableSchema.find(p => p.name == columnName)
+    assert(field.isDefined, s"Cannot find column: $columnName, Table Columns are: " +
+      s"${tableSchema.fieldNames.mkString(",")}")
+    AttributeReference(columnName, field.get.dataType, field.get.nullable)()
+  }
 }
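For a quick sanity check, a hedged sketch of how the new helpers behave when run in a Spark shell with the patched hudi-spark module on the classpath; the schema and filters below are illustrative only:

import org.apache.hudi.HoodieSparkUtils
import org.apache.spark.sql.sources.{EqualTo, Filter, GreaterThan}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

// Illustrative schema with a single string column named "partition".
val schema = StructType(Seq(StructField("partition", StringType, nullable = true)))
val filters: Array[Filter] = Array(
  EqualTo("partition", "2021/03/01"),
  GreaterThan("partition", "2021/01/01"))

// Expected sql: ((`partition` = '2021/03/01') AND (`partition` > '2021/01/01'));
// None is returned if any filter cannot be converted (see the unsupported-filter test below).
val expr = HoodieSparkUtils.convertToCatalystExpressions(filters, schema)
println(expr.map(_.sql))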
diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/MergeOnReadSnapshotRelation.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/MergeOnReadSnapshotRelation.scala
index c9d413bbdc570..13cf43ead2bb1 100644
--- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/MergeOnReadSnapshotRelation.scala
+++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/MergeOnReadSnapshotRelation.scala
@@ -67,7 +67,6 @@ class MergeOnReadSnapshotRelation(val sqlContext: SQLContext,
     DataSourceReadOptions.REALTIME_MERGE_OPT_KEY,
     DataSourceReadOptions.DEFAULT_REALTIME_MERGE_OPT_VAL)
   private val maxCompactionMemoryInBytes = getMaxCompactionMemoryInBytes(jobConf)
-  private val fileIndex = buildFileIndex()
   private val preCombineField = {
     val preCombineFieldFromTableConfig = metaClient.getTableConfig.getPreCombineField
     if (preCombineFieldFromTableConfig != null) {
@@ -94,6 +93,8 @@ class MergeOnReadSnapshotRelation(val sqlContext: SQLContext,
     })
     val requiredAvroSchema = AvroConversionUtils
       .convertStructTypeToAvroSchema(requiredStructSchema, tableAvroSchema.getName, tableAvroSchema.getNamespace)
+
+    val fileIndex = buildFileIndex(filters)
     val hoodieTableState = HoodieMergeOnReadTableState(
       tableStructSchema,
       requiredStructSchema,
@@ -131,7 +132,8 @@ class MergeOnReadSnapshotRelation(val sqlContext: SQLContext,
     rdd.asInstanceOf[RDD[Row]]
   }
 
-  def buildFileIndex(): List[HoodieMergeOnReadFileSplit] = {
+  def buildFileIndex(filters: Array[Filter]): List[HoodieMergeOnReadFileSplit] = {
+
     val fileStatuses = if (globPaths.isDefined) {
       // Load files from the global paths if it has defined to be compatible with the original mode
       val inMemoryFileIndex = HoodieSparkUtils.createInMemoryFileIndex(sqlContext.sparkSession, globPaths.get)
@@ -139,7 +141,19 @@ class MergeOnReadSnapshotRelation(val sqlContext: SQLContext,
     } else {
       // Load files by the HoodieFileIndex.
       val hoodieFileIndex = HoodieFileIndex(sqlContext.sparkSession, metaClient, Some(tableStructSchema), optParams, FileStatusCache.getOrCreate(sqlContext.sparkSession))
-      hoodieFileIndex.allFiles
+
+      // Collect the partition filters and convert them to a Catalyst expression
+      val partitionColumns = hoodieFileIndex.partitionSchema.fieldNames.toSet
+      val partitionFilters = filters.filter(f => f.references.forall(p => partitionColumns.contains(p)))
+      val partitionFilterExpression =
+        HoodieSparkUtils.convertToCatalystExpressions(partitionFilters, tableStructSchema)
+
+      // If the conversion succeeded, use the expression to prune partitions
+      if (partitionFilterExpression.isDefined) {
+        hoodieFileIndex.listFiles(Seq(partitionFilterExpression.get), Seq.empty).flatMap(_.files)
+      } else {
+        hoodieFileIndex.allFiles
+      }
     }
     if (fileStatuses.isEmpty) {
       // If this an empty table, return an empty split list.
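For reference, a small sketch of the filter split performed in buildFileIndex above: a filter is treated as a partition filter only when every column it references is a partition column. The column names and values below are illustrative:

import org.apache.spark.sql.sources.{EqualTo, Filter, GreaterThan}

// Illustrative: "partition" is the only partition column; "price" is a data column.
val partitionColumns = Set("partition")
val filters: Array[Filter] = Array(
  EqualTo("partition", "2021/03/01"),  // references only partition columns -> usable for pruning
  GreaterThan("price", 10.0))          // references a data column -> not used for file listing

val partitionFilters = filters.filter(f => f.references.forall(partitionColumns.contains))
// partitionFilters == Array(EqualTo("partition", "2021/03/01"))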
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestConvertFilterToCatalystExpression.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestConvertFilterToCatalystExpression.scala
new file mode 100644
index 0000000000000..d1a117086c30d
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestConvertFilterToCatalystExpression.scala
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hudi
+
+import org.apache.hudi.HoodieSparkUtils.convertToCatalystExpressions
+import org.apache.hudi.HoodieSparkUtils.convertToCatalystExpression
+import org.apache.spark.sql.sources.{And, EqualNullSafe, EqualTo, Filter, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Not, Or, StringContains, StringEndsWith, StringStartsWith}
+import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType, StructField, StructType}
+import org.junit.jupiter.api.Assertions.assertEquals
+import org.junit.jupiter.api.Test
+
+import scala.collection.mutable.ArrayBuffer
+
+class TestConvertFilterToCatalystExpression {
+
+  private lazy val tableSchema = {
+    val fields = new ArrayBuffer[StructField]()
+    fields.append(StructField("id", LongType, nullable = false))
+    fields.append(StructField("name", StringType, nullable = true))
+    fields.append(StructField("price", DoubleType, nullable = true))
+    fields.append(StructField("ts", IntegerType, nullable = false))
+    StructType(fields)
+  }
+
+  @Test
+  def testBaseConvert(): Unit = {
+    checkConvertFilter(eq("id", 1), "(`id` = 1)")
+    checkConvertFilter(eqs("name", "a1"), "(`name` <=> 'a1')")
+    checkConvertFilter(lt("price", 10), "(`price` < 10)")
+    checkConvertFilter(lte("ts", 1), "(`ts` <= 1)")
+    checkConvertFilter(gt("price", 10), "(`price` > 10)")
+    checkConvertFilter(gte("price", 10), "(`price` >= 10)")
+    checkConvertFilter(in("id", 1, 2, 3), "(`id` IN (1, 2, 3))")
+    checkConvertFilter(isNull("id"), "(`id` IS NULL)")
+    checkConvertFilter(isNotNull("name"), "(`name` IS NOT NULL)")
+    checkConvertFilter(and(lt("ts", 10), gt("ts", 1)),
+      "((`ts` < 10) AND (`ts` > 1))")
+    checkConvertFilter(or(lte("ts", 10), gte("ts", 1)),
+      "((`ts` <= 10) OR (`ts` >= 1))")
+    checkConvertFilter(not(and(lt("ts", 10), gt("ts", 1))),
+      "(NOT ((`ts` < 10) AND (`ts` > 1)))")
+    checkConvertFilter(startWith("name", "ab"), "`name` LIKE 'ab%'")
+    checkConvertFilter(endWith("name", "cd"), "`name` LIKE '%cd'")
+    checkConvertFilter(contains("name", "e"), "`name` LIKE '%e%'")
+  }
+
+  @Test
+  def testConvertFilters(): Unit = {
+    checkConvertFilters(Array.empty[Filter], null)
+    checkConvertFilters(Array(eq("id", 1)), "(`id` = 1)")
+    checkConvertFilters(Array(lt("ts", 10), gt("ts", 1)),
+      "((`ts` < 10) AND (`ts` > 1))")
+  }
+
+  @Test
+  def testUnSupportConvert(): Unit = {
+    checkConvertFilters(Array(unsupport()), null)
+    checkConvertFilters(Array(and(unsupport(), eq("id", 1))), null)
+    checkConvertFilters(Array(or(unsupport(), eq("id", 1))), null)
+    checkConvertFilters(Array(and(eq("id", 1), not(unsupport()))), null)
+  }
+
+  private def checkConvertFilter(filter: Filter, expectExpression: String): Unit = {
+    val exp = convertToCatalystExpression(filter, tableSchema)
+    if (expectExpression == null) {
+      assertEquals(exp.isEmpty, true)
+    } else {
+      assertEquals(exp.isDefined, true)
+      assertEquals(expectExpression, exp.get.sql)
+    }
+  }
+
+  private def checkConvertFilters(filters: Array[Filter], expectExpression: String): Unit = {
+    val exp = convertToCatalystExpressions(filters, tableSchema)
+    if (expectExpression == null) {
+      assertEquals(exp.isEmpty, true)
+    } else {
+      assertEquals(exp.isDefined, true)
+      assertEquals(expectExpression, exp.get.sql)
+    }
+  }
+
+  private def eq(attribute: String, value: Any): Filter = {
+    EqualTo(attribute, value)
+  }
+
+  private def eqs(attribute: String, value: Any): Filter = {
+    EqualNullSafe(attribute, value)
+  }
+
+  private def gt(attribute: String, value: Any): Filter = {
+    GreaterThan(attribute, value)
+  }
+
+  private def gte(attribute: String, value: Any): Filter = {
+    GreaterThanOrEqual(attribute, value)
+  }
+
+  private def lt(attribute: String, value: Any): Filter = {
+    LessThan(attribute, value)
+  }
+
+  private def lte(attribute: String, value: Any): Filter = {
+    LessThanOrEqual(attribute, value)
+  }
+
+  private def in(attribute: String, values: Any*): Filter = {
+    In(attribute, values.toArray)
+  }
+
+  private def isNull(attribute: String): Filter = {
+    IsNull(attribute)
+  }
+
+  private def isNotNull(attribute: String): Filter = {
+    IsNotNull(attribute)
+  }
+
+  private def and(left: Filter, right: Filter): Filter = {
+    And(left, right)
+  }
+
+  private def or(left: Filter, right: Filter): Filter = {
+    Or(left, right)
+  }
+
+  private def not(child: Filter): Filter = {
+    Not(child)
+  }
+
+  private def startWith(attribute: String, value: String): Filter = {
+    StringStartsWith(attribute, value)
+  }
+
+  private def endWith(attribute: String, value: String): Filter = {
+    StringEndsWith(attribute, value)
+  }
+
+  private def contains(attribute: String, value: String): Filter = {
+    StringContains(attribute, value)
+  }
+
+  private def unsupport(): Filter = {
+    UnSupportFilter("")
+  }
+
+  case class UnSupportFilter(value: Any) extends Filter {
+    override def references: Array[String] = Array.empty
+  }
+}
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala
index 00c40abf1c3c7..eba2ac2c54b81 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.functions._
 import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue}
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
 import org.junit.jupiter.params.ParameterizedTest
-import org.junit.jupiter.params.provider.ValueSource
+import org.junit.jupiter.params.provider.{CsvSource, ValueSource}
 
 import scala.collection.JavaConversions._
 
@@ -614,4 +614,67 @@ class TestMORDataSource extends HoodieClientTestBase {
       .load(basePath)
     assertEquals(N + 1, hoodieIncViewDF1.count())
   }
+
+  @ParameterizedTest
+  @CsvSource(Array("true, false", "false, true", "false, false", "true, true"))
+  def testMORPartitionPrune(partitionEncode: Boolean, hiveStylePartition: Boolean): Unit = {
+    val partitions = Array("2021/03/01", "2021/03/02", "2021/03/03", "2021/03/04", "2021/03/05")
+    val newDataGen = new HoodieTestDataGenerator(partitions)
+    val records1 = newDataGen.generateInsertsContainsAllPartitions("000", 100)
+    val inputDF1 = spark.read.json(spark.sparkContext.parallelize(recordsToStrings(records1), 2))
+
+    val partitionCounts = partitions.map(p => p -> records1.count(r => r.getPartitionPath == p)).toMap
+
+    inputDF1.write.format("hudi")
+      .options(commonOpts)
+      .option(DataSourceWriteOptions.OPERATION_OPT_KEY, DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL)
+      .option(DataSourceWriteOptions.TABLE_TYPE_OPT_KEY, DataSourceWriteOptions.MOR_TABLE_TYPE_OPT_VAL)
+      .option(DataSourceWriteOptions.URL_ENCODE_PARTITIONING_OPT_KEY, partitionEncode)
+      .option(DataSourceWriteOptions.HIVE_STYLE_PARTITIONING_OPT_KEY, hiveStylePartition)
+      .mode(SaveMode.Overwrite)
+      .save(basePath)
+
+    val count1 = spark.read.format("hudi")
+      .load(basePath)
+      .filter("partition = '2021/03/01'")
+      .count()
+    assertEquals(partitionCounts("2021/03/01"), count1)
+
+    val count2 = spark.read.format("hudi")
+      .load(basePath)
+      .filter("partition > '2021/03/01' and partition < '2021/03/03'")
+      .count()
+    assertEquals(partitionCounts("2021/03/02"), count2)
+
+    val count3 = spark.read.format("hudi")
+      .load(basePath)
+      .filter("partition != '2021/03/01'")
+      .count()
+    assertEquals(records1.size() - partitionCounts("2021/03/01"), count3)
+
+    val count4 = spark.read.format("hudi")
+      .load(basePath)
+      .filter("partition like '2021/03/03%'")
+      .count()
+    assertEquals(partitionCounts("2021/03/03"), count4)
+
+    val count5 = spark.read.format("hudi")
+      .load(basePath)
+      .filter("partition like '%2021/03/%'")
+      .count()
+    assertEquals(records1.size(), count5)
+
+    val count6 = spark.read.format("hudi")
+      .load(basePath)
+      .filter("partition = '2021/03/01' or partition = '2021/03/05'")
+      .count()
+    assertEquals(partitionCounts("2021/03/01") + partitionCounts("2021/03/05"), count6)
+
+    val count7 = spark.read.format("hudi")
+      .load(basePath)
+      .filter("substr(partition, 9, 10) = '03'")
+      .count()
+
+    assertEquals(partitionCounts("2021/03/03"), count7)
+  }
 }
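Finally, a hedged sketch of the query pattern this change targets, mirroring the test above; the base path is illustrative and `spark` is assumed to be an active SparkSession with the patched Hudi bundle:

// Illustrative read of a MERGE_ON_READ table: with this patch the partition predicate
// is converted to a Catalyst expression and passed to HoodieFileIndex.listFiles, so
// only files under the matching partition are listed and scanned.
val df = spark.read.format("hudi")
  .load("/tmp/hudi/mor_trips")              // hypothetical base path
  .filter("partition = '2021/03/01'")
println(df.count())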