mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -53,6 +53,7 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel, RandomForestModel}
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLUtils}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.apache.spark.sql.types.LongType
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils

@@ -1142,12 +1143,21 @@ private[python] class PythonMLLibAPI extends Serializable {
     new RowMatrix(rows.rdd, numRows, numCols)
   }

+  def createRowMatrix(df: DataFrame, numRows: Long, numCols: Int): RowMatrix = {
+    require(df.schema.length == 1 && df.schema.head.dataType.getClass == classOf[VectorUDT],
+      "DataFrame must have a single vector type column")
+    new RowMatrix(df.rdd.map { case Row(vector: Vector) => vector }, numRows, numCols)
+  }
+
   /**
    * Wrapper around IndexedRowMatrix constructor.
    */
   def createIndexedRowMatrix(rows: DataFrame, numRows: Long, numCols: Int): IndexedRowMatrix = {
     // We use DataFrames for serialization of IndexedRows from Python,
     // so map each Row in the DataFrame back to an IndexedRow.
+    require(rows.schema.length == 2 && rows.schema.head.dataType == LongType &&
+      rows.schema(1).dataType.getClass == classOf[VectorUDT],
+      "DataFrame must consist of a long type index column and a vector type column")
     val indexedRows = rows.rdd.map {
       case Row(index: Long, vector: Vector) => IndexedRow(index, vector)
     }
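For orientation, here is a hedged sketch of how these Scala-side `require` checks surface in Python: a schema mismatch propagates through Py4J and is raised to the caller as `IllegalArgumentException`. The sketch assumes an active `SparkSession` bound to the name `spark`; the column names are arbitrary.

```python
from pyspark.mllib.linalg import Vectors
from pyspark.mllib.linalg.distributed import RowMatrix
from pyspark.sql.utils import IllegalArgumentException

# A single vector-typed column satisfies the require() in createRowMatrix.
ok = spark.createDataFrame([(Vectors.dense([1.0, 2.0]),)], ["features"])
mat = RowMatrix(ok)

# A string column fails the schema check; the Scala-side message
# ("DataFrame must have a single vector type column") arrives here
# wrapped in IllegalArgumentException.
bad = spark.createDataFrame([("monkey",)], ["s"])
try:
    RowMatrix(bad)
except IllegalArgumentException as e:
    print(e)
```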
15 changes: 11 additions & 4 deletions python/pyspark/mllib/linalg/distributed.py
@@ -30,6 +30,7 @@
 from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
 from pyspark.mllib.linalg import _convert_to_vector, DenseMatrix, Matrix, QRDecomposition
 from pyspark.mllib.stat import MultivariateStatisticalSummary
+from pyspark.sql import DataFrame
 from pyspark.storagelevel import StorageLevel


@@ -57,7 +58,8 @@ class RowMatrix(DistributedMatrix):
     Represents a row-oriented distributed Matrix with no meaningful
     row indices.

-    :param rows: An RDD of vectors.
+    :param rows: An RDD or DataFrame of vectors. If a DataFrame is provided, it must have a single
+                 vector typed column.
     :param numRows: Number of rows in the matrix. A non-positive
                     value means unknown, at which point the number
                     of rows will be determined by the number of
@@ -73,7 +75,7 @@ def __init__(self, rows, numRows=0, numCols=0):

         Create a wrapper over a Java RowMatrix.

-        Publicly, we require that `rows` be an RDD. However, for
+        Publicly, we require that `rows` be an RDD or DataFrame. However, for
         internal usage, `rows` can also be a Java RowMatrix
         object, in which case we can wrap it directly. This
         assists in clean matrix conversions.
@@ -94,6 +96,8 @@ def __init__(self, rows, numRows=0, numCols=0):
         if isinstance(rows, RDD):
             rows = rows.map(_convert_to_vector)
             java_matrix = callMLlibFunc("createRowMatrix", rows, long(numRows), int(numCols))
+        elif isinstance(rows, DataFrame):
+            java_matrix = callMLlibFunc("createRowMatrix", rows, long(numRows), int(numCols))
         elif (isinstance(rows, JavaObject)
               and rows.getClass().getSimpleName() == "RowMatrix"):
             java_matrix = rows
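With the new `DataFrame` branch in place, a RowMatrix can be built without first dropping to an RDD. A minimal sketch, again assuming a `SparkSession` named `spark`; the column name `features` is arbitrary since only the schema is checked:

```python
from pyspark.mllib.linalg import Vectors
from pyspark.mllib.linalg.distributed import RowMatrix

# One vector-typed column; the DataFrame is handed to the Scala wrapper as-is.
df = spark.createDataFrame(
    [(Vectors.dense([1.0, 2.0]),), (Vectors.dense([3.0, 4.0]),)],
    ["features"])
mat = RowMatrix(df)
print(mat.numRows(), mat.numCols())  # 2 2
```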
@@ -461,7 +465,8 @@ class IndexedRowMatrix(DistributedMatrix):
"""
Represents a row-oriented distributed Matrix with indexed rows.

:param rows: An RDD of IndexedRows or (long, vector) tuples.
:param rows: An RDD of IndexedRows or (long, vector) tuples or a DataFrame consisting of a
long typed column of indices and a vector typed column.
:param numRows: Number of rows in the matrix. A non-positive
value means unknown, at which point the number
of rows will be determined by the max row
@@ -477,7 +482,7 @@ def __init__(self, rows, numRows=0, numCols=0):

         Create a wrapper over a Java IndexedRowMatrix.

-        Publicly, we require that `rows` be an RDD. However, for
+        Publicly, we require that `rows` be an RDD or DataFrame. However, for
         internal usage, `rows` can also be a Java IndexedRowMatrix
         object, in which case we can wrap it directly. This
         assists in clean matrix conversions.
@@ -506,6 +511,8 @@ def __init__(self, rows, numRows=0, numCols=0):
             # IndexedRows on the Scala side.
             java_matrix = callMLlibFunc("createIndexedRowMatrix", rows.toDF(),
                                         long(numRows), int(numCols))
+        elif isinstance(rows, DataFrame):
+            java_matrix = callMLlibFunc("createIndexedRowMatrix", rows, long(numRows), int(numCols))
         elif (isinstance(rows, JavaObject)
               and rows.getClass().getSimpleName() == "IndexedRowMatrix"):
             java_matrix = rows
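The `IndexedRowMatrix` path is analogous, except the DataFrame must carry exactly two columns in order: a long index and a vector. A hedged sketch under the same `spark` assumption (Python ints infer to `LongType`, which is what the Scala `require` expects):

```python
from pyspark.mllib.linalg import Vectors
from pyspark.mllib.linalg.distributed import IndexedRowMatrix

# Column order matters: (long index, vector), matching the Scala-side check.
df = spark.createDataFrame(
    [(0, Vectors.dense([1.0, 2.0])), (1, Vectors.dense([3.0, 4.0]))],
    ["index", "features"])
mat = IndexedRowMatrix(df)
print(mat.numRows(), mat.numCols())  # 2 2
```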
23 changes: 23 additions & 0 deletions python/pyspark/mllib/tests/test_linalg.py
@@ -25,10 +25,15 @@
 from pyspark.serializers import PickleSerializer
 from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector, \
     DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
+from pyspark.mllib.linalg.distributed import RowMatrix, IndexedRowMatrix
 from pyspark.mllib.regression import LabeledPoint
+from pyspark.sql import Row
 from pyspark.testing.mllibutils import MLlibTestCase
 from pyspark.testing.utils import have_scipy

+if sys.version >= '3':
+    long = int
+

 class VectorTests(MLlibTestCase):

@@ -431,6 +436,24 @@ def test_infer_schema(self):
         else:
             raise TypeError("expecting a vector but got %r of type %r" % (v, type(v)))

+    def test_row_matrix_from_dataframe(self):
+        from pyspark.sql.utils import IllegalArgumentException
+        df = self.spark.createDataFrame([Row(Vectors.dense(1))])
+        row_matrix = RowMatrix(df)
+        self.assertEqual(row_matrix.numRows(), 1)
+        self.assertEqual(row_matrix.numCols(), 1)
+        with self.assertRaises(IllegalArgumentException):
+            RowMatrix(df.selectExpr("'monkey'"))
+
+    def test_indexed_row_matrix_from_dataframe(self):
+        from pyspark.sql.utils import IllegalArgumentException
+        df = self.spark.createDataFrame([Row(long(0), Vectors.dense(1))])
+        matrix = IndexedRowMatrix(df)
+        self.assertEqual(matrix.numRows(), 1)
+        self.assertEqual(matrix.numCols(), 1)
+        with self.assertRaises(IllegalArgumentException):
+            IndexedRowMatrix(df.drop("_1"))
+

 class MatrixUDTTests(MLlibTestCase):
