Skip to content

Commit

Permalink
[JAX] Add a test verifying the behavior of module-level state accesse…
Browse files Browse the repository at this point in the history
…d by colocated Python

A new test verifies that
* Python module-level variables can be created/set and read from a colocated Python function
* Python module-level variables are not pickled on the controller (JAX) or sent to executors via pickling

An API for defining user-defined state and accessing it from multiple colocated
Python functions (i.e., object support) will be added later. That will be a
recommended way to express user-defined state. The capability of accessing
Python module variables is still crucial because a lot of Python code
(including JAX) requires this behavior to implement caching.

PiperOrigin-RevId: 722898985
  • Loading branch information
hyeontaek authored and Google-ML-Automation committed Feb 4, 2025
1 parent 6281b86 commit 8615355
Showing 1 changed file with 43 additions and 1 deletion.
44 changes: 43 additions & 1 deletion tests/colocated_python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import threading
import time
from typing import Sequence
Expand Down Expand Up @@ -51,7 +52,8 @@ def _colocated_cpu_devices(


_count_colocated_python_specialization_cache_miss = jtu.count_events(
"colocated_python_func._get_specialized_func")
"colocated_python_func._get_specialized_func"
)


class ColocatedPythonTest(jtu.JaxTestCase):
Expand Down Expand Up @@ -330,6 +332,46 @@ def add(x: jax.Array, y: jax.Array) -> jax.Array:
out = jax.device_get(out)
np.testing.assert_equal(out, np.array([2 + 4, 0 + 8]))

def testModuleVariableAccess(self):
try:
# The following pattern of storing and accessing non-serialized state in
# the Python module is discouraged for storing user-defined state.
# However, it should still work because many caching mechanisms rely on
# this behavior.

# Poison the test's own `colocated_python` module with a non-serializable
# object (file) to detect any invalid attempt to serialize the module as
# part of a colocated Python function.
colocated_python._testing_non_serializable_object = (
tempfile.TemporaryFile()
)

@colocated_python.colocated_python
def set_global_state(x: jax.Array) -> jax.Array:
colocated_python._testing_global_state = x
return x + 1

@colocated_python.colocated_python
def get_global_state(x: jax.Array) -> jax.Array:
del x
return colocated_python._testing_global_state

cpu_devices = _colocated_cpu_devices(jax.local_devices())
x = np.array(1)
x = jax.device_put(x, cpu_devices[0])
y = np.array(2)
y = jax.device_put(y, cpu_devices[0])

jax.block_until_ready(set_global_state(x))
out = jax.device_get(get_global_state(y))

np.testing.assert_equal(out, np.array(1))
finally:
if "_testing_non_serializable_object" in colocated_python.__dict__:
del colocated_python._testing_non_serializable_object
if "_testing_global_state" in colocated_python.__dict__:
del colocated_python._testing_global_state


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 8615355

Please sign in to comment.