@@ -42,6 +42,7 @@
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.joda.time.DateTime;
import org.junit.After;
import org.junit.Assert;
import org.junit.Assume;
@@ -361,6 +362,25 @@ public void addPartitionToPartitioned() {
sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName));
}

@Test
public void addDataPartitionedByDateToPartitioned() {
createDatePartitionedFileTable("parquet");

String createIceberg =
"CREATE TABLE %s (id Integer, name String, dept String, date Date) USING iceberg PARTITIONED BY (date)";

sql(createIceberg, tableName);

Object result = scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`', map('date', '2021-01-01'))",
catalogName, tableName, fileTableDir.getAbsolutePath());

Assert.assertEquals(2L, result);

assertEquals("Iceberg table contains correct data",
sql("SELECT id, name, dept, date FROM %s WHERE date = '2021-01-01' ORDER BY id", sourceTableName),
sql("SELECT id, name, dept, date FROM %s ORDER BY id", tableName));
}

@Test
public void addFilteredPartitionsToPartitioned() {
createCompositePartitionedTable("parquet");
@@ -779,6 +799,25 @@ public void testPartitionedImportFromEmptyPartitionDoesNotThrow() {
unpartitionedDF.col("dept"),
unpartitionedDF.col("name").as("naMe"));

private static final StructField[] dateStruct = {
new StructField("id", DataTypes.IntegerType, true, Metadata.empty()),
new StructField("name", DataTypes.StringType, true, Metadata.empty()),
new StructField("dept", DataTypes.StringType, true, Metadata.empty()),
new StructField("ts", DataTypes.DateType, true, Metadata.empty())
};

private static java.sql.Date toDate(String value) {
return new java.sql.Date(DateTime.parse(value).getMillis());
}

private static final Dataset<Row> dateDF =
spark.createDataFrame(
ImmutableList.of(
RowFactory.create(1, "John Doe", "hr", toDate("2021-01-01")),
RowFactory.create(2, "Jane Doe", "hr", toDate("2021-01-01")),
RowFactory.create(3, "Matt Doe", "hr", toDate("2021-01-02")),
RowFactory.create(4, "Will Doe", "facilities", toDate("2021-01-02"))),
new StructType(dateStruct)).repartition(2);

private void createUnpartitionedFileTable(String format) {
String createParquet =
@@ -852,4 +891,13 @@ private void createPartitionedHiveTable() {
partitionedDF.write().insertInto(sourceTableName);
partitionedDF.write().insertInto(sourceTableName);
}

private void createDatePartitionedFileTable(String format) {
String createParquet = "CREATE TABLE %s (id Integer, name String, dept String, date Date) USING %s " +
"PARTITIONED BY (date) LOCATION '%s'";

sql(createParquet, sourceTableName, format, fileTableDir.getAbsolutePath());

dateDF.write().insertInto(sourceTableName);
}
}
@@ -27,6 +27,7 @@
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.Path;
import org.apache.iceberg.MetadataTableType;
import org.apache.iceberg.MetadataTableUtils;
@@ -48,6 +49,7 @@
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.iceberg.spark.SparkTableUtil.SparkPartition;
import org.apache.iceberg.spark.source.SparkTable;
@@ -77,6 +79,7 @@
import org.apache.spark.sql.connector.expressions.Transform;
import org.apache.spark.sql.execution.datasources.FileStatusCache;
import org.apache.spark.sql.execution.datasources.InMemoryFileIndex;
import org.apache.spark.sql.execution.datasources.PartitionDirectory;
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation;
import org.apache.spark.sql.types.IntegerType;
import org.apache.spark.sql.types.LongType;
@@ -745,9 +748,11 @@ public static TableIdentifier identifierToTableIdentifier(Identifier identifier)
* @param spark a Spark session
* @param rootPath the root path of the table
* @param format format of the data files
* @param partitionFilter a partition filter, given as a map from partition column name to the value to match
* @return the table's partitions that match the filter (all partitions when the filter is empty)
*/
public static List<SparkPartition> getPartitions(SparkSession spark, Path rootPath, String format) {
public static List<SparkPartition> getPartitions(SparkSession spark, Path rootPath, String format,
Map<String, String> partitionFilter) {
FileStatusCache fileStatusCache = FileStatusCache.getOrCreate(spark);
Map<String, String> emptyMap = Collections.emptyMap();

@@ -768,9 +773,23 @@ public static List<SparkPartition> getPartitions(SparkSession spark, Path rootPa

org.apache.spark.sql.execution.datasources.PartitionSpec spec = fileIndex.partitionSpec();
StructType schema = spec.partitionColumns();
if (schema.isEmpty()) {
return Lists.newArrayList();
}

List<org.apache.spark.sql.catalyst.expressions.Expression> filterExpressions =
SparkUtil.partitionMapToExpression(schema, partitionFilter);
Seq<org.apache.spark.sql.catalyst.expressions.Expression> scalaPartitionFilters =
JavaConverters.asScalaBufferConverter(filterExpressions).asScala().toSeq();

List<org.apache.spark.sql.catalyst.expressions.Expression> dataFilters = Lists.newArrayList();
Seq<org.apache.spark.sql.catalyst.expressions.Expression> scalaDataFilters =
JavaConverters.asScalaBufferConverter(dataFilters).asScala().toSeq();
Contributor: Is there an easier way to construct an empty sequence? Also, since this is always empty, can you put the dataFilters definition and this line next to one another? The line to create scalaPartitionFilters can be next to the line above that creates filterExpressions.
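For illustration only, one way to get the empty sequence without converting an empty Java list is to go through Scala's Seq companion object; this is a hedged sketch of the idea, not necessarily what the PR ended up using:

import org.apache.spark.sql.catalyst.expressions.Expression;
import scala.collection.Seq;
import scala.collection.Seq$;

class EmptySeqSketch {
  // The Seq companion object can produce an empty Seq directly, so the
  // always-empty dataFilters list does not need a JavaConverters round-trip.
  static Seq<Expression> emptyDataFilters() {
    return Seq$.MODULE$.<Expression>empty();
  }
}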


Seq<PartitionDirectory> filteredPartitions = fileIndex.listFiles(scalaPartitionFilters, scalaDataFilters);

return JavaConverters
.seqAsJavaListConverter(spec.partitions())
.seqAsJavaListConverter(filteredPartitions)
.asJava()
.stream()
.map(partition -> {
Expand All @@ -781,7 +800,11 @@ public static List<SparkPartition> getPartitions(SparkSession spark, Path rootPa
Object value = CatalystTypeConverters.convertToScala(catalystValue, field.dataType());
values.put(field.name(), String.valueOf(value));
});
return new SparkPartition(values, partition.path().toString(), format);

FileStatus fileStatus =
JavaConverters.seqAsJavaListConverter(partition.files()).asJava().get(0);
Contributor: Why does this use partition.files() instead of partition.path()?

Contributor (author): Because here, partition is a PartitionDirectory:

case class PartitionDirectory(values: InternalRow, files: Seq[FileStatus])

listFiles returns a Seq of PartitionDirectory

  def listFiles(
      partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): Seq[PartitionDirectory]

Before my change, partition is PartitionPath

case class PartitionPath(values: InternalRow, path: Path)
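
To make the difference concrete, here is a minimal sketch (assuming the Spark 3 execution.datasources classes quoted above) of how the partition location is recovered once listFiles returns PartitionDirectory instead of PartitionPath:

import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.Path;
import org.apache.spark.sql.execution.datasources.PartitionDirectory;
import scala.collection.JavaConverters;

class PartitionPathSketch {
  // A PartitionDirectory exposes the files in the partition rather than the
  // partition path itself, so the location is taken from a file's parent directory.
  static Path partitionLocation(PartitionDirectory partition) {
    FileStatus firstFile =
        JavaConverters.seqAsJavaListConverter(partition.files()).asJava().get(0);
    return firstFile.getPath().getParent();
  }
}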

Contributor: Great, thanks for the context! I assumed that it would use the same values.


return new SparkPartition(values, fileStatus.getPath().getParent().toString(), format);
}).collect(Collectors.toList());
}
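A hedged usage sketch of the updated signature (the root path and filter value below are illustrative, not taken from the PR):

import java.util.List;
import org.apache.hadoop.fs.Path;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.spark.Spark3Util;
import org.apache.iceberg.spark.SparkTableUtil.SparkPartition;
import org.apache.spark.sql.SparkSession;

class GetPartitionsSketch {
  // Lists only the partitions of a Parquet file table matching date=2021-01-01;
  // an empty filter map preserves the old behavior of returning every partition.
  static List<SparkPartition> datePartitions(SparkSession spark) {
    return Spark3Util.getPartitions(
        spark,
        new Path("/tmp/parquet_table"),  // illustrative location
        "parquet",
        ImmutableMap.of("date", "2021-01-01"));
  }
}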

@@ -19,7 +19,10 @@

package org.apache.iceberg.spark;

import java.sql.Date;
import java.sql.Timestamp;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
@@ -31,14 +34,23 @@
import org.apache.iceberg.hadoop.HadoopConfigurable;
import org.apache.iceberg.io.FileIO;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.transforms.Transform;
import org.apache.iceberg.transforms.UnknownTransform;
import org.apache.iceberg.types.TypeUtil;
import org.apache.iceberg.types.Types;
import org.apache.iceberg.util.Pair;
import org.apache.spark.sql.RuntimeConfig;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.expressions.BoundReference;
import org.apache.spark.sql.catalyst.expressions.EqualTo;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.Literal;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.util.SerializableConfiguration;
import org.joda.time.DateTime;

public class SparkUtil {

@@ -179,4 +191,63 @@ public static Configuration hadoopConfCatalogOverrides(SparkSession spark, Strin
private static String hadoopConfPrefixForCatalog(String catalogName) {
return String.format(SPARK_CATALOG_HADOOP_CONF_OVERRIDE_FMT_STR, catalogName);
}

/**
* Get a List of Spark filter Expressions from a partition filter map.
*
* @param schema table schema containing the columns referenced by the filters
* @param filters filters given as a Map, where each key is a table column name
* and each value is the value to filter on for that column.
* @return a List of the filters as Spark Expressions.
*/
public static List<Expression> partitionMapToExpression(StructType schema,
Map<String, String> filters) {
List<Expression> filterExpressions = Lists.newArrayList();
for (Map.Entry<String, String> entry : filters.entrySet()) {
try {
int index = schema.fieldIndex(entry.getKey());
DataType dataType = schema.fields()[index].dataType();
BoundReference ref = new BoundReference(index, dataType, true);
switch (dataType.typeName()) {
case "integer":
filterExpressions.add(new EqualTo(ref,
Literal.create(Integer.parseInt(entry.getValue()), DataTypes.IntegerType)));
break;
case "string":
filterExpressions.add(new EqualTo(ref, Literal.create(entry.getValue(), DataTypes.StringType)));
break;
case "short":
filterExpressions.add(new EqualTo(ref,
Literal.create(Short.parseShort(entry.getValue()), DataTypes.ShortType)));
break;
case "long":
filterExpressions.add(new EqualTo(ref,
Literal.create(Long.parseLong(entry.getValue()), DataTypes.LongType)));
break;
case "float":
filterExpressions.add(new EqualTo(ref,
Literal.create(Float.parseFloat(entry.getValue()), DataTypes.FloatType)));
break;
case "double":
filterExpressions.add(new EqualTo(ref,
Literal.create(Double.parseDouble(entry.getValue()), DataTypes.DoubleType)));
break;
case "date":
filterExpressions.add(new EqualTo(ref,
Literal.create(new Date(DateTime.parse(entry.getValue()).getMillis()), DataTypes.DateType)));
break;
case "timestamp":
filterExpressions.add(new EqualTo(ref,
Literal.create(new Timestamp(DateTime.parse(entry.getValue()).getMillis()), DataTypes.TimestampType)));
break;
default:
throw new IllegalStateException("Unexpected data type in partition filters: " + dataType);
}
} catch (IllegalArgumentException e) {
// ignore if filter is not on table columns
}
}

return filterExpressions;
}
}
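As a usage sketch of partitionMapToExpression (the single DateType partition column named "date" is an assumption for illustration, not code from this PR):

import java.util.List;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.spark.SparkUtil;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

class PartitionFilterSketch {
  // Builds a single EqualTo expression binding the "date" partition column to
  // 2021-01-01; filter keys that are not columns of the schema are ignored.
  static List<Expression> dateFilter() {
    StructType partitionSchema = new StructType(new StructField[] {
        new StructField("date", DataTypes.DateType, true, Metadata.empty())
    });
    return SparkUtil.partitionMapToExpression(
        partitionSchema, ImmutableMap.of("date", "2021-01-01"));
  }
}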
@@ -159,7 +159,8 @@ private static void ensureNameMappingPresent(Table table) {
private void importFileTable(Table table, Path tableLocation, String format, Map<String, String> partitionFilter,
boolean checkDuplicateFiles) {
// List Partitions via Spark InMemory file search interface
List<SparkPartition> partitions = Spark3Util.getPartitions(spark(), tableLocation, format);
List<SparkPartition> partitions =
Spark3Util.getPartitions(spark(), tableLocation, format, partitionFilter);

if (table.spec().isUnpartitioned()) {
Preconditions.checkArgument(partitions.isEmpty(), "Cannot add partitioned files to an unpartitioned table");
@@ -171,12 +172,8 @@ private void importFileTable(Table table, Path tableLocation, String format, Map
importPartitions(table, ImmutableList.of(partition), checkDuplicateFiles);
} else {
Preconditions.checkArgument(!partitions.isEmpty(),
"Cannot find any partitions in table %s", partitions);
List<SparkPartition> filteredPartitions = SparkTableUtil.filterPartitions(partitions, partitionFilter);
Preconditions.checkArgument(!filteredPartitions.isEmpty(),
"Cannot find any partitions which match the given filter. Partition filter is %s",
MAP_JOINER.join(partitionFilter));
importPartitions(table, filteredPartitions, checkDuplicateFiles);
"Cannot find any matching partitions in table %s", partitions);
importPartitions(table, partitions, checkDuplicateFiles);
}
}
