Skip to content

Commit

Permalink
Simplify and cleanup the test_mypy.py file (pydantic#10612)
Browse files Browse the repository at this point in the history
  • Loading branch information
Viicos authored Oct 14, 2024
1 parent 6d3717c commit 9ef4637
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 108 deletions.
12 changes: 6 additions & 6 deletions pydantic/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,16 @@ def version_info() -> str:
return '\n'.join('{:>30} {}'.format(k + ':', str(v).replace('\n', ' ')) for k, v in info.items())


def parse_mypy_version(version: str) -> tuple[int, ...]:
"""Parse mypy string version to tuple of ints.
def parse_mypy_version(version: str) -> tuple[int, int, int]:
"""Parse `mypy` string version to a 3-tuple of ints.
It parses normal version like `0.930` and extra info followed by a `+` sign
like `0.940+dev.04cac4b5d911c4f9529e6ce86a27b44f28846f5d.dirty`.
It parses normal version like `1.11.0` and extra info followed by a `+` sign
like `1.11.0+dev.d6d9d8cd4f27c52edac1f537e236ec48a01e54cb.dirty`.
Args:
version: The mypy version string.
Returns:
A tuple of ints. e.g. (0, 930).
A triple of ints, e.g. `(1, 11, 0)`.
"""
return tuple(map(int, version.partition('+')[0].split('.')))
return tuple(map(int, version.partition('+')[0].split('.'))) # pyright: ignore[reportReturnType]
192 changes: 90 additions & 102 deletions tests/mypy/test_mypy.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,38 @@
from __future__ import annotations

import dataclasses
import importlib
import os
import re
import sys
from bisect import insort
from collections.abc import Collection
from pathlib import Path
from typing import Any, List, Optional, Tuple, Union
from typing import TYPE_CHECKING

import pytest
from _pytest.mark import Mark, MarkDecorator
from _pytest.mark.structures import ParameterSet
from typing_extensions import TypeAlias

try:
# Pyright doesn't like try/expect blocks for imports:
if TYPE_CHECKING:
from mypy import api as mypy_api
from mypy.version import __version__ as mypy_version

from pydantic.version import parse_mypy_version
else:
try:
from mypy import api as mypy_api
from mypy.version import __version__ as mypy_version

from pydantic.version import parse_mypy_version

except ImportError:
mypy_api = None
mypy_version = None
except ImportError:
mypy_api = None
mypy_version = None

parse_mypy_version = lambda _: (0,) # noqa: E731
parse_mypy_version = lambda _: (0,) # noqa: E731


MYPY_VERSION_TUPLE = parse_mypy_version(mypy_version)
Expand All @@ -35,52 +48,44 @@
os.chdir(Path(__file__).parent.parent.parent)


@dataclasses.dataclass
class MypyCasesBuilder:
configs: Union[str, List[str]]
modules: Union[str, List[str]]
marks: Any = None

def build(self) -> List[Union[Tuple[str, str], Any]]:
"""
Produces the cartesian product of the configs and modules, optionally with marks.
"""
if isinstance(self.configs, str):
self.configs = [self.configs]
if isinstance(self.modules, str):
self.modules = [self.modules]
built_cases = []
for config in self.configs:
for module in self.modules:
built_cases.append((config, module))
if self.marks is not None:
built_cases = [pytest.param(config, module, marks=self.marks) for config, module in built_cases]
return built_cases


cases = (
# Type hint taken from the signature of `pytest.param`:
Marks: TypeAlias = 'MarkDecorator | Collection[MarkDecorator | Mark]'


def build_cases(
configs: list[str],
modules: list[str],
marks: Marks = (),
) -> list[ParameterSet]:
"""Produces the cartesian product of the configs and modules, optionally with marks."""

return [pytest.param(config, module, marks=marks) for config in configs for module in modules]


cases: list[ParameterSet | tuple[str, str]] = [
# No plugin
MypyCasesBuilder(
*build_cases(
['mypy-default.ini', 'pyproject-default.toml'],
['fail1.py', 'fail2.py', 'fail3.py', 'fail4.py', 'pydantic_settings.py'],
).build()
+ MypyCasesBuilder(
),
*build_cases(
['mypy-default.ini', 'pyproject-default.toml'],
'success.py',
['success.py'],
pytest.mark.skipif(MYPY_VERSION_TUPLE > (1, 0, 1), reason='Need to handle some more things for mypy >=1.1.1'),
).build()
+ MypyCasesBuilder(
),
*build_cases(
['mypy-default.ini', 'pyproject-default.toml'],
'root_models.py',
['root_models.py'],
pytest.mark.skipif(
MYPY_VERSION_TUPLE < (1, 1, 1), reason='`dataclass_transform` only supported on mypy >= 1.1.1'
),
).build()
+ MypyCasesBuilder(
'mypy-default.ini', ['plugin_success.py', 'plugin_success_baseConfig.py', 'metaclass_args.py']
).build()
),
*build_cases(
['mypy-default.ini'],
['plugin_success.py', 'plugin_success_baseConfig.py', 'metaclass_args.py'],
),
# Default plugin config
+ MypyCasesBuilder(
*build_cases(
['mypy-plugin.ini', 'pyproject-plugin.toml'],
[
'plugin_success.py',
Expand All @@ -89,9 +94,9 @@ def build(self) -> List[Union[Tuple[str, str], Any]]:
'plugin_fail_baseConfig.py',
'pydantic_settings.py',
],
).build()
),
# Strict plugin config
+ MypyCasesBuilder(
*build_cases(
['mypy-plugin-strict.ini', 'pyproject-plugin-strict.toml'],
[
'plugin_success.py',
Expand All @@ -100,9 +105,9 @@ def build(self) -> List[Union[Tuple[str, str], Any]]:
'plugin_success_baseConfig.py',
'plugin_fail_baseConfig.py',
],
).build()
),
# One-off cases
+ [
*[
('mypy-plugin.ini', 'custom_constructor.py'),
('mypy-plugin.ini', 'config_conditional_extra.py'),
('mypy-plugin.ini', 'covariant_typevar.py'),
Expand All @@ -119,23 +124,22 @@ def build(self) -> List[Union[Tuple[str, str], Any]]:
('pyproject-plugin-no-strict-optional.toml', 'no_strict_optional.py'),
('pyproject-plugin-strict-equality.toml', 'strict_equality.py'),
('pyproject-plugin.toml', 'from_orm_v1_noconflict.py'),
]
)


@dataclasses.dataclass
class MypyTestTarget:
parsed_mypy_version: Tuple[int, ...]
output_path: Path
],
]


@dataclasses.dataclass
class MypyTestConfig:
existing: Optional[MypyTestTarget] # the oldest target with an output that is no older than the installed mypy
current: MypyTestTarget # the target for the current installed mypy
existing_output_path: Path | None
"""The path pointing to the existing test result, or `None` if this is the first time the test is run."""

current_output_path: Path
"""The path pointing to the current test result to be created or compared to the existing one."""


def get_test_config(module_path: Path, config_path: Path) -> MypyTestConfig:
"""Given a file to test with a specific config, get a test config."""

outputs_dir = PYDANTIC_ROOT / 'tests/mypy/outputs'
outputs_dir.mkdir(exist_ok=True)
existing_versions = [
Expand All @@ -145,10 +149,10 @@ def get_test_config(module_path: Path, config_path: Path) -> MypyTestConfig:
def _convert_to_output_path(v: str) -> Path:
return outputs_dir / v / config_path.name.replace('.', '_') / module_path.name

existing = None
existing: Path | None = None

# Build sorted list of (parsed_version, version) pairs, including the current mypy version being used
parsed_version_pairs = sorted([(parse_mypy_version(v), v) for v in existing_versions])
parsed_version_pairs = sorted((parse_mypy_version(v), v) for v in existing_versions)
if MYPY_VERSION_TUPLE not in [x[0] for x in parsed_version_pairs]:
insort(parsed_version_pairs, (MYPY_VERSION_TUPLE, mypy_version))

Expand All @@ -157,15 +161,23 @@ def _convert_to_output_path(v: str) -> Path:
continue
output_path = _convert_to_output_path(version)
if output_path.exists():
existing = MypyTestTarget(parsed_version, output_path)
existing = output_path
break

current = MypyTestTarget(MYPY_VERSION_TUPLE, _convert_to_output_path(mypy_version))
return MypyTestConfig(existing, current)
return MypyTestConfig(existing, _convert_to_output_path(mypy_version))


def get_expected_return_code(source_code: str) -> int:
"""Return 1 if at least one `# MYPY:` comment was found, else 0."""
if re.findall(r'^\s*# MYPY:', source_code, flags=re.MULTILINE):
return 1
return 0


@pytest.mark.filterwarnings('ignore:ast.:DeprecationWarning') # these are produced by mypy in python 3.12
@pytest.mark.parametrize('config_filename,python_filename', cases)
@pytest.mark.parametrize(
['config_filename', 'python_filename'],
cases,
)
def test_mypy_results(config_filename: str, python_filename: str, request: pytest.FixtureRequest) -> None:
input_path = PYDANTIC_ROOT / 'tests/mypy/modules' / python_filename
config_path = PYDANTIC_ROOT / 'tests/mypy/configs' / config_filename
Expand Down Expand Up @@ -195,10 +207,10 @@ def test_mypy_results(config_filename: str, python_filename: str, request: pytes

input_code = input_path.read_text()

existing_output_code: Optional[str] = None
if test_config.existing is not None:
existing_output_code = test_config.existing.output_path.read_text()
print(f'Comparing output with {test_config.existing.output_path}')
existing_output_code: str | None = None
if test_config.existing_output_path is not None:
existing_output_code = test_config.existing_output_path.read_text()
print(f'Comparing output with {test_config.existing_output_path}')
else:
print(f'Comparing output with {input_path} (expecting no mypy errors)')

Expand All @@ -208,8 +220,8 @@ def test_mypy_results(config_filename: str, python_filename: str, request: pytes
# Test passed, no changes needed
pass
elif request.config.getoption('update_mypy'):
test_config.current.output_path.parent.mkdir(parents=True, exist_ok=True)
test_config.current.output_path.write_text(merged_output)
test_config.current_output_path.parent.mkdir(parents=True, exist_ok=True)
test_config.current_output_path.write_text(merged_output)
else:
print('**** Merged Output ****')
print(merged_output)
Expand Down Expand Up @@ -237,50 +249,26 @@ def test_bad_toml_config() -> None:
assert str(e.value) == 'Configuration value must be a boolean for key: init_forbid_extra'


def get_expected_return_code(source_code: str) -> int:
if re.findall(r'^\s*# MYPY:', source_code, flags=re.MULTILINE):
return 1
return 0


@pytest.mark.parametrize('module', ['dataclass_no_any', 'plugin_success', 'plugin_success_baseConfig'])
@pytest.mark.filterwarnings('ignore:.*is deprecated.*:DeprecationWarning')
@pytest.mark.filterwarnings('ignore:.*are deprecated.*:DeprecationWarning')
def test_success_cases_run(module: str) -> None:
"""
Ensure the "success" files can actually be executed
"""
importlib.import_module(f'tests.mypy.modules.{module}')


def test_explicit_reexports():
from pydantic import __all__ as root_all
from pydantic.deprecated.tools import __all__ as tools
from pydantic.main import __all__ as main
from pydantic.networks import __all__ as networks
from pydantic.types import __all__ as types

for name, export_all in [('main', main), ('network', networks), ('tools', tools), ('types', types)]:
for export in export_all:
assert export in root_all, f'{export} is in {name}.__all__ but missing from re-export in __init__.py'


def test_explicit_reexports_exist():
import pydantic

for name in pydantic.__all__:
assert hasattr(pydantic, name), f'{name} is in pydantic.__all__ but missing from pydantic'
module_name = f'tests.mypy.modules.{module}'
try:
importlib.import_module(module_name)
except Exception:
pytest.fail(reason=f'Unable to execute module {module_name}')


@pytest.mark.parametrize(
'v_str,v_tuple',
['v_str', 'v_tuple'],
[
('0', (0,)),
('0.930', (0, 930)),
('0.940+dev.04cac4b5d911c4f9529e6ce86a27b44f28846f5d.dirty', (0, 940)),
('1.11.0', (1, 11, 0)),
('1.11.0+dev.d6d9d8cd4f27c52edac1f537e236ec48a01e54cb.dirty', (1, 11, 0)),
],
)
def test_parse_mypy_version(v_str, v_tuple):
def test_parse_mypy_version(v_str: str, v_tuple: tuple[int, int, int]) -> None:
assert parse_mypy_version(v_str) == v_tuple


Expand Down
17 changes: 17 additions & 0 deletions tests/test_dunder_all.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
def test_explicit_reexports() -> None:
from pydantic import __all__ as root_all
from pydantic.deprecated.tools import __all__ as tools
from pydantic.main import __all__ as main
from pydantic.networks import __all__ as networks
from pydantic.types import __all__ as types

for name, export_all in [('main', main), ('networks', networks), ('deprecated.tools', tools), ('types', types)]:
for export in export_all:
assert export in root_all, f'{export} is in `pydantic.{name}.__all__` but missing in `pydantic.__all__`'


def test_explicit_reexports_exist() -> None:
import pydantic

for name in pydantic.__all__:
assert hasattr(pydantic, name), f'{name} is in `pydantic.__all__` but `from pydantic import {name}` fails'

0 comments on commit 9ef4637

Please sign in to comment.