CPU identity subtraction < 0 #26147

Open
andrewlkd opened this issue Jan 28, 2025 · 8 comments
Labels
bug Something isn't working

Comments

@andrewlkd

Description

Hello,

I'm not sure if this is a JAX bug or a device precision issue, but the following code produces a negative value when jitted (0 is expected).

import jax

def f(x1, s):
  x2 = x1 * (1.0 + s)
  d = x2**2 - x1**2
  return d

f_jit = jax.jit(f)
x = 0.1
s = 0.0
print(f(x, s))      # 0
print(f_jit(x, s))  # < 0

Note that any of the following fixes the issue:

  • replacing the subtraction with d = (x2 + x1) * (x2 - x1) (sketch below),
  • not passing s as an argument and instead hardcoding it to 0, or
  • running on TPU.
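
For concreteness, here is a sketch of the first workaround (the same f as above, just with the subtraction factored):

import jax

def f_factored(x1, s):
  x2 = x1 * (1.0 + s)
  # algebraically identical to x2**2 - x1**2, but x2 - x1 is computed
  # directly instead of cancelling two separately rounded squares
  return (x2 + x1) * (x2 - x1)

print(jax.jit(f_factored)(0.1, 0.0))  # 0.0 (the factored form avoids the negative value)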

Thanks!

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.5.1
jaxlib: 0.5.1
numpy:  2.2.1
python: 3.11.8 (stable, redacted, redacted) [Clang 9999.0.0 (faa3f752896903c2d09d389970d3d0ebf50a1073)]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node=[redacted], release='5.10.0-smp-1105.32.0.0', version='#1 [v5.10.0-1105.32.0.0] SMP @1729903589', machine='x86_64')
@jakevdp
Collaborator

jakevdp commented Jan 28, 2025

Hi - in general JIT compilation does not guarantee bitwise-exact outputs, but the output should generally be within the expected precision of the floating-point representation being used. In this case you're working with float32, so differences on the order of np.finfo('float32').eps, or about 1E-7, are not unexpected.
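
As a rough sanity check (a sketch, reusing x, s, and f_jit from your snippet), you can compare the magnitude of the jitted result against eps scaled by the size of the squares being subtracted:

import numpy as np

eps = np.finfo('float32').eps   # ~1.1920929e-07
tol = eps * x**2                # tolerance at the scale of the values being subtracted
print(abs(f_jit(x, s)) <= tol)  # expected True if the difference is just rounding error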

You can dig a bit into what's going on by using the ahead-of-time compilation tools; for example:

print("uncompiled:")
print(f_jit.lower(x, s).as_text())
print("\ncompiled:")
print(f_jit.lower(x, s).compile().as_text())
uncompiled:
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32> {mhlo.layout_mode = "default"}, %arg1: tensor<f32> {mhlo.layout_mode = "default"}) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %0 = stablehlo.add %cst, %arg1 : tensor<f32>
    %1 = stablehlo.multiply %arg0, %0 : tensor<f32>
    %2 = stablehlo.multiply %1, %1 : tensor<f32>
    %3 = stablehlo.multiply %arg0, %arg0 : tensor<f32>
    %4 = stablehlo.subtract %2, %3 : tensor<f32>
    return %4 : tensor<f32>
  }
}


compiled:
HloModule jit_f, is_scheduled=true, entry_computation_layout={(f32[], f32[])->f32[]}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}

%fused_computation (param_0.1: f32[], param_1.4: f32[]) -> f32[] {
  %param_0.1 = f32[] parameter(0)
  %param_1.4 = f32[] parameter(1)
  %constant.0 = f32[] constant(1)
  %add.0 = f32[] add(f32[] %param_1.4, f32[] %constant.0), metadata={op_name="jit(f)/jit(main)/add" source_file="<ipython-input-1-afa1b93b6de2>" source_line=4}
  %multiply.2 = f32[] multiply(f32[] %param_0.1, f32[] %add.0), metadata={op_name="jit(f)/jit(main)/mul" source_file="<ipython-input-1-afa1b93b6de2>" source_line=4}
  %multiply.1 = f32[] multiply(f32[] %multiply.2, f32[] %multiply.2), metadata={op_name="jit(f)/jit(main)/integer_pow" source_file="<ipython-input-1-afa1b93b6de2>" source_line=5}
  %multiply.0 = f32[] multiply(f32[] %param_0.1, f32[] %param_0.1), metadata={op_name="jit(f)/jit(main)/integer_pow" source_file="<ipython-input-1-afa1b93b6de2>" source_line=5}
  ROOT %subtract.0 = f32[] subtract(f32[] %multiply.1, f32[] %multiply.0), metadata={op_name="jit(f)/jit(main)/sub" source_file="<ipython-input-1-afa1b93b6de2>" source_line=5}
}

ENTRY %main.9 (Arg_0.1: f32[], Arg_1.2: f32[]) -> f32[] {
  %Arg_0.1 = f32[] parameter(0), metadata={op_name="x1"}
  %Arg_1.2 = f32[] parameter(1), metadata={op_name="s"}
  ROOT %fusion = f32[] fusion(f32[] %Arg_0.1, f32[] %Arg_1.2), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(f)/jit(main)/sub" source_file="<ipython-input-1-afa1b93b6de2>" source_line=5}
}

That said, I don't see anything obviously problematic here. The compiled version fuses the full computation into a single kernel, and perhaps some of the fusion logic re-orders the computation in such a way that errors accumulate differently.

@andrewlkd
Author

Thanks, Jake! I had indeed looked at some of the lowered output and was confused because I couldn't find an obvious difference between the cases that did and didn't work.

This is a simplified reproducer of a method in our repo: https://github.com/google-deepmind/graphcast/blob/main/graphcast/samplers_utils.py#L418

The method returns NaNs when running on CPU with stochastic_churn_rate equal to 0. Otherwise it is fine (on TPU, or when stochastic_churn_rate is non-zero).

We have identified that new_noise_level**2 - noise_level**2 can evaluate to less than 0 when new_noise_level == noise_level (this occurs when the stochastic_churn_rate is 0).

I guess we'll have to add a jnp.maximum call to clamp at 0, or rewrite the subtraction in the factored difference-of-squares form; something like the sketch below.
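
(A sketch only; noise_increment is a hypothetical helper name, and I'm assuming the subtraction feeds a jnp.sqrt, which would explain the NaNs.)

import jax.numpy as jnp

def noise_increment(new_noise_level, noise_level):
  # Option 1: clamp the subtraction at 0 before the square root
  clamped = jnp.maximum(new_noise_level**2 - noise_level**2, 0.0)
  # Option 2: factor the subtraction; (a + a) * (a - a) is exactly 0.0
  # when new_noise_level == noise_level
  factored = (new_noise_level + noise_level) * (new_noise_level - noise_level)
  return jnp.sqrt(clamped)  # or jnp.sqrt(factored)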

@pearu
Collaborator

pearu commented Jan 29, 2025

I can reproduce the issue on CPU but not on CUDA.

While it is true that the un-jitted f(x, s) is evaluated in float64 (the inputs are Python floats, so this is plain Python arithmetic), I would expect that forcing float32 inputs, that is, calling f(jnp.float32(x), jnp.float32(s)), would produce the same result as f_jit(x, s).

If "the fusion logic re-orders the computation in such a way that errors accumulate differently" is indeed true, I'd consider this as a bug because floating-point arithmetic is non-associative and algorithms that are designed to take non-associativity into account to improve the accuracy of results, become broken.

However, with the given inputs (x1=0.1, s=0.0), I cannot pinpoint what could cause the errors to accumulate differently:

  • fp addition is exact when one operand is zero,
  • fp multiplication is exact when one operand is one.

Hence x1 and x2 ought to be equal, their squares ought to be equal, and subtracting equal values ought to give exactly zero.
Even if CPU and CUDA operations use different FTZ (flush-to-zero) modes, that cannot explain the non-zero result from the jitted function.
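
For reference, a quick NumPy-level check of the two exactness facts above (a sketch, mirroring the original f with explicit float32 scalars):

import numpy as np

x1 = np.float32(0.1)
s = np.float32(0.0)
x2 = x1 * (np.float32(1.0) + s)  # add-with-zero and multiply-by-one are both exact
assert x2 == x1                  # so x2 and x1 are identical
assert x2 * x2 - x1 * x1 == 0.0  # and subtracting the equal squares gives exactly zero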

@pearu
Collaborator

pearu commented Jan 29, 2025

@andrewlkd, using d = (x2 + x1) * (x2 - x1) makes more sense because it reduces the cancellation errors that occur when using x2**2 - x1**2.
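
A quick illustration with nearby but unequal inputs (a sketch; the specific values are just for demonstration):

import numpy as np

x1 = np.float32(1.0)
x2 = np.float32(1.0) + np.float32(1e-4)
reference = np.float64(x2)**2 - np.float64(x1)**2  # float64 reference

# squaring first: the subtraction cancels most digits, amplifying the
# rounding error of each squared term
print(x2**2 - x1**2)
# factored form: x2 - x1 is exact here (x1 and x2 are within a factor of 2),
# so the result stays much closer to the reference
print((x2 + x1) * (x2 - x1))
print(reference)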

@pearu
Collaborator

pearu commented Jan 29, 2025

Here is a simpler reproducer of the issue:

>>> def f(x1, x2):
...   return x2 ** 2 - x1 ** 2
... 
>>> jax.jit(f)(jnp.array(0.1), jnp.array(0.1))
Array(-4.0978193e-10, dtype=float32, weak_type=True)
>>> def f(x1, x2):
...   return 0 + x1 ** 2 - x1 ** 2
... 
>>> jax.jit(f)(jnp.array(0.1), jnp.array(0.1))
Array(0., dtype=float32, weak_type=True)

@pearu
Collaborator

pearu commented Jan 29, 2025

Notice the following (coincidence?):

>>> numpy.float32(0.1) ** 2 - numpy.float32(numpy.float64(numpy.float32(0.1)) ** 2)
-4.097819306103645e-10

which reproduces the result above; that is, the value originates from a float->double->float cast.
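
Spelling this out (a sketch; the exact printed digits may differ):

import numpy as np

x = np.float32(0.1)
sq_f32 = x * x                          # square rounded to float32
sq_f64 = np.float64(x) * np.float64(x)  # square carried in float64
# same magnitude as the jitted result (~4.1e-10); the sign just depends on
# which term keeps the extra precision
print(np.float64(sq_f32) - sq_f64)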

@jakevdp
Collaborator

jakevdp commented Jan 29, 2025

This would be worth reporting upstream at https://github.com/openxla/xla. @pearu would you like to do that?

@pearu
Collaborator

pearu commented Jan 30, 2025

Here's the report upstream: openxla/xla#22116. It includes a couple of other diagnostic results.
