diff --git a/src/stan/mcmc/covar_adaptation.hpp b/src/stan/mcmc/covar_adaptation.hpp index 6c21c63e33..05b1a87d49 100644 --- a/src/stan/mcmc/covar_adaptation.hpp +++ b/src/stan/mcmc/covar_adaptation.hpp @@ -14,6 +14,12 @@ class covar_adaptation : public windowed_adaptation { explicit covar_adaptation(int n) : windowed_adaptation("covariance"), estimator_(n) {} + /** + * Return true if covariance was updated and adaptation is not finished + * + * @param covar Covariance + * @param q Last draw + */ bool learn_covariance(Eigen::MatrixXd& covar, const Eigen::VectorXd& q) { if (adaptation_window()) estimator_.add_sample(q); @@ -30,11 +36,11 @@ class covar_adaptation : public windowed_adaptation { estimator_.restart(); - ++adapt_window_counter_; - return true; + increment_window_counter(); + return true && !finished(); } - ++adapt_window_counter_; + increment_window_counter(); return false; } diff --git a/src/stan/mcmc/var_adaptation.hpp b/src/stan/mcmc/var_adaptation.hpp index c81de41d98..1edb8ff41f 100644 --- a/src/stan/mcmc/var_adaptation.hpp +++ b/src/stan/mcmc/var_adaptation.hpp @@ -14,6 +14,12 @@ class var_adaptation : public windowed_adaptation { explicit var_adaptation(int n) : windowed_adaptation("variance"), estimator_(n) {} + /** + * Return true if variance was updated and adaptation is not finished + * + * @param var Diagonal covariance + * @param q Last draw + */ bool learn_variance(Eigen::VectorXd& var, const Eigen::VectorXd& q) { if (adaptation_window()) estimator_.add_sample(q); @@ -29,11 +35,11 @@ class var_adaptation : public windowed_adaptation { estimator_.restart(); - ++adapt_window_counter_; - return true; + increment_window_counter(); + return true && !finished(); } - ++adapt_window_counter_; + increment_window_counter(); return false; } diff --git a/src/stan/mcmc/windowed_adaptation.hpp b/src/stan/mcmc/windowed_adaptation.hpp index 36bf1d1d98..81fb46eac0 100644 --- a/src/stan/mcmc/windowed_adaptation.hpp +++ b/src/stan/mcmc/windowed_adaptation.hpp @@ -109,6 +109,19 @@ class windowed_adaptation : public base_adaptation { } } + /** + * Check if there is any more warmup left to do + */ + bool finished() { return adapt_window_counter_ + 1 >= num_warmup_; } + + /** + * Increment the window counter and return the new value + */ + unsigned int increment_window_counter() { + adapt_window_counter_ += 1; + return adapt_window_counter_; + } + protected: std::string estimator_name_; diff --git a/src/test/unit/mcmc/covar_adaptation_test.cpp b/src/test/unit/mcmc/covar_adaptation_test.cpp index f1e0853d88..9bcdf36a2a 100644 --- a/src/test/unit/mcmc/covar_adaptation_test.cpp +++ b/src/test/unit/mcmc/covar_adaptation_test.cpp @@ -15,15 +15,26 @@ TEST(McmcCovarAdaptation, learn_covariance) { target_covar *= 1e-3 * 5.0 / (n_learn + 5.0); stan::mcmc::covar_adaptation adapter(n); - adapter.set_window_params(50, 0, 0, n_learn, logger); + adapter.set_window_params(30, 0, 0, n_learn, logger); - for (int i = 0; i < n_learn; ++i) - adapter.learn_covariance(covar, q); + for (int i = 0; i < n_learn - 1; ++i) { + EXPECT_FALSE(adapter.learn_covariance(covar, q)); + } + // Learn covariance should return true at end of first window + EXPECT_TRUE(adapter.learn_covariance(covar, q)); for (int i = 0; i < n; ++i) { for (int j = 0; j < n; ++j) { EXPECT_EQ(target_covar(i, j), covar(i, j)); } } + + // Make sure learn_covariance doesn't return true after second window + // (adaptation finished) + for (int i = 0; i < 2 * n_learn; ++i) { + EXPECT_FALSE(adapter.learn_covariance(covar, q)); + } + EXPECT_TRUE(adapter.finished()); + EXPECT_EQ(0, logger.call_count()); } diff --git a/src/test/unit/mcmc/var_adaptation_test.cpp b/src/test/unit/mcmc/var_adaptation_test.cpp index 708f5b2e7f..6580809ad3 100644 --- a/src/test/unit/mcmc/var_adaptation_test.cpp +++ b/src/test/unit/mcmc/var_adaptation_test.cpp @@ -15,13 +15,23 @@ TEST(McmcVarAdaptation, learn_variance) { target_var *= 1e-3 * 5.0 / (n_learn + 5.0); stan::mcmc::var_adaptation adapter(n); - adapter.set_window_params(50, 0, 0, n_learn, logger); + adapter.set_window_params(30, 0, 0, n_learn, logger); - for (int i = 0; i < n_learn; ++i) - adapter.learn_variance(var, q); + for (int i = 0; i < n_learn - 1; ++i) { + EXPECT_FALSE(adapter.learn_variance(var, q)); + } + // Learn variance should return true at end of first window + EXPECT_TRUE(adapter.learn_variance(var, q)); for (int i = 0; i < n; ++i) EXPECT_EQ(target_var(i), var(i)); + // Make sure learn_variance doesn't return true after second window + // (adaptation finished) + for (int i = 0; i < 2 * n_learn; ++i) { + EXPECT_FALSE(adapter.learn_variance(var, q)); + } + EXPECT_TRUE(adapter.finished()); + EXPECT_EQ(0, logger.call_count()); } diff --git a/src/test/unit/mcmc/windowed_adaptation_test.cpp b/src/test/unit/mcmc/windowed_adaptation_test.cpp index c30455b0ff..f54c867d76 100644 --- a/src/test/unit/mcmc/windowed_adaptation_test.cpp +++ b/src/test/unit/mcmc/windowed_adaptation_test.cpp @@ -46,3 +46,17 @@ TEST(McmcWindowedAdaptation, set_window_params3) { ASSERT_EQ(0, logger.call_count()); ASSERT_EQ(0, logger.call_count_info()); } + +TEST(McmcWindowedAdaptation, finished) { + stan::test::unit::instrumented_logger logger; + + stan::mcmc::windowed_adaptation adapter("test"); + + adapter.set_window_params(1000, 75, 50, 25, logger); + + for (size_t i = 0; i < 999; i++) { + EXPECT_FALSE(adapter.finished()); + adapter.increment_window_counter(); + } + EXPECT_TRUE(adapter.finished()); +} diff --git a/src/test/unit/services/sample/hmc_nuts_dense_e_adapt_test.cpp b/src/test/unit/services/sample/hmc_nuts_dense_e_adapt_test.cpp index 8638966d5d..7fd9b11eaa 100644 --- a/src/test/unit/services/sample/hmc_nuts_dense_e_adapt_test.cpp +++ b/src/test/unit/services/sample/hmc_nuts_dense_e_adapt_test.cpp @@ -184,3 +184,47 @@ TEST_F(ServicesSampleHmcNutsDenseEAdapt, output_regression) { EXPECT_EQ(1, logger.find_info("seconds (Total)")); EXPECT_EQ(0, logger.call_count_error()); } + +TEST_F(ServicesSampleHmcNutsDenseEAdapt, no_timestep_reset) { + unsigned int random_seed = 0; + unsigned int chain = 1; + double init_radius = 0; + int num_warmup = 70; + int num_samples = 100; + int num_thin = 5; + bool save_warmup = true; + int refresh = 0; + double stepsize = 0.1; + double stepsize_jitter = 0; + int max_depth = 8; + double delta = .1; + double gamma = .1; + double kappa = .1; + double t0 = .1; + unsigned int init_buffer = 0; + unsigned int term_buffer = 0; + unsigned int window = 10; + stan::test::unit::instrumented_interrupt interrupt; + EXPECT_EQ(interrupt.call_count(), 0); + + int return_code = stan::services::sample::hmc_nuts_dense_e_adapt( + model, context, random_seed, chain, init_radius, num_warmup, num_samples, + num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth, + delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt, + logger, init, parameter, diagnostic); + + std::vector string_values = parameter.string_values(); + bool found_step_size = false; + for (size_t i = 0; i < string_values.size(); i++) { + // Make sure the sampler wrote a Step size and that it is not reset to + // exactly 1 + if (string_values[i].compare("Step size")) { + found_step_size = true; + EXPECT_NE(string_values[i].compare("Step size = 1"), 0); + } + } + + EXPECT_TRUE(found_step_size); + + EXPECT_EQ(return_code, 0); +}