diff --git a/.circleci/config.yml b/.circleci/config.yml index 5e21b41812..dfa66dadf7 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -138,6 +138,62 @@ jobs: command: | conda activate chronon_py sbt +scalafmtCheck + # run these separately as we need an isolated JVM to not have Spark session settings interfere with other runs + # long term goal is to refactor the current testing spark session builder and avoid adding new single-test jobs to CI + "Scala 13 -- Iceberg Format Tests": + executor: docker_baseimg_executor + steps: + - checkout + - run: + name: Run Scala 13 tests for Iceberg format + environment: + format_test: iceberg + shell: /bin/bash -leuxo pipefail + command: | + conda activate chronon_py + # Increase if we see OOM. + export SBT_OPTS="-XX:+CMSClassUnloadingEnabled -XX:MaxPermSize=4G -Xmx4G -Xms2G" + sbt ';project spark_embedded; ++ 2.13.6; testOnly ai.chronon.spark.test.TableUtilsFormatTest' + - store_test_results: + path: /chronon/spark/target/test-reports + - store_test_results: + path: /chronon/aggregator/target/test-reports + - run: + name: Compress spark-warehouse + command: | + cd /tmp/ && tar -czvf spark-warehouse.tar.gz chronon/spark-warehouse + when: on_fail + - store_artifacts: + path: /tmp/spark-warehouse.tar.gz + destination: spark_warehouse.tar.gz + when: on_fail + "Scala 13 -- Iceberg Table Utils Tests": + executor: docker_baseimg_executor + steps: + - checkout + - run: + name: Run Scala 13 tests for Iceberg Table Utils + environment: + format_test: iceberg + shell: /bin/bash -leuxo pipefail + command: | + conda activate chronon_py + # Increase if we see OOM. + export SBT_OPTS="-XX:+CMSClassUnloadingEnabled -XX:MaxPermSize=4G -Xmx4G -Xms2G" + sbt ';project spark_embedded; ++ 2.13.6; testOnly ai.chronon.spark.test.TableUtilsTest' + - store_test_results: + path: /chronon/spark/target/test-reports + - store_test_results: + path: /chronon/aggregator/target/test-reports + - run: + name: Compress spark-warehouse + command: | + cd /tmp/ && tar -czvf spark-warehouse.tar.gz chronon/spark-warehouse + when: on_fail + - store_artifacts: + path: /tmp/spark-warehouse.tar.gz + destination: spark_warehouse.tar.gz + when: on_fail workflows: build_test_deploy: @@ -161,3 +217,9 @@ workflows: - "Chronon Python Lint": requires: - "Pull Docker Image" + - "Scala 13 -- Iceberg Format Tests": + requires: + - "Pull Docker Image" + - "Scala 13 -- Iceberg Table Utils Tests": + requires: + - "Pull Docker Image" \ No newline at end of file diff --git a/.gitignore b/.gitignore index 90d6e9b021..d276e1b53a 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,11 @@ *.logs *.iml .idea/ +*.jvmopts +.bloop* +.metals* +.venv* +*metals.sbt* .eclipse **/.vscode/ **/__pycache__/ @@ -21,7 +26,9 @@ api/py/.coverage api/py/htmlcov/ **/derby.log cs - +*.bloop +*.metals +*.venv # Documentation builds docs/build/ diff --git a/build.sbt b/build.sbt index 438c034971..932f6e6259 100644 --- a/build.sbt +++ b/build.sbt @@ -13,6 +13,7 @@ lazy val spark3_1_1 = "3.1.1" lazy val spark3_2_1 = "3.2.1" lazy val spark3_5_3 = "3.5.3" lazy val tmp_warehouse = "/tmp/chronon/" +lazy val icebergVersion = "1.1.0" ThisBuild / organization := "ai.chronon" ThisBuild / organizationName := "chronon" @@ -189,6 +190,16 @@ val VersionMatrix: Map[String, VersionDependency] = Map( Some("1.0.1"), Some("2.0.2") ), + // 3.2 is the minimum Spark version for Iceberg + // due to INSERT_INTO support without specifying the iceberg format + "iceberg32" -> VersionDependency( + Seq( + "org.apache.iceberg" %% "iceberg-spark-runtime-3.2", + ), + 
None, + None, + Some(icebergVersion), + ), "jackson" -> VersionDependency( Seq( "com.fasterxml.jackson.core" % "jackson-core", @@ -415,7 +426,7 @@ lazy val spark_uber = (project in file("spark")) libraryDependencies ++= (if (use_spark_3_5.value) fromMatrix(scalaVersion.value, "jackson", "spark-all-3-5/provided", "delta-core/provided") else - fromMatrix(scalaVersion.value, "jackson", "spark-all/provided", "delta-core/provided") + fromMatrix(scalaVersion.value, "jackson", "spark-all/provided", "delta-core/provided", "iceberg32/provided")), ) lazy val spark_embedded = (project in file("spark")) @@ -427,7 +438,7 @@ libraryDependencies ++= (if (use_spark_3_5.value) fromMatrix(scalaVersion.value, "spark-all-3-5", "delta-core") else - fromMatrix(scalaVersion.value, "spark-all", "delta-core"), + fromMatrix(scalaVersion.value, "spark-all", "delta-core", "iceberg32")), target := target.value.toPath.resolveSibling("target-embedded").toFile, Test / test := {} ) diff --git a/spark/src/main/scala/ai/chronon/spark/ChrononKryoRegistrator.scala b/spark/src/main/scala/ai/chronon/spark/ChrononKryoRegistrator.scala index 931efe6080..f5365d4995 100644 --- a/spark/src/main/scala/ai/chronon/spark/ChrononKryoRegistrator.scala +++ b/spark/src/main/scala/ai/chronon/spark/ChrononKryoRegistrator.scala @@ -178,3 +178,44 @@ class ChrononDeltaLakeKryoRegistrator extends ChrononKryoRegistrator { additionalDeltaNames.foreach(name => doRegister(name, kryo)) } } + +class ChrononIcebergKryoRegistrator extends ChrononKryoRegistrator { + override def registerClasses(kryo: Kryo): Unit = { + super.registerClasses(kryo) + val additionalIcebergNames = Seq( + "org.apache.iceberg.spark.source.SerializableTableWithSize", + "org.apache.iceberg.encryption.PlaintextEncryptionManager", + "org.apache.iceberg.hadoop.HadoopFileIO", + "org.apache.iceberg.SerializableTable$SerializableConfSupplier", + "org.apache.iceberg.util.SerializableMap", + "org.apache.iceberg.LocationProviders$DefaultLocationProvider", + "org.apache.iceberg.spark.source.SparkWrite$TaskCommit", + "org.apache.iceberg.DataFile", + "org.apache.iceberg.GenericDataFile", + "org.apache.iceberg.FileContent", + "org.apache.iceberg.FileFormat", + "org.apache.iceberg.SerializableByteBufferMap", + "org.apache.iceberg.PartitionData", + // For some reason registering just .Types doesn't work + "org.apache.iceberg.types.Types$StructType", + "org.apache.iceberg.types.Types$NestedField", + "org.apache.iceberg.types.Types$StringType", + "org.apache.iceberg.types.Types$IntegerType", + "org.apache.iceberg.types.Types$LongType", + "org.apache.iceberg.types.Types$DoubleType", + "org.apache.iceberg.types.Types$FloatType", + "org.apache.iceberg.types.Types$BooleanType", + "org.apache.iceberg.types.Types$DateType", + "org.apache.iceberg.types.Types$TimestampType", + "org.apache.iceberg.types.Types$TimeType", + "org.apache.iceberg.types.Types$DecimalType", + "org.apache.iceberg.types.Types$NestedField$", + "org.apache.iceberg.SnapshotRef", + "org.apache.iceberg.SnapshotRefType", + "org.apache.iceberg.spark.source.SerializableTableWithSize$SerializableMetadataTableWithSize", + "org.apache.iceberg.MetadataTableType", + "org.apache.iceberg.BaseFile$1" + ) + additionalIcebergNames.foreach(name => doRegister(name, kryo)) + } +} diff --git a/spark/src/main/scala/ai/chronon/spark/Extensions.scala b/spark/src/main/scala/ai/chronon/spark/Extensions.scala index e080d87a4e..c97fcbe5d2 100644 --- a/spark/src/main/scala/ai/chronon/spark/Extensions.scala +++ 
b/spark/src/main/scala/ai/chronon/spark/Extensions.scala @@ -83,7 +83,7 @@ object Extensions { .groupBy(col(TableUtils(dataFrame.sparkSession).partitionColumn)) .count() .collect() - .map(row => row.getString(0) -> row.getLong(1)) + .map(row => row.get(0).toString -> row.getLong(1)) .toMap DfWithStats(dataFrame, partitionCounts) } diff --git a/spark/src/main/scala/ai/chronon/spark/SparkSessionBuilder.scala b/spark/src/main/scala/ai/chronon/spark/SparkSessionBuilder.scala index 1dfc4fc64a..d1c8eacda4 100644 --- a/spark/src/main/scala/ai/chronon/spark/SparkSessionBuilder.scala +++ b/spark/src/main/scala/ai/chronon/spark/SparkSessionBuilder.scala @@ -23,6 +23,7 @@ import org.apache.spark.SPARK_VERSION import java.io.File import java.util.logging.Logger import scala.util.Properties +import java.util.UUID object SparkSessionBuilder { @transient private lazy val logger = LoggerFactory.getLogger(getClass) @@ -39,6 +40,8 @@ object SparkSessionBuilder { additionalConfig: Option[Map[String, String]] = None, enforceKryoSerializer: Boolean = true): SparkSession = { + val userName = Properties.userName + val warehouseDir = localWarehouseLocation.map(expandUser).getOrElse(DefaultWarehouseDir.getAbsolutePath) // allow us to override the format by specifying env vars. This allows us to not have to worry about interference // between Spark sessions created in existing chronon tests that need the hive format and some specific tests // that require a format override like delta lake. @@ -50,6 +53,17 @@ object SparkSessionBuilder { "spark.chronon.table_write.format" -> "delta" ) (configMap, "ai.chronon.spark.ChrononDeltaLakeKryoRegistrator") + case Some("iceberg") => + val configMap = Map( + "spark.sql.extensions" -> "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions", + "spark.sql.catalog.spark_catalog" -> "org.apache.iceberg.spark.SparkSessionCatalog", + "spark.chronon.table_write.format" -> "iceberg", + "spark.chronon.table_read.format" -> "iceberg", + "spark.sql.catalog.local" -> "org.apache.iceberg.spark.SparkCatalog", + "spark.sql.catalog.spark_catalog.type" -> "hadoop", + "spark.sql.catalog.spark_catalog.warehouse" -> s"$warehouseDir/data" + ) + (configMap, "ai.chronon.spark.ChrononIcebergKryoRegistrator") case _ => (Map.empty, "ai.chronon.spark.ChrononKryoRegistrator") } @@ -60,8 +74,7 @@ object SparkSessionBuilder { //required to run spark locally with hive support enabled - for sbt test System.setSecurityManager(null) } - val userName = Properties.userName - val warehouseDir = localWarehouseLocation.map(expandUser).getOrElse(DefaultWarehouseDir.getAbsolutePath) + var baseBuilder = SparkSession .builder() .appName(name) diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala index 23c9238bca..0486534b58 100644 --- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala @@ -239,9 +239,15 @@ case object Iceberg extends Format { override def partitions(tableName: String, partitionColumns: Seq[String])(implicit sparkSession: SparkSession): Seq[Map[String, String]] = { sparkSession.sqlContext - .sql(s"SHOW PARTITIONS $tableName") + .sql(s"SELECT partition FROM $tableName" ++ ".partitions") .collect() - .map(row => parseHivePartition(row.getString(0))) + .map { row => + val partitionStruct = row.getStruct(0) + partitionStruct.schema.fieldNames.zipWithIndex.map { + case (fieldName, idx) => + fieldName -> partitionStruct.get(idx).toString + }.toMap + } } private 
def getIcebergPartitions(tableName: String, @@ -394,7 +400,14 @@ case class TableUtils(sparkSession: SparkSession) { rdd } - def tableExists(tableName: String): Boolean = sparkSession.catalog.tableExists(tableName) + def tableExists(tableName: String): Boolean = { + try { + sparkSession.sql(s"DESCRIBE TABLE $tableName") + true + } catch { + case _: AnalysisException => false + } + } def loadEntireTable(tableName: String): DataFrame = sparkSession.table(tableName) @@ -972,17 +985,39 @@ case class TableUtils(sparkSession: SparkSession) { partitions: Seq[String], partitionColumn: String = partitionColumn, subPartitionFilters: Map[String, String] = Map.empty): Unit = { + // TODO this is using datasource v1 semantics, which won't be compatible with non-hive catalogs + // notably, the unit test iceberg integration uses hadoop because of + // https://github.com/apache/iceberg/issues/7847 if (partitions.nonEmpty && tableExists(tableName)) { - val partitionSpecs = partitions - .map { partition => - val mainSpec = s"$partitionColumn='$partition'" - val specs = mainSpec +: subPartitionFilters.map { - case (key, value) => s"${key}='${value}'" - }.toSeq - specs.mkString("PARTITION (", ",", ")") - } - .mkString(",") - val dropSql = s"ALTER TABLE $tableName DROP IF EXISTS $partitionSpecs" + val dropSql = tableFormatProvider.readFormat(tableName) match { + // really this is Dsv1 vs Dsv2, not hive vs iceberg, + // but we break this way since only Iceberg is migrated to Dsv2 + case Iceberg => + // Build WHERE clause: (ds='2024-05-01' OR ds='2024-05-02') [AND k='v' AND …] + val mainPred = partitions + .map(p => s"$partitionColumn='${p}'") + .mkString("(", " OR ", ")") + + val extraPred = subPartitionFilters + .map { case (k, v) => s"$k='${v}'" } + .mkString(" AND ") + + val where = Seq(mainPred, extraPred).filter(_.nonEmpty).mkString(" AND ") + + s"DELETE FROM $tableName WHERE $where" + case _ => + // default case is for Hive + val partitionSpecs = partitions + .map { partition => + val mainSpec = s"$partitionColumn='$partition'" + val specs = mainSpec +: subPartitionFilters.map { + case (key, value) => s"${key}='${value}'" + }.toSeq + specs.mkString("PARTITION (", ",", ")") + } + .mkString(",") + s"ALTER TABLE $tableName DROP IF EXISTS $partitionSpecs" + } sql(dropSql) } else { logger.info(s"$tableName doesn't exist, please double check before drop partitions") diff --git a/spark/src/test/scala/ai/chronon/spark/test/ExtensionsTest.scala b/spark/src/test/scala/ai/chronon/spark/test/ExtensionsTest.scala index aecee8bfe7..73198cfb1a 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/ExtensionsTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/ExtensionsTest.scala @@ -43,4 +43,19 @@ class ExtensionsTest { } assertEquals(0, diff.count()) } + + @Test + def testDfWithStatsLongPartition(): Unit = { + val df = Seq( + (1, 20240103L), + (2, 20240104L), + (3, 20240104L) + ).toDF("key", "ds") + + val dfWithStats: DfWithStats = DfWithStats(df) + val stats = dfWithStats.stats + + assertEquals(3L, stats.count) + } + } \ No newline at end of file diff --git a/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala b/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala index 397986f63a..23dda045bc 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala @@ -420,6 +420,88 @@ class JoinTest { assertEquals(diff.count(), 0) } + @Test + def testEventsEventsTemporalLongDs(): Unit = { + val spark: SparkSession = 
SparkSessionBuilder.build("JoinTest" + "_" + Random.alphanumeric.take(6).mkString, local = true) + spark.conf.set("spark.chronon.partition.format", "yyyy-MM-dd") + val tableUtils = TableUtils(spark) + val namespace = "test_namespace_jointest" + "_" + Random.alphanumeric.take(6).mkString + tableUtils.createDatabase(namespace) + val joinConf = getEventsEventsTemporal("temporal", namespace) + val viewsSchema = List( + Column("user", api.StringType, 10000), + Column("item", api.StringType, 100), + Column("time_spent_ms", api.LongType, 5000) + ) + + val viewsTable = s"$namespace.view_temporal" + DataFrameGen.events(spark, viewsSchema, count = 1000, partitions = 200) + .withColumn("ds", col("ds").cast("long")) + .save(viewsTable, Map("tblProp1" -> "1")) + + val viewsSource = Builders.Source.events( + table = viewsTable, + query = Builders.Query(selects = Builders.Selects("time_spent_ms"), startPartition = yearAgo) + ) + val viewsGroupBy = Builders.GroupBy( + sources = Seq(viewsSource), + keyColumns = Seq("item"), + aggregations = Seq( + Builders.Aggregation(operation = Operation.AVERAGE, inputColumn = "time_spent_ms"), + Builders.Aggregation(operation = Operation.MIN, inputColumn = "ts"), + Builders.Aggregation(operation = Operation.MAX, inputColumn = "ts") + ), + metaData = Builders.MetaData(name = "unit_test.item_views", namespace = namespace) + ) + + // left side + val itemQueries = List(Column("item", api.StringType, 100)) + val itemQueriesTable = s"$namespace.item_queries" + val itemQueriesDf = DataFrameGen + .events(spark, itemQueries, 1000, partitions = 100) + // duplicate the events + itemQueriesDf.union(itemQueriesDf).save(itemQueriesTable) + + val start = tableUtils.partitionSpec.minus(today, new Window(100, TimeUnit.DAYS)) + (new Analyzer(tableUtils, joinConf, monthAgo, today)).run() + val join = new Join(joinConf = joinConf, endPartition = dayAndMonthBefore, tableUtils) + val computed = join.computeJoin(Some(100)) + computed.show() + + val expected = tableUtils.sql(s""" + |WITH + | queries AS (SELECT item, ts, ds from $itemQueriesTable where ds >= '$start' and ds <= '$dayAndMonthBefore') + | SELECT queries.item, queries.ts, queries.ds, part.user_unit_test_item_views_ts_min, part.user_unit_test_item_views_ts_max, part.user_unit_test_item_views_time_spent_ms_average + | FROM (SELECT queries.item, + | queries.ts, + | queries.ds, + | MIN(IF(queries.ts > $viewsTable.ts, $viewsTable.ts, null)) as user_unit_test_item_views_ts_min, + | MAX(IF(queries.ts > $viewsTable.ts, $viewsTable.ts, null)) as user_unit_test_item_views_ts_max, + | AVG(IF(queries.ts > $viewsTable.ts, time_spent_ms, null)) as user_unit_test_item_views_time_spent_ms_average + | FROM queries left outer join $viewsTable + | ON queries.item = $viewsTable.item + | WHERE $viewsTable.item IS NOT NULL AND $viewsTable.ds >= '$yearAgo' AND $viewsTable.ds <= '$dayAndMonthBefore' + | GROUP BY queries.item, queries.ts, queries.ds) as part + | JOIN queries + | ON queries.item <=> part.item AND queries.ts <=> part.ts AND queries.ds <=> part.ds + |""".stripMargin) + expected.show() + + val diff = Comparison.sideBySide(computed, expected, List("item", "ts", "ds")) + val queriesBare = + tableUtils.sql(s"SELECT item, ts, ds from $itemQueriesTable where ds >= '$start' and ds <= '$dayAndMonthBefore'") + assertEquals(queriesBare.count(), computed.count()) + if (diff.count() > 0) { + logger.debug(s"Diff count: ${diff.count()}") + logger.debug(s"diff result rows") + diff + .replaceWithReadableTime(Seq("ts", "a_user_unit_test_item_views_ts_max", 
"b_user_unit_test_item_views_ts_max"), + dropOriginal = true) + .show() + } + assertEquals(diff.count(), 0) + } + @Test def testEventsEventsCumulative(): Unit = { val spark: SparkSession = SparkSessionBuilder.build("JoinTest" + "_" + Random.alphanumeric.take(6).mkString, local = true) diff --git a/spark/src/test/scala/ai/chronon/spark/test/TableUtilsFormatTest.scala b/spark/src/test/scala/ai/chronon/spark/test/TableUtilsFormatTest.scala index d4c8b806ad..ad0cb54906 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/TableUtilsFormatTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/TableUtilsFormatTest.scala @@ -89,7 +89,7 @@ class TableUtilsFormatTest { Row(5L, 6, "7", "2022-10-02") ) ) - testInsertPartitions(spark, tableUtils, tableName, format, df1, df2, ds1 = "2022-10-01", ds2 = "2022-10-02") + testInsertPartitions(spark, tableUtils, tableName, format, df1, df2, ds1 = "2022-10-01", ds2 = "2022-10-02") } @Test diff --git a/spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala b/spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala index 38417e7039..6eb3414582 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala @@ -22,6 +22,7 @@ import ai.chronon.spark.test.TestUtils.makeDf import ai.chronon.api.{StructField, _} import ai.chronon.online.SparkConversions import ai.chronon.spark.{IncompatibleSchemaException, PartitionRange, SparkSessionBuilder, TableUtils} +import ai.chronon.spark.SparkSessionBuilder.FormatTestEnvVar import org.apache.hadoop.hive.ql.exec.UDF import org.apache.spark.sql.functions.col import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession, types} @@ -31,8 +32,6 @@ import org.junit.Test import java.time.Instant import scala.util.{Random, Try} - - class SimpleAddUDF extends UDF { def evaluate(value: Int): Int = { value + 20 @@ -40,6 +39,7 @@ class SimpleAddUDF extends UDF { } class TableUtilsTest { + val format: String = sys.env.getOrElse(FormatTestEnvVar, "hive") lazy val spark: SparkSession = SparkSessionBuilder.build("TableUtilsTest", local = true) private val tableUtils = TableUtils(spark) @@ -75,10 +75,12 @@ class TableUtilsTest { Seq( types.StructField("name", types.StringType, nullable = true), types.StructField("age", types.IntegerType, nullable = false), - types.StructField("address", types.StructType(Seq( - types.StructField("street", types.StringType, nullable = true), - types.StructField("city", types.StringType, nullable = true) - ))) + types.StructField("address", + types.StructType( + Seq( + types.StructField("street", types.StringType, nullable = true), + types.StructField("city", types.StringType, nullable = true) + ))) ) ) val expectedFieldNames = Seq("name", "age", "address", "address.street", "address.city") @@ -286,13 +288,11 @@ class TableUtilsTest { |""".stripMargin) assertEquals(updated.count(), 2) assertTrue( - updated - .collect() - .sameElements( - List( - Row(1L, 2, "2022-10-01", "2022-11-01"), - Row(3L, 8, "2022-10-05", "2022-11-03") - ))) + updated.collect().toSet == + Set( + Row(1L, 2, "2022-10-01", "2022-11-01"), + Row(3L, 8, "2022-10-05", "2022-11-03") + )) } @Test @@ -340,7 +340,9 @@ class TableUtilsTest { PartitionRange("2022-10-05", "2022-10-05")(tableUtils))) } - private def prepareTestDataWithSubPartitionsWithView(tableName: String, viewName: String, partitionColOpt: Option[String] = None): Unit = { + private def prepareTestDataWithSubPartitionsWithView(tableName: String, + viewName: String, 
+ partitionColOpt: Option[String] = None): Unit = { prepareTestDataWithSubPartitions(tableName, partitionColOpt) tableUtils.sql(s"CREATE OR REPLACE VIEW $viewName AS SELECT * FROM $tableName") } @@ -367,10 +369,10 @@ class TableUtilsTest { Row(3L, "2022-11-03", "2022-11-03") ) ) - tableUtils.insertPartitions(df1, - tableName, - partitionColumns = Seq(partitionColOpt.getOrElse(tableUtils.partitionColumn), - Constants.LabelPartitionColumn)) + tableUtils.insertPartitions( + df1, + tableName, + partitionColumns = Seq(partitionColOpt.getOrElse(tableUtils.partitionColumn), Constants.LabelPartitionColumn)) } @@ -456,14 +458,16 @@ class TableUtilsTest { // test that chronon_archived flag exists and is set to true val tblProps = tableUtils.sql(s"SHOW TBLPROPERTIES $dbName.$archiveTableName").collect() val mapVal = readTblPropertiesMap(tblProps) - assert(mapVal.getOrElse("chronon_archived","false") == "true") + assert(mapVal.getOrElse("chronon_archived", "false") == "true") // test after a un-archive we can remove chronon_archived property tableUtils.sql(s"ALTER TABLE $dbName.$archiveTableName RENAME TO $tableName") - tableUtils.alterTableProperties(tableName, properties = Map("chronon_archived" -> "true"), unsetProperties = Seq("chronon_archived")) + tableUtils.alterTableProperties(tableName, + properties = Map("chronon_archived" -> "true"), + unsetProperties = Seq("chronon_archived")) val tblPropsAfter = tableUtils.sql(s"SHOW TBLPROPERTIES $tableName").collect() val mapValAfter = readTblPropertiesMap(tblPropsAfter) - assert(mapValAfter.getOrElse("chronon_archived","false") == "false") + assert(mapValAfter.getOrElse("chronon_archived", "false") == "false") } @Test @@ -487,7 +491,8 @@ class TableUtilsTest { val tableName = s"$dbName.test_table" val viewName = s"$dbName.v_test_table" tableUtils.sql(s"CREATE DATABASE IF NOT EXISTS $dbName") - tableUtils.sql(s"CREATE TABLE IF NOT EXISTS $tableName (test INT, test_col STRING) PARTITIONED BY (ds STRING) STORED AS PARQUET") + tableUtils.sql( + s"CREATE TABLE IF NOT EXISTS $tableName (test INT, test_col STRING) PARTITIONED BY (ds STRING) STORED AS PARQUET") tableUtils.sql(s"CREATE OR REPLACE VIEW $viewName AS SELECT test, test_col FROM $tableName") val table_format = tableUtils.tableReadFormat(tableName) @@ -512,7 +517,7 @@ class TableUtilsTest { val viewName = "db.v_test_table_with_sub_partition" val partitionCol = "custom_partition_date" prepareTestDataWithSubPartitionsWithView(tableName, viewName, partitionColOpt = Some(partitionCol)) - val partitions = tableUtils.partitions(viewName, partitionColOpt=Some(partitionCol)) + val partitions = tableUtils.partitions(viewName, partitionColOpt = Some(partitionCol)) assertEquals(Seq("2022-11-01", "2022-11-02", "2022-11-03").sorted, partitions.sorted) } @@ -522,7 +527,9 @@ class TableUtilsTest { val viewName = "db.v_test_table_with_sub_partition" val partitionCol = "custom_partition_date" prepareTestDataWithSubPartitionsWithView(tableName, viewName, partitionColOpt = Some(partitionCol)) - val partitions = tableUtils.partitions(viewName, subPartitionsFilter=Map("label_ds" -> "2022-11-02"), partitionColOpt=Some(partitionCol)) + val partitions = tableUtils.partitions(viewName, + subPartitionsFilter = Map("label_ds" -> "2022-11-02"), + partitionColOpt = Some(partitionCol)) assertEquals(Seq("2022-11-01", "2022-11-02").sorted, partitions.sorted) } @@ -547,4 +554,75 @@ class TableUtilsTest { assertTrue(firstDs.contains("2022-11-01")) } } + + @Test + def testGetPartitionsWithLongPartition(): Unit = { + val tableName = 
"db.test_long_partitions" + spark.sql("CREATE DATABASE IF NOT EXISTS db") + val structFields = Array( + StructField("dateInt", LongType), + StructField("hour", IntType), + StructField("eventType", StringType), + StructField("labelDs", StringType), + StructField("featureValue", IntType) + ) + + val rows = List( + Row(20220101L, 1, "event1", "2022-01-01", 4), // 2022-01-01 with hr=1 + Row(20220102L, 2, "event2", "2022-01-02", 2), // 2022-01-02 with hr=2 + Row(20220103L, 10, "event1", "2022-01-03", 9), // 2022-01-03 with hr=10 + Row(20220104L, 12, "event1", "20224-01-04", 12) // 2022-01-04 with hr=12 + ) + + val df1 = makeDf( + spark, + StructType( + tableName, + structFields + ), + rows + ) + val partitionColumns = Seq("dateInt", "hour", "eventType") + tableUtils.insertPartitions(df1, tableName, partitionColumns = partitionColumns) + assert(tableUtils.tableExists(tableName)) + val partitions = tableUtils.partitions(tableName, Map.empty, partitionColOpt = Some("dateInt")) + assert(partitions.size == 4) + assert(tableUtils.allPartitions(tableName).size == 4) + } + + @Test + def testInsertPartitionsRemoveColumnsLongDs(): Unit = { + val tableName = "db.test_table_long_2" + spark.sql("CREATE DATABASE IF NOT EXISTS db") + val columns1 = Array( + StructField("longField", LongType), + StructField("intField", IntType), + StructField("stringField", StringType) + ) + val df1 = makeDf( + spark, + StructType( + tableName, + columns1 + :+ StructField("doubleField", DoubleType) + :+ StructField("ds", LongType) + ), + List( + Row(1L, 2, "3", 4.0, 20221001L) + ) + ) + + val df2 = makeDf( + spark, + StructType( + tableName, + columns1 :+ StructField("ds", LongType) + ), + List( + Row(5L, 6, "7", 20221002L) + ) + ) + testInsertPartitions(tableName, df1, df2, ds1 = "2022-10-01", ds2 = "2022-10-02") + } + }