Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 33 additions & 24 deletions python/pyspark/mllib/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,32 +461,41 @@ def __init__(self, size, *args):
self.size = int(size)
""" Size of the vector. """
assert 1 <= len(args) <= 2, "must pass either 2 or 3 arguments"
if len(args) == 1:
pairs = args[0]
if type(pairs) == dict:
pairs = pairs.items()
pairs = sorted(pairs)
self.indices = np.array([p[0] for p in pairs], dtype=np.int32)
""" A list of indices corresponding to active entries. """
self.values = np.array([p[1] for p in pairs], dtype=np.float64)
""" A list of values corresponding to active entries. """
if isinstance(args[0], bytes):
assert isinstance(args[1], bytes), "values should be string too"
if args[0]:
self.indices = np.frombuffer(args[0], np.int32)
self.values = np.frombuffer(args[1], np.float64)
else:
# np.frombuffer() doesn't work well with empty string in older version
self.indices = np.array([], dtype=np.int32)
self.values = np.array([], dtype=np.float64)
else:
if isinstance(args[0], bytes):
assert isinstance(args[1], bytes), "values should be string too"
if args[0]:
self.indices = np.frombuffer(args[0], np.int32)
self.values = np.frombuffer(args[1], np.float64)
else:
# np.frombuffer() doesn't work well with empty string in older version
self.indices = np.array([], dtype=np.int32)
self.values = np.array([], dtype=np.float64)
if len(args) == 1:
args = args[0]
if isinstance(args, dict):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if args[0] is a list of (index, value) tuples?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea is to change a dict into a (indices, values) tuple and then unpack it below on L485

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK makes sense

args = args.items()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

items() has no guarantees on the ordering of the keys... Is it okay that indices may not be sorted after this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think we should use an OrderedDict here

args = list(zip(*args))

# Handle empty args case.
if len(args) == 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unnecessary since L463 checks that len(args) >= 1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But args is changed before this line.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mutation only happens inside the if block on L474, so this can be moved into that block without affecting correctness

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you suggesting something like this?

        if len(args) == 1:
            args = args[0]
            if isinstance(args, dict):
                args = args.items()
            args = list(zip(*args))

            # Handle empty args case.
            if len(args) == 0:
                indices = []
                values = []

        indices, values = args

That would not work because it will raise an error if lens(args) is zero. and if I keep the else statement, the unpacking is done only when a list of lists / tuples is provided.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By len(args), I mean to handle this case.

SparseVector(1, {})

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking something more along the lines of

        if len(args) == 1:
            args = args[0]

            # Handle empty args case.
            if len(args) == 0:
                args = [[], []]
            elif isinstance(args, dict):
                args = list(zip(*args.items())
        indices, values = args

You don't have to worry about len(args) == 0 because of L463 and this makes it so len(args) == 0 is only checked when args can possibly be [[]] or [dict()] (the current impl will check both len(args) == 1 and len(args) ==0 when args = [[0,1,2],[3,4,5]] whereas after this change only len(args)==1 will be checked)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm.. This seems better, but I do know that the lens(args) is checked in L463 LOL (as you have repeated 3 times), it was just to check only the empty dict case.

indices = []
values = []
else:
self.indices = np.array(args[0], dtype=np.int32)
self.values = np.array(args[1], dtype=np.float64)
assert len(self.indices) == len(self.values), "index and value arrays not same length"
for i in xrange(len(self.indices) - 1):
if self.indices[i] >= self.indices[i + 1]:
raise TypeError("indices array must be sorted")
indices, values = args

""" A list of indices corresponding to active entries. """
self.indices = np.array(indices, dtype=np.int32)
""" A list of values corresponding to active entries. """
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did we remove this check?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that #8166 requires sorted indices

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is because this sorting is expensive, this code path is followed is so common internally.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, but the PR I linked (and maybe other parts of the code) depend on sorted indices. Have we checked if the ordering of indices is assumed elsewhere in the code?

Also, Python's sorted uses timsort internally which has O(n) complexity for an already sorted input so the overhead is not terrible.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I often just look at the direct SparseVector initialization but not Vectors.sparse. hence I did not notice it.

The SparseVector initialization should take constant time and one should expect the user should supply sorted indices (I remember another PR which was closed because it did a O(n) check)

However, I do think that this should be documented somewhere clearly that the indices provided should be sorted.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if we should expect users to supply sorted indices. There are tests covering this use case.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To match Scala, we should:

  • have constant-time initialization for SparseVector (for fast conversions between MLlib and Breeze/scipy)
  • allow sorting for construction via Vectors.sparse (but that can depend upon the argument types, as in Scala)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MechCoder I know your summer of code is over, so please do say if you're too busy to update.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hey, I'll update it definitely in a while (I sent you a mail regarding how much time I would be able to allocate from now on)

self.values = np.array(values, dtype=np.float64)

indices_length = len(self.indices)
values_length = len(self.values)
if indices_length != values_length:
raise ValueError(
"expected values of length %d, got %d." % (indices_length, values_length))
if indices_length > self.size:
raise ValueError("expected indices length <= %d, got %d" % (self.size, indices_length))

def numNonzeros(self):
return np.count_nonzero(self.values)
Expand Down