diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 334dc8e38bb8..344b8acc3f86 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -461,32 +461,41 @@ def __init__(self, size, *args): self.size = int(size) """ Size of the vector. """ assert 1 <= len(args) <= 2, "must pass either 2 or 3 arguments" - if len(args) == 1: - pairs = args[0] - if type(pairs) == dict: - pairs = pairs.items() - pairs = sorted(pairs) - self.indices = np.array([p[0] for p in pairs], dtype=np.int32) - """ A list of indices corresponding to active entries. """ - self.values = np.array([p[1] for p in pairs], dtype=np.float64) - """ A list of values corresponding to active entries. """ + if isinstance(args[0], bytes): + assert isinstance(args[1], bytes), "values should be string too" + if args[0]: + self.indices = np.frombuffer(args[0], np.int32) + self.values = np.frombuffer(args[1], np.float64) + else: + # np.frombuffer() doesn't work well with empty string in older version + self.indices = np.array([], dtype=np.int32) + self.values = np.array([], dtype=np.float64) else: - if isinstance(args[0], bytes): - assert isinstance(args[1], bytes), "values should be string too" - if args[0]: - self.indices = np.frombuffer(args[0], np.int32) - self.values = np.frombuffer(args[1], np.float64) - else: - # np.frombuffer() doesn't work well with empty string in older version - self.indices = np.array([], dtype=np.int32) - self.values = np.array([], dtype=np.float64) + if len(args) == 1: + args = args[0] + if isinstance(args, dict): + args = args.items() + args = list(zip(*args)) + + # Handle empty args case. + if len(args) == 0: + indices = [] + values = [] else: - self.indices = np.array(args[0], dtype=np.int32) - self.values = np.array(args[1], dtype=np.float64) - assert len(self.indices) == len(self.values), "index and value arrays not same length" - for i in xrange(len(self.indices) - 1): - if self.indices[i] >= self.indices[i + 1]: - raise TypeError("indices array must be sorted") + indices, values = args + + """ A list of indices corresponding to active entries. """ + self.indices = np.array(indices, dtype=np.int32) + """ A list of values corresponding to active entries. """ + self.values = np.array(values, dtype=np.float64) + + indices_length = len(self.indices) + values_length = len(self.values) + if indices_length != values_length: + raise ValueError( + "expected values of length %d, got %d." % (indices_length, values_length)) + if indices_length > self.size: + raise ValueError("expected indices length <= %d, got %d" % (self.size, indices_length)) def numNonzeros(self): return np.count_nonzero(self.values)