diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index ec1627a3898b..865a14509485 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -33,6 +33,7 @@ import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
 import org.apache.commons.io.FileUtils
 
 import org.apache.spark.sql._
+import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
 import org.apache.spark.sql.types._
 
@@ -51,6 +52,13 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     checkAnswer(newEntries, originalEntries)
   }
 
+  test("resolve avro data source") {
+    Seq("avro", "com.databricks.spark.avro").foreach { provider =>
+      assert(DataSource.lookupDataSource(provider, spark.sessionState.conf) ===
+        classOf[org.apache.spark.sql.avro.AvroFileFormat])
+    }
+  }
+
   test("reading from multiple paths") {
     val df = spark.read.format("avro").load(episodesAvro, episodesAvro)
     assert(df.count == 16)
@@ -456,7 +464,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     // get the same values back.
     withTempPath { tempDir =>
       val name = "AvroTest"
-      val namespace = "com.databricks.spark.avro"
+      val namespace = "org.apache.spark.avro"
       val parameters = Map("recordName" -> name, "recordNamespace" -> namespace)
 
       val avroDir = tempDir + "/namedAvro"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 0c3d9a4895fe..b1a10fdb6020 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -571,6 +571,7 @@ object DataSource extends Logging {
     val nativeOrc = classOf[OrcFileFormat].getCanonicalName
     val socket = classOf[TextSocketSourceProvider].getCanonicalName
     val rate = classOf[RateStreamProvider].getCanonicalName
+    val avro = "org.apache.spark.sql.avro.AvroFileFormat"
 
     Map(
       "org.apache.spark.sql.jdbc" -> jdbc,
@@ -592,6 +593,7 @@ object DataSource extends Logging {
       "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm,
       "org.apache.spark.ml.source.libsvm" -> libsvm,
       "com.databricks.spark.csv" -> csv,
+      "com.databricks.spark.avro" -> avro,
       "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket,
       "org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate
     )
@@ -635,12 +637,6 @@ object DataSource extends Logging {
             "Hive built-in ORC data source must be used with Hive support enabled. " +
             "Please use the native ORC data source by setting 'spark.sql.orc.impl' to " +
             "'native'")
-        } else if (provider1.toLowerCase(Locale.ROOT) == "avro" ||
-          provider1 == "com.databricks.spark.avro") {
-          throw new AnalysisException(
-            s"Failed to find data source: ${provider1.toLowerCase(Locale.ROOT)}. " +
-            "Please find an Avro package at " +
-            "http://spark.apache.org/third-party-projects.html")
         } else {
           throw new ClassNotFoundException(
             s"Failed to find data source: $provider1. Please find packages at " +
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index dfb9c137b74f..86083d1701c2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1689,22 +1689,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     }
     assert(e.message.contains("Hive built-in ORC data source must be used with Hive support"))
 
-    e = intercept[AnalysisException] {
-      sql(s"select id from `com.databricks.spark.avro`.`file_path`")
-    }
-    assert(e.message.contains("Failed to find data source: com.databricks.spark.avro."))
-
-    // data source type is case insensitive
-    e = intercept[AnalysisException] {
-      sql(s"select id from Avro.`file_path`")
-    }
-    assert(e.message.contains("Failed to find data source: avro."))
-
-    e = intercept[AnalysisException] {
-      sql(s"select id from avro.`file_path`")
-    }
-    assert(e.message.contains("Failed to find data source: avro."))
-
     e = intercept[AnalysisException] {
       sql(s"select id from `org.apache.spark.sql.sources.HadoopFsRelationProvider`.`file_path`")
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
index 4adbff5c663b..95460fa70d8f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
@@ -77,19 +77,9 @@ class ResolvedDataSourceSuite extends SparkFunSuite with SharedSQLContext {
   }
 
   test("error message for unknown data sources") {
-    val error1 = intercept[AnalysisException] {
-      getProvidingClass("avro")
-    }
-    assert(error1.getMessage.contains("Failed to find data source: avro."))
-
-    val error2 = intercept[AnalysisException] {
-      getProvidingClass("com.databricks.spark.avro")
-    }
-    assert(error2.getMessage.contains("Failed to find data source: com.databricks.spark.avro."))
-
-    val error3 = intercept[ClassNotFoundException] {
+    val error = intercept[ClassNotFoundException] {
       getProvidingClass("asfdwefasdfasdf")
     }
-    assert(error3.getMessage.contains("Failed to find data source: asfdwefasdfasdf."))
+    assert(error.getMessage.contains("Failed to find data source: asfdwefasdfasdf."))
  }
 }
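For context, the `backwardCompatibilityMap` entry added above means user code written against the legacy `com.databricks.spark.avro` provider name keeps resolving once the built-in `spark-avro` module is on the classpath, instead of failing with the old hard-coded error. A minimal sketch of the resulting user-facing behavior; the session setup and input path are illustrative, not part of this change:

```scala
import org.apache.spark.sql.SparkSession

object AvroProviderCompatSketch {
  def main(args: Array[String]): Unit = {
    // Assumes the external spark-avro module is on the classpath,
    // e.g. started with --packages org.apache.spark:spark-avro_2.11:<version>.
    val spark = SparkSession.builder()
      .appName("avro-provider-compat")
      .master("local[*]")
      .getOrCreate()

    // Both provider strings now resolve to org.apache.spark.sql.avro.AvroFileFormat
    // through DataSource.lookupDataSource and its backward-compatibility map.
    val byShortName = spark.read
      .format("avro")
      .load("/tmp/events.avro") // hypothetical path
    val byLegacyName = spark.read
      .format("com.databricks.spark.avro")
      .load("/tmp/events.avro")

    assert(byShortName.schema == byLegacyName.schema)
    spark.stop()
  }
}
```

The same resolution applies on the SQL path, which is why the `select id from avro.\`file_path\`` negative tests are deleted above: with the module present, those queries now load the file rather than raising `AnalysisException`.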