Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/stan/mcmc/covar_adaptation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ class covar_adaptation : public windowed_adaptation {
explicit covar_adaptation(int n)
: windowed_adaptation("covariance"), estimator_(n) {}

/**
* Return true if covariance was updated and adaptation is not finished
*
* @param covar Covariance
* @param q Last draw
*/
bool learn_covariance(Eigen::MatrixXd& covar, const Eigen::VectorXd& q) {
if (adaptation_window())
estimator_.add_sample(q);
Expand All @@ -30,11 +36,11 @@ class covar_adaptation : public windowed_adaptation {

estimator_.restart();

++adapt_window_counter_;
return true;
increment_window_counter();
return true && !finished();
}

++adapt_window_counter_;
increment_window_counter();
return false;
}

Expand Down
12 changes: 9 additions & 3 deletions src/stan/mcmc/var_adaptation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ class var_adaptation : public windowed_adaptation {
explicit var_adaptation(int n)
: windowed_adaptation("variance"), estimator_(n) {}

/**
* Return true if variance was updated and adaptation is not finished
*
* @param var Diagonal covariance
* @param q Last draw
*/
bool learn_variance(Eigen::VectorXd& var, const Eigen::VectorXd& q) {
if (adaptation_window())
estimator_.add_sample(q);
Expand All @@ -29,11 +35,11 @@ class var_adaptation : public windowed_adaptation {

estimator_.restart();

++adapt_window_counter_;
return true;
increment_window_counter();
return true && !finished();
}

++adapt_window_counter_;
increment_window_counter();
return false;
}

Expand Down
19 changes: 19 additions & 0 deletions src/stan/mcmc/windowed_adaptation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,25 @@ class windowed_adaptation : public base_adaptation {
}
}

/**
* Check if there is any more warmup left to do
*/
bool finished() {
if (adapt_window_counter_ + 1 >= num_warmup_) {
return true;
} else {
return false;
}
}

/**
* Increment the window counter and return the new value
*/
unsigned int increment_window_counter() {
adapt_window_counter_ += 1;
return adapt_window_counter_;
}

protected:
std::string estimator_name_;

Expand Down
17 changes: 14 additions & 3 deletions src/test/unit/mcmc/covar_adaptation_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,26 @@ TEST(McmcCovarAdaptation, learn_covariance) {
target_covar *= 1e-3 * 5.0 / (n_learn + 5.0);

stan::mcmc::covar_adaptation adapter(n);
adapter.set_window_params(50, 0, 0, n_learn, logger);
adapter.set_window_params(30, 0, 0, n_learn, logger);

for (int i = 0; i < n_learn; ++i)
adapter.learn_covariance(covar, q);
for (int i = 0; i < n_learn - 1; ++i) {
EXPECT_FALSE(adapter.learn_covariance(covar, q));
}
// Learn variance should return true at end of first window
EXPECT_TRUE(adapter.learn_covariance(covar, q));

for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
EXPECT_EQ(target_covar(i, j), covar(i, j));
}
}

// Make sure learn_variance doesn't return true after second window
// (adaptation finished)
for (int i = 0; i < 2 * n_learn; ++i) {
EXPECT_FALSE(adapter.learn_covariance(covar, q));
}
EXPECT_TRUE(adapter.finished());

EXPECT_EQ(0, logger.call_count());
}
16 changes: 13 additions & 3 deletions src/test/unit/mcmc/var_adaptation_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,23 @@ TEST(McmcVarAdaptation, learn_variance) {
target_var *= 1e-3 * 5.0 / (n_learn + 5.0);

stan::mcmc::var_adaptation adapter(n);
adapter.set_window_params(50, 0, 0, n_learn, logger);
adapter.set_window_params(30, 0, 0, n_learn, logger);

for (int i = 0; i < n_learn; ++i)
adapter.learn_variance(var, q);
for (int i = 0; i < n_learn - 1; ++i) {
EXPECT_FALSE(adapter.learn_variance(var, q));
}
// Learn variance should return true at end of first window
EXPECT_TRUE(adapter.learn_variance(var, q));

for (int i = 0; i < n; ++i)
EXPECT_EQ(target_var(i), var(i));

// Make sure learn_variance doesn't return true after second window
// (adaptation finished)
for (int i = 0; i < 2 * n_learn; ++i) {
EXPECT_FALSE(adapter.learn_variance(var, q));
}
EXPECT_TRUE(adapter.finished());

EXPECT_EQ(0, logger.call_count());
}
14 changes: 14 additions & 0 deletions src/test/unit/mcmc/windowed_adaptation_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,17 @@ TEST(McmcWindowedAdaptation, set_window_params3) {
ASSERT_EQ(0, logger.call_count());
ASSERT_EQ(0, logger.call_count_info());
}

TEST(McmcWindowedAdaptation, finished) {
stan::test::unit::instrumented_logger logger;

stan::mcmc::windowed_adaptation adapter("test");

adapter.set_window_params(1000, 75, 50, 25, logger);

for (size_t i = 0; i < 999; i++) {
EXPECT_FALSE(adapter.finished());
adapter.increment_window_counter();
}
EXPECT_TRUE(adapter.finished());
}