[SPARK-9525] [PySpark] [MLlib] Optimize SparseVector initialization #7854
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -461,32 +461,41 @@ def __init__(self, size, *args): | |
| self.size = int(size) | ||
| """ Size of the vector. """ | ||
| assert 1 <= len(args) <= 2, "must pass either 2 or 3 arguments" | ||
| if len(args) == 1: | ||
| pairs = args[0] | ||
| if type(pairs) == dict: | ||
| pairs = pairs.items() | ||
| pairs = sorted(pairs) | ||
| self.indices = np.array([p[0] for p in pairs], dtype=np.int32) | ||
| """ A list of indices corresponding to active entries. """ | ||
| self.values = np.array([p[1] for p in pairs], dtype=np.float64) | ||
| """ A list of values corresponding to active entries. """ | ||
| if isinstance(args[0], bytes): | ||
| assert isinstance(args[1], bytes), "values should be string too" | ||
| if args[0]: | ||
| self.indices = np.frombuffer(args[0], np.int32) | ||
| self.values = np.frombuffer(args[1], np.float64) | ||
| else: | ||
| # np.frombuffer() doesn't work well with empty string in older version | ||
| self.indices = np.array([], dtype=np.int32) | ||
| self.values = np.array([], dtype=np.float64) | ||
| else: | ||
| if isinstance(args[0], bytes): | ||
| assert isinstance(args[1], bytes), "values should be string too" | ||
| if args[0]: | ||
| self.indices = np.frombuffer(args[0], np.int32) | ||
| self.values = np.frombuffer(args[1], np.float64) | ||
| else: | ||
| # np.frombuffer() doesn't work well with empty string in older version | ||
| self.indices = np.array([], dtype=np.int32) | ||
| self.values = np.array([], dtype=np.float64) | ||
| if len(args) == 1: | ||
| args = args[0] | ||
| if isinstance(args, dict): | ||
| args = args.items() | ||
|
Contributor
Author
Yes, I think we should use an OrderedDict here.
||
| args = list(zip(*args)) | ||
|
|
||
| # Handle empty args case. | ||
| if len(args) == 0: | ||
|
Contributor
This is unnecessary since L463 checks that len(args) >= 1.
Contributor
Author
But args is changed before this line.
Contributor
Mutation only happens inside the
Contributor
Author
Are you suggesting something like this? That would not work because it will raise an error if
Contributor
Author
By len(args), I mean to handle this case.
Contributor
I was thinking something more along the lines of

    if len(args) == 1:
        args = args[0]
        # Handle empty args case.
        if len(args) == 0:
            args = [[], []]
        elif isinstance(args, dict):
            args = list(zip(*args.items()))
    indices, values = args

You don't have to worry about
Contributor
Author
Hmm.. This seems better, but I do know that len(args) is checked in L463 LOL (as you have repeated 3 times); it was just to check only the empty dict case.
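The empty-dict case this thread is about can be reproduced in isolation. The sketch below assumes plain Python 3 and uses a hypothetical helper, `split_pairs`, that is not part of pyspark. Transposing an empty dict with `zip(*...)` yields an empty list, so a two-name unpack fails unless the empty case is handled first.

```python
def split_pairs(arg):
    """Hypothetical helper (not pyspark API): split a dict or an iterable
    of (index, value) pairs into an indices tuple and a values tuple."""
    if isinstance(arg, dict):
        arg = arg.items()
    transposed = list(zip(*arg))
    # zip(*[]) produces nothing, so an empty input gives [] rather than
    # two empty tuples; handle it before unpacking.
    if len(transposed) == 0:
        return (), ()
    indices, values = transposed
    return indices, values


# Without the guard, unpacking the transposed empty dict fails.
try:
    indices, values = list(zip(*{}.items()))
    unpack_failed = False
except ValueError:
    unpack_failed = True
```

This is only one way to place the guard; the reviewer's variant above checks for emptiness before the dict test instead of after the transpose.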
||
| indices = [] | ||
| values = [] | ||
| else: | ||
| self.indices = np.array(args[0], dtype=np.int32) | ||
| self.values = np.array(args[1], dtype=np.float64) | ||
| assert len(self.indices) == len(self.values), "index and value arrays not same length" | ||
| for i in xrange(len(self.indices) - 1): | ||
| if self.indices[i] >= self.indices[i + 1]: | ||
| raise TypeError("indices array must be sorted") | ||
| indices, values = args | ||
|
|
||
| """ A list of indices corresponding to active entries. """ | ||
| self.indices = np.array(indices, dtype=np.int32) | ||
| """ A list of values corresponding to active entries. """ | ||
|
Contributor
Why did we remove this check?
Contributor
Note that #8166 requires sorted indices.
Contributor
Author
It is because this sorting is expensive, and this code path is followed very commonly internally.
Contributor
Yeah, but the PR I linked (and maybe other parts of the code) depend on sorted indices. Have we checked if the ordering of indices is assumed elsewhere in the code? Also, Python's
Contributor
Author
Ah, I often just look at the direct SparseVector initialization but not
The SparseVector initialization should take constant time, and one should expect the user to supply sorted indices (I remember another PR that was closed because it did an O(n) check).
However, I do think it should be clearly documented somewhere that the indices provided should be sorted.
Contributor
I'm not sure if we should expect users to supply sorted indices. There are tests covering this use case.
Member
To match Scala, we should:
Member
@MechCoder I know your summer of code is over, so please do say if you're too busy to update.
Contributor
Author
Hey, I'll definitely update it in a while (I sent you a mail regarding how much time I would be able to allocate from now on).
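One way to reconcile the two positions in this thread (fast construction for sorted input, but still accepting unsorted input) is to check order in a single pass and sort only when the check fails. The sketch below is an illustration under that assumption, not what this PR implements; `normalize_indices` is a hypothetical name, and rejecting duplicate indices is also an assumption.

```python
def normalize_indices(indices, values):
    """Hypothetical helper (not pyspark API): return (indices, values)
    as lists with strictly increasing indices."""
    indices = list(indices)
    values = list(values)
    if len(indices) != len(values):
        raise ValueError(
            "expected values of length %d, got %d." % (len(indices), len(values)))
    # One linear pass: is the input already strictly increasing?
    if all(a < b for a, b in zip(indices, indices[1:])):
        return indices, values
    # Otherwise sort the (index, value) pairs by index.
    pairs = sorted(zip(indices, values), key=lambda p: p[0])
    sorted_indices = [p[0] for p in pairs]
    # Assumption: duplicate indices are treated as an error.
    if any(a == b for a, b in zip(sorted_indices, sorted_indices[1:])):
        raise ValueError("indices must be unique")
    return sorted_indices, [p[1] for p in pairs]
```

The check is still O(n), so it does not answer the objection that construction should take constant time; it only avoids the sort on input that is already ordered.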
||
| self.values = np.array(values, dtype=np.float64) | ||
|
|
||
| indices_length = len(self.indices) | ||
| values_length = len(self.values) | ||
| if indices_length != values_length: | ||
| raise ValueError( | ||
| "expected values of length %d, got %d." % (indices_length, values_length)) | ||
| if indices_length > self.size: | ||
| raise ValueError("expected indices length <= %d, got %d" % (self.size, indices_length)) | ||
|
|
||
| def numNonzeros(self): | ||
| return np.count_nonzero(self.values) | ||
|
|
||
What if args[0] is a list of (index, value) tuples?
The idea is to change a dict into a (indices, values) tuple and then unpack it below on L485.
OK makes sense
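The transposition described here treats a dict and a list of (index, value) tuples the same way, because `dict.items()` also yields pairs. A minimal sketch, assuming Python 3.7+ (so dict insertion order is preserved) and a hypothetical helper name `transpose_pairs` that does not appear in the PR:

```python
def transpose_pairs(pairs):
    """Hypothetical helper: turn an iterable of (index, value) pairs
    into a two-element list [indices, values]."""
    return list(zip(*pairs))


pairs_from_dict = {1: 1.0, 3: 5.5}.items()
pairs_from_list = [(1, 1.0), (3, 5.5)]

# Both inputs unpack into the same indices and values.
indices, values = transpose_pairs(pairs_from_list)
```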