Skip to content

Commit 6ef2052

Browse files
committed
allow more Param types
1 parent e64c6c7 commit 6ef2052

File tree

5 files changed

+102
-63
lines changed

5 files changed

+102
-63
lines changed

src/ModelParameters.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import AbstractNumbers,
1414
Tables
1515

1616
using Setfield
17+
18+
using Base: tail
1719

1820
export AbstractModel, Model, StaticModel
1921

src/model.jl

Lines changed: 54 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -83,17 +83,17 @@ abstract type AbstractModel end
8383
Base.parent(m::AbstractModel) = getfield(m, :parent)
8484
setparent(m::AbstractModel, newparent) = @set m.parent = newparent
8585

86-
params(m::AbstractModel) = params(parent(m))
87-
stripparams(m::AbstractModel) = stripparams(parent(m))
86+
params(m::AbstractModel; kw...) = params(parent(m); kw...)
87+
stripparams(m::AbstractModel; kw...) = stripparams(parent(m); kw...)
88+
paramfieldnames(m::AbstractModel; kw...) = paramfieldnames(parent(m); kw...)
89+
paramparenttypes(m::AbstractModel; kw...) = paramparenttypes(parent(m); kw...)
8890
function update(x::T, values) where {T<:AbstractModel}
8991
hasfield(T, :parent) || _updatenotdefined(T)
9092
setparent(x, update(parent(x), values))
9193
end
9294

9395
@noinline _update_methoderror(T) = error("Interface method `update` is not defined for $T")
9496

95-
paramfieldnames(m) = Flatten.fieldnameflatten(parent(m), SELECT, IGNORE)
96-
paramparenttypes(m) = Flatten.metaflatten(parent(m), _fieldparentbasetype, SELECT, IGNORE)
9797

9898
_fieldparentbasetype(T, ::Type{Val{N}}) where N = T.name.wrapper
9999

@@ -102,64 +102,64 @@ _fieldparentbasetype(T, ::Type{Val{N}}) where N = T.name.wrapper
102102

103103
# It may seem expensive always calling `param`, but flattening the
104104
# object occurs once at compile-time, and should have very little cost here.
105-
Base.length(m::AbstractModel) = length(params(m))
106-
Base.size(m::AbstractModel) = (length(params(m)),)
107-
Base.first(m::AbstractModel) = first(params(m))
108-
Base.last(m::AbstractModel) = last(params(m))
105+
Base.length(m::AbstractModel; kw...) = length(params(m; kw...))
106+
Base.size(m::AbstractModel; kw...) = (length(params(m; kw...)),)
107+
Base.first(m::AbstractModel; kw...) = first(params(m; kw...))
108+
Base.last(m::AbstractModel; kw...) = last(params(m; kw...))
109109
Base.firstindex(m::AbstractModel) = 1
110-
Base.lastindex(m::AbstractModel) = length(params(m))
110+
Base.lastindex(m::AbstractModel; kw...) = length(params(m; kw...))
111111
Base.getindex(m::AbstractModel, i) = getindex(params(m), i)
112112
Base.iterate(m::AbstractModel) = (first(params(m)), 1)
113113
Base.iterate(m::AbstractModel, s) = s > length(m) ? nothing : (params(m)[s], s + 1)
114114

115115
# Vector methods
116-
Base.collect(m::AbstractModel) = collect(m.val)
117-
Base.vec(m::AbstractModel) = collect(m)
118-
Base.Array(m::AbstractModel) = vec(m)
116+
Base.collect(m::AbstractModel; kw...) = collect(m[:val; kw...])
117+
Base.vec(m::AbstractModel; kw...) = collect(m; kw...)
118+
Base.Array(m::AbstractModel; kw...) = vec(m; kw...)
119119

120120
# Dict methods - data as columns
121121
Base.haskey(m::AbstractModel, key::Symbol) = key in keys(m)
122122
Base.keys(m::AbstractModel) = _keys(params(m), m)
123123

124-
@inline function Base.setindex!(m::AbstractModel, x, nm::Symbol)
124+
@inline function Base.setindex!(m::AbstractModel, x, nm::Symbol; kw...)
125125
if nm == :component
126126
erorr("cannot set :component index")
127127
elseif nm == :fieldname
128128
erorr("cannot set :fieldname index")
129129
else
130130
newparent = if nm in keys(m)
131-
_setindex(parent(m), Tuple(x), nm)
131+
_setindex(parent(m), Tuple(x), nm; kw...)
132132
else
133-
_addindex(parent(m), Tuple(x), nm)
133+
_addindex(parent(m), Tuple(x), nm; kw...)
134134
end
135135
setparent!(m, newparent)
136136
end
137137
end
138138
# TODO do this with lenses
139-
@inline function _setindex(obj, xs::Tuple, nm::Symbol)
139+
@inline function _setindex(obj, xs::Tuple, nm::Symbol; select=SELECT, ignore=IGNORE)
140140
lens = Setfield.PropertyLens{nm}()
141-
newparams = map(params(obj), xs) do par, x
142-
Param(Setfield.set(parent(par), lens, x))
141+
newparams = map(params(obj; select, ignore), xs) do p, x
142+
setparent(p, Setfield.set(parent(p), lens, x))
143143
end
144-
Flatten.reconstruct(obj, newparams, SELECT, IGNORE)
144+
Flatten.reconstruct(obj, newparams, select, ignore)
145145
end
146-
@inline function _addindex(obj, xs::Tuple, nm::Symbol)
147-
newparams = map(params(obj), xs) do par, x
148-
Param((; parent(par)..., (nm => x,)...))
146+
@inline function _addindex(obj, xs::Tuple, nm::Symbol; select=SELECT, ignore=IGNORE)
147+
newparams = map(params(obj), xs) do p, x
148+
setparent(p, (; parent(p)..., (nm => x,)...))
149149
end
150-
Flatten.reconstruct(obj, newparams, SELECT, IGNORE)
150+
Flatten.reconstruct(obj, newparams, select, ignore)
151151
end
152152

153153
_keys(params::Tuple, m::AbstractModel) = (:component, :fieldname, keys(first(params))...)
154154
_keys(params::Tuple{}, m::AbstractModel) = ()
155155

156-
@inline function Base.getindex(m::AbstractModel, nm::Symbol)
156+
@inline function Base.getindex(m::AbstractModel, nm::Symbol; kw...)
157157
if nm == :component
158-
paramparenttypes(m)
158+
paramparenttypes(m; kw...)
159159
elseif nm == :fieldname
160-
paramfieldnames(m)
160+
paramfieldnames(m; kw...)
161161
else
162-
map(p -> getindex(p, nm), params(m))
162+
map(p -> getindex(p, nm), params(m; kw...))
163163
end
164164
end
165165

@@ -181,16 +181,15 @@ end
181181

182182
setparent!(m::AbstractModel, newparent) = setfield!(m, :parent, newparent)
183183

184-
update!(m::AbstractModel, vals::AbstractVector{<:AbstractParam}) = update!(m, Tuple(vals))
185-
function update!(params::Tuple{<:AbstractParam,Vararg{<:AbstractParam}})
186-
setparent!(m, Flatten.reconstruct(parent(m), params, SELECT, IGNORE))
184+
function update!(m::AbstractModel, vals::AbstractVector{<:AbstractParam}; kw...)
185+
update!(m, Tuple(vals); kw...)
187186
end
188-
function update!(m::AbstractModel, table)
187+
function update!(m::AbstractModel, table; kw...)
189188
cols = (c for c in Tables.columnnames(table) if !(c in (:component, :fieldname)))
190189
for col in cols
191-
setindex!(m, Tables.getcolumn(table, col), col)
190+
setindex!(m, Tables.getcolumn(table, col), col; kw...)
192191
end
193-
m
192+
return m
194193
end
195194

196195
"""
@@ -220,30 +219,34 @@ mutable struct Model <: AbstractModel
220219
end
221220
Model(m::AbstractModel) = Model(parent(m))
222221

223-
@inline @generated function _update_params(ps::P, values::Union{<:AbstractVector,<:Tuple}) where {N,P<:NTuple{N,Param}}
222+
const ParamTuple = Tuple{Vararg{<:AbstractParam}}
223+
224+
@inline @generated function _update_params(ps::P, values::Union{<:AbstractVector,<:Tuple}) where {P<:ParamTuple}
224225
expr = Expr(:tuple)
225-
for i in 1:N
226-
expr_i = :(Param(NamedTuple{keys(ps[$i])}((values[$i], Base.tail(parent(ps[$i]))...))))
226+
for i in 1:length(ps.parameters)
227+
expr_i = :(setparent(ps[$i], NamedTuple{keys(ps[$i])}((values[$i], tail(parent(ps[$i]))...))))
227228
push!(expr.args, expr_i)
228229
end
229230
return expr
230231
end
231232

232-
update(x, values) = _update(ModelParameters.params(x), x, values)
233-
@inline function _update(p::P, x, values::Union{<:AbstractVector,<:Tuple}) where {N,P<:NTuple{N,Param}}
234-
@assert length(values) == N "values length must match the number of parameters"
235-
newparams = _update_params(p, values)
236-
Flatten.reconstruct(x, newparams, SELECT, IGNORE)
233+
update(x, values; kw...) = _update(ModelParameters.params(x), x, values; kw...)
234+
@inline function _update(ps::P, x, values::Union{<:AbstractVector,<:Tuple};
235+
select=SELECT, ignore=IGNORE
236+
) where {P<:ParamTuple}
237+
@assert length(values) == length(ps) "values length must match the number of parameters"
238+
newparams = _update_params(ps, values)
239+
return Flatten.reconstruct(x, newparams, select, ignore)
237240
end
238-
@inline function _update(p::P, x, table) where {N,P<:NTuple{N,Param}}
239-
@assert size(table, 1) == N "number of rows must match the number of parameters"
241+
@inline function _update(ps::P, x, table; select=SELECT, ignore=IGNORE) where {P<:ParamTuple}
242+
@assert size(table, 1) == length(ps) "number of rows must match the number of parameters"
240243
colnames = (c for c in Tables.columnnames(table) if !(c in (:component, :fieldname)))
241-
newparams = map(p, tuple(1:N...)) do param, i
242-
let names = keys(param);
243-
Param(NamedTuple{names}(map(name -> Tables.getcolumn(table, name)[i], names)))
244+
newparams = map(ps, tuple(1:length(ps)...)) do p, i
245+
let names = keys(p);
246+
setparent(p, NamedTuple{names}(map(name -> Tables.getcolumn(table, name)[i], names)))
244247
end
245248
end
246-
Flatten.reconstruct(x, newparams, SELECT, IGNORE)
249+
return Flatten.reconstruct(x, newparams, select, ignore)
247250
end
248251

249252
"""
@@ -270,17 +273,17 @@ StaticModel(m::AbstractModel) = StaticModel(parent(m))
270273

271274
# Model Utils
272275

273-
_expandpars(x) = Flatten.reconstruct(parent, _expandkeys(parent), SELECT, IGNORE)
276+
_expandpars(x) = Flatten.reconstruct(parent, _expandkeys(parent), select, ignore)
274277
# Expand all Params to have the same keys, filling with `nothing`
275278
# This probably will allocate due to `union` returning `Vector`
276279
function _expandkeys(x)
277280
pars = params(x)
278281
allkeys = Tuple(union(map(keys, pars)...))
279-
return map(pars) do par
282+
return map(pars) do p
280283
vals = map(allkeys) do key
281-
get(par, key, nothing)
284+
get(p, key, nothing)
282285
end
283-
Param(NamedTuple{allkeys}(vals))
286+
setparent(p, NamedTuple{allkeys}(vals))
284287
end
285288
end
286289

src/param.jl

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,18 @@ function ConstructionBase.setproperties(p::P, patch::NamedTuple) where P <: Abst
1515
P.name.wrapper(fields)
1616
end
1717

18-
@inline withunits(m, args...) = map(p -> withunits(p, args...), params(m))
18+
@inline withunits(x, args...; kw...) = map(p -> withunits(p, args...), params(x; kw...))
1919
@inline function withunits(p::AbstractParam, fn::Symbol=:val)
2020
_applyunits(*, getproperty(p, fn), get(p, :units, nothing))
2121
end
2222

23-
@inline stripunits(m, xs) = map(stripunits, params(m), xs)
23+
@inline stripunits(x, xs; kw...) = map(stripunits, params(x; kw...), xs)
2424
@inline function stripunits(p::AbstractParam, x)
2525
_applyunits(/, x, get(p, :units, nothing))
2626
end
2727

28+
setparent(m::P, newparent) where P<:AbstractParam = P.name.wrapper(newparent)
29+
2830
# Param might have `nothing` for units
2931
@inline _applyunits(f, x, units) = f(x, units)
3032
@inline _applyunits(f, x, ::Nothing) = x
@@ -42,9 +44,11 @@ Base.values(p::AbstractParam) = values(parent(p))
4244
@inline Base.getproperty(p::AbstractParam, x::Symbol) = getproperty(parent(p), x)
4345
@inline Base.get(p::AbstractParam, key::Symbol, default) = get(parent(p), key, default)
4446
@inline Base.getindex(p::AbstractParam, i) = getindex(parent(p), i)
47+
Base.parent(p::AbstractParam) = getfield(p, :parent)
4548

4649

4750
# AbstractNumber interface
51+
4852
Base.convert(::Type{Number}, x::AbstractParam) = AbstractNumbers.number(x)
4953
Base.convert(::Type{P}, x::P) where {P<:AbstractParam} = x
5054
AbstractNumbers.number(p::AbstractParam) = withunits(p)
@@ -79,15 +83,20 @@ end
7983
Param(val; kwargs...) = Param((; val=val, kwargs...))
8084
Param(; kwargs...) = Param((; kwargs...))
8185

82-
Base.parent(p::Param) = getfield(p, :parent)
83-
84-
# Methods for objects that hold params
85-
params(x) = Flatten.flatten(x, SELECT, IGNORE)
86-
stripparams(x) = hasparam(x) ? Flatten.reconstruct(x, withunits(x), SELECT, IGNORE) : x
86+
params(x; select=SELECT, ignore=IGNORE) = Flatten.flatten(x, select, ignore)
87+
function stripparams(x; select=SELECT, ignore=IGNORE)
88+
hasparam(x) ? Flatten.reconstruct(x, withunits(x), select, ignore) : x
89+
end
90+
function paramfieldnames(x; select=SELECT, ignore=IGNORE)
91+
Flatten.fieldnameflatten(x, select, ignore)
92+
end
93+
function paramparenttypes(x; select=SELECT, ignore=IGNORE)
94+
Flatten.metaflatten(x, _fieldparentbasetype, select, ignore)
95+
end
8796

8897

8998
# Utils
90-
hasparam(obj) = length(params(obj)) > 0
99+
hasparam(obj; kw...) = length(params(obj; kw...)) > 0
91100

92101
_checkhasval(nt::NamedTuple{Keys}) where {Keys} = first(Keys) == :val || _novalerror(nt)
93102
# @noinline avoids allocations unless there is actually an error

src/tables.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ _columntypes(m) = map(k -> Union{map(typeof, getindex(m, k))...}, keys(m))
1010
# As Columns
1111
Tables.columnaccess(::Type{<:AbstractModel}) = true
1212
Tables.columns(m::AbstractModel) = m
13-
Tables.getcolumn(m::AbstractModel, nm::Symbol) = collect(getindex(m, nm))
14-
Tables.getcolumn(m::AbstractModel, i::Int) = collect(getindex(m, i))
15-
Tables.getcolumn(m::AbstractModel, ::Type{T}, col::Int, nm::Symbol) where T =
16-
collect(getindex(m, nm))
13+
Tables.getcolumn(m::AbstractModel, nm::Symbol; kw...) = collect(getindex(m, nm; kw...))
14+
Tables.getcolumn(m::AbstractModel, i::Int; kw...) = collect(getindex(m, i; kw...))
15+
Tables.getcolumn(m::AbstractModel, ::Type{T}, col::Int, nm::Symbol; kw...) where T =
16+
collect(getindex(m, nm; kw...))
1717

1818
# As rows
1919
Tables.rowaccess(::Type{<:AbstractModel}) = true

test/runtests.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,28 @@ end
232232
@test isa(tuplegroups.A.i, Tuple)
233233
@test isa(tuplegroups.B.j, Tuple)
234234
end
235+
236+
struct FreeParam{T,P<:NamedTuple} <: AbstractParam{T}
237+
parent::P
238+
end
239+
FreeParam(nt::NT) where {NT<:NamedTuple} = begin
240+
ModelParameters._checkhasval(nt)
241+
FreeParam{typeof(nt.val),NT}(nt)
242+
end
243+
FreeParam(val; kwargs...) = FreeParam((; val=val, kwargs...))
244+
FreeParam(; kwargs...) = FreeParam((; kwargs...))
245+
246+
@testset "spefify param type" begin
247+
sf = S2(
248+
FreeParam(99),
249+
FreeParam(7),
250+
Param(100.0; bounds=(50.0, 150.0))
251+
)
252+
@test params(sf; select=FreeParam) == (FreeParam(99), FreeParam(7))
253+
m = Model(sf)
254+
@test m[:val, select=Param] == (100.0,)
255+
@test getindex(m, :val; select=FreeParam) == (99, 7)
256+
m[:bounds, select=FreeParam] = ((1, 100), (1, 10))
257+
@test m[:bounds] == ((1, 100), (1, 10), (50.0, 150.0))
258+
end
259+

0 commit comments

Comments
 (0)