1 #ifndef STAN_MCMC_HMC_NUTS_BASE_NUTS_HPP 2 #define STAN_MCMC_HMC_NUTS_BASE_NUTS_HPP 5 #include <boost/math/special_functions/fpclassify.hpp> 31 template <
class Model,
template<
class,
class>
class Hamiltonian,
32 template<
class>
class Integrator,
class BaseRNG>
34 public base_hmc<Model, Hamiltonian, Integrator, BaseRNG> {
37 base_hmc<Model, Hamiltonian, Integrator, BaseRNG>(model, rng),
38 depth_(0), max_depth_(5), max_delta_(1000),
39 n_leapfrog_(0), divergent_(0), energy_(0) {
61 this->sample_stepsize();
67 this->hamiltonian_.sample_p(this->z_, this->rand_int_);
68 this->hamiltonian_.init(this->z_, info_writer, error_writer);
78 Eigen::VectorXd rho_init = this->z_.p;
79 Eigen::VectorXd rho_plus(n_cont); rho_plus.setZero();
80 Eigen::VectorXd rho_minus(n_cont); rho_minus.setZero();
82 util.
H0 = this->hamiltonian_.H(this->z_);
85 util.
log_u = std::log(this->rand_uniform_());
97 while (util.
criterion && (this->depth_ <= this->max_depth_)) {
100 Eigen::VectorXd* rho = 0;
102 if (this->rand_uniform_() > 0.5) {
113 this->z_.ps_point::operator=(*z);
115 int n_valid_subtree = build_tree(depth_, *rho, 0, z_propose, util,
116 info_writer, error_writer);
125 double subtree_prob = 0;
128 subtree_prob =
static_cast<double>(n_valid_subtree) /
129 static_cast<double>(n_valid);
131 subtree_prob = n_valid_subtree ? 1 : 0;
134 if (this->rand_uniform_() < subtree_prob)
135 z_sample = z_propose;
137 n_valid += n_valid_subtree;
140 this->z_.ps_point::operator=(z_plus);
141 Eigen::VectorXd delta_rho = rho_minus + rho_init + rho_plus;
143 util.
criterion = compute_criterion(z_minus, this->z_, delta_rho);
146 this->n_leapfrog_ = util.
n_tree;
148 double accept_prob = util.
sum_prob /
static_cast<double>(util.
n_tree);
150 this->z_.ps_point::operator=(z_sample);
151 this->energy_ = this->hamiltonian_.H(this->z_);
152 return sample(this->z_.q, - this->z_.V, accept_prob);
156 names.push_back(
"stepsize__");
157 names.push_back(
"treedepth__");
158 names.push_back(
"n_leapfrog__");
159 names.push_back(
"divergent__");
160 names.push_back(
"energy__");
164 values.push_back(this->epsilon_);
165 values.push_back(this->depth_);
166 values.push_back(this->n_leapfrog_);
167 values.push_back(this->divergent_);
168 values.push_back(this->energy_);
171 virtual bool compute_criterion(
ps_point& start,
172 typename Hamiltonian<Model, BaseRNG>
174 Eigen::VectorXd& rho) = 0;
184 this->integrator_.evolve(this->z_, this->hamiltonian_,
185 util.
sign * this->epsilon_,
186 info_writer, error_writer);
189 if (z_init_parent) *z_init_parent = this->z_;
190 z_propose = this->z_;
192 double h = this->hamiltonian_.H(this->z_);
193 if (boost::math::isnan(h))
194 h = std::numeric_limits<double>::infinity();
197 if (!util.
criterion) ++(this->divergent_);
199 util.
sum_prob += std::min(1.0, std::exp(util.
H0 - h));
202 return (util.
log_u + (h - util.
H0) < 0);
206 Eigen::VectorXd left_subtree_rho(rho.size());
207 left_subtree_rho.setZero();
210 int n1 = build_tree(depth - 1, left_subtree_rho, &z_init,
212 info_writer, error_writer);
214 if (z_init_parent) *z_init_parent = z_init;
218 Eigen::VectorXd right_subtree_rho(rho.size());
219 right_subtree_rho.setZero();
222 int n2 = build_tree(depth - 1, right_subtree_rho, 0,
223 z_propose_right, util,
224 info_writer, error_writer);
226 double accept_prob =
static_cast<double>(n2) /
227 static_cast<double>(n1 + n2);
229 if ( util.
criterion && (this->rand_uniform_() < accept_prob) )
230 z_propose = z_propose_right;
232 Eigen::VectorXd& subtree_rho = left_subtree_rho;
233 subtree_rho += right_subtree_rho;
237 util.
criterion &= compute_criterion(z_init, this->z_, subtree_rho);
void sample(stan::mcmc::base_mcmc *sampler, int num_warmup, int num_samples, int num_thin, int refresh, bool save, stan::services::sample::mcmc_writer< Model, SampleRecorder, DiagnosticRecorder, MessageRecorder > &mcmc_writer, stan::mcmc::sample &init_s, Model &model, RNG &base_rng, const std::string &prefix, const std::string &suffix, std::ostream &o, StartTransitionCallback &callback, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
Probability, optimization and sampling library.
void set_max_depth(int d)
void set_max_delta(double d)
Point in a generic phase space.
int build_tree(int depth, Eigen::VectorXd &rho, ps_point *z_init_parent, ps_point &z_propose, nuts_util &util, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
base_writer is an abstract base class defining the interface for Stan writer callbacks.
base_nuts_classic(const Model &model, BaseRNG &rng)
sample transition(sample &init_sample, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
void get_sampler_param_names(std::vector< std::string > &names)
double cont_params(int k) const
void get_sampler_params(std::vector< double > &values)