-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[better_errors] Add debug info to the Jaxprs formed for AD (step 2) #26313
Conversation
32151e2
to
7606afc
Compare
2331c9e
to
5b9b852
Compare
debug_fwd = debug_info("custom_vjp fwd", self.fwd, args, kwargs, | ||
static_argnums=self.nondiff_argnums) | ||
# TODO(necula): figure out how to construct the debug_bwd args | ||
debug_bwd = debug_info("custom_vjp bwd", self.bwd, args, {}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As noted in the TODO, this definitely isn't the right signature here. Would it be better to use a more obvious placeholder?
(Side note, while I'm here: I think the arguments to bwd
are *static_args, residuals, output_cotangents
, where residuals
are output by fwd
and output_cotangents
has the tangent structure associated with fun
's output. I can see why this will be tricky to annotate!)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indeed, this is tricky because I do not know the structure of the residuals and output_cotangents at this point.
The signatures are only needed to pick up the proper names for arg_names
. Fixing that will have to come later, right now I am focusing just populating at least the function src info. We may never be able to get the arg_names
perfectly right, there are so many places where functions are transformed. In that case, I will think (later) about a way to safely default to None
. For now, the default is that we use None
if the length of arg_names
is different than jaxpr.invars
.
Following jax-ml#26078 , we add debug info to more calls of lu.wrap_init.
5b9b852
to
abcaec7
Compare
This follows after jax-ml#26078, and jax-ml#26313, adding `debug_info` to more calls to `lu.wrap_init`. As part of this I have changed the primitives `custom_vjp_call_jaxpr` and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`, but in almost all cases if was really ` lu.WrappedFun.call_wrapped`.
This follows after jax-ml#26078, and jax-ml#26313, adding `debug_info` to more calls to `lu.wrap_init`. As part of this I have changed the primitives `custom_vjp_call_jaxpr` and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`, but in almost all cases if was really ` lu.WrappedFun.call_wrapped`.
This follows after jax-ml#26078, and jax-ml#26313, adding `debug_info` to more calls to `lu.wrap_init`. As part of this I have changed the primitives `custom_vjp_call_jaxpr` and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`, but in almost all cases if was really ` lu.WrappedFun.call_wrapped`.
This follows after jax-ml#26078, and jax-ml#26313, adding `debug_info` to more calls to `lu.wrap_init`. As part of this I have changed the primitives `custom_vjp_call_jaxpr` and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`, but in almost all cases if was really ` lu.WrappedFun.call_wrapped`.
This follows after jax-ml#26078, and jax-ml#26313, adding `debug_info` to more calls to `lu.wrap_init`. As part of this I have changed the primitives `custom_vjp_call_jaxpr` and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`, but in almost all cases if was really ` lu.WrappedFun.call_wrapped`.
This follows after jax-ml#26078, and jax-ml#26313, adding `debug_info` to more calls to `lu.wrap_init`. As part of this I have changed the primitives `custom_vjp_call_jaxpr` and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`, but in almost all cases if was really ` lu.WrappedFun.call_wrapped`.
This follows after jax-ml#26078, and jax-ml#26313, adding `debug_info` to more calls to `lu.wrap_init`. As part of this I have changed the primitives `custom_vjp_call_jaxpr` and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`, but in almost all cases if was really ` lu.WrappedFun.call_wrapped`.
This follows after jax-ml#26078, and jax-ml#26313, adding `debug_info` to more calls to `lu.wrap_init`. As part of this I have changed the primitives `custom_vjp_call_jaxpr` and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`, but in almost all cases if was really ` lu.WrappedFun.call_wrapped`.
This follows after jax-ml#26078, jax-ml#26313, jax-ml#26348, adding `debug_info` to more calls to `lu.wrap_init`. As part of this I have changed the primitive `custom_transpose` to take the `transpose` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`. These changes ensure that all the `lu.wrap_init` and `Jaxpr` are called with debug_info in the `api_test.py:CustomTransposeTest`.
This follows after jax-ml#26078, jax-ml#26313, jax-ml#26348, adding `debug_info` to more calls to `lu.wrap_init`. As part of this I have changed the primitive `custom_transpose` to take the `transpose` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`. These changes ensure that all the `lu.wrap_init` and `Jaxpr` are called with debug_info in the `api_test.py:CustomTransposeTest`.
This follows after jax-ml#26078, jax-ml#26313, jax-ml#26348, adding `debug_info` to more calls to `lu.wrap_init`. As part of this I have changed the primitive `custom_transpose` to take the `transpose` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`. These changes ensure that all the `lu.wrap_init` and `Jaxpr` are called with debug_info in the `api_test.py:CustomTransposeTest`.
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. These changes ensure that all the lu.wrap_init and Jaxpr are called with debug_info in the api_test.py:CustomTransposeTest, api_test.py:CustomVmapTest and api_test.py:RematTest.
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. These changes ensure that all the lu.wrap_init and Jaxpr are called with debug_info in the api_test.py:CustomTransposeTest, api_test.py:CustomVmapTest and api_test.py:RematTest.
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. These changes ensure that all the lu.wrap_init and Jaxpr are called with debug_info in the api_test.py:CustomTransposeTest, api_test.py:CustomVmapTest and api_test.py:RematTest.
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Here I changed the `custom_jvp_call` to replace the parameter `jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun` that can carry debug info).
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Here I changed the `custom_jvp_call` to replace the parameter `jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun` that can carry debug info).
This follows after jax-ml#26078, jax-ml#26313, jax-ml#26348, adding `debug_info` to more calls to `lu.wrap_init`. As part of this I have changed the primitive `custom_transpose` to take the `transpose` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`. These changes ensure that all the `lu.wrap_init` and `Jaxpr` are called with debug_info in the `api_test.py:CustomTransposeTest`.
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. These changes ensure that all the lu.wrap_init and Jaxpr are called with debug_info in the api_test.py:CustomTransposeTest, api_test.py:CustomVmapTest and api_test.py:RematTest.
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Here I changed the `custom_jvp_call` to replace the parameter `jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun` that can carry debug info).
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Here I changed the `custom_jvp_call` to replace the parameter `jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun` that can carry debug info).
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Here I changed the `custom_jvp_call` to replace the parameter `jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun` that can carry debug info). Also fixed uses in shard_map, checkify, and jax2tf.
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Here I changed the `custom_jvp_call` to replace the parameter `jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun` that can carry debug info). Also fixed uses in shard_map, checkify, and jax2tf.
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Here I changed the `custom_jvp_call` to replace the parameter `jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun` that can carry debug info). Also fixed uses in shard_map, checkify, and jax2tf.
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Here I changed the `custom_jvp_call` to replace the parameter `jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun` that can carry debug info). Also fixed uses in shard_map, checkify, and jax2tf.
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Fixes in jet, ...
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Fixes in jet, ...
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Here I changed the `custom_jvp_call` to replace the parameter `jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun` that can carry debug info). Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Here I changed the `custom_jvp_call` to replace the parameter `jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun` that can carry debug info). Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Fixes in jet, stateful code, tests.
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Fixes in jet, stateful code, key_reuse, ode, tests.
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Fixes in jet, stateful code, key_reuse, ode, tests.
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Fixes in jet, stateful code, key_reuse, ode, pallas, tests.
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Fixes in jet, stateful code, key_reuse, ode, pallas, tests.
Following #26078 , we add debug info to more calls of
lu.wrap_init
.Most of the work here is to construct a debug info in appropriate places and thread it to places where
lu.wrap_init
is called. There are just a lot of calls tolu.wrap_init
, and this PR does not even fix all of them.We also add type declarations in all the functions that I had to understand for this PR.