What are the compatible versions of CUDA, PyBaMM, and JAX? #4527
-
PyBaMM Version: 24.9
JAX Version: 0.4.27
Python Version: 3.10
CUDA Version: 12.6

Describe the bug: JAX 0.4.27 has a bug that produces the message below, and JAX fixed it in version 0.4.28. However, PyBaMM 24.9 only supports JAX 0.4.27, so I need PyBaMM to support JAX versions newer than 0.4.27.

'+ptx86' is not a recognized feature for this target (ignoring feature)

Steps to Reproduce: pip install -U "jax[cuda12]==0.4.27"

Script: this notebook demo: https://docs.pybamm.org/en/latest/source/examples/notebooks/solvers/idaklu-jax-interface.html

Relevant log output: '+ptx86' is not a recognized feature for this target (ignoring feature)
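The version conflict in this report can be checked mechanically: an exact pin on JAX 0.4.27 can never include a fix that first shipped in 0.4.28. A minimal sketch, assuming plain numeric `major.minor.patch` version strings; `parse_version` and `pin_allows_fix` are hypothetical helpers, not part of PyBaMM or JAX:

```python
def parse_version(text: str) -> tuple[int, ...]:
    """Turn a plain numeric version such as '0.4.27' into (0, 4, 27)."""
    return tuple(int(part) for part in text.split("."))


def pin_allows_fix(pinned: str, first_fixed: str) -> bool:
    """True if an exact pin (==pinned) already includes a fix released in first_fixed."""
    # Tuples compare element by element, so (0, 4, 27) < (0, 4, 28)
    # and (0, 4, 10) > (0, 4, 9), unlike naive string comparison.
    return parse_version(pinned) >= parse_version(first_fixed)


if __name__ == "__main__":
    # Versions taken from the report above.
    print(pin_allows_fix("0.4.27", "0.4.28"))
```

Comparing integer tuples rather than raw strings avoids the classic trap where "0.4.9" sorts after "0.4.10". Pre-release tags such as "0.4.28rc1" are not handled by this sketch.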
Replies: 4 comments 10 replies
-
@jsbrittain, I'm guessing this is a question for you.
-
@lcwxz1989 PyBaMM installs the default (CPU-only) JAX build. You can upgrade it to support CUDA on your system (CUDA 12, by the look of it in this case) with:

pip install --upgrade "jax[cuda12]"

If you now relaunch Python and …
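The reply above is cut off mid-sentence, so its exact next step is unknown. One common way to confirm that the CUDA build is active after relaunching Python is to ask JAX which backend it selected. A minimal sketch, assuming JAX may or may not be installed; `jax_gpu_available` is a hypothetical helper name:

```python
def jax_gpu_available() -> bool:
    """Return True if JAX is importable and its default backend is a GPU."""
    try:
        import jax
    except ImportError:
        # JAX is not installed at all, so there is no GPU backend either.
        return False
    # default_backend() returns the platform name of the default XLA backend,
    # e.g. "cpu", "gpu" (CUDA) or "tpu".
    return jax.default_backend() == "gpu"


if __name__ == "__main__":
    print("JAX GPU backend active:", jax_gpu_available())
```

If this prints `False` after installing `jax[cuda12]`, printing `jax.devices()` in the same session shows which devices JAX actually found, which is useful detail to include when reporting the problem.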
-
I'll convert this to a discussion, since there doesn't seem to be a bug on our end; more details can be shared there. A CUDA-enabled JAX distribution is required for GPU support.
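Whether an environment has a CUDA-enabled JAX distribution can also be checked without importing JAX, by inspecting installed packages. A sketch, assuming that recent `jax[cuda12]` installs pull in at least one distribution whose name starts with `jax-cuda12` (such as a CUDA plugin package); packaging has changed between JAX releases, so confirm the actual names with `pip list` on your system. Both helper names are hypothetical:

```python
from importlib.metadata import distributions


def installed_distribution_names() -> set[str]:
    """Names of all installed distributions, normalized to lower-case with dashes."""
    names = set()
    for dist in distributions():
        name = dist.metadata.get("Name")
        if name:  # skip broken installs that have no Name field
            names.add(name.lower().replace("_", "-"))
    return names


def has_cuda_jax(names: set[str], cuda_major: int = 12) -> bool:
    """Heuristic: True if any installed distribution looks like a JAX CUDA build.

    ASSUMPTION: CUDA-enabled JAX installs ship a package named
    'jax-cuda<major>-...'; releases that packaged CUDA differently are missed.
    """
    prefix = f"jax-cuda{cuda_major}"
    return any(name.startswith(prefix) for name in names)


if __name__ == "__main__":
    print("CUDA-enabled JAX installed:", has_cuda_jax(installed_distribution_names()))
```

Separating name collection from the check keeps the heuristic easy to test against a hand-written set of names.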
-
If PyBaMM hard-codes JAX "0.4.27", I think there should be a compatible CUDA version. Can anyone share which CUDA version works with it?
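Answering a CUDA-compatibility question is easier when the exact installed versions are posted. A minimal sketch using only the standard library; the package list is an assumption (`nvidia-cuda-runtime-cu12` is, as far as we know, the pip wheel that provides the CUDA 12 runtime for `jax[cuda12]`, and a system-wide CUDA toolkit reported by `nvidia-smi` will not show up here):

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional

# ASSUMPTION: these distribution names cover the packages relevant to this thread.
PACKAGES = ["pybamm", "jax", "jaxlib", "nvidia-cuda-runtime-cu12"]


def installed_versions(names: list[str]) -> dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    report = {}
    for name in names:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None
    return report


if __name__ == "__main__":
    for name, ver in installed_versions(PACKAGES).items():
        print(f"{name}: {ver or 'not installed'}")
```

When JAX is importable, `jax.print_environment_info()` prints a similar summary (JAX, jaxlib, NumPy, and Python versions, plus `nvidia-smi` output where available).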