Skip to content

Commit

Permalink
Reverts 3b2410f
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 721908845
  • Loading branch information
emilyfertig authored and Google-ML-Automation committed Feb 5, 2025
1 parent 781172c commit 413269f
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 40 deletions.
130 changes: 91 additions & 39 deletions jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension as xe
from jax._src.lib import xla_extension_version
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
PmapSharding, SingleDeviceSharding,
Expand All @@ -55,7 +56,10 @@

def _get_device(a: ArrayImpl) -> Device:
devices = a.sharding._internal_device_list # pytype: disable=attribute-error
assert len(devices) == 1
if len(devices) != 1:
raise ValueError(
"When making an array from single-device arrays the input arrays must "
f"have one shard each. An argument array had {len(devices)} shard(s).")
return devices[0]


Expand Down Expand Up @@ -195,54 +199,102 @@ def __init__(self, aval: core.ShapedArray, sharding: Sharding,

self.aval = aval
self._sharding = sharding
self._arrays = [a._arrays[0] for a in arrays]
self._committed = committed
self._npy_value = None
arrays = [a._arrays[0] for a in arrays]

# Don't rearrange if skip_checks is enabled because this assumes that the
# input buffers are already arranged properly. This usually happens when
# Arrays are created as the output of a JAX transformation
# (like pjit, etc).
if not _skip_checks or config.enable_checks.value:
self._check_and_rearrange()
arrays = self._check_and_rearrange(arrays, self._sharding, self.aval)
self._arrays = arrays # type: ignore

def _check_and_rearrange(self):
device_id_to_buffer = {_get_device(db).id: db for db in self._arrays}
if xla_extension_version >= 310:
def _check_and_rearrange(self, arrays, sharding, aval):
device_id_to_buffer = {_get_device(db).id: db for db in arrays}

addressable_dev = self.sharding.addressable_devices
if len(self._arrays) != len(addressable_dev):
raise ValueError(
f"Expected {len(addressable_dev)} per-device arrays "
"(this is how many devices are addressable by the sharding), but "
f"got {len(self._arrays)}")

array_device_ids = set(device_id_to_buffer.keys())
addressable_device_ids = {d.id for d in addressable_dev}
# Calculate a symmetric difference because the device ids between sharding
# and _arrays should match.
diff = array_device_ids ^ addressable_device_ids
if diff:
dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
err_msg = (
"Addressable devices and per-device arrays devices do not match.")
if dev_in_sharding_not_in_arrays:
err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} "
"that are not present in per-device arrays.")
if dev_in_arrays_not_in_sharding:
err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} "
"that are not present in the sharding.")
raise ValueError(err_msg)

_validate_shape_and_dtype_for_per_device_arrays(
self._arrays,
sharding=self.sharding,
aval=self.aval,
expected_shape=self.sharding.shard_shape(self.shape),
)
# Rearrange arrays based on the device assignment.
addressable_da = self.sharding._addressable_device_assignment
self._arrays = [device_id_to_buffer[device.id] for device in addressable_da]
addressable_dev = sharding.addressable_devices
if len(arrays) != len(addressable_dev):
raise ValueError(
f"Expected {len(addressable_dev)} per-device arrays "
"(this is how many devices are addressable by the sharding), but "
f"got {len(arrays)}")

array_device_ids = set(device_id_to_buffer.keys())
addressable_device_ids = {d.id for d in addressable_dev}
if len(array_device_ids) != len(arrays):
buffer_device_ids = [_get_device(db).id for db in arrays]
raise ValueError(
"When making an array from single-device arrays, the input arrays"
" must be from distinct devices, but got device IDs"
f" {buffer_device_ids}")

# Calculate a symmetric difference because the device ids between sharding
# and _arrays should match.
diff = array_device_ids ^ addressable_device_ids
if diff:
dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
err_msg = (
"Addressable devices and per-device arrays devices do not match.")
if dev_in_sharding_not_in_arrays:
err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} "
"that are not present in per-device arrays.")
if dev_in_arrays_not_in_sharding:
err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} "
"that are not present in the sharding.")
raise ValueError(err_msg)

_validate_shape_and_dtype_for_per_device_arrays(
arrays,
sharding=sharding,
aval=aval,
expected_shape=sharding.shard_shape(aval.shape),
)

# Rearrange arrays based on the device assignment.
addressable_da = sharding._addressable_device_assignment
return [device_id_to_buffer[device.id] for device in addressable_da]
else:
def _check_and_rearrange(self): # type: ignore
device_id_to_buffer = {_get_device(db).id: db for db in self._arrays}

addressable_dev = self.sharding.addressable_devices
if len(self._arrays) != len(addressable_dev):
raise ValueError(
f"Expected {len(addressable_dev)} per-device arrays "
"(this is how many devices are addressable by the sharding), but "
f"got {len(self._arrays)}")

array_device_ids = set(device_id_to_buffer.keys())
addressable_device_ids = {d.id for d in addressable_dev}
# Calculate a symmetric difference because the device ids between sharding
# and _arrays should match.
diff = array_device_ids ^ addressable_device_ids
if diff:
dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
err_msg = (
"Addressable devices and per-device arrays devices do not match.")
if dev_in_sharding_not_in_arrays:
err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} "
"that are not present in per-device arrays.")
if dev_in_arrays_not_in_sharding:
err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} "
"that are not present in the sharding.")
raise ValueError(err_msg)

_validate_shape_and_dtype_for_per_device_arrays(
self._arrays,
sharding=self.sharding,
aval=self.aval,
expected_shape=self.sharding.shard_shape(self.shape),
)
# Rearrange arrays based on the device assignment.
addressable_da = self.sharding._addressable_device_assignment
self._arrays = [device_id_to_buffer[device.id] for device in addressable_da]

@property
def shape(self) -> Shape:
Expand Down
2 changes: 1 addition & 1 deletion tests/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def test_duplicated_devices_in_arrays(self):
# Sharding device ids = {0, 1}
s = jax.sharding.NamedSharding(mesh, P('x'))
inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
# _arrays device ids = {0, 2}
# _arrays device ids = {0, 0}
bufs = [jax.device_put(inp_data, jax.devices()[0]) for _ in range(2)]
with self.assertRaisesRegex(
ValueError,
Expand Down

0 comments on commit 413269f

Please sign in to comment.