Skip to content

Commit

Permalink
WIP: Add explicitly wrapping versions of integer arithmetic
Browse files Browse the repository at this point in the history
This adds operators `+%`, `-%`, `*%`, which are equivalent to the
non-`%` versions, but indicate an explicit semantic expectation that
two's complement wrapping behavior is expected and correct. As discussed
at JuliaCon 2014 and every year since, users have often requested
a way to opt into explicit overflow checking of arithmetic, whether
for debugging or because they have regulatory or procedural requirements
that expect to be able to do this. Having explicit operators for
overflowing semantics allows use cases that depend on overflow behavior
for correct functioning to explicitly opt-out of any such checking.

I want to explicitly emphasize that there are no plans to change
the default behavior of arithmetic in Julia, neither by introducing
error checking nor by making it undefined behavior (as in C). The
general consensus here is that while overflow checking can be useful,
and would be a fine default, even if hardware supported it efficiently
(which it doesn't), the performance costs of performing the check
(through inhibition of other optimizations) are too high. In our experience
overflow also tends to be relatively harmless, even if it can be a very
rude awakening to users coming from Python or other languages whose
integers are arbitrary-precision by default.

The idea here is simply to give users another tool in their arsenal
for checking correctness. Think sanitizers, not language change.
This PR includes a macro `@Base.Experimental.make_all_arithmetic_checked`,
that will define overrides to make arithmetic checked, but does not
include any mechanism (e.g. #50239) to make this fast.

What is included in this PR:
 - Flisp parser changes to parse the new operators
 - Definitions of the new operators
 - Some basic replacements in base to give a flavor for using the
   new operator and make sure it works

Still to be done:
 - [ ] Parser changes in JuliaSyntax
 - [ ] Correct parsing for `+%` by itself, which currently parses as `+(%)`

The places to change in base were found by using the above-mentioned
macro and running the test suite. I did not work through the tests
exhaustively. We have many tests that explicitly expect overflow and
many others that we should go through on a case by case basis. The
idea here is merely to give an idea of the kind of changes that
may be required if overflow checking is enabled. I think they can
broadly be classed into:

- Crypto and hashing code that explicitly wants modular arithmetic
- Bit twiddling code for arithmetic tricks (though many of these,
  particularly in Ryu, could probably be replaced with better
  abstractions).
- UInt8 range checks written by Stefan
- Misc
  • Loading branch information
Keno committed Aug 3, 2023
1 parent 210c5b5 commit 3b57645
Show file tree
Hide file tree
Showing 44 changed files with 310 additions and 131 deletions.
2 changes: 1 addition & 1 deletion base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3405,7 +3405,7 @@ pushfirst!(A, a, b, c...) = pushfirst!(pushfirst!(A, c...), a, b)

const hash_abstractarray_seed = UInt === UInt64 ? 0x7e2d6fb6448beb77 : 0xd4514ce5
function hash(A::AbstractArray, h::UInt)
h += hash_abstractarray_seed
h +%= hash_abstractarray_seed
# Axes are themselves AbstractArrays, so hashing them directly would stack overflow
# Instead hash the tuple of firsts and lasts along each dimension
h = hash(map(first, axes(A)), h)
Expand Down
4 changes: 3 additions & 1 deletion base/abstractset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ max_values(T::Union{map(X -> Type{X}, BitIntegerSmall_types)...}) = 1 << (8*size
function max_values(T::Union)
a = max_values(T.a)::Int
b = max_values(T.b)::Int
return max(a, b, a + b)
r, o = add_with_overflow(a, b)
o && return typemax(Int)
return r
end
max_values(::Type{Bool}) = 2
max_values(::Type{Nothing}) = 1
Expand Down
15 changes: 10 additions & 5 deletions base/bool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,17 @@ isone(x::Bool) = x

## do arithmetic as Int ##

+(x::Bool) = Int(x)
-(x::Bool) = -Int(x)

+(x::Bool, y::Bool) = Int(x) + Int(y)
-(x::Bool, y::Bool) = Int(x) - Int(y)
+(x::Bool) = Int(x)
+%(x::Bool) = Int(x)
-(x::Bool) = -%(Int(x))
-%(x::Bool) = -%(Int(x))

+(x::Bool, y::Bool) = Int(x) +% Int(y)
-(x::Bool, y::Bool) = Int(x) -% Int(y)
+%(x::Bool, y::Bool) = Int(x) +% Int(y)
-%(x::Bool, y::Bool) = Int(x) -% Int(y)
*(x::Bool, y::Bool) = x & y
*%(x::Bool, y::Bool) = x & y
^(x::Bool, y::Bool) = x | !y
^(x::Integer, y::Bool) = ifelse(y, x, one(x))

Expand Down
11 changes: 8 additions & 3 deletions base/char.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,9 @@ isless(x::AbstractChar, y::AbstractChar) = isless(Char(x), Char(y))
hash(x::AbstractChar, h::UInt) = hash(Char(x), h)
widen(::Type{T}) where {T<:AbstractChar} = T

@inline -%(x::AbstractChar, y::AbstractChar) = Int(x) -% Int(y)
@inline -(x::AbstractChar, y::AbstractChar) = Int(x) - Int(y)
@inline function -(x::T, y::Integer) where {T<:AbstractChar}
@inline function -%(x::T, y::Integer) where {T<:AbstractChar}
if x isa Char
u = Int32((bitcast(UInt32, x) >> 24) % Int8)
if u >= 0 # inline the runtime fast path
Expand All @@ -234,7 +235,7 @@ widen(::Type{T}) where {T<:AbstractChar} = T
end
return T(Int32(x) - Int32(y))
end
@inline function +(x::T, y::Integer) where {T<:AbstractChar}
@inline function +%(x::T, y::Integer) where {T<:AbstractChar}
if x isa Char
u = Int32((bitcast(UInt32, x) >> 24) % Int8)
if u >= 0 # inline the runtime fast path
Expand All @@ -244,7 +245,11 @@ end
end
return T(Int32(x) + Int32(y))
end
@inline +(x::Integer, y::AbstractChar) = y + x
@inline +%(x::Integer, y::AbstractChar) = y + x

-(x::AbstractChar, y::Integer) = x -% y
+(x::AbstractChar, y::Integer) = x +% y
+(x::Integer, y::AbstractChar) = x +% y

# `print` should output UTF-8 by default for all AbstractChar types.
# (Packages may implement other IO subtypes to specify different encodings.)
Expand Down
26 changes: 15 additions & 11 deletions base/checked.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import Core.Intrinsics:
checked_srem_int,
checked_uadd_int, checked_usub_int, checked_umul_int, checked_udiv_int,
checked_urem_int
import ..no_op_err, ..@inline, ..@noinline, ..checked_length
import ..no_op_err, ..@inline, ..@noinline, ..checked_length, ..BitInteger

# define promotion behavior for checked operations
checked_add(x::Integer, y::Integer) = checked_add(promote(x,y)...)
Expand Down Expand Up @@ -98,7 +98,7 @@ throw_overflowerr_negation(x) = (@noinline;
throw(OverflowError(Base.invokelatest(string, "checked arithmetic: cannot compute -x for x = ", x, "::", typeof(x)))))
if BrokenSignedInt != Union{}
function checked_neg(x::BrokenSignedInt)
r = -x
r = -%(x)
(x<0) & (r<0) && throw_overflowerr_negation(x)
r
end
Expand Down Expand Up @@ -140,11 +140,11 @@ Calculates `r = x+y`, with the flag `f` indicating whether overflow has occurred
function add_with_overflow end
add_with_overflow(x::T, y::T) where {T<:SignedInt} = checked_sadd_int(x, y)
add_with_overflow(x::T, y::T) where {T<:UnsignedInt} = checked_uadd_int(x, y)
add_with_overflow(x::Bool, y::Bool) = (x+y, false)
add_with_overflow(x::Bool, y::Bool) = (x +% y, false)

if BrokenSignedInt != Union{}
function add_with_overflow(x::T, y::T) where T<:BrokenSignedInt
r = x + y
r = x +% y
# x and y have the same sign, and the result has a different sign
f = (x<0) == (y<0) != (r<0)
r, f
Expand All @@ -154,7 +154,7 @@ if BrokenUnsignedInt != Union{}
function add_with_overflow(x::T, y::T) where T<:BrokenUnsignedInt
# x + y > typemax(T)
# Note: ~y == -y-1
x + y, x > ~y
x +% y, x > ~y
end
end

Expand All @@ -171,7 +171,11 @@ The overflow protection may impose a perceptible performance penalty.
"""
function checked_add(x::T, y::T) where T<:Integer
@inline
z, b = add_with_overflow(x, y)
zb = add_with_overflow(x, y)
# Avoid use of tuple destructuring, which uses arithmetic internally,
# so that this can be used as a replacement for +
z = getfield(zb, 1)
b = getfield(zb, 2)
b && throw_overflowerr_binaryop(:+, x, y)
z
end
Expand Down Expand Up @@ -206,7 +210,7 @@ sub_with_overflow(x::Bool, y::Bool) = (x-y, false)

if BrokenSignedInt != Union{}
function sub_with_overflow(x::T, y::T) where T<:BrokenSignedInt
r = x - y
r = x -% y
# x and y have different signs, and the result has a different sign than x
f = (x<0) != (y<0) == (r<0)
r, f
Expand All @@ -215,7 +219,7 @@ end
if BrokenUnsignedInt != Union{}
function sub_with_overflow(x::T, y::T) where T<:BrokenUnsignedInt
# x - y < 0
x - y, x < y
x -% y, x < y
end
end

Expand All @@ -242,7 +246,7 @@ Calculates `r = x*y`, with the flag `f` indicating whether overflow has occurred
function mul_with_overflow end
mul_with_overflow(x::T, y::T) where {T<:SignedInt} = checked_smul_int(x, y)
mul_with_overflow(x::T, y::T) where {T<:UnsignedInt} = checked_umul_int(x, y)
mul_with_overflow(x::Bool, y::Bool) = (x*y, false)
mul_with_overflow(x::Bool, y::Bool) = (x *% y, false)

if BrokenSignedIntMul != Union{} && BrokenSignedIntMul != Int128
function mul_with_overflow(x::T, y::T) where T<:BrokenSignedIntMul
Expand Down Expand Up @@ -273,14 +277,14 @@ if Int128 <: BrokenSignedIntMul
else
false
end
x*y, f
x *% y, f
end
end
if UInt128 <: BrokenUnsignedIntMul
# Avoid BigInt
function mul_with_overflow(x::T, y::T) where T<:UInt128
# x * y > typemax(T)
x * y, y > 0 && x > fld(typemax(T), y)
x *% y, y > 0 && x > fld(typemax(T), y)
end
end

Expand Down
11 changes: 11 additions & 0 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,14 @@ macro pure(ex)
end

# END 1.10 deprecations

# BEGIN 1.11 deprecations

# These operators are new in 1.11, but these fallback methods are added for
# compatibility while packages adjust to defining both operators, to allow
# Base and other packages to start using these.
# Each fallback only applies when both arguments have the same type `T`, and it
# simply forwards to the corresponding non-`%` operator.
*%(a::T, b::T) where {T} = *(a, b)
+%(a::T, b::T) where {T} = +(a, b)
-%(a::T, b::T) where {T} = -(a, b)

# END 1.11 deprecations
69 changes: 69 additions & 0 deletions base/docs/basedocs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2644,6 +2644,30 @@ julia> +(1, 20, 4)
"""
(+)(x, y...)

"""
    +%(x::Integer, y::Integer...)

Addition operator with semantic wrapping. In the default Julia environment, this
is equivalent to the regular addition operator `+`. However, some users may choose to overwrite
`+` in their local environment to perform checked arithmetic instead (e.g. using
[`Experimental.@make_all_arithmetic_checked`](@ref)). The `+%` operator may be used to indicate
that wrapping behavior is semantically expected and correct and should thus be exempted from
any opt-in overflow checking.

# Examples
```jldoctest
julia> 1 +% 20 +% 4
25

julia> +%(1, 20, 4)
25

julia> typemax(Int) +% 1
-9223372036854775808
```
"""
(+%)(x, y...)

"""
-(x)
Expand Down Expand Up @@ -2683,6 +2707,27 @@ julia> -(2, 4.5)
"""
-(x, y)

"""
    -%(x::Integer, y::Integer...)

Subtraction operator with semantic wrapping. In the default Julia environment, this
is equivalent to the regular subtraction operator `-`. However, some users may choose to overwrite
`-` in their local environment to perform checked arithmetic instead (e.g. using
[`Experimental.@make_all_arithmetic_checked`](@ref)). The `-%` operator may be used to indicate
that wrapping behavior is semantically expected and correct and should thus be exempted from
any opt-in overflow checking.

# Examples
```jldoctest
julia> 2 -% 3
-1

julia> -%(typemin(Int))
-9223372036854775808
```
"""
(-%)(x, y...)

"""
*(x, y...)
Expand All @@ -2699,6 +2744,30 @@ julia> *(2, 7, 8)
"""
(*)(x, y...)

"""
    *%(x::Integer, y::Integer, z::Integer...)

Multiplication operator with semantic wrapping. In the default Julia environment, this
is equivalent to the regular multiplication operator `*`. However, some users may choose to overwrite
`*` in their local environment to perform checked arithmetic instead (e.g. using
[`Experimental.@make_all_arithmetic_checked`](@ref)). The `*%` operator may be used to indicate
that wrapping behavior is semantically expected and correct and should thus be exempted from
any opt-in overflow checking.

# Examples
```jldoctest
julia> 2 *% 7 *% 8
112

julia> *%(2, 7, 8)
112

julia> 0xff *% 0xff
0x01
```
"""
(*%)(x, y, z...)

"""
/(x, y)
Expand Down
21 changes: 21 additions & 0 deletions base/experimental.jl
Original file line number Diff line number Diff line change
Expand Up @@ -368,4 +368,25 @@ adding them to the global method table.
"""
:@MethodTable

"""
    Experimental.@make_all_arithmetic_checked()

This macro defines methods that overwrite the base definition of basic arithmetic (+,-,*),
to use their checked variants instead. Explicitly overflowing arithmetic operators (+%,-%,*%)
are not affected.

The methods are defined for `Base.BitInteger` arguments (unary `-`, and binary `-`, `+`, `*`
on two arguments of the same type), plus `-` on two `AbstractChar`s.

!!! warning
    This macro is temporary and will likely be replaced by a more complete mechanism in the
    future. It is subject to change or removal without notice.
"""
macro make_all_arithmetic_checked()
    # The whole quoted block is escaped, so every name in it is resolved in the
    # *caller's* module. `BitInteger` is not exported from Base, so it has to be
    # spelled `Base.BitInteger`; otherwise the expansion only works in modules
    # that happen to have imported it.
    return esc(quote
        Base.:(-)(x::Base.BitInteger) = Base.Checked.checked_neg(x)
        Base.:(-)(x::T, y::T) where {T<:Base.BitInteger} = Base.Checked.checked_sub(x, y)
        Base.:(+)(x::T, y::T) where {T<:Base.BitInteger} = Base.Checked.checked_add(x, y)
        Base.:(*)(x::T, y::T) where {T<:Base.BitInteger} = Base.Checked.checked_mul(x, y)
        Base.:(-)(x::AbstractChar, y::AbstractChar) = Int(x) - Int(y)
    end)
end

end
3 changes: 3 additions & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,11 @@ export
÷,
&,
*,
*%,
+,
+%,
-,
-%,
/,
//,
<,
Expand Down
5 changes: 3 additions & 2 deletions base/filesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,10 @@ end

function read(f::File, ::Type{Char})
b0 = read(f, UInt8)
l = 0x08 * (0x04 - UInt8(leading_ones(b0)))
lo = UInt8(leading_ones(b0))
c = UInt32(b0) << 24
if l ≤ 0x10
if 0x02 ≤ lo ≤ 0x04
l = 0x08 * (0x04 - lo)
s = 16
while s ≥ l && !eof(f)
# this works around lack of peek(::File)
Expand Down
4 changes: 2 additions & 2 deletions base/float.jl
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ function hash(x::Float64, h::UInt)
elseif isnan(x)
return hx_NaN ⊻ h # NaN does not have a stable bit pattern
end
return hash_uint64(bitcast(UInt64, x)) - 3h
return hash_uint64(bitcast(UInt64, x)) -% (3 *% h)
end

hash(x::Float32, h::UInt) = hash(Float64(x), h)
Expand All @@ -665,7 +665,7 @@ function hash(x::Float16, h::UInt)
elseif isnan(x)
return hx_NaN ⊻ h # NaN does not have a stable bit pattern
end
return hash_uint64(bitcast(UInt64, Float64(x))) - 3h
return hash_uint64(bitcast(UInt64, Float64(x))) -% (3 *% h)
end

## generic hashing for rational values ##
Expand Down
15 changes: 8 additions & 7 deletions base/gmp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ module GMP

export BigInt

import .Base: *, +, -, /, <, <<, >>, >>>, <=, ==, >, >=, ^, (~), (&), (|), xor, nand, nor,
import .Base: *, *%, +, +%, -, -%, /, <, <<, >>, >>>, <=, ==, >, >=, ^, (~), (&), (|), xor, nand, nor,
binomial, cmp, convert, div, divrem, factorial, cld, fld, gcd, gcdx, lcm, mod,
ndigits, promote_rule, rem, show, isqrt, string, powermod, sum, prod,
trailing_zeros, trailing_ones, count_ones, count_zeros, tryparse_internal,
Expand Down Expand Up @@ -334,7 +334,7 @@ function BigInt(x::Integer)
isbits(x) && typemin(Clong) <= x <= typemax(Clong) && return BigInt((x % Clong)::Clong)
nd = ndigits(x, base=2)
z = MPZ.realloc2(nd)
ux = unsigned(x < 0 ? -x : x)
ux = unsigned(x < 0 ? -%(x) : x)
size = 0
limbnbits = sizeof(Limb) << 3
while nd > 0
Expand Down Expand Up @@ -494,6 +494,7 @@ big(n::Integer) = convert(BigInt, n)

# Binary ops
for (fJ, fC) in ((:+, :add), (:-,:sub), (:*, :mul),
(:+%, :add), (:-%,:sub), (:*%, :mul),
(:mod, :fdiv_r), (:rem, :tdiv_r),
(:gcd, :gcd), (:lcm, :lcm),
(:&, :and), (:|, :ior), (:xor, :xor))
Expand Down Expand Up @@ -552,10 +553,10 @@ end
-(x::BigInt, c::CulongMax) = MPZ.sub_ui(x, c)
-(c::CulongMax, x::BigInt) = MPZ.ui_sub(c, x)

+(x::BigInt, c::ClongMax) = c < 0 ? -(x, -(c % Culong)) : x + convert(Culong, c)
+(c::ClongMax, x::BigInt) = c < 0 ? -(x, -(c % Culong)) : x + convert(Culong, c)
-(x::BigInt, c::ClongMax) = c < 0 ? +(x, -(c % Culong)) : -(x, convert(Culong, c))
-(c::ClongMax, x::BigInt) = c < 0 ? -(x + -(c % Culong)) : -(convert(Culong, c), x)
+(x::BigInt, c::ClongMax) = c < 0 ? -(x, -%(c % Culong)) : x + convert(Culong, c)
+(c::ClongMax, x::BigInt) = c < 0 ? -(x, -%(c % Culong)) : x + convert(Culong, c)
-(x::BigInt, c::ClongMax) = c < 0 ? +(x, -%(c % Culong)) : -(x, convert(Culong, c))
-(c::ClongMax, x::BigInt) = c < 0 ? -(x + -%(c % Culong)) : -(convert(Culong, c), x)

*(x::BigInt, c::CulongMax) = MPZ.mul_ui(x, c)
*(c::CulongMax, x::BigInt) = x * c
Expand Down Expand Up @@ -873,7 +874,7 @@ if Limb === UInt64 === UInt
return hash(unsafe_load(ptr), h)
elseif sz == -1
limb = unsafe_load(ptr)
limb <= typemin(Int) % UInt && return hash(-(limb % Int), h)
limb <= typemin(Int) % UInt && return hash(-%(limb % Int), h)
end
pow = trailing_zeros(x)
nd = Base.ndigits0z(x, 2)
Expand Down
Loading

0 comments on commit 3b57645

Please sign in to comment.