spark/src/main/scala/ai/chronon/spark/TableUtils.scala (61 additions, 30 deletions)
@@ -88,7 +88,9 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
  private val blockingCacheEviction: Boolean =
    sparkSession.conf.get("spark.chronon.table_write.cache.blocking", "false").toBoolean

-  private[spark] lazy val tableFormatProvider: FormatProvider = {
+  // Mark this @transient because the format provider is not always serializable,
+  // e.g. BigQueryImpl when instantiated reflectively for the BigQuery flavor.
+  @transient private[spark] lazy val tableFormatProvider: FormatProvider = {
    val clazzName =
      sparkSession.conf.get("spark.chronon.table.format_provider.class", classOf[DefaultFormatProvider].getName)
    val mirror = runtimeMirror(getClass.getClassLoader)
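For context on the @transient change above: Spark serializes the enclosing object when it is captured by tasks, so a non-serializable member such as a reflectively constructed format provider would fail that serialization. A @transient lazy val is skipped on write and rebuilt on first access after deserialization. A minimal, self-contained sketch of the pattern (illustrative names, not Chronon's code):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Deliberately not Serializable, standing in for something like BigQueryImpl.
class OpaqueProvider {
  def name: String = "provider"
}

class Holder extends Serializable {
  // Without @transient, writeObject(new Holder) below would throw
  // NotSerializableException; with @transient lazy val, the field is skipped
  // on write and recomputed on first access after readObject.
  @transient lazy val provider: OpaqueProvider = new OpaqueProvider
}

object TransientLazyValSketch {
  def main(args: Array[String]): Unit = {
    val buffer = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(buffer)
    oos.writeObject(new Holder) // succeeds: the non-serializable provider is skipped
    oos.close()

    val restored = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
      .readObject()
      .asInstanceOf[Holder]
    println(restored.provider.name) // the lazy val is rebuilt here
  }
}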
@@ -261,6 +263,48 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
  def firstAvailablePartition(tableName: String, subPartitionFilters: Map[String, String] = Map.empty): Option[String] =
    partitions(tableName, subPartitionFilters).reduceOption((x, y) => Ordering[String].min(x, y))

+  def createTable(df: DataFrame,
+                  tableName: String,
+                  partitionColumns: Seq[String] = Seq.empty,
+                  writeFormatTypeString: String = "",
+                  tableProperties: Map[String, String] = null,
+                  fileFormat: String = "PARQUET",
+                  autoExpand: Boolean = false): Boolean = {
+    val doesTableExist = tableExists(tableName)
+
+    // CREATE TABLE SQL doesn't work for BigQuery here. Instead of creating the table explicitly,
+    // we rely on the BigQuery connector to create the table implicitly when the data is eventually written.
+    if (writeFormatTypeString.toUpperCase == "BIGQUERY") {
+      logger.info(s"Skipping table creation in BigQuery for $tableName. tableExists=$doesTableExist")
+
+      return doesTableExist
+    }
+
+    if (!doesTableExist) {
+      val creationSql = createTableSql(tableName, df.schema, partitionColumns, tableProperties, fileFormat)
+      try {
+        sql(creationSql)
+      } catch {
+        case _: TableAlreadyExistsException =>
+          logger.info(s"Table $tableName already exists, skipping creation")
+        case e: Exception =>
+          logger.error(s"Failed to create table $tableName", e)
+          throw e
+      }
+    }
+
+    // TODO: we also need to allow BigQuery tables to have their table properties (or tags) persisted.
+    // https://app.asana.com/0/1208949807589885/1209111629687568/f
+    if (tableProperties != null && tableProperties.nonEmpty) {
+      sql(alterTablePropertiesSql(tableName, tableProperties))
+    }

[Review comment, Collaborator] I think we actually want to persist some of these table properties though. The behavior here isn't quite equivalent for BigQuery in that any custom properties we pass through here ultimately don't make it to the BigQuery table. That might be problematic down the line.

[Review comment, Collaborator] @david-zlai mind just adding a TODO here so we can take care of it for BigQuery?
+    if (autoExpand) {
+      expandTable(tableName, df.schema)
+    }
+
+    true
+  }
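A hedged usage sketch of the new return value (the tableUtils handle and table names are illustrative, not from the PR): for non-BigQuery formats the method returns true once the table exists, whether it was created by this call or beforehand; for BigQuery it creates nothing and only reports prior existence, since the connector materializes the table at write time.

// Sketch only; assumes a TableUtils instance `tableUtils` and an input DataFrame `df`.
val hiveCreated = tableUtils.createTable(df, "db.events", partitionColumns = Seq("ds"))
// hiveCreated == true: the table now exists (created here or previously).

val bqExists = tableUtils.createTable(df, "db.events_bq", writeFormatTypeString = "BIGQUERY")
// bqExists is just tableExists("db.events_bq"): nothing was created, and the
// table will only appear once data is written through the BigQuery connector.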

  // Needs provider
  def insertPartitions(df: DataFrame,
                       tableName: String,
@@ -279,30 +323,18 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
      df
    }

-    if (!tableExists(tableName)) {
-      val creationSql = createTableSql(tableName, dfRearranged.schema, partitionColumns, tableProperties, fileFormat)
-      try {
-        sql(creationSql)
-      } catch {
-        case _: TableAlreadyExistsException =>
-          logger.info(s"Table $tableName already exists, skipping creation")
-        case e: Exception =>
-          logger.error(s"Failed to create table $tableName", e)
-          throw e
-      }
-    }
-    if (tableProperties != null && tableProperties.nonEmpty) {
-      sql(alterTablePropertiesSql(tableName, tableProperties))
-    }
-
-    if (autoExpand) {
-      expandTable(tableName, dfRearranged.schema)
-    }
+    val isTableCreated = createTable(dfRearranged,
+                                     tableName,
+                                     partitionColumns,
+                                     tableFormatProvider.writeFormat(tableName).createTableTypeString,
+                                     tableProperties,
+                                     fileFormat,
+                                     autoExpand)

-    val finalizedDf = if (autoExpand) {
+    val finalizedDf = if (autoExpand && isTableCreated) {
[Review comment, Collaborator @tchow-zlai, Jan 8, 2025] @david-zlai will this be correct if the table already exists and therefore does not need to be created? I think looking at the code yes, but could you add a unit test to verify that?

[Review comment, Contributor Author] added a test
      // reselect the columns so that any deprecated columns will be selected as NULL before write
-      val updatedSchema = getSchemaFromTable(tableName)
-      val finalColumns = updatedSchema.fieldNames.map(fieldName => {
+      val tableSchema = getSchemaFromTable(tableName)
+      val finalColumns = tableSchema.fieldNames.map(fieldName => {
        if (dfRearranged.schema.fieldNames.contains(fieldName)) {
          col(fieldName)
        } else {
@@ -362,13 +394,12 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
                         saveMode: SaveMode = SaveMode.Overwrite,
                         fileFormat: String = "PARQUET"): Unit = {

-    if (!tableExists(tableName)) {
-      sql(createTableSql(tableName, df.schema, Seq.empty[String], tableProperties, fileFormat))
-    } else {
-      if (tableProperties != null && tableProperties.nonEmpty) {
-        sql(alterTablePropertiesSql(tableName, tableProperties))
-      }
-    }
+    createTable(df,
+                tableName,
+                Seq.empty[String],
+                tableFormatProvider.writeFormat(tableName).createTableTypeString,
+                tableProperties,
+                fileFormat)

    repartitionAndWrite(df, tableName, saveMode, None)
  }
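On the autoExpand path above, insertPartitions reselects against the table's (possibly expanded) schema so that columns missing from the incoming DataFrame are written as NULLs. The else branch is elided in this diff, but the pattern with plain Spark APIs looks roughly like the following (an illustrative helper, not the actual implementation):

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, lit}

// Align `incoming` to the table's column list; columns the DataFrame lacks
// (e.g. deprecated or newly expanded ones) are filled with NULL. In practice
// the NULL would be cast to the table column's type before writing.
def alignToTableSchema(incoming: DataFrame, tableFieldNames: Seq[String]): DataFrame =
  incoming.select(
    tableFieldNames.map { name =>
      if (incoming.columns.contains(name)) col(name)
      else lit(null).as(name)
    }: _*
  )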
spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala (125 additions, 0 deletions)
@@ -32,6 +32,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types
import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse
import org.junit.Assert.assertTrue
import org.junit.Test

@@ -431,4 +432,128 @@ class TableUtilsTest {
    tableUtils.sql("CREATE TEMPORARY FUNCTION test AS 'ai.chronon.spark.test.SimpleAddUDF'")
  }

  @Test
  def testInsertPartitionsTableExistsAlready(): Unit = {
    val tableName = "db.test_table_exists_already"

    spark.sql("CREATE DATABASE IF NOT EXISTS db")
    val columns = Array(
      StructField("long_field", LongType),
      StructField("int_field", IntType),
      StructField("string_field", StringType),
      StructField("ds", StringType)
    )

    // Create the table beforehand
    spark.sql(s"CREATE TABLE IF NOT EXISTS $tableName (long_field LONG, int_field INT, string_field STRING, ds STRING)")

    val df1 = makeDf(
      spark,
      StructType(
        tableName,
        columns
      ),
      List(
        Row(1L, 2, "3", "2022-10-01")
      )
    )
    val df2 = makeDf(
      spark,
      StructType(
        tableName,
        columns
      ),
      List(
        Row(1L, 2, "3", "2022-10-02")
      )
    )

    // check if insertion still works
    testInsertPartitions(tableName, df1, df2, ds1 = "2022-10-01", ds2 = "2022-10-02")
  }

  @Test
  def testCreateTableAlreadyExists(): Unit = {
    val tableName = "db.test_create_table_already_exists"
    spark.sql("CREATE DATABASE IF NOT EXISTS db")

    val columns = Array(
      StructField("long_field", LongType),
      StructField("int_field", IntType),
      StructField("string_field", StringType)
    )

    spark.sql(
      "CREATE TABLE IF NOT EXISTS db.test_create_table_already_exists (long_field LONG, int_field INT, string_field STRING)")

    try {
      val df = makeDf(
        spark,
        StructType(
          tableName,
          columns
        ),
        List(
          Row(1L, 2, "3")
        )
      )
      assertTrue(tableUtils.createTable(df, tableName))
      assertTrue(spark.catalog.tableExists(tableName))
    } finally {
      spark.sql(s"DROP TABLE IF EXISTS $tableName")
    }
  }

  @Test
  def testCreateTable(): Unit = {
    val tableName = "db.test_create_table"
    spark.sql("CREATE DATABASE IF NOT EXISTS db")
    try {
      val columns = Array(
        StructField("long_field", LongType),
        StructField("int_field", IntType),
        StructField("string_field", StringType)
      )
      val df = makeDf(
        spark,
        StructType(
          tableName,
          columns
        ),
        List(
          Row(1L, 2, "3")
        )
      )
      assertTrue(tableUtils.createTable(df, tableName))
      assertTrue(spark.catalog.tableExists(tableName))
    } finally {
      spark.sql(s"DROP TABLE IF EXISTS $tableName")
    }
  }

  @Test
  def testCreateTableBigQuery(): Unit = {
    val tableName = "db.test_create_table_bigquery"
    spark.sql("CREATE DATABASE IF NOT EXISTS db")
    try {
      val columns = Array(
        StructField("long_field", LongType),
        StructField("int_field", IntType),
        StructField("string_field", StringType)
      )
      val df = makeDf(
        spark,
        StructType(
          tableName,
          columns
        ),
        List(Row(1L, 2, "3"))
      )
      assertFalse(tableUtils.createTable(df, tableName, writeFormatTypeString = "BIGQUERY"))
      assertFalse(spark.catalog.tableExists(tableName))
    } finally {
      spark.sql(s"DROP TABLE IF EXISTS $tableName")
    }
  }

}