Skip to content

Commit

Permalink
Merge pull request #872 from SciML/sb/complex_check
Browse files Browse the repository at this point in the history
fix: checking complex type in the parameters of nn
  • Loading branch information
ChrisRackauckas authored Jul 5, 2024
2 parents 6627aa0 + e86183e commit c563541
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 9 deletions.
12 changes: 6 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,22 @@ Adapt = "4"
AdvancedHMC = "0.6.1"
Aqua = "0.8"
ArrayInterface = "7.9"
CUDA = "5.2"
ChainRulesCore = "1.21"
ComponentArrays = "0.15.8"
CUDA = "5.3"
ChainRulesCore = "1.24"
ComponentArrays = "0.15.14"
Cubature = "1.5"
DiffEqNoiseProcess = "5.20"
Distributions = "0.25.107"
DocStringExtensions = "0.9"
DocStringExtensions = "0.9.3"
DomainSets = "0.6, 0.7"
Flux = "0.14.11"
ForwardDiff = "0.10.36"
Functors = "0.4.4"
Functors = "0.4.10"
Integrals = "4.4"
LineSearches = "7.2"
LinearAlgebra = "1"
LogDensityProblems = "2"
Lux = "0.5.22"
Lux = "0.5.58"
LuxCUDA = "0.3.2"
MCMCChains = "6"
MethodOfLines = "0.11"
Expand Down
2 changes: 1 addition & 1 deletion src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ using DomainSets: Domain, ClosedInterval, AbstractInterval, leftendpoint, righte
using SciMLBase: @add_kwonly, parameterless_type
using UnPack: @unpack
import ChainRulesCore, Lux, ComponentArrays
using Lux: FromFluxAdaptor
using Lux: FromFluxAdaptor, recursive_eltype
using ChainRulesCore: @non_differentiable

RuntimeGeneratedFunctions.init(@__MODULE__)
Expand Down
3 changes: 1 addition & 2 deletions src/ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,7 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
!(chain isa Lux.AbstractExplicitLayer) &&
error("Only Lux.AbstractExplicitLayer neural networks are supported")
phi, init_params = generate_phi_θ(chain, t0, u0, init_params)
((eltype(eltype(init_params).types[1]) <: Complex ||
eltype(eltype(init_params).types[2]) <: Complex) &&
(recursive_eltype(init_params) <: Complex &&
alg.strategy isa QuadratureTraining) &&
error("QuadratureTraining cannot be used with complex parameters. Use other strategies.")

Expand Down

0 comments on commit c563541

Please sign in to comment.