diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java index 9a89c8193dd6e..b2c908dc73a61 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java @@ -49,4 +49,35 @@ public DataSourceV2Options(Map originalMap) { public Optional get(String key) { return Optional.ofNullable(keyLowerCasedMap.get(toLowerCase(key))); } + + /** + * Returns the boolean value to which the specified key is mapped, + * or defaultValue if there is no mapping for the key. The key match is case-insensitive + */ + public boolean getBoolean(String key, boolean defaultValue) { + String lcaseKey = toLowerCase(key); + return keyLowerCasedMap.containsKey(lcaseKey) ? + Boolean.parseBoolean(keyLowerCasedMap.get(lcaseKey)) : defaultValue; + } + + /** + * Returns the integer value to which the specified key is mapped, + * or defaultValue if there is no mapping for the key. The key match is case-insensitive + */ + public int getInt(String key, int defaultValue) { + String lcaseKey = toLowerCase(key); + return keyLowerCasedMap.containsKey(lcaseKey) ? + Integer.parseInt(keyLowerCasedMap.get(lcaseKey)) : defaultValue; + } + + /** + * Returns the long value to which the specified key is mapped, + * or defaultValue if there is no mapping for the key. The key match is case-insensitive + */ + public long getLong(String key, long defaultValue) { + String lcaseKey = toLowerCase(key); + return keyLowerCasedMap.containsKey(lcaseKey) ? + Long.parseLong(keyLowerCasedMap.get(lcaseKey)) : defaultValue; + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala index 933f4075bcc8a..752d3c193cc74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala @@ -37,4 +37,35 @@ class DataSourceV2OptionsSuite extends SparkFunSuite { val options = new DataSourceV2Options(Map("foo" -> "bAr").asJava) assert(options.get("foo").get == "bAr") } + + test("getInt") { + val options = new DataSourceV2Options(Map("numFOo" -> "1", "foo" -> "bar").asJava) + assert(options.getInt("numFOO", 10) == 1) + assert(options.getInt("numFOO2", 10) == 10) + + intercept[NumberFormatException]{ + options.getInt("foo", 1) + } + } + + test("getBoolean") { + val options = new DataSourceV2Options( + Map("isFoo" -> "true", "isFOO2" -> "false", "foo" -> "bar").asJava) + assert(options.getBoolean("isFoo", false)) + assert(!options.getBoolean("isFoo2", true)) + assert(options.getBoolean("isBar", true)) + assert(!options.getBoolean("isBar", false)) + assert(!options.getBoolean("FOO", true)) + } + + test("getLong") { + val options = new DataSourceV2Options(Map("numFoo" -> "9223372036854775807", + "foo" -> "bar").asJava) + assert(options.getLong("numFOO", 0L) == 9223372036854775807L) + assert(options.getLong("numFoo2", -1L) == -1L) + + intercept[NumberFormatException]{ + options.getLong("foo", 0L) + } + } }