From 651d18303419ae4065f80eed4d700f862b092367 Mon Sep 17 00:00:00 2001
From: Hyeontaek Lim
Date: Mon, 3 Feb 2025 19:55:37 -0800
Subject: [PATCH] [JAX] Add a test verifying the behavior of module-level
 state accessed by colocated Python

A new test verifies that

* Python module-level variables can be created/set and read from a colocated Python function
* Python module-level variables are not pickled on the controller (JAX) or sent to executors via pickling

An API for defining user-defined state and accessing it from multiple colocated Python functions (i.e., object support) will be added later. That will be a recommended way to express user-defined state. The capability of accessing Python module variables is still crucial because a lot of Python code (including JAX) requires this behavior to implement caching.

PiperOrigin-RevId: 722898985
---
 tests/colocated_python_test.py | 45 +++++++++++++++++++++++++++++++++-
 1 file changed, 44 insertions(+), 1 deletion(-)

diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py
index f9dd3ce52b58..fc04fab4d20a 100644
--- a/tests/colocated_python_test.py
+++ b/tests/colocated_python_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import tempfile
 import threading
 import time
 from typing import Sequence
@@ -51,7 +52,8 @@ def _colocated_cpu_devices(
 
 
 _count_colocated_python_specialization_cache_miss = jtu.count_events(
-    "colocated_python_func._get_specialized_func")
+    "colocated_python_func._get_specialized_func"
+)
 
 
 class ColocatedPythonTest(jtu.JaxTestCase):
@@ -330,6 +332,47 @@ def add(x: jax.Array, y: jax.Array) -> jax.Array:
     out = jax.device_get(out)
     np.testing.assert_equal(out, np.array([2 + 4, 0 + 8]))
 
+  def testModuleVariableAccess(self):
+    try:
+      # The following pattern of storing and accessing non-serialized state in
+      # the Python module is discouraged for storing user-defined state.
+      # However, it should still work because many caching mechanisms rely on
+      # this behavior.
+
+      # Poison the test's own `colocated_python` module with a non-serializable
+      # object (file) to detect any invalid attempt to serialize the module as
+      # part of a colocated Python function.
+      colocated_python._testing_non_serializable_object = (
+          tempfile.TemporaryFile()
+      )
+
+      @colocated_python.colocated_python
+      def set_global_state(x: jax.Array) -> jax.Array:
+        colocated_python._testing_global_state = x
+        return x + 1
+
+      @colocated_python.colocated_python
+      def get_global_state(x: jax.Array) -> jax.Array:
+        del x
+        return colocated_python._testing_global_state
+
+      cpu_devices = _colocated_cpu_devices(jax.local_devices())
+      x = np.array(1)
+      x = jax.device_put(x, cpu_devices[0])
+      y = np.array(2)
+      y = jax.device_put(y, cpu_devices[0])
+
+      jax.block_until_ready(set_global_state(x))
+      out = jax.device_get(get_global_state(y))
+
+      np.testing.assert_equal(out, np.array(1))
+    finally:
+      if "_testing_non_serializable_object" in colocated_python.__dict__:
+        colocated_python._testing_non_serializable_object.close()
+        del colocated_python._testing_non_serializable_object
+      if "_testing_global_state" in colocated_python.__dict__:
+        del colocated_python._testing_global_state
+
 
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())