diff --git a/spec/std/enumerable_spec.cr b/spec/std/enumerable_spec.cr index 3383f5ad3f9c..bf72b35e2e4b 100644 --- a/spec/std/enumerable_spec.cr +++ b/spec/std/enumerable_spec.cr @@ -803,6 +803,24 @@ describe "Enumerable" do describe "max" do it { [1, 2, 3].max.should eq(3) } + it { [1, 2, 3].max(0).should eq([] of Int32) } + it { [1, 2, 3].max(1).should eq([3]) } + it { [1, 2, 3].max(2).should eq([3, 2]) } + it { [1, 2, 3].max(3).should eq([3, 2, 1]) } + it { [1, 2, 3].max(4).should eq([3, 2, 1]) } + it { ([] of Int32).max(0).should eq([] of Int32) } + it { ([] of Int32).max(5).should eq([] of Int32) } + it { + (0..1000).map { |x| (x*137 + x*x*139) % 5000 }.max(10).should eq([ + 4992, 4990, 4980, 4972, 4962, 4962, 4960, 4960, 4952, 4952, + ]) + } + + it "does not modify the array" do + xs = [7, 5, 2, 4, 9] + xs.max(2).should eq([9, 7]) + xs.should eq([7, 5, 2, 4, 9]) + end it "raises if empty" do expect_raises Enumerable::EmptyError do @@ -810,11 +828,23 @@ describe "Enumerable" do end end + it "raises if n is negative" do + expect_raises ArgumentError do + ([1, 2, 3] of Int32).max(-1) + end + end + it "raises if not comparable" do expect_raises ArgumentError do [Float64::NAN, 1.0, 2.0, Float64::NAN].max end end + + it "raises if not comparable in max(n)" do + expect_raises ArgumentError do + [Float64::NAN, 1.0, 2.0, Float64::NAN].max(2) + end + end end describe "max?" do @@ -851,6 +881,24 @@ describe "Enumerable" do describe "min" do it { [1, 2, 3].min.should eq(1) } + it { [1, 2, 3].min(0).should eq([] of Int32) } + it { [1, 2, 3].min(1).should eq([1]) } + it { [1, 2, 3].min(2).should eq([1, 2]) } + it { [1, 2, 3].min(3).should eq([1, 2, 3]) } + it { [1, 2, 3].min(4).should eq([1, 2, 3]) } + it { ([] of Int32).min(0).should eq([] of Int32) } + it { ([] of Int32).min(1).should eq([] of Int32) } + it { + (0..1000).map { |x| (x*137 + x*x*139) % 5000 }.min(10).should eq([ + 0, 10, 20, 26, 26, 26, 26, 30, 32, 32, + ]) + } + + it "does not modify the array" do + xs = [7, 5, 2, 4, 9] + xs.min(2).should eq([2, 4]) + xs.should eq([7, 5, 2, 4, 9]) + end it "raises if empty" do expect_raises Enumerable::EmptyError do @@ -858,11 +906,23 @@ describe "Enumerable" do end end + it "raises if n is negative" do + expect_raises ArgumentError do + ([1, 2, 3] of Int32).min(-1) + end + end + it "raises if not comparable" do expect_raises ArgumentError do [-1.0, Float64::NAN, -3.0].min end end + + it "raises if not comparable in min(n)" do + expect_raises ArgumentError do + [Float64::NAN, 1.0, 2.0, Float64::NAN].min(2) + end + end end describe "min?" do diff --git a/src/enumerable.cr b/src/enumerable.cr index 5476e2b55433..d613916be685 100644 --- a/src/enumerable.cr +++ b/src/enumerable.cr @@ -965,6 +965,35 @@ module Enumerable(T) ary end + private def quickselect_internal(data : Array(T), left : Int, right : Int, k : Int) : T + loop do + return data[left] if left == right + pivot_index = left + (right - left)//2 + pivot_index = quickselect_partition_internal(data, left, right, pivot_index) + if k == pivot_index + return data[k] + elsif k < pivot_index + right = pivot_index - 1 + else + left = pivot_index + 1 + end + end + end + + private def quickselect_partition_internal(data : Array(T), left : Int, right : Int, pivot_index : Int) : Int + pivot_value = data[pivot_index] + data.swap(pivot_index, right) + store_index = left + (left...right).each do |i| + if compare_or_raise(data[i], pivot_value) < 0 + data.swap(store_index, i) + store_index += 1 + end + end + data.swap(right, store_index) + store_index + end + # Returns the element with the maximum value in the collection. # # It compares using `>` so it will work for any type that supports that method. @@ -984,6 +1013,30 @@ module Enumerable(T) max_by? &.itself end + # Returns an array of the maximum *count* elements, sorted descending. + # + # It compares using `<=>` so it will work for any type that supports that method. + # + # ``` + # [7, 5, 2, 4, 9].max(3) # => [9, 7, 5] + # %w[Eve Alice Bob Mallory Carol].max(2) # => ["Mallory", "Eve"] + # ``` + # + # Returns all elements sorted descending if *count* is greater than the number + # of elements in the source. + # + # Raises `Enumerable::ArgumentError` if *count* is negative or if any elements + # are not comparable. + def max(count : Int) : Array(T) + raise ArgumentError.new("Count must be positive") if count < 0 + data = self.is_a?(Array) ? self.dup : self.to_a + n = data.size + count = n if count > n + (0...count).map do |i| + quickselect_internal(data, 0, n - 1, n - 1 - i) + end + end + # Returns the element for which the passed block returns with the maximum value. # # It compares using `>` so the block must return a type that supports that method @@ -1073,6 +1126,30 @@ module Enumerable(T) min_by? &.itself end + # Returns an array of the minimum *count* elements, sorted ascending. + # + # It compares using `<=>` so it will work for any type that supports that method. + # + # ``` + # [7, 5, 2, 4, 9].min(3) # => [2, 4, 5] + # %w[Eve Alice Bob Mallory Carol].min(2) # => ["Alice", "Bob"] + # ``` + # + # Returns all elements sorted ascending if *count* is greater than the number + # of elements in the source. + # + # Raises `Enumerable::ArgumentError` if *count* is negative or if any elements + # are not comparable. + def min(count : Int) : Array(T) + raise ArgumentError.new("Count must be positive") if count < 0 + data = self.is_a?(Array) ? self.dup : self.to_a + n = data.size + count = n if count > n + (0...count).map do |i| + quickselect_internal(data, 0, n - 1, i) + end + end + # Returns the element for which the passed block returns with the minimum value. # # It compares using `<` so the block must return a type that supports that method