Commit 3bb8731

Refactors HiveParquetSuite
1 parent aa2cb2e commit 3bb8731

File tree

1 file changed: +41 -78 lines changed

sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala

Lines changed: 41 additions & 78 deletions
@@ -17,103 +17,66 @@
 
 package org.apache.spark.sql.parquet
 
-import java.io.File
-
-import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
-
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row}
-import org.apache.spark.sql.catalyst.types.{DataType, StringType, IntegerType}
-import org.apache.spark.sql.{parquet, SchemaRDD}
-import org.apache.spark.util.Utils
-
-// Implicits
-import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.catalyst.expressions.Row
+import org.apache.spark.sql.hive.test.TestHive
 
 case class Cases(lower: String, UPPER: String)
 
-class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach {
-
-  val dirname = Utils.createTempDir()
-
-  var testRDD: SchemaRDD = null
-
-  override def beforeAll() {
-    // write test data
-    ParquetTestData.writeFile()
-    testRDD = parquetFile(ParquetTestData.testDir.toString)
-    testRDD.registerTempTable("testsource")
-  }
-
-  override def afterAll() {
-    Utils.deleteRecursively(ParquetTestData.testDir)
-    Utils.deleteRecursively(dirname)
-    reset() // drop all tables that were registered as part of the tests
-  }
-
-  // in case tests are failing we delete before and after each test
-  override def beforeEach() {
-    Utils.deleteRecursively(dirname)
-  }
+class HiveParquetSuite extends QueryTest with ParquetTest {
+  val sqlContext = TestHive
 
-  override def afterEach() {
-    Utils.deleteRecursively(dirname)
-  }
+  import sqlContext._
 
   test("Case insensitive attribute names") {
-    val tempFile = File.createTempFile("parquet", "")
-    tempFile.delete()
-    sparkContext.parallelize(1 to 10)
-      .map(_.toString)
-      .map(i => Cases(i, i))
-      .saveAsParquetFile(tempFile.getCanonicalPath)
-
-    parquetFile(tempFile.getCanonicalPath).registerTempTable("cases")
-    sql("SELECT upper FROM cases").collect().map(_.getString(0)) === (1 to 10).map(_.toString)
-    sql("SELECT LOWER FROM cases").collect().map(_.getString(0)) === (1 to 10).map(_.toString)
+    withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") {
+      val expected = (1 to 4).map(i => Row(i.toString))
+      checkAnswer(sql("SELECT upper FROM cases"), expected)
+      checkAnswer(sql("SELECT LOWER FROM cases"), expected)
+    }
   }
 
   test("SELECT on Parquet table") {
-    val rdd = sql("SELECT * FROM testsource").collect()
-    assert(rdd != null)
-    assert(rdd.forall(_.size == 6))
+    val data = (1 to 4).map(i => (i, s"val_$i"))
+    withParquetTable(data, "t") {
+      checkAnswer(sql("SELECT * FROM t"), data)
+    }
   }
 
   test("Simple column projection + filter on Parquet table") {
-    val rdd = sql("SELECT myboolean, mylong FROM testsource WHERE myboolean=true").collect()
-    assert(rdd.size === 5, "Filter returned incorrect number of rows")
-    assert(rdd.forall(_.getBoolean(0)), "Filter returned incorrect Boolean field value")
+    withParquetTable((1 to 4).map(i => (i % 2 == 0, i, s"val_$i")), "t") {
+      checkAnswer(
+        sql("SELECT `_1`, `_3` FROM t WHERE `_1` = true"),
+        Seq(Row(true, "val_2"), Row(true, "val_4")))
+    }
   }
 
   test("Converting Hive to Parquet Table via saveAsParquetFile") {
-    sql("SELECT * FROM src").saveAsParquetFile(dirname.getAbsolutePath)
-    parquetFile(dirname.getAbsolutePath).registerTempTable("ptable")
-    val rddOne = sql("SELECT * FROM src").collect().sortBy(_.getInt(0))
-    val rddTwo = sql("SELECT * from ptable").collect().sortBy(_.getInt(0))
-
-    compareRDDs(rddOne, rddTwo, "src (Hive)", Seq("key:Int", "value:String"))
+    withTempPath { dir =>
+      sql("SELECT * FROM src").saveAsParquetFile(dir.getCanonicalPath)
+      parquetFile(dir.getCanonicalPath).registerTempTable("p")
+      withTempTable("p") {
+        checkAnswer(
+          sql("SELECT * FROM src ORDER BY key"),
+          sql("SELECT * from p ORDER BY key").collect().toSeq)
+      }
+    }
   }
 
-  test("INSERT OVERWRITE TABLE Parquet table") {
-    sql("SELECT * FROM testsource").saveAsParquetFile(dirname.getAbsolutePath)
-    parquetFile(dirname.getAbsolutePath).registerTempTable("ptable")
-    // let's do three overwrites for good measure
-    sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
-    sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
-    sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
-    val rddCopy = sql("SELECT * FROM ptable").collect()
-    val rddOrig = sql("SELECT * FROM testsource").collect()
-    assert(rddCopy.size === rddOrig.size, "INSERT OVERWRITE changed size of table??")
-    compareRDDs(rddOrig, rddCopy, "testsource", ParquetTestData.testSchemaFieldNames)
-  }
 
-  private def compareRDDs(rddOne: Array[Row], rddTwo: Array[Row], tableName: String, fieldNames: Seq[String]) {
-    var counter = 0
-    (rddOne, rddTwo).zipped.foreach {
-      (a,b) => (a,b).zipped.toArray.zipWithIndex.foreach {
-        case ((value_1, value_2), index) =>
-          assert(value_1 === value_2, s"table $tableName row $counter field ${fieldNames(index)} don't match")
+  test("INSERT OVERWRITE TABLE Parquet table") {
+    withParquetTable((1 to 4).map(i => (i, s"val_$i")), "t") {
+      withTempPath { file =>
+        sql("SELECT * FROM t LIMIT 1").saveAsParquetFile(file.getCanonicalPath)
+        parquetFile(file.getCanonicalPath).registerTempTable("p")
+        withTempTable("p") {
+          // let's do three overwrites for good measure
+          sql("INSERT OVERWRITE TABLE p SELECT * FROM t")
+          sql("INSERT OVERWRITE TABLE p SELECT * FROM t")
+          sql("INSERT OVERWRITE TABLE p SELECT * FROM t")
+          checkAnswer(sql("SELECT * FROM p"), sql("SELECT * FROM t").collect().toSeq)
+        }
       }
-      counter = counter + 1
     }
   }
 }
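
Note: the refactored suite leans on loan-pattern helpers from the ParquetTest trait (withParquetTable, withTempPath, withTempTable) plus checkAnswer from QueryTest, replacing the manual beforeAll/afterAll/beforeEach/afterEach bookkeeping deleted above. The trait itself is not part of this diff; the following is only a minimal sketch of the pattern those helpers presumably follow, with hypothetical stand-ins (createParquetTable, dropTable) in place of the real SQLContext operations seen in the diff (saveAsParquetFile, parquetFile(...).registerTempTable, and so on).

import java.io.File

// Rough sketch only; not the actual Spark implementation of ParquetTest.
// createParquetTable and dropTable are hypothetical hooks standing in for
// the SQLContext operations used in the diff above.
trait ParquetTestSketch {
  def createParquetTable[T <: Product](data: Seq[T], tableName: String): Unit
  def dropTable(tableName: String): Unit

  // Hand the test body a scratch path, then delete it recursively afterwards.
  protected def withTempPath(f: File => Unit): Unit = {
    val path = File.createTempFile("parquetTest", "")
    path.delete()
    try f(path) finally {
      def deleteRecursively(file: File): Unit = {
        Option(file.listFiles()).foreach(_.foreach(deleteRecursively))
        file.delete()
      }
      deleteRecursively(path)
    }
  }

  // Run the test body, then drop the temporary table even if the body throws.
  protected def withTempTable(tableName: String)(f: => Unit): Unit = {
    try f finally dropTable(tableName)
  }

  // Register data as a Parquet-backed temp table, run the body, then clean up.
  protected def withParquetTable[T <: Product](data: Seq[T], tableName: String)(f: => Unit): Unit = {
    createParquetTable(data, tableName)
    withTempTable(tableName)(f)
  }
}

In the real suite, val sqlContext = TestHive together with import sqlContext._ presumably supplies the concrete operations and the implicits that sql and checkAnswer need, so each test declares its own data and cleans up after itself instead of sharing the testsource table and dirname directory of the old version.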
