Skip to content

Commit

Permalink
PR #20853: Enable stablehlo-complex-math-expander pass.
Browse files Browse the repository at this point in the history
Imported from GitHub PR #20853

As in the title.

Enabling stablehlo-complex-math-expander pass improves the accuracy of complex log_plus_one as follows. The accuracy pattern (obtained by running `functional_algorithms` test `test_accuracy` with JAX backend on CUDA arrays) before enabling the pass (legend: label `=` means ULP differences with respect to reference value is `0`, label `1` - ULP difference is `1`, etc, label `!` - ULP difference is more that `3`, `E` - ULP difference is more that `300`):

```
  -inf    11111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111112
  -4e35   11111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111112221
  -4e32   11111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111222121
  -5e29   11111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111112111211
  -5e26   11111111111111111111111111111111111111111111111111111111111111111111111111111111111111111112222111211
  -6e23   11=11111111111111111111111111111111111111111111111111111111111111111111111111111111111111122111111111
  -6e20   11=1==1111111111111===============================================================1111122122111111111
  -7e17   11=1==111111111111111===========================================================111112222222111111111
  -8e14   11=1==111=1111111111111111111111111111111111111111111111111111111111111111111111111122222122111111111
  -9e11   11=1==111===11111111111111111111111111111111111111111111111111111111111111111111111221122122111111111
  -9e8    11=1==111===1=111111111111111111111111111111111111111111111111111111111111111111111121112122111121211
  -1e6    11=1==111===1=1=1111111111111111111111111111111111111111111111111111111111111222211121122122111121121
  -1e3    11=1==111===1=1=1111111111111111111111111111111111111111111111111111111111112221111111121122111111211
  -1      11=1==111===1=1=111111111!!!E!!111111111111111111111111111111111111111333!222222211121212122111111211
  -1e-3   11=1==111===1=1=1111=11112222!!!!!111111111111111111111111111111111111222222211221111111211211111111=
  -2e-6   11=1==111===1=1=1111=1==1212222=1!!!!111111================1111111111=1112122121211111121122111111=11
  -2e-9   11=1==111===1=1=1111=1===212222=====1!!!!11111========11111111111=====111212212121112211212211111111=
  -2e-12  11=1==111===1=1=1111=1===212122======1==!!!!!1111111111111111==1======111222212121112112212211=111=1=
  -2e-15  11=1==111===1=1=1111=1===212122=============!!!!!11111111=============11121221222111211221111111=1=1=
  -2e-18  11=1==111===1=1=1111=1===212222=================!!111=================11122221222111111111=111=1=1=1=
  -2e-21  11=1==111===1=1=1111=1===212122=================11111=================111212212121112111=1111==1=1=1=
  -3e-24  11=1==111===1=1=1111=1===211122===================1===================11121221212111111111=11==1=1=1=
  -3e-27  11=1==111===1=1=1111=1===212222=======================================111222212221111111===11==1=1=1=
  -3e-30  11=1==111===1=1=1111=1===211122=======================================111222212121111111===11==1=1=1=
  -4e-33  11=1==111===1=1=1111=1===212222=======================================111222211111111=11===11==1=1=1=
  -4e-36  11=1==111===1=1=1111=1===212221=======================================11122121==111=1=11===11==1=1=1=
  0       11=1==111===1=1=1111=1===212221=======================================111222211=11==1=11===11==1=1=1=
  4e-36   11=1==111===1=1=1111=1===212122=======================================111212212111111=11===11==1=1=1=
  4e-33   11=1==111===1=1=1111=1===211122=======================================111212212111=11111===11==1=1=1=
  3e-30   11=1==111===1=1=1111=1===212122=======================================111222212221111111===11==1=1=1=
  3e-27   11=1==111===1=1=1111=1===212222===================1===================11122221222111111111=11==1=1=1=
  3e-24   11=1==111===1=1=1111=1===211122==================111==================111222212121111111=1111==1=1=1=
  3e-21   11=1==111===1=1=1111=1===212222================11!1111================11121221122111111121=111=1=1=1=
  2e-18   11=1==111===1=1=1111=1===212122=============2!!!!11111111=============11122221222111121121111111=1=1=
  2e-15   11=1==111===1=1=1111=1===211122=========1!!!!1111111111111111=========1112122121211112222122111111=1=
  2e-12   11=1==111===1=1=1111=1===212122=====1!!!!111111========111111111======111222212121112112112211111111=
  2e-9    11=1==111===1=1=1111=1==121222111!!!!111111================111111111111112222122211111212122111111=11
  2e-6    11=1==111===1=1=1111=11112221!!!!!111111111111111111111111111111111111221222212211111221212211111111=
  2e-3    11=1==111===1=1=111111111!!!!!3111111111111111111111111111111111111111333!222211211121121122111111111
  2       11=1==111===1=1=1111111111111111111111111111111111111111111111111111111111122222111121112122111121121
  1e3     11=1==111===1=1=1111111111111111111111111111111111111111111111111111111111111122211122112122111111111
  1e6     11=1==111===1=111111111111111111111111111111111111111111111111111111111111111111221112121122111111211
  1e9     11=1==111===1111111111111===================================================1111122221112122111111111
  1e12    11=1==111=1111111111111=======================================================11111122222122111111211
  1e15    11=1==11111111111111111111111111111111111111111111111111111111111111111111111111111111122122111111111
  9e17    11=1==1111111111111===============================================================1111122122111111111
  8e20    11=11111111111111111111111111111111111111111111111111111111111111111111111111111111111111222121121111
  8e23    11111111111111111111111111111111111111111111111111111111111111111111111111111111111111111112222111211
  7e26    1111111111111===========================================================================1111122221111
  6e29    11111111111===============================================================================11111222221
  6e32    111111111===================================================================================111112211
  5e35    1111111=======================================================================================1111112
          -inf  -3e29 -2e20 -1e11 -1e2  -8e-8 -6e-1 -4e-2 -3e-3 1e-33 1e-24 2e-15 3e-6  4e3   5e12  8e21  1e31
```

<details>

```
| ULP-difference |                                z |                     jax:log1p(z) |                  mpmath:log1p(z) |                 numpy:log1p(z) |                      fa:log1p(z) |
| -------------- | -------------------------------- | -------------------------------- | -------------------------------- | ------------------------------ | -------------------------------- |
|            417 |   (-1.4850477e-05-0.0054416545j) |     (-4.456342e-08-0.005441682j) |   (-4.4564903e-08-0.0054416815j) | (-5.9604645e-08-0.0054416815j) |    (-4.4564903e-08-0.005441681j) |
|            262 |   (-4.604984e-27+9.5909554e-14j) |  (-5.6626192e-30+9.5909554e-14j) |  (-5.6627178e-30+9.5909554e-14j) |                 9.5909554e-14j |  (-5.6627178e-30+9.5909554e-14j) |
|            205 |  (-4.2529954e-19+9.2143404e-10j) |  (-7.7920414e-22+9.2143404e-10j) |   (-7.791938e-22+9.2143404e-10j) |                 9.2143404e-10j |   (-7.791938e-22+9.2143404e-10j) |
|            196 |  (-3.3057245e-07+0.00081537315j) |   (1.8443131e-09+0.00081537326j) |   (1.8442914e-09+0.00081537326j) |                 0.00081537326j |   (1.8442914e-09+0.00081537326j) |
|            185 |  (-2.2565159e-20-2.1264301e-10j) |   (4.3367205e-23-2.1264301e-10j) |   (4.3366622e-23-2.1264301e-10j) |                -2.1264301e-10j |   (4.3366622e-23-2.1264301e-10j) |
|            117 |     (-0.0012349789+0.049288128j) |    (-1.9556726e-05+0.049309075j) |      (-1.955694e-05+0.04930907j) |   (-1.9610121e-05+0.04930907j) |      (-1.955694e-05+0.04930907j) |
|            110 |      (-0.0034784595-0.08273122j) |      (-5.018548e-05-0.08283005j) |      (-5.018508e-05-0.08283005j) |    (-5.018837e-05-0.08283006j) |      (-5.018508e-05-0.08283006j) |
|            105 |    (-9.969826e-30+4.464897e-15j) |    (-2.173442e-33+4.464897e-15j) |   (-2.1734613e-33+4.464897e-15j) |                  4.464897e-15j |   (-2.1734613e-33+4.464897e-15j) |
|            103 |  (-2.1849506e-12-2.0837508e-06j) |  (-1.3941756e-14-2.0837508e-06j) |  (-1.3941843e-14-2.0837508e-06j) |                -2.0837508e-06j |  (-1.3941843e-14-2.0837508e-06j) |
|            ... |                              ... |                              ... |                              ... |                            ... |                              ... |
|              8 |  (-2.3963775e-36-2.2461159e-18j) |    (1.261409e-37-2.2461159e-18j) |   (1.2614081e-37-2.2461159e-18j) |                -2.2461159e-18j |   (1.2614081e-37-2.2461159e-18j) |
|              8 |  (-3.1043605e-12+2.4817625e-06j) |  (-2.4787897e-14+2.4817625e-06j) |  (-2.4787911e-14+2.4817625e-06j) |                 2.4817625e-06j |  (-2.4787911e-14+2.4817625e-06j) |
|              7 |  (-2.0593258e-30-1.9919841e-15j) |   (-7.532549e-32-1.9919841e-15j) |   (-7.532545e-32-1.9919841e-15j) |                -1.9919841e-15j |   (-7.532545e-32-1.9919841e-15j) |
|              7 |   (-1.542265e-12+1.7250827e-06j) |  (-5.4309855e-14+1.7250827e-06j) |   (-5.430983e-14+1.7250827e-06j) |                 1.7250827e-06j |   (-5.430983e-14+1.7250827e-06j) |
|              6 |    (-6.859575e-37-1.138099e-18j) |   (-3.8322845e-38-1.138099e-18j) |   (-3.8322828e-38-1.138099e-18j) |                 -1.138099e-18j |   (-3.8322828e-38-1.138099e-18j) |
|              5 |       (0.0002259754-0.08273122j) |      (0.0036349818-0.082524665j) |        (0.003634983-0.08252467j) |      (0.003634991-0.08252467j) |       (0.0036349827-0.08252467j) |
|              5 | (-5.2773154e-33+1.02234825e-16j) | (-5.1335675e-35+1.02234825e-16j) | (-5.1335646e-35+1.02234825e-16j) |                1.02234825e-16j | (-5.1335646e-35+1.02234825e-16j) |
|              4 |     (-3.2815045e-08-0.11343931j) |       (0.0063931597-0.11295645j) |       (0.006393158-0.112956434j) |     (0.0063930997-0.11295644j) |       (0.0063931583-0.11295645j) |
|              4 |   (-1.683874e-24-1.7716747e-12j) |  (-1.1445849e-25-1.7716747e-12j) |  (-1.1445844e-25-1.7716747e-12j) |                -1.7716747e-12j |  (-1.1445844e-25-1.7716747e-12j) |
|              4 |       (0.0002259754+0.51158303j) |          (0.11641588+0.4727794j) |         (0.11641591+0.47277942j) |       (0.11641591+0.47277942j) |        (0.116415925+0.47277942j) |

```

</details>

and after enabling the pass:

```
 -inf    11111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111112
  -4e35   11111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111112221
  -4e32   11111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111222121
  -5e29   11111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111112111211
  -5e26   11111111111111111111111111111111111111111111111111111111111111111111111111111111111111111112222111211
  -6e23   11=11111111111111111111111111111111111111111111111111111111111111111111111111111111111111122111111111
  -6e20   11=1==1111111111111===============================================================1111122122111111111
  -7e17   11=1==111111111111111===========================================================111112222222111111111
  -8e14   11=1==111=1111111111111111111111111111111111111111111111111111111111111111111111111122222122111111111
  -9e11   11=1==111===11111111111111111111111111111111111111111111111111111111111111111111111221122122111111111
  -9e8    11=1==111===1=111111111111111111111111111111111111111111111111111111111111111111111121112122111121211
  -1e6    11=1==111===1=1=1111111111111111111111111111111111111111111111111111111111111222211121122122111121121
  -1e3    11=1==111===1=1=1111111111111111111111111111111111111111111111111111111111112221111111121122111111211
  -1      11=1==111===1=1=111111111!222221111111111111111111111111111111111111111222222222211121212122111111211
  -1e-3   11=1==111===1=1=1111=1111222222111111111111111111111111111111111111111112222211221111111211211111111=
  -2e-6   11=1==111===1=1=1111=1==12122221=====================================11112122121211111121122111111=11
  -2e-9   11=1==111===1=1=1111=1==1212222=======================================111212212121112211212211111111=
  -2e-12  11=1==111===1=1=1111=1==1212122======1====1===============11===1======111222212121112112212211=111=1=
  -2e-15  11=1==111===1=1=1111=1==1212122=================1=====================11121221222111211221111111=1=1=
  -2e-18  11=1==111===1=1=1111=1==1212222=================11111=================11122221222111111111=111=1=1=1=
  -2e-21  11=1==111===1=1=1111=1==1212122==================111==================111212212121112111=1111==1=1=1=
  -3e-24  11=1==111===1=1=1111=1==1211122===================1===================11121221212111111111=11==1=1=1=
  -3e-27  11=1==111===1=1=1111=1==1212222=======================================111222212221111111===11==1=1=1=
  -3e-30  11=1==111===1=1=1111=1==1211122=======================================111222212121111111===11==1=1=1=
  -4e-33  11=1==111===1=1=1111=1==1212222=======================================111222211111111=11===11==1=1=1=
  -4e-36  11=1==111===1=1=1111=1==1212221=======================================11122121==111=1=11===11==1=1=1=
  0       11=1==111===1=1=1111=1==1212221=======================================111222211=11==1=11===11==1=1=1=
  4e-36   11=1==111===1=1=1111=1==1212122=======================================111212212111111=11===11==1=1=1=
  4e-33   11=1==111===1=1=1111=1==1211122=======================================111212212111=11111===11==1=1=1=
  3e-30   11=1==111===1=1=1111=1==1212122=======================================111222212221111111===11==1=1=1=
  3e-27   11=1==111===1=1=1111=1==1212222===================1===================11122221222111111111=11==1=1=1=
  3e-24   11=1==111===1=1=1111=1==1211122==================111==================111222212121111111=1111==1=1=1=
  3e-21   11=1==111===1=1=1111=1==1212222==================111==================11121221122111111121=111=1=1=1=
  2e-18   11=1==111===1=1=1111=1==1212122=======================================11122221222111121121111111=1=1=
  2e-15   11=1==111===1=1=1111=1==1211122=======================================1112122121211112222122111111=1=
  2e-12   11=1==111===1=1=1111=1==1212122=======================================111222212121112112112211111111=
  2e-9    11=1==111===1=1=1111=1==12122211=====================================11112222122211111212122111111=11
  2e-6    11=1==111===1=1=1111=1111222122111111111111111111111111111111111111111111222212211111221212211111111=
  2e-3    11=1==111===1=1=1111111113222221111111111111111111111111111111111111111122222211211121121122111111111
  2       11=1==111===1=1=1111111111111111111111111111111111111111111111111111111111122222111121112122111121121
  1e3     11=1==111===1=1=1111111111111111111111111111111111111111111111111111111111111122211122112122111111111
  1e6     11=1==111===1=111111111111111111111111111111111111111111111111111111111111111111221112121122111111211
  1e9     11=1==111===1111111111111===================================================1111122221112122111111111
  1e12    11=1==111=1111111111111=======================================================11111122222122111111211
  1e15    11=1==11111111111111111111111111111111111111111111111111111111111111111111111111111111122122111111111
  9e17    11=1==1111111111111===============================================================1111122122111111111
  8e20    11=11111111111111111111111111111111111111111111111111111111111111111111111111111111111111222121121111
  8e23    11111111111111111111111111111111111111111111111111111111111111111111111111111111111111111112222111211
  7e26    1111111111111===========================================================================1111122221111
  6e29    11111111111===============================================================================11111222221
  6e32    111111111===================================================================================111112211
  5e35    1111111=======================================================================================1111112
          -inf  -3e29 -2e20 -1e11 -1e2  -8e-8 -6e-1 -4e-2 -3e-3 1e-33 1e-24 2e-15 3e-6  4e3   5e12  8e21  1e31

```

<details>

```
| ULP-difference |                        z |            jax:log1p(z) |         mpmath:log1p(z) |          numpy:log1p(z) |             fa:log1p(z) |
| -------------- | ------------------------ | ----------------------- | ----------------------- | ----------------------- | ----------------------- |
|              4 | (-1.1209118-0.11343931j) | (-1.7969997-2.3880699j) | (-1.7970002-2.3880696j) | (-1.7970002-2.3880696j) | (-1.7969997-2.3880699j) |

```

</details>
Copybara import of the project:

--
896c0f4 by Pearu Peterson <[email protected]>:

Enable stable-complex-math-expander pass.

Merging this change closes #20853

COPYBARA_INTEGRATE_REVIEW=#20853 from pearu:pearu/complex-math-expander 896c0f4
PiperOrigin-RevId: 709846266
  • Loading branch information
pearu authored and Google-ML-Automation committed Dec 26, 2024
1 parent 4107c21 commit ccc23a9
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions xla/pjrt/mlir_to_hlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ absl::Status MlirToXlaComputation(mlir::ModuleOp module,
mlir::BaseScopedDiagnosticHandler diagnostic_handler(context);
{
mlir::PassManager pm(context);
// Expand stablehlo complex math functions such as log_plus_one, etc.
pm.addNestedPass<mlir::func::FuncOp>(
mlir::stablehlo::createStablehloComplexMathExpanderPass());
pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
pm.addNestedPass<mlir::func::FuncOp>(
mlir::mhlo::createChloLegalizeToHloPass());
Expand Down Expand Up @@ -223,6 +226,10 @@ absl::StatusOr<std::string> SerializeUsingVersionedStablehlo(
// Legalize CHLO -> [StableHLO+Shape] -> StableHLO
// Preserve higher-level ops with XLA support. To be replaced by composites.
mlir::PassManager pm(context);
// Expand stablehlo complex math functions such as log_plus_one, etc.
pm.addNestedPass<mlir::func::FuncOp>(
mlir::stablehlo::createStablehloComplexMathExpanderPass());

xla::sdy::addSdyRoundTripExportPipeline(pm);
pm.addNestedPass<mlir::func::FuncOp>(
mlir::mhlo::createChloLegalizeToHighLevelMhloPass());
Expand Down

0 comments on commit ccc23a9

Please sign in to comment.