Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[better_errors] Add debug info to the Jaxprs formed for AD (step 2) #26313

Merged
merged 1 commit into from
Feb 5, 2025

Conversation

gnecula
Copy link
Collaborator

@gnecula gnecula commented Feb 4, 2025

Following #26078 , we add debug info to more calls of lu.wrap_init.
Most of the work here is to construct a debug info in appropriate places and thread it to places where lu.wrap_init is called. There are just a lot of calls to lu.wrap_init, and this PR does not even fix all of them.

We also add type declarations in all the functions that I had to understand for this PR.

@gnecula gnecula self-assigned this Feb 4, 2025
@gnecula gnecula added the pull ready Ready for copybara import and testing label Feb 4, 2025
@gnecula gnecula force-pushed the debug_info_vjp branch 2 times, most recently from 32151e2 to 7606afc Compare February 4, 2025 19:13
@gnecula gnecula changed the title [better_errors] Add debug info to the Jaxprs formed for AD [better_errors] Add debug info to the Jaxprs formed for AD (step 2) Feb 5, 2025
@gnecula gnecula force-pushed the debug_info_vjp branch 9 times, most recently from 2331c9e to 5b9b852 Compare February 5, 2025 11:45
@gnecula gnecula requested a review from dfm February 5, 2025 11:47
jax/_src/custom_derivatives.py Outdated Show resolved Hide resolved
jax/_src/custom_derivatives.py Outdated Show resolved Hide resolved
debug_fwd = debug_info("custom_vjp fwd", self.fwd, args, kwargs,
static_argnums=self.nondiff_argnums)
# TODO(necula): figure out how to construct the debug_bwd args
debug_bwd = debug_info("custom_vjp bwd", self.bwd, args, {})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As noted in the TODO, this definitely isn't the right signature here. Would it be better to use a more obvious placeholder?

(Side note, while I'm here: I think the arguments to bwd are *static_args, residuals, output_cotangents, where residuals are output by fwd and output_cotangents has the tangent structure associated with fun's output. I can see why this will be tricky to annotate!)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, this is tricky because I do not know the structure of the residuals and output_cotangents at this point.

The signatures are only needed to pick up the proper names for arg_names. Fixing that will have to come later, right now I am focusing just populating at least the function src info. We may never be able to get the arg_names perfectly right, there are so many places where functions are transformed. In that case, I will think (later) about a way to safely default to None. For now, the default is that we use None if the length of arg_names is different than jaxpr.invars.

Following jax-ml#26078 , we add debug info to more calls of lu.wrap_init.
@copybara-service copybara-service bot merged commit c46b021 into jax-ml:main Feb 5, 2025
23 checks passed
@gnecula gnecula deleted the debug_info_vjp branch February 5, 2025 18:58
gnecula added a commit to gnecula/jax that referenced this pull request Feb 6, 2025
This follows after jax-ml#26078, and jax-ml#26313, adding `debug_info` to
more calls to `lu.wrap_init`.

As part of this I have changed the primitives `custom_vjp_call_jaxpr`
and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`,
which carries debug info. Previously, this was a `Callable`, but in
almost all cases if was really ` lu.WrappedFun.call_wrapped`.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 6, 2025
This follows after jax-ml#26078, and jax-ml#26313, adding `debug_info` to
more calls to `lu.wrap_init`.

As part of this I have changed the primitives `custom_vjp_call_jaxpr`
and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`,
which carries debug info. Previously, this was a `Callable`, but in
almost all cases if was really ` lu.WrappedFun.call_wrapped`.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 6, 2025
This follows after jax-ml#26078, and jax-ml#26313, adding `debug_info` to
more calls to `lu.wrap_init`.

As part of this I have changed the primitives `custom_vjp_call_jaxpr`
and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`,
which carries debug info. Previously, this was a `Callable`, but in
almost all cases if was really ` lu.WrappedFun.call_wrapped`.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 6, 2025
This follows after jax-ml#26078, and jax-ml#26313, adding `debug_info` to
more calls to `lu.wrap_init`.

As part of this I have changed the primitives `custom_vjp_call_jaxpr`
and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`,
which carries debug info. Previously, this was a `Callable`, but in
almost all cases if was really ` lu.WrappedFun.call_wrapped`.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 6, 2025
This follows after jax-ml#26078, and jax-ml#26313, adding `debug_info` to
more calls to `lu.wrap_init`.

As part of this I have changed the primitives `custom_vjp_call_jaxpr`
and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`,
which carries debug info. Previously, this was a `Callable`, but in
almost all cases if was really ` lu.WrappedFun.call_wrapped`.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 6, 2025
This follows after jax-ml#26078, and jax-ml#26313, adding `debug_info` to
more calls to `lu.wrap_init`.

As part of this I have changed the primitives `custom_vjp_call_jaxpr`
and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`,
which carries debug info. Previously, this was a `Callable`, but in
almost all cases if was really ` lu.WrappedFun.call_wrapped`.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 6, 2025
This follows after jax-ml#26078, and jax-ml#26313, adding `debug_info` to
more calls to `lu.wrap_init`.

As part of this I have changed the primitives `custom_vjp_call_jaxpr`
and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`,
which carries debug info. Previously, this was a `Callable`, but in
almost all cases if was really ` lu.WrappedFun.call_wrapped`.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 6, 2025
This follows after jax-ml#26078, and jax-ml#26313, adding `debug_info` to
more calls to `lu.wrap_init`.

As part of this I have changed the primitives `custom_vjp_call_jaxpr`
and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`,
which carries debug info. Previously, this was a `Callable`, but in
almost all cases if was really ` lu.WrappedFun.call_wrapped`.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 7, 2025
This follows after jax-ml#26078, jax-ml#26313, jax-ml#26348, adding `debug_info` to more calls to `lu.wrap_init`.

As part of this I have changed the primitive `custom_transpose` to take the `transpose` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`.

These changes ensure that all the `lu.wrap_init` and `Jaxpr` are called with debug_info in the `api_test.py:CustomTransposeTest`.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 7, 2025
This follows after jax-ml#26078, jax-ml#26313, jax-ml#26348, adding `debug_info` to more calls to `lu.wrap_init`.

As part of this I have changed the primitive `custom_transpose` to take the `transpose` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`.

These changes ensure that all the `lu.wrap_init` and `Jaxpr` are called with debug_info in the `api_test.py:CustomTransposeTest`.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 7, 2025
This follows after jax-ml#26078, jax-ml#26313, jax-ml#26348, adding `debug_info` to more calls to `lu.wrap_init`.

As part of this I have changed the primitive `custom_transpose` to take the `transpose` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`.

These changes ensure that all the `lu.wrap_init` and `Jaxpr` are called with debug_info in the `api_test.py:CustomTransposeTest`.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 7, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

These changes ensure that all the lu.wrap_init and Jaxpr are called with debug_info in the api_test.py:CustomTransposeTest,
api_test.py:CustomVmapTest and api_test.py:RematTest.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 7, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

These changes ensure that all the lu.wrap_init and Jaxpr are called with debug_info in the api_test.py:CustomTransposeTest,
api_test.py:CustomVmapTest and api_test.py:RematTest.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 7, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

These changes ensure that all the lu.wrap_init and Jaxpr are called with debug_info in the api_test.py:CustomTransposeTest,
api_test.py:CustomVmapTest and api_test.py:RematTest.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 7, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows after jax-ml#26078, jax-ml#26313, jax-ml#26348, adding `debug_info` to more calls to `lu.wrap_init`.

As part of this I have changed the primitive `custom_transpose` to take the `transpose` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`.

These changes ensure that all the `lu.wrap_init` and `Jaxpr` are called with debug_info in the `api_test.py:CustomTransposeTest`.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

These changes ensure that all the lu.wrap_init and Jaxpr are called with debug_info in the api_test.py:CustomTransposeTest,
api_test.py:CustomVmapTest and api_test.py:RematTest.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).

Also fixed uses in shard_map, checkify, and jax2tf.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).

Also fixed uses in shard_map, checkify, and jax2tf.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).

Also fixed uses in shard_map, checkify, and jax2tf.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).

Also fixed uses in shard_map, checkify, and jax2tf.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Fixes in jet, ...
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Fixes in jet, ...
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).

Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).

Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Fixes in jet, stateful code, tests.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Fixes in jet, stateful code, key_reuse, ode, tests.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Fixes in jet, stateful code, key_reuse, ode, tests.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 9, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Fixes in jet, stateful code, key_reuse, ode, pallas, tests.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 9, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Fixes in jet, stateful code, key_reuse, ode, pallas, tests.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants