Skip to content

Commit

Permalink
Default JAX_CPU_COLLECTIVES_IMPLEMENTATION to 'gloo'.
Browse files Browse the repository at this point in the history
This enables CPU collectives by default, making multi-process CPU
communication work without extra configuration.

PiperOrigin-RevId: 722209646
  • Loading branch information
skye authored and Google-ML-Automation committed Feb 2, 2025
1 parent de48ce2 commit 202af1e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* Changes
* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as
env vars. Before they could only be specified via jax.config or flags.
* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` now defaults to `'gloo'`, meaning
multi-process CPU communication works out-of-the-box.

## jax 0.5.0 (Jan 17, 2025)

Expand Down
5 changes: 5 additions & 0 deletions jax/_src/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@

MIN_COMPUTE_CAPABILITY = 52

_DEFAULT_CPU_COLLECTIVES_IMPL = 'gloo'

# TODO(phawkins): Remove jax_xla_backend.
_XLA_BACKEND = config.string_flag(
'jax_xla_backend', '',
Expand Down Expand Up @@ -244,6 +246,9 @@ def make_cpu_client(
'"jax_cpu_collectives_implementation", "gloo")` instead.',
DeprecationWarning,
)
if collectives_impl is None:
collectives_impl = _DEFAULT_CPU_COLLECTIVES_IMPL

if collectives_impl == 'gloo':
collectives = xla_client._xla.make_gloo_tcp_collectives(
distributed_client=distributed.global_state.client,
Expand Down

0 comments on commit 202af1e

Please sign in to comment.