Skip to content

Commit

Permalink
Changing sqrt_abs in favour of sqrt_nan as part of resolving issue Mi…
Browse files Browse the repository at this point in the history
  • Loading branch information
johanbluecreek committed Sep 7, 2022
1 parent 325d681 commit 17566e6
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/Core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import .OperatorsModule:
log2_nan,
log10_nan,
log1p_nan,
sqrt_abs,
sqrt_nan,
acosh_nan,
neg,
greater,
Expand Down
8 changes: 5 additions & 3 deletions src/Operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ function acosh_nan(x::T)::T where {T<:Real}
x < T(1) && return T(NaN)
return acosh(x)
end
function sqrt_nan(x::T)::T where {T<:Real}
x < T(0) && return T(NaN)
return sqrt(x)
end

# Generics:
square(x) = x * x
Expand All @@ -77,10 +81,8 @@ log2_nan(x) = log2(x)
log10_nan(x) = log10(x)
log1p_nan(x) = log1p(x)
acosh_nan(x) = acosh(x)
sqrt_nan(x) = sqrt(x)

function sqrt_abs(x::T)::T where {T}
return sqrt(abs(x))
end
function neg(x::T)::T where {T}
return -x
end
Expand Down
4 changes: 2 additions & 2 deletions src/Options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import ..OperatorsModule:
log10_nan,
log2_nan,
log1p_nan,
sqrt_abs,
sqrt_nan,
acosh_nan,
atanh_clip
import ..EquationModule: Node, string_tree
Expand Down Expand Up @@ -110,7 +110,7 @@ function unaopmap(op)
elseif op == log1p
return log1p_nan
elseif op == sqrt
return sqrt_abs
return sqrt_nan
elseif op == acosh
return acosh_nan
elseif op == atanh
Expand Down
4 changes: 2 additions & 2 deletions src/SymbolicRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ export Population,
log10_nan,
log1p_nan,
acosh_nan,
sqrt_abs,
sqrt_nan,
neg,
greater,
relu,
Expand Down Expand Up @@ -129,7 +129,7 @@ import .CoreModule:
log2_nan,
log10_nan,
log1p_nan,
sqrt_abs,
sqrt_nan,
acosh_nan,
neg,
greater,
Expand Down
6 changes: 3 additions & 3 deletions test/test_operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using SymbolicRegression:
log_nan,
log2_nan,
log10_nan,
sqrt_abs,
sqrt_nan,
acosh_nan,
neg,
greater,
Expand All @@ -26,7 +26,6 @@ types_to_test = [Float16, Float32, Float64, BigFloat]
for T in types_to_test
val = T(0.5)
val2 = T(3.2)
@test sqrt_abs(val) == sqrt_abs(-val)
@test abs(log_nan(val) - log(val)) < 1e-6
@test isnan(log_nan(-val))
@test abs(log2_nan(val) - log2(val)) < 1e-6
Expand All @@ -36,7 +35,8 @@ for T in types_to_test
@test abs(acosh_nan(val2) - acosh(val2)) < 1e-6
@test isnan(acosh_nan(-val2))
@test neg(-val) == val
@test sqrt_abs(val) == sqrt(val)
@test sqrt_nan(val) == sqrt(val)
@test isnan(sqrt_nan(-val))
@test mult(val, val2) == val * val2
@test plus(val, val2) == val + val2
@test sub(val, val2) == val - val2
Expand Down
4 changes: 2 additions & 2 deletions test/test_tree_construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ include("test_params.jl")
x1 = 2.0

# Initialize functions in Base....
for unaop in [cos, exp, log_nan, log2_nan, log10_nan, relu, gamma, acosh_nan]
for unaop in [cos, exp, log_nan, log2_nan, log10_nan, sqrt_nan, relu, gamma, acosh_nan]
for binop in [sub]
function make_options(; kw...)
return Options(;
Expand Down Expand Up @@ -56,7 +56,7 @@ for unaop in [cos, exp, log_nan, log2_nan, log10_nan, relu, gamma, acosh_nan]

Random.seed!(0)
N = 100
if unaop in [log_nan, log2_nan, log10_nan, acosh_nan]
if unaop in [log_nan, log2_nan, log10_nan, acosh_nan, sqrt_nan]
X = T.(rand(MersenneTwister(0), 5, N) / 3)
else
X = T.(randn(MersenneTwister(0), 5, N) / 3)
Expand Down

0 comments on commit 17566e6

Please sign in to comment.