From 17566e6c9fad26f7b1201bbc7029d19c1fa015a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johan=20Bl=C3=A5b=C3=A4ck?= Date: Wed, 7 Sep 2022 09:03:19 +0200 Subject: [PATCH] Changing sqrt_abs in favour of sqrt_nan as part of resolving issue #109 --- src/Core.jl | 2 +- src/Operators.jl | 8 +++++--- src/Options.jl | 4 ++-- src/SymbolicRegression.jl | 4 ++-- test/test_operators.jl | 6 +++--- test/test_tree_construction.jl | 4 ++-- 6 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/Core.jl b/src/Core.jl index 7dc17425e..1f76db9f6 100644 --- a/src/Core.jl +++ b/src/Core.jl @@ -34,7 +34,7 @@ import .OperatorsModule: log2_nan, log10_nan, log1p_nan, - sqrt_abs, + sqrt_nan, acosh_nan, neg, greater, diff --git a/src/Operators.jl b/src/Operators.jl index 8f7ea89e2..4da4f8757 100644 --- a/src/Operators.jl +++ b/src/Operators.jl @@ -63,6 +63,10 @@ function acosh_nan(x::T)::T where {T<:Real} x < T(1) && return T(NaN) return acosh(x) end +function sqrt_nan(x::T)::T where {T<:Real} + x < T(0) && return T(NaN) + return sqrt(x) +end # Generics: square(x) = x * x @@ -77,10 +81,8 @@ log2_nan(x) = log2(x) log10_nan(x) = log10(x) log1p_nan(x) = log1p(x) acosh_nan(x) = acosh(x) +sqrt_nan(x) = sqrt(x) -function sqrt_abs(x::T)::T where {T} - return sqrt(abs(x)) -end function neg(x::T)::T where {T} return -x end diff --git a/src/Options.jl b/src/Options.jl index 30e89ab10..1caeec59e 100644 --- a/src/Options.jl +++ b/src/Options.jl @@ -18,7 +18,7 @@ import ..OperatorsModule: log10_nan, log2_nan, log1p_nan, - sqrt_abs, + sqrt_nan, acosh_nan, atanh_clip import ..EquationModule: Node, string_tree @@ -110,7 +110,7 @@ function unaopmap(op) elseif op == log1p return log1p_nan elseif op == sqrt - return sqrt_abs + return sqrt_nan elseif op == acosh return acosh_nan elseif op == atanh diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 11e8d5839..97e6ddbea 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -40,7 +40,7 @@ export Population, log10_nan, log1p_nan, acosh_nan, - sqrt_abs, + sqrt_nan, neg, greater, relu, @@ -129,7 +129,7 @@ import .CoreModule: log2_nan, log10_nan, log1p_nan, - sqrt_abs, + sqrt_nan, acosh_nan, neg, greater, diff --git a/test/test_operators.jl b/test/test_operators.jl index 0facd6243..a9ec7548a 100644 --- a/test/test_operators.jl +++ b/test/test_operators.jl @@ -9,7 +9,7 @@ using SymbolicRegression: log_nan, log2_nan, log10_nan, - sqrt_abs, + sqrt_nan, acosh_nan, neg, greater, @@ -26,7 +26,6 @@ types_to_test = [Float16, Float32, Float64, BigFloat] for T in types_to_test val = T(0.5) val2 = T(3.2) - @test sqrt_abs(val) == sqrt_abs(-val) @test abs(log_nan(val) - log(val)) < 1e-6 @test isnan(log_nan(-val)) @test abs(log2_nan(val) - log2(val)) < 1e-6 @@ -36,7 +35,8 @@ for T in types_to_test @test abs(acosh_nan(val2) - acosh(val2)) < 1e-6 @test isnan(acosh_nan(-val2)) @test neg(-val) == val - @test sqrt_abs(val) == sqrt(val) + @test sqrt_nan(val) == sqrt(val) + @test isnan(sqrt_nan(-val)) @test mult(val, val2) == val * val2 @test plus(val, val2) == val + val2 @test sub(val, val2) == val - val2 diff --git a/test/test_tree_construction.jl b/test/test_tree_construction.jl index be26cc693..b00e370da 100644 --- a/test/test_tree_construction.jl +++ b/test/test_tree_construction.jl @@ -8,7 +8,7 @@ include("test_params.jl") x1 = 2.0 # Initialize functions in Base.... -for unaop in [cos, exp, log_nan, log2_nan, log10_nan, relu, gamma, acosh_nan] +for unaop in [cos, exp, log_nan, log2_nan, log10_nan, sqrt_nan, relu, gamma, acosh_nan] for binop in [sub] function make_options(; kw...) return Options(; @@ -56,7 +56,7 @@ for unaop in [cos, exp, log_nan, log2_nan, log10_nan, relu, gamma, acosh_nan] Random.seed!(0) N = 100 - if unaop in [log_nan, log2_nan, log10_nan, acosh_nan] + if unaop in [log_nan, log2_nan, log10_nan, acosh_nan, sqrt_nan] X = T.(rand(MersenneTwister(0), 5, N) / 3) else X = T.(randn(MersenneTwister(0), 5, N) / 3)