Stan  2.14.0
probability, sampling & optimization
base_xhmc.hpp
Go to the documentation of this file.
1 #ifndef STAN_MCMC_HMC_NUTS_BASE_XHMC_HPP
2 #define STAN_MCMC_HMC_NUTS_BASE_XHMC_HPP
3 
5 #include <boost/math/special_functions/fpclassify.hpp>
8 #include <algorithm>
9 #include <cmath>
10 #include <limits>
11 #include <string>
12 #include <vector>
13 
14 namespace stan {
15  namespace mcmc {
/**
 * Pool two weighted running averages without leaving log space for
 * the weights.
 *
 * a1 and a2 are weighted averages whose normalizing constants (total
 * weights) are exp(log_w1) and exp(log_w2). On return, sum_a holds
 *   (w1 * a1 + w2 * a2) / (w1 + w2)
 * and log_sum_w holds log(w1 + w2). The larger log weight is factored
 * out before exponentiating, so the exponential argument is never
 * positive and cannot overflow.
 *
 * Equal log weights are pooled with a ratio of exactly 1. For finite
 * weights this matches exp(0); it also keeps the result well defined
 * when both log weights are the same infinity, where the difference
 * of the two would be NaN.
 *
 * Declared inline because the definition lives in a header; without
 * it, including the header from more than one translation unit
 * produces multiple definitions of the same function.
 *
 * @param a1 First running average
 * @param log_w1 Log of the total weight behind a1
 * @param a2 Second running average
 * @param log_w2 Log of the total weight behind a2
 * @param sum_a Output: pooled average
 * @param log_sum_w Output: log of the pooled total weight
 */
inline void stable_sum(double a1, double log_w1, double a2, double log_w2,
                       double& sum_a, double& log_sum_w) {
  if (log_w2 > log_w1) {
    // w2 dominates: express w1 relative to w2.
    const double e = std::exp(log_w1 - log_w2);
    sum_a = (e * a1 + a2) / (1 + e);
    log_sum_w = log_w2 + std::log(1 + e);
  } else {
    // w1 dominates or ties. A tie is handled explicitly so that
    // (-inf) - (-inf) never reaches std::exp.
    const double e = (log_w1 == log_w2) ? 1.0 : std::exp(log_w2 - log_w1);
    sum_a = (a1 + e * a2) / (1 + e);
    log_sum_w = log_w1 + std::log(1 + e);
  }
}
51 
56  template <class Model, template<class, class> class Hamiltonian,
57  template<class> class Integrator, class BaseRNG>
58  class base_xhmc : public base_hmc<Model, Hamiltonian, Integrator, BaseRNG> {
59  public:
60  base_xhmc(const Model& model, BaseRNG& rng)
61  : base_hmc<Model, Hamiltonian, Integrator, BaseRNG>(model, rng),
62  depth_(0), max_depth_(5), max_deltaH_(1000), x_delta_(0.1),
63  n_leapfrog_(0), divergent_(0), energy_(0) {
64  }
65 
67 
68  void set_max_depth(int d) {
69  if (d > 0)
70  max_depth_ = d;
71  }
72 
73  void set_max_deltaH(double d) {
74  max_deltaH_ = d;
75  }
76 
77  void set_x_delta(double d) {
78  if (d > 0)
79  x_delta_ = d;
80  }
81 
82  int get_max_depth() { return this->max_depth_; }
83  double get_max_deltaH() { return this->max_deltaH_; }
84  double get_x_delta() { return this->x_delta_; }
85 
/**
 * Run one XHMC transition starting from init_sample and return the
 * selected state as a new sample.
 *
 * The trajectory is grown by repeatedly appending a subtree of the
 * current depth in a randomly chosen time direction. After every
 * accepted subtree the running weighted average of dG_dt and the log
 * of the summed weights are pooled via stable_sum, and the subtree's
 * proposal replaces the current selection with probability
 * exp(log_sum_weight_subtree - log_sum_weight), where log_sum_weight
 * already includes the new subtree.
 *
 * Growth stops when max_depth_ is reached, when build_tree reports an
 * invalid subtree, or when |ave| falls below x_delta_.
 *
 * Side effects: updates z_, depth_, divergent_, n_leapfrog_ and
 * energy_.
 *
 * @param init_sample Sample whose continuous parameters seed the state
 * @return sample built from the selected position, the negated
 *         potential (-V) and the averaged Metropolis acceptance
 *         probability
 */
86  sample
87  transition(sample& init_sample,
// NOTE(review): the rest of the parameter list and the opening brace
// are missing from this listing; the body below uses info_writer and
// error_writer, so they are presumably declared there.
90  // Initialize the algorithm
91  this->sample_stepsize();
92 
93  this->seed(init_sample.cont_params());
94 
// presumably draws a fresh momentum and then initializes the
// Hamiltonian's cached quantities at the seeded point -- confirm
// against the Hamiltonian implementation
95  this->hamiltonian_.sample_p(this->z_, this->rand_int_);
96  this->hamiltonian_.init(this->z_, info_writer, error_writer);
97 
// Forward and backward ends of the trajectory; both start at the
// initial state.
98  ps_point z_plus(this->z_);
99  ps_point z_minus(z_plus);
100 
// Currently selected state and the proposal returned by each subtree.
101  ps_point z_sample(z_plus);
102  ps_point z_propose(z_plus);
103 
// The running average starts with the initial point alone, which has
// weight exp(H0 - H0) = 1.
104  double ave = this->hamiltonian_.dG_dt(this->z_,
105  info_writer, error_writer);
106  double log_sum_weight = 0; // log(exp(H0 - H0))
107 
108  double H0 = this->hamiltonian_.H(this->z_);
109  int n_leapfrog = 0;
110  double sum_metro_prob = 1; // exp(H0 - H0)
111 
112  // Build a trajectory until a subtree is invalid, the exhaustion
// criterion is met, or the maximum depth is reached
113  this->depth_ = 0;
114  this->divergent_ = 0;
115 
116  while (this->depth_ < this->max_depth_) {
117  // Build a new subtree in a random direction
118  bool valid_subtree = false;
119  double ave_subtree = 0;
// Zero total weight, expressed in log space.
120  double log_sum_weight_subtree
121  = -std::numeric_limits<double>::infinity();
122 
123  if (this->rand_uniform_() > 0.5) {
// Extend forward in time: resume from the forward end and store the
// new end point afterwards.
124  this->z_.ps_point::operator=(z_plus);
125  valid_subtree
126  = build_tree(this->depth_, z_propose,
127  ave_subtree, log_sum_weight_subtree,
128  H0, 1, n_leapfrog, sum_metro_prob,
129  info_writer, error_writer);
130  z_plus.ps_point::operator=(this->z_);
131  } else {
// Extend backward in time: same procedure from the backward end.
132  this->z_.ps_point::operator=(z_minus);
133  valid_subtree
134  = build_tree(this->depth_, z_propose,
135  ave_subtree, log_sum_weight_subtree,
136  H0, -1, n_leapfrog, sum_metro_prob,
137  info_writer, error_writer);
138  z_minus.ps_point::operator=(this->z_);
139  }
140 
// An invalid subtree is discarded entirely: it is neither pooled into
// the running average nor eligible to supply the next sample.
141  if (!valid_subtree) break;
142  stable_sum(ave, log_sum_weight,
143  ave_subtree, log_sum_weight_subtree,
144  ave, log_sum_weight);
145 
146  // Sample from an accepted subtree
147  ++(this->depth_);
148 
// log_sum_weight already contains the new subtree, so this is the
// subtree's share of the total trajectory weight.
149  double accept_prob
150  = std::exp(log_sum_weight_subtree - log_sum_weight);
151  if (this->rand_uniform_() < accept_prob)
152  z_sample = z_propose;
153 
154  // Break if exhaustion criterion is satisfied
155  if (std::fabs(ave) < x_delta_)
156  break;
157  }
158 
159  this->n_leapfrog_ = n_leapfrog;
160 
161  // Compute average acceptance probability across entire trajectory,
162  // even over subtrees that may have been rejected
// The "+ 1" accounts for the initial point, which contributed 1 to
// sum_metro_prob without a leapfrog step.
163  double accept_prob
164  = sum_metro_prob / static_cast<double>(n_leapfrog + 1);
165 
166  this->z_.ps_point::operator=(z_sample);
167  this->energy_ = this->hamiltonian_.H(this->z_);
168  return sample(this->z_.q, -this->z_.V, accept_prob);
169  }
170 
/**
 * Append the names of the sampler diagnostics to names, in the same
 * order in which get_sampler_params appends their values.
 *
 * Existing entries of names are preserved.
 *
 * @param names Vector that receives the five diagnostic names
 */
void get_sampler_param_names(std::vector<std::string>& names) {
  const char* const labels[] = {
    "stepsize__",
    "treedepth__",
    "n_leapfrog__",
    "divergent__",
    "energy__"
  };
  names.insert(names.end(),
               labels, labels + sizeof(labels) / sizeof(labels[0]));
}
178 
179  void get_sampler_params(std::vector<double>& values) {
180  values.push_back(this->epsilon_);
181  values.push_back(this->depth_);
182  values.push_back(this->n_leapfrog_);
183  values.push_back(this->divergent_);
184  values.push_back(this->energy_);
185  }
186 
/**
 * Recursively build a subtree of the given depth starting from the
 * current state z_, advancing z_ to the far end of the subtree.
 *
 * A depth-0 subtree is a single leapfrog step. A deeper subtree is
 * two consecutive subtrees of depth - 1 ("left" then "right"), built
 * in the same time direction. Construction stops early, returning an
 * invalid result, as soon as a divergence is detected or an inner
 * subtree is invalid.
 *
 * NOTE(review): declared to return int although every return
 * statement yields a boolean.
 *
 * @param depth Depth of the subtree to build
 * @param z_propose Output: state proposed from this subtree
 * @param ave In/out: weighted running average of dG_dt; every visited
 *        point is pooled into it
 * @param log_sum_weight In/out: log of the summed weights behind ave
 * @param H0 Hamiltonian of the trajectory's initial state
 * @param sign Time direction of integration (multiplies the step size)
 * @param n_leapfrog In/out: number of leapfrog steps taken so far
 * @param sum_metro_prob In/out: summed Metropolis acceptance
 *        probabilities min(1, exp(H0 - h)) of the visited points
 * @return Nonzero if the subtree is valid, zero otherwise
 */
203  int build_tree(int depth, ps_point& z_propose,
204  double& ave, double& log_sum_weight,
205  double H0, double sign, int& n_leapfrog,
206  double& sum_metro_prob,
// NOTE(review): the rest of the parameter list and the opening brace
// are missing from this listing; the body below uses info_writer and
// error_writer, so they are presumably declared there.
209  // Base case
210  if (depth == 0) {
// Take one integrator step of size epsilon_ in the requested direction.
211  this->integrator_.evolve(this->z_, this->hamiltonian_,
212  sign * this->epsilon_,
213  info_writer, error_writer);
214  ++n_leapfrog;
215 
// A NaN Hamiltonian is treated as infinite energy, which gives the
// point zero weight and triggers the divergence check below.
216  double h = this->hamiltonian_.H(this->z_);
217  if (boost::math::isnan(h))
218  h = std::numeric_limits<double>::infinity();
219 
220  if ((h - H0) > this->max_deltaH_) this->divergent_ = true;
221 
222  double dG_dt = this->hamiltonian_.dG_dt(this->z_,
223  info_writer,
224  error_writer);
225 
// Pool this point into the running average with log weight H0 - h.
226  stable_sum(ave, log_sum_weight,
227  dG_dt, H0 - h,
228  ave, log_sum_weight);
229 
// Accumulate min(1, exp(H0 - h)) without exponentiating a positive
// argument.
230  if (H0 - h > 0)
231  sum_metro_prob += 1;
232  else
233  sum_metro_prob += std::exp(H0 - h);
234 
235  z_propose = this->z_;
236 
// A single-step subtree is valid unless it diverged.
237  return !this->divergent_;
238  }
239  // General recursion
240 
241  // Build the left subtree
242  double ave_left = 0;
243  double log_sum_weight_left = -std::numeric_limits<double>::infinity();
244 
245  bool valid_left
246  = build_tree(depth - 1, z_propose,
247  ave_left, log_sum_weight_left,
248  H0, sign, n_leapfrog, sum_metro_prob,
249  info_writer, error_writer);
250 
// An invalid left half invalidates the whole subtree; its
// contribution is not pooled into the caller's average.
251  if (!valid_left) return false;
252  stable_sum(ave, log_sum_weight,
253  ave_left, log_sum_weight_left,
254  ave, log_sum_weight);
255 
256  // Build the right subtree
// The right half continues from where the left half ended (z_ was
// advanced by the recursive call above).
257  ps_point z_propose_right(this->z_);
258  double ave_right = 0;
259  double log_sum_weight_right = -std::numeric_limits<double>::infinity();
260 
261  bool valid_right
262  = build_tree(depth - 1, z_propose_right,
263  ave_right, log_sum_weight_right,
264  H0, sign, n_leapfrog, sum_metro_prob,
265  info_writer, error_writer);
266 
267  if (!valid_right) return false;
268  stable_sum(ave, log_sum_weight,
269  ave_right, log_sum_weight_right,
270  ave, log_sum_weight);
271 
272  // Multinomial sample from right subtree
// Statistics of this subtree alone (left + right), kept separate from
// the caller's running totals in ave / log_sum_weight.
273  double ave_subtree;
274  double log_sum_weight_subtree;
275  stable_sum(ave_left, log_sum_weight_left,
276  ave_right, log_sum_weight_right,
277  ave_subtree, log_sum_weight_subtree);
278 
// Keep the left proposal or switch to the right one in proportion to
// the right half's share of this subtree's weight.
279  double accept_prob
280  = std::exp(log_sum_weight_right - log_sum_weight_subtree);
281  if (this->rand_uniform_() < accept_prob)
282  z_propose = z_propose_right;
283 
// The subtree is valid only while its own average has not yet dropped
// below the exhaustion threshold.
284  return std::fabs(ave_subtree) >= x_delta_;
285  }
286 
// Sampler state and configuration.
// NOTE(review): max_depth_, n_leapfrog_ and divergent_ are initialized
// by the constructor and used by the methods of this class, but their
// declarations are not visible in this listing.

// Number of subtrees accepted (trajectory doublings) in the most
// recent transition.
287  int depth_;
// Energy error (h - H0) above which a leapfrog state is flagged as
// divergent.
289  double max_deltaH_;
// Exhaustion threshold compared against the magnitude of the weighted
// average of dG_dt.
290  double x_delta_;
291 
// Hamiltonian evaluated at the state selected by the most recent
// transition.
294  double energy_;
295  };
296 
297  } // mcmc
298 } // stan
299 #endif
Hamiltonian< Model, BaseRNG >::PointType z_
Definition: base_hmc.hpp:164
base_xhmc(const Model &model, BaseRNG &rng)
Definition: base_xhmc.hpp:60
sample transition(sample &init_sample, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
Definition: base_xhmc.hpp:87
Exhaustive Hamiltonian Monte Carlo (XHMC) with multinomial sampling.
Definition: base_xhmc.hpp:58
void sample(stan::mcmc::base_mcmc *sampler, int num_warmup, int num_samples, int num_thin, int refresh, bool save, stan::services::sample::mcmc_writer< Model, SampleRecorder, DiagnosticRecorder, MessageRecorder > &mcmc_writer, stan::mcmc::sample &init_s, Model &model, RNG &base_rng, const std::string &prefix, const std::string &suffix, std::ostream &o, StartTransitionCallback &callback, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
Definition: sample.hpp:17
Probability, optimization and sampling library.
Point in a generic phase space.
Definition: ps_point.hpp:17
void set_max_depth(int d)
Definition: base_xhmc.hpp:68
int build_tree(int depth, ps_point &z_propose, double &ave, double &log_sum_weight, double H0, double sign, int &n_leapfrog, double &sum_metro_prob, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
Recursively build a new subtree to completion or until the subtree becomes invalid.
Definition: base_xhmc.hpp:203
base_writer is an abstract base class defining the interface for Stan writer callbacks.
Definition: base_writer.hpp:20
void stable_sum(double a1, double log_w1, double a2, double log_w2, double &sum_a, double &log_sum_w)
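a1 and a2 are running averages of the form $a_k = \frac{\sum_{n \in N_k} w_n f_n}{\sum_{n \in N_k} w_n}$ for $k = 1, 2$, and the weights are the respective normalizing constants $w_k = \sum_{n \in N_k} w_n$...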
Definition: base_xhmc.hpp:39
void set_x_delta(double d)
Definition: base_xhmc.hpp:77
void get_sampler_params(std::vector< double > &values)
Definition: base_xhmc.hpp:179
void seed(const Eigen::VectorXd &q)
Definition: base_hmc.hpp:53
void get_sampler_param_names(std::vector< std::string > &names)
Definition: base_xhmc.hpp:171
boost::uniform_01< BaseRNG & > rand_uniform_
Definition: base_hmc.hpp:171
void set_max_deltaH(double d)
Definition: base_xhmc.hpp:73
double cont_params(int k) const
Definition: sample.hpp:24
Hamiltonian< Model, BaseRNG > hamiltonian_
Definition: base_hmc.hpp:166
Integrator< Hamiltonian< Model, BaseRNG > > integrator_
Definition: base_hmc.hpp:165

     [ Stan Home Page ] © 2011–2016, Stan Development Team.