Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 147 additions & 1 deletion python/pyspark/mllib/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,46 @@ def __init__(self, ar):
ar = ar.astype(np.float64)
self.array = ar

@staticmethod
def parse(s):
"""
Parse string representation back into the DenseVector.

>>> DenseVector.parse(' [ 0.0,1.0,2.0, 3.0]')
DenseVector([0.0, 1.0, 2.0, 3.0])
"""
start = s.find('[')
if start == -1:
raise ValueError("Array should start with '['.")
end = s.find(']')
if end == -1:
raise ValueError("Array should end with ']'.")
s = s[start + 1: end]

try:
values = [float(val) for val in s.split(',')]
except ValueError:
raise ValueError("Unable to parse values from %s" % s)
return DenseVector(values)

def __reduce__(self):
return DenseVector, (self.array.tostring(),)

def numNonzeros(self):
return np.count_nonzero(self.array)

def norm(self, p):
"""
Calculte the norm of a DenseVector.

>>> a = DenseVector([0, -1, 2, -3])
>>> a.norm(2)
3.7...
>>> a.norm(1)
6.0
"""
return np.linalg.norm(self.array, p)

def dot(self, other):
"""
Compute the dot product of two Vectors. We support
Expand Down Expand Up @@ -387,8 +424,74 @@ def __init__(self, size, *args):
if self.indices[i] >= self.indices[i + 1]:
raise TypeError("indices array must be sorted")

def numNonzeros(self):
return np.count_nonzero(self.values)

def norm(self, p):
"""
Calculte the norm of a SparseVector.

>>> a = SparseVector(4, [0, 1], [3., -4.])
>>> a.norm(1)
7.0
>>> a.norm(2)
5.0
"""
return np.linalg.norm(self.values, p)

def __reduce__(self):
return (SparseVector, (self.size, self.indices.tostring(), self.values.tostring()))
return (
SparseVector,
(self.size, self.indices.tostring(), self.values.tostring()))

@staticmethod
def parse(s):
"""
Parse string representation back into the DenseVector.

>>> SparseVector.parse(' (4, [0,1 ],[ 4.0,5.0] )')
SparseVector(4, {0: 4.0, 1: 5.0})
"""
start = s.find('(')
if start == -1:
raise ValueError("Tuple should start with '('")
end = s.find(')')
if start == -1:
raise ValueError("Tuple should end with ')'")
s = s[start + 1: end].strip()

size = s[: s.find(',')]
try:
size = int(size)
except ValueError:
raise ValueError("Cannot parse size %s." % size)

ind_start = s.find('[')
if ind_start == -1:
raise ValueError("Indices array should start with '['.")
ind_end = s.find(']')
if ind_end == -1:
raise ValueError("Indices array should end with ']'")
new_s = s[ind_start + 1: ind_end]
ind_list = new_s.split(',')
try:
indices = [int(ind) for ind in ind_list]
except ValueError:
raise ValueError("Unable to parse indices from %s." % new_s)
s = s[ind_end + 1:].strip()

val_start = s.find('[')
if val_start == -1:
raise ValueError("Values array should start with '['.")
val_end = s.find(']')
if val_end == -1:
raise ValueError("Values array should end with ']'.")
val_list = s[val_start + 1: val_end].split(',')
try:
values = [float(val) for val in val_list]
except ValueError:
raise ValueError("Unable to parse values from %s." % s)
return SparseVector(size, indices, values)

def dot(self, other):
"""
Expand Down Expand Up @@ -633,6 +736,49 @@ def stringify(vector):
"""
return str(vector)

@staticmethod
def squared_distance(v1, v2):
"""
Squared distance between two vectors.
a and b can be of type SparseVector, DenseVector, np.ndarray
or array.array.

>>> a = Vectors.sparse(4, [(0, 1), (3, 4)])
>>> b = Vectors.dense([2, 5, 4, 1])
>>> a.squared_distance(b)
51.0
"""
v1, v2 = _convert_to_vector(v1), _convert_to_vector(v2)
return v1.squared_distance(v2)

@staticmethod
def norm(vector, p):
"""
Find norm of the given vector.
"""
return _convert_to_vector(vector).norm(p)

@staticmethod
def parse(s):
"""Parse a string representation back into the Vector.

>>> Vectors.parse('[2,1,2 ]')
DenseVector([2.0, 1.0, 2.0])
>>> Vectors.parse(' ( 100, [0], [2])')
SparseVector(100, {0: 2.0})
"""
if s.find('(') == -1 and s.find('[') != -1:
return DenseVector.parse(s)
elif s.find('(') != -1:
return SparseVector.parse(s)
else:
raise ValueError(
"Cannot find tokens '[' or '(' from the input string.")

@staticmethod
def zeros(size):
return DenseVector(np.zeros(size))


class Matrix(object):
"""
Expand Down
25 changes: 24 additions & 1 deletion python/pyspark/mllib/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import tempfile
import array as pyarray

from numpy import array, array_equal, zeros
from numpy import array, array_equal, zeros, inf
from py4j.protocol import Py4JJavaError

if sys.version_info[:2] <= (2, 6):
Expand Down Expand Up @@ -220,6 +220,29 @@ def test_dense_matrix_is_transposed(self):
self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5]))
self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9]))

def test_parse_vector(self):
a = DenseVector([3, 4, 6, 7])
self.assertTrue(str(a), '[3.0,4.0,6.0,7.0]')
self.assertTrue(Vectors.parse(str(a)), a)
a = SparseVector(4, [0, 2], [3, 4])
self.assertTrue(str(a), '(4,[0,2],[3.0,4.0])')
self.assertTrue(Vectors.parse(str(a)), a)
a = SparseVector(10, [0, 1], [4, 5])
self.assertTrue(SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a)

def test_norms(self):
a = DenseVector([0, 2, 3, -1])
self.assertAlmostEqual(a.norm(2), 3.742, 3)
self.assertTrue(a.norm(1), 6)
self.assertTrue(a.norm(inf), 3)
a = SparseVector(4, [0, 2], [3, -4])
self.assertAlmostEqual(a.norm(2), 5)
self.assertTrue(a.norm(1), 7)
self.assertTrue(a.norm(inf), 4)

tmp = SparseVector(4, [0, 2], [3, 0])
self.assertEqual(tmp.numNonzeros(), 1)


class ListTests(MLlibTestCase):

Expand Down