Hi!

Many thanks for open-sourcing this package.

I've been using code from this amazing package for my research (preprint here) and have found that the default single precision (`float32`) is insufficient for the Kalman filtering and smoothing operations, causing numerical instabilities. In particular, the Periodic kernel is rather sensitive to the matrix operations in `_sequential_kf()` and `_sequential_rts()`. The same happens with the Matern32 kernel when the lengthscales are too large.

However, reverting to `float64` by setting `config.update("jax_enable_x64", True)` makes everything quite slow, especially when I use `objax` neural network modules, since doing so puts all arrays into double precision.

Currently, my workaround is to set the neural network weights to `float32` manually, cast input arrays to `float32` before they enter the network, and cast the outputs back to `float64`. Roughly, it looks something like the sketch below.
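For concreteness, here is a rough sketch of this casting workaround (the network definition and names are purely illustrative, not my actual model):

```python
import jax
import jax.numpy as jnp
import objax

# Double precision globally, needed for the Kalman filtering/smoothing.
jax.config.update("jax_enable_x64", True)

# Illustrative network only -- a stand-in for the real model.
net = objax.nn.Sequential([
    objax.nn.Linear(2, 32), objax.functional.relu,
    objax.nn.Linear(32, 1),
])

# Manually keep the network weights in single precision.
for var in net.vars().values():
    var.assign(var.value.astype(jnp.float32))

def apply_net(x_f64):
    """Run the float32 network inside the otherwise float64 pipeline."""
    y_f32 = net(x_f64.astype(jnp.float32))  # cast inputs down to float32
    return y_f32.astype(jnp.float64)        # cast outputs back up to float64
```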
However, I was wondering whether there could be a more elegant solution, as is done in https://github.com/thomaspinder/GPJax, where all arrays are assumed to be float64. My understanding is that their package depends on Haiku, but I'm unsure how they got around the computational scalability issue.

Software and hardware details:

Thanks in advance.

Best,
Harrison
That's great to hear that you've found the package useful, and thanks for sharing your paper - it looks great!
I agree that float64 is generally needed when working with GPs, whether that's using the Markov formulation or the standard one. Unfortunately I've never tried switching between float32 and float64, and I'm not aware of a more elegant solution to your problem. I'm also not aware of how / whether GPJax solves this issue - perhaps you could ask the authors of that package?