1 #ifndef STAN_MCMC_HMC_NUTS_BASE_NUTS_HPP 2 #define STAN_MCMC_HMC_NUTS_BASE_NUTS_HPP 5 #include <boost/math/special_functions/fpclassify.hpp> 6 #include <stan/math/prim/scal.hpp> 20 template <
class Model,
template<
class,
class>
class Hamiltonian,
21 template<
class>
class Integrator,
class BaseRNG>
25 :
base_hmc<Model, Hamiltonian, Integrator, BaseRNG>(model, rng),
62 Eigen::VectorXd p_sharp_plus = this->
hamiltonian_.dtau_dp(this->
z_);
63 Eigen::VectorXd p_sharp_dummy = p_sharp_plus;
64 Eigen::VectorXd p_sharp_minus = p_sharp_plus;
65 Eigen::VectorXd rho = this->
z_.p;
67 double log_sum_weight = 0;
70 double sum_metro_prob = 1;
78 Eigen::VectorXd rho_subtree = Eigen::VectorXd::Zero(rho.size());
79 bool valid_subtree =
false;
80 double log_sum_weight_subtree
81 = -std::numeric_limits<double>::infinity();
84 this->
z_.ps_point::operator=(z_plus);
87 p_sharp_dummy, p_sharp_plus, rho_subtree,
89 log_sum_weight_subtree, sum_metro_prob,
90 info_writer, error_writer);
91 z_plus.ps_point::operator=(this->
z_);
93 this->
z_.ps_point::operator=(z_minus);
96 p_sharp_dummy, p_sharp_minus, rho_subtree,
98 log_sum_weight_subtree, sum_metro_prob,
99 info_writer, error_writer);
100 z_minus.ps_point::operator=(this->
z_);
103 if (!valid_subtree)
break;
108 if (log_sum_weight_subtree > log_sum_weight) {
109 z_sample = z_propose;
112 = std::exp(log_sum_weight_subtree - log_sum_weight);
114 z_sample = z_propose;
118 = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree);
131 = sum_metro_prob /
static_cast<double>(n_leapfrog + 1);
133 this->
z_.ps_point::operator=(z_sample);
135 return sample(this->
z_.q, -this->z_.V, accept_prob);
139 names.push_back(
"stepsize__");
140 names.push_back(
"treedepth__");
141 names.push_back(
"n_leapfrog__");
142 names.push_back(
"divergent__");
143 names.push_back(
"energy__");
148 values.push_back(this->
depth_);
151 values.push_back(this->
energy_);
155 Eigen::VectorXd& p_sharp_plus,
156 Eigen::VectorXd& rho) {
157 return p_sharp_plus.dot(rho) > 0
158 && p_sharp_minus.dot(rho) > 0;
180 Eigen::VectorXd& p_sharp_left,
181 Eigen::VectorXd& p_sharp_right,
182 Eigen::VectorXd& rho,
183 double H0,
double sign,
int& n_leapfrog,
184 double& log_sum_weight,
double& sum_metro_prob,
191 info_writer, error_writer);
195 if (boost::math::isnan(h))
196 h = std::numeric_limits<double>::infinity();
200 log_sum_weight = math::log_sum_exp(log_sum_weight, H0 - h);
205 sum_metro_prob += std::exp(H0 - h);
207 z_propose = this->
z_;
211 p_sharp_right = p_sharp_left;
216 Eigen::VectorXd p_sharp_dummy(this->
z_.p.size());
219 double log_sum_weight_left = -std::numeric_limits<double>::infinity();
220 Eigen::VectorXd rho_left = Eigen::VectorXd::Zero(rho.size());
224 p_sharp_left, p_sharp_dummy, rho_left,
225 H0, sign, n_leapfrog,
226 log_sum_weight_left, sum_metro_prob,
227 info_writer, error_writer);
229 if (!valid_left)
return false;
234 double log_sum_weight_right = -std::numeric_limits<double>::infinity();
235 Eigen::VectorXd rho_right = Eigen::VectorXd::Zero(rho.size());
239 p_sharp_dummy, p_sharp_right, rho_right,
240 H0, sign, n_leapfrog,
241 log_sum_weight_right, sum_metro_prob,
242 info_writer, error_writer);
244 if (!valid_right)
return false;
247 double log_sum_weight_subtree
248 = math::log_sum_exp(log_sum_weight_left, log_sum_weight_right);
250 = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree);
252 if (log_sum_weight_right > log_sum_weight_subtree) {
253 z_propose = z_propose_right;
256 = std::exp(log_sum_weight_right - log_sum_weight_subtree);
258 z_propose = z_propose_right;
261 Eigen::VectorXd rho_subtree = rho_left + rho_right;
base_nuts(const Model &model, BaseRNG &rng)
Hamiltonian< Model, BaseRNG >::PointType z_
void sample(stan::mcmc::base_mcmc *sampler, int num_warmup, int num_samples, int num_thin, int refresh, bool save, stan::services::sample::mcmc_writer< Model, SampleRecorder, DiagnosticRecorder, MessageRecorder > &mcmc_writer, stan::mcmc::sample &init_s, Model &model, RNG &base_rng, const std::string &prefix, const std::string &suffix, std::ostream &o, StartTransitionCallback &callback, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
void get_sampler_params(std::vector< double > &values)
Probability, optimization and sampling library.
Point in a generic phase space.
bool build_tree(int depth, ps_point &z_propose, Eigen::VectorXd &p_sharp_left, Eigen::VectorXd &p_sharp_right, Eigen::VectorXd &rho, double H0, double sign, int &n_leapfrog, double &log_sum_weight, double &sum_metro_prob, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
Recursively build a new subtree to completion or until the subtree becomes invalid.
sample transition(sample &init_sample, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
void set_max_delta(double d)
The No-U-Turn sampler (NUTS) with multinomial sampling.
base_writer is an abstract base class defining the interface for Stan writer callbacks.
void seed(const Eigen::VectorXd &q)
virtual bool compute_criterion(Eigen::VectorXd &p_sharp_minus, Eigen::VectorXd &p_sharp_plus, Eigen::VectorXd &rho)
boost::uniform_01< BaseRNG & > rand_uniform_
void set_max_depth(int d)
void get_sampler_param_names(std::vector< std::string > &names)
double cont_params(int k) const
Hamiltonian< Model, BaseRNG > hamiltonian_
Integrator< Hamiltonian< Model, BaseRNG > > integrator_