diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index 425fc02e315e9..cab2fe9b90de2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -560,21 +560,31 @@ case class RocksDBConf(
 object RocksDBConf {
   /** Common prefix of all confs in SQLConf that affects RocksDB */
-  val ROCKSDB_CONF_NAME_PREFIX = "spark.sql.streaming.stateStore.rocksdb"
+  val ROCKSDB_SQL_CONF_NAME_PREFIX = "spark.sql.streaming.stateStore.rocksdb"
 
-  private case class ConfEntry(name: String, default: String) {
-    def fullName: String = s"$ROCKSDB_CONF_NAME_PREFIX.${name}".toLowerCase(Locale.ROOT)
+  private abstract class ConfEntry(name: String, val default: String) {
+    def fullName: String = name.toLowerCase(Locale.ROOT)
   }
 
+  private case class SQLConfEntry(name: String, override val default: String)
+    extends ConfEntry(name, default) {
+
+    override def fullName: String =
+      s"$ROCKSDB_SQL_CONF_NAME_PREFIX.${name}".toLowerCase(Locale.ROOT)
+  }
+
+  private case class ExtraConfEntry(name: String, override val default: String)
+    extends ConfEntry(name, default)
+
   // Configuration that specifies whether to compact the RocksDB data every time data is committed
-  private val COMPACT_ON_COMMIT_CONF = ConfEntry("compactOnCommit", "false")
-  private val BLOCK_SIZE_KB_CONF = ConfEntry("blockSizeKB", "4")
-  private val BLOCK_CACHE_SIZE_MB_CONF = ConfEntry("blockCacheSizeMB", "8")
-  private val LOCK_ACQUIRE_TIMEOUT_MS_CONF = ConfEntry("lockAcquireTimeoutMs", "60000")
-  private val RESET_STATS_ON_LOAD = ConfEntry("resetStatsOnLoad", "true")
+  private val COMPACT_ON_COMMIT_CONF = SQLConfEntry("compactOnCommit", "false")
+  private val BLOCK_SIZE_KB_CONF = SQLConfEntry("blockSizeKB", "4")
+  private val BLOCK_CACHE_SIZE_MB_CONF = SQLConfEntry("blockCacheSizeMB", "8")
+  private val LOCK_ACQUIRE_TIMEOUT_MS_CONF = SQLConfEntry("lockAcquireTimeoutMs", "60000")
+  private val RESET_STATS_ON_LOAD = SQLConfEntry("resetStatsOnLoad", "true")
   // Config to specify the number of open files that can be used by the DB. Value of -1 means
   // that files opened are always kept open.
-  private val MAX_OPEN_FILES_CONF = ConfEntry("maxOpenFiles", "-1")
+  private val MAX_OPEN_FILES_CONF = SQLConfEntry("maxOpenFiles", "-1")
 
   // Configuration to set the RocksDB format version. When upgrading the RocksDB version in Spark,
   // it may introduce a new table format version that can not be supported by an old RocksDB version
   // used by an old Spark version. Hence, we store the table format version in the checkpoint when
@@ -586,7 +596,7 @@ object RocksDBConf {
   //
   // Note: this is also defined in `SQLConf.STATE_STORE_ROCKSDB_FORMAT_VERSION`. These two
   // places should be updated together.
-  private val FORMAT_VERSION = ConfEntry("formatVersion", "5")
+  private val FORMAT_VERSION = SQLConfEntry("formatVersion", "5")
 
   // Flag to enable/disable tracking the total number of rows.
   // When this is enabled, this class does additional lookup on write operations (put/delete) to
@@ -594,33 +604,45 @@ object RocksDBConf {
   // The additional lookups bring non-trivial overhead on write-heavy workloads - if your query
   // does lots of writes on state, it would be encouraged to turn off the config and turn on
   // again when you really need the know the number for observability/debuggability.
-  private val TRACK_TOTAL_NUMBER_OF_ROWS = ConfEntry("trackTotalNumberOfRows", "true")
+  private val TRACK_TOTAL_NUMBER_OF_ROWS = SQLConfEntry("trackTotalNumberOfRows", "true")
 
   def apply(storeConf: StateStoreConf): RocksDBConf = {
-    val confs = CaseInsensitiveMap[String](storeConf.confs)
+    val sqlConfs = CaseInsensitiveMap[String](storeConf.sqlConfs)
+    val extraConfs = CaseInsensitiveMap[String](storeConf.extraOptions)
+
+    def getConfigMap(conf: ConfEntry): CaseInsensitiveMap[String] = {
+      conf match {
+        case _: SQLConfEntry => sqlConfs
+        case _: ExtraConfEntry => extraConfs
+      }
+    }
 
     def getBooleanConf(conf: ConfEntry): Boolean = {
-      Try { confs.getOrElse(conf.fullName, conf.default).toBoolean } getOrElse {
+      Try { getConfigMap(conf).getOrElse(conf.fullName, conf.default).toBoolean } getOrElse {
         throw new IllegalArgumentException(s"Invalid value for '${conf.fullName}', must be boolean")
       }
     }
 
     def getIntConf(conf: ConfEntry): Int = {
-      Try { confs.getOrElse(conf.fullName, conf.default).toInt } getOrElse {
+      Try { getConfigMap(conf).getOrElse(conf.fullName, conf.default).toInt } getOrElse {
         throw new IllegalArgumentException(s"Invalid value for '${conf.fullName}', " +
           "must be an integer")
       }
     }
 
     def getPositiveLongConf(conf: ConfEntry): Long = {
-      Try { confs.getOrElse(conf.fullName, conf.default).toLong } filter { _ >= 0 } getOrElse {
+      Try {
+        getConfigMap(conf).getOrElse(conf.fullName, conf.default).toLong
+      } filter { _ >= 0 } getOrElse {
         throw new IllegalArgumentException(
           s"Invalid value for '${conf.fullName}', must be a positive integer")
       }
     }
 
     def getPositiveIntConf(conf: ConfEntry): Int = {
-      Try { confs.getOrElse(conf.fullName, conf.default).toInt } filter { _ >= 0 } getOrElse {
+      Try {
+        getConfigMap(conf).getOrElse(conf.fullName, conf.default).toInt
+      } filter { _ >= 0 } getOrElse {
         throw new IllegalArgumentException(
           s"Invalid value for '${conf.fullName}', must be a positive integer")
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
index 66bb37d7a57bd..21a1874534846 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.internal.SQLConf
 /** A class that contains configuration parameters for [[StateStore]]s. */
 class StateStoreConf(
     @transient private val sqlConf: SQLConf,
-    extraOptions: Map[String, String] = Map.empty)
+    val extraOptions: Map[String, String] = Map.empty)
   extends Serializable {
 
   def this() = this(new SQLConf)
@@ -71,11 +71,10 @@ class StateStoreConf(
 
   /**
    * Additional configurations related to state store. This will capture all configs in
-   * SQLConf that start with `spark.sql.streaming.stateStore.` and extraOptions for a specific
-   * operator.
+   * SQLConf that start with `spark.sql.streaming.stateStore.`
    */
-  val confs: Map[String, String] =
-    sqlConf.getAllConfs.filter(_._1.startsWith("spark.sql.streaming.stateStore.")) ++ extraOptions
+  val sqlConfs: Map[String, String] =
+    sqlConf.getAllConfs.filter(_._1.startsWith("spark.sql.streaming.stateStore."))
 }
 
 object StateStoreConf {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
index 2dcb10536d4f8..dc505963b4d34 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
@@ -173,7 +173,7 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest {
         (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName),
         (SQLConf.CHECKPOINT_LOCATION.key -> dir.getCanonicalPath),
         (SQLConf.SHUFFLE_PARTITIONS.key, "1"),
-        (s"${RocksDBConf.ROCKSDB_CONF_NAME_PREFIX}.trackTotalNumberOfRows" -> "false")) {
+        (s"${RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX}.trackTotalNumberOfRows" -> "false")) {
 
         val inputData = MemoryStream[Int]
         val query = inputData.toDS().toDF("value")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
index 67181d7684e96..1998e2af114d4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
@@ -73,9 +73,9 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     val testConfs = Seq(
       ("spark.sql.streaming.stateStore.providerClass",
         classOf[RocksDBStateStoreProvider].getName),
-      (RocksDBConf.ROCKSDB_CONF_NAME_PREFIX + ".compactOnCommit", "true"),
-      (RocksDBConf.ROCKSDB_CONF_NAME_PREFIX + ".lockAcquireTimeoutMs", "10"),
-      (RocksDBConf.ROCKSDB_CONF_NAME_PREFIX + ".maxOpenFiles", "1000"),
+      (RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".compactOnCommit", "true"),
+      (RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".lockAcquireTimeoutMs", "10"),
+      (RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".maxOpenFiles", "1000"),
       (SQLConf.STATE_STORE_ROCKSDB_FORMAT_VERSION.key, "4")
     )
     testConfs.foreach { case (k, v) => spark.conf.set(k, v) }
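For reviewers who want to see the resulting lookup behaviour in isolation, below is a minimal, self-contained Scala sketch. It is not part of the patch and does not use Spark classes; the names ConfResolutionSketch and resolve are illustrative only. It mirrors the SQLConfEntry/ExtraConfEntry split and the getConfigMap dispatch introduced above: SQL-level entries resolve against the prefixed map derived from SQLConf, while operator-level entries resolve against the extraOptions map.

// Hypothetical standalone sketch, not Spark code: illustrates the two-map
// resolution pattern introduced in RocksDBConf.apply above.
object ConfResolutionSketch {
  sealed abstract class ConfEntry(val name: String, val default: String) {
    def fullName: String = name.toLowerCase
  }

  // Entries read from SQLConf carry the RocksDB SQL conf prefix in their full name.
  final case class SQLConfEntry(override val name: String, override val default: String)
    extends ConfEntry(name, default) {
    override def fullName: String =
      s"spark.sql.streaming.stateStore.rocksdb.$name".toLowerCase
  }

  // Entries read from the per-operator extraOptions map keep their plain name.
  final case class ExtraConfEntry(override val name: String, override val default: String)
    extends ConfEntry(name, default)

  def resolve(
      conf: ConfEntry,
      sqlConfs: Map[String, String],
      extraConfs: Map[String, String]): String = {
    // Same dispatch idea as getConfigMap in the patch: pick the map by entry type.
    val configMap = conf match {
      case _: SQLConfEntry => sqlConfs
      case _: ExtraConfEntry => extraConfs
    }
    configMap.getOrElse(conf.fullName, conf.default)
  }

  def main(args: Array[String]): Unit = {
    // Keys are lowercased here because fullName is lowercased and plain Maps are case-sensitive.
    val sqlConfs = Map("spark.sql.streaming.statestore.rocksdb.maxopenfiles" -> "1000")
    val extraConfs = Map("sometunable" -> "42") // hypothetical operator-level option
    println(resolve(SQLConfEntry("maxOpenFiles", "-1"), sqlConfs, extraConfs)) // prints 1000
    println(resolve(ExtraConfEntry("someTunable", "0"), sqlConfs, extraConfs)) // prints 42
  }
}

Under this reading, an entry declared as an ExtraConfEntry can only be supplied through the per-operator extraOptions, which is the separation the StateStoreConf change above (splitting sqlConfs from extraOptions) makes possible.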