
Update integrate_1d to use variadic autodiff stuff internally in preparation for closures #2397

Merged Mar 31, 2021. 28 commits (this view shows changes from 11 commits).
6e0f680  Saving work (bbbales2, Nov 22, 2020)
340e62c  integrate_1d_new working (variadic integrate_1d) (Issue #2110) (bbbales2, Nov 23, 2020)
767ebb6  Merge remote-tracking branch 'origin/develop' into feature/variadic-i… (bbbales2, Feb 24, 2021)
f84f7eb  Bit of cleanup and added adapter file (Issue #2197) (bbbales2, Feb 24, 2021)
5019637  Renamed tests (Issue #2197) (bbbales2, Feb 24, 2021)
403987f  Turned on reverse mode tests (Issue #2197) (bbbales2, Feb 24, 2021)
7f53b90  Merge remote-tracking branch 'origin/develop' into feature/variadic-i… (bbbales2, Feb 27, 2021)
132cd2c  [Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0… (stan-buildbot, Feb 27, 2021)
59d1099  Switch binds to lambdas (bbbales2, Mar 25, 2021)
989cb49  Merge commit 'a426eea0ec9d9a7547061bc776a08e509d3406f3' into HEAD (yashikno, Mar 25, 2021)
a6cebfc  [Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0… (stan-buildbot, Mar 25, 2021)
abaa602  use double nested reverse pass to save making N tuple copies (SteveBronder, Mar 25, 2021)
82624e8  Update stan/math/rev/functor/integrate_1d.hpp (bbbales2, Mar 29, 2021)
3469fa1  Update stan/math/prim/functor/integrate_1d.hpp (bbbales2, Mar 29, 2021)
eaaeeb2  Updated docs (bbbales2, Mar 29, 2021)
95b4032  Update stan/math/rev/functor/integrate_1d.hpp (bbbales2, Mar 29, 2021)
485e089  Merge branch 'feature/variadic-integrate-1d' of github.com:stan-dev/m… (bbbales2, Mar 29, 2021)
c4540c5  Reordered if (bbbales2, Mar 29, 2021)
e4eb66f  Merge remote-tracking branch 'origin/review/variadic-integrate-1d' in… (bbbales2, Mar 29, 2021)
aba98b6  Put error checks into functions (bbbales2, Mar 29, 2021)
231b0e9  Merge commit 'f390d823261266c2a0999eeaedcd5ac216f857b3' into HEAD (yashikno, Mar 29, 2021)
77164cc  [Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0… (stan-buildbot, Mar 29, 2021)
4e882f2  small cleanups and use get the adjoint in each integration by a looku… (SteveBronder, Mar 29, 2021)
a2e8e57  Merge remote-tracking branch 'origin/develop' into feature/variadic-i… (bbbales2, Mar 30, 2021)
36847d9  Merge remote-tracking branch 'origin/review/integrate-1d-variadic-2' … (bbbales2, Mar 30, 2021)
cd1b0cb  Use varis to get nth gradient (bbbales2, Mar 30, 2021)
9360084  [Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0… (stan-buildbot, Mar 30, 2021)
49ba955  remove changes to math::get() now that we don't need it in integrate1d (SteveBronder, Mar 30, 2021)
1 change: 1 addition & 0 deletions stan/math/prim/functor.hpp
@@ -10,6 +10,7 @@
#include <stan/math/prim/functor/finite_diff_gradient_auto.hpp>
#include <stan/math/prim/functor/for_each.hpp>
#include <stan/math/prim/functor/integrate_1d.hpp>
#include <stan/math/prim/functor/integrate_1d_adapter.hpp>
#include <stan/math/prim/functor/integrate_ode_rk45.hpp>
#include <stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp>
#include <stan/math/prim/functor/ode_ckrk.hpp>
58 changes: 41 additions & 17 deletions stan/math/prim/functor/integrate_1d.hpp
@@ -1,9 +1,10 @@
- #ifndef STAN_MATH_PRIM_FUNCTOR_integrate_1d_HPP
- #define STAN_MATH_PRIM_FUNCTOR_integrate_1d_HPP
+ #ifndef STAN_MATH_PRIM_FUNCTOR_INTEGRATE_1D_HPP
+ #define STAN_MATH_PRIM_FUNCTOR_INTEGRATE_1D_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/functor/integrate_1d_adapter.hpp>
#include <boost/math/quadrature/exp_sinh.hpp>
#include <boost/math/quadrature/sinh_sinh.hpp>
#include <boost/math/quadrature/tanh_sinh.hpp>
@@ -57,6 +58,8 @@ inline double integrate(const F& f, double a, double b,
bool used_two_integrals = false;
size_t levels;
double Q = 0.0;
// if a or b is infinite, set xc argument to NaN (see docs above for user
// function for xc info)
[Review comment, Collaborator] [optional] We can do this some other time, but it would be nice to turn the errors in the if (used_two_integrals) branches into a function that is just called in the places where we use two integrals.

[Reply, Member Author] I split these into lambda functions. Does that look better, or is used_two_integrals clearer?

auto f_wrap = [&](double x) { return f(x, NOT_A_NUMBER); };
if (std::isinf(a) && std::isinf(b)) {
boost::math::quadrature::sinh_sinh<double> integrator;
@@ -131,6 +134,39 @@ inline double integrate(const F& f, double a, double b,
return Q;
}
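As context for the review exchange above, here is a minimal sketch (mine, not part of this diff) of what pulling the duplicated error reporting into a lambda can look like. The names check_quadrature_error, error, and L1 are assumptions, and the merged code may phrase the check differently:

auto check_quadrature_error = [&](double error, double L1) {
  // Throw when the Boost error estimate exceeds the requested relative
  // tolerance scaled by the L1 norm of the integral.
  if (error > relative_tolerance * L1) {
    throw_domain_error("integrate", "error estimate of integral", error, "",
                       " exceeds the given relative tolerance times the "
                       "norm of the integral");
  }
};
// Both branches that split the domain into two integrals can then call
// check_quadrature_error(err, L1) instead of repeating the throw.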

/**
* Compute the integral of the single variable function f from a to b to within
* a specified relative tolerance. a and b can be finite or infinite.
*
* @tparam F type of the function to be integrated
* @tparam Args types of additional arguments passed to f
* @param f the function to be integrated
* @param a lower limit of integration
* @param b upper limit of integration
* @param relative_tolerance tolerance passed to Boost quadrature
* @param[in, out] msgs the print stream for warning messages
* @param args additional arguments passed to f
* @return numeric integral of function f
*/
template <typename F, typename... Args,
require_all_not_st_var<Args...>* = nullptr>
inline double integrate_1d_impl(const F& f, double a, double b,
double relative_tolerance, std::ostream* msgs,
const Args&... args) {
static const char* function = "integrate_1d";
check_less_or_equal(function, "lower limit", a, b);

if (a == b) {
if (std::isinf(a)) {
throw_domain_error(function, "Integration endpoints are both", a, "", "");
}
return 0.0;
} else {
return integrate(
[&](const auto& x, const auto& xc) { return f(x, xc, msgs, args...); },
a, b, relative_tolerance);
}
}
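To make the new variadic interface concrete, a small usage sketch (mine, not from the PR): the integrand now receives msgs and the additional arguments directly, with no packing into theta/x_r/x_i.

#include <stan/math.hpp>
#include <cmath>

// Integrand signature for the variadic interface: f(x, xc, msgs, args...).
auto integrand = [](double x, double xc, std::ostream* msgs, double shape) {
  return std::exp(-shape * x * x);  // Gaussian-like bump, scaled by `shape`
};

// Integrate over [0, 1], forwarding one extra argument (shape = 2.5) to f.
double Q = stan::math::integrate_1d_impl(
    integrand, 0.0, 1.0, std::sqrt(stan::math::EPSILON), nullptr, 2.5);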

/**
* Compute the integral of the single variable function f from a to b to within
* a specified relative tolerance. a and b can be finite or infinite.
@@ -178,26 +214,14 @@ inline double integrate(const F& f, double a, double b,
* @return numeric integral of function f
*/
template <typename F>
- inline double integrate_1d(const F& f, const double a, const double b,
+ inline double integrate_1d(const F& f, double a, double b,
const std::vector<double>& theta,
const std::vector<double>& x_r,
const std::vector<int>& x_i, std::ostream* msgs,
const double relative_tolerance
= std::sqrt(EPSILON)) {
static const char* function = "integrate_1d";
check_less_or_equal(function, "lower limit", a, b);

if (a == b) {
if (std::isinf(a)) {
throw_domain_error(function, "Integration endpoints are both", a, "", "");
}
return 0.0;
} else {
return integrate(
std::bind<double>(f, std::placeholders::_1, std::placeholders::_2,
theta, x_r, x_i, msgs),
a, b, relative_tolerance);
}
return integrate_1d_impl(integrate_1d_adapter<F>(f), a, b, relative_tolerance,
msgs, theta, x_r, x_i);
}

} // namespace math
28 changes: 28 additions & 0 deletions stan/math/prim/functor/integrate_1d_adapter.hpp
@@ -0,0 +1,28 @@
#ifndef STAN_MATH_PRIM_FUNCTOR_INTEGRATE_1D_ADAPTER_HPP
#define STAN_MATH_PRIM_FUNCTOR_INTEGRATE_1D_ADAPTER_HPP

#include <ostream>
#include <vector>

/**
* Adapt the non-variadic integrate_1d arguments to the variadic
* integrate_1d_impl interface
*
* @tparam F type of function to adapt
*/
template <typename F>
struct integrate_1d_adapter {
const F& f_;

explicit integrate_1d_adapter(const F& f) : f_(f) {}

template <typename T_a, typename T_b, typename T_theta>
auto operator()(const T_a& x, const T_b& xc, std::ostream* msgs,
const std::vector<T_theta>& theta,
const std::vector<double>& x_r,
const std::vector<int>& x_i) const {
return f_(x, xc, theta, x_r, x_i, msgs);
}
};

#endif
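To illustrate what the adapter buys (a sketch of mine, not part of the diff): a functor written against the old fixed signature keeps working because the adapter reorders the variadic call f(x, xc, msgs, theta, x_r, x_i) back into f(x, xc, theta, x_r, x_i, msgs).

// Legacy-style integrand using the fixed (theta, x_r, x_i, msgs) signature.
struct old_style_f {
  template <typename T1, typename T2, typename T3>
  auto operator()(const T1& x, const T2& xc, const std::vector<T3>& theta,
                  const std::vector<double>& x_r, const std::vector<int>& x_i,
                  std::ostream* msgs) const {
    return theta[0] * x;
  }
};

// integrate_1d_adapter<old_style_f> wraps it so integrate_1d_impl can call
// adapter(x, xc, msgs, theta, x_r, x_i) and land on the old argument order.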
175 changes: 118 additions & 57 deletions stan/math/rev/functor/integrate_1d.hpp
@@ -28,25 +28,42 @@ namespace math {
* NaN, a std::domain_error is thrown
*
* @tparam F type of f
* @tparam Args types of arguments to f
* @param f function to compute gradients of
* @param x location at which to evaluate gradients
* @param xc complement of location (if bounded domain of integration)
* @param n compute gradient with respect to nth parameter
* @param msgs stream for messages
* @param args other arguments to pass to f
*/
template <typename F>
template <typename F, typename... Args>
inline double gradient_of_f(const F &f, const double &x, const double &xc,
const std::vector<double> &theta_vals,
const std::vector<double> &x_r,
const std::vector<int> &x_i, size_t n,
std::ostream *msgs) {
size_t n, std::ostream *msgs,
const Args &... args) {
double gradient = 0.0;

// Run nested autodiff in this scope
nested_rev_autodiff nested;

std::vector<var> theta_var(theta_vals.size());
for (size_t i = 0; i < theta_vals.size(); i++) {
theta_var[i] = theta_vals[i];
}
var fx = f(x, xc, theta_var, x_r, x_i, msgs);
std::tuple<decltype(deep_copy_vars(args))...> args_tuple_local_copy(
deep_copy_vars(args)...);

Eigen::VectorXd adjoints = Eigen::VectorXd::Zero(count_vars(args...));

var fx = apply(
[&f, &x, &xc, msgs](auto &&... args) { return f(x, xc, msgs, args...); },
args_tuple_local_copy);

fx.grad();
gradient = theta_var[n].adj();

apply(
[&](auto &&... args) {
accumulate_adjoints(adjoints.data(),
std::forward<decltype(args)>(args)...);
},
std::move(args_tuple_local_copy));

gradient = adjoints.coeff(n);
[Review comment, Collaborator] I forgot to make this comment in my review, but optionally we could figure out which arg holds the adjoint value we need here in a clever way that doesn't require copying all of them. We could have some function to get the Nth var in a tuple, but I'm also fine with not doing that in this PR.

[Follow-up, @SteveBronder, Mar 26, 2021] We could have some internal function like the rough sketch below (get_the_adj and get_nonzero_value are placeholder helpers):

template <typename... Args>
double get_nth_adjoint(size_t n, const std::tuple<Args...>& tuple_arg) {
  size_t accum_vars = 0;
  bool stop_checking = false;
  // for_each walks the args tuple from left to right
  std::array<double, sizeof...(Args)> possible_adjs = for_each(
      [n, &accum_vars, &stop_checking](auto&& arg) {
        if (stop_checking) {
          return 0.0;
        }
        size_t num_vars = count_vars(arg);
        if (accum_vars + num_vars <= n) {
          // The nth var is not in this arg; keep moving along
          accum_vars += num_vars;
          return 0.0;
        } else {
          // We reached the first arg that contains the nth var
          stop_checking = true;
          // placeholder: pull the (n - accum_vars)th adjoint out of this arg
          return get_the_adj(arg, n - accum_vars);
        }
      },
      tuple_arg);
  // Scan possible_adjs for the first nonzero value (or return zero if all are zero)
  return get_nonzero_value(possible_adjs);
}

[Reply, Member Author] Yeah, it's awkward. I also want to leave it for now.

The way to speed this up is writing our own quadratures, or talking to the Boost people about an interface where we integrate multiple functions over the same domain together. Right now we compute all three of these integrals totally separately:

\int f(x, a, b) dx
\int \partial f(x, a, b) / \partial a dx
\int \partial f(x, a, b) / \partial b dx

But any time we compute \partial f(x, a, b) / \partial a we also get \partial f(x, a, b) / \partial b, so the efficiency gain would come from taking advantage of that (what we're doing here throws away a ton of gradient info).

if (is_nan(gradient)) {
if (fx.val() == 0) {
gradient = 0;
@@ -59,6 +76,94 @@ inline double gradient_of_f(const F &f, const double &x, const double &xc,
return gradient;
}
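The body above follows the standard nested-autodiff recipe: deep-copy the var arguments onto an inner tape, evaluate, run grad(), then read adjoints off the copies so the outer tape is untouched. A self-contained miniature of the same pattern (my sketch, not PR code):

#include <stan/math/rev.hpp>

// Compute d(x^2)/dx at a plain double value without disturbing any
// enclosing reverse-mode tape.
double d_square_dx(double x_val) {
  stan::math::nested_rev_autodiff nested;  // inner tape, recovered on scope exit
  stan::math::var x = x_val;               // local copy, analogous to deep_copy_vars
  stan::math::var fx = x * x;
  fx.grad();                               // reverse pass over the inner tape only
  return x.adj();                          // adjoint of the copy: 2 * x_val
}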

/**
* Return the integral of f from a to b to the given relative tolerance
*
* @tparam F type of the functor to integrate
* @tparam T_a type of the lower limit
* @tparam T_b type of the upper limit
* @tparam Args types of additional arguments passed to f
*
* @param f the functor to integrate
* @param a lower limit of integration
* @param b upper limit of integration
* @param relative_tolerance relative tolerance passed to Boost quadrature
* @param[in, out] msgs the print stream for warning messages
* @param args additional arguments to pass to f
* @return numeric integral of function f
*/
template <typename F, typename T_a, typename T_b, typename... Args,
require_any_st_var<T_a, T_b, Args...> * = nullptr>
inline return_type_t<T_a, T_b, Args...> integrate_1d_impl(
const F &f, const T_a &a, const T_b &b, double relative_tolerance,
std::ostream *msgs, const Args &... args) {
static const char *function = "integrate_1d";
check_less_or_equal(function, "lower limit", a, b);

double a_val = value_of(a);
double b_val = value_of(b);

if (a_val == b_val) {
if (is_inf(a_val)) {
throw_domain_error(function, "Integration endpoints are both", a_val, "",
"");
}
return var(0.0);
} else {
std::tuple<decltype(value_of(args))...> args_val_tuple(value_of(args)...);

double integral = integrate(
[&](const auto &x, const auto &xc) {
return apply([&](auto &&... args) { return f(x, xc, msgs, args...); },
args_val_tuple);
},
a_val, b_val, relative_tolerance);

size_t num_vars_ab = count_vars(a, b);
size_t num_vars_args = count_vars(args...);
vari **varis = ChainableStack::instance_->memalloc_.alloc_array<vari *>(
num_vars_ab + num_vars_args);
double *partials = ChainableStack::instance_->memalloc_.alloc_array<double>(
num_vars_ab + num_vars_args);
double *partials_ptr = partials;

save_varis(varis, a, b, args...);

for (size_t i = 0; i < num_vars_ab + num_vars_args; ++i) {
partials[i] = 0.0;
}

if (!is_inf(a) && is_var<T_a>::value) {
*partials_ptr = apply(
[&f, a_val, msgs](auto &&... args) {
return -f(a_val, 0.0, msgs, args...);
},
args_val_tuple);
partials_ptr++;
}

if (!is_inf(b) && is_var<T_b>::value) {
*partials_ptr
= apply([&f, b_val, msgs](
auto &&... args) { return f(b_val, 0.0, msgs, args...); },
args_val_tuple);
partials_ptr++;
}

for (size_t n = 0; n < num_vars_args; ++n) {
*partials_ptr = integrate(
[&](const auto &x, const auto &xc) {
return gradient_of_f<F, Args...>(f, x, xc, n, msgs, args...);
},
a_val, b_val, relative_tolerance);
partials_ptr++;
}

return var(new precomputed_gradients_vari(
integral, num_vars_ab + num_vars_args, varis, partials));
}
}
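The partials assembled above are exactly the three pieces of the Leibniz integral rule, which is why the endpoint contributions are -f(a) and +f(b) while every parameter gets its own quadrature over the gradient of f:

\frac{d}{dt} \int_{a(t)}^{b(t)} f(x, t) \, dx
    = f\big(b(t), t\big) \, b'(t) - f\big(a(t), t\big) \, a'(t)
    + \int_{a(t)}^{b(t)} \frac{\partial f(x, t)}{\partial t} \, dx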

/**
* Compute the integral of the single variable function f from a to b to within
* a specified relative tolerance. a and b can be finite or infinite.
@@ -120,52 +225,8 @@ inline return_type_t<T_a, T_b, T_theta> integrate_1d(
const F &f, const T_a &a, const T_b &b, const std::vector<T_theta> &theta,
const std::vector<double> &x_r, const std::vector<int> &x_i,
std::ostream *msgs, const double relative_tolerance = std::sqrt(EPSILON)) {
static const char *function = "integrate_1d";
check_less_or_equal(function, "lower limit", a, b);

if (value_of(a) == value_of(b)) {
if (is_inf(a)) {
throw_domain_error(function, "Integration endpoints are both",
value_of(a), "", "");
}
return var(0.0);
} else {
double integral = integrate(
std::bind<double>(f, std::placeholders::_1, std::placeholders::_2,
value_of(theta), x_r, x_i, msgs),
value_of(a), value_of(b), relative_tolerance);

size_t N_theta_vars = is_var<T_theta>::value ? theta.size() : 0;
std::vector<double> dintegral_dtheta(N_theta_vars);
std::vector<var> theta_concat(N_theta_vars);

if (N_theta_vars > 0) {
std::vector<double> theta_vals = value_of(theta);

for (size_t n = 0; n < N_theta_vars; ++n) {
dintegral_dtheta[n] = integrate(
std::bind<double>(gradient_of_f<F>, f, std::placeholders::_1,
std::placeholders::_2, theta_vals, x_r, x_i, n,
msgs),
value_of(a), value_of(b), relative_tolerance);
theta_concat[n] = theta[n];
}
}

if (!is_inf(a) && is_var<T_a>::value) {
theta_concat.push_back(a);
dintegral_dtheta.push_back(
-value_of(f(value_of(a), 0.0, theta, x_r, x_i, msgs)));
}

if (!is_inf(b) && is_var<T_b>::value) {
theta_concat.push_back(b);
dintegral_dtheta.push_back(
value_of(f(value_of(b), 0.0, theta, x_r, x_i, msgs)));
}

return precomputed_gradients(integral, theta_concat, dintegral_dtheta);
}
return integrate_1d_impl(integrate_1d_adapter<F>(f), a, b, relative_tolerance,
msgs, theta, x_r, x_i);
}

} // namespace math
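For completeness, a hypothetical end-to-end example (names and values are mine) of the unchanged user-facing signature, integrating theta[0] * x over [0, 1] with the coefficient as a var so both the value and the gradient flow through the machinery above:

#include <stan/math.hpp>
#include <vector>

int main() {
  std::vector<stan::math::var> theta = {2.0};
  auto f = [](const auto& x, const auto& xc, const auto& theta,
              const std::vector<double>& x_r, const std::vector<int>& x_i,
              std::ostream* msgs) { return theta[0] * x; };
  stan::math::var I
      = stan::math::integrate_1d(f, 0.0, 1.0, theta, {}, {}, nullptr);
  I.grad();
  // I.val() == 1.0        : integral of 2 * x over [0, 1]
  // theta[0].adj() == 0.5 : d/dtheta of the integral of theta * x is 1/2
}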