Skip to content

Commit

Permalink
add assert rewriter
Browse files Browse the repository at this point in the history
  • Loading branch information
tsv1 committed May 26, 2024
1 parent c70f4fa commit bce0d77
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 16 deletions.
3 changes: 2 additions & 1 deletion tests/plugins/director/rich/printer/test_rich_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,8 @@ def test_print_exception(*, printer: RichPrinter, exc_info: ExcInfo, console_: M
def test_print_pretty_exception(*, printer: RichPrinter, exc_info: ExcInfo, console_: Mock):
with given:
trace = Traceback.extract(exc_info.type, exc_info.value, exc_info.traceback)
tb = TestTraceback(trace, max_frames=8, word_wrap=False, width=console_.size.width)
tb = TestTraceback(trace, max_frames=8, word_wrap=False, width=console_.size.width,
indent_guides=False)

with when:
printer.print_pretty_exception(exc_info)
Expand Down
9 changes: 7 additions & 2 deletions vedro/plugins/assert_rewriter/_assert_rewriter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Type

from vedro.core import ConfigType, Dispatcher, Plugin, PluginConfig
from vedro.events import ArgParsedEvent, ConfigLoadedEvent
from vedro.events import ArgParsedEvent, ArgParseEvent, ConfigLoadedEvent

from ._assert_rewriter_module_loader import AssertRewriterModuleLoader
from ._legacy_assert_rewriter import LegacyAssertRewriter
Expand All @@ -15,13 +15,18 @@ def __init__(self, config: Type["AssertRewriter"]):

def subscribe(self, dispatcher: Dispatcher) -> None:
dispatcher.listen(ConfigLoadedEvent, self.on_config_loaded) \
.listen(ArgParseEvent, self.on_arg_parse) \
.listen(ArgParsedEvent, self.on_arg_parsed)

def on_config_loaded(self, event: ConfigLoadedEvent) -> None:
self._global_config: ConfigType = event.config

def on_arg_parse(self, event: ArgParseEvent) -> None:
event.arg_parser.add_argument("--exp-pretty-diff", action="store_true", default=False,
help="")

def on_arg_parsed(self, event: ArgParsedEvent) -> None:
exp_pretty_diff = getattr(event.args, "exp_pretty_diff", False)
exp_pretty_diff = event.args.exp_pretty_diff
module_loader = AssertRewriterModuleLoader if exp_pretty_diff else LegacyAssertRewriter
self._global_config.Registry.ModuleLoader.register(module_loader, self)

Expand Down
54 changes: 53 additions & 1 deletion vedro/plugins/assert_rewriter/_assert_rewriter_module_loader.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,63 @@
import ast
import inspect
from importlib.abc import Loader
from types import ModuleType
from typing import Any, Optional

from vedro.core import ModuleFileLoader

__all__ = ("AssertRewriterModuleLoader",)


def assert_eq(left: Any, right: Any, message: Optional[str] = None) -> bool:
if left != right:
exc = AssertionError(message)
exc.__vedro_assert_left__ = left # type: ignore
exc.__vedro_assert_right__ = right # type: ignore
exc.__vedro_assert_message__ = message # type: ignore
raise exc

return True


class AssertRewriter(ast.NodeTransformer):
def visit_Assert(self, node: ast.Assert) -> ast.Assert:
if not isinstance(node.test, ast.Compare):
return node

if not len(node.test.ops) == 1:
return node

if not isinstance(node.test.ops[0], ast.Eq):
return node

msg = node.msg if node.msg else ast.Constant(value="", kind=None)
new_node = ast.Assert(
test=ast.Call(
func=ast.Name(id='assert_eq', ctx=ast.Load()),
args=[node.test.left, node.test.comparators[0], msg],
keywords=[],
),
)

ast.copy_location(new_node, node)
ast.fix_missing_locations(new_node)

return new_node


class AssertRewriterModuleLoader(ModuleFileLoader):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._assert_rewriter = AssertRewriter()

__builtins__["assert_eq"] = assert_eq

def _exec_module(self, loader: Loader, module: ModuleType) -> None:
super()._exec_module(loader, module)
source_code = inspect.getsource(module)

tree = ast.parse(source_code)
tree = self._assert_rewriter.visit(tree)

transformed = compile(tree, module.__file__, "exec") # type: ignore
exec(transformed, module.__dict__)
35 changes: 34 additions & 1 deletion vedro/plugins/director/rich/_rich_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
from rich.pretty import Pretty
from rich.status import Status
from rich.style import Style
from rich.text import Text
from rich.traceback import Trace, Traceback

import vedro
from vedro.core import ExcInfo, ScenarioStatus, StepStatus

from .utils.pretty_diff import pretty_diff

__all__ = ("RichPrinter",)


Expand Down Expand Up @@ -132,6 +135,20 @@ def _filter_locals(self, trace: Trace) -> None:
frame.locals = {k: v for k, v in frame.locals.items()
if k != "self" and k.isidentifier()}

def __trim_traceback(self, tb: TracebackType) -> TracebackType:
# Traverse the traceback to exclude the last frame
frames = []
while tb is not None:
frames.append(tb)
tb = tb.tb_next # type: ignore
trimmed_frames = frames[:-1] # Exclude the last frame

# Construct a new traceback object from the trimmed frames
new_tb = None
for frame in reversed(trimmed_frames):
new_tb = TracebackType(new_tb, frame.tb_frame, frame.tb_lasti, frame.tb_lineno)
return new_tb

def print_pretty_exception(self, exc_info: ExcInfo, *,
max_frames: int = 8, # min=4 (see rich.traceback.Traceback impl)
show_locals: bool = False,
Expand All @@ -143,6 +160,9 @@ def print_pretty_exception(self, exc_info: ExcInfo, *,
else:
traceback = self.__filter_internals(exc_info.traceback)

if hasattr(exc_info.value, "__vedro_assert_left__"):
traceback = self.__trim_traceback(traceback)

trace = Traceback.extract(exc_info.type, exc_info.value, traceback,
show_locals=show_locals)

Expand All @@ -153,10 +173,23 @@ def print_pretty_exception(self, exc_info: ExcInfo, *,
width = self._console.size.width

tb = self._traceback_factory(trace, max_frames=max_frames, word_wrap=word_wrap,
width=width)
width=width, indent_guides=False)
self._console.print(tb)
self.print_empty_line()

if hasattr(exc_info.value, "__vedro_assert_left__"):
left = getattr(exc_info.value, "__vedro_assert_left__")
right = getattr(exc_info.value, "__vedro_assert_right__")

self.pretty_print(
Text(">>> assert ", style="bold") +
Text("actual ", style="bold red") +
Text("== ", style="bold") +
Text("expected", style="bold green")
)
self.pretty_print(pretty_diff(left, right))
self.print_empty_line()

def pretty_format(self, value: Any) -> Any:
warnings.warn("Deprecated: method will be removed in v2.0", DeprecationWarning)
if hasattr(value, "__rich__") or hasattr(value, "__rich_console__"):
Expand Down
7 changes: 0 additions & 7 deletions vedro/plugins/director/rich/_rich_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def __init__(self, config: Type["RichReporter"], *,
self._show_scope = config.show_scope
self._v2_verbosity = config.v2_verbosity
self._ring_bell = config.ring_bell
self._exp_pretty_diff = False
self._namespace: Union[str, None] = None

def subscribe(self, dispatcher: Dispatcher) -> None:
Expand Down Expand Up @@ -117,11 +116,6 @@ def on_arg_parse(self, event: ArgParseEvent) -> None:
dest="ring_bell",
help="Trigger a 'bell' sound at the end of scenario execution")

group.add_argument("--exp-pretty-diff",
action="store_true",
default=self._exp_pretty_diff,
help="")

def on_arg_parsed(self, event: ArgParsedEvent) -> None:
self._verbosity = event.args.verbose
self._show_scope = event.args.show_scope
Expand All @@ -140,7 +134,6 @@ def on_arg_parsed(self, event: ArgParsedEvent) -> None:
self._tb_show_internal_calls = event.args.tb_show_internal_calls
self._tb_show_locals = event.args.tb_show_locals
self._ring_bell = event.args.ring_bell
self._exp_pretty_diff = event.args.exp_pretty_diff

if self._tb_max_frames < 4:
raise ValueError("RichReporter: `tb_max_frames` must be >= 4")
Expand Down
8 changes: 4 additions & 4 deletions vedro/plugins/director/rich/utils/pretty_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Generator, Iterable, List, Optional, Tuple

from rich.console import Group
from rich.panel import Panel
from rich.padding import Padding
from rich.text import Text


Expand Down Expand Up @@ -60,8 +60,8 @@ def _compare(actual: Any, expected: Any) -> Generator[str, None, None]:
yield from differ.compare(_format(actual), _format(expected))


def pretty_diff(actual: Any, expected: Any) -> Panel:
diff = _compare(actual, expected)
def pretty_diff(actual: Any, expected: Any) -> Padding:
diff = _compare(expected, actual)
colored_diff = _color_diff(diff)
renderable = Group(*colored_diff)
return Panel(renderable, title="Diff", padding=(1, 1), expand=True)
return Padding(renderable, (0, 2))

0 comments on commit bce0d77

Please sign in to comment.