Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP/RFC: Add explicitly wrapping versions of integer arithmetic #50790

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3405,7 +3405,7 @@ pushfirst!(A, a, b, c...) = pushfirst!(pushfirst!(A, c...), a, b)

const hash_abstractarray_seed = UInt === UInt64 ? 0x7e2d6fb6448beb77 : 0xd4514ce5
function hash(A::AbstractArray, h::UInt)
h += hash_abstractarray_seed
h +%= hash_abstractarray_seed
# Axes are themselves AbstractArrays, so hashing them directly would stack overflow
# Instead hash the tuple of firsts and lasts along each dimension
h = hash(map(first, axes(A)), h)
Expand Down
4 changes: 3 additions & 1 deletion base/abstractset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ max_values(T::Union{map(X -> Type{X}, BitIntegerSmall_types)...}) = 1 << (8*size
function max_values(T::Union)
a = max_values(T.a)::Int
b = max_values(T.b)::Int
return max(a, b, a + b)
r, o = add_with_overflow(a, b)
o && return typemax(Int)
return r
end
max_values(::Type{Bool}) = 2
max_values(::Type{Nothing}) = 1
Expand Down
15 changes: 10 additions & 5 deletions base/bool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,17 @@ isone(x::Bool) = x

## do arithmetic as Int ##

+(x::Bool) = Int(x)
-(x::Bool) = -Int(x)

+(x::Bool, y::Bool) = Int(x) + Int(y)
-(x::Bool, y::Bool) = Int(x) - Int(y)
+(x::Bool) = Int(x)
+%(x::Bool) = Int(x)
-(x::Bool) = -%(Int(x))
-%(x::Bool) = -%(Int(x))

+(x::Bool, y::Bool) = Int(x) +% Int(y)
-(x::Bool, y::Bool) = Int(x) -% Int(y)
+%(x::Bool, y::Bool) = Int(x) +% Int(y)
-%(x::Bool, y::Bool) = Int(x) -% Int(y)
*(x::Bool, y::Bool) = x & y
*%(x::Bool, y::Bool) = x & y
^(x::Bool, y::Bool) = x | !y
^(x::Integer, y::Bool) = ifelse(y, x, one(x))

Expand Down
11 changes: 8 additions & 3 deletions base/char.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,9 @@ isless(x::AbstractChar, y::AbstractChar) = isless(Char(x), Char(y))
hash(x::AbstractChar, h::UInt) = hash(Char(x), h)
widen(::Type{T}) where {T<:AbstractChar} = T

@inline -%(x::AbstractChar, y::AbstractChar) = Int(x) -% Int(y)
@inline -(x::AbstractChar, y::AbstractChar) = Int(x) - Int(y)
@inline function -(x::T, y::Integer) where {T<:AbstractChar}
@inline function -%(x::T, y::Integer) where {T<:AbstractChar}
if x isa Char
u = Int32((bitcast(UInt32, x) >> 24) % Int8)
if u >= 0 # inline the runtime fast path
Expand All @@ -234,7 +235,7 @@ widen(::Type{T}) where {T<:AbstractChar} = T
end
return T(Int32(x) - Int32(y))
end
@inline function +(x::T, y::Integer) where {T<:AbstractChar}
@inline function +%(x::T, y::Integer) where {T<:AbstractChar}
if x isa Char
u = Int32((bitcast(UInt32, x) >> 24) % Int8)
if u >= 0 # inline the runtime fast path
Expand All @@ -244,7 +245,11 @@ end
end
return T(Int32(x) + Int32(y))
end
@inline +(x::Integer, y::AbstractChar) = y + x
@inline +%(x::Integer, y::AbstractChar) = y + x

-(x::AbstractChar, y::Integer) = x -% y
+(x::AbstractChar, y::Integer) = x +% y
+(x::Integer, y::AbstractChar) = x +% y

# `print` should output UTF-8 by default for all AbstractChar types.
# (Packages may implement other IO subtypes to specify different encodings.)
Expand Down
26 changes: 15 additions & 11 deletions base/checked.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import Core.Intrinsics:
checked_srem_int,
checked_uadd_int, checked_usub_int, checked_umul_int, checked_udiv_int,
checked_urem_int
import ..no_op_err, ..@inline, ..@noinline, ..checked_length
import ..no_op_err, ..@inline, ..@noinline, ..checked_length, ..BitInteger

# define promotion behavior for checked operations
checked_add(x::Integer, y::Integer) = checked_add(promote(x,y)...)
Expand Down Expand Up @@ -98,7 +98,7 @@ throw_overflowerr_negation(x) = (@noinline;
throw(OverflowError(Base.invokelatest(string, "checked arithmetic: cannot compute -x for x = ", x, "::", typeof(x)))))
if BrokenSignedInt != Union{}
function checked_neg(x::BrokenSignedInt)
r = -x
r = -%(x)
(x<0) & (r<0) && throw_overflowerr_negation(x)
r
end
Expand Down Expand Up @@ -140,11 +140,11 @@ Calculates `r = x+y`, with the flag `f` indicating whether overflow has occurred
function add_with_overflow end
add_with_overflow(x::T, y::T) where {T<:SignedInt} = checked_sadd_int(x, y)
add_with_overflow(x::T, y::T) where {T<:UnsignedInt} = checked_uadd_int(x, y)
add_with_overflow(x::Bool, y::Bool) = (x+y, false)
add_with_overflow(x::Bool, y::Bool) = (x +% y, false)

if BrokenSignedInt != Union{}
function add_with_overflow(x::T, y::T) where T<:BrokenSignedInt
r = x + y
r = x +% y
# x and y have the same sign, and the result has a different sign
f = (x<0) == (y<0) != (r<0)
r, f
Expand All @@ -154,7 +154,7 @@ if BrokenUnsignedInt != Union{}
function add_with_overflow(x::T, y::T) where T<:BrokenUnsignedInt
# x + y > typemax(T)
# Note: ~y == -y-1
x + y, x > ~y
x +% y, x > ~y
end
end

Expand All @@ -171,7 +171,11 @@ The overflow protection may impose a perceptible performance penalty.
"""
function checked_add(x::T, y::T) where T<:Integer
@inline
z, b = add_with_overflow(x, y)
zb = add_with_overflow(x, y)
# Avoid use of tuple destructuring, which uses aritmetic internally,
# so that this can be used as a replacement for +
z = getfield(zb, 1)
b = getfield(zb, 2)
b && throw_overflowerr_binaryop(:+, x, y)
z
end
Expand Down Expand Up @@ -206,7 +210,7 @@ sub_with_overflow(x::Bool, y::Bool) = (x-y, false)

if BrokenSignedInt != Union{}
function sub_with_overflow(x::T, y::T) where T<:BrokenSignedInt
r = x - y
r = x -% y
# x and y have different signs, and the result has a different sign than x
f = (x<0) != (y<0) == (r<0)
r, f
Expand All @@ -215,7 +219,7 @@ end
if BrokenUnsignedInt != Union{}
function sub_with_overflow(x::T, y::T) where T<:BrokenUnsignedInt
# x - y < 0
x - y, x < y
x -% y, x < y
end
end

Expand All @@ -242,7 +246,7 @@ Calculates `r = x*y`, with the flag `f` indicating whether overflow has occurred
function mul_with_overflow end
mul_with_overflow(x::T, y::T) where {T<:SignedInt} = checked_smul_int(x, y)
mul_with_overflow(x::T, y::T) where {T<:UnsignedInt} = checked_umul_int(x, y)
mul_with_overflow(x::Bool, y::Bool) = (x*y, false)
mul_with_overflow(x::Bool, y::Bool) = (x *% y, false)

if BrokenSignedIntMul != Union{} && BrokenSignedIntMul != Int128
function mul_with_overflow(x::T, y::T) where T<:BrokenSignedIntMul
Expand Down Expand Up @@ -273,14 +277,14 @@ if Int128 <: BrokenSignedIntMul
else
false
end
x*y, f
x *% y, f
end
end
if UInt128 <: BrokenUnsignedIntMul
# Avoid BigInt
function mul_with_overflow(x::T, y::T) where T<:UInt128
# x * y > typemax(T)
x * y, y > 0 && x > fld(typemax(T), y)
x *% y, y > 0 && x > fld(typemax(T), y)
end
end

Expand Down
11 changes: 11 additions & 0 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,14 @@ macro pure(ex)
end

# END 1.10 deprecations

# BEGIN 1.11 deprecations

# These operators are new in 1.11, but these fallback methods are added for
# compatibility while packages adjust to defining both operators, to allow
# Base and other packages to start using these.
*%(a::T, b::T) where {T} = *(a, b)
+%(a::T, b::T) where {T} = +(a, b)
-%(a::T, b::T) where {T} = -(a, b)

# END 1.11 deprecations
69 changes: 69 additions & 0 deletions base/docs/basedocs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2644,6 +2644,30 @@ julia> +(1, 20, 4)
"""
(+)(x, y...)

"""
+%(x::Integer, y::Integer...)

Addition operator with semantic wrapping. In the default Julia environment, this
is equivalent to the regular addition operator `+`. However, some users may choose to overwrite
`+` in their local environment to perform checked arithmetic instead (e.g. using
[`Experimental.@make_all_arithmetic_checked`](@ref)). The `+%` operator may be used to indicate
that wrapping behavior is semantically expected and correct and should thus be exempted from
any opt-in overflow checking.

# Examples
```jldoctest
julia> 1 +% 20 +% 4
25

julia> +%(1, 20, 4)
25

julia> typemax(Int) +% 1
-9223372036854775808
```
"""
(+%)(x, y...)

"""
-(x)

Expand Down Expand Up @@ -2683,6 +2707,27 @@ julia> -(2, 4.5)
"""
-(x, y)

"""
-%(x::Integer, y::Integer...)

Subtraction operator with semantic wrapping. In the default Julia environment, this
is equivalent to the regular subtraction operator `-`. However, some users may choose to overwrite
`-` in their local environment to perform checked arithmetic instead (e.g. using
[`Experimental.@make_all_arithmetic_checked`](@ref)). The `-%` operator may be used to indicate
that wrapping behavior is semantically expected and correct and should thus be exempted from
any opt-in overflow checking.

# Examples
```jldoctest
julia> 2 -% 3
-1

julia> -(typemin(Int))
-9223372036854775808
```
"""
(-%)(x, y...)

"""
*(x, y...)

Expand All @@ -2699,6 +2744,30 @@ julia> *(2, 7, 8)
"""
(*)(x, y...)

"""
*%(x::Integer, y::Integer, z::Integer...)

Multiplication operator with semantic wrapping. In the default Julia environment, this
is equivalent to the regular multiplication operator `*`. However, some users may choose to overwrite
`*` in their local environment to perform checked arithmetic instead (e.g. using
[`Experimental.@make_all_arithmetic_checked`](@ref)). The `*%` operator may be used to indicate
that wrapping behavior is semantically expected and correct and should thus be exempted from
any opt-in overflow checking.

# Examples
```jldoctest
julia> 2 *% 7 *% 8
112

julia> *(2, 7, 8)
112

julia> 0xff *% 0xff
0x01
```
"""
(*%)(x, y, z...)

"""
/(x, y)

Expand Down
21 changes: 21 additions & 0 deletions base/experimental.jl
Original file line number Diff line number Diff line change
Expand Up @@ -368,4 +368,25 @@ adding them to the global method table.
"""
:@MethodTable

"""
Experimental.@make_all_arithmetic_checked()

This macro defines methods that overwrite the base definition of basic arithmetic (+,-,*),
to use their checked variants instead. Explicitly overflowing arithmetic operators (+%,-%,*%)
are not affected.

!!! warning
This macro is temporary and will likely be replaced by a more complete mechanism in the
future. It is subject to change or removal without notice.
"""
macro make_all_arithmetic_checked()
esc(quote
Base.:(-)(x::BitInteger) = Base.Checked.checked_neg(x)
Base.:(-)(x::T, y::T) where {T<:BitInteger} = Base.Checked.checked_sub(x, y)
Base.:(+)(x::T, y::T) where {T<:BitInteger} = Base.Checked.checked_add(x, y)
Base.:(*)(x::T, y::T) where {T<:BitInteger} = Base.Checked.checked_mul(x, y)
Base.:(-)(x::AbstractChar, y::AbstractChar) = Int(x) - Int(y)
end)
end

end
3 changes: 3 additions & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,11 @@ export
÷,
&,
*,
*%,
+,
+%,
-,
-%,
/,
//,
<,
Expand Down
5 changes: 3 additions & 2 deletions base/filesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,10 @@ end

function read(f::File, ::Type{Char})
b0 = read(f, UInt8)
l = 0x08 * (0x04 - UInt8(leading_ones(b0)))
lo = UInt8(leading_ones(b0))
c = UInt32(b0) << 24
if l ≤ 0x10
if 0x02 ≤ lo ≤ 0x04
l = 0x08 * (0x04 - lo)
s = 16
while s ≥ l && !eof(f)
# this works around lack of peek(::File)
Expand Down
4 changes: 2 additions & 2 deletions base/float.jl
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ function hash(x::Float64, h::UInt)
elseif isnan(x)
return hx_NaN ⊻ h # NaN does not have a stable bit pattern
end
return hash_uint64(bitcast(UInt64, x)) - 3h
return hash_uint64(bitcast(UInt64, x)) -% (3 *% h)
end

hash(x::Float32, h::UInt) = hash(Float64(x), h)
Expand All @@ -665,7 +665,7 @@ function hash(x::Float16, h::UInt)
elseif isnan(x)
return hx_NaN ⊻ h # NaN does not have a stable bit pattern
end
return hash_uint64(bitcast(UInt64, Float64(x))) - 3h
return hash_uint64(bitcast(UInt64, Float64(x))) - (3 *% h)
end

## generic hashing for rational values ##
Expand Down
15 changes: 8 additions & 7 deletions base/gmp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ module GMP

export BigInt

import .Base: *, +, -, /, <, <<, >>, >>>, <=, ==, >, >=, ^, (~), (&), (|), xor, nand, nor,
import .Base: *, *%, +, +%, -, -%, /, <, <<, >>, >>>, <=, ==, >, >=, ^, (~), (&), (|), xor, nand, nor,
binomial, cmp, convert, div, divrem, factorial, cld, fld, gcd, gcdx, lcm, mod,
ndigits, promote_rule, rem, show, isqrt, string, powermod, sum, prod,
trailing_zeros, trailing_ones, count_ones, count_zeros, tryparse_internal,
Expand Down Expand Up @@ -334,7 +334,7 @@ function BigInt(x::Integer)
isbits(x) && typemin(Clong) <= x <= typemax(Clong) && return BigInt((x % Clong)::Clong)
nd = ndigits(x, base=2)
z = MPZ.realloc2(nd)
ux = unsigned(x < 0 ? -x : x)
ux = unsigned(x < 0 ? -%(x) : x)
size = 0
limbnbits = sizeof(Limb) << 3
while nd > 0
Expand Down Expand Up @@ -494,6 +494,7 @@ big(n::Integer) = convert(BigInt, n)

# Binary ops
for (fJ, fC) in ((:+, :add), (:-,:sub), (:*, :mul),
(:+%, :add), (:-%,:sub), (:*%, :mul),
(:mod, :fdiv_r), (:rem, :tdiv_r),
(:gcd, :gcd), (:lcm, :lcm),
(:&, :and), (:|, :ior), (:xor, :xor))
Expand Down Expand Up @@ -552,10 +553,10 @@ end
-(x::BigInt, c::CulongMax) = MPZ.sub_ui(x, c)
-(c::CulongMax, x::BigInt) = MPZ.ui_sub(c, x)

+(x::BigInt, c::ClongMax) = c < 0 ? -(x, -(c % Culong)) : x + convert(Culong, c)
+(c::ClongMax, x::BigInt) = c < 0 ? -(x, -(c % Culong)) : x + convert(Culong, c)
-(x::BigInt, c::ClongMax) = c < 0 ? +(x, -(c % Culong)) : -(x, convert(Culong, c))
-(c::ClongMax, x::BigInt) = c < 0 ? -(x + -(c % Culong)) : -(convert(Culong, c), x)
+(x::BigInt, c::ClongMax) = c < 0 ? -(x, -%(c % Culong)) : x + convert(Culong, c)
+(c::ClongMax, x::BigInt) = c < 0 ? -(x, -%(c % Culong)) : x + convert(Culong, c)
-(x::BigInt, c::ClongMax) = c < 0 ? +(x, -%(c % Culong)) : -(x, convert(Culong, c))
-(c::ClongMax, x::BigInt) = c < 0 ? -(x + -%(c % Culong)) : -(convert(Culong, c), x)

*(x::BigInt, c::CulongMax) = MPZ.mul_ui(x, c)
*(c::CulongMax, x::BigInt) = x * c
Expand Down Expand Up @@ -873,7 +874,7 @@ if Limb === UInt64 === UInt
return hash(unsafe_load(ptr), h)
elseif sz == -1
limb = unsafe_load(ptr)
limb <= typemin(Int) % UInt && return hash(-(limb % Int), h)
limb <= typemin(Int) % UInt && return hash(-%(limb % Int), h)
end
pow = trailing_zeros(x)
nd = Base.ndigits0z(x, 2)
Expand Down
Loading