@@ -1,8 +1,13 @@
package ai.chronon.integrations.cloud_gcp

import ai.chronon.spark.format.Format
import com.google.cloud.bigquery.BigQuery
import com.google.cloud.bigquery.connector.common.BigQueryUtil
import com.google.cloud.spark.bigquery.SchemaConverters
Collaborator (author) comment:
had to pull in all the shaded stuff from the bigquery connector in order to leverage some of the utils.
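
For context, a minimal sketch (illustrative names only) of why the shaded and unshaded types cannot be mixed: the repackaged classes live under a different package, so the compiler treats them as unrelated to their unshaded counterparts, and a client from one namespace rejects argument types from the other.

```scala
// Illustrative only: the connector's shaded TableId is a distinct class from
// the unshaded google-cloud-bigquery TableId, so the two never unify.
import com.google.cloud.spark.bigquery.repackaged.com.google.cloud.bigquery.{TableId => ShadedTableId}
import com.google.cloud.bigquery.{TableId => UnshadedTableId}

object ShadingSketch extends App {
  val shaded: ShadedTableId = ShadedTableId.of("some-project", "some_dataset", "some_table")
  // val unshaded: UnshadedTableId = shaded // does not compile: unrelated types
  println(shaded.getDataset) // equivalent getters exist on both classes
}
```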

import com.google.cloud.spark.bigquery.SchemaConvertersConfiguration
import com.google.cloud.spark.bigquery.repackaged.com.google.cloud.bigquery.BigQuery
import com.google.cloud.spark.bigquery.repackaged.com.google.cloud.bigquery.StandardTableDefinition
import com.google.cloud.spark.bigquery.repackaged.com.google.cloud.bigquery.TableInfo
import com.google.cloud.spark.bigquery.repackaged.com.google.cloud.bigquery.TimePartitioning
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
@@ -11,19 +16,47 @@ import org.apache.spark.sql.functions.to_date

case class BigQueryFormat(project: String, bqClient: BigQuery, override val options: Map[String, String])
extends Format {

override def name: String = "bigquery"

override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(
implicit sparkSession: SparkSession): Seq[String] =
super.primaryPartitions(tableName, partitionColumn, subPartitionsFilter)

override def createTable(df: DataFrame,
tableName: String,
partitionColumns: Seq[String],
tableProperties: Map[String, String],
fileFormat: String): (String => Unit) => Unit = {
throw new UnsupportedOperationException("createTable not yet supported for BigQuery")

def inner(df: DataFrame, tableName: String, partitionColumns: Seq[String])(sqlEvaluator: String => Unit) = {

// See: https://cloud.google.com/bigquery/docs/partitioned-tables#limitations
// "BigQuery does not support partitioning by multiple columns. Only one column can be used to partition a table."
assert(partitionColumns.size < 2,
s"BigQuery only supports at most one partition column, incoming spec: ${partitionColumns}")
val shadedTableId = BigQueryUtil.parseTableId(tableName)

val shadedBqSchema =
SchemaConverters.from(SchemaConvertersConfiguration.createDefault()).toBigQuerySchema(df.schema)

val baseTableDef = StandardTableDefinition.newBuilder
.setSchema(shadedBqSchema)

val tableDefinition = partitionColumns.headOption
.map((col) => {
val timePartitioning = TimePartitioning.newBuilder(TimePartitioning.Type.DAY).setField(col)
baseTableDef
.setTimePartitioning(timePartitioning.build())
})
.getOrElse(baseTableDef)

val tableInfoBuilder = TableInfo.newBuilder(shadedTableId, tableDefinition.build)

val tableInfo = tableInfoBuilder.build

bqClient.create(tableInfo)
}

inner(df, tableName, partitionColumns)
}

override def partitions(tableName: String)(implicit sparkSession: SparkSession): Seq[Map[String, String]] = {
@@ -32,7 +65,6 @@ case class BigQueryFormat(project: String, bqClient: BigQuery, override val opti
val table = tableIdentifier.getTable
val database =
Option(tableIdentifier.getDataset).getOrElse(throw new IllegalArgumentException("database required!"))

try {

// See: https://cloud.google.com/bigquery/docs/information-schema-columns
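The new `createTable` returns a curried `(String => Unit) => Unit`: the caller hands in a SQL evaluator, but the BigQuery implementation creates the table through the client directly and never invokes it. A hypothetical caller-side sketch (names such as `format`, `spark`, and the table coordinates are assumptions, not code from this PR):

```scala
// Sketch: invoking the curried createTable API. Assumes `format` is a
// BigQueryFormat and `df` a DataFrame with a "ds" partition column.
val createOp: (String => Unit) => Unit =
  format.createTable(df, "my-project.my_dataset.events", Seq("ds"), Map.empty, "PARQUET")

// The closure receives the SQL evaluator; for BigQuery it is ignored because
// the table is created via the client, not via a SQL statement.
createOp(sqlText => spark.sql(sqlText))
```

This matches the partitioning logic above: with `Seq("ds")`, the created table is DAY time-partitioned on `ds`; with more than one partition column the assert fires.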
@@ -4,15 +4,8 @@ import ai.chronon.spark.TableUtils
import ai.chronon.spark.format.Format
import ai.chronon.spark.format.FormatProvider
import ai.chronon.spark.format.Hive
import com.google.cloud.bigquery.BigQuery
import com.google.cloud.bigquery.BigQueryOptions
import com.google.cloud.bigquery.ExternalTableDefinition
import com.google.cloud.bigquery.FormatOptions
import com.google.cloud.bigquery.StandardTableDefinition
import com.google.cloud.bigquery.Table
import com.google.cloud.bigquery.TableDefinition
import com.google.cloud.bigquery.connector.common.BigQueryUtil
import com.google.cloud.spark.bigquery.repackaged.com.google.cloud.bigquery.TableId
import com.google.cloud.spark.bigquery.repackaged.com.google.cloud.bigquery._
import org.apache.spark.sql.SparkSession

import scala.jdk.CollectionConverters._
@@ -65,7 +58,8 @@ case class GcpFormatProvider(sparkSession: SparkSession) extends FormatProvider
val formatOptions = definition.getFormatOptions
.asInstanceOf[FormatOptions]
val externalTable = table.getDefinition.asInstanceOf[ExternalTableDefinition]
val uri = Option(externalTable.getHivePartitioningOptions)
val uri = scala
.Option(externalTable.getHivePartitioningOptions)
.map(_.getSourceUriPrefix)
.getOrElse {
val uris = externalTable.getSourceUris.asScala
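The `scala.Option` qualification above is presumably forced by the new wildcard import: once `...repackaged.com.google.cloud.bigquery._` is in scope, any member named `Option` in that package would shadow `scala.Option`. A minimal, self-contained illustration of the shadowing mechanism (the `fake` object here is a stand-in, not the actual shaded contents):

```scala
// Hypothetical demonstration of import shadowing; `fake.Option` stands in
// for whatever the shaded wildcard import drags into scope.
object ShadowingSketch {
  object fake { class Option }
  import fake._                 // shadows scala.Option in this scope
  // val a = Option("x")        // would not resolve to scala.Option here
  val b = scala.Option("x")     // fully qualified, unambiguous
}
```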
@@ -6,7 +6,6 @@ import com.google.cloud.hadoop.fs.gcs.GoogleHadoopFS
import com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystem
import com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystemConfiguration
import com.google.cloud.hadoop.fs.gcs.HadoopConfigurationProperty
import com.google.cloud.hadoop.gcsio.GoogleCloudStorageFileSystem
import org.apache.spark.sql.SparkSession
import org.junit.Assert.assertEquals
import org.junit.Assert.assertTrue
@@ -1,7 +1,7 @@
package ai.chronon.integrations.cloud_gcp

import ai.chronon.spark.SparkSessionBuilder
import com.google.cloud.bigquery._
import com.google.cloud.spark.bigquery.repackaged.com.google.cloud.bigquery._
import org.apache.spark.sql.SparkSession
import org.mockito.Mockito.when
import org.scalatest.flatspec.AnyFlatSpec
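Since the test file now mocks the shaded client, here is a hypothetical end-to-end sketch of how such a test could drive the new `createTable` path (class names come from the PR; the assembly, project IDs, and assertion are assumptions):

```scala
import com.google.cloud.spark.bigquery.repackaged.com.google.cloud.bigquery.{BigQuery, TableInfo}
import org.apache.spark.sql.SparkSession
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito.{mock, verify}

object CreateTableTestSketch extends App {
  val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
  import spark.implicits._

  val bq     = mock(classOf[BigQuery]) // the mock must be the shaded interface
  val format = BigQueryFormat("my-project", bq, Map.empty)
  val df     = Seq(("a", "2024-01-01")).toDF("key", "ds")

  // The evaluator is a no-op: BigQueryFormat creates the table via the client.
  format.createTable(df, "my-project.my_dataset.tbl", Seq("ds"), Map.empty, "PARQUET")(_ => ())

  verify(bq).create(any[TableInfo]) // exactly one client-side create call
  spark.stop()
}
```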
13 changes: 5 additions & 8 deletions spark/src/main/scala/ai/chronon/spark/TableUtils.scala
@@ -285,17 +285,14 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
fileFormat: String,
autoExpand: Boolean = false): Unit = {

val doesTableExist = tableReachable(tableName)

val writeFormat = tableFormatProvider.writeFormat(tableName)

val createTableOperation =
writeFormat.createTable(df, tableName, partitionColumns, tableProperties, fileFormat)

if (!doesTableExist) {
if (!tableReachable(tableName)) {

try {

val writeFormat = tableFormatProvider.writeFormat(tableName)
val createTableOperation =
writeFormat.createTable(df, tableName, partitionColumns, tableProperties, fileFormat)

createTableOperation(sql)

} catch {
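One note on the design choice in this hunk: moving the `writeFormat` lookup and `createTableOperation` construction inside the `if` means the format provider is consulted only when the table is actually missing. A `lazy val` at the original position would achieve the same deferral; shown purely as a comparison (not what the PR does), using the identifiers from the surrounding method:

```scala
// Alternative sketch: lazy vals keep the declarations up front but still
// defer the provider lookup until the not-reachable branch first uses them.
lazy val writeFormat = tableFormatProvider.writeFormat(tableName)
lazy val createTableOperation =
  writeFormat.createTable(df, tableName, partitionColumns, tableProperties, fileFormat)
```

The PR's version is arguably clearer, since the values are scoped to the only branch that uses them.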