Skip to content

Commit

Permalink
Merge 3bab85c into 9059111
Browse files Browse the repository at this point in the history
  • Loading branch information
mjskay authored Nov 20, 2023
2 parents 9059111 + 3bab85c commit a76290c
Show file tree
Hide file tree
Showing 22 changed files with 495 additions and 53 deletions.
13 changes: 13 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ S3method(.subset_draws,draws_df)
S3method(.subset_draws,draws_list)
S3method(.subset_draws,draws_matrix)
S3method(.subset_draws,draws_rvars)
S3method(Complex,rvar)
S3method(Math,rvar)
S3method(Math,rvar_factor)
S3method(Ops,rvar)
Expand Down Expand Up @@ -275,6 +276,7 @@ S3method(rhat_basic,default)
S3method(rhat_basic,rvar)
S3method(rhat_nested,default)
S3method(rhat_nested,rvar)
S3method(sd,complex)
S3method(sd,default)
S3method(sd,rvar)
S3method(split_chains,draws)
Expand All @@ -297,6 +299,7 @@ S3method(thin_draws,draws)
S3method(thin_draws,rvar)
S3method(unique,rvar)
S3method(unique,rvar_factor)
S3method(var,complex)
S3method(var,default)
S3method(var,rvar)
S3method(variables,"NULL")
Expand All @@ -305,6 +308,7 @@ S3method(variables,draws_df)
S3method(variables,draws_list)
S3method(variables,draws_matrix)
S3method(variables,draws_rvars)
S3method(variance,complex)
S3method(variance,draws_array)
S3method(variance,draws_matrix)
S3method(variance,rvar)
Expand All @@ -313,6 +317,7 @@ S3method(vec_cast,character.rvar_factor)
S3method(vec_cast,character.rvar_ordered)
S3method(vec_cast,distribution.rvar)
S3method(vec_cast,rvar.character)
S3method(vec_cast,rvar.complex)
S3method(vec_cast,rvar.distribution)
S3method(vec_cast,rvar.double)
S3method(vec_cast,rvar.factor)
Expand All @@ -323,6 +328,7 @@ S3method(vec_cast,rvar.rvar)
S3method(vec_cast,rvar.rvar_factor)
S3method(vec_cast,rvar.rvar_ordered)
S3method(vec_cast,rvar_factor.character)
S3method(vec_cast,rvar_factor.complex)
S3method(vec_cast,rvar_factor.double)
S3method(vec_cast,rvar_factor.factor)
S3method(vec_cast,rvar_factor.integer)
Expand All @@ -332,6 +338,7 @@ S3method(vec_cast,rvar_factor.rvar)
S3method(vec_cast,rvar_factor.rvar_factor)
S3method(vec_cast,rvar_factor.rvar_ordered)
S3method(vec_cast,rvar_ordered.character)
S3method(vec_cast,rvar_ordered.complex)
S3method(vec_cast,rvar_ordered.double)
S3method(vec_cast,rvar_ordered.factor)
S3method(vec_cast,rvar_ordered.integer)
Expand All @@ -349,6 +356,7 @@ S3method(vec_ptype,rvar_factor)
S3method(vec_ptype,rvar_ordered)
S3method(vec_ptype2,character.rvar_factor)
S3method(vec_ptype2,character.rvar_ordered)
S3method(vec_ptype2,complex.rvar)
S3method(vec_ptype2,distribution.rvar)
S3method(vec_ptype2,double.rvar)
S3method(vec_ptype2,factor.rvar_factor)
Expand All @@ -357,6 +365,7 @@ S3method(vec_ptype2,integer.rvar)
S3method(vec_ptype2,logical.rvar)
S3method(vec_ptype2,ordered.rvar_factor)
S3method(vec_ptype2,ordered.rvar_ordered)
S3method(vec_ptype2,rvar.complex)
S3method(vec_ptype2,rvar.distribution)
S3method(vec_ptype2,rvar.double)
S3method(vec_ptype2,rvar.integer)
Expand Down Expand Up @@ -396,6 +405,7 @@ export(as_draws_list)
export(as_draws_matrix)
export(as_draws_rvars)
export(as_rvar)
export(as_rvar_complex)
export(as_rvar_factor)
export(as_rvar_integer)
export(as_rvar_logical)
Expand Down Expand Up @@ -438,7 +448,10 @@ export(is_draws_list)
export(is_draws_matrix)
export(is_draws_rvars)
export(is_rvar)
export(is_rvar_complex)
export(is_rvar_factor)
export(is_rvar_integer)
export(is_rvar_logical)
export(is_rvar_ordered)
export(iteration_ids)
export(mad)
Expand Down
4 changes: 2 additions & 2 deletions R/as_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ check_draws_object <- function(x) {
#' @noRd
check_variables_are_numeric <- function(
x, to = "draws_array",
is_non_numeric = function(x_i) !is.numeric(x_i) && !is.logical(x_i),
is_non_numeric = function(x_i) !is.numeric(x_i) && !is.logical(x_i) && !is.complex(x_i),
convert = TRUE
) {

Expand Down Expand Up @@ -145,7 +145,7 @@ validate_draws_per_variable <- function(...) {
# '.nchains' is an additional argument in chain supporting formats
stop_no_call("'.nchains' is not supported for this format.")
}
out <- lapply(out, as.numeric)
out <- lapply(out, function(x) if (is.numeric(x) || is.complex(x)) x else as.numeric(x))
ndraws_per_variable <- lengths(out)
ndraws <- max(ndraws_per_variable)
if (!all(ndraws_per_variable %in% c(1, ndraws))) {
Expand Down
31 changes: 27 additions & 4 deletions R/rvar-.R
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,14 @@ setOldClass(get_rvar_class(ordered(NULL)))

# helpers: validation -----------------------------------------------------------------

# check the given rvar is not complex
check_rvar_not_complex <- function(x, f = NULL) {
if (is_rvar_complex(x)) {
f <- if (is.null(f)) "" else paste0("`", f, "` ")
stop_no_call("Cannot apply ", f, "function to complex rvars.")
}
}

# Check the passed yank index (for x[[...]]) is valid
check_rvar_yank_index = function(x, i, ...) {
index <- dots_list(i, ..., .preserve_empty = TRUE, .ignore_empty = "none")
Expand Down Expand Up @@ -948,12 +956,16 @@ summarise_rvar_within_draws <- function(x, .f, ..., .transpose = FALSE, .when_em
#' by first collapsing dimensions into columns of the draws matrix
#' (so that .f can be a rowXXX() function)
#' @param x an rvar
#' @param name function name to use for error messages
#' @param .name function name to use for error messages
#' @param .f a function that takes a matrix and summarises its rows, like rowMeans
#' @param ... arguments passed to `.f`
#' @param .ordered_okay can this function be applied to rvar_ordereds?
#' @noRd
summarise_rvar_within_draws_via_matrix <- function(x, .name, .f, ..., .ordered_okay = FALSE) {
if (is_rvar_complex(x)) {
return(summarise_rvar_within_draws(x, match.fun(.name), ...))
}

.length <- length(x)
if (!.length) {
x <- rvar()
Expand All @@ -966,7 +978,7 @@ summarise_rvar_within_draws_via_matrix <- function(x, .name, .f, ..., .ordered_o
.draws <- .f(draws_of(as_rvar_numeric(x)), ...)
.draws <- while_preserving_dims(function(.draws) ordered(.levels[round(.draws)], .levels), .draws)
} else if (is_rvar_factor(x)) {
stop_no_call("Cannot apply `", .name, "` function to rvar_factor objects.")
stop_no_call("Cannot apply `rvar_", .name, "` function to rvar_factor objects.")
} else {
.draws <- .f(draws_of(x), ...)
}
Expand Down Expand Up @@ -997,18 +1009,29 @@ summarise_rvar_by_element <- function(x, .f, ...) {
#' by first collapsing dimensions into columns of the draws matrix, applying the
#' function, then restoring dimensions (so that .f can be a colXXX() function)
#' @param x an rvar
#' @param name function name to use for error messages
#' @param .name function name to use for error messages, and also function to
#' be used as a backup for complex numbers
#' @param .f a function that takes a matrix and summarises its columns, like colMeans
#' @param .extra_dim extra dims added by `.f` to the output, e.g. in the case of
#' matrixStats::colRanges this is `2`
#' @param .extra_dimnames extra dimension names for dims added by `.f` to the output
#' @param .ordered_okay can this function be applied to rvar_ordereds?
#' @param .factor_okay can this function be applied to rvar_factors?
#' @param .complex_okay can this function be applied to complex rvars? If not,
#' the function match.fun(.name) will be used instead, element-by-element.
#' @param ... arguments passed to `.f`
#' @noRd
summarise_rvar_by_element_via_matrix <- function(
x, .name, .f, .extra_dim = NULL, .extra_dimnames = NULL, .ordered_okay = TRUE, .factor_okay = FALSE, ...
x, .name, .f,
.extra_dim = NULL, .extra_dimnames = NULL,
.ordered_okay = TRUE, .factor_okay = FALSE,
.complex_okay = FALSE,
...
) {
if (is_rvar_complex(x) && !.complex_okay) {
return(summarise_rvar_by_element(x, match.fun(.name), ...))
}

.dim <- dim(x)
.dimnames <- dimnames(x)
.length <- length(x)
Expand Down
75 changes: 72 additions & 3 deletions R/rvar-cast.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
#' @details For objects that are already [`rvar`]s, returns them (with modified dimensions
#' if `dim` is not `NULL`).
#'
#' For numeric or logical vectors or arrays, returns an [`rvar`] with a single draw and
#' For [`numeric`], [`complex`], or [`logical`] vectors or arrays, returns an [`rvar`] with a single draw and
#' the same dimensions as `x`. This is in contrast to the [rvar()] constructor, which
#' treats the first dimension of `x` as the draws dimension. As a result, `as_rvar()`
#' is useful for creating constants.
#'
#' While `as_rvar()` attempts to pick the most suitable subtype of [`rvar`] based on the
#' type of `x` (possibly returning an [`rvar_factor`] or [`rvar_ordered`]),
#' `as_rvar_numeric()`, `as_rvar_integer()`, and `as_rvar_logical()` always coerce
#' the draws of the output [`rvar`] to be [`numeric`], [`integer`], or [`logical`]
#' `as_rvar_numeric()`, `as_rvar_complex()`, `as_rvar_integer()`, and `as_rvar_logical()` always coerce
#' the draws of the output [`rvar`] to be [`numeric`], [`complex`], [`integer`], or [`logical`]
#' (respectively), and always return a base [`rvar`], never a subtype.
#'
#' @seealso [rvar()] to construct [`rvar`]s directly. See [rdo()], [rfun()], and
Expand Down Expand Up @@ -87,6 +87,14 @@ as_rvar_numeric <- function(x, dim = NULL, dimnames = NULL, nchains = NULL) {
out
}

#' @rdname as_rvar
#' @export
as_rvar_complex <- function(x, dim = NULL, dimnames = NULL, nchains = NULL) {
out <- as_rvar(x, dim = dim, dimnames = dimnames, nchains = nchains)
draws_of(out) <- while_preserving_dims(as.complex, draws_of(out))
out
}

#' @rdname as_rvar
#' @export
as_rvar_integer <- function(x, dim = NULL, dimnames = NULL, nchains = NULL) {
Expand Down Expand Up @@ -121,6 +129,51 @@ is_rvar <- function(x) {
inherits(x, "rvar")
}

#' Is `x` a complex random variable?
#'
#' Test if `x` is an [`rvar`] backed by [`complex`] draws.
#'
#' @inheritParams is_rvar
#'
#' @seealso [as_rvar_complex()] to convert objects to `rvar`s backed by [`complex`] draws.
#'
#' @return `TRUE` if `x` is an [`rvar`] backed by [`complex`] draws, `FALSE` otherwise.
#'
#' @export
is_rvar_complex <- function(x) {
is.complex(draws_of(x))
}

#' Is `x` an integer random variable?
#'
#' Test if `x` is an [`rvar`] backed by [`integer`] draws.
#'
#' @inheritParams is_rvar
#'
#' @seealso [as_rvar_integer()] to convert objects to `rvar`s backed by [`integer`] draws.
#'
#' @return `TRUE` if `x` is an [`rvar`] backed by [`integer`] draws, `FALSE` otherwise.
#'
#' @export
is_rvar_integer <- function(x) {
is.integer(draws_of(x))
}

#' Is `x` a logical random variable?
#'
#' Test if `x` is an [`rvar`] backed by [`logical`] draws.
#'
#' @inheritParams is_rvar
#'
#' @seealso [as_rvar_logical()] to convert objects to `rvar`s backed by [`logical`] draws.
#'
#' @return `TRUE` if `x` is an [`rvar`] backed by [`logical`] draws, `FALSE` otherwise.
#'
#' @export
is_rvar_logical <- function(x) {
is.logical(draws_of(x))
}

#' @export
is.matrix.rvar <- function(x) {
length(dim(draws_of(x))) == 3
Expand Down Expand Up @@ -384,6 +437,22 @@ vec_cast.rvar_factor.double <- function(x, to, ...) new_constant_rvar(while_pres
#' @export
vec_cast.rvar_ordered.double <- function(x, to, ...) new_constant_rvar(while_preserving_dims(as.ordered, x))

# complex -> rvar
#' @export
vec_ptype2.complex.rvar <- function(x, y, ...) new_rvar()
#' @export
vec_ptype2.rvar.complex <- function(x, y, ...) new_rvar()
#' @export
vec_cast.rvar.complex <- function(x, to, ...) new_constant_rvar(x)

# complex -> rvar_factor
#' @export
vec_cast.rvar_factor.complex <- function(x, to, ...) new_constant_rvar(while_preserving_dims(as.factor, x))

# complex -> rvar_ordered
#' @export
vec_cast.rvar_ordered.complex <- function(x, to, ...) new_constant_rvar(while_preserving_dims(as.ordered, x))

# integer -> rvar
#' @export
vec_ptype2.integer.rvar <- function(x, y, ...) new_rvar()
Expand Down
12 changes: 8 additions & 4 deletions R/rvar-dist.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#' @name rvar-dist
#' @export
density.rvar <- function(x, at, ...) {
check_rvar_not_complex(x, "density")
summarise_rvar_by_element(x, function(draws) {
d <- density(draws, cut = 0, ...)
f <- approxfun(d$x, d$y, yleft = 0, yright = 0)
Expand All @@ -66,6 +67,7 @@ distributional::cdf
#' @rdname rvar-dist
#' @export
cdf.rvar <- function(x, q, ...) {
check_rvar_not_complex(x, "cdf")
summarise_rvar_by_element(x, function(draws) {
ecdf(draws)(q)
})
Expand All @@ -91,13 +93,15 @@ cdf.rvar_ordered <- function(x, q, ...) {
#' @rdname rvar-dist
#' @export
quantile.rvar <- function(x, probs, ...) {
check_rvar_not_complex(x, "quantile")
summarise_rvar_by_element_via_matrix(x,
"quantile",
function(draws) {
t(matrixStats::colQuantiles(draws, probs = probs, useNames = TRUE, ...))
},
function(..., names) t(matrixStats::colQuantiles(..., useNames = FALSE)),
.extra_dim = length(probs),
.extra_dimnames = list(NULL)
.extra_dimnames = list(NULL),
probs = probs,
names = FALSE,
...
)
}

Expand Down
14 changes: 12 additions & 2 deletions R/rvar-math.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,27 @@ Math.rvar <- function(x, ...) {
if (.Generic %in% c("cumsum", "cumprod", "cummax", "cummin")) {
# cumulative functions need to be handled differently
# from other functions in this generic
new_rvar(t(apply(draws_of(x), 1, f)), .nchains = nchains(x))
if (length(x) > 1) {
draws_of(x) <- t(apply(draws_of(x), 1, f))
}
} else {
new_rvar(f(draws_of(x), ...), .nchains = nchains(x))
draws_of(x) <- f(draws_of(x), ...)
}

x
}

#' @export
Math.rvar_factor <- function(x, ...) {
stop_no_call("Cannot apply `", .Generic, "` function to rvar_factor objects.")
}

#' @export
Complex.rvar <- function(z) {
f <- get(.Generic)
rvar_apply_vec_fun(f, z)
}

# matrix stuff ---------------------------------------------------

#' Matrix multiplication of random variables
Expand Down
Loading

0 comments on commit a76290c

Please sign in to comment.