Skip to content
Merged
28 changes: 28 additions & 0 deletions spec/std/enumerable_spec.cr
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
require "spec"
require "./spec_helper"
require "spec/helpers/iterate"

module SomeInterface; end
Expand Down Expand Up @@ -1364,6 +1365,19 @@ describe "Enumerable" do
it { [1, 2, 3].sum(4.5).should eq(10.5) }
it { (1..3).sum { |x| x * 2 }.should eq(12) }
it { (1..3).sum(1.5) { |x| x * 2 }.should eq(13.5) }
it { [1, 3_u64].sum(0_i32).should eq(4_u32) }
it { [1, 3].sum(0_u64).should eq(4_u64) }
it { [1, 10000000000_u64].sum(0_u64).should eq(10000000001) }
pending_wasm32 "raises if union types are summed", tags: %w[slow] do
assert_compile_error <<-CRYSTAL,
require "prelude"
[1, 10000000000_u64].sum
CRYSTAL
"`Enumerable#sum` and `#product` do not support Union " +
"types. Instead, use `Enumerable#sum(initial)` and " +
"`#product(initial)`, respectively, with an initial value " +
"of the intended type of the call."
end

it "uses additive_identity from type" do
typeof([1, 2, 3].sum).should eq(Int32)
Expand Down Expand Up @@ -1405,6 +1419,20 @@ describe "Enumerable" do
typeof([1.5, 2.5, 3.5].product).should eq(Float64)
typeof([1, 2, 3].product(&.to_f)).should eq(Float64)
end

it { [1, 3_u64].product(3_i32).should eq(9_u32) }
it { [1, 3].product(3_u64).should eq(9_u64) }
it { [1, 10000000000_u64].product(3_u64).should eq(30000000000_u64) }
pending_wasm32 "raises if union types are multiplied", tags: %w[slow] do
assert_compile_error <<-CRYSTAL,
require "prelude"
[1, 10000000000_u64].product
CRYSTAL
"`Enumerable#sum` and `#product` do not support Union " +
"types. Instead, use `Enumerable#sum(initial)` and " +
"`#product(initial)`, respectively, with an initial value " +
"of the intended type of the call."
end
end

describe "first" do
Expand Down
27 changes: 27 additions & 0 deletions spec/std/spec_helper.cr
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,33 @@ def compile_file(source_file, *, bin_name = "executable_file", flags = %w(), fil
end
end

def assert_compile_error(source, expected_error, *, flags = %w(), file = __FILE__, line = __LINE__)
# can't use backtick in interpreted code (#12241)
pending_interpreted! "Unable to compile Crystal code in interpreted code"

with_tempfile("source_file", file: file) do |source_file|
File.write(source_file, source)

bin_name = "executable_file"
with_temp_executable(bin_name, file: file) do |executable_file|
compiler = ENV["CRYSTAL_SPEC_COMPILER_BIN"]? || "bin/crystal"
args = ["build"] + flags + ["-o", executable_file, source_file]
output = IO::Memory.new
status = Process.run(compiler, args, env: {
"CRYSTAL_PATH" => Crystal::PATH,
"CRYSTAL_LIBRARY_PATH" => Crystal::LIBRARY_PATH,
"CRYSTAL_CACHE_DIR" => Crystal::CACHE_DIR,
"NO_COLOR" => "1",
}, output: output, error: output)

output.to_s.should contain(expected_error)

status.success?.should be_false
File.exists?(executable_file).should be_false
end
end
end

def compile_source(source, flags = %w(), file = __FILE__, &)
with_tempfile("source_file", file: file) do |source_file|
File.write(source_file, source)
Expand Down
8 changes: 4 additions & 4 deletions spec/support/wasm32.cr
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
require "spec"

{% if flag?(:wasm32) %}
def pending_wasm32(description = "assert", file = __FILE__, line = __LINE__, end_line = __END_LINE__, &block)
pending("#{description} [wasm32]", file, line, end_line)
def pending_wasm32(description = "assert", file = __FILE__, line = __LINE__, end_line = __END_LINE__, focus : Bool = false, tags : String | Enumerable(String) | Nil = nil, &block)
pending("#{description} [wasm32]", file, line, end_line, focus: focus, tags: tags)
end

def pending_wasm32(*, describe, file = __FILE__, line = __LINE__, end_line = __END_LINE__, &block)
pending_wasm32(describe, file, line, end_line) { }
end
{% else %}
def pending_wasm32(description = "assert", file = __FILE__, line = __LINE__, end_line = __END_LINE__, &block)
it(description, file, line, end_line, &block)
def pending_wasm32(description = "assert", file = __FILE__, line = __LINE__, end_line = __END_LINE__, focus : Bool = false, tags : String | Enumerable(String) | Nil = nil, &block)
it(description, file, line, end_line, focus: focus, tags: tags, &block)
end

def pending_wasm32(*, describe, file = __FILE__, line = __LINE__, end_line = __END_LINE__, &block)
Expand Down
34 changes: 22 additions & 12 deletions src/enumerable.cr
Original file line number Diff line number Diff line change
Expand Up @@ -1771,7 +1771,7 @@ module Enumerable(T)
end

private def additive_identity(reflect)
type = reflect.first
type = reflect.type
if type.responds_to? :additive_identity
type.additive_identity
else
Expand Down Expand Up @@ -1808,7 +1808,10 @@ module Enumerable(T)
# Expects all types returned from the block to respond to `#+` method.
#
# This method calls `.additive_identity` on the yielded type to determine the
# type of the sum value.
# type of the sum value. Hence, it can fail to compile if
# `.additive_identity` fails to determine a safe type, e.g., in case of
# union types. In such cases, use `sum(initial)` with an initial value of
# the expected type of the sum value.
#
# If the collection is empty, returns `additive_identity`.
#
Expand Down Expand Up @@ -1847,15 +1850,15 @@ module Enumerable(T)
# ```
#
# This method calls `.multiplicative_identity` on the element type to determine the
# type of the sum value.
# type of the product value.
#
# If the collection is empty, returns `multiplicative_identity`.
#
# ```
# ([] of Int32).product # => 1
# ```
def product
product Reflect(T).first.multiplicative_identity
product Reflect(T).type.multiplicative_identity
end

# Multiplies *initial* and all the elements in the collection
Expand Down Expand Up @@ -1886,16 +1889,19 @@ module Enumerable(T)
#
# Expects all types returned from the block to respond to `#*` method.
#
# This method calls `.multiplicative_identity` on the element type to determine the
# type of the sum value.
# This method calls `.multiplicative_identity` on the element type to
# determine the type of the product value. Hence, it can fail to compile if
# `.multiplicative_identity` fails to determine a safe type, e.g., in case
# of union types. In such cases, use `product(initial)` with an initial
# value of the expected type of the product value.
#
# If the collection is empty, returns `multiplicative_identity`.
#
# ```
# ([] of Int32).product { |x| x + 1 } # => 1
# ```
def product(& : T -> _)
product(Reflect(typeof(yield Enumerable.element_type(self))).first.multiplicative_identity) do |value|
product(Reflect(typeof(yield Enumerable.element_type(self))).type.multiplicative_identity) do |value|
yield value
end
end
Expand Down Expand Up @@ -2285,12 +2291,16 @@ module Enumerable(T)

# :nodoc:
private struct Reflect(X)
# For now it's just a way to implement `Enumerable#sum` in a way that the
# initial value given to it has the type of the first type in the union,
# if the type is a union.
def self.first
# For now, Reflect is used to reject union types in `#sum()` and
# `#product()` methods.
def self.type
{% if X.union? %}
{{X.union_types.first}}
{{
raise("`Enumerable#sum` and `#product` do not support Union " +
"types. Instead, use `Enumerable#sum(initial)` and " +
"`#product(initial)`, respectively, with an initial value " +
"of the intended type of the call.")
}}
{% else %}
X
{% end %}
Expand Down