Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -407,9 +407,6 @@ class FlintSpark(val spark: SparkSession) extends Logging {
case (true, false) => AUTO
case (false, false) => FULL
case (false, true) => INCREMENTAL
case (true, true) =>
throw new IllegalArgumentException(
"auto_refresh and incremental_refresh options cannot both be true")
}

// validate allowed options depending on refresh mode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package org.opensearch.flint.spark
import scala.collection.JavaConverters.mapAsJavaMapConverter

import org.opensearch.flint.spark.FlintSparkIndexOptions.empty
import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh

import org.apache.spark.sql.catalog.Column
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
Expand Down Expand Up @@ -59,7 +60,7 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) {
* ignore existing index
*/
def create(ignoreIfExists: Boolean = false): Unit = {
  // Validate the freshly built index before handing it to FlintSpark,
  // so misconfigured options fail fast instead of at refresh time.
  val index = validateIndex(buildIndex())
  flint.createIndex(index, ignoreIfExists)
}

/**
* Copy Flint index with updated options.
Expand All @@ -80,7 +81,24 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) {
val updatedMetadata = index
.metadata()
.copy(options = updatedOptions.options.mapValues(_.asInstanceOf[AnyRef]).asJava)
FlintSparkIndexFactory.create(updatedMetadata).get
validateIndex(FlintSparkIndexFactory.create(updatedMetadata).get)
}

/**
 * Pre-validate an index before it is created or updated. The default implementation delegates
 * to the refresh strategy for this index, because the index options being validated mostly
 * configure how the index is refreshed. Subclasses may override to add extra checks.
 *
 * @param index
 *   Flint index to validate
 * @return
 *   the same index if validation passed; otherwise the validation exception propagates
 */
protected def validateIndex(index: FlintSparkIndex): FlintSparkIndex = {
  val refresh = FlintSparkIndexRefresh.create(index.name(), index) // TODO: remove first argument?
  refresh.validate(flint.spark)
  index
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark

import java.io.IOException

import org.apache.hadoop.fs.Path
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.streaming.CheckpointFileManager
import org.apache.spark.sql.flint.{loadTable, parseTableName, qualifyTableName}

/**
* Flint Spark validation helper.
*/
/**
 * Flint Spark validation helper: shared checks used by index refresh strategies
 * before a refresh is started.
 */
trait FlintSparkValidationHelper extends Logging {

  /**
   * Checks whether any source table of the given Flint index uses an unsupported table provider.
   *
   * NOTE(review): despite the name, this returns true when at least one source table IS a Hive
   * table (currently the only unsupported provider) — callers negate the result in `require`.
   * Consider renaming (e.g. `hasUnsupportedTableProvider`) to match the actual semantics.
   *
   * @param spark
   *   Spark session
   * @param index
   *   Flint index
   * @return
   *   true if any source table is a Hive table, otherwise false
   */
  def isTableProviderSupported(spark: SparkSession, index: FlintSparkIndex): Boolean = {
    // Extract source table name (possibly more than one for MV query)
    val tableNames = index match {
      case skipping: FlintSparkSkippingIndex => Seq(skipping.tableName)
      case covering: FlintSparkCoveringIndex => Seq(covering.tableName)
      case mv: FlintSparkMaterializedView =>
        // An MV query may join/union several relations; collect every one referenced
        spark.sessionState.sqlParser
          .parsePlan(mv.query)
          .collect { case relation: UnresolvedRelation =>
            qualifyTableName(spark, relation.tableName)
          }
    }

    // Validate if any source table is not supported (currently Hive only)
    tableNames.exists { tableName =>
      val (catalog, ident) = parseTableName(spark, tableName)
      // NOTE(review): `.get` throws if the table cannot be loaded — presumably the table's
      // existence was verified earlier in the create path; confirm against callers
      val table = loadTable(catalog, ident).get

      // TODO: add allowed table provider list
      DDLUtils.isHiveTable(Option(table.properties().get("provider")))
    }
  }

  /**
   * Checks whether a specified checkpoint location is accessible. Accessibility, in this context,
   * means that the folder exists and the current Spark session has the necessary permissions to
   * access it.
   *
   * @param spark
   *   Spark session
   * @param checkpointLocation
   *   checkpoint location
   * @return
   *   true if accessible, otherwise false (including when the existence check itself fails
   *   with an IOException, which is logged and treated as inaccessible)
   */
  def isCheckpointLocationAccessible(spark: SparkSession, checkpointLocation: String): Boolean = {
    try {
      // Use the same checkpoint manager abstraction Structured Streaming uses,
      // so the accessibility check matches what the streaming job will do
      val checkpointManager =
        CheckpointFileManager.create(
          new Path(checkpointLocation),
          spark.sessionState.newHadoopConf())

      checkpointManager.exists(new Path(checkpointLocation))
    } catch {
      case e: IOException =>
        logWarning(s"Failed to check if checkpoint location $checkpointLocation exists", e)
        false
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

package org.opensearch.flint.spark.refresh

import org.opensearch.flint.spark.{FlintSparkIndex, FlintSparkIndexOptions}
import java.util.Collections

import org.opensearch.flint.spark.{FlintSparkIndex, FlintSparkIndexOptions, FlintSparkValidationHelper}
import org.opensearch.flint.spark.FlintSparkIndex.{quotedTableName, StreamingRefresh}
import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode.{AUTO, RefreshMode}

Expand All @@ -23,10 +25,41 @@ import org.apache.spark.sql.streaming.{DataStreamWriter, Trigger}
* @param index
* Flint index
*/
class AutoIndexRefresh(indexName: String, index: FlintSparkIndex) extends FlintSparkIndexRefresh {
class AutoIndexRefresh(indexName: String, index: FlintSparkIndex)
extends FlintSparkIndexRefresh
with FlintSparkValidationHelper {

override def refreshMode: RefreshMode = AUTO

/**
 * Validates auto-refresh options before the streaming job starts. Fails with
 * IllegalArgumentException (via require) on the first violated rule.
 *
 * @param spark
 *   Spark session
 */
override def validate(spark: SparkSession): Unit = {
  // Incremental refresh cannot be enabled at the same time as auto refresh
  val options = index.options
  require(
    !options.incrementalRefresh(),
    "Incremental refresh cannot be enabled if auto refresh is enabled")

  // Hive table doesn't support auto refresh (helper returns true when a Hive table is found)
  require(
    !isTableProviderSupported(spark, index),
    "Index auto refresh doesn't support Hive table")

  // Checkpoint location is required if the checkpoint-mandatory option is set.
  // An empty-map FlintSparkConf is used only to read session-level config defaults.
  val flintSparkConf = new FlintSparkConf(Collections.emptyMap[String, String])
  val checkpointLocation = options.checkpointLocation()
  if (flintSparkConf.isCheckpointMandatory) {
    require(
      checkpointLocation.isDefined,
      s"Checkpoint location is required if ${CHECKPOINT_MANDATORY.key} option enabled")
  }

  // Checkpoint location, when given, must exist and be accessible
  if (checkpointLocation.isDefined) {
    require(
      isCheckpointLocationAccessible(spark, checkpointLocation.get),
      s"Checkpoint location ${checkpointLocation.get} doesn't exist or no permission to access")
  }
}

override def start(spark: SparkSession, flintSparkConf: FlintSparkConf): Option[String] = {
val options = index.options
val tableName = index.metadata().source
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,20 @@ trait FlintSparkIndexRefresh extends Logging {
*/
def refreshMode: RefreshMode

/**
 * Validates the current index refresh settings before the actual execution begins. This method
 * checks the integrity of the index refresh configuration and ensures that all options set for
 * the current refresh mode are valid. This preemptive validation helps identify configuration
 * issues before the refresh operation is initiated, minimizing runtime errors and potential
 * inconsistencies.
 *
 * @param spark
 *   Spark session
 * @throws IllegalArgumentException
 *   if any invalid or inapplicable config is identified
 */
def validate(spark: SparkSession): Unit

/**
* Start refreshing the index.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ class FullIndexRefresh(

override def refreshMode: RefreshMode = FULL

override def validate(spark: SparkSession): Unit = {
  // Full refresh intentionally validates nothing for now, including Hive table validation.
  // This allows users to continue using their existing Hive table with full refresh only.
}

override def start(spark: SparkSession, flintSparkConf: FlintSparkConf): Option[String] = {
logInfo(s"Start refreshing index $indexName in full mode")
index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package org.opensearch.flint.spark.refresh

import org.opensearch.flint.spark.FlintSparkIndex
import org.opensearch.flint.spark.{FlintSparkIndex, FlintSparkValidationHelper}
import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode.{INCREMENTAL, RefreshMode}

import org.apache.spark.sql.SparkSession
Expand All @@ -20,18 +20,31 @@ import org.apache.spark.sql.flint.config.FlintSparkConf
* Flint index
*/
class IncrementalIndexRefresh(indexName: String, index: FlintSparkIndex)
extends FlintSparkIndexRefresh {
extends FlintSparkIndexRefresh
with FlintSparkValidationHelper {

override def refreshMode: RefreshMode = INCREMENTAL

/**
 * Validates incremental-refresh options: the source must not be a Hive table and a
 * checkpoint location must be provided and accessible. Fails with
 * IllegalArgumentException (via require) on the first violated rule.
 *
 * @param spark
 *   Spark session
 */
override def validate(spark: SparkSession): Unit = {
  // Non-Hive table is required for incremental refresh
  // (helper returns true when a Hive table is found)
  require(
    !isTableProviderSupported(spark, index),
    "Index incremental refresh doesn't support Hive table")

  // Checkpoint location is required regardless of the checkpoint-mandatory option.
  // Reuse the bound value instead of re-evaluating options.checkpointLocation().
  val checkpointLocation = index.options.checkpointLocation()
  require(
    checkpointLocation.isDefined,
    "Checkpoint location is required by incremental refresh")
  require(
    isCheckpointLocationAccessible(spark, checkpointLocation.get),
    s"Checkpoint location ${checkpointLocation.get} doesn't exist or no permission to access")
}

override def start(spark: SparkSession, flintSparkConf: FlintSparkConf): Option[String] = {
logInfo(s"Start refreshing index $indexName in incremental mode")

// TODO: move this to validation method together in future
if (index.options.checkpointLocation().isEmpty) {
throw new IllegalStateException("Checkpoint location is required by incremental refresh")
}

// Reuse auto refresh which uses AvailableNow trigger and will stop once complete
val jobId =
new AutoIndexRefresh(indexName, index)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.hive.HiveSessionStateBuilder
import org.apache.spark.sql.internal.{SessionState, StaticSQLConf}
import org.apache.spark.sql.test.{SharedSparkSession, TestSparkSession}

/**
* Flint Spark base suite with Hive support enabled. Because enabling Hive support in Spark
* configuration alone is not adequate, as [[TestSparkSession]] disregards it and consistently
* creates its own instance of [[org.apache.spark.sql.test.TestSQLSessionStateBuilder]]. We need
* to override its session state with that of Hive in the meanwhile.
*
* Note that we need to extend [[SharedSparkSession]] to call super.sparkConf() method.
*/
trait SparkHiveSupportSuite extends SharedSparkSession {

  override protected def sparkConf: SparkConf = {
    super.sparkConf
      // Enable Hive support
      .set(StaticSQLConf.CATALOG_IMPLEMENTATION.key, "hive")
      // Use in-memory Derby as Hive metastore so no need to clean up metastore_db folder after test
      .set("javax.jdo.option.ConnectionURL", "jdbc:derby:memory:metastore_db;create=true")
      // Clear any configured remote metastore URI so the in-memory Derby store is used
      .set("hive.metastore.uris", "")
  }

  override protected def createSparkSession: TestSparkSession = {
    // Tear down any session created by a previous suite so the Hive-enabled conf takes effect
    SparkSession.cleanupAnyExistingSession()
    new FlintTestSparkSession(sparkConf)
  }

  // Test session whose state is built by Hive's session-state builder instead of
  // TestSQLSessionStateBuilder, which TestSparkSession would otherwise always use.
  class FlintTestSparkSession(sparkConf: SparkConf) extends TestSparkSession(sparkConf) { self =>

    override lazy val sessionState: SessionState = {
      // Override to replace [[TestSQLSessionStateBuilder]] with Hive session state.
      // NOTE(review): `spark` here presumably resolves to the enclosing suite's session —
      // confirm this is intentional rather than passing `self`/`this` (this inner session)
      new HiveSessionStateBuilder(spark, None).build()
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite {
test("create skipping index with auto refresh should fail if mandatory checkpoint enabled") {
setFlintSparkConf(CHECKPOINT_MANDATORY, "true")
try {
the[IllegalStateException] thrownBy {
the[IllegalArgumentException] thrownBy {
sql(s"""
| CREATE INDEX $testIndex ON $testTable
| (name, age)
Expand Down
Loading