Skip to content

Commit

Permalink
Add Rasters.sample (#791)
Browse files Browse the repository at this point in the history
* extract tweaks

* add rasters.sample

* get rid of some toy code

* just write kw... in docs

Co-authored-by: Rafael Schouten <[email protected]>

* use Type{G}

Co-authored-by: Rafael Schouten <[email protected]>

* add a missing where G

* neater formatting

* better reuse _rowtypes code

* more code style

Co-authored-by: Rafael Schouten <[email protected]>

* extract tweaks

* add rasters.sample

* get rid of some toy code

* neater formatting

* better reuse _rowtypes code

* more code style

Co-authored-by: Rafael Schouten <[email protected]>

* merge _rowtype changes from other branch

* optimization for extract points with skipmissing

* streamline _rowtype

* correctly use _rowtype

* improve type stability

* switch to using dimindices internally

* let users specify geometry type

* just use a broadcast

* don't need to be compatible with julia < 1.9

Co-authored-by: Rafael Schouten <[email protected]>

* deal with type stability better

* allow not providing `n`

* tweak the docstring

---------

Co-authored-by: Rafael Schouten <[email protected]>
  • Loading branch information
tiemvanderdeure and rafaqz authored Oct 16, 2024
1 parent 3202707 commit 1f00143
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 38 deletions.
7 changes: 6 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
Proj = "c94c279d-25a6-4763-9509-64d165bea63e"
RasterDataSources = "3cb90ccd-e1b6-4867-9617-4276c8b2ca36"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
ZarrDatasets = "519a4cdf-1362-424a-9ea1-b1d782dbb24b"

[extensions]
Expand All @@ -44,6 +45,7 @@ RastersMakieExt = "Makie"
RastersNCDatasetsExt = "NCDatasets"
RastersProjExt = "Proj"
RastersRasterDataSourcesExt = "RasterDataSources"
RastersStatsBaseExt = "StatsBase"
RastersZarrDatasetsExt = "ZarrDatasets"

[compat]
Expand Down Expand Up @@ -81,6 +83,7 @@ SafeTestsets = "0.1"
Setfield = "0.6, 0.7, 0.8, 1"
Shapefile = "0.10, 0.11"
Statistics = "1"
StatsBase = "0.34"
Test = "1"
ZarrDatasets = "0.1"
julia = "1.10"
Expand All @@ -101,8 +104,10 @@ Proj = "c94c279d-25a6-4763-9509-64d165bea63e"
RasterDataSources = "3cb90ccd-e1b6-4867-9617-4276c8b2ca36"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Shapefile = "8e980c4a-a4fe-5da2-b3a7-4b4b0353a2f4"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "ArchGDAL", "CFTime", "CoordinateTransformations", "DataFrames", "GeoDataFrames", "GeometryBasics", "GRIBDatasets", "NCDatasets", "Plots", "Proj", "RasterDataSources", "SafeTestsets", "Shapefile", "Statistics", "Test", "ZarrDatasets"]
test = ["Aqua", "ArchGDAL", "CFTime", "CoordinateTransformations", "DataFrames", "GeoDataFrames", "GeometryBasics", "GRIBDatasets", "NCDatasets", "Plots", "Proj", "RasterDataSources", "SafeTestsets", "Shapefile", "StableRNGs", "Statistics", "StatsBase", "Test", "ZarrDatasets"]
13 changes: 13 additions & 0 deletions ext/RastersStatsBaseExt/RastersStatsBaseExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Package extension module for Rasters that depends on StatsBase
# (registered in Project.toml as `RastersStatsBaseExt = "StatsBase"`).
module RastersStatsBaseExt

using Rasters, StatsBase
using StatsBase.Random

# Short alias used to qualify Rasters names throughout this extension.
const RA = Rasters

import Rasters: _True, _False, _booltype, istrue
import Rasters.DimensionalData as DD

# Defines the `Rasters.sample` methods and their private helpers.
include("sample.jl")

end # Module
94 changes: 94 additions & 0 deletions ext/RastersStatsBaseExt/sample.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Convenience method for callers that do not supply a random number generator:
# forward to the `rng`-taking method using the default RNG.
# `Random.default_rng()` is the documented accessor for the default generator;
# the legacy `Random.GLOBAL_RNG` binding is avoided.
Rasters.sample(x::RA.RasterStackOrArray, args...; kw...) =
    Rasters.sample(Random.default_rng(), x, args...; kw...)
# User-facing entry point with an explicit RNG.
# Normalises the keywords (layer names, geometry specification, boolean flags)
# and hands everything to `_sample`, which does the actual work.
@inline function Rasters.sample(
    rng::Random.AbstractRNG, x::RA.RasterStackOrArray, args...;
    geometry=(X,Y),
    index = false,
    names=RA._names(x),
    name=names,
    skipmissing=false,
    weights=nothing,
    weightstype::Type{<:StatsBase.AbstractWeights}=StatsBase.Weights,
    kw...
)
    # `name` may be a single Symbol or a Tuple of Symbols; work with a tuple.
    layernames = DD._astuple(name)
    # Resolve the geometry spec into a flag, a geometry type and the dims to use.
    usegeometry, G, geomdims = _geometrytype(x, geometry)

    return _sample(rng, x, args...;
        dims = geomdims,
        names = NamedTuple{layernames}(layernames),
        geometry = usegeometry,
        geometrytype = G,
        # Plain Bools become _True/_False singletons so later code can stay type stable
        index = _booltype(index),
        skipmissing = _booltype(skipmissing),
        weights = weights,
        weightstype = weightstype,
        kw... # forwarded on towards StatsBase.sample (e.g. `replace`, `ordered`)
    )
end
# Draw `n` rows from `x`.
# Samples `n` indices, computes the row type, wraps `x` as a stack holding
# exactly the layers `K`, then builds one row per sampled index.
function _sample(
    rng, x, n::Integer;
    dims, names::NamedTuple{K}, geometry, geometrytype::Type{G}, index, skipmissing,
    weights, weightstype, kw...,
) where {K, G}
    sampled = sample_indices(rng, x, skipmissing, weights, weightstype, n; kw...)
    rowtype = RA._rowtype(x, G; geometry, index, skipmissing, skipinvalid = _True(), names)
    # A single raster is wrapped in a one-layer stack so both cases share one code path
    stack = if x isa AbstractRasterStack
        x[K]
    else
        RasterStack(NamedTuple{K}((x,)))
    end
    return _getindices(rowtype, stack, dims, sampled)
end
# Draw a single row from `x` (no `n` given).
# Same steps as the `n::Integer` method, but one index is sampled and one row is
# returned. Extra keywords in `kw` are accepted but not forwarded to `sample_indices`.
function _sample(
    rng, x;
    dims, names::NamedTuple{K}, geometry, geometrytype::Type{G}, index, skipmissing,
    weights, weightstype, kw...,
) where {K, G}
    sampled = sample_indices(rng, x, skipmissing, weights, weightstype)
    rowtype = RA._rowtype(x, G; geometry, index, skipmissing, skipinvalid = _True(), names)
    # A single raster is wrapped in a one-layer stack so both cases share one code path
    stack = if x isa AbstractRasterStack
        x[K]
    else
        RasterStack(NamedTuple{K}((x,)))
    end
    return _getindex(rowtype, stack, dims, sampled)
end

# Build one row of type `T` for every entry of `indices` by broadcasting
# `_getindex` over them.
function _getindices(::Type{T}, x, dims, indices) where {T}
    row(I) = _getindex(T, x, dims, I)
    return row.(indices)
end

# Build a single row of type `T` for the sampled index `idx`:
# the layer values of `x` at `idx`, the point of `dims` at `idx`, and the raw index.
function _getindex(::Type{T}, x::AbstractRasterStack{<:Any, NT}, dims, idx) where {T, NT}
    # Layer values at the sampled position, converted to the stack's NamedTuple type
    props = NT(x[RA.commondims(idx, x)])
    # Coordinates of the sampled position along the requested dimensions
    point = DimPoints(dims)[RA.commondims(idx, dims)]
    I = val(idx)
    return RA._maybe_add_fields(T, props, point, I)
end

# Sample indices of `x` when the user supplied no weights.
# `args` is either empty (a single draw) or holds the number of draws.
function sample_indices(rng, x, skipmissing, weights::Nothing, weightstype, args...; kw...)
    candidates = RA.DimIndices(x)
    # Without skipmissing every index is equally likely
    istrue(skipmissing) || return StatsBase.sample(rng, candidates, args...; kw...)
    # Otherwise weight by the boolean mask so masked-out cells get weight zero
    maskweights = StatsBase.Weights(vec(boolmask(x)))
    return StatsBase.sample(rng, candidates, maskweights, args...; kw...)
end
# Sample indices of `x` using user supplied `weights`.
# The weights are combined with the boolean mask when `skipmissing` is set, used
# as-is when their dims equal those of `x`, and otherwise expanded to the dims of
# `x`. They are then flattened and wrapped in the weights type `W`.
function sample_indices(rng, x, skipmissing, weights::AbstractDimArray, ::Type{W}, args...; kw...) where W
    expanded = if istrue(skipmissing)
        @d boolmask(x) .* weights
    elseif dims(weights) == dims(x)
        weights
    else
        @d ones(eltype(weights), dims(x)) .* weights
    end
    wts = W(vec(expanded))
    return StatsBase.sample(rng, RA.DimIndices(x), wts, args...; kw...)
end
# Resolve a `Bool` geometry specification.
# `true` is rejected because it does not say which dimensions to use;
# `false` means no geometry: returns `_False()`, `Nothing` and all dims of `x`.
function _geometrytype(x, geometry::Bool)
    geometry && error("Specify a geometry type by setting `geometry` to a Tuple or NamedTuple of Dimensions. E.g. `geometry = (X, Y)`")
    return _False(), Nothing, dims(x)
end

# Resolve a `Tuple` geometry specification.
# Returns `_True()`, a `Tuple` type built from the eltypes of the dims that `x`
# shares with `geometry`, and those dims.
function _geometrytype(x, geometry::Tuple)
    geomdims = DD.commondims(DD.dims(x), geometry)
    G = Tuple{map(eltype, geomdims)...}
    return _True(), G, geomdims
end
# Resolve a `NamedTuple` geometry specification.
# Like the `Tuple` method, but the returned geometry type is a `NamedTuple`
# carrying the keys `K` given by the user.
function _geometrytype(x, geometry::NamedTuple{K}) where K
    geomdims = DD.commondims(DD.dims(x), values(geometry))
    G = NamedTuple{K, Tuple{map(eltype, geomdims)...}}
    return _True(), G, geomdims
end
44 changes: 44 additions & 0 deletions src/extensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,50 @@ Note that cellarea returns the area in square m, while cellsize still uses squar
return cellarea(args...; kw..., radius = 6371.0088)
end

"""
    Rasters.sample([rng], x, [n::Integer]; kw...)

Sample `n` random and optionally weighted points from a `Raster` or `RasterStack`.
Returns a `Vector` of `NamedTuple`, closely resembling the return type of [`extract`](@ref).
When `n` is omitted, a single point is sampled and returned as one `NamedTuple`
rather than a `Vector`.

Run `using StatsBase` to make this method available.
Note that this function is not exported to avoid confusion with `StatsBase.sample`.

# Keywords

- `geometry`: include `:geometry` in returned `NamedTuple`. Specify the type and dimensions of the returned geometry by
    providing a `Tuple` or `NamedTuple` of dimensions. Defaults to `(X,Y)`.
- `index`: include `:index` of the `CartesianIndex` in returned `NamedTuple`, `false` by default.
- `name`: a `Symbol` or `Tuple` of `Symbol` corresponding to layer/s of a `RasterStack` to extract. All layers by default.
- `skipmissing`: skip missing points automatically.
- `weights`: A DimArray that matches one or more of the dimensions of `x` with weights for sampling.
- `weightstype`: a `StatsBase.AbstractWeights` specifying the type of weights. Defaults to `StatsBase.Weights`.
- `replace`: sample with replacement, `true` by default. See `StatsBase.sample`.
- `ordered`: sample in order, `false` by default. See `StatsBase.sample`.

# Example

This code draws 5 random points from a raster, weighted by cell area.

```julia
using Rasters, Rasters.Lookups, Proj, StatsBase
xdim = X(Projected(90.0:10.0:120; sampling=Intervals(Start()), crs=EPSG(4326)))
ydim = Y(Projected(0.0:10.0:50; sampling=Intervals(Start()), crs=EPSG(4326)))
myraster = rand(xdim, ydim)
Rasters.sample(myraster, 5; weights = cellarea(myraster))

# output

5-element Vector{@NamedTuple{geometry::Tuple{Float64, Float64}, ::Union{Missing, Float64}}}:
@NamedTuple{geometry::Tuple{Float64, Float64}, ::Union{Missing, Float64}}(((90.0, 10.0), 0.7360504790189618))
@NamedTuple{geometry::Tuple{Float64, Float64}, ::Union{Missing, Float64}}(((90.0, 30.0), 0.5447657183842469))
@NamedTuple{geometry::Tuple{Float64, Float64}, ::Union{Missing, Float64}}(((90.0, 30.0), 0.5447657183842469))
@NamedTuple{geometry::Tuple{Float64, Float64}, ::Union{Missing, Float64}}(((90.0, 10.0), 0.7360504790189618))
@NamedTuple{geometry::Tuple{Float64, Float64}, ::Union{Missing, Float64}}(((110.0, 10.0), 0.5291143028176258))
```
"""
sample(args...; kw...) = throw_extension_error(sample, "StatsBase", :RastersStatsBaseExt, args)



# Other shared stubs
function layerkeys end
function smapseries end
Expand Down
43 changes: 8 additions & 35 deletions src/methods/extract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ function extract end
end

function _extract(A::RasterStackOrArray, geom::Missing, names, kw...)
T = _extractrowtype(A, geom; names, kw...)
T = _rowtype(A, geom; names, kw...)
[_maybe_add_fields(T, map(_ -> missing, names), missing, missing)]
end
function _extract(A::RasterStackOrArray, geom; names, kw...)
Expand All @@ -82,11 +82,7 @@ function _extract(A::RasterStackOrArray, ::Nothing, data;
names, skipmissing, geometrycolumn, kw...
)
geoms = _get_geometries(data, geometrycolumn)
T = if istrue(skipmissing)
_extractrowtype(A, nonmissingtype(eltype(geoms)); names, skipmissing, kw...)
else
_extractrowtype(A, eltype(geoms); names, skipmissing, kw...)
end
T = _rowtype(A, eltype(geoms); names, skipmissing, kw...)
# Handle empty / all missing cases
(length(geoms) > 0 && any(!ismissing, geoms)) || return T[]

Expand All @@ -95,25 +91,21 @@ function _extract(A::RasterStackOrArray, ::Nothing, data;
# We need to split out points from other geoms
# TODO this will fail with mixed point/geom vectors
if trait1 isa GI.PointTrait
rows = Vector{T}(undef, length(geoms))
if istrue(skipmissing)
T2 = _extractrowtype(A, eltype(geoms), Tuple{Int64, Int64};
names, skipmissing = _False(), kw...)
rows = Vector{T2}(undef, length(geoms))
j = 1
for i in eachindex(geoms)
g = geoms[i]
ismissing(g) && continue
e = _extract_point(T2, A, g, skipmissing; names, kw...)
e = _extract_point(T, A, g, skipmissing; names, kw...)
if !ismissing(e)
rows[j] = e
j += 1
end
nothing
end
deleteat!(rows, j:length(rows))
rows = T === T2 ? rows : T.(rows)
else
rows = Vector{T}(undef, length(geoms))
for i in eachindex(geoms)
g = geoms[i]
rows[i] = _extract_point(T, A, g, skipmissing; names, kw...)::T
Expand Down Expand Up @@ -141,14 +133,14 @@ end
function _extract(A::RasterStackOrArray, ::GI.AbstractMultiPointTrait, geom;
skipmissing, kw...
)
T = _extractrowtype(A, GI.getpoint(geom, 1); names, skipmissing, kw...)
T = _rowtype(A, GI.getpoint(geom, 1); names, skipmissing, kw...)
rows = (_extract_point(T, A, p, skipmissing; kw...) for p in GI.getpoint(geom))
return skipmissing isa _True ? collect(_skip_missing_rows(rows, _missingval_or_missing(A), names)) : collect(rows)
end
function _extract(A::RasterStackOrArray, ::GI.PointTrait, geom;
skipmissing, kw...
)
T = _extractrowtype(A, geom; names, skipmissing, kw...)
T = _rowtype(A, geom; names, skipmissing, kw...)
_extract_point(T, A, geom, skipmissing; kw...)
end
function _extract(A::RasterStackOrArray, t::GI.AbstractGeometryTrait, geom;
Expand All @@ -166,7 +158,7 @@ function _extract(A::RasterStackOrArray, t::GI.AbstractGeometryTrait, geom;
else
GI.x(p), GI.y(p)
end
T = _extractrowtype(A, tuplepoint; names, skipmissing, kw...)
T = _rowtype(A, tuplepoint; names, skipmissing, kw...)
B = boolmask(geom; to=template, kw...)
offset = CartesianIndex(map(x -> first(x) - 1, parentindices(parent(template))))
# Add a row for each pixel that is `true` in the mask
Expand Down Expand Up @@ -285,27 +277,8 @@ Base.@assume_effects :total function _maybe_add_fields(::Type{T}, props::NamedTu
:index in K ? merge((; geometry=point, index=I), props) : merge((; geometry=point), props)
else
:index in K ? merge((; index=I), props) : props
end
end

function _extractrowtype(x, g; geometry, index, skipmissing, names, kw...)
I = if istrue(skipmissing)
Tuple{Int, Int}
else
Union{Missing, Tuple{Int, Int}}
end
_extractrowtype(x, g, I; geometry, index, skipmissing, names, kw...)
end
function _extractrowtype(x, g, ::Type{I}; geometry, index, skipmissing, names, kw...) where I
G = if istrue(skipmissing)
nonmissingtype(typeof(g))
else
typeof(g)
end
_extractrowtype(x, G, I; geometry, index, skipmissing, names, kw...)
end |> T
end
_extractrowtype(x, ::Type{G}, ::Type{I}; geometry, index, skipmissing, names, kw...) where {G, I} =
_rowtype(x, G, I; geometry, index, skipmissing, names)

@inline _skip_missing_rows(rows, ::Missing, names) =
Iterators.filter(row -> !any(ismissing, row), rows)
Expand Down
12 changes: 10 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -372,12 +372,20 @@ _booltype(x) = x ? _True() : _False()
# Convert the `_True`/`_False` singleton flags back into a plain `Bool`.
istrue(::_True) = true
istrue(::_False) = false

function _rowtype(x, ::Type{G}, ::Type{I}; geometry, index, skipmissing, names) where {G, I}
# skipinvalid: can G and I be missing. skipmissing: can nametypes be missing
_rowtype(x, g, args...; kw...) = _rowtype(x, typeof(g), args...; kw...)
function _rowtype(
x, ::Type{G}, i::Type{I} = typeof(size(x));
geometry, index, skipmissing, skipinvalid = skipmissing, names, kw...
) where {G, I}
_G = istrue(skipinvalid) ? nonmissingtype(G) : G
_I = istrue(skipinvalid) ? I : Union{Missing, I}
keys = _rowkeys(geometry, index, names)
types = _rowtypes(x, G, I, geometry, index, skipmissing, names)
types = _rowtypes(x, _G, _I, geometry, index, skipmissing, names)
NamedTuple{keys,types}
end


function _rowtypes(
x, ::Type{G}, ::Type{I}, geometry::_True, index::_True, skipmissing, names::NamedTuple{Names}
) where {G,I,Names}
Expand Down
Loading

0 comments on commit 1f00143

Please sign in to comment.