Skip to content

Commit

Permalink
Use varis to get nth gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
bbbales2 committed Mar 30, 2021
1 parent 36847d9 commit cd1b0cb
Showing 1 changed file with 11 additions and 24 deletions.
35 changes: 11 additions & 24 deletions stan/math/rev/functor/integrate_1d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,43 +102,30 @@ inline return_type_t<T_a, T_b, Args...> integrate_1d_impl(
// The arguments copy is used multiple times in the following nests, so
// do it once in a separate nest for efficiency
auto args_tuple_local_copy = std::make_tuple(deep_copy_vars(args)...);

// Save the varis so it's easy to efficiently access the nth adjoint
std::vector<vari*> local_varis(num_vars_args);
apply([&](const auto&... args) {
save_varis(local_varis.data(), args...);
}, args_tuple_local_copy);

for (size_t n = 0; n < num_vars_args; ++n) {
// This computes the integral of the gradient of f with respect to the
// nth parameter in args using a nested nested reverse mode autodiff
*partials_ptr = integrate(
[&](const auto &x, const auto &xc) {
argument_nest.set_zero_all_adjoints();

nested_rev_autodiff gradient_nest;
var fx = apply(
[&f, &x, &xc, msgs](auto &&... local_args) {
return f(x, xc, msgs, local_args...);
},
args_tuple_local_copy);
fx.grad();
size_t adjoint_count = 0;
double gradient = 0;
bool not_found = true;
// for_each is guaranteed to go off from left to right.
// So for var argument we count the number of previous vars
// until we go past n, then index into that argument to get
// the correct adjoint.
stan::math::for_each(
[&](auto &arg) {
using arg_t = decltype(arg);
using scalar_arg_t = scalar_type_t<arg_t>;
if (is_var<scalar_arg_t>::value) {
size_t var_count = count_vars(arg);
if (((adjoint_count + var_count) < n) && not_found) {
adjoint_count += var_count;
} else if (not_found) {
not_found = false;
gradient
= forward_as<var>(stan::get(arg, n - adjoint_count))
.adj();
}
}
},
args_tuple_local_copy);

double gradient = local_varis[n]->adj();

// Gradients that evaluate to NaN are set to zero if the function
// itself evaluates to zero. If the function is not zero and the
// gradient evaluates to NaN, a std::domain_error is thrown
Expand Down

0 comments on commit cd1b0cb

Please sign in to comment.