Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update typehinting code to use primitives and collections types as main types over typing variants #33427

Merged
123 changes: 100 additions & 23 deletions sdks/python/apache_beam/typehints/native_type_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,18 @@
# pytype: skip-file

import collections
import collections.abc
import logging
import sys
import types
import typing
from typing import Generic
from typing import TypeVar

from apache_beam.typehints import typehints

T = TypeVar('T')

_LOGGER = logging.getLogger(__name__)

# Describes an entry in the type map in convert_to_beam_type.
Expand All @@ -45,7 +50,18 @@
frozenset: typing.FrozenSet,
}

# Builtin collection types that are treated as first-class type hints. A type
# is checked against this list either directly or through its __origin__.
_BUILTINS = [
dict,
list,
tuple,
set,
frozenset,
]

_CONVERTED_COLLECTIONS = [
collections.abc.Iterable,
collections.abc.Iterator,
collections.abc.Generator,
collections.abc.Set,
collections.abc.MutableSet,
collections.abc.Collection,
Expand Down Expand Up @@ -99,14 +115,25 @@ def _match_issubclass(match_against):
return lambda user_type: _safe_issubclass(user_type, match_against)


def _is_primative(user_type, primative):
jrmccluskey marked this conversation as resolved.
Show resolved Hide resolved
# catch bare primatives
if user_type is primative:
return True
return getattr(user_type, '__origin__', None) is primative


def _match_is_primative(match_against):
  """Builds a single-argument matcher bound to one primitive type.

  Args:
    match_against: The builtin type the returned matcher compares against.

  Returns:
    A callable taking user_type and returning the result of
    _is_primative(user_type, match_against).
  """
  def _matcher(user_type):
    return _is_primative(user_type, match_against)

  return _matcher


def _match_is_exactly_mapping(user_type):
# Avoid unintentionally catching all subtypes (e.g. strings and mappings).
expected_origin = collections.abc.Mapping
return getattr(user_type, '__origin__', None) is expected_origin


def _match_is_exactly_iterable(user_type):
if user_type is typing.Iterable:
if user_type is typing.Iterable or user_type is collections.abc.Iterable:
return True
# Avoid unintentionally catching all subtypes (e.g. strings and mappings).
expected_origin = collections.abc.Iterable
Expand Down Expand Up @@ -152,11 +179,13 @@ def _match_is_union(user_type):
return False


def match_is_set(user_type):
if _safe_issubclass(user_type, typing.Set):
def _match_is_set(user_type):
if _safe_issubclass(user_type, typing.Set) or _is_primative(user_type, set):
return True
elif getattr(user_type, '__origin__', None) is not None:
return _safe_issubclass(user_type.__origin__, collections.abc.Set)
return _safe_issubclass(
user_type.__origin__, collections.abc.Set) or _safe_issubclass(
user_type.__origin__, collections.abc.MutableSet)
else:
return False

Expand Down Expand Up @@ -197,6 +226,36 @@ def convert_builtin_to_typing(typ):
return typ


def convert_typing_to_builtin(typ):
  """Converts a given typing collections type to its builtin counterpart.

  Type arguments are converted recursively, one at a time, so typing
  collections nested at any depth (e.g., typing.Set[typing.Tuple[int]])
  are rewritten to builtin generics as well.

  Args:
    typ: A typing type (e.g., typing.List[int]).

  Returns:
    type: The corresponding builtin type (e.g., list[int]). A type whose
      origin is not one of the builtin collections is returned unchanged,
      and a bare collection (no type arguments) is returned as its origin.
  """
  origin = getattr(typ, '__origin__', None)
  args = getattr(typ, '__args__', None)
  # Typing types return the primitive type as the origin from 3.9 on
  if origin not in _BUILTINS:
    return typ
  # Early return for bare types
  if not args:
    return origin
  # Convert each argument on its own. Passing the whole args tuple to the
  # recursive call would return it untouched (a tuple instance has no
  # __origin__), leaving nested typing hints unconverted. Non-type arguments
  # such as Ellipsis in Tuple[int, ...] pass through unchanged.
  converted_args = tuple(convert_typing_to_builtin(arg) for arg in args)
  # Subscripting with a tuple is equivalent to listing the arguments, so this
  # covers list, dict, tuple, set and frozenset uniformly.
  return origin[converted_args]


def convert_collections_to_typing(typ):
"""Converts a given collections.abc type to a typing object.

Expand All @@ -216,6 +275,24 @@ def convert_collections_to_typing(typ):
return typ


def is_builtin(typ):
  """Reports whether typ is a builtin collection or parameterized from one.

  Args:
    typ: The type hint to inspect.

  Returns:
    True if typ itself, or its __origin__ attribute, is a member of
    _BUILTINS; False otherwise.
  """
  return typ in _BUILTINS or getattr(typ, '__origin__', None) in _BUILTINS


# During type inference of WindowedValue, we need to pass in the inner value
# type. This cannot be achieved immediately with WindowedValue class because it
# is not parameterized. Changing it to a generic class (e.g. WindowedValue[T])
# could work in theory. However, the class is cythonized and it seems that
# cython does not handle generic classes well.
# The workaround here is to create a separate class solely for the type
# inference purpose. This class should never be used for creating instances.
class TypedWindowedValue(Generic[T]):
  """Generic placeholder used only as a type hint; it cannot be instantiated.

  Constructing an instance always raises NotImplementedError.
  """
  def __init__(self, *args, **kwargs):
    reason = "This class is solely for type inference"
    raise NotImplementedError(reason)


def convert_to_beam_type(typ):
"""Convert a given typing type to a Beam type.

Expand All @@ -238,11 +315,8 @@ def convert_to_beam_type(typ):
sys.version_info.minor >= 10) and (isinstance(typ, types.UnionType)):
typ = typing.Union[typ]

if isinstance(typ, types.GenericAlias):
typ = convert_builtin_to_typing(typ)

if getattr(typ, '__module__', None) == 'collections.abc':
typ = convert_collections_to_typing(typ)
if getattr(typ, '__module__', None) == 'typing':
typ = convert_typing_to_builtin(typ)

typ_module = getattr(typ, '__module__', None)
if isinstance(typ, typing.TypeVar):
Expand All @@ -267,8 +341,16 @@ def convert_to_beam_type(typ):
# TODO(https://github.com/apache/beam/issues/20076): Currently unhandled.
_LOGGER.info('Converting NewType type hint to Any: "%s"', typ)
return typehints.Any
elif (typ_module != 'typing') and (typ_module != 'collections.abc'):
# Only translate types from the typing and collections.abc modules.
elif typ_module == 'apache_beam.typehints.native_type_compatibility' and \
getattr(typ, "__name__", typ.__origin__.__name__) == 'TypedWindowedValue':
# Need to pass through WindowedValue class so that it can be converted
# to the correct type constraint in Beam
# This is needed to fix https://github.com/apache/beam/issues/33356
pass

elif (typ_module != 'typing') and (typ_module !=
'collections.abc') and not is_builtin(typ):
# Only translate primitives and types from collections.abc and typing.
return typ
if (typ_module == 'collections.abc' and
typ.__origin__ not in _CONVERTED_COLLECTIONS):
Expand All @@ -285,39 +367,34 @@ def convert_to_beam_type(typ):
_TypeMapEntry(match=is_forward_ref, arity=0, beam_type=typehints.Any),
_TypeMapEntry(match=is_any, arity=0, beam_type=typehints.Any),
_TypeMapEntry(
match=_match_issubclass(typing.Dict),
arity=2,
beam_type=typehints.Dict),
match=_match_is_primative(dict), arity=2, beam_type=typehints.Dict),
_TypeMapEntry(
match=_match_is_exactly_iterable,
arity=1,
beam_type=typehints.Iterable),
_TypeMapEntry(
match=_match_issubclass(typing.List),
arity=1,
beam_type=typehints.List),
match=_match_is_primative(list), arity=1, beam_type=typehints.List),
# FrozenSets are a specific instance of a set, so we check this first.
_TypeMapEntry(
match=_match_issubclass(typing.FrozenSet),
match=_match_is_primative(frozenset),
arity=1,
beam_type=typehints.FrozenSet),
_TypeMapEntry(match=match_is_set, arity=1, beam_type=typehints.Set),
_TypeMapEntry(match=_match_is_set, arity=1, beam_type=typehints.Set),
# NamedTuple is a subclass of Tuple, but it needs special handling.
# We just convert it to Any for now.
# This MUST appear before the entry for the normal Tuple.
_TypeMapEntry(
match=match_is_named_tuple, arity=0, beam_type=typehints.Any),
_TypeMapEntry(
match=_match_issubclass(typing.Tuple),
arity=-1,
match=_match_is_primative(tuple), arity=-1,
beam_type=typehints.Tuple),
_TypeMapEntry(match=_match_is_union, arity=-1, beam_type=typehints.Union),
_TypeMapEntry(
match=_match_issubclass(typing.Generator),
match=_match_issubclass(collections.abc.Generator),
arity=3,
beam_type=typehints.Generator),
_TypeMapEntry(
match=_match_issubclass(typing.Iterator),
match=_match_issubclass(collections.abc.Iterator),
arity=1,
beam_type=typehints.Iterator),
_TypeMapEntry(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from apache_beam.typehints.native_type_compatibility import convert_to_beam_types
from apache_beam.typehints.native_type_compatibility import convert_to_typing_type
from apache_beam.typehints.native_type_compatibility import convert_to_typing_types
from apache_beam.typehints.native_type_compatibility import convert_typing_to_builtin
from apache_beam.typehints.native_type_compatibility import is_any

_TestNamedTuple = typing.NamedTuple(
Expand All @@ -43,6 +44,7 @@ class _TestClass(object):


T = typing.TypeVar('T')
U = typing.TypeVar('U')


class _TestGeneric(typing.Generic[T]):
Expand Down Expand Up @@ -140,7 +142,7 @@ def test_convert_to_beam_type_with_builtin_types(self):
(
'builtin nested tuple',
tuple[str, list],
typehints.Tuple[str, typehints.List[typehints.Any]],
typehints.Tuple[str, typehints.List[typehints.TypeVariable('T')]],
)
]

Expand All @@ -159,7 +161,7 @@ def test_convert_to_beam_type_with_collections_types(self):
typehints.Iterable[int]),
(
'collection generator',
collections.abc.Generator[int],
collections.abc.Generator[int, None, None],
typehints.Generator[int]),
(
'collection iterator',
Expand All @@ -177,9 +179,8 @@ def test_convert_to_beam_type_with_collections_types(self):
'mapping not caught',
collections.abc.Mapping[str, int],
collections.abc.Mapping[str, int]),
('set', collections.abc.Set[str], typehints.Set[str]),
('set', collections.abc.Set[int], typehints.Set[int]),
('mutable set', collections.abc.MutableSet[int], typehints.Set[int]),
('enum set', collections.abc.Set[_TestEnum], typehints.Set[_TestEnum]),
(
'enum mutable set',
collections.abc.MutableSet[_TestEnum],
Expand Down Expand Up @@ -337,6 +338,24 @@ def test_is_any(self):
for expected, typ in test_cases:
self.assertEqual(expected, is_any(typ), msg='%s' % typ)

def test_convert_typing_to_builtin(self):
  """Checks that typing collection hints convert to equal builtin generics."""
  # Each case is (description, typing hint, expected builtin generic).
  cases = (
      ('list', typing.List[int], list[int]),
      ('dict', typing.Dict[str, int], dict[str, int]),
      ('tuple', typing.Tuple[str, int], tuple[str, int]),
      ('set', typing.Set[str], set[str]),
      ('frozenset', typing.FrozenSet[int], frozenset[int]),
      (
          'nested',
          typing.List[typing.Dict[str, typing.Tuple[int]]],
          list[dict[str, tuple[int]]]),
      ('typevar', typing.List[T], list[T]),
      ('nested_typevar', typing.Dict[T, typing.List[U]], dict[T, list[U]]),
  )

  for label, hint, expected in cases:
    actual = convert_typing_to_builtin(hint)
    self.assertEqual(actual, expected, label)


if __name__ == '__main__':
unittest.main()
Loading