diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index dec118330aec..fef788f78359 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -76,12 +76,12 @@ private[libsvm] class LibSVMFileFormat
 
   override def toString: String = "LibSVM"
 
-  private def verifySchema(dataSchema: StructType): Unit = {
+  private def verifySchema(dataSchema: StructType, forWriting: Boolean): Unit = {
     if (
       dataSchema.size != 2 ||
         !dataSchema(0).dataType.sameType(DataTypes.DoubleType) ||
         !dataSchema(1).dataType.sameType(new VectorUDT()) ||
-        !(dataSchema(1).metadata.getLong(LibSVMOptions.NUM_FEATURES).toInt > 0)
+        !(forWriting || dataSchema(1).metadata.getLong(LibSVMOptions.NUM_FEATURES).toInt > 0)
     ) {
       throw new IOException(s"Illegal schema for libsvm data, schema=$dataSchema")
     }
@@ -119,7 +119,7 @@ private[libsvm] class LibSVMFileFormat
       job: Job,
       options: Map[String, String],
       dataSchema: StructType): OutputWriterFactory = {
-    verifySchema(dataSchema)
+    verifySchema(dataSchema, true)
     new OutputWriterFactory {
       override def newInstance(
           path: String,
@@ -142,7 +142,7 @@ private[libsvm] class LibSVMFileFormat
       filters: Seq[Filter],
       options: Map[String, String],
       hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
-    verifySchema(dataSchema)
+    verifySchema(dataSchema, false)
     val numFeatures = dataSchema("features").metadata.getLong(LibSVMOptions.NUM_FEATURES).toInt
     assert(numFeatures > 0)
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index a67e49d54e14..3eabff434e8d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -19,13 +19,16 @@ package org.apache.spark.ml.source.libsvm
 
 import java.io.{File, IOException}
 import java.nio.charset.StandardCharsets
+import java.util.List
 
 import com.google.common.io.Files
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
+import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{Row, SaveMode}
+import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 import org.apache.spark.util.Utils
 
 
@@ -44,14 +47,14 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
       """
         |0 2:4.0 4:5.0 6:6.0
       """.stripMargin
-    val dir = Utils.createDirectory(tempDir.getCanonicalPath, "data")
+    val dir = Utils.createTempDir()
     val succ = new File(dir, "_SUCCESS")
     val file0 = new File(dir, "part-00000")
     val file1 = new File(dir, "part-00001")
     Files.write("", succ, StandardCharsets.UTF_8)
     Files.write(lines0, file0, StandardCharsets.UTF_8)
     Files.write(lines1, file1, StandardCharsets.UTF_8)
-    path = dir.toURI.toString
+    path = dir.getPath
   }
 
   override def afterAll(): Unit = {
@@ -108,12 +111,12 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   test("write libsvm data and read it again") {
     val df = spark.read.format("libsvm").load(path)
-    val tempDir2 = new File(tempDir, "read_write_test")
-    val writepath = tempDir2.toURI.toString
+    val writePath = Utils.createTempDir().getPath
+    // TODO: Remove requirement to coalesce by supporting multiple reads.
-    df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writepath)
+    df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writePath)
 
-    val df2 = spark.read.format("libsvm").load(writepath)
+    val df2 = spark.read.format("libsvm").load(writePath)
     val row1 = df2.first()
     val v = row1.getAs[SparseVector](1)
     assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
 
@@ -126,6 +129,27 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("write libsvm data from scratch and read it again") {
+    val rawData = new java.util.ArrayList[Row]()
+    rawData.add(Row(1.0, Vectors.sparse(3, Seq((0, 2.0), (1, 3.0)))))
+    rawData.add(Row(4.0, Vectors.sparse(3, Seq((0, 5.0), (2, 6.0)))))
+
+    val struct = StructType(
+      StructField("labelFoo", DoubleType, false) ::
+      StructField("featuresBar", VectorType, false) :: Nil
+    )
+    val df = spark.sqlContext.createDataFrame(rawData, struct)
+
+    val writePath = Utils.createTempDir().getPath
+
+    df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writePath)
+
+    val df2 = spark.read.format("libsvm").load(writePath)
+    val row1 = df2.first()
+    val v = row1.getAs[SparseVector](1)
+    assert(v == Vectors.sparse(3, Seq((0, 2.0), (1, 3.0))))
+  }
+
   test("select features from libsvm relation") {
     val df = spark.read.format("libsvm").load(path)
     df.select("features").rdd.map { case Row(d: Vector) => d }.first
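
For context, a minimal end-to-end sketch (not part of the patch) of the user-facing behavior this change enables: writing libsvm data from a DataFrame built from scratch, whose vector column carries no numFeatures metadata. Previously verifySchema rejected such a schema even on the write path. The local SparkSession, column names, and output path below are illustrative assumptions, not taken from the patch.

```scala
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.{Row, SaveMode, SparkSession}
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

val spark = SparkSession.builder().master("local[2]").getOrCreate()

// A DataFrame built from scratch: the vector column has no numFeatures
// metadata, which verifySchema used to require even for writes.
val rows = java.util.Arrays.asList(
  Row(1.0, Vectors.sparse(3, Seq((0, 2.0), (1, 3.0)))),
  Row(4.0, Vectors.sparse(3, Seq((0, 5.0), (2, 6.0)))))
val schema = StructType(
  StructField("label", DoubleType, nullable = false) ::
  StructField("features", VectorType, nullable = false) :: Nil)
val df = spark.createDataFrame(rows, schema)

// Write path: verifySchema(dataSchema, forWriting = true) skips the
// numFeatures check, so this no longer throws IOException.
df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save("/tmp/libsvm-out")

// Read path: verifySchema(dataSchema, forWriting = false) still enforces
// numFeatures > 0, but by then the data source has inferred it from the files.
val df2 = spark.read.format("libsvm").load("/tmp/libsvm-out")
```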