From 468922894bb8996976f17db4c1c63ada0de95e4e Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Sun, 2 Jul 2017 00:00:18 +0800 Subject: [PATCH 1/2] load test table based on case sensitivity --- .../apache/spark/sql/hive/test/TestHive.scala | 9 +++- .../apache/spark/sql/hive/TestHiveSuite.scala | 43 +++++++++++++++++++ 2 files changed, 50 insertions(+), 2 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 4e1792321c89..7cb84bdedae5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.test import java.io.File -import java.util.{Set => JavaSet} +import java.util.{Locale, Set => JavaSet} import scala.collection.JavaConverters._ import scala.collection.mutable @@ -449,6 +449,8 @@ private[hive] class TestHiveSparkSession( private val loadedTables = new collection.mutable.HashSet[String] + def getLoadedTables: collection.mutable.HashSet[String] = loadedTables + def loadTestTable(name: String) { if (!(loadedTables contains name)) { // Marks the table as loaded first to prevent infinite mutually recursive table loading. @@ -553,7 +555,10 @@ private[hive] class TestHiveQueryExecution( val referencedTables = describedTables ++ logical.collect { case UnresolvedRelation(tableIdent, _) => tableIdent.table } - val referencedTestTables = referencedTables.filter(sparkSession.testTables.contains) + val formattedRefTables = referencedTables.map { t => + if (sparkSession.sessionState.conf.caseSensitiveAnalysis) t else t.toLowerCase(Locale.ROOT) + } + val referencedTestTables = formattedRefTables.filter(sparkSession.testTables.contains) logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(sparkSession.loadTestTable) // Proceed with analysis. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala new file mode 100644 index 000000000000..7a59ddcfc418 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.hive.test.{TestHiveSingleton, TestHiveSparkSession} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils + + +class TestHiveSuite extends TestHiveSingleton with SQLTestUtils { + test("load test table based on case sensitivity") { + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + sql("SELECT * FROM SRC").queryExecution.analyzed + assert(spark.asInstanceOf[TestHiveSparkSession].getLoadedTables.contains("src")) + } + spark.asInstanceOf[TestHiveSparkSession].reset() + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val err = intercept[AnalysisException] { + sql("SELECT * FROM SRC").queryExecution.analyzed + } + assert(err.message.contains("Table or view not found")) + } + spark.asInstanceOf[TestHiveSparkSession].reset() + } +} From a30b75ae9dc86fd79833e95a143efc2909a6386c Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Mon, 3 Jul 2017 21:52:25 +0800 Subject: [PATCH 2/2] use resolver --- .../scala/org/apache/spark/sql/hive/test/TestHive.scala | 8 ++++---- .../scala/org/apache/spark/sql/hive/TestHiveSuite.scala | 8 +++++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 7cb84bdedae5..801f9b992364 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.test import java.io.File -import java.util.{Locale, Set => JavaSet} +import java.util.{Set => JavaSet} import scala.collection.JavaConverters._ import scala.collection.mutable @@ -555,10 +555,10 @@ private[hive] class TestHiveQueryExecution( val referencedTables = describedTables ++ logical.collect { case UnresolvedRelation(tableIdent, _) => tableIdent.table } - val formattedRefTables = referencedTables.map { t => - if (sparkSession.sessionState.conf.caseSensitiveAnalysis) t else t.toLowerCase(Locale.ROOT) + val resolver = sparkSession.sessionState.conf.resolver + val referencedTestTables = sparkSession.testTables.keys.filter { testTable => + referencedTables.exists(resolver(_, testTable)) } - val referencedTestTables = formattedRefTables.filter(sparkSession.testTables.contains) logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(sparkSession.loadTestTable) // Proceed with analysis. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala index 7a59ddcfc418..193fa83dbad9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala @@ -25,12 +25,14 @@ import org.apache.spark.sql.test.SQLTestUtils class TestHiveSuite extends TestHiveSingleton with SQLTestUtils { test("load test table based on case sensitivity") { + val testHiveSparkSession = spark.asInstanceOf[TestHiveSparkSession] withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { sql("SELECT * FROM SRC").queryExecution.analyzed - assert(spark.asInstanceOf[TestHiveSparkSession].getLoadedTables.contains("src")) + assert(testHiveSparkSession.getLoadedTables.contains("src")) + assert(testHiveSparkSession.getLoadedTables.size == 1) } - spark.asInstanceOf[TestHiveSparkSession].reset() + testHiveSparkSession.reset() withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { val err = intercept[AnalysisException] { @@ -38,6 +40,6 @@ class TestHiveSuite extends TestHiveSingleton with SQLTestUtils { } assert(err.message.contains("Table or view not found")) } - spark.asInstanceOf[TestHiveSparkSession].reset() + testHiveSparkSession.reset() } }