The origin of this issue is a bug report in jax-ml/jax#26147 that has the following simple reproducer:
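The reproducer code itself is not quoted here; a minimal sketch of it, assuming a single-element float32 input (the concrete value below is a placeholder, not necessarily the one used in the report), would be:

    # issue26147.py -- sketch of the reproducer; the input value is a placeholder.
    import jax
    import jax.numpy as jnp

    def f(x1, x2):
        return x1 ** 2 - x2 ** 2

    x = jnp.array([1.2345678], dtype=jnp.float32)
    print(f"{f(x, x)=}")
    print(f"{jax.jit(f)(x, x)=}")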
This reproducer computes the difference of equal squares, so the expected result is 0. However, when using the CPU backend with optimization, the execution of jax.jit(f)(x, x) produces a non-zero result:

$ JAX_PLATFORM_NAME=cpu XLA_FLAGS="--xla_backend_optimization_level=1" python issue26147.py
f(x, x)=Array([0.], dtype=float32)
jax.jit(f)(x, x)=Array([-4.0978193e-10], dtype=float32)
When the XLA backend optimization is disabled, or when the CUDA/TPU backends are used (see jax-ml/jax#26147), the output is as expected:

$ JAX_PLATFORM_NAME=cpu XLA_FLAGS="--xla_backend_optimization_level=0" python issue26147.py
f(x, x)=Array([0.], dtype=float32)
jax.jit(f)(x, x)=Array([0.], dtype=float32)
Note that the value of -4.0978193e-10 indicates that float64 power evaluation is used in one of the square expressions but not in both, as one could conclude by comparing it against a mixed float64/float32 evaluation of the same expression (this relation could be a coincidence, but it could also be a hint).
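A sketch of such a mixed-precision check, with a placeholder input value (this is not the comparison from the original report):

    # Evaluate one square in float64 and the other in float32, then take the
    # difference in float32.
    import numpy as np

    x = np.float32(1.2345678)            # placeholder value
    f64_square = np.float64(x) ** 2      # square evaluated in float64
    f32_square = x ** 2                  # square evaluated in float32
    print(np.float32(f64_square - f32_square))

Since the float64 square of a float32 value is exact, the printed difference is just the float32 rounding residual of x ** 2 for the given input.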
The same issue exists when the difference of powers with other integer exponents is computed, while the issue does not exist when using non-integer exponents. For example, when defining

    def f(x1, x2):
        return x1 ** 2.1 - x2 ** 2.1

the output is also correct when using the CPU backend with optimization:
$ JAX_PLATFORM_NAME=cpu XLA_FLAGS="--xla_backend_optimization_level=1" python issue26147.py
f(x, x)=Array([0.], dtype=float32)
jax.jit(f)(x, x)=Array([0.], dtype=float32)

Note that in JAX, integer powers go via lax.integer_pow_p, which for n = 2 lowers to a simple hlo.mul, while floating-point powers go via lax.pow_p, which lowers to hlo.power. So for XLA this issue is about the optimization of mul, not the optimization of power.

@jakevdp good point. Indeed, the XLA integer power uses exponentiation by squaring, which expands the power in terms of products (see EmitIntegerPow), and the subtraction of these multiplications is then optimized to use a fused multiply-add that involves higher-precision arithmetic than float32.

Here's a script that provides a model of the issue for integer powers with exponents 2, 3, and 4, respectively:
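The script itself is not reproduced above; the following is a minimal sketch of the idea, covering only the exponent-2 case and using placeholder input values, of how fusing the subtraction of two float32 products into an FMA can make x ** 2 - x ** 2 non-zero:

    # Model of the FMA effect for exponent 2; the input values are placeholders.
    import numpy as np

    def model_pow2_diff(x):
        x = np.float32(x)
        rounded = x * x                  # x*x rounded to float32 (integer_pow(2) -> mul)
        # An FMA computes a*b + c with a single final rounding.  The float64
        # product of two float32 values is exact, so the expression below models
        # fma(x, x, -rounded) evaluated in float32.
        return np.float32(np.float64(x) * np.float64(x) - np.float64(rounded))

    for value in [1.2345678, 3.3333333, 0.1]:
        print(value, model_pow2_diff(value))  # non-zero whenever x*x is inexact in float32

Whether the compiler actually forms such an FMA, and with which operand order, depends on the backend and the optimization level, which is consistent with the discrepancy disappearing at --xla_backend_optimization_level=0.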
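As a final aside (this snippet is not part of the original thread), the split between integer and floating-point power lowering mentioned above can be checked by printing the jaxpr for each kind of exponent:

    # Inspect which primitive each kind of exponent lowers to.
    import jax
    import jax.numpy as jnp

    x = jnp.float32(1.5)
    print(jax.make_jaxpr(lambda v: v ** 2)(x))    # should show the integer_pow primitive
    print(jax.make_jaxpr(lambda v: v ** 2.1)(x))  # should show the pow primitive

The lowered HLO of a jitted function can likewise be inspected with jax.jit(fn).lower(*args).as_text(), which should show a multiply for the integer case and a power op for the non-integer case.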