Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Finalize some deprecations from jax.lib.xla_client #26292

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as
env vars. Before they could only be specified via jax.config or flags.

* Deprecations
* Several previously-deprecated APIs have been removed, including:
* From `jax.lib.xla_client`: `FftType`, `PaddingType`, `dtype_to_etype`,
and `shape_from_pyval`.

## jax 0.5.0 (Jan 17, 2025)

As of this release, JAX now uses
Expand Down
32 changes: 13 additions & 19 deletions jax/lib/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from jax._src.lax.fft import FftType as _FftType
from jax._src.lib import xla_client as _xc

get_topology_for_devices = _xc.get_topology_for_devices
Expand Down Expand Up @@ -48,23 +47,27 @@
),
None,
),
# Added Oct 10 2024
# Finalized 2024-02-03; remove after 2025-05-03
"FftType": (
"jax.lib.xla_client.FftType is deprecated; use jax.lax.FftType.",
_FftType,
"jax.lib.xla_client.FftType was removed in JAX v0.5.1; use jax.lax.FftType.",
None,
),
"PaddingType": (
(
"jax.lib.xla_client.PaddingType is deprecated; this type is unused"
" by JAX so there is no replacement."
"jax.lib.xla_client.PaddingType was removed in JAX v0.5.1;"
" this type is unused by JAX so there is no replacement."
),
_xc.PaddingType,
None,
),
# Added Oct 11 2024
"dtype_to_etype": (
"dtype_to_etype is deprecated; use StableHLO instead.",
_xc.dtype_to_etype,
"dtype_to_etype was removed in JAX v0.5.1; use StableHLO instead.",
None,
),
"shape_from_pyval": (
"shape_from_pyval was removed in JAX v0.5.1; use StableHLO instead.",
None,
),
# Added Oct 11 2024
"ops": (
"ops is deprecated; use StableHLO instead.",
_xc.ops,
Expand All @@ -74,10 +77,6 @@
"(https://jax.readthedocs.io/en/latest/ffi.html)",
_xc.register_custom_call_target,
),
"shape_from_pyval": (
"shape_from_pyval is deprecated; use StableHLO instead.",
_xc.shape_from_pyval,
),
"PrimitiveType": (
"PrimitiveType is deprecated; use StableHLO instead.",
_xc.PrimitiveType,
Expand All @@ -104,14 +103,10 @@
import typing as _typing

if _typing.TYPE_CHECKING:
dtype_to_etype = _xc.dtype_to_etype
ops = _xc.ops
register_custom_call_target = _xc.register_custom_call_target
shape_from_pyval = _xc.shape_from_pyval
ArrayImpl = _xc.ArrayImpl
Device = _xc.Device
FftType = _FftType
PaddingType = _xc.PaddingType
PrimitiveType = _xc.PrimitiveType
Shape = _xc.Shape
XlaBuilder = _xc.XlaBuilder
Expand All @@ -123,5 +118,4 @@
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _typing
del _FftType
del _xc
Loading