@@ -66,8 +66,6 @@ class BigQueryCatalogTest extends AnyFlatSpec with MockitoSugar {
val nativeTable = "data.sample_native"
val table = tableUtils.loadTable(nativeTable)
table.show
val partitioned = tableUtils.isPartitioned(nativeTable)
println(partitioned)
// val database = tableUtils.createDatabase("test_database")
val allParts = tableUtils.allPartitions(nativeTable)
println(allParts)
@@ -80,8 +78,6 @@ class BigQueryCatalogTest extends AnyFlatSpec with MockitoSugar {
println(bs)
val table = tableUtils.loadTable(externalTable)
table.show
val partitioned = tableUtils.isPartitioned(externalTable)
println(partitioned)
// val database = tableUtils.createDatabase("test_database")
val allParts = tableUtils.allPartitions(externalTable)
println(allParts)
4 changes: 0 additions & 4 deletions spark/src/main/scala/ai/chronon/spark/Extensions.scala
@@ -162,10 +162,6 @@ object Extensions {
sortByCols = sortByCols)
}

def saveUnPartitioned(tableName: String, tableProperties: Map[String, String] = null): Unit = {
TableUtils(df.sparkSession).insertUnPartitioned(df, tableName, tableProperties)
}

def prefixColumnNames(prefix: String, columns: Seq[String]): DataFrame = {
columns.foldLeft(df) { (renamedDf, key) =>
renamedDf.withColumnRenamed(key, s"${prefix}_$key")
2 changes: 1 addition & 1 deletion spark/src/main/scala/ai/chronon/spark/GroupByUpload.scala
@@ -258,7 +258,7 @@ object GroupByUpload {
kvDf
.union(metaDf)
.withColumn("ds", lit(endDs))
.saveUnPartitioned(groupByConf.metaData.uploadTable, groupByConf.metaData.tableProps)
.save(groupByConf.metaData.uploadTable, groupByConf.metaData.tableProps, partitionColumns = List.empty)

val kvDfReloaded = tableUtils
.loadTable(groupByConf.metaData.uploadTable)
2 changes: 1 addition & 1 deletion spark/src/main/scala/ai/chronon/spark/StagingQuery.scala
@@ -51,7 +51,7 @@ class StagingQuery(stagingQueryConf: api.StagingQuery, endPartition: String, tab
}
// the input table is not partitioned, usually for data testing or for kaggle demos
if (stagingQueryConf.startPartition == null) {
tableUtils.sql(stagingQueryConf.query).saveUnPartitioned(outputTable)
tableUtils.sql(stagingQueryConf.query).save(outputTable, partitionColumns = List.empty)
} else {
val overrideStart = overrideStartPartition.getOrElse(stagingQueryConf.startPartition)
val unfilledRanges =
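The caller-side change repeated across GroupByUpload, StagingQuery, CompareJob, and PartitionRunner is the same: the dedicated saveUnPartitioned helper goes away and writes route through the generic save with an explicitly empty partition-column list. A minimal sketch of the new call shape (the DataFrame df, table name, and props are placeholders, and Chronon's Extensions implicits are assumed to be in scope):

import ai.chronon.spark.Extensions._

// before: df.saveUnPartitioned("ns.upload_table", props)
// after: same write path as partitioned tables, just with no partition columns
df.save("ns.upload_table", props, partitionColumns = List.empty)

Folding the unpartitioned case into save keeps a single write path and leaves the partitioning decision to the caller.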
42 changes: 9 additions & 33 deletions spark/src/main/scala/ai/chronon/spark/TableUtils.scala
@@ -30,7 +30,7 @@ import ai.chronon.spark.TableUtils.{
TableCreationStatus
}
import ai.chronon.spark.format.CreationUtils.alterTablePropertiesSql
import ai.chronon.spark.format.{DefaultFormatProvider, Format, FormatProvider}
import ai.chronon.spark.format.{DefaultFormatProvider, FormatProvider}
import org.apache.hadoop.hive.metastore.api.AlreadyExistsException
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project}
@@ -111,7 +111,6 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
private val aggregationParallelism: Int = sparkSession.conf.get("spark.chronon.group_by.parallelism", "1000").toInt

sparkSession.sparkContext.setLogLevel("ERROR")
// converts String-s like "a=b/c=d" to Map("a" -> "b", "c" -> "d")

def preAggRepartition(df: DataFrame): DataFrame =
if (df.rdd.getNumPartitions < aggregationParallelism) {
@@ -122,7 +121,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable

def tableReachable(tableName: String): Boolean = {
try {
tableReadFormat(tableName).isDefined
tableFormatProvider.readFormat(tableName).isDefined
} catch {
case ex: Exception =>
logger.info(s"""Couldn't reach $tableName. Error: ${ex.getMessage.red}
@@ -137,12 +136,6 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
sparkSession.read.load(DataPointer.from(tableName, sparkSession))
}

def isPartitioned(tableName: String): Boolean = {
// TODO: use proper way to detect if a table is partitioned or not
val schema = getSchemaFromTable(tableName)
schema.fieldNames.contains(partitionColumn)
}

// Needs provider
def createDatabase(database: String): Boolean = {
try {
@@ -159,17 +152,16 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
}
}

def tableReadFormat(tableName: String): Option[Format] = tableFormatProvider.readFormat(tableName)

// Needs provider
// return all specified partition columns in a table in format of Map[partitionName, PartitionValue]
def allPartitions(tableName: String, partitionColumnsFilter: Seq[String] = Seq.empty): Seq[Map[String, String]] = {

if (!tableReachable(tableName)) return Seq.empty[Map[String, String]]

val format = tableReadFormat(tableName).getOrElse(
throw new IllegalStateException(
s"Could not determine read format of table ${tableName}. It is no longer reachable."))
val format = tableFormatProvider
.readFormat(tableName)
.getOrElse(
throw new IllegalStateException(
s"Could not determine read format of table ${tableName}. It is no longer reachable."))
val partitionSeq = format.partitions(tableName)(sparkSession)

if (partitionColumnsFilter.isEmpty) {
@@ -189,7 +181,8 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
subPartitionsFilter: Map[String, String] = Map.empty,
partitionColumnName: String = partitionColumn): Seq[String] = {

tableReadFormat(tableName)
tableFormatProvider
.readFormat(tableName)
.map((format) => {
val partitions = format.primaryPartitions(tableName, partitionColumnName, subPartitionsFilter)(sparkSession)

@@ -385,23 +378,6 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
}
}

// Needs provider
def insertUnPartitioned(df: DataFrame,
tableName: String,
tableProperties: Map[String, String] = null,
saveMode: SaveMode = SaveMode.Overwrite,
fileFormat: String = "PARQUET"): Unit = {

val creationStatus = createTable(df, tableName, Seq.empty[String], tableProperties, fileFormat)

creationStatus match {
case TableUtils.TableCreatedWithoutInitialData | TableUtils.TableAlreadyExists =>
repartitionAndWrite(df, tableName, saveMode, None, partitionColumns = Seq.empty)
case TableUtils.TableCreatedWithInitialData =>
}

}

def columnSizeEstimator(dataType: DataType): Long = {
dataType match {
// TODO: improve upon this very basic estimate approach
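With tableReadFormat, isPartitioned, and insertUnPartitioned removed from TableUtils, format lookups go through the format provider directly. A rough sketch of the new lookup, assuming a TableUtils built from an active SparkSession (the session setup and table name are placeholders):

import ai.chronon.spark.TableUtils
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate() // placeholder session
val tableUtils = TableUtils(spark)

// An empty Option means the table could not be resolved; tableReachable relies on the same check.
val maybeFormat = tableUtils.tableFormatProvider.readFormat("ns.some_table")
println(maybeFormat.map(_.toString).getOrElse("table not reachable"))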
4 changes: 2 additions & 2 deletions spark/src/main/scala/ai/chronon/spark/stats/CompareJob.scala
@@ -79,13 +79,13 @@ class CompareJob(
logger.info("Saving comparison output..")
logger.info(
s"Comparison schema ${compareDf.schema.fields.map(sb => (sb.name, sb.dataType)).toMap.mkString("\n - ")}")
compareDf.saveUnPartitioned(comparisonTableName, tableProps)
compareDf.save(comparisonTableName, tableProps, partitionColumns = List.empty)

// Save the metrics table
logger.info("Saving metrics output..")
val metricsDf = metricsTimedKvRdd.toFlatDf
logger.info(s"Metrics schema ${metricsDf.schema.fields.map(sb => (sb.name, sb.dataType)).toMap.mkString("\n - ")}")
metricsDf.saveUnPartitioned(metricsTableName, tableProps)
metricsDf.save(metricsTableName, tableProps, partitionColumns = List.empty)

logger.info("Printing basic comparison results..")
logger.info("(Note: This is just an estimation and not a detailed analysis of results)")
@@ -119,7 +119,7 @@ class PartitionRunner[T](verb: String,
if (outputDf.columns.contains(tu.partitionColumn)) {
outputDf.save(outputTable)
} else {
outputDf.saveUnPartitioned(outputTable)
outputDf.save(outputTable, partitionColumns = List.empty)
}
println(s"""
|Finished computing range ${i + 1}/$n
@@ -147,7 +147,7 @@ class TableUtilsFormatTest extends AnyFlatSpec {
it should "return empty read format if table doesn't exist" in {
val dbName = s"db_${System.currentTimeMillis()}"
val tableName = s"$dbName.test_table_nonexistent_$format"
assertTrue(tableUtils.tableReadFormat(tableName).isEmpty)
assertTrue(tableUtils.tableFormatProvider.readFormat(tableName).isEmpty)
assertFalse(tableUtils.tableReachable(tableName))
}
}
@@ -188,7 +188,7 @@ object TableUtilsFormatTest {
tableUtils.insertPartitions(df2, tableName, autoExpand = true)

// check that we wrote out a table in the right format
val readTableFormat = tableUtils.tableReadFormat(tableName).get.toString
val readTableFormat = tableUtils.tableFormatProvider.readFormat(tableName).get.toString
assertTrue(s"Mismatch in table format: $readTableFormat; expected: $format", readTableFormat.toLowerCase == format)

// check we have all the partitions written