diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
index 541d39d3c3..8642b951d7 100644
--- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
+++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
@@ -88,7 +88,9 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
   private val blockingCacheEviction: Boolean =
     sparkSession.conf.get("spark.chronon.table_write.cache.blocking", "false").toBoolean
 
-  private[spark] lazy val tableFormatProvider: FormatProvider = {
+  // Marked @transient because the format provider is not always serializable,
+  // e.g. BigQueryImpl when instantiated reflectively for the bq flavor.
+  @transient private[spark] lazy val tableFormatProvider: FormatProvider = {
     val clazzName =
       sparkSession.conf.get("spark.chronon.table.format_provider.class", classOf[DefaultFormatProvider].getName)
     val mirror = runtimeMirror(getClass.getClassLoader)
@@ -261,6 +263,48 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
   def firstAvailablePartition(tableName: String, subPartitionFilters: Map[String, String] = Map.empty): Option[String] =
     partitions(tableName, subPartitionFilters).reduceOption((x, y) => Ordering[String].min(x, y))
 
+  def createTable(df: DataFrame,
+                  tableName: String,
+                  partitionColumns: Seq[String] = Seq.empty,
+                  writeFormatTypeString: String = "",
+                  tableProperties: Map[String, String] = null,
+                  fileFormat: String = "PARQUET",
+                  autoExpand: Boolean = false): Boolean = {
+    val doesTableExist = tableExists(tableName)
+
+    // CREATE TABLE SQL doesn't work for BigQuery here. Instead of creating the table explicitly,
+    // we rely on the BigQuery connector to create it implicitly when the data is eventually written.
+    if (writeFormatTypeString.toUpperCase == "BIGQUERY") {
+      logger.info(s"Skipping table creation in BigQuery for $tableName. tableExists=$doesTableExist")
+
+      return doesTableExist
+    }
+
+    if (!doesTableExist) {
+      val creationSql = createTableSql(tableName, df.schema, partitionColumns, tableProperties, fileFormat)
+      try {
+        sql(creationSql)
+      } catch {
+        case _: TableAlreadyExistsException =>
+          logger.info(s"Table $tableName already exists, skipping creation")
+        case e: Exception =>
+          logger.error(s"Failed to create table $tableName", e)
+          throw e
+      }
+    }
+
+    // TODO: we also need to allow BigQuery tables to have their table properties (or tags) persisted.
+    // https://app.asana.com/0/1208949807589885/1209111629687568/f
+    if (tableProperties != null && tableProperties.nonEmpty) {
+      sql(alterTablePropertiesSql(tableName, tableProperties))
+    }
+    if (autoExpand) {
+      expandTable(tableName, df.schema)
+    }
+
+    true
+  }
+
   // Needs provider
   def insertPartitions(df: DataFrame,
                        tableName: String,
@@ -279,30 +323,18 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
       df
     }
 
-    if (!tableExists(tableName)) {
-      val creationSql = createTableSql(tableName, dfRearranged.schema, partitionColumns, tableProperties, fileFormat)
-      try {
-        sql(creationSql)
-      } catch {
-        case _: TableAlreadyExistsException =>
-          logger.info(s"Table $tableName already exists, skipping creation")
-        case e: Exception =>
-          logger.error(s"Failed to create table $tableName", e)
-          throw e
-      }
-    }
-    if (tableProperties != null && tableProperties.nonEmpty) {
-      sql(alterTablePropertiesSql(tableName, tableProperties))
-    }
-
-    if (autoExpand) {
-      expandTable(tableName, dfRearranged.schema)
-    }
+    val isTableCreated = createTable(dfRearranged,
+                                     tableName,
+                                     partitionColumns,
+                                     tableFormatProvider.writeFormat(tableName).createTableTypeString,
+                                     tableProperties,
+                                     fileFormat,
+                                     autoExpand)
 
-    val finalizedDf = if (autoExpand) {
+    val finalizedDf = if (autoExpand && isTableCreated) {
       // reselect the columns so that any deprecated columns will be selected as NULL before write
-      val updatedSchema = getSchemaFromTable(tableName)
-      val finalColumns = updatedSchema.fieldNames.map(fieldName => {
+      val tableSchema = getSchemaFromTable(tableName)
+      val finalColumns = tableSchema.fieldNames.map(fieldName => {
         if (dfRearranged.schema.fieldNames.contains(fieldName)) {
           col(fieldName)
         } else {
@@ -362,13 +394,12 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
                           saveMode: SaveMode = SaveMode.Overwrite,
                           fileFormat: String = "PARQUET"): Unit = {
 
-    if (!tableExists(tableName)) {
-      sql(createTableSql(tableName, df.schema, Seq.empty[String], tableProperties, fileFormat))
-    } else {
-      if (tableProperties != null && tableProperties.nonEmpty) {
-        sql(alterTablePropertiesSql(tableName, tableProperties))
-      }
-    }
+    createTable(df,
+                tableName,
+                Seq.empty[String],
+                tableFormatProvider.writeFormat(tableName).createTableTypeString,
+                tableProperties,
+                fileFormat)
 
     repartitionAndWrite(df, tableName, saveMode, None)
   }
diff --git a/spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala b/spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala
index c39e8efe55..dddcba0b4c 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.types
 import org.junit.Assert.assertEquals
+import org.junit.Assert.assertFalse
 import org.junit.Assert.assertTrue
 import org.junit.Test
 
@@ -431,4 +432,128 @@ class TableUtilsTest {
     tableUtils.sql("CREATE TEMPORARY FUNCTION test AS 'ai.chronon.spark.test.SimpleAddUDF'")
   }
 
+  @Test
+  def testInsertPartitionsTableExistsAlready(): Unit = {
+    val tableName = "db.test_table_exists_already"
+
+    spark.sql("CREATE DATABASE IF NOT EXISTS db")
+    val columns = Array(
+      StructField("long_field", LongType),
+      StructField("int_field", IntType),
+      StructField("string_field", StringType),
+      StructField("ds", StringType)
+    )
+
+    // Create the table beforehand
+    spark.sql(s"CREATE TABLE IF NOT EXISTS $tableName (long_field LONG, int_field INT, string_field STRING, ds STRING)")
+
+    val df1 = makeDf(
+      spark,
+      StructType(
+        tableName,
+        columns
+      ),
+      List(
+        Row(1L, 2, "3", "2022-10-01")
+      )
+    )
+    val df2 = makeDf(
+      spark,
+      StructType(
+        tableName,
+        columns
+      ),
+      List(
+        Row(1L, 2, "3", "2022-10-02")
+      )
+    )
+
+    // check if insertion still works
+    testInsertPartitions(tableName, df1, df2, ds1 = "2022-10-01", ds2 = "2022-10-02")
+  }
+
+  @Test
+  def testCreateTableAlreadyExists(): Unit = {
+    val tableName = "db.test_create_table_already_exists"
+    spark.sql("CREATE DATABASE IF NOT EXISTS db")
+
+    val columns = Array(
+      StructField("long_field", LongType),
+      StructField("int_field", IntType),
+      StructField("string_field", StringType)
+    )
+
+    spark.sql(
+      "CREATE TABLE IF NOT EXISTS db.test_create_table_already_exists (long_field LONG, int_field INT, string_field STRING)")
+
+    try {
+      val df = makeDf(
+        spark,
+        StructType(
+          tableName,
+          columns
+        ),
+        List(
+          Row(1L, 2, "3")
+        )
+      )
+      assertTrue(tableUtils.createTable(df, tableName))
+      assertTrue(spark.catalog.tableExists(tableName))
+    } finally {
+      spark.sql(s"DROP TABLE IF EXISTS $tableName")
+    }
+  }
+
+  @Test
+  def testCreateTable(): Unit = {
+    val tableName = "db.test_create_table"
+    spark.sql("CREATE DATABASE IF NOT EXISTS db")
+    try {
+      val columns = Array(
+        StructField("long_field", LongType),
+        StructField("int_field", IntType),
+        StructField("string_field", StringType)
+      )
+      val df = makeDf(
+        spark,
+        StructType(
+          tableName,
+          columns
+        ),
+        List(
+          Row(1L, 2, "3")
+        )
+      )
+      assertTrue(tableUtils.createTable(df, tableName))
+      assertTrue(spark.catalog.tableExists(tableName))
+    } finally {
+      spark.sql(s"DROP TABLE IF EXISTS $tableName")
+    }
+  }
+
+  @Test
+  def testCreateTableBigQuery(): Unit = {
+    val tableName = "db.test_create_table_bigquery"
+    spark.sql("CREATE DATABASE IF NOT EXISTS db")
+    try {
+      val columns = Array(
+        StructField("long_field", LongType),
+        StructField("int_field", IntType),
+        StructField("string_field", StringType)
+      )
+      val df = makeDf(
+        spark,
+        StructType(
+          tableName,
+          columns
+        ),
+        List(Row(1L, 2, "3"))
+      )
+      assertFalse(tableUtils.createTable(df, tableName, writeFormatTypeString = "BIGQUERY"))
+      assertFalse(spark.catalog.tableExists(tableName))
+    } finally {
+      spark.sql(s"DROP TABLE IF EXISTS $tableName")
+    }
+  }
+
 }
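
Reviewer note, not part of the patch: a minimal usage sketch of the new createTable entry point, reusing the
makeDf helper and Chronon api types (StructType, StructField, LongType, StringType) from TableUtilsTest. The
SparkSession `spark`, the `tableUtils` instance, and the table name db.example_table are assumed/illustrative.

    // Sketch only: build a small DataFrame the same way the tests do.
    val df = makeDf(
      spark,
      StructType("db.example_table", Array(StructField("id", LongType), StructField("ds", StringType))),
      List(Row(1L, "2022-10-01"))
    )

    // Hive-style formats: createTable issues CREATE TABLE (tolerating a racing
    // TableAlreadyExistsException) and returns true once the table exists.
    assert(tableUtils.createTable(df, "db.example_table"))

    // BigQuery: DDL is skipped because the connector creates the table implicitly at
    // write time, so the call only reports whether the table already exists (false here).
    assert(!tableUtils.createTable(df, "db.example_table", writeFormatTypeString = "BIGQUERY"))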