|
24 | 24 | import sys |
25 | 25 | import shlex |
26 | 26 | import traceback |
| 27 | +from bisect import bisect_right |
27 | 28 | from subprocess import Popen, PIPE |
28 | 29 | from tempfile import NamedTemporaryFile |
29 | 30 | from threading import Thread |
@@ -534,6 +535,7 @@ def func(iterator): |
534 | 535 | return reduce(op, vals, zeroValue) |
535 | 536 |
|
536 | 537 | # TODO: aggregate |
| 538 | + |
537 | 539 |
|
538 | 540 | def sum(self): |
539 | 541 | """ |
@@ -610,6 +612,60 @@ def sampleVariance(self): |
610 | 612 | """ |
611 | 613 | return self.stats().sampleVariance() |
612 | 614 |
|
def getBuckets(self, bucketCount):
    """
    Compute the lower boundaries of bucketCount buckets evenly spaced
    between the min and max of the RDD.

    The min and max are read from self.stats(). The returned list holds
    exactly bucketCount ascending lower bounds, starting at the min; the
    max itself is never returned as a boundary. For example, data whose
    min is 1 and max is 100 with bucketCount=2 gives [1, 50].

    When min and max are integers the step is computed with floor
    division so the boundaries stay integral; when they are floats, true
    division is used.

    :param bucketCount: number of buckets requested; must be at least 1.
    :return: a list of bucketCount lower boundaries, or an empty list when
        the step works out to zero (all values equal, or an integer range
        narrower than bucketCount).
    :raises ValueError: if bucketCount is smaller than 1.
    """
    if bucketCount < 1:
        raise ValueError("bucketCount must be at least 1")

    # use the statscounter as a quick way of getting max and min
    mm_stats = self.stats()
    lowest = mm_stats.min()
    highest = mm_stats.max()

    span = highest - lowest
    if isinstance(span, float):
        step = span / bucketCount
    else:
        # Explicit floor division: "/" on two ints yields a float on
        # Python 3, which would turn integral boundaries into floats.
        step = span // bucketCount

    # A zero step cannot separate any buckets.
    if step == 0:
        return []

    # Build the boundaries explicitly instead of range(min, max, step):
    # range() rejects float steps and, when the span is not divisible by
    # bucketCount, returns more boundaries than requested.
    return [lowest + i * step for i in range(bucketCount)]
| 635 | + |
def histogram(self, bucketCount, buckets=None):
    """
    Count how many elements of this RDD fall into each bucket.

    Each element is located among the sorted bucket boundaries with
    bisect_right and counted under a (low, high) key:

    - between two boundaries b[i] and b[i+1]: (b[i], b[i+1] - 1)
    - at or above the last boundary: (b[-1], float("inf"))
    - below the first boundary: (float("-inf"), b[0] - 1)

    Only buckets that received at least one element appear in the result.

    :param bucketCount: number of buckets to compute with getBuckets when
        no explicit buckets are supplied; ignored otherwise.
    :param buckets: optional sequence of bucket boundaries. It is sorted
        before use because bisect_right needs ascending input.
    :return: the per-partition counters (defaultdict of int) merged by
        self.reduce into one mapping of (low, high) -> count.
    :raises ValueError: if fewer than two boundaries are available.
    """
    if not buckets:
        buckets = self.getBuckets(bucketCount)
    if len(buckets) < 2:
        raise ValueError("requires more than 1 bucket")

    # bisect_right only gives meaningful positions on ascending input.
    buckets = sorted(buckets)
    bucketTotal = len(buckets)

    # histogram partition
    def histogramPartition(iterator):
        counters = defaultdict(int)
        for obj in iterator:
            k = bisect_right(buckets, obj)
            if k == 0:
                # below the first boundary
                key = (float("-inf"), buckets[0] - 1)
            elif k == bucketTotal:
                # at or above the last boundary: open-ended bucket
                key = (buckets[k - 1], float("inf"))
            else:
                key = (buckets[k - 1], buckets[k] - 1)
            counters[key] += 1
        yield counters

    # merge counters
    def mergeCounters(d1, d2):
        # Add every key of d2, including those d1 has not seen yet;
        # otherwise buckets that first appear in a later partition
        # would be lost from the merged result.
        for k, v in d2.items():
            d1[k] = d1.get(k, 0) + v
        return d1

    return self.mapPartitions(histogramPartition).reduce(mergeCounters)
| 667 | + |
| 668 | + |
613 | 669 | def countByValue(self): |
614 | 670 | """ |
615 | 671 | Return the count of each unique value in this RDD as a dictionary of |
|
0 commit comments