Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/main/scala/com/databricks/spark/csv/CsvRelation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,14 @@ case class CsvRelation protected[spark] (
try {
index = 0
while (index < schemaFields.length) {
rowArray(index) = TypeCast.castTo(tokens(index), schemaFields(index).dataType)
rowArray(index) = if (schemaFields(index).nullable && tokens(index) == ""){
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need to check for nullability for StringType. Basically I think the following logic would be simple and intuitive:

  • if StringType just return the token (don't need to check for nullability)
  • for any other type, if the token is "", return null; otherwise cast.

I suggest moving this logic (along with some comment that explains it) to TypeCast.castTo (maybe as a simple private method). This way you can add a few simple unit tests for it. Added benefit is we keep CsvRelation simpler.

schemaFields(index).dataType match {
case StringType => ""
case _ => null
}
} else {
TypeCast.castTo(tokens(index), schemaFields(index).dataType)
}
index = index + 1
}
Some(Row.fromSeq(rowArray))
Expand Down
4 changes: 4 additions & 0 deletions src/test/resources/null-numbers.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
name,age
alice,35
bob,
,24
16 changes: 16 additions & 0 deletions src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class CsvFastSuite extends FunSuite {
val carsFile8859 = "src/test/resources/cars_iso-8859-1.csv"
val carsTsvFile = "src/test/resources/cars.tsv"
val carsAltFile = "src/test/resources/cars-alternative.csv"
val nullNumbersFile = "src/test/resources/null-numbers.csv"
val emptyFile = "src/test/resources/empty.csv"
val escapeFile = "src/test/resources/escape.csv"
val tempEmptyDir = "target/test/empty2/"
Expand Down Expand Up @@ -387,4 +388,19 @@ class CsvFastSuite extends FunSuite {
assert(results.first().getInt(0) === 1997)

}

test("DSL test nullable fields") {
  // Empty tokens should become null for nullable non-string fields (age: IntegerType),
  // while an empty token for the non-nullable StringType field stays "".
  val results = new CsvParser()
    .withSchema(StructType(List(
      StructField("name", StringType, false),
      StructField("age", IntegerType, true))))
    .withUseHeader(true)
    .withParserLib("univocity")
    .csvFile(TestSQLContext, nullNumbersFile)
    .collect()

  // Use === (ScalaTest TripleEquals, as in the other tests of this suite)
  // so failures report both sides of the comparison.
  assert(results.head.toSeq === Seq("alice", 35))
  assert(results(1).toSeq === Seq("bob", null))
  assert(results(2).toSeq === Seq("", 24))
}
}
15 changes: 15 additions & 0 deletions src/test/scala/com/databricks/spark/csv/CsvSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class CsvSuite extends FunSuite {
val carsFile8859 = "src/test/resources/cars_iso-8859-1.csv"
val carsTsvFile = "src/test/resources/cars.tsv"
val carsAltFile = "src/test/resources/cars-alternative.csv"
val nullNumbersFile = "src/test/resources/null-numbers.csv"
val emptyFile = "src/test/resources/empty.csv"
val escapeFile = "src/test/resources/escape.csv"
val tempEmptyDir = "target/test/empty/"
Expand Down Expand Up @@ -392,4 +393,18 @@ class CsvSuite extends FunSuite {
assert(results.first().getInt(0) === 1997)

}

test("DSL test nullable fields") {
  // Same scenario as CsvFastSuite but with the default (commons) parser:
  // empty tokens become null for the nullable IntegerType field, and ""
  // for the non-nullable StringType field.
  val results = new CsvParser()
    .withSchema(StructType(List(
      StructField("name", StringType, false),
      StructField("age", IntegerType, true))))
    .withUseHeader(true)
    .csvFile(TestSQLContext, nullNumbersFile)
    .collect()

  // Use === (ScalaTest TripleEquals, consistent with the other tests in
  // this suite) so assertion failures show both compared values.
  assert(results.head.toSeq === Seq("alice", 35))
  assert(results(1).toSeq === Seq("bob", null))
  assert(results(2).toSeq === Seq("", 24))
}
}