1 #ifndef STAN_VARIATIONAL_ADVI_HPP 2 #define STAN_VARIATIONAL_ADVI_HPP 4 #include <stan/math.hpp> 13 #include <boost/circular_buffer.hpp> 14 #include <boost/lexical_cast.hpp> 25 namespace variational {
38 template <
class Model,
class Q,
class BaseRNG>
57 Eigen::VectorXd& cont_params,
59 int n_monte_carlo_grad,
60 int n_monte_carlo_elbo,
62 int n_posterior_samples)
70 static const char*
function =
"stan::variational::advi";
71 math::check_positive(
function,
72 "Number of Monte Carlo samples for gradients",
74 math::check_positive(
function,
75 "Number of Monte Carlo samples for ELBO",
77 math::check_positive(
function,
78 "Evaluate ELBO at every eval_elbo iteration",
80 math::check_positive(
function,
81 "Number of posterior samples for output",
102 static const char*
function =
103 "stan::variational::advi::calc_ELBO";
106 int dim = variational.dimension();
107 Eigen::VectorXd zeta(dim);
109 int n_dropped_evaluations = 0;
111 variational.sample(
rng_, zeta);
113 std::stringstream ss;
114 double log_prob =
model_.template log_prob<false, true>(zeta, &ss);
115 if (ss.str().length() > 0)
116 message_writer(ss.str());
117 stan::math::check_finite(
function,
"log_prob", log_prob);
120 }
catch (
const std::domain_error& e) {
121 ++n_dropped_evaluations;
122 if (n_dropped_evaluations >= n_monte_carlo_elbo_) {
123 const char* name =
"The number of dropped evaluations";
124 const char* msg1 =
"has reached its maximum amount (";
125 const char* msg2 =
"). Your model may be either severely " 126 "ill-conditioned or misspecified.";
127 stan::math::domain_error(
function, name, n_monte_carlo_elbo_,
133 elbo += variational.entropy();
150 static const char*
function =
151 "stan::variational::advi::calc_ELBO_grad";
153 stan::math::check_size_match(
function,
154 "Dimension of elbo_grad",
155 elbo_grad.dimension(),
156 "Dimension of variational q",
157 variational.dimension());
158 stan::math::check_size_match(
function,
159 "Dimension of variational q",
160 variational.dimension(),
161 "Dimension of variables in model",
164 variational.calc_grad(elbo_grad,
182 int adapt_iterations,
185 static const char*
function =
"stan::variational::advi::adapt_eta";
187 stan::math::check_positive(
function,
188 "Number of adaptation iterations",
191 message_writer(
"Begin eta adaptation.");
194 const int eta_sequence_size = 5;
195 double eta_sequence[eta_sequence_size] = {100, 10, 1, 0.1, 0.01};
198 double elbo = -std::numeric_limits<double>::max();
199 double elbo_best = -std::numeric_limits<double>::max();
202 elbo_init =
calc_ELBO(variational, message_writer);
203 }
catch (
const std::domain_error& e) {
204 const char* name =
"Cannot compute ELBO using the initial " 205 "variational distribution.";
206 const char* msg1 =
"Your model may be either " 207 "severely ill-conditioned or misspecified.";
208 stan::math::domain_error(
function, name,
"", msg1);
212 Q elbo_grad = Q(
model_.num_params_r());
215 Q history_grad_squared = Q(
model_.num_params_r());
217 double pre_factor = 0.9;
218 double post_factor = 0.1;
220 double eta_best = 0.0;
224 bool do_more_tuning =
true;
225 int eta_sequence_index = 0;
226 while (do_more_tuning) {
228 eta = eta_sequence[eta_sequence_index];
230 int print_progress_m;
231 for (
int iter_tune = 1; iter_tune <= adapt_iterations; ++iter_tune) {
232 print_progress_m = eta_sequence_index
233 * adapt_iterations + iter_tune;
236 adapt_iterations * eta_sequence_size,
237 adapt_iterations,
true,
"",
"", message_writer);
243 }
catch (
const std::domain_error& e) {
244 elbo_grad.set_to_zero();
248 if (iter_tune == 1) {
249 history_grad_squared += elbo_grad.square();
251 history_grad_squared = pre_factor * history_grad_squared
252 + post_factor * elbo_grad.square();
254 eta_scaled = eta / sqrt(static_cast<double>(iter_tune));
256 variational += eta_scaled * elbo_grad
257 / (tau + history_grad_squared.sqrt());
262 elbo =
calc_ELBO(variational, message_writer);
263 }
catch (
const std::domain_error& e) {
264 elbo = -std::numeric_limits<double>::max();
270 if (elbo < elbo_best && elbo_best > elbo_init) {
271 std::stringstream ss;
273 <<
" Found best value [eta = " << eta_best
275 if (eta_sequence_index < eta_sequence_size - 1)
276 ss << (
" earlier than expected.");
279 message_writer(ss.str());
281 do_more_tuning =
false;
283 if (eta_sequence_index < eta_sequence_size - 1) {
290 if (elbo > elbo_init) {
291 std::stringstream ss;
293 <<
" Found best value [eta = " << eta_best
295 message_writer(ss.str());
298 do_more_tuning =
false;
300 const char* name =
"All proposed step-sizes";
301 const char* msg1 =
"failed. Your model may be either " 302 "severely ill-conditioned or misspecified.";
303 stan::math::domain_error(
function, name,
"", msg1);
307 history_grad_squared.set_to_zero();
309 ++eta_sequence_index;
334 static const char*
function =
335 "stan::variational::advi::stochastic_gradient_ascent";
337 stan::math::check_positive(
function,
"Eta stepsize", eta);
338 stan::math::check_positive(
function,
339 "Relative objective function tolerance",
341 stan::math::check_positive(
function,
342 "Maximum iterations",
346 Q elbo_grad = Q(
model_.num_params_r());
349 Q history_grad_squared = Q(
model_.num_params_r());
351 double pre_factor = 0.9;
352 double post_factor = 0.1;
357 double elbo_best = -std::numeric_limits<double>::max();
358 double elbo_prev = -std::numeric_limits<double>::max();
359 double delta_elbo = std::numeric_limits<double>::max();
360 double delta_elbo_ave = std::numeric_limits<double>::max();
361 double delta_elbo_med = std::numeric_limits<double>::max();
365 =
static_cast<int>(std::max(0.1 * max_iterations /
eval_elbo_,
367 boost::circular_buffer<double> elbo_diff(cb_size);
369 message_writer(
"Begin stochastic gradient ascent.");
370 message_writer(
" iter" 377 clock_t start = clock();
382 bool do_more_iterations =
true;
383 for (
int iter_counter = 1; do_more_iterations; ++iter_counter) {
388 if (iter_counter == 1) {
389 history_grad_squared += elbo_grad.square();
391 history_grad_squared = pre_factor * history_grad_squared
392 + post_factor * elbo_grad.square();
394 eta_scaled = eta / sqrt(static_cast<double>(iter_counter));
397 variational += eta_scaled * elbo_grad
398 / (tau + history_grad_squared.sqrt());
403 elbo =
calc_ELBO(variational, message_writer);
404 if (elbo > elbo_best)
407 elbo_diff.push_back(delta_elbo);
408 delta_elbo_ave = std::accumulate(elbo_diff.begin(),
409 elbo_diff.end(), 0.0)
410 /
static_cast<double>(elbo_diff.size());
412 std::stringstream ss;
414 << std::setw(4) << iter_counter
416 << std::right << std::setw(9) << std::setprecision(1)
419 << std::setw(16) << std::fixed << std::setprecision(3)
422 << std::setw(15) << std::fixed << std::setprecision(3)
426 delta_t =
static_cast<double>(end - start) / CLOCKS_PER_SEC;
428 std::vector<double> print_vector;
429 print_vector.clear();
430 print_vector.push_back(iter_counter);
431 print_vector.push_back(delta_t);
432 print_vector.push_back(elbo);
433 diagnostic_writer(print_vector);
435 if (delta_elbo_ave < tol_rel_obj) {
436 ss <<
" MEAN ELBO CONVERGED";
437 do_more_iterations =
false;
440 if (delta_elbo_med < tol_rel_obj) {
441 ss <<
" MEDIAN ELBO CONVERGED";
442 do_more_iterations =
false;
446 if (delta_elbo_med > 0.5 || delta_elbo_ave > 0.5) {
447 ss <<
" MAY BE DIVERGING... INSPECT ELBO";
451 message_writer(ss.str());
453 if (do_more_iterations ==
false &&
455 message_writer(
"Informational Message: The ELBO at a previous " 456 "iteration is larger than the ELBO upon " 458 message_writer(
"This variational approximation may not " 459 "have converged to a good optimum.");
463 if (iter_counter == max_iterations) {
464 message_writer(
"Informational Message: The maximum number of " 465 "iterations is reached! The algorithm may not have " 467 message_writer(
"This variational approximation is not " 468 "guaranteed to be meaningful.");
469 do_more_iterations =
false;
486 int run(
double eta,
bool adapt_engaged,
int adapt_iterations,
487 double tol_rel_obj,
int max_iterations,
492 diagnostic_writer(
"iter,time_in_seconds,ELBO");
498 eta =
adapt_eta(variational, adapt_iterations, message_writer);
499 parameter_writer(
"Stepsize adaptation complete.");
500 std::stringstream ss;
501 ss <<
"eta = " << eta;
502 parameter_writer(ss.str());
506 tol_rel_obj, max_iterations,
507 message_writer, diagnostic_writer);
514 std::vector<int> disc_vector;
517 0, cont_vector, disc_vector,
522 std::stringstream ss;
523 ss <<
"Drawing a sample of size " 525 <<
" from the approximate posterior... ";
526 message_writer(ss.str());
534 0, cont_vector, disc_vector,
538 message_writer(
"COMPLETED.");
553 std::vector<double> v;
554 for (boost::circular_buffer<double>::const_iterator i = cb.begin();
555 i != cb.end(); ++i) {
559 size_t n = v.size() / 2;
560 std::nth_element(v.begin(), v.begin()+n, v.end());
572 return std::fabs((curr - prev) / prev);
double adapt_eta(Q &variational, int adapt_iterations, interface_callbacks::writer::base_writer &message_writer) const
Heuristic grid search to adapt eta to the scale of the problem.
Automatic Differentiation Variational Inference.
Probability, optimization and sampling library.
void write_iteration(Model &model, RNG &base_rng, double lp, std::vector< double > &cont_vector, std::vector< int > &disc_vector, interface_callbacks::writer::base_writer &message_writer, interface_callbacks::writer::base_writer &parameter_writer)
void calc_ELBO_grad(const Q &variational, Q &elbo_grad, interface_callbacks::writer::base_writer &message_writer) const
Calculates the "black box" gradient of the ELBO.
base_writer is an abstract base class defining the interface for Stan writer callbacks.
double rel_difference(double prev, double curr) const
Compute the relative difference between two double values.
advi(Model &m, Eigen::VectorXd &cont_params, BaseRNG &rng, int n_monte_carlo_grad, int n_monte_carlo_elbo, int eval_elbo, int n_posterior_samples)
Constructor.
double calc_ELBO(const Q &variational, interface_callbacks::writer::base_writer &message_writer) const
Calculates the Evidence Lower BOund (ELBO) by sampling from the variational distribution and then evaluating the log joint density of the model at the transformed samples.
void stochastic_gradient_ascent(Q &variational, double eta, double tol_rel_obj, int max_iterations, interface_callbacks::writer::base_writer &message_writer, interface_callbacks::writer::base_writer &diagnostic_writer) const
Runs stochastic gradient ascent with an adaptive stepsize sequence.
Eigen::VectorXd & cont_params_
void print_progress(int m, int start, int finish, int refresh, bool tune, const std::string &prefix, const std::string &suffix, interface_callbacks::writer::base_writer &writer)
Helper function for printing progress for variational inference.
int run(double eta, bool adapt_engaged, int adapt_iterations, double tol_rel_obj, int max_iterations, interface_callbacks::writer::base_writer &message_writer, interface_callbacks::writer::base_writer &parameter_writer, interface_callbacks::writer::base_writer &diagnostic_writer) const
Runs ADVI and writes to output.
double circ_buff_median(const boost::circular_buffer< double > &cb) const
Compute the median of a circular buffer.