Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.internal.config

import java.util.Locale
import java.util.concurrent.TimeUnit
import java.util.regex.PatternSyntaxException

Expand Down Expand Up @@ -46,6 +47,16 @@ private object ConfigHelpers {
}
}

def toEnum[E <: Enumeration](s: String, enumClass: E, key: String): enumClass.Value = {
try {
enumClass.withName(s.trim.toUpperCase(Locale.ROOT))
} catch {
case _: NoSuchElementException =>
throw new IllegalArgumentException(
s"$key should be one of ${enumClass.values.mkString(", ")}, but was $s")
}
}

def stringToSeq[T](str: String, converter: String => T): Seq[T] = {
SparkStringUtils.stringToSeq(str).map(converter)
}
Expand Down Expand Up @@ -271,6 +282,11 @@ private[spark] case class ConfigBuilder(key: String) {
new TypedConfigBuilder(this, v => v)
}

def enumConf(e: Enumeration): TypedConfigBuilder[e.Value] = {
checkPrependConfig
new TypedConfigBuilder(this, toEnum(_, e, key))
}

def timeConf(unit: TimeUnit): TypedConfigBuilder[Long] = {
checkPrependConfig
new TypedConfigBuilder(this, timeFromString(_, unit), timeToString(_, unit))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,4 +387,25 @@ class ConfigEntrySuite extends SparkFunSuite {
ConfigBuilder(testKey("oc5")).onCreate(_ => onCreateCalled = true).fallbackConf(fallback)
assert(onCreateCalled)
}


test("SPARK-51874: Add Enum support to ConfigBuilder") {
object MyTestEnum extends Enumeration {
val X, Y, Z = Value
}
val conf = new SparkConf()
val enumConf = ConfigBuilder("spark.test.enum.key")
.enumConf(MyTestEnum)
.createWithDefault(MyTestEnum.X)
assert(conf.get(enumConf) === MyTestEnum.X)
conf.set(enumConf, MyTestEnum.Y)
assert(conf.get(enumConf) === MyTestEnum.Y)
conf.set(enumConf.key, "Z")
assert(conf.get(enumConf) === MyTestEnum.Z)
val e = intercept[IllegalArgumentException] {
conf.set(enumConf.key, "A")
conf.get(enumConf)
}
assert(e.getMessage === s"${enumConf.key} should be one of X, Y, Z, but was A")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ object CTESubstitution extends Rule[LogicalPlan] {

val cteDefs = ArrayBuffer.empty[CTERelationDef]
val (substituted, firstSubstituted) =
LegacyBehaviorPolicy.withName(conf.getConf(LEGACY_CTE_PRECEDENCE_POLICY)) match {
conf.getConf(LEGACY_CTE_PRECEDENCE_POLICY) match {
case LegacyBehaviorPolicy.EXCEPTION =>
assertNoNameConflictsInCTE(plan)
traverseAndSubstituteCTE(plan, forceInline, Seq.empty, cteDefs, None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper {
Some("hiveCaseSensitiveInferenceMode")
} else if (conf.getConf(SQLConf.LEGACY_INLINE_CTE_IN_COMMANDS)) {
Some("legacyInlineCTEInCommands")
} else if (LegacyBehaviorPolicy.withName(conf.getConf(SQLConf.LEGACY_CTE_PRECEDENCE_POLICY)) !=
} else if (conf.getConf(SQLConf.LEGACY_CTE_PRECEDENCE_POLICY) !=
LegacyBehaviorPolicy.CORRECTED) {
Some("legacyCTEPrecedencePolicy")
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] extends Logging {

def createObject(in: IN): OUT = {
// We are allowed to choose codegen-only or no-codegen modes if under tests.
val fallbackMode = CodegenObjectFactoryMode.withName(SQLConf.get.codegenFactoryMode)
val fallbackMode = SQLConf.get.codegenFactoryMode

fallbackMode match {
case CodegenObjectFactoryMode.CODEGEN_ONLY =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression =>
object ToStringBase {
def getBinaryFormatter: BinaryFormatter = {
val style = SQLConf.get.getConf(SQLConf.BINARY_OUTPUT_STYLE)
style.map(BinaryOutputStyle.withName) match {
style match {
case Some(BinaryOutputStyle.UTF8) =>
(array: Array[Byte]) => UTF8String.fromBytes(array)
case Some(BinaryOutputStyle.BASIC) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria
keys.append(keyNormalized)
values.append(value)
} else {
if (mapKeyDedupPolicy == SQLConf.MapKeyDedupPolicy.EXCEPTION.toString) {
if (mapKeyDedupPolicy == SQLConf.MapKeyDedupPolicy.EXCEPTION) {
throw QueryExecutionErrors.duplicateMapKeyFoundError(key)
} else if (mapKeyDedupPolicy == SQLConf.MapKeyDedupPolicy.LAST_WIN.toString) {
} else if (mapKeyDedupPolicy == SQLConf.MapKeyDedupPolicy.LAST_WIN) {
// Overwrite the previous value, as the policy is last wins.
values(index) = value
} else {
Expand Down
Loading