17 | 17 |
18 | 18 | package org.apache.spark.sql.parquet |
19 | 19 |
20 | | -import java.io.File |
21 | | - |
22 | | -import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} |
23 | | - |
24 | | -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row} |
25 | | -import org.apache.spark.sql.catalyst.types.{DataType, StringType, IntegerType} |
26 | | -import org.apache.spark.sql.{parquet, SchemaRDD} |
27 | | -import org.apache.spark.util.Utils |
28 | | - |
29 | | -// Implicits |
30 | | -import org.apache.spark.sql.hive.test.TestHive._ |
| 20 | +import org.apache.spark.sql.QueryTest |
| 21 | +import org.apache.spark.sql.catalyst.expressions.Row |
| 22 | +import org.apache.spark.sql.hive.test.TestHive |
31 | 23 |
32 | 24 | case class Cases(lower: String, UPPER: String) |
33 | 25 |
34 | | -class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { |
35 | | - |
36 | | - val dirname = Utils.createTempDir() |
37 | | - |
38 | | - var testRDD: SchemaRDD = null |
39 | | - |
40 | | - override def beforeAll() { |
41 | | - // write test data |
42 | | - ParquetTestData.writeFile() |
43 | | - testRDD = parquetFile(ParquetTestData.testDir.toString) |
44 | | - testRDD.registerTempTable("testsource") |
45 | | - } |
46 | | - |
47 | | - override def afterAll() { |
48 | | - Utils.deleteRecursively(ParquetTestData.testDir) |
49 | | - Utils.deleteRecursively(dirname) |
50 | | - reset() // drop all tables that were registered as part of the tests |
51 | | - } |
52 | | - |
53 | | - // in case tests are failing we delete before and after each test |
54 | | - override def beforeEach() { |
55 | | - Utils.deleteRecursively(dirname) |
56 | | - } |
| 26 | +class HiveParquetSuite extends QueryTest with ParquetTest { |
| 27 | + val sqlContext = TestHive |
57 | 28 |
58 | | - override def afterEach() { |
59 | | - Utils.deleteRecursively(dirname) |
60 | | - } |
| 29 | + import sqlContext._ |
61 | 30 |
62 | 31 | test("Case insensitive attribute names") { |
63 | | - val tempFile = File.createTempFile("parquet", "") |
64 | | - tempFile.delete() |
65 | | - sparkContext.parallelize(1 to 10) |
66 | | - .map(_.toString) |
67 | | - .map(i => Cases(i, i)) |
68 | | - .saveAsParquetFile(tempFile.getCanonicalPath) |
69 | | - |
70 | | - parquetFile(tempFile.getCanonicalPath).registerTempTable("cases") |
71 | | - sql("SELECT upper FROM cases").collect().map(_.getString(0)) === (1 to 10).map(_.toString) |
72 | | - sql("SELECT LOWER FROM cases").collect().map(_.getString(0)) === (1 to 10).map(_.toString) |
| 32 | + withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { |
| 33 | + val expected = (1 to 4).map(i => Row(i.toString)) |
| 34 | + checkAnswer(sql("SELECT upper FROM cases"), expected) |
| 35 | + checkAnswer(sql("SELECT LOWER FROM cases"), expected) |
| 36 | + } |
73 | 37 | } |
74 | 38 |
75 | 39 | test("SELECT on Parquet table") { |
76 | | - val rdd = sql("SELECT * FROM testsource").collect() |
77 | | - assert(rdd != null) |
78 | | - assert(rdd.forall(_.size == 6)) |
| 40 | + val data = (1 to 4).map(i => (i, s"val_$i")) |
| 41 | + withParquetTable(data, "t") { |
| 42 | + checkAnswer(sql("SELECT * FROM t"), data) |
| 43 | + } |
79 | 44 | } |
80 | 45 |
81 | 46 | test("Simple column projection + filter on Parquet table") { |
82 | | - val rdd = sql("SELECT myboolean, mylong FROM testsource WHERE myboolean=true").collect() |
83 | | - assert(rdd.size === 5, "Filter returned incorrect number of rows") |
84 | | - assert(rdd.forall(_.getBoolean(0)), "Filter returned incorrect Boolean field value") |
| 47 | + withParquetTable((1 to 4).map(i => (i % 2 == 0, i, s"val_$i")), "t") { |
| 48 | + checkAnswer( |
| 49 | + sql("SELECT `_1`, `_3` FROM t WHERE `_1` = true"), |
| 50 | + Seq(Row(true, "val_2"), Row(true, "val_4"))) |
| 51 | + } |
85 | 52 | } |
86 | 53 |
87 | 54 | test("Converting Hive to Parquet Table via saveAsParquetFile") { |
88 | | - sql("SELECT * FROM src").saveAsParquetFile(dirname.getAbsolutePath) |
89 | | - parquetFile(dirname.getAbsolutePath).registerTempTable("ptable") |
90 | | - val rddOne = sql("SELECT * FROM src").collect().sortBy(_.getInt(0)) |
91 | | - val rddTwo = sql("SELECT * from ptable").collect().sortBy(_.getInt(0)) |
92 | | - |
93 | | - compareRDDs(rddOne, rddTwo, "src (Hive)", Seq("key:Int", "value:String")) |
| 55 | + withTempPath { dir => |
| 56 | + sql("SELECT * FROM src").saveAsParquetFile(dir.getCanonicalPath) |
| 57 | + parquetFile(dir.getCanonicalPath).registerTempTable("p") |
| 58 | + withTempTable("p") { |
| 59 | + checkAnswer( |
| 60 | + sql("SELECT * FROM src ORDER BY key"), |
| 61 | + sql("SELECT * from p ORDER BY key").collect().toSeq) |
| 62 | + } |
| 63 | + } |
94 | 64 | } |
95 | 65 |
96 | | - test("INSERT OVERWRITE TABLE Parquet table") { |
97 | | - sql("SELECT * FROM testsource").saveAsParquetFile(dirname.getAbsolutePath) |
98 | | - parquetFile(dirname.getAbsolutePath).registerTempTable("ptable") |
99 | | - // let's do three overwrites for good measure |
100 | | - sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() |
101 | | - sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() |
102 | | - sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() |
103 | | - val rddCopy = sql("SELECT * FROM ptable").collect() |
104 | | - val rddOrig = sql("SELECT * FROM testsource").collect() |
105 | | - assert(rddCopy.size === rddOrig.size, "INSERT OVERWRITE changed size of table??") |
106 | | - compareRDDs(rddOrig, rddCopy, "testsource", ParquetTestData.testSchemaFieldNames) |
107 | | - } |
108 | 66 |
109 | | - private def compareRDDs(rddOne: Array[Row], rddTwo: Array[Row], tableName: String, fieldNames: Seq[String]) { |
110 | | - var counter = 0 |
111 | | - (rddOne, rddTwo).zipped.foreach { |
112 | | - (a,b) => (a,b).zipped.toArray.zipWithIndex.foreach { |
113 | | - case ((value_1, value_2), index) => |
114 | | - assert(value_1 === value_2, s"table $tableName row $counter field ${fieldNames(index)} don't match") |
| 67 | + test("INSERT OVERWRITE TABLE Parquet table") { |
| 68 | + withParquetTable((1 to 4).map(i => (i, s"val_$i")), "t") { |
| 69 | + withTempPath { file => |
| 70 | + sql("SELECT * FROM t LIMIT 1").saveAsParquetFile(file.getCanonicalPath) |
| 71 | + parquetFile(file.getCanonicalPath).registerTempTable("p") |
| 72 | + withTempTable("p") { |
| 73 | + // let's do three overwrites for good measure |
| 74 | + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") |
| 75 | + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") |
| 76 | + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") |
| 77 | + checkAnswer(sql("SELECT * FROM p"), sql("SELECT * FROM t").collect().toSeq) |
| 78 | + } |
115 | 79 | } |
116 | | - counter = counter + 1 |
117 | 80 | } |
118 | 81 | } |
119 | 82 | } |
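
For reference, the rewritten tests lean on loan-pattern helpers (withParquetTable, withTempPath, withTempTable) pulled in through the ParquetTest trait, which replace the suite's old manual beforeAll/afterAll/beforeEach/afterEach cleanup. Below is a minimal sketch of that pattern; it only illustrates the shape (set up a resource, run the test body, clean up in a finally). The trait name TempResourceHelpers and the registerParquetTable / dropTempTable hooks are placeholders assumed for this sketch, not Spark's actual API, which wires these helpers to the suite's SQLContext.

import java.io.File
import java.nio.file.Files

// Sketch only: the real ParquetTest trait binds these helpers to a SQLContext.
trait TempResourceHelpers {
  // Hooks the real trait delegates to the SQLContext; left abstract (assumed) here.
  protected def registerParquetTable(data: Seq[Product], tableName: String): Unit
  protected def dropTempTable(tableName: String): Unit

  // Runs `f` with a fresh, not-yet-created temporary path, then deletes it.
  protected def withTempPath(f: File => Unit): Unit = {
    val path = Files.createTempFile("spark-test", ".parquet").toFile
    path.delete()  // callers expect a path that does not exist yet
    try f(path) finally deleteRecursively(path)
  }

  // Guarantees the named temp table is dropped after `f`, even if it fails.
  protected def withTempTable(tableName: String)(f: => Unit): Unit = {
    try f finally dropTempTable(tableName)
  }

  // Registers `data` as a Parquet-backed temp table for the span of `f`.
  protected def withParquetTable(data: Seq[Product], tableName: String)(f: => Unit): Unit = {
    registerParquetTable(data, tableName)
    withTempTable(tableName)(f)
  }

  private def deleteRecursively(file: File): Unit = {
    Option(file.listFiles()).foreach(_.foreach(deleteRecursively))
    file.delete()
  }
}

In the suite above, TestHive plays the role of that SQLContext, so each test creates its table, runs its assertions through checkAnswer, and is guaranteed to drop the table and delete any temporary path even when an assertion fails.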