diff --git a/bin/ci b/bin/ci index dc48b73bcc10..7ee9495df137 100755 --- a/bin/ci +++ b/bin/ci @@ -74,7 +74,8 @@ prepare_system() { build() { with_build_env 'make std_spec clean' - with_build_env 'make crystal std_spec compiler_spec docs' + with_build_env 'make crystal' + with_build_env 'make std_spec compiler_spec docs FLAGS=--overflow-checked' with_build_env 'find samples -name "*.cr" | xargs -L 1 ./bin/crystal build --no-codegen' with_build_env './bin/crystal tool format --check samples spec src' } diff --git a/spec/compiler/codegen/overflow_check_scope_spec.cr b/spec/compiler/codegen/overflow_check_scope_spec.cr new file mode 100644 index 000000000000..8f2d7d9cbae4 --- /dev/null +++ b/spec/compiler/codegen/overflow_check_scope_spec.cr @@ -0,0 +1,327 @@ +require "../../spec_helper" + +def checked_run(code) + run(code, overflow_check: Crystal::OverflowCheckScope::Policy::Checked) +end + +def unchecked_run(code) + run(code, overflow_check: Crystal::OverflowCheckScope::Policy::Unchecked) +end + +describe "Code gen: overflow check scope" do + describe "for add" do + it "can be unchecked" do + checked_run(%( + unchecked { 2147483647_i32 + 1_i32 } + )).to_i.should eq(-2147483648_i32) + end + + it "can be checked " do + unchecked_run(%( + require "prelude" + + x = 0 + begin + checked { 2147483647_i32 + 1_i32 } + x = 1 + rescue OverflowError + x = 2 + end + x + )).to_i.should eq(2) + end + + {% for type in [UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64] %} + it "wrap around if unchecked for {{type}}" do + unchecked_run(%( + require "prelude" + {{type}}::MAX + {{type}}.new(1) == {{type}}::MIN + )).to_b.should be_true + end + + it "raises if checked for {{type}}" do + checked_run(%( + require "prelude" + begin + {{type}}::MAX + {{type}}.new(1) + 0 + rescue OverflowError + 1 + end + )).to_i.should eq(1) + end + + it "wrap around if unchecked for {{type}} + Int64" do + unchecked_run(%( + require "prelude" + {{type}}::MAX + 1_i64 == {{type}}::MIN + )).to_b.should 
be_true + end + + it "raises if checked for {{type}} + Int64" do + checked_run(%( + require "prelude" + begin + {{type}}::MAX + 1_i64 + 0 + rescue OverflowError + 1 + end + )).to_i.should eq(1) + end + {% end %} + end + + describe "for sub" do + it "can be unchecked" do + checked_run(%( + unchecked { -2147483648_i32 - 1_i32 } + )).to_i.should eq(2147483647_i32) + end + + it "can be checked " do + unchecked_run(%( + require "prelude" + + x = 0 + begin + checked { -2147483648_i32 - 1_i32 } + x = 1 + rescue OverflowError + x = 2 + end + x + )).to_i.should eq(2) + end + + {% for type in [UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64] %} + it "wrap around if unchecked for {{type}}" do + unchecked_run(%( + require "prelude" + {{type}}::MIN - {{type}}.new(1) == {{type}}::MAX + )).to_b.should be_true + end + + it "raises if checked for {{type}}" do + checked_run(%( + require "prelude" + begin + {{type}}::MIN - {{type}}.new(1) + 0 + rescue OverflowError + 1 + end + )).to_i.should eq(1) + end + + it "wrap around if unchecked for {{type}} - Int64" do + unchecked_run(%( + require "prelude" + {{type}}::MIN - 1_i64 == {{type}}::MAX + )).to_b.should be_true + end + + it "raises if checked for {{type}} - Int64" do + checked_run(%( + require "prelude" + begin + {{type}}::MIN - 1_i64 + 0 + rescue OverflowError + 1 + end + )).to_i.should eq(1) + end + {% end %} + end + + describe "for mul" do + it "can be unchecked" do + checked_run(%( + unchecked { 2147483647_i32 * 2_i32 } + )).to_i.should eq(-2_i32) + end + + it "can be checked " do + unchecked_run(%( + require "prelude" + + x = 0 + begin + checked { 2147483647_i32 * 2_i32 } + x = 1 + rescue OverflowError + x = 2 + end + x + )).to_i.should eq(2) + end + + {% for type in [UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64] %} + it "wrap around if unchecked for {{type}}" do + unchecked_run(%( + require "prelude" + ({{type}}::MAX / {{type}}.new(2) + {{type}}.new(1)) * {{type}}.new(2) == {{type}}::MIN + )).to_b.should 
be_true + end + + it "raises if checked for {{type}}" do + checked_run(%( + require "prelude" + begin + ({{type}}::MAX / {{type}}.new(2) + {{type}}.new(1)) * {{type}}.new(2) + 0 + rescue OverflowError + 1 + end + )).to_i.should eq(1) + end + + it "wrap around if unchecked for {{type}} + Int64" do + unchecked_run(%( + require "prelude" + ({{type}}::MAX / {{type}}.new(2) + {{type}}.new(1)) * 2_i64 == {{type}}::MIN + )).to_b.should be_true + end + + it "raises if checked for {{type}} + Int64" do + checked_run(%( + require "prelude" + begin + ({{type}}::MAX / {{type}}.new(2) + {{type}}.new(1)) * 2_i64 + 0 + rescue OverflowError + 1 + end + )).to_i.should eq(1) + end + {% end %} + end + + it "obey default checked" do + checked_run(%( + require "prelude" + + x = 0 + begin + a = 2147483647_i32 + 1_i32 + x = 1 + rescue OverflowError + x = 2 + end + x + )).to_i.should eq(2) + end + + it "obey default unchecked" do + unchecked_run(%( + 2147483647_i32 + 1_i32 + )).to_i.should eq(-2147483648_i32) + end + + it "is obeyed in return" do + checked_run(%( + def inc(v) + unchecked { + return v + 1_i8 + } + end + + inc(127_i8) + )).to_i.should eq(-128) + end + + it "can be nested" do + unchecked_run(%( + require "prelude" + + begin + checked { unchecked { 2147483647_i32 + 1_i32 } + 2147483647_i32 + 2147483647_i32 + 2_i32 } + 0 + rescue OverflowError + 1 + end + )).to_i.should eq(1) + end + + describe "work at lexical scope." 
do + it "is not forwarded to function calls" do + checked_run(%( + require "prelude" + + def inc_checked(v) + v + 1_i8 + end + + def twice(v) + unchecked { inc_checked(inc_checked(v)) } + end + + begin + twice(126_i8) + 1 + rescue OverflowError + 2 + end + )).to_i.should eq(2) + + unchecked_run(%( + require "prelude" + + def inc_unchecked(v) + v + 1_i8 + end + + def twice(v) + checked { inc_unchecked(inc_unchecked(v)) } + end + + begin + twice(126_i8) + 1 + rescue OverflowError + 2 + end + )).to_i.should eq(1) + end + + it "is not forwarded to blocks yields" do + unchecked_run(%( + require "prelude" + + def inc_unchecked + yield + 1_i8 + end + + def foo + checked { + inc_unchecked { 127_i8 } + } + end + + foo + )).to_i.should eq(-128_i8) + end + + it "is forwarded to blocks" do + unchecked_run(%( + require "prelude" + + def inc_unchecked(v) + yield v + end + + def foo + checked { + inc_unchecked(127_i8) { |v| v + 1_i8 } + } + 1 + rescue OverflowError + 0 + end + + foo + )).to_i.should eq(0) + end + end +end diff --git a/spec/compiler/parser/parser_spec.cr b/spec/compiler/parser/parser_spec.cr index c268ab82577f..7d8b0f46592d 100644 --- a/spec/compiler/parser/parser_spec.cr +++ b/spec/compiler/parser/parser_spec.cr @@ -126,6 +126,9 @@ module Crystal it_parses "foo[1] /2", Call.new(Call.new("foo".call, "[]", 1.int32), "/", 2.int32) it_parses "[1] /2", Call.new(([1.int32] of ASTNode).array, "/", 2.int32) + it_parses "checked { 1 + 2 }", OverflowCheckScope.new(OverflowCheckScope::Policy::Checked, Expressions.from([Call.new(1.int32, "+", 2.int32)] of ASTNode)) + it_parses "unchecked { 1 + 2 }", OverflowCheckScope.new(OverflowCheckScope::Policy::Unchecked, Expressions.from([Call.new(1.int32, "+", 2.int32)] of ASTNode)) + it_parses "!1", Not.new(1.int32) it_parses "- 1", Call.new(1.int32, "-") it_parses "+ 1", Call.new(1.int32, "+") @@ -181,6 +184,7 @@ module Crystal extend class struct module enum while until return next break lib fun alias pointerof sizeof 
instance_sizeof typeof private protected asm out + checked unchecked end ).each do |kw| assert_syntax_error "def foo(#{kw}); end", "cannot use '#{kw}' as an argument name", 1, 9 @@ -1643,6 +1647,8 @@ module Crystal assert_end_location "extend Foo" assert_end_location "1.as(Int32)" assert_end_location "puts obj.foo" + assert_end_location "checked { 1 + 1 }" + assert_end_location "unchecked { 1 + 1 }" assert_syntax_error %({"a" : 1}), "space not allowed between named argument name and ':'" assert_syntax_error %({"a": 1, "b" : 2}), "space not allowed between named argument name and ':'" diff --git a/spec/compiler/semantic/overflow_check_scope_spec.cr b/spec/compiler/semantic/overflow_check_scope_spec.cr new file mode 100644 index 000000000000..482ba9779024 --- /dev/null +++ b/spec/compiler/semantic/overflow_check_scope_spec.cr @@ -0,0 +1,10 @@ +require "../../spec_helper" + +describe "Semantic: overflow check scope" do + it "type block by expression" do + assert_type("checked { 1 }") { int32 } + assert_type("unchecked { 1 }") { int32 } + assert_type("checked { 1 + 2; '1' }") { char } + assert_type("def foo; unchecked { return 1 }; end; foo") { int32 } + end +end diff --git a/spec/spec_helper.cr b/spec/spec_helper.cr index 5ea235ae0cc6..08472506d5f0 100644 --- a/spec/spec_helper.cr +++ b/spec/spec_helper.cr @@ -153,7 +153,7 @@ class Crystal::SpecRunOutput end end -def run(code, filename = nil, inject_primitives = true, debug = Crystal::Debug::None) +def run(code, filename = nil, inject_primitives = true, debug = Crystal::Debug::None, overflow_check = Crystal::DefaultOverflowCheckPolicy) code = inject_primitives(code) if inject_primitives # Code that requires the prelude doesn't run in LLVM's MCJIT @@ -174,6 +174,7 @@ def run(code, filename = nil, inject_primitives = true, debug = Crystal::Debug:: compiler = Compiler.new compiler.debug = debug + compiler.overflow_check = overflow_check compiler.compile Compiler::Source.new("spec", code), output_filename output = 
`#{output_filename}` @@ -181,7 +182,7 @@ def run(code, filename = nil, inject_primitives = true, debug = Crystal::Debug:: SpecRunOutput.new(output) else - Program.new.run(code, filename: filename, debug: debug) + Program.new.run(code, filename: filename, debug: debug, overflow_check: overflow_check) end end diff --git a/spec/std/int_spec.cr b/spec/std/int_spec.cr index 13142e734fea..62896e06f143 100644 --- a/spec/std/int_spec.cr +++ b/spec/std/int_spec.cr @@ -39,6 +39,12 @@ describe "Int" do x.should be_a(Int64) end + it "should overflow with larger integers" do + expect_raises(OverflowError) do + 51_i64 ** 12 + end + end + describe "with float" do it { (2 ** 2.0).should be_close(4, 0.0001) } it { (2 ** 2.5_f32).should be_close(5.656854249492381, 0.0001) } @@ -46,6 +52,42 @@ describe "Int" do end end + describe "unchecked_pow" do + it "with positive Int32" do + x = 2.unchecked_pow(2) + x.should eq(4) + x.should be_a(Int32) + + x = 2.unchecked_pow(0) + x.should eq(1) + x.should be_a(Int32) + end + + it "with positive UInt8" do + x = 2_u8.unchecked_pow(2) + x.should eq(4) + x.should be_a(UInt8) + end + + it "raises with negative exponent" do + expect_raises(ArgumentError, "Cannot raise an integer to a negative integer power, use floats for that") do + 2.unchecked_pow(-1) + end + end + + it "should work with large integers" do + x = 51_i64.unchecked_pow(11) + x.should eq(6071163615208263051_i64) + x.should be_a(Int64) + end + + it "should wrap with larger integers" do + x = 51_i64.unchecked_pow(12) + x.should eq(-3965304877440961871_i64) + x.should be_a(Int64) + end + end + describe "#===(:Char)" do it { (99 === 'c').should be_true } it { (99_u8 === 'c').should be_true } diff --git a/src/char/reader.cr b/src/char/reader.cr index 8531aedee7e2..be7e23b3031a 100644 --- a/src/char/reader.cr +++ b/src/char/reader.cr @@ -200,7 +200,7 @@ struct Char end if first < 0xe0 - return yield (first << 6) + (second - 0x3080), 2, nil + return yield __next_unchecked { (first << 6) + 
(second - 0x3080) }, 2, nil end third = byte_at?(pos + 2) @@ -217,7 +217,7 @@ struct Char invalid_byte_sequence 3 end - return yield (first << 12) + (second << 6) + (third - 0xE2080), 3, nil + return yield __next_unchecked { (first << 12) + (second << 6) + (third - 0xE2080) }, 3, nil end if first == 0xf0 && second < 0x90 @@ -236,7 +236,7 @@ struct Char end if first < 0xf5 - return yield (first << 18) + (second << 12) + (third << 6) + (fourth - 0x3C82080), 4, nil + return yield __next_unchecked { (first << 18) + (second << 12) + (third << 6) + (fourth - 0x3C82080) }, 4, nil end invalid_byte_sequence 4 diff --git a/src/compiler/crystal/codegen/call.cr b/src/compiler/crystal/codegen/call.cr index 8277ae58c212..eed288906a8e 100644 --- a/src/compiler/crystal/codegen/call.cr +++ b/src/compiler/crystal/codegen/call.cr @@ -283,7 +283,12 @@ class Crystal::CodeGenVisitor context.return_phi = phi request_value do + @caller_overflow_check = @overflow_check + @overflow_check = @main_overflow_check + accept target_def.body + + @overflow_check = @caller_overflow_check.not_nil! 
end phi.add @last, target_def.body.type?, last: true diff --git a/src/compiler/crystal/codegen/codegen.cr b/src/compiler/crystal/codegen/codegen.cr index 236b525b48c4..59697c3b32e6 100644 --- a/src/compiler/crystal/codegen/codegen.cr +++ b/src/compiler/crystal/codegen/codegen.cr @@ -15,17 +15,17 @@ module Crystal GET_EXCEPTION_NAME = "__crystal_get_exception" class Program - def run(code, filename = nil, debug = Debug::Default) + def run(code, filename = nil, debug = Debug::Default, overflow_check = DefaultOverflowCheckPolicy) parser = Parser.new(code) parser.filename = filename node = parser.parse node = normalize node node = semantic node - evaluate node, debug: debug + evaluate node, debug: debug, overflow_check: overflow_check end - def evaluate(node, debug = Debug::Default) - llvm_mod = codegen(node, single_module: true, debug: debug)[""].mod + def evaluate(node, debug = Debug::Default, overflow_check = DefaultOverflowCheckPolicy) + llvm_mod = codegen(node, single_module: true, debug: debug, overflow_check: overflow_check)[""].mod main = llvm_mod.functions[MAIN_NAME] main_return_type = main.return_type @@ -60,8 +60,8 @@ module Crystal end end - def codegen(node, single_module = false, debug = Debug::Default) - visitor = CodeGenVisitor.new self, node, single_module: single_module, debug: debug + def codegen(node, single_module = false, debug = Debug::Default, overflow_check = DefaultOverflowCheckPolicy) + visitor = CodeGenVisitor.new self, node, single_module: single_module, debug: debug, overflow_check: overflow_check visitor.accept node visitor.process_finished_hooks visitor.finish @@ -140,9 +140,11 @@ module Crystal @main_llvm_typer : LLVMTyper @main_module_info : ModuleInfo @main_builder : CrystalLLVMBuilder + @overflow_check : OverflowCheckScope::Policy - def initialize(@program : Program, @node : ASTNode, single_module = false, @debug = Debug::Default) + def initialize(@program : Program, @node : ASTNode, single_module = false, @debug = Debug::Default, 
@overflow_check = DefaultOverflowCheckPolicy) @single_module = !!single_module + @main_overflow_check = @overflow_check @abi = @program.target_machine.abi @llvm_context = LLVM::Context.new # LLVM::Context.register(@llvm_context, "main") @@ -1453,7 +1455,12 @@ module Crystal @needs_value = true set_ensure_exception_handler(block) + old_overflow_check = @overflow_check + @overflow_check = @caller_overflow_check.not_nil! + accept block.body + + @overflow_check = old_overflow_check end phi.add @last, block.body.type?, last: true diff --git a/src/compiler/crystal/codegen/fun.cr b/src/compiler/crystal/codegen/fun.cr index cf5e8be7ebf1..eb57b8a43ac9 100644 --- a/src/compiler/crystal/codegen/fun.cr +++ b/src/compiler/crystal/codegen/fun.cr @@ -63,6 +63,8 @@ class Crystal::CodeGenVisitor old_needs_value = @needs_value + old_overflow_check = @overflow_check + with_cloned_context do |old_context| context.type = self_type context.vars = LLVMVars.new @@ -77,6 +79,8 @@ class Crystal::CodeGenVisitor @rescue_block = nil @needs_value = true + @overflow_check = @main_overflow_check + args = codegen_fun_signature(mangled_name, target_def, self_type, is_fun_literal, is_closure) needs_body = !target_def.is_a?(External) || is_exported_fun @@ -156,6 +160,8 @@ class Crystal::CodeGenVisitor @alloca_block = old_alloca_block @needs_value = old_needs_value + @overflow_check = old_overflow_check + if @debug.line_numbers? 
# set_current_debug_location associates a scope from the current fun, # and at this point the current one should be the old one before diff --git a/src/compiler/crystal/codegen/primitives.cr b/src/compiler/crystal/codegen/primitives.cr index ff87cf55230f..4845cda564dc 100644 --- a/src/compiler/crystal/codegen/primitives.cr +++ b/src/compiler/crystal/codegen/primitives.cr @@ -13,10 +13,18 @@ class Crystal::CodeGenVisitor end end + def visit(node : OverflowCheckScope) + old_overflow_check = @overflow_check + @overflow_check = node.policy + accept node.body + @overflow_check = old_overflow_check + false + end + def codegen_primitive(call, node, target_def, call_args) @last = case node.name when "binary" - codegen_primitive_binary node, target_def, call_args + codegen_primitive_binary call.name_location, node, target_def, call_args when "cast" codegen_primitive_cast node, target_def, call_args when "allocate" @@ -74,13 +82,13 @@ class Crystal::CodeGenVisitor end end - def codegen_primitive_binary(node, target_def, call_args) + def codegen_primitive_binary(location, node, target_def, call_args) p1, p2 = call_args t1, t2 = target_def.owner, target_def.args[0].type - codegen_binary_op target_def.name, t1, t2, p1, p2 + codegen_binary_op location, target_def.name, t1, t2, p1, p2 end - def codegen_binary_op(op, t1 : BoolType, t2 : BoolType, p1, p2) + def codegen_binary_op(location, op, t1 : BoolType, t2 : BoolType, p1, p2) case op when "==" then builder.icmp LLVM::IntPredicate::EQ, p1, p2 when "!=" then builder.icmp LLVM::IntPredicate::NE, p1, p2 @@ -88,7 +96,7 @@ class Crystal::CodeGenVisitor end end - def codegen_binary_op(op, t1 : CharType, t2 : CharType, p1, p2) + def codegen_binary_op(location, op, t1 : CharType, t2 : CharType, p1, p2) case op when "==" then return builder.icmp LLVM::IntPredicate::EQ, p1, p2 when "!=" then return builder.icmp LLVM::IntPredicate::NE, p1, p2 @@ -100,7 +108,7 @@ class Crystal::CodeGenVisitor end end - def codegen_binary_op(op, t1 : 
SymbolType, t2 : SymbolType, p1, p2) + def codegen_binary_op(location, op, t1 : SymbolType, t2 : SymbolType, p1, p2) case op when "==" then return builder.icmp LLVM::IntPredicate::EQ, p1, p2 when "!=" then return builder.icmp LLVM::IntPredicate::NE, p1, p2 @@ -108,55 +116,228 @@ class Crystal::CodeGenVisitor end end - def codegen_binary_op(op, t1 : IntegerType, t2 : IntegerType, p1, p2) + def codegen_binary_op(location, op, t1 : IntegerType, t2 : IntegerType, p1, p2) # Comparisons are a bit trickier because we want to get comparisons # between signed and unsigned integers right. case op - when "<" then return @last = codegen_binary_op_lt(t1, t2, p1, p2) - when "<=" then return @last = codegen_binary_op_lte(t1, t2, p1, p2) - when ">" then return @last = codegen_binary_op_gt(t1, t2, p1, p2) - when ">=" then return @last = codegen_binary_op_gte(t1, t2, p1, p2) + when "<" then return codegen_binary_op_lt(t1, t2, p1, p2) + when "<=" then return codegen_binary_op_lte(t1, t2, p1, p2) + when ">" then return codegen_binary_op_gt(t1, t2, p1, p2) + when ">=" then return codegen_binary_op_gte(t1, t2, p1, p2) end - p1, p2 = codegen_binary_extend_int(t1, t2, p1, p2) - - @last = case op - when "+" then builder.add p1, p2 - when "-" then builder.sub p1, p2 - when "*" then builder.mul p1, p2 - when "/", "unsafe_div" then t1.signed? ? builder.sdiv(p1, p2) : builder.udiv(p1, p2) - when "%", "unsafe_mod" then t1.signed? ? builder.srem(p1, p2) : builder.urem(p1, p2) - when "unsafe_shl" then builder.shl(p1, p2) - when "unsafe_shr" then t1.signed? ? 
builder.ashr(p1, p2) : builder.lshr(p1, p2) - when "|" then or(p1, p2) - when "&" then and(p1, p2) - when "^" then builder.xor(p1, p2) - when "==" then return builder.icmp LLVM::IntPredicate::EQ, p1, p2 - when "!=" then return builder.icmp LLVM::IntPredicate::NE, p1, p2 - else raise "BUG: trying to codegen #{t1} #{op} #{t2}" - end + tmax, p1, p2 = codegen_binary_extend_int(t1, t2, p1, p2) - if t1.normal_rank != t2.normal_rank && t1.rank < t2.rank - @last = trunc @last, llvm_type(t1) + case op + when "+" then codegen_binary_op_add(location, tmax, t1, t2, p1, p2) + when "-" then codegen_binary_op_sub(location, tmax, t1, t2, p1, p2) + when "*" then codegen_binary_op_mul(location, tmax, t1, t2, p1, p2) + when "/", "unsafe_div" then codegen_trunc_binary_op_result(t1, t2, t1.signed? ? builder.sdiv(p1, p2) : builder.udiv(p1, p2)) + when "%", "unsafe_mod" then codegen_trunc_binary_op_result(t1, t2, t1.signed? ? builder.srem(p1, p2) : builder.urem(p1, p2)) + when "unsafe_shl" then codegen_trunc_binary_op_result(t1, t2, builder.shl(p1, p2)) + when "unsafe_shr" then codegen_trunc_binary_op_result(t1, t2, t1.signed? ? builder.ashr(p1, p2) : builder.lshr(p1, p2)) + when "|" then codegen_trunc_binary_op_result(t1, t2, or(p1, p2)) + when "&" then codegen_trunc_binary_op_result(t1, t2, and(p1, p2)) + when "^" then codegen_trunc_binary_op_result(t1, t2, builder.xor(p1, p2)) + when "==" then builder.icmp(LLVM::IntPredicate::EQ, p1, p2) + when "!=" then builder.icmp(LLVM::IntPredicate::NE, p1, p2) + else raise "BUG: trying to codegen #{t1} #{op} #{t2}" end - - @last end def codegen_binary_extend_int(t1, t2, p1, p2) if t1.normal_rank == t2.normal_rank # Nothing to do + tmax = t1 elsif t1.rank < t2.rank p1 = extend_int t1, t2, p1 + tmax = t2 else p2 = extend_int t2, t1, p2 + tmax = t1 + end + {tmax, p1, p2} + end + + # Ensures the result is returned in the type of the left hand side operand t1. 
+ # This is only needed if the operation was carried on in the realm of t2 + # because it was of higher rank + def codegen_trunc_binary_op_result(t1, t2, result) + if t1.normal_rank != t2.normal_rank && t1.rank < t2.rank + result = trunc result, llvm_type(t1) + else + result + end + end + + def codegen_binary_op_add(location, t : IntegerType, t1, t2, p1, p2) + if overflow_checked_scope? + llvm_fun = case t.kind + when :i8 + binary_overflow_fun "llvm.sadd.with.overflow.i8", llvm_context.int8 + when :i16 + binary_overflow_fun "llvm.sadd.with.overflow.i16", llvm_context.int16 + when :i32 + binary_overflow_fun "llvm.sadd.with.overflow.i32", llvm_context.int32 + when :i64 + binary_overflow_fun "llvm.sadd.with.overflow.i64", llvm_context.int64 + when :u8 + binary_overflow_fun "llvm.uadd.with.overflow.i8", llvm_context.int8 + when :u16 + binary_overflow_fun "llvm.uadd.with.overflow.i16", llvm_context.int16 + when :u32 + binary_overflow_fun "llvm.uadd.with.overflow.i32", llvm_context.int32 + when :u64 + binary_overflow_fun "llvm.uadd.with.overflow.i64", llvm_context.int64 + else + raise "unreachable" + end + + codegen_binary_overflow_check(location, llvm_fun, t, t1, t2, p1, p2) + else + result = builder.add p1, p2 + codegen_trunc_binary_op_result(t1, t2, result) end - {p1, p2} + end + + def codegen_binary_op_sub(location, t : IntegerType, t1, t2, p1, p2) + if overflow_checked_scope? 
+ llvm_fun = case t.kind + when :i8 + binary_overflow_fun "llvm.ssub.with.overflow.i8", llvm_context.int8 + when :i16 + binary_overflow_fun "llvm.ssub.with.overflow.i16", llvm_context.int16 + when :i32 + binary_overflow_fun "llvm.ssub.with.overflow.i32", llvm_context.int32 + when :i64 + binary_overflow_fun "llvm.ssub.with.overflow.i64", llvm_context.int64 + when :u8 + binary_overflow_fun "llvm.usub.with.overflow.i8", llvm_context.int8 + when :u16 + binary_overflow_fun "llvm.usub.with.overflow.i16", llvm_context.int16 + when :u32 + binary_overflow_fun "llvm.usub.with.overflow.i32", llvm_context.int32 + when :u64 + binary_overflow_fun "llvm.usub.with.overflow.i64", llvm_context.int64 + else + raise "unreachable" + end + + codegen_binary_overflow_check(location, llvm_fun, t, t1, t2, p1, p2) + else + result = builder.sub(p1, p2) + codegen_trunc_binary_op_result(t1, t2, result) + end + end + + def codegen_binary_op_mul(location, t : IntegerType, t1, t2, p1, p2) + if overflow_checked_scope? + llvm_fun = case t.kind + when :i8 + binary_overflow_fun "llvm.smul.with.overflow.i8", llvm_context.int8 + when :i16 + binary_overflow_fun "llvm.smul.with.overflow.i16", llvm_context.int16 + when :i32 + binary_overflow_fun "llvm.smul.with.overflow.i32", llvm_context.int32 + when :i64 + binary_overflow_fun "llvm.smul.with.overflow.i64", llvm_context.int64 + when :u8 + binary_overflow_fun "llvm.umul.with.overflow.i8", llvm_context.int8 + when :u16 + binary_overflow_fun "llvm.umul.with.overflow.i16", llvm_context.int16 + when :u32 + binary_overflow_fun "llvm.umul.with.overflow.i32", llvm_context.int32 + when :u64 + binary_overflow_fun "llvm.umul.with.overflow.i64", llvm_context.int64 + else + raise "unreachable" + end + + codegen_binary_overflow_check(location, llvm_fun, t, t1, t2, p1, p2) + else + result = builder.mul(p1, p2) + codegen_trunc_binary_op_result(t1, t2, result) + end + end + + # Generates a call to llvm_fun(p1, p2). + # t1, t2 are the original types of p1, p2. 
+ # t is the super type of t1 and t2 where the operation is performed. + # llvm_fun returns {res, o_bit} where the o_bit signals overflow. + # The generated code also performs a range check and truncation of res + # in order to fit in the original type t1 if needed. + # + # ``` + # %res_with_overflow = call {T, i1} @llvm_fun(T %p1, T %p2) + # %res = extractvalue {T, i1} %res_with_overflow, 0 + # %o_bit = extractvalue {T, i1} %res_with_overflow, 1 + # ;; if T != T1 + # %out_of_range = %res < T1::MIN || %res > T1::MAX ;; compare T1.range and %res + # br i1 or(%o_bit, %out_of_range), label %overflow, label %normal + # ;; else + # br i1 %o_bit, label %overflow, label %normal + # ;; end + # + # overflow: + # ;; codegen: raise OverflowError.new with caller's location + # + # normal: + # ;; if T != T1 + # ;; %res' is returned + # %res' = trunc T %res to T1 + # ;; else + # ;; %res is returned + # ;; end + # ``` + private def codegen_binary_overflow_check(location, llvm_fun, t : IntegerType, t1, t2, p1, p2) + res_with_overflow = builder.call(llvm_fun, [p1, p2]) + + res = extract_value res_with_overflow, 0 + o_bit = extract_value res_with_overflow, 1 + + if t != t1 + t1_min_value, t1_max_value = t1.range + # out_of_range = res < t1_min_value || res > t1_max_value + out_of_range = or( + codegen_binary_op_lt(t, t1, res, int(t1_min_value, t1)), + codegen_binary_op_gt(t, t1, res, int(t1_max_value, t1)) + ) + + overflow = or(o_bit, out_of_range) + else + overflow = o_bit + end + + op_overflow = new_block "overflow" + op_normal = new_block "normal" + + cond overflow, op_overflow, op_normal + position_at_end op_overflow + + ex = Call.new(Path.global("OverflowError").at(location), "new").at(location) + call = Call.global("raise", ex).at(location) + visitor = MainVisitor.new(@program) + @program.visit_main call, visitor: visitor + accept call + + position_at_end op_normal + + codegen_trunc_binary_op_result(t1, t2, res) + end + + private def binary_overflow_fun(fun_name, llvm_operand_type) + llvm_mod.functions[fun_name]? 
|| + llvm_mod.functions.add(fun_name, [llvm_operand_type, llvm_operand_type], + llvm_context.struct([llvm_operand_type, llvm_context.int1])) + end + + private def overflow_checked_scope? + @overflow_check == OverflowCheckScope::Policy::Checked end def codegen_binary_op_lt(t1, t2, p1, p2) if t1.signed? == t2.signed? - p1, p2 = codegen_binary_extend_int(t1, t2, p1, p2) + _, p1, p2 = codegen_binary_extend_int(t1, t2, p1, p2) builder.icmp (t1.signed? ? LLVM::IntPredicate::SLT : LLVM::IntPredicate::ULT), p1, p2 else if t1.signed? && t2.unsigned? @@ -194,7 +375,7 @@ class Crystal::CodeGenVisitor def codegen_binary_op_lte(t1, t2, p1, p2) if t1.signed? == t2.signed? - p1, p2 = codegen_binary_extend_int(t1, t2, p1, p2) + _, p1, p2 = codegen_binary_extend_int(t1, t2, p1, p2) builder.icmp (t1.signed? ? LLVM::IntPredicate::SLE : LLVM::IntPredicate::ULE), p1, p2 else if t1.signed? && t2.unsigned? @@ -232,7 +413,7 @@ class Crystal::CodeGenVisitor def codegen_binary_op_gt(t1, t2, p1, p2) if t1.signed? == t2.signed? - p1, p2 = codegen_binary_extend_int(t1, t2, p1, p2) + _, p1, p2 = codegen_binary_extend_int(t1, t2, p1, p2) builder.icmp (t1.signed? ? LLVM::IntPredicate::SGT : LLVM::IntPredicate::UGT), p1, p2 else if t1.signed? && t2.unsigned? @@ -270,7 +451,7 @@ class Crystal::CodeGenVisitor def codegen_binary_op_gte(t1, t2, p1, p2) if t1.signed? == t2.signed? - p1, p2 = codegen_binary_extend_int(t1, t2, p1, p2) + _, p1, p2 = codegen_binary_extend_int(t1, t2, p1, p2) builder.icmp (t1.signed? ? LLVM::IntPredicate::SGE : LLVM::IntPredicate::UGE), p1, p2 else if t1.signed? && t2.unsigned? 
@@ -306,17 +487,17 @@ class Crystal::CodeGenVisitor end end - def codegen_binary_op(op, t1 : IntegerType, t2 : FloatType, p1, p2) + def codegen_binary_op(location, op, t1 : IntegerType, t2 : FloatType, p1, p2) p1 = codegen_cast(t1, t2, p1) - codegen_binary_op(op, t2, t2, p1, p2) + codegen_binary_op(location, op, t2, t2, p1, p2) end - def codegen_binary_op(op, t1 : FloatType, t2 : IntegerType, p1, p2) + def codegen_binary_op(location, op, t1 : FloatType, t2 : IntegerType, p1, p2) p2 = codegen_cast(t2, t1, p2) - codegen_binary_op op, t1, t1, p1, p2 + codegen_binary_op(location, op, t1, t1, p1, p2) end - def codegen_binary_op(op, t1 : FloatType, t2 : FloatType, p1, p2) + def codegen_binary_op(location, op, t1 : FloatType, t2 : FloatType, p1, p2) if t1.rank < t2.rank p1 = extend_float t2, p1 elsif t1.rank > t2.rank @@ -340,11 +521,11 @@ class Crystal::CodeGenVisitor @last end - def codegen_binary_op(op, t1 : TypeDefType, t2, p1, p2) - codegen_binary_op op, t1.remove_typedef, t2, p1, p2 + def codegen_binary_op(location, op, t1 : TypeDefType, t2, p1, p2) + codegen_binary_op(location, op, t1.remove_typedef, t2, p1, p2) end - def codegen_binary_op(op, t1, t2, p1, p2) + def codegen_binary_op(location, op, t1, t2, p1, p2) raise "BUG: codegen_binary_op called with #{t1} #{op} #{t2}" end diff --git a/src/compiler/crystal/command.cr b/src/compiler/crystal/command.cr index 578307a0e5b3..b30189c1e382 100644 --- a/src/compiler/crystal/command.cr +++ b/src/compiler/crystal/command.cr @@ -300,6 +300,13 @@ class Crystal::Command opts.on("--no-debug", "Skip any symbolic debug info") do compiler.debug = Crystal::Debug::None end + # TODO change default + opts.on("--overflow-checked", "Perform integer overflow check") do + compiler.overflow_check = Crystal::OverflowCheckScope::Policy::Checked + end + opts.on("--overflow-unchecked", "Skip integer overflow check (default)") do + compiler.overflow_check = Crystal::OverflowCheckScope::Policy::Unchecked + end {% unless LibLLVM::IS_38 || 
LibLLVM::IS_39 %} opts.on("--lto=FLAG", "Use ThinLTO --lto=thin") do |flag| error "--lto=thin is the only lto supported option" unless flag == "thin" @@ -488,6 +495,13 @@ class Crystal::Command opts.on("--no-debug", "Skip any symbolic debug info") do compiler.debug = Crystal::Debug::None end + # TODO change default + opts.on("--overflow-checked", "Perform integer overflow check") do + compiler.overflow_check = Crystal::OverflowCheckScope::Policy::Checked + end + opts.on("--overflow-unchecked", "Skip integer overflow check (default)") do + compiler.overflow_check = Crystal::OverflowCheckScope::Policy::Unchecked + end opts.on("-D FLAG", "--define FLAG", "Define a compile-time flag") do |flag| compiler.flags << flag end diff --git a/src/compiler/crystal/compiler.cr b/src/compiler/crystal/compiler.cr index 7368f4e8d1a1..da569576cfb4 100644 --- a/src/compiler/crystal/compiler.cr +++ b/src/compiler/crystal/compiler.cr @@ -12,6 +12,9 @@ module Crystal Default = LineNumbers end + # TODO change to checked in future release + DefaultOverflowCheckPolicy = OverflowCheckScope::Policy::Unchecked + # Main interface to the compiler. # # A Compiler parses source code, type checks it and @@ -45,6 +48,9 @@ module Crystal # that can be understood by `gdb` and `lldb`. property debug = Debug::Default + # Defines the default overflow check policy + property overflow_check = DefaultOverflowCheckPolicy + # If `true`, `.ll` files will be generated in the default cache # directory for each generated LLVM module. property? dump_ll = false @@ -187,6 +193,7 @@ module Crystal program.stdout = stdout program.show_error_trace = show_error_trace? 
program.progress_tracker = @progress_tracker + program.overflow_check = @overflow_check program end @@ -223,7 +230,7 @@ module Crystal private def bc_flags_changed?(output_dir) bc_flags_changed = true - current_bc_flags = "#{@target_triple}|#{@mcpu}|#{@mattr}|#{@release}|#{@link_flags}" + current_bc_flags = "#{@target_triple}|#{@mcpu}|#{@mattr}|#{@release}|#{@link_flags}|#{@overflow_check}" bc_flags_filename = "#{output_dir}/bc_flags" if File.file?(bc_flags_filename) previous_bc_flags = File.read(bc_flags_filename).strip @@ -235,7 +242,10 @@ module Crystal private def codegen(program, node : ASTNode, sources, output_filename) llvm_modules = @progress_tracker.stage("Codegen (crystal)") do - program.codegen node, debug: debug, single_module: @single_module || (!@thin_lto && @release) || @cross_compile || @emit + program.codegen node, + debug: debug, + single_module: @single_module || (!@thin_lto && @release) || @cross_compile || @emit, + overflow_check: overflow_check end output_dir = CacheDir.instance.directory_for(sources) diff --git a/src/compiler/crystal/program.cr b/src/compiler/crystal/program.cr index afd0fb6556f2..de9c1789b0dc 100644 --- a/src/compiler/crystal/program.cr +++ b/src/compiler/crystal/program.cr @@ -113,6 +113,9 @@ module Crystal # Set to a `ProgressTracker` object which tracks compilation progress. 
property progress_tracker = ProgressTracker.new + # Defines the default overflow check policy + property overflow_check = DefaultOverflowCheckPolicy + def initialize super(self, self, "main") diff --git a/src/compiler/crystal/semantic/main_visitor.cr b/src/compiler/crystal/semantic/main_visitor.cr index 8b02f747333c..3948001bc232 100644 --- a/src/compiler/crystal/semantic/main_visitor.cr +++ b/src/compiler/crystal/semantic/main_visitor.cr @@ -3003,6 +3003,10 @@ module Crystal false end + def end_visit(node : OverflowCheckScope) + node.bind_to node.body + end + # # Helpers def free_vars diff --git a/src/compiler/crystal/semantic/recursive_struct_checker.cr b/src/compiler/crystal/semantic/recursive_struct_checker.cr index ff3d2d4e4570..132443f19126 100644 --- a/src/compiler/crystal/semantic/recursive_struct_checker.cr +++ b/src/compiler/crystal/semantic/recursive_struct_checker.cr @@ -36,16 +36,16 @@ class Crystal::RecursiveStructChecker if struct?(type) target = type - checked = Set(Type).new + checked_types = Set(Type).new path = [] of Var | Type - check_recursive_instance_var_container(target, type, checked, path) + check_recursive_instance_var_container(target, type, checked_types, path) end if type.is_a?(AliasType) && !type.simple? target = type - checked = Set(Type).new + checked_types = Set(Type).new path = [] of Var | Type - check_recursive(target, type.aliased_type, checked, path) + check_recursive(target, type.aliased_type, checked_types, path) end check_types(type) @@ -60,7 +60,7 @@ class Crystal::RecursiveStructChecker end end - def check_recursive(target, type, checked, path) + def check_recursive(target, type, checked_types, path) if target == type if target.is_a?(AliasType) alias_message = " (recursive aliases are structs)" @@ -90,14 +90,14 @@ class Crystal::RecursiveStructChecker end end - return if checked.includes?(type) + return if checked_types.includes?(type) if type.is_a?(VirtualType) if type.struct? 
push(path, type) do type.subtypes.each do |subtype| push(path, subtype) do - check_recursive(target, subtype, checked, path) + check_recursive(target, subtype, checked_types, path) end end end @@ -109,7 +109,7 @@ class Crystal::RecursiveStructChecker # Check if the module is composed, recursively, of the target struct type.raw_including_types.try &.each do |module_type| push(path, module_type) do - check_recursive(target, module_type, checked, path) + check_recursive(target, module_type, checked_types, path) end end end @@ -117,7 +117,7 @@ class Crystal::RecursiveStructChecker if type.is_a?(InstanceVarContainer) if struct?(type) - check_recursive_instance_var_container(target, type, checked, path) + check_recursive_instance_var_container(target, type, checked_types, path) end end @@ -125,7 +125,7 @@ class Crystal::RecursiveStructChecker push(path, type) do type.union_types.each do |union_type| push(path, union_type) do - check_recursive(target, union_type, checked, path) + check_recursive(target, union_type, checked_types, path) end end end @@ -135,7 +135,7 @@ class Crystal::RecursiveStructChecker push(path, type) do type.tuple_types.each do |tuple_type| push(path, tuple_type) do - check_recursive(target, tuple_type, checked, path) + check_recursive(target, tuple_type, checked_types, path) end end end @@ -145,24 +145,24 @@ class Crystal::RecursiveStructChecker push(path, type) do type.entries.each do |entry| push(path, entry.type) do - check_recursive(target, entry.type, checked, path) + check_recursive(target, entry.type, checked_types, path) end end end end end - def check_recursive_instance_var_container(target, type, checked, path) - checked.add type + def check_recursive_instance_var_container(target, type, checked_types, path) + checked_types.add type type.all_instance_vars.each_value do |var| var_type = var.type? 
next unless var_type push(path, var) do - check_recursive(target, var_type, checked, path) + check_recursive(target, var_type, checked_types, path) end end - checked.delete type + checked_types.delete type end def path_to_s(path) diff --git a/src/compiler/crystal/semantic/top_level_visitor.cr b/src/compiler/crystal/semantic/top_level_visitor.cr index 08d4fef0e4e1..5cf3e46ee5bf 100644 --- a/src/compiler/crystal/semantic/top_level_visitor.cr +++ b/src/compiler/crystal/semantic/top_level_visitor.cr @@ -675,7 +675,7 @@ class Crystal::TopLevelVisitor < Crystal::SemanticVisitor if counter == 0 # In case the member is set to 0 1 else - counter * 2 + __next_unchecked { counter * 2 } end else counter + 1 diff --git a/src/compiler/crystal/syntax/ast.cr b/src/compiler/crystal/syntax/ast.cr index 6a6b9620a199..bbabdd8ebef7 100644 --- a/src/compiler/crystal/syntax/ast.cr +++ b/src/compiler/crystal/syntax/ast.cr @@ -1526,6 +1526,30 @@ module Crystal def_equals_and_hash @body, @types, @name end + class OverflowCheckScope < ASTNode + enum Policy + Unchecked + Checked + end + + property body : ASTNode + property policy : Policy + + def initialize(@policy : Policy, body = nil) + @body = Expressions.from body + end + + def accept_children(visitor) + @body.accept visitor + end + + def clone_without_location + OverflowCheckScope.new(@policy, @body.clone) + end + + def_equals_and_hash @policy, @body + end + class ExceptionHandler < ASTNode property body : ASTNode property rescues : Array(Rescue)? 
diff --git a/src/compiler/crystal/syntax/lexer.cr b/src/compiler/crystal/syntax/lexer.cr index 5145ae54b283..c828c70ba631 100644 --- a/src/compiler/crystal/syntax/lexer.cr +++ b/src/compiler/crystal/syntax/lexer.cr @@ -762,6 +762,10 @@ module Crystal if next_char == 's' && next_char == 'e' return check_ident_or_keyword(:case, start) end + when 'h' + if next_char == 'e' && next_char == 'c' && next_char == 'k' && next_char == 'e' && next_char == 'd' + return check_ident_or_keyword(:checked, start) + end when 'l' if next_char == 'a' && next_char == 's' && next_char == 's' return check_ident_or_keyword(:class, start) @@ -1002,6 +1006,10 @@ module Crystal when 'u' if next_char == 'n' case next_char + when 'c' + if next_char == 'h' && next_char == 'e' && next_char == 'c' && next_char == 'k' && next_char == 'e' && next_char == 'd' + return check_ident_or_keyword(:unchecked, start) + end when 'i' case next_char when 'o' diff --git a/src/compiler/crystal/syntax/parser.cr b/src/compiler/crystal/syntax/parser.cr index 3dafc20158c6..06f9ee28d742 100644 --- a/src/compiler/crystal/syntax/parser.cr +++ b/src/compiler/crystal/syntax/parser.cr @@ -1064,6 +1064,10 @@ module Crystal parse_annotation_def end end + when :unchecked + check_type_declaration { parse_unchecked } + when :checked + check_type_declaration { parse_checked } else set_visibility parse_var_or_call end @@ -1349,6 +1353,31 @@ module Crystal types end + def parse_checked + parse_overflow_check_scope OverflowCheckScope::Policy::Checked + end + + def parse_unchecked + parse_overflow_check_scope OverflowCheckScope::Policy::Unchecked + end + + def parse_overflow_check_scope(policy) + slash_is_regex! 
+ next_token_skip_space_or_newline + + check :"{" + next_token_skip_space_or_newline + + body = parse_expressions + skip_statement_end + + end_location = token_end_location + check :"}" + next_token_skip_space + + OverflowCheckScope.new(policy, body).at_end(end_location) + end + def parse_while parse_while_or_until While end @@ -3724,6 +3753,7 @@ module Crystal :extend, :class, :struct, :module, :enum, :while, :until, :return, :next, :break, :lib, :fun, :alias, :pointerof, :sizeof, :instance_sizeof, :typeof, :private, :protected, :asm, :out, + :checked, :unchecked, # `end` is also invalid because it maybe terminate `def` block. :end true diff --git a/src/compiler/crystal/syntax/to_s.cr b/src/compiler/crystal/syntax/to_s.cr index 0f9346457c37..cabe655d65b9 100644 --- a/src/compiler/crystal/syntax/to_s.cr +++ b/src/compiler/crystal/syntax/to_s.cr @@ -1087,6 +1087,20 @@ module Crystal false end + def visit(node : OverflowCheckScope) + case node.policy + when OverflowCheckScope::Policy::Unchecked + @str << "unchecked" + when OverflowCheckScope::Policy::Checked + @str << "checked" + end + + @str << " {" + node.body.accept self + @str << '}' + false + end + def to_s_binary(node, op) left_needs_parens = need_parens(node.left) in_parenthesis(left_needs_parens, node.left) diff --git a/src/compiler/crystal/syntax/transformer.cr b/src/compiler/crystal/syntax/transformer.cr index 687e5f653da8..09a6a4e11429 100644 --- a/src/compiler/crystal/syntax/transformer.cr +++ b/src/compiler/crystal/syntax/transformer.cr @@ -260,6 +260,11 @@ module Crystal node end + def transform(node : OverflowCheckScope) + node.body = node.body.transform(self) + node + end + def transform(node : Generic) node.name = node.name.transform(self) transform_many node.type_vars diff --git a/src/crypto/blowfish.cr b/src/crypto/blowfish.cr index c02964db17f8..8126dc5246af 100644 --- a/src/crypto/blowfish.cr +++ b/src/crypto/blowfish.cr @@ -73,7 +73,9 @@ class Crypto::Blowfish c = (x >> 8) & 0xff_u32 b = (x >> 
16) & 0xff_u32 a = (x >> 24) & 0xff_u32 - ((@s.to_unsafe[a] + @s1[b]) ^ @s2[c]) + @s3[d] + __next_unchecked { + ((@s.to_unsafe[a] + @s1[b]) ^ @s2[c]) + @s3[d] + } end private def next_word(data, pos) diff --git a/src/crystal/hasher.cr b/src/crystal/hasher.cr index 388083d74bca..e5939a5b375e 100644 --- a/src/crystal/hasher.cr +++ b/src/crystal/hasher.cr @@ -93,26 +93,32 @@ struct Crystal::Hasher end private def permute(v : UInt64) - @a = rotl32(@a ^ v) * C1 - @b = (rotl32(@b) ^ v) * C2 - self + __next_unchecked { + @a = rotl32(@a ^ v) * C1 + @b = (rotl32(@b) ^ v) * C2 + self + } end def result - a, b = @a, @b - a ^= (a >> 23) ^ (a >> 40) - b ^= (b >> 23) ^ (b >> 40) - a *= C1 - b *= C2 - a ^= a >> 32 - b ^= b >> 32 - a + b + __next_unchecked { + a, b = @a, @b + a ^= (a >> 23) ^ (a >> 40) + b ^= (b >> 23) ^ (b >> 40) + a *= C1 + b *= C2 + a ^= a >> 32 + b ^= b >> 32 + a + b + } end def nil - @a += @b - @b += 1 - self + __next_unchecked { + @a += @b + @b += 1 + self + } end def bool(value) diff --git a/src/debug/dwarf/info.cr b/src/debug/dwarf/info.cr index 6f165a3597b4..362514b74f0f 100644 --- a/src/debug/dwarf/info.cr +++ b/src/debug/dwarf/info.cr @@ -38,23 +38,25 @@ module Debug end def each - end_offset = @offset + @unit_length - attributes = [] of {AT, FORM, Value} + __next_unchecked { + end_offset = @offset + @unit_length + attributes = [] of {AT, FORM, Value} - while @io.tell < end_offset - code = DWARF.read_unsigned_leb128(@io) - attributes.clear + while @io.tell < end_offset + code = DWARF.read_unsigned_leb128(@io) + attributes.clear - if abbrev = abbreviations[code - 1]? # abbreviations.find { |a| a.code == abbrev } - abbrev.attributes.each do |attr| - value = read_attribute_value(attr.form) - attributes << {attr.at, attr.form, value} + if abbrev = abbreviations[code - 1]? 
# abbreviations.find { |a| a.code == abbrev } + abbrev.attributes.each do |attr| + value = read_attribute_value(attr.form) + attributes << {attr.at, attr.form, value} + end + yield code, abbrev, attributes + else + yield code, nil, attributes end - yield code, abbrev, attributes - else - yield code, nil, attributes end - end + } end private def read_attribute_value(form) diff --git a/src/debug/dwarf/line_numbers.cr b/src/debug/dwarf/line_numbers.cr index 30a2afab4d67..2037c75ea871 100644 --- a/src/debug/dwarf/line_numbers.cr +++ b/src/debug/dwarf/line_numbers.cr @@ -278,88 +278,90 @@ module Debug # TODO: support LNE::DefineFile (manually register file, uncommon) private def read_statement_program(sequence) - registers = Register.new(sequence.default_is_stmt) + __next_unchecked { + registers = Register.new(sequence.default_is_stmt) - loop do - opcode = @io.read_byte.not_nil! - - if opcode >= sequence.opcode_base - # special opcode - adjusted_opcode = opcode - sequence.opcode_base - operation_advance = adjusted_opcode / sequence.line_range - increment_address_and_op_index(operation_advance) - - registers.line += sequence.line_base + (adjusted_opcode % sequence.line_range) - register_to_matrix(sequence, registers) - registers.reset - elsif opcode == 0 - # extended opcode - len = DWARF.read_unsigned_leb128(@io) - 1 # -1 accounts for the opcode - extended_opcode = LNE.new(@io.read_byte.not_nil!) 
- - case extended_opcode - when LNE::EndSequence - registers.end_sequence = true - register_to_matrix(sequence, registers) - if (@io.tell - @offset - sequence.offset) < sequence.total_length - registers = Register.new(sequence.default_is_stmt) - else - break - end - when LNE::SetAddress - case len - when 8 then registers.address = @io.read_bytes(UInt64) - when 4 then registers.address = @io.read_bytes(UInt32).to_u64 - else @io.skip(len) - end - registers.op_index = 0_u32 - when LNE::SetDiscriminator - registers.discriminator = DWARF.read_unsigned_leb128(@io) - else - # skip unsupported opcode - @io.read_fully(Bytes.new(len)) - end - else - # standard opcode - standard_opcode = LNS.new(opcode) + loop do + opcode = @io.read_byte.not_nil! - case standard_opcode - when LNS::Copy - register_to_matrix(sequence, registers) - registers.reset - when LNS::AdvancePc - operation_advance = DWARF.read_unsigned_leb128(@io) - increment_address_and_op_index(operation_advance) - when LNS::AdvanceLine - registers.line += DWARF.read_signed_leb128(@io) - when LNS::SetFile - registers.file = DWARF.read_unsigned_leb128(@io) - when LNS::SetColumn - registers.column = DWARF.read_unsigned_leb128(@io) - when LNS::NegateStmt - registers.is_stmt = !registers.is_stmt - when LNS::SetBasicBlock - registers.basic_block = true - when LNS::ConstAddPc - adjusted_opcode = 255 - sequence.opcode_base + if opcode >= sequence.opcode_base + # special opcode + adjusted_opcode = opcode - sequence.opcode_base operation_advance = adjusted_opcode / sequence.line_range increment_address_and_op_index(operation_advance) - when LNS::FixedAdvancePc - registers.address += @io.read_bytes(UInt16).not_nil! 
- registers.op_index = 0_u32 - when LNS::SetPrologueEnd - registers.prologue_end = true - when LNS::SetEpiloqueBegin - registers.epilogue_begin = true - when LNS::SetIsa - registers.isa = DWARF.read_unsigned_leb128(@io) + + registers.line += sequence.line_base + (adjusted_opcode % sequence.line_range) + register_to_matrix(sequence, registers) + registers.reset + elsif opcode == 0 + # extended opcode + len = DWARF.read_unsigned_leb128(@io) - 1 # -1 accounts for the opcode + extended_opcode = LNE.new(@io.read_byte.not_nil!) + + case extended_opcode + when LNE::EndSequence + registers.end_sequence = true + register_to_matrix(sequence, registers) + if (@io.tell - @offset - sequence.offset) < sequence.total_length + registers = Register.new(sequence.default_is_stmt) + else + break + end + when LNE::SetAddress + case len + when 8 then registers.address = @io.read_bytes(UInt64) + when 4 then registers.address = @io.read_bytes(UInt32).to_u64 + else @io.skip(len) + end + registers.op_index = 0_u32 + when LNE::SetDiscriminator + registers.discriminator = DWARF.read_unsigned_leb128(@io) + else + # skip unsupported opcode + @io.read_fully(Bytes.new(len)) + end else - # consume unknown opcode args - n_args = sequence.standard_opcode_lengths[opcode.to_i] - n_args.times { DWARF.read_unsigned_leb128(@io) } + # standard opcode + standard_opcode = LNS.new(opcode) + + case standard_opcode + when LNS::Copy + register_to_matrix(sequence, registers) + registers.reset + when LNS::AdvancePc + operation_advance = DWARF.read_unsigned_leb128(@io) + increment_address_and_op_index(operation_advance) + when LNS::AdvanceLine + registers.line += DWARF.read_signed_leb128(@io) + when LNS::SetFile + registers.file = DWARF.read_unsigned_leb128(@io) + when LNS::SetColumn + registers.column = DWARF.read_unsigned_leb128(@io) + when LNS::NegateStmt + registers.is_stmt = !registers.is_stmt + when LNS::SetBasicBlock + registers.basic_block = true + when LNS::ConstAddPc + adjusted_opcode = 255 - 
sequence.opcode_base + operation_advance = adjusted_opcode / sequence.line_range + increment_address_and_op_index(operation_advance) + when LNS::FixedAdvancePc + registers.address += @io.read_bytes(UInt16).not_nil! + registers.op_index = 0_u32 + when LNS::SetPrologueEnd + registers.prologue_end = true + when LNS::SetEpiloqueBegin + registers.epilogue_begin = true + when LNS::SetIsa + registers.isa = DWARF.read_unsigned_leb128(@io) + else + # consume unknown opcode args + n_args = sequence.standard_opcode_lengths[opcode.to_i] + n_args.times { DWARF.read_unsigned_leb128(@io) } + end end end - end + } end @current_sequence_matrix : Array(Row)? diff --git a/src/digest/md5.cr b/src/digest/md5.cr index edabb9eb9ae0..75d5bd6a30bf 100644 --- a/src/digest/md5.cr +++ b/src/digest/md5.cr @@ -103,27 +103,35 @@ class Digest::MD5 < Digest::Base end def ff(a, b, c, d, x, s, ac) - a += f(b, c, d) + x + ac.to_u32 - a = rotate_left a, s - a += b + __next_unchecked { + a += f(b, c, d) + x + ac.to_u32 + a = rotate_left a, s + a += b + } end def gg(a, b, c, d, x, s, ac) - a += g(b, c, d) + x + ac.to_u32 - a = rotate_left a, s - a += b + __next_unchecked { + a += g(b, c, d) + x + ac.to_u32 + a = rotate_left a, s + a += b + } end def hh(a, b, c, d, x, s, ac) - a += h(b, c, d) + x + ac.to_u32 - a = rotate_left a, s - a += b + __next_unchecked { + a += h(b, c, d) + x + ac.to_u32 + a = rotate_left a, s + a += b + } end def ii(a, b, c, d, x, s, ac) - a += i(b, c, d) + x + ac.to_u32 - a = rotate_left a, s - a += b + __next_unchecked { + a += i(b, c, d) + x + ac.to_u32 + a = rotate_left a, s + a += b + } end def transform(in) @@ -201,10 +209,12 @@ class Digest::MD5 < Digest::Base c = ii(c, d, a, b, in[2], S43, 718787259) # 63 b = ii(b, c, d, a, in[9], S44, 3951481745) # 64 - @buf[0] += a - @buf[1] += b - @buf[2] += c - @buf[3] += d + __next_unchecked { + @buf[0] += a + @buf[1] += b + @buf[2] += c + @buf[3] += d + } end def final diff --git a/src/digest/sha1.cr b/src/digest/sha1.cr index 
d4c5758c6ece..895820345638 100644 --- a/src/digest/sha1.cr +++ b/src/digest/sha1.cr @@ -45,72 +45,74 @@ class Digest::SHA1 < Digest::Base end def process_message_block - k = {0x5A827999_u32, 0x6ED9EBA1_u32, 0x8F1BBCDC_u32, 0xCA62C1D6_u32} - - w = uninitialized UInt32[80] - - {% for t in (0...16) %} - w[{{t}}] = @message_block[{{t}} * 4].to_u32 << 24 - w[{{t}}] |= @message_block[{{t}} * 4 + 1].to_u32 << 16 - w[{{t}}] |= @message_block[{{t}} * 4 + 2].to_u32 << 8 - w[{{t}}] |= @message_block[{{t}} * 4 + 3].to_u32 - {% end %} - - {% for t in (16...80) %} - w[{{t}}] = circular_shift(1, w[{{t - 3}}] ^ w[{{t - 8}}] ^ w[{{t - 14}}] ^ w[{{t - 16}}]) - {% end %} - - a = @intermediate_hash[0] - b = @intermediate_hash[1] - c = @intermediate_hash[2] - d = @intermediate_hash[3] - e = @intermediate_hash[4] - - {% for t in (0...20) %} - temp = circular_shift(5, a) + - ((b & c) | ((~b) & d)) + e + w[{{t}}] + k[0] - e = d - d = c - c = circular_shift(30, b) - b = a - a = temp - {% end %} - - {% for t in (20...40) %} - temp = circular_shift(5, a) + (b ^ c ^ d) + e + w[{{t}}] + k[1] - e = d - d = c - c = circular_shift(30, b) - b = a - a = temp - {% end %} - - {% for t in (40...60) %} - temp = circular_shift(5, a) + - ((b & c) | (b & d) | (c & d)) + e + w[{{t}}] + k[2] - e = d - d = c - c = circular_shift(30, b) - b = a - a = temp - {% end %} - - {% for t in (60...80) %} - temp = circular_shift(5, a) + (b ^ c ^ d) + e + w[{{t}}] + k[3] - e = d - d = c - c = circular_shift(30, b) - b = a - a = temp - {% end %} - - @intermediate_hash[0] += a - @intermediate_hash[1] += b - @intermediate_hash[2] += c - @intermediate_hash[3] += d - @intermediate_hash[4] += e - - @message_block_index = 0 + __next_unchecked { + k = {0x5A827999_u32, 0x6ED9EBA1_u32, 0x8F1BBCDC_u32, 0xCA62C1D6_u32} + + w = uninitialized UInt32[80] + + {% for t in (0...16) %} + w[{{t}}] = @message_block[{{t}} * 4].to_u32 << 24 + w[{{t}}] |= @message_block[{{t}} * 4 + 1].to_u32 << 16 + w[{{t}}] |= @message_block[{{t}} * 4 + 
2].to_u32 << 8 + w[{{t}}] |= @message_block[{{t}} * 4 + 3].to_u32 + {% end %} + + {% for t in (16...80) %} + w[{{t}}] = circular_shift(1, w[{{t - 3}}] ^ w[{{t - 8}}] ^ w[{{t - 14}}] ^ w[{{t - 16}}]) + {% end %} + + a = @intermediate_hash[0] + b = @intermediate_hash[1] + c = @intermediate_hash[2] + d = @intermediate_hash[3] + e = @intermediate_hash[4] + + {% for t in (0...20) %} + temp = circular_shift(5, a) + + ((b & c) | ((~b) & d)) + e + w[{{t}}] + k[0] + e = d + d = c + c = circular_shift(30, b) + b = a + a = temp + {% end %} + + {% for t in (20...40) %} + temp = circular_shift(5, a) + (b ^ c ^ d) + e + w[{{t}}] + k[1] + e = d + d = c + c = circular_shift(30, b) + b = a + a = temp + {% end %} + + {% for t in (40...60) %} + temp = circular_shift(5, a) + + ((b & c) | (b & d) | (c & d)) + e + w[{{t}}] + k[2] + e = d + d = c + c = circular_shift(30, b) + b = a + a = temp + {% end %} + + {% for t in (60...80) %} + temp = circular_shift(5, a) + (b ^ c ^ d) + e + w[{{t}}] + k[3] + e = d + d = c + c = circular_shift(30, b) + b = a + a = temp + {% end %} + + @intermediate_hash[0] += a + @intermediate_hash[1] += b + @intermediate_hash[2] += c + @intermediate_hash[3] += d + @intermediate_hash[4] += e + + @message_block_index = 0 + } end def circular_shift(bits, word) diff --git a/src/exception.cr b/src/exception.cr index e419f07636f2..a2df739e7347 100644 --- a/src/exception.cr +++ b/src/exception.cr @@ -129,6 +129,12 @@ class DivisionByZeroError < Exception end end +class OverflowError < Exception + def initialize(message = "Overflow") + super(message) + end +end + # Raised when a method is not implemented. 
# # This can be used either to stub out method bodies, or when the method is not diff --git a/src/float/printer/grisu3.cr b/src/float/printer/grisu3.cr index 17734ed5c204..6671cd6be7de 100644 --- a/src/float/printer/grisu3.cr +++ b/src/float/printer/grisu3.cr @@ -151,7 +151,9 @@ module Float::Printer::Grisu3 # Since too_low = too_high - unsafe_interval this is equivalent to # [too_high - unsafe_interval + 4 ulp; too_high - 2 ulp] # Conceptually we have: rest ~= too_high - buffer - return (2 * unit <= rest) && (rest <= unsafe_interval - 4 * unit) + __next_unchecked { + return (2 * unit <= rest) && (rest <= unsafe_interval - 4 * unit) + } end # Generates the digits of input number w. diff --git a/src/float/printer/ieee.cr b/src/float/printer/ieee.cr index 5e9893f46f43..019e65d1dfde 100644 --- a/src/float/printer/ieee.cr +++ b/src/float/printer/ieee.cr @@ -148,7 +148,9 @@ module Float::Printer::IEEE exp = 1 - EXPONENT_BIAS_64 else frac = (d64 & SIGNIFICAND_MASK_64) + HIDDEN_BIT_64 - exp = (((d64 & EXPONENT_MASK_64) >> PHYSICAL_SIGNIFICAND_SIZE_64) - EXPONENT_BIAS_64).to_i + __next_unchecked { + exp = (((d64 & EXPONENT_MASK_64) >> PHYSICAL_SIGNIFICAND_SIZE_64) - EXPONENT_BIAS_64).to_i + } end {frac, exp} @@ -162,7 +164,9 @@ module Float::Printer::IEEE exp = 1 - EXPONENT_BIAS_32 else frac = (d32 & SIGNIFICAND_MASK_32) + HIDDEN_BIT_32 - exp = (((d32 & EXPONENT_MASK_32) >> PHYSICAL_SIGNIFICAND_SIZE_32) - EXPONENT_BIAS_32).to_i + __next_unchecked { + exp = (((d32 & EXPONENT_MASK_32) >> PHYSICAL_SIGNIFICAND_SIZE_32) - EXPONENT_BIAS_32).to_i + } end {frac.to_u64, exp} diff --git a/src/hash.cr b/src/hash.cr index 4bc9b7eb1a2a..4486aa42b589 100644 --- a/src/hash.cr +++ b/src/hash.cr @@ -759,7 +759,9 @@ class Hash(K, V) copy = hasher copy = key.hash(copy) copy = value.hash(copy) - result += copy.result + __next_unchecked { + result += copy.result + } end result.hash(hasher) diff --git a/src/int.cr b/src/int.cr index c778e39944d1..29c23283dcb3 100644 --- a/src/int.cr +++ 
b/src/int.cr @@ -248,6 +248,8 @@ struct Int # Raises `ArgumentError` if *exponent* is negative: if this is needed, # either use a float base or a float exponent. # + # Raises `OverflowError` if the overflow policy is `checked`. + # # ``` # 2 ** 3 # => 8 # 2 ** 0 # => 1 @@ -262,12 +264,41 @@ struct Int k = self while exponent > 0 result *= k if exponent & 0b1 != 0 - k *= k exponent = exponent.unsafe_shr(1) + k *= k if exponent > 0 end result end + # Returns the value of raising `self` to the power of *exponent*. + # + # Raises `ArgumentError` if *exponent* is negative: if this is needed, + # either use a float base or a float exponent. + # + # Intermediate multiplication will wrap around silently in case of overflow. + # + # ``` + # 2 ** 3 # => 8 + # 2 ** 0 # => 1 + # 2 ** -1 # ArgumentError + # ``` + def unchecked_pow(exponent : Int) : self + if exponent < 0 + raise ArgumentError.new "Cannot raise an integer to a negative integer power, use floats for that" + end + + __next_unchecked { + result = self.class.new(1) + k = self + while exponent > 0 + result *= k if exponent & 0b1 != 0 + exponent = exponent.unsafe_shr(1) + k *= k if exponent > 0 + end + result + } + end + # Returns the value of raising `self` to the power of *exponent*. 
# # ``` diff --git a/src/json/lexer.cr b/src/json/lexer.cr index ab4d25baf5ff..a2cf6c94fe4b 100644 --- a/src/json/lexer.cr +++ b/src/json/lexer.cr @@ -247,11 +247,13 @@ abstract class JSON::Lexer integer = (current_char - '0').to_i64 char = next_char while '0' <= char <= '9' - append_number_char - integer *= 10 - integer += char - '0' - digits += 1 - char = next_char + __next_unchecked { + append_number_char + integer *= 10 + integer += char - '0' + digits += 1 + char = next_char + } end case char @@ -279,12 +281,14 @@ abstract class JSON::Lexer end while '0' <= char <= '9' - append_number_char - integer *= 10 - integer += char - '0' - divisor *= 10 - digits += 1 - char = next_char + __next_unchecked { + append_number_char + integer *= 10 + integer += char - '0' + divisor *= 10 + digits += 1 + char = next_char + } end float = integer.to_f64 / divisor diff --git a/src/macros.cr b/src/macros.cr index 22777ab31013..56a1cabf39d7 100644 --- a/src/macros.cr +++ b/src/macros.cr @@ -193,3 +193,12 @@ macro assert_responds_to(var, method) raise "Expected {{var}} to respond to :{{method}}, not #{ {{var}} }" end end + +# TODO remove in next release +macro __next_unchecked + {% if Crystal::VERSION.includes?("0.25.0+") || compare_versions(Crystal::VERSION, "0.25.0") > 0 %} + unchecked { {{ yield }} } + {% else %} + {{ yield }} + {% end %} +end diff --git a/src/primitives.cr b/src/primitives.cr index a5bffa75608c..d45a9c1c644a 100644 --- a/src/primitives.cr +++ b/src/primitives.cr @@ -325,6 +325,7 @@ end {% if op != "/" %} # Returns the result of {{desc.id}} `self` and *other*. 
@[Primitive(:binary)] + @[Raises] def {{op.id}}(other : {{int2.id}}) : self end {% end %} diff --git a/src/random.cr b/src/random.cr index 6a9989821d11..1dedb8fa006b 100644 --- a/src/random.cr +++ b/src/random.cr @@ -116,107 +116,109 @@ module Random {% utype = "UInt#{size}".id %} {% for type in ["Int#{size}".id, utype] %} private def rand_int(max : {{type}}) : {{type}} - if max == 0 - return {{type}}.new(0) - end - - unless max > 0 - raise ArgumentError.new "Invalid bound for rand: #{max}" - end - - # The basic ideas of the algorithm are best illustrated with examples. - # - # Let's say we have a random number generator that gives uniformly distributed random - # numbers between 0 and 15. We need to get a uniformly distributed random number between - # 0 and 5 (*max* = 6). The typical mistake made in this case is to just use `rand() % 6`, - # but it is clear that some results will appear more often than others. So, the surefire - # approach is to make the RNG spit out numbers until it gives one inside our desired range. - # That is really wasteful though. So the approach taken here is to discard only a small - # range of the possible generated numbers, and use the modulo operation on the "valid" ones, - # like this (where X means "discard and try again"): - # - # Generated number: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 - # Result: 0 1 2 3 4 5 0 1 2 3 4 5 X X X X - # - # 12 is the *limit* here - the highest number divisible by *max* while still being within - # bounds of what the RNG can produce. - # - # On the other side of the spectrum is the problem of generating a random number in a higher - # range than what the RNG can produce. Let's say we have the same mentioned RNG, but we need - # a uniformly distributed random number between 0 and 255. All that needs to be done is to - # generate two random numbers between 0 and 15, and combine their bits - # (i.e. `rand()*16 + rand()`). 
- # - # Using a combination of these tricks, any RNG can be turned into any RNG, however, there - # are several difficult parts about this. The code below uses as few calls to the underlying - # RNG as possible, meaning that (with the above example) with *max* being 257, it would call - # the RNG 3 times. (Of course, it doesn't actually deal with RNGs that produce numbers - # 0 to 15, only with the `UInt8`, `UInt16`, `UInt32` and `UInt64` ranges. - # - # Another problem is how to actually compute the *limit*. The obvious way to do it, which is - # `(RAND_MAX + 1) / max * max`, fails because `RAND_MAX` is usually already the highest - # number that an integer type can hold. And even the *limit* itself will often be - # `RAND_MAX + 1`, meaning that we don't have to discard anything. The ways to deal with this - # are described below. - - # if max - 1 <= typeof(next_u)::MAX - if typeof(next_u).new(max - 1) == max - 1 - # One number from the RNG will be enough. - # All the computations will (almost) fit into `typeof(next_u)`. - - # Relies on integer overflow + wraparound to find the highest number divisible by *max*. - limit = typeof(next_u).new(0) - (typeof(next_u).new(0) - max) % max - # *limit* might be 0, which means it would've been `typeof(next_u)::MAX + 1, but didn't - # fit into the integer type. - - loop do - result = next_u - - # For a uniform distribution we may need to throw away some numbers - if result < limit || limit == 0 - return {{type}}.new(result % max) - end + __next_unchecked { + if max == 0 + return {{type}}.new(0) end - else - # We need to find out how many random numbers need to be combined to be able to generate a - # random number of this magnitude. - # All the computations will be based on `{{utype}}` as the larger type. - - # `rand_max - 1` is the maximal number we can get from combining `needed_parts` random - # numbers. - # Compute *rand_max* as `(typeof(next_u)::MAX + 1) ** needed_parts)`. 
- # If *rand_max* overflows, that means it has reached `high({{utype}}) + 1`. - rand_max = {{utype}}.new(1) << (sizeof(typeof(next_u))*8) - needed_parts = 1 - while rand_max < max && rand_max > 0 - rand_max <<= sizeof(typeof(next_u))*8 - needed_parts += 1 + + unless max > 0 + raise ArgumentError.new "Invalid bound for rand: #{max}" end - limit = - if rand_max > 0 - # `rand_max` didn't overflow, so we can calculate the *limit* the straightforward way. - rand_max / max * max - else - # *rand_max* is `{{utype}}::MAX + 1`, need the same wraparound trick. *limit* might - # overflow, which means it would've been `{{utype}}::MAX + 1`, but didn't fit into - # the integer type. - {{utype}}.new(0) - ({{utype}}.new(0) - max) % max + # The basic ideas of the algorithm are best illustrated with examples. + # + # Let's say we have a random number generator that gives uniformly distributed random + # numbers between 0 and 15. We need to get a uniformly distributed random number between + # 0 and 5 (*max* = 6). The typical mistake made in this case is to just use `rand() % 6`, + # but it is clear that some results will appear more often than others. So, the surefire + # approach is to make the RNG spit out numbers until it gives one inside our desired range. + # That is really wasteful though. So the approach taken here is to discard only a small + # range of the possible generated numbers, and use the modulo operation on the "valid" ones, + # like this (where X means "discard and try again"): + # + # Generated number: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + # Result: 0 1 2 3 4 5 0 1 2 3 4 5 X X X X + # + # 12 is the *limit* here - the highest number divisible by *max* while still being within + # bounds of what the RNG can produce. + # + # On the other side of the spectrum is the problem of generating a random number in a higher + # range than what the RNG can produce. 
Let's say we have the same mentioned RNG, but we need + # a uniformly distributed random number between 0 and 255. All that needs to be done is to + # generate two random numbers between 0 and 15, and combine their bits + # (i.e. `rand()*16 + rand()`). + # + # Using a combination of these tricks, any RNG can be turned into any RNG, however, there + # are several difficult parts about this. The code below uses as few calls to the underlying + # RNG as possible, meaning that (with the above example) with *max* being 257, it would call + # the RNG 3 times. (Of course, it doesn't actually deal with RNGs that produce numbers + # 0 to 15, only with the `UInt8`, `UInt16`, `UInt32` and `UInt64` ranges.) + # + # Another problem is how to actually compute the *limit*. The obvious way to do it, which is + # `(RAND_MAX + 1) / max * max`, fails because `RAND_MAX` is usually already the highest + # number that an integer type can hold. And even the *limit* itself will often be + # `RAND_MAX + 1`, meaning that we don't have to discard anything. The ways to deal with this + # are described below. + + # if max - 1 <= typeof(next_u)::MAX + if typeof(next_u).new(max - 1) == max - 1 + # One number from the RNG will be enough. + # All the computations will (almost) fit into `typeof(next_u)`. + + # Relies on integer overflow + wraparound to find the highest number divisible by *max*. + limit = typeof(next_u).new(0) - (typeof(next_u).new(0) - max) % max + # *limit* might be 0, which means it would've been `typeof(next_u)::MAX + 1`, but didn't + # fit into the integer type. + + loop do + result = next_u + + # For a uniform distribution we may need to throw away some numbers + if result < limit || limit == 0 + return {{type}}.new(result % max) + end + end + else + # We need to find out how many random numbers need to be combined to be able to generate a + # random number of this magnitude. + # All the computations will be based on `{{utype}}` as the larger type. 
+ + # `rand_max - 1` is the maximal number we can get from combining `needed_parts` random + # numbers. + # Compute *rand_max* as `(typeof(next_u)::MAX + 1) ** needed_parts`. + # If *rand_max* overflows, that means it has reached `high({{utype}}) + 1`. + rand_max = {{utype}}.new(1) << (sizeof(typeof(next_u))*8) + needed_parts = 1 + while rand_max < max && rand_max > 0 + rand_max <<= sizeof(typeof(next_u))*8 + needed_parts += 1 end - loop do - result = rand_type({{utype}}, needed_parts) - - # For a uniform distribution we may need to throw away some numbers. - if result < limit || limit == 0 - return {{type}}.new(result % max) + limit = + if rand_max > 0 + # `rand_max` didn't overflow, so we can calculate the *limit* the straightforward way. + rand_max / max * max + else + # *rand_max* is `{{utype}}::MAX + 1`, need the same wraparound trick. *limit* might + # overflow, which means it would've been `{{utype}}::MAX + 1`, but didn't fit into + # the integer type. + {{utype}}.new(0) - ({{utype}}.new(0) - max) % max + end + + loop do + result = rand_type({{utype}}, needed_parts) + + # For a uniform distribution we may need to throw away some numbers. + if result < limit || limit == 0 + return {{type}}.new(result % max) + end end end - end + } end private def rand_range(range : Range({{type}}, {{type}})) : {{type}} - span = {{utype}}.new(range.end - range.begin) + span = {{utype}}.new(__next_unchecked { range.end - range.begin }) if range.excludes_end? 
unless range.begin < range.end raise ArgumentError.new "Invalid range for rand: #{range}" diff --git a/src/random/isaac.cr b/src/random/isaac.cr index 0070a8e8ce15..c05a648eb99f 100644 --- a/src/random/isaac.cr +++ b/src/random/isaac.cr @@ -43,21 +43,23 @@ class Random::ISAAC end private def isaac - @cc += 1 - @bb += cc - - 256.times do |i| - @aa ^= case i % 4 - when 0 then aa << 13 - when 1 then aa >> 6 - when 2 then aa << 2 - else aa >> 16 - end - x = @mm[i] - @aa = @mm[(i + 128) % 256] + aa - @mm[i] = y = @mm[(x >> 2) % 256] + aa + bb - @rsl[i] = @bb = @mm[(y >> 10) % 256] + x - end + __next_unchecked { + @cc += 1 + @bb += cc + + 256.times do |i| + @aa ^= case i % 4 + when 0 then aa << 13 + when 1 then aa >> 6 + when 2 then aa << 2 + else aa >> 16 + end + x = @mm[i] + @aa = @mm[(i + 128) % 256] + aa + @mm[i] = y = @mm[(x >> 2) % 256] + aa + bb + @rsl[i] = @bb = @mm[(y >> 10) % 256] + x + end + } end private def init_by_array(seeds) @@ -67,25 +69,29 @@ class Random::ISAAC a = b = c = d = e = f = g = h = 0x9e3779b9_u32 mix = ->{ - a ^= b << 11; d += a; b += c - b ^= c >> 2; e += b; c += d - c ^= d << 8; f += c; d += e - d ^= e >> 16; g += d; e += f - e ^= f << 10; h += e; f += g - f ^= g >> 4; a += f; g += h - g ^= h << 8; b += g; h += a - h ^= a >> 9; c += h; a += b + __next_unchecked { + a ^= b << 11; d += a; b += c + b ^= c >> 2; e += b; c += d + c ^= d << 8; f += c; d += e + d ^= e >> 16; g += d; e += f + e ^= f << 10; h += e; f += g + f ^= g >> 4; a += f; g += h + g ^= h << 8; b += g; h += a + h ^= a >> 9; c += h; a += b + } } 4.times(&mix) scramble = ->(seed : StaticArray(UInt32, 256)) { - 0.step(to: 255, by: 8) do |i| - a += seed[i]; b += seed[i + 1]; c += seed[i + 2]; d += seed[i + 3] - e += seed[i + 4]; f += seed[i + 5]; g += seed[i + 6]; h += seed[i + 7] - mix.call - @mm[i] = a; @mm[i + 1] = b; @mm[i + 2] = c; @mm[i + 3] = d - @mm[i + 4] = e; @mm[i + 5] = f; @mm[i + 6] = g; @mm[i + 7] = h - end + __next_unchecked { + 0.step(to: 255, by: 8) do |i| + a += 
seed[i]; b += seed[i + 1]; c += seed[i + 2]; d += seed[i + 3] + e += seed[i + 4]; f += seed[i + 5]; g += seed[i + 6]; h += seed[i + 7] + mix.call + @mm[i] = a; @mm[i + 1] = b; @mm[i + 2] = c; @mm[i + 3] = d + @mm[i + 4] = e; @mm[i + 5] = f; @mm[i + 6] = g; @mm[i + 7] = h + end + } } scramble.call(@rsl) diff --git a/src/random/pcg32.cr b/src/random/pcg32.cr index 2ad06cfb25a7..3179849b49fc 100644 --- a/src/random/pcg32.cr +++ b/src/random/pcg32.cr @@ -58,33 +58,40 @@ class Random::PCG32 @state = 0_u64 @inc = (initseq << 1) | 1 next_u - @state += initstate + __next_unchecked { + @state += initstate + } next_u end def next_u - oldstate = @state - @state = oldstate * PCG_DEFAULT_MULTIPLIER_64 + @inc - xorshifted = UInt32.new(((oldstate >> 18) ^ oldstate) >> 27) - rot = UInt32.new(oldstate >> 59) - return UInt32.new((xorshifted >> rot) | (xorshifted << ((~rot + 1) & 31))) + __next_unchecked { + oldstate = @state + @state = oldstate * PCG_DEFAULT_MULTIPLIER_64 + @inc + xorshifted = UInt32.new(((oldstate >> 18) ^ oldstate) >> 27) + rot = UInt32.new(oldstate >> 59) + res = UInt32.new((xorshifted >> rot) | (xorshifted << ((~rot + 1) & 31))) + return res + } end def jump(delta) - deltau64 = UInt64.new(delta) - acc_mult = 1u64 - acc_plus = 0u64 - cur_plus = @inc - cur_mult = PCG_DEFAULT_MULTIPLIER_64 - while (deltau64 > 0) - if deltau64 & 1 > 0 - acc_mult *= cur_mult - acc_plus = acc_plus * cur_mult + cur_plus + __next_unchecked { + deltau64 = UInt64.new(delta) + acc_mult = 1u64 + acc_plus = 0u64 + cur_plus = @inc + cur_mult = PCG_DEFAULT_MULTIPLIER_64 + while (deltau64 > 0) + if deltau64 & 1 > 0 + acc_mult *= cur_mult + acc_plus = acc_plus * cur_mult + cur_plus + end + cur_plus = (cur_mult + 1) * cur_plus + cur_mult *= cur_mult + deltau64 /= 2 end - cur_plus = (cur_mult + 1) * cur_plus - cur_mult *= cur_mult - deltau64 /= 2 - end - @state = acc_mult * @state + acc_plus + @state = acc_mult * @state + acc_plus + } end end diff --git a/src/string.cr b/src/string.cr index 
a1c7aaea8def..1653969adaab 100644 --- a/src/string.cr +++ b/src/string.cr @@ -497,7 +497,7 @@ class String if info.negative {% if max_negative %} return yield if info.value > {{max_negative}} - -info.value.to_{{method}} + __next_unchecked { 0.to_{{method}} - info.value.to_{{method}} } {% else %} return yield {% end %} @@ -582,7 +582,9 @@ class String value *= base old = value - value += digit + __next_unchecked { + value += digit + } if value < old invalid = true break @@ -2531,7 +2533,7 @@ class String {% if i != 1 %} byte = head_pointer.value {% end %} - hash = hash * PRIME_RK + pointer.value - pow * byte + hash = __next_unchecked { hash * PRIME_RK + pointer.value - pow * byte } pointer += 1 head_pointer += 1 {% end %} @@ -2579,9 +2581,9 @@ class String # calculate a rolling hash of search text (needle) search_hash = 0u32 search.each_byte do |b| - search_hash = search_hash * PRIME_RK + b + search_hash = __next_unchecked { search_hash * PRIME_RK + b } end - pow = PRIME_RK ** search.bytesize + pow = PRIME_RK.unchecked_pow search.bytesize # Find start index with offset char_index = 0 @@ -2608,7 +2610,7 @@ class String hash_end_pointer = pointer + search.bytesize return if hash_end_pointer > end_pointer while pointer < hash_end_pointer - hash = hash * PRIME_RK + pointer.value + hash = __next_unchecked { hash * PRIME_RK + pointer.value } pointer += 1 end @@ -2695,9 +2697,9 @@ class String # calculate a rolling hash of search text (needle) search_hash = 0u32 search.to_slice.reverse_each do |b| - search_hash = search_hash * PRIME_RK + b + search_hash = __next_unchecked { search_hash * PRIME_RK + b } end - pow = PRIME_RK ** search.bytesize + pow = PRIME_RK.unchecked_pow search.bytesize hash = 0u32 char_index = size @@ -2715,7 +2717,7 @@ class String byte = pointer.value char_index -= 1 if (byte & 0xC0) != 0x80 - hash = hash * PRIME_RK + byte + hash = __next_unchecked { hash * PRIME_RK + byte } end while true @@ -2733,7 +2735,7 @@ class String char_index -= 1 if (byte & 
0xC0) != 0x80 # update a rolling hash of this text (haystack) - hash = hash * PRIME_RK + byte - pow * tail_pointer.value + hash = __next_unchecked { hash * PRIME_RK + byte - pow * tail_pointer.value } end end @@ -2869,9 +2871,9 @@ class String # calculate a rolling hash of search text (needle) search_hash = 0u32 search.each_byte do |b| - search_hash = search_hash * PRIME_RK + b + search_hash = __next_unchecked { search_hash * PRIME_RK + b } end - pow = PRIME_RK ** search.bytesize + pow = PRIME_RK.unchecked_pow search.bytesize # calculate a rolling hash of this text (haystack) pointer = head_pointer = to_unsafe + offset @@ -2880,7 +2882,7 @@ class String hash = 0u32 return if hash_end_pointer > end_pointer while pointer < hash_end_pointer - hash = hash * PRIME_RK + pointer.value + hash = __next_unchecked { hash * PRIME_RK + pointer.value } pointer += 1 end @@ -2893,7 +2895,7 @@ class String return if pointer >= end_pointer # update a rolling hash of this text (haystack) - hash = hash * PRIME_RK + pointer.value - pow * head_pointer.value + hash = __next_unchecked { hash * PRIME_RK + pointer.value - pow * head_pointer.value } pointer += 1 head_pointer += 1 offset += 1 diff --git a/src/time/span.cr b/src/time/span.cr index 894af5eb33cb..267b4a124c17 100644 --- a/src/time/span.cr +++ b/src/time/span.cr @@ -109,7 +109,7 @@ struct Time::Span # "legal" (i.e. temporary) (e.g. if other parameters are negative) or # illegal (e.g. sign change). if days > 0 - sd = SECONDS_PER_DAY.to_i64 * days + sd = __next_unchecked { SECONDS_PER_DAY.to_i64 * days } if sd < days overflow = true elsif s < 0 @@ -123,7 +123,7 @@ struct Time::Span overflow = s < 0 end elsif days < 0 - sd = SECONDS_PER_DAY.to_i64 * days + sd = __next_unchecked { SECONDS_PER_DAY.to_i64 * days } if sd > days overflow = true elsif s <= 0