CPU identity subtraction < 0 #26147
Comments
Hi - in general JIT compilation does not guarantee bitwise exact outputs, but the output should generally be within the expected precision of the floating point representation being used. In this case, you're working with float32, so differences on the order of the float32 machine epsilon (roughly 1e-7 in relative terms) are expected.

You can dig a bit into what's going on by using the ahead-of-time compilation tools; for example:

print("uncompiled:")
print(f_jit.lower(x, s).as_text())
print("\ncompiled:")
print(f_jit.lower(x, s).compile().as_text())
That said, I don't see anything obviously problematic here. The compiled version fuses the full computation into a single kernel, and perhaps some of the fusion logic re-orders the computation in such a way that errors accumulate differently.
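To put the reported value in perspective, here is a minimal NumPy check (a sketch, not code from the thread, assuming an observed difference of about -4.1e-10) showing that the discrepancy is below one float32 ulp of the squared values:

```python
import numpy as np

observed = 4.0978193e-10                    # magnitude of the reported difference
square = np.float32(0.1) * np.float32(0.1)  # ~0.01, computed in float32
ulp = np.spacing(square)                    # gap between adjacent float32 values near 0.01
print(float(ulp))                           # ~9.3e-10
print(observed < ulp)                       # True: within one rounding step of float32
```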
Thanks Jake! Indeed, I had taken a look at some of the lowered methods and was confused because I couldn't find an obvious issue between cases that did and didn't work. We have a method in our repo, https://github.com/google-deepmind/graphcast/blob/main/graphcast/samplers_utils.py#L418, of which this is a simple reproducer. The method returns NaNs when running on CPU; we have identified the subtraction producing a small negative value as the likely cause. I guess we'll have to add some workaround.
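For context, if the NaNs in the linked sampler come from taking a square root of a quantity that is mathematically non-negative but computed as slightly negative (an assumption here, since the sampler code itself is not quoted in the thread), a typical guard is to clamp before the root. A minimal sketch with a hypothetical helper name:

```python
import jax.numpy as jnp

def sqrt_of_nonnegative(d):
    # Hypothetical guard: clamp a theoretically non-negative value so that
    # tiny negative rounding errors do not turn the square root into NaN.
    return jnp.sqrt(jnp.maximum(d, 0.0))
```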
I can reproduce the issue on CPU but not on CUDA. While it is true that evaluation under JIT compilation is not guaranteed to be bitwise exact, if "the fusion logic re-orders the computation in such a way that errors accumulate differently" is indeed what happens here, I'd consider this a bug: floating-point arithmetic is non-associative, and algorithms that are designed to take non-associativity into account to improve the accuracy of results become broken. However, with the given inputs (x1=0.1, s=0.0), I cannot pinpoint what could cause differences in error accumulation: adding zero is exact in floating-point arithmetic, hence x1 and x2 ought to be equal, so their squares ought to be equal, and subtracting equal values ought to give exactly zero.
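The same argument can be checked directly in plain NumPy (a small illustration, not code from the comment):

```python
import numpy as np

x1 = np.float32(0.1)
s = np.float32(0.0)
x2 = x1 + s

print(x2 == x1)           # True: adding zero is exact
print(x2 * x2 - x1 * x1)  # 0.0: identical values have identical squares
```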
@andrewlkd, using …
Here is a simpler reproducer of the issue:

>>> def f(x1, x2):
...     return x2 ** 2 - x1 ** 2
...
>>> jax.jit(f)(jnp.array(0.1), jnp.array(0.1))
Array(-4.0978193e-10, dtype=float32, weak_type=True)
>>> def f(x1, x2):
...     return 0 + x1 ** 2 - x1 ** 2
...
>>> jax.jit(f)(jnp.array(0.1), jnp.array(0.1))
Array(0., dtype=float32, weak_type=True)
Notice the following (coincidence?):

>>> numpy.float32(0.1) ** 2 - numpy.float32(numpy.float64(numpy.float32(0.1)) ** 2)
-4.097819306103645e-10

This reproduces the above result; that is, the origin of the value is a float->double->float cast.
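Spelling the same check out a little more explicitly (a sketch based on the expression above): computing one square in float64 and the other in float32 and then subtracting reproduces the jitted value, which is what suggests a float->double->float cast somewhere in the compiled computation.

```python
import numpy as np

x = np.float32(0.1)
sq_f32 = x * x                           # square rounded to float32
sq_f64 = np.float64(x) * np.float64(x)   # square kept in float64

# Roughly -4.0978193e-10, matching the jitted output above.
print(sq_f64 - np.float64(sq_f32))
```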
This would be worth reporting upstream at https://github.com/openxla/xla. @pearu would you like to do that?
Here's the report to upstream: openxla/xla#22116. It includes a couple of other diagnostic results.
Description
Hello,
I'm not sure if this is a JAX bug or a device precision issue, but the following code produces a negative value when jitted (0 is expected).
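The reproducer code block did not survive in this copy of the issue; the following is a reconstruction based on the names used elsewhere in the thread (f_jit, x1, x2, s), so treat it as a sketch of the reported behaviour rather than the exact original:

```python
import jax
import jax.numpy as jnp

def f(x1, s):
    x2 = x1 + s
    d = x2 ** 2 - x1 ** 2   # mathematically zero when s == 0
    return d

f_jit = jax.jit(f)
x = jnp.array(0.1, dtype=jnp.float32)
s = jnp.array(0.0, dtype=jnp.float32)

print(f(x, s))      # 0.0 when run eagerly
print(f_jit(x, s))  # a small negative value on CPU, e.g. -4.0978193e-10
```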
Note that the following all fix the issue:
- computing d as (x2 + x1) * (x2 - x1) instead (see the sketch below), and
- not passing s as an argument, and instead hardcoding it to 0.
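As an illustration of the first workaround (a sketch, not the reporter's code): with s == 0, x2 - x1 is exactly zero, so the factored product is exactly zero no matter how the individual squares would have been rounded.

```python
def f_factored(x1, s):
    x2 = x1 + s
    # x2 - x1 is exactly 0.0 when s == 0, so the product is exactly 0.0.
    return (x2 + x1) * (x2 - x1)
```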
Thanks!
System info (python version, jaxlib version, accelerator, etc.)