diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 322ef93473da..4c478a5477c0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -53,6 +53,7 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTree
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLUtils}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.apache.spark.sql.types.LongType
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 
@@ -1142,12 +1143,21 @@ private[python] class PythonMLLibAPI extends Serializable {
     new RowMatrix(rows.rdd, numRows, numCols)
   }
 
+  def createRowMatrix(df: DataFrame, numRows: Long, numCols: Int): RowMatrix = {
+    require(df.schema.length == 1 && df.schema.head.dataType.getClass == classOf[VectorUDT],
+      "DataFrame must have a single vector type column")
+    new RowMatrix(df.rdd.map { case Row(vector: Vector) => vector }, numRows, numCols)
+  }
+
   /**
    * Wrapper around IndexedRowMatrix constructor.
    */
   def createIndexedRowMatrix(rows: DataFrame, numRows: Long, numCols: Int): IndexedRowMatrix = {
     // We use DataFrames for serialization of IndexedRows from Python,
     // so map each Row in the DataFrame back to an IndexedRow.
+    require(rows.schema.length == 2 && rows.schema.head.dataType == LongType &&
+      rows.schema(1).dataType.getClass == classOf[VectorUDT],
+      "DataFrame must consist of a long type index column and a vector type column")
     val indexedRows = rows.rdd.map {
       case Row(index: Long, vector: Vector) => IndexedRow(index, vector)
     }
diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py
index b7f09782be9d..56701758c89c 100644
--- a/python/pyspark/mllib/linalg/distributed.py
+++ b/python/pyspark/mllib/linalg/distributed.py
@@ -30,6 +30,7 @@
 from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
 from pyspark.mllib.linalg import _convert_to_vector, DenseMatrix, Matrix, QRDecomposition
 from pyspark.mllib.stat import MultivariateStatisticalSummary
+from pyspark.sql import DataFrame
 from pyspark.storagelevel import StorageLevel
 
 
@@ -57,7 +58,8 @@ class RowMatrix(DistributedMatrix):
     Represents a row-oriented distributed Matrix with no meaningful
     row indices.
 
-    :param rows: An RDD of vectors.
+    :param rows: An RDD or DataFrame of vectors. If a DataFrame is provided, it must have a single
+                 vector typed column.
     :param numRows: Number of rows in the matrix. A non-positive
                     value means unknown, at which point the number
                     of rows will be determined by the number of
@@ -73,7 +75,7 @@ def __init__(self, rows, numRows=0, numCols=0):
         """
         Create a wrapper over a Java RowMatrix.
 
-        Publicly, we require that `rows` be an RDD. However, for
+        Publicly, we require that `rows` be an RDD or DataFrame. However, for
         internal usage, `rows` can also be a Java RowMatrix
         object, in which case we can wrap it directly. This
         assists in clean matrix conversions.
@@ -94,6 +96,8 @@ def __init__(self, rows, numRows=0, numCols=0):
         if isinstance(rows, RDD):
             rows = rows.map(_convert_to_vector)
             java_matrix = callMLlibFunc("createRowMatrix", rows, long(numRows), int(numCols))
+        elif isinstance(rows, DataFrame):
+            java_matrix = callMLlibFunc("createRowMatrix", rows, long(numRows), int(numCols))
         elif (isinstance(rows, JavaObject)
               and rows.getClass().getSimpleName() == "RowMatrix"):
             java_matrix = rows
@@ -461,7 +465,8 @@ class IndexedRowMatrix(DistributedMatrix):
     """
     Represents a row-oriented distributed Matrix with indexed rows.
 
-    :param rows: An RDD of IndexedRows or (long, vector) tuples.
+    :param rows: An RDD of IndexedRows or (long, vector) tuples or a DataFrame consisting of a
+                 long typed column of indices and a vector typed column.
     :param numRows: Number of rows in the matrix. A non-positive
                     value means unknown, at which point the number
                     of rows will be determined by the max row
@@ -477,7 +482,7 @@ def __init__(self, rows, numRows=0, numCols=0):
         """
         Create a wrapper over a Java IndexedRowMatrix.
 
-        Publicly, we require that `rows` be an RDD. However, for
+        Publicly, we require that `rows` be an RDD or DataFrame. However, for
         internal usage, `rows` can also be a Java IndexedRowMatrix
         object, in which case we can wrap it directly. This
         assists in clean matrix conversions.
@@ -506,6 +511,8 @@ def __init__(self, rows, numRows=0, numCols=0):
             # IndexedRows on the Scala side.
             java_matrix = callMLlibFunc("createIndexedRowMatrix", rows.toDF(),
                                         long(numRows), int(numCols))
+        elif isinstance(rows, DataFrame):
+            java_matrix = callMLlibFunc("createIndexedRowMatrix", rows, long(numRows), int(numCols))
         elif (isinstance(rows, JavaObject)
               and rows.getClass().getSimpleName() == "IndexedRowMatrix"):
             java_matrix = rows
diff --git a/python/pyspark/mllib/tests/test_linalg.py b/python/pyspark/mllib/tests/test_linalg.py
index 703aed2fe16a..588fc6259317 100644
--- a/python/pyspark/mllib/tests/test_linalg.py
+++ b/python/pyspark/mllib/tests/test_linalg.py
@@ -25,10 +25,15 @@
 from pyspark.serializers import PickleSerializer
 from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector, \
     DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
+from pyspark.mllib.linalg.distributed import RowMatrix, IndexedRowMatrix
 from pyspark.mllib.regression import LabeledPoint
+from pyspark.sql import Row
 from pyspark.testing.mllibutils import MLlibTestCase
 from pyspark.testing.utils import have_scipy
 
+if sys.version >= '3':
+    long = int
+
 
 class VectorTests(MLlibTestCase):
 
@@ -431,6 +436,24 @@ def test_infer_schema(self):
             else:
                 raise TypeError("expecting a vector but got %r of type %r" % (v, type(v)))
 
+    def test_row_matrix_from_dataframe(self):
+        from pyspark.sql.utils import IllegalArgumentException
+        df = self.spark.createDataFrame([Row(Vectors.dense(1))])
+        row_matrix = RowMatrix(df)
+        self.assertEqual(row_matrix.numRows(), 1)
+        self.assertEqual(row_matrix.numCols(), 1)
+        with self.assertRaises(IllegalArgumentException):
+            RowMatrix(df.selectExpr("'monkey'"))
+
+    def test_indexed_row_matrix_from_dataframe(self):
+        from pyspark.sql.utils import IllegalArgumentException
+        df = self.spark.createDataFrame([Row(long(0), Vectors.dense(1))])
+        matrix = IndexedRowMatrix(df)
+        self.assertEqual(matrix.numRows(), 1)
+        self.assertEqual(matrix.numCols(), 1)
+        with self.assertRaises(IllegalArgumentException):
+            IndexedRowMatrix(df.drop("_1"))
+
 
 class MatrixUDTTests(MLlibTestCase):
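A minimal usage sketch of the two code paths this patch adds, mirroring the new tests. It is not part of the patch itself; the local SparkSession bootstrap and the column names "index"/"vector" are illustrative assumptions (any long-typed index column plus a vector-typed column satisfies the Scala-side require checks).

# Usage sketch only; mirrors the added tests. Assumed for illustration:
# the local SparkSession setup and the "index"/"vector" column names.
from pyspark.sql import Row, SparkSession
from pyspark.mllib.linalg import Vectors
from pyspark.mllib.linalg.distributed import RowMatrix, IndexedRowMatrix

spark = SparkSession.builder.master("local[2]").getOrCreate()

# RowMatrix: the DataFrame must have exactly one vector-typed column.
vec_df = spark.createDataFrame([Row(Vectors.dense(1.0, 2.0)),
                                Row(Vectors.dense(3.0, 4.0))])
row_matrix = RowMatrix(vec_df)
assert (row_matrix.numRows(), row_matrix.numCols()) == (2, 2)

# IndexedRowMatrix: a long-typed index column followed by a vector column.
idx_df = spark.createDataFrame(
    [(0, Vectors.dense(1.0, 2.0)), (1, Vectors.dense(3.0, 4.0))],
    ["index", "vector"])
indexed = IndexedRowMatrix(idx_df)
assert (indexed.numRows(), indexed.numCols()) == (2, 2)

# A non-conforming schema fails the Scala-side require(...) and surfaces as
# pyspark.sql.utils.IllegalArgumentException, as exercised in the tests.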