Skip to content

Commit

Permalink
Return arrays from ArrayImpl._check_and_rearrange. Build IFRT shard…
Browse files Browse the repository at this point in the history
…ings with both addressable and non-addressable devices, instead of only addressable devices.

This is a roll-forward of two previous rollbacks after fixing breakages.

PiperOrigin-RevId: 721929080
  • Loading branch information
emilyfertig authored and Google-ML-Automation committed Feb 4, 2025
1 parent 654a2f6 commit fbaae7e
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 4 deletions.
59 changes: 56 additions & 3 deletions jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@

def _get_device(a: ArrayImpl) -> Device:
  """Return the single device referenced by the sharding of ``a``.

  Args:
    a: Array whose ``sharding._internal_device_list`` is inspected.

  Returns:
    The only entry of that device list.

  Raises:
    ValueError: If the device list does not contain exactly one device.
  """
  devices = a.sharding._internal_device_list  # pytype: disable=attribute-error
  # Validate with an explicit exception instead of an ``assert``: an assert
  # would surface as an uninformative AssertionError and is stripped entirely
  # when Python runs with -O, whereas callers need the descriptive ValueError.
  if len(devices) != 1:
    raise ValueError(
        "When making an array from single-device arrays the input arrays must "
        f"have one shard each. An argument array had {len(devices)} shard(s).")
  return devices[0]


Expand Down Expand Up @@ -195,17 +198,67 @@ def __init__(self, aval: core.ShapedArray, sharding: Sharding,

self.aval = aval
self._sharding = sharding
self._arrays = [a._arrays[0] for a in arrays]
self._committed = committed
self._npy_value = None
arrays = [a._arrays[0] for a in arrays]

# Don't rearrange if skip_checks is enabled because this assumes that the
# input buffers are already arranged properly. This usually happens when
# Arrays are created as output of a JAX transformation
# (like pjit, etc).
if not _skip_checks or config.enable_checks.value:
self._check_and_rearrange()
arrays = self._check_and_rearrange_arrays(
arrays, self._sharding, self.aval)
self._arrays = arrays # type: ignore

def _check_and_rearrange_arrays(self, arrays, sharding, aval):
  """Validate per-device arrays against ``sharding`` and reorder them.

  Args:
    arrays: Per-device arrays; each is passed to ``_get_device`` to find the
      device it lives on.
    sharding: Sharding providing ``addressable_devices``, ``shard_shape`` and
      ``_addressable_device_assignment``.
    aval: Abstract value whose ``shape`` is used to compute the expected
      per-shard shape.

  Returns:
    A list with the entries of ``arrays`` ordered to follow
    ``sharding._addressable_device_assignment``.

  Raises:
    ValueError: If the number of arrays differs from the number of
      addressable devices, if two arrays share a device id, or if the set of
      array device ids differs from the set of addressable device ids.
      ``_get_device`` and the shape/dtype validation helper may raise as well.
  """
  # Index every buffer by the id of the device that holds it.
  buffers_by_device_id = {}
  for buf in arrays:
    buffers_by_device_id[_get_device(buf).id] = buf

  addressable_devices = sharding.addressable_devices
  num_arrays = len(arrays)
  if num_arrays != len(addressable_devices):
    raise ValueError(
        f"Expected {len(addressable_devices)} per-device arrays "
        "(this is how many devices are addressable by the sharding), but "
        f"got {num_arrays}")

  array_ids = set(buffers_by_device_id)
  sharding_ids = {dev.id for dev in addressable_devices}

  # Fewer distinct ids than arrays means at least two arrays share a device.
  if len(array_ids) != num_arrays:
    buffer_device_ids = [_get_device(buf).id for buf in arrays]
    raise ValueError(
        "When making an array from single-device arrays, the input arrays"
        " must be from distinct devices, but got device IDs"
        f" {buffer_device_ids}")

  # The device ids of the sharding and of the arrays must match exactly, so
  # report both directions of any mismatch.
  if array_ids != sharding_ids:
    only_in_sharding = sharding_ids - array_ids
    only_in_arrays = array_ids - sharding_ids
    parts = ["Addressable devices and per-device arrays devices do not match."]
    if only_in_sharding:
      parts.append(
          f" Sharding contains devices {only_in_sharding} "
          "that are not present in per-device arrays.")
    if only_in_arrays:
      parts.append(
          f" Per-device arrays contain devices {only_in_arrays} "
          "that are not present in the sharding.")
    raise ValueError("".join(parts))

  expected_shard_shape = sharding.shard_shape(aval.shape)
  _validate_shape_and_dtype_for_per_device_arrays(
      arrays,
      sharding=sharding,
      aval=aval,
      expected_shape=expected_shard_shape,
  )

  # Rearrange arrays based on the device assignment.
  ordered = []
  for device in sharding._addressable_device_assignment:
    ordered.append(buffers_by_device_id[device.id])
  return ordered

# TODO(emilyaf): Remove this method and its callsite in py_array.cc once
# xla_extension_version < 310 is no longer supported.
def _check_and_rearrange(self):
device_id_to_buffer = {_get_device(db).id: db for db in self._arrays}

Expand Down
2 changes: 1 addition & 1 deletion tests/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def test_duplicated_devices_in_arrays(self):
# Sharding device ids = {0, 1}
s = jax.sharding.NamedSharding(mesh, P('x'))
inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
# _arrays device ids = {0, 2}
# _arrays device ids = {0, 0}
bufs = [jax.device_put(inp_data, jax.devices()[0]) for _ in range(2)]
with self.assertRaisesRegex(
ValueError,
Expand Down

0 comments on commit fbaae7e

Please sign in to comment.