Enabling CPU backend optimization reveals a possible inconsistency in computing float32 power differences #22116

Open
pearu opened this issue Jan 30, 2025 · 2 comments

Comments

@pearu

pearu commented Jan 30, 2025

The origin of this issue is a bug report in jax-ml/jax#26147 that has the following simple reproducer:

# File: issue26147.py 
import jax
import jax.numpy as jnp

def f(x1, x2):
    return x1 ** 2 - x2 ** 2

x = jnp.array([0.1], dtype=jnp.float32)

print(f'{f(x, x)=}')
print(f'{jax.jit(f)(x, x)=}')

that computes the difference of equal squares, for which the expected result is 0. However, when using the CPU backend with optimization enabled, the execution of jax.jit(f)(x, x) produces a non-zero result:

$ JAX_PLATFORM_NAME=cpu XLA_FLAGS="--xla_backend_optimization_level=1" python issue26147.py 
f(x, x)=Array([0.], dtype=float32)
jax.jit(f)(x, x)=Array([-4.0978193e-10], dtype=float32)

When the XLA backend optimization is disabled, or when the CUDA/TPU backends are used (see jax-ml/jax#26147), the output is as expected:

$ JAX_PLATFORM_NAME=cpu XLA_FLAGS="--xla_backend_optimization_level=0" python issue26147.py 
f(x, x)=Array([0.], dtype=float32)
jax.jit(f)(x, x)=Array([0.], dtype=float32)

Note that the value -4.0978193e-10 suggests that float64 evaluation is used for one of the square expressions but not for both, as one could conclude from the following result:

>>> import numpy
>>> numpy.float32(numpy.float32(0.1) ** 2 - numpy.float32(numpy.float64(numpy.float32(0.1)) ** 2))
-4.0978193e-10

(this relation could be a coincidence, but it could also be a hint...).

The same issue exists when computing the difference of powers with other integer exponents, while it does not exist when using non-integer exponents. For example, when defining

def f(x1, x2):
    return x1 ** 2.1 - x2 ** 2.1

then the output is correct even when using the CPU backend with optimization enabled:

$ JAX_PLATFORM_NAME=cpu XLA_FLAGS="--xla_backend_optimization_level=1" python issue26147.py
f(x, x)=Array([0.], dtype=float32)
jax.jit(f)(x, x)=Array([0.], dtype=float32)
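
For reference, the two definitions can be compared at the JAX level with jax.make_jaxpr, which prints the primitives they use (a minimal sketch; the array x is the same as in the reproducer):

import jax
import jax.numpy as jnp

x = jnp.array([0.1], dtype=jnp.float32)

# integer exponent
print(jax.make_jaxpr(lambda x1, x2: x1 ** 2 - x2 ** 2)(x, x))

# non-integer exponent
print(jax.make_jaxpr(lambda x1, x2: x1 ** 2.1 - x2 ** 2.1)(x, x))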
@jakevdp

jakevdp commented Jan 30, 2025

Note that in JAX, integer powers go via lax.integer_pow_p, which for n = 2 lowers to a simple hlo.mul, while floating-point powers go via lax.pow_p, which lowers to hlo.power. So for XLA this issue is about the optimization of mul, not the optimization of power.
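
One way to see this is to print the lowered modules for both variants (a sketch; the exact op names in the output depend on the JAX/StableHLO version):

import jax
import jax.numpy as jnp

x = jnp.array([0.1], dtype=jnp.float32)

# integer exponent: the lowering contains a multiply rather than a power op
print(jax.jit(lambda x1, x2: x1 ** 2 - x2 ** 2).lower(x, x).as_text())

# non-integer exponent: the lowering contains a power op
print(jax.jit(lambda x1, x2: x1 ** 2.1 - x2 ** 2.1).lower(x, x).as_text())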

@pearu

pearu commented Jan 31, 2025

@jakevdp good point. Indeed, XLA evaluates integer powers with exponentiation by squaring, which expands the power into a sequence of products (see EmitIntegerPow), and the subtraction of these products is then optimized to use a fused multiply-add, which involves higher-precision arithmetic than float32.

Here's a script that models the issue for integer powers with exponents 2, 3, and 4:

import numpy
import jax
import jax.numpy as jnp

def fma(x, y, z):
    # model of a fused multiply-add: x + y*z is evaluated in higher precision
    # and rounded to float32 only once
    return numpy.float32(numpy.float64(x) + numpy.float64(y) * numpy.float64(z))

def f2(x1, x2):
    return x1 ** 2 - x2 ** 2

def fused_f2(x1, x2):
    return fma(-x2 * x2, x1, x1)

def f3(x1, x2):
    return x1 ** 3 - x2 ** 3

def fused_f3(x1, x2):
    return fma(-x2 * x2 * x2, x1 * x1, x1)

def f4(x1, x2):
    return x1 ** 4 - x2 ** 4

def fused_f4(x1, x2):
    return fma(-(x2 * x2) * (x2 * x2), x1 * x1, x1 * x1)

x = jnp.array([0.1], dtype=jnp.float32)

print(f'{jax.jit(f2)(x, x)} {fused_f2(x, x)}')  # -> [-4.0978193e-10] [-4.0978193e-10]
print(f'{jax.jit(f3)(x, x)} {fused_f3(x, x)}')  # -> [3.8184227e-11] [3.8184227e-11]
print(f'{jax.jit(f4)(x, x)} {fused_f4(x, x)}')  # -> [2.1304009e-12] [2.1304009e-12]

Funnily enough, this is a well-known problem with FMA: "Fused multiply–add can usually be relied on to give more accurate results. However, William Kahan has pointed out that it can give problems if used unthinkingly." And from Kahan's paper: "Fused MACs cannot be used indiscriminately; there are a few programs ... from which Fused MACs must be banned".

As a resolution, I suggest not using FMA in this particular case to ensure correctness of the result.
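
To illustrate the suggestion with a standalone NumPy-only sketch of the same model (variable names here are just for the illustration): when both squares are rounded to float32 before the subtraction, identical inputs cancel exactly, while keeping one product in double precision reproduces the observed error.

import numpy

x = numpy.float32(0.1)

# unfused model: both products are rounded to float32 before subtracting
unfused = numpy.float32(x * x) - numpy.float32(x * x)

# fused model: one product is kept in float64, as in the fma model above
fused = numpy.float32(numpy.float64(x) * numpy.float64(x) - numpy.float32(x * x))

print(unfused)  # 0.0
print(fused)    # -4.0978193e-10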
