diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 334dc8e38bb8f..29568b0eab1bc 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -25,6 +25,7 @@ import sys import array +import struct if sys.version >= '3': basestring = str @@ -122,6 +123,15 @@ def _format_float_list(l): return [_format_float(x) for x in l] +def _double_to_long_bits(value): + if value != value: + # value is NaN, standardize to canonical non-signaling NaN + return 0x7ff8000000000000 + else: + # pack double into 64 bits, then unpack as long int + return struct.unpack('Q', struct.pack('d', value))[0] + + class VectorUDT(UserDefinedType): """ SQL user-defined type (UDT) for Vector. @@ -409,6 +419,34 @@ def __eq__(self, other): def __ne__(self, other): return not self == other + def __hash__(self): + """ + Compute hashcode + + >>> v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) + >>> v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + >>> hash(v1) == hash(v2) + True + >>> v2 = DenseVector([0.0, 1.0, 0.0, 5.5]) + >>> hash(v1) == hash(v2) + True + >>> v2 = DenseVector([1.0, 1.0, 0.0, 5.5]) + >>> hash(v1) == hash(v2) + False + """ + size = len(self) + result = 31 + size + count = 0 + i = 0 + while i < size and count < 16: + if self.array[i] != 0: + bits = _double_to_long_bits(self.array[i] + i) + result = 31 * result + (bits ^ (bits >> 32)) + + count += 1 + i += 1 + return result + def __getattr__(self, item): return getattr(self.array, item) @@ -739,6 +777,33 @@ def __getitem__(self, index): def __ne__(self, other): return not self.__eq__(other) + def __hash__(self): + """ + Compute hashcode + + >>> v1 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + >>> v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + >>> hash(v1) == hash(v2) + True + >>> v2 = SparseVector(4, [(1, 1.0), (3, 2.5)]) + >>> hash(v1) == hash(v2) + False + >>> v2 = SparseVector(4, [(2, 1.0), (3, 5.5)]) + >>> hash(v1) == hash(v2) + False + """ + result = 31 + self.size + count = 0 + i = 0 + while i < len(self.values) and count < 16: + if self.values[i] != 0: + bits = _double_to_long_bits(self.values[i] + self.indices[i]) + result = 31 * result + (bits ^ (bits >> 32)) + + count += 1 + i += 1 + return result + class Vectors(object):