Stan  2.14.0
probability, sampling & optimization
base_nuts.hpp
Go to the documentation of this file.
1 #ifndef STAN_MCMC_HMC_NUTS_BASE_NUTS_HPP
2 #define STAN_MCMC_HMC_NUTS_BASE_NUTS_HPP
3 
5 #include <boost/math/special_functions/fpclassify.hpp>
6 #include <stan/math/prim/scal.hpp>
9 #include <algorithm>
10 #include <cmath>
11 #include <limits>
12 #include <string>
13 #include <vector>
14 
15 namespace stan {
16  namespace mcmc {
20  template <class Model, template<class, class> class Hamiltonian,
21  template<class> class Integrator, class BaseRNG>
22  class base_nuts : public base_hmc<Model, Hamiltonian, Integrator, BaseRNG> {
23  public:
24  base_nuts(const Model& model, BaseRNG& rng)
25  : base_hmc<Model, Hamiltonian, Integrator, BaseRNG>(model, rng),
26  depth_(0), max_depth_(5), max_deltaH_(1000),
27  n_leapfrog_(0), divergent_(0), energy_(0) {
28  }
29 
31 
32  void set_max_depth(int d) {
33  if (d > 0)
34  max_depth_ = d;
35  }
36 
37  void set_max_delta(double d) {
38  max_deltaH_ = d;
39  }
40 
41  int get_max_depth() { return this->max_depth_; }
42  double get_max_delta() { return this->max_deltaH_; }
43 
44  sample
45  transition(sample& init_sample,
48  // Initialize the algorithm
49  this->sample_stepsize();
50 
51  this->seed(init_sample.cont_params());
52 
53  this->hamiltonian_.sample_p(this->z_, this->rand_int_);
54  this->hamiltonian_.init(this->z_, info_writer, error_writer);
55 
56  ps_point z_plus(this->z_);
57  ps_point z_minus(z_plus);
58 
59  ps_point z_sample(z_plus);
60  ps_point z_propose(z_plus);
61 
62  Eigen::VectorXd p_sharp_plus = this->hamiltonian_.dtau_dp(this->z_);
63  Eigen::VectorXd p_sharp_dummy = p_sharp_plus;
64  Eigen::VectorXd p_sharp_minus = p_sharp_plus;
65  Eigen::VectorXd rho = this->z_.p;
66 
67  double log_sum_weight = 0; // log(exp(H0 - H0))
68  double H0 = this->hamiltonian_.H(this->z_);
69  int n_leapfrog = 0;
70  double sum_metro_prob = 1; // exp(H0 - H0)
71 
72  // Build a trajectory until the NUTS criterion is no longer satisfied
73  this->depth_ = 0;
74  this->divergent_ = false;
75 
76  while (this->depth_ < this->max_depth_) {
77  // Build a new subtree in a random direction
78  Eigen::VectorXd rho_subtree = Eigen::VectorXd::Zero(rho.size());
79  bool valid_subtree = false;
80  double log_sum_weight_subtree
81  = -std::numeric_limits<double>::infinity();
82 
83  if (this->rand_uniform_() > 0.5) {
84  this->z_.ps_point::operator=(z_plus);
85  valid_subtree
86  = build_tree(this->depth_, z_propose,
87  p_sharp_dummy, p_sharp_plus, rho_subtree,
88  H0, 1, n_leapfrog,
89  log_sum_weight_subtree, sum_metro_prob,
90  info_writer, error_writer);
91  z_plus.ps_point::operator=(this->z_);
92  } else {
93  this->z_.ps_point::operator=(z_minus);
94  valid_subtree
95  = build_tree(this->depth_, z_propose,
96  p_sharp_dummy, p_sharp_minus, rho_subtree,
97  H0, -1, n_leapfrog,
98  log_sum_weight_subtree, sum_metro_prob,
99  info_writer, error_writer);
100  z_minus.ps_point::operator=(this->z_);
101  }
102 
103  if (!valid_subtree) break;
104 
105  // Sample from an accepted subtree
106  ++(this->depth_);
107 
108  if (log_sum_weight_subtree > log_sum_weight) {
109  z_sample = z_propose;
110  } else {
111  double accept_prob
112  = std::exp(log_sum_weight_subtree - log_sum_weight);
113  if (this->rand_uniform_() < accept_prob)
114  z_sample = z_propose;
115  }
116 
117  log_sum_weight
118  = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree);
119 
120  // Break when NUTS criterion is no longer satisfied
121  rho += rho_subtree;
122  if (!compute_criterion(p_sharp_minus, p_sharp_plus, rho))
123  break;
124  }
125 
126  this->n_leapfrog_ = n_leapfrog;
127 
128  // Compute average acceptance probabilty across entire trajectory,
129  // even over subtrees that may have been rejected
130  double accept_prob
131  = sum_metro_prob / static_cast<double>(n_leapfrog + 1);
132 
133  this->z_.ps_point::operator=(z_sample);
134  this->energy_ = this->hamiltonian_.H(this->z_);
135  return sample(this->z_.q, -this->z_.V, accept_prob);
136  }
137 
138  void get_sampler_param_names(std::vector<std::string>& names) {
139  names.push_back("stepsize__");
140  names.push_back("treedepth__");
141  names.push_back("n_leapfrog__");
142  names.push_back("divergent__");
143  names.push_back("energy__");
144  }
145 
146  void get_sampler_params(std::vector<double>& values) {
147  values.push_back(this->epsilon_);
148  values.push_back(this->depth_);
149  values.push_back(this->n_leapfrog_);
150  values.push_back(this->divergent_);
151  values.push_back(this->energy_);
152  }
153 
154  virtual bool compute_criterion(Eigen::VectorXd& p_sharp_minus,
155  Eigen::VectorXd& p_sharp_plus,
156  Eigen::VectorXd& rho) {
157  return p_sharp_plus.dot(rho) > 0
158  && p_sharp_minus.dot(rho) > 0;
159  }
160 
179  bool build_tree(int depth, ps_point& z_propose,
180  Eigen::VectorXd& p_sharp_left,
181  Eigen::VectorXd& p_sharp_right,
182  Eigen::VectorXd& rho,
183  double H0, double sign, int& n_leapfrog,
184  double& log_sum_weight, double& sum_metro_prob,
187  // Base case
188  if (depth == 0) {
189  this->integrator_.evolve(this->z_, this->hamiltonian_,
190  sign * this->epsilon_,
191  info_writer, error_writer);
192  ++n_leapfrog;
193 
194  double h = this->hamiltonian_.H(this->z_);
195  if (boost::math::isnan(h))
196  h = std::numeric_limits<double>::infinity();
197 
198  if ((h - H0) > this->max_deltaH_) this->divergent_ = true;
199 
200  log_sum_weight = math::log_sum_exp(log_sum_weight, H0 - h);
201 
202  if (H0 - h > 0)
203  sum_metro_prob += 1;
204  else
205  sum_metro_prob += std::exp(H0 - h);
206 
207  z_propose = this->z_;
208  rho += this->z_.p;
209 
210  p_sharp_left = this->hamiltonian_.dtau_dp(this->z_);
211  p_sharp_right = p_sharp_left;
212 
213  return !this->divergent_;
214  }
215  // General recursion
216  Eigen::VectorXd p_sharp_dummy(this->z_.p.size());
217 
218  // Build the left subtree
219  double log_sum_weight_left = -std::numeric_limits<double>::infinity();
220  Eigen::VectorXd rho_left = Eigen::VectorXd::Zero(rho.size());
221 
222  bool valid_left
223  = build_tree(depth - 1, z_propose,
224  p_sharp_left, p_sharp_dummy, rho_left,
225  H0, sign, n_leapfrog,
226  log_sum_weight_left, sum_metro_prob,
227  info_writer, error_writer);
228 
229  if (!valid_left) return false;
230 
231  // Build the right subtree
232  ps_point z_propose_right(this->z_);
233 
234  double log_sum_weight_right = -std::numeric_limits<double>::infinity();
235  Eigen::VectorXd rho_right = Eigen::VectorXd::Zero(rho.size());
236 
237  bool valid_right
238  = build_tree(depth - 1, z_propose_right,
239  p_sharp_dummy, p_sharp_right, rho_right,
240  H0, sign, n_leapfrog,
241  log_sum_weight_right, sum_metro_prob,
242  info_writer, error_writer);
243 
244  if (!valid_right) return false;
245 
246  // Multinomial sample from right subtree
247  double log_sum_weight_subtree
248  = math::log_sum_exp(log_sum_weight_left, log_sum_weight_right);
249  log_sum_weight
250  = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree);
251 
252  if (log_sum_weight_right > log_sum_weight_subtree) {
253  z_propose = z_propose_right;
254  } else {
255  double accept_prob
256  = std::exp(log_sum_weight_right - log_sum_weight_subtree);
257  if (this->rand_uniform_() < accept_prob)
258  z_propose = z_propose_right;
259  }
260 
261  Eigen::VectorXd rho_subtree = rho_left + rho_right;
262  rho += rho_subtree;
263 
264  return compute_criterion(p_sharp_left, p_sharp_right, rho_subtree);
265  }
266 
267  int depth_;
269  double max_deltaH_;
270 
273  double energy_;
274  };
275 
276  } // mcmc
277 } // stan
278 #endif
base_nuts(const Model &model, BaseRNG &rng)
Definition: base_nuts.hpp:24
Hamiltonian< Model, BaseRNG >::PointType z_
Definition: base_hmc.hpp:164
void sample(stan::mcmc::base_mcmc *sampler, int num_warmup, int num_samples, int num_thin, int refresh, bool save, stan::services::sample::mcmc_writer< Model, SampleRecorder, DiagnosticRecorder, MessageRecorder > &mcmc_writer, stan::mcmc::sample &init_s, Model &model, RNG &base_rng, const std::string &prefix, const std::string &suffix, std::ostream &o, StartTransitionCallback &callback, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
Definition: sample.hpp:17
void get_sampler_params(std::vector< double > &values)
Definition: base_nuts.hpp:146
Probability, optimization and sampling library.
Point in a generic phase space.
Definition: ps_point.hpp:17
bool build_tree(int depth, ps_point &z_propose, Eigen::VectorXd &p_sharp_left, Eigen::VectorXd &p_sharp_right, Eigen::VectorXd &rho, double H0, double sign, int &n_leapfrog, double &log_sum_weight, double &sum_metro_prob, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
Recursively build a new subtree to completion or until the subtree becomes invalid.
Definition: base_nuts.hpp:179
sample transition(sample &init_sample, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
Definition: base_nuts.hpp:45
void set_max_delta(double d)
Definition: base_nuts.hpp:37
The No-U-Turn sampler (NUTS) with multinomial sampling.
Definition: base_nuts.hpp:22
base_writer is an abstract base class defining the interface for Stan writer callbacks.
Definition: base_writer.hpp:20
void seed(const Eigen::VectorXd &q)
Definition: base_hmc.hpp:53
virtual bool compute_criterion(Eigen::VectorXd &p_sharp_minus, Eigen::VectorXd &p_sharp_plus, Eigen::VectorXd &rho)
Definition: base_nuts.hpp:154
boost::uniform_01< BaseRNG & > rand_uniform_
Definition: base_hmc.hpp:171
void set_max_depth(int d)
Definition: base_nuts.hpp:32
void get_sampler_param_names(std::vector< std::string > &names)
Definition: base_nuts.hpp:138
double cont_params(int k) const
Definition: sample.hpp:24
Hamiltonian< Model, BaseRNG > hamiltonian_
Definition: base_hmc.hpp:166
Integrator< Hamiltonian< Model, BaseRNG > > integrator_
Definition: base_hmc.hpp:165

     [ Stan Home Page ] © 2011–2016, Stan Development Team.