Skip to content

Commit ec60491

Browse files
authored
Refactor type conversions (#173)
* `big`, `widen` and `float` now only return the promoted primal type when called on `Dual` * reorganize and shorten the code for type conversions
1 parent 2cad635 commit ec60491

File tree

2 files changed

+34
-52
lines changed

2 files changed

+34
-52
lines changed

src/overloads/conversion.jl

Lines changed: 33 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,81 +1,67 @@
1-
#! format: off
2-
31
##===============#
42
# AbstractTracer #
53
#================#
64

7-
## Type conversions (non-dual)
85
Base.promote_rule(::Type{T}, ::Type{N}) where {T<:AbstractTracer,N<:Real} = T
96
Base.promote_rule(::Type{N}, ::Type{T}) where {T<:AbstractTracer,N<:Real} = T
107

11-
Base.big(::Type{T}) where {T<:AbstractTracer} = T
12-
Base.widen(::Type{T}) where {T<:AbstractTracer} = T
13-
Base.float(::Type{T}) where {T<:AbstractTracer} = T
14-
15-
Base.convert(::Type{T}, x::Real) where {T<:AbstractTracer} = myempty(T)
16-
Base.convert(::Type{T}, t::T) where {T<:AbstractTracer} = t
8+
Base.convert(::Type{T}, x::Real) where {T<:AbstractTracer} = myempty(T)
9+
Base.convert(::Type{T}, t::T) where {T<:AbstractTracer} = t
1710
Base.convert(::Type{<:Real}, t::T) where {T<:AbstractTracer} = t
1811

19-
## Constants
20-
# These are methods defined on types. Methods on variables are in operators.jl
21-
Base.zero(::Type{T}) where {T<:AbstractTracer} = myempty(T)
22-
Base.one(::Type{T}) where {T<:AbstractTracer} = myempty(T)
23-
Base.oneunit(::Type{T}) where {T<:AbstractTracer} = myempty(T)
24-
Base.typemin(::Type{T}) where {T<:AbstractTracer} = myempty(T)
25-
Base.typemax(::Type{T}) where {T<:AbstractTracer} = myempty(T)
26-
Base.eps(::Type{T}) where {T<:AbstractTracer} = myempty(T)
27-
Base.floatmin(::Type{T}) where {T<:AbstractTracer} = myempty(T)
28-
Base.floatmax(::Type{T}) where {T<:AbstractTracer} = myempty(T)
29-
Base.maxintfloat(::Type{T}) where {T<:AbstractTracer} = myempty(T)
30-
3112
##======#
3213
# Duals #
3314
#=======#
3415

35-
function Base.promote_rule(::Type{Dual{P1, T}}, ::Type{Dual{P2, T}}) where {P1,P2,T}
16+
function Base.promote_rule(::Type{Dual{P1,T}}, ::Type{Dual{P2,T}}) where {P1,P2,T}
3617
PP = Base.promote_type(P1, P2) # TODO: possible method call error?
3718
return Dual{PP,T}
3819
end
39-
function Base.promote_rule(::Type{Dual{P, T}}, ::Type{N}) where {P,T,N<:Real}
20+
function Base.promote_rule(::Type{Dual{P,T}}, ::Type{N}) where {P,T,N<:Real}
4021
PP = Base.promote_type(P, N) # TODO: possible method call error?
4122
return Dual{PP,T}
4223
end
43-
function Base.promote_rule(::Type{N}, ::Type{Dual{P, T}}) where {P,T,N<:Real}
24+
function Base.promote_rule(::Type{N}, ::Type{Dual{P,T}}) where {P,T,N<:Real}
4425
PP = Base.promote_type(P, N) # TODO: possible method call error?
4526
return Dual{PP,T}
4627
end
4728

48-
Base.big(::Type{D}) where {P,T,D<:Dual{P,T}} = Dual{big(P),T}
49-
Base.widen(::Type{D}) where {P,T,D<:Dual{P,T}} = Dual{widen(P),T}
50-
Base.float(::Type{D}) where {P,T,D<:Dual{P,T}} = Dual{float(P),T}
29+
Base.convert(::Type{D}, x::Real) where {P,T,D<:Dual{P,T}} = Dual(x, myempty(T))
30+
Base.convert(::Type{D}, d::D) where {P,T,D<:Dual{P,T}} = d
31+
Base.convert(::Type{N}, d::D) where {N<:Real,P,T,D<:Dual{P,T}} = Dual(convert(N, primal(d)), tracer(d))
5132

52-
Base.convert(::Type{D}, x::Real) where {P,T,D<:Dual{P,T}} = Dual(x, myempty(T))
53-
Base.convert(::Type{D}, d::D) where {P,T,D<:Dual{P,T}} = d
54-
Base.convert(::Type{N}, d::D) where {N<:Real,P,T,D<:Dual{P,T}} = Dual(convert(N, primal(d)), tracer(d))
55-
56-
function Base.convert(::Type{Dual{P1,T}}, d::Dual{P2,T}) where {P1,P2,T}
33+
function Base.convert(::Type{Dual{P1,T}}, d::Dual{P2,T}) where {P1,P2,T}
5734
return Dual(convert(P1, primal(d)), tracer(d))
5835
end
5936

60-
# Explicit type conversions
37+
##==========================#
38+
# Explicit type conversions #
39+
#===========================#
40+
6141
for T in (:Int, :Integer, :Float64, :Float32)
42+
# Currently only defined on Dual to avoid invalidations.
6243
@eval function Base.$T(d::Dual)
6344
isemptytracer(d) || throw(InexactError(Symbol($T), $T, d))
64-
$T(primal(d))
45+
return $T(primal(d))
6546
end
6647
end
6748

68-
## Constants
69-
# These are methods defined on types. Methods on variables are in operators.jl
70-
# TODO: only return primal on methods on variable
71-
Base.zero(::Type{D}) where {P,T,D<:Dual{P,T}} = zero(P)
72-
Base.one(::Type{D}) where {P,T,D<:Dual{P,T}} = one(P)
73-
Base.oneunit(::Type{D}) where {P,T,D<:Dual{P,T}} = oneunit(P)
74-
Base.typemin(::Type{D}) where {P,T,D<:Dual{P,T}} = typemin(P)
75-
Base.typemax(::Type{D}) where {P,T,D<:Dual{P,T}} = typemax(P)
76-
Base.eps(::Type{D}) where {P,T,D<:Dual{P,T}} = eps(P)
77-
Base.floatmin(::Type{D}) where {P,T,D<:Dual{P,T}} = floatmin(P)
78-
Base.floatmax(::Type{D}) where {P,T,D<:Dual{P,T}} = floatmax(P)
79-
Base.maxintfloat(::Type{D}) where {P,T,D<:Dual{P,T}} = maxintfloat(P)
49+
##======================#
50+
# Named type promotions #
51+
#=======================#
8052

81-
#! format: on
53+
for f in (:big, :widen, :float)
54+
@eval Base.$f(::Type{T}) where {T<:AbstractTracer} = T
55+
@eval Base.$f(::Type{D}) where {P,T,D<:Dual{P,T}} = $f(P) # only return primal type
56+
end
57+
58+
##============================#
59+
# Constant functions on types #
60+
#=============================#
61+
62+
# Methods on variables are in operators.jl
63+
for f in
64+
(:zero, :one, :oneunit, :typemin, :typemax, :eps, :floatmin, :floatmax, :maxintfloat)
65+
@eval Base.$f(::Type{T}) where {T<:AbstractTracer} = myempty(T)
66+
@eval Base.$f(::Type{D}) where {P,T,D<:Dual{P,T}} = $f(P) # only return primal
67+
end

test/test_constructors.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,7 @@ end
4545
function test_type_conversion_functions(::Type{D}, f::Function) where {P,T,D<:Dual{P,T}}
4646
@testset "Primal type $P_IN" for P_IN in (Int, Float32, Irrational)
4747
P_OUT = f(P_IN)
48-
49-
# Note that this tests Dual{P_IN,T}, not Dual{P,T}
50-
D_IN = Dual{P_IN,T}
51-
D_OUT = Dual{P_OUT,T}
52-
@test f(D_IN) == D_OUT
48+
@test f(Dual{P_IN,T}) == P_OUT # NOTE: this tests Dual{P_IN,T}, not Dual{P,T}
5349
end
5450
end
5551

0 commit comments

Comments
 (0)