diff --git a/jax/_src/array.py b/jax/_src/array.py index c99698a4153c..b88cde61f221 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -55,7 +55,10 @@ def _get_device(a: ArrayImpl) -> Device: devices = a.sharding._internal_device_list # pytype: disable=attribute-error - assert len(devices) == 1 + if len(devices) != 1: + raise ValueError( + "When making an array from single-device arrays the input arrays must " + f"have one shard each. An argument array had {len(devices)} shard(s).") return devices[0] @@ -195,17 +198,67 @@ def __init__(self, aval: core.ShapedArray, sharding: Sharding, self.aval = aval self._sharding = sharding - self._arrays = [a._arrays[0] for a in arrays] self._committed = committed self._npy_value = None + arrays = [a._arrays[0] for a in arrays] # Don't rearrange if skip_checks is enabled because this assumes that the # input buffers are already arranged properly. This usually happens when # Array's are created as output of a JAX transformation # (like pjit, etc). if not _skip_checks or config.enable_checks.value: - self._check_and_rearrange() + arrays = self._check_and_rearrange_arrays( + arrays, self._sharding, self.aval) + self._arrays = arrays # type: ignore + def _check_and_rearrange_arrays(self, arrays, sharding, aval): + device_id_to_buffer = {_get_device(db).id: db for db in arrays} + + addressable_dev = sharding.addressable_devices + if len(arrays) != len(addressable_dev): + raise ValueError( + f"Expected {len(addressable_dev)} per-device arrays " + "(this is how many devices are addressable by the sharding), but " + f"got {len(arrays)}") + + array_device_ids = set(device_id_to_buffer.keys()) + addressable_device_ids = {d.id for d in addressable_dev} + if len(array_device_ids) != len(arrays): + buffer_device_ids = [_get_device(db).id for db in arrays] + raise ValueError( + "When making an array from single-device arrays, the input arrays" + " must be from distinct devices, but got device IDs" + f" {buffer_device_ids}") + + # Calculate a symmetric difference because the device ids between sharding + # and _arrays should match. + diff = array_device_ids ^ addressable_device_ids + if diff: + dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids + dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids + err_msg = ( + "Addressable devices and per-device arrays devices do not match.") + if dev_in_sharding_not_in_arrays: + err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} " + "that are not present in per-device arrays.") + if dev_in_arrays_not_in_sharding: + err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} " + "that are not present in the sharding.") + raise ValueError(err_msg) + + _validate_shape_and_dtype_for_per_device_arrays( + arrays, + sharding=sharding, + aval=aval, + expected_shape=sharding.shard_shape(aval.shape), + ) + + # Rearrange arrays based on the device assignment. + addressable_da = sharding._addressable_device_assignment + return [device_id_to_buffer[device.id] for device in addressable_da] + + # TODO(emilyaf): Remove this method and its callsite in py_array.cc once + # xla_extension_version < 310 is no longer supported. def _check_and_rearrange(self): device_id_to_buffer = {_get_device(db).id: db for db in self._arrays} diff --git a/tests/array_test.py b/tests/array_test.py index 2b1f53f4bea5..99517295c2b4 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -374,7 +374,7 @@ def test_duplicated_devices_in_arrays(self): # Sharding device ids = {0, 1} s = jax.sharding.NamedSharding(mesh, P('x')) inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) - # _arrays device ids = {0, 2} + # _arrays device ids = {0, 0} bufs = [jax.device_put(inp_data, jax.devices()[0]) for _ in range(2)] with self.assertRaisesRegex( ValueError,