Stan  2.14.0
probability, sampling & optimization
function_signatures_def.hpp
Go to the documentation of this file.
1 #ifndef STAN_LANG_AST_SIGS_FUNCTION_SIGNATURES_DEF_HPP
2 #define STAN_LANG_AST_SIGS_FUNCTION_SIGNATURES_DEF_HPP
3 
4 #include <stan/lang/ast.hpp>
5 #include <limits>
6 #include <map>
7 #include <set>
8 #include <string>
9 #include <utility>
10 #include <vector>
11 
12 namespace stan {
13  namespace lang {
14 
16  if (sigs_ == 0) return;
17  delete sigs_;
18  sigs_ = 0;
19  }
20 
22  // TODO(carpenter): for threaded autodiff, requires double-check lock
23  if (!sigs_)
24  sigs_ = new function_signatures;
25  return *sigs_;
26  }
27 
29  const std::pair<std::string, function_signature_t>& name_sig) {
30  user_defined_set_.insert(name_sig);
31  }
32 
34  const std::pair<std::string, function_signature_t>& name_sig) {
35  return user_defined_set_.find(name_sig) != user_defined_set_.end();
36  }
37 
38  bool function_signatures::is_defined(const std::string& name,
39  const function_signature_t& sig) {
40  if (sigs_map_.find(name) == sigs_map_.end())
41  return false;
42  const std::vector<function_signature_t> sigs = sigs_map_[name];
43  for (size_t i = 0; i < sigs.size(); ++i)
44  if (sig.second == sigs[i].second)
45  return true;
46  return false;
47  }
48 
50  const {
51  using std::map;
52  using std::string;
53  using std::vector;
54  map<string, vector<function_signature_t> >::const_iterator it
55  = sigs_map_.find(fun);
56  if (it == sigs_map_.end())
57  return false;
58  const vector<function_signature_t> sigs = it->second;
59  for (size_t i = 0; i < sigs.size(); ++i) {
60  if (sigs[i].second.size() == 0
61  || sigs[i].second[0].base_type_ != INT_T)
62  return false;
63  }
64  return true;
65  }
66 
67  void function_signatures::add(const std::string& name,
68  const expr_type& result_type,
69  const std::vector<expr_type>& arg_types) {
70  sigs_map_[name].push_back(function_signature_t(result_type, arg_types));
71  }
72 
73  void function_signatures::add(const std::string& name,
74  const expr_type& result_type) {
75  std::vector<expr_type> arg_types;
76  add(name, result_type, arg_types);
77  }
78 
79  void function_signatures::add(const std::string& name,
80  const expr_type& result_type,
81  const expr_type& arg_type) {
82  std::vector<expr_type> arg_types;
83  arg_types.push_back(arg_type);
84  add(name, result_type, arg_types);
85  }
86 
87  void function_signatures::add(const std::string& name,
88  const expr_type& result_type,
89  const expr_type& arg_type1,
90  const expr_type& arg_type2) {
91  std::vector<expr_type> arg_types;
92  arg_types.push_back(arg_type1);
93  arg_types.push_back(arg_type2);
94  add(name, result_type, arg_types);
95  }
96 
97  void function_signatures::add(const std::string& name,
98  const expr_type& result_type,
99  const expr_type& arg_type1,
100  const expr_type& arg_type2,
101  const expr_type& arg_type3) {
102  std::vector<expr_type> arg_types;
103  arg_types.push_back(arg_type1);
104  arg_types.push_back(arg_type2);
105  arg_types.push_back(arg_type3);
106  add(name, result_type, arg_types);
107  }
108 
109  void function_signatures::add(const std::string& name,
110  const expr_type& result_type,
111  const expr_type& arg_type1,
112  const expr_type& arg_type2,
113  const expr_type& arg_type3,
114  const expr_type& arg_type4) {
115  std::vector<expr_type> arg_types;
116  arg_types.push_back(arg_type1);
117  arg_types.push_back(arg_type2);
118  arg_types.push_back(arg_type3);
119  arg_types.push_back(arg_type4);
120  add(name, result_type, arg_types);
121  }
122 
123  void function_signatures::add(const std::string& name,
124  const expr_type& result_type,
125  const expr_type& arg_type1,
126  const expr_type& arg_type2,
127  const expr_type& arg_type3,
128  const expr_type& arg_type4,
129  const expr_type& arg_type5) {
130  std::vector<expr_type> arg_types;
131  arg_types.push_back(arg_type1);
132  arg_types.push_back(arg_type2);
133  arg_types.push_back(arg_type3);
134  arg_types.push_back(arg_type4);
135  arg_types.push_back(arg_type5);
136  add(name, result_type, arg_types);
137  }
138 
139  void function_signatures::add(const std::string& name,
140  const expr_type& result_type,
141  const expr_type& arg_type1,
142  const expr_type& arg_type2,
143  const expr_type& arg_type3,
144  const expr_type& arg_type4,
145  const expr_type& arg_type5,
146  const expr_type& arg_type6) {
147  std::vector<expr_type> arg_types;
148  arg_types.push_back(arg_type1);
149  arg_types.push_back(arg_type2);
150  arg_types.push_back(arg_type3);
151  arg_types.push_back(arg_type4);
152  arg_types.push_back(arg_type5);
153  arg_types.push_back(arg_type6);
154  add(name, result_type, arg_types);
155  }
156 
157  void function_signatures::add(const std::string& name,
158  const expr_type& result_type,
159  const expr_type& arg_type1,
160  const expr_type& arg_type2,
161  const expr_type& arg_type3,
162  const expr_type& arg_type4,
163  const expr_type& arg_type5,
164  const expr_type& arg_type6,
165  const expr_type& arg_type7) {
166  std::vector<expr_type> arg_types;
167  arg_types.push_back(arg_type1);
168  arg_types.push_back(arg_type2);
169  arg_types.push_back(arg_type3);
170  arg_types.push_back(arg_type4);
171  arg_types.push_back(arg_type5);
172  arg_types.push_back(arg_type6);
173  arg_types.push_back(arg_type7);
174  add(name, result_type, arg_types);
175  }
176 
177  void function_signatures::add_nullary(const::std::string& name) {
178  add(name, DOUBLE_T);
179  }
180 
181  void function_signatures::add_unary(const::std::string& name) {
182  add(name, DOUBLE_T, DOUBLE_T);
183  }
184 
186  name) {
187  for (size_t i = 0; i < 8; ++i) {
188  add(name, expr_type(DOUBLE_T, i), expr_type(INT_T, i));
189  add(name, expr_type(DOUBLE_T, i), expr_type(DOUBLE_T, i));
190  add(name, expr_type(VECTOR_T, i), expr_type(VECTOR_T, i));
192  add(name, expr_type(MATRIX_T, i), expr_type(MATRIX_T, i));
193  }
194  }
195 
196  void function_signatures::add_binary(const::std::string& name) {
197  add(name, DOUBLE_T, DOUBLE_T, DOUBLE_T);
198  }
199 
200  void function_signatures::add_ternary(const::std::string& name) {
202  }
203 
204  void function_signatures::add_quaternary(const::std::string& name) {
206  }
207 
209  const std::vector<expr_type>& call_args,
210  const std::vector<expr_type>& sig_args) {
211  if (call_args.size() != sig_args.size()) {
212  return -1; // failure
213  }
214  int num_promotions = 0;
215  for (size_t i = 0; i < call_args.size(); ++i) {
216  if (call_args[i] == sig_args[i]) {
217  continue;
218  } else if (call_args[i].is_primitive_int()
219  && sig_args[i].is_primitive_double()) {
220  ++num_promotions;
221  } else {
222  return -1; // failed match
223  }
224  }
225  return num_promotions;
226  }
227 
228  int function_signatures::get_signature_matches(const std::string& name,
229  const std::vector<expr_type>& args,
230  function_signature_t& signature) {
231  if (!has_key(name)) return 0;
232  std::vector<function_signature_t> signatures = sigs_map_[name];
233  size_t min_promotions = std::numeric_limits<size_t>::max();
234  size_t num_matches = 0;
235  for (size_t i = 0; i < signatures.size(); ++i) {
236  signature = signatures[i];
237  int promotions = num_promotions(args, signature.second);
238  if (promotions < 0) continue; // no match
239  size_t promotions_ui = static_cast<size_t>(promotions);
240  if (promotions_ui < min_promotions) {
241  min_promotions = promotions_ui;
242  num_matches = 1;
243  } else if (promotions_ui == min_promotions) {
244  ++num_matches;
245  }
246  }
247  return num_matches;
248  }
249 
250 
251 
/**
 * Return true if the specified name is the internal function name of a
 * binary infix operator (for example "add" or "elt_divide").
 *
 * @param name internal function name
 * @return true if name names a binary operator
 */
bool is_binary_operator(const std::string& name) {
  static const char* const binary_ops[] = {
    "add", "subtract", "multiply", "divide", "modulus",
    "mdivide_left", "mdivide_right", "elt_multiply", "elt_divide"
  };
  const size_t n_ops = sizeof(binary_ops) / sizeof(binary_ops[0]);
  for (size_t k = 0; k < n_ops; ++k)
    if (name == binary_ops[k])
      return true;
  return false;
}
263 
/**
 * Return true if the specified name is the internal function name of a
 * unary prefix operator ("minus" or "logical_negation").
 */
bool is_unary_operator(const std::string& name) {
  if (name == "minus")
    return true;
  return name == "logical_negation";
}
268 
/**
 * Return true if the specified name is the internal function name of a
 * unary postfix operator; "transpose" is the only one.
 */
bool is_unary_postfix_operator(const std::string& name) {
  return name.compare("transpose") == 0;
}
272 
273  bool is_operator(const std::string& name) {
274  return is_binary_operator(name)
275  || is_unary_operator(name)
276  || is_unary_postfix_operator(name);
277  }
278 
279 
280 
281 
/**
 * Return the source-language symbol for an operator's internal function
 * name, or "ERROR" if the name is not an operator name.  Note that both
 * "divide" and "mdivide_right" map to "/".
 *
 * @param name internal function name
 * @return operator symbol, or "ERROR"
 */
std::string fun_name_to_operator(const std::string& name) {
  // {internal name, symbol}.  pow and unary + are handled by the parser
  // and so have no entry here.
  static const char* const table[][2] = {
    // binary infix
    {"add", "+"},
    {"subtract", "-"},
    {"multiply", "*"},
    {"divide", "/"},
    {"modulus", "%"},
    {"mdivide_left", "\\"},
    {"mdivide_right", "/"},
    {"elt_multiply", ".*"},
    {"elt_divide", "./"},
    // unary prefix
    {"minus", "-"},
    {"logical_negation", "!"},
    // unary postfix
    {"transpose", "'"}
  };
  const size_t n_entries = sizeof(table) / sizeof(table[0]);
  for (size_t k = 0; k < n_entries; ++k) {
    if (name == table[k][0])
      return table[k][1];
  }
  return "ERROR";
}
304 
305  void print_signature(const std::string& name,
306  const std::vector<expr_type>& arg_types,
307  bool sampling_error_style,
308  std::ostream& msgs) {
309  static size_t OP_SIZE = std::string("operator").size();
310  msgs << " ";
311  if (name.size() > OP_SIZE && name.substr(0, OP_SIZE) == "operator") {
312  std::string operator_name = name.substr(OP_SIZE);
313  if (arg_types.size() == 2) {
314  msgs << arg_types[0] << " " << operator_name << " " << arg_types[1]
315  << std::endl;
316  return;
317  } else if (arg_types.size() == 1) {
318  if (operator_name == "'") // exception for postfix
319  msgs << arg_types[0] << operator_name << std::endl;
320  else
321  msgs << operator_name << arg_types[0] << std::endl;
322  return;
323  } else {
324  // should not be reachable due to operator grammar
325  // continue on purpose to get more info to user if this happens
326  msgs << "Operators must have 1 or 2 arguments." << std::endl;
327  }
328  }
329  if (sampling_error_style && arg_types.size() > 0)
330  msgs << arg_types[0] << " ~ ";
331  msgs << name << "(";
332  size_t start = sampling_error_style ? 1 : 0;
333  for (size_t j = start; j < arg_types.size(); ++j) {
334  if (j > start) msgs << ", ";
335  msgs << arg_types[j];
336  }
337  msgs << ")" << std::endl;
338  }
339 
341  const std::vector<expr_type>& args,
342  std::ostream& error_msgs,
343  bool sampling_error_style) {
344  std::vector<function_signature_t> signatures = sigs_map_[name];
345  size_t match_index = 0;
346  size_t min_promotions = std::numeric_limits<size_t>::max();
347  size_t num_matches = 0;
348 
349  std::string display_name;
350  if (is_operator(name)) {
351  display_name = "operator" + fun_name_to_operator(name);
352  } else if (sampling_error_style && ends_with("_log", name)) {
353  display_name = name.substr(0, name.size() - 4);
354  } else if (sampling_error_style
355  && (ends_with("_lpdf", name) || ends_with("_lcdf", name))) {
356  display_name = name.substr(0, name.size() - 5);
357  } else {
358  display_name = name;
359  }
360 
361  for (size_t i = 0; i < signatures.size(); ++i) {
362  int promotions = num_promotions(args, signatures[i].second);
363  if (promotions < 0) continue; // no match
364  size_t promotions_ui = static_cast<size_t>(promotions);
365  if (promotions_ui < min_promotions) {
366  min_promotions = promotions_ui;
367  match_index = i;
368  num_matches = 1;
369  } else if (promotions_ui == min_promotions) {
370  ++num_matches;
371  }
372  }
373 
374  if (num_matches == 1)
375  return signatures[match_index].first;
376 
377  // all returns after here are for ill-typed input
378 
379  if (num_matches == 0) {
380  error_msgs << "No matches for: "
381  << std::endl << std::endl;
382  } else {
383  error_msgs << "Ambiguous: "
384  << num_matches << " matches with "
385  << min_promotions << " integer promotions for: "
386  << std::endl;
387  }
388  print_signature(display_name, args, sampling_error_style, error_msgs);
389 
390  if (signatures.size() == 0) {
391  error_msgs << std::endl
392  << (sampling_error_style ? "Distribution " : "Function ")
393  << display_name << " not found.";
394  if (sampling_error_style)
395  error_msgs << " Require function with _lpdf or _lpmf or _log suffix";
396  error_msgs << std::endl;
397  } else {
398  error_msgs << std::endl
399  << "Available argument signatures for "
400  << display_name << ":" << std::endl << std::endl;
401 
402  for (size_t i = 0; i < signatures.size(); ++i) {
403  print_signature(display_name, signatures[i].second,
404  sampling_error_style, error_msgs);
405  }
406  error_msgs << std::endl;
407  }
408  return expr_type(); // ill-formed dummy
409  }
410 
/**
 * Construct the signature table.  The constructor body is the textual
 * contents of the included header, which is expanded in place here.
 * NOTE(review): that header presumably consists of add(...) /
 * add_unary(...)-style registration calls for the built-in functions;
 * confirm against stan/lang/function_signatures.h.
 */
function_signatures::function_signatures() {
#include <stan/lang/function_signatures.h> // NOLINT
}
414 
415  bool function_signatures::has_user_defined_key(const std::string& key)
416  const {
417  using std::pair;
418  using std::set;
419  using std::string;
420  for (set<pair<string, function_signature_t> >::const_iterator
421  it = user_defined_set_.begin();
422  it != user_defined_set_.end();
423  ++it) {
424  if (it->first == key)
425  return true;
426  }
427  return false;
428  }
429 
430  std::set<std::string> function_signatures::key_set() const {
431  using std::map;
432  using std::set;
433  using std::string;
434  using std::vector;
435  set<string> result;
436  for (map<string, vector<function_signature_t> >::const_iterator
437  it = sigs_map_.begin();
438  it != sigs_map_.end();
439  ++it)
440  result.insert(it->first);
441  return result;
442  }
443 
444  bool function_signatures::has_key(const std::string& key) const {
445  return sigs_map_.find(key) != sigs_map_.end();
446  }
447 
453  function_signatures* function_signatures::sigs_ = 0;
454  }
455 }
456 #endif
expr_type get_result_type(const std::string &name, const std::vector< expr_type > &args, std::ostream &error_msgs, bool sampling_error_style=false)
Return the result expression type resulting from applying a function of the specified name and argum...
void add(const std::string &name, const expr_type &result_type, const std::vector< expr_type > &arg_types)
Add a built-in function with the specified name, result type, and argument types.
const int ROW_VECTOR_T
Row vector type; scalar type is real.
bool is_operator(const std::string &name)
Probability, optimization and sampling library.
static function_signatures & instance()
Return the instance of this singleton.
void add(const std::string &name, const expr_type &result_type, const expr_type &arg_type1)
Add a built-in function with the specified name, result type, and argument types.
bool is_unary_postfix_operator(const std::string &name)
This class is a singleton used to store the available functions in the Stan object language and their...
const int DOUBLE_T
Real scalar type.
int num_promotions(const std::vector< expr_type > &call_args, const std::vector< expr_type > &sig_args)
Return the number of integer to real promotions required to convert the specified call arguments to t...
Structure for function application.
Definition: fun.hpp:17
int get_signature_matches(const std::string &name, const std::vector< expr_type > &args, function_signature_t &signature)
Return the number of declared function signatures that match the specified name, argument types...
Structure of the type of an expression, which consists of a base type and a number of dimensions...
Definition: expr_type.hpp:14
void add_binary(const ::std::string &name)
Add a built-in function with the specified name, a real return type, and two real arguments...
bool has_user_defined_key(const std::string &name) const
Return true if the specified name is the name of a user-defined function.
bool is_unary_operator(const std::string &name)
void add_unary(const ::std::string &name)
Add a built-in function with the specified name, a real return type, and a single real argument...
static void reset_sigs()
Reset the signature singleton to contain no instances.
bool discrete_first_arg(const std::string &name) const
Return true if all of the function signatures for functions with the specified name have integer base...
void add_nullary(const ::std::string &name)
Add a built-in function with the specified name, a real return type, and no arguments.
bool ends_with(const std::string &suffix, const std::string &s)
Returns true if the specified suffix appears at the end of the specified string.
bool is_defined(const std::string &name, const function_signature_t &sig)
Return true if the specified function name is defined for the specified signature.
void add_quaternary(const ::std::string &name)
Add a built-in function with the specified name, a real return type, and four real arguments...
bool is_user_defined(const std::pair< std::string, function_signature_t > &name_sig)
Return true if the specified name and signature have been added as user-defined functions.
void print_signature(const std::string &name, const std::vector< expr_type > &arg_types, bool sampling_error_style, std::ostream &msgs)
std::pair< expr_type, std::vector< expr_type > > function_signature_t
The type of a function signature, mapping a vector of argument expression types to a result expressio...
bool is_binary_operator(const std::string &name)
const int INT_T
Integer type.
void add_ternary(const ::std::string &name)
Add a built-in function with the specified name, a real return type, and three real arguments...
std::set< std::string > key_set() const
Return the set of function names defined.
const int VECTOR_T
Column vector type; scalar type is real.
bool has_key(const std::string &key) const
Return true if specified key is the name of a declared function.
void set_user_defined(const std::pair< std::string, function_signature_t > &name_sig)
Set the specified name and signature to be a user-defined function.
void add_unary_vectorized(const ::std::string &name)
Add built-in functions for all the vectorized forms of a unary function with the specified name and a ...
std::string fun_name_to_operator(const std::string &name)
const int MATRIX_T
Matrix type; scalar type is real.

     [ Stan Home Page ] © 2011–2016, Stan Development Team.