Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[JAX] Add a test verifying the behavior of module-level state accesse…
…d by colocated Python A new test verifies that * Python module-level variables can be created/set and read from a colocated Python function * Python module-level variables are not pickled on the controller (JAX) or sent to executors via pickling An API for defining user-defined state and accessing it from multiple colocated Python functions (i.e., object support) will be added later. That will be a recommended way to express user-defined state. The capability of accessing Python module variables is still crucial because a lot of Python code (including JAX) requires this behavior to implement caching. PiperOrigin-RevId: 722898985
- Loading branch information