Skip to content

Commit c97b1de

Browse files
authored
Merge 30442fa into 9059111
2 parents 9059111 + 30442fa commit c97b1de

22 files changed

+439
-53
lines changed

NAMESPACE

+13
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ S3method(.subset_draws,draws_df)
2323
S3method(.subset_draws,draws_list)
2424
S3method(.subset_draws,draws_matrix)
2525
S3method(.subset_draws,draws_rvars)
26+
S3method(Complex,rvar)
2627
S3method(Math,rvar)
2728
S3method(Math,rvar_factor)
2829
S3method(Ops,rvar)
@@ -275,6 +276,7 @@ S3method(rhat_basic,default)
275276
S3method(rhat_basic,rvar)
276277
S3method(rhat_nested,default)
277278
S3method(rhat_nested,rvar)
279+
S3method(sd,complex)
278280
S3method(sd,default)
279281
S3method(sd,rvar)
280282
S3method(split_chains,draws)
@@ -297,6 +299,7 @@ S3method(thin_draws,draws)
297299
S3method(thin_draws,rvar)
298300
S3method(unique,rvar)
299301
S3method(unique,rvar_factor)
302+
S3method(var,complex)
300303
S3method(var,default)
301304
S3method(var,rvar)
302305
S3method(variables,"NULL")
@@ -305,6 +308,7 @@ S3method(variables,draws_df)
305308
S3method(variables,draws_list)
306309
S3method(variables,draws_matrix)
307310
S3method(variables,draws_rvars)
311+
S3method(variance,complex)
308312
S3method(variance,draws_array)
309313
S3method(variance,draws_matrix)
310314
S3method(variance,rvar)
@@ -313,6 +317,7 @@ S3method(vec_cast,character.rvar_factor)
313317
S3method(vec_cast,character.rvar_ordered)
314318
S3method(vec_cast,distribution.rvar)
315319
S3method(vec_cast,rvar.character)
320+
S3method(vec_cast,rvar.complex)
316321
S3method(vec_cast,rvar.distribution)
317322
S3method(vec_cast,rvar.double)
318323
S3method(vec_cast,rvar.factor)
@@ -323,6 +328,7 @@ S3method(vec_cast,rvar.rvar)
323328
S3method(vec_cast,rvar.rvar_factor)
324329
S3method(vec_cast,rvar.rvar_ordered)
325330
S3method(vec_cast,rvar_factor.character)
331+
S3method(vec_cast,rvar_factor.complex)
326332
S3method(vec_cast,rvar_factor.double)
327333
S3method(vec_cast,rvar_factor.factor)
328334
S3method(vec_cast,rvar_factor.integer)
@@ -332,6 +338,7 @@ S3method(vec_cast,rvar_factor.rvar)
332338
S3method(vec_cast,rvar_factor.rvar_factor)
333339
S3method(vec_cast,rvar_factor.rvar_ordered)
334340
S3method(vec_cast,rvar_ordered.character)
341+
S3method(vec_cast,rvar_ordered.complex)
335342
S3method(vec_cast,rvar_ordered.double)
336343
S3method(vec_cast,rvar_ordered.factor)
337344
S3method(vec_cast,rvar_ordered.integer)
@@ -349,6 +356,7 @@ S3method(vec_ptype,rvar_factor)
349356
S3method(vec_ptype,rvar_ordered)
350357
S3method(vec_ptype2,character.rvar_factor)
351358
S3method(vec_ptype2,character.rvar_ordered)
359+
S3method(vec_ptype2,complex.rvar)
352360
S3method(vec_ptype2,distribution.rvar)
353361
S3method(vec_ptype2,double.rvar)
354362
S3method(vec_ptype2,factor.rvar_factor)
@@ -357,6 +365,7 @@ S3method(vec_ptype2,integer.rvar)
357365
S3method(vec_ptype2,logical.rvar)
358366
S3method(vec_ptype2,ordered.rvar_factor)
359367
S3method(vec_ptype2,ordered.rvar_ordered)
368+
S3method(vec_ptype2,rvar.complex)
360369
S3method(vec_ptype2,rvar.distribution)
361370
S3method(vec_ptype2,rvar.double)
362371
S3method(vec_ptype2,rvar.integer)
@@ -396,6 +405,7 @@ export(as_draws_list)
396405
export(as_draws_matrix)
397406
export(as_draws_rvars)
398407
export(as_rvar)
408+
export(as_rvar_complex)
399409
export(as_rvar_factor)
400410
export(as_rvar_integer)
401411
export(as_rvar_logical)
@@ -438,7 +448,10 @@ export(is_draws_list)
438448
export(is_draws_matrix)
439449
export(is_draws_rvars)
440450
export(is_rvar)
451+
export(is_rvar_complex)
441452
export(is_rvar_factor)
453+
export(is_rvar_integer)
454+
export(is_rvar_logical)
442455
export(is_rvar_ordered)
443456
export(iteration_ids)
444457
export(mad)

R/as_draws.R

+2-2
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ check_draws_object <- function(x) {
109109
#' @noRd
110110
check_variables_are_numeric <- function(
111111
x, to = "draws_array",
112-
is_non_numeric = function(x_i) !is.numeric(x_i) && !is.logical(x_i),
112+
is_non_numeric = function(x_i) !is.numeric(x_i) && !is.logical(x_i) && !is.complex(x_i),
113113
convert = TRUE
114114
) {
115115

@@ -145,7 +145,7 @@ validate_draws_per_variable <- function(...) {
145145
# '.nchains' is an additional argument in chain supporting formats
146146
stop_no_call("'.nchains' is not supported for this format.")
147147
}
148-
out <- lapply(out, as.numeric)
148+
out <- lapply(out, as_numeric_or_complex)
149149
ndraws_per_variable <- lengths(out)
150150
ndraws <- max(ndraws_per_variable)
151151
if (!all(ndraws_per_variable %in% c(1, ndraws))) {

R/misc.R

+5
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,11 @@ as_one_character <- function(x, allow_na = FALSE) {
111111
x
112112
}
113113

114+
# coerce 'x' to a numeric or complex vector
115+
as_numeric_or_complex <- function(x) {
116+
if (is.numeric(x) || is.complex(x)) x else as.numeric(x)
117+
}
118+
114119
# check if all inputs are NULL
115120
all_null <- function(...) {
116121
all(ulapply(list(...), is.null))

R/rvar-.R

+27-4
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,14 @@ setOldClass(get_rvar_class(ordered(NULL)))
474474

475475
# helpers: validation -----------------------------------------------------------------
476476

477+
# check the given rvar is not complex
478+
check_rvar_not_complex <- function(x, f = NULL) {
479+
if (is_rvar_complex(x)) {
480+
f <- if (is.null(f)) "" else paste0("`", f, "` ")
481+
stop_no_call("Cannot apply ", f, "function to complex rvars.")
482+
}
483+
}
484+
477485
# Check the passed yank index (for x[[...]]) is valid
478486
check_rvar_yank_index = function(x, i, ...) {
479487
index <- dots_list(i, ..., .preserve_empty = TRUE, .ignore_empty = "none")
@@ -948,12 +956,16 @@ summarise_rvar_within_draws <- function(x, .f, ..., .transpose = FALSE, .when_em
948956
#' by first collapsing dimensions into columns of the draws matrix
949957
#' (so that .f can be a rowXXX() function)
950958
#' @param x an rvar
951-
#' @param name function name to use for error messages
959+
#' @param .name function name to use for error messages
952960
#' @param .f a function that takes a matrix and summarises its rows, like rowMeans
953961
#' @param ... arguments passed to `.f`
954962
#' @param .ordered_okay can this function be applied to rvar_ordereds?
955963
#' @noRd
956964
summarise_rvar_within_draws_via_matrix <- function(x, .name, .f, ..., .ordered_okay = FALSE) {
965+
if (is_rvar_complex(x)) {
966+
return(summarise_rvar_within_draws(x, match.fun(.name), ...))
967+
}
968+
957969
.length <- length(x)
958970
if (!.length) {
959971
x <- rvar()
@@ -966,7 +978,7 @@ summarise_rvar_within_draws_via_matrix <- function(x, .name, .f, ..., .ordered_o
966978
.draws <- .f(draws_of(as_rvar_numeric(x)), ...)
967979
.draws <- while_preserving_dims(function(.draws) ordered(.levels[round(.draws)], .levels), .draws)
968980
} else if (is_rvar_factor(x)) {
969-
stop_no_call("Cannot apply `", .name, "` function to rvar_factor objects.")
981+
stop_no_call("Cannot apply `rvar_", .name, "` function to rvar_factor objects.")
970982
} else {
971983
.draws <- .f(draws_of(x), ...)
972984
}
@@ -997,18 +1009,29 @@ summarise_rvar_by_element <- function(x, .f, ...) {
9971009
#' by first collapsing dimensions into columns of the draws matrix, applying the
9981010
#' function, then restoring dimensions (so that .f can be a colXXX() function)
9991011
#' @param x an rvar
1000-
#' @param name function name to use for error messages
1012+
#' @param .name function name to use for error messages, and also function to
1013+
#' be used as a backup for complex numbers
10011014
#' @param .f a function that takes a matrix and summarises its columns, like colMeans
10021015
#' @param .extra_dim extra dims added by `.f` to the output, e.g. in the case of
10031016
#' matrixStats::colRanges this is `2`
10041017
#' @param .extra_dimnames extra dimension names for dims added by `.f` to the output
10051018
#' @param .ordered_okay can this function be applied to rvar_ordereds?
10061019
#' @param .factor_okay can this function be applied to rvar_factors?
1020+
#' @param .complex_okay can this function be applied to complex rvars? If not,
1021+
#' the function match.fun(.name) will be used instead, element-by-element.
10071022
#' @param ... arguments passed to `.f`
10081023
#' @noRd
10091024
summarise_rvar_by_element_via_matrix <- function(
1010-
x, .name, .f, .extra_dim = NULL, .extra_dimnames = NULL, .ordered_okay = TRUE, .factor_okay = FALSE, ...
1025+
x, .name, .f,
1026+
.extra_dim = NULL, .extra_dimnames = NULL,
1027+
.ordered_okay = TRUE, .factor_okay = FALSE,
1028+
.complex_okay = FALSE,
1029+
...
10111030
) {
1031+
if (is_rvar_complex(x) && !.complex_okay) {
1032+
return(summarise_rvar_by_element(x, match.fun(.name), ...))
1033+
}
1034+
10121035
.dim <- dim(x)
10131036
.dimnames <- dimnames(x)
10141037
.length <- length(x)

R/rvar-cast.R

+72-3
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@
1111
#' @details For objects that are already [`rvar`]s, returns them (with modified dimensions
1212
#' if `dim` is not `NULL`).
1313
#'
14-
#' For numeric or logical vectors or arrays, returns an [`rvar`] with a single draw and
14+
#' For [`numeric`], [`complex`], or [`logical`] vectors or arrays, returns an [`rvar`] with a single draw and
1515
#' the same dimensions as `x`. This is in contrast to the [rvar()] constructor, which
1616
#' treats the first dimension of `x` as the draws dimension. As a result, `as_rvar()`
1717
#' is useful for creating constants.
1818
#'
1919
#' While `as_rvar()` attempts to pick the most suitable subtype of [`rvar`] based on the
2020
#' type of `x` (possibly returning an [`rvar_factor`] or [`rvar_ordered`]),
21-
#' `as_rvar_numeric()`, `as_rvar_integer()`, and `as_rvar_logical()` always coerce
22-
#' the draws of the output [`rvar`] to be [`numeric`], [`integer`], or [`logical`]
21+
#' `as_rvar_numeric()`, `as_rvar_complex()`, `as_rvar_integer()`, and `as_rvar_logical()` always coerce
22+
#' the draws of the output [`rvar`] to be [`numeric`], [`complex`], [`integer`], or [`logical`]
2323
#' (respectively), and always return a base [`rvar`], never a subtype.
2424
#'
2525
#' @seealso [rvar()] to construct [`rvar`]s directly. See [rdo()], [rfun()], and
@@ -87,6 +87,14 @@ as_rvar_numeric <- function(x, dim = NULL, dimnames = NULL, nchains = NULL) {
8787
out
8888
}
8989

90+
#' @rdname as_rvar
91+
#' @export
92+
as_rvar_complex <- function(x, dim = NULL, dimnames = NULL, nchains = NULL) {
93+
out <- as_rvar(x, dim = dim, dimnames = dimnames, nchains = nchains)
94+
draws_of(out) <- while_preserving_dims(as.complex, draws_of(out))
95+
out
96+
}
97+
9098
#' @rdname as_rvar
9199
#' @export
92100
as_rvar_integer <- function(x, dim = NULL, dimnames = NULL, nchains = NULL) {
@@ -121,6 +129,51 @@ is_rvar <- function(x) {
121129
inherits(x, "rvar")
122130
}
123131

132+
#' Is `x` a complex random variable?
133+
#'
134+
#' Test if `x` is an [`rvar`] backed by [`complex`] draws.
135+
#'
136+
#' @inheritParams is_rvar
137+
#'
138+
#' @seealso [as_rvar_complex()] to convert objects to `rvar`s backed by [`complex`] draws.
139+
#'
140+
#' @return `TRUE` if `x` is an [`rvar`] backed by [`complex`] draws, `FALSE` otherwise.
141+
#'
142+
#' @export
143+
is_rvar_complex <- function(x) {
144+
is.complex(draws_of(x))
145+
}
146+
147+
#' Is `x` an integer random variable?
148+
#'
149+
#' Test if `x` is an [`rvar`] backed by [`integer`] draws.
150+
#'
151+
#' @inheritParams is_rvar
152+
#'
153+
#' @seealso [as_rvar_integer()] to convert objects to `rvar`s backed by [`integer`] draws.
154+
#'
155+
#' @return `TRUE` if `x` is an [`rvar`] backed by [`integer`] draws, `FALSE` otherwise.
156+
#'
157+
#' @export
158+
is_rvar_integer <- function(x) {
159+
is.integer(draws_of(x))
160+
}
161+
162+
#' Is `x` a logical random variable?
163+
#'
164+
#' Test if `x` is an [`rvar`] backed by [`logical`] draws.
165+
#'
166+
#' @inheritParams is_rvar
167+
#'
168+
#' @seealso [as_rvar_logical()] to convert objects to `rvar`s backed by [`logical`] draws.
169+
#'
170+
#' @return `TRUE` if `x` is an [`rvar`] backed by [`logical`] draws, `FALSE` otherwise.
171+
#'
172+
#' @export
173+
is_rvar_logical <- function(x) {
174+
is.logical(draws_of(x))
175+
}
176+
124177
#' @export
125178
is.matrix.rvar <- function(x) {
126179
length(dim(draws_of(x))) == 3
@@ -384,6 +437,22 @@ vec_cast.rvar_factor.double <- function(x, to, ...) new_constant_rvar(while_pres
384437
#' @export
385438
vec_cast.rvar_ordered.double <- function(x, to, ...) new_constant_rvar(while_preserving_dims(as.ordered, x))
386439

440+
# complex -> rvar
441+
#' @export
442+
vec_ptype2.complex.rvar <- function(x, y, ...) new_rvar()
443+
#' @export
444+
vec_ptype2.rvar.complex <- function(x, y, ...) new_rvar()
445+
#' @export
446+
vec_cast.rvar.complex <- function(x, to, ...) new_constant_rvar(x)
447+
448+
# complex -> rvar_factor
449+
#' @export
450+
vec_cast.rvar_factor.complex <- function(x, to, ...) new_constant_rvar(while_preserving_dims(as.factor, x))
451+
452+
# complex -> rvar_ordered
453+
#' @export
454+
vec_cast.rvar_ordered.complex <- function(x, to, ...) new_constant_rvar(while_preserving_dims(as.ordered, x))
455+
387456
# integer -> rvar
388457
#' @export
389458
vec_ptype2.integer.rvar <- function(x, y, ...) new_rvar()

R/rvar-dist.R

+8-4
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
#' @name rvar-dist
4141
#' @export
4242
density.rvar <- function(x, at, ...) {
43+
check_rvar_not_complex(x, "density")
4344
summarise_rvar_by_element(x, function(draws) {
4445
d <- density(draws, cut = 0, ...)
4546
f <- approxfun(d$x, d$y, yleft = 0, yright = 0)
@@ -66,6 +67,7 @@ distributional::cdf
6667
#' @rdname rvar-dist
6768
#' @export
6869
cdf.rvar <- function(x, q, ...) {
70+
check_rvar_not_complex(x, "cdf")
6971
summarise_rvar_by_element(x, function(draws) {
7072
ecdf(draws)(q)
7173
})
@@ -91,13 +93,15 @@ cdf.rvar_ordered <- function(x, q, ...) {
9193
#' @rdname rvar-dist
9294
#' @export
9395
quantile.rvar <- function(x, probs, ...) {
96+
check_rvar_not_complex(x, "quantile")
9497
summarise_rvar_by_element_via_matrix(x,
9598
"quantile",
96-
function(draws) {
97-
t(matrixStats::colQuantiles(draws, probs = probs, useNames = TRUE, ...))
98-
},
99+
function(..., names) t(matrixStats::colQuantiles(..., useNames = FALSE)),
99100
.extra_dim = length(probs),
100-
.extra_dimnames = list(NULL)
101+
.extra_dimnames = list(NULL),
102+
probs = probs,
103+
names = FALSE,
104+
...
101105
)
102106
}
103107

R/rvar-math.R

+12-2
Original file line numberDiff line numberDiff line change
@@ -95,17 +95,27 @@ Math.rvar <- function(x, ...) {
9595
if (.Generic %in% c("cumsum", "cumprod", "cummax", "cummin")) {
9696
# cumulative functions need to be handled differently
9797
# from other functions in this generic
98-
new_rvar(t(apply(draws_of(x), 1, f)), .nchains = nchains(x))
98+
if (length(x) > 1) {
99+
draws_of(x) <- t(apply(draws_of(x), 1, f))
100+
}
99101
} else {
100-
new_rvar(f(draws_of(x), ...), .nchains = nchains(x))
102+
draws_of(x) <- f(draws_of(x), ...)
101103
}
104+
105+
x
102106
}
103107

104108
#' @export
105109
Math.rvar_factor <- function(x, ...) {
106110
stop_no_call("Cannot apply `", .Generic, "` function to rvar_factor objects.")
107111
}
108112

113+
#' @export
114+
Complex.rvar <- function(z) {
115+
f <- get(.Generic)
116+
rvar_apply_vec_fun(f, z)
117+
}
118+
109119
# matrix stuff ---------------------------------------------------
110120

111121
#' Matrix multiplication of random variables

0 commit comments

Comments
 (0)