@@ -36,27 +36,27 @@ import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
import org.apache.spark.sql.types._

class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
-val episodesFile = "src/test/resources/episodes.avro"
-val testFile = "src/test/resources/test.avro"
+val episodesAvro = testFile("episodes.avro")
+val testAvro = testFile("test.avro")

override protected def beforeAll(): Unit = {
super.beforeAll()
spark.conf.set("spark.sql.files.maxPartitionBytes", 1024)
}

def checkReloadMatchesSaved(originalFile: String, newFile: String): Unit = {
-val originalEntries = spark.read.avro(testFile).collect()
+val originalEntries = spark.read.avro(testAvro).collect()
val newEntries = spark.read.avro(newFile)
checkAnswer(newEntries, originalEntries)
}

test("reading from multiple paths") {
-val df = spark.read.avro(episodesFile, episodesFile)
+val df = spark.read.avro(episodesAvro, episodesAvro)
assert(df.count == 16)
}

test("reading and writing partitioned data") {
-val df = spark.read.avro(episodesFile)
+val df = spark.read.avro(episodesAvro)
val fields = List("title", "air_date", "doctor")
for (field <- fields) {
withTempPath { dir =>
@@ -72,22 +72,22 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}

test("request no fields") {
-val df = spark.read.avro(episodesFile)
+val df = spark.read.avro(episodesAvro)
df.createOrReplaceTempView("avro_table")
assert(spark.sql("select count(*) from avro_table").collect().head === Row(8))
}

test("convert formats") {
withTempPath { dir =>
-val df = spark.read.avro(episodesFile)
+val df = spark.read.avro(episodesAvro)
df.write.parquet(dir.getCanonicalPath)
assert(spark.read.parquet(dir.getCanonicalPath).count() === df.count)
}
}

test("rearrange internal schema") {
withTempPath { dir =>
-val df = spark.read.avro(episodesFile)
+val df = spark.read.avro(episodesAvro)
df.select("doctor", "title").write.avro(dir.getCanonicalPath)
}
}
@@ -362,7 +362,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
val deflateDir = s"$dir/deflate"
val snappyDir = s"$dir/snappy"

-val df = spark.read.avro(testFile)
+val df = spark.read.avro(testAvro)
spark.conf.set(AVRO_COMPRESSION_CODEC, "uncompressed")
df.write.avro(uncompressDir)
spark.conf.set(AVRO_COMPRESSION_CODEC, "deflate")
@@ -381,49 +381,49 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}

test("dsl test") {
-val results = spark.read.avro(episodesFile).select("title").collect()
+val results = spark.read.avro(episodesAvro).select("title").collect()
assert(results.length === 8)
}

test("support of various data types") {
// This test uses data from test.avro. You can see the data and the schema of this file in
// test.json and test.avsc
-val all = spark.read.avro(testFile).collect()
+val all = spark.read.avro(testAvro).collect()
assert(all.length == 3)

-val str = spark.read.avro(testFile).select("string").collect()
+val str = spark.read.avro(testAvro).select("string").collect()
assert(str.map(_(0)).toSet.contains("Terran is IMBA!"))

-val simple_map = spark.read.avro(testFile).select("simple_map").collect()
+val simple_map = spark.read.avro(testAvro).select("simple_map").collect()
assert(simple_map(0)(0).getClass.toString.contains("Map"))
assert(simple_map.map(_(0).asInstanceOf[Map[String, Some[Int]]].size).toSet == Set(2, 0))

-val union0 = spark.read.avro(testFile).select("union_string_null").collect()
+val union0 = spark.read.avro(testAvro).select("union_string_null").collect()
assert(union0.map(_(0)).toSet == Set("abc", "123", null))

-val union1 = spark.read.avro(testFile).select("union_int_long_null").collect()
+val union1 = spark.read.avro(testAvro).select("union_int_long_null").collect()
assert(union1.map(_(0)).toSet == Set(66, 1, null))

-val union2 = spark.read.avro(testFile).select("union_float_double").collect()
+val union2 = spark.read.avro(testAvro).select("union_float_double").collect()
assert(
union2
.map(x => new java.lang.Double(x(0).toString))
.exists(p => Math.abs(p - Math.PI) < 0.001))

-val fixed = spark.read.avro(testFile).select("fixed3").collect()
+val fixed = spark.read.avro(testAvro).select("fixed3").collect()
assert(fixed.map(_(0).asInstanceOf[Array[Byte]]).exists(p => p(1) == 3))

-val enum = spark.read.avro(testFile).select("enum").collect()
+val enum = spark.read.avro(testAvro).select("enum").collect()
assert(enum.map(_(0)).toSet == Set("SPADES", "CLUBS", "DIAMONDS"))

-val record = spark.read.avro(testFile).select("record").collect()
+val record = spark.read.avro(testAvro).select("record").collect()
assert(record(0)(0).getClass.toString.contains("Row"))
assert(record.map(_(0).asInstanceOf[Row](0)).contains("TEST_STR123"))

-val array_of_boolean = spark.read.avro(testFile).select("array_of_boolean").collect()
+val array_of_boolean = spark.read.avro(testAvro).select("array_of_boolean").collect()
assert(array_of_boolean.map(_(0).asInstanceOf[Seq[Boolean]].size).toSet == Set(3, 1, 0))

-val bytes = spark.read.avro(testFile).select("bytes").collect()
+val bytes = spark.read.avro(testAvro).select("bytes").collect()
assert(bytes.map(_(0).asInstanceOf[Array[Byte]].length).toSet == Set(3, 1, 0))
}

@@ -432,7 +432,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
s"""
|CREATE TEMPORARY VIEW avroTable
|USING avro
|OPTIONS (path "$episodesFile")
|OPTIONS (path "${episodesAvro}")
""".stripMargin.replaceAll("\n", " "))

assert(spark.sql("SELECT * FROM avroTable").collect().length === 8)
@@ -443,8 +443,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
// get the same values back.
withTempPath { dir =>
val avroDir = s"$dir/avro"
-spark.read.avro(testFile).write.avro(avroDir)
-checkReloadMatchesSaved(testFile, avroDir)
+spark.read.avro(testAvro).write.avro(avroDir)
+checkReloadMatchesSaved(testAvro, avroDir)
}
}

@@ -457,8 +457,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
val parameters = Map("recordName" -> name, "recordNamespace" -> namespace)

val avroDir = tempDir + "/namedAvro"
-spark.read.avro(testFile).write.options(parameters).avro(avroDir)
-checkReloadMatchesSaved(testFile, avroDir)
+spark.read.avro(testAvro).write.options(parameters).avro(avroDir)
+checkReloadMatchesSaved(testAvro, avroDir)

// Look at the raw file and make sure it has the namespace info
val rawSaved = spark.sparkContext.textFile(avroDir)
@@ -532,10 +532,11 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}

test("support of globbed paths") {
val e1 = spark.read.avro("*/test/resources/episodes.avro").collect()
val resourceDir = testFile(".")
val e1 = spark.read.avro(resourceDir + "../*/episodes.avro").collect()
assert(e1.length == 8)

val e2 = spark.read.avro("src/*/*/episodes.avro").collect()
val e2 = spark.read.avro(resourceDir + "../../*/*/episodes.avro").collect()
assert(e2.length == 8)
}

@@ -574,8 +575,12 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
| }]
|}
""".stripMargin
-val result = spark.read.option(AvroFileFormat.AvroSchema, avroSchema).avro(testFile).collect()
-val expected = spark.read.avro(testFile).select("string").collect()
+val result = spark
+.read
+.option(AvroFileFormat.AvroSchema, avroSchema)
+.avro(testAvro)
+.collect()
+val expected = spark.read.avro(testAvro).select("string").collect()
assert(result.sameElements(expected))
}

@@ -593,7 +598,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|}
""".stripMargin
val result = spark.read.option(AvroFileFormat.AvroSchema, avroSchema)
-.avro(testFile).select("missingField").first
+.avro(testAvro).select("missingField").first
assert(result === Row("foo"))
}

@@ -632,7 +637,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
s"""
|CREATE TEMPORARY VIEW episodes
|USING avro
|OPTIONS (path "$episodesFile")
|OPTIONS (path "${episodesAvro}")
""".stripMargin.replaceAll("\n", " "))
spark.sql(
s"""
@@ -657,7 +662,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
test("test save and load") {
// Test if load works as expected
withTempPath { tempDir =>
-val df = spark.read.avro(episodesFile)
+val df = spark.read.avro(episodesAvro)
assert(df.count == 8)

val tempSaveDir = s"$tempDir/save/"
@@ -671,7 +676,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
test("test load with non-Avro file") {
// Test if load works as expected
withTempPath { tempDir =>
-val df = spark.read.avro(episodesFile)
+val df = spark.read.avro(episodesAvro)
assert(df.count == 8)

val tempSaveDir = s"$tempDir/save/"
@@ -701,10 +706,10 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
StructField("record", StructType(Seq(StructField("value_field", StringType, false))), false),
StructField("array_of_boolean", ArrayType(BooleanType), false),
StructField("bytes", BinaryType, true)))
-val withSchema = spark.read.schema(partialColumns).avro(testFile).collect()
+val withSchema = spark.read.schema(partialColumns).avro(testAvro).collect()
val withOutSchema = spark
.read
-.avro(testFile)
+.avro(testAvro)
.select("string", "simple_map", "complex_map", "union_string_null", "union_int_long_null",
"fixed3", "fixed2", "enum", "record", "array_of_boolean", "bytes")
.collect()
@@ -722,7 +727,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
StructField("non_exist_field", StringType, false),
StructField("non_exist_field2", StringType, false))),
false)))
-val withEmptyColumn = spark.read.schema(schema).avro(testFile).collect()
+val withEmptyColumn = spark.read.schema(schema).avro(testAvro).collect()

assert(withEmptyColumn.forall(_ == Row(null: String, Row(null: String, null: String))))
}
@@ -60,10 +60,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te
private val unescapedQuotesFile = "test-data/unescaped-quotes.csv"
private val valueMalformedFile = "test-data/value-malformed.csv"

-private def testFile(fileName: String): String = {
-  Thread.currentThread().getContextClassLoader.getResource(fileName).toString
-}

/** Verifies data and schema. */
private def verifyCars(
df: DataFrame,
@@ -48,10 +48,6 @@ class TestFileFilter extends PathFilter {
class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
import testImplicits._

-def testFile(fileName: String): String = {
-  Thread.currentThread().getContextClassLoader.getResource(fileName).toString
-}

test("Type promotion") {
def checkTypePromotion(expected: Any, actual: Any) {
assert(expected.getClass == actual.getClass,
@@ -21,10 +21,10 @@ import java.io.File

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
import org.apache.spark.sql.types.{StringType, StructType}

-class WholeTextFileSuite extends QueryTest with SharedSQLContext {
+class WholeTextFileSuite extends QueryTest with SharedSQLContext with SQLTestUtils {

// Hadoop's FileSystem caching does not use the Configuration as part of its cache key, which
// can cause Filesystem.get(Configuration) to return a cached instance created with a different
@@ -35,13 +35,10 @@ class WholeTextFileSuite extends QueryTest with SharedSQLContext {
protected override def sparkConf =
super.sparkConf.set("spark.hadoop.fs.file.impl.disable.cache", "true")

-private def testFile: String = {
-  Thread.currentThread().getContextClassLoader.getResource("test-data/text-suite.txt").toString
-}

test("reading text file with option wholetext=true") {
val df = spark.read.option("wholetext", "true")
.format("text").load(testFile)
.format("text")
.load(testFile("test-data/text-suite.txt"))
// schema
assert(df.schema == new StructType().add("value", StringType))

@@ -391,6 +391,13 @@ private[sql] trait SQLTestUtilsBase
val fs = hadoopPath.getFileSystem(spark.sessionState.newHadoopConf())
fs.makeQualified(hadoopPath).toUri
}

+/**
+ * Returns the full path to the given file in the resource folder.
+ */
+protected def testFile(fileName: String): String = {
+  Thread.currentThread().getContextClassLoader.getResource(fileName).toString
+}
}

private[sql] object SQLTestUtils {
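With the helper now defined once on SQLTestUtilsBase, any suite mixing in SQLTestUtils can resolve test resources from the classpath instead of hard-coding src/test/resources paths. A minimal sketch of the resulting pattern; the suite name and resource file below are hypothetical, not part of this change:

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}

// Hypothetical suite demonstrating the shared testFile(...) helper.
class ExampleSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
  test("read a classpath resource") {
    // Resolves the file through the context class loader, so the test works
    // regardless of the directory the JVM was launched from.
    val path = testFile("test-data/cars.csv") // hypothetical resource
    val df = spark.read.option("header", "true").csv(path)
    assert(df.count() > 0)
  }
}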