-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-9793] [MLlib] [PySpark] PySpark DenseVector, SparseVector implement __eq__ and __hash__ correctly #8166
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1e9d1bc
7489a44
83f51ed
fca0f5a
d3f8c14
3b8ac7a
b58d1bb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,6 +25,7 @@ | |
|
|
||
| import sys | ||
| import array | ||
| import struct | ||
|
|
||
| if sys.version >= '3': | ||
| basestring = str | ||
|
|
@@ -122,6 +123,13 @@ def _format_float_list(l): | |
| return [_format_float(x) for x in l] | ||
|
|
||
|
|
||
| def _double_to_long_bits(value): | ||
| if np.isnan(value): | ||
| value = float('nan') | ||
| # pack double into 64 bits, then unpack as long int | ||
| return struct.unpack('Q', struct.pack('d', value))[0] | ||
|
|
||
|
|
||
| class VectorUDT(UserDefinedType): | ||
| """ | ||
| SQL user-defined type (UDT) for Vector. | ||
|
|
@@ -404,11 +412,31 @@ def __repr__(self): | |
| return "DenseVector([%s])" % (', '.join(_format_float(i) for i in self.array)) | ||
|
|
||
| def __eq__(self, other): | ||
| return isinstance(other, DenseVector) and np.array_equal(self.array, other.array) | ||
| if isinstance(other, DenseVector): | ||
| return np.array_equal(self.array, other.array) | ||
| elif isinstance(other, SparseVector): | ||
| if len(self) != other.size: | ||
| return False | ||
| return Vectors._equals(list(xrange(len(self))), self.array, other.indices, other.values) | ||
| return False | ||
|
|
||
| def __ne__(self, other): | ||
| return not self == other | ||
|
|
||
| def __hash__(self): | ||
| size = len(self) | ||
| result = 31 + size | ||
| nnz = 0 | ||
| i = 0 | ||
| while i < size and nnz < 128: | ||
| if self.array[i] != 0: | ||
| result = 31 * result + i | ||
| bits = _double_to_long_bits(self.array[i]) | ||
| result = 31 * result + (bits ^ (bits >> 32)) | ||
| nnz += 1 | ||
| i += 1 | ||
| return result | ||
|
|
||
| def __getattr__(self, item): | ||
| return getattr(self.array, item) | ||
|
|
||
|
|
@@ -704,20 +732,14 @@ def __repr__(self): | |
| return "SparseVector({0}, {{{1}}})".format(self.size, entries) | ||
|
|
||
| def __eq__(self, other): | ||
| """ | ||
| Test SparseVectors for equality. | ||
|
|
||
| >>> v1 = SparseVector(4, [(1, 1.0), (3, 5.5)]) | ||
| >>> v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) | ||
| >>> v1 == v2 | ||
| True | ||
| >>> v1 != v2 | ||
| False | ||
| """ | ||
| return (isinstance(other, self.__class__) | ||
| and other.size == self.size | ||
| and np.array_equal(other.indices, self.indices) | ||
| and np.array_equal(other.values, self.values)) | ||
| if isinstance(other, SparseVector): | ||
| return other.size == self.size and np.array_equal(other.indices, self.indices) \ | ||
| and np.array_equal(other.values, self.values) | ||
| elif isinstance(other, DenseVector): | ||
| if self.size != len(other): | ||
| return False | ||
| return Vectors._equals(self.indices, self.values, list(xrange(len(other))), other.array) | ||
| return False | ||
|
|
||
| def __getitem__(self, index): | ||
| inds = self.indices | ||
|
|
@@ -739,6 +761,19 @@ def __getitem__(self, index): | |
| def __ne__(self, other): | ||
| return not self.__eq__(other) | ||
|
|
||
| def __hash__(self): | ||
| result = 31 + self.size | ||
| nnz = 0 | ||
| i = 0 | ||
| while i < len(self.values) and nnz < 128: | ||
| if self.values[i] != 0: | ||
| result = 31 * result + int(self.indices[i]) | ||
| bits = _double_to_long_bits(self.values[i]) | ||
| result = 31 * result + (bits ^ (bits >> 32)) | ||
| nnz += 1 | ||
| i += 1 | ||
| return result | ||
|
|
||
|
|
||
| class Vectors(object): | ||
|
|
||
|
|
@@ -841,6 +876,31 @@ def parse(s): | |
| def zeros(size): | ||
| return DenseVector(np.zeros(size)) | ||
|
|
||
| @staticmethod | ||
| def _equals(v1_indices, v1_values, v2_indices, v2_values): | ||
| """ | ||
| Check equality between sparse/dense vectors, | ||
| v1_indices and v2_indices assume to be strictly increasing. | ||
| """ | ||
| v1_size = len(v1_values) | ||
| v2_size = len(v2_values) | ||
| k1 = 0 | ||
| k2 = 0 | ||
| all_equal = True | ||
| while all_equal: | ||
| while k1 < v1_size and v1_values[k1] == 0: | ||
| k1 += 1 | ||
| while k2 < v2_size and v2_values[k2] == 0: | ||
| k2 += 1 | ||
|
|
||
| if k1 >= v1_size or k2 >= v2_size: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: since
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto for
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually I think checking
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, that's fine with me |
||
| return k1 >= v1_size and k2 >= v2_size | ||
|
|
||
| all_equal = v1_indices[k1] == v2_indices[k2] and v1_values[k1] == v2_values[k2] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about when
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, this is OK since https://github.com/apache/spark/blob/master/python/pyspark/mllib/linalg/__init__.py#L489 checks for that. Could you please document this assumption though?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes,
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, can you please document that in the method's docstring. #7854 is proposing to remove the explicit call to |
||
| k1 += 1 | ||
| k2 += 1 | ||
| return all_equal | ||
|
|
||
|
|
||
| class Matrix(object): | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can make the code more readable: