69 changes: 43 additions & 26 deletions spark/src/main/scala/ai/chronon/spark/TableUtils.scala
@@ -88,7 +88,9 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
   private val blockingCacheEviction: Boolean =
     sparkSession.conf.get("spark.chronon.table_write.cache.blocking", "false").toBoolean

-  private[spark] lazy val tableFormatProvider: FormatProvider = {
+  // Add transient here because it can contain BigQueryImpl during reflection with bq flavor
> Contributor: maybe swap the 'BigqueryImpl' with the generic phrase - concrete format provider which might not be serializable? (As this isn't restricted to BQ)
>
> Author: removing
+  // and that can't be serialized by Spark
+  @transient private[spark] lazy val tableFormatProvider: FormatProvider = {
     val clazzName =
       sparkSession.conf.get("spark.chronon.table.format_provider.class", classOf[DefaultFormatProvider].getName)
     val mirror = runtimeMirror(getClass.getClassLoader)
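For context on the @transient change: once a lazy val on a Serializable class has been materialized, its backing field is serialized along with the enclosing object when Spark ships it to executors, so a provider holding a non-serializable client (such as a BigQuery client) fails task serialization. A minimal sketch of the failure mode and the fix, with hypothetical names, not code from this PR:

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical stand-in for a format provider wrapping a client that does
// not implement java.io.Serializable (the BigQueryImpl case mentioned above).
class NonSerializableProvider

class Utils(@transient val spark: SparkSession) extends Serializable {
  // Without @transient: after first access, the initialized field is serialized
  // with the object and Spark throws java.io.NotSerializableException.
  // With @transient: the field is skipped during serialization and the lazy val
  // is simply re-initialized on first access after deserialization.
  @transient lazy val provider: NonSerializableProvider = new NonSerializableProvider
}
```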
@@ -261,6 +263,35 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
   def firstAvailablePartition(tableName: String, subPartitionFilters: Map[String, String] = Map.empty): Option[String] =
     partitions(tableName, subPartitionFilters).reduceOption((x, y) => Ordering[String].min(x, y))

+  def createTable(df: DataFrame,
+                  tableName: String,
+                  partitionColumns: Seq[String] = Seq.empty,
+                  writeFormatTypeString: String = "",
+                  tableProperties: Map[String, String] = null,
+                  fileFormat: String = "PARQUET",
+                  autoExpand: Boolean = false): Unit = {
+    // create table sql doesn't work for bigquery here. instead of creating the table explicitly, we can rely on the
+    // bq connector to indirectly create the table and eventually write the data
+    if (!tableExists(tableName) && writeFormatTypeString.toUpperCase != "BIGQUERY") {
+      val creationSql = createTableSql(tableName, df.schema, partitionColumns, tableProperties, fileFormat)
+      try {
+        sql(creationSql)
+      } catch {
+        case _: TableAlreadyExistsException =>
+          logger.info(s"Table $tableName already exists, skipping creation")
+        case e: Exception =>
+          logger.error(s"Failed to create table $tableName", e)
+          throw e
+      }
+    }
+    if (tableProperties != null && tableProperties.nonEmpty) {
+      sql(alterTablePropertiesSql(tableName, tableProperties))
> Collaborator: I think we actually want to persist some of these table properties though. The behavior here isn't quite equivalent for bigquery in that any custom properties we pass through here ultimately don't make it to the bigquery table. That might be problematic down the line.
>
> Collaborator: @david-zlai mind just adding a TODO here so we can take care of it for BigQuery?
+    }
+    if (autoExpand) {
+      expandTable(tableName, df.schema)
+    }
+  }
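For reference, the indirect path the in-code comment relies on: with the spark-bigquery-connector, the table is created as a side effect of writing the DataFrame (CREATE_IF_NEEDED semantics), so no explicit CREATE TABLE is issued for the BIGQUERY format. A hedged sketch, assuming the open-source connector's options; the table id is hypothetical:

```scala
import org.apache.spark.sql.SaveMode

// Sketch only: the connector creates the target table if it does not exist
// and then writes the rows, which is why createTable can skip explicit DDL.
df.write
  .format("bigquery")
  .option("writeMethod", "direct") // "indirect" also works, plus a temporaryGcsBucket
  .mode(SaveMode.Append)
  .save("my_dataset.my_table")     // hypothetical dataset.table
```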

+  // Needs provider
   def insertPartitions(df: DataFrame,
                        tableName: String,
@@ -279,25 +310,15 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
       df
     }

-    if (!tableExists(tableName)) {
-      val creationSql = createTableSql(tableName, dfRearranged.schema, partitionColumns, tableProperties, fileFormat)
-      try {
-        sql(creationSql)
-      } catch {
-        case _: TableAlreadyExistsException =>
-          logger.info(s"Table $tableName already exists, skipping creation")
-        case e: Exception =>
-          logger.error(s"Failed to create table $tableName", e)
-          throw e
-      }
-    }
-    if (tableProperties != null && tableProperties.nonEmpty) {
-      sql(alterTablePropertiesSql(tableName, tableProperties))
-    }
+    val writeFormatTypeString = tableFormatProvider.writeFormat(tableName).createTableTypeString

-    if (autoExpand) {
-      expandTable(tableName, dfRearranged.schema)
-    }
+    createTable(dfRearranged,
+                tableName,
+                partitionColumns,
+                writeFormatTypeString,
+                tableProperties,
+                fileFormat,
+                autoExpand)

     val finalizedDf = if (autoExpand) {
> Collaborator: the chain of logic here doesn't match the original then - in this if branch we are trying to retrieve the schema from the table, but it's possible it doesn't yet exist, since autoExpand may be set true for a BigQuery table. getSchemaFromTable will probably throw an exception if the table doesn't exist. So we probably need to pull this df logic into createTable as well.
>
> Author: great catch
>
> Author: weird this piece of code is only in insertPartitions but not at all in insertUnPartitioned 🙃
>
> Author: oh...insertUnpartitioned doesn't even have autoExpand in it

       // reselect the columns so that any deprecated columns will be selected as NULL before write
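A minimal sketch of the guard the first comment above is asking for, assuming TableUtils' existing tableExists and getSchemaFromTable helpers; an illustration of the idea, not the PR's final fix:

```scala
import org.apache.spark.sql.functions.{col, lit}

// Only read the schema back from the catalog when the table really exists;
// for BigQuery, createTable defers creation to the connector, so at this
// point the table may not exist yet and getSchemaFromTable would throw.
val finalizedDf = if (autoExpand && tableExists(tableName)) {
  val tableSchema = getSchemaFromTable(tableName)
  // reselect so that columns present in the table but missing from the
  // DataFrame are written as typed NULLs
  val finalColumns = tableSchema.fieldNames.map { name =>
    if (dfRearranged.columns.contains(name)) col(name)
    else lit(null).cast(tableSchema(name).dataType).as(name)
  }
  dfRearranged.select(finalColumns: _*)
} else {
  dfRearranged
}
```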
@@ -362,13 +383,9 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
                           saveMode: SaveMode = SaveMode.Overwrite,
                           fileFormat: String = "PARQUET"): Unit = {

-    if (!tableExists(tableName)) {
-      sql(createTableSql(tableName, df.schema, Seq.empty[String], tableProperties, fileFormat))
-    } else {
-      if (tableProperties != null && tableProperties.nonEmpty) {
-        sql(alterTablePropertiesSql(tableName, tableProperties))
-      }
-    }
+    val writeFormatTypeString = tableFormatProvider.writeFormat(tableName).createTableTypeString
+
+    createTable(df, tableName, Seq.empty[String], writeFormatTypeString, tableProperties, fileFormat)

     repartitionAndWrite(df, tableName, saveMode, None)
   }
51 changes: 49 additions & 2 deletions spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala
@@ -31,8 +31,7 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.types
-import org.junit.Assert.assertEquals
-import org.junit.Assert.assertTrue
+import org.junit.Assert.{assertEquals, assertFalse, assertTrue}
 import org.junit.Test

 import scala.util.Try
@@ -431,4 +430,52 @@ class TableUtilsTest {
tableUtils.sql("CREATE TEMPORARY FUNCTION test AS 'ai.chronon.spark.test.SimpleAddUDF'")
}

+  @Test
+  def testCreateTable(): Unit = {
+    val tableName = "db.test_create_table"
+    spark.sql("CREATE DATABASE IF NOT EXISTS db")
+
+    val columns = Array(
+      StructField("long_field", LongType),
+      StructField("int_field", IntType),
+      StructField("string_field", StringType)
+    )
+    val df = makeDf(
+      spark,
+      StructType(
+        tableName,
+        columns
+      ),
+      List(
+        Row(1L, 2, "3")
+      )
+    )
+    tableUtils.createTable(df, tableName)
+    assertTrue(spark.catalog.tableExists(tableName))
+  }
> Contributor: 🛠️ Refactor suggestion: Add table cleanup and data verification.
>
> The test should:
>   1. Clean up the table after the test
>   2. Verify the schema and data, not just table existence
>
> Suggested change:
 @Test
 def testCreateTable(): Unit = {
   val tableName = "db.test_create_table"
   spark.sql("CREATE DATABASE IF NOT EXISTS db")
+  try {
     val columns = Array(
       StructField("long_field", LongType),
       StructField("int_field", IntType),
       StructField("string_field", StringType)
     )
     val df = makeDf(
       spark,
       StructType(
         tableName,
         columns
       ),
       List(
         Row(1L, 2, "3")
       )
     )
     tableUtils.createTable(df, tableName)
     assertTrue(spark.catalog.tableExists(tableName))
+    val createdTable = spark.table(tableName)
+    assertEquals(df.schema, createdTable.schema)
+    assertEquals(df.collect().toSeq, createdTable.collect().toSeq)
+  } finally {
+    spark.sql(s"DROP TABLE IF EXISTS $tableName")
+  }
 }


+  @Test
+  def testCreateTableBigQuery(): Unit = {
+    val tableName = "db.test_create_table_bigquery"
+    spark.sql("CREATE DATABASE IF NOT EXISTS db")
+
+    val columns = Array(
+      StructField("long_field", LongType),
+      StructField("int_field", IntType),
+      StructField("string_field", StringType)
+    )
+    val df = makeDf(
+      spark,
+      StructType(
+        tableName,
+        columns
+      ),
+      List(
+        Row(1L, 2, "3")
+      )
+    )
+    tableUtils.createTable(df, tableName, writeFormatTypeString = "BIGQUERY")
+    assertFalse(spark.catalog.tableExists(tableName))
+  }

}