Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ jobs:
fail-fast: true
matrix:
version:
- '1.3'
- '1.5'
- 'nightly'
os:
Expand Down
2 changes: 1 addition & 1 deletion InteractModels/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Interact = "0.10"
ModelParameters = "0.3"
Reexport = "0.2, 1"
julia = "1"
julia = "1.5"

[extras]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Flatten = "0.4"
PrettyTables = "0.8, 0.9, 0.10, 0.11, 0.12, 1"
Setfield = "0.6, 0.7, 0.8"
Tables = "1"
julia = "1"
julia = "1.5"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Expand Down
2 changes: 2 additions & 0 deletions src/ModelParameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import AbstractNumbers,
Tables

using Setfield

using Base: tail

export AbstractModel, Model, StaticModel

Expand Down
105 changes: 54 additions & 51 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,17 @@ abstract type AbstractModel end
Base.parent(m::AbstractModel) = getfield(m, :parent)
setparent(m::AbstractModel, newparent) = @set m.parent = newparent

params(m::AbstractModel) = params(parent(m))
stripparams(m::AbstractModel) = stripparams(parent(m))
params(m::AbstractModel; kw...) = params(parent(m); kw...)
stripparams(m::AbstractModel; kw...) = stripparams(parent(m); kw...)
paramfieldnames(m::AbstractModel; kw...) = paramfieldnames(parent(m); kw...)
paramparenttypes(m::AbstractModel; kw...) = paramparenttypes(parent(m); kw...)
function update(x::T, values) where {T<:AbstractModel}
hasfield(T, :parent) || _updatenotdefined(T)
setparent(x, update(parent(x), values))
end

@noinline _update_methoderror(T) = error("Interface method `update` is not defined for $T")

paramfieldnames(m) = Flatten.fieldnameflatten(parent(m), SELECT, IGNORE)
paramparenttypes(m) = Flatten.metaflatten(parent(m), _fieldparentbasetype, SELECT, IGNORE)

_fieldparentbasetype(T, ::Type{Val{N}}) where N = T.name.wrapper

Expand All @@ -102,64 +102,64 @@ _fieldparentbasetype(T, ::Type{Val{N}}) where N = T.name.wrapper

# It may seem expensive always calling `param`, but flattening the
# object occurs once at compile-time, and should have very little cost here.
Base.length(m::AbstractModel) = length(params(m))
Base.size(m::AbstractModel) = (length(params(m)),)
Base.first(m::AbstractModel) = first(params(m))
Base.last(m::AbstractModel) = last(params(m))
Base.length(m::AbstractModel; kw...) = length(params(m; kw...))
Base.size(m::AbstractModel; kw...) = (length(params(m; kw...)),)
Base.first(m::AbstractModel; kw...) = first(params(m; kw...))
Base.last(m::AbstractModel; kw...) = last(params(m; kw...))
Base.firstindex(m::AbstractModel) = 1
Base.lastindex(m::AbstractModel) = length(params(m))
Base.lastindex(m::AbstractModel; kw...) = length(params(m; kw...))
Base.getindex(m::AbstractModel, i) = getindex(params(m), i)
Base.iterate(m::AbstractModel) = (first(params(m)), 1)
Base.iterate(m::AbstractModel, s) = s > length(m) ? nothing : (params(m)[s], s + 1)

# Vector methods
Base.collect(m::AbstractModel) = collect(m.val)
Base.vec(m::AbstractModel) = collect(m)
Base.Array(m::AbstractModel) = vec(m)
Base.collect(m::AbstractModel; kw...) = collect(m[:val; kw...])
Base.vec(m::AbstractModel; kw...) = collect(m; kw...)
Base.Array(m::AbstractModel; kw...) = vec(m; kw...)

# Dict methods - data as columns
Base.haskey(m::AbstractModel, key::Symbol) = key in keys(m)
Base.keys(m::AbstractModel) = _keys(params(m), m)

@inline function Base.setindex!(m::AbstractModel, x, nm::Symbol)
@inline function Base.setindex!(m::AbstractModel, x, nm::Symbol; kw...)
if nm == :component
erorr("cannot set :component index")
elseif nm == :fieldname
erorr("cannot set :fieldname index")
else
newparent = if nm in keys(m)
_setindex(parent(m), Tuple(x), nm)
_setindex(parent(m), Tuple(x), nm; kw...)
else
_addindex(parent(m), Tuple(x), nm)
_addindex(parent(m), Tuple(x), nm; kw...)
end
setparent!(m, newparent)
end
end
# TODO do this with lenses
@inline function _setindex(obj, xs::Tuple, nm::Symbol)
@inline function _setindex(obj, xs::Tuple, nm::Symbol; select=SELECT, ignore=IGNORE)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should select and ignore maybe have type bounds, Type{T} where {T<:AbstractParam}?

lens = Setfield.PropertyLens{nm}()
newparams = map(params(obj), xs) do par, x
Param(Setfield.set(parent(par), lens, x))
newparams = map(params(obj; select, ignore), xs) do p, x
setparent(p, Setfield.set(parent(p), lens, x))
end
Flatten.reconstruct(obj, newparams, SELECT, IGNORE)
Flatten.reconstruct(obj, newparams, select, ignore)
end
@inline function _addindex(obj, xs::Tuple, nm::Symbol)
newparams = map(params(obj), xs) do par, x
Param((; parent(par)..., (nm => x,)...))
@inline function _addindex(obj, xs::Tuple, nm::Symbol; select=SELECT, ignore=IGNORE)
newparams = map(params(obj), xs) do p, x
setparent(p, (; parent(p)..., (nm => x,)...))
end
Flatten.reconstruct(obj, newparams, SELECT, IGNORE)
Flatten.reconstruct(obj, newparams, select, ignore)
end

_keys(params::Tuple, m::AbstractModel) = (:component, :fieldname, keys(first(params))...)
_keys(params::Tuple{}, m::AbstractModel) = ()

@inline function Base.getindex(m::AbstractModel, nm::Symbol)
@inline function Base.getindex(m::AbstractModel, nm::Symbol; kw...)
if nm == :component
paramparenttypes(m)
paramparenttypes(m; kw...)
elseif nm == :fieldname
paramfieldnames(m)
paramfieldnames(m; kw...)
else
map(p -> getindex(p, nm), params(m))
map(p -> getindex(p, nm), params(m; kw...))
end
end

Expand All @@ -181,16 +181,15 @@ end

setparent!(m::AbstractModel, newparent) = setfield!(m, :parent, newparent)

update!(m::AbstractModel, vals::AbstractVector{<:AbstractParam}) = update!(m, Tuple(vals))
function update!(params::Tuple{<:AbstractParam,Vararg{<:AbstractParam}})
setparent!(m, Flatten.reconstruct(parent(m), params, SELECT, IGNORE))
function update!(m::AbstractModel, vals::AbstractVector{<:AbstractParam}; kw...)
update!(m, Tuple(vals); kw...)
end
function update!(m::AbstractModel, table)
function update!(m::AbstractModel, table; kw...)
cols = (c for c in Tables.columnnames(table) if !(c in (:component, :fieldname)))
for col in cols
setindex!(m, Tables.getcolumn(table, col), col)
setindex!(m, Tables.getcolumn(table, col), col; kw...)
end
m
return m
end

"""
Expand Down Expand Up @@ -220,30 +219,34 @@ mutable struct Model <: AbstractModel
end
Model(m::AbstractModel) = Model(parent(m))

@inline @generated function _update_params(ps::P, values::Union{<:AbstractVector,<:Tuple}) where {N,P<:NTuple{N,Param}}
const ParamTuple = Tuple{Vararg{<:AbstractParam}}

@inline @generated function _update_params(ps::P, values::Union{<:AbstractVector,<:Tuple}) where {P<:ParamTuple}
expr = Expr(:tuple)
for i in 1:N
expr_i = :(Param(NamedTuple{keys(ps[$i])}((values[$i], Base.tail(parent(ps[$i]))...))))
for i in 1:length(ps.parameters)
expr_i = :(setparent(ps[$i], NamedTuple{keys(ps[$i])}((values[$i], tail(parent(ps[$i]))...))))
push!(expr.args, expr_i)
end
return expr
end

update(x, values) = _update(ModelParameters.params(x), x, values)
@inline function _update(p::P, x, values::Union{<:AbstractVector,<:Tuple}) where {N,P<:NTuple{N,Param}}
@assert length(values) == N "values length must match the number of parameters"
newparams = _update_params(p, values)
Flatten.reconstruct(x, newparams, SELECT, IGNORE)
update(x, values; kw...) = _update(ModelParameters.params(x), x, values; kw...)
@inline function _update(ps::P, x, values::Union{<:AbstractVector,<:Tuple};
select=SELECT, ignore=IGNORE
) where {P<:ParamTuple}
@assert length(values) == length(ps) "values length must match the number of parameters"
newparams = _update_params(ps, values)
return Flatten.reconstruct(x, newparams, select, ignore)
end
@inline function _update(p::P, x, table) where {N,P<:NTuple{N,Param}}
@assert size(table, 1) == N "number of rows must match the number of parameters"
@inline function _update(ps::P, x, table; select=SELECT, ignore=IGNORE) where {P<:ParamTuple}
@assert size(table, 1) == length(ps) "number of rows must match the number of parameters"
colnames = (c for c in Tables.columnnames(table) if !(c in (:component, :fieldname)))
newparams = map(p, tuple(1:N...)) do param, i
let names = keys(param);
Param(NamedTuple{names}(map(name -> Tables.getcolumn(table, name)[i], names)))
newparams = map(ps, tuple(1:length(ps)...)) do p, i
let names = keys(p);
setparent(p, NamedTuple{names}(map(name -> Tables.getcolumn(table, name)[i], names)))
end
end
Flatten.reconstruct(x, newparams, SELECT, IGNORE)
return Flatten.reconstruct(x, newparams, select, ignore)
end

"""
Expand All @@ -270,17 +273,17 @@ StaticModel(m::AbstractModel) = StaticModel(parent(m))

# Model Utils

_expandpars(x) = Flatten.reconstruct(parent, _expandkeys(parent), SELECT, IGNORE)
_expandpars(x) = Flatten.reconstruct(parent, _expandkeys(parent), select, ignore)
# Expand all Params to have the same keys, filling with `nothing`
# This probably will allocate due to `union` returning `Vector`
function _expandkeys(x)
pars = params(x)
allkeys = Tuple(union(map(keys, pars)...))
return map(pars) do par
return map(pars) do p
vals = map(allkeys) do key
get(par, key, nothing)
get(p, key, nothing)
end
Param(NamedTuple{allkeys}(vals))
setparent(p, NamedTuple{allkeys}(vals))
end
end

Expand Down
25 changes: 17 additions & 8 deletions src/param.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,18 @@ function ConstructionBase.setproperties(p::P, patch::NamedTuple) where P <: Abst
P.name.wrapper(fields)
end

@inline withunits(m, args...) = map(p -> withunits(p, args...), params(m))
@inline withunits(x, args...; kw...) = map(p -> withunits(p, args...), params(x; kw...))
@inline function withunits(p::AbstractParam, fn::Symbol=:val)
_applyunits(*, getproperty(p, fn), get(p, :units, nothing))
end

@inline stripunits(m, xs) = map(stripunits, params(m), xs)
@inline stripunits(x, xs; kw...) = map(stripunits, params(x; kw...), xs)
@inline function stripunits(p::AbstractParam, x)
_applyunits(/, x, get(p, :units, nothing))
end

setparent(m::P, newparent) where P<:AbstractParam = P.name.wrapper(newparent)

# Param might have `nothing` for units
@inline _applyunits(f, x, units) = f(x, units)
@inline _applyunits(f, x, ::Nothing) = x
Expand All @@ -42,9 +44,11 @@ Base.values(p::AbstractParam) = values(parent(p))
@inline Base.getproperty(p::AbstractParam, x::Symbol) = getproperty(parent(p), x)
@inline Base.get(p::AbstractParam, key::Symbol, default) = get(parent(p), key, default)
@inline Base.getindex(p::AbstractParam, i) = getindex(parent(p), i)
Base.parent(p::AbstractParam) = getfield(p, :parent)


# AbstractNumber interface

Base.convert(::Type{Number}, x::AbstractParam) = AbstractNumbers.number(x)
Base.convert(::Type{P}, x::P) where {P<:AbstractParam} = x
AbstractNumbers.number(p::AbstractParam) = withunits(p)
Expand Down Expand Up @@ -79,15 +83,20 @@ end
Param(val; kwargs...) = Param((; val=val, kwargs...))
Param(; kwargs...) = Param((; kwargs...))

Base.parent(p::Param) = getfield(p, :parent)

# Methods for objects that hold params
params(x) = Flatten.flatten(x, SELECT, IGNORE)
stripparams(x) = hasparam(x) ? Flatten.reconstruct(x, withunits(x), SELECT, IGNORE) : x
params(x; select=SELECT, ignore=IGNORE) = Flatten.flatten(x, select, ignore)
function stripparams(x; select=SELECT, ignore=IGNORE)
hasparam(x) ? Flatten.reconstruct(x, withunits(x), select, ignore) : x
end
function paramfieldnames(x; select=SELECT, ignore=IGNORE)
Flatten.fieldnameflatten(x, select, ignore)
end
function paramparenttypes(x; select=SELECT, ignore=IGNORE)
Flatten.metaflatten(x, _fieldparentbasetype, select, ignore)
end


# Utils
hasparam(obj) = length(params(obj)) > 0
hasparam(obj; kw...) = length(params(obj; kw...)) > 0

_checkhasval(nt::NamedTuple{Keys}) where {Keys} = first(Keys) == :val || _novalerror(nt)
# @noinline avoids allocations unless there is actually an error
Expand Down
8 changes: 4 additions & 4 deletions src/tables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ _columntypes(m) = map(k -> Union{map(typeof, getindex(m, k))...}, keys(m))
# As Columns
Tables.columnaccess(::Type{<:AbstractModel}) = true
Tables.columns(m::AbstractModel) = m
Tables.getcolumn(m::AbstractModel, nm::Symbol) = collect(getindex(m, nm))
Tables.getcolumn(m::AbstractModel, i::Int) = collect(getindex(m, i))
Tables.getcolumn(m::AbstractModel, ::Type{T}, col::Int, nm::Symbol) where T =
collect(getindex(m, nm))
Tables.getcolumn(m::AbstractModel, nm::Symbol; kw...) = collect(getindex(m, nm; kw...))
Tables.getcolumn(m::AbstractModel, i::Int; kw...) = collect(getindex(m, i; kw...))
Tables.getcolumn(m::AbstractModel, ::Type{T}, col::Int, nm::Symbol; kw...) where T =
collect(getindex(m, nm; kw...))
25 changes: 25 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,28 @@ end
@test isa(tuplegroups.A.i, Tuple)
@test isa(tuplegroups.B.j, Tuple)
end

struct FreeParam{T,P<:NamedTuple} <: AbstractParam{T}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're going to do this, then the macro to abstract away this boilerplate (and reduce chance of user error) should be included in this PR, I think.

parent::P
end
FreeParam(nt::NT) where {NT<:NamedTuple} = begin
ModelParameters._checkhasval(nt)
FreeParam{typeof(nt.val),NT}(nt)
end
FreeParam(val; kwargs...) = FreeParam((; val=val, kwargs...))
FreeParam(; kwargs...) = FreeParam((; kwargs...))

@testset "spefify param type" begin
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you seem to agree that the FreeParam/FixedParam is not a good use-case for this feature, I would like to see an example use-case where this feature actually is actually necessary in contrast to using table manipulations, i.e. something where simply adding an extra field to Param wouldn't cut it or would be prohibitively difficult. I can't really think of one off the top of my head, but maybe you can.

sf = S2(
FreeParam(99),
FreeParam(7),
Param(100.0; bounds=(50.0, 150.0))
)
@test params(sf; select=FreeParam) == (FreeParam(99), FreeParam(7))
m = Model(sf)
@test m[:val, select=Param] == (100.0,)
@test getindex(m, :val; select=FreeParam) == (99, 7)
m[:bounds, select=FreeParam] = ((1, 100), (1, 10))
@test m[:bounds] == ((1, 100), (1, 10), (50.0, 150.0))
end