Skip to content

Commit

Permalink
add docs
Browse files Browse the repository at this point in the history
  • Loading branch information
youkaichao committed Sep 16, 2024
1 parent c134678 commit dcd91d8
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 3 deletions.
43 changes: 43 additions & 0 deletions depyf/explain/enable_debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,46 @@ def enable_bytecode_hook(hook):
@contextlib.contextmanager
def prepare_debug(dump_src_dir, clean_wild_fx_code=True, log_bytecode=False):
"""
A context manager to dump debugging information for torch.compile.
It should wrap the code that actually triggers the compilation, rather than
the code that applies ``torch.compile``.
Example:
.. code-block:: python
import torch
@torch.compile
def toy_example(a, b):
x = a / (torch.abs(a) + 1)
if b.sum() < 0:
b = b * -1
return x * b
def main():
for _ in range(100):
toy_example(torch.randn(10), torch.randn(10))
if __name__ == "__main__":
# main()
# surround the code you want to run inside `with depyf.prepare_debug`
import depyf
with depyf.prepare_debug("./dump_src_dir"):
main()
After running the code, you will find the dumped information in the directory ``dump_src_dir``. The details are organized into the following:
- ``full_code_for_xxx.py`` for each function using torch.compile
- ``__transformed_code_for_xxx.py`` for Python code associated with each graph.
- ``__transformed_code_for_xxx.py.xxx_bytecode`` for Python bytecode, dumped code object, can be loaded via ``dill.load(open("/path/to/file", "wb"))``. Note that the load function might import some modules like transformers. Make sure you have these modules installed.
- ``__compiled_fn_xxx.py`` for each computation graph and its optimization:
- ``Captured Graph``: a plain forward computation graph
- ``Joint Graph``: joint forward-backward graph from AOTAutograd
- ``Forward Graph``: forward graph from AOTAutograd
- ``Backward Graph``: backward graph from AOTAutograd
- ``kernel xxx``: compiled CPU/GPU kernel wrapper from Inductor.
Args:
dump_src_dir: the directory to dump the source code.
clean_wild_fx_code: whether to clean the wild fx code that are not recognized for parts of compiled functions. They are usually used by PyTorch internally.
Expand Down Expand Up @@ -185,6 +225,9 @@ def prepare_debug(dump_src_dir, clean_wild_fx_code=True, log_bytecode=False):

@contextlib.contextmanager
def debug():
"""
A context manager to debug the compiled code. Essentially, it sets a breakpoint to pause the program and allows you to check the full source code in files with prefix ``full_code_for_`` in the ``dump_src_dir`` argument of :func:`depyf.prepare_debug`, and set breakpoints in their separate ``__transformed_code_`` files according to the function name. Then continue your debugging.
"""
from .global_variables import data
if data["is_inside_prepare_debug"]:
raise RuntimeError("You cannot use `depyf.debug` inside `depyf.prepare_debug`.")
Expand Down
11 changes: 11 additions & 0 deletions docs/api_reference.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
API Reference
=============

Understand and debug ``torch.compile``
--------------------------------------

.. autofunction:: depyf.prepare_debug

.. autofunction:: depyf.debug

.. warning::

It is recommended to read the :doc:`walk_through` to have a basic understanding of how ``torch.compile`` works, before using the above functions.

Decompile general Python Bytecode/Function
-------------------------------------------

Expand Down
6 changes: 3 additions & 3 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ If you'd like to contribute (which we highly appreciate), please read the `devel
:maxdepth: 1
:hidden:

api_reference
walk_through
faq
dev_doc
opt_tutorial
api_reference
dev_doc
faq

0 comments on commit dcd91d8

Please sign in to comment.