Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tsv1 committed Jun 12, 2024
1 parent cb9f9c2 commit 818c2ec
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 4 deletions.
50 changes: 50 additions & 0 deletions tests/plugins/assert_rewriter/test_assert_rewriter_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import os
from pathlib import Path
from textwrap import dedent

import pytest
from baby_steps import given, then, when
from niltype import Nil
from pytest import raises

from vedro.plugins.assert_rewriter import AssertRewriterLoader, CompareOperator, assert_


@pytest.fixture()
def tmp_scn_dir(tmp_path: Path) -> Path:
cwd = os.getcwd()
try:
os.chdir(tmp_path)

scn_dir = tmp_path / "scenarios/"
scn_dir.mkdir(exist_ok=True)
yield scn_dir.relative_to(tmp_path)
finally:
os.chdir(cwd)


async def test_load(tmp_scn_dir: Path):
with given:
path = tmp_scn_dir / "scenario.py"
path.write_text(dedent('''
import vedro
class Scenario(vedro.Scenario):
def then(self):
assert 1 == 2
'''))

loader = AssertRewriterLoader()
module = await loader.load(path)
scenario = module.Scenario()

with when, raises(BaseException) as exc:
scenario.then()

with then:
assert exc.type is AssertionError
assert str(exc.value) == ""

assert assert_.get_left(exc.value) == 1
assert assert_.get_right(exc.value) == 2
assert assert_.get_operator(exc.value) == CompareOperator.EQUAL
assert assert_.get_message(exc.value) == Nil
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ def then(self):
module = await loader.load(path)
scenario = module.Scenario()

with when, raises(BaseException) as exception:
with when, raises(BaseException) as exc:
scenario.then()

with then:
assert exception.type is AssertionError
assert str(exception.value) == "assert 1 == 2"
assert exc.type is AssertionError
assert str(exc.value) == "assert 1 == 2"
82 changes: 82 additions & 0 deletions tests/plugins/assert_rewriter/test_node_assert_rewriter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import ast
from ast import dump

from baby_steps import given, then, when
from pytest import warns

from vedro.plugins.assert_rewriter import NodeAssertRewriter


def test_rewrite_assert():
with given:
rewriter = NodeAssertRewriter("assert_tool", {ast.Eq: "assert_equal"})
tree = ast.parse("assert x == y")

with when:
rewritten_tree = rewriter.visit(tree)

with then:
assert dump(rewritten_tree) == dump(ast.parse(
"assert assert_tool.assert_equal(x, y)"
))


def test_rewrite_assert_with_message():
with given:
rewriter = NodeAssertRewriter("assert_tool", {ast.Eq: "assert_equal"})
tree = ast.parse("assert x == y, 'x should be equal to y'")

with when:
rewritten_tree = rewriter.visit(tree)

with then:
assert dump(rewritten_tree) == dump(ast.parse(
"assert assert_tool.assert_equal(x, y, message='x should be equal to y')"
))


def test_rewrite_assert_multiple_comparisons():
with given:
rewriter = NodeAssertRewriter("assert_tool", {
ast.Lt: "assert_less",
ast.LtE: "assert_less_equal",
})
tree = ast.parse("assert x < y <= z")

with when:
rewritten_tree = rewriter.visit(tree)

with then:
assert dump(rewritten_tree) == dump(ast.parse(
"assert assert_tool.assert_less(x, y) and assert_tool.assert_less_equal(y, z)"
))


def test_rewrite_assert_truthy():
with given:
rewriter = NodeAssertRewriter("assert_tool", {type(None): "assert_truthy"})
tree = ast.parse("assert x")

with when:
rewritten_tree = rewriter.visit(tree)

with then:
assert dump(rewritten_tree) == dump(ast.parse(
"assert assert_tool.assert_truthy(x)"
))


def test_rewrite_assert_unsupported_operator():
with given:
rewriter = NodeAssertRewriter("assert_tool", {})
tree = ast.parse("assert x is y")

with when, warns(Warning) as record:
rewritten_tree = rewriter.visit(tree)

with then:
assert dump(rewritten_tree) == dump(tree)

warning_message = str(record.list[0].message)
assert "NodeAssertRewriter: Unsupported operator" in warning_message
assert "<ast.Is object" in warning_message
2 changes: 1 addition & 1 deletion vedro/plugins/assert_rewriter/_node_assert_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _create_assert(self, left: ast.AST,
if method is not Nil:
return self._create_assert_call(method, args, keywords)
else:
warnings.warn(f"AssertRewriter: Unsupported operator '{operator}'")
warnings.warn(f"NodeAssertRewriter: Unsupported operator '{operator}'")
raise ValueError(f"Unsupported operator: '{operator}'")

def _create_assert_call(self, method: str,
Expand Down

0 comments on commit 818c2ec

Please sign in to comment.