Skip to content

Commit

Permalink
Merge pull request #124 from mrc-ide/mrc-6096
Browse files Browse the repository at this point in the history
Port tests to new interface
  • Loading branch information
richfitz authored Dec 9, 2024
2 parents e6af121 + f74bc48 commit 6c95da7
Show file tree
Hide file tree
Showing 7 changed files with 473 additions and 378 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: monty
Title: Monte Carlo Models
Version: 0.3.12
Version: 0.3.13
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("Wes", "Hinsley", role = "aut"),
Expand Down
24 changes: 20 additions & 4 deletions R/cpp11.R

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 21 additions & 4 deletions R/random.R
Original file line number Diff line number Diff line change
Expand Up @@ -447,19 +447,36 @@ monty_random_n_negative_binomial_mu <- function(n_samples, size, mu, state) {
##'
##' @param sd The standard deviation of the normal distribution
##'
##' @param algorithm The algorithm to use for the normal samples;
##' currently `box_muller`, `polar` and and `ziggurat` are
##' supported, with the latter being considerably faster. The
##' default may change in a future version.
##'
##' @inheritParams monty_random_real
##' @inherit monty_random_real return
##'
##' @export
monty_random_normal <- function(mean, sd, state) {
cpp_monty_random_normal(mean, sd, state)
monty_random_normal <- function(mean, sd, state, algorithm = "box_muller") {
switch(algorithm,
box_muller = cpp_monty_random_normal_box_muller(mean, sd, state),
polar = cpp_monty_random_normal_polar(mean, sd, state),
ziggurat = cpp_monty_random_normal_ziggurat(mean, sd, state),
cli::cli_abort("Unknown normal algorithm '{algorithm}'"))
}


##' @export
##' @rdname monty_random_normal
monty_random_n_normal <- function(n_samples, mean, sd, state) {
cpp_monty_random_n_normal(n_samples, mean, sd, state)
monty_random_n_normal <- function(n_samples, mean, sd, state,
algorithm = "box_muller") {
switch(algorithm,
box_muller =
cpp_monty_random_n_normal_box_muller(n_samples, mean, sd, state),
polar =
cpp_monty_random_n_normal_polar(n_samples, mean, sd, state),
ziggurat =
cpp_monty_random_n_normal_ziggurat(n_samples, mean, sd, state),
cli::cli_abort("Unknown normal algorithm '{algorithm}'"))
}


Expand Down
9 changes: 7 additions & 2 deletions man/monty_random_normal.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

48 changes: 40 additions & 8 deletions src/cpp11.cpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

55 changes: 51 additions & 4 deletions src/random.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -533,25 +533,72 @@ cpp11::doubles cpp_monty_random_n_negative_binomial_mu(size_t n_samples,
}

// normal
[[cpp11::register]]
template <typename monty::random::algorithm::normal A>
cpp11::doubles cpp_monty_random_normal(cpp11::doubles mean,
cpp11::doubles sd,
cpp11::sexp ptr) {
const auto fn = [](auto& state, auto mean, auto sd) { return monty::random::normal<double>(state, mean, sd); };
const auto fn = [](auto& state, auto mean, auto sd) { return monty::random::normal<double, A>(state, mean, sd); };
return monty_random_sample_1_2(fn, ptr, "normal",
mean, sd, "mean", "sd");
}

[[cpp11::register]]
template <typename monty::random::algorithm::normal A>
cpp11::doubles cpp_monty_random_n_normal(size_t n_samples,
cpp11::doubles mean,
cpp11::doubles sd,
cpp11::sexp ptr) {
const auto fn = [](auto& state, auto mean, auto sd) { return monty::random::normal<double>(state, mean, sd); };
const auto fn = [](auto& state, auto mean, auto sd) { return monty::random::normal<double, A>(state, mean, sd); };
return monty_random_sample_n_2(fn, n_samples, ptr, "normal",
mean, sd, "mean", "sd");
}

[[cpp11::register]]
cpp11::doubles cpp_monty_random_normal_box_muller(cpp11::doubles mean,
cpp11::doubles sd,
cpp11::sexp ptr) {
return cpp_monty_random_normal<monty::random::algorithm::normal::box_muller>(mean, sd, ptr);
}


[[cpp11::register]]
cpp11::doubles cpp_monty_random_normal_polar(cpp11::doubles mean,
cpp11::doubles sd,
cpp11::sexp ptr) {
return cpp_monty_random_normal<monty::random::algorithm::normal::polar>(mean, sd, ptr);
}

[[cpp11::register]]
cpp11::doubles cpp_monty_random_normal_ziggurat(cpp11::doubles mean,
cpp11::doubles sd,
cpp11::sexp ptr) {
return cpp_monty_random_normal<monty::random::algorithm::normal::ziggurat>(mean, sd, ptr);
}

[[cpp11::register]]
cpp11::doubles cpp_monty_random_n_normal_box_muller(size_t n_samples,
cpp11::doubles mean,
cpp11::doubles sd,
cpp11::sexp ptr) {
return cpp_monty_random_n_normal<monty::random::algorithm::normal::box_muller>(n_samples, mean, sd, ptr);
}


[[cpp11::register]]
cpp11::doubles cpp_monty_random_n_normal_polar(size_t n_samples,
cpp11::doubles mean,
cpp11::doubles sd,
cpp11::sexp ptr) {
return cpp_monty_random_n_normal<monty::random::algorithm::normal::polar>(n_samples, mean, sd, ptr);
}

[[cpp11::register]]
cpp11::doubles cpp_monty_random_n_normal_ziggurat(size_t n_samples,
cpp11::doubles mean,
cpp11::doubles sd,
cpp11::sexp ptr) {
return cpp_monty_random_n_normal<monty::random::algorithm::normal::ziggurat>(n_samples, mean, sd, ptr);
}

// uniform
[[cpp11::register]]
cpp11::doubles cpp_monty_random_uniform(cpp11::doubles min,
Expand Down
Loading

0 comments on commit 6c95da7

Please sign in to comment.