Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(deps): update dependency jax to >=0.4.38, <=0.4.38 #940

Closed
wants to merge 1 commit into from

Conversation

renovate[bot]
Copy link
Contributor

@renovate renovate bot commented Dec 18, 2024

This PR contains the following updates:

Package Change Age Adoption Passing Confidence
jax >=0.4.16, <=0.4.34 -> >=0.4.38, <=0.4.38 age adoption passing confidence

Release Notes

jax-ml/jax (jax)

v0.4.38

Compare Source

  • Changes:

    • jax.tree.flatten_with_path and jax.tree.map_with_path are added
      as shortcuts of the corresponding tree_util functions.
  • Deprecations

    • a number of APIs in the internal jax.core namespace have been deprecated.
      Most were no-ops, were little-used, or can be replaced by APIs of the same
      name in {mod}jax.extend.core; see the documentation for {mod}jax.extend
      for information on the compatibility guarantees of these semi-public extensions.
    • Several previously-deprecated APIs have been removed, including:
      • from {mod}jax.core: check_eqn, check_type, check_valid_jaxtype, and
        non_negative_dim.
      • from {mod}jax.lib.xla_bridge: xla_client and default_backend.
      • from {mod}jax.lib.xla_client: _xla and bfloat16.
      • from {mod}jax.numpy: round_.
  • New Features

    • {func}jax.export.export can be used for device-polymorphic export with
      shardings constructed with {func}jax.sharding.AbstractMesh.
      See the jax.export documentation.
    • Added {func}jax.lax.split. This is a primitive version of
      {func}jax.numpy.split, added because it yields a more compact
      transpose during automatic differentiation.

v0.4.37

Compare Source

This is a patch release of jax 0.4.36. Only "jax" was released at this version.

  • Bug fixes
    • Fixed a bug where jit would error if an argument was named f (#​25329).
    • Fix a bug that will throw index out of range error in
      {func}jax.lax.while_loop if the user register pytree node class with
      different aux data for the flatten and flatten_with_path.
    • Pinned a new libtpu release (0.0.6) that fixes a compiler bug on TPU v6e.

v0.4.36

Compare Source

  • Breaking Changes

    • This release lands "stackless", an internal change to JAX's tracing
      machinery. We made trace dispatch purely a function of context rather than a
      function of both context and data. This let us delete a lot of machinery for
      managing data-dependent tracing: levels, sublevels, post_process_call,
      new_base_main, custom_bind, and so on. The change should only affect
      users that use JAX internals.

      If you do use JAX internals then you may need to
      update your code (see
      jax-ml/jax@c36e1f7
      for clues about how to do this). There might also be version skew
      issues with JAX libraries that do this. If you find this change breaks your
      non-JAX-internals-using code then try the
      config.jax_data_dependent_tracing_fallback flag as a workaround, and if
      you need help updating your code then please file a bug.

    • {func}jax.experimental.jax2tf.convert with native_serialization=False
      or with enable_xla=False have been deprecated since July 2024, with
      JAX version 0.4.31. Now we removed support for these use cases. jax2tf
      with native serialization will still be supported.

    • In jax.interpreters.xla, the xb, xc, and xe symbols have been removed
      after being deprecated in JAX v0.4.31. Instead use xb = jax.lib.xla_bridge,
      xc = jax.lib.xla_client, and xe = jax.lib.xla_extension.

    • The deprecated module jax.experimental.export has been removed. It was replaced
      by {mod}jax.export in JAX v0.4.30. See the migration guide
      for information on migrating to the new API.

    • The initial argument to {func}jax.nn.softmax and {func}jax.nn.log_softmax
      has been removed, after being deprecated in v0.4.27.

    • Calling np.asarray on typed PRNG keys (i.e. keys produced by :func:jax.random.key)
      now raises an error. Previously, this returned a scalar object array.

    • The following deprecated methods and functions in {mod}jax.export have
      been removed:

      • jax.export.DisabledSafetyCheck.shape_assertions: it had no effect
        already.
      • jax.export.Exported.lowering_platforms: use platforms.
      • jax.export.Exported.mlir_module_serialization_version:
        use calling_convention_version.
      • jax.export.Exported.uses_shape_polymorphism:
        use uses_global_constants.
      • the lowering_platforms kwarg for {func}jax.export.export: use
        platforms instead.
    • The kwargs symbolic_scope and symbolic_constraints from
      {func}jax.export.symbolic_args_specs have been removed. They were
      deprecated in June 2024. Use scope and constraints instead.

    • Hashing of tracers, which has been deprecated since version 0.4.30, now
      results in a TypeError.

    • Refactor: JAX build CLI (build/build.py) now uses a subcommand structure and
      replaces previous build.py usage. Run python build/build.py --help for
      more details. Brief overview of the new subcommand options:

      • build: Builds JAX wheel packages. For e.g., python build/build.py build --wheels=jaxlib,jax-cuda-pjrt
      • requirements_update: Updates requirements_lock.txt files.
    • {func}jax.scipy.linalg.toeplitz now does implicit batching on multi-dimensional
      inputs. To recover the previous behavior, you can call {func}jax.numpy.ravel
      on the function inputs.

    • {func}jax.scipy.special.gamma and {func}jax.scipy.special.gammasgn now
      return NaN for negative integer inputs, to match the behavior of Scihttps://github.com/scipy/scipy/pull/21827scipy/pull/21827.

    • jax.clear_backends was removed after being deprecated in v0.4.26.

    • We removed the custom call "__gpu$xla.gpu.triton" from the list of custom
      call that we guarantee export stability. This is because this custom call
      relies on Triton IR, which is not guaranteed to be stable. If you need
      to export code that uses this custom call, you can use the disabled_checks
      parameter. See more details in the documentation.

  • New Features

    • {func}jax.jit got a new compiler_options: dict[str, Any] argument, for
      passing compilation options to XLA. For the moment it's undocumented and
      may be in flux.
    • {func}jax.tree_util.register_dataclass now allows metadata fields to be
      declared inline via {func}dataclasses.field. See the function documentation
      for examples.
    • Added {func}jax.numpy.put_along_axis.
    • {func}jax.lax.linalg.eig and the related jax.numpy functions
      ({func}jax.numpy.linalg.eig and {func}jax.numpy.linalg.eigvals) are now
      supported on GPU. See {jax-issue}#24663 for more details.
    • Added two new configuration flags, jax_exec_time_optimization_effort and jax_memory_fitting_effort, to control the amount of effort the compiler spends minimizing execution time and memory usage, respectively. Valid values are between -1.0 and 1.0, default is 0.0.
  • Bug fixes

    • Fixed a bug where the GPU implementations of LU and QR decomposition would
      result in an indexing overflow for batch sizes close to int32 max. See
      {jax-issue}#24843 for more details.
  • Deprecations

    • jax.lib.xla_extension.ArrayImpl and jax.lib.xla_client.ArrayImpl are deprecated;
      use jax.Array instead.
    • jax.lib.xla_extension.XlaRuntimeError is deprecated; use jax.errors.JaxRuntimeError
      instead.

v0.4.35

Compare Source

  • Breaking Changes

    • {func}jax.numpy.isscalar now returns True for any array-like object with
      zero dimensions. Previously it only returned True for zero-dimensional
      array-like objects with a weak dtype.
    • jax.experimental.host_callback has been deprecated since March 2024, with
      JAX version 0.4.26. Now we removed it.
      See {jax-issue}#20385 for a discussion of alternatives.
  • Changes:

    • jax.lax.FftType was introduced as a public name for the enum of FFT
      operations. The semi-public API jax.lib.xla_client.FftType has been
      deprecated.
    • TPU: JAX now installs TPU support from the libtpu package rather than
      libtpu-nightly. For the next few releases JAX will pin an empty version of
      libtpu-nightly as well as libtpu to ease the transition; that dependency
      will be removed in Q1 2025.
  • Deprecations:

    • The semi-public API jax.lib.xla_client.PaddingType has been deprecated.
      No JAX APIs consume this type, so there is no replacement.
    • The default behavior of {func}jax.pure_callback and
      {func}jax.extend.ffi.ffi_call under vmap has been deprecated and so has
      the vectorized parameter to those functions. The vmap_method parameter
      should be used instead for better defined behavior. See the discussion in
      {jax-issue}#23881 for more details.
    • The semi-public API jax.lib.xla_client.register_custom_call_target has
      been deprecated. Use the JAX FFI instead.
    • The semi-public APIs jax.lib.xla_client.dtype_to_etype,
      jax.lib.xla_client.ops,
      jax.lib.xla_client.shape_from_pyval, jax.lib.xla_client.PrimitiveType,
      jax.lib.xla_client.Shape, jax.lib.xla_client.XlaBuilder, and
      jax.lib.xla_client.XlaComputation have been deprecated. Use StableHLO
      instead.

Configuration

📅 Schedule: Branch creation - At any time (no schedule defined), Automerge - At any time (no schedule defined).

🚦 Automerge: Disabled by config. Please merge this manually once you are satisfied.

Rebasing: Whenever PR becomes conflicted, or you tick the rebase/retry checkbox.

🔕 Ignore: Close this PR and you won't be reminded about this update again.


  • If you want to rebase/retry this PR, check this box

This PR was generated by Mend Renovate. View the repository job log.

@renovate renovate bot requested a review from a team December 18, 2024 02:58
@w-gc w-gc closed this Dec 18, 2024
@github-actions github-actions bot locked and limited conversation to collaborators Dec 18, 2024
@renovate renovate bot deleted the renovate/jax-0.x branch December 18, 2024 06:01
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant