diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index eced260..899fd5e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,7 +13,6 @@ jobs: fail-fast: true matrix: version: - - '1.3' - '1.5' - 'nightly' os: diff --git a/InteractModels/Project.toml b/InteractModels/Project.toml index 9ef3b36..0397583 100644 --- a/InteractModels/Project.toml +++ b/InteractModels/Project.toml @@ -12,7 +12,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Interact = "0.10" ModelParameters = "0.3" Reexport = "0.2, 1" -julia = "1" +julia = "1.5" [extras] DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" diff --git a/Project.toml b/Project.toml index c64ea43..5e5f533 100644 --- a/Project.toml +++ b/Project.toml @@ -20,7 +20,7 @@ Flatten = "0.4" PrettyTables = "0.8, 0.9, 0.10, 0.11, 0.12, 1" Setfield = "0.6, 0.7, 0.8" Tables = "1" -julia = "1" +julia = "1.5" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" diff --git a/src/ModelParameters.jl b/src/ModelParameters.jl index 7b9cc29..c40a658 100644 --- a/src/ModelParameters.jl +++ b/src/ModelParameters.jl @@ -14,6 +14,8 @@ import AbstractNumbers, Tables using Setfield + +using Base: tail export AbstractModel, Model, StaticModel diff --git a/src/model.jl b/src/model.jl index c8da2e2..235a844 100644 --- a/src/model.jl +++ b/src/model.jl @@ -83,8 +83,10 @@ abstract type AbstractModel end Base.parent(m::AbstractModel) = getfield(m, :parent) setparent(m::AbstractModel, newparent) = @set m.parent = newparent -params(m::AbstractModel) = params(parent(m)) -stripparams(m::AbstractModel) = stripparams(parent(m)) +params(m::AbstractModel; kw...) = params(parent(m); kw...) +stripparams(m::AbstractModel; kw...) = stripparams(parent(m); kw...) +paramfieldnames(m::AbstractModel; kw...) = paramfieldnames(parent(m); kw...) +paramparenttypes(m::AbstractModel; kw...) = paramparenttypes(parent(m); kw...) function update(x::T, values) where {T<:AbstractModel} hasfield(T, :parent) || _updatenotdefined(T) setparent(x, update(parent(x), values)) @@ -92,8 +94,6 @@ end @noinline _update_methoderror(T) = error("Interface method `update` is not defined for $T") -paramfieldnames(m) = Flatten.fieldnameflatten(parent(m), SELECT, IGNORE) -paramparenttypes(m) = Flatten.metaflatten(parent(m), _fieldparentbasetype, SELECT, IGNORE) _fieldparentbasetype(T, ::Type{Val{N}}) where N = T.name.wrapper @@ -102,64 +102,64 @@ _fieldparentbasetype(T, ::Type{Val{N}}) where N = T.name.wrapper # It may seem expensive always calling `param`, but flattening the # object occurs once at compile-time, and should have very little cost here. -Base.length(m::AbstractModel) = length(params(m)) -Base.size(m::AbstractModel) = (length(params(m)),) -Base.first(m::AbstractModel) = first(params(m)) -Base.last(m::AbstractModel) = last(params(m)) +Base.length(m::AbstractModel; kw...) = length(params(m; kw...)) +Base.size(m::AbstractModel; kw...) = (length(params(m; kw...)),) +Base.first(m::AbstractModel; kw...) = first(params(m; kw...)) +Base.last(m::AbstractModel; kw...) = last(params(m; kw...)) Base.firstindex(m::AbstractModel) = 1 -Base.lastindex(m::AbstractModel) = length(params(m)) +Base.lastindex(m::AbstractModel; kw...) = length(params(m; kw...)) Base.getindex(m::AbstractModel, i) = getindex(params(m), i) Base.iterate(m::AbstractModel) = (first(params(m)), 1) Base.iterate(m::AbstractModel, s) = s > length(m) ? nothing : (params(m)[s], s + 1) # Vector methods -Base.collect(m::AbstractModel) = collect(m.val) -Base.vec(m::AbstractModel) = collect(m) -Base.Array(m::AbstractModel) = vec(m) +Base.collect(m::AbstractModel; kw...) = collect(m[:val; kw...]) +Base.vec(m::AbstractModel; kw...) = collect(m; kw...) +Base.Array(m::AbstractModel; kw...) = vec(m; kw...) # Dict methods - data as columns Base.haskey(m::AbstractModel, key::Symbol) = key in keys(m) Base.keys(m::AbstractModel) = _keys(params(m), m) -@inline function Base.setindex!(m::AbstractModel, x, nm::Symbol) +@inline function Base.setindex!(m::AbstractModel, x, nm::Symbol; kw...) if nm == :component erorr("cannot set :component index") elseif nm == :fieldname erorr("cannot set :fieldname index") else newparent = if nm in keys(m) - _setindex(parent(m), Tuple(x), nm) + _setindex(parent(m), Tuple(x), nm; kw...) else - _addindex(parent(m), Tuple(x), nm) + _addindex(parent(m), Tuple(x), nm; kw...) end setparent!(m, newparent) end end # TODO do this with lenses -@inline function _setindex(obj, xs::Tuple, nm::Symbol) +@inline function _setindex(obj, xs::Tuple, nm::Symbol; select=SELECT, ignore=IGNORE) lens = Setfield.PropertyLens{nm}() - newparams = map(params(obj), xs) do par, x - Param(Setfield.set(parent(par), lens, x)) + newparams = map(params(obj; select, ignore), xs) do p, x + setparent(p, Setfield.set(parent(p), lens, x)) end - Flatten.reconstruct(obj, newparams, SELECT, IGNORE) + Flatten.reconstruct(obj, newparams, select, ignore) end -@inline function _addindex(obj, xs::Tuple, nm::Symbol) - newparams = map(params(obj), xs) do par, x - Param((; parent(par)..., (nm => x,)...)) +@inline function _addindex(obj, xs::Tuple, nm::Symbol; select=SELECT, ignore=IGNORE) + newparams = map(params(obj), xs) do p, x + setparent(p, (; parent(p)..., (nm => x,)...)) end - Flatten.reconstruct(obj, newparams, SELECT, IGNORE) + Flatten.reconstruct(obj, newparams, select, ignore) end _keys(params::Tuple, m::AbstractModel) = (:component, :fieldname, keys(first(params))...) _keys(params::Tuple{}, m::AbstractModel) = () -@inline function Base.getindex(m::AbstractModel, nm::Symbol) +@inline function Base.getindex(m::AbstractModel, nm::Symbol; kw...) if nm == :component - paramparenttypes(m) + paramparenttypes(m; kw...) elseif nm == :fieldname - paramfieldnames(m) + paramfieldnames(m; kw...) else - map(p -> getindex(p, nm), params(m)) + map(p -> getindex(p, nm), params(m; kw...)) end end @@ -181,16 +181,15 @@ end setparent!(m::AbstractModel, newparent) = setfield!(m, :parent, newparent) -update!(m::AbstractModel, vals::AbstractVector{<:AbstractParam}) = update!(m, Tuple(vals)) -function update!(params::Tuple{<:AbstractParam,Vararg{<:AbstractParam}}) - setparent!(m, Flatten.reconstruct(parent(m), params, SELECT, IGNORE)) +function update!(m::AbstractModel, vals::AbstractVector{<:AbstractParam}; kw...) + update!(m, Tuple(vals); kw...) end -function update!(m::AbstractModel, table) +function update!(m::AbstractModel, table; kw...) cols = (c for c in Tables.columnnames(table) if !(c in (:component, :fieldname))) for col in cols - setindex!(m, Tables.getcolumn(table, col), col) + setindex!(m, Tables.getcolumn(table, col), col; kw...) end - m + return m end """ @@ -220,30 +219,34 @@ mutable struct Model <: AbstractModel end Model(m::AbstractModel) = Model(parent(m)) -@inline @generated function _update_params(ps::P, values::Union{<:AbstractVector,<:Tuple}) where {N,P<:NTuple{N,Param}} +const ParamTuple = Tuple{Vararg{<:AbstractParam}} + +@inline @generated function _update_params(ps::P, values::Union{<:AbstractVector,<:Tuple}) where {P<:ParamTuple} expr = Expr(:tuple) - for i in 1:N - expr_i = :(Param(NamedTuple{keys(ps[$i])}((values[$i], Base.tail(parent(ps[$i]))...)))) + for i in 1:length(ps.parameters) + expr_i = :(setparent(ps[$i], NamedTuple{keys(ps[$i])}((values[$i], tail(parent(ps[$i]))...)))) push!(expr.args, expr_i) end return expr end -update(x, values) = _update(ModelParameters.params(x), x, values) -@inline function _update(p::P, x, values::Union{<:AbstractVector,<:Tuple}) where {N,P<:NTuple{N,Param}} - @assert length(values) == N "values length must match the number of parameters" - newparams = _update_params(p, values) - Flatten.reconstruct(x, newparams, SELECT, IGNORE) +update(x, values; kw...) = _update(ModelParameters.params(x), x, values; kw...) +@inline function _update(ps::P, x, values::Union{<:AbstractVector,<:Tuple}; + select=SELECT, ignore=IGNORE +) where {P<:ParamTuple} + @assert length(values) == length(ps) "values length must match the number of parameters" + newparams = _update_params(ps, values) + return Flatten.reconstruct(x, newparams, select, ignore) end -@inline function _update(p::P, x, table) where {N,P<:NTuple{N,Param}} - @assert size(table, 1) == N "number of rows must match the number of parameters" +@inline function _update(ps::P, x, table; select=SELECT, ignore=IGNORE) where {P<:ParamTuple} + @assert size(table, 1) == length(ps) "number of rows must match the number of parameters" colnames = (c for c in Tables.columnnames(table) if !(c in (:component, :fieldname))) - newparams = map(p, tuple(1:N...)) do param, i - let names = keys(param); - Param(NamedTuple{names}(map(name -> Tables.getcolumn(table, name)[i], names))) + newparams = map(ps, tuple(1:length(ps)...)) do p, i + let names = keys(p); + setparent(p, NamedTuple{names}(map(name -> Tables.getcolumn(table, name)[i], names))) end end - Flatten.reconstruct(x, newparams, SELECT, IGNORE) + return Flatten.reconstruct(x, newparams, select, ignore) end """ @@ -270,17 +273,17 @@ StaticModel(m::AbstractModel) = StaticModel(parent(m)) # Model Utils -_expandpars(x) = Flatten.reconstruct(parent, _expandkeys(parent), SELECT, IGNORE) +_expandpars(x) = Flatten.reconstruct(parent, _expandkeys(parent), select, ignore) # Expand all Params to have the same keys, filling with `nothing` # This probably will allocate due to `union` returning `Vector` function _expandkeys(x) pars = params(x) allkeys = Tuple(union(map(keys, pars)...)) - return map(pars) do par + return map(pars) do p vals = map(allkeys) do key - get(par, key, nothing) + get(p, key, nothing) end - Param(NamedTuple{allkeys}(vals)) + setparent(p, NamedTuple{allkeys}(vals)) end end diff --git a/src/param.jl b/src/param.jl index 098921e..1607a58 100644 --- a/src/param.jl +++ b/src/param.jl @@ -15,16 +15,18 @@ function ConstructionBase.setproperties(p::P, patch::NamedTuple) where P <: Abst P.name.wrapper(fields) end -@inline withunits(m, args...) = map(p -> withunits(p, args...), params(m)) +@inline withunits(x, args...; kw...) = map(p -> withunits(p, args...), params(x; kw...)) @inline function withunits(p::AbstractParam, fn::Symbol=:val) _applyunits(*, getproperty(p, fn), get(p, :units, nothing)) end -@inline stripunits(m, xs) = map(stripunits, params(m), xs) +@inline stripunits(x, xs; kw...) = map(stripunits, params(x; kw...), xs) @inline function stripunits(p::AbstractParam, x) _applyunits(/, x, get(p, :units, nothing)) end +setparent(m::P, newparent) where P<:AbstractParam = P.name.wrapper(newparent) + # Param might have `nothing` for units @inline _applyunits(f, x, units) = f(x, units) @inline _applyunits(f, x, ::Nothing) = x @@ -42,9 +44,11 @@ Base.values(p::AbstractParam) = values(parent(p)) @inline Base.getproperty(p::AbstractParam, x::Symbol) = getproperty(parent(p), x) @inline Base.get(p::AbstractParam, key::Symbol, default) = get(parent(p), key, default) @inline Base.getindex(p::AbstractParam, i) = getindex(parent(p), i) +Base.parent(p::AbstractParam) = getfield(p, :parent) # AbstractNumber interface + Base.convert(::Type{Number}, x::AbstractParam) = AbstractNumbers.number(x) Base.convert(::Type{P}, x::P) where {P<:AbstractParam} = x AbstractNumbers.number(p::AbstractParam) = withunits(p) @@ -79,15 +83,20 @@ end Param(val; kwargs...) = Param((; val=val, kwargs...)) Param(; kwargs...) = Param((; kwargs...)) -Base.parent(p::Param) = getfield(p, :parent) - -# Methods for objects that hold params -params(x) = Flatten.flatten(x, SELECT, IGNORE) -stripparams(x) = hasparam(x) ? Flatten.reconstruct(x, withunits(x), SELECT, IGNORE) : x +params(x; select=SELECT, ignore=IGNORE) = Flatten.flatten(x, select, ignore) +function stripparams(x; select=SELECT, ignore=IGNORE) + hasparam(x) ? Flatten.reconstruct(x, withunits(x), select, ignore) : x +end +function paramfieldnames(x; select=SELECT, ignore=IGNORE) + Flatten.fieldnameflatten(x, select, ignore) +end +function paramparenttypes(x; select=SELECT, ignore=IGNORE) + Flatten.metaflatten(x, _fieldparentbasetype, select, ignore) +end # Utils -hasparam(obj) = length(params(obj)) > 0 +hasparam(obj; kw...) = length(params(obj; kw...)) > 0 _checkhasval(nt::NamedTuple{Keys}) where {Keys} = first(Keys) == :val || _novalerror(nt) # @noinline avoids allocations unless there is actually an error diff --git a/src/tables.jl b/src/tables.jl index 94db809..953fc37 100644 --- a/src/tables.jl +++ b/src/tables.jl @@ -10,7 +10,7 @@ _columntypes(m) = map(k -> Union{map(typeof, getindex(m, k))...}, keys(m)) # As Columns Tables.columnaccess(::Type{<:AbstractModel}) = true Tables.columns(m::AbstractModel) = m -Tables.getcolumn(m::AbstractModel, nm::Symbol) = collect(getindex(m, nm)) -Tables.getcolumn(m::AbstractModel, i::Int) = collect(getindex(m, i)) -Tables.getcolumn(m::AbstractModel, ::Type{T}, col::Int, nm::Symbol) where T = - collect(getindex(m, nm)) +Tables.getcolumn(m::AbstractModel, nm::Symbol; kw...) = collect(getindex(m, nm; kw...)) +Tables.getcolumn(m::AbstractModel, i::Int; kw...) = collect(getindex(m, i; kw...)) +Tables.getcolumn(m::AbstractModel, ::Type{T}, col::Int, nm::Symbol; kw...) where T = + collect(getindex(m, nm; kw...)) diff --git a/test/runtests.jl b/test/runtests.jl index bc9e2cc..0c84283 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -246,3 +246,28 @@ end @test isa(tuplegroups.A.i, Tuple) @test isa(tuplegroups.B.j, Tuple) end + +struct FreeParam{T,P<:NamedTuple} <: AbstractParam{T} + parent::P +end +FreeParam(nt::NT) where {NT<:NamedTuple} = begin + ModelParameters._checkhasval(nt) + FreeParam{typeof(nt.val),NT}(nt) +end +FreeParam(val; kwargs...) = FreeParam((; val=val, kwargs...)) +FreeParam(; kwargs...) = FreeParam((; kwargs...)) + +@testset "spefify param type" begin + sf = S2( + FreeParam(99), + FreeParam(7), + Param(100.0; bounds=(50.0, 150.0)) + ) + @test params(sf; select=FreeParam) == (FreeParam(99), FreeParam(7)) + m = Model(sf) + @test m[:val, select=Param] == (100.0,) + @test getindex(m, :val; select=FreeParam) == (99, 7) + m[:bounds, select=FreeParam] = ((1, 100), (1, 10)) + @test m[:bounds] == ((1, 100), (1, 10), (50.0, 150.0)) +end +