Skip to content

Commit

Permalink
Fix bug in JAX array initialization: Qs.set(0).set(P0) -> Qs.at[0].se…
Browse files Browse the repository at this point in the history
…t(P0).

The current code fails with the following AttributeError:

    --> 240     Qs = Qs.set(0).set(P0)  # first element requires different initialisation
    AttributeError: DynamicJaxprTracer has no attribute set
  • Loading branch information
fsaad committed May 2, 2023
1 parent ad56794 commit aa52e19
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion bayesnewton/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def parallel_filtering_operator(elem1, elem2):


def make_associative_filtering_elements(As, Qs, H, ys, noise_covs, m0, P0):
Qs = Qs.set(0).set(P0) # first element requires different initialisation
Qs = Qs.at[0].set(P0) # first element requires different initialisation
AA, b, C, J, eta = parallel_filtering_element(As, Qs, H, noise_covs, ys)
# modify initial b to account for m0 (not needed if m0=zeros)
S = H @ Qs[0] @ H.T + noise_covs[0]
Expand Down

0 comments on commit aa52e19

Please sign in to comment.