diff --git a/sdks/python/apache_beam/coders/avro_record.py b/sdks/python/apache_beam/coders/avro_record.py index c9ed26d34eb7..c16c922591ea 100644 --- a/sdks/python/apache_beam/coders/avro_record.py +++ b/sdks/python/apache_beam/coders/avro_record.py @@ -24,6 +24,7 @@ class AvroRecord(object): """Simple wrapper class for dictionary records.""" + def __init__(self, value): self.record = value diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py index 5dff35052901..9292fd5e25d5 100644 --- a/sdks/python/apache_beam/coders/coder_impl.py +++ b/sdks/python/apache_beam/coders/coder_impl.py @@ -125,6 +125,7 @@ class CoderImpl(object): """For internal use only; no backwards-compatibility guarantees.""" + def encode_to_stream(self, value, stream, nested): # type: (Any, create_OutputStream, bool) -> None @@ -211,6 +212,7 @@ class SimpleCoderImpl(CoderImpl): """For internal use only; no backwards-compatibility guarantees. Subclass of CoderImpl implementing stream methods using encode/decode.""" + def encode_to_stream(self, value, stream, nested): # type: (Any, create_OutputStream, bool) -> None @@ -228,6 +230,7 @@ class StreamCoderImpl(CoderImpl): """For internal use only; no backwards-compatibility guarantees. Subclass of CoderImpl implementing encode/decode using stream methods.""" + def encode(self, value): # type: (Any) -> bytes out = create_OutputStream() @@ -255,6 +258,7 @@ class CallbackCoderImpl(CoderImpl): This is the default implementation used if Coder._get_impl() is not overwritten. """ + def __init__(self, encoder, decoder, size_estimator=None): self._encoder = encoder self._decoder = decoder @@ -297,6 +301,7 @@ def __repr__(self): class ProtoCoderImpl(SimpleCoderImpl): """For internal use only; no backwards-compatibility guarantees.""" + def __init__(self, proto_message_type): self.proto_message_type = proto_message_type @@ -311,12 +316,14 @@ def decode(self, encoded): class DeterministicProtoCoderImpl(ProtoCoderImpl): """For internal use only; no backwards-compatibility guarantees.""" + def encode(self, value): return value.SerializePartialToString(deterministic=True) class ProtoPlusCoderImpl(SimpleCoderImpl): """For internal use only; no backwards-compatibility guarantees.""" + def __init__(self, proto_plus_type): # type: (Type[proto.Message]) -> None self.proto_plus_type = proto_plus_type @@ -356,6 +363,7 @@ def decode(self, value): class FastPrimitivesCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees.""" + def __init__( self, fallback_coder_impl, requires_deterministic_step_label=None): self.fallback_coder_impl = fallback_coder_impl @@ -610,6 +618,7 @@ class BytesCoderImpl(CoderImpl): """For internal use only; no backwards-compatibility guarantees. A coder for bytes/str objects.""" + def encode_to_stream(self, value, out, nested): # type: (bytes, create_OutputStream, bool) -> None @@ -636,6 +645,7 @@ class BooleanCoderImpl(CoderImpl): """For internal use only; no backwards-compatibility guarantees. A coder for bool objects.""" + def encode_to_stream(self, value, out, nested): out.write_byte(1 if value else 0) @@ -675,12 +685,12 @@ class MapCoderImpl(StreamCoderImpl): attribute values. A coder for typing.Mapping objects.""" + def __init__( self, key_coder, # type: CoderImpl value_coder, # type: CoderImpl - is_deterministic = False - ): + is_deterministic=False): self._key_coder = key_coder self._value_coder = value_coder self._is_deterministic = is_deterministic @@ -760,6 +770,7 @@ def estimate_size(self, unused_value, nested=False): class BigEndianShortCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees.""" + def encode_to_stream(self, value, out, nested): # type: (int, create_OutputStream, bool) -> None out.write_bigendian_int16(value) @@ -776,6 +787,7 @@ def estimate_size(self, unused_value, nested=False): class SinglePrecisionFloatCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees.""" + def encode_to_stream(self, value, out, nested): # type: (float, create_OutputStream, bool) -> None out.write_bigendian_float(value) @@ -792,6 +804,7 @@ def estimate_size(self, unused_value, nested=False): class FloatCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees.""" + def encode_to_stream(self, value, out, nested): # type: (float, create_OutputStream, bool) -> None out.write_bigendian_double(value) @@ -863,6 +876,7 @@ class TimestampCoderImpl(StreamCoderImpl): that of the Java SDK InstantCoder. https://github.com/apache/beam/blob/f5029b4f0dfff404310b2ef55e2632bbacc7b04f/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/InstantCoder.java#L79 """ + def encode_to_stream(self, value, out, nested): # type: (Timestamp, create_OutputStream, bool) -> None millis = value.micros // 1000 @@ -889,6 +903,7 @@ def estimate_size(self, unused_value, nested=False): class TimerCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees.""" + def __init__(self, key_coder_impl, window_coder_impl): self._timestamp_coder_impl = TimestampCoderImpl() self._boolean_coder_impl = BooleanCoderImpl() @@ -947,6 +962,7 @@ class VarIntCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees. A coder for int objects.""" + def encode_to_stream(self, value, out, nested): # type: (int, create_OutputStream, bool) -> None out.write_var_int64(value) @@ -978,6 +994,7 @@ class SingletonCoderImpl(CoderImpl): """For internal use only; no backwards-compatibility guarantees. A coder that always encodes exactly one value.""" + def __init__(self, value): self._value = value @@ -1005,6 +1022,7 @@ class AbstractComponentCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees. CoderImpl for coders that are comprised of several component coders.""" + def __init__(self, coder_impls): for c in coder_impls: assert isinstance(c, CoderImpl), c @@ -1030,8 +1048,8 @@ def decode_from_stream(self, in_stream, nested): # type: (create_InputStream, bool) -> Any return self._construct_from_components([ c.decode_from_stream( - in_stream, nested or i + 1 < len(self._coder_impls)) for i, - c in enumerate(self._coder_impls) + in_stream, nested or i + 1 < len(self._coder_impls)) + for i, c in enumerate(self._coder_impls) ]) def estimate_size(self, value, nested=False): @@ -1061,6 +1079,7 @@ def get_estimated_size_and_observables(self, value, nested=False): class AvroCoderImpl(SimpleCoderImpl): """For internal use only; no backwards-compatibility guarantees.""" + def __init__(self, schema): self.parsed_schema = parse_schema(json.loads(schema)) @@ -1077,6 +1096,7 @@ def decode(self, encoded): class TupleCoderImpl(AbstractComponentCoderImpl): """A coder for tuple objects.""" + def _extract_components(self, value): return tuple(value) @@ -1085,6 +1105,7 @@ def _construct_from_components(self, components): class _ConcatSequence(object): + def __init__(self, head, tail): # type: (Iterable[Any], Iterable[Any]) -> None self._head = head @@ -1277,12 +1298,14 @@ class TupleSequenceCoderImpl(SequenceCoderImpl): """For internal use only; no backwards-compatibility guarantees. A coder for homogeneous tuple objects.""" + def _construct_from_sequence(self, components): return tuple(components) class _AbstractIterable(object): """Wraps an iterable hiding methods that might not always be available.""" + def __init__(self, contents): self._contents = contents @@ -1315,6 +1338,7 @@ class IterableCoderImpl(SequenceCoderImpl): """For internal use only; no backwards-compatibility guarantees. A coder for homogeneous iterable objects.""" + def __init__(self, *args, use_abstract_iterable=None, **kwargs): super().__init__(*args, **kwargs) if use_abstract_iterable is None: @@ -1332,6 +1356,7 @@ class ListCoderImpl(SequenceCoderImpl): """For internal use only; no backwards-compatibility guarantees. A coder for homogeneous list objects.""" + def _construct_from_sequence(self, components): return components if isinstance(components, list) else list(components) @@ -1360,6 +1385,7 @@ class PaneInfoCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees. Coder for a PaneInfo descriptor.""" + def _choose_encoding(self, value): if ((value._index == 0 and value._nonspeculative_index == 0) or value._timing == PaneInfoTiming_UNKNOWN): @@ -1422,6 +1448,7 @@ def estimate_size(self, value, nested=False): class _OrderedUnionCoderImpl(StreamCoderImpl): + def __init__(self, coder_impl_types, fallback_coder_impl): assert len(coder_impl_types) < 128 self._types, self._coder_impls = zip(*coder_impl_types) @@ -1555,6 +1582,7 @@ class ParamWindowedValueCoderImpl(WindowedValueCoderImpl): encoding, and uses the supplied parameterized timestamp, windows and pane info values during decoding when reconstructing the windowed value.""" + def __init__(self, value_coder, window_coder, payload): super().__init__(value_coder, TimestampCoderImpl(), window_coder) self._timestamp, self._windows, self._pane_info = self._from_proto( @@ -1595,6 +1623,7 @@ class LengthPrefixCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees. Coder which prefixes the length of the encoded object in the stream.""" + def __init__(self, value_coder): # type: (CoderImpl) -> None self._value_coder = value_coder @@ -1625,6 +1654,7 @@ class ShardedKeyCoderImpl(StreamCoderImpl): shard id byte string encoded user key """ + def __init__(self, key_coder_impl): self._shard_id_coder_impl = BytesCoderImpl() self._key_coder_impl = key_coder_impl @@ -1660,6 +1690,7 @@ class TimestampPrefixingWindowCoderImpl(StreamCoderImpl): window's max_timestamp() encoded window using it's own coder. """ + def __init__(self, window_coder_impl: CoderImpl) -> None: self._window_coder_impl = window_coder_impl @@ -1687,6 +1718,7 @@ def _create_opaque_window(end, encoded_window): from apache_beam.transforms.window import BoundedWindow class _OpaqueWindow(BoundedWindow): + def __init__(self, end, encoded_window): super().__init__(end) self.encoded_window = encoded_window @@ -1715,6 +1747,7 @@ class TimestampPrefixingOpaqueWindowCoderImpl(StreamCoderImpl): window's max_timestamp() length prefixed encoded window """ + def __init__(self) -> None: pass @@ -1770,6 +1803,7 @@ def finalize_write(self): class GenericRowColumnEncoder(RowColumnEncoder): + def __init__(self, coder_impl, column): self.coder_impl = coder_impl self.column = column @@ -1790,6 +1824,7 @@ def finalize_write(self): class RowCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees.""" + def __init__(self, schema, components): self.schema = schema self.num_fields = len(self.schema.fields) @@ -1859,8 +1894,7 @@ def _row_column_encoders(self, columns): RowColumnEncoder.create( self.schema.fields[i].type.atomic_type, self.components[i], - columns[name]) for i, - name in enumerate(self.field_names) + columns[name]) for i, name in enumerate(self.field_names) ] def encode_batch_to_stream(self, columns: Dict[str, np.ndarray], out): @@ -1964,6 +1998,7 @@ def decode_batch_from_stream(self, dest: Dict[str, np.ndarray], in_stream): class LogicalTypeCoderImpl(StreamCoderImpl): + def __init__(self, logical_type, representation_coder): self.logical_type = logical_type self.representation_coder = representation_coder.get_impl() @@ -1982,6 +2017,7 @@ class BigIntegerCoderImpl(StreamCoderImpl): For interoperability with Java SDK, encoding needs to match that of the Java SDK BigIntegerCoder.""" + def encode_to_stream(self, value, out, nested): # type: (int, create_OutputStream, bool) -> None if value < 0: diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py index 0f2a42686854..54bbcb3a3729 100644 --- a/sdks/python/apache_beam/coders/coders.py +++ b/sdks/python/apache_beam/coders/coders.py @@ -139,6 +139,7 @@ def deserialize_coder(serialized): class Coder(object): """Base class for coders.""" + def encode(self, value): # type: (Any) -> bytes @@ -331,6 +332,7 @@ def register_urn(cls, urn, parameter_type, fn=None): A corresponding to_runner_api_parameter method would be expected that returns the tuple ('beam:fn:foo', FooPayload) """ + def register(fn): cls._known_urns[urn] = parameter_type, fn return fn @@ -382,9 +384,8 @@ def register_structured_urn(urn, cls): """ setattr( cls, - 'to_runner_api_parameter', - lambda self, - unused_context: (urn, None, self._get_component_coders())) + 'to_runner_api_parameter', lambda self, unused_context: + (urn, None, self._get_component_coders())) # pylint: disable=unused-variable @Coder.register_urn(urn, None) @@ -403,6 +404,7 @@ def _pickle_from_runner_api_parameter(payload, components, context): class StrUtf8Coder(Coder): """A coder used for reading and writing strings as UTF-8.""" + def encode(self, value): return value.encode('utf-8') @@ -422,6 +424,7 @@ def to_type_hint(self): class ToBytesCoder(Coder): """A default string coder used if no sink coder is specified.""" + def encode(self, value): return value if isinstance(value, bytes) else str(value).encode('utf-8') @@ -444,6 +447,7 @@ class FastCoder(Coder): this class inverts that by defining encode() and decode() in terms of _create_impl(). """ + def encode(self, value): """Encodes the given object into a byte string.""" return self.get_impl().encode(value) @@ -461,6 +465,7 @@ def _create_impl(self): class BytesCoder(FastCoder): """Byte string coder.""" + def _create_impl(self): return coder_impl.BytesCoderImpl() @@ -482,6 +487,7 @@ def __hash__(self): class BooleanCoder(FastCoder): + def _create_impl(self): return coder_impl.BooleanCoderImpl() @@ -503,6 +509,7 @@ def __hash__(self): class MapCoder(FastCoder): + def __init__(self, key_coder, value_coder): # type: (Coder, Coder) -> None self._key_coder = key_coder @@ -548,6 +555,7 @@ def __repr__(self): # This is a separate class from MapCoder as the former is a standard coder with # no way to carry the is_deterministic bit. class DeterministicMapCoder(FastCoder): + def __init__(self, key_coder, value_coder): # type: (Coder, Coder) -> None assert key_coder.is_deterministic() @@ -576,6 +584,7 @@ def __repr__(self): class NullableCoder(FastCoder): + def __init__(self, value_coder): # type: (Coder) -> None self._value_coder = value_coder @@ -630,6 +639,7 @@ def __repr__(self): class VarIntCoder(FastCoder): """Variable-length integer coder.""" + def _create_impl(self): return coder_impl.VarIntCoderImpl() @@ -652,6 +662,7 @@ def __hash__(self): class BigEndianShortCoder(FastCoder): """A coder used for big-endian int16 values.""" + def _create_impl(self): return coder_impl.BigEndianShortCoderImpl() @@ -671,6 +682,7 @@ def __hash__(self): class SinglePrecisionFloatCoder(FastCoder): """A coder used for single-precision floating-point values.""" + def _create_impl(self): return coder_impl.SinglePrecisionFloatCoderImpl() @@ -696,6 +708,7 @@ class FloatCoder(FastCoder): :class:`SinglePrecisionFloatCoder` for a single-precision version of this coder. """ + def _create_impl(self): return coder_impl.FloatCoderImpl() @@ -718,6 +731,7 @@ def __hash__(self): class TimestampCoder(FastCoder): """A coder used for timeutil.Timestamp values.""" + def _create_impl(self): return coder_impl.TimestampCoderImpl() @@ -736,6 +750,7 @@ class _TimerCoder(FastCoder): """A coder used for timer values. For internal use.""" + def __init__(self, key_coder, window_coder): # type: (Coder, Coder) -> None self._key_coder = key_coder @@ -769,6 +784,7 @@ def __hash__(self): class SingletonCoder(FastCoder): """A coder that always encodes exactly one value.""" + def __init__(self, value): self._value = value @@ -806,6 +822,7 @@ def maybe_dill_loads(o): class _PickleCoderBase(FastCoder): """Base class for pickling coders.""" + def is_deterministic(self): # type: () -> bool # Note that the default coder, the PickleCoder, is not deterministic (for @@ -836,6 +853,7 @@ def __hash__(self): class _MemoizingPickleCoder(_PickleCoderBase): """Coder using Python's pickle functionality with memoization.""" + def __init__(self, cache_size=16): super().__init__() self.cache_size = cache_size @@ -863,6 +881,7 @@ def to_type_hint(self): class PickleCoder(_PickleCoderBase): """Coder using Python's pickle functionality.""" + def _create_impl(self): dumps = pickle.dumps protocol = pickle.HIGHEST_PROTOCOL @@ -878,12 +897,14 @@ def to_type_hint(self): class DillCoder(_PickleCoderBase): """Coder using dill's pickle functionality.""" + def _create_impl(self): return coder_impl.CallbackCoderImpl(maybe_dill_dumps, maybe_dill_loads) class DeterministicFastPrimitivesCoder(FastCoder): """Throws runtime errors when encoding non-deterministic values.""" + def __init__(self, coder, step_label): self._underlying_coder = coder self._step_label = step_label @@ -916,6 +937,7 @@ class FastPrimitivesCoder(FastCoder): For unknown types, falls back to another coder (e.g. PickleCoder). """ + def __init__(self, fallback_coder=PickleCoder()): # type: (Coder) -> None self._fallback_coder = fallback_coder @@ -962,6 +984,7 @@ class FakeDeterministicFastPrimitivesCoder(FastPrimitivesCoder): This can be registered as a fallback coder to go back to the behavior before deterministic encoding was enforced (BEAM-11719). """ + def is_deterministic(self): return True @@ -1013,6 +1036,7 @@ class ProtoCoder(FastCoder): any protobuf Message object. """ + def __init__(self, proto_message_type): # type: (Type[google.protobuf.message.Message]) -> None self.proto_message_type = proto_message_type @@ -1065,6 +1089,7 @@ class DeterministicProtoCoder(ProtoCoder): version of the protoc compiler what was used to generate the protobuf messages. """ + def _create_impl(self): return coder_impl.DeterministicProtoCoderImpl(self.proto_message_type) @@ -1082,6 +1107,7 @@ class ProtoPlusCoder(FastCoder): ProtoPlusCoder is registered in the global CoderRegistry as the default coder for any proto.Message object. """ + def __init__(self, proto_plus_message_type): # type: (Type[proto.Message]) -> None self.proto_plus_message_type = proto_plus_message_type @@ -1117,6 +1143,7 @@ def to_type_hint(self): class AvroGenericCoder(FastCoder): """A coder used for AvroRecord values.""" + def __init__(self, schema): self.schema = schema @@ -1148,6 +1175,7 @@ def from_runner_api_parameter(payload, unused_components, unused_context): class TupleCoder(FastCoder): """Coder of tuple objects.""" + def __init__(self, components): # type: (Iterable[Coder]) -> None self._coders = tuple(components) @@ -1224,6 +1252,7 @@ def from_runner_api_parameter(unused_payload, components, unused_context): class TupleSequenceCoder(FastCoder): """Coder of homogeneous tuple objects.""" + def __init__(self, elem_coder): # type: (Coder) -> None self._elem_coder = elem_coder @@ -1267,6 +1296,7 @@ def __hash__(self): class ListLikeCoder(FastCoder): """Coder of iterables of homogeneous objects.""" + def __init__(self, elem_coder): # type: (Coder) -> None self._elem_coder = elem_coder @@ -1310,6 +1340,7 @@ def __hash__(self): class IterableCoder(ListLikeCoder): """Coder of iterables of homogeneous objects.""" + def to_type_hint(self): return typehints.Iterable[self._elem_coder.to_type_hint()] @@ -1319,6 +1350,7 @@ def to_type_hint(self): class ListCoder(ListLikeCoder): """Coder of Python lists.""" + def to_type_hint(self): return typehints.List[self._elem_coder.to_type_hint()] @@ -1328,6 +1360,7 @@ def _create_impl(self): class GlobalWindowCoder(SingletonCoder): """Coder for global windows.""" + def __init__(self): from apache_beam.transforms import window super().__init__(window.GlobalWindow()) @@ -1339,6 +1372,7 @@ def __init__(self): class IntervalWindowCoder(FastCoder): """Coder for an window defined by a start timestamp and a duration.""" + def _create_impl(self): return coder_impl.IntervalWindowCoderImpl() @@ -1358,6 +1392,7 @@ def __hash__(self): class _OrderedUnionCoder(FastCoder): + def __init__( self, *coder_types: Tuple[type, Coder], fallback_coder: Optional[Coder]): self._coder_types = coder_types @@ -1390,6 +1425,7 @@ def __hash__(self): class WindowedValueCoder(FastCoder): """Coder for windowed values.""" + def __init__(self, wrapped_value_coder, window_coder=None): # type: (Coder, Optional[Coder]) -> None if not window_coder: @@ -1452,6 +1488,7 @@ def __hash__(self): class ParamWindowedValueCoder(WindowedValueCoder): """A coder used for parameterized windowed values.""" + def __init__(self, payload, components): super().__init__(components[0], components[1]) self.payload = payload @@ -1494,6 +1531,7 @@ class LengthPrefixCoder(FastCoder): """For internal use only; no backwards-compatibility guarantees. Coder which prefixes the length of the encoded object in the stream.""" + def __init__(self, value_coder): # type: (Coder) -> None self._value_coder = value_coder @@ -1595,6 +1633,7 @@ def from_runner_api_parameter(payload, components, context): class ShardedKeyCoder(FastCoder): """A coder for sharded key.""" + def __init__(self, key_coder): # type: (Coder) -> None self._key_coder = key_coder @@ -1644,6 +1683,7 @@ class TimestampPrefixingWindowCoder(FastCoder): Coder which prefixes the max timestamp of arbitrary window to its encoded form.""" + def __init__(self, window_coder: Coder) -> None: self._window_coder = window_coder @@ -1679,6 +1719,7 @@ class TimestampPrefixingOpaqueWindowCoder(FastCoder): """For internal use only; no backwards-compatibility guarantees. Coder which decodes windows as bytes.""" + def __init__(self) -> None: pass @@ -1704,6 +1745,7 @@ def __hash__(self): class BigIntegerCoder(FastCoder): + def _create_impl(self): return coder_impl.BigIntegerCoderImpl() @@ -1722,6 +1764,7 @@ def __hash__(self): class DecimalCoder(FastCoder): + def _create_impl(self): return coder_impl.DecimalCoderImpl() diff --git a/sdks/python/apache_beam/coders/coders_property_based_test.py b/sdks/python/apache_beam/coders/coders_property_based_test.py index 9279fc31c099..c80418f0f3ee 100644 --- a/sdks/python/apache_beam/coders/coders_property_based_test.py +++ b/sdks/python/apache_beam/coders/coders_property_based_test.py @@ -91,6 +91,7 @@ class TypesAreAllTested(unittest.TestCase): + def test_all_types_are_tested(self): # Verify that all types among Beam's defined types are being tested self.assertEqual( @@ -100,6 +101,7 @@ def test_all_types_are_tested(self): class ProperyTestingCoders(unittest.TestCase): + @given(st.text()) def test_string_coder(self, txt: str): coder = StrUtf8Coder() @@ -144,9 +146,7 @@ def test_row_coder(self, data: st.DataObject): row = RowType( **{ name: data.draw(SCHEMA_TYPES_TO_STRATEGY[type_]) - for name, - type_, - nullable in schema + for name, type_, nullable in schema }) coder = RowCoder(typing_to_runner_api(RowType).row_type.schema) diff --git a/sdks/python/apache_beam/coders/coders_test.py b/sdks/python/apache_beam/coders/coders_test.py index 5e5debca36e6..c4f6f239ee9c 100644 --- a/sdks/python/apache_beam/coders/coders_test.py +++ b/sdks/python/apache_beam/coders/coders_test.py @@ -34,6 +34,7 @@ class PickleCoderTest(unittest.TestCase): + def test_basics(self): v = ('a' * 10, 'b' * 90) pickler = coders.PickleCoder() @@ -52,6 +53,7 @@ def test_equality(self): class CodersTest(unittest.TestCase): + def test_str_utf8_coder(self): real_coder = coders_registry.get_coder(bytes) expected_coder = coders.BytesCoder() @@ -75,6 +77,7 @@ def test_str_utf8_coder(self): # TODO(https://github.com/apache/beam/issues/22319): The proto file should be # placed in a common directory that can be shared between java and python. class ProtoCoderTest(unittest.TestCase): + def test_proto_coder(self): ma = test_message.MessageA() mb = ma.field2.add() @@ -106,6 +109,7 @@ def test_proto_coder_on_protobuf_message_subclasses(self): class DeterministicProtoCoderTest(unittest.TestCase): + def test_deterministic_proto_coder(self): ma = test_message.MessageA() mb = ma.field2.add() @@ -147,6 +151,7 @@ class ProtoPlusMessageWithMap(proto.Message): class ProtoPlusCoderTest(unittest.TestCase): + def test_proto_plus_coder(self): ma = ProtoPlusMessageA() ma.field2 = [ProtoPlusMessageB(field1=True)] @@ -195,6 +200,7 @@ class AvroTestRecord(AvroRecord): class AvroCoderTest(unittest.TestCase): + def test_avro_record_coder(self): real_coder = coders_registry.get_coder(AvroTestRecord) expected_coder = AvroTestCoder() @@ -219,6 +225,7 @@ def test_avro_record_coder(self): class DummyClass(object): """A class with no registered coder.""" + def __init__(self): pass @@ -232,6 +239,7 @@ def __hash__(self): class FallbackCoderTest(unittest.TestCase): + def test_default_fallback_path(self): """Test fallback path picks a matching coder if no coder is registered.""" @@ -243,6 +251,7 @@ def test_default_fallback_path(self): class NullableCoderTest(unittest.TestCase): + def test_determinism(self): deterministic = coders_registry.get_coder(typehints.Optional[int]) deterministic.as_deterministic_coder('label') @@ -257,12 +266,14 @@ def test_determinism(self): class LengthPrefixCoderTest(unittest.TestCase): + def test_to_type_hint(self): coder = coders.LengthPrefixCoder(coders.BytesCoder()) assert coder.to_type_hint() is bytes class NumpyIntAsKeyTest(unittest.TestCase): + def test_numpy_int(self): # this type is not supported as the key import numpy as np diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py index f3381cdb1d69..5ce45b1aeabe 100644 --- a/sdks/python/apache_beam/coders/coders_test_common.py +++ b/sdks/python/apache_beam/coders/coders_test_common.py @@ -69,6 +69,7 @@ class MyEnum(enum.Enum): class DefinesGetState: + def __init__(self, value): self.value = value @@ -80,12 +81,14 @@ def __eq__(self, other): class DefinesGetAndSetState(DefinesGetState): + def __setstate__(self, value): self.value = value # Defined out of line for picklability. class CustomCoder(coders.Coder): + def encode(self, x): return str(x + 1).encode('utf-8') @@ -163,7 +166,7 @@ def tearDownClass(cls): coders.BigEndianShortCoder, coders.SinglePrecisionFloatCoder, coders.ToBytesCoder, - coders.BigIntegerCoder, # tested in DecimalCoder + coders.BigIntegerCoder, # tested in DecimalCoder coders.TimestampPrefixingOpaqueWindowCoder, ]) cls.seen_nested -= set( @@ -434,6 +437,7 @@ def test_iterable_coder_unknown_length(self): self._test_iterable_coder_of_unknown_length(80000) def _test_iterable_coder_of_unknown_length(self, count): + def iter_generator(count): for i in range(count): yield i @@ -614,7 +618,9 @@ def test_length_prefix_coder(self): coders.TupleCoder((coder, coder)), (b'', b'a'), (b'bc', b'def')) def test_nested_observables(self): + class FakeObservableIterator(observable.ObservableMixin): + def __iter__(self): return iter([1, 2, 3]) @@ -672,9 +678,14 @@ def test_nullable_coder(self): def test_map_coder(self): values = [ - {1: "one", 300: "three hundred"}, # force yapf to be nice + { + 1: "one", 300: "three hundred" + }, # force yapf to be nice {}, - {i: str(i) for i in range(5000)} + { + i: str(i) + for i in range(5000) + } ] map_coder = coders.MapCoder(coders.VarIntCoder(), coders.StrUtf8Coder()) self.check_coder(map_coder, *values) diff --git a/sdks/python/apache_beam/coders/fast_coders_test.py b/sdks/python/apache_beam/coders/fast_coders_test.py index fa8643c2a383..8fc7e4a3bdb4 100644 --- a/sdks/python/apache_beam/coders/fast_coders_test.py +++ b/sdks/python/apache_beam/coders/fast_coders_test.py @@ -27,6 +27,7 @@ class FastCoders(unittest.TestCase): + def test_using_fast_impl(self): try: utils.check_compiled('apache_beam.coders.coder_impl') diff --git a/sdks/python/apache_beam/coders/observable.py b/sdks/python/apache_beam/coders/observable.py index 8cdc3227ec85..28ad58876de9 100644 --- a/sdks/python/apache_beam/coders/observable.py +++ b/sdks/python/apache_beam/coders/observable.py @@ -29,6 +29,7 @@ class ObservableMixin(object): Subclasses need to call self.notify_observers with any object yielded. """ + def __init__(self): self.observers = [] diff --git a/sdks/python/apache_beam/coders/observable_test.py b/sdks/python/apache_beam/coders/observable_test.py index df4e7ef09408..b1a59a192880 100644 --- a/sdks/python/apache_beam/coders/observable_test.py +++ b/sdks/python/apache_beam/coders/observable_test.py @@ -37,7 +37,9 @@ def observer(self, value, key=None): self.observed_keys.append(key) def test_observable(self): + class Watched(observable.ObservableMixin): + def __iter__(self): for i in (1, 4, 3): self.notify_observers(i, key='a%d' % i) diff --git a/sdks/python/apache_beam/coders/row_coder.py b/sdks/python/apache_beam/coders/row_coder.py index e93abbc887fb..3b39c00b7f0f 100644 --- a/sdks/python/apache_beam/coders/row_coder.py +++ b/sdks/python/apache_beam/coders/row_coder.py @@ -51,6 +51,7 @@ class RowCoder(FastCoder): Implements the beam:coder:row:v1 standard coder spec. """ + def __init__(self, schema, force_deterministic=False): """Initializes a :class:`RowCoder`. @@ -189,6 +190,7 @@ def _nonnull_coder_from_type(field_type): class LogicalTypeCoder(FastCoder): + def __init__(self, logical_type, representation_coder): self.logical_type = logical_type self.representation_coder = representation_coder diff --git a/sdks/python/apache_beam/coders/row_coder_test.py b/sdks/python/apache_beam/coders/row_coder_test.py index 6ac982835cb3..440a4182f97a 100644 --- a/sdks/python/apache_beam/coders/row_coder_test.py +++ b/sdks/python/apache_beam/coders/row_coder_test.py @@ -423,8 +423,7 @@ def test_batch_encode_decode(self): for size in [len(self.PEOPLE) - 1, len(self.PEOPLE), len(self.PEOPLE) + 1]: dest = { field: np.ndarray((size, ), dtype=a.dtype) - for field, - a in columnar.items() + for field, a in columnar.items() } n = min(size, len(self.PEOPLE)) self.assertEqual( diff --git a/sdks/python/apache_beam/coders/slow_coders_test.py b/sdks/python/apache_beam/coders/slow_coders_test.py index 7915116a19a3..ab325edd5e20 100644 --- a/sdks/python/apache_beam/coders/slow_coders_test.py +++ b/sdks/python/apache_beam/coders/slow_coders_test.py @@ -29,6 +29,7 @@ 'Remove non-cython tests.' 'https://github.com/apache/beam/issues/28307') class SlowCoders(unittest.TestCase): + def test_using_slow_impl(self): try: # pylint: disable=wrong-import-position diff --git a/sdks/python/apache_beam/coders/slow_stream.py b/sdks/python/apache_beam/coders/slow_stream.py index b08ad8e9a37f..f4cfb5b82738 100644 --- a/sdks/python/apache_beam/coders/slow_stream.py +++ b/sdks/python/apache_beam/coders/slow_stream.py @@ -29,6 +29,7 @@ class OutputStream(object): """For internal use only; no backwards-compatibility guarantees. A pure Python implementation of stream.OutputStream.""" + def __init__(self): self.data: List[bytes] = [] self.byte_count = 0 @@ -91,6 +92,7 @@ class ByteCountingOutputStream(OutputStream): """For internal use only; no backwards-compatibility guarantees. A pure Python implementation of stream.ByteCountingOutputStream.""" + def __init__(self): # Note that we don't actually use any of the data initialized by our super. super().__init__() @@ -119,6 +121,7 @@ class InputStream(object): """For internal use only; no backwards-compatibility guarantees. A pure Python implementation of stream.InputStream.""" + def __init__(self, data: bytes) -> None: self.data = data self.pos = 0 diff --git a/sdks/python/apache_beam/coders/standard_coders_test.py b/sdks/python/apache_beam/coders/standard_coders_test.py index 47df0116f2c6..59b7024188a2 100644 --- a/sdks/python/apache_beam/coders/standard_coders_test.py +++ b/sdks/python/apache_beam/coders/standard_coders_test.py @@ -77,6 +77,7 @@ def parse_float(s): def value_parser_from_schema(schema): + def attribute_parser_from_type(type_): parser = nonnull_attribute_parser_from_type(type_) if type_.nullable: @@ -139,65 +140,48 @@ class StandardCodersTest(unittest.TestCase): 'beam:coder:bool:v1': lambda x: x, 'beam:coder:string_utf8:v1': lambda x: x, 'beam:coder:varint:v1': lambda x: x, - 'beam:coder:kv:v1': lambda x, - key_parser, - value_parser: (key_parser(x['key']), value_parser(x['value'])), + 'beam:coder:kv:v1': lambda x, key_parser, value_parser: + (key_parser(x['key']), value_parser(x['value'])), 'beam:coder:interval_window:v1': lambda x: IntervalWindow( - start=Timestamp(micros=(x['end'] - x['span']) * 1000), - end=Timestamp(micros=x['end'] * 1000)), - 'beam:coder:iterable:v1': lambda x, - parser: list(map(parser, x)), - 'beam:coder:state_backed_iterable:v1': lambda x, - parser: list(map(parser, x)), + start=Timestamp(micros=(x['end'] - x['span']) * 1000), end=Timestamp( + micros=x['end'] * 1000)), + 'beam:coder:iterable:v1': lambda x, parser: list(map(parser, x)), + 'beam:coder:state_backed_iterable:v1': lambda x, parser: list( + map(parser, x)), 'beam:coder:global_window:v1': lambda x: window.GlobalWindow(), - 'beam:coder:windowed_value:v1': lambda x, - value_parser, - window_parser: windowed_value.create( - value_parser(x['value']), - x['timestamp'] * 1000, - tuple(window_parser(w) for w in x['windows'])), - 'beam:coder:param_windowed_value:v1': lambda x, - value_parser, + 'beam:coder:windowed_value:v1': lambda x, value_parser, window_parser: + windowed_value.create( + value_parser(x['value']), x['timestamp'] * 1000, tuple( + window_parser(w) for w in x['windows'])), + 'beam:coder:param_windowed_value:v1': lambda x, value_parser, window_parser: windowed_value.create( - value_parser(x['value']), - x['timestamp'] * 1000, - tuple(window_parser(w) for w in x['windows']), - PaneInfo( - x['pane']['is_first'], - x['pane']['is_last'], - PaneInfoTiming.from_string(x['pane']['timing']), - x['pane']['index'], - x['pane']['on_time_index'])), - 'beam:coder:timer:v1': lambda x, - value_parser, - window_parser: userstate.Timer( - user_key=value_parser(x['userKey']), - dynamic_timer_tag=x['dynamicTimerTag'], - clear_bit=x['clearBit'], - windows=tuple(window_parser(w) for w in x['windows']), - fire_timestamp=None, - hold_timestamp=None, - paneinfo=None) if x['clearBit'] else userstate.Timer( - user_key=value_parser(x['userKey']), - dynamic_timer_tag=x['dynamicTimerTag'], - clear_bit=x['clearBit'], - fire_timestamp=Timestamp(micros=x['fireTimestamp'] * 1000), - hold_timestamp=Timestamp(micros=x['holdTimestamp'] * 1000), - windows=tuple(window_parser(w) for w in x['windows']), - paneinfo=PaneInfo( - x['pane']['is_first'], - x['pane']['is_last'], - PaneInfoTiming.from_string(x['pane']['timing']), - x['pane']['index'], - x['pane']['on_time_index'])), + value_parser(x['value']), x['timestamp'] * 1000, tuple( + window_parser(w) for w in x['windows']), PaneInfo( + x['pane']['is_first'], x['pane']['is_last'], PaneInfoTiming. + from_string(x['pane']['timing']), x['pane']['index'], x[ + 'pane']['on_time_index'])), + 'beam:coder:timer:v1': lambda x, value_parser, window_parser: userstate. + Timer( + user_key=value_parser(x['userKey']), dynamic_timer_tag=x[ + 'dynamicTimerTag'], clear_bit=x['clearBit'], windows=tuple( + window_parser(w) for w in x['windows']), fire_timestamp=None, + hold_timestamp=None, paneinfo=None) + if x['clearBit'] else userstate.Timer( + user_key=value_parser(x['userKey']), dynamic_timer_tag=x[ + 'dynamicTimerTag'], clear_bit=x['clearBit'], fire_timestamp= + Timestamp(micros=x['fireTimestamp'] * 1000), hold_timestamp=Timestamp( + micros=x['holdTimestamp'] * 1000), windows=tuple( + window_parser(w) for w in x['windows']), paneinfo=PaneInfo( + x['pane']['is_first'], x['pane']['is_last'], + PaneInfoTiming.from_string(x['pane']['timing']), x[ + 'pane']['index'], x['pane']['on_time_index'])), 'beam:coder:double:v1': parse_float, - 'beam:coder:sharded_key:v1': lambda x, - value_parser: ShardedKey( + 'beam:coder:sharded_key:v1': lambda x, value_parser: ShardedKey( key=value_parser(x['key']), shard_id=x['shardId'].encode('utf-8')), - 'beam:coder:custom_window:v1': lambda x, - window_parser: window_parser(x['window']), - 'beam:coder:nullable:v1': lambda x, - value_parser: x.encode('utf-8') if x else None + 'beam:coder:custom_window:v1': lambda x, window_parser: window_parser( + x['window']), + 'beam:coder:nullable:v1': lambda x, value_parser: x.encode('utf-8') + if x else None } def test_standard_coders(self): @@ -206,6 +190,7 @@ def test_standard_coders(self): self._run_standard_coder(name, spec) def _run_standard_coder(self, name, spec): + def assert_equal(actual, expected): """Handle nan values which self.assertEqual fails on.""" if (isinstance(actual, float) and isinstance(expected, float) and diff --git a/sdks/python/apache_beam/coders/typecoders.py b/sdks/python/apache_beam/coders/typecoders.py index 1667cb7a916a..fdc66d5d0f6c 100644 --- a/sdks/python/apache_beam/coders/typecoders.py +++ b/sdks/python/apache_beam/coders/typecoders.py @@ -79,6 +79,7 @@ def MakeXyzs(v): class CoderRegistry(object): """A coder registry for typehint/coder associations.""" + def __init__(self, fallback_coder=None): self._coders: Dict[Any, Type[coders.Coder]] = {} self.custom_types: List[Any] = [] @@ -188,6 +189,7 @@ class FirstOf(object): """For internal use only; no backwards-compatibility guarantees. A class used to get the first matching coder from a list of coders.""" + def __init__(self, coders: Iterable[Type[coders.Coder]]) -> None: self._coders = coders diff --git a/sdks/python/apache_beam/coders/typecoders_test.py b/sdks/python/apache_beam/coders/typecoders_test.py index 3adc8255409d..03e52a0802c6 100644 --- a/sdks/python/apache_beam/coders/typecoders_test.py +++ b/sdks/python/apache_beam/coders/typecoders_test.py @@ -27,6 +27,7 @@ class CustomClass(object): + def __init__(self, n): self.number = n @@ -38,6 +39,7 @@ def __hash__(self): class CustomCoder(coders.Coder): + def encode(self, value): return str(value.number).encode('ASCII') @@ -52,6 +54,7 @@ def is_deterministic(self): class TypeCodersTest(unittest.TestCase): + def test_register_non_type_coder(self): coder = CustomCoder() with self.assertRaisesRegex( diff --git a/sdks/python/apache_beam/dataframe/convert.py b/sdks/python/apache_beam/dataframe/convert.py index c5a0d1025c6d..52b4900cf790 100644 --- a/sdks/python/apache_beam/dataframe/convert.py +++ b/sdks/python/apache_beam/dataframe/convert.py @@ -92,12 +92,14 @@ def to_dataframe( class RowsToDataFrameFn(beam.DoFn): + @beam.DoFn.yields_elements def process_batch(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]: yield batch class ElementsToSeriesFn(beam.DoFn): + @beam.DoFn.yields_elements def process_batch(self, batch: pd.Series) -> Iterable[pd.Series]: yield batch @@ -136,6 +138,7 @@ def _make_unbatched_pcoll( class DataFrameToRowsFn(beam.DoFn): + def __init__(self, proxy, include_indexes): self._proxy = proxy self._include_indexes = include_indexes @@ -150,6 +153,7 @@ def infer_output_type(self, input_element_type): class SeriesToElementsFn(beam.DoFn): + def __init__(self, proxy): self._proxy = proxy @@ -257,14 +261,14 @@ def extract_input(placeholder): {ix: df._expr for (ix, df) in enumerate(new_dataframes)}) - TO_PCOLLECTION_CACHE.update( - {new_dataframes[ix]._expr._id: pc - for ix, pc in new_results.items()}) + TO_PCOLLECTION_CACHE.update({ + new_dataframes[ix]._expr._id: pc + for ix, pc in new_results.items() + }) raw_results = { ix: TO_PCOLLECTION_CACHE[df._expr._id] - for ix, - df in enumerate(dataframes) + for ix, df in enumerate(dataframes) } if yield_elements == "schemas": diff --git a/sdks/python/apache_beam/dataframe/convert_test.py b/sdks/python/apache_beam/dataframe/convert_test.py index b00ce0e51fa8..3c78c8a4343e 100644 --- a/sdks/python/apache_beam/dataframe/convert_test.py +++ b/sdks/python/apache_beam/dataframe/convert_test.py @@ -25,6 +25,7 @@ def equal_to_unordered_series(expected): + def check(actual): actual = pd.concat(actual) if sorted(expected) != sorted(actual): @@ -34,6 +35,7 @@ def check(actual): class ConvertTest(unittest.TestCase): + def test_convert_yield_pandas(self): with beam.Pipeline() as p: a = pd.Series([1, 2, 3]) @@ -185,7 +187,9 @@ def test_convert_memoization_clears_cache(self): logging.disable(logging.NOTSET) def test_auto_convert(self): + class MySchemaTransform(beam.PTransform): + def expand(self, pcoll): return pcoll | beam.Map( lambda x: beam.Row( diff --git a/sdks/python/apache_beam/dataframe/doctests.py b/sdks/python/apache_beam/dataframe/doctests.py index 57ee8009ba44..52b72efde9b9 100644 --- a/sdks/python/apache_beam/dataframe/doctests.py +++ b/sdks/python/apache_beam/dataframe/doctests.py @@ -60,6 +60,7 @@ class FakePandasObject(object): """A stand-in for the wrapped pandas objects. """ + def __init__(self, pandas_obj, test_env): self._pandas_obj = pandas_obj self._test_env = test_env @@ -94,6 +95,7 @@ class TestEnvironment(object): These classes are patched to be able to recognize and retrieve inputs and results, stored in `self._inputs` and `self._all_frames` respectively. """ + def __init__(self): self._inputs = {} self._all_frames = {} @@ -157,6 +159,7 @@ def __exit__(self, *unused_args): del self._ALL_RESULTS[self._id] def record_fn(self, name): + def record(value): self._ALL_RESULTS[self._id][name].append(value) @@ -173,6 +176,7 @@ def get_recorded(self, name): class _DeferrredDataframeOutputChecker(doctest.OutputChecker): """Validates output by replacing DeferredBase[...] with computed values. """ + def __init__(self, env, use_beam): self._env = env if use_beam: @@ -188,8 +192,7 @@ def compute_using_session(self, to_compute): session = expressions.PartitioningSession(self._env._inputs) return { name: session.evaluate(frame._expr) - for name, - frame in to_compute.items() + for name, frame in to_compute.items() } def compute_using_beam(self, to_compute): @@ -198,13 +201,13 @@ def compute_using_beam(self, to_compute): input_pcolls = { placeholder: p | 'Create%s' % placeholder >> beam.Create([input[::2], input[1::2]]) - for placeholder, - input in self._env._inputs.items() + for placeholder, input in self._env._inputs.items() } output_pcolls = ( - input_pcolls | transforms._DataframeExpressionsTransform( - {name: frame._expr - for name, frame in to_compute.items()})) + input_pcolls | transforms._DataframeExpressionsTransform({ + name: frame._expr + for name, frame in to_compute.items() + })) for name, output_pcoll in output_pcolls.items(): _ = output_pcoll | 'Record%s' % name >> beam.FlatMap( recorder.record_fn(name)) @@ -341,6 +344,7 @@ class BeamDataframeDoctestRunner(doctest.DocTestRunner): """A Doctest runner suitable for replacing the `pd` module with one backed by beam. """ + def __init__( self, env, @@ -359,18 +363,15 @@ def to_callable(cond): self._wont_implement_ok = { test: [to_callable(cond) for cond in examples] - for test, - examples in (wont_implement_ok or {}).items() + for test, examples in (wont_implement_ok or {}).items() } self._not_implemented_ok = { test: [to_callable(cond) for cond in examples] - for test, - examples in (not_implemented_ok or {}).items() + for test, examples in (not_implemented_ok or {}).items() } self._skip = { test: [to_callable(cond) for cond in examples] - for test, - examples in (skip or {}).items() + for test, examples in (skip or {}).items() } super().__init__( checker=_DeferrredDataframeOutputChecker(self._test_env, use_beam), @@ -420,6 +421,7 @@ def run(self, test, **kwargs): return result def report_success(self, out, test, example, got): + def extract_concise_reason(got, expected_exc): m = re.search(r"Implement(?:ed)?Error:\s+(.*)\n$", got) if m: @@ -463,6 +465,7 @@ class AugmentedTestResults(doctest.TestResults): class Summary(object): + def __init__(self, failures=0, tries=0, skipped=0, error_reasons=None): self.failures = failures self.tries = tries @@ -487,6 +490,7 @@ def __add__(self, other): merged_reasons) def summarize(self): + def print_partition(indent, desc, n, total): print("%s%d %s (%.1f%%)" % (" " * indent, n, desc, n / total * 100)) @@ -535,9 +539,9 @@ def is_example_line(line): IMPORT_PANDAS = 'import pandas as pd' example_srcs = [] - lines = iter([(lineno, line.rstrip()) for lineno, - line in enumerate(rst.split('\n')) if is_example_line(line)] + - [(None, 'END')]) + lines = iter([(lineno, line.rstrip()) + for lineno, line in enumerate(rst.split('\n')) + if is_example_line(line)] + [(None, 'END')]) # https://ipython.readthedocs.io/en/stable/sphinxext.html lineno, line = next(lines) @@ -610,6 +614,7 @@ def test_rst_ipython( """Extracts examples from an rst file and run them through pandas to get the expected output, and then compare them against our dataframe implementation. """ + def run_tests(extraglobs, optionflags, **kwargs): # The patched one. tests = parse_rst_ipython_tests(rst, name, extraglobs, optionflags) @@ -690,12 +695,8 @@ def _run_patched(func, *args, **kwargs): # Unfortunately the runner is not injectable. original_doc_test_runner = doctest.DocTestRunner doctest.DocTestRunner = lambda **kwargs: BeamDataframeDoctestRunner( - env, - use_beam=use_beam, - wont_implement_ok=wont_implement_ok, - not_implemented_ok=not_implemented_ok, - skip=skip, - **kwargs) + env, use_beam=use_beam, wont_implement_ok=wont_implement_ok, + not_implemented_ok=not_implemented_ok, skip=skip, **kwargs) with expressions.allow_non_parallel_operations(): return func( *args, extraglobs=extraglobs, optionflags=optionflags, **kwargs) diff --git a/sdks/python/apache_beam/dataframe/doctests_test.py b/sdks/python/apache_beam/dataframe/doctests_test.py index df24213c8716..f966dc1f92fc 100644 --- a/sdks/python/apache_beam/dataframe/doctests_test.py +++ b/sdks/python/apache_beam/dataframe/doctests_test.py @@ -144,6 +144,7 @@ def foo(x): class DoctestTest(unittest.TestCase): + def test_good(self): result = doctests.teststring(SAMPLE_DOCTEST, report=False) self.assertEqual(result.attempted, 3) diff --git a/sdks/python/apache_beam/dataframe/expressions.py b/sdks/python/apache_beam/dataframe/expressions.py index 2ef172b8dad3..94a9d4d825ce 100644 --- a/sdks/python/apache_beam/dataframe/expressions.py +++ b/sdks/python/apache_beam/dataframe/expressions.py @@ -33,6 +33,7 @@ class Session(object): The bindings typically include required placeholders, but may be any intermediate expression as well. """ + def __init__(self, bindings=None): self._bindings = dict(bindings or {}) @@ -60,6 +61,7 @@ class PartitioningSession(Session): For testing only. """ + def evaluate(self, expr): import pandas as pd import collections @@ -205,6 +207,7 @@ class Expression(Generic[T]): expression. However, unless the inputs are Singleton-partitioned, the expression makes no guarantees about the partitioning of the output. """ + def __init__(self, name: str, proxy: T, _id: Optional[str] = None): self._name = name self._proxy = proxy @@ -250,6 +253,7 @@ def preserves_partition_by(self) -> partitionings.Partitioning: class PlaceholderExpression(Expression): """An expression whose value must be explicitly bound in the session.""" + def __init__( self, proxy: T, @@ -282,6 +286,7 @@ def preserves_partition_by(self): class ConstantExpression(Expression): """An expression whose value is known at pipeline construction time.""" + def __init__(self, value: T, proxy: Optional[T] = None): """Initialize a constant expression. @@ -314,6 +319,7 @@ def preserves_partition_by(self): class ComputedExpression(Expression): """An expression whose value must be computed at pipeline execution time.""" + def __init__( self, name: str, @@ -410,6 +416,7 @@ def allow_non_parallel_operations(allow=True): class NonParallelOperation(Exception): + def __init__(self, msg): super().__init__(self, msg) self.msg = msg diff --git a/sdks/python/apache_beam/dataframe/expressions_test.py b/sdks/python/apache_beam/dataframe/expressions_test.py index 2c5c716f5beb..65e428f9780f 100644 --- a/sdks/python/apache_beam/dataframe/expressions_test.py +++ b/sdks/python/apache_beam/dataframe/expressions_test.py @@ -23,6 +23,7 @@ class ExpressionTest(unittest.TestCase): + def test_placeholder_expression(self): a = expressions.PlaceholderExpression(None) b = expressions.PlaceholderExpression(None) diff --git a/sdks/python/apache_beam/dataframe/frame_base.py b/sdks/python/apache_beam/dataframe/frame_base.py index 8e206fc5e037..a700c044090a 100644 --- a/sdks/python/apache_beam/dataframe/frame_base.py +++ b/sdks/python/apache_beam/dataframe/frame_base.py @@ -43,6 +43,7 @@ def __init__(self, expr): @classmethod def _register_for(cls, pandas_type): + def wrapper(deferred_type): cls._pandas_type_map[pandas_type] = deferred_type return deferred_type @@ -91,6 +92,7 @@ class UnusableUnpickledDeferredBase(object): Trying to use this object after unpickling is a bug and will result in an error. """ + def __init__(self, name): self._name = name @@ -103,6 +105,7 @@ class DeferredFrame(DeferredBase): class _DeferredScalar(DeferredBase): + def apply(self, func, name=None, args=()): if name is None: name = func.__name__ @@ -135,6 +138,7 @@ def __bool__(self): def _scalar_binop(op): + def binop(self, other): if not isinstance(other, DeferredBase): return self.apply(lambda left: getattr(left, op)(other), name=op) @@ -442,6 +446,7 @@ def wrapper(*args, **kwargs): def _copy_and_mutate(func): + def wrapper(self, *args, **kwargs): copy = self.copy() func(copy, *args, **kwargs) @@ -462,6 +467,7 @@ def maybe_inplace(func): the inplace operation will refernce the updated expression. For internal use only. No backwards compatibility guarantees.""" + @functools.wraps(func) def wrapper(self, inplace=False, **kwargs): result = func(self, **kwargs) @@ -491,6 +497,7 @@ def args_to_kwargs(base_type, removed_method=False, removed_args=None): removed_args: If not empty, which arguments have been dropped in the running Pandas version. """ + def wrap(func): if removed_method: # Do no processing, let Beam function itself raise the error if called. @@ -566,6 +573,7 @@ def with_docs_from(base_type, name=None, removed_method=False): removed_method used in cases where a method has been removed in a later version of Pandas. """ + def wrap(func): if removed_method: func.__doc__ = ( @@ -659,6 +667,7 @@ def populate_defaults(base_type, removed_method=False, removed_args=None): removed_args: If not empty, which arguments have been dropped in the running Pandas version. """ + def wrap(func): if removed_method: return func @@ -750,6 +759,7 @@ class WontImplementError(NotImplementedError): Raising this error will also prevent this doctests from being validated when run with the beam dataframe validation doctest runner. """ + def __init__(self, msg, reason=None): if reason is not None: if reason not in _WONT_IMPLEMENT_REASONS: diff --git a/sdks/python/apache_beam/dataframe/frame_base_test.py b/sdks/python/apache_beam/dataframe/frame_base_test.py index 0a73905339fd..9c3743147140 100644 --- a/sdks/python/apache_beam/dataframe/frame_base_test.py +++ b/sdks/python/apache_beam/dataframe/frame_base_test.py @@ -24,6 +24,7 @@ class FrameBaseTest(unittest.TestCase): + def test_elementwise_func(self): a = pd.Series([1, 2, 3]) b = pd.Series([100, 200, 300]) @@ -56,6 +57,7 @@ def test_elementwise_func_kwarg(self): self.assertTrue(sub(x, y)._expr.evaluate_at(session).equals(a - b)) def test_maybe_inplace(self): + @frame_base.maybe_inplace def add_one(frame): return frame + 1 @@ -71,11 +73,14 @@ def add_one(frame): self.assertIsNot(x._expr, original_expr) def test_args_to_kwargs(self): + class Base(object): + def func(self, a=1, b=2, c=3, *, kw_only=4): pass class Proxy(object): + @frame_base.args_to_kwargs(Base) def func(self, **kwargs): return kwargs @@ -92,7 +97,9 @@ def func(self, **kwargs): proxy.func(2, 4, 6, 8) def test_args_to_kwargs_populates_defaults(self): + class Base(object): + def func(self, a=1, b=2, c=3): pass @@ -100,6 +107,7 @@ def func_removed_args(self, a): pass class Proxy(object): + @frame_base.args_to_kwargs(Base) @frame_base.populate_defaults(Base) def func(self, a, c=1000, **kwargs): @@ -133,11 +141,14 @@ def func_removed_args(self, a, c, **kwargs): self.assertEqual(proxy.func_removed_args(12, d=100), {'a': 12, 'd': 100}) def test_args_to_kwargs_populates_default_handles_kw_only(self): + class Base(object): + def func(self, a, b=2, c=3, *, kw_only=4): pass class ProxyUsesKwOnly(object): + @frame_base.args_to_kwargs(Base) @frame_base.populate_defaults(Base) def func(self, a, kw_only, **kwargs): @@ -158,6 +169,7 @@ def func(self, a, kw_only, **kwargs): proxy.func(2, 4, 6, 8) # got too many positioned arguments class ProxyDoesntUseKwOnly(object): + @frame_base.args_to_kwargs(Base) @frame_base.populate_defaults(Base) def func(self, a, **kwargs): @@ -175,11 +187,14 @@ def func(self, a, **kwargs): }) def test_populate_defaults_overwrites_copy(self): + class Base(object): + def func(self, a=1, b=2, c=3, *, copy=None): pass class Proxy(object): + @frame_base.args_to_kwargs(Base) @frame_base.populate_defaults(Base) def func(self, a, copy, **kwargs): diff --git a/sdks/python/apache_beam/dataframe/frames.py b/sdks/python/apache_beam/dataframe/frames.py index ccd01f35f87b..b2d22a58c065 100644 --- a/sdks/python/apache_beam/dataframe/frames.py +++ b/sdks/python/apache_beam/dataframe/frames.py @@ -65,6 +65,7 @@ def populate_not_implemented(pd_type): + def wrapper(deferred_type): for attr in dir(pd_type): # Don't auto-define hidden methods or dunders @@ -91,6 +92,7 @@ def wrapper(deferred_type): def _fillna_alias(method): + def wrapper(self, *args, **kwargs): return self.fillna(*args, method=method, **kwargs) @@ -153,6 +155,7 @@ def wrapper(self, *args, **kwargs): def _agg_method(base, func): + def wrapper(self, *args, **kwargs): return self.agg(func, *args, **kwargs) @@ -177,6 +180,7 @@ def wrapper(self, *args, **kwargs): class DeferredDataFrameOrSeries(frame_base.DeferredFrame): + def _render_indexes(self): if self.index.nlevels == 1: return 'index=' + ( @@ -232,13 +236,9 @@ def drop(self, labels, axis, index, columns, errors, **kwargs): return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'drop', - lambda df: df.drop( - axis=axis, - index=index, - columns=columns, - errors=errors, - **kwargs), [self._expr], + 'drop', lambda df: df.drop( + axis=axis, index=index, columns=columns, errors=errors, **kwargs + ), [self._expr], proxy=proxy, requires_partition_by=requires)) @@ -248,8 +248,8 @@ def drop(self, labels, axis, index, columns, errors, **kwargs): def droplevel(self, level, axis): return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'droplevel', - lambda df: df.droplevel(level, axis=axis), [self._expr], + 'droplevel', lambda df: df.droplevel(level, axis=axis), + [self._expr], requires_partition_by=partitionings.Arbitrary(), preserves_partition_by=partitionings.Arbitrary() if axis in (1, 'column') else partitionings.Singleton())) @@ -259,8 +259,7 @@ def droplevel(self, level, axis): def swaplevel(self, **kwargs): return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'swaplevel', - lambda df: df.swaplevel(**kwargs), [self._expr], + 'swaplevel', lambda df: df.swaplevel(**kwargs), [self._expr], requires_partition_by=partitionings.Arbitrary(), preserves_partition_by=partitionings.Arbitrary())) @@ -295,13 +294,13 @@ def fillna(self, value, method, axis, limit, **kwargs): # This is OK, as its index must be the same size as the columns set of # self, so cannot be too large. class AsScalar(object): + def __init__(self, value): self.value = value with expressions.allow_non_parallel_operations(): value_expr = expressions.ComputedExpression( - 'as_scalar', - lambda df: AsScalar(df), [value._expr], + 'as_scalar', lambda df: AsScalar(df), [value._expr], requires_partition_by=partitionings.Singleton()) get_value = lambda x: x.value @@ -322,14 +321,9 @@ def __init__(self, value): return frame_base.DeferredFrame.wrap( # yapf: disable expressions.ComputedExpression( - 'fillna', - lambda df, - value: df.fillna( - get_value(value), - method=method, - axis=axis, - limit=limit, - **kwargs), [self._expr, value_expr], + 'fillna', lambda df, value: df.fillna( + get_value(value), method=method, axis=axis, limit=limit, ** + kwargs), [self._expr, value_expr], preserves_partition_by=partitionings.Arbitrary(), requires_partition_by=requires)) @@ -345,30 +339,30 @@ def __init__(self, value): @frame_base.with_docs_from(pd.DataFrame) def first(self, offset): per_partition = expressions.ComputedExpression( - 'first-per-partition', - lambda df: df.sort_index().first(offset=offset), [self._expr], + 'first-per-partition', lambda df: df.sort_index().first(offset=offset), + [self._expr], preserves_partition_by=partitionings.Arbitrary(), requires_partition_by=partitionings.Arbitrary()) with expressions.allow_non_parallel_operations(True): return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'first', - lambda df: df.sort_index().first(offset=offset), [per_partition], + 'first', lambda df: df.sort_index().first(offset=offset), + [per_partition], preserves_partition_by=partitionings.Arbitrary(), requires_partition_by=partitionings.Singleton())) @frame_base.with_docs_from(pd.DataFrame) def last(self, offset): per_partition = expressions.ComputedExpression( - 'last-per-partition', - lambda df: df.sort_index().last(offset=offset), [self._expr], + 'last-per-partition', lambda df: df.sort_index().last(offset=offset), + [self._expr], preserves_partition_by=partitionings.Arbitrary(), requires_partition_by=partitionings.Arbitrary()) with expressions.allow_non_parallel_operations(True): return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'last', - lambda df: df.sort_index().last(offset=offset), [per_partition], + 'last', lambda df: df.sort_index().last(offset=offset), + [per_partition], preserves_partition_by=partitionings.Arbitrary(), requires_partition_by=partitionings.Singleton())) @@ -388,8 +382,7 @@ def groupby(self, by, level, axis, as_index, group_keys, **kwargs): if axis in (1, 'columns'): return _DeferredGroupByCols( expressions.ComputedExpression( - 'groupbycols', - lambda df: df.groupby( + 'groupbycols', lambda df: df.groupby( by, axis=axis, group_keys=group_keys, **kwargs), [self._expr], requires_partition_by=partitionings.Arbitrary(), preserves_partition_by=partitionings.Arbitrary()), @@ -561,11 +554,9 @@ def prepend_index(df, by): # type: ignore return DeferredGroupBy( expressions.ComputedExpression( - 'groupbyindex', - lambda df: df.groupby( - level=list(range(df.index.nlevels)), - group_keys=group_keys, - **kwargs), [to_group], + 'groupbyindex', lambda df: df.groupby( + level=list(range(df.index.nlevels)), group_keys=group_keys, ** + kwargs), [to_group], requires_partition_by=partitionings.Index(), preserves_partition_by=partitionings.Arbitrary()), kwargs, @@ -612,8 +603,8 @@ def reset_index(self, level=None, **kwargs): requires_partition_by = partitionings.Arbitrary() return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'reset_index', - lambda df: df.reset_index(level=level, **kwargs), [self._expr], + 'reset_index', lambda df: df.reset_index(level=level, **kwargs), + [self._expr], preserves_partition_by=partitionings.Singleton(), requires_partition_by=requires_partition_by)) @@ -705,12 +696,8 @@ def replace(self, to_replace, value, limit, method, **kwargs): "requires collecting all data on a single node.")) return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'replace', - lambda df: df.replace( - to_replace=to_replace, - value=value, - limit=limit, - method=method, + 'replace', lambda df: df.replace( + to_replace=to_replace, value=value, limit=limit, method=method, **kwargs), [self._expr], preserves_partition_by=partitionings.Arbitrary(), requires_partition_by=requires_partition_by)) @@ -732,10 +719,8 @@ def tz_localize(self, ambiguous, **kwargs): elif isinstance(ambiguous, frame_base.DeferredFrame): return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'tz_localize', - lambda df, - ambiguous: df.tz_localize(ambiguous=ambiguous, **kwargs), - [self._expr, ambiguous._expr], + 'tz_localize', lambda df, ambiguous: df.tz_localize( + ambiguous=ambiguous, **kwargs), [self._expr, ambiguous._expr], requires_partition_by=partitionings.Index(), preserves_partition_by=partitionings.Singleton())) elif ambiguous == 'infer': @@ -767,8 +752,7 @@ def size(self): with expressions.allow_non_parallel_operations(True): return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'sum_sizes', - lambda sizes: sizes.sum(), [sizes], + 'sum_sizes', lambda sizes: sizes.sum(), [sizes], requires_partition_by=partitionings.Singleton(), preserves_partition_by=partitionings.Singleton())) @@ -787,8 +771,7 @@ def length(self): with expressions.allow_non_parallel_operations(True): return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'sum_lengths', - lambda lengths: lengths.sum(), [lengths], + 'sum_lengths', lambda lengths: lengths.sum(), [lengths], requires_partition_by=partitionings.Singleton(), preserves_partition_by=partitionings.Singleton())) @@ -812,8 +795,7 @@ def empty(self): with expressions.allow_non_parallel_operations(True): return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'check_all_empty', - lambda empties: empties.all(), [empties], + 'check_all_empty', lambda empties: empties.all(), [empties], requires_partition_by=partitionings.Singleton(), preserves_partition_by=partitionings.Singleton())) @@ -834,8 +816,7 @@ def bool(self): # Will throw if overall dataset has != 1 element return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'combine_all_bools', - lambda bools: bools.bool(), [bools], + 'combine_all_bools', lambda bools: bools.bool(), [bools], proxy=bool(), requires_partition_by=partitionings.Singleton(), preserves_partition_by=partitionings.Singleton())) @@ -845,8 +826,7 @@ def equals(self, other): intermediate = expressions.ComputedExpression( 'equals_partitioned', # Wrap scalar results in a Series for easier concatenation later - lambda df, - other: pd.Series(df.equals(other)), + lambda df, other: pd.Series(df.equals(other)), [self._expr, other._expr], requires_partition_by=partitionings.Index(), preserves_partition_by=partitionings.Singleton()) @@ -854,8 +834,7 @@ def equals(self, other): with expressions.allow_non_parallel_operations(True): return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'aggregate_equals', - lambda df: df.all(), [intermediate], + 'aggregate_equals', lambda df: df.all(), [intermediate], requires_partition_by=partitionings.Singleton(), preserves_partition_by=partitionings.Singleton())) @@ -1010,8 +989,7 @@ def unstack(self, **kwargs): "Please upgrade to pandas 1.2.0 or higher to use this operation.") return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'unstack', - lambda s: s.unstack(**kwargs), [self._expr], + 'unstack', lambda s: s.unstack(**kwargs), [self._expr], requires_partition_by=partitionings.Index())) else: # Unstacking MultiIndex objects @@ -1062,8 +1040,7 @@ def unstack(self, **kwargs): with expressions.allow_non_parallel_operations(True): return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'unstack', - lambda s: pd.concat([proxy, s.unstack(**kwargs)]), + 'unstack', lambda s: pd.concat([proxy, s.unstack(**kwargs)]), [self._expr], proxy=proxy, requires_partition_by=partitionings.Singleton())) @@ -1080,8 +1057,7 @@ def xs(self, key, axis, level, **kwargs): # KeyError at construction time for missing columns. return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'xs', - lambda df: df.xs(key, axis=axis, **kwargs), [self._expr], + 'xs', lambda df: df.xs(key, axis=axis, **kwargs), [self._expr], requires_partition_by=partitionings.Arbitrary(), preserves_partition_by=partitionings.Arbitrary())) elif axis not in ('index', 0): @@ -1239,6 +1215,7 @@ def pipe(self, func, *args, **kwargs): @populate_not_implemented(pd.Series) @frame_base.DeferredFrame._register_for(pd.Series) class DeferredSeries(DeferredDataFrameOrSeries): + def __repr__(self): return ( f'DeferredSeries(name={self.name!r}, dtype={self.dtype}, ' @@ -1251,6 +1228,7 @@ def name(self): @name.setter def name(self, value): + def fn(s): s = s.copy() s.name = value @@ -1266,16 +1244,14 @@ def fn(s): @frame_base.with_docs_from(pd.Series) def hasnans(self): has_nans = expressions.ComputedExpression( - 'hasnans', - lambda s: pd.Series(s.hasnans), [self._expr], + 'hasnans', lambda s: pd.Series(s.hasnans), [self._expr], requires_partition_by=partitionings.Arbitrary(), preserves_partition_by=partitionings.Singleton()) with expressions.allow_non_parallel_operations(): return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'combine_hasnans', - lambda s: s.any(), [has_nans], + 'combine_hasnans', lambda s: s.any(), [has_nans], requires_partition_by=partitionings.Singleton(), preserves_partition_by=partitionings.Singleton())) @@ -1312,8 +1288,7 @@ def __getitem__(self, key): expressions.ComputedExpression( # yapf: disable 'getitem', - lambda df, - indexer: df[indexer], + lambda df, indexer: df[indexer], [self._expr, key._expr], requires_partition_by=partitionings.Index(), preserves_partition_by=partitionings.Arbitrary())) @@ -1369,9 +1344,7 @@ def append(self, to_append, ignore_index, verify_integrity, **kwargs): return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'append', - lambda s, - to_append: s.append( + 'append', lambda s, to_append: s.append( to_append, verify_integrity=verify_integrity, **kwargs), [self._expr, to_append._expr], requires_partition_by=requires, @@ -1399,9 +1372,7 @@ def align(self, other, join, axis, level, method, **kwargs): # multiple return values. aligned = frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'align', - lambda x, - y: pd.concat([x, y], axis=1, join='inner'), + 'align', lambda x, y: pd.concat([x, y], axis=1, join='inner'), [self._expr, other._expr], requires_partition_by=partitionings.Index(), preserves_partition_by=partitionings.Arbitrary())) @@ -1488,8 +1459,7 @@ def compute_idx(s): with expressions.allow_non_parallel_operations(True): return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'idx_combine', - lambda s: func(s, **kwargs), [idx_func], + 'idx_combine', lambda s: func(s, **kwargs), [idx_func], requires_partition_by=partitionings.Singleton(), preserves_partition_by=partitionings.Singleton())) @@ -1514,8 +1484,7 @@ def explode(self, ignore_index): partitionings.Singleton() if ignore_index else partitionings.Index()) return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'explode', - lambda s: s.explode(ignore_index), [self._expr], + 'explode', lambda s: s.explode(ignore_index), [self._expr], preserves_partition_by=preserves, requires_partition_by=partitionings.Arbitrary())) @@ -1558,8 +1527,7 @@ def dot(self, other): if right_is_series: result = expressions.ComputedExpression( - 'extract', - lambda df: df[0], [sums], + 'extract', lambda df: df[0], [sums], requires_partition_by=partitionings.Singleton()) else: result = sums @@ -1592,8 +1560,7 @@ def quantile(self, q, **kwargs): return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'quantile', - lambda df: df.quantile(q=q, **kwargs), [self._expr], + 'quantile', lambda df: df.quantile(q=q, **kwargs), [self._expr], requires_partition_by=requires, preserves_partition_by=partitionings.Singleton())) @@ -1679,9 +1646,8 @@ def corr(self, other, method, min_periods): # and custom partitioning. return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'corr', - lambda df, - other: df.corr(other, method=method, min_periods=min_periods), + 'corr', lambda df, other: df.corr( + other, method=method, min_periods=min_periods), [self._expr, other._expr], requires_partition_by=partitionings.Singleton(reason=reason))) @@ -1896,8 +1862,7 @@ def combine_co_moments(data): def dropna(self, **kwargs): return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'dropna', - lambda df: df.dropna(**kwargs), [self._expr], + 'dropna', lambda df: df.dropna(**kwargs), [self._expr], preserves_partition_by=partitionings.Arbitrary(), requires_partition_by=partitionings.Arbitrary())) @@ -2006,8 +1971,8 @@ def aggregate(self, func, axis, *args, **kwargs): rows = [self.agg([f], *args, **kwargs) for f in func] return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'join_aggregate', - lambda *rows: pd.concat(rows), [row._expr for row in rows])) + 'join_aggregate', lambda *rows: pd.concat(rows), + [row._expr for row in rows])) else: # We're only handling a single column. It could be 'func' or ['func'], # which produce different results. 'func' produces a scalar, ['func'] @@ -2190,15 +2155,14 @@ def nlargest(self, keep, **kwargs): reason="order-sensitive") kwargs['keep'] = keep per_partition = expressions.ComputedExpression( - 'nlargest-per-partition', - lambda df: df.nlargest(**kwargs), [self._expr], + 'nlargest-per-partition', lambda df: df.nlargest(**kwargs), + [self._expr], preserves_partition_by=partitionings.Arbitrary(), requires_partition_by=partitionings.Arbitrary()) with expressions.allow_non_parallel_operations(True): return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'nlargest', - lambda df: df.nlargest(**kwargs), [per_partition], + 'nlargest', lambda df: df.nlargest(**kwargs), [per_partition], preserves_partition_by=partitionings.Arbitrary(), requires_partition_by=partitionings.Singleton())) @@ -2220,21 +2184,21 @@ def nsmallest(self, keep, **kwargs): reason="order-sensitive") kwargs['keep'] = keep per_partition = expressions.ComputedExpression( - 'nsmallest-per-partition', - lambda df: df.nsmallest(**kwargs), [self._expr], + 'nsmallest-per-partition', lambda df: df.nsmallest(**kwargs), + [self._expr], preserves_partition_by=partitionings.Arbitrary(), requires_partition_by=partitionings.Arbitrary()) with expressions.allow_non_parallel_operations(True): return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'nsmallest', - lambda df: df.nsmallest(**kwargs), [per_partition], + 'nsmallest', lambda df: df.nsmallest(**kwargs), [per_partition], preserves_partition_by=partitionings.Arbitrary(), requires_partition_by=partitionings.Singleton())) @property # type: ignore @frame_base.with_docs_from(pd.Series) def is_unique(self): + def set_index(s): s = s[:] s.index = s @@ -2247,16 +2211,14 @@ def set_index(s): preserves_partition_by=partitionings.Singleton()) is_unique_distributed = expressions.ComputedExpression( - 'is_unique_distributed', - lambda s: pd.Series(s.is_unique), [self_index], + 'is_unique_distributed', lambda s: pd.Series(s.is_unique), [self_index], requires_partition_by=partitionings.Index(), preserves_partition_by=partitionings.Singleton()) with expressions.allow_non_parallel_operations(): return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'combine', - lambda s: s.all(), [is_unique_distributed], + 'combine', lambda s: s.all(), [is_unique_distributed], requires_partition_by=partitionings.Singleton(), preserves_partition_by=partitionings.Singleton())) @@ -2293,8 +2255,7 @@ def unique(self, as_series=False): reason="non-deferred-result") return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'unique', - lambda df: pd.Series(df.unique()), [self._expr], + 'unique', lambda df: pd.Series(df.unique()), [self._expr], preserves_partition_by=partitionings.Singleton(), requires_partition_by=partitionings.Singleton( reason="unique() cannot currently be parallelized."))) @@ -2302,9 +2263,8 @@ def unique(self, as_series=False): @frame_base.with_docs_from(pd.Series) def update(self, other): self._expr = expressions.ComputedExpression( - 'update', - lambda df, - other: df.update(other) or df, [self._expr, other._expr], + 'update', lambda df, other: df.update(other) or df, + [self._expr, other._expr], preserves_partition_by=partitionings.Arbitrary(), requires_partition_by=partitionings.Index()) @@ -2424,16 +2384,14 @@ def repeat(self, repeats, axis): if isinstance(repeats, int): return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( - 'repeat', - lambda series: series.repeat(repeats), [self._expr], + 'repeat', lambda series: series.repeat(repeats), [self._expr], requires_partition_by=partitionings.Arbitrary(), preserves_partition_by=partitionings.Arbitrary())) elif isinstance(repeats, frame_base.DeferredBase): return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( 'repeat', - lambda series, - repeats_series: series.repeat(repeats_series), + lambda series, repeats_series: series.repeat(repeats_series), [self._expr, repeats._expr], requires_partition_by=partitionings.Index(), preserves_partition_by=partitionings.Arbitrary())) @@ -2467,8 +2425,7 @@ def compare(self, other, align_axis, **kwargs): return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( 'compare', - lambda s, - other: s.compare(other, align_axis, **kwargs), + lambda s, other: s.compare(other, align_axis, **kwargs), [self._expr, other._expr], requires_partition_by=partitionings.Index(), preserves_partition_by=preserves_partition)) @@ -2477,6 +2434,7 @@ def compare(self, other, align_axis, **kwargs): @populate_not_implemented(pd.DataFrame) @frame_base.DeferredFrame._register_for(pd.DataFrame) class DeferredDataFrame(DeferredDataFrameOrSeries): + def __repr__(self): return ( f'DeferredDataFrame(columns={list(self.columns)}, ' @@ -2489,6 +2447,7 @@ def columns(self): @columns.setter def columns(self, columns): + def set_columns(df): df = df.copy() df.columns = columns diff --git a/sdks/python/apache_beam/dataframe/frames_test.py b/sdks/python/apache_beam/dataframe/frames_test.py index f99b77e446a8..d739152f6fd4 100644 --- a/sdks/python/apache_beam/dataframe/frames_test.py +++ b/sdks/python/apache_beam/dataframe/frames_test.py @@ -61,6 +61,7 @@ def _get_deferred_args(*args): class _AbstractFrameTest(unittest.TestCase): """Test sub-class with utilities for verifying DataFrame operations.""" + def _run_error_test( self, func, *args, construction_time=True, distributed=True): """Verify that func(*args) raises the same exception in pandas and in Beam. @@ -119,6 +120,7 @@ def _run_inplace_test(self, func, arg, **kwargs): Checks that func performs the same inplace operation on arg, in pandas and in Beam.""" + def wrapper(df): df = df.copy() func(df) @@ -273,6 +275,7 @@ def _run_test( class DeferredFrameTest(_AbstractFrameTest): """Miscellaneous tessts for DataFrame operations.""" + def test_series_arithmetic(self): a = pd.Series([1, 2, 3]) b = pd.Series([100, 200, 300]) @@ -345,6 +348,7 @@ def test_dataframe_xs(self): lambda df: df.xs('state'), df.set_index(['provider', 'time'])) def test_set_column(self): + def new_column(df): df['NewCol'] = df['Speed'] @@ -355,6 +359,7 @@ def new_column(df): self._run_inplace_test(new_column, df) def test_set_column_from_index(self): + def new_column(df): df['NewCol'] = df.index @@ -390,8 +395,7 @@ def test_tz_localize_ambiguous_series(self): ambiguous = pd.Series([True, True, False], index=s.index) self._run_test( - lambda s, - ambiguous: s.tz_localize('CET', ambiguous=ambiguous), + lambda s, ambiguous: s.tz_localize('CET', ambiguous=ambiguous), s, ambiguous) @@ -444,8 +448,7 @@ def test_combine_dataframe(self): df2 = pd.DataFrame({'A': [1, 1], 'B': [3, 3]}) take_smaller = lambda s1, s2: s1 if s1.sum() < s2.sum() else s2 self._run_test( - lambda df, - df2: df.combine(df2, take_smaller), + lambda df, df2: df.combine(df2, take_smaller), df, df2, nonparallel=True) @@ -455,8 +458,7 @@ def test_combine_dataframe_fill(self): df2 = pd.DataFrame({'A': [1, 1], 'B': [3, 3]}) take_smaller = lambda s1, s2: s1 if s1.sum() < s2.sum() else s2 self._run_test( - lambda df1, - df2: df1.combine(df2, take_smaller, fill_value=-5), + lambda df1, df2: df1.combine(df2, take_smaller, fill_value=-5), df1, df2, nonparallel=True) @@ -465,8 +467,7 @@ def test_combine_Series(self): s1 = pd.Series({'falcon': 330.0, 'eagle': 160.0}) s2 = pd.Series({'falcon': 345.0, 'eagle': 200.0, 'duck': 30.0}) self._run_test( - lambda s1, - s2: s1.combine(s2, max), + lambda s1, s2: s1.combine(s2, max), s1, s2, nonparallel=True, @@ -576,16 +577,14 @@ def test_merge(self): 'rkey': ['foo', 'bar', 'baz', 'foo'], 'value': [5, 6, 7, 8] }) self._run_test( - lambda df1, - df2: df1.merge(df2, left_on='lkey', right_on='rkey').rename( + lambda df1, df2: df1.merge(df2, left_on='lkey', right_on='rkey').rename( index=lambda x: '*'), df1, df2, nonparallel=True, check_proxy=False) self._run_test( - lambda df1, - df2: df1.merge( + lambda df1, df2: df1.merge( df2, left_on='lkey', right_on='rkey', suffixes=('_left', '_right')). rename(index=lambda x: '*'), df1, @@ -600,8 +599,8 @@ def test_merge_left_join(self): df2 = pd.DataFrame({'a': ['foo', 'baz'], 'c': [3, 4]}) self._run_test( - lambda df1, - df2: df1.merge(df2, how='left', on='a').rename(index=lambda x: '*'), + lambda df1, df2: df1.merge(df2, how='left', on='a').rename( + index=lambda x: '*'), df1, df2, nonparallel=True, @@ -618,8 +617,7 @@ def test_merge_on_index(self): }).set_index('rkey') self._run_test( - lambda df1, - df2: df1.merge(df2, left_index=True, right_index=True), + lambda df1, df2: df1.merge(df2, left_index=True, right_index=True), df1, df2, check_proxy=False) @@ -632,16 +630,14 @@ def test_merge_same_key(self): 'key': ['foo', 'bar', 'baz', 'foo'], 'value': [5, 6, 7, 8] }) self._run_test( - lambda df1, - df2: df1.merge(df2, on='key').rename(index=lambda x: '*'), + lambda df1, df2: df1.merge(df2, on='key').rename(index=lambda x: '*'), df1, df2, nonparallel=True, check_proxy=False) self._run_test( - lambda df1, - df2: df1.merge(df2, on='key', suffixes=('_left', '_right')).rename( - index=lambda x: '*'), + lambda df1, df2: df1.merge(df2, on='key', suffixes=('_left', '_right')). + rename(index=lambda x: '*'), df1, df2, nonparallel=True, @@ -652,16 +648,15 @@ def test_merge_same_key_doctest(self): df2 = pd.DataFrame({'a': ['foo', 'baz'], 'c': [3, 4]}) self._run_test( - lambda df1, - df2: df1.merge(df2, how='left', on='a').rename(index=lambda x: '*'), + lambda df1, df2: df1.merge(df2, how='left', on='a').rename( + index=lambda x: '*'), df1, df2, nonparallel=True, check_proxy=False) # Test without specifying 'on' self._run_test( - lambda df1, - df2: df1.merge(df2, how='left').rename(index=lambda x: '*'), + lambda df1, df2: df1.merge(df2, how='left').rename(index=lambda x: '*'), df1, df2, nonparallel=True, @@ -672,8 +667,7 @@ def test_merge_same_key_suffix_collision(self): df2 = pd.DataFrame({'a': ['foo', 'baz'], 'c': [3, 4], 'a_rsuffix': [7, 8]}) self._run_test( - lambda df1, - df2: df1.merge( + lambda df1, df2: df1.merge( df2, how='left', on='a', suffixes=('_lsuffix', '_rsuffix')).rename( index=lambda x: '*'), df1, @@ -682,9 +676,9 @@ def test_merge_same_key_suffix_collision(self): check_proxy=False) # Test without specifying 'on' self._run_test( - lambda df1, - df2: df1.merge(df2, how='left', suffixes=('_lsuffix', '_rsuffix')). - rename(index=lambda x: '*'), + lambda df1, df2: df1.merge( + df2, how='left', suffixes=('_lsuffix', '_rsuffix')).rename( + index=lambda x: '*'), df1, df2, nonparallel=True, @@ -731,10 +725,8 @@ def test_value_counts_with_nans(self): for normalize in (True, False): for dropna in (True, False): self._run_test( - lambda df, - dropna=dropna, - normalize=normalize: df.num_wings.value_counts( - dropna=dropna, normalize=normalize), + lambda df, dropna=dropna, normalize=normalize: df.num_wings. + value_counts(dropna=dropna, normalize=normalize), df) def test_value_counts_does_not_support_sort(self): @@ -962,11 +954,8 @@ def test_dataframe_melt(self): df) self._run_test( lambda df: df.melt( - id_vars=['A'], - value_vars=['B'], - var_name='myVarname', - value_name='myValname', - ignore_index=False), + id_vars=['A'], value_vars=['B'], var_name='myVarname', value_name= + 'myValname', ignore_index=False), df) self._run_test( lambda df: df.melt( @@ -1041,14 +1030,12 @@ def test_append_verify_integrity(self): df2 = pd.DataFrame({'A': range(10), 'B': range(10)}, index=range(9, 19)) self._run_error_test( - lambda s1, - s2: s1.append(s2, verify_integrity=True), + lambda s1, s2: s1.append(s2, verify_integrity=True), df1['A'], df2['A'], construction_time=False) self._run_error_test( - lambda df1, - df2: df1.append(df2, verify_integrity=True), + lambda df1, df2: df1.append(df2, verify_integrity=True), df1, df2, construction_time=False) @@ -1144,12 +1131,12 @@ def test_drop_duplicates(self): ( lambda base: base.from_dict({ 'row_1': [3, 2, 1, 0], 'row_2': ['a', 'b', 'c', 'd'] - }, - orient='index'), ), + }, orient='index'), ), ( lambda base: base.from_records( - np.array([(3, 'a'), (2, 'b'), (1, 'c'), (0, 'd')], - dtype=[('col_1', 'i4'), ('col_2', 'U1')])), ), + np.array([(3, 'a'), (2, 'b'), (1, 'c'), + (0, 'd')], dtype=[('col_1', 'i4'), + ('col_2', 'U1')])), ), ]) def test_create_methods(self, func): expected = func(pd.DataFrame) @@ -1242,8 +1229,7 @@ def test_dt_tz_localize_ambiguous_series(self): ambiguous = pd.Series([True, True, False], index=s.index) self._run_test( - lambda s, - ambiguous: s.dt.tz_localize('CET', ambiguous=ambiguous), + lambda s, ambiguous: s.dt.tz_localize('CET', ambiguous=ambiguous), s, ambiguous) @@ -1297,25 +1283,22 @@ def test_compare_dataframe(self): self._run_test(lambda df1, df2: df1.compare(df2), df1, df2) self._run_test( - lambda df1, - df2: df1.compare(df2, align_axis=0), + lambda df1, df2: df1.compare(df2, align_axis=0), df1, df2, check_proxy=False) self._run_test(lambda df1, df2: df1.compare(df2, keep_shape=True), df1, df2) self._run_test( - lambda df1, - df2: df1.compare(df2, align_axis=0, keep_shape=True), + lambda df1, df2: df1.compare(df2, align_axis=0, keep_shape=True), df1, df2) self._run_test( - lambda df1, - df2: df1.compare(df2, keep_shape=True, keep_equal=True), + lambda df1, df2: df1.compare(df2, keep_shape=True, keep_equal=True), df1, df2) self._run_test( - lambda df1, - df2: df1.compare(df2, align_axis=0, keep_shape=True, keep_equal=True), + lambda df1, df2: df1.compare( + df2, align_axis=0, keep_shape=True, keep_equal=True), df1, df2) @@ -1398,6 +1381,7 @@ def test_idxmax(self): self._run_test(lambda s2: s2.idxmax(skipna=False), s2) def test_pipe(self): + def df_times(df, column, times): df[column] = df[column] * times return df @@ -1720,6 +1704,7 @@ def numeric_only_kwargs_for_pandas_2(agg_type: str) -> dict[str, bool]: class GroupByTest(_AbstractFrameTest): """Tests for DataFrame/Series GroupBy operations.""" + @staticmethod def median_sum_fn(x): with warnings.catch_warnings(): @@ -1956,8 +1941,7 @@ def test_groupby_aggregate_grouped_column(self): [2, 1], ['foo', 0], [1, 'str'], - [3, 0, 2, 1], - ]) + [3, 0, 2, 1], ]) def test_groupby_level_agg(self, level): df = GROUPBY_DF.set_index(['group', 'foo', 'bar', 'str'], drop=False) self._run_test(lambda df: df.groupby(level=level).bar.max(), df) @@ -2212,8 +2196,8 @@ def test_dataframe_agg_level(self): def test_series_agg_multifunc_level(self): # level= is ignored for multiple agg fns self._run_test( - lambda df: df.set_index(['group', 'foo']).bar.agg(['min', 'max'], - level=0), + lambda df: df.set_index(['group', 'foo']).bar.agg(['min', 'max'], level= + 0), GROUPBY_DF) def test_series_mean_skipna(self): @@ -2408,6 +2392,7 @@ class BeamSpecificTest(unittest.TestCase): """Tests for functionality that's specific to the Beam DataFrame API. These features don't exist in pandas so we must verify them independently.""" + def assert_frame_data_equivalent( self, actual, expected, check_column_subset=False, extra_col_value=0): """Verify that actual is the same as expected, ignoring the index and order @@ -2807,8 +2792,7 @@ def test_sample_with_weights_distribution(self): self.assertTrue(target_weight > other_weight * 10, "weights too close") result = self._evaluate( - lambda s, - weights: s.sample(n=num_samples, weights=weights).sum(), + lambda s, weights: s.sample(n=num_samples, weights=weights).sum(), # The first elements are 1, the rest are all 0. This means that when # we sum all the sampled elements (above), the result should be the # number of times the first elements (aka targets) were sampled. @@ -2941,6 +2925,7 @@ def test_astype_categorical_rejected(self): class AllowNonParallelTest(unittest.TestCase): + def _use_non_parallel_operation(self): _ = frame_base.DeferredFrame.wrap( expressions.PlaceholderExpression(pd.Series([1, 2, 3]))).replace( @@ -3036,6 +3021,7 @@ def test_datetime_tz(self): class DocstringTest(unittest.TestCase): + @parameterized.expand([ (frames.DeferredDataFrame, pd.DataFrame), (frames.DeferredSeries, pd.Series), @@ -3073,6 +3059,7 @@ def test_docs_defined(self, beam_type, pd_type): class ReprTest(unittest.TestCase): + def test_basic_dataframe(self): df = frame_base.DeferredFrame.wrap( expressions.ConstantExpression(GROUPBY_DF)) @@ -3188,6 +3175,7 @@ def test_series_with_named_multi_index(self): '[interactive] dependency is not installed.') @isolated_env class InteractiveDataFrameTest(unittest.TestCase): + def test_collect_merged_dataframes(self): p = beam.Pipeline(InteractiveRunner()) pcoll_1 = ( diff --git a/sdks/python/apache_beam/dataframe/io.py b/sdks/python/apache_beam/dataframe/io.py index 5fcb7326a026..d7a5e650b59e 100644 --- a/sdks/python/apache_beam/dataframe/io.py +++ b/sdks/python/apache_beam/dataframe/io.py @@ -171,8 +171,7 @@ def to_json(df, path, orient=None, *args, **kwargs): @frame_base.with_docs_from(pd) def read_html(path, *args, **kwargs): return _ReadFromPandas( - lambda *args, - **kwargs: pd.read_html(*args, **kwargs)[0], + lambda *args, **kwargs: pd.read_html(*args, **kwargs)[0], path, args, kwargs) @@ -193,8 +192,8 @@ def to_html(df, path, *args, **kwargs): def _binary_reader(format): func = getattr(pd, 'read_%s' % format) - result = lambda path, *args, **kwargs: _ReadFromPandas(func, path, args, - kwargs) + result = lambda path, *args, **kwargs: _ReadFromPandas( + func, path, args, kwargs) result.__name__ = f'read_{format}' return result @@ -202,10 +201,8 @@ def _binary_reader(format): def _binary_writer(format): result = ( - lambda df, - path, - *args, - **kwargs: _as_pc(df) | _WriteToPandas(f'to_{format}', path, args, kwargs)) + lambda df, path, *args, **kwargs: _as_pc(df) | _WriteToPandas( + f'to_{format}', path, args, kwargs)) result.__name__ = f'to_{format}' return result @@ -249,6 +246,7 @@ def _shift_range_index(offset, df): class _ReadFromPandas(beam.PTransform): + def __init__( self, reader, @@ -294,9 +292,10 @@ def expand(self, root): matches_pcoll.pipeline | 'DoOnce' >> beam.Create([None]) | beam.Map( - lambda _, - paths: {path: ix - for ix, path in enumerate(sorted(paths))}, + lambda _, paths: { + path: ix + for ix, path in enumerate(sorted(paths)) + }, paths=beam.pvalue.AsList( matches_pcoll | beam.Map(lambda match: match.path)))) @@ -318,6 +317,7 @@ def expand(self, root): class _Splitter: + def empty_buffer(self): """Returns an empty buffer of the right type (string or bytes). """ @@ -351,6 +351,7 @@ class _DelimSplitter(_Splitter): This delimiter is assumed ot never occur within a record. """ + def __init__(self, delim, read_chunk_size=_DEFAULT_BYTES_CHUNKSIZE): # Multi-char delimiters would require more care across chunk boundaries. assert len(delim) == 1 @@ -393,6 +394,7 @@ class _TextFileSplitter(_DelimSplitter): Currently does not handle quoted newlines, so is off by default, but such support could be added in the future. """ + def __init__(self, args, kwargs, read_chunk_size=_DEFAULT_BYTES_CHUNKSIZE): if args: # TODO(robertwb): Automatically populate kwargs as we do for df methods. @@ -470,6 +472,7 @@ class _TruncatingFileHandle(object): As with all SDF trackers, the endpoint may change dynamically during reading. """ + def __init__(self, underlying, tracker, splitter): self._underlying = underlying self._tracker = tracker @@ -582,6 +585,7 @@ def flush(self): class _ReadFromPandasDoFn(beam.DoFn, beam.RestrictionProvider): + def __init__(self, reader, args, kwargs, binary, incremental, splitter): # avoid pickling issues if reader.__module__.startswith('pandas.'): @@ -653,6 +657,7 @@ def process( class _WriteToPandas(beam.PTransform): + def __init__( self, writer, path, args, kwargs, incremental=False, binary=True): self.writer = writer @@ -677,6 +682,7 @@ def expand(self, pcoll): class _WriteToPandasFileSink(fileio.FileSink): + def __init__(self, writer, args, kwargs, incremental, binary): if 'compression' in kwargs: raise NotImplementedError('compression') @@ -765,6 +771,7 @@ def flush_buffer(self): class ReadViaPandas(beam.PTransform): + def __init__( self, format, @@ -787,6 +794,7 @@ def expand(self, p): class WriteViaPandas(beam.PTransform): + def __init__(self, format, *args, **kwargs): self._writer_func = globals()['to_%s' % format] self._args = args @@ -830,6 +838,7 @@ class _ReadGbq(beam.PTransform): unspecified or set to false, the default is currently utilized (EXPORT). If the flag is set to true, 'DIRECT_READ' will be utilized.""" + def __init__( self, table=None, diff --git a/sdks/python/apache_beam/dataframe/io_it_test.py b/sdks/python/apache_beam/dataframe/io_it_test.py index 9f750e2ff58c..3e4a9a037f93 100644 --- a/sdks/python/apache_beam/dataframe/io_it_test.py +++ b/sdks/python/apache_beam/dataframe/io_it_test.py @@ -40,6 +40,7 @@ @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') class ReadUsingReadGbqTests(unittest.TestCase): + @pytest.mark.it_postcommit def test_ReadGbq(self): from apache_beam.dataframe import convert diff --git a/sdks/python/apache_beam/dataframe/io_test.py b/sdks/python/apache_beam/dataframe/io_test.py index 92bb10225c78..17784589c661 100644 --- a/sdks/python/apache_beam/dataframe/io_test.py +++ b/sdks/python/apache_beam/dataframe/io_test.py @@ -69,6 +69,7 @@ class MyRow(typing.NamedTuple): platform.system() == 'Windows', 'https://github.com/apache/beam/issues/20642') class IOTest(unittest.TestCase): + def setUp(self): self._temp_roots = [] @@ -200,6 +201,7 @@ def _run_read_write_test( big['text'] = big.number.map(lambda n: 'f' + 'o' * n) def frame_equal_to(expected_, check_index=True, check_names=True): + def check(actual): expected = expected_ try: @@ -312,6 +314,7 @@ def test_truncating_filehandle_iter(self): ('skiprows', dict(skiprows=[0, 1], header=[0, 1], comment='X')), ]) def test_csv_splitter(self, name, kwargs): + def assert_frame_equal(expected, actual): try: pandas.testing.assert_frame_equal(expected, actual) @@ -442,6 +445,7 @@ def test_double_write(self): @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') class ReadGbqTransformTests(unittest.TestCase): + @mock.patch.object(BigQueryWrapper, 'get_table') def test_bad_schema_public_api_direct_read(self, get_table): try: @@ -466,6 +470,7 @@ def test_bad_schema_public_api_direct_read(self, get_table): table="dataset.sample_table", use_bqstorage_api=True) def test_unsupported_callable(self): + def filterTable(table): if table is not None: return table diff --git a/sdks/python/apache_beam/dataframe/pandas_doctests_test.py b/sdks/python/apache_beam/dataframe/pandas_doctests_test.py index c7ea908a9336..70e8693ab175 100644 --- a/sdks/python/apache_beam/dataframe/pandas_doctests_test.py +++ b/sdks/python/apache_beam/dataframe/pandas_doctests_test.py @@ -27,6 +27,7 @@ @unittest.skipIf( sys.platform == 'win32', '[https://github.com/apache/beam/issues/20361]') class DoctestTest(unittest.TestCase): + def test_ndframe_tests(self): # IO methods are tested in io_test.py skip_writes = { @@ -70,8 +71,7 @@ def test_ndframe_tests(self): "df.loc['2016-01-05':'2016-01-10', :].tail()" ], 'pandas.core.generic.NDFrame.replace': [ - "s.replace([1, 2], method='bfill')", - # Relies on method='pad' + "s.replace([1, 2], method='bfill')", # Relies on method='pad' "s.replace('a')", # Relies on method='pad' # value=None is not valid for pandas < 1.4 @@ -142,8 +142,7 @@ def test_ndframe_tests(self): # some kind, 2 was passed # pandas doctests only verify the type of exception 'df.rename(2)' - ], - # For pandas >= 1.4, rename is changed to _rename + ], # For pandas >= 1.4, rename is changed to _rename 'pandas.core.generic.NDFrame._rename': [ # Seems to be an upstream bug. The actual error has a different # message: @@ -151,8 +150,7 @@ def test_ndframe_tests(self): # some kind, 2 was passed # pandas doctests only verify the type of exception 'df.rename(2)' - ], - # Tests rely on setting index + ], # Tests rely on setting index 'pandas.core.generic.NDFrame.rename_axis': ['*'], # Raises right exception, but testing framework has matching issues. 'pandas.core.generic.NDFrame.replace': [ @@ -160,7 +158,7 @@ def test_ndframe_tests(self): ], 'pandas.core.generic.NDFrame.squeeze': ['*'], - # NameError + # NameError 'pandas.core.generic.NDFrame.resample': ['df'], # Skipped so we don't need to install natsort @@ -211,8 +209,7 @@ def test_dataframe_tests(self): "df.nsmallest(3, 'population', keep='last')", ], 'pandas.core.frame.DataFrame.replace': [ - "s.replace([1, 2], method='bfill')", - # Relies on method='pad' + "s.replace([1, 2], method='bfill')", # Relies on method='pad' "s.replace('a')", # Relies on method='pad' # value=None is not valid for pandas < 1.4 @@ -256,8 +253,7 @@ def test_dataframe_tests(self): "df.melt(id_vars=[('A', 'D')], value_vars=[('B', 'E')])", "df.melt(id_vars=['A'], value_vars=['B'],\n" + " var_name='myVarname', value_name='myValname')" - ], - # Most keep= options are order-sensitive + ], # Most keep= options are order-sensitive 'pandas.core.frame.DataFrame.drop_duplicates': ['*'], 'pandas.core.frame.DataFrame.duplicated': [ 'df.duplicated()', @@ -294,20 +290,18 @@ def test_dataframe_tests(self): "df1.merge(df2, how='cross')" ], - # TODO(https://github.com/apache/beam/issues/20759) + # TODO(https://github.com/apache/beam/issues/20759) 'pandas.core.frame.DataFrame.set_index': [ "df.set_index([s, s**2])", ], - 'pandas.core.frame.DataFrame.set_axis': [ "df.set_axis(range(0,2), axis='index')", ], - # TODO(https://github.com/apache/beam/issues/21014) + # TODO(https://github.com/apache/beam/issues/21014) 'pandas.core.frame.DataFrame.value_counts': [ - 'df.value_counts(dropna=False)' + 'df.value_counts(dropna=False)' ], - 'pandas.core.frame.DataFrame.to_timestamp': ['*'] }, skip={ @@ -315,14 +309,12 @@ def test_dataframe_tests(self): '*': [ # mul doesn't work in Beam with axis='index'. "df.mul({'circle': 0, 'triangle': 2, 'rectangle': 3}, " - "axis='index')", - # eq doesn't work with axis='index'. + "axis='index')", # eq doesn't work with axis='index'. "df.eq([250, 250, 100], axis='index')", # New test in Pandas 2.1 that uses indexes. 'df != pd.Series([100, 250], index=["cost", "revenue"])', # New test in Pandas 2.1 that uses indexes. 'df.le(df_multindex, level=1)' - ], # DeferredDataFrame doesn't implement the DF interchange protocol. 'pandas.core.frame.DataFrame.__dataframe__': ['*'], @@ -335,20 +327,17 @@ def test_dataframe_tests(self): 'df', 'df2 = pd.DataFrame(data=df1, index=["a", "c"])', 'df2', - ], - # s2 created with reindex + ], # s2 created with reindex 'pandas.core.frame.DataFrame.dot': [ 'df.dot(s2)', ], - 'pandas.core.frame.DataFrame.resample': ['df'], 'pandas.core.frame.DataFrame.asfreq': ['*'], # Throws NotImplementedError when modifying df 'pandas.core.frame.DataFrame.axes': [ # Returns deferred index. 'df.axes', - ], - # Skipped because the relies on loc to set cells in df2 + ], # Skipped because the relies on loc to set cells in df2 'pandas.core.frame.DataFrame.compare': ['*'], 'pandas.core.frame.DataFrame.cov': [ # Relies on setting entries ahead of time. @@ -371,8 +360,7 @@ def test_dataframe_tests(self): # This should pass as set_axis(axis='columns') # and fail with set_axis(axis='index') "df.set_axis(['a', 'b', 'c'], axis='index')" - ], - # Beam's implementation takes a filepath as an argument. + ], # Beam's implementation takes a filepath as an argument. 'pandas.core.frame.DataFrame.to_html': ['*'], 'pandas.core.frame.DataFrame.to_markdown': ['*'], 'pandas.core.frame.DataFrame.to_parquet': ['*'], @@ -384,11 +372,10 @@ def test_dataframe_tests(self): 'df.insert(1, "newcol", [99, 99])', 'df.insert(0, "col1", [100, 100], allow_duplicates=True)' ], - 'pandas.core.frame.DataFrame.to_records': [ 'df.index = df.index.rename("I")', - 'index_dtypes = f" 100).mean()', ], - 'pandas.core.series.Series.asfreq': ['*'], - # error formatting + 'pandas.core.series.Series.asfreq': ['*'], # error formatting 'pandas.core.series.Series.append': [ 's1.append(s2, verify_integrity=True)', ], 'pandas.core.series.Series.cov': [ # Differs in LSB on jenkins. "s1.cov(s2)", - ], - # Test framework doesn't materialze DeferredIndex. + ], # Test framework doesn't materialze DeferredIndex. 'pandas.core.series.Series.keys': ['s.keys()'], # Skipped idxmax/idxmin due an issue with the test framework 'pandas.core.series.Series.idxmin': ['s.idxmin()'], @@ -620,14 +600,12 @@ def test_series_tests(self): # Fails when result is a singleton: # https://github.com/apache/beam/issues/28559 'pandas.core.series.Series.kurt': [ - 'df.kurt(axis=None).round(6)', - 's.kurt()' + 'df.kurt(axis=None).round(6)', 's.kurt()' ], # Fails when result is a singleton: # https://github.com/apache/beam/issues/28559 'pandas.core.series.Series.sem': [ - 'df.sem().round(6)', - 's.sem().round(6)' + 'df.sem().round(6)', 's.sem().round(6)' ], }) self.assertEqual(result.failed, 0) @@ -675,13 +653,13 @@ def test_string_tests(self): "pd.Series(['foo', 'fuz', np.nan]).str.replace('f', repr)" ], - # output has incorrect formatting in 1.2.x + # output has incorrect formatting in 1.2.x f'{module_name}.StringMethods.extractall': ['*'], # For split and rsplit, if expand=True, then the series # must be of CategoricalDtype, which pandas doesn't convert to f'{module_name}.StringMethods.rsplit': [ - 's.str.split(r"\\+|=", expand=True)', # for pandas<1.4 + 's.str.split(r"\\+|=", expand=True)', # for pandas<1.4 's.str.split(expand=True)', 's.str.rsplit("/", n=1, expand=True)', 's.str.split(r"and|plus", expand=True)', @@ -692,7 +670,7 @@ def test_string_tests(self): 's.str.split(r"\\.jpg", regex=False, expand=True)' ], f'{module_name}.StringMethods.split': [ - 's.str.split(r"\\+|=", expand=True)', # for pandas<1.4 + 's.str.split(r"\\+|=", expand=True)', # for pandas<1.4 's.str.split(expand=True)', 's.str.rsplit("/", n=1, expand=True)', 's.str.split(r"and|plus", expand=True)', @@ -741,16 +719,16 @@ def test_datetime_tests(self): ], 'pandas.core.indexes.accessors.TimedeltaProperties.to_pytimedelta': [ '*' - ], - # pylint: enable=line-too-long - # Test uses to_datetime. Beam calls to_datetime element-wise, and - # therefore the .tz attribute is not evaluated on entire Series. - # Hence, .tz becomes None, unless explicitly set. - # See: see test_tz_with_utc_zone_set_explicitly + ], # pylint: enable=line-too-long + # Test uses to_datetime. Beam calls to_datetime element-wise, and + # therefore the .tz attribute is not evaluated on entire Series. + # Hence, .tz becomes None, unless explicitly set. + # See: see test_tz_with_utc_zone_set_explicitly 'pandas.core.indexes.accessors.DatetimeProperties.tz': ['*'], }) datetimelike_result = doctests.testmod( - pd.core.arrays.datetimelike, use_beam=False, + pd.core.arrays.datetimelike, + use_beam=False, not_implemented_ok={ # Beam Dataframes don't implement a deferred to_timedelta operation. # Top-level issue: https://github.com/apache/beam/issues/20318 @@ -758,14 +736,12 @@ def test_datetime_tests(self): "ser = pd.Series(pd.to_timedelta([1, 2, 3], unit='d'))", "tdelta_idx = pd.to_timedelta([1, 2, 3], unit='D')", 'tdelta_idx = pd.to_timedelta(["0 days", "10 days", "20 days"])', # pylint: disable=line-too-long - "tdelta_idx", "tdelta_idx.inferred_freq", "tdelta_idx.mean()", ], }) - datetime_result = doctests.testmod( pd.core.arrays.datetimes, use_beam=False, @@ -782,7 +758,8 @@ def test_datetime_tests(self): '*': [ "ser = pd.Series(pd.to_timedelta([1, 2, 3], unit='d'))", "tdelta_idx = pd.to_timedelta([1, 2, 3], unit='D')", - 'tdelta_idx = pd.to_timedelta(["0 days", "10 days", "20 days"])'], # pylint: disable=line-too-long + 'tdelta_idx = pd.to_timedelta(["0 days", "10 days", "20 days"])' + ], # pylint: disable=line-too-long # Verifies index version of this method 'pandas.core.arrays.datetimes.DatetimeArray.to_period': [ 'df.index.to_period("M")' @@ -872,8 +849,7 @@ def test_groupby_tests(self): 'pandas.core.groupby.groupby.GroupBy.resample': [ 'df.iloc[2, 0] = 5', 'df', - ], - # df is reassigned + ], # df is reassigned 'pandas.core.groupby.groupby.GroupBy.rank': ['df'], # TODO: Raise wont implement for list passed as a grouping column # Currently raises unhashable type: list @@ -887,11 +863,10 @@ def test_groupby_tests(self): pd.core.groupby.generic, use_beam=False, wont_implement_ok={ - '*' : [ + '*': [ # resample is WontImpl. "ser.resample('MS').nunique()", - ], - # TODO: Is take actually deprecated? + ], # TODO: Is take actually deprecated? 'pandas.core.groupby.generic.DataFrameGroupBy.take': ['*'], 'pandas.core.groupby.generic.SeriesGroupBy.take': ['*'], 'pandas.core.groupby.generic.SeriesGroupBy.nsmallest': [ @@ -945,23 +920,24 @@ def test_groupby_tests(self): "df.loc[df.index[:5], 'a'] = np.nan", "df.loc[df.index[5:10], 'b'] = np.nan", "df.cov(min_periods=12)", - ], - # These examples rely on grouping by a list + ], # These examples rely on grouping by a list 'pandas.core.groupby.generic.SeriesGroupBy.aggregate': ['*'], 'pandas.core.groupby.generic.DataFrameGroupBy.aggregate': ['*'], # Skipped idxmax/idxmin due an issue with the test framework 'pandas.core.groupby.generic.SeriesGroupBy.idxmin': ['s.idxmin()'], 'pandas.core.groupby.generic.SeriesGroupBy.idxmax': ['s.idxmax()'], # Order-sensitive operations. TODO: Return a better error message. - 'pandas.core.groupby.generic.SeriesGroupBy.is_monotonic_increasing': ['*'], # pylint: disable=line-too-long - 'pandas.core.groupby.generic.SeriesGroupBy.is_monotonic_decreasing': ['*'], # pylint: disable=line-too-long + 'pandas.core.groupby.generic.SeriesGroupBy.is_monotonic_increasing': [ + '*' + ], # pylint: disable=line-too-long + 'pandas.core.groupby.generic.SeriesGroupBy.is_monotonic_decreasing': [ + '*' + ], # pylint: disable=line-too-long # Uses as_index, which is currently not_implemented 'pandas.core.groupby.generic.DataFrameGroupBy.value_counts': [ - "df.groupby('gender', as_index=False).value_counts()", - # pylint: disable=line-too-long + "df.groupby('gender', as_index=False).value_counts()", # pylint: disable=line-too-long "df.groupby('gender', as_index=False).value_counts(normalize=True)", - ], - # These examples rely on grouping by a list + ], # These examples rely on grouping by a list 'pandas.core.groupby.generic.SeriesGroupBy.fillna': ['*'], # These examples rely on grouping by a list 'pandas.core.groupby.generic.DataFrameGroupBy.fillna': ['*'], @@ -972,8 +948,7 @@ def test_groupby_tests(self): # Named aggregation not supported yet. 'pandas.core.groupby.generic.NamedAgg': [ 'df.groupby("key").agg(result_a=agg_a, result_1=agg_1)' - ], - # These examples rely on grouping by a list + ], # These examples rely on grouping by a list 'pandas.core.groupby.generic.DataFrameGroupBy.transform': ['*'], # These examples rely on grouping by a list 'pandas.core.groupby.generic.SeriesGroupBy.transform': ['*'], @@ -1024,7 +999,9 @@ def test_top_level(self): 'pivot': ['*'], 'to_datetime': ['s.head()'], 'to_pickle': ['*'], - 'unique': ['pd.unique(pd.Series([("a", "b"), ("b", "a"), ("a", "c"), ("b", "a")]).values)'], # pylint: disable=line-too-long + 'unique': [ + 'pd.unique(pd.Series([("a", "b"), ("b", "a"), ("a", "c"), ("b", "a")]).values)' + ], # pylint: disable=line-too-long 'melt': [ "pd.melt(df, id_vars=['A'], value_vars=['B'])", "pd.melt(df, id_vars=['A'], value_vars=['B', 'C'])", @@ -1039,34 +1016,27 @@ def test_top_level(self): 'concat': [ 'pd.concat([df5, df6], verify_integrity=True)', 'pd.concat([df7, new_row.to_frame().T], ignore_index=True)' - ], - # doctest DeprecationWarning - 'melt': ['df'], - # Order-sensitive re-indexing. + ], # doctest DeprecationWarning + 'melt': ['df'], # Order-sensitive re-indexing. 'merge': [ "df1.merge(df2, left_on='lkey', right_on='rkey')", "df1.merge(df2, left_on='lkey', right_on='rkey',\n" " suffixes=('_left', '_right'))", "df1.merge(df2, how='left', on='a')", - ], - # Not an actual test. + ], # Not an actual test. 'option_context': ['*'], 'factorize': ['codes', 'uniques'], # Bad top-level use of un-imported function. 'merge_ordered': [ 'merge_ordered(df1, df2, fill_method="ffill", left_by="group")' - ], - # Expected error. + ], # Expected error. 'pivot': [ "df.pivot(index='foo', columns='bar', values='baz')", "df.pivot(index='foo', columns='bar')['baz']", - "df.pivot(index='foo', columns='bar', values=['baz', 'zoo'])", - # pylint: disable=line-too-long - 'df.pivot(index="lev1", columns=["lev2", "lev3"],values="values")', - # pylint: disable=line-too-long + "df.pivot(index='foo', columns='bar', values=['baz', 'zoo'])", # pylint: disable=line-too-long + 'df.pivot(index="lev1", columns=["lev2", "lev3"],values="values")', # pylint: disable=line-too-long 'df.pivot(index=["lev1", "lev2"], columns=["lev3"],values="values")' - ], - # Never written. + ], # Never written. 'to_pickle': ['os.remove("./dummy.pkl")'], **skip_reads }) diff --git a/sdks/python/apache_beam/dataframe/pandas_top_level_functions.py b/sdks/python/apache_beam/dataframe/pandas_top_level_functions.py index a8139675ad39..061529f5aeaa 100644 --- a/sdks/python/apache_beam/dataframe/pandas_top_level_functions.py +++ b/sdks/python/apache_beam/dataframe/pandas_top_level_functions.py @@ -28,6 +28,7 @@ def _call_on_first_arg(name): + def wrapper(target, *args, **kwargs): if isinstance(target, frame_base.DeferredBase): return getattr(target, name)(*args, **kwargs) @@ -131,14 +132,9 @@ def concat( expressions.ComputedExpression( 'concat', lambda *objs: pd.concat( - objs, - axis=axis, - join=join, - ignore_index=ignore_index, - keys=keys, - levels=levels, - names=names, - verify_integrity=verify_integrity), # yapf break + objs, axis=axis, join=join, ignore_index=ignore_index, keys= + keys, levels=levels, names=names, verify_integrity= + verify_integrity), # yapf break exprs, requires_partition_by=required_partitioning, preserves_partition_by=preserves_partitioning)) diff --git a/sdks/python/apache_beam/dataframe/partitionings.py b/sdks/python/apache_beam/dataframe/partitionings.py index 1fe760fe8589..094099089283 100644 --- a/sdks/python/apache_beam/dataframe/partitionings.py +++ b/sdks/python/apache_beam/dataframe/partitionings.py @@ -28,6 +28,7 @@ class Partitioning(object): """A class representing a (consistent) partitioning of dataframe objects. """ + def __repr__(self): return self.__class__.__name__ @@ -71,6 +72,7 @@ class Index(Partitioning): The ordering is implemented via the is_subpartitioning_of method, where the examples on the right are subpartitionings of the examples on the left above. """ + def __init__(self, levels=None): self._levels = levels @@ -148,6 +150,7 @@ def apply_consistent_order(dfs): class Singleton(Partitioning): """A partitioning of all the data into a single partition. """ + def __init__(self, reason=None): self._reason = reason @@ -189,6 +192,7 @@ class JoinIndex(Partitioning): Expressions desiring to make use of this index should simply declare a requirement of JoinIndex(). """ + def __init__(self, ancestor=None): self._ancestor = ancestor @@ -229,6 +233,7 @@ def check(self, dfs): class Arbitrary(Partitioning): """A partitioning imposing no constraints on the actual partitioning. """ + def __eq__(self, other): return type(self) == type(other) diff --git a/sdks/python/apache_beam/dataframe/schemas.py b/sdks/python/apache_beam/dataframe/schemas.py index f849ab11e77c..c45cee5e04bb 100644 --- a/sdks/python/apache_beam/dataframe/schemas.py +++ b/sdks/python/apache_beam/dataframe/schemas.py @@ -61,6 +61,7 @@ class BatchRowsAsDataFrame(beam.PTransform): Batching parameters are inherited from :class:`~apache_beam.transforms.util.BatchElements`. """ + def __init__(self, *args, proxy=None, **kwargs): self._batch_elements_transform = BatchElements(*args, **kwargs) self._proxy = proxy @@ -204,6 +205,7 @@ class UnbatchPandas(beam.PTransform): levels are unnamed (name=None), or if any of the names are not unique among all column and index names. """ + def __init__(self, proxy, include_indexes=False): self._proxy = proxy self._include_indexes = include_indexes diff --git a/sdks/python/apache_beam/dataframe/schemas_test.py b/sdks/python/apache_beam/dataframe/schemas_test.py index 4c196e29e712..634443e4bf88 100644 --- a/sdks/python/apache_beam/dataframe/schemas_test.py +++ b/sdks/python/apache_beam/dataframe/schemas_test.py @@ -47,6 +47,7 @@ def matches_df(expected): + def check_df_pcoll_equal(actual): actual = pd.concat(actual) sorted_actual = actual.sort_values(by=list(actual.columns)).reset_index( @@ -157,6 +158,7 @@ def test_name_func(testcase_func, param_num, params): class SchemasTest(unittest.TestCase): + def test_simple_df(self): expected = pd.DataFrame({ 'name': list(str(i) for i in range(5)), @@ -265,6 +267,7 @@ def test_batch_with_df_transform(self): assert_that(res, equal_to([('Falcon', 375.), ('Parrot', 25.)])) def assert_typehints_equal(self, left, right): + def maybe_drop_rowtypeconstraint(typehint): if isinstance(typehint, row_type.RowTypeConstraint): return typehint.user_type diff --git a/sdks/python/apache_beam/dataframe/transforms.py b/sdks/python/apache_beam/dataframe/transforms.py index 59e5eec05d2f..354f18618f57 100644 --- a/sdks/python/apache_beam/dataframe/transforms.py +++ b/sdks/python/apache_beam/dataframe/transforms.py @@ -90,6 +90,7 @@ class DataframeTransform(transforms.PTransform): .. _schema-aware: https://beam.apache.org/documentation/programming-guide/#what-is-a-schema """ + def __init__( self, func, proxy=None, yield_elements="schemas", include_indexes=False): self._func = func @@ -138,6 +139,7 @@ def expand(self, input_pcolls): class _DataframeExpressionsTransform(transforms.PTransform): + def __init__(self, outputs): self._outputs = outputs @@ -163,9 +165,11 @@ def _apply_deferred_ops( Logically, `_apply_deferred_ops({x: a, y: b}, {f: F(x, y), g: G(x, y)})` returns `{f: F(a, b), g: G(a, b)}`. """ + class ComputeStage(beam.PTransform): """A helper transform that computes a single stage of operations. """ + def __init__(self, stage): self.stage = stage @@ -208,12 +212,13 @@ def expand(self, pcolls): | 'SumSizes' >> beam.CombineGlobally(sum) | 'NumPartitions' >> beam.Map( lambda size: max( - MIN_PARTITIONS, - min(MAX_PARTITIONS, size // TARGET_PARTITION_SIZE)))) + MIN_PARTITIONS, min( + MAX_PARTITIONS, size // TARGET_PARTITION_SIZE)))) partition_fn = self.stage.partitioning.partition_fn class Partition(beam.PTransform): + def expand(self, pcoll): return ( pcoll @@ -245,10 +250,11 @@ def expand(self, pcoll): # Actually evaluate the expressions. def evaluate(partition, stage=self.stage, **side_inputs): + def lookup(expr): # Use proxy if there's no data in this partition - return expr.proxy( - ).iloc[:0] if partition[expr._id] is None else partition[expr._id] + return expr.proxy().iloc[:0] if partition[ + expr._id] is None else partition[expr._id] session = expressions.Session( dict([(expr, lookup(expr)) for expr in tabular_inputs] + @@ -265,6 +271,7 @@ class Stage(object): Note that these Dataframe "stages" contain a CoGBK and hence are often split across multiple "executable" stages. """ + def __init__(self, inputs, partitioning): self.inputs = set(inputs) if (len(self.inputs) > 1 and @@ -301,6 +308,7 @@ def output_partitioning_in_stage(expr, stage): """Return the output partitioning of expr when computed in stage, or returns None if the expression cannot be computed in this stage. """ + def maybe_upgrade_to_join_index(partitioning): if partitioning.is_subpartitioning_of(partitionings.JoinIndex()): return partitionings.JoinIndex(expr) @@ -420,8 +428,10 @@ def expr_to_stage(expr): @_memoize def stage_to_result(stage): - return {expr._id: expr_to_pcoll(expr) - for expr in stage.inputs} | ComputeStage(stage) + return { + expr._id: expr_to_pcoll(expr) + for expr in stage.inputs + } | ComputeStage(stage) @_memoize def expr_to_pcoll(expr): @@ -484,6 +494,7 @@ def _total_memory_usage(frame): class _PreBatch(beam.DoFn): + def __init__( self, target_size=TARGET_PARTITION_SIZE, min_size=MIN_PARTITION_SIZE): self._target_size = target_size @@ -519,6 +530,7 @@ class _ReBatch(beam.DoFn): Also groups across partitions, up to a given data size, to recover some efficiency in the face of over-partitioning. """ + def __init__( self, target_size=TARGET_PARTITION_SIZE, min_size=MIN_PARTITION_SIZE): self._target_size = target_size diff --git a/sdks/python/apache_beam/dataframe/transforms_test.py b/sdks/python/apache_beam/dataframe/transforms_test.py index a143606cc913..25829a256faa 100644 --- a/sdks/python/apache_beam/dataframe/transforms_test.py +++ b/sdks/python/apache_beam/dataframe/transforms_test.py @@ -74,6 +74,7 @@ def df_equal_to(expected): class TransformTest(unittest.TestCase): + def run_scenario(self, input, func): expected = func(input) @@ -274,6 +275,7 @@ def test_input_output_polymorphism(self): proxy = one_series[:0] def equal_to_series(expected): + def check(actual): actual = pd.concat(actual) if not expected.equals(actual): @@ -302,8 +304,7 @@ def check(actual): assert_that( dict(x=one, y=two) | 'DictIn' >> transforms.DataframeTransform( - lambda x, - y: (x + y), + lambda x, y: (x + y), proxy=dict(x=proxy, y=proxy), yield_elements='pandas'), equal_to_series(three_series), @@ -348,14 +349,14 @@ def test_rename(self): with expressions.allow_non_parallel_operations(): self.run_scenario( - df, - lambda df: df.rename( + df, lambda df: df.rename( columns={'B': 'C'}, index={ 0: 2, 2: 0 }, errors='raise')) class FusionTest(unittest.TestCase): + @staticmethod def fused_stages(p): return p.result.monitoring_metrics().query( @@ -381,6 +382,7 @@ def test_loc_filter(self): self.assertEqual(self.fused_stages(p), 1) def test_column_manipulation(self): + def set_column(df, name, s): df[name] = s return df @@ -394,6 +396,7 @@ def set_column(df, name, s): class TransformPartsTest(unittest.TestCase): + def test_rebatch(self): with beam.Pipeline() as p: sA = pd.Series(range(1000)) diff --git a/sdks/python/apache_beam/examples/avro_nyc_trips.py b/sdks/python/apache_beam/examples/avro_nyc_trips.py index 23d25649dad5..0f7cfb301c8f 100644 --- a/sdks/python/apache_beam/examples/avro_nyc_trips.py +++ b/sdks/python/apache_beam/examples/avro_nyc_trips.py @@ -125,6 +125,7 @@ class CreateKeyWithServiceAndDay(beam.DoFn): company name and the day of the week of the ride. The value is the original record dictionary object. """ + def process(self, record: dict): options = { 'HV0002': 'Juno', 'HV0003': 'Uber', 'HV0004': 'Via', 'HV0005': 'Lyft' @@ -154,6 +155,7 @@ class CalculatePricePerAttribute(beam.CombineFn): hire vehicle service. And calculates the price per mile, minute, and trip for both the driver and passenger. """ + def create_accumulator(self): total_price = 0.0 total_driver_pay = 0.0 @@ -180,13 +182,9 @@ def add_input(self, accumulator, record): return ( total_price + sum( record[name] for name in ( - 'base_passenger_fare', - 'tolls', - 'bcf', - 'sales_tax', - 'congestion_surcharge', - 'airport_fee', - 'tips') if record[name] is not None), + 'base_passenger_fare', 'tolls', 'bcf', 'sales_tax', + 'congestion_surcharge', 'airport_fee', 'tips') + if record[name] is not None), total_driver_pay + record['driver_pay'] + record['tips'], total_trip_miles + record['trip_miles'], total_trip_time + record['trip_time'], diff --git a/sdks/python/apache_beam/examples/avro_nyc_trips_test.py b/sdks/python/apache_beam/examples/avro_nyc_trips_test.py index 0dd31962a2f4..67141f23c4ca 100644 --- a/sdks/python/apache_beam/examples/avro_nyc_trips_test.py +++ b/sdks/python/apache_beam/examples/avro_nyc_trips_test.py @@ -29,6 +29,7 @@ class AvroNycTripsTest(unittest.TestCase): + def test_create_key_with_service_and_day(self): RECORDS = [ { diff --git a/sdks/python/apache_beam/examples/complete/autocomplete.py b/sdks/python/apache_beam/examples/complete/autocomplete.py index 4e4c5143b96b..e4e45dd1a7cc 100644 --- a/sdks/python/apache_beam/examples/complete/autocomplete.py +++ b/sdks/python/apache_beam/examples/complete/autocomplete.py @@ -57,6 +57,7 @@ def format_result(prefix_candidates): class TopPerPrefix(beam.PTransform): + def __init__(self, count): # TODO(BEAM-6158): Revert the workaround once we can pickle super() on py3. # super().__init__() diff --git a/sdks/python/apache_beam/examples/complete/autocomplete_it_test.py b/sdks/python/apache_beam/examples/complete/autocomplete_it_test.py index a19af5873186..101eccc84e38 100644 --- a/sdks/python/apache_beam/examples/complete/autocomplete_it_test.py +++ b/sdks/python/apache_beam/examples/complete/autocomplete_it_test.py @@ -32,6 +32,7 @@ def format_output_file(output_string): + def extract_prefix_topk_words_tuples(line): match = re.match(r'(.*): \[(.*)\]', line) prefix = match.group(1) diff --git a/sdks/python/apache_beam/examples/complete/distribopt.py b/sdks/python/apache_beam/examples/complete/distribopt.py index 89c312fcbf5e..3fd9a70691d7 100644 --- a/sdks/python/apache_beam/examples/complete/distribopt.py +++ b/sdks/python/apache_beam/examples/complete/distribopt.py @@ -68,6 +68,7 @@ class Simulator(object): """Greenhouse simulation for the optimization of greenhouse parameters.""" + def __init__(self, quantities): self.quantities = np.atleast_1d(quantities) @@ -99,6 +100,7 @@ class CreateGrid(beam.PTransform): } Output: tuple (mapping_identifier, {crop -> greenhouse}) """ + class PreGenerateMappings(beam.DoFn): """ParDo implementation forming based on two elements a small sub grid. @@ -107,6 +109,7 @@ class PreGenerateMappings(beam.DoFn): two tuples, and a list of remaining records. Both serve as an input to GenerateMappings. """ + def process(self, element): records = list(element[1]) # Split of 2 crops and pre-generate the subgrid. @@ -135,6 +138,7 @@ class GenerateMappings(beam.DoFn): Input: output of PreGenerateMappings Output: tuples of the form (mapping_identifier, {crop -> greenhouse}) """ + @staticmethod def _coordinates_to_greenhouse(coordinates, greenhouses, crops): # Map the grid coordinates back to greenhouse labels @@ -185,6 +189,7 @@ def expand(self, records): class OptimizeGrid(beam.PTransform): """A transform for optimizing all greenhouses of the mapping grid.""" + class CreateOptimizationTasks(beam.DoFn): """ Create tasks for optimization. @@ -192,6 +197,7 @@ class CreateOptimizationTasks(beam.DoFn): Input: (mapping_identifier, {crop -> greenhouse}) Output: ((mapping_identifier, greenhouse), [(crop, quantity),...]) """ + def process(self, element, quantities): mapping_identifier, mapping = element @@ -213,6 +219,7 @@ class OptimizeProductParameters(beam.DoFn): - solution: (mapping_identifier, (greenhouse, [production parameters])) - costs: (crop, greenhouse, mapping_identifier, cost) """ + @staticmethod def _optimize_production_parameters(sim): # setup initial starting point & bounds @@ -247,6 +254,7 @@ def expand(self, inputs): class CreateTransportData(beam.DoFn): """Transform records to pvalues ((crop, greenhouse), transport_cost)""" + def process(self, record): crop = record['crop'] for greenhouse, transport_cost in record['transport_costs']: diff --git a/sdks/python/apache_beam/examples/complete/estimate_pi.py b/sdks/python/apache_beam/examples/complete/estimate_pi.py index 530a270308d9..350f078b138d 100644 --- a/sdks/python/apache_beam/examples/complete/estimate_pi.py +++ b/sdks/python/apache_beam/examples/complete/estimate_pi.py @@ -81,12 +81,14 @@ def combine_results(results): class JsonCoder(object): """A JSON coder used to format the final result.""" + def encode(self, x): return json.dumps(x).encode('utf-8') class EstimatePiTransform(beam.PTransform): """Runs 10M trials, and combine the results to estimate pi.""" + def __init__(self, tries_per_work_item=100000): self.tries_per_work_item = tries_per_work_item diff --git a/sdks/python/apache_beam/examples/complete/estimate_pi_it_test.py b/sdks/python/apache_beam/examples/complete/estimate_pi_it_test.py index bf6f8fc76c11..6365e1491dba 100644 --- a/sdks/python/apache_beam/examples/complete/estimate_pi_it_test.py +++ b/sdks/python/apache_beam/examples/complete/estimate_pi_it_test.py @@ -31,6 +31,7 @@ class EstimatePiIT(unittest.TestCase): + @pytest.mark.no_xdist @pytest.mark.examples_postcommit def test_estimate_pi_output_file(self): diff --git a/sdks/python/apache_beam/examples/complete/estimate_pi_test.py b/sdks/python/apache_beam/examples/complete/estimate_pi_test.py index ff224f4a9cab..909c4ab66b2a 100644 --- a/sdks/python/apache_beam/examples/complete/estimate_pi_test.py +++ b/sdks/python/apache_beam/examples/complete/estimate_pi_test.py @@ -29,6 +29,7 @@ def in_between(lower, upper): + def _in_between(actual): _, _, estimate = actual[0] if estimate < lower or estimate > upper: @@ -39,6 +40,7 @@ def _in_between(actual): class EstimatePiTest(unittest.TestCase): + def test_basics(self): with TestPipeline() as p: result = p | 'Estimate' >> estimate_pi.EstimatePiTransform(5000) diff --git a/sdks/python/apache_beam/examples/complete/game/game_stats.py b/sdks/python/apache_beam/examples/complete/game/game_stats.py index 233d22b75427..950006776964 100644 --- a/sdks/python/apache_beam/examples/complete/game/game_stats.py +++ b/sdks/python/apache_beam/examples/complete/game/game_stats.py @@ -104,6 +104,7 @@ class ParseGameEventFn(beam.DoFn): The human-readable time string is not used here. """ + def __init__(self): # TODO(BEAM-6158): Revert the workaround once we can pickle super() on py3. # super().__init__() @@ -130,6 +131,7 @@ class ExtractAndSumScore(beam.PTransform): The constructor argument `field` determines whether 'team' or 'user' info is extracted. """ + def __init__(self, field): # TODO(BEAM-6158): Revert the workaround once we can pickle super() on py3. # super().__init__() @@ -150,6 +152,7 @@ class TeamScoresDict(beam.DoFn): formats everything together into a dictionary. The dictionary is in the format {'bigquery_column': value} """ + def process(self, team_score, window=beam.DoFn.WindowParam): team, score = team_score start = timestamp2str(int(window.start)) @@ -163,6 +166,7 @@ def process(self, team_score, window=beam.DoFn.WindowParam): class WriteToBigQuery(beam.PTransform): """Generate, format, and write BigQuery table row information.""" + def __init__(self, table_name, dataset, schema, project): """Initializes the transform. Args: @@ -231,6 +235,7 @@ def expand(self, user_scores): class UserSessionActivity(beam.DoFn): """Calculate and output an element's session duration, in seconds.""" + def process(self, elem, window=beam.DoFn.WindowParam): yield (window.end.micros - window.start.micros) // 1000000 diff --git a/sdks/python/apache_beam/examples/complete/game/hourly_team_score.py b/sdks/python/apache_beam/examples/complete/game/hourly_team_score.py index 48a105af527d..21d00fad8508 100644 --- a/sdks/python/apache_beam/examples/complete/game/hourly_team_score.py +++ b/sdks/python/apache_beam/examples/complete/game/hourly_team_score.py @@ -104,6 +104,7 @@ class ParseGameEventFn(beam.DoFn): The human-readable time string is not used here. """ + def __init__(self): # TODO(BEAM-6158): Revert the workaround once we can pickle super() on py3. # super().__init__() @@ -130,6 +131,7 @@ class ExtractAndSumScore(beam.PTransform): The constructor argument `field` determines whether 'team' or 'user' info is extracted. """ + def __init__(self, field): # TODO(BEAM-6158): Revert the workaround once we can pickle super() on py3. # super().__init__() @@ -150,6 +152,7 @@ class TeamScoresDict(beam.DoFn): formats everything together into a dictionary. The dictionary is in the format {'bigquery_column': value} """ + def process(self, team_score, window=beam.DoFn.WindowParam): team, score = team_score start = timestamp2str(int(window.start)) @@ -163,6 +166,7 @@ def process(self, team_score, window=beam.DoFn.WindowParam): class WriteToBigQuery(beam.PTransform): """Generate, format, and write BigQuery table row information.""" + def __init__(self, table_name, dataset, schema, project): """Initializes the transform. Args: @@ -195,6 +199,7 @@ def expand(self, pcoll): # [START main] class HourlyTeamScore(beam.PTransform): + def __init__(self, start_min, stop_min, window_duration): # TODO(BEAM-6158): Revert the workaround once we can pickle super() on py3. # super().__init__() diff --git a/sdks/python/apache_beam/examples/complete/game/leader_board.py b/sdks/python/apache_beam/examples/complete/game/leader_board.py index 308e1e1cf5c0..71a28406bf7d 100644 --- a/sdks/python/apache_beam/examples/complete/game/leader_board.py +++ b/sdks/python/apache_beam/examples/complete/game/leader_board.py @@ -113,6 +113,7 @@ class ParseGameEventFn(beam.DoFn): The human-readable time string is not used here. """ + def __init__(self): # TODO(BEAM-6158): Revert the workaround once we can pickle super() on py3. # super().__init__() @@ -139,6 +140,7 @@ class ExtractAndSumScore(beam.PTransform): The constructor argument `field` determines whether 'team' or 'user' info is extracted. """ + def __init__(self, field): # TODO(BEAM-6158): Revert the workaround once we can pickle super() on py3. # super().__init__() @@ -159,6 +161,7 @@ class TeamScoresDict(beam.DoFn): formats everything together into a dictionary. The dictionary is in the format {'bigquery_column': value} """ + def process(self, team_score, window=beam.DoFn.WindowParam): team, score = team_score start = timestamp2str(int(window.start)) @@ -172,6 +175,7 @@ def process(self, team_score, window=beam.DoFn.WindowParam): class WriteToBigQuery(beam.PTransform): """Generate, format, and write BigQuery table row information.""" + def __init__(self, table_name, dataset, schema, project): """Initializes the transform. Args: @@ -209,6 +213,7 @@ class CalculateTeamScores(beam.PTransform): Extract team/score pairs from the event stream, using hour-long windows by default. """ + def __init__(self, team_window_duration, allowed_lateness): # TODO(BEAM-6158): Revert the workaround once we can pickle super() on py3. # super().__init__() @@ -241,6 +246,7 @@ class CalculateUserScores(beam.PTransform): """Extract user/score pairs from the event stream using processing time, via global windowing. Get periodic updates on all users' running scores. """ + def __init__(self, allowed_lateness): # TODO(BEAM-6158): Revert the workaround once we can pickle super() on py3. # super().__init__() diff --git a/sdks/python/apache_beam/examples/complete/game/user_score.py b/sdks/python/apache_beam/examples/complete/game/user_score.py index 03f0d00fc30f..00056cff586a 100644 --- a/sdks/python/apache_beam/examples/complete/game/user_score.py +++ b/sdks/python/apache_beam/examples/complete/game/user_score.py @@ -96,6 +96,7 @@ class ParseGameEventFn(beam.DoFn): The human-readable time string is not used here. """ + def __init__(self): # TODO(BEAM-6158): Revert the workaround once we can pickle super() on py3. # super().__init__() @@ -123,6 +124,7 @@ class ExtractAndSumScore(beam.PTransform): The constructor argument `field` determines whether 'team' or 'user' info is extracted. """ + def __init__(self, field): # TODO(BEAM-6158): Revert the workaround once we can pickle super() on py3. # super().__init__() @@ -140,6 +142,7 @@ def expand(self, pcoll): class UserScore(beam.PTransform): + def expand(self, pcoll): return ( pcoll diff --git a/sdks/python/apache_beam/examples/complete/juliaset/juliaset/juliaset.py b/sdks/python/apache_beam/examples/complete/juliaset/juliaset/juliaset.py index 4f98105c66f1..c08b968e0abd 100644 --- a/sdks/python/apache_beam/examples/complete/juliaset/juliaset/juliaset.py +++ b/sdks/python/apache_beam/examples/complete/juliaset/juliaset/juliaset.py @@ -46,6 +46,7 @@ def get_julia_set_point_color(element, c, n, max_iterations): def generate_julia_set_colors(pipeline, c, n, max_iterations): """Compute julia set coordinates for each point in our set.""" + def point_set(n): for x in range(n): for y in range(n): diff --git a/sdks/python/apache_beam/examples/complete/juliaset/juliaset/juliaset_test.py b/sdks/python/apache_beam/examples/complete/juliaset/juliaset/juliaset_test.py index 6416831f4269..0a69d89edb1c 100644 --- a/sdks/python/apache_beam/examples/complete/juliaset/juliaset/juliaset_test.py +++ b/sdks/python/apache_beam/examples/complete/juliaset/juliaset/juliaset_test.py @@ -33,6 +33,7 @@ @pytest.mark.examples_postcommit class JuliaSetTest(unittest.TestCase): + def setUp(self): self.test_files = {} self.test_files['output_coord_file_name'] = self.generate_temp_file() diff --git a/sdks/python/apache_beam/examples/complete/juliaset/setup.py b/sdks/python/apache_beam/examples/complete/juliaset/setup.py index c3a9fe043765..05ecbd8439e1 100644 --- a/sdks/python/apache_beam/examples/complete/juliaset/setup.py +++ b/sdks/python/apache_beam/examples/complete/juliaset/setup.py @@ -81,6 +81,7 @@ class build(_build): # pylint: disable=invalid-name class CustomCommands(setuptools.Command): """A setuptools Command class able to run arbitrary commands.""" + def initialize_options(self): pass diff --git a/sdks/python/apache_beam/examples/complete/tfidf.py b/sdks/python/apache_beam/examples/complete/tfidf.py index d7829f9d1c7d..651e9ec32983 100644 --- a/sdks/python/apache_beam/examples/complete/tfidf.py +++ b/sdks/python/apache_beam/examples/complete/tfidf.py @@ -54,6 +54,7 @@ class TfIdf(beam.PTransform): the value is a piece of the document's content. The output is mapping from terms to scores for each document URI. """ + def expand(self, uri_to_content): # Compute the total number of documents, and prepare a singleton diff --git a/sdks/python/apache_beam/examples/complete/tfidf_it_test.py b/sdks/python/apache_beam/examples/complete/tfidf_it_test.py index 3ecbd0c1ecae..6985f288b2ad 100644 --- a/sdks/python/apache_beam/examples/complete/tfidf_it_test.py +++ b/sdks/python/apache_beam/examples/complete/tfidf_it_test.py @@ -41,6 +41,7 @@ class TfIdfIT(unittest.TestCase): + @pytest.mark.examples_postcommit @pytest.mark.sickbay_flink def test_basics(self): diff --git a/sdks/python/apache_beam/examples/complete/tfidf_test.py b/sdks/python/apache_beam/examples/complete/tfidf_test.py index 085b9e2dd186..6b31dd05f351 100644 --- a/sdks/python/apache_beam/examples/complete/tfidf_test.py +++ b/sdks/python/apache_beam/examples/complete/tfidf_test.py @@ -36,6 +36,7 @@ class TfIdfTest(unittest.TestCase): + def test_tfidf_transform(self): with TestPipeline() as p: diff --git a/sdks/python/apache_beam/examples/complete/top_wikipedia_sessions.py b/sdks/python/apache_beam/examples/complete/top_wikipedia_sessions.py index 50b026edf240..e6c0bc7e01fb 100644 --- a/sdks/python/apache_beam/examples/complete/top_wikipedia_sessions.py +++ b/sdks/python/apache_beam/examples/complete/top_wikipedia_sessions.py @@ -76,6 +76,7 @@ class ComputeSessions(beam.PTransform): A session is defined as a string of edits where each is separated from the next by less than an hour. """ + def expand(self, pcoll): return ( pcoll @@ -86,6 +87,7 @@ def expand(self, pcoll): class TopPerMonth(beam.PTransform): """Computes the longest session ending in each month.""" + def expand(self, pcoll): return ( pcoll @@ -110,6 +112,7 @@ def format_output(element, window=beam.DoFn.WindowParam): class ComputeTopSessions(beam.PTransform): """Computes the top user sessions for each month.""" + def __init__(self, sampling_threshold): # TODO(BEAM-6158): Revert the workaround once we can pickle super() on py3. # super().__init__() diff --git a/sdks/python/apache_beam/examples/cookbook/bigquery_side_input.py b/sdks/python/apache_beam/examples/cookbook/bigquery_side_input.py index 6f40b14226b5..f858ecc29653 100644 --- a/sdks/python/apache_beam/examples/cookbook/bigquery_side_input.py +++ b/sdks/python/apache_beam/examples/cookbook/bigquery_side_input.py @@ -43,6 +43,7 @@ def create_groups(group_ids, corpus, word, ignore_corpus, ignore_word): """Generate groups given the input PCollections.""" + def attach_corpus_fn(group, corpus, ignore): selected = None len_corpus = len(corpus) diff --git a/sdks/python/apache_beam/examples/cookbook/bigquery_side_input_test.py b/sdks/python/apache_beam/examples/cookbook/bigquery_side_input_test.py index ba9b61e342bd..3274c4a0e9b2 100644 --- a/sdks/python/apache_beam/examples/cookbook/bigquery_side_input_test.py +++ b/sdks/python/apache_beam/examples/cookbook/bigquery_side_input_test.py @@ -30,6 +30,7 @@ class BigQuerySideInputTest(unittest.TestCase): + def test_create_groups(self): with TestPipeline() as p: diff --git a/sdks/python/apache_beam/examples/cookbook/bigquery_tornadoes_test.py b/sdks/python/apache_beam/examples/cookbook/bigquery_tornadoes_test.py index e9f3f679ad80..a8763bf79cc1 100644 --- a/sdks/python/apache_beam/examples/cookbook/bigquery_tornadoes_test.py +++ b/sdks/python/apache_beam/examples/cookbook/bigquery_tornadoes_test.py @@ -30,6 +30,7 @@ class BigQueryTornadoesTest(unittest.TestCase): + def test_basics(self): with TestPipeline() as p: rows = ( diff --git a/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py b/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py index 0a8c55d17d3a..a6a28b56ba06 100644 --- a/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py +++ b/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py @@ -66,6 +66,7 @@ class GenerateTestRows(beam.PTransform): Bigtable Table. """ + def __init__(self, number, project_id=None, instance_id=None, table_id=None): # TODO(BEAM-6158): Revert the workaround once we can pickle super() on py3. # super().__init__() diff --git a/sdks/python/apache_beam/examples/cookbook/coders.py b/sdks/python/apache_beam/examples/cookbook/coders.py index 33b63fd2954f..655f7a88d7c3 100644 --- a/sdks/python/apache_beam/examples/cookbook/coders.py +++ b/sdks/python/apache_beam/examples/cookbook/coders.py @@ -44,6 +44,7 @@ class JsonCoder(Coder): """A JSON coder interpreting each line as a JSON string.""" + def encode(self, x): return json.dumps(x).encode('utf-8') diff --git a/sdks/python/apache_beam/examples/cookbook/coders_it_test.py b/sdks/python/apache_beam/examples/cookbook/coders_it_test.py index 941311ce5dc3..837c0c68423f 100644 --- a/sdks/python/apache_beam/examples/cookbook/coders_it_test.py +++ b/sdks/python/apache_beam/examples/cookbook/coders_it_test.py @@ -32,6 +32,7 @@ def format_result(result_string): + def format_tuple(result_elem_list): [country, counter] = result_elem_list return country, int(counter.strip()) diff --git a/sdks/python/apache_beam/examples/cookbook/combiners_test.py b/sdks/python/apache_beam/examples/cookbook/combiners_test.py index c24f60b5dcfe..a20e7a02c73a 100644 --- a/sdks/python/apache_beam/examples/cookbook/combiners_test.py +++ b/sdks/python/apache_beam/examples/cookbook/combiners_test.py @@ -57,6 +57,7 @@ def test_combine_per_key_with_callable(self): def test_combine_per_key_with_custom_callable(self): """CombinePerKey using a custom function reducing iterables.""" + def multiply(values): result = 1 for v in values: diff --git a/sdks/python/apache_beam/examples/cookbook/custom_ptransform.py b/sdks/python/apache_beam/examples/cookbook/custom_ptransform.py index a922216a5220..51325ccab333 100644 --- a/sdks/python/apache_beam/examples/cookbook/custom_ptransform.py +++ b/sdks/python/apache_beam/examples/cookbook/custom_ptransform.py @@ -36,6 +36,7 @@ class Count1(beam.PTransform): """Count as a subclass of PTransform, with an apply method.""" + def expand(self, pcoll): return ( pcoll diff --git a/sdks/python/apache_beam/examples/cookbook/custom_ptransform_it_test.py b/sdks/python/apache_beam/examples/cookbook/custom_ptransform_it_test.py index 9ad0c52bf23c..52e4fdc77df9 100644 --- a/sdks/python/apache_beam/examples/cookbook/custom_ptransform_it_test.py +++ b/sdks/python/apache_beam/examples/cookbook/custom_ptransform_it_test.py @@ -31,6 +31,7 @@ def format_result(result_string): + def format_tuple(result_elem_list): [country, counter] = result_elem_list return country, int(counter.strip()) diff --git a/sdks/python/apache_beam/examples/cookbook/datastore_wordcount.py b/sdks/python/apache_beam/examples/cookbook/datastore_wordcount.py index 9d71ac32aff2..cc2db388587f 100644 --- a/sdks/python/apache_beam/examples/cookbook/datastore_wordcount.py +++ b/sdks/python/apache_beam/examples/cookbook/datastore_wordcount.py @@ -81,6 +81,7 @@ @beam.typehints.with_output_types(Text) class WordExtractingDoFn(beam.DoFn): """Parse each line of input text into words.""" + def __init__(self): self.empty_line_counter = Metrics.counter('main', 'empty_lines') self.word_length_counter = Metrics.counter('main', 'word_lengths') @@ -111,6 +112,7 @@ def process(self, element: Entity) -> Optional[Iterable[Text]]: class EntityWrapper(object): """Create a Cloud Datastore entity from the given string.""" + def __init__(self, project, namespace, kind, ancestor): self._project = project self._namespace = namespace diff --git a/sdks/python/apache_beam/examples/cookbook/group_with_coder.py b/sdks/python/apache_beam/examples/cookbook/group_with_coder.py index 8a959138d3da..2ce443b85afa 100644 --- a/sdks/python/apache_beam/examples/cookbook/group_with_coder.py +++ b/sdks/python/apache_beam/examples/cookbook/group_with_coder.py @@ -42,12 +42,14 @@ class Player(object): """A custom class used as a key in combine/group transforms.""" + def __init__(self, name): self.name = name class PlayerCoder(coders.Coder): """A custom coder for the Player class.""" + def encode(self, o): """Encode to bytes with a trace that coder was used.""" # Our encoding prepends an 'x:' prefix. diff --git a/sdks/python/apache_beam/examples/cookbook/multiple_output_pardo.py b/sdks/python/apache_beam/examples/cookbook/multiple_output_pardo.py index 781ec9a8682a..51cdc80373d3 100644 --- a/sdks/python/apache_beam/examples/cookbook/multiple_output_pardo.py +++ b/sdks/python/apache_beam/examples/cookbook/multiple_output_pardo.py @@ -135,7 +135,9 @@ class CountWords(beam.PTransform): A PTransform that converts a PCollection containing words into a PCollection of "word: count" strings. """ + def expand(self, pcoll): + def count_ones(word_ones): (word, ones) = word_ones return (word, sum(ones)) diff --git a/sdks/python/apache_beam/examples/dataframe/taxiride_it_test.py b/sdks/python/apache_beam/examples/dataframe/taxiride_it_test.py index f81b7d8cfa60..7ab66e7fd0f4 100644 --- a/sdks/python/apache_beam/examples/dataframe/taxiride_it_test.py +++ b/sdks/python/apache_beam/examples/dataframe/taxiride_it_test.py @@ -34,6 +34,7 @@ class TaxirideIT(unittest.TestCase): + def setUp(self): self.test_pipeline = TestPipeline(is_integration_test=True) self.outdir = ( diff --git a/sdks/python/apache_beam/examples/inference/anomaly_detection/anomaly_detection_pipeline/main.py b/sdks/python/apache_beam/examples/inference/anomaly_detection/anomaly_detection_pipeline/main.py index ae59b94a1a67..75a88ef23c9f 100644 --- a/sdks/python/apache_beam/examples/inference/anomaly_detection/anomaly_detection_pipeline/main.py +++ b/sdks/python/apache_beam/examples/inference/anomaly_detection/anomaly_detection_pipeline/main.py @@ -77,6 +77,7 @@ class PytorchNoBatchModelHandler(PytorchModelHandlerKeyedTensor): Restricting max_batch_size to 1 means there is only 1 example per `batch` in the run_inference() call. """ + def batch_elements_kwargs(self): return {"max_batch_size": 1} diff --git a/sdks/python/apache_beam/examples/inference/anomaly_detection/anomaly_detection_pipeline/pipeline/transformations.py b/sdks/python/apache_beam/examples/inference/anomaly_detection/anomaly_detection_pipeline/pipeline/transformations.py index bb90c5a93cb4..dcdf2d3c78e7 100644 --- a/sdks/python/apache_beam/examples/inference/anomaly_detection/anomaly_detection_pipeline/pipeline/transformations.py +++ b/sdks/python/apache_beam/examples/inference/anomaly_detection/anomaly_detection_pipeline/pipeline/transformations.py @@ -60,6 +60,7 @@ def tokenize_sentence(input_dict): class ModelWrapper(DistilBertModel): """Wrapper to DistilBertModel to get embeddings when calling forward function.""" + def forward(self, **kwargs): output = super().forward(**kwargs) sentence_embedding = ( @@ -123,6 +124,7 @@ def run_inference(self, batch, model, inference_args=None): class NormalizeEmbedding(beam.DoFn): """A DoFn for normalization of text embedding.""" + def process(self, element, *args, **kwargs): """ For each element in the input PCollection, normalize the embedding vector, and @@ -139,6 +141,7 @@ def process(self, element, *args, **kwargs): class DecodePubSubMessage(beam.DoFn): """A DoFn for decoding PubSub message into a dictionary.""" + def process(self, element, *args, **kwargs): """ For each element in the input PCollection, retrieve the id and decode the bytes into string @@ -154,6 +157,7 @@ def process(self, element, *args, **kwargs): class DecodePrediction(beam.DoFn): """A DoFn for decoding the prediction from RunInference.""" + def process(self, element): """ The `process` function takes the output of RunInference and returns a dictionary @@ -170,6 +174,7 @@ def process(self, element): class TriggerEmailAlert(beam.DoFn): """A DoFn for sending email using yagmail.""" + def setup(self): """ Opens the cred.json file and initializes the yag SMTP client. diff --git a/sdks/python/apache_beam/examples/inference/anomaly_detection/write_data_to_pubsub_pipeline/pipeline/utils.py b/sdks/python/apache_beam/examples/inference/anomaly_detection/write_data_to_pubsub_pipeline/pipeline/utils.py index 66568084349e..30b8062e1283 100644 --- a/sdks/python/apache_beam/examples/inference/anomaly_detection/write_data_to_pubsub_pipeline/pipeline/utils.py +++ b/sdks/python/apache_beam/examples/inference/anomaly_detection/write_data_to_pubsub_pipeline/pipeline/utils.py @@ -55,6 +55,7 @@ def get_dataset(categories: list, split: str = "train"): class AssignUniqueID(beam.DoFn): """A DoFn for assigning Unique ID to each text.""" + def process(self, element, *args, **kwargs): uid = str(uuid.uuid4()) yield {"id": uid, "text": element} @@ -62,6 +63,7 @@ def process(self, element, *args, **kwargs): class ConvertToPubSubMessage(beam.DoFn): """A DoFn for converting into PubSub message format.""" + def process(self, element, *args, **kwargs): yield PubsubMessage( data=element["text"].encode("utf-8"), attributes={"id": element["id"]}) diff --git a/sdks/python/apache_beam/examples/inference/huggingface_language_modeling.py b/sdks/python/apache_beam/examples/inference/huggingface_language_modeling.py index 69c2eacc593d..ede4c8753aaf 100644 --- a/sdks/python/apache_beam/examples/inference/huggingface_language_modeling.py +++ b/sdks/python/apache_beam/examples/inference/huggingface_language_modeling.py @@ -75,6 +75,7 @@ class PostProcessor(beam.DoFn): The logits are the output of the Model. We can get the word with the highest probability of being a candidate replacement word by taking the argmax. """ + def __init__(self, tokenizer: AutoTokenizer): super().__init__() self.tokenizer = tokenizer diff --git a/sdks/python/apache_beam/examples/inference/huggingface_question_answering.py b/sdks/python/apache_beam/examples/inference/huggingface_question_answering.py index 7d4899cc38d9..4ac5813a650f 100644 --- a/sdks/python/apache_beam/examples/inference/huggingface_question_answering.py +++ b/sdks/python/apache_beam/examples/inference/huggingface_question_answering.py @@ -48,6 +48,7 @@ class PostProcessor(beam.DoFn): Hugging Face Pipeline for Question Answering returns a dictionary with score, start and end index of answer and the answer. """ + def process(self, result: tuple[str, PredictionResult]) -> Iterable[str]: text, prediction = result predicted_answer = prediction.inference['answer'] diff --git a/sdks/python/apache_beam/examples/inference/large_language_modeling/main.py b/sdks/python/apache_beam/examples/inference/large_language_modeling/main.py index 6fd06ae4d1d6..a3adad56ed65 100644 --- a/sdks/python/apache_beam/examples/inference/large_language_modeling/main.py +++ b/sdks/python/apache_beam/examples/inference/large_language_modeling/main.py @@ -37,6 +37,7 @@ class Preprocess(beam.DoFn): + def __init__(self, tokenizer: AutoTokenizer): self._tokenizer = tokenizer @@ -59,6 +60,7 @@ def process(self, element): class Postprocess(beam.DoFn): + def __init__(self, tokenizer: AutoTokenizer): self._tokenizer = tokenizer diff --git a/sdks/python/apache_beam/examples/inference/milk_quality_prediction_windowing.py b/sdks/python/apache_beam/examples/inference/milk_quality_prediction_windowing.py index dfec4640fdad..71761a49ccfc 100644 --- a/sdks/python/apache_beam/examples/inference/milk_quality_prediction_windowing.py +++ b/sdks/python/apache_beam/examples/inference/milk_quality_prediction_windowing.py @@ -138,6 +138,7 @@ class MilkQualityAggregation(NamedTuple): class AggregateMilkQualityResults(beam.CombineFn): """Simple aggregation to keep track of the number of samples with good, bad and medium quality milk.""" + def create_accumulator(self): return MilkQualityAggregation(0, 0, 0) diff --git a/sdks/python/apache_beam/examples/inference/multi_language_inference/multi_language_custom_transform/multi_language_custom_transform/composite_transform.py b/sdks/python/apache_beam/examples/inference/multi_language_inference/multi_language_custom_transform/multi_language_custom_transform/composite_transform.py index 3e1794bc5829..a386acb77084 100644 --- a/sdks/python/apache_beam/examples/inference/multi_language_inference/multi_language_custom_transform/multi_language_custom_transform/composite_transform.py +++ b/sdks/python/apache_beam/examples/inference/multi_language_inference/multi_language_custom_transform/multi_language_custom_transform/composite_transform.py @@ -35,6 +35,7 @@ class InferenceTransform(ptransform.PTransform): + class PytorchModelHandlerKeyedTensorWrapper(PytorchModelHandlerKeyedTensor): """Wrapper to PytorchModelHandler to limit batch size to 1. The tokenized strings generated from BertTokenizer may have different @@ -43,10 +44,12 @@ class PytorchModelHandlerKeyedTensorWrapper(PytorchModelHandlerKeyedTensor): Restricting max_batch_size to 1 means there is only 1 example per `batch` in the run_inference() call. """ + def batch_elements_kwargs(self): return {'max_batch_size': 1} class Preprocess(beam.DoFn): + def __init__(self, tokenizer): self._tokenizer = tokenizer logging.info('Starting Preprocess.') @@ -80,6 +83,7 @@ def process(self, text: str): return [(text, tokens)] class Postprocess(beam.DoFn): + def __init__(self, bert_tokenizer): self.bert_tokenizer = bert_tokenizer logging.info('Starting Postprocess') diff --git a/sdks/python/apache_beam/examples/inference/online_clustering/clustering_pipeline/main.py b/sdks/python/apache_beam/examples/inference/online_clustering/clustering_pipeline/main.py index 350eec7ed875..4d9187254c11 100644 --- a/sdks/python/apache_beam/examples/inference/online_clustering/clustering_pipeline/main.py +++ b/sdks/python/apache_beam/examples/inference/online_clustering/clustering_pipeline/main.py @@ -74,6 +74,7 @@ class PytorchNoBatchModelHandler(PytorchModelHandlerKeyedTensor): Restricting max_batch_size to 1 means there is only 1 example per `batch` in the run_inference() call. """ + def batch_elements_kwargs(self): return {"max_batch_size": 1} diff --git a/sdks/python/apache_beam/examples/inference/online_clustering/clustering_pipeline/pipeline/transformations.py b/sdks/python/apache_beam/examples/inference/online_clustering/clustering_pipeline/pipeline/transformations.py index 26010516f0c6..2e6b2a8e439b 100644 --- a/sdks/python/apache_beam/examples/inference/online_clustering/clustering_pipeline/pipeline/transformations.py +++ b/sdks/python/apache_beam/examples/inference/online_clustering/clustering_pipeline/pipeline/transformations.py @@ -54,6 +54,7 @@ def tokenize_sentence(input_dict): class ModelWrapper(DistilBertModel): """Wrapper to DistilBertModel to get embeddings when calling forward function.""" + def forward(self, **kwargs): output = super().forward(**kwargs) sentence_embedding = ( @@ -84,6 +85,7 @@ def mean_pooling(self, model_output, attention_mask): class NormalizeEmbedding(beam.DoFn): """A DoFn for normalization of text embedding.""" + def process(self, element, *args, **kwargs): """ For each element in the input PCollection, normalize the embedding vector, and @@ -99,6 +101,7 @@ def process(self, element, *args, **kwargs): class Decode(beam.DoFn): """A DoFn for decoding PubSub message into a dictionary.""" + def process(self, element, *args, **kwargs): """ For each element in the input PCollection, retrieve the id and decode the bytes into string @@ -181,6 +184,7 @@ def process( class GetUpdates(beam.DoFn): """A DoFn for printing the clusters and items belonging to each cluster.""" + def process(self, element, *args, **kwargs): """ Prints and returns clusters with items contained in it diff --git a/sdks/python/apache_beam/examples/inference/online_clustering/write_data_to_pubsub_pipeline/pipeline/utils.py b/sdks/python/apache_beam/examples/inference/online_clustering/write_data_to_pubsub_pipeline/pipeline/utils.py index c0007de63d29..d2de237d41c3 100644 --- a/sdks/python/apache_beam/examples/inference/online_clustering/write_data_to_pubsub_pipeline/pipeline/utils.py +++ b/sdks/python/apache_beam/examples/inference/online_clustering/write_data_to_pubsub_pipeline/pipeline/utils.py @@ -55,6 +55,7 @@ def get_dataset(categories: list, split: str = "train"): class AssignUniqueID(beam.DoFn): """A DoFn for assigning Unique ID to each text.""" + def process(self, element, *args, **kwargs): uid = str(uuid.uuid4()) yield {"id": uid, "text": element} @@ -62,6 +63,7 @@ def process(self, element, *args, **kwargs): class ConvertToPubSubMessage(beam.DoFn): """A DoFn for converting into PubSub message format.""" + def process(self, element, *args, **kwargs): yield PubsubMessage( data=element["text"].encode("utf-8"), attributes={"id": element["id"]}) diff --git a/sdks/python/apache_beam/examples/inference/onnx_sentiment_classification.py b/sdks/python/apache_beam/examples/inference/onnx_sentiment_classification.py index 0e62ab865431..4e9803c890c6 100644 --- a/sdks/python/apache_beam/examples/inference/onnx_sentiment_classification.py +++ b/sdks/python/apache_beam/examples/inference/onnx_sentiment_classification.py @@ -62,6 +62,7 @@ def filter_empty_lines(text: str) -> Iterator[str]: class PostProcessor(beam.DoFn): + def process(self, element: tuple[str, PredictionResult]) -> Iterable[str]: filename, prediction_result = element prediction = np.argmax(prediction_result.inference, axis=0) @@ -112,6 +113,7 @@ class OnnxNoBatchModelHandler(OnnxModelHandlerNumpy): Restricting max_batch_size to 1 means there is only 1 example per `batch` in the run_inference() call. """ + def batch_elements_kwargs(self): return {'max_batch_size': 1} diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py b/sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py index 787341263fde..878b5b9763df 100644 --- a/sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py +++ b/sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py @@ -115,6 +115,7 @@ class PostProcessor(beam.DoFn): Return filename, prediction and the model id used to perform the prediction """ + def process(self, element: tuple[str, PredictionResult]) -> Iterable[str]: filename, prediction_result = element prediction = torch.argmax(prediction_result.inference, dim=0) diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_segmentation.py b/sdks/python/apache_beam/examples/inference/pytorch_image_segmentation.py index 5e5f77a679c3..76f2311ba4f3 100644 --- a/sdks/python/apache_beam/examples/inference/pytorch_image_segmentation.py +++ b/sdks/python/apache_beam/examples/inference/pytorch_image_segmentation.py @@ -160,6 +160,7 @@ def filter_empty_lines(text: str) -> Iterator[str]: class PostProcessor(beam.DoFn): + def process(self, element: tuple[str, PredictionResult]) -> Iterable[str]: filename, prediction_result = element prediction_labels = prediction_result.inference['labels'] diff --git a/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py b/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py index a616998d2c73..a3ef1db7609c 100644 --- a/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py +++ b/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py @@ -78,6 +78,7 @@ class PostProcessor(beam.DoFn): of the words in BERT’s vocabulary. We can get the word with the highest probability of being a candidate replacement word by taking the argmax. """ + def __init__(self, bert_tokenizer: BertTokenizer): super().__init__() self.bert_tokenizer = bert_tokenizer @@ -165,6 +166,7 @@ class PytorchNoBatchModelHandler(PytorchModelHandlerKeyedTensor): Restricting max_batch_size to 1 means there is only 1 example per `batch` in the run_inference() call. """ + def batch_elements_kwargs(self): return {'max_batch_size': 1} @@ -181,18 +183,19 @@ def batch_elements_kwargs(self): bert_tokenizer = BertTokenizer.from_pretrained(known_args.bert_tokenizer) if not known_args.input: - text = (pipeline | 'CreateSentences' >> beam.Create([ - 'The capital of France is Paris .', - 'It is raining cats and dogs .', - 'He looked up and saw the sun and stars .', - 'Today is Monday and tomorrow is Tuesday .', - 'There are 5 coconuts on this palm tree .', - 'The richest person in the world is not here .', - 'Malls are amazing places to shop because you can find everything you need under one roof .', # pylint: disable=line-too-long - 'This audiobook is sure to liquefy your brain .', - 'The secret ingredient to his wonderful life was gratitude .', - 'The biggest animal in the world is the whale .', - ])) + text = ( + pipeline | 'CreateSentences' >> beam.Create([ + 'The capital of France is Paris .', + 'It is raining cats and dogs .', + 'He looked up and saw the sun and stars .', + 'Today is Monday and tomorrow is Tuesday .', + 'There are 5 coconuts on this palm tree .', + 'The richest person in the world is not here .', + 'Malls are amazing places to shop because you can find everything you need under one roof .', # pylint: disable=line-too-long + 'This audiobook is sure to liquefy your brain .', + 'The secret ingredient to his wonderful life was gratitude .', + 'The biggest animal in the world is the whale .', + ])) else: text = ( pipeline | 'ReadSentences' >> beam.io.ReadFromText(known_args.input)) diff --git a/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py b/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py index 18c4c3e653b4..0bc5ed266dfd 100644 --- a/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py +++ b/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py @@ -166,6 +166,7 @@ def filter_empty_lines(text: str) -> Iterator[str]: class KeyExamplesForEachModelType(beam.DoFn): """Duplicate data to run against each model type""" + def process( self, element: tuple[torch.Tensor, str]) -> Iterable[tuple[str, torch.Tensor]]: @@ -174,6 +175,7 @@ def process( class PostProcessor(beam.DoFn): + def process( self, element: tuple[str, PredictionResult]) -> tuple[torch.Tensor, str]: model, prediction_result = element @@ -183,6 +185,7 @@ def process( class FormatResults(beam.DoFn): + def process(self, element): _, filename_prediction = element predictions = filename_prediction['predictions'] diff --git a/sdks/python/apache_beam/examples/inference/run_inference_side_inputs.py b/sdks/python/apache_beam/examples/inference/run_inference_side_inputs.py index 755eff17c163..42cd7596bcca 100644 --- a/sdks/python/apache_beam/examples/inference/run_inference_side_inputs.py +++ b/sdks/python/apache_beam/examples/inference/run_inference_side_inputs.py @@ -38,22 +38,26 @@ # create some fake models which returns different inference results. class FakeModelDefault: + def predict(self, example: int) -> int: return example class FakeModelAdd(FakeModelDefault): + def predict(self, example: int) -> int: return example + 1 class FakeModelSub(FakeModelDefault): + def predict(self, example: int) -> int: return example - 1 class FakeModelHandlerReturnsPredictionResult( base.ModelHandler[int, base.PredictionResult, FakeModelDefault]): + def __init__(self, clock=None, model_id='model_default'): self.model_id = model_id self._fake_clock = clock @@ -96,6 +100,7 @@ def run(argv=None, save_main_session=True): options.view_as(SetupOptions).save_main_session = save_main_session class GetModel(beam.DoFn): + def process(self, element) -> Iterable[base.ModelMetadata]: if time.time() > mid_ts: yield base.ModelMetadata( diff --git a/sdks/python/apache_beam/examples/inference/runinference_metrics/pipeline/transformations.py b/sdks/python/apache_beam/examples/inference/runinference_metrics/pipeline/transformations.py index e7f6f9d44689..ccbf6753078f 100644 --- a/sdks/python/apache_beam/examples/inference/runinference_metrics/pipeline/transformations.py +++ b/sdks/python/apache_beam/examples/inference/runinference_metrics/pipeline/transformations.py @@ -27,6 +27,7 @@ class CustomPytorchModelHandlerKeyedTensor(PytorchModelHandlerKeyedTensor): """Wrapper around PytorchModelHandlerKeyedTensor to load a model on CPU.""" + def load_model(self) -> torch.nn.Module: """Loads and initializes a Pytorch model for processing.""" model = self._model_class(**self._model_params) @@ -43,6 +44,7 @@ class HuggingFaceStripBatchingWrapper(DistilBertForSequenceClassification): as a list of dicts instead of a dict of lists. Another workaround can be found here where they disable batching instead. https://github.com/apache/beam/blob/master/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py""" + def forward(self, **kwargs): output = super().forward(**kwargs) return [dict(zip(output, v)) for v in zip(*output.values())] @@ -50,6 +52,7 @@ def forward(self, **kwargs): class Tokenize(beam.DoFn): """A DoFn for tokenizing texts""" + def __init__(self, model_name: str): """Initialises a tokenizer based on the model_name""" self._model_name = model_name @@ -75,6 +78,7 @@ def process(self, text_input: str): class PostProcessor(beam.DoFn): """Postprocess the RunInference output""" + def process(self, element): """ Takes the input text and the prediction result, and returns a dictionary diff --git a/sdks/python/apache_beam/examples/inference/sklearn_japanese_housing_regression.py b/sdks/python/apache_beam/examples/inference/sklearn_japanese_housing_regression.py index 0a527e88dec2..8b85af3a05bc 100644 --- a/sdks/python/apache_beam/examples/inference/sklearn_japanese_housing_regression.py +++ b/sdks/python/apache_beam/examples/inference/sklearn_japanese_housing_regression.py @@ -88,6 +88,7 @@ def sort_by_features(dataframe, max_size): class LoadDataframe(beam.DoFn): + def process(self, file_name: str) -> Iterable[pandas.DataFrame]: """ Loads data files as a pandas dataframe.""" file = FileSystems.open(file_name, 'rb') diff --git a/sdks/python/apache_beam/examples/inference/sklearn_mnist_classification.py b/sdks/python/apache_beam/examples/inference/sklearn_mnist_classification.py index d7d08e294e9d..6f3890c3c0a6 100644 --- a/sdks/python/apache_beam/examples/inference/sklearn_mnist_classification.py +++ b/sdks/python/apache_beam/examples/inference/sklearn_mnist_classification.py @@ -51,6 +51,7 @@ class PostProcessor(beam.DoFn): """Process the PredictionResult to get the predicted label. Returns a comma separated string with true label and predicted label. """ + def process(self, element: tuple[int, PredictionResult]) -> Iterable[str]: label, prediction_result = element prediction = prediction_result.inference diff --git a/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py b/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py index b44d775f4ad3..a24d8c6b3db5 100644 --- a/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py +++ b/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py @@ -37,6 +37,7 @@ class PostProcessor(beam.DoFn): """Process the PredictionResult to get the predicted label. Returns predicted label. """ + def setup(self): labels_path = tf.keras.utils.get_file( 'ImageNetLabels.txt', diff --git a/sdks/python/apache_beam/examples/inference/tensorflow_mnist_classification.py b/sdks/python/apache_beam/examples/inference/tensorflow_mnist_classification.py index bf85bb1aef16..f5f9b4fa8c9c 100644 --- a/sdks/python/apache_beam/examples/inference/tensorflow_mnist_classification.py +++ b/sdks/python/apache_beam/examples/inference/tensorflow_mnist_classification.py @@ -45,6 +45,7 @@ class PostProcessor(beam.DoFn): """Process the PredictionResult to get the predicted label. Returns a comma separated string with true label and predicted label. """ + def process(self, element: tuple[int, PredictionResult]) -> Iterable[str]: label, prediction_result = element prediction = numpy.argmax(prediction_result.inference, axis=0) diff --git a/sdks/python/apache_beam/examples/inference/tensorrt_object_detection.py b/sdks/python/apache_beam/examples/inference/tensorrt_object_detection.py index 677d36b9b767..39e1d625375f 100644 --- a/sdks/python/apache_beam/examples/inference/tensorrt_object_detection.py +++ b/sdks/python/apache_beam/examples/inference/tensorrt_object_detection.py @@ -167,6 +167,7 @@ class PostProcessor(beam.DoFn): an integer that we can transform into actual string class using COCO_OBJ_DET_CLASSES as reference. """ + def process(self, element: tuple[str, PredictionResult]) -> Iterable[str]: key, prediction_result = element filename, im_width, im_height = key diff --git a/sdks/python/apache_beam/examples/inference/tensorrt_text_classification.py b/sdks/python/apache_beam/examples/inference/tensorrt_text_classification.py index a5cda68fd79e..e101e7d2de82 100644 --- a/sdks/python/apache_beam/examples/inference/tensorrt_text_classification.py +++ b/sdks/python/apache_beam/examples/inference/tensorrt_text_classification.py @@ -43,6 +43,7 @@ class Preprocess(beam.DoFn): The input sentences are tokenized because the model is expecting tokens. """ + def __init__(self, tokenizer: AutoTokenizer): self._tokenizer = tokenizer @@ -59,6 +60,7 @@ class Postprocess(beam.DoFn): We can get the class label by getting the index of maximum logit using argmax. """ + def __init__(self, tokenizer: AutoTokenizer): self._tokenizer = tokenizer diff --git a/sdks/python/apache_beam/examples/inference/tfx_bsl/build_tensorflow_model.py b/sdks/python/apache_beam/examples/inference/tfx_bsl/build_tensorflow_model.py index 9230f84955eb..31fd8c922f5b 100644 --- a/sdks/python/apache_beam/examples/inference/tfx_bsl/build_tensorflow_model.py +++ b/sdks/python/apache_beam/examples/inference/tfx_bsl/build_tensorflow_model.py @@ -54,6 +54,7 @@ class TFModelWrapperWithSignature(tf.keras.Model): saved_model_spec=saved_model_spec) model_handler = CreateModelHandler(inferece_spec_type) """ + def __init__( self, model, diff --git a/sdks/python/apache_beam/examples/inference/tfx_bsl/tensorflow_image_classification.py b/sdks/python/apache_beam/examples/inference/tfx_bsl/tensorflow_image_classification.py index 5df0b51e36d7..ed78534dd43e 100644 --- a/sdks/python/apache_beam/examples/inference/tfx_bsl/tensorflow_image_classification.py +++ b/sdks/python/apache_beam/examples/inference/tfx_bsl/tensorflow_image_classification.py @@ -95,6 +95,7 @@ def convert_image_to_example_proto(tensor: tf.Tensor) -> tf.train.Example: class ProcessInferenceToString(beam.DoFn): + def process( self, element: tuple[str, prediction_log_pb2.PredictionLog]) -> Iterable[str]: diff --git a/sdks/python/apache_beam/examples/inference/tfx_bsl/tfx_bsl_inference_it_test.py b/sdks/python/apache_beam/examples/inference/tfx_bsl/tfx_bsl_inference_it_test.py index d72794df4f77..4b4376775853 100644 --- a/sdks/python/apache_beam/examples/inference/tfx_bsl/tfx_bsl_inference_it_test.py +++ b/sdks/python/apache_beam/examples/inference/tfx_bsl/tfx_bsl_inference_it_test.py @@ -60,6 +60,7 @@ def process_outputs(filepath): tfx_bsl is None, 'Missing dependencies. ' 'Test depends on tfx_bsl') class TFXRunInferenceTests(unittest.TestCase): + @pytest.mark.uses_tensorflow @pytest.mark.it_postcommit def test_tfx_run_inference_mobilenetv2(self): diff --git a/sdks/python/apache_beam/examples/inference/vertex_ai_image_classification.py b/sdks/python/apache_beam/examples/inference/vertex_ai_image_classification.py index 20312e7d3c88..55c28b244473 100644 --- a/sdks/python/apache_beam/examples/inference/vertex_ai_image_classification.py +++ b/sdks/python/apache_beam/examples/inference/vertex_ai_image_classification.py @@ -117,6 +117,7 @@ def preprocess_image(data: bytes) -> list[float]: class PostProcessor(beam.DoFn): + def process(self, element: tuple[str, PredictionResult]) -> Iterable[str]: img_name, prediction_result = element prediction_vals = prediction_result.inference diff --git a/sdks/python/apache_beam/examples/inference/vllm_text_completion.py b/sdks/python/apache_beam/examples/inference/vllm_text_completion.py index 2708c0f3d1a1..f1f0b769affa 100644 --- a/sdks/python/apache_beam/examples/inference/vllm_text_completion.py +++ b/sdks/python/apache_beam/examples/inference/vllm_text_completion.py @@ -116,6 +116,7 @@ def parse_known_args(argv): class PostProcessor(beam.DoFn): + def process(self, element: PredictionResult) -> Iterable[str]: yield str(element.example) + ": " + str(element.inference) diff --git a/sdks/python/apache_beam/examples/inference/xgboost_iris_classification.py b/sdks/python/apache_beam/examples/inference/xgboost_iris_classification.py index 498511a5a2cf..660a745f7e63 100644 --- a/sdks/python/apache_beam/examples/inference/xgboost_iris_classification.py +++ b/sdks/python/apache_beam/examples/inference/xgboost_iris_classification.py @@ -46,6 +46,7 @@ class PostProcessor(beam.DoFn): """Process the PredictionResult to get the predicted label. Returns a comma separated string with true label and predicted label. """ + def process(self, element: tuple[int, PredictionResult]) -> Iterable[str]: label, prediction_result = element prediction = prediction_result.inference @@ -101,8 +102,8 @@ def load_sklearn_iris_test_data( dataset['data'], dataset['target'], test_size=.2, random_state=seed) if split: - return [(index, data_type(sample.reshape(1, -1))) for index, - sample in enumerate(x_test)] + return [(index, data_type(sample.reshape(1, -1))) + for index, sample in enumerate(x_test)] return [(0, data_type(x_test))] diff --git a/sdks/python/apache_beam/examples/ml_transform/ml_transform_it_test.py b/sdks/python/apache_beam/examples/ml_transform/ml_transform_it_test.py index 96fb3f775671..e28dc3313fa2 100644 --- a/sdks/python/apache_beam/examples/ml_transform/ml_transform_it_test.py +++ b/sdks/python/apache_beam/examples/ml_transform/ml_transform_it_test.py @@ -59,6 +59,7 @@ def _publish_metrics(pipeline, metric_value, metrics_table, metric_name): @pytest.mark.uses_tft class LargeMovieReviewDatasetProcessTest(unittest.TestCase): + def test_process_large_movie_review_dataset(self): input_data_dir = 'gs://apache-beam-ml/datasets/aclImdb' artifact_location = os.path.join(_OUTPUT_GCS_BUCKET_ROOT, uuid.uuid4().hex) diff --git a/sdks/python/apache_beam/examples/ml_transform/vocab_tfidf_processing.py b/sdks/python/apache_beam/examples/ml_transform/vocab_tfidf_processing.py index b8ae61ce51e5..f0d5d1c652b5 100644 --- a/sdks/python/apache_beam/examples/ml_transform/vocab_tfidf_processing.py +++ b/sdks/python/apache_beam/examples/ml_transform/vocab_tfidf_processing.py @@ -66,6 +66,7 @@ def Shuffle(pcoll): class ReadAndShuffleData(beam.PTransform): + def __init__(self, pos_file_pattern, neg_file_pattern): self.pos_file_pattern = pos_file_pattern self.neg_file_pattern = neg_file_pattern @@ -94,8 +95,7 @@ def expand(self, pcoll): shuffled_examples | beam.Map( lambda label_review: { - REVIEW_COLUMN: label_review[0], - LABEL_COLUMN: label_review[1], + REVIEW_COLUMN: label_review[0], LABEL_COLUMN: label_review[1], RAW_DATA_KEY: label_review[0] })) @@ -144,6 +144,7 @@ def preprocess_data( class MapTFIDFScoreToVocab(beam.DoFn): + def __init__(self, artifact_location): self.artifact_location = artifact_location diff --git a/sdks/python/apache_beam/examples/online_clustering.py b/sdks/python/apache_beam/examples/online_clustering.py index 9746995dce77..5c71bc52353a 100644 --- a/sdks/python/apache_beam/examples/online_clustering.py +++ b/sdks/python/apache_beam/examples/online_clustering.py @@ -37,6 +37,7 @@ class SaveModel(core.DoFn): """Saves trained clustering model to persistent storage""" + def __init__(self, checkpoints_path: str): self.checkpoints_path = checkpoints_path @@ -64,6 +65,7 @@ def process(self, model): class AssignClusterLabelsFn(core.DoFn): """Takes a trained model and input data and labels all data instances using the trained model.""" + def process(self, batch, model, model_id): cluster_labels = model.predict(batch) for e, i in zip(batch, cluster_labels): @@ -72,6 +74,7 @@ def process(self, batch, model, model_id): class SelectLatestModelState(core.CombineFn): """Selects that latest version of a model after training""" + def create_accumulator(self): # create and initialise accumulator return None, 0 @@ -178,6 +181,7 @@ def process( class ConvertToNumpyArray(core.DoFn): """Helper function to convert incoming data to numpy arrays that are accepted by sklearn""" + def process(self, element, *args, **kwargs): if isinstance(element, (tuple, list)): yield np.array(element) @@ -190,6 +194,7 @@ def process(self, element, *args, **kwargs): class ClusteringPreprocessing(ptransform.PTransform): + def __init__( self, n_clusters: int, batch_size: int, is_batched: bool = False): """ Preprocessing for Clustering Transformation @@ -235,6 +240,7 @@ def expand(self, pcoll): class OnlineClustering(ptransform.PTransform): + def __init__( self, clustering_algorithm, @@ -303,6 +309,7 @@ def expand(self, pcoll): class AssignClusterLabelsRunInference(ptransform.PTransform): + def __init__(self, checkpoints_path): super().__init__() self.clustering_model = SklearnModelHandlerNumpy( @@ -318,6 +325,7 @@ def expand(self, pcoll): class AssignClusterLabelsInMemoryModel(ptransform.PTransform): + def __init__( self, model, n_clusters, batch_size, is_batched=False, model_id=None): self.model = model diff --git a/sdks/python/apache_beam/examples/per_entity_training.py b/sdks/python/apache_beam/examples/per_entity_training.py index e2b267492b27..7eef6d47c205 100644 --- a/sdks/python/apache_beam/examples/per_entity_training.py +++ b/sdks/python/apache_beam/examples/per_entity_training.py @@ -44,6 +44,7 @@ class CreateKey(beam.DoFn): + def process(self, element, *args, **kwargs): # 3rd column of the dataset is Education idx = 3 @@ -62,6 +63,7 @@ def custom_filter(element): class PrepareDataforTraining(beam.DoFn): """Preprocess data in a format suitable for training.""" + def process(self, element, *args, **kwargs): key, values = element #Convert to dataframe @@ -82,6 +84,7 @@ class TrainModel(beam.DoFn): normalizes numerical columns and then fits a decision tree classifier. """ + def process(self, element, *args, **kwargs): X, y, cat_ix, num_ix, key = element steps = [('c', OneHotEncoder(handle_unknown='ignore'), cat_ix), @@ -95,6 +98,7 @@ def process(self, element, *args, **kwargs): class ModelSink(fileio.FileSink): + def open(self, fh): self._fh = fh diff --git a/sdks/python/apache_beam/examples/snippets/snippets.py b/sdks/python/apache_beam/examples/snippets/snippets.py index c849af4a00b3..1833ae643862 100644 --- a/sdks/python/apache_beam/examples/snippets/snippets.py +++ b/sdks/python/apache_beam/examples/snippets/snippets.py @@ -80,6 +80,7 @@ class RenameFiles(PipelineVisitor): This is as close as we can get to have code snippets that are executed and are also ready to presented in webdocs. """ + def __init__(self, renames): self.renames = renames @@ -199,6 +200,7 @@ def pipeline_options_remote(): from apache_beam.options.pipeline_options import PipelineOptions class MyOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_argument('--input') @@ -256,6 +258,7 @@ def pipeline_options_local(): from apache_beam.options.pipeline_options import PipelineOptions class MyOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_argument( @@ -332,6 +335,7 @@ def pipeline_logging(lines, output): import logging class ExtractWordsFn(beam.DoFn): + def process(self, element): words = re.findall(r'[A-Za-z\']+', element) for word in words: @@ -360,12 +364,14 @@ def pipeline_monitoring(): import apache_beam as beam class ExtractWordsFn(beam.DoFn): + def process(self, element): words = re.findall(r'[A-Za-z\']+', element) for word in words: yield word class FormatCountsFn(beam.DoFn): + def process(self, element): word, count = element yield '%s: %s' % (word, count) @@ -419,6 +425,7 @@ def examples_wordcount_templated(): # [START example_wordcount_templated] class WordcountTemplatedOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): # Use add_value_provider_argument for arguments to be templatable @@ -521,11 +528,13 @@ def examples_ptransforms_templated(renames): from apache_beam.options.value_provider import StaticValueProvider class TemplatedUserOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_value_provider_argument('--templated_int', type=int) class MySumFn(beam.DoFn): + def __init__(self, templated_int): self.templated_int = templated_int @@ -556,6 +565,7 @@ def process(self, an_int): # Defining a new source. # [START model_custom_source_new_source] class CountingSource(iobase.BoundedSource): + def __init__(self, count): self.records_read = Metrics.counter(self.__class__, 'recordsRead') self._count = count @@ -609,6 +619,7 @@ class _CountingSource(CountingSource): # [START model_custom_source_new_ptransform] class ReadFromCountingSource(PTransform): + def __init__(self, count): super().__init__() self._count = count @@ -682,6 +693,7 @@ def model_custom_source(count): # # [START model_custom_sink_new_sink] class SimpleKVSink(iobase.Sink): + def __init__(self, simplekv, url, final_table_name): self._simplekv = simplekv self._url = url @@ -710,6 +722,7 @@ def finalize_write(self, access_token, table_names, pre_finalize_result): # Defining a writer for the new sink. # [START model_custom_sink_new_writer] class SimpleKVWriter(iobase.Writer): + def __init__(self, simplekv, access_token, table_name): self._simplekv = simplekv self._access_token = access_token @@ -730,6 +743,7 @@ def close(self): # [START model_custom_sink_new_ptransform] class WriteToKVSink(PTransform): + def __init__(self, simplekv, url, final_table_name): self._simplekv = simplekv super().__init__() @@ -810,6 +824,7 @@ def model_custom_sink( def model_textio(renames): """Using a Read and Write transform to read/write text files.""" + def filter_words(x): import re return re.findall(r'[A-Za-z\']+', x) @@ -1317,6 +1332,7 @@ def join_info(name, emails, phone_numbers): # [START model_library_transforms_keys] class Keys(beam.PTransform): + def expand(self, pcoll): return pcoll | 'Keys' >> beam.Map(lambda k_v: k_v[0]) @@ -1327,6 +1343,7 @@ def expand(self, pcoll): # [START model_library_transforms_count] class Count(beam.PTransform): + def expand(self, pcoll): return ( pcoll @@ -1364,11 +1381,13 @@ def accessing_valueprovider_info_after_run(): from apache_beam.options.value_provider import RuntimeValueProvider class MyOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_value_provider_argument('--string_value', type=str) class LogValueProvidersFn(beam.DoFn): + def __init__(self, string_vp): self.string_vp = string_vp @@ -1569,6 +1588,7 @@ def sdf_basic_example(): # [START SDF_BasicExample] class FileToWordsRestrictionProvider(beam.transforms.core.RestrictionProvider ): + def initial_restriction(self, file_name): return OffsetRange(0, os.stat(file_name).st_size) @@ -1576,6 +1596,7 @@ def create_tracker(self, restriction): return beam.io.restriction_trackers.OffsetRestrictionTracker() class FileToWordsFn(beam.DoFn): + def process( self, file_name, @@ -1603,6 +1624,7 @@ def sdf_basic_example_with_splitting(): # [START SDF_BasicExampleWithSplitting] class FileToWordsRestrictionProvider(beam.transforms.core.RestrictionProvider ): + def split(self, file_name, restriction): # Compute and output 64 MiB size ranges to process in parallel split_size = 64 * (1 << 20) @@ -1624,6 +1646,7 @@ class MyRestrictionProvider(object): # [START SDF_UserInitiatedCheckpoint] class MySplittableDoFn(beam.DoFn): + def process( self, element, @@ -1657,6 +1680,7 @@ def sdf_get_size(): # The RestrictionProvider is responsible for calculating the size of given # restriction. class MyRestrictionProvider(beam.transforms.core.RestrictionProvider): + def restriction_size(self, file_name, restriction): weight = 2 if "expensiveRecords" in file_name else 1 return restriction.size() * weight @@ -1665,6 +1689,7 @@ def restriction_size(self, file_name, restriction): def sdf_bad_try_claim_loop(): + class FileToWordsRestrictionProvider(object): pass @@ -1672,6 +1697,7 @@ class FileToWordsRestrictionProvider(object): # [START SDF_BadTryClaimLoop] class BadTryClaimLoop(beam.DoFn): + def process( self, file_name, @@ -1702,12 +1728,14 @@ class MyRestrictionProvider(object): # (Optional) Define a custom watermark state type to save information between # bundle processing rounds. class MyCustomerWatermarkEstimatorState(object): + def __init__(self, element, restriction): # Store data necessary for future watermark computations pass # Define a WatermarkEstimator class MyCustomWatermarkEstimator(WatermarkEstimator): + def __init__(self, estimator_state): self.state = estimator_state @@ -1727,6 +1755,7 @@ def get_estimator_state(self): # Then, a WatermarkEstimatorProvider needs to be created for this # WatermarkEstimator class MyWatermarkEstimatorProvider(WatermarkEstimatorProvider): + def initial_estimator_state(self, element, restriction): return MyCustomerWatermarkEstimatorState(element, restriction) @@ -1735,6 +1764,7 @@ def create_watermark_estimator(self, estimator_state): # Finally, define the SDF using your estimator. class MySplittableDoFn(beam.DoFn): + def process( self, element, @@ -1750,6 +1780,7 @@ def process( def sdf_truncate(): # [START SDF_Truncate] class MyRestrictionProvider(beam.transforms.core.RestrictionProvider): + def truncate(self, file_name, restriction): if "optional" in file_name: # Skip optional files @@ -1764,6 +1795,7 @@ def bundle_finalize(): # [START BundleFinalize] class MySplittableDoFn(beam.DoFn): + def process(self, element, bundle_finalizer=beam.DoFn.BundleFinalizerParam): # ... produce output ... diff --git a/sdks/python/apache_beam/examples/snippets/snippets_examples_wordcount_debugging.py b/sdks/python/apache_beam/examples/snippets/snippets_examples_wordcount_debugging.py index 8efa3d827bd0..ac3aac5ae0b8 100644 --- a/sdks/python/apache_beam/examples/snippets/snippets_examples_wordcount_debugging.py +++ b/sdks/python/apache_beam/examples/snippets/snippets_examples_wordcount_debugging.py @@ -63,6 +63,7 @@ def examples_wordcount_debugging(renames): class FilterTextFn(beam.DoFn): """A DoFn that filters for a specific key based on a regular expression.""" + def __init__(self, pattern): self.pattern = pattern # A custom metric can track values in your pipeline as it runs. Create diff --git a/sdks/python/apache_beam/examples/snippets/snippets_examples_wordcount_wordcount.py b/sdks/python/apache_beam/examples/snippets/snippets_examples_wordcount_wordcount.py index 30e6fe7dced2..021094e3e9a1 100644 --- a/sdks/python/apache_beam/examples/snippets/snippets_examples_wordcount_wordcount.py +++ b/sdks/python/apache_beam/examples/snippets/snippets_examples_wordcount_wordcount.py @@ -88,6 +88,7 @@ def CountWords(pcoll): # [START examples_wordcount_wordcount_dofn] class FormatAsTextFn(beam.DoFn): + def process(self, element): word, count = element yield '%s: %s' % (word, count) diff --git a/sdks/python/apache_beam/examples/snippets/snippets_test.py b/sdks/python/apache_beam/examples/snippets/snippets_test.py index 54a57673b5f4..3c4aba1815a0 100644 --- a/sdks/python/apache_beam/examples/snippets/snippets_test.py +++ b/sdks/python/apache_beam/examples/snippets/snippets_test.py @@ -92,6 +92,7 @@ class ParDoTest(unittest.TestCase): """Tests for model/par-do.""" + def test_pardo(self): # Note: "words" and "ComputeWordLengthFn" are referenced by name in # the text of the doc. @@ -100,6 +101,7 @@ def test_pardo(self): # [START model_pardo_pardo] class ComputeWordLengthFn(beam.DoFn): + def process(self, element): return [len(element)] @@ -116,6 +118,7 @@ def test_pardo_yield(self): # [START model_pardo_yield] class ComputeWordLengthFn(beam.DoFn): + def process(self, element): yield len(element) @@ -209,6 +212,7 @@ def test_pardo_side_input_dofn(self): # [START model_pardo_side_input_dofn] class FilterUsingLength(beam.DoFn): + def process(self, element, lower_bound, upper_bound=float('inf')): if lower_bound <= len(element) <= upper_bound: yield element @@ -220,6 +224,7 @@ def process(self, element, lower_bound, upper_bound=float('inf')): def test_pardo_with_tagged_outputs(self): # [START model_pardo_emitting_values_on_tagged_outputs] class ProcessWords(beam.DoFn): + def process(self, element, cutoff_length, marker): if len(element) <= cutoff_length: # Emit this short word to the main output. @@ -289,6 +294,7 @@ def even_odd(x): class TypeHintsTest(unittest.TestCase): + def test_bad_types(self): # [START type_hints_missing_define_numbers] p = TestPipeline() @@ -322,6 +328,7 @@ def test_bad_types(self): # [START type_hints_do_fn] @beam.typehints.with_input_types(int) class FilterEvensDoFn(beam.DoFn): + def process(self, element): if element % 2 == 0: yield element @@ -340,6 +347,7 @@ def process(self, element): @beam.typehints.with_input_types(T) @beam.typehints.with_output_types(Tuple[int, T]) class MyTransform(beam.PTransform): + def expand(self, pcoll): return pcoll | beam.Map(lambda x: (len(x), x)) @@ -362,6 +370,7 @@ def test_bad_types_annotations(self): # pylint: disable=expression-not-assigned # pylint: disable=unused-variable class FilterEvensDoFn(beam.DoFn): + def process(self, element): if element % 2 == 0: yield element @@ -383,6 +392,7 @@ def process(self, element): from typing import Iterable class TypedFilterEvensDoFn(beam.DoFn): + def process(self, element: int) -> Iterable[int]: if element % 2 == 0: yield element @@ -397,6 +407,7 @@ def process(self, element: int) -> Iterable[int]: from typing import List, Optional class FilterEvensDoubleDoFn(beam.DoFn): + def process(self, element: int) -> Optional[List[int]]: if element % 2 == 0: return [element, element] @@ -420,6 +431,7 @@ def my_fn(element: int) -> str: from apache_beam.pvalue import PCollection class IntToStr(beam.PTransform): + def expand(self, pcoll: PCollection[int]) -> PCollection[str]: return pcoll | beam.Map(lambda elem: str(elem)) @@ -463,11 +475,13 @@ def test_deterministic_key(self): from typing import Tuple class Player(object): + def __init__(self, team, name): self.team = team self.name = name class PlayerCoder(beam.coders.Coder): + def encode(self, player): return ('%s:%s' % (player.team, player.name)).encode('utf-8') @@ -502,11 +516,13 @@ class DummyReadTransform(beam.PTransform): To be used for testing. """ + def __init__(self, file_to_read=None, compression_type=None): self.file_to_read = file_to_read self.compression_type = compression_type class ReadDoFn(beam.DoFn): + def __init__(self, file_to_read, compression_type): self.file_to_read = file_to_read self.compression_type = compression_type @@ -541,10 +557,12 @@ class DummyWriteTransform(beam.PTransform): To be used for testing. """ + def __init__(self, file_to_write=None, file_name_suffix=''): self.file_to_write = file_to_write class WriteDoFn(beam.DoFn): + def __init__(self, file_to_write): self.file_to_write = file_to_write self.file_obj = None @@ -643,6 +661,7 @@ def test_model_custom_sink(self): tempdir_name = tempfile.mkdtemp() class SimpleKV(object): + def __init__(self, tmp_dir): self._dummy_token = 'dummy_token' self._tmp_dir = tmp_dir @@ -841,6 +860,7 @@ def test_examples_wordcount_debugging(self): @mock.patch('apache_beam.io.ReadFromPubSub') @mock.patch('apache_beam.io.WriteToPubSub') def test_examples_wordcount_streaming(self, *unused_mocks): + def FakeReadFromPubSub(topic=None, subscription=None, values=None): expected_topic = topic expected_subscription = subscription @@ -853,6 +873,7 @@ def _inner(topic=None, subscription=None): return _inner class AssertTransform(beam.PTransform): + def __init__(self, matcher): self.matcher = matcher @@ -998,8 +1019,8 @@ def test_model_co_group_by_key_tuple(self): ] # [END model_group_by_key_cogroupbykey_tuple_formatted_outputs] expected_results = [ - '%s; %s; %s' % (name, info['emails'], info['phones']) for name, - info in results + '%s; %s; %s' % (name, info['emails'], info['phones']) + for name, info in results ] self.assertEqual(expected_results, formatted_results) self.assertEqual(formatted_results, self.get_output(result_path)) @@ -1017,6 +1038,7 @@ def test_model_use_and_query_metrics(self): # [START metrics_usage_example] class FilterTextFn(beam.DoFn): """A DoFn that filters for a specific key based on a regex.""" + def __init__(self, pattern): self.pattern = pattern # A custom metric can track values in your pipeline as it runs. Create @@ -1192,6 +1214,7 @@ def test_model_other_composite_triggers(self): class CombineTest(unittest.TestCase): """Tests for model/combine.""" + def test_global_sum(self): pc = [1, 2, 3] # [START global_sum] @@ -1259,6 +1282,7 @@ def test_custom_average(self): # [START combine_custom_average_define] class AverageFn(beam.CombineFn): + def create_accumulator(self): return (0.0, 0) @@ -1380,6 +1404,7 @@ def extract_timestamp_from_log_entry(entry): # [START setting_timestamp] class AddTimestampDoFn(beam.DoFn): + def process(self, element): # Extract the numeric Unix seconds-since-epoch timestamp to be # associated with the current log entry. @@ -1403,10 +1428,12 @@ def process(self, element): class PTransformTest(unittest.TestCase): """Tests for PTransform.""" + def test_composite(self): # [START model_composite_transform] class ComputeWordLengths(beam.PTransform): + def expand(self, pcoll): # Transform logic goes here. return pcoll | beam.Map(lambda x: len(x)) @@ -1420,6 +1447,7 @@ def expand(self, pcoll): class SlowlyChangingSideInputsTest(unittest.TestCase): """Tests for PTransform.""" + def test_side_input_slow_update(self): temp_file = tempfile.NamedTemporaryFile(delete=True) src_file_pattern = temp_file.name @@ -1441,12 +1469,13 @@ def test_side_input_slow_update(self): for j in range(count): f.write('f' + idstr + 'a' + str(j) + '\n') - sample_main_input_elements = ([first_ts - 2, # no output due to no SI - first_ts + 1, # First window - first_ts + 8, # Second window - first_ts + 15, # Third window - first_ts + 22, # Fourth window - ]) + sample_main_input_elements = ([ + first_ts - 2, # no output due to no SI + first_ts + 1, # First window + first_ts + 8, # Second window + first_ts + 15, # Third window + first_ts + 22, # Fourth window + ]) pipeline, pipeline_result = snippets.side_input_slow_update( src_file_pattern, first_ts, last_ts, interval, diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/approximatequantiles_test.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/approximatequantiles_test.py index 2adfcd05b99a..e8ec7e794cbe 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/approximatequantiles_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/approximatequantiles_test.py @@ -32,10 +32,11 @@ @mock.patch('apache_beam.Pipeline', TestPipeline) @mock.patch( 'apache_beam.examples.snippets.transforms.aggregation.' - 'approximatequantiles.print', - lambda x: x) + 'approximatequantiles.print', lambda x: x) class ApproximateQuantilesTest(unittest.TestCase): + def test_approximatequantiles(self): + def check_result(quantiles): assert_that(quantiles, equal_to([[0, 250, 500, 750, 1000]])) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/approximateunique_test.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/approximateunique_test.py index c945cec534b8..96dd1a58746c 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/approximateunique_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/approximateunique_test.py @@ -33,10 +33,11 @@ @mock.patch('apache_beam.Pipeline', TestPipeline) @mock.patch( 'apache_beam.examples.snippets.transforms.aggregation.' - 'approximateunique.print', - lambda x: x) + 'approximateunique.print', lambda x: x) class ApproximateUniqueTest(unittest.TestCase): + def test_approximateunique(self): + def check_result(approx_count): actual_count = 1000 sample_size = 16 diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/batchelements_test.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/batchelements_test.py index b370f8cd16be..bd29fa291ef5 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/batchelements_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/batchelements_test.py @@ -57,6 +57,7 @@ def identity(x): identity) # pylint: enable=line-too-long class BatchElementsTest(unittest.TestCase): + def test_batchelements(self): batchelements.batchelements(check_batches) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/cogroupbykey_test.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/cogroupbykey_test.py index ad4ed99a6e2b..e4d31b36a0a3 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/cogroupbykey_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/cogroupbykey_test.py @@ -51,6 +51,7 @@ def normalize_element(elem): 'apache_beam.examples.snippets.transforms.aggregation.cogroupbykey.print', str) class CoGroupByKeyTest(unittest.TestCase): + def test_cogroupbykey(self): cogroupbykey.cogroupbykey(check_plants) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combineglobally_combinefn.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combineglobally_combinefn.py index 18e15e9e9bb9..75fae060fa70 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combineglobally_combinefn.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combineglobally_combinefn.py @@ -39,6 +39,7 @@ def combineglobally_combinefn(test=None): import apache_beam as beam class PercentagesFn(beam.CombineFn): + def create_accumulator(self): return {} diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combineglobally_test.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combineglobally_test.py index e7435c644e24..b8886ba4e1e1 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combineglobally_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combineglobally_test.py @@ -81,6 +81,7 @@ def check_percentages(actual): str) # pylint: enable=line-too-long class CombineGloballyTest(unittest.TestCase): + def test_combineglobally_function(self): combineglobally_function.combineglobally_function(check_common_items) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combineperkey_combinefn.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combineperkey_combinefn.py index 5b5606401b5f..4b9f2d03bb15 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combineperkey_combinefn.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combineperkey_combinefn.py @@ -39,6 +39,7 @@ def combineperkey_combinefn(test=None): import apache_beam as beam class AverageFn(beam.CombineFn): + def create_accumulator(self): sum = 0.0 count = 0 diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combineperkey_side_inputs_singleton.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combineperkey_side_inputs_singleton.py index b20571bde62d..eb6182667173 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combineperkey_side_inputs_singleton.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combineperkey_side_inputs_singleton.py @@ -37,8 +37,7 @@ def combineperkey_side_inputs_singleton(test=None): ('🍅', 3), ]) | 'Saturated sum' >> beam.CombinePerKey( - lambda values, - max_value: min(sum(values), max_value), + lambda values, max_value: min(sum(values), max_value), max_value=beam.pvalue.AsSingleton(max_value)) | beam.Map(print)) # [END combineperkey_side_inputs_singleton] diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combineperkey_test.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combineperkey_test.py index 6492734246af..1162ade2c27c 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combineperkey_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combineperkey_test.py @@ -101,6 +101,7 @@ def check_average(actual): 'apache_beam.examples.snippets.transforms.aggregation.combineperkey_combinefn.print', str) class CombinePerKeyTest(unittest.TestCase): + def test_combineperkey_simple(self): combineperkey_simple.combineperkey_simple(check_total) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combinevalues_combinefn.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combinevalues_combinefn.py index d32b33424484..0b3028507bac 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combinevalues_combinefn.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combinevalues_combinefn.py @@ -39,6 +39,7 @@ def combinevalues_combinefn(test=None): import apache_beam as beam class AverageFn(beam.CombineFn): + def create_accumulator(self): return {} diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combinevalues_side_inputs_singleton.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combinevalues_side_inputs_singleton.py index efec1635c18e..c07697b2c88b 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combinevalues_side_inputs_singleton.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combinevalues_side_inputs_singleton.py @@ -49,8 +49,7 @@ def combinevalues_side_inputs_singleton(test=None): ('🍅', [4, 5, 3]), ]) | 'Saturated sum' >> beam.CombineValues( - lambda values, - max_value: min(sum(values), max_value), + lambda values, max_value: min(sum(values), max_value), max_value=beam.pvalue.AsSingleton(max_value)) | beam.Map(print)) # [END combinevalues_side_inputs_singleton] diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combinevalues_test.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combinevalues_test.py index 8693e3f300fd..3485306cff69 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combinevalues_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/combinevalues_test.py @@ -95,6 +95,7 @@ def check_percentages_per_season(actual): 'apache_beam.examples.snippets.transforms.aggregation.combinevalues_combinefn.print', str) class CombineValuesTest(unittest.TestCase): + def test_combinevalues_function(self): combinevalues_function.combinevalues_function(check_saturated_total) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/count_test.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/count_test.py index 59f95a1eec1e..cf2b0a03e744 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/count_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/count_test.py @@ -70,6 +70,7 @@ def check_total_unique_elements(actual): 'apache_beam.examples.snippets.transforms.aggregation.count_per_element.print', str) class CountTest(unittest.TestCase): + def test_count_globally(self): count_globally.count_globally(check_total_elements) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/distinct_test.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/distinct_test.py index 6f1246ee363a..098a62499d89 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/distinct_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/distinct_test.py @@ -41,6 +41,7 @@ def check_unique_elements(actual): @mock.patch( 'apache_beam.examples.snippets.transforms.aggregation.distinct.print', str) class DistinctTest(unittest.TestCase): + def test_distinct(self): distinct.distinct(check_unique_elements) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/groupbykey_test.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/groupbykey_test.py index 1a9b8f158c1a..33f77af17310 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/groupbykey_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/groupbykey_test.py @@ -45,6 +45,7 @@ def check_produce_counts(actual): 'apache_beam.examples.snippets.transforms.aggregation.groupbykey.print', str) class GroupByKeyTest(unittest.TestCase): + def test_groupbykey(self): groupbykey.groupbykey(check_produce_counts) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/groupintobatches_test.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/groupintobatches_test.py index 09449f77f6e4..4619c1d68b4b 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/groupintobatches_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/groupintobatches_test.py @@ -45,6 +45,7 @@ def check_batches_with_keys(actual): str) # pylint: enable=line-too-long class GroupIntoBatchesTest(unittest.TestCase): + def test_groupintobatches(self): groupintobatches.groupintobatches(check_batches_with_keys) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/latest_test.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/latest_test.py index 0661fa862870..a8e4531d21aa 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/latest_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/latest_test.py @@ -55,6 +55,7 @@ def check_latest_elements_per_key(actual): 'apache_beam.examples.snippets.transforms.aggregation.latest_per_key.print', str) class LatestTest(unittest.TestCase): + def test_latest_globally(self): latest_globally.latest_globally(check_latest_element) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/max_test.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/max_test.py index 1851d555d282..7d0f04cd649c 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/max_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/max_test.py @@ -53,6 +53,7 @@ def check_elements_with_max_value_per_key(actual): 'apache_beam.examples.snippets.transforms.aggregation.max_per_key.print', str) class MaxTest(unittest.TestCase): + def test_max_globally(self): beam_max_globally.max_globally(check_max_element) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/mean_test.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/mean_test.py index 7d1420ef0ce8..b9336c66a150 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/mean_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/mean_test.py @@ -53,6 +53,7 @@ def check_elements_with_mean_value_per_key(actual): 'apache_beam.examples.snippets.transforms.aggregation.mean_per_key.print', str) class MeanTest(unittest.TestCase): + def test_mean_globally(self): mean_globally.mean_globally(check_mean_element) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/min_test.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/min_test.py index 65b970800deb..182682f0ce1a 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/min_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/min_test.py @@ -51,6 +51,7 @@ def check_elements_with_min_value_per_key(actual): 'apache_beam.examples.snippets.transforms.aggregation.min_per_key.print', str) class MinTest(unittest.TestCase): + def test_min_globally(self): beam_min_globally.min_globally(check_min_element) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/sample_test.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/sample_test.py index a4ecde5bdaba..74716499dd20 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/sample_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/sample_test.py @@ -57,6 +57,7 @@ def check_samples_per_key(actual): 'apache_beam.examples.snippets.transforms.aggregation.sample_fixed_size_per_key.print', str) class SampleTest(unittest.TestCase): + def test_sample_fixed_size_globally(self): sample_fixed_size_globally.sample_fixed_size_globally(check_sample) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/sum_test.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/sum_test.py index 122b8c892d71..d926644e38f3 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/sum_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/sum_test.py @@ -51,6 +51,7 @@ def check_totals_per_key(actual): 'apache_beam.examples.snippets.transforms.aggregation.sum_per_key.print', str) class SumTest(unittest.TestCase): + def test_sum_globally(self): beam_sum_globally.sum_globally(check_total) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/tolist_test.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/tolist_test.py index 8f5235000c49..d46c105d6bdf 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/tolist_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/tolist_test.py @@ -39,7 +39,9 @@ def identity(x): identity) # pylint: enable=line-too-long class BatchElementsTest(unittest.TestCase): + def test_tolist(self): + def check(result): assert_that( result diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/top_of.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/top_of.py index 51094a993de9..8596e9fcfec8 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/top_of.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/top_of.py @@ -48,12 +48,11 @@ def top_of(test=None): '🌽 Corn', ]) | 'Shortest names' >> beam.combiners.Top.Of( - 2, # number of elements - key=len, # optional, defaults to the element itself + 2, # number of elements + key=len, # optional, defaults to the element itself reverse=True, # optional, defaults to False (largest/descending) ) - | beam.Map(print) - ) + | beam.Map(print)) # [END top_of] if test: test(shortest_elements) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/top_per_key.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/top_per_key.py index 676f10ffc310..246de87f55ea 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/top_per_key.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/top_per_key.py @@ -51,12 +51,11 @@ def top_per_key(test=None): ('winter', '🍆 Eggplant'), ]) | 'Shortest names per key' >> beam.combiners.Top.PerKey( - 2, # number of elements - key=len, # optional, defaults to the value itself + 2, # number of elements + key=len, # optional, defaults to the value itself reverse=True, # optional, defaults to False (largest/descending) ) - | beam.Map(print) - ) + | beam.Map(print)) # [END top_per_key] if test: test(shortest_elements_per_key) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/top_test.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/top_test.py index e928088f2ed5..20b360cfe096 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/top_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/top_test.py @@ -101,6 +101,7 @@ def check_shortest_elements_per_key(actual): 'apache_beam.examples.snippets.transforms.aggregation.top_per_key.print', str) class TopTest(unittest.TestCase): + def test_top_largest(self): top_largest.top_largest(check_largest_elements) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py index 8a7cdfbe9263..506773000893 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py @@ -62,6 +62,7 @@ def validate_enrichment_with_vertex_ai_legacy(): @mock.patch('sys.stdout', new_callable=StringIO) class EnrichmentTest(unittest.TestCase): + def test_enrichment_with_bigtable(self, mock_stdout): enrichment_with_bigtable() output = mock_stdout.getvalue().splitlines() diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/filter_side_inputs_dict.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/filter_side_inputs_dict.py index 64a4b0aa97c5..a969765e4d12 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/filter_side_inputs_dict.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/filter_side_inputs_dict.py @@ -65,8 +65,7 @@ def filter_side_inputs_dict(test=None): }, ]) | 'Filter plants by duration' >> beam.Filter( - lambda plant, - keep_duration: keep_duration[plant['duration']], + lambda plant, keep_duration: keep_duration[plant['duration']], keep_duration=beam.pvalue.AsDict(keep_duration), ) | beam.Map(print)) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/filter_side_inputs_iter.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/filter_side_inputs_iter.py index 42043a38c35b..9db1c61443e4 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/filter_side_inputs_iter.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/filter_side_inputs_iter.py @@ -65,8 +65,7 @@ def filter_side_inputs_iter(test=None): }, ]) | 'Filter valid plants' >> beam.Filter( - lambda plant, - valid_durations: plant['duration'] in valid_durations, + lambda plant, valid_durations: plant['duration'] in valid_durations, valid_durations=beam.pvalue.AsIter(valid_durations), ) | beam.Map(print)) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/filter_side_inputs_singleton.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/filter_side_inputs_singleton.py index 5971082becd3..34662b4e257b 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/filter_side_inputs_singleton.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/filter_side_inputs_singleton.py @@ -61,8 +61,7 @@ def filter_side_inputs_singleton(test=None): }, ]) | 'Filter perennials' >> beam.Filter( - lambda plant, - duration: plant['duration'] == duration, + lambda plant, duration: plant['duration'] == duration, duration=beam.pvalue.AsSingleton(perennial), ) | beam.Map(print)) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/filter_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/filter_test.py index fee9e2f03ead..67aeddad3989 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/filter_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/filter_test.py @@ -73,6 +73,7 @@ def check_valid_plants(actual): 'apache_beam.examples.snippets.transforms.elementwise.filter_side_inputs_dict.print', str) class FilterTest(unittest.TestCase): + def test_filter_function(self): filter_function.filter_function(check_perennials) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/flatmap_side_inputs_singleton.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/flatmap_side_inputs_singleton.py index 84aed00b1046..6abf3f485a18 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/flatmap_side_inputs_singleton.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/flatmap_side_inputs_singleton.py @@ -48,8 +48,7 @@ def flatmap_side_inputs_singleton(test=None): '🍅Tomato,🥔Potato', ]) | 'Split words' >> beam.FlatMap( - lambda text, - delimiter: text.split(delimiter), + lambda text, delimiter: text.split(delimiter), delimiter=beam.pvalue.AsSingleton(delimiter), ) | beam.Map(print)) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/flatmap_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/flatmap_test.py index 6dd02a208a74..7bd6e74dda12 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/flatmap_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/flatmap_test.py @@ -91,6 +91,7 @@ def check_valid_plants(actual): 'apache_beam.examples.snippets.transforms.elementwise.flatmap_side_inputs_dict.print', str) class FlatMapTest(unittest.TestCase): + def test_flatmap_simple(self): flatmap_simple.flatmap_simple(check_plants) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/keys_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/keys_test.py index fcfe370234f1..6dee2d5f98f2 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/keys_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/keys_test.py @@ -43,6 +43,7 @@ def check_icons(actual): @mock.patch( 'apache_beam.examples.snippets.transforms.elementwise.keys.print', str) class KeysTest(unittest.TestCase): + def test_keys(self): keys.keys(check_icons) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/kvswap_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/kvswap_test.py index fa494314bd3b..549245b4d7d0 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/kvswap_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/kvswap_test.py @@ -43,6 +43,7 @@ def check_plants(actual): @mock.patch( 'apache_beam.examples.snippets.transforms.elementwise.kvswap.print', str) class KvSwapTest(unittest.TestCase): + def test_kvswap(self): kvswap.kvswap(check_plants) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/map_context.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/map_context.py index 26b5558928cb..a827e72b3f3a 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/map_context.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/map_context.py @@ -71,9 +71,8 @@ def random_nonce(): ], reshuffle=False) | 'Strip header' >> beam.Map( - lambda text, - a=beam.DoFn.SetupContextParam(random_nonce), - b=beam.DoFn.BundleContextParam(random_nonce): f"{text} {a} {b}") + lambda text, a=beam.DoFn.SetupContextParam(random_nonce), b=beam. + DoFn.BundleContextParam(random_nonce): f"{text} {a} {b}") | beam.Map(print)) # [END map_context] if test: diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/map_side_inputs_iter.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/map_side_inputs_iter.py index c155764a0cd4..f0be086e918c 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/map_side_inputs_iter.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/map_side_inputs_iter.py @@ -51,8 +51,7 @@ def map_side_inputs_iter(test=None): '# 🥔Potato\n', ]) | 'Strip header' >> beam.Map( - lambda text, - chars: text.strip(''.join(chars)), + lambda text, chars: text.strip(''.join(chars)), chars=beam.pvalue.AsIter(chars), ) | beam.Map(print)) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/map_side_inputs_singleton.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/map_side_inputs_singleton.py index 323134e315ea..8dc6ebb354a9 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/map_side_inputs_singleton.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/map_side_inputs_singleton.py @@ -51,8 +51,7 @@ def map_side_inputs_singleton(test=None): '# 🥔Potato\n', ]) | 'Strip header' >> beam.Map( - lambda text, - chars: text.strip(chars), + lambda text, chars: text.strip(chars), chars=beam.pvalue.AsSingleton(chars), ) | beam.Map(print)) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/map_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/map_test.py index db766bb8e7c2..2b0e2c06aa37 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/map_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/map_test.py @@ -88,6 +88,7 @@ def check_plant_details(actual): 'apache_beam.examples.snippets.transforms.elementwise.map_context.print', str) class MapTest(unittest.TestCase): + def test_map_simple(self): map_simple.map_simple(check_plants) @@ -116,6 +117,7 @@ def test_map_context(self): import re def check_nonces(output): + def shares_same_nonces(elements): s = set(re.search(r'\d+ \d+', e).group(0) for e in elements) assert len(s) == 1, s diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/mltransform_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/mltransform_test.py index 261b480b1083..f046e1bd4242 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/mltransform_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/mltransform_test.py @@ -68,6 +68,7 @@ def check_mltransform_compute_and_apply_vocabulary_with_scalar(): @mock.patch('apache_beam.Pipeline', TestPipeline) @mock.patch('sys.stdout', new_callable=StringIO) class MLTransformStdOutTest(unittest.TestCase): + def test_mltransform_compute_and_apply_vocab(self, mock_stdout): mltransform_compute_and_apply_vocabulary() predicted = mock_stdout.getvalue().splitlines() diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/pardo_dofn.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/pardo_dofn.py index 0a6c27a21d8f..eebb847c5ee8 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/pardo_dofn.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/pardo_dofn.py @@ -38,6 +38,7 @@ def pardo_dofn(test=None): import apache_beam as beam class SplitWords(beam.DoFn): + def __init__(self, delimiter=','): self.delimiter = delimiter diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/pardo_dofn_methods.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/pardo_dofn_methods.py index 868519602569..cd0f1d8015aa 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/pardo_dofn_methods.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/pardo_dofn_methods.py @@ -38,6 +38,7 @@ def pardo_dofn_methods(test=None): import apache_beam as beam class DoFnMethods(beam.DoFn): + def __init__(self): print('__init__') self.window = beam.transforms.window.GlobalWindow() diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/pardo_dofn_params.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/pardo_dofn_params.py index ae777836d210..1f78aa13b06d 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/pardo_dofn_params.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/pardo_dofn_params.py @@ -39,6 +39,7 @@ def pardo_dofn_params(test=None): import apache_beam as beam class AnalyzeElement(beam.DoFn): + def process( self, elem, diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/pardo_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/pardo_test.py index 1b72640f5c5c..295834389365 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/pardo_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/pardo_test.py @@ -89,6 +89,7 @@ def check_dofn_methods(actual): 'apache_beam.examples.snippets.transforms.elementwise.pardo_dofn_params.print', str) class ParDoTest(unittest.TestCase): + def test_pardo_dofn(self): pardo_dofn.pardo_dofn(check_plants) @@ -99,6 +100,7 @@ def test_pardo_dofn_params(self): @mock.patch('apache_beam.Pipeline', TestPipeline) @mock.patch('sys.stdout', new_callable=StringIO) class ParDoStdoutTest(unittest.TestCase): + def test_pardo_dofn_methods(self, mock_stdout): expected = pardo_dofn_methods.pardo_dofn_methods(check_dofn_methods) actual = mock_stdout.getvalue().splitlines() diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/partition_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/partition_test.py index 3ffc49abe758..80bc614bc1b5 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/partition_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/partition_test.py @@ -91,6 +91,7 @@ def check_split_datasets(actual1, actual2): 'apache_beam.examples.snippets.transforms.elementwise.partition_multiple_arguments.print', lambda elem: elem) class PartitionTest(unittest.TestCase): + def test_partition_function(self): partition_function.partition_function(check_partitions) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/regex_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/regex_test.py index 2ae53089173b..310d7c83f5d3 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/regex_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/regex_test.py @@ -156,6 +156,7 @@ def check_split(actual): 'apache_beam.examples.snippets.transforms.elementwise.regex_split.print', str) class RegexTest(unittest.TestCase): + def test_matches(self): regex_matches.regex_matches(check_matches) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/runinference.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/runinference.py index 8021a60929a1..2a041ae9c2c8 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/runinference.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/runinference.py @@ -23,6 +23,7 @@ class LinearRegression(torch.nn.Module): + def __init__(self, input_dim=1, output_dim=1): super().__init__() self.linear = torch.nn.Linear(input_dim, output_dim) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/runinference_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/runinference_test.py index 8dd46c659f39..90eed2f7b515 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/runinference_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/runinference_test.py @@ -93,6 +93,7 @@ def check_sklearn_unkeyed_model_handler(actual): 'apache_beam.examples.snippets.transforms.elementwise.runinference_sklearn_keyed_model_handler.print', str) class RunInferenceTest(unittest.TestCase): + def test_sklearn_unkeyed_model_handler(self): runinference_sklearn_unkeyed_model_handler.sklearn_unkeyed_model_handler( check_sklearn_unkeyed_model_handler) @@ -105,6 +106,7 @@ def test_sklearn_keyed_model_handler(self): @mock.patch('apache_beam.Pipeline', TestPipeline) @mock.patch('sys.stdout', new_callable=StringIO) class RunInferenceStdoutTest(unittest.TestCase): + @pytest.mark.uses_pytorch def test_check_torch_keyed_model_handler(self, mock_stdout): runinference.torch_keyed_model_handler() diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/tostring_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/tostring_test.py index e63282a75737..c263b3e18d06 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/tostring_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/tostring_test.py @@ -75,6 +75,7 @@ def check_plants_csv(actual): 'apache_beam.examples.snippets.transforms.elementwise.tostring_iterables.print', str) class ToStringTest(unittest.TestCase): + def test_tostring_kvs(self): tostring_kvs.tostring_kvs(check_plants) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/values_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/values_test.py index acd84215a69d..77d6d4b39795 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/values_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/values_test.py @@ -43,6 +43,7 @@ def check_plants(actual): @mock.patch( 'apache_beam.examples.snippets.transforms.elementwise.values.print', str) class ValuesTest(unittest.TestCase): + def test_values(self): values.values(check_plants) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/withtimestamps_event_time.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/withtimestamps_event_time.py index c5e4013e0300..3f2cbb5ea327 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/withtimestamps_event_time.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/withtimestamps_event_time.py @@ -38,6 +38,7 @@ def withtimestamps_event_time(test=None): import apache_beam as beam class GetTimestamp(beam.DoFn): + def process(self, plant, timestamp=beam.DoFn.TimestampParam): yield '{} - {}'.format(timestamp.to_utc_datetime(), plant['name']) @@ -45,17 +46,26 @@ def process(self, plant, timestamp=beam.DoFn.TimestampParam): plant_timestamps = ( pipeline | 'Garden plants' >> beam.Create([ - {'name': 'Strawberry', 'season': 1585699200}, # April, 2020 - {'name': 'Carrot', 'season': 1590969600}, # June, 2020 - {'name': 'Artichoke', 'season': 1583020800}, # March, 2020 - {'name': 'Tomato', 'season': 1588291200}, # May, 2020 - {'name': 'Potato', 'season': 1598918400}, # September, 2020 + { + 'name': 'Strawberry', 'season': 1585699200 + }, # April, 2020 + { + 'name': 'Carrot', 'season': 1590969600 + }, # June, 2020 + { + 'name': 'Artichoke', 'season': 1583020800 + }, # March, 2020 + { + 'name': 'Tomato', 'season': 1588291200 + }, # May, 2020 + { + 'name': 'Potato', 'season': 1598918400 + }, # September, 2020 ]) | 'With timestamps' >> beam.Map( lambda plant: beam.window.TimestampedValue(plant, plant['season'])) | 'Get timestamp' >> beam.ParDo(GetTimestamp()) - | beam.Map(print) - ) + | beam.Map(print)) # [END withtimestamps_event_time] if test: test(plant_timestamps) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/withtimestamps_logical_clock.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/withtimestamps_logical_clock.py index 9a2ba36bacd6..5ce187d807d6 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/withtimestamps_logical_clock.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/withtimestamps_logical_clock.py @@ -38,6 +38,7 @@ def withtimestamps_logical_clock(test=None): import apache_beam as beam class GetTimestamp(beam.DoFn): + def process(self, plant, timestamp=beam.DoFn.TimestampParam): event_id = int(timestamp.micros / 1e6) # equivalent to seconds yield '{} - {}'.format(event_id, plant['name']) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/withtimestamps_processing_time.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/withtimestamps_processing_time.py index eacb02d99850..871c16b2bc09 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/withtimestamps_processing_time.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/withtimestamps_processing_time.py @@ -39,6 +39,7 @@ def withtimestamps_processing_time(test=None): import time class GetTimestamp(beam.DoFn): + def process(self, plant, timestamp=beam.DoFn.TimestampParam): yield '{} - {}'.format(timestamp.to_utc_datetime(), plant['name']) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/withtimestamps_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/withtimestamps_test.py index 9a86a6f54c5e..9184b917e538 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/withtimestamps_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/withtimestamps_test.py @@ -82,6 +82,7 @@ def check_plant_processing_times(actual): 'apache_beam.examples.snippets.transforms.elementwise.withtimestamps_processing_time.print', str) class WithTimestampsTest(unittest.TestCase): + def test_event_time(self): withtimestamps_event_time.withtimestamps_event_time(check_plant_timestamps) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/other/create_test.py b/sdks/python/apache_beam/examples/snippets/transforms/other/create_test.py index 0457fdae7779..4ed85b4c9ee8 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/other/create_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/other/create_test.py @@ -43,6 +43,7 @@ def check_create(actual): @mock.patch('apache_beam.Pipeline', TestPipeline) @mock.patch('apache_beam.examples.snippets.transforms.other.create.print', str) class CreateTest(unittest.TestCase): + def test_create(self): create.create(check_create) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/other/flatten_test.py b/sdks/python/apache_beam/examples/snippets/transforms/other/flatten_test.py index bd029c6eeb28..5fe473d5a82e 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/other/flatten_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/other/flatten_test.py @@ -45,6 +45,7 @@ def check_flatten(actual): @mock.patch('apache_beam.Pipeline', TestPipeline) @mock.patch('apache_beam.examples.snippets.transforms.other.flatten.print', str) class FlattenTest(unittest.TestCase): + def test_flatten(self): flatten.flatten(check_flatten) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/other/window.py b/sdks/python/apache_beam/examples/snippets/transforms/other/window.py index 484917e77658..8f7a1047792c 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/other/window.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/other/window.py @@ -41,22 +41,38 @@ def window(test=None): import apache_beam as beam with beam.Pipeline() as pipeline: - produce = (pipeline - | 'Garden plants' >> beam.Create([ - {'name': 'Strawberry', 'season': 1585699200}, # April, 2020 - {'name': 'Strawberry', 'season': 1588291200}, # May, 2020 - {'name': 'Carrot', 'season': 1590969600}, # June, 2020 - {'name': 'Artichoke', 'season': 1583020800}, # March, 2020 - {'name': 'Artichoke', 'season': 1585699200}, # April, 2020 - {'name': 'Tomato', 'season': 1588291200}, # May, 2020 - {'name': 'Potato', 'season': 1598918400}, # September, 2020 - ]) - | 'With timestamps' >> beam.Map(lambda plant: beam.window.TimestampedValue(plant['name'], plant['season'])) - | 'Window into fixed 2-month windows' >> beam.WindowInto( - beam.window.FixedWindows(2 * 30 * 24 * 60 * 60)) - | 'Count per window' >> beam.combiners.Count.PerElement() - | 'Print results' >> beam.Map(print) - ) + produce = ( + pipeline + | 'Garden plants' >> beam.Create([ + { + 'name': 'Strawberry', 'season': 1585699200 + }, # April, 2020 + { + 'name': 'Strawberry', 'season': 1588291200 + }, # May, 2020 + { + 'name': 'Carrot', 'season': 1590969600 + }, # June, 2020 + { + 'name': 'Artichoke', 'season': 1583020800 + }, # March, 2020 + { + 'name': 'Artichoke', 'season': 1585699200 + }, # April, 2020 + { + 'name': 'Tomato', 'season': 1588291200 + }, # May, 2020 + { + 'name': 'Potato', 'season': 1598918400 + }, # September, 2020 + ]) + | 'With timestamps' >> beam.Map( + lambda plant: beam.window.TimestampedValue( + plant['name'], plant['season'])) + | 'Window into fixed 2-month windows' >> beam.WindowInto( + beam.window.FixedWindows(2 * 30 * 24 * 60 * 60)) + | 'Count per window' >> beam.combiners.Count.PerElement() + | 'Print results' >> beam.Map(print)) # [END window] if test: diff --git a/sdks/python/apache_beam/examples/snippets/transforms/other/window_test.py b/sdks/python/apache_beam/examples/snippets/transforms/other/window_test.py index 6fb53b04b287..a95a4ac80fd6 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/other/window_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/other/window_test.py @@ -43,6 +43,7 @@ def check_window(actual): @mock.patch('apache_beam.Pipeline', TestPipeline) @mock.patch('apache_beam.examples.snippets.transforms.other.window.print', str) class WindowTest(unittest.TestCase): + def test_window(self): window.window(check_window) diff --git a/sdks/python/apache_beam/examples/snippets/util.py b/sdks/python/apache_beam/examples/snippets/util.py index 7911b32ce4d3..181e90639ca0 100644 --- a/sdks/python/apache_beam/examples/snippets/util.py +++ b/sdks/python/apache_beam/examples/snippets/util.py @@ -37,6 +37,7 @@ def assert_matches_stdout( comparing them. Can be used to sort lists before comparing. label (str): [optional] Label to make transform names unique. """ + def stdout_to_python_object(elem_str): try: elem = ast.literal_eval(elem_str) diff --git a/sdks/python/apache_beam/examples/snippets/util_test.py b/sdks/python/apache_beam/examples/snippets/util_test.py index fae26a339d53..857ca100a70e 100644 --- a/sdks/python/apache_beam/examples/snippets/util_test.py +++ b/sdks/python/apache_beam/examples/snippets/util_test.py @@ -28,6 +28,7 @@ class UtilTest(unittest.TestCase): + def test_assert_matches_stdout_object(self): expected = [ "{'a': '🍓', 'b': True}", diff --git a/sdks/python/apache_beam/examples/sql_taxi.py b/sdks/python/apache_beam/examples/sql_taxi.py index e8a29806d72a..9cee37305f68 100644 --- a/sdks/python/apache_beam/examples/sql_taxi.py +++ b/sdks/python/apache_beam/examples/sql_taxi.py @@ -50,8 +50,8 @@ def run(output_topic, pipeline_args): # Use beam.Row to create a schema-aware PCollection | "Create beam Row" >> beam.Map( lambda x: beam.Row( - ride_status=str(x['ride_status']), - passenger_count=int(x['passenger_count']))) + ride_status=str(x['ride_status']), passenger_count=int( + x['passenger_count']))) # SqlTransform will computes result within an existing window | "15s fixed windows" >> beam.WindowInto(beam.window.FixedWindows(15)) # Aggregate drop offs and pick ups that occur within each 15s window @@ -68,13 +68,10 @@ def run(output_topic, pipeline_args): # the outputs of the query. # Collect those attributes, as well as window information, into a dict | "Assemble Dictionary" >> beam.Map( - lambda row, - window=beam.DoFn.WindowParam: { - "ride_status": row.ride_status, - "num_rides": row.num_rides, - "total_passengers": row.total_passengers, - "window_start": window.start.to_rfc3339(), - "window_end": window.end.to_rfc3339() + lambda row, window=beam.DoFn.WindowParam: { + "ride_status": row.ride_status, "num_rides": row.num_rides, + "total_passengers": row.total_passengers, "window_start": window + .start.to_rfc3339(), "window_end": window.end.to_rfc3339() }) | "Convert to JSON" >> beam.Map(json.dumps) | "UTF-8 encode" >> beam.Map(lambda s: s.encode("utf-8")) diff --git a/sdks/python/apache_beam/examples/streaming_wordcount_debugging.py b/sdks/python/apache_beam/examples/streaming_wordcount_debugging.py index af99a4ab537d..fcd074653b80 100644 --- a/sdks/python/apache_beam/examples/streaming_wordcount_debugging.py +++ b/sdks/python/apache_beam/examples/streaming_wordcount_debugging.py @@ -53,6 +53,7 @@ class PrintFn(beam.DoFn): """A DoFn that prints label, element, its window, and its timstamp. """ + def __init__(self, label): self.label = label @@ -75,6 +76,7 @@ class AddTimestampFn(beam.DoFn): For example, 120 and Sometext will result in: (120, Timestamp(120) and (Sometext, Timestamp(1234567890). """ + def process(self, element): logging.info('Adding timestamp to: %s', element) try: diff --git a/sdks/python/apache_beam/examples/streaming_wordcount_debugging_it_test.py b/sdks/python/apache_beam/examples/streaming_wordcount_debugging_it_test.py index f3460ec24f1a..cfac01c2e12d 100644 --- a/sdks/python/apache_beam/examples/streaming_wordcount_debugging_it_test.py +++ b/sdks/python/apache_beam/examples/streaming_wordcount_debugging_it_test.py @@ -57,6 +57,7 @@ class StreamingWordcountDebuggingIT(unittest.TestCase): + def setUp(self): self.test_pipeline = TestPipeline(is_integration_test=True) self.project = self.test_pipeline.get_option('project') diff --git a/sdks/python/apache_beam/examples/streaming_wordcount_debugging_test.py b/sdks/python/apache_beam/examples/streaming_wordcount_debugging_test.py index 2ab80993080c..98bcd7d9d702 100644 --- a/sdks/python/apache_beam/examples/streaming_wordcount_debugging_test.py +++ b/sdks/python/apache_beam/examples/streaming_wordcount_debugging_test.py @@ -39,10 +39,12 @@ class StreamingWordcountDebugging(unittest.TestCase): + @unittest.skipIf(pubsub is None, 'GCP dependencies are not installed') @mock.patch('apache_beam.io.ReadFromPubSub') @mock.patch('apache_beam.io.WriteToPubSub') def test_streaming_wordcount_debugging(self, *unused_mocks): + def FakeReadFromPubSub(topic=None, subscription=None, values=None): expected_topic = topic expected_subscription = subscription @@ -55,6 +57,7 @@ def _inner(topic=None, subscription=None): return _inner class AssertTransform(beam.PTransform): + def __init__(self, matcher): self.matcher = matcher diff --git a/sdks/python/apache_beam/examples/streaming_wordcount_it_test.py b/sdks/python/apache_beam/examples/streaming_wordcount_it_test.py index 3fc4bf0fbc2f..4c85c39574c9 100644 --- a/sdks/python/apache_beam/examples/streaming_wordcount_it_test.py +++ b/sdks/python/apache_beam/examples/streaming_wordcount_it_test.py @@ -43,6 +43,7 @@ class StreamingWordCountIT(unittest.TestCase): + def setUp(self): self.test_pipeline = TestPipeline(is_integration_test=True) self.project = self.test_pipeline.get_option('project') diff --git a/sdks/python/apache_beam/examples/windowed_wordcount.py b/sdks/python/apache_beam/examples/windowed_wordcount.py index 7889f61b9a62..c4a82cd3b17a 100644 --- a/sdks/python/apache_beam/examples/windowed_wordcount.py +++ b/sdks/python/apache_beam/examples/windowed_wordcount.py @@ -40,6 +40,7 @@ def find_words(element): class FormatDoFn(beam.DoFn): + def process(self, element, window=beam.DoFn.WindowParam): ts_format = '%Y-%m-%d %H:%M:%S.%f UTC' window_start = window.start.to_utc_datetime().strftime(ts_format) diff --git a/sdks/python/apache_beam/examples/wordcount.py b/sdks/python/apache_beam/examples/wordcount.py index a9138647581c..b4b225e980a5 100644 --- a/sdks/python/apache_beam/examples/wordcount.py +++ b/sdks/python/apache_beam/examples/wordcount.py @@ -50,6 +50,7 @@ class WordExtractingDoFn(beam.DoFn): """Parse each line of input text into words.""" + def process(self, element): """Returns an iterator over the words of this element. diff --git a/sdks/python/apache_beam/examples/wordcount_debugging.py b/sdks/python/apache_beam/examples/wordcount_debugging.py index 581bbd3adc1b..4c11168862e1 100644 --- a/sdks/python/apache_beam/examples/wordcount_debugging.py +++ b/sdks/python/apache_beam/examples/wordcount_debugging.py @@ -78,6 +78,7 @@ class FilterTextFn(beam.DoFn): """A DoFn that filters for a specific key based on a regular expression.""" + def __init__(self, pattern): # TODO(BEAM-6158): Revert the workaround once we can pickle super() on py3. # super().__init__() @@ -115,7 +116,9 @@ class CountWords(beam.PTransform): A PTransform that converts a PCollection containing lines of text into a PCollection of (word, count) tuples. """ + def expand(self, pcoll): + def count_ones(word_ones): (word, ones) = word_ones return (word, sum(ones)) diff --git a/sdks/python/apache_beam/examples/wordcount_with_metrics.py b/sdks/python/apache_beam/examples/wordcount_with_metrics.py index f575a8d7fbba..b7103fb67636 100644 --- a/sdks/python/apache_beam/examples/wordcount_with_metrics.py +++ b/sdks/python/apache_beam/examples/wordcount_with_metrics.py @@ -52,6 +52,7 @@ class WordExtractingDoFn(beam.DoFn): """Parse each line of input text into words.""" + def __init__(self): # TODO(BEAM-6158): Revert the workaround once we can pickle super() on py3. # super().__init__() diff --git a/sdks/python/apache_beam/examples/wordcount_xlang.py b/sdks/python/apache_beam/examples/wordcount_xlang.py index 80a0c8fe7a4f..823df2ecf345 100644 --- a/sdks/python/apache_beam/examples/wordcount_xlang.py +++ b/sdks/python/apache_beam/examples/wordcount_xlang.py @@ -39,6 +39,7 @@ class WordExtractingDoFn(beam.DoFn): """Parse each line of input text into words.""" + def process(self, element): """Returns an iterator over the words of this element. diff --git a/sdks/python/apache_beam/internal/dill_pickler.py b/sdks/python/apache_beam/internal/dill_pickler.py index e1d6b7e74e49..f96753389f23 100644 --- a/sdks/python/apache_beam/internal/dill_pickler.py +++ b/sdks/python/apache_beam/internal/dill_pickler.py @@ -174,6 +174,7 @@ def save_code(pickler, obj): class _NoOpContextManager(object): + def __enter__(self): pass @@ -251,6 +252,7 @@ def _nested_type_wrapper(fun): For nested class object only it will save the containing class object so the nested structure is recreated during unpickle. """ + def wrapper(pickler, obj): # When the nested class is defined in the __main__ module we do not have to # do anything special because the pickler itself will save the constituent @@ -309,7 +311,8 @@ def save_module(pickler, obj): else: dill_log.info('M2: %s' % obj) # pylint: disable=protected-access - pickler.save_reduce(dill.dill._import_module, (obj.__name__, ), obj=obj) + pickler.save_reduce( + dill.dill._import_module, (obj.__name__, ), obj=obj) # pylint: enable=protected-access dill_log.info('# M2') diff --git a/sdks/python/apache_beam/internal/gcp/auth.py b/sdks/python/apache_beam/internal/gcp/auth.py index 66c08b8344cb..489de12ac4b6 100644 --- a/sdks/python/apache_beam/internal/gcp/auth.py +++ b/sdks/python/apache_beam/internal/gcp/auth.py @@ -92,6 +92,7 @@ class _ApitoolsCredentialsAdapter: upgrading the auth library used by Beam without simultaneously upgrading all the GCP client libraries (a much larger change). """ + def __init__(self, google_auth_credentials): self._google_auth_credentials = google_auth_credentials diff --git a/sdks/python/apache_beam/internal/gcp/auth_test.py b/sdks/python/apache_beam/internal/gcp/auth_test.py index 654d8e815a50..ef69898ea291 100644 --- a/sdks/python/apache_beam/internal/gcp/auth_test.py +++ b/sdks/python/apache_beam/internal/gcp/auth_test.py @@ -31,6 +31,7 @@ class MockLoggingHandler(logging.Handler): """Mock logging handler to check for expected logs.""" + def __init__(self, *args, **kwargs): self.reset() logging.Handler.__init__(self, *args, **kwargs) @@ -50,6 +51,7 @@ def reset(self): @unittest.skipIf(gauth is None, 'Google Auth dependencies are not installed') class AuthTest(unittest.TestCase): + @mock.patch('google.auth.default') def test_auth_with_retrys(self, unused_mock_arg): pipeline_options = PipelineOptions() diff --git a/sdks/python/apache_beam/internal/gcp/json_value_test.py b/sdks/python/apache_beam/internal/gcp/json_value_test.py index 21337de1a103..f28fe97cf93f 100644 --- a/sdks/python/apache_beam/internal/gcp/json_value_test.py +++ b/sdks/python/apache_beam/internal/gcp/json_value_test.py @@ -37,6 +37,7 @@ @unittest.skipIf(JsonValue is None, 'GCP dependencies are not installed') class JsonValueTest(unittest.TestCase): + def test_string_to(self): self.assertEqual(JsonValue(string_value='abc'), to_json_value('abc')) diff --git a/sdks/python/apache_beam/internal/http_client_test.py b/sdks/python/apache_beam/internal/http_client_test.py index c8790782a8c1..4d40fec5951f 100644 --- a/sdks/python/apache_beam/internal/http_client_test.py +++ b/sdks/python/apache_beam/internal/http_client_test.py @@ -30,6 +30,7 @@ class HttpClientTest(unittest.TestCase): + def test_proxy_from_env_http_with_port(self): with mock.patch.dict(os.environ, http_proxy='http://localhost:9000'): proxy_info = proxy_info_from_environment_var('http_proxy') diff --git a/sdks/python/apache_beam/internal/metrics/cells.py b/sdks/python/apache_beam/internal/metrics/cells.py index 989dc7183045..3f44a4df0e2e 100644 --- a/sdks/python/apache_beam/internal/metrics/cells.py +++ b/sdks/python/apache_beam/internal/metrics/cells.py @@ -47,6 +47,7 @@ class HistogramCell(MetricCell): This class is thread safe since underlying histogram object is thread safe. """ + def __init__(self, bucket_type): self._bucket_type = bucket_type self.data = HistogramData.identity_element(bucket_type) @@ -73,6 +74,7 @@ def to_runner_api_monitoring_info(self, name, transform_id): class HistogramCellFactory(MetricCellFactory): + def __init__(self, bucket_type): self._bucket_type = bucket_type @@ -89,6 +91,7 @@ def __hash__(self): class HistogramResult(object): + def __init__(self, data: 'HistogramData') -> None: self.data = data @@ -126,6 +129,7 @@ class HistogramData(object): This object is not thread safe, so it's not supposed to be modified outside the HistogramCell. """ + def __init__(self, histogram): self.histogram = histogram diff --git a/sdks/python/apache_beam/internal/metrics/cells_test.py b/sdks/python/apache_beam/internal/metrics/cells_test.py index 066dec4a2635..4f1b6bf6854d 100644 --- a/sdks/python/apache_beam/internal/metrics/cells_test.py +++ b/sdks/python/apache_beam/internal/metrics/cells_test.py @@ -28,6 +28,7 @@ class TestHistogramCell(unittest.TestCase): + @classmethod def _modify_histogram(cls, d): for i in range(cls.NUM_ITERATIONS): diff --git a/sdks/python/apache_beam/internal/metrics/metric.py b/sdks/python/apache_beam/internal/metrics/metric.py index 8acf800ff8c6..39fe2143571f 100644 --- a/sdks/python/apache_beam/internal/metrics/metric.py +++ b/sdks/python/apache_beam/internal/metrics/metric.py @@ -60,6 +60,7 @@ class Metrics(object): + @staticmethod def counter( urn: str, @@ -106,6 +107,7 @@ def histogram( class DelegatingHistogram(Histogram): """Metrics Histogram that Delegates functionality to MetricsEnvironment.""" + def __init__( self, metric_name: MetricName, @@ -125,6 +127,7 @@ def update(self, value: object) -> None: class MetricLogger(object): """Simple object to locally aggregate and log metrics.""" + def __init__(self) -> None: self._metric: Dict[MetricName, 'MetricCell'] = {} self._lock = threading.Lock() @@ -152,8 +155,8 @@ def log_metrics(self, reset_after_logging: bool = False) -> None: if self._lock.acquire(False): try: current_millis = int(time.time() * 1000) - if ((current_millis - self._last_logging_millis) > - self.minimum_logging_frequency_msec): + if ((current_millis - self._last_logging_millis) + > self.minimum_logging_frequency_msec): logging_metric_info = [ '[Locally aggregated metrics since %s]' % datetime.datetime.fromtimestamp( @@ -180,6 +183,7 @@ class ServiceCallMetric(object): TODO(ajamato): Add Request latency metric. """ + def __init__( self, request_count_urn: str, diff --git a/sdks/python/apache_beam/internal/metrics/metric_test.py b/sdks/python/apache_beam/internal/metrics/metric_test.py index 22b64ee73aee..2e775036e1b0 100644 --- a/sdks/python/apache_beam/internal/metrics/metric_test.py +++ b/sdks/python/apache_beam/internal/metrics/metric_test.py @@ -34,6 +34,7 @@ class MetricLoggerTest(unittest.TestCase): + @patch('apache_beam.internal.metrics.metric._LOGGER') def test_log_metrics(self, mock_logger): logger = MetricLogger() @@ -44,6 +45,7 @@ def test_log_metrics(self, mock_logger): logger.log_metrics() class Contains(str): + def __eq__(self, other): return self in other @@ -52,6 +54,7 @@ def __eq__(self, other): class MetricsTest(unittest.TestCase): + def test_create_process_wide(self): sampler = statesampler.StateSampler('', counters.CounterFactory()) statesampler.set_current_tracker(sampler) diff --git a/sdks/python/apache_beam/internal/module_test.py b/sdks/python/apache_beam/internal/module_test.py index ff0ad0c564e6..52b387725e98 100644 --- a/sdks/python/apache_beam/internal/module_test.py +++ b/sdks/python/apache_beam/internal/module_test.py @@ -25,12 +25,16 @@ class TopClass(object): + class NestedClass(object): + def __init__(self, datum): self.datum = 'X:%s' % datum class MiddleClass(object): + class NestedClass(object): + def __init__(self, datum): self.datum = 'Y:%s' % datum @@ -45,13 +49,16 @@ def get_lambda_with_closure(message): class Xyz(object): """A class to be pickled.""" + def foo(self, s): return re.findall(r'\w+', s) def create_class(datum): """Creates an unnamable class to be pickled.""" + class Z(object): + def get(self): return 'Z:%s' % datum diff --git a/sdks/python/apache_beam/internal/util.py b/sdks/python/apache_beam/internal/util.py index 85a6e4c43b83..bac34bd8286f 100644 --- a/sdks/python/apache_beam/internal/util.py +++ b/sdks/python/apache_beam/internal/util.py @@ -51,6 +51,7 @@ class ArgumentPlaceholder(object): Fn object by the time it executes will have such values replaced with real computed values. """ + def __eq__(self, other): """Tests for equality of two placeholder objects. @@ -97,8 +98,7 @@ def swapper(value): # by sorting the entries first. This will be important when putting back # PValues. new_kwargs = dict((k, swapper(v)) if isinstance(v, pvalue_class) else (k, v) - for k, - v in sorted(kwargs.items())) + for k, v in sorted(kwargs.items())) return (new_args, new_kwargs, pvals) @@ -123,8 +123,8 @@ def insert_values_in_args(args, kwargs, values): for arg in args ] new_kwargs = dict( - (k, next(v_iter)) if isinstance(v, ArgumentPlaceholder) else (k, v) for k, - v in sorted(kwargs.items())) + (k, next(v_iter)) if isinstance(v, ArgumentPlaceholder) else (k, v) + for k, v in sorted(kwargs.items())) return (new_args, new_kwargs) diff --git a/sdks/python/apache_beam/internal/util_test.py b/sdks/python/apache_beam/internal/util_test.py index ded1190f8405..2449b5412c32 100644 --- a/sdks/python/apache_beam/internal/util_test.py +++ b/sdks/python/apache_beam/internal/util_test.py @@ -26,6 +26,7 @@ class UtilTest(unittest.TestCase): + def test_remove_objects_from_args(self): args, kwargs, objs = remove_objects_from_args( [1, 'a'], {'x': 1, 'y': 3.14}, (str, float)) diff --git a/sdks/python/apache_beam/io/avroio.py b/sdks/python/apache_beam/io/avroio.py index 8b7958a00b80..3b8bd11ffec1 100644 --- a/sdks/python/apache_beam/io/avroio.py +++ b/sdks/python/apache_beam/io/avroio.py @@ -85,6 +85,7 @@ class ReadFromAvro(PTransform): that comply with the schema contained in the Avro file that contains those records. """ + def __init__( self, file_pattern=None, @@ -287,6 +288,7 @@ def expand(self, pbegin): class _AvroUtils(object): + @staticmethod def advance_file_past_next_sync_marker(f, sync_marker): buf_size = 10000 @@ -322,6 +324,7 @@ class _FastAvroSource(filebasedsource.FileBasedSource): TODO: remove ``_AvroSource`` in favor of using ``_FastAvroSource`` everywhere once it has been more widely tested """ + def read_records(self, file_name, range_tracker): next_block_start = -1 @@ -367,6 +370,7 @@ class WriteToAvro(beam.transforms.PTransform): If the input has a schema, a corresponding avro schema will be automatically generated and used to write the output records.""" + def __init__( self, file_path_prefix, @@ -411,13 +415,8 @@ def __init__( """ self._schema = schema self._sink_provider = lambda avro_schema: _create_avro_sink( - file_path_prefix, - avro_schema, - codec, - file_name_suffix, - num_shards, - shard_name_template, - mime_type) + file_path_prefix, avro_schema, codec, file_name_suffix, num_shards, + shard_name_template, mime_type) def expand(self, pcoll): if self._schema: @@ -465,6 +464,7 @@ def _create_avro_sink( class _BaseAvroSink(filebasedsink.FileBasedSink): """A base for a sink for avro files. """ + def __init__( self, file_path_prefix, @@ -496,6 +496,7 @@ def display_data(self): class _FastAvroSink(_BaseAvroSink): """A sink for avro files using FastAvro. """ + def __init__( self, file_path_prefix, diff --git a/sdks/python/apache_beam/io/avroio_test.py b/sdks/python/apache_beam/io/avroio_test.py index 2d25010da486..ad6ccfeb26ad 100644 --- a/sdks/python/apache_beam/io/avroio_test.py +++ b/sdks/python/apache_beam/io/avroio_test.py @@ -174,8 +174,8 @@ def test_schema_read_write(self): @pytest.mark.xlang_sql_expansion_service @unittest.skipIf( - TestPipeline().get_pipeline_options().view_as(StandardOptions).runner is - None, + TestPipeline().get_pipeline_options().view_as(StandardOptions).runner + is None, "Must be run with a runner that supports staging java artifacts.") def test_avro_schema_to_beam_schema_with_nullable_atomic_fields(self): records = [] @@ -396,6 +396,7 @@ def test_read_with_splitting_pattern(self): self._run_avro_test(pattern, 100, True, expected_result) def test_dynamic_work_rebalancing_exhaustive(self): + def compare_split_points(file_name): source = _FastAvroSource(file_name) splits = [ @@ -625,6 +626,7 @@ def test_writer_open_and_close(self): class TestFastAvro(AvroBase, unittest.TestCase): + def __init__(self, methodName='runTest'): super().__init__(methodName) self.SCHEMA = parse_schema(json.loads(self.SCHEMA_STRING)) diff --git a/sdks/python/apache_beam/io/aws/clients/s3/boto3_client.py b/sdks/python/apache_beam/io/aws/clients/s3/boto3_client.py index aee24ac2c052..4706f52f31a9 100644 --- a/sdks/python/apache_beam/io/aws/clients/s3/boto3_client.py +++ b/sdks/python/apache_beam/io/aws/clients/s3/boto3_client.py @@ -40,6 +40,7 @@ class Client(object): """ Wrapper for boto3 library """ + def __init__(self, options): assert boto3 is not None, 'Missing boto3 requirement' if isinstance(options, pipeline_options.PipelineOptions): diff --git a/sdks/python/apache_beam/io/aws/clients/s3/client_test.py b/sdks/python/apache_beam/io/aws/clients/s3/client_test.py index 67797b1736b9..b809c7e456b5 100644 --- a/sdks/python/apache_beam/io/aws/clients/s3/client_test.py +++ b/sdks/python/apache_beam/io/aws/clients/s3/client_test.py @@ -27,6 +27,7 @@ class ClientErrorTest(unittest.TestCase): + def setUp(self): # These tests can be run locally against a mock S3 client, or as integration diff --git a/sdks/python/apache_beam/io/aws/clients/s3/fake_client.py b/sdks/python/apache_beam/io/aws/clients/s3/fake_client.py index e2d34a11a46f..a36de57d308d 100644 --- a/sdks/python/apache_beam/io/aws/clients/s3/fake_client.py +++ b/sdks/python/apache_beam/io/aws/clients/s3/fake_client.py @@ -26,6 +26,7 @@ class FakeFile(object): + def __init__(self, bucket, key, contents, etag=None): self.bucket = bucket self.key = key @@ -53,6 +54,7 @@ def get_metadata(self): class FakeS3Client(object): + def __init__(self): self.files = {} self.list_continuation_tokens = {} diff --git a/sdks/python/apache_beam/io/aws/clients/s3/messages.py b/sdks/python/apache_beam/io/aws/clients/s3/messages.py index 8555364b727c..902ef36ccd73 100644 --- a/sdks/python/apache_beam/io/aws/clients/s3/messages.py +++ b/sdks/python/apache_beam/io/aws/clients/s3/messages.py @@ -22,6 +22,7 @@ class GetRequest(): """ S3 request object for `Get` command """ + def __init__(self, bucket, object): self.bucket = bucket self.object = object @@ -31,6 +32,7 @@ class UploadResponse(): """ S3 response object for `StartUpload` command """ + def __init__(self, upload_id): self.upload_id = upload_id @@ -39,6 +41,7 @@ class UploadRequest(): """ S3 request object for `StartUpload` command """ + def __init__(self, bucket, object, mime_type): self.bucket = bucket self.object = object @@ -49,6 +52,7 @@ class UploadPartRequest(): """ S3 request object for `UploadPart` command """ + def __init__(self, bucket, object, upload_id, part_number, bytes): self.bucket = bucket self.object = object @@ -62,6 +66,7 @@ class UploadPartResponse(): """ S3 response object for `UploadPart` command """ + def __init__(self, etag, part_number): self.etag = etag self.part_number = part_number @@ -71,6 +76,7 @@ class CompleteMultipartUploadRequest(): """ S3 request object for `UploadPart` command """ + def __init__(self, bucket, object, upload_id, parts): # parts is a list of objects of the form # {'ETag': response.etag, 'PartNumber': response.part_number} @@ -85,6 +91,7 @@ class ListRequest(): """ S3 request object for `List` command """ + def __init__(self, bucket, prefix, continuation_token=None): self.bucket = bucket self.prefix = prefix @@ -95,6 +102,7 @@ class ListResponse(): """ S3 response object for `List` command """ + def __init__(self, items, next_token=None): self.items = items self.next_token = next_token @@ -104,6 +112,7 @@ class Item(): """ An item in S3 """ + def __init__(self, etag, key, last_modified, size, mime_type=None): self.etag = etag self.key = key @@ -116,12 +125,14 @@ class DeleteRequest(): """ S3 request object for `Delete` command """ + def __init__(self, bucket, object): self.bucket = bucket self.object = object class DeleteBatchRequest(): + def __init__(self, bucket, objects): # `objects` is a list of strings corresponding to the keys to be deleted # in the bucket @@ -130,6 +141,7 @@ def __init__(self, bucket, objects): class DeleteBatchResponse(): + def __init__(self, deleted, failed, errors): # `deleted` is a list of strings corresponding to the keys that were deleted # `failed` is a list of strings corresponding to the keys that caused errors @@ -140,6 +152,7 @@ def __init__(self, deleted, failed, errors): class CopyRequest(): + def __init__(self, src_bucket, src_key, dest_bucket, dest_key): self.src_bucket = src_bucket self.src_key = src_key @@ -148,6 +161,7 @@ def __init__(self, src_bucket, src_key, dest_bucket, dest_key): class S3ClientError(Exception): + def __init__(self, message=None, code=None): self.message = message self.code = code diff --git a/sdks/python/apache_beam/io/aws/s3filesystem_test.py b/sdks/python/apache_beam/io/aws/s3filesystem_test.py index 036727cd7a70..07abe3486c75 100644 --- a/sdks/python/apache_beam/io/aws/s3filesystem_test.py +++ b/sdks/python/apache_beam/io/aws/s3filesystem_test.py @@ -41,6 +41,7 @@ @unittest.skipIf(s3filesystem is None, 'AWS dependencies are not installed') class S3FileSystemTest(unittest.TestCase): + def setUp(self): pipeline_options = PipelineOptions() self.fs = s3filesystem.S3FileSystem(pipeline_options=pipeline_options) diff --git a/sdks/python/apache_beam/io/aws/s3io.py b/sdks/python/apache_beam/io/aws/s3io.py index 887bb4c7baad..01be057ea9c3 100644 --- a/sdks/python/apache_beam/io/aws/s3io.py +++ b/sdks/python/apache_beam/io/aws/s3io.py @@ -56,6 +56,7 @@ def parse_s3_path(s3_path, object_optional=False): class S3IO(object): """S3 I/O client.""" + def __init__(self, client=None, options=None): if client is None and options is None: raise ValueError('Must provide one of client or options') @@ -566,6 +567,7 @@ def rename_files(self, src_dest_pairs): class S3Downloader(Downloader): + def __init__(self, client, path, buffer_size): self._client = client self._path = path @@ -602,6 +604,7 @@ def get_range(self, start, end): class S3Uploader(Uploader): + def __init__(self, client, path, mime_type='application/octet-stream'): self._client = client self._path = path diff --git a/sdks/python/apache_beam/io/aws/s3io_test.py b/sdks/python/apache_beam/io/aws/s3io_test.py index ffab95727078..339d7378b887 100644 --- a/sdks/python/apache_beam/io/aws/s3io_test.py +++ b/sdks/python/apache_beam/io/aws/s3io_test.py @@ -64,6 +64,7 @@ def test_bad_s3_path_object_optional(self): class TestS3IO(unittest.TestCase): + def _insert_random_file(self, client, path, size): bucket, name = s3io.parse_s3_path(path) contents = os.urandom(size) diff --git a/sdks/python/apache_beam/io/azure/blobstoragefilesystem_test.py b/sdks/python/apache_beam/io/azure/blobstoragefilesystem_test.py index c3418e137e87..1ff28793b96f 100644 --- a/sdks/python/apache_beam/io/azure/blobstoragefilesystem_test.py +++ b/sdks/python/apache_beam/io/azure/blobstoragefilesystem_test.py @@ -42,6 +42,7 @@ @unittest.skipIf( blobstoragefilesystem is None, 'Azure dependencies are not installed') class BlobStorageFileSystemTest(unittest.TestCase): + def setUp(self): pipeline_options = PipelineOptions() self.fs = blobstoragefilesystem.BlobStorageFileSystem( diff --git a/sdks/python/apache_beam/io/azure/blobstorageio.py b/sdks/python/apache_beam/io/azure/blobstorageio.py index cfa4fe7d2916..e681b668cc44 100644 --- a/sdks/python/apache_beam/io/azure/blobstorageio.py +++ b/sdks/python/apache_beam/io/azure/blobstorageio.py @@ -86,6 +86,7 @@ def get_azfs_url(storage_account, container, blob=''): class Blob(): """A Blob in Azure Blob Storage.""" + def __init__(self, etag, name, last_updated, size, mime_type): self.etag = etag self.name = name @@ -101,6 +102,7 @@ class BlobStorageIOError(IOError, retry.PermanentException): class BlobStorageError(Exception): """Blob Storage client error.""" + def __init__(self, message=None, code=None): self.message = message self.code = code @@ -108,6 +110,7 @@ def __init__(self, message=None, code=None): class BlobStorageIO(object): """Azure Blob Storage I/O client.""" + def __init__(self, client=None, pipeline_options=None): if client is None: azure_options = pipeline_options.view_as(AzureOptions) @@ -654,6 +657,7 @@ def list_files(self, path, with_metadata=False): class BlobStorageDownloader(Downloader): + def __init__(self, client, path, buffer_size): self._client = client self._path = path @@ -692,6 +696,7 @@ def get_range(self, start, end): class BlobStorageUploader(Uploader): + def __init__(self, client, path, mime_type='application/octet-stream'): self._client = client self._path = path diff --git a/sdks/python/apache_beam/io/components/util.py b/sdks/python/apache_beam/io/components/util.py index 0726ecb13825..c78d462ba0a0 100644 --- a/sdks/python/apache_beam/io/components/util.py +++ b/sdks/python/apache_beam/io/components/util.py @@ -34,6 +34,7 @@ class MovingSum(object): convenience we expose the count of entries as well so this doubles as a moving average tracker. """ + def __init__(self, window_ms, bucket_ms): if window_ms < bucket_ms or bucket_ms <= 0: raise ValueError("window_ms >= bucket_ms > 0 please") diff --git a/sdks/python/apache_beam/io/concat_source.py b/sdks/python/apache_beam/io/concat_source.py index 35ae6fe27317..bfed92d2e36f 100644 --- a/sdks/python/apache_beam/io/concat_source.py +++ b/sdks/python/apache_beam/io/concat_source.py @@ -35,6 +35,7 @@ class ConcatSource(iobase.BoundedSource): Primarily for internal use, use the ``apache_beam.Flatten`` transform to create the union of several reads. """ + def __init__(self, sources): self._source_bundles = [ source if isinstance(source, iobase.SourceBundle) else @@ -98,6 +99,7 @@ class ConcatRangeTracker(iobase.RangeTracker): """For internal use only; no backwards-compatibility guarantees. Range tracker for ConcatSource""" + def __init__(self, start, end, source_bundles): """Initializes ``ConcatRangeTracker`` diff --git a/sdks/python/apache_beam/io/concat_source_test.py b/sdks/python/apache_beam/io/concat_source_test.py index efa24b3975fc..38aaa6e6a31a 100644 --- a/sdks/python/apache_beam/io/concat_source_test.py +++ b/sdks/python/apache_beam/io/concat_source_test.py @@ -85,6 +85,7 @@ def __eq__(self, other): class ConcatSourceTest(unittest.TestCase): + def test_range_source(self): source_test_utils.assert_split_at_fraction_exhaustive(RangeSource(0, 10, 3)) diff --git a/sdks/python/apache_beam/io/debezium.py b/sdks/python/apache_beam/io/debezium.py index ada25760c27a..d94e72b42e40 100644 --- a/sdks/python/apache_beam/io/debezium.py +++ b/sdks/python/apache_beam/io/debezium.py @@ -115,6 +115,7 @@ class DriverClassName(Enum): class _JsonStringToDictionaries(DoFn): """ A DoFn that consumes a JSON string and yields a python dictionary """ + def process(self, json_string): obj = json.loads(json_string) yield obj diff --git a/sdks/python/apache_beam/io/external/generate_sequence_test.py b/sdks/python/apache_beam/io/external/generate_sequence_test.py index d44fd363b95a..21246a7d81ae 100644 --- a/sdks/python/apache_beam/io/external/generate_sequence_test.py +++ b/sdks/python/apache_beam/io/external/generate_sequence_test.py @@ -40,6 +40,7 @@ os.environ.get('EXPANSION_PORT'), "EXPANSION_PORT environment var is not provided.") class XlangGenerateSequenceTest(unittest.TestCase): + def test_generate_sequence(self): port = os.environ.get('EXPANSION_PORT') address = 'localhost:%s' % port diff --git a/sdks/python/apache_beam/io/external/xlang_debeziumio_it_test.py b/sdks/python/apache_beam/io/external/xlang_debeziumio_it_test.py index abe9530787e8..1cbbf0709497 100644 --- a/sdks/python/apache_beam/io/external/xlang_debeziumio_it_test.py +++ b/sdks/python/apache_beam/io/external/xlang_debeziumio_it_test.py @@ -37,10 +37,11 @@ @unittest.skipIf( PostgresContainer is None, 'testcontainers package is not installed') @unittest.skipIf( - TestPipeline().get_pipeline_options().view_as(StandardOptions).runner is - None, + TestPipeline().get_pipeline_options().view_as(StandardOptions).runner + is None, 'Do not run this test on precommit suites.') class CrossLanguageDebeziumIOTest(unittest.TestCase): + def setUp(self): self.username = 'debezium' self.password = 'dbz' diff --git a/sdks/python/apache_beam/io/external/xlang_jdbcio_it_test.py b/sdks/python/apache_beam/io/external/xlang_jdbcio_it_test.py index 38a405c2d331..8d8fa3c36ff0 100644 --- a/sdks/python/apache_beam/io/external/xlang_jdbcio_it_test.py +++ b/sdks/python/apache_beam/io/external/xlang_jdbcio_it_test.py @@ -71,8 +71,8 @@ @unittest.skipIf( PostgresContainer is None, 'testcontainers package is not installed') @unittest.skipIf( - TestPipeline().get_pipeline_options().view_as(StandardOptions).runner is - None, + TestPipeline().get_pipeline_options().view_as(StandardOptions).runner + is None, 'Do not run this test on precommit suites.') class CrossLanguageJdbcIOTest(unittest.TestCase): DbData = typing.NamedTuple( diff --git a/sdks/python/apache_beam/io/external/xlang_kafkaio_it_test.py b/sdks/python/apache_beam/io/external/xlang_kafkaio_it_test.py index a7bf686d0642..f9a7b5a9fea2 100644 --- a/sdks/python/apache_beam/io/external/xlang_kafkaio_it_test.py +++ b/sdks/python/apache_beam/io/external/xlang_kafkaio_it_test.py @@ -66,6 +66,7 @@ def process( class CrossLanguageKafkaIO(object): + def __init__( self, bootstrap_servers, topic, null_key, expansion_service=None): self.bootstrap_servers = bootstrap_servers @@ -113,6 +114,7 @@ def run_xlang_kafkaio(self, pipeline): class CrossLanguageKafkaIOTest(unittest.TestCase): + @unittest.skipUnless( os.environ.get('LOCAL_KAFKA_JAR'), "LOCAL_KAFKA_JAR environment var is not provided.") diff --git a/sdks/python/apache_beam/io/external/xlang_kafkaio_perf_test.py b/sdks/python/apache_beam/io/external/xlang_kafkaio_perf_test.py index 08a6baee468d..2cb13f886817 100644 --- a/sdks/python/apache_beam/io/external/xlang_kafkaio_perf_test.py +++ b/sdks/python/apache_beam/io/external/xlang_kafkaio_perf_test.py @@ -37,6 +37,7 @@ class KafkaIOTestOptions(LoadTestOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_argument( @@ -57,6 +58,7 @@ def _add_argparse_args(cls, parser): class KafkaIOPerfTest: """Performance test for cross-language Kafka IO pipeline.""" + def run(self): write_test = _KafkaIOBatchWritePerfTest() read_test = _KafkaIOSDFReadPerfTest() @@ -65,6 +67,7 @@ def run(self): class _KafkaIOBatchWritePerfTest(LoadTest): + def __init__(self): super().__init__(WRITE_NAMESPACE) self.test_options = self.pipeline.get_pipeline_options().view_as( @@ -93,6 +96,7 @@ def cleanup(self): class _KafkaIOSDFReadPerfTest(LoadTest): + def __init__(self): super().__init__(READ_NAMESPACE) self.test_options = self.pipeline.get_pipeline_options().view_as( diff --git a/sdks/python/apache_beam/io/external/xlang_kinesisio_it_test.py b/sdks/python/apache_beam/io/external/xlang_kinesisio_it_test.py index c9181fb2a721..007fad31bf04 100644 --- a/sdks/python/apache_beam/io/external/xlang_kinesisio_it_test.py +++ b/sdks/python/apache_beam/io/external/xlang_kinesisio_it_test.py @@ -79,6 +79,7 @@ TestPipeline().get_pipeline_options().view_as(StandardOptions).runner, 'Do not run this test on precommit suites.') class CrossLanguageKinesisIOTest(unittest.TestCase): + @unittest.skipUnless( TestPipeline().get_option('aws_kinesis_stream'), 'Cannot test on real aws without pipeline options provided') @@ -242,6 +243,7 @@ def tearDown(self): class KinesisHelper: + def __init__(self, access_key, secret_key, region, service_endpoint): self.kinesis_client = boto3.client( service_name='kinesis', diff --git a/sdks/python/apache_beam/io/external/xlang_snowflakeio_it_test.py b/sdks/python/apache_beam/io/external/xlang_snowflakeio_it_test.py index f78175a8696d..e702edd391f5 100644 --- a/sdks/python/apache_beam/io/external/xlang_snowflakeio_it_test.py +++ b/sdks/python/apache_beam/io/external/xlang_snowflakeio_it_test.py @@ -92,11 +92,13 @@ TestPipeline().get_option('server_name') is None, 'Snowflake IT test requires external configuration to be run.') class SnowflakeTest(unittest.TestCase): + def test_snowflake_write_read(self): self.run_write() self.run_read() def run_write(self): + def user_data_mapper(test_row): return [ str(test_row.number_column).encode('utf-8'), @@ -137,6 +139,7 @@ def user_data_mapper(test_row): )) def run_read(self): + def csv_mapper(bytes_array): return TestRow( int(bytes_array[0]), diff --git a/sdks/python/apache_beam/io/filebasedio_perf_test.py b/sdks/python/apache_beam/io/filebasedio_perf_test.py index 78a390d9bed4..e3c7965e7f7e 100644 --- a/sdks/python/apache_beam/io/filebasedio_perf_test.py +++ b/sdks/python/apache_beam/io/filebasedio_perf_test.py @@ -45,6 +45,7 @@ class FileBasedIOTestOptions(LoadTestOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_argument( @@ -79,6 +80,7 @@ class SyntheticRecordToStrFn(beam.DoFn): Output length = 4(ceil[len(key)/3] + ceil[len(value)/3]) + 1 """ + def process(self, element): import base64 yield base64.b64encode(element[0]) + b' ' + base64.b64encode(element[1]) @@ -86,6 +88,7 @@ def process(self, element): class CreateFolderFn(beam.DoFn): """Create folder at pipeline runtime.""" + def __init__(self, folder): self.folder = folder @@ -97,6 +100,7 @@ def process(self, element): class TextIOPerfTest: + def run(self): write_test = _TextIOWritePerfTest(need_cleanup=False) read_test = _TextIOReadPerfTest(input_folder=write_test.output_folder) @@ -105,6 +109,7 @@ def run(self): class _TextIOWritePerfTest(LoadTest): + def __init__(self, need_cleanup=True): super().__init__(WRITE_NAMESPACE) self.need_cleanup = need_cleanup @@ -146,6 +151,7 @@ def cleanup(self): class _TextIOReadPerfTest(LoadTest): + def __init__(self, input_folder): super().__init__(READ_NAMESPACE) self.test_options = self.pipeline.get_pipeline_options().view_as( diff --git a/sdks/python/apache_beam/io/filebasedsink.py b/sdks/python/apache_beam/io/filebasedsink.py index eb433bd60583..faf05e89aa4f 100644 --- a/sdks/python/apache_beam/io/filebasedsink.py +++ b/sdks/python/apache_beam/io/filebasedsink.py @@ -427,6 +427,7 @@ def __eq__(self, other): class FileBasedSinkWriter(iobase.Writer): """The writer for FileBasedSink. """ + def __init__(self, sink, temp_shard_path): self.sink = sink self.temp_shard_path = temp_shard_path @@ -440,10 +441,10 @@ def write(self, value): def at_capacity(self): return ( self.sink.max_records_per_shard and - self.num_records_written >= self.sink.max_records_per_shard - ) or ( - self.sink.max_bytes_per_shard and - self.sink.byte_counter.bytes_written >= self.sink.max_bytes_per_shard) + self.num_records_written >= self.sink.max_records_per_shard) or ( + self.sink.max_bytes_per_shard and + self.sink.byte_counter.bytes_written + >= self.sink.max_bytes_per_shard) def close(self): self.sink.close(self.temp_handle) @@ -451,6 +452,7 @@ def close(self): class _ByteCountingWriter: + def __init__(self, writer): self.writer = writer self.bytes_written = 0 diff --git a/sdks/python/apache_beam/io/filebasedsink_test.py b/sdks/python/apache_beam/io/filebasedsink_test.py index 121bc479200f..ae0c9c94ccd2 100644 --- a/sdks/python/apache_beam/io/filebasedsink_test.py +++ b/sdks/python/apache_beam/io/filebasedsink_test.py @@ -50,6 +50,7 @@ class _TestCaseWithTempDirCleanUp(unittest.TestCase): Inherited test cases will call self._new_tempdir() to start a temporary dir which will be deleted at the end of the tests (when tearDown() is called). """ + def setUp(self): self._tempdirs = [] @@ -79,6 +80,7 @@ def _create_temp_file(self, name='', suffix='', dir=None, content=None): class MyFileBasedSink(filebasedsink.FileBasedSink): + def open(self, temp_path): # TODO: Fix main session pickling. # file_handle = super().open(temp_path) @@ -99,6 +101,7 @@ def close(self, file_handle): class TestFileBasedSink(_TestCaseWithTempDirCleanUp): + def _common_init(self, sink): # Manually invoke the generic Sink API. init_token = sink.initialize_write() @@ -202,6 +205,7 @@ def run_temp_dir_check( dir_root_path, prefix, separator): + def _get_temp_dir(file_path_prefix): sink = MyFileBasedSink( file_path_prefix, diff --git a/sdks/python/apache_beam/io/filebasedsource.py b/sdks/python/apache_beam/io/filebasedsource.py index 49b1b1d125f1..d6bbf8e8d8f1 100644 --- a/sdks/python/apache_beam/io/filebasedsource.py +++ b/sdks/python/apache_beam/io/filebasedsource.py @@ -241,6 +241,7 @@ def _determine_splittability_from_compression_type(file_path, compression_type): class _SingleFileSource(iobase.BoundedSource): """Denotes a source for a specific file type.""" + def __init__( self, file_based_source, @@ -339,6 +340,7 @@ def default_output_coder(self): class _ExpandIntoRanges(DoFn): + def __init__( self, splittable, compression_type, desired_bundle_size, min_bundle_size): self._desired_bundle_size = desired_bundle_size @@ -372,6 +374,7 @@ def process(self, element: Union[str, FileMetadata], *args, class _ReadRange(DoFn): + def __init__( self, source_from_file: Union[str, iobase.BoundedSource], @@ -406,6 +409,7 @@ class ReadAllFiles(PTransform): PTransform authors who wishes to implement file-based Read transforms that read a PCollection of files. """ + def __init__( self, splittable: bool, diff --git a/sdks/python/apache_beam/io/filebasedsource_test.py b/sdks/python/apache_beam/io/filebasedsource_test.py index e68d2afbac9d..cb3b245112f2 100644 --- a/sdks/python/apache_beam/io/filebasedsource_test.py +++ b/sdks/python/apache_beam/io/filebasedsource_test.py @@ -47,6 +47,7 @@ class LineSource(FileBasedSource): + def read_records(self, file_name, range_tracker): f = self.open_file(file_name) try: @@ -178,7 +179,9 @@ def write_pattern(lines_per_file, no_data=False): class TestConcatSource(unittest.TestCase): + class DummySource(iobase.BoundedSource): + def __init__(self, values): self._values = values @@ -252,6 +255,7 @@ def test_estimate_size(self): class TestFileBasedSource(unittest.TestCase): + def setUp(self): # Reducing the size of thread pools. Without this test execution may fail in # environments with limited amount of resources. @@ -584,10 +588,12 @@ def test_read_auto_pattern_compressed_and_uncompressed(self): assert_that(pcoll, equal_to(lines)) def test_splits_get_coder_from_fbs(self): + class DummyCoder(object): val = 12345 class FileBasedSourceWithCoder(LineSource): + def default_output_coder(self): return DummyCoder() @@ -601,6 +607,7 @@ def default_output_coder(self): class TestSingleFileSource(unittest.TestCase): + def setUp(self): # Reducing the size of thread pools. Without this test execution may fail in # environments with limited amount of resources. diff --git a/sdks/python/apache_beam/io/fileio.py b/sdks/python/apache_beam/io/fileio.py index 111206a18a28..4d33cb644e47 100644 --- a/sdks/python/apache_beam/io/fileio.py +++ b/sdks/python/apache_beam/io/fileio.py @@ -161,6 +161,7 @@ def allow_empty_match(pattern, setting): class _MatchAllFn(beam.DoFn): + def __init__(self, empty_match_treatment): self._empty_match_treatment = empty_match_treatment @@ -183,6 +184,7 @@ class MatchFiles(beam.PTransform): This ``PTransform`` returns a ``PCollection`` of matching files in the form of ``FileMetadata`` objects.""" + def __init__( self, file_pattern: str, @@ -200,6 +202,7 @@ class MatchAll(beam.PTransform): This ``PTransform`` returns a ``PCollection`` of matching files in the form of ``FileMetadata`` objects.""" + def __init__(self, empty_match_treatment=EmptyMatchTreatment.ALLOW): self._empty_match_treatment = empty_match_treatment @@ -212,6 +215,7 @@ def expand( class ReadableFile(object): """A utility class for accessing files.""" + def __init__(self, metadata, compression=None): self.metadata = metadata self._compression = compression @@ -231,6 +235,7 @@ def read_utf8(self): class _ReadMatchesFn(beam.DoFn): + def __init__(self, compression, skip_directories): self._compression = compression self._skip_directories = skip_directories @@ -271,6 +276,7 @@ class MatchContinuously(beam.PTransform): (https://cloud.google.com/storage/docs/pubsub-notifications) when using GCS if possible. """ + def __init__( self, file_pattern, @@ -345,6 +351,7 @@ class ReadMatches(beam.PTransform): """Converts each result of MatchFiles() or MatchAll() to a ReadableFile. This helps read in a file's contents or obtain a file descriptor.""" + def __init__(self, compression=None, skip_directories=True): self._compression = compression self._skip_directories = skip_directories @@ -373,6 +380,7 @@ class FileSink(object): - The ``create_metadata`` method, which creates all metadata passed to Filesystems.create. """ + def create_metadata( self, destination: str, full_file_name: str) -> FileMetadata: return FileMetadata( @@ -396,6 +404,7 @@ class TextSink(FileSink): This sink simply calls file_handler.write(record.encode('utf8') + '\n') on all records that come into it. """ + def open(self, fh): self._fh = fh @@ -466,6 +475,7 @@ def _format_shard( def destination_prefix_naming(suffix=None) -> FileNaming: + def _inner(window, pane, shard_index, total_shards, compression, destination): prefix = str(destination) return _format_shard( @@ -475,6 +485,7 @@ def _inner(window, pane, shard_index, total_shards, compression, destination): def default_file_naming(prefix, suffix=None) -> FileNaming: + def _inner(window, pane, shard_index, total_shards, compression, destination): return _format_shard( window, pane, shard_index, total_shards, compression, prefix, suffix) @@ -483,6 +494,7 @@ def _inner(window, pane, shard_index, total_shards, compression, destination): def single_file_naming(prefix, suffix=None) -> FileNaming: + def _inner(window, pane, shard_index, total_shards, compression, destination): assert shard_index in (0, None), shard_index assert total_shards in (1, None), total_shards @@ -669,6 +681,7 @@ def _create_writer( class _MoveTempFilesIntoFinalDestinationFn(beam.DoFn): + def __init__(self, path, file_naming_fn, temp_dir): self.path = path self.file_naming_fn = file_naming_fn @@ -751,6 +764,7 @@ def _check_orphaned_files(self, writer_key): class _WriteShardedRecordsFn(beam.DoFn): + def __init__( self, base_path, sink_fn: Callable[[Any], FileSink], shards: int): self.base_path = base_path @@ -795,6 +809,7 @@ def process( class _AppendShardedDestination(beam.DoFn): + def __init__(self, destination: Callable[[Any], str], shards: int): self.destination_fn = destination self.shards = shards @@ -879,12 +894,13 @@ def finish_bundle(self): sink.flush() writer.close() - file_result = FileResult(self._file_names[key], - shard_index=-1, - total_shards=0, - window=key[1], - pane=None, # TODO(pabloem): get the pane info - destination=key[0]) + file_result = FileResult( + self._file_names[key], + shard_index=-1, + total_shards=0, + window=key[1], + pane=None, # TODO(pabloem): get the pane info + destination=key[0]) yield beam.pvalue.TaggedOutput( self.WRITTEN_FILES, diff --git a/sdks/python/apache_beam/io/fileio_test.py b/sdks/python/apache_beam/io/fileio_test.py index ff4be9d3d7cc..4fc088746b4e 100644 --- a/sdks/python/apache_beam/io/fileio_test.py +++ b/sdks/python/apache_beam/io/fileio_test.py @@ -59,6 +59,7 @@ def _get_file_reader(readable_file): class MatchTest(_TestCaseWithTempDirCleanUp): + def test_basic_two_files(self): files = [] tempdir = '%s%s' % (self._new_tempdir(), os.sep) @@ -137,6 +138,7 @@ def test_match_files_one_directory_failure2(self): class ReadTest(_TestCaseWithTempDirCleanUp): + def test_basic_file_name_provided(self): content = 'TestingMyContent\nIn multiple lines\nhaha!' dir = '%s%s' % (self._new_tempdir(), os.sep) @@ -322,6 +324,7 @@ def test_transform_on_gcs(self): class MatchContinuouslyTest(_TestCaseWithTempDirCleanUp): + def test_with_deduplication(self): files = [] tempdir = '%s%s' % (self._new_tempdir(), os.sep) @@ -461,6 +464,7 @@ class WriteFilesTest(_TestCaseWithTempDirCleanUp): for elm in SIMPLE_COLLECTION} class CsvSink(fileio.TextSink): + def __init__(self, headers): self.headers = headers @@ -469,6 +473,7 @@ def write(self, record): self._fh.write('\n'.encode('utf8')) class JsonSink(fileio.TextSink): + def write(self, record): self._fh.write(json.dumps(record).encode('utf8')) self._fh.write('\n'.encode('utf8')) @@ -497,8 +502,8 @@ def test_write_to_single_file_batch(self): def test_write_to_dynamic_destination(self): sink_params = [ - fileio.TextSink, # pass a type signature - fileio.TextSink() # pass a FileSink object + fileio.TextSink, # pass a type signature + fileio.TextSink() # pass a FileSink object ] for sink in sink_params: @@ -522,8 +527,8 @@ def test_write_to_dynamic_destination(self): | fileio.ReadMatches() | beam.Map( lambda f: ( - os.path.basename(f.metadata.path).split('-')[0], - sorted(map(int, f.read_utf8().strip().split('\n')))))) + os.path.basename(f.metadata.path).split('-')[0], sorted( + map(int, f.read_utf8().strip().split('\n')))))) assert_that( result, @@ -664,7 +669,9 @@ def test_write_to_different_file_types(self): label='verifyApache') def record_dofn(self): + class RecordDoFn(beam.DoFn): + def process(self, element): WriteFilesTest.all_records.append(element) diff --git a/sdks/python/apache_beam/io/filesystem.py b/sdks/python/apache_beam/io/filesystem.py index bdc25dcf0fe5..e7c649b47430 100644 --- a/sdks/python/apache_beam/io/filesystem.py +++ b/sdks/python/apache_beam/io/filesystem.py @@ -450,6 +450,7 @@ class FileMetadata(object): last_updated_in_seconds: [Optional] last modified timestamp of the file, or valued 0.0 if not specified. """ + def __init__( self, path: str, @@ -486,12 +487,14 @@ class MatchResult(object): """Result from the ``FileSystem`` match operation which contains the list of matched ``FileMetadata``. """ + def __init__(self, pattern: str, metadata_list: List[FileMetadata]) -> None: self.metadata_list = metadata_list self.pattern = pattern class BeamIOError(IOError): + def __init__(self, msg, exception_details=None): """Class representing the errors thrown in the batch file operations. Args: diff --git a/sdks/python/apache_beam/io/filesystem_test.py b/sdks/python/apache_beam/io/filesystem_test.py index a4d456a366da..731d276e1647 100644 --- a/sdks/python/apache_beam/io/filesystem_test.py +++ b/sdks/python/apache_beam/io/filesystem_test.py @@ -42,6 +42,7 @@ class TestingFileSystem(FileSystem): + def __init__(self, pipeline_options, has_dirs=False): super().__init__(pipeline_options) self._has_dirs = has_dirs @@ -112,6 +113,7 @@ def delete(self, paths): class TestFileSystem(unittest.TestCase): + def setUp(self): self.fs = TestingFileSystem(pipeline_options=None) @@ -210,8 +212,8 @@ def test_match_glob(self, file_pattern, expected_object_names): # It's a filter function of type (str, int) -> bool # that returns true for expected objects filter_func = expected_object_names - expected_object_names = [(short_path, size) for short_path, - size in objects + expected_object_names = [(short_path, size) + for short_path, size in objects if filter_func(short_path, size)] for object_name, size in objects: @@ -219,8 +221,7 @@ def test_match_glob(self, file_pattern, expected_object_names): self.fs._insert_random_file(file_name, size) expected_file_names = [('gs://%s/%s' % (bucket_name, object_name), size) - for object_name, - size in expected_object_names] + for object_name, size in expected_object_names] actual_file_names = [ (file_metadata.path, file_metadata.size_in_bytes) for file_metadata in self._flatten_match(self.fs.match([file_pattern])) @@ -258,6 +259,7 @@ def test_translate_pattern(self, os_path, sep_re): class TestFileSystemWithDirs(TestFileSystem): + def setUp(self): self.fs = TestingFileSystem(pipeline_options=None, has_dirs=True) diff --git a/sdks/python/apache_beam/io/filesystemio.py b/sdks/python/apache_beam/io/filesystemio.py index 571d1f2d2699..340b5065d539 100644 --- a/sdks/python/apache_beam/io/filesystemio.py +++ b/sdks/python/apache_beam/io/filesystemio.py @@ -37,6 +37,7 @@ class Downloader(metaclass=abc.ABCMeta): Implementations should support random access reads. """ + @property @abc.abstractmethod def size(self): @@ -60,6 +61,7 @@ def get_range(self, start, end): class Uploader(metaclass=abc.ABCMeta): """Upload interface for a single file.""" + @abc.abstractmethod def put(self, data): """Write data to file sequentially. @@ -81,6 +83,7 @@ def finish(self): class DownloaderStream(io.RawIOBase): """Provides a stream interface for Downloader objects.""" + def __init__( self, downloader, read_buffer_size=io.DEFAULT_BUFFER_SIZE, mode='rb'): """Initializes the stream. @@ -174,6 +177,7 @@ def readall(self): class UploaderStream(io.RawIOBase): """Provides a stream interface for Uploader objects.""" + def __init__(self, uploader, mode='wb'): """Initializes the stream. @@ -228,6 +232,7 @@ class PipeStream(object): Remembers the last ``size`` bytes read and allows rewinding the stream by that amount exactly. See BEAM-6380 for more. """ + def __init__(self, recv_pipe): self.conn = recv_pipe self.closed = False diff --git a/sdks/python/apache_beam/io/filesystemio_test.py b/sdks/python/apache_beam/io/filesystemio_test.py index cbf9449e78d1..48206687c0b5 100644 --- a/sdks/python/apache_beam/io/filesystemio_test.py +++ b/sdks/python/apache_beam/io/filesystemio_test.py @@ -32,6 +32,7 @@ class FakeDownloader(filesystemio.Downloader): + def __init__(self, data): self._data = data self.last_read_size = -1 @@ -46,6 +47,7 @@ def get_range(self, start, end): class FakeUploader(filesystemio.Uploader): + def __init__(self): self.data = b'' self.last_write_size = -1 @@ -64,6 +66,7 @@ def finish(self): class TestDownloaderStream(unittest.TestCase): + def test_file_attributes(self): downloader = FakeDownloader(data=None) stream = filesystemio.DownloaderStream(downloader) @@ -102,6 +105,7 @@ def test_read_buffered(self): class TestUploaderStream(unittest.TestCase): + def test_file_attributes(self): uploader = FakeUploader() stream = filesystemio.UploaderStream(uploader) @@ -147,6 +151,7 @@ def test_write_buffered(self): class TestPipeStream(unittest.TestCase): + def _read_and_verify(self, stream, expected, buffer_size, success): data_list = [] bytes_read = 0 diff --git a/sdks/python/apache_beam/io/filesystems_test.py b/sdks/python/apache_beam/io/filesystems_test.py index 1ea7c34d9f4f..66bf316c7226 100644 --- a/sdks/python/apache_beam/io/filesystems_test.py +++ b/sdks/python/apache_beam/io/filesystems_test.py @@ -36,6 +36,7 @@ def _gen_fake_join(separator): """Returns a callable that joins paths with the given separator.""" + def _join(first_path, *paths): return separator.join((first_path.rstrip(separator), ) + paths) @@ -43,6 +44,7 @@ def _join(first_path, *paths): class FileSystemsTest(unittest.TestCase): + def setUp(self): self.tmpdir = tempfile.mkdtemp() diff --git a/sdks/python/apache_beam/io/flink/flink_streaming_impulse_source_test.py b/sdks/python/apache_beam/io/flink/flink_streaming_impulse_source_test.py index f6ba5abdd575..15405e459e45 100644 --- a/sdks/python/apache_beam/io/flink/flink_streaming_impulse_source_test.py +++ b/sdks/python/apache_beam/io/flink/flink_streaming_impulse_source_test.py @@ -26,6 +26,7 @@ class FlinkStreamingImpulseSourceTest(unittest.TestCase): + def test_serialization(self): p = beam.Pipeline() # pylint: disable=expression-not-assigned diff --git a/sdks/python/apache_beam/io/gcp/__init__.py b/sdks/python/apache_beam/io/gcp/__init__.py index f88a0117aa46..bf5dfd0acb57 100644 --- a/sdks/python/apache_beam/io/gcp/__init__.py +++ b/sdks/python/apache_beam/io/gcp/__init__.py @@ -25,7 +25,9 @@ from apitools.base.py import transfer class _WrapperNamespace(object): + class BytesGenerator(email_generator.BytesGenerator): + def _write_lines(self, lines): self.write(lines) diff --git a/sdks/python/apache_beam/io/gcp/big_query_query_to_table_it_test.py b/sdks/python/apache_beam/io/gcp/big_query_query_to_table_it_test.py index 052790c4a202..1b85997dd720 100644 --- a/sdks/python/apache_beam/io/gcp/big_query_query_to_table_it_test.py +++ b/sdks/python/apache_beam/io/gcp/big_query_query_to_table_it_test.py @@ -89,6 +89,7 @@ class BigQueryQueryToTableIT(unittest.TestCase): + def setUp(self): self.test_pipeline = TestPipeline(is_integration_test=True) self.runner_name = type(self.test_pipeline.runner).__name__ diff --git a/sdks/python/apache_beam/io/gcp/bigquery.py b/sdks/python/apache_beam/io/gcp/bigquery.py index 9f60b5af6726..ce8441e29821 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery.py +++ b/sdks/python/apache_beam/io/gcp/bigquery.py @@ -520,6 +520,7 @@ class TableRowJsonCoder(coders.Coder): table schema in order to obtain the ordered list of field names. Reading from sources on the other hand does not need the table schema. """ + def __init__(self, table_schema=None): # The table schema is needed for encoding TableRows as JSON (writing to # sinks) because the ordered list of field names is used in the JSON @@ -642,6 +643,7 @@ class _BigQueryExportResult: class _CustomBigQuerySource(BoundedSource): + def __init__( self, method, @@ -829,8 +831,10 @@ def split(self, desired_bundle_size, start_position=None, stop_position=None): weight=1.0, source=source, start_position=None, stop_position=None) def get_range_tracker(self, start_position, stop_position): + class CustomBigQuerySourceRangeTracker(RangeTracker): """A RangeTracker that always returns positions as None.""" + def start_position(self): return None @@ -1221,9 +1225,11 @@ def split(self, desired_bundle_size, start_position=None, stop_position=None): weight=1.0, source=source, start_position=None, stop_position=None) def get_range_tracker(self, start_position, stop_position): + class NonePositionRangeTracker(RangeTracker): """A RangeTracker that always returns positions as None. Prevents the BigQuery Storage source from being read() before being split().""" + def start_position(self): return None @@ -1341,6 +1347,7 @@ def read_avro(self): class _ReadReadRowsResponsesWithFastAvro(): """An iterator that deserializes ReadRowsResponses using the fastavro library.""" + def __init__(self, read_rows_iterator, read_rows_response): self.read_rows_iterator = read_rows_iterator self.read_rows_response = read_rows_response @@ -1617,8 +1624,8 @@ def process( # Flush current batch first if adding this row will exceed our limits # limits: byte size; number of rows - if ((self._destination_buffer_byte_size[destination] + row_byte_size > - self._max_insert_payload_size) or + if ((self._destination_buffer_byte_size[destination] + row_byte_size + > self._max_insert_payload_size) or len(self._rows_buffer[destination]) >= self._max_batch_size): flushed_batch = self._flush_batch(destination) # After flushing our existing batch, we now buffer the current row @@ -1712,9 +1719,8 @@ def _flush_batch(self, destination): # - WARNING when we are continuing to retry, and have a deadline. # - ERROR when we will no longer retry, or MAY retry forever. log_level = ( - logging.WARN if should_retry or - self._retry_strategy != RetryStrategy.RETRY_ALWAYS else - logging.ERROR) + logging.WARN if should_retry or self._retry_strategy + != RetryStrategy.RETRY_ALWAYS else logging.ERROR) _LOGGER.log(log_level, message) @@ -1740,16 +1746,13 @@ def _flush_batch(self, destination): [ pvalue.TaggedOutput( BigQueryWriteFn.FAILED_ROWS_WITH_ERRORS, - w.with_value((destination, row, err))) for row, - err, - w in failed_rows + w.with_value((destination, row, err))) + for row, err, w in failed_rows ], [ pvalue.TaggedOutput( BigQueryWriteFn.FAILED_ROWS, w.with_value((destination, row))) - for row, - unused_err, - w in failed_rows + for row, unused_err, w in failed_rows ]) @@ -1761,6 +1764,7 @@ def _flush_batch(self, destination): class _StreamToBigQuery(PTransform): + def __init__( self, table_reference, @@ -1801,6 +1805,7 @@ def __init__( self._max_insert_payload_size = max_insert_payload_size class InsertIdPrefixFn(DoFn): + def start_bundle(self): self.prefix = str(uuid.uuid4()) self._row_count = 0 @@ -1891,6 +1896,7 @@ class WriteToBigQuery(PTransform): tables. The elements would come in as Python dictionaries, or as `TableRow` instances. """ + class Method(object): DEFAULT = 'DEFAULT' STREAMING_INSERTS = 'STREAMING_INSERTS' @@ -2331,10 +2337,9 @@ def to_runner_api_parameter(self, context): # remove_objects_from_args and insert_values_in_args # are currently implemented. def serialize(side_inputs): - return {(SIDE_INPUT_PREFIX + '%s') % ix: - si.to_runner_api(context).SerializeToString() - for ix, - si in enumerate(side_inputs)} + return {(SIDE_INPUT_PREFIX + '%s') % ix: si.to_runner_api( + context).SerializeToString() + for ix, si in enumerate(side_inputs)} table_side_inputs = serialize(self.table_side_inputs) schema_side_inputs = serialize(self.schema_side_inputs) @@ -2382,8 +2387,8 @@ def deserialize(side_inputs): # to_runner_api_parameter above). indexed_side_inputs = [( get_sideinput_index(tag), - pvalue.AsSideInput.from_runner_api(si, context)) for tag, - si in deserialized_side_inputs.items()] + pvalue.AsSideInput.from_runner_api(si, context)) + for tag, si in deserialized_side_inputs.items()] return [si for _, si in sorted(indexed_side_inputs)] config['table_side_inputs'] = deserialize(config['table_side_inputs']) @@ -2395,6 +2400,7 @@ def deserialize(side_inputs): class WriteResult: """The result of a WriteToBigQuery transform. """ + def __init__( self, method: str = None, @@ -2661,8 +2667,8 @@ def expand(self, input): failed_rows = failed_rows | beam.Map(lambda row: row.as_dict()) failed_rows_with_errors = failed_rows_with_errors | beam.Map( lambda row: { - "error_message": row.error_message, - "failed_row": row.failed_row.as_dict() + "error_message": row.error_message, "failed_row": row.failed_row. + as_dict() }) return WriteResult( @@ -2671,6 +2677,7 @@ def expand(self, input): failed_rows_with_errors=failed_rows_with_errors) class ConvertToBeamRows(PTransform): + def __init__(self, schema, dynamic_destinations): self.schema = schema self.dynamic_destinations = dynamic_destinations @@ -2682,8 +2689,8 @@ def expand(self, input_dicts): | "Convert dict to Beam Row" >> beam.Map( lambda row: beam.Row( **{ - StorageWriteToBigQuery.DESTINATION: row[0], - StorageWriteToBigQuery.RECORD: bigquery_tools. + StorageWriteToBigQuery.DESTINATION: row[ + 0], StorageWriteToBigQuery.RECORD: bigquery_tools. beam_row_from_dict(row[1], self.schema) }))) else: @@ -2804,6 +2811,7 @@ class ReadFromBigQuery(PTransform): `BEAM_ROW`. For more information on schemas, see https://beam.apache.org/documentation/programming-guide/#what-is-a-schema) """ + class Method(object): EXPORT = 'EXPORT' # This is currently the default. DIRECT_READ = 'DIRECT_READ' @@ -2972,6 +2980,7 @@ class ReadFromBigQueryRequest: """ Class that defines data to read from BQ. """ + def __init__( self, query: str = None, diff --git a/sdks/python/apache_beam/io/gcp/bigquery_avro_tools_test.py b/sdks/python/apache_beam/io/gcp/bigquery_avro_tools_test.py index eca208d26612..36f652ad7cee 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_avro_tools_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_avro_tools_test.py @@ -27,6 +27,7 @@ @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') class TestBigQueryToAvroSchema(unittest.TestCase): + def test_convert_bigquery_schema_to_avro_schema(self): subfields = [ bigquery.TableFieldSchema( diff --git a/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py b/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py index 3145fb511068..2a6156eb4624 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py @@ -92,6 +92,7 @@ def _generate_job_name(job_name, job_type, step_name): def file_prefix_generator( with_validation=True, pipeline_gcs_location=None, temp_location=None): + def _generate_file_prefix(unused_elm): # If a gcs location is provided to the pipeline, then we shall use that. # Otherwise, we shall use the temp_location from pipeline options. @@ -291,6 +292,7 @@ class WriteGroupedRecordsToFile(beam.DoFn): Experimental; no backwards compatibility guarantees. """ + def __init__( self, schema, max_file_size=_DEFAULT_MAX_FILE_SIZE, file_format=None): self.schema = schema @@ -344,6 +346,7 @@ class UpdateDestinationSchema(beam.DoFn): Experimental; no backwards compatibility guarantees. """ + def __init__( self, project=None, @@ -790,6 +793,7 @@ class PartitionFiles(beam.DoFn): SINGLE_PARTITION_TAG = 'SINGLE_PARTITION' class Partition(object): + def __init__(self, max_size, max_files, files=None, size=0): self.max_size = max_size self.max_files = max_files @@ -848,6 +852,7 @@ def process(self, element): class DeleteTablesFn(beam.DoFn): + def __init__(self, test_client=None): self.test_client = test_client @@ -1145,8 +1150,7 @@ def _load_data( # https://github.com/apache/beam/issues/24535. finished_temp_tables_load_job_ids_list_pc = ( finished_temp_tables_load_job_ids_pc | beam.MapTuple( - lambda destination, - job_reference: ( + lambda destination, job_reference: ( bigquery_tools.parse_table_reference(destination).tableId, (destination, job_reference))) | beam.GroupByKey() @@ -1234,8 +1238,7 @@ def expand(self, pcoll): singleton_pc | "SchemaModJobNamePrefix" >> beam.Map( lambda _: _generate_job_name( - job_name, - bigquery_tools.BigQueryJobTypes.LOAD, + job_name, bigquery_tools.BigQueryJobTypes.LOAD, 'SCHEMA_MOD_STEP'))) copy_job_name_pcv = pvalue.AsSingleton( diff --git a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py index 10453d9c8baf..051aee5507ef 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py @@ -254,6 +254,7 @@ def check_many_files(output_pcs): class TestWriteGroupedRecordsToFile(_TestCaseWithTempDirCleanUp): + def _consume_input(self, fn, input, checks): if checks is None: return @@ -404,6 +405,7 @@ def test_partition_files_dofn_size_split(self): class TestBigQueryFileLoads(_TestCaseWithTempDirCleanUp): + def test_trigger_load_jobs_with_empty_files(self): destination = "project:dataset.table" empty_files = [] @@ -799,6 +801,7 @@ def test_triggering_frequency(self, is_streaming, with_auto_sharding): # Insert a fake clock to work with auto-sharding which needs a processing # time timer. class _FakeClock(object): + def __init__(self, now=time.time()): self._now = now @@ -827,12 +830,12 @@ def __call__(self): if is_streaming: _SIZE = len(_ELEMENTS) fisrt_batch = [ - TimestampedValue(value, start_time + i + 1) for i, - value in enumerate(_ELEMENTS[:_SIZE // 2]) + TimestampedValue(value, start_time + i + 1) + for i, value in enumerate(_ELEMENTS[:_SIZE // 2]) ] second_batch = [ - TimestampedValue(value, start_time + _SIZE // 2 + i + 1) for i, - value in enumerate(_ELEMENTS[_SIZE // 2:]) + TimestampedValue(value, start_time + _SIZE // 2 + i + 1) + for i, value in enumerate(_ELEMENTS[_SIZE // 2:]) ] # Advance processing time between batches of input elements to fire the # user triggers. Intentionally advance the processing time twice for the @@ -1031,12 +1034,10 @@ def test_multiple_destinations_transform(self): _ = ( input | "WriteWithMultipleDestsFreely" >> bigquery.WriteToBigQuery( - table=lambda x, - tables: + table=lambda x, tables: (tables['table1'] if 'language' in x else tables['table2']), table_side_inputs=(table_record_pcv, ), - schema=lambda dest, - schema_map: schema_map.get(dest, None), + schema=lambda dest, schema_map: schema_map.get(dest, None), schema_side_inputs=(schema_map_pcv, ), create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED, write_disposition=beam.io.BigQueryDisposition.WRITE_EMPTY)) @@ -1045,8 +1046,7 @@ def test_multiple_destinations_transform(self): input | "WriteWithMultipleDests" >> bigquery.WriteToBigQuery( table=lambda x: (output_table_3 if 'language' in x else output_table_4), - schema=lambda dest, - schema_map: schema_map.get(dest, None), + schema=lambda dest, schema_map: schema_map.get(dest, None), schema_side_inputs=(schema_map_pcv, ), create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED, write_disposition=beam.io.BigQueryDisposition.WRITE_EMPTY, diff --git a/sdks/python/apache_beam/io/gcp/bigquery_io_metadata.py b/sdks/python/apache_beam/io/gcp/bigquery_io_metadata.py index a730f2cfc9bb..c98a841c3a62 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_io_metadata.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_io_metadata.py @@ -80,6 +80,7 @@ class BigQueryIOMetadata(object): Do not construct directly, use the create_bigquery_io_metadata factory. Which will request metadata properly based on which runner is being used. """ + def __init__(self, beam_job_id=None, step_name=None): self.beam_job_id = beam_job_id self.step_name = step_name diff --git a/sdks/python/apache_beam/io/gcp/bigquery_io_metadata_test.py b/sdks/python/apache_beam/io/gcp/bigquery_io_metadata_test.py index c91202465388..a83b2e0a78be 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_io_metadata_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_io_metadata_test.py @@ -26,6 +26,7 @@ class BigqueryIoMetadataTest(unittest.TestCase): + def test_is_valid_cloud_label_value(self): # A dataflow job ID. # Lowercase letters, numbers, underscores and hyphens are allowed. diff --git a/sdks/python/apache_beam/io/gcp/bigquery_io_read_pipeline.py b/sdks/python/apache_beam/io/gcp/bigquery_io_read_pipeline.py index 8ca9736b025e..b294c5401254 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_io_read_pipeline.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_io_read_pipeline.py @@ -37,6 +37,7 @@ class RowToStringWithSlowDown(beam.DoFn): + def process(self, element, num_slow=0, *args, **kwargs): if num_slow == 0: diff --git a/sdks/python/apache_beam/io/gcp/bigquery_json_it_test.py b/sdks/python/apache_beam/io/gcp/bigquery_json_it_test.py index 6716aa1bc10f..e1e71aa875d7 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_json_it_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_json_it_test.py @@ -137,6 +137,7 @@ def maybe_unescape(value): return json.loads(value) class CompareJson(beam.DoFn, unittest.TestCase): + def process(self, row): country_code = row["country_code"] expected = json_data[country_code] diff --git a/sdks/python/apache_beam/io/gcp/bigquery_read_internal.py b/sdks/python/apache_beam/io/gcp/bigquery_read_internal.py index f038b48e04d5..0ab27650a9f8 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_read_internal.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_read_internal.py @@ -109,15 +109,19 @@ class _PassThroughThenCleanup(PTransform): Utilizes readiness of PCollection to trigger DoFn. """ + def __init__(self, side_input=None): self.side_input = side_input def expand(self, input): + class PassThrough(beam.DoFn): + def process(self, element): yield element class RemoveExtractedFiles(beam.DoFn): + def process(self, unused_element, unused_signal, gcs_locations): FileSystems.delete(list(gcs_locations)) @@ -144,6 +148,7 @@ class _PassThroughThenCleanupTempDatasets(PTransform): Utilizes readiness of PCollection to trigger DoFn. """ + def __init__(self, side_input=None): self.side_input = side_input @@ -151,10 +156,12 @@ def expand(self, input): pipeline_options = input.pipeline.options class PassThrough(beam.DoFn): + def process(self, element): yield element class CleanUpProjects(beam.DoFn): + def process(self, unused_element, unused_signal, pipeline_details): bq = bigquery_tools.BigQueryWrapper.from_pipeline_options( pipeline_options) @@ -191,6 +198,7 @@ class _BigQueryReadSplit(beam.transforms.DoFn): This transform will start a BigQuery export job, and output a number of file sources that are consumed downstream. """ + def __init__( self, options: PipelineOptions, @@ -404,6 +412,7 @@ def _get_project(self): class _JsonToDictCoder(coders.Coder): """A coder for a JSON string to a Python dict.""" + def __init__(self, table_schema): self.fields = self._convert_to_tuple(table_schema.fields) self._converters = { diff --git a/sdks/python/apache_beam/io/gcp/bigquery_read_it_test.py b/sdks/python/apache_beam/io/gcp/bigquery_read_it_test.py index 913d6e078d89..7fc6ab581428 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_read_it_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_read_it_test.py @@ -62,6 +62,7 @@ def skip(runners): runners = [runners] def inner(fn): + @wraps(fn) def wrapped(self): if self.runner_name in runners: @@ -603,6 +604,7 @@ def test_iobase_source_with_query_and_filters(self): class ReadNewTypesTests(BigQueryReadIntegrationTests): + @classmethod def setUpClass(cls): super(ReadNewTypesTests, cls).setUpClass() @@ -824,6 +826,7 @@ def test_read_queries(self): class ReadInteractiveRunnerTests(BigQueryReadIntegrationTests): + @skip(['PortableRunner', 'FlinkRunner']) @pytest.mark.it_postcommit def test_read_in_interactive_runner(self): diff --git a/sdks/python/apache_beam/io/gcp/bigquery_read_perf_test.py b/sdks/python/apache_beam/io/gcp/bigquery_read_perf_test.py index 0b4cfe2ecbae..2f989c7c3684 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_read_perf_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_read_perf_test.py @@ -79,6 +79,7 @@ class BigQueryReadPerfTest(LoadTest): + def __init__(self): super().__init__() self.input_dataset = self.pipeline.get_option('input_dataset') diff --git a/sdks/python/apache_beam/io/gcp/bigquery_schema_tools.py b/sdks/python/apache_beam/io/gcp/bigquery_schema_tools.py index beb373a7dea3..7d3acd4d16cb 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_schema_tools.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_schema_tools.py @@ -101,6 +101,7 @@ def convert_to_usertype(table_schema, selected_fields=None): class BeamSchemaConversionDoFn(DoFn): + def __init__(self, pcoll_val_ctor): self._pcoll_val_ctor = pcoll_val_ctor diff --git a/sdks/python/apache_beam/io/gcp/bigquery_schema_tools_test.py b/sdks/python/apache_beam/io/gcp/bigquery_schema_tools_test.py index 7ae49dff205d..7079d5b4f159 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_schema_tools_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_schema_tools_test.py @@ -35,6 +35,7 @@ @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') class TestBigQueryToSchema(unittest.TestCase): + def test_check_schema_conversions(self): fields = [ bigquery.TableFieldSchema(name='stn', type='STRING', mode="NULLABLE"), @@ -173,6 +174,7 @@ def test_unsupported_value_provider(self): table=value_provider.ValueProvider(), output_type='BEAM_ROW') def test_unsupported_callable(self): + def filterTable(table): if table is not None: return table diff --git a/sdks/python/apache_beam/io/gcp/bigquery_test.py b/sdks/python/apache_beam/io/gcp/bigquery_test.py index 435fe67d02fc..8e5070126d5e 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_test.py @@ -148,6 +148,7 @@ def _load_or_default(filename): HttpError is None or gcp_bigquery is None, 'GCP dependencies are not installed') class TestTableRowJsonCoder(unittest.TestCase): + def test_row_as_table_row(self): schema_definition = [('s', 'STRING'), ('i', 'INTEGER'), ('f', 'FLOAT'), ('b', 'BOOLEAN'), ('n', 'NUMERIC'), ('r', 'RECORD'), @@ -172,8 +173,8 @@ def test_row_as_table_row(self): '"g": "LINESTRING(1 2, 3 4, 5 6, 7 8)"}') schema = bigquery.TableSchema( fields=[ - bigquery.TableFieldSchema(name=k, type=v) for k, - v in schema_definition + bigquery.TableFieldSchema(name=k, type=v) + for k, v in schema_definition ]) coder = TableRowJsonCoder(table_schema=schema) @@ -211,8 +212,8 @@ def json_compliance_exception(self, value): schema_definition = [('f', 'FLOAT')] schema = bigquery.TableSchema( fields=[ - bigquery.TableFieldSchema(name=k, type=v) for k, - v in schema_definition + bigquery.TableFieldSchema(name=k, type=v) + for k, v in schema_definition ]) coder = TableRowJsonCoder(table_schema=schema) test_row = bigquery.TableRow( @@ -231,8 +232,10 @@ def test_invalid_json_neg_inf(self): @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') class TestJsonToDictCoder(unittest.TestCase): + @staticmethod def _make_schema(fields): + def _fill_schema(fields): for field in fields: table_field = bigquery.TableFieldSchema() @@ -334,9 +337,12 @@ def test_repeatable_field_is_properly_converted(self): HttpError is None or HttpForbiddenError is None, 'GCP dependencies are not installed') class TestReadFromBigQuery(unittest.TestCase): + @classmethod def setUpClass(cls): + class UserDefinedOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_value_provider_argument('--gcs_location') @@ -507,7 +513,9 @@ def test_create_temp_dataset_exception(self, exception_type, error_message): expected_retries=3), ]) def test_get_table_transient_exception(self, responses, expected_retries): + class DummyTable: + class DummySchema: fields = [] @@ -619,7 +627,9 @@ def store_callback(unused_request): expected_retries=2), ]) def test_get_table_non_transient_exception(self, responses, expected_retries): + class DummyTable: + class DummySchema: fields = [] @@ -765,6 +775,7 @@ def test_read_all_lineage(self): @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') class TestBigQuerySink(unittest.TestCase): + def test_table_spec_display_data(self): sink = beam.io.BigQuerySink('dataset.table') dd = DisplayData.create_from(sink) @@ -796,6 +807,7 @@ def test_project_table_display_data(self): @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') class TestWriteToBigQuery(unittest.TestCase): + def _cleanup_files(self): if os.path.exists('insert_calls1'): os.remove('insert_calls1') @@ -962,8 +974,7 @@ def test_to_from_runner_api(self): schema = value_provider.StaticValueProvider(str, '"a:str"') original = WriteToBigQuery( - table=lambda _, - side_input: side_input['table'], + table=lambda _, side_input: side_input['table'], table_side_inputs=(table_record_pcv, ), schema=schema) @@ -978,8 +989,7 @@ def test_to_from_runner_api(self): # Find the transform from the context. write_to_bq_id = [ - k for k, - v in pipeline_proto.components.transforms.items() + k for k, v in pipeline_proto.components.transforms.items() if v.unique_name == 'MyWriteToBigQuery' ][0] deserialized_node = context.transforms.get_by_id(write_to_bq_id) @@ -1010,6 +1020,7 @@ def test_to_from_runner_api(self): original_side_input_data.view_fn, deserialized_side_input_data.view_fn) def test_streaming_triggering_frequency_without_auto_sharding(self): + def noop(table, **kwargs): return [] @@ -1039,6 +1050,7 @@ def noop(table, **kwargs): test_client=client)) def test_streaming_triggering_frequency_with_auto_sharding(self): + def noop(table, **kwargs): return [] @@ -1219,25 +1231,27 @@ class BigQueryStreamingInsertsErrorHandling(unittest.TestCase): # failed rows param( insert_response=[ - exceptions.TooManyRequests if exceptions else None, - None], - error_reason='Too Many Requests', # not in _NON_TRANSIENT_ERRORS + exceptions.TooManyRequests if exceptions else None, None + ], + error_reason='Too Many Requests', # not in _NON_TRANSIENT_ERRORS failed_rows=[]), # reason not in _NON_TRANSIENT_ERRORS for row 1 on both attempts, sent to # failed rows after hitting max_retries param( insert_response=[ - exceptions.InternalServerError if exceptions else None, - exceptions.InternalServerError if exceptions else None], - error_reason='Internal Server Error', # not in _NON_TRANSIENT_ERRORS + exceptions.InternalServerError if exceptions else None, + exceptions.InternalServerError if exceptions else None + ], + error_reason='Internal Server Error', # not in _NON_TRANSIENT_ERRORS failed_rows=['value1', 'value3', 'value5']), # reason in _NON_TRANSIENT_ERRORS for row 1 on both attempts, sent to # failed_rows after hitting max_retries param( insert_response=[ - exceptions.Forbidden if exceptions else None, - exceptions.Forbidden if exceptions else None], - error_reason='Forbidden', # in _NON_TRANSIENT_ERRORS + exceptions.Forbidden if exceptions else None, + exceptions.Forbidden if exceptions else None + ], + error_reason='Forbidden', # in _NON_TRANSIENT_ERRORS failed_rows=['value1', 'value3', 'value5']), ]) def test_insert_rows_json_exception_retry_always( @@ -1363,63 +1377,63 @@ def test_insert_rows_json_exception_retry_never( @parameterized.expand([ param( exception_type=exceptions.DeadlineExceeded if exceptions else None, - error_reason='Deadline Exceeded', # not in _NON_TRANSIENT_ERRORS + error_reason='Deadline Exceeded', # not in _NON_TRANSIENT_ERRORS failed_values=[], expected_call_count=2), param( exception_type=exceptions.Conflict if exceptions else None, - error_reason='Conflict', # not in _NON_TRANSIENT_ERRORS + error_reason='Conflict', # not in _NON_TRANSIENT_ERRORS failed_values=[], expected_call_count=2), param( exception_type=exceptions.TooManyRequests if exceptions else None, - error_reason='Too Many Requests', # not in _NON_TRANSIENT_ERRORS + error_reason='Too Many Requests', # not in _NON_TRANSIENT_ERRORS failed_values=[], expected_call_count=2), param( exception_type=exceptions.InternalServerError if exceptions else None, - error_reason='Internal Server Error', # not in _NON_TRANSIENT_ERRORS + error_reason='Internal Server Error', # not in _NON_TRANSIENT_ERRORS failed_values=[], expected_call_count=2), param( exception_type=exceptions.BadGateway if exceptions else None, - error_reason='Bad Gateway', # not in _NON_TRANSIENT_ERRORS + error_reason='Bad Gateway', # not in _NON_TRANSIENT_ERRORS failed_values=[], expected_call_count=2), param( exception_type=exceptions.ServiceUnavailable if exceptions else None, - error_reason='Service Unavailable', # not in _NON_TRANSIENT_ERRORS + error_reason='Service Unavailable', # not in _NON_TRANSIENT_ERRORS failed_values=[], expected_call_count=2), param( exception_type=exceptions.GatewayTimeout if exceptions else None, - error_reason='Gateway Timeout', # not in _NON_TRANSIENT_ERRORS + error_reason='Gateway Timeout', # not in _NON_TRANSIENT_ERRORS failed_values=[], expected_call_count=2), param( exception_type=exceptions.BadRequest if exceptions else None, - error_reason='Bad Request', # in _NON_TRANSIENT_ERRORS + error_reason='Bad Request', # in _NON_TRANSIENT_ERRORS failed_values=['value1', 'value2'], expected_call_count=1), param( exception_type=exceptions.Unauthorized if exceptions else None, - error_reason='Unauthorized', # in _NON_TRANSIENT_ERRORS + error_reason='Unauthorized', # in _NON_TRANSIENT_ERRORS failed_values=['value1', 'value2'], expected_call_count=1), param( exception_type=exceptions.Forbidden if exceptions else None, - error_reason='Forbidden', # in _NON_TRANSIENT_ERRORS + error_reason='Forbidden', # in _NON_TRANSIENT_ERRORS failed_values=['value1', 'value2'], expected_call_count=1), param( exception_type=exceptions.NotFound if exceptions else None, - error_reason='Not Found', # in _NON_TRANSIENT_ERRORS + error_reason='Not Found', # in _NON_TRANSIENT_ERRORS failed_values=['value1', 'value2'], expected_call_count=1), param( exception_type=exceptions.MethodNotImplemented - if exceptions else None, - error_reason='Not Implemented', # in _NON_TRANSIENT_ERRORS + if exceptions else None, + error_reason='Not Implemented', # in _NON_TRANSIENT_ERRORS failed_values=['value1', 'value2'], expected_call_count=1), ]) @@ -1915,6 +1929,7 @@ def store_callback(table, **kwargs): @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') class BigQueryStreamingInsertTransformTests(unittest.TestCase): + def test_dofn_client_process_performs_batching(self): client = mock.Mock() client.tables.Get.return_value = bigquery.Table( @@ -2051,6 +2066,7 @@ def test_with_batched_input(self): @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') class PipelineBasedStreamingInsertTest(_TestCaseWithTempDirCleanUp): + @mock.patch('time.sleep') def test_failure_has_same_insert_ids(self, unused_mock_sleep): tempdir = '%s%s' % (self._new_tempdir(), os.sep) @@ -2460,12 +2476,10 @@ def test_multiple_destinations_transform(self): r = ( input | "WriteWithMultipleDests" >> beam.io.gcp.bigquery.WriteToBigQuery( - table=lambda x, - tables: + table=lambda x, tables: (tables['table1'] if 'language' in x else tables['table2']), table_side_inputs=(table_record_pcv, ), - schema=lambda dest, - table_map: table_map.get(dest, None), + schema=lambda dest, table_map: table_map.get(dest, None), schema_side_inputs=(schema_table_pcv, ), insert_retry_strategy=RetryStrategy.RETRY_ON_TRANSIENT_ERROR, method='STREAMING_INSERTS')) @@ -2665,8 +2679,7 @@ def test_avro_file_load(self): input | 'WriteToBigQuery' >> beam.io.gcp.bigquery.WriteToBigQuery( table='%s:%s' % (self.project, self.output_table), - schema=lambda _, - schema: schema, + schema=lambda _, schema: schema, schema_side_inputs=(beam.pvalue.AsSingleton(schema_pc), ), method='FILE_LOADS', temp_file_format=bigquery_tools.FileFormat.AVRO, diff --git a/sdks/python/apache_beam/io/gcp/bigquery_tools.py b/sdks/python/apache_beam/io/gcp/bigquery_tools.py index 48da929a07b2..4686ded89b59 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_tools.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_tools.py @@ -1282,8 +1282,8 @@ def insert_rows( # can happen during retries on failures. # TODO(silviuc): Must add support to writing TableRow's instead of dicts. insert_ids = [ - str(self.unique_row_id) if not insert_ids else insert_ids[i] for i, - _ in enumerate(rows) + str(self.unique_row_id) if not insert_ids else insert_ids[i] + for i, _ in enumerate(rows) ] rows = [ fast_json_loads(fast_json_dumps(r, default=default_encoder)) @@ -1392,6 +1392,7 @@ class RowAsDictJsonCoder(coders.Coder): This is the default coder for sources and sinks if the coder argument is not specified. """ + def encode(self, table_row): # The normal error when dumping NAN/INF values is: # ValueError: Out of range float values are not JSON compliant @@ -1419,6 +1420,7 @@ class JsonRowWriter(io.IOBase): A writer which provides an IOBase-like interface for writing table rows (represented as dicts) as newline-delimited JSON strings. """ + def __init__(self, file_handle): """Initialize an JsonRowWriter. @@ -1459,6 +1461,7 @@ class AvroRowWriter(io.IOBase): A writer which provides an IOBase-like interface for writing table rows (represented as dicts) as Avro records. """ + def __init__(self, file_handle, schema): """Initialize an AvroRowWriter. @@ -1552,6 +1555,7 @@ class AppendDestinationsFn(DoFn): Experimental; no backwards compatibility guarantees. """ + def __init__(self, destination): self._display_destination = destination self.destination = AppendDestinationsFn._get_table_fn(destination) @@ -1656,6 +1660,7 @@ def get_table_schema_from_string(schema): def table_schema_to_dict(table_schema): """Create a dictionary representation of table schema for serialization """ + def get_table_field(field): """Create a dictionary representation of a table field """ diff --git a/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py b/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py index 1307a7886924..c463186e6f74 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py @@ -72,6 +72,7 @@ @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') class TestTableSchemaParser(unittest.TestCase): + def test_parse_table_schema_from_json(self): string_field = bigquery.TableFieldSchema( name='s', type='STRING', mode='NULLABLE', description='s description') @@ -109,6 +110,7 @@ def test_parse_table_schema_from_json(self): @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') class TestTableReferenceParser(unittest.TestCase): + def test_calling_with_table_reference(self): table_ref = bigquery.TableReference() table_ref.projectId = 'test_project' @@ -175,6 +177,7 @@ def test_calling_with_all_arguments(self): @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') class TestBigQueryWrapper(unittest.TestCase): + def test_delete_non_existing_dataset(self): client = mock.Mock() client.datasets.Delete.side_effect = HttpError( @@ -388,6 +391,7 @@ def test_get_or_create_table_invalid_tablename(self, table_id): False) def test_wait_for_job_returns_true_when_job_is_done(self): + def make_response(state): m = mock.Mock() m.status.errorResult = None @@ -581,6 +585,7 @@ def test_start_query_job_priority_configuration(self): @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') class TestRowAsDictJsonCoder(unittest.TestCase): + def test_row_as_dict(self): coder = RowAsDictJsonCoder() test_value = {'s': 'abc', 'i': 123, 'f': 123.456, 'b': True} @@ -622,6 +627,7 @@ def test_ensure_ascii(self): @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') class TestJsonRowWriter(unittest.TestCase): + def test_write_row(self): rows = [ { @@ -654,6 +660,7 @@ def test_write_row(self): @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') class TestAvroRowWriter(unittest.TestCase): + def test_write_row(self): schema = bigquery.TableSchema( fields=[ @@ -682,6 +689,7 @@ def test_write_row(self): class TestBQJobNames(unittest.TestCase): + def test_simple_names(self): self.assertEqual( "beam_bq_job_EXPORT_beamappjobtest_abcd", @@ -720,6 +728,7 @@ def test_matches_template(self): @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') class TestCheckSchemaEqual(unittest.TestCase): + def test_simple_schemas(self): schema1 = bigquery.TableSchema(fields=[]) self.assertTrue(check_schema_equal(schema1, schema1)) @@ -985,9 +994,8 @@ def test_typehints_from_repeated_schema(self): schema = {"fields": self.get_schema_fields_with_mode("repeated")} typehints = get_beam_typehints_from_tableschema(schema) - expected_repeated_typehints = [ - (name, Sequence[type]) for name, type in self.EXPECTED_TYPEHINTS - ] + expected_repeated_typehints = [(name, Sequence[type]) + for name, type in self.EXPECTED_TYPEHINTS] self.assertEqual(typehints, expected_repeated_typehints) @@ -995,9 +1003,8 @@ def test_typehints_from_nullable_schema(self): schema = {"fields": self.get_schema_fields_with_mode("nullable")} typehints = get_beam_typehints_from_tableschema(schema) - expected_nullable_typehints = [ - (name, Optional[type]) for name, type in self.EXPECTED_TYPEHINTS - ] + expected_nullable_typehints = [(name, Optional[type]) + for name, type in self.EXPECTED_TYPEHINTS] self.assertEqual(typehints, expected_nullable_typehints) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_write_it_test.py b/sdks/python/apache_beam/io/gcp/bigquery_write_it_test.py index cd3edf19de5f..783938447875 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_write_it_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_write_it_test.py @@ -591,8 +591,10 @@ def test_big_query_write_temp_table_append_schema_update(self, file_format): max_file_size=1, # bytes method=beam.io.WriteToBigQuery.Method.FILE_LOADS, additional_bq_parameters={ - 'schemaUpdateOptions': ['ALLOW_FIELD_ADDITION', - 'ALLOW_FIELD_RELAXATION']}, + 'schemaUpdateOptions': [ + 'ALLOW_FIELD_ADDITION', 'ALLOW_FIELD_RELAXATION' + ] + }, temp_file_format=file_format)) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_write_perf_test.py b/sdks/python/apache_beam/io/gcp/bigquery_write_perf_test.py index 1aafb1b60a85..f19ba97048d5 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_write_perf_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_write_perf_test.py @@ -67,6 +67,7 @@ class BigQueryWritePerfTest(LoadTest): + def __init__(self): super().__init__() self.output_dataset = self.pipeline.get_option('output_dataset') diff --git a/sdks/python/apache_beam/io/gcp/bigtableio.py b/sdks/python/apache_beam/io/gcp/bigtableio.py index ffb1852eb0f4..1911f62e2fc0 100644 --- a/sdks/python/apache_beam/io/gcp/bigtableio.py +++ b/sdks/python/apache_beam/io/gcp/bigtableio.py @@ -80,6 +80,7 @@ class _BigTableWriteFn(beam.DoFn): table_id(str): GCP Table ID """ + def __init__(self, project_id, instance_id, table_id): """ Constructor of the Write connector of Bigtable Args: @@ -253,6 +254,7 @@ def expand(self, input): self._project_id, self._instance_id, self._table_id))) class _DirectRowMutationsToBeamRow(beam.DoFn): + def process(self, direct_row): args = {"key": direct_row.row_key, "mutations": []} # start accumulating mutations in a list @@ -346,6 +348,7 @@ def expand(self, input): # To make use of those methods and to give Python users a more familiar # object, we process each Beam Row and return a PartialRowData equivalent. class _BeamRowToPartialRowData(beam.DoFn): + def process(self, row): key = row.key families = row.column_families diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1new/datastore_write_it_pipeline.py b/sdks/python/apache_beam/io/gcp/datastore/v1new/datastore_write_it_pipeline.py index f56443fe6ed8..491213a86da4 100644 --- a/sdks/python/apache_beam/io/gcp/datastore/v1new/datastore_write_it_pipeline.py +++ b/sdks/python/apache_beam/io/gcp/datastore/v1new/datastore_write_it_pipeline.py @@ -66,6 +66,7 @@ class EntityWrapper(object): Namespace and project are taken from the parent key. """ + def __init__(self, kind, parent_key): self._kind = kind self._parent_key = parent_key diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1new/datastoreio.py b/sdks/python/apache_beam/io/gcp/datastore/v1new/datastoreio.py index f120234e9740..bc609e524272 100644 --- a/sdks/python/apache_beam/io/gcp/datastore/v1new/datastoreio.py +++ b/sdks/python/apache_beam/io/gcp/datastore/v1new/datastoreio.py @@ -173,6 +173,7 @@ def display_data(self): @typehints.with_output_types(types.Query) class _SplitQueryFn(DoFn): """A `DoFn` that splits a given query into multiple sub-queries.""" + def __init__(self, num_splits): super().__init__() self._num_splits = num_splits @@ -281,6 +282,7 @@ def get_estimated_num_splits(client, query): @typehints.with_output_types(types.Entity) class _QueryFn(DoFn): """A DoFn that fetches entities from Cloud Datastore, for a given query.""" + def process(self, query, *unused_args, **unused_kwargs): if query.namespace is None: query.namespace = '' @@ -359,6 +361,7 @@ class DatastoreMutateFn(DoFn): should be idempotent (`upsert` and `delete` mutations) to prevent duplicate data or errors. """ + def __init__(self, project): """ Args: @@ -523,6 +526,7 @@ class WriteToDatastore(_Mutate): property key is empty then it is filled with the project ID passed to this transform. """ + def __init__( self, project, @@ -540,6 +544,7 @@ def __init__( super().__init__(mutate_fn, throttle_rampup, hint_num_workers) class _DatastoreWriteFn(_Mutate.DatastoreMutateFn): + def element_to_client_batch_item(self, element): if not isinstance(element, types.Entity): raise ValueError( @@ -575,6 +580,7 @@ class DeleteFromDatastore(_Mutate): project ID passed to this transform. If ``project`` field in key is empty then it is filled with the project ID passed to this transform. """ + def __init__( self, project, @@ -593,6 +599,7 @@ def __init__( super().__init__(mutate_fn, throttle_rampup, hint_num_workers) class _DatastoreDeleteFn(_Mutate.DatastoreMutateFn): + def element_to_client_batch_item(self, element): if not isinstance(element, types.Key): raise ValueError( diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1new/datastoreio_test.py b/sdks/python/apache_beam/io/gcp/datastore/v1new/datastoreio_test.py index aac99cb8c1f0..e15895c1d020 100644 --- a/sdks/python/apache_beam/io/gcp/datastore/v1new/datastoreio_test.py +++ b/sdks/python/apache_beam/io/gcp/datastore/v1new/datastoreio_test.py @@ -51,6 +51,7 @@ # used for internal testing only class FakeMessage: + def __init__(self, entity, key): self.entity = entity self.key = key @@ -64,6 +65,7 @@ def ByteSize(self): # used for internal testing only class FakeMutation(object): + def __init__(self, entity=None, key=None): """Fake mutation request object. @@ -81,6 +83,7 @@ def __init__(self, entity=None, key=None): class FakeBatch(object): + def __init__(self, all_batch_items=None, commit_count=None): """Fake ``google.cloud.datastore.batch.Batch`` object. @@ -116,6 +119,7 @@ def commit(self): @unittest.skipIf(client is None, 'Datastore dependencies are not installed') class MutateTest(unittest.TestCase): + def test_write_mutations_no_errors(self): mock_batch = MagicMock() mock_throttler = MagicMock() diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1new/query_splitter.py b/sdks/python/apache_beam/io/gcp/datastore/v1new/query_splitter.py index 842579dfb40f..a9c3be668bd8 100644 --- a/sdks/python/apache_beam/io/gcp/datastore/v1new/query_splitter.py +++ b/sdks/python/apache_beam/io/gcp/datastore/v1new/query_splitter.py @@ -133,6 +133,7 @@ class IdOrName(object): Implements sort ordering: by ID, then by name, keys with IDs before those with names. """ + def __init__(self, id_or_name): self.id_or_name = id_or_name if isinstance(id_or_name, str): diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1new/rampup_throttling_fn_test.py b/sdks/python/apache_beam/io/gcp/datastore/v1new/rampup_throttling_fn_test.py index 0bbe953e93d2..a0baab36542d 100644 --- a/sdks/python/apache_beam/io/gcp/datastore/v1new/rampup_throttling_fn_test.py +++ b/sdks/python/apache_beam/io/gcp/datastore/v1new/rampup_throttling_fn_test.py @@ -33,6 +33,7 @@ class _RampupDelayException(Exception): class RampupThrottlerTransformTest(unittest.TestCase): + @patch('datetime.datetime') @patch('time.sleep') def test_rampup_throttling(self, mock_sleep, mock_datetime): diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1new/types.py b/sdks/python/apache_beam/io/gcp/datastore/v1new/types.py index f7ce69099ca0..98a7fca27f6f 100644 --- a/sdks/python/apache_beam/io/gcp/datastore/v1new/types.py +++ b/sdks/python/apache_beam/io/gcp/datastore/v1new/types.py @@ -37,6 +37,7 @@ class Query(object): + def __init__( self, kind=None, @@ -152,6 +153,7 @@ def __repr__(self): class Key(object): + def __init__( self, path_elements: List[Union[str, int]], @@ -228,6 +230,7 @@ def __repr__(self): class Entity(object): + def __init__(self, key: Key, exclude_from_indexes: Iterable[str] = ()): """ Represents a Datastore entity. diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1new/util.py b/sdks/python/apache_beam/io/gcp/datastore/v1new/util.py index 06a22143f59d..213f289f26bf 100644 --- a/sdks/python/apache_beam/io/gcp/datastore/v1new/util.py +++ b/sdks/python/apache_beam/io/gcp/datastore/v1new/util.py @@ -46,6 +46,7 @@ class MovingSum(object): convenience we expose the count of entries as well so this doubles as a moving average tracker. """ + def __init__(self, window_ms, bucket_ms): if window_ms < bucket_ms or bucket_ms <= 0: raise ValueError("window_ms >= bucket_ms > 0 please") @@ -110,6 +111,7 @@ def has_data(self, now): class DynamicBatchSizer(object): """Determines request sizes for future Datastore RPCs.""" + def __init__(self): self._commit_time_per_entity_ms = MovingSum( window_ms=120000, bucket_ms=10000) diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1new/util_test.py b/sdks/python/apache_beam/io/gcp/datastore/v1new/util_test.py index f82b223b67ed..de97c1d699eb 100644 --- a/sdks/python/apache_beam/io/gcp/datastore/v1new/util_test.py +++ b/sdks/python/apache_beam/io/gcp/datastore/v1new/util_test.py @@ -66,6 +66,7 @@ def test_data_expires_from_moving_window(self): class DynamicWriteBatcherTest(unittest.TestCase): + def setUp(self): self._batcher = util.DynamicBatchSizer() diff --git a/sdks/python/apache_beam/io/gcp/experimental/spannerio.py b/sdks/python/apache_beam/io/gcp/experimental/spannerio.py index 3b616a2452a8..d45cbd4e3967 100644 --- a/sdks/python/apache_beam/io/gcp/experimental/spannerio.py +++ b/sdks/python/apache_beam/io/gcp/experimental/spannerio.py @@ -311,6 +311,7 @@ class _BeamSpannerConfiguration(namedtuple("_BeamSpannerConfiguration", A namedtuple holds the immutable data of the connection string to the cloud spanner. """ + @property def snapshot_options(self): snapshot_options = {} @@ -324,6 +325,7 @@ def snapshot_options(self): @with_input_types(ReadOperation, _SPANNER_TRANSACTION) @with_output_types(typing.List[typing.Any]) class _NaiveSpannerReadDoFn(DoFn): + def __init__(self, spanner_configuration): """ A naive version of Spanner read which uses the transaction API of the @@ -458,6 +460,7 @@ class _CreateReadPartitions(DoFn): mappings of information used perform actual partitioned reads via :meth:`process_read_batch`. """ + def __init__(self, spanner_configuration): self._spanner_configuration = spanner_configuration @@ -504,6 +507,7 @@ class _CreateTransactionFn(DoFn): https://googleapis.dev/python/spanner/latest/database-api.html?highlight= batch_snapshot#google.cloud.spanner_v1.database.BatchSnapshot.to_dict """ + def __init__( self, project_id, @@ -592,6 +596,7 @@ class _ReadFromPartitionFn(DoFn): """ A DoFn to perform reads from the partition. """ + def __init__(self, spanner_configuration): self._spanner_configuration = spanner_configuration self.base_labels = { @@ -681,14 +686,25 @@ class ReadFromSpanner(PTransform): ReadFromSpanner uses BatchAPI to perform all read operations. """ - def __init__(self, project_id, instance_id, database_id, pool=None, - read_timestamp=None, exact_staleness=None, credentials=None, - sql=None, params=None, param_types=None, # with_query - table=None, query_name=None, columns=None, index="", - keyset=None, # with_table - read_operations=None, # for read all - transaction=None - ): + def __init__( + self, + project_id, + instance_id, + database_id, + pool=None, + read_timestamp=None, + exact_staleness=None, + credentials=None, + sql=None, + params=None, + param_types=None, # with_query + table=None, + query_name=None, + columns=None, + index="", + keyset=None, # with_table + read_operations=None, # for read all + transaction=None): """ A PTransform that uses Spanner Batch API to perform reads. @@ -816,6 +832,7 @@ def display_data(self): class WriteToSpanner(PTransform): + def __init__( self, project_id, @@ -907,6 +924,7 @@ class MutationGroup(deque): """ A Bundle of Spanner Mutations (_Mutator). """ + @property def info(self): cells = 0 @@ -986,11 +1004,8 @@ def __init__( self._replace = replace self._delete = delete - if sum([1 for x in [self._insert, - self._update, - self._insert_or_update, - self._replace, - self._delete] if x is not None]) != 1: + if sum([1 for x in [self._insert, self._update, self._insert_or_update, + self._replace, self._delete] if x is not None]) != 1: raise ValueError( "No or more than one write mutation operation " "provided: <%s: %s>" % (self.__class__.__name__, str(self.__dict__))) @@ -1118,6 +1133,7 @@ class _BatchFn(DoFn): """ Batches mutations together. """ + def __init__(self, max_batch_size_bytes, max_number_rows, max_number_cells): self._max_batch_size_bytes = max_batch_size_bytes self._max_number_rows = max_number_rows @@ -1194,6 +1210,7 @@ def process(self, element): class _WriteToSpannerDoFn(DoFn): + def __init__(self, spanner_configuration): self._spanner_configuration = spanner_configuration self._db_instance = None @@ -1274,6 +1291,7 @@ class _MakeMutationGroupsFn(DoFn): """ Make Mutation group object if the element is the instance of _Mutator. """ + def process(self, element): if isinstance(element, MutationGroup): yield element @@ -1286,6 +1304,7 @@ def process(self, element): class _WriteGroup(PTransform): + def __init__(self, max_batch_size_bytes, max_number_rows, max_number_cells): self._max_batch_size_bytes = max_batch_size_bytes self._max_number_rows = max_number_rows diff --git a/sdks/python/apache_beam/io/gcp/experimental/spannerio_read_perf_test.py b/sdks/python/apache_beam/io/gcp/experimental/spannerio_read_perf_test.py index 18f6c29593e7..e88af3ab6c60 100644 --- a/sdks/python/apache_beam/io/gcp/experimental/spannerio_read_perf_test.py +++ b/sdks/python/apache_beam/io/gcp/experimental/spannerio_read_perf_test.py @@ -79,6 +79,7 @@ class SpannerReadPerfTest(LoadTest): + def __init__(self): super().__init__() self.project = self.pipeline.get_option('project') @@ -113,6 +114,7 @@ def _create_input_data(self): Runs an additional pipeline which creates test data and waits for its completion. """ + def format_record(record): import base64 return base64.b64encode(record[1]) diff --git a/sdks/python/apache_beam/io/gcp/experimental/spannerio_test.py b/sdks/python/apache_beam/io/gcp/experimental/spannerio_test.py index 0e22041dbea4..856bb35fab54 100644 --- a/sdks/python/apache_beam/io/gcp/experimental/spannerio_test.py +++ b/sdks/python/apache_beam/io/gcp/experimental/spannerio_test.py @@ -80,6 +80,7 @@ def _generate_test_data(): @mock.patch('apache_beam.io.gcp.experimental.spannerio.Client') @mock.patch('apache_beam.io.gcp.experimental.spannerio.BatchSnapshot') class SpannerReadTest(unittest.TestCase): + def test_read_with_query_batch( self, mock_batch_snapshot_class, mock_client_class): @@ -434,6 +435,7 @@ def test_display_data(self, *args): @mock.patch('apache_beam.io.gcp.experimental.spannerio.Client') @mock.patch('google.cloud.spanner_v1.database.BatchCheckout') class SpannerWriteTest(unittest.TestCase): + def test_spanner_write(self, mock_batch_snapshot_class, mock_batch_checkout): ks = spanner.KeySet(keys=[[1233], [1234]]) diff --git a/sdks/python/apache_beam/io/gcp/experimental/spannerio_write_perf_test.py b/sdks/python/apache_beam/io/gcp/experimental/spannerio_write_perf_test.py index c61608ff6743..42754f71b3d6 100644 --- a/sdks/python/apache_beam/io/gcp/experimental/spannerio_write_perf_test.py +++ b/sdks/python/apache_beam/io/gcp/experimental/spannerio_write_perf_test.py @@ -107,6 +107,7 @@ def _init_setup(self): self._create_database() def test(self): + def format_record(record): import base64 return base64.b64encode(record[1]) diff --git a/sdks/python/apache_beam/io/gcp/gcsfilesystem_test.py b/sdks/python/apache_beam/io/gcp/gcsfilesystem_test.py index ade8529dcac8..3b1fd95d5671 100644 --- a/sdks/python/apache_beam/io/gcp/gcsfilesystem_test.py +++ b/sdks/python/apache_beam/io/gcp/gcsfilesystem_test.py @@ -40,6 +40,7 @@ @unittest.skipIf(gcsfilesystem is None, 'GCP dependencies are not installed') class GCSFileSystemTest(unittest.TestCase): + def setUp(self): pipeline_options = PipelineOptions() self.fs = gcsfilesystem.GCSFileSystem(pipeline_options=pipeline_options) @@ -189,10 +190,8 @@ def test_copy_file_error(self, mock_gcsio): gcsio_mock.copy.side_effect = exception # Issue batch rename. - expected_results = { - (s, d): exception - for s, d in zip(sources, destinations) - } + expected_results = {(s, d): exception + for s, d in zip(sources, destinations)} # Issue batch copy. with self.assertRaisesRegex(BeamIOError, diff --git a/sdks/python/apache_beam/io/gcp/gcsio.py b/sdks/python/apache_beam/io/gcp/gcsio.py index 8056de51f43f..a0d0b1ce0168 100644 --- a/sdks/python/apache_beam/io/gcp/gcsio.py +++ b/sdks/python/apache_beam/io/gcp/gcsio.py @@ -137,6 +137,7 @@ def create_storage_client(pipeline_options, use_credentials=True): class GcsIO(object): """Google Cloud Storage I/O client.""" + def __init__( self, storage_client: Optional[storage.Client] = None, @@ -584,6 +585,7 @@ def is_soft_delete_enabled(self, gcs_path): class BeamBlobReader(BlobReader): + def __init__( self, blob, @@ -605,6 +607,7 @@ def read(self, size=-1): class BeamBlobWriter(BlobWriter): + def __init__( self, blob, diff --git a/sdks/python/apache_beam/io/gcp/gcsio_retry_test.py b/sdks/python/apache_beam/io/gcp/gcsio_retry_test.py index 750879ae0284..f78d59de6da1 100644 --- a/sdks/python/apache_beam/io/gcp/gcsio_retry_test.py +++ b/sdks/python/apache_beam/io/gcp/gcsio_retry_test.py @@ -37,6 +37,7 @@ @unittest.skipIf((gcsio_retry is None or api_exceptions is None), 'GCP dependencies are not installed') class TestGCSIORetry(unittest.TestCase): + def test_retry_on_non_retriable(self): mock = Mock(side_effect=[ Exception('Something wrong!'), diff --git a/sdks/python/apache_beam/io/gcp/gcsio_test.py b/sdks/python/apache_beam/io/gcp/gcsio_test.py index 19df15dcf7fa..bebc001549ff 100644 --- a/sdks/python/apache_beam/io/gcp/gcsio_test.py +++ b/sdks/python/apache_beam/io/gcp/gcsio_test.py @@ -155,6 +155,7 @@ def delete_blob(self, name, **kwargs): class FakeBlob(object): + def __init__( self, name, @@ -228,6 +229,7 @@ def test_bad_gcs_path_object_optional(self): class SampleOptions(object): + def __init__(self, project, region, kms_key=None): self.project = DEFAULT_GCP_PROJECT self.region = region @@ -253,6 +255,7 @@ def _make_credentials(project=None, universe_domain=_DEFAULT_UNIVERSE_DOMAIN): @unittest.skipIf(NotFound is None, 'GCP dependencies are not installed') class TestGCSIO(unittest.TestCase): + def _insert_random_file( self, client, diff --git a/sdks/python/apache_beam/io/gcp/healthcare/dicomio.py b/sdks/python/apache_beam/io/gcp/healthcare/dicomio.py index a73de19d5a35..3e4185b368db 100644 --- a/sdks/python/apache_beam/io/gcp/healthcare/dicomio.py +++ b/sdks/python/apache_beam/io/gcp/healthcare/dicomio.py @@ -175,6 +175,7 @@ class DicomSearch(PTransform): } """ + def __init__( self, buffer_size=8, max_workers=5, client=None, credential=None): """Initializes DicomSearch. @@ -201,6 +202,7 @@ def expand(self, pcoll): class _QidoReadFn(beam.DoFn): """A DoFn for executing every qido query request.""" + def __init__(self, buffer_size, max_workers, client, credential=None): self.buffer_size = buffer_size self.max_workers = max_workers @@ -317,6 +319,7 @@ class FormatToQido(PTransform): } """ + def __init__(self, credential=None): """Initializes FormatToQido. Args: @@ -331,6 +334,7 @@ def expand(self, pcoll): class _ConvertStringToQido(beam.DoFn): """A DoFn for converting pubsub string to qido search parameters.""" + def process(self, element): # Some constants for DICOM pubsub message NUM_PUBSUB_STR_ENTRIES = 15 @@ -418,6 +422,7 @@ class UploadToDicomStore(PTransform): } """ + def __init__( self, destination_dict, @@ -475,6 +480,7 @@ def expand(self, pcoll): class _StoreInstance(beam.DoFn): """A DoFn read or fetch dicom files then push it to a dicom store.""" + def __init__( self, destination_dict, diff --git a/sdks/python/apache_beam/io/gcp/healthcare/dicomio_integration_test.py b/sdks/python/apache_beam/io/gcp/healthcare/dicomio_integration_test.py index 58b97305f61a..f88b37e3a152 100644 --- a/sdks/python/apache_beam/io/gcp/healthcare/dicomio_integration_test.py +++ b/sdks/python/apache_beam/io/gcp/healthcare/dicomio_integration_test.py @@ -118,6 +118,7 @@ def get_gcs_file_http(file_name): @unittest.skipIf(DicomSearch is None, 'GCP dependencies are not installed') class DICOMIoIntegrationTest(unittest.TestCase): + def setUp(self): self.test_pipeline = TestPipeline(is_integration_test=True) self.project = self.test_pipeline.get_option('project') diff --git a/sdks/python/apache_beam/io/gcp/healthcare/dicomio_test.py b/sdks/python/apache_beam/io/gcp/healthcare/dicomio_test.py index 30e5d4c0f770..7341050134af 100644 --- a/sdks/python/apache_beam/io/gcp/healthcare/dicomio_test.py +++ b/sdks/python/apache_beam/io/gcp/healthcare/dicomio_test.py @@ -165,6 +165,7 @@ def test_failed_convert(self): @unittest.skipIf(DicomSearch is None, 'GCP dependencies are not installed') class TestDicomSearch(unittest.TestCase): + @patch("apache_beam.io.gcp.healthcare.dicomio.DicomApiHttpClient") def test_successful_search(self, MockClient): input_dict = {} @@ -334,6 +335,7 @@ def test_client_search_notfound(self, MockClient): @unittest.skipIf(DicomSearch is None, 'GCP dependencies are not installed') class TestDicomStoreInstance(_TestCaseWithTempDirCleanUp): + @patch("apache_beam.io.gcp.healthcare.dicomio.DicomApiHttpClient") def test_store_byte_file(self, MockClient): input_dict = {} diff --git a/sdks/python/apache_beam/io/gcp/pubsub.py b/sdks/python/apache_beam/io/gcp/pubsub.py index 9e006dbeda93..d039c72e40bb 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub.py +++ b/sdks/python/apache_beam/io/gcp/pubsub.py @@ -87,6 +87,7 @@ class PubsubMessage(object): ordering_key: (str) If non-empty, identifies related messages for which publish order is respected by the PubSub subscription. """ + def __init__( self, data, @@ -286,6 +287,7 @@ def to_runner_api_parameter(self, context): class _AddMetricsPassThrough(DoFn): + def __init__(self, project, topic=None, sub=None): self.project = project self.topic = topic @@ -317,6 +319,7 @@ def ReadStringsFromPubSub(topic=None, subscription=None, id_label=None): class _ReadStringsFromPubSub(PTransform): """This class is deprecated. Use ``ReadFromPubSub`` instead.""" + def __init__(self, topic=None, subscription=None, id_label=None): super().__init__() self.topic = topic @@ -340,6 +343,7 @@ def WriteStringsToPubSub(topic): class _WriteStringsToPubSub(PTransform): """This class is deprecated. Use ``WriteToPubSub`` instead.""" + def __init__(self, topic): """Initializes ``_WriteStringsToPubSub``. @@ -356,6 +360,7 @@ def expand(self, pcoll): class _AddMetricsAndMap(DoFn): + def __init__(self, fn, project, topic=None): self.project = project self.topic = topic @@ -492,6 +497,7 @@ class _PubSubSource(iobase.SourceBase): with_attributes: If False, will fetch just message data. Otherwise, fetches ``PubsubMessage`` protobufs. """ + def __init__( self, topic: Optional[str] = None, @@ -547,6 +553,7 @@ class _PubSubSink(object): This ``NativeSource`` is overridden by a native Pubsub implementation. """ + def __init__( self, topic: str, @@ -615,6 +622,7 @@ class MultipleReadFromPubSub(PTransform): results = pipeline | MultipleReadFromPubSub( [topic_1, topic_2, subscription_1]) """ + def __init__( self, pubsub_source_descriptors: List[PubSubSourceDescriptor], diff --git a/sdks/python/apache_beam/io/gcp/pubsub_io_perf_test.py b/sdks/python/apache_beam/io/gcp/pubsub_io_perf_test.py index aece17a1eaf3..2ea05efb3642 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub_io_perf_test.py +++ b/sdks/python/apache_beam/io/gcp/pubsub_io_perf_test.py @@ -75,6 +75,7 @@ class PubsubIOPerfTest(LoadTest): + def _setup_env(self): if not self.pipeline.get_option('pubsub_namespace_prefix'): logging.error('--pubsub_namespace_prefix argument is required.') @@ -108,6 +109,7 @@ def _setup_pubsub(self): class PubsubWritePerfTest(PubsubIOPerfTest): + def __init__(self): super().__init__(WRITE_METRICS_NAMESPACE) self._setup_env() @@ -115,6 +117,7 @@ def __init__(self): self._setup_pipeline() def test(self): + def to_pubsub_message(element): import uuid from apache_beam.io import PubsubMessage @@ -151,6 +154,7 @@ def _setup_pubsub(self): class PubsubReadPerfTest(PubsubIOPerfTest): + def __init__(self): super().__init__(READ_METRICS_NAMESPACE) self._setup_env() diff --git a/sdks/python/apache_beam/io/gcp/pubsub_test.py b/sdks/python/apache_beam/io/gcp/pubsub_test.py index 73ba8d6abdb6..f53fc37a5942 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub_test.py +++ b/sdks/python/apache_beam/io/gcp/pubsub_test.py @@ -68,6 +68,7 @@ class TestPubsubMessage(unittest.TestCase): + def test_payload_valid(self): _ = PubsubMessage('', None) _ = PubsubMessage('data', None) @@ -136,6 +137,7 @@ def test_repr(self): @unittest.skipIf(pubsub is None, 'GCP dependencies are not installed') class TestReadFromPubSubOverride(unittest.TestCase): + def test_expand_with_topic(self): options = PipelineOptions([]) options.view_as(StandardOptions).streaming = True @@ -243,6 +245,7 @@ def test_expand_with_other_options(self): @unittest.skipIf(pubsub is None, 'GCP dependencies are not installed') class TestMultiReadFromPubSubOverride(unittest.TestCase): + def test_expand_with_multiple_sources(self): options = PipelineOptions([]) options.view_as(StandardOptions).streaming = True @@ -329,9 +332,9 @@ def test_expand_with_multiple_sources_and_other_options(self): PubSubSourceDescriptor( source=source, id_label=id_label, - timestamp_attribute=timestamp_attribute) for source, - id_label, - timestamp_attribute in zip(sources, id_labels, timestamp_attributes) + timestamp_attribute=timestamp_attribute) + for source, id_label, timestamp_attribute in zip( + sources, id_labels, timestamp_attributes) ] pcoll = (p | MultipleReadFromPubSub(pubsub_sources) | beam.Map(lambda x: x)) @@ -364,6 +367,7 @@ def test_expand_with_wrong_source(self): @unittest.skipIf(pubsub is None, 'GCP dependencies are not installed') class TestWriteStringsToPubSubOverride(unittest.TestCase): + def test_expand_deprecated(self): options = PipelineOptions([]) options.view_as(StandardOptions).streaming = True @@ -416,6 +420,7 @@ def test_expand(self): @unittest.skipIf(pubsub is None, 'GCP dependencies are not installed') class TestPubSubSource(unittest.TestCase): + def test_display_data_topic(self): source = _PubSubSource('projects/fakeprj/topics/a_topic', None, 'a_label') dd = DisplayData.create_from(source) @@ -453,6 +458,7 @@ def test_display_data_no_subscription(self): @unittest.skipIf(pubsub is None, 'GCP dependencies are not installed') class TestPubSubSink(unittest.TestCase): + def test_display_data(self): sink = WriteToPubSub( 'projects/fakeprj/topics/a_topic', @@ -498,6 +504,7 @@ def finish_bundle(self): @unittest.skipIf(pubsub is None, 'GCP dependencies are not installed') @mock.patch('google.cloud.pubsub.SubscriberClient') class TestReadFromPubSub(unittest.TestCase): + def test_read_messages_success(self, mock_pubsub): data = b'data' publish_time_secs = 1520861821 @@ -848,6 +855,7 @@ def test_read_from_pubsub_no_overwrite(self, unused_mock): @unittest.skipIf(pubsub is None, 'GCP dependencies are not installed') @mock.patch('google.cloud.pubsub.PublisherClient') class TestWriteToPubSub(unittest.TestCase): + def test_write_messages_success(self, mock_pubsub): data = 'data' payloads = [data] diff --git a/sdks/python/apache_beam/io/gcp/pubsublite/external.py b/sdks/python/apache_beam/io/gcp/pubsublite/external.py index a0e46c1b4d88..d2579ee5d21c 100644 --- a/sdks/python/apache_beam/io/gcp/pubsublite/external.py +++ b/sdks/python/apache_beam/io/gcp/pubsublite/external.py @@ -46,6 +46,7 @@ class _ReadExternal(ExternalTransform): Experimental; no backwards-compatibility guarantees. """ + def __init__( self, subscription_path, @@ -86,6 +87,7 @@ class _WriteExternal(ExternalTransform): Experimental; no backwards-compatibility guarantees. """ + def __init__( self, topic_path, diff --git a/sdks/python/apache_beam/io/gcp/pubsublite/proto_api.py b/sdks/python/apache_beam/io/gcp/pubsublite/proto_api.py index a8e3defc4f99..16de38774b2c 100644 --- a/sdks/python/apache_beam/io/gcp/pubsublite/proto_api.py +++ b/sdks/python/apache_beam/io/gcp/pubsublite/proto_api.py @@ -34,6 +34,7 @@ class ReadFromPubSubLite(PTransform): Experimental; no backwards-compatibility guarantees. """ + def __init__( self, subscription_path, @@ -71,6 +72,7 @@ class WriteToPubSubLite(PTransform): Experimental; no backwards-compatibility guarantees. """ + def __init__( self, topic_path, diff --git a/sdks/python/apache_beam/io/gcp/spanner.py b/sdks/python/apache_beam/io/gcp/spanner.py index 9089d746fe1c..692618b813cf 100644 --- a/sdks/python/apache_beam/io/gcp/spanner.py +++ b/sdks/python/apache_beam/io/gcp/spanner.py @@ -368,6 +368,7 @@ def _add_doc( row_type=None, operation_suffix=None, ): + def _doc(obj): obj.__doc__ = value.format( operation=operation, diff --git a/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher.py b/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher.py index 4504ba43b2c1..f4a09f08ce78 100644 --- a/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher.py +++ b/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher.py @@ -60,6 +60,7 @@ class BigqueryMatcher(BaseMatcher): Fetch Bigquery data with given query, compute a hash string and compare with expected checksum. """ + def __init__(self, project, query, checksum, timeout_secs=0): """Initialize BigQueryMatcher object. Args: @@ -85,6 +86,7 @@ def __init__(self, project, query, checksum, timeout_secs=0): self.timeout_secs = timeout_secs def _matches(self, _): + @retry.with_exponential_backoff( num_retries=1000, initial_delay_secs=0.5, @@ -141,6 +143,7 @@ class BigqueryFullResultMatcher(BigqueryMatcher): Fetch Bigquery data with given query, compare to the expected data. """ + def __init__(self, project, query, data): """Initialize BigQueryMatcher object. Args: @@ -206,6 +209,7 @@ def _get_query_result(self): class BigQueryTableMatcher(BaseMatcher): """Matcher that verifies the properties of a Table in BigQuery.""" + def __init__(self, project, dataset, table, expected_properties): if bigquery is None: raise ImportError('Bigquery dependencies are not installed.') @@ -231,8 +235,8 @@ def _matches(self, _): _LOGGER.info('Table proto is %s', self.actual_table) return all( - self._match_property(v, self._get_or_none(self.actual_table, k)) for k, - v in self.expected_properties.items()) + self._match_property(v, self._get_or_none(self.actual_table, k)) + for k, v in self.expected_properties.items()) @staticmethod def _get_or_none(obj, attr): @@ -250,8 +254,8 @@ def _match_property(expected, actual): if isinstance(expected, dict): return all( BigQueryTableMatcher._match_property( - v, BigQueryTableMatcher._get_or_none(actual, k)) for k, - v in expected.items()) + v, BigQueryTableMatcher._get_or_none(actual, k)) + for k, v in expected.items()) else: return expected == actual diff --git a/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher_test.py b/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher_test.py index f0cf14fedced..c62c871d395f 100644 --- a/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher_test.py +++ b/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher_test.py @@ -43,6 +43,7 @@ @unittest.skipIf(bigquery is None, 'Bigquery dependencies are not installed.') @mock.patch.object(bigquery, 'Client') class BigqueryMatcherTest(unittest.TestCase): + def setUp(self): self._mock_result = mock.Mock() patch_retry(self, bq_verifier) @@ -115,6 +116,7 @@ def test_bigquery_matcher_query_error_checksum(self, mock_bigquery): @unittest.skipIf(bigquery is None, 'Bigquery dependencies are not installed.') @mock.patch.object(bigquery_tools, 'BigQueryWrapper') class BigqueryTableMatcherTest(unittest.TestCase): + def setUp(self): self._mock_result = mock.Mock() patch_retry(self, bq_verifier) @@ -160,6 +162,7 @@ def test_bigquery_table_matcher_query_error_retry(self, mock_bigquery): @unittest.skipIf(bigquery is None, 'Bigquery dependencies are not installed.') @mock.patch.object(bigquery, 'Client') class BigqueryFullResultStreamingMatcherTest(unittest.TestCase): + def setUp(self): self.timeout = 0.01 diff --git a/sdks/python/apache_beam/io/gcp/tests/pubsub_matcher.py b/sdks/python/apache_beam/io/gcp/tests/pubsub_matcher.py index 85836eaf3374..0dbf5ad4d5b0 100644 --- a/sdks/python/apache_beam/io/gcp/tests/pubsub_matcher.py +++ b/sdks/python/apache_beam/io/gcp/tests/pubsub_matcher.py @@ -49,6 +49,7 @@ class PubSubMessageMatcher(BaseMatcher): This matcher can block the test and keep pulling messages from given subscription until all expected messages are shown or timeout. """ + def __init__( self, project, diff --git a/sdks/python/apache_beam/io/gcp/tests/pubsub_matcher_test.py b/sdks/python/apache_beam/io/gcp/tests/pubsub_matcher_test.py index 46656349205a..6dbb9ae7f870 100644 --- a/sdks/python/apache_beam/io/gcp/tests/pubsub_matcher_test.py +++ b/sdks/python/apache_beam/io/gcp/tests/pubsub_matcher_test.py @@ -40,6 +40,7 @@ @mock.patch('time.sleep', return_value=None) @mock.patch('google.cloud.pubsub.SubscriberClient') class PubSubMatcherTest(unittest.TestCase): + def setUp(self): self.mock_presult = mock.MagicMock() diff --git a/sdks/python/apache_beam/io/gcp/tests/utils_test.py b/sdks/python/apache_beam/io/gcp/tests/utils_test.py index 5ac41df1d3e5..b906ddae798b 100644 --- a/sdks/python/apache_beam/io/gcp/tests/utils_test.py +++ b/sdks/python/apache_beam/io/gcp/tests/utils_test.py @@ -42,6 +42,7 @@ @unittest.skipIf(bigquery is None, 'Bigquery dependencies are not installed.') @mock.patch.object(bigquery, 'Client') class UtilsTest(unittest.TestCase): + def setUp(self): test_utils.patch_retry(self, utils) @@ -76,6 +77,7 @@ def test_delete_table_fails_not_found(self, mock_client): @unittest.skipIf(pubsub is None, 'GCP dependencies are not installed') class PubSubUtilTest(unittest.TestCase): + def test_write_to_pubsub(self): mock_pubsub = mock.Mock() topic_path = "project/fakeproj/topics/faketopic" @@ -178,6 +180,7 @@ def test_read_from_pubsub_flaky(self): [test_utils.PullResponseMessage(data, ack_id=ack_id)]) class FlakyPullResponse(object): + def __init__(self, pull_response): self.pull_response = pull_response self._state = -1 @@ -210,17 +213,17 @@ def test_read_from_pubsub_many(self): } for i in range(number_of_elements)] ack_ids = ['ack_id_{}'.format(i) for i in range(number_of_elements)] messages = [ - PubsubMessage(data, attributes) for data, - attributes in zip(data_list, attributes_list) + PubsubMessage(data, attributes) + for data, attributes in zip(data_list, attributes_list) ] response_messages = [ test_utils.PullResponseMessage(data, attributes, ack_id=ack_id) - for data, - attributes, - ack_id in zip(data_list, attributes_list, ack_ids) + for data, attributes, ack_id in zip( + data_list, attributes_list, ack_ids) ] class SequentialPullResponse(object): + def __init__(self, response_messages, response_size): self.response_messages = response_messages self.response_size = response_size diff --git a/sdks/python/apache_beam/io/gcp/tests/xlang_spannerio_it_test.py b/sdks/python/apache_beam/io/gcp/tests/xlang_spannerio_it_test.py index 43a74f170531..f75869829b69 100644 --- a/sdks/python/apache_beam/io/gcp/tests/xlang_spannerio_it_test.py +++ b/sdks/python/apache_beam/io/gcp/tests/xlang_spannerio_it_test.py @@ -70,6 +70,7 @@ class SpannerPartTestRow(NamedTuple): @unittest.skipIf( DockerContainer is None, 'testcontainers package is not installed.') class CrossLanguageSpannerIOTest(unittest.TestCase): + @classmethod def setUpClass(cls): parser = argparse.ArgumentParser() @@ -131,6 +132,7 @@ def to_row_fn(i): [[f'or_update{i}', i, i % 2 == 0] for i in range(3)]) def test_spanner_insert(self): + def to_row_fn(num): return SpannerTestRow( f_string=f'insert{num}', f_int64=num, f_boolean=None) @@ -259,6 +261,7 @@ def retry(fn, retries, err_msg, *args, **kwargs): class SpannerHelper(object): + def __init__(self, project_id, instance_id, table, use_emulator): self.use_emulator = use_emulator self.table = table diff --git a/sdks/python/apache_beam/io/hadoopfilesystem.py b/sdks/python/apache_beam/io/hadoopfilesystem.py index cf488c228a28..d51c28a56e97 100644 --- a/sdks/python/apache_beam/io/hadoopfilesystem.py +++ b/sdks/python/apache_beam/io/hadoopfilesystem.py @@ -61,6 +61,7 @@ class HdfsDownloader(filesystemio.Downloader): + def __init__(self, hdfs_client, path): self._hdfs_client = hdfs_client self._path = path @@ -77,6 +78,7 @@ def get_range(self, start, end): class HdfsUploader(filesystemio.Uploader): + def __init__(self, hdfs_client, path): self._hdfs_client = hdfs_client if self._hdfs_client.status(path, strict=False) is not None: @@ -101,6 +103,7 @@ class HadoopFileSystem(FileSystem): URL arguments to methods expect strings starting with ``hdfs://``. """ + def __init__(self, pipeline_options): """Initializes a connection to HDFS. diff --git a/sdks/python/apache_beam/io/hadoopfilesystem_test.py b/sdks/python/apache_beam/io/hadoopfilesystem_test.py index 8c21effc8823..9a2f46300e8d 100644 --- a/sdks/python/apache_beam/io/hadoopfilesystem_test.py +++ b/sdks/python/apache_beam/io/hadoopfilesystem_test.py @@ -91,6 +91,7 @@ class FakeHdfsError(Exception): class FakeHdfs(object): """Fake implementation of ``hdfs.Client``.""" + def __init__(self): self.files = {} @@ -202,6 +203,7 @@ def checksum(self, path): @parameterized_class(('full_urls', ), [(False, ), (True, )]) class HadoopFileSystemTest(unittest.TestCase): + def setUp(self): self._fake_hdfs = FakeHdfs() hdfs.hdfs.InsecureClient = (lambda *args, **kwargs: self._fake_hdfs) @@ -610,6 +612,7 @@ def test_delete_error(self): class HadoopFileSystemRuntimeValueProviderTest(unittest.TestCase): """Tests pipeline_options, in the form of a RuntimeValueProvider.runtime_options object.""" + def setUp(self): self._fake_hdfs = FakeHdfs() hdfs.hdfs.InsecureClient = (lambda *args, **kwargs: self._fake_hdfs) diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py index 53215275e050..2ff644609a8c 100644 --- a/sdks/python/apache_beam/io/iobase.py +++ b/sdks/python/apache_beam/io/iobase.py @@ -143,6 +143,7 @@ class BoundedSource(SourceBase): implementations may invoke methods of ``BoundedSource`` objects through multi-threaded and/or reentrant execution modes. """ + def estimate_size(self) -> Optional[int]: """Estimates the size of source in bytes. @@ -846,6 +847,7 @@ class Writer(object): See ``iobase.Sink`` for more detailed documentation about the process of writing to a sink. """ + def write(self, value): """Writes a value to the sink using the current writer. """ @@ -1122,6 +1124,7 @@ def from_runner_api_parameter( class WriteImpl(ptransform.PTransform): """Implements the writing of custom sinks.""" + def __init__(self, sink: Sink) -> None: super().__init__() self.sink = sink @@ -1174,6 +1177,7 @@ class _WriteBundleDoFn(core.DoFn): """A DoFn for writing elements to an iobase.Writer. Opens a writer at the first element and closes the writer at finish_bundle(). """ + def __init__(self, sink): self.sink = sink @@ -1200,6 +1204,7 @@ def finish_bundle(self): class _WriteKeyedBundleDoFn(core.DoFn): + def __init__(self, sink): self.sink = sink @@ -1242,6 +1247,7 @@ def _finalize_write( class _RoundRobinKeyFn(core.DoFn): + def start_bundle(self): self.counter = None @@ -1265,6 +1271,7 @@ class RestrictionTracker(object): * https://s.apache.org/splittable-do-fn * https://s.apache.org/splittable-do-fn-python-sdk """ + def current_restriction(self): """Returns the current restriction. @@ -1399,6 +1406,7 @@ class WatermarkEstimator(object): Internal state must not be updated asynchronously. """ + def get_estimator_state(self): """Get current state of the WatermarkEstimator instance, which can be used to recreate the WatermarkEstimator when processing the restriction. See @@ -1424,6 +1432,7 @@ def observe_timestamp(self, timestamp: timestamp.Timestamp) -> None: class RestrictionProgress(object): """Used to record the progress of a restriction.""" + def __init__(self, **kwargs): # Only accept keyword arguments. self._fraction = kwargs.pop('fraction', None) @@ -1478,6 +1487,7 @@ def with_completed(self, completed: int) -> 'RestrictionProgress': class _SDFBoundedSourceRestriction(object): """ A restriction wraps SourceBundle and RangeTracker. """ + def __init__(self, source_bundle, range_tracker=None): self._source_bundle = source_bundle self._range_tracker = range_tracker @@ -1541,6 +1551,7 @@ class _SDFBoundedSourceRestrictionTracker(RestrictionTracker): Delegated RangeTracker guarantees synchronization safety. """ + def __init__(self, restriction): if not isinstance(restriction, _SDFBoundedSourceRestriction): raise ValueError( @@ -1577,6 +1588,7 @@ def is_bounded(self): class _SDFBoundedSourceWrapperRestrictionCoder(coders.Coder): + def decode(self, value): return _SDFBoundedSourceRestriction(SourceBundle(*pickler.loads(value))) @@ -1595,6 +1607,7 @@ class _SDFBoundedSourceRestrictionProvider(core.RestrictionProvider): This restriction provider initializes restriction based on input element that is expected to be of BoundedSource type. """ + def __init__(self, desired_chunk_size=None, restriction_coder=None): self._desired_chunk_size = desired_chunk_size self._restriction_coder = ( @@ -1644,12 +1657,15 @@ class SDFBoundedSourceReader(PTransform): NOTE: This transform can only be used with beam_fn_api enabled. """ + def __init__(self, data_to_display=None): self._data_to_display = data_to_display or {} super().__init__() def _create_sdf_bounded_source_dofn(self): + class SDFBoundedSourceDoFn(core.DoFn): + def __init__(self, dd): self._dd = dd diff --git a/sdks/python/apache_beam/io/iobase_test.py b/sdks/python/apache_beam/io/iobase_test.py index eb9617cfae34..12775a0a4c5a 100644 --- a/sdks/python/apache_beam/io/iobase_test.py +++ b/sdks/python/apache_beam/io/iobase_test.py @@ -35,6 +35,7 @@ class SDFBoundedSourceRestrictionProviderTest(unittest.TestCase): + def setUp(self): self.initial_range_start = 0 self.initial_range_stop = 4 @@ -121,6 +122,7 @@ def test_restriction_size(self): class SDFBoundedSourceRestrictionTrackerTest(unittest.TestCase): + def setUp(self): self.initial_start_pos = 0 self.initial_stop_pos = 4 @@ -194,6 +196,7 @@ def test_try_split_with_any_exception(self): class UseSdfBoundedSourcesTests(unittest.TestCase): + def _run_sdf_wrapper_pipeline(self, source, expected_values): with beam.Pipeline() as p: experiments = (p._options.view_as(DebugOptions).experiments or []) @@ -210,6 +213,7 @@ def _run_sdf_wrapper_pipeline(self, source, expected_values): @mock.patch('apache_beam.io.iobase.SDFBoundedSourceReader.expand') def test_sdf_wrapper_overrides_read(self, sdf_wrapper_mock_expand): + def _fake_wrapper_expand(pbegin): return pbegin | beam.Map(lambda x: 'fake') diff --git a/sdks/python/apache_beam/io/jdbc.py b/sdks/python/apache_beam/io/jdbc.py index 11570680a2f3..52c0636eac32 100644 --- a/sdks/python/apache_beam/io/jdbc.py +++ b/sdks/python/apache_beam/io/jdbc.py @@ -381,6 +381,7 @@ class JdbcDateType(LogicalType[datetime.date, MillisInstant, str]): Support of Legacy JdbcIO DATE logical type. Deemed to change when Java JDBCIO has been migrated to Beam portable logical types. """ + def __init__(self, argument=""): pass @@ -425,6 +426,7 @@ class JdbcTimeType(LogicalType[datetime.time, MillisInstant, str]): Support of Legacy JdbcIO TIME logical type. . Deemed to change when Java JDBCIO has been migrated to Beam portable logical types. """ + def __init__(self, argument=""): pass diff --git a/sdks/python/apache_beam/io/localfilesystem.py b/sdks/python/apache_beam/io/localfilesystem.py index daf69b8d030c..ffe3f58a5ee7 100644 --- a/sdks/python/apache_beam/io/localfilesystem.py +++ b/sdks/python/apache_beam/io/localfilesystem.py @@ -36,6 +36,7 @@ class LocalFileSystem(FileSystem): """A Local ``FileSystem`` implementation for accessing files on disk. """ + @classmethod def scheme(cls): """URI scheme for the FileSystem @@ -333,6 +334,7 @@ def delete(self, paths): Raises: ``BeamIOError``: if any of the delete operations fail """ + def _delete_path(path): """Recursively delete the file or directory at the provided path. """ diff --git a/sdks/python/apache_beam/io/localfilesystem_test.py b/sdks/python/apache_beam/io/localfilesystem_test.py index 1370790970e9..e7d4244dd3af 100644 --- a/sdks/python/apache_beam/io/localfilesystem_test.py +++ b/sdks/python/apache_beam/io/localfilesystem_test.py @@ -38,6 +38,7 @@ def _gen_fake_join(separator): """Returns a callable that joins paths with the given separator.""" + def _join(first_path, *paths): return separator.join((first_path.rstrip(separator), ) + paths) @@ -46,6 +47,7 @@ def _join(first_path, *paths): def _gen_fake_split(separator): """Returns a callable that splits a with the given separator.""" + def _split(path): sep_index = path.rfind(separator) if sep_index >= 0: @@ -57,6 +59,7 @@ def _split(path): class LocalFileSystemTest(unittest.TestCase): + def setUp(self): self.tmpdir = tempfile.mkdtemp() pipeline_options = PipelineOptions() @@ -357,8 +360,8 @@ def check_tree(self, path, value, expected_leaf_count=None): elif isinstance(value, dict): # recurse to check subdirectory tree actual_leaf_count = sum([ - self.check_tree(os.path.join(path, basename), v) for basename, - v in value.items() + self.check_tree(os.path.join(path, basename), v) + for basename, v in value.items() ]) else: raise Exception('Unexpected value in tempdir tree: %s' % value) diff --git a/sdks/python/apache_beam/io/mongodbio.py b/sdks/python/apache_beam/io/mongodbio.py index 6ffc82f59676..0a5673dc7787 100644 --- a/sdks/python/apache_beam/io/mongodbio.py +++ b/sdks/python/apache_beam/io/mongodbio.py @@ -113,6 +113,7 @@ class ReadFromMongoDB(PTransform): """A ``PTransform`` to read MongoDB documents into a ``PCollection``.""" + def __init__( self, uri="mongodb://localhost:27017", @@ -169,6 +170,7 @@ def expand(self, pcoll): class _ObjectIdRangeTracker(OrderedPositionRangeTracker): """RangeTracker for tracking mongodb _id of bson ObjectId type.""" + def position_to_fraction( self, pos: ObjectId, @@ -242,6 +244,7 @@ class _BoundedMongoSource(iobase.BoundedSource): implementations may invoke methods of ``_BoundedMongoSource`` objects through multi-threaded and/or reentrant execution modes. """ + def __init__( self, uri=None, @@ -458,12 +461,12 @@ def _get_split_keys( with MongoClient(self.uri, **self.spec) as client: name_space = "%s.%s" % (self.db, self.coll) return client[self.db].command( - "splitVector", - name_space, - keyPattern={"_id": 1}, # Ascending index - min={"_id": start_pos}, - max={"_id": end_pos}, - maxChunkSize=desired_chunk_size_in_mb, + "splitVector", + name_space, + keyPattern={"_id": 1}, # Ascending index + min={"_id": start_pos}, + max={"_id": end_pos}, + maxChunkSize=desired_chunk_size_in_mb, )["splitKeys"] def _get_auto_buckets( @@ -584,6 +587,7 @@ def _count_id_range(self, start_position, stop_position): class _ObjectIdHelper: """A Utility class to manipulate bson object ids.""" + @classmethod def id_to_int(cls, _id: Union[int, ObjectId]) -> int: """ @@ -670,6 +674,7 @@ class WriteToMongoDB(PTransform): with different unique IDs. """ + def __init__( self, uri="mongodb://localhost:27017", @@ -718,6 +723,7 @@ def expand(self, pcoll): class _GenerateObjectIdFn(DoFn): + def process(self, element, *args, **kwargs): # if _id field already exist we keep it as it is, otherwise the ptransform # generates a new _id field to achieve idempotent write to mongodb. @@ -734,6 +740,7 @@ def process(self, element, *args, **kwargs): class _WriteMongoFn(DoFn): + def __init__( self, uri=None, db=None, coll=None, batch_size=100, extra_params=None): if extra_params is None: @@ -769,6 +776,7 @@ def display_data(self): class _MongoSink: + def __init__(self, uri=None, db=None, coll=None, extra_params=None): if extra_params is None: extra_params = {} diff --git a/sdks/python/apache_beam/io/mongodbio_it_test.py b/sdks/python/apache_beam/io/mongodbio_it_test.py index dfbc9e65305b..5471daf5097a 100644 --- a/sdks/python/apache_beam/io/mongodbio_it_test.py +++ b/sdks/python/apache_beam/io/mongodbio_it_test.py @@ -33,6 +33,7 @@ class GenerateDocs(beam.DoFn): + def process(self, num_docs, *args, **kwargs): for i in range(num_docs): yield {'number': i, 'number_mod_2': i % 2, 'number_mod_3': i % 3} diff --git a/sdks/python/apache_beam/io/mongodbio_test.py b/sdks/python/apache_beam/io/mongodbio_test.py index 150eac2d5437..a77eef4b71f7 100644 --- a/sdks/python/apache_beam/io/mongodbio_test.py +++ b/sdks/python/apache_beam/io/mongodbio_test.py @@ -49,6 +49,7 @@ class _MockMongoColl(object): """Fake mongodb collection cursor.""" + def __init__(self, docs): self.docs = docs @@ -104,9 +105,10 @@ def _filter(self, filter): @staticmethod def _projection(docs, projection=None): if projection: - return [{k: v - for k, v in doc.items() if k in projection or k == '_id'} - for doc in docs] + return [{ + k: v + for k, v in doc.items() if k in projection or k == '_id' + } for doc in docs] return docs def find(self, filter=None, projection=None, **kwargs): @@ -170,6 +172,7 @@ def aggregate(self, pipeline, **kwargs): class _MockMongoDb(object): """Fake Mongo Db.""" + def __init__(self, docs): self.docs = docs @@ -210,6 +213,7 @@ def get_split_keys(self, command, ns, min, max, maxChunkSize, **kwargs): class _MockMongoClient: + def __init__(self, docs): self.docs = docs @@ -302,6 +306,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): ), ]) class MongoSourceTest(unittest.TestCase): + @mock.patch('apache_beam.io.mongodbio.MongoClient') def setUp(self, mock_client): self._docs = [{'_id': self._ids[i], 'x': i} for i in range(len(self._ids))] @@ -570,6 +575,7 @@ def test_range_is_not_splittable(self): @parameterized_class(('bucket_auto', ), [(False, ), (True, )]) class ReadFromMongoDBTest(unittest.TestCase): + @mock.patch('apache_beam.io.mongodbio.MongoClient') def test_read_from_mongodb(self, mock_client): documents = [{ @@ -594,6 +600,7 @@ def test_read_from_mongodb(self, mock_client): class GenerateObjectIdFnTest(unittest.TestCase): + def test_process(self): with TestPipeline() as p: output = ( @@ -608,6 +615,7 @@ def test_process(self): class WriteMongoFnTest(unittest.TestCase): + @mock.patch('apache_beam.io.mongodbio._MongoSink') def test_process(self, mock_sink): docs = [{'x': 1}, {'x': 2}, {'x': 3}] @@ -626,6 +634,7 @@ def test_display_data(self): class MongoSinkTest(unittest.TestCase): + @mock.patch('apache_beam.io.mongodbio.MongoClient') def test_write(self, mock_client): docs = [{'x': 1}, {'x': 2}, {'x': 3}] @@ -636,6 +645,7 @@ def test_write(self, mock_client): class WriteToMongoDBTest(unittest.TestCase): + @mock.patch('apache_beam.io.mongodbio.MongoClient') def test_write_to_mongodb_with_existing_id(self, mock_client): _id = objectid.ObjectId() @@ -671,6 +681,7 @@ def test_write_to_mongodb_with_generated_id(self, mock_client): class ObjectIdHelperTest(TestCase): + def test_conversion(self): test_cases = [ (objectid.ObjectId('000000000000000000000000'), 0), @@ -713,6 +724,7 @@ def test_increment_id(self): class ObjectRangeTrackerTest(TestCase): + def test_fraction_position_conversion(self): start_int = 0 stop_int = 2**96 - 1 diff --git a/sdks/python/apache_beam/io/parquetio.py b/sdks/python/apache_beam/io/parquetio.py index 48c51428c17d..b1f637c2e467 100644 --- a/sdks/python/apache_beam/io/parquetio.py +++ b/sdks/python/apache_beam/io/parquetio.py @@ -76,6 +76,7 @@ class _ArrowTableToRowDictionaries(DoFn): """ A DoFn that consumes an Arrow table and yields a python dictionary for each row in the table.""" + def process(self, table, with_filename=False): if with_filename: file_name = table[0] @@ -94,6 +95,7 @@ def process(self, table, with_filename=False): class _RowDictionariesToArrowTable(DoFn): """ A DoFn that consumes python dictionarys and yields a pyarrow table.""" + def __init__( self, schema, @@ -154,6 +156,7 @@ def _flush_buffer(self): class _ArrowTableToBeamRows(DoFn): + def __init__(self, beam_type): self._beam_type = beam_type @@ -166,6 +169,7 @@ def infer_output_type(self, input_type): class _BeamRowsToArrowTable(DoFn): + @DoFn.yields_elements def process_batch(self, element: pa.Table) -> Iterator[pa.Table]: yield element @@ -175,6 +179,7 @@ class ReadFromParquetBatched(PTransform): """A :class:`~apache_beam.transforms.ptransform.PTransform` for reading Parquet files as a `PCollection` of `pyarrow.Table`. This `PTransform` is currently experimental. No backward-compatibility guarantees.""" + def __init__( self, file_pattern=None, min_bundle_size=0, validate=True, columns=None): """ Initializes :class:`~ReadFromParquetBatched` @@ -232,6 +237,7 @@ def display_data(self): class ReadFromParquet(PTransform): """A `PTransform` for reading Parquet files.""" + def __init__( self, file_pattern=None, @@ -370,6 +376,7 @@ def expand(self, pvalue): class ReadAllFromParquet(PTransform): + def __init__(self, with_filename=False, **kwargs): self._with_filename = with_filename self._read_batches = ReadAllFromParquetBatched( @@ -381,6 +388,7 @@ def expand(self, pvalue): class _ParquetUtils(object): + @staticmethod def find_first_row_group_index(pf, start_offset): for i in range(_ParquetUtils.get_number_of_row_groups(pf)): @@ -406,6 +414,7 @@ def get_number_of_row_groups(pf): class _ParquetSource(filebasedsource.FileBasedSource): """A source for reading Parquet files. """ + def __init__( self, file_pattern, min_bundle_size=0, validate=False, columns=None): super().__init__( @@ -464,6 +473,7 @@ def split_points_unclaimed(stop_position): class WriteToParquet(PTransform): """A ``PTransform`` for writing parquet files. """ + def __init__( self, file_path_prefix, @@ -599,6 +609,7 @@ class WriteToParquetBatched(PTransform): This ``PTransform`` is currently experimental. No backward-compatibility guarantees. """ + def __init__( self, file_path_prefix, @@ -724,6 +735,7 @@ def _create_parquet_sink( class _ParquetSink(filebasedsink.FileBasedSink): """A sink for parquet files from batches.""" + def __init__( self, file_path_prefix, diff --git a/sdks/python/apache_beam/io/parquetio_it_test.py b/sdks/python/apache_beam/io/parquetio_it_test.py index 052b54f3ebfb..a5caf769e768 100644 --- a/sdks/python/apache_beam/io/parquetio_it_test.py +++ b/sdks/python/apache_beam/io/parquetio_it_test.py @@ -46,6 +46,7 @@ @unittest.skipIf(pa is None, "PyArrow is not installed.") class TestParquetIT(unittest.TestCase): + def setUp(self): pass @@ -117,6 +118,7 @@ def _generate_data(self, p, output_prefix, init_size, data_size): class ProducerFn(DoFn): + def __init__(self, number): super().__init__() self._number = number diff --git a/sdks/python/apache_beam/io/parquetio_test.py b/sdks/python/apache_beam/io/parquetio_test.py index e33ee4ec1129..22723e28652a 100644 --- a/sdks/python/apache_beam/io/parquetio_test.py +++ b/sdks/python/apache_beam/io/parquetio_test.py @@ -65,6 +65,7 @@ @unittest.skipIf(pa is None, "PyArrow is not installed.") @pytest.mark.uses_pyarrow class TestParquet(unittest.TestCase): + def setUp(self): # Reducing the size of thread pools. Without this test execution may fail in # environments with limited amount of resources. diff --git a/sdks/python/apache_beam/io/range_trackers.py b/sdks/python/apache_beam/io/range_trackers.py index ba56fd3f3559..9f69ad4e4de1 100644 --- a/sdks/python/apache_beam/io/range_trackers.py +++ b/sdks/python/apache_beam/io/range_trackers.py @@ -207,9 +207,9 @@ def split_points(self): if self._split_points_unclaimed_callback else iobase.RangeTracker.SPLIT_POINTS_UNKNOWN) split_points_remaining = ( - iobase.RangeTracker.SPLIT_POINTS_UNKNOWN - if split_points_unclaimed == iobase.RangeTracker.SPLIT_POINTS_UNKNOWN - else (split_points_unclaimed + 1)) + iobase.RangeTracker.SPLIT_POINTS_UNKNOWN if split_points_unclaimed + == iobase.RangeTracker.SPLIT_POINTS_UNKNOWN else + (split_points_unclaimed + 1)) return (split_points_consumed, split_points_remaining) @@ -306,6 +306,7 @@ class UnsplittableRangeTracker(iobase.RangeTracker): ignoring all calls to :meth:`.try_split()`. All other calls will be delegated to the given :class:`~apache_beam.io.iobase.RangeTracker`. """ + def __init__(self, range_tracker): """Initializes UnsplittableRangeTracker. @@ -350,6 +351,7 @@ class LexicographicKeyRangeTracker(OrderedPositionRangeTracker): """A range tracker that tracks progress through a lexicographically ordered keyspace of strings. """ + @classmethod def fraction_to_position( cls, diff --git a/sdks/python/apache_beam/io/range_trackers_test.py b/sdks/python/apache_beam/io/range_trackers_test.py index 0bf37997f2ce..5ec46a806fe2 100644 --- a/sdks/python/apache_beam/io/range_trackers_test.py +++ b/sdks/python/apache_beam/io/range_trackers_test.py @@ -29,6 +29,7 @@ class OffsetRangeTrackerTest(unittest.TestCase): + def test_try_return_record_simple_sparse(self): tracker = range_trackers.OffsetRangeTracker(100, 200) self.assertTrue(tracker.try_claim(110)) @@ -206,7 +207,9 @@ def dummy_callback(stop_position): class OrderedPositionRangeTrackerTest(unittest.TestCase): + class DoubleRangeTracker(range_trackers.OrderedPositionRangeTracker): + @staticmethod def fraction_to_position(fraction, start, end): return start + (end - start) * fraction @@ -283,6 +286,7 @@ def test_out_of_range(self): class UnsplittableRangeTrackerTest(unittest.TestCase): + def test_try_claim(self): tracker = range_trackers.UnsplittableRangeTracker( range_trackers.OffsetRangeTracker(100, 200)) diff --git a/sdks/python/apache_beam/io/requestresponse.py b/sdks/python/apache_beam/io/requestresponse.py index d7011e5a8ff3..fce2ab5086a6 100644 --- a/sdks/python/apache_beam/io/requestresponse.py +++ b/sdks/python/apache_beam/io/requestresponse.py @@ -92,6 +92,7 @@ def retry_on_exception(exception: Exception): class _MetricsCollector: """A metrics collector that tracks RequestResponseIO related usage.""" + def __init__(self, namespace: str): """ Args: @@ -120,6 +121,7 @@ class Caller(contextlib.AbstractContextManager, For setup and teardown of clients when applicable, implement the ``__enter__`` and ``__exit__`` methods respectively.""" + @abc.abstractmethod def __call__(self, request: RequestT, *args, **kwargs) -> ResponseT: """Calls a Web API with the ``RequestT`` and returns a @@ -161,6 +163,7 @@ class ShouldBackOff(abc.ABC): class Repeater(abc.ABC): """Provides mechanism to repeat requests for a configurable condition.""" + @abc.abstractmethod def repeat( self, @@ -218,6 +221,7 @@ class ExponentialBackOffRepeater(Repeater): It utilizes the decorator :func:`apache_beam.utils.retry.with_exponential_backoff`. """ + def __init__(self): pass @@ -247,6 +251,7 @@ def repeat( class NoOpsRepeater(Repeater): """Executes a request just once irrespective of any exception. """ + def repeat( self, caller: Caller[RequestT, ResponseT], @@ -273,6 +278,7 @@ class DefaultThrottler(PreCallThrottler): https://landing.google.com/sre/book/chapters/handling-overload.html. delay_secs (int): minimum number of seconds to throttle a request. """ + def __init__( self, window_ms: int = 1, @@ -289,6 +295,7 @@ class _FilterCacheReadFn(beam.DoFn): It emits to main output for successful cache read requests or to the tagged output - `cache_misses` - otherwise.""" + def process(self, element: Tuple[RequestT, ResponseT], *args, **kwargs): if not element[1]: yield pvalue.TaggedOutput('cache_misses', element[0]) @@ -313,6 +320,7 @@ class _Call(beam.PTransform[beam.PCollection[RequestT], repeater: (Optional) provides methods to repeat requests to API. throttler: (Optional) provides methods to pre-throttle a request. """ + def __init__( self, caller: Caller[RequestT, ResponseT], @@ -335,6 +343,7 @@ def expand( class _CallDoFn(beam.DoFn): + def setup(self): self._caller.__enter__() self._metrics_collector = _MetricsCollector(self._caller.__str__()) @@ -389,6 +398,7 @@ class Cache(abc.ABC): For adding cache support to RequestResponseIO, implement this class. """ + @abc.abstractmethod def get_read(self): """returns a PTransform that reads from the cache.""" @@ -442,6 +452,7 @@ class _RedisCaller(Caller): It provides the functionality for making requests to Redis server using :class:`apache_beam.io.requestresponse.RequestResponseIO`. """ + def __init__( self, host: str, @@ -552,6 +563,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): class _ReadFromRedis(beam.PTransform[beam.PCollection[RequestT], beam.PCollection[ResponseT]]): """A `PTransform` that performs Redis cache read.""" + def __init__( self, host: str, @@ -601,6 +613,7 @@ class _WriteToRedis(beam.PTransform[beam.PCollection[Tuple[RequestT, ResponseT]], beam.PCollection[ResponseT]]): """A `PTransfrom` that performs write to Redis cache.""" + def __init__( self, host: str, @@ -657,6 +670,7 @@ def ensure_coders_exist(request_coder): class RedisCache(Cache): """Configure cache using Redis for :class:`apache_beam.io.requestresponse.RequestResponseIO`.""" + def __init__( self, host: str, @@ -731,6 +745,7 @@ def request_coder(self, request_coder: coders.Coder): class FlattenBatch(beam.DoFn): """Flatten a batched PCollection.""" + def process(self, elements, *args, **kwargs): for element in elements: yield element @@ -744,6 +759,7 @@ class RequestResponseIO(beam.PTransform[beam.PCollection[RequestT], by making a call to the API as defined in `Caller`'s `__call__` method and returns a :class:`~apache_beam.pvalue.PCollection` of responses. """ + def __init__( self, caller: Caller[RequestT, ResponseT], diff --git a/sdks/python/apache_beam/io/requestresponse_it_test.py b/sdks/python/apache_beam/io/requestresponse_it_test.py index 712ccc7881d6..3abf3f2465a1 100644 --- a/sdks/python/apache_beam/io/requestresponse_it_test.py +++ b/sdks/python/apache_beam/io/requestresponse_it_test.py @@ -55,6 +55,7 @@ class EchoITOptions(PipelineOptions): -infra/mock-apis#integration for details on how to acquire values required by ``EchoITOptions``. """ + @classmethod def _add_argparse_args(cls, parser) -> None: parser.add_argument( @@ -92,6 +93,7 @@ class EchoHTTPCaller(Caller[Request, EchoResponse]): """Implements ``Caller`` to call the ``EchoServiceGrpc``'s HTTP handler. The purpose of ``EchoHTTPCaller`` is to support integration tests. """ + def __init__(self, url: str): self.url = url + _HTTP_PATH @@ -129,6 +131,7 @@ def __call__(self, request: Request, *args, **kwargs) -> EchoResponse: class ValidateResponse(beam.DoFn): """Validates response received from Mock API server.""" + def process(self, element, *args, **kwargs): if (element.id != 'echo-should-never-exceed-quota' or element.payload != _PAYLOAD): @@ -172,6 +175,7 @@ def test_request_response_io(self): class ValidateCacheResponses(beam.DoFn): """Validates that the responses are fetched from the cache.""" + def process(self, element, *args, **kwargs): if not element[1] or 'cached-' not in element[1]: raise ValueError( @@ -181,12 +185,14 @@ def process(self, element, *args, **kwargs): class ValidateCallerResponses(beam.DoFn): """Validates that the responses are fetched from the caller.""" + def process(self, element, *args, **kwargs): if not element[1] or 'ACK-' not in element[1]: raise ValueError('responses not fetched from caller when they should.') class FakeCallerForCache(Caller[str, str]): + def __init__(self, use_cache: bool = False): self.use_cache = use_cache @@ -205,6 +211,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): @pytest.mark.uses_testcontainer class TestRedisCache(unittest.TestCase): + def setUp(self) -> None: self.retries = 3 self._start_container() diff --git a/sdks/python/apache_beam/io/requestresponse_test.py b/sdks/python/apache_beam/io/requestresponse_test.py index 3bc85a5e103a..cd9daa9bccbf 100644 --- a/sdks/python/apache_beam/io/requestresponse_test.py +++ b/sdks/python/apache_beam/io/requestresponse_test.py @@ -45,6 +45,7 @@ class AckCaller(Caller[str, str]): """AckCaller acknowledges the incoming request by returning a request with ACK.""" + def __enter__(self): pass @@ -58,6 +59,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): class CallerWithTimeout(AckCaller): """CallerWithTimeout sleeps for 2 seconds before responding. Used to test timeout in RequestResponseIO.""" + def __call__(self, request: str, *args, **kwargs): time.sleep(2) return f"ACK: {request}" @@ -66,12 +68,14 @@ def __call__(self, request: str, *args, **kwargs): class CallerWithRuntimeError(AckCaller): """CallerWithRuntimeError raises a `RuntimeError` for RequestResponseIO to raise a UserCodeExecutionException.""" + def __call__(self, request: str, *args, **kwargs): if not request: raise RuntimeError("Exception expected, not an error.") class CallerThatRetries(AckCaller): + def __init__(self): self.count = -1 @@ -86,6 +90,7 @@ def __call__(self, request: str, *args, **kwargs): class TestCaller(unittest.TestCase): + def test_valid_call(self): caller = AckCaller() with TestPipeline() as test_pipeline: diff --git a/sdks/python/apache_beam/io/restriction_trackers.py b/sdks/python/apache_beam/io/restriction_trackers.py index 4b819e87a8d6..ecd44a2d22ee 100644 --- a/sdks/python/apache_beam/io/restriction_trackers.py +++ b/sdks/python/apache_beam/io/restriction_trackers.py @@ -26,6 +26,7 @@ class OffsetRange(object): + def __init__(self, start, stop): if start > stop: raise ValueError( @@ -77,6 +78,7 @@ class OffsetRestrictionTracker(RestrictionTracker): Offset range is represented as OffsetRange. """ + def __init__(self, offset_range: OffsetRange) -> None: assert isinstance(offset_range, OffsetRange), offset_range self._range = offset_range @@ -157,6 +159,7 @@ def is_bounded(self): class UnsplittableRestrictionTracker(RestrictionTracker): """An `iobase.RestrictionTracker` that wraps another but does not split.""" + def __init__(self, underling_tracker): self._underling_tracker = underling_tracker diff --git a/sdks/python/apache_beam/io/restriction_trackers_test.py b/sdks/python/apache_beam/io/restriction_trackers_test.py index 0d3eee18036e..e2d0f083c9b8 100644 --- a/sdks/python/apache_beam/io/restriction_trackers_test.py +++ b/sdks/python/apache_beam/io/restriction_trackers_test.py @@ -27,6 +27,7 @@ class OffsetRangeTest(unittest.TestCase): + def test_create(self): OffsetRange(0, 10) OffsetRange(10, 10) @@ -70,6 +71,7 @@ def test_split_at(self): class OffsetRestrictionTrackerTest(unittest.TestCase): + def test_try_claim(self): tracker = OffsetRestrictionTracker(OffsetRange(100, 200)) self.assertEqual(OffsetRange(100, 200), tracker.current_restriction()) diff --git a/sdks/python/apache_beam/io/source_test_utils.py b/sdks/python/apache_beam/io/source_test_utils.py index b40f70604c42..1b3aeb474462 100644 --- a/sdks/python/apache_beam/io/source_test_utils.py +++ b/sdks/python/apache_beam/io/source_test_utils.py @@ -156,8 +156,8 @@ def assert_sources_equal_reference_source(reference_source_info, sources_info): 'source_info must a three tuple where first' 'item of the tuple gives a ' 'iobase.BoundedSource. Received: %r' % source_info) - if (type(reference_source_info[0].default_output_coder()) != type( - source_info[0].default_output_coder())): + if (type(reference_source_info[0].default_output_coder()) + != type(source_info[0].default_output_coder())): raise ValueError( 'Reference source %r and the source %r must use the same coder. ' 'They are using %r and %r respectively instead.' % ( @@ -341,8 +341,8 @@ def _assert_split_at_fraction_behavior( num_items_to_read_before_split, split_result)) - elif ( - expected_outcome != ExpectedSplitOutcome.MUST_BE_CONSISTENT_IF_SUCCEEDS): + elif (expected_outcome + != ExpectedSplitOutcome.MUST_BE_CONSISTENT_IF_SUCCEEDS): raise ValueError('Unknown type of expected outcome: %r' % expected_outcome) current_items.extend([value for value in reader_iter]) diff --git a/sdks/python/apache_beam/io/source_test_utils_test.py b/sdks/python/apache_beam/io/source_test_utils_test.py index 081a6fcb60ca..1f8dcaa6a487 100644 --- a/sdks/python/apache_beam/io/source_test_utils_test.py +++ b/sdks/python/apache_beam/io/source_test_utils_test.py @@ -26,6 +26,7 @@ class SourceTestUtilsTest(unittest.TestCase): + def _create_file_with_data(self, lines): assert isinstance(lines, list) with tempfile.NamedTemporaryFile(delete=False) as f: diff --git a/sdks/python/apache_beam/io/sources_test.py b/sdks/python/apache_beam/io/sources_test.py index f75e4fdafff0..94cbf25de5c6 100644 --- a/sdks/python/apache_beam/io/sources_test.py +++ b/sdks/python/apache_beam/io/sources_test.py @@ -89,6 +89,7 @@ def _get_file_size(self): class SourcesTest(unittest.TestCase): + def _create_temp_file(self, contents): with tempfile.NamedTemporaryFile(delete=False) as f: f.write(contents) diff --git a/sdks/python/apache_beam/io/textio.py b/sdks/python/apache_beam/io/textio.py index 0d7803bcabb1..047a8a111905 100644 --- a/sdks/python/apache_beam/io/textio.py +++ b/sdks/python/apache_beam/io/textio.py @@ -425,6 +425,7 @@ def output_type_hint(self): class _TextSourceWithFilename(_TextSource): + def read_records(self, file_name, range_tracker): records = super().read_records(file_name, range_tracker) for record in records: @@ -436,6 +437,7 @@ def output_type_hint(self): class _TextSink(filebasedsink.FileBasedSink): """A sink to a GCS or local text file or files.""" + def __init__( self, file_path_prefix, @@ -811,6 +813,7 @@ class ReadFromTextWithFilename(ReadFromText): class WriteToText(PTransform): """A :class:`~apache_beam.transforms.ptransform.PTransform` for writing to text files.""" + def __init__( self, file_path_prefix: str, @@ -899,6 +902,7 @@ def expand(self, pcoll): import pandas def append_pandas_args(src, exclude): + def append(dest): state = None skip = False diff --git a/sdks/python/apache_beam/io/textio_test.py b/sdks/python/apache_beam/io/textio_test.py index d1bfdf6bfd35..bf5736d2ff99 100644 --- a/sdks/python/apache_beam/io/textio_test.py +++ b/sdks/python/apache_beam/io/textio_test.py @@ -52,6 +52,7 @@ class DummyCoder(coders.Coder): + def encode(self, x): raise ValueError @@ -1444,6 +1445,7 @@ def test_read_escaped_escapechar_after_splitting_many(self): class TextSinkTest(unittest.TestCase): + def setUp(self): super().setUp() self.lines = [b'Line %d' % d for d in range(100)] @@ -1712,6 +1714,7 @@ def test_write_max_bytes_per_shard(self): class CsvTest(unittest.TestCase): + def test_csv_read_write(self): records = [beam.Row(a='str', b=ix) for ix in range(3)] with tempfile.TemporaryDirectory() as dest: @@ -1770,6 +1773,7 @@ def test_non_utf8_csv_read_write(self): class JsonTest(unittest.TestCase): + def test_json_read_write(self): records = [beam.Row(a='str', b=ix) for ix in range(3)] with tempfile.TemporaryDirectory() as dest: diff --git a/sdks/python/apache_beam/io/tfrecordio.py b/sdks/python/apache_beam/io/tfrecordio.py index d3bb0f8acf3f..467ecbeb2c2d 100644 --- a/sdks/python/apache_beam/io/tfrecordio.py +++ b/sdks/python/apache_beam/io/tfrecordio.py @@ -75,6 +75,7 @@ class _TFRecordUtil(object): Note that masks and length are represented in LittleEndian order. """ + @classmethod def _masked_crc32c(cls, value, crc32c_fn=_default_crc32c_fn): """Compute a masked crc32c checksum for a value. @@ -166,6 +167,7 @@ class _TFRecordSource(FileBasedSource): For detailed TFRecords format description see: https://www.tensorflow.org/versions/r1.11/api_guides/python/python_io#TFRecords_Format_Details """ + def __init__(self, file_pattern, coder, compression_type, validate): """Initialize a TFRecordSource. See ReadFromTFRecord for details.""" super().__init__( @@ -202,6 +204,7 @@ def _create_tfrecordio_source( class ReadAllFromTFRecord(PTransform): """A ``PTransform`` for reading a ``PCollection`` of TFRecord files.""" + def __init__( self, coder=coders.BytesCoder(), @@ -239,6 +242,7 @@ def expand(self, pvalue): class ReadFromTFRecord(PTransform): """Transform for reading TFRecord sources.""" + def __init__( self, file_pattern, @@ -273,6 +277,7 @@ class _TFRecordSink(filebasedsink.FileBasedSink): For detailed TFRecord format description see: https://www.tensorflow.org/versions/r1.11/api_guides/python/python_io#TFRecords_Format_Details """ + def __init__( self, file_path_prefix, @@ -298,6 +303,7 @@ def write_encoded_record(self, file_handle, value): class WriteToTFRecord(PTransform): """Transform for writing to TFRecord sinks.""" + def __init__( self, file_path_prefix, diff --git a/sdks/python/apache_beam/io/tfrecordio_test.py b/sdks/python/apache_beam/io/tfrecordio_test.py index a867c0212ad3..3955948e488d 100644 --- a/sdks/python/apache_beam/io/tfrecordio_test.py +++ b/sdks/python/apache_beam/io/tfrecordio_test.py @@ -88,6 +88,7 @@ def _write_file_gzip(path, base64_records): class TestTFRecordUtil(unittest.TestCase): + def setUp(self): self.record = binascii.a2b_base64(FOO_RECORD_BASE64) @@ -159,6 +160,7 @@ def test_compatibility_read_write(self): class TestTFRecordSink(unittest.TestCase): + def _write_lines(self, sink, path, lines): f = sink.open(path) for l in lines: @@ -200,6 +202,7 @@ def test_write_record_multiple(self): @unittest.skipIf(tf is None, 'tensorflow not installed.') class TestWriteToTFRecord(TestTFRecordSink): + def test_write_record_gzip(self): with TempDir() as temp_dir: file_path_prefix = temp_dir.create_temp_file('result') @@ -236,6 +239,7 @@ def test_write_record_auto(self): class TestReadFromTFRecord(unittest.TestCase): + def test_process_single(self): with TempDir() as temp_dir: path = temp_dir.create_temp_file('result') @@ -328,6 +332,7 @@ def test_process_gzip_auto(self): class TestReadAllFromTFRecord(unittest.TestCase): + def _write_glob(self, temp_dir, suffix, include_empty=False): for _ in range(3): path = temp_dir.create_temp_file(suffix) @@ -468,6 +473,7 @@ def test_process_auto(self): class TestEnd2EndWriteAndRead(unittest.TestCase): + def create_inputs(self): input_array = [[random.random() - 0.5 for _ in range(15)] for _ in range(12)] diff --git a/sdks/python/apache_beam/io/utils.py b/sdks/python/apache_beam/io/utils.py index 0d1f52f35f2b..61abb74fdf3f 100644 --- a/sdks/python/apache_beam/io/utils.py +++ b/sdks/python/apache_beam/io/utils.py @@ -28,6 +28,7 @@ class CountingSource(iobase.BoundedSource): + def __init__(self, count): self.records_read = Metrics.counter(self.__class__, 'recordsRead') self._count = count diff --git a/sdks/python/apache_beam/io/utils_test.py b/sdks/python/apache_beam/io/utils_test.py index 76ed9969f995..efe14986f677 100644 --- a/sdks/python/apache_beam/io/utils_test.py +++ b/sdks/python/apache_beam/io/utils_test.py @@ -27,6 +27,7 @@ class CountingSourceTest(unittest.TestCase): + def setUp(self): self.source = CountingSource(10) diff --git a/sdks/python/apache_beam/io/watermark_estimators.py b/sdks/python/apache_beam/io/watermark_estimators.py index ea68608c821a..623a657539f5 100644 --- a/sdks/python/apache_beam/io/watermark_estimators.py +++ b/sdks/python/apache_beam/io/watermark_estimators.py @@ -29,6 +29,7 @@ class MonotonicWatermarkEstimator(WatermarkEstimator): """A WatermarkEstimator which assumes that timestamps of all ouput records are increasing monotonically. """ + def __init__(self, timestamp): """For a new pair, the initial value is None. When resuming processing, the initial timestamp will be the last reported @@ -54,7 +55,9 @@ def default_provider(): """Provide a default WatermarkEstimatorProvider for MonotonicWatermarkEstimator. """ + class DefaultMonotonicWatermarkEstimator(WatermarkEstimatorProvider): + def initial_estimator_state(self, element, restriction): return None @@ -67,6 +70,7 @@ def create_watermark_estimator(self, estimator_state): class WalltimeWatermarkEstimator(WatermarkEstimator): """A WatermarkEstimator which uses processing time as the estimated watermark. """ + def __init__(self, timestamp=None): self._timestamp = timestamp or Timestamp.now() @@ -85,7 +89,9 @@ def default_provider(): """Provide a default WatermarkEstimatorProvider for WalltimeWatermarkEstimator. """ + class DefaultWalltimeWatermarkEstimator(WatermarkEstimatorProvider): + def initial_estimator_state(self, element, restriction): return None @@ -100,6 +106,7 @@ class ManualWatermarkEstimator(WatermarkEstimator): The DoFn must invoke set_watermark to advance the watermark. """ + def __init__(self, watermark): self._watermark = watermark @@ -142,7 +149,9 @@ def default_provider(): """Provide a default WatermarkEstimatorProvider for WalltimeWatermarkEstimator. """ + class DefaultManualWatermarkEstimatorProvider(WatermarkEstimatorProvider): + def initial_estimator_state(self, element, restriction): return None diff --git a/sdks/python/apache_beam/io/watermark_estimators_test.py b/sdks/python/apache_beam/io/watermark_estimators_test.py index 6fd1e463ddaf..af2c3317248b 100644 --- a/sdks/python/apache_beam/io/watermark_estimators_test.py +++ b/sdks/python/apache_beam/io/watermark_estimators_test.py @@ -32,6 +32,7 @@ class MonotonicWatermarkEstimatorTest(unittest.TestCase): + def test_initialize_from_state(self): timestamp = Timestamp(10) watermark_estimator = MonotonicWatermarkEstimator(timestamp) @@ -60,6 +61,7 @@ def test_get_estimator_state(self): class WalltimeWatermarkEstimatorTest(unittest.TestCase): + @mock.patch('apache_beam.utils.timestamp.Timestamp.now') def test_initialization(self, mock_timestamp): now_time = Timestamp.now() - Duration(10) @@ -84,6 +86,7 @@ def test_advance_watermark_with_incorrect_sys_clock(self): class ManualWatermarkEstimatorTest(unittest.TestCase): + def test_initialization(self): watermark_estimator = ManualWatermarkEstimator(None) self.assertIsNone(watermark_estimator.get_estimator_state()) diff --git a/sdks/python/apache_beam/metrics/cells.py b/sdks/python/apache_beam/metrics/cells.py index c2c2e8015ef2..52eaf53604a2 100644 --- a/sdks/python/apache_beam/metrics/cells.py +++ b/sdks/python/apache_beam/metrics/cells.py @@ -60,6 +60,7 @@ class MetricCell(object): and may be subject to parallel/concurrent updates. Cells should only be used directly within a runner. """ + def __init__(self): self._lock = threading.Lock() self._start_time = None @@ -89,6 +90,7 @@ def __reduce__(self): class MetricCellFactory(object): + def __call__(self): # type: () -> MetricCell raise NotImplementedError @@ -105,6 +107,7 @@ class CounterCell(MetricCell): This class is thread safe. """ + def __init__(self, *args): super().__init__(*args) self.value = 0 @@ -170,6 +173,7 @@ class DistributionCell(MetricCell): This class is thread safe. """ + def __init__(self, *args): super().__init__(*args) self.data = DistributionData.identity_element() @@ -225,6 +229,7 @@ class AbstractMetricCell(MetricCell): This class is thread safe. """ + def __init__(self, data_class): super().__init__() self.data_class = data_class @@ -268,6 +273,7 @@ class GaugeCell(AbstractMetricCell): This class is thread safe. """ + def __init__(self): super().__init__(GaugeData) @@ -297,6 +303,7 @@ class StringSetCell(AbstractMetricCell): This class is thread safe. """ + def __init__(self): super().__init__(StringSetData) @@ -326,6 +333,7 @@ class BoundedTrieCell(AbstractMetricCell): This class is thread safe. """ + def __init__(self): super().__init__(BoundedTrieData) @@ -346,6 +354,7 @@ def to_runner_api_monitoring_info_impl(self, name, transform_id): class DistributionResult(object): """The result of a Distribution metric.""" + def __init__(self, data): # type: (DistributionData) -> None self.data = data @@ -401,6 +410,7 @@ def mean(self): class GaugeResult(object): + def __init__(self, data): # type: (GaugeData) -> None self.data = data @@ -441,6 +451,7 @@ class GaugeData(object): This object is not thread safe, so it's not supposed to be modified by other than the GaugeCell that contains it. """ + def __init__(self, value, timestamp=None): # type: (Optional[int], Optional[int]) -> None self.value = value @@ -501,6 +512,7 @@ class DistributionData(object): This object is not thread safe, so it's not supposed to be modified by other than the DistributionCell that contains it. """ + def __init__(self, sum, count, min, max): # type: (int, int, int, int) -> None if count: @@ -665,6 +677,7 @@ def identity_element() -> "StringSetData": class _BoundedTrieNode(object): + def __init__(self): # invariant: size = len(self.flattened()) = min(1, sum(size of children)) self._size = 1 @@ -688,8 +701,7 @@ def from_proto(proto: metrics_pb2.BoundedTrieNode) -> '_BoundedTrieNode': else: node._children = { name: _BoundedTrieNode.from_proto(child) - for name, - child in proto.children.items() + for name, child in proto.children.items() } node._size = max(1, sum(child._size for child in node._children.values())) return node diff --git a/sdks/python/apache_beam/metrics/cells_test.py b/sdks/python/apache_beam/metrics/cells_test.py index 1cd15fced86c..33fc57db53da 100644 --- a/sdks/python/apache_beam/metrics/cells_test.py +++ b/sdks/python/apache_beam/metrics/cells_test.py @@ -36,6 +36,7 @@ class TestCounterCell(unittest.TestCase): + @classmethod def _modify_counter(cls, d): for i in range(cls.NUM_ITERATIONS): @@ -49,7 +50,8 @@ def test_parallel_access(self): threads = [] c = CounterCell() for _ in range(TestCounterCell.NUM_THREADS): - t = threading.Thread(target=TestCounterCell._modify_counter, args=(c, )) + t = threading.Thread( + target=TestCounterCell._modify_counter, args=(c, )) threads.append(t) t.start() @@ -84,6 +86,7 @@ def test_start_time_set(self): class TestDistributionCell(unittest.TestCase): + @classmethod def _modify_distribution(cls, d): for i in range(cls.NUM_ITERATIONS): @@ -142,6 +145,7 @@ def test_start_time_set(self): class TestGaugeCell(unittest.TestCase): + def test_basic_operations(self): g = GaugeCell() g.set(10) @@ -177,6 +181,7 @@ def test_start_time_set(self): class TestStringSetCell(unittest.TestCase): + def test_not_leak_mutable_set(self): c = StringSetCell() c.add('test') @@ -209,6 +214,7 @@ def test_add_size_tracked_correctly(self): class TestBoundedTrieNode(unittest.TestCase): + @classmethod def random_segments_fixed_depth(cls, n, depth, overlap, rand): if depth == 0: @@ -259,6 +265,7 @@ def assert_covers_flattened(self, flattened, expected, max_truncated=0): self.assertEqual(seen_truncated, truncated, truncated - seen_truncated) def run_covers_test(self, flattened, expected, max_truncated): + def parse(s): return tuple(s.strip('*')) + (s.endswith('*'), ) diff --git a/sdks/python/apache_beam/metrics/execution.py b/sdks/python/apache_beam/metrics/execution.py index c28c8340a505..fdd0216b461c 100644 --- a/sdks/python/apache_beam/metrics/execution.py +++ b/sdks/python/apache_beam/metrics/execution.py @@ -70,6 +70,7 @@ class MetricKey(object): and any extra label metadata added by the runner specific metric collection service. """ + def __init__(self, step, metric, labels=None): """Initializes ``MetricKey``. @@ -110,6 +111,7 @@ class MetricResult(object): attempted: The logical updates of the metric. This attribute's type is that of metric type result (e.g. int, DistributionResult, GaugeResult). """ + def __init__(self, key, committed, attempted): """Initializes ``MetricResult``. Args: @@ -150,6 +152,7 @@ class _MetricsEnvironment(object): This class is not meant to be instantiated, instead being used to keep track of global state. """ + def current_container(self): """Returns the current MetricsContainer.""" sampler = statesampler.get_current_tracker() @@ -167,6 +170,7 @@ def process_wide_container(self): class _TypedMetricName(object): """Like MetricName, but also stores the cell type of the metric.""" + def __init__( self, cell_type, # type: Union[Type[MetricCell], MetricCellFactory] @@ -201,6 +205,7 @@ def __reduce__(self): class MetricUpdater(object): """A callable that updates the metric as quickly as possible.""" + def __init__( self, cell_type, # type: Union[Type[MetricCell], MetricCellFactory] @@ -238,6 +243,7 @@ class MetricsContainer(object): Or the metrics associated with the process/SDK harness. I.e. memory usage. """ + def __init__(self, step_name): self.step_name = step_name self.lock = threading.Lock() @@ -287,32 +293,27 @@ def get_cumulative(self): """ counters = { MetricKey(self.step_name, k.metric_name): v.get_cumulative() - for k, - v in self.metrics.items() if k.cell_type == CounterCell + for k, v in self.metrics.items() if k.cell_type == CounterCell } distributions = { MetricKey(self.step_name, k.metric_name): v.get_cumulative() - for k, - v in self.metrics.items() if k.cell_type == DistributionCell + for k, v in self.metrics.items() if k.cell_type == DistributionCell } gauges = { MetricKey(self.step_name, k.metric_name): v.get_cumulative() - for k, - v in self.metrics.items() if k.cell_type == GaugeCell + for k, v in self.metrics.items() if k.cell_type == GaugeCell } string_sets = { MetricKey(self.step_name, k.metric_name): v.get_cumulative() - for k, - v in self.metrics.items() if k.cell_type == StringSetCell + for k, v in self.metrics.items() if k.cell_type == StringSetCell } bounded_tries = { MetricKey(self.step_name, k.metric_name): v.get_cumulative() - for k, - v in self.metrics.items() if k.cell_type == BoundedTrieCell + for k, v in self.metrics.items() if k.cell_type == BoundedTrieCell } return MetricUpdates( @@ -320,8 +321,8 @@ def get_cumulative(self): def to_runner_api(self): return [ - cell.to_runner_api_user_metric(key.metric_name) for key, - cell in self.metrics.items() + cell.to_runner_api_user_metric(key.metric_name) + for key, cell in self.metrics.items() ] def to_runner_api_monitoring_infos(self, transform_id): @@ -332,8 +333,7 @@ def to_runner_api_monitoring_infos(self, transform_id): items = list(self.metrics.items()) all_metrics = [ cell.to_runner_api_monitoring_info(key.metric_name, transform_id) - for key, - cell in items + for key, cell in items ] return { monitoring_infos.to_key(mi): mi @@ -364,6 +364,7 @@ class MetricUpdates(object): For Distribution metrics, it is DistributionData, and for Counter metrics, it's an int. """ + def __init__( self, counters=None, # type: Optional[Dict[MetricKey, int]] diff --git a/sdks/python/apache_beam/metrics/execution_test.py b/sdks/python/apache_beam/metrics/execution_test.py index 38e27f1f3d0c..8649a5612375 100644 --- a/sdks/python/apache_beam/metrics/execution_test.py +++ b/sdks/python/apache_beam/metrics/execution_test.py @@ -26,6 +26,7 @@ class TestMetricKey(unittest.TestCase): + def test_equality_for_key_with_labels(self): test_labels = {'label1', 'value1'} test_object = MetricKey( @@ -73,6 +74,7 @@ def test_equality_for_key_with_no_labels(self): class TestMetricsContainer(unittest.TestCase): + def test_add_to_counter(self): mc = MetricsContainer('astep') counter = mc.get_counter(MetricName('namespace', 'name')) diff --git a/sdks/python/apache_beam/metrics/metric.py b/sdks/python/apache_beam/metrics/metric.py index 9cf42370f4b1..c4bdc40bacdc 100644 --- a/sdks/python/apache_beam/metrics/metric.py +++ b/sdks/python/apache_beam/metrics/metric.py @@ -61,6 +61,7 @@ class Metrics(object): """Lets users create/access metric objects during pipeline execution.""" + @staticmethod def get_namespace(namespace: Union[Type, str]) -> str: if isinstance(namespace, type): @@ -155,6 +156,7 @@ def bounded_trie( class DelegatingCounter(Counter): """Metrics Counter that Delegates functionality to MetricsEnvironment.""" + def __init__( self, metric_name: MetricName, process_wide: bool = False) -> None: super().__init__(metric_name) @@ -166,24 +168,28 @@ def __init__( class DelegatingDistribution(Distribution): """Metrics Distribution Delegates functionality to MetricsEnvironment.""" + def __init__(self, metric_name: MetricName) -> None: super().__init__(metric_name) self.update = MetricUpdater(cells.DistributionCell, metric_name) # type: ignore[method-assign] class DelegatingGauge(Gauge): """Metrics Gauge that Delegates functionality to MetricsEnvironment.""" + def __init__(self, metric_name: MetricName) -> None: super().__init__(metric_name) self.set = MetricUpdater(cells.GaugeCell, metric_name) # type: ignore[method-assign] class DelegatingStringSet(StringSet): """Metrics StringSet that Delegates functionality to MetricsEnvironment.""" + def __init__(self, metric_name: MetricName) -> None: super().__init__(metric_name) self.add = MetricUpdater(cells.StringSetCell, metric_name) # type: ignore[method-assign] class DelegatingBoundedTrie(BoundedTrie): """Metrics StringSet that Delegates functionality to MetricsEnvironment.""" + def __init__(self, metric_name: MetricName) -> None: super().__init__(metric_name) self.add = MetricUpdater(cells.BoundedTrieCell, metric_name) # type: ignore[method-assign] @@ -280,6 +286,7 @@ class MetricsFilter(object): Note: This class only supports user defined metrics. """ + def __init__(self) -> None: self._names: Set[str] = set() self._namespaces: Set[str] = set() diff --git a/sdks/python/apache_beam/metrics/metric_test.py b/sdks/python/apache_beam/metrics/metric_test.py index 2e2e51b267a7..9b4a7bd0818b 100644 --- a/sdks/python/apache_beam/metrics/metric_test.py +++ b/sdks/python/apache_beam/metrics/metric_test.py @@ -43,6 +43,7 @@ class NameTest(unittest.TestCase): + def test_basic_metric_name(self): name = MetricName('namespace1', 'name1') self.assertEqual(name.namespace, 'namespace1') @@ -57,6 +58,7 @@ def test_basic_metric_name(self): class MetricResultsTest(unittest.TestCase): + def test_metric_filter_namespace_matching(self): filter = MetricsFilter().with_namespace('ns1') name = MetricName('ns1', 'name1') @@ -103,7 +105,9 @@ def test_metric_filter_step_matching(self): class MetricsTest(unittest.TestCase): + def test_get_namespace_class(self): + class MyClass(object): pass @@ -153,6 +157,7 @@ def test_general_urn_metric_name_str(self): @pytest.mark.it_validatesrunner def test_user_counter_using_pardo(self): + class SomeDoFn(beam.DoFn): """A custom dummy DoFn using yield.""" static_counter_elements = metrics.Metrics.counter( @@ -250,6 +255,7 @@ def test_create_counter_distribution(self): class LineageTest(unittest.TestCase): + def test_fq_name(self): test_cases = { "apache-beam": "apache-beam", diff --git a/sdks/python/apache_beam/metrics/metricbase.py b/sdks/python/apache_beam/metrics/metricbase.py index 9b35bb24f895..38c4f8e6e225 100644 --- a/sdks/python/apache_beam/metrics/metricbase.py +++ b/sdks/python/apache_beam/metrics/metricbase.py @@ -56,6 +56,7 @@ class MetricName(object): allows grouping related metrics together and also prevents collisions between multiple metrics of the same name. """ + def __init__( self, namespace: Optional[str], @@ -113,6 +114,7 @@ def fast_name(self): class Metric(object): """Base interface of a metric object.""" + def __init__(self, metric_name: MetricName) -> None: self.metric_name = metric_name @@ -120,6 +122,7 @@ def __init__(self, metric_name: MetricName) -> None: class Counter(Metric): """Counter metric interface. Allows a count to be incremented/decremented during pipeline execution.""" + def inc(self, n=1): raise NotImplementedError @@ -132,6 +135,7 @@ class Distribution(Metric): Allows statistics about the distribution of a variable to be collected during pipeline execution.""" + def update(self, value): raise NotImplementedError @@ -141,6 +145,7 @@ class Gauge(Metric): Allows tracking of the latest value of a variable during pipeline execution.""" + def set(self, value): raise NotImplementedError @@ -149,6 +154,7 @@ class StringSet(Metric): """StringSet Metric interface. Reports set of unique string values during pipeline execution..""" + def add(self, value): raise NotImplementedError @@ -157,6 +163,7 @@ class BoundedTrie(Metric): """BoundedTrie Metric interface. Reports set of unique string values during pipeline execution..""" + def add(self, value): raise NotImplementedError @@ -166,5 +173,6 @@ class Histogram(Metric): Allows statistics about the percentile of a variable to be collected during pipeline execution.""" + def update(self, value): raise NotImplementedError diff --git a/sdks/python/apache_beam/metrics/monitoring_infos.py b/sdks/python/apache_beam/metrics/monitoring_infos.py index cb4e60e218f6..6dc4b7ef9c57 100644 --- a/sdks/python/apache_beam/metrics/monitoring_infos.py +++ b/sdks/python/apache_beam/metrics/monitoring_infos.py @@ -495,8 +495,7 @@ def merge(a, b): return metrics_pb2.MonitoringInfo( urn=a.urn, type=a.type, - labels=dict((label, value) for label, - value in a.labels.items() + labels=dict((label, value) for label, value in a.labels.items() if b.labels.get(label) == value), payload=combiner(a.payload, b.payload)) diff --git a/sdks/python/apache_beam/metrics/monitoring_infos_test.py b/sdks/python/apache_beam/metrics/monitoring_infos_test.py index 022943f417c2..ad10bc1c3a24 100644 --- a/sdks/python/apache_beam/metrics/monitoring_infos_test.py +++ b/sdks/python/apache_beam/metrics/monitoring_infos_test.py @@ -25,6 +25,7 @@ class MonitoringInfosTest(unittest.TestCase): + def test_parse_namespace_and_name_for_nonuser_metric(self): input = monitoring_infos.create_monitoring_info( "beam:dummy:metric", "typeurn", None) diff --git a/sdks/python/apache_beam/ml/gcp/cloud_dlp.py b/sdks/python/apache_beam/ml/gcp/cloud_dlp.py index cb33ef60ef2c..b6625e00e6bc 100644 --- a/sdks/python/apache_beam/ml/gcp/cloud_dlp.py +++ b/sdks/python/apache_beam/ml/gcp/cloud_dlp.py @@ -56,6 +56,7 @@ class MaskDetectedDetails(PTransform): }, inspection_config={'info_types': [{'name': 'EMAIL_ADDRESS'}]}) """ + def __init__( self, project=None, @@ -138,6 +139,7 @@ class InspectForDetails(PTransform): pipeline | InspectForDetails(project='example-gcp-project', inspection_config={'info_types': [{'name': 'EMAIL_ADDRESS'}]}) """ + def __init__( self, project=None, @@ -181,6 +183,7 @@ def expand(self, pcoll): class _DeidentifyFn(DoFn): + def __init__(self, config=None, timeout=None, project=None, client=None): self.config = config self.timeout = timeout @@ -204,6 +207,7 @@ def process(self, element, **kwargs): class _InspectFn(DoFn): + def __init__(self, config=None, timeout=None, project=None): self.config = config self.timeout = timeout diff --git a/sdks/python/apache_beam/ml/gcp/cloud_dlp_it_test.py b/sdks/python/apache_beam/ml/gcp/cloud_dlp_it_test.py index a699aaa36be5..b9b923710dde 100644 --- a/sdks/python/apache_beam/ml/gcp/cloud_dlp_it_test.py +++ b/sdks/python/apache_beam/ml/gcp/cloud_dlp_it_test.py @@ -61,6 +61,7 @@ def extract_inspection_results(response): @unittest.skipIf(dlp_v2 is None, 'GCP dependencies are not installed') class CloudDLPIT(unittest.TestCase): + def setUp(self): self.test_pipeline = TestPipeline(is_integration_test=True) self.runner_name = type(self.test_pipeline.runner).__name__ diff --git a/sdks/python/apache_beam/ml/gcp/cloud_dlp_test.py b/sdks/python/apache_beam/ml/gcp/cloud_dlp_test.py index d4153e5b3fe9..2793b0326f06 100644 --- a/sdks/python/apache_beam/ml/gcp/cloud_dlp_test.py +++ b/sdks/python/apache_beam/ml/gcp/cloud_dlp_test.py @@ -45,6 +45,7 @@ @unittest.skipIf(dlp_v2 is None, 'GCP dependencies are not installed') class TestDeidentifyText(unittest.TestCase): + def test_exception_raised_when_no_config_is_provided(self): with self.assertRaises(ValueError): with TestPipeline() as p: @@ -54,8 +55,11 @@ def test_exception_raised_when_no_config_is_provided(self): @unittest.skipIf(dlp_v2 is None, 'GCP dependencies are not installed') class TestDeidentifyFn(unittest.TestCase): + def test_deidentify_called(self): + class ClientMock(object): + def deidentify_content(self, *args, **kwargs): # Check that we can marshal a valid request. dlp.DeidentifyContentRequest(kwargs['request']) @@ -99,6 +103,7 @@ def common_project_path(self, *args): @unittest.skipIf(dlp_v2 is None, 'GCP dependencies are not installed') class TestInspectText(unittest.TestCase): + def test_exception_raised_then_no_config_provided(self): with self.assertRaises(ValueError): with TestPipeline() as p: @@ -108,8 +113,11 @@ def test_exception_raised_then_no_config_provided(self): @unittest.skipIf(dlp_v2 is None, 'GCP dependencies are not installed') class TestInspectFn(unittest.TestCase): + def test_inspect_called(self): + class ClientMock(object): + def inspect_content(self, *args, **kwargs): # Check that we can marshal a valid request. dlp.InspectContentRequest(kwargs['request']) diff --git a/sdks/python/apache_beam/ml/gcp/naturallanguageml.py b/sdks/python/apache_beam/ml/gcp/naturallanguageml.py index f46b8d61639b..e6ee917ce00c 100644 --- a/sdks/python/apache_beam/ml/gcp/naturallanguageml.py +++ b/sdks/python/apache_beam/ml/gcp/naturallanguageml.py @@ -52,6 +52,7 @@ class Document(object): from_gcs (bool): Whether the content should be interpret as a Google Cloud Storage URI. The default value is :data:`False`. """ + def __init__( self, content: str, @@ -108,6 +109,7 @@ def AnnotateText( @beam.typehints.with_input_types(Document) @beam.typehints.with_output_types(language_v1.AnnotateTextResponse) class _AnnotateTextFn(beam.DoFn): + def __init__( self, features: Union[Mapping[str, bool], diff --git a/sdks/python/apache_beam/ml/gcp/naturallanguageml_test.py b/sdks/python/apache_beam/ml/gcp/naturallanguageml_test.py index 891726cb2688..98f7a447c439 100644 --- a/sdks/python/apache_beam/ml/gcp/naturallanguageml_test.py +++ b/sdks/python/apache_beam/ml/gcp/naturallanguageml_test.py @@ -36,6 +36,7 @@ @unittest.skipIf(language is None, 'GCP dependencies are not installed') class NaturalLanguageMlTest(unittest.TestCase): + def assertCounterEqual(self, pipeline_result, counter_name, expected): metrics = pipeline_result.metrics().query( MetricsFilter().with_name(counter_name)) diff --git a/sdks/python/apache_beam/ml/gcp/naturallanguageml_test_it.py b/sdks/python/apache_beam/ml/gcp/naturallanguageml_test_it.py index 9adf56a90102..c8e09f6b1ad8 100644 --- a/sdks/python/apache_beam/ml/gcp/naturallanguageml_test_it.py +++ b/sdks/python/apache_beam/ml/gcp/naturallanguageml_test_it.py @@ -49,6 +49,7 @@ def extract(response): @pytest.mark.it_postcommit @unittest.skipIf(AnnotateText is None, 'GCP dependencies are not installed') class NaturalLanguageMlTestIT(unittest.TestCase): + def test_analyzing_syntax(self): with TestPipeline(is_integration_test=True) as p: output = ( diff --git a/sdks/python/apache_beam/ml/gcp/recommendations_ai.py b/sdks/python/apache_beam/ml/gcp/recommendations_ai.py index 696ea5e322ea..c4855de7cd64 100644 --- a/sdks/python/apache_beam/ml/gcp/recommendations_ai.py +++ b/sdks/python/apache_beam/ml/gcp/recommendations_ai.py @@ -87,6 +87,7 @@ class CreateCatalogItem(PTransform): project='example-gcp-project', catalog_name='my-catalog') """ + def __init__( self, project: str = None, @@ -131,6 +132,7 @@ def expand(self, pcoll): class _CreateCatalogItemFn(DoFn): + def __init__( self, project: str = None, @@ -181,6 +183,7 @@ class ImportCatalogItems(PTransform): project='example-gcp-project', catalog_name='my-catalog') """ + def __init__( self, max_batch_size: int = 5000, @@ -229,6 +232,7 @@ def expand(self, pcoll): class _ImportCatalogItemsFn(DoFn): + def __init__( self, project=None, @@ -282,6 +286,7 @@ class WriteUserEvent(PTransform): catalog_name='my-catalog', event_store='my_event_store') """ + def __init__( self, project: str = None, @@ -380,6 +385,7 @@ class ImportUserEvents(PTransform): catalog_name='my-catalog', event_store='my_event_store') """ + def __init__( self, max_batch_size: int = 5000, @@ -487,6 +493,7 @@ class PredictUserEvent(PTransform): event_store='my_event_store', placement_id='recently_viewed_default') """ + def __init__( self, project: str = None, diff --git a/sdks/python/apache_beam/ml/gcp/recommendations_ai_test.py b/sdks/python/apache_beam/ml/gcp/recommendations_ai_test.py index 2f688d97a309..27cb929de769 100644 --- a/sdks/python/apache_beam/ml/gcp/recommendations_ai_test.py +++ b/sdks/python/apache_beam/ml/gcp/recommendations_ai_test.py @@ -39,6 +39,7 @@ recommendationengine is None, "Recommendations AI dependencies not installed.") class RecommendationsAICatalogItemTest(unittest.TestCase): + def setUp(self): self._mock_client = mock.Mock() self._mock_client.create_catalog_item.return_value = ( @@ -106,6 +107,7 @@ def test_ImportCatalogItems(self): recommendationengine is None, "Recommendations AI dependencies not installed.") class RecommendationsAIUserEventTest(unittest.TestCase): + def setUp(self): self._mock_client = mock.Mock() self._mock_client.write_user_event.return_value = ( @@ -169,6 +171,7 @@ def test_ImportUserEvents(self): recommendationengine is None, "Recommendations AI dependencies not installed.") class RecommendationsAIPredictTest(unittest.TestCase): + def setUp(self): self._mock_client = mock.Mock() self._mock_client.predict.return_value = [ diff --git a/sdks/python/apache_beam/ml/gcp/videointelligenceml.py b/sdks/python/apache_beam/ml/gcp/videointelligenceml.py index ebd35d2426c0..13c653bac274 100644 --- a/sdks/python/apache_beam/ml/gcp/videointelligenceml.py +++ b/sdks/python/apache_beam/ml/gcp/videointelligenceml.py @@ -54,6 +54,7 @@ class AnnotateVideo(PTransform): bytes base64-encoded video data. Accepts an `AsDict` side input that maps each video to a video context. """ + def __init__( self, features, @@ -119,6 +120,7 @@ class _VideoAnnotateFn(DoFn): service and outputs an element with the return result of the API (``google.cloud.videointelligence_v1.AnnotateVideoResponse``). """ + def __init__(self, features, location_id, metadata, timeout): super().__init__() self._client = None @@ -171,6 +173,7 @@ class AnnotateVideoWithContext(AnnotateVideo): where the former is either an URI (e.g. a GCS URI) or bytes base64-encoded video data """ + def __init__(self, features, location_id=None, metadata=None, timeout=120): """ Args: @@ -209,6 +212,7 @@ class _VideoAnnotateFnWithContext(_VideoAnnotateFn): an element with the return result of the API (``google.cloud.videointelligence_v1.AnnotateVideoResponse``). """ + def __init__(self, features, location_id, metadata, timeout): super().__init__( features=features, diff --git a/sdks/python/apache_beam/ml/gcp/videointelligenceml_test.py b/sdks/python/apache_beam/ml/gcp/videointelligenceml_test.py index 79c841938cdb..46296f233408 100644 --- a/sdks/python/apache_beam/ml/gcp/videointelligenceml_test.py +++ b/sdks/python/apache_beam/ml/gcp/videointelligenceml_test.py @@ -42,6 +42,7 @@ VideoIntelligenceServiceClient is None, 'Video intelligence dependencies are not installed') class VideoIntelligenceTest(unittest.TestCase): + def setUp(self): self._mock_client = mock.Mock() self.m2 = mock.Mock() diff --git a/sdks/python/apache_beam/ml/gcp/visionml.py b/sdks/python/apache_beam/ml/gcp/visionml.py index dd29dd377388..907cfb1580f0 100644 --- a/sdks/python/apache_beam/ml/gcp/visionml.py +++ b/sdks/python/apache_beam/ml/gcp/visionml.py @@ -190,6 +190,7 @@ class AnnotateImageWithContext(AnnotateImage): where the former is either an URI (e.g. a GCS URI) or bytes base64-encoded image data. """ + def __init__( self, features, @@ -271,6 +272,7 @@ class _ImageAnnotateFn(DoFn): """A DoFn that sends each input element to the GCP Vision API. Returns ``google.cloud.vision.BatchAnnotateImagesResponse``. """ + def __init__(self, features, retry, timeout, client_options, metadata): super().__init__() self._client = None diff --git a/sdks/python/apache_beam/ml/gcp/visionml_test.py b/sdks/python/apache_beam/ml/gcp/visionml_test.py index 479b3d80e4de..d94f32f501d2 100644 --- a/sdks/python/apache_beam/ml/gcp/visionml_test.py +++ b/sdks/python/apache_beam/ml/gcp/visionml_test.py @@ -41,6 +41,7 @@ @unittest.skipIf( ImageAnnotatorClient is None, 'Vision dependencies are not installed') class VisionTest(unittest.TestCase): + def setUp(self): self._mock_client = mock.Mock() self._mock_client.batch_annotate_images.return_value = None diff --git a/sdks/python/apache_beam/ml/gcp/visionml_test_it.py b/sdks/python/apache_beam/ml/gcp/visionml_test_it.py index 00fd38704a02..7647012ef00b 100644 --- a/sdks/python/apache_beam/ml/gcp/visionml_test_it.py +++ b/sdks/python/apache_beam/ml/gcp/visionml_test_it.py @@ -43,6 +43,7 @@ def extract(response): @pytest.mark.it_postcommit @unittest.skipIf(vision is None, 'GCP dependencies are not installed') class VisionMlTestIT(unittest.TestCase): + def test_text_detection_with_language_hint(self): IMAGES_TO_ANNOTATE = [ 'gs://apache-beam-samples/advanced_analytics/vision/sign.jpg' diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 5cc38cd842c9..98331a833dc5 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -155,6 +155,7 @@ class KeyModelPathMapping(Generic[KeyT]): class ModelHandler(Generic[ExampleT, PredictionT, ModelT]): """Has the ability to load and apply an ML model.""" + def __init__(self): """Environment variables are set using a dict named 'env_vars' before loading the model. Child classes can accept this dict as a kwarg.""" @@ -345,6 +346,7 @@ class _ModelManager: single copy of each model into a multi_process_shared object and then return a lookup key for that object. """ + def __init__(self, mh_map: Dict[str, ModelHandler]): """ Args: @@ -453,6 +455,7 @@ class KeyModelMapping(Generic[KeyT, ExampleT, PredictionT, ModelT]): `KeyModelMapping(['key1', 'key2'], myMh)`, all examples with keys `key1` or `key2` will be run against the model defined by the `myMh` ModelHandler. """ + def __init__( self, keys: List[KeyT], mh: ModelHandler[ExampleT, PredictionT, ModelT]): self.keys = keys @@ -463,6 +466,7 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT], ModelHandler[Tuple[KeyT, ExampleT], Tuple[KeyT, PredictionT], Union[ModelT, _ModelManager]]): + def __init__( self, unkeyed: Union[ModelHandler[ExampleT, PredictionT, ModelT], @@ -849,6 +853,7 @@ class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT], Union[PredictionT, Tuple[KeyT, PredictionT]], ModelT]): + def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]): """A ModelHandler that takes examples that might have keys and returns predictions that might have keys. @@ -945,6 +950,7 @@ class _PrebatchedModelHandler(Generic[ExampleT, PredictionT, ModelT], ModelHandler[Sequence[ExampleT], PredictionT, ModelT]): + def __init__(self, base: ModelHandler[ExampleT, PredictionT, ModelT]): """A ModelHandler that skips batching in RunInference. @@ -1006,6 +1012,7 @@ class _PreProcessingModelHandler(Generic[ExampleT, PreProcessT], ModelHandler[PreProcessT, PredictionT, ModelT]): + def __init__( self, base: ModelHandler[ExampleT, PredictionT, ModelT], @@ -1071,6 +1078,7 @@ class _PostProcessingModelHandler(Generic[ExampleT, ModelT, PostProcessT], ModelHandler[ExampleT, PostProcessT, ModelT]): + def __init__( self, base: ModelHandler[ExampleT, PredictionT, ModelT], @@ -1134,6 +1142,7 @@ def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]: class RunInference(beam.PTransform[beam.PCollection[Union[ExampleT, Iterable[ExampleT]]], beam.PCollection[PredictionT]]): + def __init__( self, model_handler: ModelHandler[ExampleT, PredictionT, Any], @@ -1430,6 +1439,7 @@ class _MetricsCollector: """ A metrics collector that tracks ML related performance and memory usage. """ + def __init__(self, namespace: str, prefix: str = ''): """ Args: @@ -1496,6 +1506,7 @@ class _ModelRoutingStrategy(): different models. Currently only supports round-robin, but can be extended to support other protocols if needed. """ + def __init__(self): self._cur_index = 0 @@ -1510,6 +1521,7 @@ class _ModelStatus(): Currently, this only includes whether or not the model is valid. Uses the model tag to map models to metadata. """ + def __init__(self, share_model_across_processes: bool): self._active_tags = set() self._invalid_tags = set() @@ -1615,6 +1627,7 @@ class _SharedModelWrapper(): This allows us to round robin calls to models sitting in different processes so that we can more efficiently use resources (e.g. GPUs). """ + def __init__(self, models: List[Any], model_tag: str): self.models = models if len(models) > 1: @@ -1636,6 +1649,7 @@ def all_models(self): class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]): + def __init__( self, model_handler: ModelHandler[ExampleT, PredictionT, Any], @@ -1672,6 +1686,7 @@ def _load_model( side_input_model_path: Optional[Union[str, List[KeyModelPathMapping]]] = None ) -> _SharedModelWrapper: + def load(): """Function for constructing shared LoadedModel.""" memory_before = _get_current_process_memory_in_bytes() diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index 31f02c9c61c5..fc0d1240a0c9 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -48,11 +48,13 @@ class FakeModel: + def predict(self, example: int) -> int: return example + 1 class FakeStatefulModel: + def __init__(self, state: int): if state == 100: raise Exception('Oh no') @@ -66,6 +68,7 @@ def increment_state(self, amount: int): class FakeSlowModel: + def __init__(self, sleep_on_load_seconds=0, file_path_write_on_del=None): self._file_path_write_on_del = file_path_write_on_del @@ -82,6 +85,7 @@ def __del__(self): class FakeIncrementingModel: + def __init__(self): self._state = 0 @@ -91,6 +95,7 @@ def predict(self, example: int) -> int: class FakeSlowModelHandler(base.ModelHandler[int, int, FakeModel]): + def __init__( self, sleep_on_load: int, @@ -119,6 +124,7 @@ def batch_elements_kwargs(self): class FakeModelHandler(base.ModelHandler[int, int, FakeModel]): + def __init__( self, clock=None, @@ -189,6 +195,7 @@ def get_num_bytes(self, batch: Sequence[int]) -> int: class FakeModelHandlerReturnsPredictionResult( base.ModelHandler[int, base.PredictionResult, FakeModel]): + def __init__( self, clock=None, @@ -233,6 +240,7 @@ def share_model_across_processes(self): class FakeModelHandlerNoEnvVars(base.ModelHandler[int, int, FakeModel]): + def __init__( self, clock=None, min_batch_size=1, max_batch_size=9999, **kwargs): self._fake_clock = clock @@ -265,6 +273,7 @@ def batch_elements_kwargs(self): class FakeClock: + def __init__(self): # Start at 10 seconds. self.current_time_ns = 10_000_000_000 @@ -274,11 +283,13 @@ def time_ns(self) -> int: class ExtractInferences(beam.DoFn): + def process(self, prediction_result): yield prediction_result.inference class FakeModelHandlerNeedsBigBatch(FakeModelHandler): + def run_inference(self, batch, unused_model, inference_args=None): if len(batch) < 100: raise ValueError('Unexpectedly small batch') @@ -289,6 +300,7 @@ def batch_elements_kwargs(self): class FakeModelHandlerFailsOnInferenceArgs(FakeModelHandler): + def run_inference(self, batch, unused_model, inference_args=None): raise ValueError( 'run_inference should not be called because error should already be ' @@ -296,6 +308,7 @@ def run_inference(self, batch, unused_model, inference_args=None): class FakeModelHandlerExpectedInferenceArgs(FakeModelHandler): + def run_inference(self, batch, unused_model, inference_args=None): if not inference_args: raise ValueError('inference_args should exist') @@ -306,6 +319,7 @@ def validate_inference_args(self, inference_args): class RunInferenceBaseTest(unittest.TestCase): + def test_run_inference_impl_simple_examples(self): with TestPipeline() as pipeline: examples = [1, 5, 3, 10] @@ -498,6 +512,7 @@ def test_run_inference_impl_with_keyed_examples_many_mhs_max_models_hint( self.assertLess(load_latency_dist_aggregate.committed.count, 12) def test_keyed_many_model_handlers_validation(self): + def mult_two(example: str) -> int: return int(example) * 2 @@ -612,6 +627,7 @@ def test_run_inference_impl_with_maybe_keyed_examples_multi_process_shared( assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed') def test_run_inference_preprocessing(self): + def mult_two(example: str) -> int: return int(example) * 2 @@ -632,6 +648,7 @@ def test_run_inference_prebatched(self): assert_that(actual, equal_to(expected), label='assert:inferences') def test_run_inference_preprocessing_multiple_fns(self): + def add_one(example: str) -> int: return int(example) + 1 @@ -648,6 +665,7 @@ def mult_two(example: int) -> int: assert_that(actual, equal_to(expected), label='assert:inferences') def test_run_inference_postprocessing(self): + def mult_two(example: int) -> str: return str(example * 2) @@ -660,6 +678,7 @@ def mult_two(example: int) -> str: assert_that(actual, equal_to(expected), label='assert:inferences') def test_run_inference_postprocessing_multiple_fns(self): + def add_one(example: int) -> str: return str(int(example) + 1) @@ -676,6 +695,7 @@ def mult_two(example: int) -> int: assert_that(actual, equal_to(expected), label='assert:inferences') def test_run_inference_preprocessing_dlq(self): + def mult_two(example: str) -> int: if example == "5": raise Exception("TEST") @@ -700,6 +720,7 @@ def mult_two(example: str) -> int: bad_without_error, equal_to(expected_bad), label='assert:failures') def test_run_inference_postprocessing_dlq(self): + def mult_two(example: int) -> str: if example == 6: raise Exception("TEST") @@ -724,6 +745,7 @@ def mult_two(example: int) -> str: bad_without_error, equal_to(expected_bad), label='assert:failures') def test_run_inference_pre_and_post_processing_dlq(self): + def mult_two_pre(example: str) -> int: if example == "5": raise Exception("TEST") @@ -767,15 +789,15 @@ def mult_two_post(example: int) -> str: label='assert:failures_post') def test_run_inference_keyed_pre_and_post_processing(self): + def mult_two(element): return (element[0], element[1] * 2) with TestPipeline() as pipeline: examples = [1, 5, 3, 10] keyed_examples = [(i, example) for i, example in enumerate(examples)] - expected = [ - (i, ((example * 2) + 1) * 2) for i, example in enumerate(examples) - ] + expected = [(i, ((example * 2) + 1) * 2) + for i, example in enumerate(examples)] pcoll = pipeline | 'start' >> beam.Create(keyed_examples) actual = pcoll | base.RunInference( base.KeyedModelHandler(FakeModelHandler()).with_preprocess_fn( @@ -783,6 +805,7 @@ def mult_two(element): assert_that(actual, equal_to(expected), label='assert:inferences') def test_run_inference_maybe_keyed_pre_and_post_processing(self): + def mult_two(element): return element * 2 @@ -793,9 +816,8 @@ def mult_two_keyed(element): examples = [1, 5, 3, 10] keyed_examples = [(i, example) for i, example in enumerate(examples)] expected = [((2 * example) + 1) * 2 for example in examples] - keyed_expected = [ - (i, ((2 * example) + 1) * 2) for i, example in enumerate(examples) - ] + keyed_expected = [(i, ((2 * example) + 1) * 2) + for i, example in enumerate(examples)] model_handler = base.MaybeKeyedModelHandler(FakeModelHandler()) pcoll = pipeline | 'Unkeyed' >> beam.Create(examples) @@ -1071,6 +1093,7 @@ def test_model_handler_compatibility(self): # If this test fails, likely third party implementations of # ModelHandler will break. class ThirdPartyHandler(base.ModelHandler[int, int, FakeModel]): + def __init__(self, custom_parameter=None): pass @@ -1225,6 +1248,7 @@ def test_run_inference_side_input_in_batch(self): # applying GroupByKey to utilize windowing according to # https://beam.apache.org/documentation/programming-guide/#windowing-bounded-collections class _EmitElement(beam.DoFn): + def process(self, element): for e in element: yield e @@ -1323,6 +1347,7 @@ def test_run_inference_side_input_in_batch_per_key_models(self): ]) class _EmitElement(beam.DoFn): + def process(self, element): for e in element: yield e @@ -1424,6 +1449,7 @@ def test_run_inference_side_input_in_batch_per_key_models_split_cohort(self): ]) class _EmitElement(beam.DoFn): + def process(self, element): for e in element: yield e @@ -1499,6 +1525,7 @@ def test_run_inference_side_input_in_batch_multi_process_shared(self): # applying GroupByKey to utilize windowing according to # https://beam.apache.org/documentation/programming-guide/#windowing-bounded-collections class _EmitElement(beam.DoFn): + def process(self, element): for e in element: yield e diff --git a/sdks/python/apache_beam/ml/inference/huggingface_inference.py b/sdks/python/apache_beam/ml/inference/huggingface_inference.py index 2934a5362910..e4e61322be1e 100644 --- a/sdks/python/apache_beam/ml/inference/huggingface_inference.py +++ b/sdks/python/apache_beam/ml/inference/huggingface_inference.py @@ -48,15 +48,16 @@ "HuggingFacePipelineModelHandler", ] -TensorInferenceFn = Callable[[ - Sequence[Union[torch.Tensor, tf.Tensor]], - Union[AutoModel, TFAutoModel], - str, - Optional[Dict[str, Any]], - Optional[str], -], - Iterable[PredictionResult], - ] +TensorInferenceFn = Callable[ + [ + Sequence[Union[torch.Tensor, tf.Tensor]], + Union[AutoModel, TFAutoModel], + str, + Optional[Dict[str, Any]], + Optional[str], + ], + Iterable[PredictionResult], +] KeyedTensorInferenceFn = Callable[[ Sequence[Dict[str, Union[torch.Tensor, tf.Tensor]]], @@ -212,6 +213,7 @@ class HuggingFaceModelHandlerKeyedTensor(ModelHandler[Dict[str, PredictionResult, Union[AutoModel, TFAutoModel]]): + def __init__( self, model_uri: str, @@ -401,6 +403,7 @@ class HuggingFaceModelHandlerTensor(ModelHandler[Union[tf.Tensor, torch.Tensor], PredictionResult, Union[AutoModel, TFAutoModel]]): + def __init__( self, model_uri: str, @@ -582,6 +585,7 @@ def _default_pipeline_inference_fn( class HuggingFacePipelineModelHandler(ModelHandler[str, PredictionResult, Pipeline]): + def __init__( self, task: Union[str, PipelineTask] = "", diff --git a/sdks/python/apache_beam/ml/inference/huggingface_inference_it_test.py b/sdks/python/apache_beam/ml/inference/huggingface_inference_it_test.py index dd675d1935a7..b056bc0257cb 100644 --- a/sdks/python/apache_beam/ml/inference/huggingface_inference_it_test.py +++ b/sdks/python/apache_beam/ml/inference/huggingface_inference_it_test.py @@ -41,6 +41,7 @@ @pytest.mark.it_postcommit @pytest.mark.timeout(1800) class HuggingFaceInference(unittest.TestCase): + def test_hf_language_modeling(self): test_pipeline = TestPipeline(is_integration_test=True) # Path to text file containing some sentences diff --git a/sdks/python/apache_beam/ml/inference/huggingface_inference_test.py b/sdks/python/apache_beam/ml/inference/huggingface_inference_test.py index 74c7255afb9c..c3b22d77b01d 100644 --- a/sdks/python/apache_beam/ml/inference/huggingface_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/huggingface_inference_test.py @@ -56,12 +56,14 @@ def fake_inference_fn_tensor( class FakeTorchModel: + def predict(self, input: torch.Tensor): return input @pytest.mark.uses_transformers class HuggingFaceInferenceTest(unittest.TestCase): + def setUp(self) -> None: self.tmpdir = tempfile.mkdtemp() @@ -76,8 +78,7 @@ def test_predict_tensor(self): inference_fn=fake_inference_fn_tensor) batched_examples = [tf.constant([1]), tf.constant([10]), tf.constant([100])] expected_predictions = [ - PredictionResult(ex, pred) for ex, - pred in zip( + PredictionResult(ex, pred) for ex, pred in zip( batched_examples, [tf.math.multiply(n, 10) for n in batched_examples]) ] @@ -95,8 +96,7 @@ def test_predict_tensor_with_inference_args(self): inference_args={"add": True}) batched_examples = [tf.constant([1]), tf.constant([10]), tf.constant([100])] expected_predictions = [ - PredictionResult(ex, pred) for ex, - pred in zip( + PredictionResult(ex, pred) for ex, pred in zip( batched_examples, [ tf.math.add(tf.math.multiply(n, 10), 10) for n in batched_examples diff --git a/sdks/python/apache_beam/ml/inference/onnx_inference_it_test.py b/sdks/python/apache_beam/ml/inference/onnx_inference_it_test.py index 3902a61dc260..2a03c666cc89 100644 --- a/sdks/python/apache_beam/ml/inference/onnx_inference_it_test.py +++ b/sdks/python/apache_beam/ml/inference/onnx_inference_it_test.py @@ -47,6 +47,7 @@ def process_outputs(filepath): 'Missing dependencies. ' 'Test depends on onnx and transformers') class OnnxInference(unittest.TestCase): + @pytest.mark.uses_onnx @pytest.mark.it_postcommit def test_onnx_run_inference_roberta_sentiment_classification(self): diff --git a/sdks/python/apache_beam/ml/inference/onnx_inference_test.py b/sdks/python/apache_beam/ml/inference/onnx_inference_test.py index ab87c4cceef2..349274d78fcd 100644 --- a/sdks/python/apache_beam/ml/inference/onnx_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/onnx_inference_test.py @@ -61,6 +61,7 @@ class PytorchLinearRegression(torch.nn.Module): + def __init__(self, input_dim, output_dim): super().__init__() self.linear = torch.nn.Linear(input_dim, output_dim) @@ -75,6 +76,7 @@ def generate(self, x): class TestDataAndModel(): + def get_one_feature_samples(self): return [ np.array([1], dtype="float32"), @@ -85,8 +87,7 @@ def get_one_feature_samples(self): def get_one_feature_predictions(self): return [ - PredictionResult(ex, pred) for ex, - pred in zip( + PredictionResult(ex, pred) for ex, pred in zip( self.get_one_feature_samples(), [example * 2.0 + 0.5 for example in self.get_one_feature_samples()]) ] @@ -101,12 +102,10 @@ def get_two_feature_examples(self): def get_two_feature_predictions(self): return [ - PredictionResult(ex, pred) for ex, - pred in zip( - self.get_two_feature_examples(), - [ - f1 * 2.0 + f2 * 3 + 0.5 for f1, - f2 in self.get_two_feature_examples() + PredictionResult(ex, pred) for ex, pred in zip( + self.get_two_feature_examples(), [ + f1 * 2.0 + f2 * 3 + 0.5 + for f1, f2 in self.get_two_feature_examples() ]) ] @@ -188,6 +187,7 @@ def __init__( #pylint: disable=dangerous-default-value class OnnxTestBase(unittest.TestCase): + def setUp(self): self.tmpdir = tempfile.mkdtemp() self.test_data_and_model = TestDataAndModel() @@ -198,6 +198,7 @@ def tearDown(self): @pytest.mark.uses_onnx class OnnxPytorchRunInferenceTest(OnnxTestBase): + def test_onnx_pytorch_run_inference(self): examples = self.test_data_and_model.get_one_feature_samples() expected_predictions = self.test_data_and_model.get_one_feature_predictions( @@ -206,17 +207,23 @@ def test_onnx_pytorch_run_inference(self): model = self.test_data_and_model.get_torch_one_feature_model() path = os.path.join(self.tmpdir, 'my_onnx_pytorch_path') dummy_input = torch.randn(4, 1, requires_grad=True) - torch.onnx.export(model, - dummy_input, # model input - path, # where to save the model - export_params=True, # store the trained parameter weights - opset_version=10, # the ONNX version - do_constant_folding=True, # whether to execute constant- - # folding for optimization - input_names = ['input'], # model's input names - output_names = ['output'], # model's output names - dynamic_axes={'input' : {0 : 'batch_size'}, - 'output' : {0 : 'batch_size'}}) + torch.onnx.export( + model, + dummy_input, # model input + path, # where to save the model + export_params=True, # store the trained parameter weights + opset_version=10, # the ONNX version + do_constant_folding=True, # whether to execute constant- + # folding for optimization + input_names=['input'], # model's input names + output_names=['output'], # model's output names + dynamic_axes={ + 'input': { + 0: 'batch_size' + }, 'output': { + 0: 'batch_size' + } + }) inference_runner = TestOnnxModelHandler(path) inference_session = ort.InferenceSession( @@ -252,6 +259,7 @@ def test_namespace(self): @pytest.mark.uses_onnx class OnnxTensorflowRunInferenceTest(OnnxTestBase): + def test_onnx_tensorflow_run_inference(self): examples = self.test_data_and_model.get_one_feature_samples() expected_predictions = self.test_data_and_model.get_one_feature_predictions( @@ -276,6 +284,7 @@ def test_onnx_tensorflow_run_inference(self): @pytest.mark.uses_onnx class OnnxSklearnRunInferenceTest(OnnxTestBase): + def save_model(self, model, input_dim, path): # assume float input initial_type = [('float_input', FloatTensorType([None, input_dim]))] @@ -303,19 +312,26 @@ def test_onnx_sklearn_run_inference(self): @pytest.mark.uses_onnx class OnnxPytorchRunInferencePipelineTest(OnnxTestBase): + def exportModelToOnnx(self, model, path): dummy_input = torch.randn(4, 2, requires_grad=True) - torch.onnx.export(model, - dummy_input, # model input - path, # where to save the model - export_params=True, # store the trained parameter weights - opset_version=10, # the ONNX version - do_constant_folding=True, # whether to execute constant - # folding for optimization - input_names = ['input'], # odel's input names - output_names = ['output'], # model's output names - dynamic_axes={'input' : {0 : 'batch_size'}, - 'output' : {0 : 'batch_size'}}) + torch.onnx.export( + model, + dummy_input, # model input + path, # where to save the model + export_params=True, # store the trained parameter weights + opset_version=10, # the ONNX version + do_constant_folding=True, # whether to execute constant + # folding for optimization + input_names=['input'], # odel's input names + output_names=['output'], # model's output names + dynamic_axes={ + 'input': { + 0: 'batch_size' + }, 'output': { + 0: 'batch_size' + } + }) def test_pipeline_local_model_simple(self): with TestPipeline() as pipeline: @@ -414,6 +430,7 @@ def test_invalid_input_type(self): @pytest.mark.uses_onnx class OnnxTensorflowRunInferencePipelineTest(OnnxTestBase): + def exportModelToOnnx(self, model, path): spec = (tf.TensorSpec((None, 2), tf.float32, name="input"), ) _, _ = tf2onnx.convert.from_keras(model, @@ -469,6 +486,7 @@ def test_invalid_input_type(self): @pytest.mark.uses_onnx class OnnxSklearnRunInferencePipelineTest(OnnxTestBase): + def save_model(self, model, input_dim, path): # assume float input initial_type = [('float_input', FloatTensorType([None, input_dim]))] diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py index 9a89cba7243a..7c8c5c376cb0 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py @@ -162,6 +162,7 @@ def make_tensor_model_fn(model_fn: str) -> TensorInferenceFn: model_fn: A string name of the method to be used. This is accessed through getattr(model, model_fn) """ + def attr_fn( batch: Sequence[torch.Tensor], model: torch.nn.Module, @@ -182,6 +183,7 @@ def attr_fn( class PytorchModelHandlerTensor(ModelHandler[torch.Tensor, PredictionResult, torch.nn.Module]): + def __init__( self, state_dict_path: Optional[str] = None, @@ -391,6 +393,7 @@ def make_keyed_tensor_model_fn(model_fn: str) -> KeyedTensorInferenceFn: model_fn: A string name of the method to be used. This is accessed through getattr(model, model_fn) """ + def attr_fn( batch: Sequence[Dict[str, torch.Tensor]], model: torch.nn.Module, @@ -423,6 +426,7 @@ def attr_fn( class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor], PredictionResult, torch.nn.Module]): + def __init__( self, state_dict_path: Optional[str] = None, diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py index 2cc49be54599..51c23ef1389a 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py @@ -66,6 +66,7 @@ def process_outputs(filepath): 'Missing dependencies. ' 'Test depends on torch, torchvision, pillow, and transformers') class PyTorchInference(unittest.TestCase): + @pytest.mark.uses_pytorch @pytest.mark.it_postcommit def test_torch_run_inference_imagenet_mobilenetv2(self): @@ -139,8 +140,8 @@ def test_torch_run_inference_coco_maskrcnn_resnet50_fpn_v1_and_v2(self): output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt']) model_state_dict_paths = [ - 'gs://apache-beam-ml/models/torchvision.models.detection.maskrcnn_resnet50_fpn.pth', # pylint: disable=line-too-long - 'gs://apache-beam-ml/models/torchvision.models.detection.maskrcnn_resnet50_fpn_v2.pth' # pylint: disable=line-too-long + 'gs://apache-beam-ml/models/torchvision.models.detection.maskrcnn_resnet50_fpn.pth', # pylint: disable=line-too-long + 'gs://apache-beam-ml/models/torchvision.models.detection.maskrcnn_resnet50_fpn_v2.pth' # pylint: disable=line-too-long ] images_dir = 'gs://apache-beam-ml/datasets/coco/raw-data/val2017' extra_opts = { diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py index dd5793af2dd1..3ad5d2e3d20d 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py @@ -60,10 +60,8 @@ ] TWO_FEATURES_PREDICTIONS = [ - PredictionResult(ex, pred) for ex, - pred in zip( - TWO_FEATURES_EXAMPLES, - torch.Tensor( + PredictionResult(ex, pred) for ex, pred in zip( + TWO_FEATURES_EXAMPLES, torch.Tensor( [f1 * 2.0 + f2 * 3 + 0.5 for f1, f2 in TWO_FEATURES_EXAMPLES]).reshape(-1, 1)) ] @@ -95,20 +93,17 @@ ] KEYED_TORCH_PREDICTIONS = [ - PredictionResult(ex, pred) for ex, - pred in zip( - KEYED_TORCH_EXAMPLES, - torch.Tensor([(example['k1'] * 2.0 + 0.5) + (example['k2'] * 2.0 + 0.5) - for example in KEYED_TORCH_EXAMPLES]).reshape(-1, 1)) + PredictionResult(ex, pred) for ex, pred in zip( + KEYED_TORCH_EXAMPLES, torch.Tensor( + [(example['k1'] * 2.0 + 0.5) + (example['k2'] * 2.0 + 0.5) + for example in KEYED_TORCH_EXAMPLES]).reshape(-1, 1)) ] KEYED_TORCH_HELPER_PREDICTIONS = [ - PredictionResult(ex, pred) for ex, - pred in zip( - KEYED_TORCH_EXAMPLES, - torch.Tensor([(example['k1'] * 2.0 + 0.5) + - (example['k2'] * 2.0 + 0.5) + 0.5 - for example in KEYED_TORCH_EXAMPLES]).reshape(-1, 1)) + PredictionResult(ex, pred) for ex, pred in zip( + KEYED_TORCH_EXAMPLES, torch.Tensor( + [(example['k1'] * 2.0 + 0.5) + (example['k2'] * 2.0 + 0.5) + 0.5 + for example in KEYED_TORCH_EXAMPLES]).reshape(-1, 1)) ] KEYED_TORCH_DICT_OUT_PREDICTIONS = [ @@ -120,6 +115,7 @@ class TestPytorchModelHandlerForInferenceOnly(PytorchModelHandlerTensor): + def __init__(self, device, *, inference_fn=default_tensor_inference_fn): self._device = device self._inference_fn = inference_fn @@ -129,6 +125,7 @@ def __init__(self, device, *, inference_fn=default_tensor_inference_fn): class TestPytorchModelHandlerKeyedTensorForInferenceOnly( PytorchModelHandlerKeyedTensor): + def __init__(self, device, *, inference_fn=default_keyed_tensor_inference_fn): self._device = device self._inference_fn = inference_fn @@ -139,8 +136,8 @@ def __init__(self, device, *, inference_fn=default_keyed_tensor_inference_fn): def _compare_prediction_result(x, y): if isinstance(x.example, dict): example_equals = all( - torch.equal(x, y) for x, - y in zip(x.example.values(), y.example.values())) + torch.equal(x, y) + for x, y in zip(x.example.values(), y.example.values())) else: example_equals = torch.equal(x.example, y.example) if not example_equals: @@ -148,8 +145,8 @@ def _compare_prediction_result(x, y): if isinstance(x.inference, dict): return all( - torch.equal(x, y) for x, - y in zip(x.inference.values(), y.inference.values())) + torch.equal(x, y) + for x, y in zip(x.inference.values(), y.inference.values())) return torch.equal(x.inference, y.inference) @@ -157,15 +154,15 @@ def _compare_prediction_result(x, y): def custom_tensor_inference_fn( batch, model, device, inference_args, model_id=None): predictions = [ - PredictionResult(ex, pred) for ex, - pred in zip( - batch, - torch.Tensor([item * 2.0 + 1.5 for item in batch]).reshape(-1, 1)) + PredictionResult(ex, pred) for ex, pred in zip( + batch, torch.Tensor([item * 2.0 + 1.5 + for item in batch]).reshape(-1, 1)) ] return predictions class PytorchLinearRegression(torch.nn.Module): + def __init__(self, input_dim, output_dim): super().__init__() self.linear = torch.nn.Linear(input_dim, output_dim) @@ -180,6 +177,7 @@ def generate(self, x): class PytorchLinearRegressionDict(torch.nn.Module): + def __init__(self, input_dim, output_dim): super().__init__() self.linear = torch.nn.Linear(input_dim, output_dim) @@ -198,6 +196,7 @@ class PytorchLinearRegressionKeyedBatchAndExtraInferenceArgs(torch.nn.Module): (typically model-related info) used to configure the model before its predict call is invoked """ + def __init__(self, input_dim, output_dim): super().__init__() self.linear = torch.nn.Linear(input_dim, output_dim) @@ -213,6 +212,7 @@ def forward(self, k1, k2, prediction_param_array, prediction_param_bool): @pytest.mark.uses_pytorch class PytorchRunInferenceTest(unittest.TestCase): + def test_run_inference_single_tensor_feature(self): examples = [ torch.from_numpy(np.array([1], dtype="float32")), @@ -221,11 +221,9 @@ def test_run_inference_single_tensor_feature(self): torch.from_numpy(np.array([10.0], dtype="float32")), ] expected_predictions = [ - PredictionResult(ex, pred) for ex, - pred in zip( - examples, - torch.Tensor([example * 2.0 + 0.5 - for example in examples]).reshape(-1, 1)) + PredictionResult(ex, pred) for ex, pred in zip( + examples, torch.Tensor( + [example * 2.0 + 0.5 for example in examples]).reshape(-1, 1)) ] model = PytorchLinearRegression(input_dim=1, output_dim=1) @@ -274,11 +272,9 @@ def test_run_inference_custom(self): torch.from_numpy(np.array([10.0], dtype="float32")), ] expected_predictions = [ - PredictionResult(ex, pred) for ex, - pred in zip( - examples, - torch.Tensor([example * 2.0 + 1.5 - for example in examples]).reshape(-1, 1)) + PredictionResult(ex, pred) for ex, pred in zip( + examples, torch.Tensor( + [example * 2.0 + 1.5 for example in examples]).reshape(-1, 1)) ] model = PytorchLinearRegression(input_dim=1, output_dim=1) @@ -308,7 +304,9 @@ def test_run_inference_keyed(self): 'k2' : torch.tensor([4, 5, 6]) } """ + class PytorchLinearRegressionMultipleArgs(torch.nn.Module): + def __init__(self, input_dim, output_dim): super().__init__() self.linear = torch.nn.Linear(input_dim, output_dim) @@ -330,7 +328,9 @@ def forward(self, k1, k2): self.assertTrue(_compare_prediction_result(actual, expected)) def test_run_inference_keyed_dict_output(self): + class PytorchLinearRegressionMultipleArgsDict(torch.nn.Module): + def __init__(self, input_dim, output_dim): super().__init__() self.linear = torch.nn.Linear(input_dim, output_dim) @@ -385,11 +385,9 @@ def test_run_inference_helper(self): torch.from_numpy(np.array([10.0], dtype="float32")), ] expected_predictions = [ - PredictionResult(ex, pred) for ex, - pred in zip( - examples, - torch.Tensor([example * 2.0 + 1.0 - for example in examples]).reshape(-1, 1)) + PredictionResult(ex, pred) for ex, pred in zip( + examples, torch.Tensor( + [example * 2.0 + 1.0 for example in examples]).reshape(-1, 1)) ] gen_fn = make_tensor_model_fn('generate') @@ -421,7 +419,9 @@ def test_run_inference_keyed_helper(self): 'k2' : torch.tensor([4, 5, 6]) } """ + class PytorchLinearRegressionMultipleArgs(torch.nn.Module): + def __init__(self, input_dim, output_dim): super().__init__() self.linear = torch.nn.Linear(input_dim, output_dim) @@ -465,6 +465,7 @@ def test_namespace(self): @pytest.mark.uses_pytorch class PytorchRunInferencePipelineTest(unittest.TestCase): + def setUp(self): self.tmpdir = tempfile.mkdtemp() @@ -663,11 +664,9 @@ def test_pipeline_gcs_model(self): examples = torch.from_numpy( np.array([1, 5, 3, 10], dtype="float32").reshape(-1, 1)) expected_predictions = [ - PredictionResult(ex, pred) for ex, - pred in zip( - examples, - torch.Tensor([example * 2.0 + 0.5 - for example in examples]).reshape(-1, 1)) + PredictionResult(ex, pred) for ex, pred in zip( + examples, torch.Tensor( + [example * 2.0 + 0.5 for example in examples]).reshape(-1, 1)) ] gs_pth = 'gs://apache-beam-ml/models/' \ @@ -691,11 +690,9 @@ def test_pipeline_gcs_model_control_batching(self): examples = torch.from_numpy( np.array([1, 5, 3, 10], dtype="float32").reshape(-1, 1)) expected_predictions = [ - PredictionResult(ex, pred) for ex, - pred in zip( - examples, - torch.Tensor([example * 2.0 + 0.5 - for example in examples]).reshape(-1, 1)) + PredictionResult(ex, pred) for ex, pred in zip( + examples, torch.Tensor( + [example * 2.0 + 0.5 for example in examples]).reshape(-1, 1)) ] def batch_validator_tensor_inference_fn( @@ -983,6 +980,7 @@ def test_env_vars_set_correctly_keyed_tensor_handler(self): @pytest.mark.uses_pytorch class PytorchInferenceTestWithMocks(unittest.TestCase): + def setUp(self): self._load_model = pytorch_inference._load_model pytorch_inference._load_model = unittest.mock.MagicMock( diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py index a29657968eaa..afef63e76923 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py @@ -82,6 +82,7 @@ def _default_numpy_inference_fn( class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray, PredictionResult, BaseEstimator]): + def __init__( self, model_uri: str, @@ -217,6 +218,7 @@ def _default_pandas_inference_fn( class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame, PredictionResult, BaseEstimator]): + def __init__( self, model_uri: str, diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference_it_test.py b/sdks/python/apache_beam/ml/inference/sklearn_inference_it_test.py index c5480234cda6..bdb0ebcc6e05 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference_it_test.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference_it_test.py @@ -54,6 +54,7 @@ def file_lines_sorted(filepath): @pytest.mark.uses_sklearn @pytest.mark.it_postcommit class SklearnInference(unittest.TestCase): + def test_sklearn_mnist_classification(self): test_pipeline = TestPipeline(is_integration_test=True) input_file = 'gs://apache-beam-ml/testing/inputs/it_mnist_data.csv' diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py index c2ea9fa1e955..15be48605ed1 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py @@ -76,13 +76,14 @@ def _compare_dataframe_predictions(a_in, b_in): example_equal = pandas.DataFrame.equals(a.example, b.example) if isinstance(a.inference, dict): return all( - math.floor(a) == math.floor(b) for a, - b in zip(a.inference.values(), b.inference.values())) and example_equal + math.floor(a) == math.floor(b) for a, b in zip( + a.inference.values(), b.inference.values())) and example_equal inference_equal = math.floor(a.inference) == math.floor(b.inference) return inference_equal and example_equal and keys_equal class FakeModel: + def __init__(self): self.total_predict_calls = 0 @@ -92,6 +93,7 @@ def predict(self, input_vector: numpy.ndarray): class FakeNumpyModelDictOut: + def __init__(self): self.total_predict_calls = 0 @@ -102,6 +104,7 @@ def predict(self, input_vector: numpy.ndarray): class FakePandasModelDictOut: + def __init__(self): self.total_predict_calls = 0 @@ -179,6 +182,7 @@ def alternate_pandas_inference_fn( class SkLearnRunInferenceTest(unittest.TestCase): + def setUp(self): self.tmpdir = tempfile.mkdtemp() diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py index 78b59975e63c..e58386af9251 100644 --- a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py +++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py @@ -99,6 +99,7 @@ def default_tensor_inference_fn( class TFModelHandlerNumpy(ModelHandler[numpy.ndarray, PredictionResult, tf.Module]): + def __init__( self, model_uri: str, @@ -235,6 +236,7 @@ def model_copies(self) -> int: class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult, tf.Module]): + def __init__( self, model_uri: str, diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py index 4786b7a03980..89df80b6708a 100644 --- a/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py +++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py @@ -72,6 +72,7 @@ def clear_tf_hub_temp_dir(model_path): @pytest.mark.uses_tf @pytest.mark.it_postcommit class TensorflowInference(unittest.TestCase): + def test_tf_mnist_classification(self): test_pipeline = TestPipeline(is_integration_test=True) input_file = 'gs://apache-beam-ml/testing/inputs/it_mnist_data.csv' diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py index dc35aa016013..be45ecdecba2 100644 --- a/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py @@ -54,11 +54,13 @@ class FakeTFNumpyModel: + def predict(self, input: numpy.ndarray): return numpy.multiply(input, 10) class FakeTFTensorModel: + def predict(self, input: tf.Tensor, add=False): if add: return tf.math.add(tf.math.multiply(input, 10), 10) @@ -86,6 +88,7 @@ def fake_inference_fn( @pytest.mark.uses_tf class TFRunInferenceTest(unittest.TestCase): + def setUp(self): self.tmpdir = tempfile.mkdtemp() @@ -116,8 +119,7 @@ def test_predict_tensor(self): tf.convert_to_tensor(numpy.array([100])), ] expected_predictions = [ - PredictionResult(ex, pred) for ex, - pred in zip( + PredictionResult(ex, pred) for ex, pred in zip( batched_examples, [tf.math.multiply(n, 10) for n in batched_examples]) ] @@ -160,8 +162,8 @@ def fake_batching_inference_fn( numpy.array([200.1, 300.2, 400.3], dtype='float32')), ] expected_predictions = [ - PredictionResult(ex, pred) for ex, - pred in zip(examples, [tf.math.multiply(n, 2) for n in examples]) + PredictionResult(ex, pred) for ex, pred in zip( + examples, [tf.math.multiply(n, 2) for n in examples]) ] pcoll = pipeline | 'start' >> beam.Create(examples) @@ -207,8 +209,8 @@ def fake_batching_inference_fn( numpy.array([200.1, 300.2, 400.3], dtype='float32')), ] expected_predictions = [ - PredictionResult(ex, pred) for ex, - pred in zip(examples, [tf.math.multiply(n, 2) for n in examples]) + PredictionResult(ex, pred) for ex, pred in zip( + examples, [tf.math.multiply(n, 2) for n in examples]) ] pcoll = pipeline | 'start' >> beam.Create(examples) @@ -250,8 +252,8 @@ def fake_batching_inference_fn( numpy.array([200.1, 300.2, 400.3], dtype='float32'), ] expected_predictions = [ - PredictionResult(ex, pred) for ex, - pred in zip(examples, [numpy.multiply(n, 2) for n in examples]) + PredictionResult(ex, pred) for ex, pred in zip( + examples, [numpy.multiply(n, 2) for n in examples]) ] pcoll = pipeline | 'start' >> beam.Create(examples) @@ -294,8 +296,8 @@ def fake_inference_fn( numpy.array([200.1, 300.2, 400.3], dtype='float32'), ] expected_predictions = [ - PredictionResult(ex, pred) for ex, - pred in zip(examples, [numpy.multiply(n, 2) for n in examples]) + PredictionResult(ex, pred) for ex, pred in zip( + examples, [numpy.multiply(n, 2) for n in examples]) ] pcoll = pipeline | 'start' >> beam.Create(examples) @@ -316,8 +318,7 @@ def test_predict_tensor_with_args(self): tf.convert_to_tensor(numpy.array([100])), ] expected_predictions = [ - PredictionResult(ex, pred) for ex, - pred in zip( + PredictionResult(ex, pred) for ex, pred in zip( batched_examples, [ tf.math.add(tf.math.multiply(n, 10), 10) for n in batched_examples @@ -339,8 +340,7 @@ def test_predict_keyed_numpy(self): ('k3', numpy.array([100], dtype=numpy.int64)), ] expected_predictions = [ - (ex[0], PredictionResult(ex[1], pred)) for ex, - pred in zip( + (ex[0], PredictionResult(ex[1], pred)) for ex, pred in zip( batched_examples, [numpy.multiply(n[1], 10) for n in batched_examples]) ] @@ -359,8 +359,7 @@ def test_predict_keyed_tensor(self): ('k3', tf.convert_to_tensor(numpy.array([100]))), ] expected_predictions = [ - (ex[0], PredictionResult(ex[1], pred)) for ex, - pred in zip( + (ex[0], PredictionResult(ex[1], pred)) for ex, pred in zip( batched_examples, [tf.math.multiply(n[1], 10) for n in batched_examples]) ] @@ -371,12 +370,14 @@ def test_predict_keyed_tensor(self): def test_load_model_exception(self): with self.assertRaises(ValueError): tensorflow_inference._load_model( - "https://tfhub.dev/google/imagenet/mobilenet_v1_075_192/quantops/classification/3", # pylint: disable=line-too-long - None, {}) + "https://tfhub.dev/google/imagenet/mobilenet_v1_075_192/quantops/classification/3", # pylint: disable=line-too-long + None, + {}) @pytest.mark.uses_tf class TFRunInferenceTestWithMocks(unittest.TestCase): + def setUp(self): self._load_model = tensorflow_inference._load_model tensorflow_inference._load_model = unittest.mock.MagicMock() diff --git a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py index 9563aa05232a..9e69d8df5586 100644 --- a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py +++ b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py @@ -101,6 +101,7 @@ def _assign_or_fail(args): class TensorRTEngine: + def __init__(self, engine: trt.ICudaEngine): """Implementation of the TensorRTEngine class which handles allocations associated with TensorRT engine. @@ -223,6 +224,7 @@ def _default_tensorRT_inference_fn( class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray, PredictionResult, TensorRTEngine]): + def __init__( self, min_batch_size: int, diff --git a/sdks/python/apache_beam/ml/inference/tensorrt_inference_test.py b/sdks/python/apache_beam/ml/inference/tensorrt_inference_test.py index 86bb7f695d3c..a292462a7723 100644 --- a/sdks/python/apache_beam/ml/inference/tensorrt_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/tensorrt_inference_test.py @@ -54,16 +54,14 @@ ] SINGLE_FEATURE_PREDICTIONS = [ - PredictionResult(ex, pred) for ex, - pred in zip( + PredictionResult(ex, pred) for ex, pred in zip( SINGLE_FEATURE_EXAMPLES, [[np.array([example * 2.0 + 0.5], dtype=np.float32)] for example in SINGLE_FEATURE_EXAMPLES]) ] SINGLE_FEATURE_CUSTOM_PREDICTIONS = [ - PredictionResult(ex, pred) for ex, - pred in zip( + PredictionResult(ex, pred) for ex, pred in zip( SINGLE_FEATURE_EXAMPLES, [[np.array([(example * 2.0 + 0.5) * 2], dtype=np.float32)] for example in SINGLE_FEATURE_EXAMPLES]) @@ -77,20 +75,18 @@ ] TWO_FEATURES_PREDICTIONS = [ - PredictionResult(ex, pred) for ex, - pred in zip( - TWO_FEATURES_EXAMPLES, - [[ - np.array([example[0] * 2.0 + example[1] * 3 + 0.5], - dtype=np.float32) + PredictionResult(ex, pred) for ex, pred in zip( + TWO_FEATURES_EXAMPLES, [[ + np.array([example[0] * 2.0 + example[1] * 3 + + 0.5], dtype=np.float32) ] for example in TWO_FEATURES_EXAMPLES]) ] def _compare_prediction_result(a, b): return ((a.example == b.example).all() and all( - np.array_equal(actual, expected) for actual, - expected in zip(a.inference, b.inference))) + np.array_equal(actual, expected) + for actual, expected in zip(a.inference, b.inference))) def _assign_or_fail(args): @@ -140,13 +136,14 @@ def _custom_tensorRT_inference_fn(batch, engine, inference_args): return [ PredictionResult( - x, [prediction[idx] * 2 for prediction in cpu_allocations]) for idx, - x in enumerate(batch) + x, [prediction[idx] * 2 for prediction in cpu_allocations]) + for idx, x in enumerate(batch) ] @pytest.mark.uses_tensorrt class TensorRTRunInferenceTest(unittest.TestCase): + @unittest.skipIf(GCSFileSystem is None, 'GCP dependencies are not installed') def test_inference_single_tensor_feature_onnx(self): """ @@ -357,6 +354,7 @@ def test_namespace(self): @pytest.mark.uses_tensorrt class TensorRTRunInferencePipelineTest(unittest.TestCase): + @unittest.skipIf(GCSFileSystem is None, 'GCP dependencies are not installed') def test_pipeline_single_tensor_feature_built_engine(self): with TestPipeline() as pipeline: diff --git a/sdks/python/apache_beam/ml/inference/utils.py b/sdks/python/apache_beam/ml/inference/utils.py index 4936ab5fe1d4..55fa02cfce20 100644 --- a/sdks/python/apache_beam/ml/inference/utils.py +++ b/sdks/python/apache_beam/ml/inference/utils.py @@ -55,8 +55,8 @@ def _convert_to_result( dict(zip(predictions.keys(), v)) for v in zip(*predictions.values()) ] return [ - PredictionResult(x, y, model_id) for x, - y in zip(batch, predictions_per_tensor) + PredictionResult(x, y, model_id) + for x, y in zip(batch, predictions_per_tensor) ] return [PredictionResult(x, y, model_id) for x, y in zip(batch, predictions)] @@ -112,6 +112,7 @@ def process(self, element, time_state=beam.DoFn.StateParam(TIME_STATE)): class WatchFilePattern(beam.PTransform): + def __init__( self, file_pattern, diff --git a/sdks/python/apache_beam/ml/inference/utils_test.py b/sdks/python/apache_beam/ml/inference/utils_test.py index 66499a5a6f48..2297cac65576 100644 --- a/sdks/python/apache_beam/ml/inference/utils_test.py +++ b/sdks/python/apache_beam/ml/inference/utils_test.py @@ -27,6 +27,7 @@ class WatchFilePatternTest(unittest.TestCase): + def test_latest_file_by_timestamp_default_value(self): # match continuously returns the files in sorted timestamp order. main_input_pcoll = [ diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py index 4c4163accfb9..d0cb1651ff79 100644 --- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py +++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py @@ -62,6 +62,7 @@ def _retry_on_appropriate_gcp_error(exception): class VertexAIModelHandlerJSON(ModelHandler[Any, PredictionResult, aiplatform.Endpoint]): + def __init__( self, endpoint_id: str, diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference_it_test.py b/sdks/python/apache_beam/ml/inference/vertex_ai_inference_it_test.py index 7c96dbe8b847..e2d0c4bd139a 100644 --- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference_it_test.py +++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference_it_test.py @@ -46,6 +46,7 @@ class VertexAIInference(unittest.TestCase): + @pytest.mark.vertex_ai_postcommit def test_vertex_ai_run_flower_image_classification(self): output_file = '/'.join([_OUTPUT_DIR, str(uuid.uuid4()), 'output.txt']) diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference_test.py b/sdks/python/apache_beam/ml/inference/vertex_ai_inference_test.py index 34c7927272d6..69156fc887ff 100644 --- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference_test.py @@ -27,6 +27,7 @@ class RetryOnClientErrorTest(unittest.TestCase): + def test_retry_on_client_error_positive(self): e = TooManyRequests(message="fake service rate limiting") self.assertTrue(_retry_on_appropriate_gcp_error(e)) @@ -37,6 +38,7 @@ def test_retry_on_client_error_negative(self): class ModelHandlerArgConditions(unittest.TestCase): + def test_exception_on_private_without_network(self): self.assertRaises( ValueError, diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference.py b/sdks/python/apache_beam/ml/inference/vllm_inference.py index 799083d16ceb..efe7fb2d69ff 100644 --- a/sdks/python/apache_beam/ml/inference/vllm_inference.py +++ b/sdks/python/apache_beam/ml/inference/vllm_inference.py @@ -107,6 +107,7 @@ def getAsyncVLLMClient(port) -> AsyncOpenAI: class _VLLMModelServer(): + def __init__(self, model_name: str, vllm_server_kwargs: Dict[str, str]): self._model_name = model_name self._vllm_server_kwargs = vllm_server_kwargs @@ -166,6 +167,7 @@ def check_connectivity(self, retries=3): class VLLMCompletionsModelHandler(ModelHandler[str, PredictionResult, _VLLMModelServer]): + def __init__( self, model_name: str, @@ -249,6 +251,7 @@ def share_model_across_processes(self) -> bool: class VLLMChatModelHandler(ModelHandler[Sequence[OpenAIChatMessage], PredictionResult, _VLLMModelServer]): + def __init__( self, model_name: str, diff --git a/sdks/python/apache_beam/ml/inference/xgboost_inference.py b/sdks/python/apache_beam/ml/inference/xgboost_inference.py index ff6f098b4150..1545946a1e62 100644 --- a/sdks/python/apache_beam/ml/inference/xgboost_inference.py +++ b/sdks/python/apache_beam/ml/inference/xgboost_inference.py @@ -70,6 +70,7 @@ def default_xgboost_inference_fn( class XGBoostModelHandler(ModelHandler[ExampleT, PredictionT, ModelT], ABC): + def __init__( self, model_class: Union[Callable[..., xgboost.Booster], @@ -173,6 +174,7 @@ class XGBoostModelHandlerNumpy(XGBoostModelHandler[numpy.ndarray, inference_fn: the inference function to use during RunInference. default=default_xgboost_inference_fn """ + def run_inference( self, batch: Sequence[numpy.ndarray], @@ -225,6 +227,7 @@ class XGBoostModelHandlerPandas(XGBoostModelHandler[pandas.DataFrame, inference_fn: the inference function to use during RunInference. default=default_xgboost_inference_fn """ + def run_inference( self, batch: Sequence[pandas.DataFrame], @@ -277,6 +280,7 @@ class XGBoostModelHandlerSciPy(XGBoostModelHandler[scipy.sparse.csr_matrix, inference_fn: the inference function to use during RunInference. default=default_xgboost_inference_fn """ + def run_inference( self, batch: Sequence[scipy.sparse.csr_matrix], @@ -329,6 +333,7 @@ class XGBoostModelHandlerDatatable(XGBoostModelHandler[datatable.Frame, inference_fn: the inference function to use during RunInference. default=default_xgboost_inference_fn """ + def run_inference( self, batch: Sequence[datatable.Frame], diff --git a/sdks/python/apache_beam/ml/inference/xgboost_inference_it_test.py b/sdks/python/apache_beam/ml/inference/xgboost_inference_it_test.py index 3db62bcc6a99..9cebe89832bb 100644 --- a/sdks/python/apache_beam/ml/inference/xgboost_inference_it_test.py +++ b/sdks/python/apache_beam/ml/inference/xgboost_inference_it_test.py @@ -80,6 +80,7 @@ def process_outputs(filepath): @pytest.mark.uses_xgboost @pytest.mark.it_postcommit class XGBoostInference(unittest.TestCase): + def test_iris_classification_numpy_single_batch(self): test_pipeline = TestPipeline(is_integration_test=True) input_type = 'numpy' diff --git a/sdks/python/apache_beam/ml/inference/xgboost_inference_test.py b/sdks/python/apache_beam/ml/inference/xgboost_inference_test.py index e09f116dfb38..771ded0be53a 100644 --- a/sdks/python/apache_beam/ml/inference/xgboost_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/xgboost_inference_test.py @@ -94,6 +94,7 @@ def build_monkeypatched_xgboost_classifier() -> xgboost.XGBClassifier: @pytest.mark.uses_xgboost class XGBoostRunInferenceTest(unittest.TestCase): + def setUp(self): self.tmpdir = tempfile.mkdtemp() diff --git a/sdks/python/apache_beam/ml/rag/chunking/base.py b/sdks/python/apache_beam/ml/rag/chunking/base.py index 626a6ea8abbe..5cd7c850356c 100644 --- a/sdks/python/apache_beam/ml/rag/chunking/base.py +++ b/sdks/python/apache_beam/ml/rag/chunking/base.py @@ -35,6 +35,7 @@ def _assign_chunk_id(chunk_id_fn: ChunkIdFn, chunk: Chunk): class ChunkingTransformProvider(MLTransformProvider): + def __init__(self, chunk_id_fn: Optional[ChunkIdFn] = None): """Base class for chunking transforms in RAG pipelines. diff --git a/sdks/python/apache_beam/ml/rag/chunking/base_test.py b/sdks/python/apache_beam/ml/rag/chunking/base_test.py index 54e25591c348..68579158bcf8 100644 --- a/sdks/python/apache_beam/ml/rag/chunking/base_test.py +++ b/sdks/python/apache_beam/ml/rag/chunking/base_test.py @@ -34,6 +34,7 @@ class WordSplitter(beam.DoFn): + def process(self, element): words = element['text'].split() for i, word in enumerate(words): @@ -44,11 +45,13 @@ def process(self, element): class InvalidChunkingProvider(ChunkingTransformProvider): + def __init__(self, chunk_id_fn: Optional[ChunkIdFn] = None): super().__init__(chunk_id_fn=chunk_id_fn) class MockChunkingProvider(ChunkingTransformProvider): + def __init__(self, chunk_id_fn: Optional[ChunkIdFn] = None): super().__init__(chunk_id_fn=chunk_id_fn) @@ -78,6 +81,7 @@ def id_equals(expected, actual): @pytest.mark.uses_transformers class ChunkingTransformProviderTest(unittest.TestCase): + def setUp(self): self.test_doc = {'text': 'hello world test', 'source': 'test.txt'} @@ -115,6 +119,7 @@ def test_chunking_transform(self): def test_custom_chunk_id_fn(self): """Test the a custom chink id function.""" + def source_index_id_fn(chunk: Chunk): return f"{chunk.metadata['source']}_{chunk.index}" diff --git a/sdks/python/apache_beam/ml/rag/chunking/langchain.py b/sdks/python/apache_beam/ml/rag/chunking/langchain.py index 9e3b6b0c8ef9..999ce7f3e601 100644 --- a/sdks/python/apache_beam/ml/rag/chunking/langchain.py +++ b/sdks/python/apache_beam/ml/rag/chunking/langchain.py @@ -33,6 +33,7 @@ class LangChainChunker(ChunkingTransformProvider): + def __init__( self, text_splitter: TextSplitter, @@ -104,6 +105,7 @@ def get_splitter_transform( class _LangChainTextSplitter(beam.DoFn): + def __init__( self, text_splitter: TextSplitter, diff --git a/sdks/python/apache_beam/ml/rag/chunking/langchain_test.py b/sdks/python/apache_beam/ml/rag/chunking/langchain_test.py index 83a4fc1a778f..515b6ec7a279 100644 --- a/sdks/python/apache_beam/ml/rag/chunking/langchain_test.py +++ b/sdks/python/apache_beam/ml/rag/chunking/langchain_test.py @@ -52,6 +52,7 @@ def chunk_equals(expected, actual): @unittest.skipIf(not LANGCHAIN_AVAILABLE, 'langchain is not installed.') class LangChainChunkingTest(unittest.TestCase): + def setUp(self): self.simple_text = { 'content': 'This is a simple test document. It has multiple sentences. ' @@ -104,8 +105,7 @@ def test_multiple_metadata_fields(self): assert_that(chunks_count, lambda x: x[0] > 0, 'Has chunks') assert_that( - chunks, - lambda x: all( + chunks, lambda x: all( c.metadata == { 'source': 'simple.txt', 'language': 'en' } for c in x)) diff --git a/sdks/python/apache_beam/ml/rag/embeddings/base_test.py b/sdks/python/apache_beam/ml/rag/embeddings/base_test.py index 3a27ae8e7ebb..fb2f92ec5d51 100644 --- a/sdks/python/apache_beam/ml/rag/embeddings/base_test.py +++ b/sdks/python/apache_beam/ml/rag/embeddings/base_test.py @@ -23,6 +23,7 @@ class RAGBaseEmbeddingsTest(unittest.TestCase): + def setUp(self): self.test_chunks = [ Chunk( diff --git a/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py b/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py index 4cb0aecd6e82..a797ec9b2013 100644 --- a/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py +++ b/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py @@ -33,6 +33,7 @@ class HuggingfaceTextEmbeddings(EmbeddingsManager): + def __init__( self, model_name: str, *, max_seq_length: Optional[int] = None, **kwargs): """Utilizes huggingface SentenceTransformer embeddings for RAG pipeline. diff --git a/sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py b/sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py index aa63d13025a1..180551ffa4ac 100644 --- a/sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py +++ b/sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py @@ -56,6 +56,7 @@ def chunk_approximately_equals(expected, actual): @unittest.skipIf( not SENTENCE_TRANSFORMERS_AVAILABLE, "sentence-transformers not available") class HuggingfaceTextEmbeddingsTest(unittest.TestCase): + def setUp(self): self.artifact_location = tempfile.mkdtemp(prefix='sentence_transformers_') self.test_chunks = [ diff --git a/sdks/python/apache_beam/ml/transforms/base.py b/sdks/python/apache_beam/ml/transforms/base.py index 57a5efd3ff0e..f9f519e6414b 100644 --- a/sdks/python/apache_beam/ml/transforms/base.py +++ b/sdks/python/apache_beam/ml/transforms/base.py @@ -121,6 +121,7 @@ class MLTransformProvider: used to process the data. """ + @abc.abstractmethod def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: """ @@ -138,6 +139,7 @@ def get_counter(self): class BaseOperation(Generic[OperationInputT, OperationOutputT], MLTransformProvider, abc.ABC): + def __init__(self, columns: list[str]) -> None: """ Base Opertation class data processing transformations. @@ -176,6 +178,7 @@ class ProcessHandler( """ Only for internal use. No backwards compatibility guarantees. """ + @abc.abstractmethod def append_transform(self, transform: BaseOperation): """ @@ -246,6 +249,7 @@ def _create_dict_adapter( # TODO:https://github.com/apache/beam/issues/29356 # Add support for inference_fn class EmbeddingsManager(MLTransformProvider): + def __init__( self, *, @@ -290,6 +294,7 @@ class MLTransform( tuple[beam.PCollection[MLTransformOutputT], beam.PCollection[beam.Row]]]], Generic[ExampleT, MLTransformOutputT]): + def __init__( self, *, @@ -470,11 +475,13 @@ def with_exception_handling( class MLTransformMetricsUsage(beam.PTransform): + def __init__(self, ml_transform: MLTransform): self._ml_transform = ml_transform self._ml_transform._counter.inc() def expand(self, pipeline): + def _increment_counters(): # increment for MLTransform. self._ml_transform._counter.inc() @@ -494,6 +501,7 @@ class _TransformAttributeManager: """ Base class used for saving and loading the attributes. """ + @staticmethod def save_attributes(artifact_location): """ @@ -517,6 +525,7 @@ class _JsonPickleTransformAttributeManager(_TransformAttributeManager): jsonpickle is used to serialize the PTransforms and save it to a json file and is compatible across python versions. """ + @staticmethod def _is_remote_path(path): is_gcs = path.find('gs://') != -1 @@ -596,6 +605,7 @@ class _MLTransformToPTransformMapper: PTransforms or attributes of PTransforms to the artifact location to seal the gap between the training and inference pipelines. """ + def __init__( self, transforms: list[MLTransformProvider], @@ -628,8 +638,8 @@ def create_ptransform_list(self): self._parent_artifact_location, uuid.uuid4().hex[:6]), artifact_mode=self.artifact_mode) append_transform = hasattr(current_ptransform, 'append_transform') - if (type(current_ptransform) != - previous_ptransform_type) or not append_transform: + if (type(current_ptransform) + != previous_ptransform_type) or not append_transform: ptransform_list.append(current_ptransform) previous_ptransform_type = type(current_ptransform) # If different PTransform is appended to the list and the PTransform @@ -672,6 +682,7 @@ class _EmbeddingHandler(ModelHandler): Args: embeddings_manager: An EmbeddingsManager instance. """ + def __init__(self, embeddings_manager: EmbeddingsManager): self.embedding_config = embeddings_manager self._underlying = self.embedding_config.get_model_handler() @@ -750,6 +761,7 @@ class _TextEmbeddingHandler(_EmbeddingHandler): Args: embeddings_manager: An EmbeddingsManager instance. """ + def _validate_column_data(self, batch): if not isinstance(batch[0], (str, bytes)): raise TypeError( @@ -785,6 +797,7 @@ class _ImageEmbeddingHandler(_EmbeddingHandler): Args: embeddings_manager: An EmbeddingsManager instance. """ + def _validate_column_data(self, batch): # Don't want to require framework-specific imports # here, so just catch columns of primatives for now. diff --git a/sdks/python/apache_beam/ml/transforms/base_test.py b/sdks/python/apache_beam/ml/transforms/base_test.py index 1ef01acca18a..e95f3e78d7e1 100644 --- a/sdks/python/apache_beam/ml/transforms/base_test.py +++ b/sdks/python/apache_beam/ml/transforms/base_test.py @@ -56,6 +56,7 @@ try: class _FakeOperation(TFTOperation): + def __init__(self, name, *args, **kwargs): super().__init__(*args, **kwargs) self.name = name @@ -72,6 +73,7 @@ def apply_transform(self, inputs, output_column_name, **kwargs): class BaseMLTransformTest(unittest.TestCase): + def setUp(self) -> None: self.artifact_location = tempfile.mkdtemp() @@ -299,7 +301,9 @@ def test_mltransform_with_counter(self): result.metrics().query(mltransform_counter)['counters'][0].result, 1) def test_non_ptransfrom_provider_class_to_mltransform(self): + class Add: + def __call__(self, x): return x + 1 @@ -323,6 +327,7 @@ def test_read_mode_with_transforms(self): class FakeModel: + def __call__(self, example: list[str]) -> list[str]: for i in range(len(example)): if not isinstance(example[i], str): @@ -332,6 +337,7 @@ def __call__(self, example: list[str]) -> list[str]: class FakeModelHandler(ModelHandler): + def run_inference( self, batch: Sequence[str], @@ -344,6 +350,7 @@ def load_model(self): class FakeEmbeddingsManager(base.EmbeddingsManager): + def __init__(self, columns, **kwargs): super().__init__(columns=columns, **kwargs) @@ -359,6 +366,7 @@ def __repr__(self): class InvalidEmbeddingsManager(base.EmbeddingsManager): + def __init__(self, **kwargs): super().__init__(**kwargs) @@ -374,6 +382,7 @@ def __repr__(self): class TextEmbeddingHandlerTest(unittest.TestCase): + def setUp(self) -> None: self.embedding_conig = FakeEmbeddingsManager(columns=['x']) self.artifact_location = tempfile.mkdtemp() @@ -405,8 +414,10 @@ def test_handler_with_dict_inputs(self): 'x': "Apache Beam" }, ] - expected_data = [{key: value[::-1] - for key, value in d.items()} for d in data] + expected_data = [{ + key: value[::-1] + for key, value in d.items() + } for d in data] with beam.Pipeline() as p: result = ( p @@ -430,8 +441,10 @@ def test_handler_with_batch_sizes(self): 'x': "Apache Beam" }, ] * 100 - expected_data = [{key: value[::-1] - for key, value in d.items()} for d in data] + expected_data = [{ + key: value[::-1] + for key, value in d.items() + } for d in data] with beam.Pipeline() as p: result = ( p @@ -456,8 +469,7 @@ def test_handler_on_multiple_columns(self): embedding_config = FakeEmbeddingsManager(columns=['x', 'y']) expected_data = [{ key: (value[::-1] if key in embedding_config.columns else value) - for key, - value in d.items() + for key, value in d.items() } for d in data] with beam.Pipeline() as p: result = ( @@ -528,6 +540,7 @@ def test_handler_with_inconsistent_keys(self): class FakeImageModel: + def __call__(self, example: list[PIL_Image]) -> list[PIL_Image]: for i in range(len(example)): if not isinstance(example[i], PIL_Image): @@ -536,6 +549,7 @@ def __call__(self, example: list[PIL_Image]) -> list[PIL_Image]: class FakeImageModelHandler(ModelHandler): + def run_inference( self, batch: Sequence[PIL_Image], @@ -548,6 +562,7 @@ def load_model(self): class FakeImageEmbeddingsManager(base.EmbeddingsManager): + def __init__(self, columns, **kwargs): super().__init__(columns=columns, **kwargs) @@ -563,6 +578,7 @@ def __repr__(self): class TestImageEmbeddingHandler(unittest.TestCase): + def setUp(self) -> None: self.embedding_config = FakeImageEmbeddingsManager(columns=['x']) self.artifact_location = tempfile.mkdtemp() @@ -627,6 +643,7 @@ def test_handler_with_dict_inputs(self): class TestUtilFunctions(unittest.TestCase): + def test_dict_input_fn_normal(self): input_list = [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}] columns = ['a', 'b'] @@ -661,6 +678,7 @@ def test_dict_output_fn_on_list_inputs(self): class TestJsonPickleTransformAttributeManager(unittest.TestCase): + def setUp(self): self.attribute_manager = base._transform_attribute_manager self.artifact_location = tempfile.mkdtemp() @@ -768,6 +786,7 @@ def test_with_same_local_artifact_location(self): class MLTransformDLQTest(unittest.TestCase): + def setUp(self) -> None: self.artifact_location = tempfile.mkdtemp() diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py b/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py index e492cb164222..1f2ff42f16be 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py @@ -49,6 +49,7 @@ class _SentenceTransformerModelHandler(ModelHandler): """ Note: Intended for internal use and guarantees no backwards compatibility. """ + def __init__( self, model_name: str, @@ -108,6 +109,7 @@ def __repr__(self) -> str: class SentenceTransformerEmbeddings(EmbeddingsManager): + def __init__( self, model_name: str, @@ -165,6 +167,7 @@ def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: class _InferenceAPIHandler(ModelHandler): + def __init__(self, config: 'InferenceAPIEmbeddings'): super().__init__() self._config = config @@ -211,14 +214,15 @@ class InferenceAPIEmbeddings(EmbeddingsManager): ignored. If none, the default url for feature extraction will be used. """ + def __init__( self, hf_token: Optional[str], columns: list[str], - model_name: Optional[str] = None, # example: "sentence-transformers/all-MiniLM-l6-v2" # pylint: disable=line-too-long + model_name: Optional[str] = None, # example: "sentence-transformers/all-MiniLM-l6-v2" # pylint: disable=line-too-long api_url: Optional[str] = None, **kwargs, - ): + ): super().__init__(columns=columns, **kwargs) self._authorization_token = {"Authorization": f"Bearer {hf_token}"} self._model_name = model_name diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/huggingface_test.py b/sdks/python/apache_beam/ml/transforms/embeddings/huggingface_test.py index d09a573b6766..6cc5ab79b6b7 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/huggingface_test.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/huggingface_test.py @@ -94,6 +94,7 @@ SentenceTransformerEmbeddings is None, 'sentence-transformers is not installed.') class SentenceTransformerEmbeddingsTest(unittest.TestCase): + def setUp(self) -> None: self.artifact_location = tempfile.mkdtemp(prefix='sentence_transformers_') # this bucket has TTL and will be deleted periodically @@ -330,6 +331,7 @@ def test_sentence_transformer_images_with_str_data_types(self): @unittest.skipIf(_HF_TOKEN is None, 'HF_TOKEN environment variable not set.') class HuggingfaceInferenceAPITest(unittest.TestCase): + def setUp(self): self.artifact_location = tempfile.mkdtemp() self.inputs = [{test_query_column: test_query}] @@ -368,6 +370,7 @@ def test_embeddings_with_inference_api(self): @unittest.skipIf(_HF_TOKEN is None, 'HF_TOKEN environment variable not set.') class HuggingfaceInferenceAPIGCSLocationTest(HuggingfaceInferenceAPITest): + def setUp(self): self.artifact_location = self.gcs_artifact_location = os.path.join( 'gs://temp-storage-for-perf-tests/tft_handler', uuid.uuid4().hex) diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py index c14904df7c2c..c9f3ff9c6436 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py @@ -40,6 +40,7 @@ class _TensorflowHubModelHandler(TFModelHandlerTensor): """ Note: Intended for internal use only. No backwards compatibility guarantees. """ + def __init__(self, preprocessing_url: Optional[str], *args, **kwargs): self.preprocessing_url = preprocessing_url super().__init__(*args, **kwargs) @@ -87,6 +88,7 @@ def run_inference(self, batch, model, inference_args, model_id=None): class TensorflowHubTextEmbeddings(EmbeddingsManager): + def __init__( self, columns: list[str], @@ -135,6 +137,7 @@ def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: class TensorflowHubImageEmbeddings(EmbeddingsManager): + def __init__(self, columns: list[str], hub_url: str, **kwargs): """ Embedding config for tensorflow hub models. This config can be used with diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py index 24bca5155fa7..6d69413c0eac 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py @@ -56,6 +56,7 @@ @unittest.skipIf( TensorflowHubTextEmbeddings is None, 'Tensorflow is not installed.') class TFHubEmbeddingsTest(unittest.TestCase): + def setUp(self) -> None: self.artifact_location = tempfile.mkdtemp() @@ -176,6 +177,7 @@ def test_with_int_data_types(self): @unittest.skipIf( TensorflowHubImageEmbeddings is None, 'Tensorflow is not installed.') class TFHubImageEmbeddingsTest(unittest.TestCase): + def setUp(self) -> None: self.artifact_location = tempfile.mkdtemp() @@ -224,6 +226,7 @@ def test_with_str_data_types(self): @unittest.skipIf( TensorflowHubTextEmbeddings is None, 'Tensorflow is not installed.') class TFHubEmbeddingsGCSArtifactLocationTest(TFHubEmbeddingsTest): + def setUp(self): self.artifact_location = os.path.join( 'gs://temp-storage-for-perf-tests/tfhub', uuid.uuid4().hex) diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py index 6df505508ae9..67903e386bbf 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py @@ -84,6 +84,7 @@ class _VertexAITextEmbeddingHandler(ModelHandler): """ Note: Intended for internal use and guarantees no backwards compatibility. """ + def __init__( self, model_name: str, @@ -168,6 +169,7 @@ def __repr__(self): class VertexAITextEmbeddings(EmbeddingsManager): + def __init__( self, model_name: str, @@ -224,6 +226,7 @@ def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: class _VertexAIImageEmbeddingHandler(ModelHandler): + def __init__( self, model_name: str, @@ -296,6 +299,7 @@ def __repr__(self): class VertexAIImageEmbeddings(EmbeddingsManager): + def __init__( self, model_name: str, diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai_test.py b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai_test.py index 93f4a28d2e9e..c548daf63f36 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai_test.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai_test.py @@ -48,6 +48,7 @@ @unittest.skipIf( VertexAITextEmbeddings is None, 'Vertex AI Python SDK is not installed.') class VertexAIEmbeddingsTest(unittest.TestCase): + def setUp(self) -> None: self.artifact_location = tempfile.mkdtemp(prefix='_vertex_ai_test') self.gcs_artifact_location = os.path.join( @@ -251,6 +252,7 @@ def test_mltransform_to_ptransform_with_vertex(self): @unittest.skipIf( VertexAIImageEmbeddings is None, 'Vertex AI Python SDK is not installed.') class VertexAIImageEmbeddingsTest(unittest.TestCase): + def setUp(self) -> None: self.artifact_location = tempfile.mkdtemp(prefix='_vertex_ai_image_test') self.gcs_artifact_location = os.path.join( diff --git a/sdks/python/apache_beam/ml/transforms/handlers.py b/sdks/python/apache_beam/ml/transforms/handlers.py index 1e752049f6e5..54903dce926f 100644 --- a/sdks/python/apache_beam/ml/transforms/handlers.py +++ b/sdks/python/apache_beam/ml/transforms/handlers.py @@ -85,6 +85,7 @@ class _DataCoder: + def __init__( self, exclude_columns, @@ -115,6 +116,7 @@ def decode(self, element): class _ConvertScalarValuesToListValues(beam.DoFn): + def process( self, element, @@ -137,6 +139,7 @@ class _ConvertNamedTupleToDict( A PTransform that converts a collection of NamedTuples or Rows into a collection of dictionaries. """ + def expand( self, pcoll: beam.PCollection[Union[beam.Row, NamedTuple]] ) -> beam.PCollection[common_types.InstanceDictType]: @@ -151,6 +154,7 @@ def expand( class TFTProcessHandler(ProcessHandler[tft_process_handler_input_type, tft_process_handler_output_type]): + def __init__( self, *, diff --git a/sdks/python/apache_beam/ml/transforms/handlers_test.py b/sdks/python/apache_beam/ml/transforms/handlers_test.py index bb5f9b5f0f70..42a64943e922 100644 --- a/sdks/python/apache_beam/ml/transforms/handlers_test.py +++ b/sdks/python/apache_beam/ml/transforms/handlers_test.py @@ -50,11 +50,13 @@ class _AddOperation(TFTOperation): + def apply_transform(self, inputs, output_column_name, **kwargs): return {output_column_name: inputs + 1} class _MultiplyOperation(TFTOperation): + def apply_transform(self, inputs, output_column_name, **kwargs): return {output_column_name: inputs * 10} @@ -72,6 +74,7 @@ class NumpyType(NamedTuple): class TFTProcessHandlerTest(unittest.TestCase): + def setUp(self) -> None: self.artifact_location = tempfile.mkdtemp() @@ -621,6 +624,7 @@ def test_handler_with_same_input_elements(self): class TFTProcessHandlerTestWithGCSLocation(TFTProcessHandlerTest): + def setUp(self) -> None: self.artifact_location = self.gcs_artifact_location = os.path.join( 'gs://temp-storage-for-perf-tests/tft_handler', uuid.uuid4().hex) diff --git a/sdks/python/apache_beam/ml/transforms/tft.py b/sdks/python/apache_beam/ml/transforms/tft.py index bfe23757642b..558b41112caa 100644 --- a/sdks/python/apache_beam/ml/transforms/tft.py +++ b/sdks/python/apache_beam/ml/transforms/tft.py @@ -70,6 +70,7 @@ def register_input_dtype(type): + def wrapper(fn): _EXPECTED_TYPES[fn.__name__] = type return fn @@ -81,6 +82,7 @@ def wrapper(fn): # Add support for outputting artifacts to a text file in human readable form. class TFTOperation(BaseOperation[common_types.TensorType, common_types.TensorType]): + def __init__(self, columns: list[str]) -> None: """ Base Operation class for TFT data processing transformations. @@ -145,6 +147,7 @@ def _split_string_with_delimiter(self, data, delimiter): @register_input_dtype(str) class ComputeAndApplyVocabulary(TFTOperation): + def __init__( self, columns: list[str], @@ -213,6 +216,7 @@ def apply_transform( @register_input_dtype(float) class ScaleToZScore(TFTOperation): + def __init__( self, columns: list[str], @@ -254,6 +258,7 @@ def apply_transform( @register_input_dtype(float) class ScaleTo01(TFTOperation): + def __init__( self, columns: list[str], @@ -294,6 +299,7 @@ def apply_transform( @register_input_dtype(float) class ScaleToGaussian(TFTOperation): + def __init__( self, columns: list[str], @@ -331,6 +337,7 @@ def apply_transform( @register_input_dtype(float) class ApplyBuckets(TFTOperation): + def __init__( self, columns: list[str], @@ -366,6 +373,7 @@ def apply_transform( @register_input_dtype(float) class ApplyBucketsWithInterpolation(TFTOperation): + def __init__( self, columns: list[str], @@ -405,6 +413,7 @@ def apply_transform( @register_input_dtype(float) class Bucketize(TFTOperation): + def __init__( self, columns: list[str], @@ -454,6 +463,7 @@ def apply_transform( @register_input_dtype(float) class TFIDF(TFTOperation): + def __init__( self, columns: list[str], @@ -525,6 +535,7 @@ def apply_transform( @register_input_dtype(float) class ScaleByMinMax(TFTOperation): + def __init__( self, columns: list[str], @@ -561,6 +572,7 @@ def apply_transform( @register_input_dtype(str) class NGrams(TFTOperation): + def __init__( self, columns: list[str], @@ -606,6 +618,7 @@ def apply_transform( @register_input_dtype(str) class BagOfWords(TFTOperation): + def __init__( self, columns: list[str], @@ -681,6 +694,7 @@ def count_unique_words( @register_input_dtype(str) class HashStrings(TFTOperation): + def __init__( self, columns: list[str], @@ -725,6 +739,7 @@ def apply_transform( @register_input_dtype(str) class DeduplicateTensorPerRow(TFTOperation): + def __init__(self, columns: list[str], name: Optional[str] = None): """ Deduplicates each row (0th dimension) of the provided tensor. diff --git a/sdks/python/apache_beam/ml/transforms/tft_test.py b/sdks/python/apache_beam/ml/transforms/tft_test.py index 6afe9b5ab302..6a461c7a6e53 100644 --- a/sdks/python/apache_beam/ml/transforms/tft_test.py +++ b/sdks/python/apache_beam/ml/transforms/tft_test.py @@ -41,6 +41,7 @@ class ScaleZScoreTest(unittest.TestCase): + def setUp(self) -> None: self.artifact_location = tempfile.mkdtemp() @@ -109,6 +110,7 @@ def test_z_score_list_data(self): class ScaleTo01Test(unittest.TestCase): + def setUp(self) -> None: self.artifact_location = tempfile.mkdtemp() @@ -156,6 +158,7 @@ def test_ScaleTo01(self): class ScaleToGaussianTest(unittest.TestCase): + def setUp(self) -> None: self.artifact_location = tempfile.mkdtemp() @@ -282,6 +285,7 @@ def test_gaussian_skewed(self): class BucketizeTest(unittest.TestCase): + def setUp(self) -> None: self.artifact_location = tempfile.mkdtemp() @@ -331,6 +335,7 @@ def test_bucketize_list(self): class ApplyBucketsTest(unittest.TestCase): + def setUp(self) -> None: self.artifact_location = tempfile.mkdtemp() @@ -365,6 +370,7 @@ def test_apply_buckets(self, test_inputs, bucket_boundaries): class ApplyBucketsWithInterpolationTest(unittest.TestCase): + def setUp(self) -> None: self.artifact_location = tempfile.mkdtemp() @@ -395,6 +401,7 @@ def test_apply_buckets(self, test_inputs, bucket_boundaries, expected_values): class ComputeAndApplyVocabTest(unittest.TestCase): + def setUp(self) -> None: self.artifact_location = tempfile.mkdtemp() @@ -404,13 +411,19 @@ def tearDown(self): def test_compute_and_apply_vocabulary_inputs(self): num_elements = 100 num_instances = num_elements + 1 - input_data = [{ - 'x': '%.10i' % i, # Front-padded to facilitate lexicographic sorting. - } for i in range(num_instances)] + input_data = [ + { + 'x': '%.10i' % + i, # Front-padded to facilitate lexicographic sorting. + } for i in range(num_instances) + ] - expected_data = [{ - 'x': (len(input_data) - 1) - i, # Due to reverse lexicographic sorting. - } for i in range(len(input_data))] + expected_data = [ + { + 'x': (len(input_data) - 1) - + i, # Due to reverse lexicographic sorting. + } for i in range(len(input_data)) + ] with beam.Pipeline() as p: actual_data = ( @@ -594,6 +607,7 @@ def test_multiple_columns_with_vocab_name(self): class TFIDIFTest(unittest.TestCase): + def setUp(self) -> None: self.artifact_location = tempfile.mkdtemp() @@ -641,6 +655,7 @@ def equals_fn(a, b): class ScaleToMinMaxTest(unittest.TestCase): + def setUp(self) -> None: self.artifact_location = tempfile.mkdtemp() @@ -687,6 +702,7 @@ def test_fail_max_value_less_than_min(self): class NGramsTest(unittest.TestCase): + def setUp(self) -> None: self.artifact_location = tempfile.mkdtemp() @@ -804,6 +820,7 @@ def test_with_multiple_string_delimiters(self): class BagOfWordsTest(unittest.TestCase): + def setUp(self) -> None: self.artifact_location = tempfile.mkdtemp() @@ -963,6 +980,7 @@ def validate_count_per_key(key_vocab_filename): class HashStringsTest(unittest.TestCase): + def setUp(self) -> None: self.artifact_location = tempfile.mkdtemp() @@ -1040,6 +1058,7 @@ def test_multi_buckets_multi_string(self): class DeduplicateTensorPerRowTest(unittest.TestCase): + def setUp(self) -> None: self.artifact_location = tempfile.mkdtemp() diff --git a/sdks/python/apache_beam/ml/transforms/utils.py b/sdks/python/apache_beam/ml/transforms/utils.py index 023657895686..540b38efb8ab 100644 --- a/sdks/python/apache_beam/ml/transforms/utils.py +++ b/sdks/python/apache_beam/ml/transforms/utils.py @@ -48,6 +48,7 @@ class ArtifactsFetcher: This is intended to be used for testing purposes only. """ + def __init__(self, artifact_location: str): tempdir = tempfile.mkdtemp() if artifact_location.startswith('gs://'): diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index 8eba11d9ea34..99c35558dcca 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -76,6 +76,7 @@ def _static_value_provider_of(value_type): A partially constructed StaticValueProvider in the form of a function. """ + def _f(value): _f.__name__ = value_type.__name__ return StaticValueProvider(value_type, value) @@ -97,6 +98,7 @@ def _add_argparse_args(cls, parser): parser.add_argument('--non_vp_arg') """ + def add_value_provider_argument(self, *args, **kwargs): """ValueProvider arguments can be either of type keyword or positional. At runtime, even positional arguments will need to be supplied in the @@ -142,9 +144,10 @@ class _DictUnionAction(argparse.Action): argparse Action take union of json loads values. If a key is specified in more than one of the values, the last value takes precedence. """ + def __call__(self, parser, namespace, values, option_string=None): - if not hasattr(namespace, - self.dest) or getattr(namespace, self.dest) is None: + if not hasattr(namespace, self.dest) or getattr(namespace, + self.dest) is None: setattr(namespace, self.dest, {}) getattr(namespace, self.dest).update(values) @@ -186,6 +189,7 @@ def _add_argparse_args(cls, parser): By default the options classes will use command line arguments to initialize the options. """ + def __init__(self, flags: Optional[Sequence[str]] = None, **kwargs) -> None: """Initialize an options class. @@ -410,6 +414,7 @@ def get_all_options( return result def to_runner_api(self): + def to_struct_value(o): if isinstance(o, (bool, int, str)): return o @@ -430,6 +435,7 @@ def to_struct_value(o): @classmethod def from_runner_api(cls, proto_options): + def from_urn(key): assert key.startswith('beam:option:') assert key.endswith(':v1') @@ -586,6 +592,7 @@ def _add_argparse_args(cls, parser): class StreamingOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_argument( @@ -598,6 +605,7 @@ def _add_argparse_args(cls, parser): class CrossLanguageOptions(PipelineOptions): + @staticmethod def _beam_services_from_enviroment(): return json.loads(os.environ.get('BEAM_SERVICE_OVERRIDES') or '{}') @@ -647,6 +655,7 @@ def enable_all_additional_type_checks(): class TypeOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): # TODO(laolu): Add a type inferencing option here once implemented. @@ -723,6 +732,7 @@ def validate(self, unused_validator): class DirectOptions(PipelineOptions): """DirectRunner-specific execution options.""" + @classmethod def _add_argparse_args(cls, parser): parser.add_argument( @@ -1055,6 +1065,7 @@ def get_cloud_profiler_service_name(self): class AzureOptions(PipelineOptions): """Azure Blob Storage options.""" + @classmethod def _add_argparse_args(cls, parser): parser.add_argument( @@ -1083,6 +1094,7 @@ def validate(self, validator): class HadoopFileSystemOptions(PipelineOptions): """``HadoopFileSystem`` connection options.""" + @classmethod def _add_argparse_args(cls, parser): parser.add_argument( @@ -1112,6 +1124,7 @@ def validate(self, validator): # TODO(silviuc): Update description when autoscaling options are in. class WorkerOptions(PipelineOptions): """Worker pool configuration options.""" + @classmethod def _add_argparse_args(cls, parser): parser.add_argument( @@ -1133,8 +1146,7 @@ def _add_argparse_args(cls, parser): type=str, choices=['NONE', 'THROUGHPUT_BASED'], default=None, # Meaning unset, distinct from 'NONE' meaning don't scale - help= - ('If and how to autoscale the workerpool.')) + help=('If and how to autoscale the workerpool.')) parser.add_argument( '--worker_machine_type', '--machine_type', @@ -1302,6 +1314,7 @@ def validate(self, validator): class DebugOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_argument( @@ -1355,6 +1368,7 @@ def validate(self, validator): class ProfilingOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_argument( @@ -1378,6 +1392,7 @@ def _add_argparse_args(cls, parser): class SetupOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): # Options for installing dependencies in the worker. @@ -1524,6 +1539,7 @@ class PortableOptions(PipelineOptions): the portable runners. Should generally be kept in sync with PortablePipelineOptions.java. """ + @classmethod def _add_argparse_args(cls, parser): parser.add_argument( @@ -1627,6 +1643,7 @@ class JobServerOptions(PipelineOptions): """Options for starting a Beam job server. Roughly corresponds to JobServerDriver.ServerConfiguration in Java. """ + @classmethod def _add_argparse_args(cls, parser): parser.add_argument( @@ -1719,6 +1736,7 @@ def _add_argparse_args(cls, parser): class SparkRunnerOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_argument( @@ -1752,6 +1770,7 @@ def _add_argparse_args(cls, parser): class PrismRunnerOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_argument( @@ -1770,6 +1789,7 @@ def _add_argparse_args(cls, parser): class TestOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): # Options for e2e test pipeline. @@ -1802,6 +1822,7 @@ def validate(self, validator): class TestDataflowOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): # This option is passed to Dataflow Runner's Pub/Sub client. The camelCase @@ -1844,6 +1865,7 @@ def __exit__(self, *exn_info): self.overrides.pop() def __call__(self, f, *args, **kwargs): + def wrapper(*args, **kwargs): with self: f(*args, **kwargs) @@ -1859,6 +1881,7 @@ def augment_options(cls, options): class S3Options(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): # These options are passed to the S3 IO Client diff --git a/sdks/python/apache_beam/options/pipeline_options_test.py b/sdks/python/apache_beam/options/pipeline_options_test.py index 66acfe654791..b66020afa9d3 100644 --- a/sdks/python/apache_beam/options/pipeline_options_test.py +++ b/sdks/python/apache_beam/options/pipeline_options_test.py @@ -54,7 +54,9 @@ # Mock runners to use for validations. class MockRunners(object): + class DataflowRunner(object): + def get_default_gcp_region(self): # Return a default so we don't have to specify --region in every test # (unless specifically testing it). @@ -62,16 +64,19 @@ def get_default_gcp_region(self): class MockGoogleCloudOptionsNoBucket(GoogleCloudOptions): + def _create_default_gcs_bucket(self): return None class MockGoogleCloudOptionsWithBucket(GoogleCloudOptions): + def _create_default_gcs_bucket(self): return "gs://default/bucket" class PipelineOptionsTest(unittest.TestCase): + def setUp(self): # Reset runtime options to avoid side-effects caused by other tests. # Note that is_accessible assertions require runtime_options to @@ -185,6 +190,7 @@ def tearDown(self): # Used for testing newly added flags. class MockOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_argument('--mock_flag', action='store_true', help='mock flag') @@ -196,6 +202,7 @@ def _add_argparse_args(cls, parser): # Use with MockOptions in test cases where multiple option classes are needed. class FakeOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_argument('--fake_flag', action='store_true', help='fake flag') @@ -296,7 +303,9 @@ def test_from_dictionary(self, flags, expected, _): expected.get('mock_json_option', {})) def test_none_from_dictionary(self): + class NoneDefaultOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_argument('--test_arg_none', default=None, type=int) @@ -525,12 +534,15 @@ def test_template_location(self): self.assertEqual(options.get_all_options()['template_location'], None) def test_redefine_options(self): + class TestRedefinedOptions(PipelineOptions): # pylint: disable=unused-variable + @classmethod def _add_argparse_args(cls, parser): parser.add_argument('--redefined_flag', action='store_true') class TestRedefinedOptions(PipelineOptions): # pylint: disable=function-redefined + @classmethod def _add_argparse_args(cls, parser): parser.add_argument('--redefined_flag', action='store_true') @@ -544,7 +556,9 @@ def _add_argparse_args(cls, parser): # _non_vp_arg for non-value-provider arguments. # The number will grow per file as tests are added. def test_value_provider_options(self): + class UserOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_value_provider_argument( diff --git a/sdks/python/apache_beam/options/pipeline_options_validator.py b/sdks/python/apache_beam/options/pipeline_options_validator.py index 7c07e5a1e6c7..ebe9c8f223ce 100644 --- a/sdks/python/apache_beam/options/pipeline_options_validator.py +++ b/sdks/python/apache_beam/options/pipeline_options_validator.py @@ -416,7 +416,7 @@ def validate_endpoint_url(self, endpoint_url): return False if url_parts.scheme not in ['http', 'https']: return False - if set( - url_parts.netloc) <= set(string.ascii_letters + string.digits + '-.'): + if set(url_parts.netloc) <= set(string.ascii_letters + string.digits + + '-.'): return True return False diff --git a/sdks/python/apache_beam/options/pipeline_options_validator_test.py b/sdks/python/apache_beam/options/pipeline_options_validator_test.py index 56f305a01b74..ffaaebf8895c 100644 --- a/sdks/python/apache_beam/options/pipeline_options_validator_test.py +++ b/sdks/python/apache_beam/options/pipeline_options_validator_test.py @@ -38,7 +38,9 @@ # Mock runners to use for validations. class MockRunners(object): + class DataflowRunner(object): + def get_default_gcp_region(self): # Return a default so we don't have to specify --region in every test # (unless specifically testing it). @@ -53,11 +55,13 @@ class OtherRunner(object): # Matcher that always passes for testing on_success_matcher option class AlwaysPassMatcher(BaseMatcher): + def _matches(self, item): return True class SetupTest(unittest.TestCase): + def check_errors_for_arguments(self, errors, args): """Checks that there is exactly one error for each given argument.""" missing = [] @@ -108,6 +112,7 @@ def test_missing_required_options(self): ]) @unittest.skip('Not compatible with new GCS client. See GH issue #26335.') def test_gcs_path(self, temp_location, staging_location, expected_error_args): + def get_validator(_temp_location, _staging_location): options = ['--project=example:example', '--job_name=job'] @@ -130,6 +135,7 @@ def get_validator(_temp_location, _staging_location): ('FOO', ['project']), ('foo:BAR', ['project']), ('fo', ['project']), ('foo', []), ('foo:bar', [])]) def test_project(self, project, expected_error_args): + def get_validator(_project): options = [ '--job_name=job', @@ -153,6 +159,7 @@ def get_validator(_project): ('FOO', ['job_name']), ('foo:bar', ['job_name']), ('fo', []), ('foo', [])]) def test_job_name(self, job_name, expected_error_args): + def get_validator(_job_name): options = [ '--project=example:example', @@ -175,6 +182,7 @@ def get_validator(_job_name): @parameterized.expand([(None, []), ('1', []), ('0', ['num_workers']), ('-1', ['num_workers'])]) def test_num_workers(self, num_workers, expected_error_args): + def get_validator(_num_workers): options = [ '--project=example:example', @@ -552,6 +560,7 @@ def test_prebuild_sdk_container_base_allowed_if_matches_custom_image(self): (b'abc', ['on_success_matcher']), (pickler.dumps(object), ['on_success_matcher'])]) def test_test_matcher(self, on_success_matcher, errors): + def get_validator(matcher): options = [ '--project=example:example', diff --git a/sdks/python/apache_beam/options/value_provider.py b/sdks/python/apache_beam/options/value_provider.py index fa1649beed26..ad038429b21b 100644 --- a/sdks/python/apache_beam/options/value_provider.py +++ b/sdks/python/apache_beam/options/value_provider.py @@ -41,6 +41,7 @@ class ValueProvider(object): """Base class that all other ValueProviders must implement. """ + def is_accessible(self): """Whether the contents of this ValueProvider is available to routines that run at graph construction time. @@ -59,6 +60,7 @@ class StaticValueProvider(ValueProvider): """StaticValueProvider is an implementation of ValueProvider that allows for a static value to be provided. """ + def __init__(self, value_type, value): """ Args: @@ -142,6 +144,7 @@ class NestedValueProvider(ValueProvider): """NestedValueProvider is an implementation of ValueProvider that allows for wrapping another ValueProvider object. """ + def __init__(self, value, translator): """Creates a NestedValueProvider that wraps the provided ValueProvider. @@ -185,6 +188,7 @@ def check_accessible(value_provider_list): assert isinstance(value_provider_list, list) def _check_accessible(fnc): + @wraps(fnc) def _f(self, *args, **kwargs): for obj in [getattr(self, vp) for vp in value_provider_list]: diff --git a/sdks/python/apache_beam/options/value_provider_test.py b/sdks/python/apache_beam/options/value_provider_test.py index 42afa8c0def3..704d5414e3a8 100644 --- a/sdks/python/apache_beam/options/value_provider_test.py +++ b/sdks/python/apache_beam/options/value_provider_test.py @@ -37,6 +37,7 @@ # _non_vp_arg for non-value-provider arguments. # The number will grow per file as tests are added. class ValueProviderTests(unittest.TestCase): + def setUp(self): # Reset runtime options to avoid side-effects caused by other tests. # Note that is_accessible assertions require runtime_options to @@ -48,7 +49,9 @@ def tearDown(self): RuntimeValueProvider.set_runtime_options(None) def test_static_value_provider_keyword_argument(self): + class UserDefinedOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_value_provider_argument( @@ -62,7 +65,9 @@ def _add_argparse_args(cls, parser): self.assertEqual(options.vpt_vp_arg1.get(), 'abc') def test_runtime_value_provider_keyword_argument(self): + class UserDefinedOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_value_provider_argument( @@ -75,7 +80,9 @@ def _add_argparse_args(cls, parser): options.vpt_vp_arg2.get() def test_static_value_provider_positional_argument(self): + class UserDefinedOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_value_provider_argument( @@ -89,7 +96,9 @@ def _add_argparse_args(cls, parser): self.assertEqual(options.vpt_vp_arg3.get(), 'abc') def test_runtime_value_provider_positional_argument(self): + class UserDefinedOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_value_provider_argument( @@ -102,7 +111,9 @@ def _add_argparse_args(cls, parser): options.vpt_vp_arg4.get() def test_static_value_provider_type_cast(self): + class UserDefinedOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_value_provider_argument( @@ -116,6 +127,7 @@ def _add_argparse_args(cls, parser): def test_set_runtime_option(self): # define ValueProvider options, with and without default values class UserDefinedOptions1(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_value_provider_argument( @@ -168,7 +180,9 @@ def _add_argparse_args(cls, parser): self.assertEqual(options.vpt_vp_arg10.get(), 1.2) def test_choices(self): + class UserDefinedOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_argument( @@ -186,7 +200,9 @@ def _add_argparse_args(cls, parser): self.assertEqual(options.vpt_vp_arg12, 2) def test_static_value_provider_choices(self): + class UserDefinedOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_value_provider_argument( @@ -240,7 +256,9 @@ def translator(x): self.assertEqual(mock_fn.call_count, 1) def test_nested_value_provider_wrap_runtime(self): + class UserDefinedOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_value_provider_argument( diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py index 6209ca1ddae8..e444eba073b5 100644 --- a/sdks/python/apache_beam/pipeline.py +++ b/sdks/python/apache_beam/pipeline.py @@ -126,6 +126,7 @@ class Pipeline(HasDisplayData): should be used to designate new names (e.g. ``input | "label" >> my_transform``). """ + @classmethod def runner_implemented_transforms(cls): # type: () -> FrozenSet[str] @@ -296,6 +297,7 @@ def _replace(self, override): class TransformUpdater(PipelineVisitor): # pylint: disable=used-before-assignment """"A visitor that replaces the matching PTransforms.""" + def __init__(self, pipeline): # type: (Pipeline) -> None self.pipeline = pipeline @@ -428,6 +430,7 @@ class InputOutputUpdater(PipelineVisitor): # pylint: disable=used-before-assign We cannot update input and output values while visiting since that results in validation errors. """ + def __init__(self, pipeline): # type: (Pipeline) -> None self.pipeline = pipeline @@ -495,6 +498,7 @@ def visit_transform(self, transform_node): def _check_replacement(self, override): # type: (PTransformOverride) -> None class ReplacementValidator(PipelineVisitor): + def visit_transform(self, transform_node): # type: (AppliedPTransform) -> None if override.matches(transform_node): @@ -810,7 +814,6 @@ def _generate_unique_label( unique_suffix = uuid.uuid4().hex[:6] return '%s_%s' % (transform.label, unique_suffix) - def _infer_result_type( self, transform, # type: ptransform.PTransform @@ -929,6 +932,7 @@ def to_runner_api( TypeOptions).allow_non_deterministic_key_coders class ForceKvInputTypes(PipelineVisitor): + def enter_composite_transform(self, transform_node): # type: (AppliedPTransform) -> None self.visit_transform(transform_node) @@ -952,8 +956,8 @@ def visit_transform(self, transform_node): if (isinstance(output.element_type, typehints.TupleHint.TupleConstraint) and len(output.element_type.tuple_types) == 2 and - pcoll.element_type.tuple_types[0] == - output.element_type.tuple_types[0]): + pcoll.element_type.tuple_types[0] + == output.element_type.tuple_types[0]): output.requires_deterministic_key_coder = ( deterministic_key_coders and transform_node.full_label) for side_input in transform_node.transform.side_inputs: @@ -1005,8 +1009,10 @@ def from_runner_api( p = Pipeline( runner=runner, options=options, - display_data={str(ix): d - for ix, d in enumerate(proto.display_data)}) + display_data={ + str(ix): d + for ix, d in enumerate(proto.display_data) + }) from apache_beam.runners import pipeline_context context = pipeline_context.PipelineContext( proto.components, requirements=proto.requirements) @@ -1047,6 +1053,7 @@ class PipelineVisitor(object): Visitor pattern class used to traverse a DAG of transforms (used internally by Pipeline for bookkeeping purposes). """ + def visit_value(self, value, producer_node): # type: (pvalue.PValue, AppliedPTransform) -> None @@ -1082,6 +1089,7 @@ class ExternalTransformFinder(PipelineVisitor): """Looks for any external transforms in the pipeline and if found records it. """ + def __init__(self): self._contains_external_transforms = False @@ -1118,6 +1126,7 @@ class AppliedPTransform(object): A transform node representing an instance of applying a PTransform (used internally by Pipeline for bookeeping purposes). """ + def __init__( self, parent, # type: Optional[AppliedPTransform] @@ -1125,7 +1134,7 @@ def __init__( full_label, # type: str main_inputs, # type: Optional[Mapping[str, Union[pvalue.PBegin, pvalue.PCollection]]] environment_id=None, # type: Optional[str] - annotations=None, # type: Optional[Dict[str, bytes]] + annotations=None, # type: Optional[Dict[str, bytes]] ): # type: (...) -> None self.parent = parent @@ -1164,8 +1173,7 @@ def annotation_to_bytes(key, a: Any) -> bytes: annotations = { key: annotation_to_bytes(key, a) - for key, - a in transform.annotations().items() + for key, a in transform.annotations().items() } self.annotations = annotations @@ -1382,13 +1390,11 @@ def transform_to_runner_api( ], inputs={ tag: context.pcollections.get_id(pc) - for tag, - pc in sorted(self.named_inputs().items()) + for tag, pc in sorted(self.named_inputs().items()) }, outputs={ tag: context.pcollections.get_id(out) - for tag, - out in sorted(self.named_outputs().items()) + for tag, out in sorted(self.named_outputs().items()) }, environment_id=environment_id, annotations=self.annotations, @@ -1429,8 +1435,8 @@ def from_runner_api( # TODO(https://github.com/apache/beam/issues/20136): use key, value pairs # instead of depending on tags with index as a suffix. indexed_side_inputs = [ - (get_sideinput_index(tag), context.pcollections.get_by_id(id)) for tag, - id in proto.inputs.items() if tag in side_input_tags + (get_sideinput_index(tag), context.pcollections.get_by_id(id)) + for tag, id in proto.inputs.items() if tag in side_input_tags ] side_inputs = [si for _, si in sorted(indexed_side_inputs)] @@ -1453,8 +1459,7 @@ def from_runner_api( result.add_part(part) result.outputs = { None if tag == 'None' else tag: context.pcollections.get_by_id(id) - for tag, - id in proto.outputs.items() + for tag, id in proto.outputs.items() } # This annotation is expected by some runners. if proto.spec.urn == common_urns.primitives.PAR_DO.urn: @@ -1486,6 +1491,7 @@ class PTransformOverride(metaclass=abc.ABCMeta): TODO: Update this to support cases where input and/our output types are different. """ + @abc.abstractmethod def matches(self, applied_ptransform): # type: (AppliedPTransform) -> bool @@ -1564,6 +1570,7 @@ class ComponentIdMap(object): Component ID assignments are only guaranteed to be unique and consistent within the scope of a ComponentIdMap instance. """ + def __init__(self, namespace="ref"): self.namespace = namespace self._counters = defaultdict(lambda: 0) # type: Dict[type, int] diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py index 1c11f953c58d..d8896f4ddea9 100644 --- a/sdks/python/apache_beam/pipeline_test.py +++ b/sdks/python/apache_beam/pipeline_test.py @@ -66,11 +66,13 @@ class FakeUnboundedSource(SourceBase): """Fake unbounded source. Does not work at runtime""" + def is_bounded(self): return False class DoubleParDo(beam.PTransform): + def expand(self, input): return input | 'Inner' >> beam.Map(lambda a: a * 2) @@ -79,6 +81,7 @@ def to_runner_api_parameter(self, context): class TripleParDo(beam.PTransform): + def expand(self, input): # Keeping labels the same intentionally to make sure that there is no label # conflict due to replacement. @@ -86,6 +89,7 @@ def expand(self, input): class ToStringParDo(beam.PTransform): + def expand(self, input): # We use copy.copy() here to make sure the typehint mechanism doesn't # automatically infer that the output type is str. @@ -93,32 +97,38 @@ def expand(self, input): class FlattenAndDouble(beam.PTransform): + def expand(self, pcolls): return pcolls | beam.Flatten() | 'Double' >> DoubleParDo() class FlattenAndTriple(beam.PTransform): + def expand(self, pcolls): return pcolls | beam.Flatten() | 'Triple' >> TripleParDo() class AddWithProductDoFn(beam.DoFn): + def process(self, input, a, b): yield input + a * b class AddThenMultiplyDoFn(beam.DoFn): + def process(self, input, a, b): yield (input + a) * b class AddThenMultiply(beam.PTransform): + def expand(self, pvalues): return pvalues[0] | beam.ParDo( AddThenMultiplyDoFn(), AsSingleton(pvalues[1]), AsSingleton(pvalues[2])) class PipelineTest(unittest.TestCase): + @staticmethod def custom_callable(pcoll): return pcoll | '+1' >> FlatMap(lambda x: [x + 1]) @@ -128,10 +138,12 @@ def custom_callable(pcoll): # work and is not related to other aspects of the tests. class CustomTransform(PTransform): + def expand(self, pcoll): return pcoll | '+1' >> FlatMap(lambda x: [x + 1]) class Visitor(PipelineVisitor): + def __init__(self, visited): self.visited = visited self.enter_composite = [] @@ -321,7 +333,9 @@ def test_reuse_cloned_custom_transform_instance(self): assert_that(result2, equal_to([5, 6, 7]), label='r2') def test_transform_no_super_init(self): + class AddSuffix(PTransform): + def __init__(self, suffix): # No call to super(...).__init__ self.suffix = suffix @@ -383,6 +397,7 @@ def test_aggregator_empty_input(self): self.assertEqual(actual, []) def test_pipeline_as_context(self): + def raise_exception(exn): raise exn @@ -392,7 +407,9 @@ def raise_exception(exn): p | Create([ValueError('msg')]) | Map(raise_exception) def test_ptransform_overrides(self): + class MyParDoOverride(PTransformOverride): + def matches(self, applied_ptransform): return isinstance(applied_ptransform.transform, DoubleParDo) @@ -411,7 +428,9 @@ def get_replacement_transform_for_applied_ptransform( p.run() def test_ptransform_override_type_hints(self): + class NoTypeHintOverride(PTransformOverride): + def matches(self, applied_ptransform): return isinstance(applied_ptransform.transform, DoubleParDo) @@ -420,6 +439,7 @@ def get_replacement_transform_for_applied_ptransform( return ToStringParDo() class WithTypeHintOverride(PTransformOverride): + def matches(self, applied_ptransform): return isinstance(applied_ptransform.transform, DoubleParDo) @@ -440,7 +460,9 @@ def get_replacement_transform_for_applied_ptransform( self.assertEqual(pcoll.producer.inputs[0].element_type, expected_type) def test_ptransform_override_multiple_inputs(self): + class MyParDoOverride(PTransformOverride): + def matches(self, applied_ptransform): return isinstance(applied_ptransform.transform, FlattenAndDouble) @@ -457,7 +479,9 @@ def get_replacement_transform(self, applied_ptransform): p.run() def test_ptransform_override_side_inputs(self): + class MyParDoOverride(PTransformOverride): + def matches(self, applied_ptransform): return ( isinstance(applied_ptransform.transform, ParDo) and @@ -478,7 +502,9 @@ def get_replacement_transform(self, transform): p.run() def test_ptransform_override_replacement_inputs(self): + class MyParDoOverride(PTransformOverride): + def matches(self, applied_ptransform): return ( isinstance(applied_ptransform.transform, ParDo) and @@ -508,11 +534,14 @@ def get_replacement_inputs(self, applied_ptransform): p.run() def test_ptransform_override_multiple_outputs(self): + class MultiOutputComposite(PTransform): + def __init__(self): self.output_tags = set() def expand(self, pcoll): + def mux_input(x): x = x * 2 if isinstance(x, int): @@ -532,6 +561,7 @@ def mux_input(x): } class MultiOutputOverride(PTransformOverride): + def matches(self, applied_ptransform): return applied_ptransform.full_label == 'MyMultiOutput' @@ -693,8 +723,11 @@ def test_incompatible_submission_and_runtime_envs_fail_pipeline(self): class DoFnTest(unittest.TestCase): + def test_element(self): + class TestDoFn(DoFn): + def process(self, element): yield element + 10 @@ -703,7 +736,9 @@ def process(self, element): assert_that(pcoll, equal_to([11, 12])) def test_side_input_no_tag(self): + class TestDoFn(DoFn): + def process(self, element, prefix, suffix): return ['%s-%s-%s' % (prefix, element, suffix)] @@ -717,7 +752,9 @@ def process(self, element, prefix, suffix): assert_that(result, equal_to(['zyx-%s-xyz' % x for x in words_list])) def test_side_input_tagged(self): + class TestDoFn(DoFn): + def process(self, element, prefix, suffix=DoFn.SideInputParam): return ['%s-%s-%s' % (prefix, element, suffix)] @@ -752,7 +789,9 @@ def test_key_param(self): pipeline.run() def test_window_param(self): + class TestDoFn(DoFn): + def process(self, element, window=DoFn.WindowParam): yield (element, (float(window.start), float(window.end))) @@ -786,7 +825,9 @@ def test_windowed_value_param(self): equal_to([(1, [IntervalWindow(0, 5)]), (7, [IntervalWindow(5, 10)])])) # pylint: disable=too-many-function-args def test_timestamp_param(self): + class TestDoFn(DoFn): + def process(self, element, timestamp=DoFn.TimestampParam): yield timestamp @@ -820,6 +861,7 @@ def test_pane_info_param(self): label='CheckGrouped') def test_context_params(self): + def test_map( x, context_a=DoFn.BundleContextParam(_TestContext, args=('a')), @@ -834,7 +876,9 @@ def test_map( self.assertEqual(_TestContext.live_contexts, 0) def test_incomparable_default(self): + class IncomparableType(object): + def __eq__(self, other): raise RuntimeError() @@ -870,12 +914,14 @@ def __exit__(self, *args): class Bacon(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_argument('--slices', type=int) class Eggs(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_argument('--style', default='scrambled') @@ -886,6 +932,7 @@ class Breakfast(Bacon, Eggs): class PipelineOptionsTest(unittest.TestCase): + def test_flag_parsing(self): options = Breakfast(['--slices=3', '--style=sunny side up', '--ignored']) self.assertEqual(3, options.slices) @@ -956,8 +1003,11 @@ def test_dir(self): class RunnerApiTest(unittest.TestCase): + def test_parent_pointer(self): + class MyPTransform(beam.PTransform): + def expand(self, p): self.p = p return p | beam.Create([None]) @@ -984,6 +1034,7 @@ def test_annotations(self): some_proto = BytesCoder().to_runner_api(None) class EmptyTransform(beam.PTransform): + def expand(self, pcoll): return pcoll @@ -991,6 +1042,7 @@ def annotations(self): return {'foo': 'some_string'} class NonEmptyTransform(beam.PTransform): + def expand(self, pcoll): return pcoll | beam.Map(lambda x: x) @@ -1017,7 +1069,9 @@ def annotations(self): self.assertEqual(seen, 2) def test_transform_ids(self): + class MyPTransform(beam.PTransform): + def expand(self, p): self.p = p return p | beam.Create([None]) @@ -1030,7 +1084,9 @@ def expand(self, p): self.assertRegex(transform_id, r'[a-zA-Z0-9-_]+') def test_input_names(self): + class MyPTransform(beam.PTransform): + def expand(self, pcolls): return pcolls.values() | beam.Flatten() @@ -1048,7 +1104,9 @@ def expand(self, pcolls): self.fail('Unable to find transform.') def test_display_data(self): + class MyParentTransform(beam.PTransform): + def expand(self, p): self.p = p return p | beam.Create([None]) @@ -1063,6 +1121,7 @@ def display_data(self) -> dict: return parent_dd class MyPTransform(MyParentTransform): + def expand(self, p): self.p = p return p | beam.Create([None]) @@ -1166,6 +1225,7 @@ def test_runner_api_roundtrip_preserves_resource_hints(self): {common_urns.resource_hints.ACCELERATOR.urn: b'gpu'}) def test_hints_on_composite_transforms_are_propagated_to_subtransforms(self): + class FooHint(ResourceHint): urn = 'foo_urn' @@ -1236,6 +1296,7 @@ def CompositeTransform(pcoll): assert found def test_environments_with_same_resource_hints_are_reused(self): + class HintX(ResourceHint): urn = 'X_urn' @@ -1304,6 +1365,7 @@ class HintIsOdd(ResourceHint): self.assertEqual(len(env_ids), 5) def test_multiple_application_of_the_same_transform_set_different_hints(self): + class FooHint(ResourceHint): urn = 'foo_urn' @@ -1348,6 +1410,7 @@ def CompositeTransform(pcoll): assert count == 2 def test_environments_are_deduplicated(self): + def file_artifact(path, hash, staged_name): return beam_runner_api_pb2.ArtifactInformation( type_urn=common_urns.artifact_types.FILE.urn, @@ -1370,11 +1433,11 @@ def file_artifact(path, hash, staged_name): 'e1': beam_runner_api_pb2.Environment( dependencies=[file_artifact('a1', 'x', 'dest')]), 'e2': beam_runner_api_pb2.Environment( - dependencies=[file_artifact('a2', 'x', 'dest')]), - # Different hash. + dependencies=[file_artifact('a2', 'x', 'dest') + ]), # Different hash. 'e3': beam_runner_api_pb2.Environment( - dependencies=[file_artifact('a3', 'y', 'dest')]), - # Different destination. + dependencies=[file_artifact('a3', 'y', 'dest') + ]), # Different destination. 'e4': beam_runner_api_pb2.Environment( dependencies=[file_artifact('a4', 'y', 'dest2')]), # Multiple files with same hash and destinations. @@ -1387,14 +1450,12 @@ def file_artifact(path, hash, staged_name): dependencies=[ file_artifact('a2', 'x', 'dest'), file_artifact('b2', 'xb', 'destB') - ]), - # Overlapping, but not identical, files. + ]), # Overlapping, but not identical, files. 'e7': beam_runner_api_pb2.Environment( dependencies=[ file_artifact('a1', 'x', 'dest'), file_artifact('b2', 'y', 'destB') - ]), - # Same files as first, but differing other properties. + ]), # Same files as first, but differing other properties. 'e0': beam_runner_api_pb2.Environment( resource_hints={'hint': b'value'}, dependencies=[file_artifact('a1', 'x', 'dest')]), diff --git a/sdks/python/apache_beam/pvalue.py b/sdks/python/apache_beam/pvalue.py index 5a400570cf18..4603e2fe32cc 100644 --- a/sdks/python/apache_beam/pvalue.py +++ b/sdks/python/apache_beam/pvalue.py @@ -80,6 +80,7 @@ class PValue(object): (2) Has a transform that can compute the value if executed. (3) Has a value which is meaningful if the transform was executed. """ + def __init__( self, pipeline: 'Pipeline', @@ -144,6 +145,7 @@ class PCollection(PValue, Generic[T]): Dataflow users should not construct PCollection objects directly in their pipelines. """ + def __eq__(self, other): if isinstance(other, PCollection): return self.tag == other.tag and self.producer == other.producer @@ -233,6 +235,7 @@ class PDone(PValue): class DoOutputsTuple(object): """An object grouping the multiple outputs of a ParDo or FlatMap transform.""" + def __init__( self, pipeline: 'Pipeline', @@ -330,6 +333,7 @@ class TaggedOutput(object): if it wants to emit on the main output and TaggedOutput objects if it wants to emit a value on a specific tagged output. """ + def __init__(self, tag: str, value: Any) -> None: if not isinstance(tag, str): raise TypeError( @@ -349,6 +353,7 @@ class AsSideInput(object): options, and should not be instantiated directly. (See instead AsSingleton, AsIter, etc.) """ + def __init__(self, pcoll: PCollection) -> None: from apache_beam.transforms import sideinputs self.pvalue = pcoll @@ -407,6 +412,7 @@ def requires_keyed_input(self): class _UnpickledSideInput(AsSideInput): + def __init__(self, side_input_data: 'SideInputData') -> None: self._data = side_input_data self._window_mapping_fn = side_input_data.window_mapping_fn @@ -426,8 +432,7 @@ def _from_runtime_iterable(it, options): def _view_options(self): return { - 'data': self._data, - # For non-fn-api runners. + 'data': self._data, # For non-fn-api runners. 'window_mapping_fn': self._data.window_mapping_fn, 'coder': self._windowed_coder(), } @@ -438,6 +443,7 @@ def _side_input_data(self): class SideInputData(object): """All of the data about a side input except for the bound PCollection.""" + def __init__( self, access_pattern: str, @@ -534,6 +540,7 @@ class AsIter(AsSideInput): (e.g., data.apply('label', MyPTransform(), AsIter(my_side_input) ) selects the former behavor. """ + def __repr__(self): return 'AsIter(%s)' % self.pvalue @@ -544,8 +551,7 @@ def _from_runtime_iterable(it, options): def _side_input_data(self) -> SideInputData: return SideInputData( common_urns.side_inputs.ITERABLE.urn, - self._window_mapping_fn, - lambda iterable: iterable) + self._window_mapping_fn, lambda iterable: iterable) @property def element_type(self): @@ -566,6 +572,7 @@ class AsList(AsSideInput): An AsList-wrapper around a PCollection whose one element is a list containing all elements in pcoll. """ + @staticmethod def _from_runtime_iterable(it, options): return list(it) @@ -590,6 +597,7 @@ class AsDict(AsSideInput): An AsDict-wrapper around a PCollection whose one element is a dict with entries for uniquely-keyed pairs in pcoll. """ + @staticmethod def _from_runtime_iterable(it, options): return dict(it) @@ -609,6 +617,7 @@ class AsMultiMap(AsSideInput): AsSingleton and AsIter are used, but returns an interface that allows key lookup. """ + @staticmethod def _from_runtime_iterable(it, options): # Legacy implementation. @@ -620,8 +629,7 @@ def _from_runtime_iterable(it, options): def _side_input_data(self) -> SideInputData: return SideInputData( common_urns.side_inputs.MULTIMAP.urn, - self._window_mapping_fn, - lambda x: x) + self._window_mapping_fn, lambda x: x) def requires_keyed_input(self): return True @@ -658,6 +666,7 @@ class Row(object): Note that in Beam 2.30.0 and later, Row objects are sensitive to field order. So `Row(x=3, y=4)` is not considered equal to `Row(y=4, x=3)`. """ + def __init__(self, **kwargs): self.__dict__.update(kwargs) @@ -681,8 +690,8 @@ def __eq__(self, other): return ( type(self) == type(other) and len(self.__dict__) == len(other.__dict__) and all( - s == o for s, - o in zip(self.__dict__.items(), other.__dict__.items()))) + s == o + for s, o in zip(self.__dict__.items(), other.__dict__.items()))) def __reduce__(self): return _make_Row, tuple(self.__dict__.items()) diff --git a/sdks/python/apache_beam/pvalue_test.py b/sdks/python/apache_beam/pvalue_test.py index 447d2327dc4f..f4a4541b05cc 100644 --- a/sdks/python/apache_beam/pvalue_test.py +++ b/sdks/python/apache_beam/pvalue_test.py @@ -29,6 +29,7 @@ class PValueTest(unittest.TestCase): + def test_pvalue_expected_arguments(self): pipeline = TestPipeline() value = PValue(pipeline) @@ -43,6 +44,7 @@ def test_assingleton_multi_element(self): class TaggedValueTest(unittest.TestCase): + def test_passed_tuple_as_tag(self): with self.assertRaisesRegex( TypeError, @@ -51,6 +53,7 @@ def test_passed_tuple_as_tag(self): class RowTest(unittest.TestCase): + def test_row_eq(self): row = Row(a=1, b=2) same = Row(a=1, b=2) diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py index c43870d55ebb..f68c8e67eabb 100644 --- a/sdks/python/apache_beam/runners/common.py +++ b/sdks/python/apache_beam/runners/common.py @@ -89,6 +89,7 @@ class NameContext(object): """Holds the name information for a step.""" + def __init__(self, step_name, transform_id=None): # type: (str, Optional[str]) -> None @@ -126,6 +127,7 @@ class Receiver(object): This class can be efficiently used to pass values between the sdk and worker harnesses. """ + def receive(self, windowed_value): # type: (WindowedValue) -> None raise NotImplementedError @@ -142,6 +144,7 @@ class MethodWrapper(object): """For internal use only; no backwards-compatibility guarantees. Represents a method that can be invoked by `DoFnInvoker`.""" + def __init__(self, obj_to_invoke, method_name): """ Initiates a ``MethodWrapper``. @@ -274,6 +277,7 @@ class DoFnSignature(object): https://s.apache.org/splittable-do-fn) (3) validating a ``DoFn`` based on the feature set offered by it. """ + def __init__(self, do_fn): # type: (core.DoFn) -> None # We add a property here for all methods defined by Beam DoFn features. @@ -473,10 +477,11 @@ class DoFnInvoker(object): A DoFnInvoker describes a particular way for invoking methods of a DoFn represented by a given DoFnSignature.""" - def __init__(self, - output_handler, # type: _OutputHandler - signature # type: DoFnSignature - ): + def __init__( + self, + output_handler, # type: _OutputHandler + signature # type: DoFnSignature + ): # type: (...) -> None """ @@ -496,8 +501,9 @@ def create_invoker( signature, # type: DoFnSignature output_handler, # type: OutputHandler context=None, # type: Optional[DoFnContext] - side_inputs=None, # type: Optional[List[sideinputs.SideInputMap]] - input_args=None, input_kwargs=None, + side_inputs=None, # type: Optional[List[sideinputs.SideInputMap]] + input_args=None, + input_kwargs=None, process_invocation=True, user_state_context=None, # type: Optional[userstate.UserStateContext] bundle_finalizer_param=None # type: Optional[core._BundleFinalizerParam] @@ -548,13 +554,13 @@ def create_invoker( user_state_context, bundle_finalizer_param) - def invoke_process(self, - windowed_value, # type: WindowedValue - restriction=None, - watermark_estimator_state=None, - additional_args=None, - additional_kwargs=None - ): + def invoke_process( + self, + windowed_value, # type: WindowedValue + restriction=None, + watermark_estimator_state=None, + additional_args=None, + additional_kwargs=None): # type: (...) -> Iterable[SplitResultResidual] """Invokes the DoFn.process() function. @@ -575,11 +581,11 @@ def invoke_process(self, """ raise NotImplementedError - def invoke_process_batch(self, - windowed_batch, # type: WindowedBatch - additional_args=None, - additional_kwargs=None - ): + def invoke_process_batch( + self, + windowed_batch, # type: WindowedBatch + additional_args=None, + additional_kwargs=None): # type: (...) -> None """Invokes the DoFn.process() function. @@ -669,34 +675,35 @@ def invoke_create_tracker(self, restriction): class SimpleInvoker(DoFnInvoker): """An invoker that processes elements ignoring windowing information.""" - def __init__(self, - output_handler, # type: OutputHandler - signature # type: DoFnSignature - ): + def __init__( + self, + output_handler, # type: OutputHandler + signature # type: DoFnSignature + ): # type: (...) -> None super().__init__(output_handler, signature) self.process_method = signature.process_method.method_value self.process_batch_method = signature.process_batch_method.method_value - def invoke_process(self, - windowed_value, # type: WindowedValue - restriction=None, - watermark_estimator_state=None, - additional_args=None, - additional_kwargs=None - ): + def invoke_process( + self, + windowed_value, # type: WindowedValue + restriction=None, + watermark_estimator_state=None, + additional_args=None, + additional_kwargs=None): # type: (...) -> Iterable[SplitResultResidual] self.output_handler.handle_process_outputs( windowed_value, self.process_method(windowed_value.value)) return [] - def invoke_process_batch(self, - windowed_batch, # type: WindowedBatch - restriction=None, - watermark_estimator_state=None, - additional_args=None, - additional_kwargs=None - ): + def invoke_process_batch( + self, + windowed_batch, # type: WindowedBatch + restriction=None, + watermark_estimator_state=None, + additional_args=None, + additional_kwargs=None): # type: (...) -> None self.output_handler.handle_process_batch_outputs( windowed_batch, self.process_batch_method(windowed_batch.values)) @@ -716,6 +723,7 @@ def _get_arg_placeholders( # Not to be confused with ArgumentPlaceHolder, which may be passed in # input_args and is a placeholder for side-inputs. class ArgPlaceholder(object): + def __init__(self, placeholder): self.placeholder = placeholder @@ -782,16 +790,17 @@ def __init__(self, placeholder): class PerWindowInvoker(DoFnInvoker): """An invoker that processes elements considering windowing information.""" - def __init__(self, - output_handler, # type: OutputHandler - signature, # type: DoFnSignature - context, # type: DoFnContext - side_inputs, # type: Iterable[sideinputs.SideInputMap] - input_args, - input_kwargs, - user_state_context, # type: Optional[userstate.UserStateContext] - bundle_finalizer_param # type: Optional[core._BundleFinalizerParam] - ): + def __init__( + self, + output_handler, # type: OutputHandler + signature, # type: DoFnSignature + context, # type: DoFnContext + side_inputs, # type: Iterable[sideinputs.SideInputMap] + input_args, + input_kwargs, + user_state_context, # type: Optional[userstate.UserStateContext] + bundle_finalizer_param # type: Optional[core._BundleFinalizerParam] + ): super().__init__(output_handler, signature) self.side_inputs = side_inputs self.context = context @@ -823,8 +832,8 @@ def __init__(self, # and has_cached_window_batch_args will be set to true if the corresponding # self.args_for_process,have been updated and should be reused directly. self.recalculate_window_args = ( - self.has_windowed_inputs or 'disable_global_windowed_args_caching' in - RuntimeValueProvider.experiments) + self.has_windowed_inputs or 'disable_global_windowed_args_caching' + in RuntimeValueProvider.experiments) self.has_cached_window_args = False self.has_cached_window_batch_args = False @@ -846,13 +855,13 @@ def __init__(self, self.kwargs_for_process_batch) = _get_arg_placeholders( signature.process_batch_method, input_args, input_kwargs) - def invoke_process(self, - windowed_value, # type: WindowedValue - restriction=None, - watermark_estimator_state=None, - additional_args=None, - additional_kwargs=None - ): + def invoke_process( + self, + windowed_value, # type: WindowedValue + restriction=None, + watermark_estimator_state=None, + additional_args=None, + additional_kwargs=None): # type: (...) -> Iterable[SplitResultResidual] if not additional_args: additional_args = [] @@ -918,11 +927,11 @@ def invoke_process(self, windowed_value, additional_args, additional_kwargs) return residuals - def invoke_process_batch(self, - windowed_batch, # type: WindowedBatch - additional_args=None, - additional_kwargs=None - ): + def invoke_process_batch( + self, + windowed_batch, # type: WindowedBatch + additional_args=None, + additional_kwargs=None): # type: (...) -> None if not additional_args: @@ -947,9 +956,9 @@ def invoke_process_batch(self, def _should_process_window_for_sdf( self, - windowed_value, # type: WindowedValue + windowed_value, # type: WindowedValue additional_kwargs, - window_index=None, # type: Optional[int] + window_index=None, # type: Optional[int] ): restriction_tracker = self.invoke_create_tracker(self.restriction) watermark_estimator = self.invoke_create_watermark_estimator( @@ -982,11 +991,12 @@ def _should_process_window_for_sdf( additional_kwargs[watermark_param] = self.threadsafe_watermark_estimator return True - def _invoke_process_per_window(self, - windowed_value, # type: WindowedValue - additional_args, - additional_kwargs, - ): + def _invoke_process_per_window( + self, + windowed_value, # type: WindowedValue + additional_args, + additional_kwargs, + ): # type: (...) -> Optional[SplitResultResidual] if self.has_cached_window_args: args_for_process, kwargs_for_process = ( @@ -1155,16 +1165,17 @@ def _invoke_process_batch_per_window( self.threadsafe_watermark_estimator) @staticmethod - def _try_split(fraction, - window_index, # type: Optional[int] - stop_window_index, # type: Optional[int] - windowed_value, # type: WindowedValue + def _try_split( + fraction, + window_index, # type: Optional[int] + stop_window_index, # type: Optional[int] + windowed_value, # type: WindowedValue restriction, watermark_estimator_state, - restriction_provider, # type: RestrictionProvider - restriction_tracker, # type: RestrictionTracker - watermark_estimator, # type: WatermarkEstimator - ): + restriction_provider, # type: RestrictionProvider + restriction_tracker, # type: RestrictionTracker + watermark_estimator, # type: WatermarkEstimator + ): # type: (...) -> Optional[Tuple[Iterable[SplitResultPrimary], Iterable[SplitResultResidual], Optional[int]]] """Try to split returning a primaries, residuals and a new stop index. @@ -1199,6 +1210,7 @@ def _try_split(fraction, splitting was not possible. new_stop_index will only be set if the splittable DoFn is window observing otherwise it will be None. """ + def compute_whole_window_split(to_index, from_index): restriction_size = restriction_provider.restriction_size( windowed_value, restriction) @@ -1405,21 +1417,22 @@ class DoFnRunner: A helper class for executing ParDo operations. """ - def __init__(self, - fn, # type: core.DoFn - args, - kwargs, - side_inputs, # type: Iterable[sideinputs.SideInputMap] - windowing, - tagged_receivers, # type: Mapping[Optional[str], Receiver] - step_name=None, # type: Optional[str] - logging_context=None, - state=None, - scoped_metrics_container=None, - operation_name=None, - transform_id=None, - user_state_context=None, # type: Optional[userstate.UserStateContext] - ): + def __init__( + self, + fn, # type: core.DoFn + args, + kwargs, + side_inputs, # type: Iterable[sideinputs.SideInputMap] + windowing, + tagged_receivers, # type: Mapping[Optional[str], Receiver] + step_name=None, # type: Optional[str] + logging_context=None, + state=None, + scoped_metrics_container=None, + operation_name=None, + transform_id=None, + user_state_context=None, # type: Optional[userstate.UserStateContext] + ): """Initializes a DoFnRunner. Args: @@ -1613,6 +1626,7 @@ def _reraise_augmented(self, exn, windowed_value=None): class OutputHandler(object): + def handle_process_outputs( self, windowed_input_element, results, watermark_estimator=None): # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None @@ -1627,15 +1641,16 @@ def handle_process_batch_outputs( class _OutputHandler(OutputHandler): """Processes output produced by DoFn method invocations.""" - def __init__(self, - window_fn, - main_receivers, # type: Receiver - tagged_receivers, # type: Mapping[Optional[str], Receiver] - per_element_output_counter, - output_batch_converter, # type: Optional[BatchConverter] - process_yields_batches, # type: bool - process_batch_yields_elements, # type: bool - ): + def __init__( + self, + window_fn, + main_receivers, # type: Receiver + tagged_receivers, # type: Mapping[Optional[str], Receiver] + per_element_output_counter, + output_batch_converter, # type: Optional[BatchConverter] + process_yields_batches, # type: bool + process_batch_yields_elements, # type: bool + ): """Initializes ``_OutputHandler``. Args: @@ -1877,6 +1892,7 @@ class DoFnState(object): Keeps track of state that DoFns want, currently, user counters. """ + def __init__(self, counter_factory): self.step_name = '' self._counter_factory = counter_factory @@ -1890,6 +1906,7 @@ def counter_for(self, aggregator): # TODO(robertwb): Replace core.DoFnContext with this. class DoFnContext(object): """For internal use only; no backwards-compatibility guarantees.""" + def __init__(self, label, element=None, state=None): self.label = label self.state = state @@ -1935,6 +1952,7 @@ class GroupByKeyInputVisitor(PipelineVisitor): TODO(BEAM-115): Once Python SDK is compatible with the new Runner API, we could directly replace the coder instead of mutating the element type. """ + def __init__(self, deterministic_key_coders=True): self.deterministic_key_coders = deterministic_key_coders @@ -1962,6 +1980,7 @@ def visit_transform(self, transform_node): def validate_pipeline_graph(pipeline_proto): """Ensures this is a correctly constructed Beam pipeline. """ + def get_coder(pcoll_id): return pipeline_proto.components.coders[ pipeline_proto.components.pcollections[pcoll_id].coder_id] @@ -1986,11 +2005,11 @@ def validate_transform(transform_id): "Bad coder for output of %s: %s" % (transform_id, output_coder)) output_values_coder = pipeline_proto.components.coders[ output_coder.component_coder_ids[1]] - if (input_coder.component_coder_ids[0] != - output_coder.component_coder_ids[0] or + if (input_coder.component_coder_ids[0] + != output_coder.component_coder_ids[0] or output_values_coder.spec.urn != common_urns.coders.ITERABLE.urn or - output_values_coder.component_coder_ids[0] != - input_coder.component_coder_ids[1]): + output_values_coder.component_coder_ids[0] + != input_coder.component_coder_ids[1]): raise ValueError( "Incompatible input coder %s and output coder %s for transform %s" % (transform_id, input_coder, output_coder)) @@ -2009,6 +2028,7 @@ def validate_transform(transform_id): def merge_common_environments(pipeline_proto, inplace=False): + def dep_key(dep): if dep.type_urn == common_urns.artifact_types.FILE.urn: payload = beam_runner_api_pb2.ArtifactFilePayload.FromString( @@ -2052,7 +2072,8 @@ def env_key(env): environment_remappings = { e: es[0] - for es in canonical_environments.values() for e in es + for es in canonical_environments.values() + for e in es } if not inplace: diff --git a/sdks/python/apache_beam/runners/common_test.py b/sdks/python/apache_beam/runners/common_test.py index ca2cd2539a8c..bfb2a29697c4 100644 --- a/sdks/python/apache_beam/runners/common_test.py +++ b/sdks/python/apache_beam/runners/common_test.py @@ -45,8 +45,11 @@ class DoFnSignatureTest(unittest.TestCase): + def test_dofn_validate_process_error(self): + class MyDoFn(DoFn): + def process(self, element, w1=DoFn.WindowParam, w2=DoFn.WindowParam): pass @@ -54,7 +57,9 @@ def process(self, element, w1=DoFn.WindowParam, w2=DoFn.WindowParam): DoFnSignature(MyDoFn()) def test_dofn_get_defaults(self): + class MyDoFn(DoFn): + def process(self, element, w=DoFn.WindowParam): pass @@ -64,7 +69,9 @@ def process(self, element, w=DoFn.WindowParam): @unittest.skip('BEAM-5878') def test_dofn_get_defaults_kwonly(self): + class MyDoFn(DoFn): + def process(self, element, *, w=DoFn.WindowParam): pass @@ -73,7 +80,9 @@ def process(self, element, *, w=DoFn.WindowParam): self.assertEqual(signature.process_method.defaults, [DoFn.WindowParam]) def test_dofn_validate_start_bundle_error(self): + class MyDoFn(DoFn): + def process(self, element): pass @@ -84,7 +93,9 @@ def start_bundle(self, w1=DoFn.WindowParam): DoFnSignature(MyDoFn()) def test_dofn_validate_finish_bundle_error(self): + class MyDoFn(DoFn): + def process(self, element): pass @@ -95,12 +106,15 @@ def finish_bundle(self, w1=DoFn.WindowParam): DoFnSignature(MyDoFn()) def test_unbounded_element_process_fn(self): + class UnboundedDoFn(DoFn): + @DoFn.unbounded_per_element() def process(self, element): pass class BoundedDoFn(DoFn): + def process(self, element): pass @@ -118,14 +132,18 @@ def setUp(self): DoFnProcessTest.all_records = [] def record_dofn(self): + class RecordDoFn(DoFn): + def process(self, element): DoFnProcessTest.all_records.append(element) return RecordDoFn() def test_dofn_process_keyparam(self): + class DoFnProcessWithKeyparam(DoFn): + def process(self, element, mykey=DoFn.KeyParam): yield "{key}-verify".format(key=mykey) @@ -147,7 +165,9 @@ def process(self, element, mykey=DoFn.KeyParam): sorted(DoFnProcessTest.all_records)) def test_dofn_process_keyparam_error_no_key(self): + class DoFnProcessWithKeyparam(DoFn): + def process(self, element, mykey=DoFn.KeyParam): yield "{key}-verify".format(key=mykey) @@ -158,12 +178,15 @@ def process(self, element, mykey=DoFn.KeyParam): (p | test_stream | beam.ParDo(DoFnProcessWithKeyparam())) def test_pardo_with_unbounded_per_element_dofn(self): + class UnboundedDoFn(beam.DoFn): + @beam.DoFn.unbounded_per_element() def process(self, element): pass class BoundedDoFn(beam.DoFn): + def process(self, element): pass @@ -177,11 +200,13 @@ def process(self, element): class TestOffsetRestrictionProvider(RestrictionProvider): + def restriction_size(self, element, restriction): return restriction.size() class PerWindowInvokerSplitTest(unittest.TestCase): + def setUp(self): self.window1 = IntervalWindow(0, 10) self.window2 = IntervalWindow(10, 20) @@ -424,7 +449,10 @@ def test_window_observing_split_on_last_window(self): expected_primary_split, expected_primary_windows, )) - hc.assert_that(residuals, hc.contains_inanyorder(expected_residual_split, )) + hc.assert_that( + residuals, hc.contains_inanyorder( + expected_residual_split, + )) self.assertEqual(stop_index, 3) def test_window_observing_split_on_first_window_fallback(self): @@ -588,6 +616,7 @@ def test_window_observing_split_on_window_boundary_round_down_on_last_window( class UtilitiesTest(unittest.TestCase): + def test_equal_environments_merged(self): pipeline_proto = merge_common_environments( beam_runner_api_pb2.Pipeline( diff --git a/sdks/python/apache_beam/runners/dask/dask_runner.py b/sdks/python/apache_beam/runners/dask/dask_runner.py index 0f2317074cea..ccea1a092108 100644 --- a/sdks/python/apache_beam/runners/dask/dask_runner.py +++ b/sdks/python/apache_beam/runners/dask/dask_runner.py @@ -49,6 +49,7 @@ class DaskOptions(PipelineOptions): + @staticmethod def _parse_timeout(candidate): try: @@ -137,6 +138,7 @@ def metrics(self): class DaskRunner(BundleBasedDirectRunner): """Executes a pipeline on a Dask distributed client.""" + @staticmethod def to_dask_bag_visitor() -> PipelineVisitor: from dask import bag as db diff --git a/sdks/python/apache_beam/runners/dask/dask_runner_test.py b/sdks/python/apache_beam/runners/dask/dask_runner_test.py index 66dda4a984f4..268e6a57684e 100644 --- a/sdks/python/apache_beam/runners/dask/dask_runner_test.py +++ b/sdks/python/apache_beam/runners/dask/dask_runner_test.py @@ -37,6 +37,7 @@ class DaskOptionsTest(unittest.TestCase): + def test_parses_connection_timeout__defaults_to_none(self): default_options = PipelineOptions([]) default_dask_options = default_options.view_as(DaskOptions) @@ -69,6 +70,7 @@ def test_parser_destinations__agree_with_dask_client(self): class DaskRunnerRunPipelineTest(unittest.TestCase): """Test class used to introspect the dask runner via a debugger.""" + def setUp(self) -> None: self.pipeline = test_pipeline.TestPipeline(runner=DaskRunner()) @@ -83,6 +85,7 @@ def test_create_multiple(self): assert_that(pcoll, equal_to([1, 2, 3, 4])) def test_create_and_map(self): + def double(x): return x * 2 @@ -91,6 +94,7 @@ def double(x): assert_that(pcoll, equal_to([2])) def test_create_and_map_multiple(self): + def double(x): return x * 2 @@ -99,6 +103,7 @@ def double(x): assert_that(pcoll, equal_to([2, 4])) def test_create_and_map_many(self): + def double(x): return x * 2 @@ -107,6 +112,7 @@ def double(x): assert_that(pcoll, equal_to(list(range(2, 21, 2)))) def test_create_map_and_groupby(self): + def double(x): return x * 2, x @@ -115,6 +121,7 @@ def double(x): assert_that(pcoll, equal_to([(2, [1])])) def test_create_map_and_groupby_multiple(self): + def double(x): return x * 2, x @@ -127,6 +134,7 @@ def double(x): assert_that(pcoll, equal_to([(2, [1, 1]), (4, [2, 2]), (6, [3])])) def test_map_with_positional_side_input(self): + def mult_by(x, y): return x * y @@ -139,6 +147,7 @@ def mult_by(x, y): assert_that(pcoll, equal_to([3])) def test_map_with_keyword_side_input(self): + def mult_by(x, y): return x * y @@ -151,6 +160,7 @@ def mult_by(x, y): assert_that(pcoll, equal_to([3])) def test_pardo_side_inputs(self): + def cross_product(elem, sides): for side in sides: yield elem, side @@ -291,16 +301,13 @@ def test_multimap_multiside_input(self): assert_that( main | "first map" >> beam.Map( - lambda k, - d, - l: (k, sorted(d[k]), sorted([e[1] for e in l])), + lambda k, d, l: (k, sorted(d[k]), sorted([e[1] for e in l])), beam.pvalue.AsMultiMap(side), beam.pvalue.AsList(side), ) | "second map" >> beam.Map( - lambda k, - d, - l: (k[0], sorted(d[k[0]]), sorted([e[1] for e in l])), + lambda k, d, l: + (k[0], sorted(d[k[0]]), sorted([e[1] for e in l])), beam.pvalue.AsMultiMap(side), beam.pvalue.AsList(side), ), @@ -323,6 +330,7 @@ def test_multimap_side_input_type_coercion(self): ) def test_pardo_unfusable_side_inputs__one(self): + def cross_product(elem, sides): for side in sides: yield elem, side @@ -337,6 +345,7 @@ def cross_product(elem, sides): ) def test_pardo_unfusable_side_inputs__two(self): + def cross_product(elem, sides): for side in sides: yield elem, side @@ -357,6 +366,7 @@ def cross_product(elem, sides): ) def test_groupby_with_fixed_windows(self): + def double(x): return x * 2, x @@ -385,6 +395,7 @@ def test_groupby_string_keys(self): class ExpectingSideInputsFn(beam.DoFn): + def __init__(self, name): self._name = name diff --git a/sdks/python/apache_beam/runners/dask/overrides.py b/sdks/python/apache_beam/runners/dask/overrides.py index b952834f12d7..50653f841853 100644 --- a/sdks/python/apache_beam/runners/dask/overrides.py +++ b/sdks/python/apache_beam/runners/dask/overrides.py @@ -44,6 +44,7 @@ def get_windowing(self, inputs: t.Any) -> beam.Windowing: @typehints.with_input_types(K) @typehints.with_output_types(K) class _Reshuffle(beam.PTransform): + def expand(self, input_or_inputs): return beam.pvalue.PCollection.from_(input_or_inputs) @@ -59,6 +60,7 @@ def expand(self, input_or_inputs): @typehints.with_input_types(t.Tuple[K, V]) @typehints.with_output_types(t.Tuple[K, t.Iterable[V]]) class _GroupByKeyOnly(beam.PTransform): + def expand(self, input_or_inputs): return beam.pvalue.PCollection.from_(input_or_inputs) @@ -73,6 +75,7 @@ def infer_output_type(self, input_type): @typehints.with_input_types(t.Tuple[K, t.Iterable[V]]) @typehints.with_output_types(t.Tuple[K, t.Iterable[V]]) class _GroupAlsoByWindow(beam.ParDo): + def __init__(self, windowing): super().__init__(_GroupAlsoByWindowDoFn(windowing)) self.windowing = windowing @@ -84,6 +87,7 @@ def expand(self, input_or_inputs): @typehints.with_input_types(t.Tuple[K, V]) @typehints.with_output_types(t.Tuple[K, t.Iterable[V]]) class _GroupByKey(beam.PTransform): + def expand(self, input_or_inputs): return ( input_or_inputs @@ -93,6 +97,7 @@ def expand(self, input_or_inputs): class _Flatten(beam.PTransform): + def expand(self, input_or_inputs): if isinstance(input_or_inputs, beam.PCollection): # NOTE(cisaacstern): I needed this to avoid @@ -106,7 +111,9 @@ def expand(self, input_or_inputs): def dask_overrides() -> t.List[PTransformOverride]: + class CreateOverride(PTransformOverride): + def matches(self, applied_ptransform: AppliedPTransform) -> bool: return applied_ptransform.transform.__class__ == beam.Create @@ -115,6 +122,7 @@ def get_replacement_transform_for_applied_ptransform( return _Create(t.cast(beam.Create, applied_ptransform.transform).values) class ReshuffleOverride(PTransformOverride): + def matches(self, applied_ptransform: AppliedPTransform) -> bool: return applied_ptransform.transform.__class__ == beam.Reshuffle @@ -123,6 +131,7 @@ def get_replacement_transform_for_applied_ptransform( return _Reshuffle() class ReadOverride(PTransformOverride): + def matches(self, applied_ptransform: AppliedPTransform) -> bool: return applied_ptransform.transform.__class__ == beam.io.Read @@ -131,6 +140,7 @@ def get_replacement_transform_for_applied_ptransform( return _Read(t.cast(beam.io.Read, applied_ptransform.transform).source) class GroupByKeyOverride(PTransformOverride): + def matches(self, applied_ptransform: AppliedPTransform) -> bool: return applied_ptransform.transform.__class__ == beam.GroupByKey @@ -139,6 +149,7 @@ def get_replacement_transform_for_applied_ptransform( return _GroupByKey() class FlattenOverride(PTransformOverride): + def matches(self, applied_ptransform: AppliedPTransform) -> bool: return applied_ptransform.transform.__class__ == beam.Flatten diff --git a/sdks/python/apache_beam/runners/dask/transform_evaluator.py b/sdks/python/apache_beam/runners/dask/transform_evaluator.py index e3bd5fd87763..f52931ea808b 100644 --- a/sdks/python/apache_beam/runners/dask/transform_evaluator.py +++ b/sdks/python/apache_beam/runners/dask/transform_evaluator.py @@ -143,12 +143,14 @@ def apply(self, input_bag: OpInput, side_inputs: OpSide = None) -> db.Bag: class NoOp(DaskBagOp): """An identity on a dask bag: returns the input as-is.""" + def apply(self, input_bag: OpInput, side_inputs: OpSide = None) -> db.Bag: return input_bag class Create(DaskBagOp): """The beginning of a Beam pipeline; the input must be `None`.""" + def apply(self, input_bag: OpInput, side_inputs: OpSide = None) -> db.Bag: assert input_bag is None, 'Create expects no input!' original_transform = t.cast(_Create, self.transform) @@ -185,6 +187,7 @@ class ParDo(DaskBagOp): This consumes a sequence of items and returns a sequence of items. """ + def apply(self, input_bag: db.Bag, side_inputs: OpSide = None) -> db.Bag: transform = t.cast(apache_beam.ParDo, self.transform) @@ -226,7 +229,9 @@ def apply(self, input_bag: db.Bag, side_inputs: OpSide = None) -> db.Bag: class GroupByKey(DaskBagOp): """Group a PCollection into a mapping of keys to elements.""" + def apply(self, input_bag: db.Bag, side_inputs: OpSide = None) -> db.Bag: + def key(item): return item[0] @@ -239,6 +244,7 @@ def value(item): class Flatten(DaskBagOp): """Produces a flattened bag from a collection of bags.""" + def apply( self, input_bag: t.List[db.Bag], side_inputs: OpSide = None) -> db.Bag: assert isinstance(input_bag, list), 'Must take a sequence of bags!' diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_exercise_metrics_pipeline.py b/sdks/python/apache_beam/runners/dataflow/dataflow_exercise_metrics_pipeline.py index bfe56c7e38c2..8479c18ece3d 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_exercise_metrics_pipeline.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_exercise_metrics_pipeline.py @@ -130,6 +130,7 @@ def metric_matchers(): class UserMetricsDoFn(beam.DoFn): """Parse each line of input text into words.""" + def __init__(self): self.total_metric = Metrics.counter(self.__class__, 'total_values') self.dist_metric = Metrics.distribution( @@ -164,14 +165,8 @@ def apply_and_run(pipeline): | beam.GroupByKey() | 'm_out' >> beam.FlatMap( lambda x: [ - 1, - 2, - 3, - 4, - 5, - beam.pvalue.TaggedOutput('once', x), - beam.pvalue.TaggedOutput('twice', x), - beam.pvalue.TaggedOutput('twice', x) + 1, 2, 3, 4, 5, beam.pvalue.TaggedOutput('once', x), beam.pvalue. + TaggedOutput('twice', x), beam.pvalue.TaggedOutput('twice', x) ])) result = pipeline.run() result.wait_until_finish() diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_exercise_metrics_pipeline_test.py b/sdks/python/apache_beam/runners/dataflow/dataflow_exercise_metrics_pipeline_test.py index 909c15896a26..f5feb90ed88c 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_exercise_metrics_pipeline_test.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_exercise_metrics_pipeline_test.py @@ -32,6 +32,7 @@ class ExerciseMetricsPipelineTest(unittest.TestCase): + def run_pipeline(self, **opts): test_pipeline = TestPipeline(is_integration_test=True) argv = test_pipeline.get_full_options_as_args(**opts) diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_exercise_streaming_metrics_pipeline.py b/sdks/python/apache_beam/runners/dataflow/dataflow_exercise_streaming_metrics_pipeline.py index 01c0c5beb909..eae4070cfd42 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_exercise_streaming_metrics_pipeline.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_exercise_streaming_metrics_pipeline.py @@ -36,6 +36,7 @@ class StreamingUserMetricsDoFn(beam.DoFn): """Generates user metrics and outputs same element.""" + def __init__(self): self.double_message_counter = Metrics.counter( self.__class__, 'double_msg_counter_name') diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_exercise_streaming_metrics_pipeline_test.py b/sdks/python/apache_beam/runners/dataflow/dataflow_exercise_streaming_metrics_pipeline_test.py index 83bb35034642..0941a6932c02 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_exercise_streaming_metrics_pipeline_test.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_exercise_streaming_metrics_pipeline_test.py @@ -50,6 +50,7 @@ class ExerciseStreamingMetricsPipelineTest(unittest.TestCase): + def setUp(self): """Creates all required topics and subs.""" self.test_pipeline = TestPipeline(is_integration_test=True) diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_job_service.py b/sdks/python/apache_beam/runners/dataflow/dataflow_job_service.py index 710c71273e34..f9e7440ca633 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_job_service.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_job_service.py @@ -29,6 +29,7 @@ class DataflowBeamJob(local_job_service.BeamJob): """A representation of a single Beam job to be run on the Dataflow runner. """ + def _invoke_runner(self): """Actually calls Dataflow and waits for completion. """ @@ -45,8 +46,7 @@ def _invoke_runner(self): dataflow_runner.DataflowRunner.poll_for_job_completion( runner, self.result, - None, - lambda dataflow_state: self.set_state( + None, lambda dataflow_state: self.set_state( portable_runner.PipelineResult.pipeline_state_to_runner_api_state( self.result.api_jobstate_to_pipeline_state(dataflow_state)))) return self.result diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_job_service_test.py b/sdks/python/apache_beam/runners/dataflow/dataflow_job_service_test.py index e2f880085cb8..64ca62761c9c 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_job_service_test.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_job_service_test.py @@ -33,6 +33,7 @@ @unittest.skipIf(apiclient is None, 'GCP dependencies are not installed') class DirectPipelineResultTest(unittest.TestCase): + def test_dry_run(self): # Not an integration test that actually runs on Dataflow, # but does exercise (most of) the translation and setup code, @@ -58,6 +59,7 @@ def test_dry_run(self): @unittest.skipIf(apiclient is None, 'GCP dependencies are not installed') class DirectPipelineTemplateTest(unittest.TestCase): + def test_template(self): job_servicer = local_job_service.LocalJobServicer( None, beam_job_type=dataflow_job_service.DataflowBeamJob) diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_metrics.py b/sdks/python/apache_beam/runners/dataflow/dataflow_metrics.py index 78c3b64595b0..01d8a471a18c 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_metrics.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_metrics.py @@ -64,6 +64,7 @@ def _get_match(proto, filter_fn): class DataflowMetrics(MetricResults): """Implementation of MetricResults class for the Dataflow runner.""" + def __init__(self, dataflow_client=None, job_result=None, job_graph=None): """Initialize the Dataflow metrics object. @@ -101,8 +102,8 @@ def _translate_step_name(self, internal_name): 'Could not translate the internal step name %r since job graph is ' 'not available.' % internal_name) user_step_name = None - if (self._job_graph and internal_name in - self._job_graph.proto_pipeline.components.transforms.keys()): + if (self._job_graph and internal_name + in self._job_graph.proto_pipeline.components.transforms.keys()): # Dataflow Runner v2 with portable job submission uses proto transform map # IDs for step names. Also PTransform.unique_name maps to user step names. # Hence we lookup user step names based on the proto. diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_metrics_test.py b/sdks/python/apache_beam/runners/dataflow/dataflow_metrics_test.py index 86e71f9c1ed2..6cac5244734a 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_metrics_test.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_metrics_test.py @@ -53,6 +53,7 @@ class DictToObject(object): """Translate from a dict(list()) structure to an object structure""" + def __init__(self, data): for name, value in data.items(): setattr(self, name, self._wrap(value)) @@ -308,8 +309,7 @@ class TestDataflowMetrics(unittest.TestCase): "additionalProperties": [ { "key": "original_name", - "value": - "ToIsmRecordForMultimap-out0-ElementCount" + "value": "ToIsmRecordForMultimap-out0-ElementCount" }, # yapf: disable { "key": "output_user_name", @@ -332,13 +332,13 @@ class TestDataflowMetrics(unittest.TestCase): "additionalProperties": [ { "key": "original_name", - "value": - "ToIsmRecordForMultimap-out0-ElementCount" + "value": "ToIsmRecordForMultimap-out0-ElementCount" }, # yapf: disable { "key": "output_user_name", "value": "ToIsmRecordForMultimap-out0" - }, { + }, + { "key": "tentative", "value": "true" } ] diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py index 162ace3ca451..dd4c8e257500 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py @@ -165,8 +165,8 @@ def rank_error(msg): # Check that job is in a post-preparation state before starting the # final countdown. - if (str(response.currentState) not in ('JOB_STATE_PENDING', - 'JOB_STATE_QUEUED')): + if (str(response.currentState) + not in ('JOB_STATE_PENDING', 'JOB_STATE_QUEUED')): # The job has failed; ensure we see any final error messages. sleep_secs = 1.0 # poll faster during the final countdown final_countdown_timer_secs -= sleep_secs @@ -244,6 +244,7 @@ class SideInputVisitor(PipelineVisitor): TODO(BEAM-115): Once Python SDK is compatible with the new Runner API, we could directly replace the coder instead of mutating the element type. """ + def visit_transform(self, transform_node): if isinstance(transform_node.transform, ParDo): new_side_inputs = [] @@ -284,6 +285,7 @@ class FlattenInputVisitor(PipelineVisitor): """A visitor that replaces the element type for input ``PCollections``s of a ``Flatten`` transform with that of the output ``PCollection``. """ + def visit_transform(self, transform_node): # Imported here to avoid circular dependencies. # pylint: disable=wrong-import-order, wrong-import-position @@ -306,6 +308,7 @@ class CombineFnVisitor(PipelineVisitor): """Checks if `CombineFn` has non-default setup or teardown methods. If yes, raises `ValueError`. """ + def visit_transform(self, applied_transform): transform = applied_transform.transform if isinstance(transform, core.ParDo) and isinstance( @@ -562,6 +565,7 @@ def get_default_gcp_region(self): class _DataflowSideInput(beam.pvalue.AsSideInput): """Wraps a side input as a dataflow-compatible side input.""" + def _view_options(self): return { 'data': self._data, @@ -661,6 +665,7 @@ def _is_runner_v2_disabled(options): class _DataflowIterableSideInput(_DataflowSideInput): """Wraps an iterable side input as dataflow-compatible side input.""" + def __init__(self, side_input): # pylint: disable=protected-access self.pvalue = side_input.pvalue @@ -675,6 +680,7 @@ def __init__(self, side_input): class _DataflowMultimapSideInput(_DataflowSideInput): """Wraps a multimap side input as dataflow-compatible side input.""" + def __init__(self, side_input): # pylint: disable=protected-access self.pvalue = side_input.pvalue @@ -689,6 +695,7 @@ def __init__(self, side_input): class DataflowPipelineResult(PipelineResult): """Represents the state of a pipeline run on the Dataflow service.""" + def __init__(self, job, runner): """Initialize a new DataflowPipelineResult instance. @@ -848,6 +855,7 @@ def __repr__(self): class DataflowRuntimeException(Exception): """Indicates an error has occurred in running this pipeline.""" + def __init__(self, msg, result): super().__init__(msg) self.result = result diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py index b5568305ce65..05d44c57dace 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py @@ -63,6 +63,7 @@ # TODO: Should not subclass ParDo. Switch to PTransform as soon as # composite transforms support display data. class SpecialParDo(beam.ParDo): + def __init__(self, fn, now): super().__init__(fn) self.fn = fn @@ -76,6 +77,7 @@ def display_data(self): class SpecialDoFn(beam.DoFn): + def display_data(self): return {'dofn_value': 42} @@ -85,6 +87,7 @@ def process(self): @unittest.skipIf(apiclient is None, 'GCP dependencies are not installed') class DataflowRunnerTest(unittest.TestCase, ExtraAssertionsMixin): + def setUp(self): self.default_properties = [ '--job_name=test-job', @@ -102,6 +105,7 @@ def test_wait_until_finish(self, patched_time_sleep): values_enum = dataflow_api.Job.CurrentStateValueValuesEnum class MockDataflowRunner(object): + def __init__(self, states): self.dataflow_client = mock.MagicMock() self.job = mock.MagicMock() @@ -169,6 +173,7 @@ def test_cancel(self, patched_time_sleep): values_enum = dataflow_api.Job.CurrentStateValueValuesEnum class MockDataflowRunner(object): + def __init__(self, state, cancel_result): self.dataflow_client = mock.MagicMock() self.job = mock.MagicMock() @@ -352,9 +357,7 @@ def test_side_input_visitor(self): pc = p | beam.Create([]) transform = beam.Map( - lambda x, - y, - z: (x, y, z), + lambda x, y, z: (x, y, z), beam.pvalue.AsSingleton(pc), beam.pvalue.AsMultiMap(pc)) applied_transform = AppliedPTransform(None, transform, "label", {'pc': pc}) @@ -454,7 +457,9 @@ def test_get_default_gcp_region_ignores_error( 'https://github.com/apache/beam/issues/18716: enable once ' 'CombineFnVisitor is fixed') def test_unsupported_combinefn_detection(self): + class CombinerWithNonDefaultSetupTeardown(combiners.CountCombineFn): + def setup(self, *args, **kwargs): pass @@ -483,7 +488,9 @@ def teardown(self, *args, **kwargs): self.fail('ValueError raised unexpectedly') def test_pack_combiners(self): + class PackableCombines(beam.PTransform): + def annotations(self): return {python_urns.APPLY_COMBINER_PACKING: b''} diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py index 4e65156f3bc7..daf84febbf6d 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py @@ -88,6 +88,7 @@ class Environment(object): """Wrapper for a dataflow Environment protobuf.""" + def __init__( self, packages, @@ -229,8 +230,8 @@ def __init__( container_image = dataflow.SdkHarnessContainerImage() container_image.containerImage = container_image_url container_image.useSingleCorePerContainer = ( - common_urns.protocols.MULTI_CORE_BUNDLE_PROCESSING.urn not in - environment.capabilities) + common_urns.protocols.MULTI_CORE_BUNDLE_PROCESSING.urn + not in environment.capabilities) container_image.environmentId = id for capability in environment.capabilities: container_image.capabilities.append(capability) @@ -316,7 +317,9 @@ def _get_python_sdk_name(self): class Job(object): """Wrapper for a dataflow Job protobuf.""" + def __str__(self): + def encode_shortstrings(input_buffer, errors='strict'): """Encoder (from Unicode) that suppresses long base64 strings.""" original_len = len(input_buffer) @@ -485,6 +488,7 @@ class DataflowApplicationClient(object): _HASH_CHUNK_SIZE = 1024 * 8 _GCS_CACHE_PREFIX = "artifact_cache" """A Dataflow API client used by application code to create and query jobs.""" + def __init__(self, options, root_staging_location=None): """Initializes a Dataflow API client object.""" self.standard_options = options.view_as(StandardOptions) @@ -1054,10 +1058,9 @@ def job_id_for_name(self, job_name): pageToken=token) response = self._client.projects_locations_jobs.List(request) for job in response.jobs: - if (job.name == job_name and job.currentState in [ - dataflow.Job.CurrentStateValueValuesEnum.JOB_STATE_RUNNING, - dataflow.Job.CurrentStateValueValuesEnum.JOB_STATE_DRAINING - ]): + if (job.name == job_name and job.currentState + in [dataflow.Job.CurrentStateValueValuesEnum.JOB_STATE_RUNNING, + dataflow.Job.CurrentStateValueValuesEnum.JOB_STATE_DRAINING]): return job.id token = response.nextPageToken if token is None: @@ -1066,6 +1069,7 @@ def job_id_for_name(self, job_name): class MetricUpdateTranslators(object): """Translators between accumulators and dataflow metric updates.""" + @staticmethod def translate_boolean(accumulator, metric_update_proto): metric_update_proto.boolean = accumulator.value @@ -1099,6 +1103,7 @@ def translate_scalar_counter_float(accumulator, metric_update_proto): class _LegacyDataflowStager(Stager): + def __init__(self, dataflow_application_client): super().__init__() self._dataflow_application_client = dataflow_application_client @@ -1215,9 +1220,8 @@ def get_response_encoding(): def _verify_interpreter_version_is_supported(pipeline_options): - if ('%s.%s' % - (sys.version_info[0], - sys.version_info[1]) in _PYTHON_VERSIONS_SUPPORTED_BY_DATAFLOW): + if ('%s.%s' % (sys.version_info[0], sys.version_info[1]) + in _PYTHON_VERSIONS_SUPPORTED_BY_DATAFLOW): return if 'dev' in beam_version.__version__: diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py index d055065cb9d9..69ad6009ac5e 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py @@ -60,6 +60,7 @@ @unittest.skipIf(apiclient is None, 'GCP dependencies are not installed') class UtilTest(unittest.TestCase): + @unittest.skip("Enable once BEAM-1080 is fixed.") def test_create_application_client(self): pipeline_options = PipelineOptions() @@ -1671,6 +1672,7 @@ def exists_return_value(*args): self.assertEqual(pipeline, pipeline_expected) def test_stage_file_with_retry(self): + def effect(self, *args, **kwargs): nonlocal count count += 1 @@ -1679,6 +1681,7 @@ def effect(self, *args, **kwargs): raise Exception("This exception is raised for testing purpose.") class Unseekable(io.IOBase): + def seekable(self): return False diff --git a/sdks/python/apache_beam/runners/dataflow/internal/clients/cloudbuild/cloudbuild_v1_client.py b/sdks/python/apache_beam/runners/dataflow/internal/clients/cloudbuild/cloudbuild_v1_client.py index 9d699aba5892..52941bfe0b4b 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/clients/cloudbuild/cloudbuild_v1_client.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/clients/cloudbuild/cloudbuild_v1_client.py @@ -130,8 +130,7 @@ def RegionalWebhook(self, request, global_params=None): request_field='httpBody', request_type_name='CloudbuildLocationsRegionalWebhookRequest', response_type_name='Empty', - supports_download=False, - ) + supports_download=False, ) class OperationsService(base_api.BaseApiService): """Service class for the operations resource.""" @@ -165,8 +164,7 @@ def Cancel(self, request, global_params=None): request_field='cancelOperationRequest', request_type_name='CloudbuildOperationsCancelRequest', response_type_name='Empty', - supports_download=False, - ) + supports_download=False, ) def Get(self, request, global_params=None): r"""Gets the latest state of a long-running operation. Clients can use this method to poll the operation result at intervals as recommended by the API service. @@ -191,8 +189,7 @@ def Get(self, request, global_params=None): request_field='', request_type_name='CloudbuildOperationsGetRequest', response_type_name='Operation', - supports_download=False, - ) + supports_download=False, ) class ProjectsBuildsService(base_api.BaseApiService): """Service class for the projects_builds resource.""" @@ -226,8 +223,7 @@ def Approve(self, request, global_params=None): request_field='approveBuildRequest', request_type_name='CloudbuildProjectsBuildsApproveRequest', response_type_name='Operation', - supports_download=False, - ) + supports_download=False, ) def Cancel(self, request, global_params=None): r"""Cancels a build in progress. @@ -251,8 +247,7 @@ def Cancel(self, request, global_params=None): request_field='', request_type_name='CancelBuildRequest', response_type_name='Build', - supports_download=False, - ) + supports_download=False, ) def Create(self, request, global_params=None): r"""Starts a build with the specified configuration. This method returns a long-running `Operation`, which includes the build ID. Pass the build ID to `GetBuild` to determine the build status (such as `SUCCESS` or `FAILURE`). @@ -276,8 +271,7 @@ def Create(self, request, global_params=None): request_field='build', request_type_name='CloudbuildProjectsBuildsCreateRequest', response_type_name='Operation', - supports_download=False, - ) + supports_download=False, ) def Get(self, request, global_params=None): r"""Returns information about a previously requested build. The `Build` that is returned includes its status (such as `SUCCESS`, `FAILURE`, or `WORKING`), and timing information. @@ -301,8 +295,7 @@ def Get(self, request, global_params=None): request_field='', request_type_name='CloudbuildProjectsBuildsGetRequest', response_type_name='Build', - supports_download=False, - ) + supports_download=False, ) def List(self, request, global_params=None): r"""Lists previously requested builds. Previously requested builds may still be in-progress, or may have finished successfully or unsuccessfully. @@ -326,8 +319,7 @@ def List(self, request, global_params=None): request_field='', request_type_name='CloudbuildProjectsBuildsListRequest', response_type_name='ListBuildsResponse', - supports_download=False, - ) + supports_download=False, ) def Retry(self, request, global_params=None): r"""Creates a new build based on the specified build. This method creates a new build using the original build request, which may or may not result in an identical build. For triggered builds: * Triggered builds resolve to a precise revision; therefore a retry of a triggered build will result in a build that uses the same revision. For non-triggered builds that specify `RepoSource`: * If the original build built from the tip of a branch, the retried build will build from the tip of that branch, which may not be the same revision as the original build. * If the original build specified a commit sha or revision ID, the retried build will use the identical source. For builds that specify `StorageSource`: * If the original build pulled source from Google Cloud Storage without specifying the generation of the object, the new build will use the current object, which may be different from the original build source. * If the original build pulled source from Cloud Storage and specified the generation of the object, the new build will attempt to use the same object, which may or may not be available depending on the bucket's lifecycle management settings. @@ -351,8 +343,7 @@ def Retry(self, request, global_params=None): request_field='', request_type_name='RetryBuildRequest', response_type_name='Operation', - supports_download=False, - ) + supports_download=False, ) class ProjectsGithubEnterpriseConfigsService(base_api.BaseApiService): """Service class for the projects_githubEnterpriseConfigs resource.""" @@ -388,8 +379,7 @@ def Create(self, request, global_params=None): request_type_name= 'CloudbuildProjectsGithubEnterpriseConfigsCreateRequest', response_type_name='Operation', - supports_download=False, - ) + supports_download=False, ) def Delete(self, request, global_params=None): r"""Delete an association between a GCP project and a GitHub Enterprise server. @@ -416,8 +406,7 @@ def Delete(self, request, global_params=None): request_type_name= 'CloudbuildProjectsGithubEnterpriseConfigsDeleteRequest', response_type_name='Operation', - supports_download=False, - ) + supports_download=False, ) def Get(self, request, global_params=None): r"""Retrieve a GitHubEnterpriseConfig. @@ -443,8 +432,7 @@ def Get(self, request, global_params=None): request_field='', request_type_name='CloudbuildProjectsGithubEnterpriseConfigsGetRequest', response_type_name='GitHubEnterpriseConfig', - supports_download=False, - ) + supports_download=False, ) def List(self, request, global_params=None): r"""List all GitHubEnterpriseConfigs for a given project. @@ -470,8 +458,7 @@ def List(self, request, global_params=None): request_type_name= 'CloudbuildProjectsGithubEnterpriseConfigsListRequest', response_type_name='ListGithubEnterpriseConfigsResponse', - supports_download=False, - ) + supports_download=False, ) def Patch(self, request, global_params=None): r"""Update an association between a GCP project and a GitHub Enterprise server. @@ -498,8 +485,7 @@ def Patch(self, request, global_params=None): request_type_name= 'CloudbuildProjectsGithubEnterpriseConfigsPatchRequest', response_type_name='Operation', - supports_download=False, - ) + supports_download=False, ) class ProjectsLocationsBitbucketServerConfigsConnectedRepositoriesService( base_api.BaseApiService): @@ -540,8 +526,7 @@ def BatchCreate(self, request, global_params=None): request_type_name= 'CloudbuildProjectsLocationsBitbucketServerConfigsConnectedRepositoriesBatchCreateRequest', response_type_name='Operation', - supports_download=False, - ) + supports_download=False, ) class ProjectsLocationsBitbucketServerConfigsReposService( base_api.BaseApiService): @@ -581,8 +566,7 @@ def List(self, request, global_params=None): request_type_name= 'CloudbuildProjectsLocationsBitbucketServerConfigsReposListRequest', response_type_name='ListBitbucketServerRepositoriesResponse', - supports_download=False, - ) + supports_download=False, ) class ProjectsLocationsBitbucketServerConfigsService(base_api.BaseApiService): """Service class for the projects_locations_bitbucketServerConfigs resource.""" @@ -621,8 +605,7 @@ def AddBitbucketServerConnectedRepository( request_type_name= 'CloudbuildProjectsLocationsBitbucketServerConfigsAddBitbucketServerConnectedRepositoryRequest', response_type_name='AddBitbucketServerConnectedRepositoryResponse', - supports_download=False, - ) + supports_download=False, ) def Create(self, request, global_params=None): r"""Creates a new `BitbucketServerConfig`. This API is experimental. @@ -649,8 +632,7 @@ def Create(self, request, global_params=None): request_type_name= 'CloudbuildProjectsLocationsBitbucketServerConfigsCreateRequest', response_type_name='Operation', - supports_download=False, - ) + supports_download=False, ) def Delete(self, request, global_params=None): r"""Delete a `BitbucketServerConfig`. This API is experimental. @@ -677,8 +659,7 @@ def Delete(self, request, global_params=None): request_type_name= 'CloudbuildProjectsLocationsBitbucketServerConfigsDeleteRequest', response_type_name='Operation', - supports_download=False, - ) + supports_download=False, ) def Get(self, request, global_params=None): r"""Retrieve a `BitbucketServerConfig`. This API is experimental. @@ -705,8 +686,7 @@ def Get(self, request, global_params=None): request_type_name= 'CloudbuildProjectsLocationsBitbucketServerConfigsGetRequest', response_type_name='BitbucketServerConfig', - supports_download=False, - ) + supports_download=False, ) def List(self, request, global_params=None): r"""List all `BitbucketServerConfigs` for a given project. This API is experimental. @@ -733,8 +713,7 @@ def List(self, request, global_params=None): request_type_name= 'CloudbuildProjectsLocationsBitbucketServerConfigsListRequest', response_type_name='ListBitbucketServerConfigsResponse', - supports_download=False, - ) + supports_download=False, ) def Patch(self, request, global_params=None): r"""Updates an existing `BitbucketServerConfig`. This API is experimental. @@ -761,8 +740,7 @@ def Patch(self, request, global_params=None): request_type_name= 'CloudbuildProjectsLocationsBitbucketServerConfigsPatchRequest', response_type_name='Operation', - supports_download=False, - ) + supports_download=False, ) def RemoveBitbucketServerConnectedRepository( self, request, global_params=None): @@ -791,8 +769,7 @@ def RemoveBitbucketServerConnectedRepository( request_type_name= 'CloudbuildProjectsLocationsBitbucketServerConfigsRemoveBitbucketServerConnectedRepositoryRequest', response_type_name='Empty', - supports_download=False, - ) + supports_download=False, ) class ProjectsLocationsBuildsService(base_api.BaseApiService): """Service class for the projects_locations_builds resource.""" @@ -827,8 +804,7 @@ def Approve(self, request, global_params=None): request_field='approveBuildRequest', request_type_name='CloudbuildProjectsLocationsBuildsApproveRequest', response_type_name='Operation', - supports_download=False, - ) + supports_download=False, ) def Cancel(self, request, global_params=None): r"""Cancels a build in progress. @@ -854,8 +830,7 @@ def Cancel(self, request, global_params=None): request_field='', request_type_name='CancelBuildRequest', response_type_name='Build', - supports_download=False, - ) + supports_download=False, ) def Create(self, request, global_params=None): r"""Starts a build with the specified configuration. This method returns a long-running `Operation`, which includes the build ID. Pass the build ID to `GetBuild` to determine the build status (such as `SUCCESS` or `FAILURE`). @@ -880,8 +855,7 @@ def Create(self, request, global_params=None): request_field='build', request_type_name='CloudbuildProjectsLocationsBuildsCreateRequest', response_type_name='Operation', - supports_download=False, - ) + supports_download=False, ) def Get(self, request, global_params=None): r"""Returns information about a previously requested build. The `Build` that is returned includes its status (such as `SUCCESS`, `FAILURE`, or `WORKING`), and timing information. @@ -907,8 +881,7 @@ def Get(self, request, global_params=None): request_field='', request_type_name='CloudbuildProjectsLocationsBuildsGetRequest', response_type_name='Build', - supports_download=False, - ) + supports_download=False, ) def List(self, request, global_params=None): r"""Lists previously requested builds. Previously requested builds may still be in-progress, or may have finished successfully or unsuccessfully. @@ -933,8 +906,7 @@ def List(self, request, global_params=None): request_field='', request_type_name='CloudbuildProjectsLocationsBuildsListRequest', response_type_name='ListBuildsResponse', - supports_download=False, - ) + supports_download=False, ) def Retry(self, request, global_params=None): r"""Creates a new build based on the specified build. This method creates a new build using the original build request, which may or may not result in an identical build. For triggered builds: * Triggered builds resolve to a precise revision; therefore a retry of a triggered build will result in a build that uses the same revision. For non-triggered builds that specify `RepoSource`: * If the original build built from the tip of a branch, the retried build will build from the tip of that branch, which may not be the same revision as the original build. * If the original build specified a commit sha or revision ID, the retried build will use the identical source. For builds that specify `StorageSource`: * If the original build pulled source from Google Cloud Storage without specifying the generation of the object, the new build will use the current object, which may be different from the original build source. * If the original build pulled source from Cloud Storage and specified the generation of the object, the new build will attempt to use the same object, which may or may not be available depending on the bucket's lifecycle management settings. @@ -960,8 +932,7 @@ def Retry(self, request, global_params=None): request_field='', request_type_name='RetryBuildRequest', response_type_name='Operation', - supports_download=False, - ) + supports_download=False, ) class ProjectsLocationsGithubEnterpriseConfigsService(base_api.BaseApiService ): @@ -1000,8 +971,7 @@ def Create(self, request, global_params=None): request_type_name= 'CloudbuildProjectsLocationsGithubEnterpriseConfigsCreateRequest', response_type_name='Operation', - supports_download=False, - ) + supports_download=False, ) def Delete(self, request, global_params=None): r"""Delete an association between a GCP project and a GitHub Enterprise server. @@ -1029,8 +999,7 @@ def Delete(self, request, global_params=None): request_type_name= 'CloudbuildProjectsLocationsGithubEnterpriseConfigsDeleteRequest', response_type_name='Operation', - supports_download=False, - ) + supports_download=False, ) def Get(self, request, global_params=None): r"""Retrieve a GitHubEnterpriseConfig. @@ -1057,8 +1026,7 @@ def Get(self, request, global_params=None): request_type_name= 'CloudbuildProjectsLocationsGithubEnterpriseConfigsGetRequest', response_type_name='GitHubEnterpriseConfig', - supports_download=False, - ) + supports_download=False, ) def List(self, request, global_params=None): r"""List all GitHubEnterpriseConfigs for a given project. @@ -1085,8 +1053,7 @@ def List(self, request, global_params=None): request_type_name= 'CloudbuildProjectsLocationsGithubEnterpriseConfigsListRequest', response_type_name='ListGithubEnterpriseConfigsResponse', - supports_download=False, - ) + supports_download=False, ) def Patch(self, request, global_params=None): r"""Update an association between a GCP project and a GitHub Enterprise server. @@ -1113,8 +1080,7 @@ def Patch(self, request, global_params=None): request_type_name= 'CloudbuildProjectsLocationsGithubEnterpriseConfigsPatchRequest', response_type_name='Operation', - supports_download=False, - ) + supports_download=False, ) class ProjectsLocationsOperationsService(base_api.BaseApiService): """Service class for the projects_locations_operations resource.""" @@ -1150,8 +1116,7 @@ def Cancel(self, request, global_params=None): request_field='cancelOperationRequest', request_type_name='CloudbuildProjectsLocationsOperationsCancelRequest', response_type_name='Empty', - supports_download=False, - ) + supports_download=False, ) def Get(self, request, global_params=None): r"""Gets the latest state of a long-running operation. Clients can use this method to poll the operation result at intervals as recommended by the API service. @@ -1177,8 +1142,7 @@ def Get(self, request, global_params=None): request_field='', request_type_name='CloudbuildProjectsLocationsOperationsGetRequest', response_type_name='Operation', - supports_download=False, - ) + supports_download=False, ) class ProjectsLocationsTriggersService(base_api.BaseApiService): """Service class for the projects_locations_triggers resource.""" @@ -1213,8 +1177,7 @@ def Create(self, request, global_params=None): request_field='buildTrigger', request_type_name='CloudbuildProjectsLocationsTriggersCreateRequest', response_type_name='BuildTrigger', - supports_download=False, - ) + supports_download=False, ) def Delete(self, request, global_params=None): r"""Deletes a `BuildTrigger` by its project ID and trigger ID. This API is experimental. @@ -1240,8 +1203,7 @@ def Delete(self, request, global_params=None): request_field='', request_type_name='CloudbuildProjectsLocationsTriggersDeleteRequest', response_type_name='Empty', - supports_download=False, - ) + supports_download=False, ) def Get(self, request, global_params=None): r"""Returns information about a `BuildTrigger`. This API is experimental. @@ -1267,8 +1229,7 @@ def Get(self, request, global_params=None): request_field='', request_type_name='CloudbuildProjectsLocationsTriggersGetRequest', response_type_name='BuildTrigger', - supports_download=False, - ) + supports_download=False, ) def List(self, request, global_params=None): r"""Lists existing `BuildTrigger`s. This API is experimental. @@ -1293,8 +1254,7 @@ def List(self, request, global_params=None): request_field='', request_type_name='CloudbuildProjectsLocationsTriggersListRequest', response_type_name='ListBuildTriggersResponse', - supports_download=False, - ) + supports_download=False, ) def Patch(self, request, global_params=None): r"""Updates a `BuildTrigger` by its project ID and trigger ID. This API is experimental. @@ -1320,8 +1280,7 @@ def Patch(self, request, global_params=None): request_field='buildTrigger', request_type_name='CloudbuildProjectsLocationsTriggersPatchRequest', response_type_name='BuildTrigger', - supports_download=False, - ) + supports_download=False, ) def Run(self, request, global_params=None): r"""Runs a `BuildTrigger` at a particular source revision. @@ -1347,8 +1306,7 @@ def Run(self, request, global_params=None): request_field='runBuildTriggerRequest', request_type_name='CloudbuildProjectsLocationsTriggersRunRequest', response_type_name='Operation', - supports_download=False, - ) + supports_download=False, ) def Webhook(self, request, global_params=None): r"""ReceiveTriggerWebhook [Experimental] is called when the API receives a webhook request targeted at a specific trigger. @@ -1374,8 +1332,7 @@ def Webhook(self, request, global_params=None): request_field='httpBody', request_type_name='CloudbuildProjectsLocationsTriggersWebhookRequest', response_type_name='ReceiveTriggerWebhookResponse', - supports_download=False, - ) + supports_download=False, ) class ProjectsLocationsWorkerPoolsService(base_api.BaseApiService): """Service class for the projects_locations_workerPools resource.""" @@ -1411,8 +1368,7 @@ def Create(self, request, global_params=None): request_field='workerPool', request_type_name='CloudbuildProjectsLocationsWorkerPoolsCreateRequest', response_type_name='Operation', - supports_download=False, - ) + supports_download=False, ) def Delete(self, request, global_params=None): r"""Deletes a `WorkerPool`. @@ -1438,8 +1394,7 @@ def Delete(self, request, global_params=None): request_field='', request_type_name='CloudbuildProjectsLocationsWorkerPoolsDeleteRequest', response_type_name='Operation', - supports_download=False, - ) + supports_download=False, ) def Get(self, request, global_params=None): r"""Returns details of a `WorkerPool`. @@ -1465,8 +1420,7 @@ def Get(self, request, global_params=None): request_field='', request_type_name='CloudbuildProjectsLocationsWorkerPoolsGetRequest', response_type_name='WorkerPool', - supports_download=False, - ) + supports_download=False, ) def List(self, request, global_params=None): r"""Lists `WorkerPool`s. @@ -1492,8 +1446,7 @@ def List(self, request, global_params=None): request_field='', request_type_name='CloudbuildProjectsLocationsWorkerPoolsListRequest', response_type_name='ListWorkerPoolsResponse', - supports_download=False, - ) + supports_download=False, ) def Patch(self, request, global_params=None): r"""Updates a `WorkerPool`. @@ -1519,8 +1472,7 @@ def Patch(self, request, global_params=None): request_field='workerPool', request_type_name='CloudbuildProjectsLocationsWorkerPoolsPatchRequest', response_type_name='Operation', - supports_download=False, - ) + supports_download=False, ) class ProjectsLocationsService(base_api.BaseApiService): """Service class for the projects_locations resource.""" @@ -1562,8 +1514,7 @@ def Create(self, request, global_params=None): request_field='buildTrigger', request_type_name='CloudbuildProjectsTriggersCreateRequest', response_type_name='BuildTrigger', - supports_download=False, - ) + supports_download=False, ) def Delete(self, request, global_params=None): r"""Deletes a `BuildTrigger` by its project ID and trigger ID. This API is experimental. @@ -1587,8 +1538,7 @@ def Delete(self, request, global_params=None): request_field='', request_type_name='CloudbuildProjectsTriggersDeleteRequest', response_type_name='Empty', - supports_download=False, - ) + supports_download=False, ) def Get(self, request, global_params=None): r"""Returns information about a `BuildTrigger`. This API is experimental. @@ -1612,8 +1562,7 @@ def Get(self, request, global_params=None): request_field='', request_type_name='CloudbuildProjectsTriggersGetRequest', response_type_name='BuildTrigger', - supports_download=False, - ) + supports_download=False, ) def List(self, request, global_params=None): r"""Lists existing `BuildTrigger`s. This API is experimental. @@ -1637,8 +1586,7 @@ def List(self, request, global_params=None): request_field='', request_type_name='CloudbuildProjectsTriggersListRequest', response_type_name='ListBuildTriggersResponse', - supports_download=False, - ) + supports_download=False, ) def Patch(self, request, global_params=None): r"""Updates a `BuildTrigger` by its project ID and trigger ID. This API is experimental. @@ -1662,8 +1610,7 @@ def Patch(self, request, global_params=None): request_field='buildTrigger', request_type_name='CloudbuildProjectsTriggersPatchRequest', response_type_name='BuildTrigger', - supports_download=False, - ) + supports_download=False, ) def Run(self, request, global_params=None): r"""Runs a `BuildTrigger` at a particular source revision. @@ -1687,8 +1634,7 @@ def Run(self, request, global_params=None): request_field='repoSource', request_type_name='CloudbuildProjectsTriggersRunRequest', response_type_name='Operation', - supports_download=False, - ) + supports_download=False, ) def Webhook(self, request, global_params=None): r"""ReceiveTriggerWebhook [Experimental] is called when the API receives a webhook request targeted at a specific trigger. @@ -1712,8 +1658,7 @@ def Webhook(self, request, global_params=None): request_field='httpBody', request_type_name='CloudbuildProjectsTriggersWebhookRequest', response_type_name='ReceiveTriggerWebhookResponse', - supports_download=False, - ) + supports_download=False, ) class ProjectsService(base_api.BaseApiService): """Service class for the projects resource.""" @@ -1755,5 +1700,4 @@ def Webhook(self, request, global_params=None): request_field='httpBody', request_type_name='CloudbuildWebhookRequest', response_type_name='Empty', - supports_download=False, - ) + supports_download=False, ) diff --git a/sdks/python/apache_beam/runners/dataflow/internal/clients/cloudbuild/cloudbuild_v1_messages.py b/sdks/python/apache_beam/runners/dataflow/internal/clients/cloudbuild/cloudbuild_v1_messages.py index 99edce0c45e6..6d51b0c79417 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/clients/cloudbuild/cloudbuild_v1_messages.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/clients/cloudbuild/cloudbuild_v1_messages.py @@ -91,6 +91,7 @@ class ApprovalResult(_messages.Message): rendered by the UI differently. An example use case is a link to an external job that approved this Build. """ + class DecisionValueValuesEnum(_messages.Enum): r"""Required. The decision of this manual approval. @@ -461,6 +462,7 @@ class Build(_messages.Message): warnings: Output only. Non-fatal problems encountered during the execution of the build. """ + class StatusValueValuesEnum(_messages.Enum): r"""Output only. Status of the build. @@ -499,6 +501,7 @@ class SubstitutionsValue(_messages.Message): Fields: additionalProperties: Additional properties of type SubstitutionsValue """ + class AdditionalProperty(_messages.Message): r"""An additional property for a SubstitutionsValue object. @@ -527,6 +530,7 @@ class TimingValue(_messages.Message): Fields: additionalProperties: Additional properties of type TimingValue """ + class AdditionalProperty(_messages.Message): r"""An additional property for a TimingValue object. @@ -584,6 +588,7 @@ class BuildApproval(_messages.Message): result: Output only. Result of manual approval for this Build. state: Output only. The state of this build's approval. """ + class StateValueValuesEnum(_messages.Enum): r"""Output only. The state of this build's approval. @@ -677,6 +682,7 @@ class BuildOptions(_messages.Message): configuration. workerPool: This field deprecated; please use `pool.name` instead. """ + class LogStreamingOptionValueValuesEnum(_messages.Enum): r"""Option to define build log streaming behavior to Google Cloud Storage. @@ -852,6 +858,7 @@ class BuildStep(_messages.Message): start when all previous build steps in the `Build.Steps` list have completed successfully. """ + class StatusValueValuesEnum(_messages.Enum): r"""Output only. Status of the build step. At this time, build step status is only updated on build completion; step status is not updated in real- @@ -981,6 +988,7 @@ class BuildTrigger(_messages.Message): webhookConfig: WebhookConfig describes the configuration of a trigger that creates a build whenever a webhook is sent to a trigger's webhook URL. """ + class EventTypeValueValuesEnum(_messages.Enum): r"""EventType allows the user to explicitly set the type of event to which this BuildTrigger should respond. This field will be validated against the @@ -1011,6 +1019,7 @@ class SubstitutionsValue(_messages.Message): Fields: additionalProperties: Additional properties of type SubstitutionsValue """ + class AdditionalProperty(_messages.Message): r"""An additional property for a SubstitutionsValue object. @@ -2119,6 +2128,7 @@ class FailureInfo(_messages.Message): detail: Explains the failure issue in more detail using hard-coded text. type: The name of the failure. """ + class TypeValueValuesEnum(_messages.Enum): r"""The name of the failure. @@ -2181,6 +2191,7 @@ class GitFileSource(_messages.Message): the trigger invocation originated is assumed to be the repo from which to read the specified path. """ + class RepoTypeValueValuesEnum(_messages.Enum): r"""See RepoType above. @@ -2321,6 +2332,7 @@ class GitRepoSource(_messages.Message): repoType: See RepoType below. uri: The URI of the repo (required). """ + class RepoTypeValueValuesEnum(_messages.Enum): r"""See RepoType below. @@ -2391,6 +2403,7 @@ class Hash(_messages.Message): type: The type of hash that was performed. value: The hash value. """ + class TypeValueValuesEnum(_messages.Enum): r"""The type of hash that was performed. @@ -2436,6 +2449,7 @@ class HttpBody(_messages.Message): extensions: Application specific response metadata. Must be set in the first response for streaming APIs. """ + @encoding.MapUnrecognizedFields('additionalProperties') class ExtensionsValueListEntry(_messages.Message): r"""A ExtensionsValueListEntry object. @@ -2448,6 +2462,7 @@ class ExtensionsValueListEntry(_messages.Message): additionalProperties: Properties of the object. Contains field @type with type URL. """ + class AdditionalProperty(_messages.Message): r"""An additional property for a ExtensionsValueListEntry object. @@ -2488,6 +2503,7 @@ class InlineSecret(_messages.Message): kmsKeyName: Resource name of Cloud KMS crypto key to decrypt the encrypted value. In format: projects/*/locations/*/keyRings/*/cryptoKeys/* """ + @encoding.MapUnrecognizedFields('additionalProperties') class EnvMapValue(_messages.Message): r"""Map of environment variable name to its encrypted value. Secret @@ -2502,6 +2518,7 @@ class EnvMapValue(_messages.Message): Fields: additionalProperties: Additional properties of type EnvMapValue """ + class AdditionalProperty(_messages.Message): r"""An additional property for a EnvMapValue object. @@ -2617,6 +2634,7 @@ class NetworkConfig(_messages.Message): configuration options](https://cloud.google.com/build/docs/private- pools/set-up-private-pool-environment) """ + class EgressOptionValueValuesEnum(_messages.Enum): r"""Option to configure network egress for the workers. @@ -2653,6 +2671,7 @@ class Notification(_messages.Message): smtpDelivery: Configuration for SMTP (email) delivery. structDelivery: Escape hatch for users to supply custom delivery configs. """ + @encoding.MapUnrecognizedFields('additionalProperties') class StructDeliveryValue(_messages.Message): r"""Escape hatch for users to supply custom delivery configs. @@ -2664,6 +2683,7 @@ class StructDeliveryValue(_messages.Message): Fields: additionalProperties: Properties of the object. """ + class AdditionalProperty(_messages.Message): r"""An additional property for a StructDeliveryValue object. @@ -2800,6 +2820,7 @@ class Operation(_messages.Message): the original method name. For example, if the original method name is `TakeSnapshot()`, the inferred response type is `TakeSnapshotResponse`. """ + @encoding.MapUnrecognizedFields('additionalProperties') class MetadataValue(_messages.Message): r"""Service-specific metadata associated with the operation. It typically @@ -2814,6 +2835,7 @@ class MetadataValue(_messages.Message): additionalProperties: Properties of the object. Contains field @type with type URL. """ + class AdditionalProperty(_messages.Message): r"""An additional property for a MetadataValue object. @@ -2845,6 +2867,7 @@ class ResponseValue(_messages.Message): additionalProperties: Properties of the object. Contains field @type with type URL. """ + class AdditionalProperty(_messages.Message): r"""An additional property for a ResponseValue object. @@ -2953,6 +2976,7 @@ class PubsubConfig(_messages.Message): topic: The name of the topic from which this subscription is receiving messages. Format is `projects/{project}/topics/{topic}`. """ + class StateValueValuesEnum(_messages.Enum): r"""Potential issues with the underlying Pub/Sub subscription configuration. Only populated on get requests. @@ -2994,6 +3018,7 @@ class PullRequestFilter(_messages.Message): invertRegex: If true, branches that do NOT match the git_ref will trigger a build. """ + class CommentControlValueValuesEnum(_messages.Enum): r"""Configure builds to run whether a repository owner or collaborator need to comment `/gcbrun`. @@ -3078,6 +3103,7 @@ class RepoSource(_messages.Message): expressions accepted is the syntax accepted by RE2 and described at https://github.com/google/re2/wiki/Syntax """ + @encoding.MapUnrecognizedFields('additionalProperties') class SubstitutionsValue(_messages.Message): r"""Substitutions to use in a triggered build. Should only be used with @@ -3090,6 +3116,7 @@ class SubstitutionsValue(_messages.Message): Fields: additionalProperties: Additional properties of type SubstitutionsValue """ + class AdditionalProperty(_messages.Message): r"""An additional property for a SubstitutionsValue object. @@ -3244,6 +3271,7 @@ class Secret(_messages.Message): in size. There can be at most 100 secret values across all of a build's secrets. """ + @encoding.MapUnrecognizedFields('additionalProperties') class SecretEnvValue(_messages.Message): r"""Map of environment variable name to its encrypted value. Secret @@ -3258,6 +3286,7 @@ class SecretEnvValue(_messages.Message): Fields: additionalProperties: Additional properties of type SecretEnvValue """ + class AdditionalProperty(_messages.Message): r"""An additional property for a SecretEnvValue object. @@ -3370,6 +3399,7 @@ class SourceProvenance(_messages.Message): `source.storage_source_manifest`, if exists, with any revisions resolved. This feature is in Preview. """ + @encoding.MapUnrecognizedFields('additionalProperties') class FileHashesValue(_messages.Message): r"""Output only. Hash(es) of the build source, which can be used to verify @@ -3386,6 +3416,7 @@ class FileHashesValue(_messages.Message): Fields: additionalProperties: Additional properties of type FileHashesValue """ + class AdditionalProperty(_messages.Message): r"""An additional property for a FileHashesValue object. @@ -3433,6 +3464,7 @@ class StandardQueryParameters(_messages.Message): uploadType: Legacy upload protocol for media (e.g. "media", "multipart"). upload_protocol: Upload protocol for media (e.g. "raw", "multipart"). """ + class AltValueValuesEnum(_messages.Enum): r"""Data format for response. @@ -3488,6 +3520,7 @@ class Status(_messages.Message): user-facing error message should be localized and sent in the google.rpc.Status.details field, or localized by the client. """ + @encoding.MapUnrecognizedFields('additionalProperties') class DetailsValueListEntry(_messages.Message): r"""A DetailsValueListEntry object. @@ -3500,6 +3533,7 @@ class DetailsValueListEntry(_messages.Message): additionalProperties: Properties of the object. Contains field @type with type URL. """ + class AdditionalProperty(_messages.Message): r"""An additional property for a DetailsValueListEntry object. @@ -3660,6 +3694,7 @@ class Warning(_messages.Message): priority: The priority for this warning. text: Explanation of the warning generated. """ + class PriorityValueValuesEnum(_messages.Enum): r"""The priority for this warning. @@ -3692,6 +3727,7 @@ class WebhookConfig(_messages.Message): state: Potential issues with the underlying Pub/Sub subscription configuration. Only populated on get requests. """ + class StateValueValuesEnum(_messages.Enum): r"""Potential issues with the underlying Pub/Sub subscription configuration. Only populated on get requests. @@ -3771,6 +3807,7 @@ class WorkerPool(_messages.Message): updateTime: Output only. Time at which the request to update the `WorkerPool` was received. """ + class StateValueValuesEnum(_messages.Enum): r"""Output only. `WorkerPool` state. @@ -3802,6 +3839,7 @@ class AnnotationsValue(_messages.Message): Fields: additionalProperties: Additional properties of type AnnotationsValue """ + class AdditionalProperty(_messages.Message): r"""An additional property for a AnnotationsValue object. diff --git a/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/message_matchers.py b/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/message_matchers.py index 5b8753dfab65..9909d81a7719 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/message_matchers.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/message_matchers.py @@ -23,6 +23,7 @@ class MetricStructuredNameMatcher(BaseMatcher): """Matches a MetricStructuredName.""" + def __init__(self, name=IGNORED, origin=IGNORED, context=IGNORED): """Creates a MetricsStructuredNameMatcher. @@ -69,6 +70,7 @@ def describe_to(self, description): class MetricUpdateMatcher(BaseMatcher): """Matches a metrics update protocol buffer.""" + def __init__( self, cumulative=IGNORED, name=IGNORED, scalar=IGNORED, kind=IGNORED): """Creates a MetricUpdateMatcher. diff --git a/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/message_matchers_test.py b/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/message_matchers_test.py index 68dd06681ca0..156fcc7a0cb0 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/message_matchers_test.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/message_matchers_test.py @@ -35,6 +35,7 @@ @unittest.skipIf(base_api is None, 'GCP dependencies are not installed') class TestMatchers(unittest.TestCase): + def test_structured_name_matcher_basic(self): metric_name = dataflow.MetricStructuredName() metric_name.name = 'metric1' diff --git a/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py b/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py index 8004762f5eec..0674bd576a2f 100644 --- a/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py +++ b/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py @@ -28,6 +28,7 @@ class NativeReadPTransformOverride(PTransformOverride): The DataflowRunner expects that the Read PTransform using native sources act as a primitive. So this override replaces the Read with a primitive. """ + def matches(self, applied_ptransform): # Imported here to avoid circular dependencies. # pylint: disable=wrong-import-order, wrong-import-position diff --git a/sdks/python/apache_beam/runners/dataflow/template_runner_test.py b/sdks/python/apache_beam/runners/dataflow/template_runner_test.py index 792c5cfd1655..60e026c57f8a 100644 --- a/sdks/python/apache_beam/runners/dataflow/template_runner_test.py +++ b/sdks/python/apache_beam/runners/dataflow/template_runner_test.py @@ -40,6 +40,7 @@ @unittest.skipIf(apiclient is None, 'GCP dependencies are not installed') class TemplatingDataflowRunnerTest(unittest.TestCase): """TemplatingDataflow tests.""" + def test_full_completion(self): # Create dummy file and close it. Note that we need to do this because # Windows does not allow NamedTemporaryFiles to be reopened elsewhere diff --git a/sdks/python/apache_beam/runners/dataflow/test_dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/test_dataflow_runner.py index 1550034afc73..2ab351d4eb74 100644 --- a/sdks/python/apache_beam/runners/dataflow/test_dataflow_runner.py +++ b/sdks/python/apache_beam/runners/dataflow/test_dataflow_runner.py @@ -39,6 +39,7 @@ class TestDataflowRunner(DataflowRunner): + def run_pipeline(self, pipeline, options): """Execute test pipeline and verify test matcher""" test_options = options.view_as(TestOptions) diff --git a/sdks/python/apache_beam/runners/direct/bundle_factory.py b/sdks/python/apache_beam/runners/direct/bundle_factory.py index 95d8c06111a2..388b27cf7484 100644 --- a/sdks/python/apache_beam/runners/direct/bundle_factory.py +++ b/sdks/python/apache_beam/runners/direct/bundle_factory.py @@ -40,6 +40,7 @@ class BundleFactory(object): in case consecutive ones share the same timestamp and windows. DirectRunnerOptions.direct_runner_use_stacked_bundle controls this option. """ + def __init__(self, stacked: bool) -> None: self._stacked = stacked @@ -78,6 +79,7 @@ class _Bundle(common.Receiver): b = Bundle(stacked=False) """ + class _StackedWindowedValues(object): """A stack of WindowedValues with the same timestamp and windows. @@ -92,6 +94,7 @@ class _StackedWindowedValues(object): windowed_values = [wv for wv in s.windowed_values()] # now windowed_values equals to [windowed_value, another_windowed_value] """ + def __init__(self, initial_windowed_value): self._initial_windowed_value = initial_windowed_value self._appended_values = [] diff --git a/sdks/python/apache_beam/runners/direct/clock.py b/sdks/python/apache_beam/runners/direct/clock.py index 99e3bed3abea..34efd09d2a72 100644 --- a/sdks/python/apache_beam/runners/direct/clock.py +++ b/sdks/python/apache_beam/runners/direct/clock.py @@ -27,6 +27,7 @@ class Clock(object): + def time(self): """Returns the number of seconds since epoch.""" raise NotImplementedError() @@ -37,12 +38,14 @@ def advance_time(self, advance_by): class RealClock(object): + def time(self): return time.time() class TestClock(object): """Clock used for Testing""" + def __init__(self, current_time=None): self._current_time = current_time if current_time else Timestamp() diff --git a/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor.py b/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor.py index 91085274f32a..184a3e691b9c 100644 --- a/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor.py +++ b/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor.py @@ -36,6 +36,7 @@ class ConsumerTrackingPipelineVisitor(PipelineVisitor): is used to schedule consuming PTransforms to consume input after the upstream transform has produced and committed output. """ + def __init__(self): self.value_to_consumers: Dict[pvalue.PValue, Set[AppliedPTransform]] = {} self.root_transforms: Set[AppliedPTransform] = set() diff --git a/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor_test.py b/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor_test.py index 7eba868afba0..e6df050d426d 100644 --- a/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor_test.py +++ b/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor_test.py @@ -40,6 +40,7 @@ class ConsumerTrackingPipelineVisitorTest(unittest.TestCase): + def setUp(self): self.pipeline = Pipeline(DirectRunner()) self.visitor = ConsumerTrackingPipelineVisitor() @@ -66,7 +67,9 @@ def test_root_transforms(self): self.assertEqual(len(self.visitor.step_names), 3) def test_side_inputs(self): + class SplitNumbersFn(DoFn): + def process(self, element): if element < 0: yield pvalue.TaggedOutput('tag_negative', element) @@ -74,6 +77,7 @@ def process(self, element): yield element class ProcessNumbersFn(DoFn): + def process(self, element, negatives): yield element @@ -148,14 +152,12 @@ def test_visitor_not_sorted(self): # Convert to string to assert they are equal. out_of_order_labels = { str(k): [str(t) for t in value_to_consumer] - for k, - value_to_consumer in v_out_of_order.value_to_consumers.items() + for k, value_to_consumer in v_out_of_order.value_to_consumers.items() } original_labels = { str(k): [str(t) for t in value_to_consumer] - for k, - value_to_consumer in v_original.value_to_consumers.items() + for k, value_to_consumer in v_original.value_to_consumers.items() } self.assertDictEqual(out_of_order_labels, original_labels) diff --git a/sdks/python/apache_beam/runners/direct/direct_metrics.py b/sdks/python/apache_beam/runners/direct/direct_metrics.py index 5beb19d4610a..e4242cc27b2c 100644 --- a/sdks/python/apache_beam/runners/direct/direct_metrics.py +++ b/sdks/python/apache_beam/runners/direct/direct_metrics.py @@ -40,6 +40,7 @@ class MetricAggregator(object): """For internal use only; no backwards-compatibility guarantees. Base interface for aggregating metric data during pipeline execution.""" + def identity_element(self): # type: () -> Any @@ -66,6 +67,7 @@ class CounterAggregator(MetricAggregator): Values aggregated should be ``int`` objects. """ + @staticmethod def identity_element(): # type: () -> int @@ -81,6 +83,7 @@ def result(self, x): class GenericAggregator(MetricAggregator): + def __init__(self, data_class): self._data_class = data_class @@ -95,6 +98,7 @@ def result(self, x): class DirectMetrics(MetricResults): + def __init__(self): self._counters = defaultdict(lambda: DirectMetric(CounterAggregator())) self._distributions = defaultdict( @@ -139,36 +143,36 @@ def query(self, filter=None): MetricResult( MetricKey(k.step, k.metric), v.extract_committed(), - v.extract_latest_attempted()) for k, - v in self._counters.items() if self.matches(filter, k) + v.extract_latest_attempted()) for k, v in self._counters.items() + if self.matches(filter, k) ] distributions = [ MetricResult( MetricKey(k.step, k.metric), v.extract_committed(), - v.extract_latest_attempted()) for k, - v in self._distributions.items() if self.matches(filter, k) + v.extract_latest_attempted()) + for k, v in self._distributions.items() if self.matches(filter, k) ] gauges = [ MetricResult( MetricKey(k.step, k.metric), v.extract_committed(), - v.extract_latest_attempted()) for k, - v in self._gauges.items() if self.matches(filter, k) + v.extract_latest_attempted()) for k, v in self._gauges.items() + if self.matches(filter, k) ] string_sets = [ MetricResult( MetricKey(k.step, k.metric), v.extract_committed(), - v.extract_latest_attempted()) for k, - v in self._string_sets.items() if self.matches(filter, k) + v.extract_latest_attempted()) for k, v in self._string_sets.items() + if self.matches(filter, k) ] bounded_tries = [ MetricResult( MetricKey(k.step, k.metric), v.extract_committed(), - v.extract_latest_attempted()) for k, - v in self._bounded_tries.items() if self.matches(filter, k) + v.extract_latest_attempted()) + for k, v in self._bounded_tries.items() if self.matches(filter, k) ] return { @@ -186,6 +190,7 @@ class DirectMetric(object): It keeps track of the metric's physical and logical updates. It's thread safe. """ + def __init__(self, aggregator): self.aggregator = aggregator self._attempted_lock = threading.Lock() diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py index 8b8937653688..c6b482e2a020 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner.py @@ -66,6 +66,7 @@ class SwitchingDirectRunner(PipelineRunner): which supports streaming execution and certain primitives not yet implemented in the FnApiRunner. """ + def is_fnapi_compatible(self): return BundleBasedDirectRunner.is_fnapi_compatible() @@ -78,6 +79,7 @@ def run_pipeline(self, pipeline, options): class _FnApiRunnerSupportVisitor(PipelineVisitor): """Visitor determining if a Pipeline can be run on the FnApiRunner.""" + def accept(self, pipeline): self.supported_by_fnapi_runner = True pipeline.visit(self) @@ -112,6 +114,7 @@ def visit_transform(self, applied_ptransform): class _PrismRunnerSupportVisitor(PipelineVisitor): """Visitor determining if a Pipeline can be run on the PrismRunner.""" + def accept(self, pipeline): self.supported_by_prism_runner = True pipeline.visit(self) @@ -193,6 +196,7 @@ def visit_transform(self, applied_ptransform): @typehints.with_output_types(typing.Tuple[K, typing.Iterable[V]]) class _GroupByKeyOnly(PTransform): """A group by key transform, ignoring windows.""" + def infer_output_type(self, input_type): key_type, value_type = trivial_inference.key_value_types(input_type) return typehints.KV[key_type, typehints.Iterable[value_type]] @@ -206,6 +210,7 @@ def expand(self, pcoll): @typehints.with_output_types(typing.Tuple[K, typing.Iterable[V]]) class _GroupAlsoByWindow(ParDo): """The GroupAlsoByWindow transform.""" + def __init__(self, windowing): super().__init__(_GroupAlsoByWindowDoFn(windowing)) self.windowing = windowing @@ -280,6 +285,7 @@ def from_runner_api_parameter(unused_ptransform, payload, context): @typehints.with_output_types(typing.Tuple[K, typing.Iterable[V]]) class _GroupByKey(PTransform): """The DirectRunner GroupByKey implementation.""" + def expand(self, pcoll): # Imported here to avoid circular dependencies. # pylint: disable=wrong-import-order, wrong-import-position @@ -342,6 +348,7 @@ def _get_transform_overrides(pipeline_options): from apache_beam.runners.direct.sdf_direct_runner import SplittableParDoOverride class CombinePerKeyOverride(PTransformOverride): + def matches(self, applied_ptransform): if isinstance(applied_ptransform.transform, CombinePerKey): return applied_ptransform.inputs[0].windowing.is_default() @@ -359,6 +366,7 @@ def get_replacement_transform_for_applied_ptransform( return transform class StreamingGroupByKeyOverride(PTransformOverride): + def matches(self, applied_ptransform): # Note: we match the exact class, since we replace it with a subclass. return applied_ptransform.transform.__class__ == _GroupByKeyOnly @@ -370,6 +378,7 @@ def get_replacement_transform_for_applied_ptransform( return transform class StreamingGroupAlsoByWindowOverride(PTransformOverride): + def matches(self, applied_ptransform): # Note: we match the exact class, since we replace it with a subclass. transform = applied_ptransform.transform @@ -386,6 +395,7 @@ def get_replacement_transform_for_applied_ptransform( return transform class TestStreamOverride(PTransformOverride): + def matches(self, applied_ptransform): from apache_beam.testing.test_stream import TestStream self.applied_ptransform = applied_ptransform @@ -401,6 +411,7 @@ class GroupByKeyPTransformOverride(PTransformOverride): This replaces the Beam implementation as a primitive. """ + def matches(self, applied_ptransform): # Imported here to avoid circular dependencies. # pylint: disable=wrong-import-order, wrong-import-position @@ -442,6 +453,7 @@ def get_replacement_transform_for_applied_ptransform( class _DirectReadFromPubSub(PTransform): + def __init__(self, source): self._source = source @@ -517,6 +529,7 @@ def _get_pubsub_transform_overrides(pipeline_options): from apache_beam.pipeline import PTransformOverride class ReadFromPubSubOverride(PTransformOverride): + def matches(self, applied_ptransform): return isinstance( applied_ptransform.transform, beam_pubsub.ReadFromPubSub) @@ -530,6 +543,7 @@ def get_replacement_transform_for_applied_ptransform( return _DirectReadFromPubSub(applied_ptransform.transform._source) class WriteToPubSubOverride(PTransformOverride): + def matches(self, applied_ptransform): return isinstance(applied_ptransform.transform, beam_pubsub.WriteToPubSub) @@ -546,6 +560,7 @@ def get_replacement_transform_for_applied_ptransform( class BundleBasedDirectRunner(PipelineRunner): """Executes a single pipeline on the local machine.""" + @staticmethod def is_fnapi_compatible(): return False @@ -568,6 +583,7 @@ def run_pipeline(self, pipeline, options): class VerifyNoCrossLanguageTransforms(PipelineVisitor): """Visitor determining whether a Pipeline uses a TestStream.""" + def visit_transform(self, applied_ptransform): if isinstance(applied_ptransform.transform, ExternalTransform): raise RuntimeError( @@ -581,6 +597,7 @@ def visit_transform(self, applied_ptransform): # If the TestStream I/O is used, use a mock test clock. class TestStreamUsageVisitor(PipelineVisitor): """Visitor determining whether a Pipeline uses a TestStream.""" + def __init__(self): self.uses_test_stream = False @@ -631,6 +648,7 @@ def visit_transform(self, applied_ptransform): class DirectPipelineResult(PipelineResult): """A DirectPipelineResult provides access to info about a pipeline.""" + def __init__(self, executor, evaluation_context): super().__init__(PipelineState.RUNNING) self._executor = executor diff --git a/sdks/python/apache_beam/runners/direct/direct_runner_test.py b/sdks/python/apache_beam/runners/direct/direct_runner_test.py index 1af5f1bc7bea..a46b83767a45 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner_test.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner_test.py @@ -44,6 +44,7 @@ class DirectPipelineResultTest(unittest.TestCase): + def test_waiting_on_result_stops_executor_threads(self): pre_test_threads = set(t.ident for t in threading.enumerate()) @@ -60,7 +61,9 @@ def test_waiting_on_result_stops_executor_threads(self): self.assertEqual(len(new_threads), 0) def test_direct_runner_metrics(self): + class MyDoFn(beam.DoFn): + def start_bundle(self): count = Metrics.counter(self.__class__, 'bundles') count.inc() @@ -141,6 +144,7 @@ def test_create_runner(self): class BundleBasedRunnerTest(unittest.TestCase): + def test_type_hints(self): with test_pipeline.TestPipeline(runner='BundleBasedDirectRunner') as p: _ = ( @@ -154,6 +158,7 @@ def test_impulse(self): class DirectRunnerRetryTests(unittest.TestCase): + def test_retry_fork_graph(self): # TODO(https://github.com/apache/beam/issues/18640): The FnApiRunner # currently does not currently support retries. @@ -183,7 +188,9 @@ def f_c(x): assert count_b == count_c == 4 def test_no_partial_writeouts(self): + class TestTransformEvaluator(_TransformEvaluator): + def __init__(self): self._execution_context = _ExecutionContext(None, {}) diff --git a/sdks/python/apache_beam/runners/direct/direct_userstate.py b/sdks/python/apache_beam/runners/direct/direct_userstate.py index 196a9a048d7a..6a7eb7fe990b 100644 --- a/sdks/python/apache_beam/runners/direct/direct_userstate.py +++ b/sdks/python/apache_beam/runners/direct/direct_userstate.py @@ -28,6 +28,7 @@ class DirectRuntimeState(userstate.RuntimeState): + def __init__(self, state_spec, state_tag, current_value_accessor): self._state_spec = state_spec self._state_tag = state_tag @@ -61,6 +62,7 @@ def _decode(self, value): class ReadModifyWriteRuntimeState(DirectRuntimeState, userstate.ReadModifyWriteRuntimeState): + def __init__(self, state_spec, state_tag, current_value_accessor): super().__init__(state_spec, state_tag, current_value_accessor) self._value = UNREAD_VALUE @@ -94,6 +96,7 @@ def is_modified(self): class BagRuntimeState(DirectRuntimeState, userstate.BagRuntimeState): + def __init__(self, state_spec, state_tag, current_value_accessor): super().__init__(state_spec, state_tag, current_value_accessor) self._cached_value = UNREAD_VALUE @@ -119,6 +122,7 @@ def clear(self): class SetRuntimeState(DirectRuntimeState, userstate.SetRuntimeState): + def __init__(self, state_spec, state_tag, current_value_accessor): super().__init__(state_spec, state_tag, current_value_accessor) self._current_accumulator = UNREAD_VALUE @@ -151,6 +155,7 @@ def is_modified(self): class CombiningValueRuntimeState(DirectRuntimeState, userstate.CombiningValueRuntimeState): """Combining value state interface object passed to user code.""" + def __init__(self, state_spec, state_tag, current_value_accessor): super().__init__(state_spec, state_tag, current_value_accessor) self._current_accumulator = UNREAD_VALUE @@ -195,6 +200,7 @@ class DirectUserStateContext(userstate.UserStateContext): The DirectUserStateContext buffers up updates that are to be committed by the TransformEvaluator after running a DoFn. """ + def __init__(self, step_context, dofn, key_coder): self.step_context = step_context self.dofn = dofn diff --git a/sdks/python/apache_beam/runners/direct/evaluation_context.py b/sdks/python/apache_beam/runners/direct/evaluation_context.py index c34735499abc..1160814bac83 100644 --- a/sdks/python/apache_beam/runners/direct/evaluation_context.py +++ b/sdks/python/apache_beam/runners/direct/evaluation_context.py @@ -53,6 +53,7 @@ class _ExecutionContext(object): It holds the watermarks for that transform, as well as keyed states. """ + def __init__(self, watermarks: '_TransformWatermarks', keyed_states): self.watermarks = watermarks self.keyed_states = keyed_states @@ -69,6 +70,7 @@ def reset(self): class _SideInputView(object): + def __init__(self, view): self._view = view self.blocked_tasks = collections.deque() @@ -88,6 +90,7 @@ class _SideInputsContainer(object): It provides methods for blocking until a side-input is available and writing to a side input. """ + def __init__(self, side_inputs: Iterable['pvalue.AsSideInput']) -> None: self._lock = threading.Lock() self._views: Dict[pvalue.AsSideInput, _SideInputView] = {} @@ -225,6 +228,7 @@ class EvaluationContext(object): appropriately. This includes updating the per-(step,key) state, updating global watermarks, and executing any callbacks that can be executed. """ + def __init__( self, pipeline_options, @@ -346,8 +350,8 @@ def _update_side_inputs_container( registered as a PCollectionView, we add the result to the PCollectionView. """ if (result.uncommitted_output_bundles and - result.uncommitted_output_bundles[0].pcollection in - self._pcollection_to_views): + result.uncommitted_output_bundles[0].pcollection + in self._pcollection_to_views): for view in self._pcollection_to_views[ result.uncommitted_output_bundles[0].pcollection]: for committed_bundle in committed_bundles: @@ -436,12 +440,14 @@ def shutdown(self): class DirectUnmergedState(InMemoryUnmergedState): """UnmergedState implementation for the DirectRunner.""" + def __init__(self): super().__init__(defensive_copy=False) class DirectStepContext(object): """Context for the currently-executing step.""" + def __init__(self, existing_keyed_state): self.existing_keyed_state = existing_keyed_state # In order to avoid partial writes of a bundle, every time diff --git a/sdks/python/apache_beam/runners/direct/executor.py b/sdks/python/apache_beam/runners/direct/executor.py index e8be9d64f993..74024c45bb1f 100644 --- a/sdks/python/apache_beam/runners/direct/executor.py +++ b/sdks/python/apache_beam/runners/direct/executor.py @@ -49,7 +49,9 @@ class _ExecutorService(object): """Thread pool for executing tasks in parallel.""" + class CallableTask(object): + def call(self, state_sampler): pass @@ -145,6 +147,7 @@ def shutdown(self): class _TransformEvaluationState(object): + def __init__(self, executor_service, scheduled: Set['TransformExecutor']): self.executor_service = executor_service self.scheduled = scheduled @@ -178,6 +181,7 @@ class _SerialEvaluationState(_TransformEvaluationState): A principal use of this is for evaluators that keeps a global state such as _GroupByKeyOnly. """ + def __init__(self, executor_service, scheduled): super().__init__(executor_service, scheduled) self.serial_queue = collections.deque() @@ -210,6 +214,7 @@ class _TransformExecutorServices(object): Controls the concurrency as appropriate for the applied transform the executor exists for. """ + def __init__(self, executor_service: _ExecutorService) -> None: self._executor_service = executor_service self._scheduled: Set[TransformExecutor] = set() @@ -240,6 +245,7 @@ class _CompletionCallback(object): that are triggered due to the arrival of elements from an upstream transform, or for a source transform. """ + def __init__( self, evaluation_context: 'EvaluationContext', @@ -408,6 +414,7 @@ def attempt_call( class Executor(object): """For internal use only; no backwards-compatibility guarantees.""" + def __init__(self, *args, **kwargs): self._executor = _ExecutorServiceParallelExecutor(*args, **kwargs) @@ -513,6 +520,7 @@ def schedule_consumption( class _TypedUpdateQueue(object): """Type checking update queue with blocking and non-blocking operations.""" + def __init__(self, item_type): self._item_type = item_type self._queue = queue.Queue() @@ -544,6 +552,7 @@ def offer(self, item): class _ExecutorUpdate(object): """An internal status update on the state of the executor.""" + def __init__( self, transform_executor, @@ -565,12 +574,14 @@ class _VisibleExecutorUpdate(object): Used for awaiting the completion to decide whether to return normally or raise an exception. """ + def __init__(self, exception=None): self.finished = exception is not None self.exception = exception class _MonitorTask(_ExecutorService.CallableTask): """MonitorTask continuously runs to ensure that pipeline makes progress.""" + def __init__(self, executor: '_ExecutorServiceParallelExecutor') -> None: self._executor = executor diff --git a/sdks/python/apache_beam/runners/direct/helper_transforms.py b/sdks/python/apache_beam/runners/direct/helper_transforms.py index 0e88c021e2f9..b8e85188ab24 100644 --- a/sdks/python/apache_beam/runners/direct/helper_transforms.py +++ b/sdks/python/apache_beam/runners/direct/helper_transforms.py @@ -31,6 +31,7 @@ class LiftedCombinePerKey(beam.PTransform): """An implementation of CombinePerKey that does mapper-side pre-combining. """ + def __init__(self, combine_fn, args, kwargs): args_to_check = itertools.chain(args, kwargs.values()) if isinstance(combine_fn, _CurriedFn): @@ -55,6 +56,7 @@ class PartialGroupByKeyCombiningValues(beam.DoFn): As bundles are in-memory-sized, we don't bother flushing until the very end. """ + def __init__(self, combine_fn): self._combine_fn = combine_fn @@ -94,6 +96,7 @@ def default_type_hints(self): class FinishCombine(beam.DoFn): """Merges partially combined results. """ + def __init__(self, combine_fn): self._combine_fn = combine_fn diff --git a/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py b/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py index e0a58db0ef3e..199a30fd9bd9 100644 --- a/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py @@ -54,6 +54,7 @@ class SplittableParDoOverride(PTransformOverride): Replaces the ParDo transform with a SplittableParDo transform that performs SDF specific logic. """ + def matches(self, applied_ptransform): assert isinstance(applied_ptransform, AppliedPTransform) transform = applied_ptransform.transform @@ -75,6 +76,7 @@ def get_replacement_transform_for_applied_ptransform( class SplittableParDo(PTransform): """A transform that processes a PCollection using a Splittable DoFn.""" + def __init__(self, ptransform): assert isinstance(ptransform, ParDo) self._ptransform = ptransform @@ -104,6 +106,7 @@ def expand(self, pcoll): class ElementAndRestriction(object): """A holder for an element, restriction, and watermark estimator state.""" + def __init__(self, element, restriction, watermark_estimator_state): self.element = element self.restriction = restriction @@ -112,6 +115,7 @@ def __init__(self, element, restriction, watermark_estimator_state): class PairWithRestrictionFn(beam.DoFn): """A transform that pairs each element with a restriction.""" + def __init__(self, do_fn): self._signature = DoFnSignature(do_fn) @@ -132,6 +136,7 @@ def process(self, element, window=beam.DoFn.WindowParam, *args, **kwargs): class SplitRestrictionFn(beam.DoFn): """A transform that perform initial splitting of Splittable DoFn inputs.""" + def __init__(self, do_fn): self._do_fn = do_fn @@ -157,12 +162,14 @@ class ExplodeWindowsFn(beam.DoFn): This is done to make sure that Splittable DoFn proceses an element for each of the windows that element belongs to. """ + def process(self, element, window=beam.DoFn.WindowParam, *args, **kwargs): yield element class RandomUniqueKeyFn(beam.DoFn): """A transform that assigns a unique key to each element.""" + def process(self, element, window=beam.DoFn.WindowParam, *args, **kwargs): # We ignore UUID collisions here since they are extremely rare. yield (uuid.uuid4().bytes, element) @@ -174,6 +181,7 @@ class ProcessKeyedElements(PTransform): Input to this transform should be a PCollection of keyed ElementAndRestriction objects. """ + def __init__( self, sdf, @@ -197,6 +205,7 @@ def expand(self, pcoll): class ProcessKeyedElementsViaKeyedWorkItemsOverride(PTransformOverride): """A transform override for ProcessElements transform.""" + def matches(self, applied_ptransform): return isinstance(applied_ptransform.transform, ProcessKeyedElements) @@ -207,6 +216,7 @@ def get_replacement_transform_for_applied_ptransform( class ProcessKeyedElementsViaKeyedWorkItems(PTransform): """A transform that processes Splittable DoFn input via KeyedWorkItems.""" + def __init__(self, process_keyed_elements_transform): self._process_keyed_elements_transform = process_keyed_elements_transform @@ -227,6 +237,7 @@ class ProcessElements(PTransform): Will be evaluated by `runners.direct.transform_evaluator._ProcessElementsEvaluator`. """ + def __init__(self, process_keyed_elements_transform): self._process_keyed_elements_transform = process_keyed_elements_transform self.sdf = self._process_keyed_elements_transform.sdf @@ -257,6 +268,7 @@ class ProcessFn(beam.DoFn): (4) after the final invocation of a given element clear any previous state set for re-invoking the element and release the output watermark. """ + def __init__(self, sdf, args_for_invoker, kwargs_for_invoker): self.sdf = sdf self._element_tag = _ReadModifyWriteStateTag('element') @@ -416,7 +428,9 @@ class SDFProcessElementInvoker(object): produced this class ends the execution and performs steps to finalize the current invocation. """ + class Result(object): + def __init__( self, residual_restriction=None, @@ -467,6 +481,7 @@ def invoke_process_element( assert isinstance(sdf_invoker, DoFnInvoker) class CheckpointState(object): + def __init__(self): self.checkpointed = None self.residual_restriction = None @@ -534,6 +549,7 @@ def initiate_checkpoint(): class _OutputHandler(OutputHandler): + def __init__(self): self.output_iter = None @@ -549,6 +565,7 @@ def reset(self): class _NoneShallPassOutputHandler(OutputHandler): + def handle_process_outputs( self, windowed_input_element: WindowedValue, diff --git a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py index 246d180cddee..5cc1cc768e9b 100644 --- a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py +++ b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py @@ -41,6 +41,7 @@ class ReadFilesProvider(RestrictionProvider): + def initial_restriction(self, element): size = os.path.getsize(element) return OffsetRange(0, size) @@ -53,6 +54,7 @@ def restriction_size(self, element, restriction): class ReadFiles(DoFn): + def __init__(self, resume_count=None): self._resume_count = resume_count @@ -92,6 +94,7 @@ def process( class ExpandStringsProvider(RestrictionProvider): + def initial_restriction(self, element): return OffsetRange(0, len(element[0])) @@ -109,6 +112,7 @@ def restriction_size(self, element, restriction): class ExpandStrings(DoFn): + def __init__(self, record_window=False): self._record_window = record_window @@ -145,6 +149,7 @@ def process( class SDFDirectRunnerTest(unittest.TestCase): + def setUp(self): super().setUp() # Importing following for DirectRunner SDF implemenation for testing. diff --git a/sdks/python/apache_beam/runners/direct/test_direct_runner.py b/sdks/python/apache_beam/runners/direct/test_direct_runner.py index 084820eb17c6..c7753a446c28 100644 --- a/sdks/python/apache_beam/runners/direct/test_direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/test_direct_runner.py @@ -29,6 +29,7 @@ class TestDirectRunner(DirectRunner): + def run_pipeline(self, pipeline, options): """Execute test pipeline and verify test matcher""" test_options = options.view_as(TestOptions) diff --git a/sdks/python/apache_beam/runners/direct/test_stream_impl.py b/sdks/python/apache_beam/runners/direct/test_stream_impl.py index c720418b05ed..2afa4ca545af 100644 --- a/sdks/python/apache_beam/runners/direct/test_stream_impl.py +++ b/sdks/python/apache_beam/runners/direct/test_stream_impl.py @@ -66,6 +66,7 @@ class _WatermarkController(PTransform): - If the instance receives an ElementEvent, it emits all specified elements to the Global Window with the event time set to the element's timestamp. """ + def __init__(self, output_tag): self.output_tag = output_tag @@ -79,6 +80,7 @@ def expand(self, pcoll): class _ExpandableTestStream(PTransform): + def __init__(self, test_stream): self.test_stream = test_stream diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py index b0278ba5356c..2f614ba23512 100644 --- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py +++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py @@ -182,6 +182,7 @@ def should_execute_serially(self, applied_ptransform): class RootBundleProvider(object): """Provides bundles for the initial execution of a root transform.""" + def __init__(self, evaluation_context, applied_ptransform): self._evaluation_context = evaluation_context self._applied_ptransform = applied_ptransform @@ -192,6 +193,7 @@ def get_root_bundles(self): class DefaultRootBundleProvider(RootBundleProvider): """Provides an empty bundle by default for root transforms.""" + def get_root_bundles(self): input_node = pvalue.PBegin(self._applied_ptransform.transform.pipeline) empty_bundle = ( @@ -206,6 +208,7 @@ class _TestStreamRootBundleProvider(RootBundleProvider): bundle emitted from the TestStream afterwards is its state: index into the stream, and the watermark. """ + def get_root_bundles(self): test_stream = self._applied_ptransform.transform @@ -230,6 +233,7 @@ def get_root_bundles(self): class _TransformEvaluator(object): """An evaluator of a specific application of a transform.""" + def __init__( self, evaluation_context: 'EvaluationContext', @@ -354,8 +358,7 @@ def _read_values_to_bundles(reader): return self._split_list_into_bundles( output_pcollection, read_result, - _BoundedReadEvaluator.MAX_ELEMENT_PER_BUNDLE, - lambda _: 1) + _BoundedReadEvaluator.MAX_ELEMENT_PER_BUNDLE, lambda _: 1) if isinstance(self._source, io.iobase.BoundedSource): # Getting a RangeTracker for the default range of the source and reading @@ -451,6 +454,7 @@ class _PairWithTimingEvaluator(_TransformEvaluator): KV(element, `TimingInfo`). Where the `TimingInfo` contains both the processing time timestamp and watermark. """ + def __init__( self, evaluation_context, @@ -727,6 +731,7 @@ def finish_bundle(self) -> TransformResult: class _FlattenEvaluator(_TransformEvaluator): """TransformEvaluator for Flatten transform.""" + def __init__( self, evaluation_context, @@ -755,6 +760,7 @@ def finish_bundle(self): class _ImpulseEvaluator(_TransformEvaluator): """TransformEvaluator for Impulse transform.""" + def finish_bundle(self): assert len(self._outputs) == 1 output_pcollection = list(self._outputs)[0] @@ -765,6 +771,7 @@ def finish_bundle(self): class _TaggedReceivers(dict): """Received ParDo output and redirect to the associated output bundle.""" + def __init__(self, evaluation_context): self._evaluation_context = evaluation_context self._null_receiver = None @@ -772,11 +779,13 @@ def __init__(self, evaluation_context): class NullReceiver(common.Receiver): """Ignores undeclared outputs, default execution mode.""" + def receive(self, element: WindowedValue) -> None: pass class _InMemoryReceiver(common.Receiver): """Buffers undeclared outputs to the given dictionary.""" + def __init__(self, target, tag): self._target = target self._tag = tag @@ -792,6 +801,7 @@ def __missing__(self, key): class _ParDoEvaluator(_TransformEvaluator): """TransformEvaluator for ParDo transform.""" + def __init__( self, evaluation_context: 'EvaluationContext', @@ -1048,6 +1058,7 @@ class _StreamingGroupAlsoByWindowEvaluator(_TransformEvaluator): GroupAlsoByWindow operation is evaluated as a normal DoFn, as defined in transforms/core.py. """ + def __init__( self, evaluation_context, diff --git a/sdks/python/apache_beam/runners/direct/util.py b/sdks/python/apache_beam/runners/direct/util.py index 11081c1289b2..3b5efc58baf6 100644 --- a/sdks/python/apache_beam/runners/direct/util.py +++ b/sdks/python/apache_beam/runners/direct/util.py @@ -25,6 +25,7 @@ class TransformResult(object): """Result of evaluating an AppliedPTransform with a TransformEvaluator.""" + def __init__( self, transform_evaluator, @@ -57,6 +58,7 @@ def __init__( class TimerFiring(object): """A single instance of a fired timer.""" + def __init__( self, encoded_key, @@ -83,6 +85,7 @@ def __repr__(self): class KeyedWorkItem(object): """A keyed item that can either be a timer firing or a list of elements.""" + def __init__(self, encoded_key, timer_firings=None, elements=None): self.encoded_key = encoded_key self.timer_firings = timer_firings or [] diff --git a/sdks/python/apache_beam/runners/direct/watermark_manager.py b/sdks/python/apache_beam/runners/direct/watermark_manager.py index 666ade6cf82d..9b8ae3431b5c 100644 --- a/sdks/python/apache_beam/runners/direct/watermark_manager.py +++ b/sdks/python/apache_beam/runners/direct/watermark_manager.py @@ -198,6 +198,7 @@ def extract_all_timers( class _TransformWatermarks(object): """Tracks input and output watermarks for an AppliedPTransform.""" + def __init__(self, clock, keyed_states, transform): self._clock = clock self._keyed_states = keyed_states diff --git a/sdks/python/apache_beam/runners/interactive/augmented_pipeline.py b/sdks/python/apache_beam/runners/interactive/augmented_pipeline.py index c1adc0c4a4f7..f46aaddb0af2 100644 --- a/sdks/python/apache_beam/runners/interactive/augmented_pipeline.py +++ b/sdks/python/apache_beam/runners/interactive/augmented_pipeline.py @@ -40,6 +40,7 @@ class AugmentedPipeline: PCollections defined by the user, reads computed PCollections as source and prunes unnecessary pipeline parts for fast computation. """ + def __init__( self, user_pipeline: beam.Pipeline, diff --git a/sdks/python/apache_beam/runners/interactive/augmented_pipeline_test.py b/sdks/python/apache_beam/runners/interactive/augmented_pipeline_test.py index 1bafb9fb16f5..9bbc799fa1eb 100644 --- a/sdks/python/apache_beam/runners/interactive/augmented_pipeline_test.py +++ b/sdks/python/apache_beam/runners/interactive/augmented_pipeline_test.py @@ -28,6 +28,7 @@ class CacheableTest(unittest.TestCase): + def setUp(self): ie.new_env() @@ -65,6 +66,7 @@ def test_ignore_pcoll_from_other_pipeline(self): class AugmentTest(unittest.TestCase): + def setUp(self): ie.new_env() diff --git a/sdks/python/apache_beam/runners/interactive/background_caching_job.py b/sdks/python/apache_beam/runners/interactive/background_caching_job.py index 71f7f77ded4e..d72cae0b298d 100644 --- a/sdks/python/apache_beam/runners/interactive/background_caching_job.py +++ b/sdks/python/apache_beam/runners/interactive/background_caching_job.py @@ -65,6 +65,7 @@ class BackgroundCachingJob(object): In both situations, the background source recording job should be treated as done successfully. """ + def __init__(self, pipeline_result, limiters): self._pipeline_result = pipeline_result self._result_lock = threading.RLock() diff --git a/sdks/python/apache_beam/runners/interactive/background_caching_job_test.py b/sdks/python/apache_beam/runners/interactive/background_caching_job_test.py index aef2f768237e..4a16e697987a 100644 --- a/sdks/python/apache_beam/runners/interactive/background_caching_job_test.py +++ b/sdks/python/apache_beam/runners/interactive/background_caching_job_test.py @@ -77,6 +77,7 @@ def _setup_test_streaming_cache(pipeline): not ie.current_env().is_interactive_ready, '[interactive] dependency is not installed.') class BackgroundCachingJobTest(unittest.TestCase): + def tearDown(self): ie.new_env() @@ -84,24 +85,23 @@ def tearDown(self): # that meet the boundedness checks. @patch( 'apache_beam.runners.interactive.background_caching_job' - '.has_source_to_cache', - lambda x: True) + '.has_source_to_cache', lambda x: True) # Disable the clean up so that we can keep the test streaming cache. @patch( 'apache_beam.runners.interactive.interactive_environment' - '.InteractiveEnvironment.cleanup', - lambda x, - y: None) + '.InteractiveEnvironment.cleanup', lambda x, y: None) def test_background_caching_job_starts_when_none_such_job_exists(self): # Create a fake PipelineResult and PipelineRunner. This is because we want # to test whether the BackgroundCachingJob can be started without having to # rely on a real pipeline run. class FakePipelineResult(beam.runners.runner.PipelineResult): + def wait_until_finish(self): return class FakePipelineRunner(beam.runners.PipelineRunner): + def run_pipeline(self, pipeline, options): return FakePipelineResult(beam.runners.runner.PipelineState.RUNNING) @@ -126,8 +126,7 @@ def run_pipeline(self, pipeline, options): @patch( 'apache_beam.runners.interactive.background_caching_job' - '.has_source_to_cache', - lambda x: False) + '.has_source_to_cache', lambda x: False) def test_background_caching_job_not_start_for_batch_pipeline(self): p = beam.Pipeline() @@ -138,14 +137,11 @@ def test_background_caching_job_not_start_for_batch_pipeline(self): @patch( 'apache_beam.runners.interactive.background_caching_job' - '.has_source_to_cache', - lambda x: True) + '.has_source_to_cache', lambda x: True) # Disable the clean up so that we can keep the test streaming cache. @patch( 'apache_beam.runners.interactive.interactive_environment' - '.InteractiveEnvironment.cleanup', - lambda x, - y: None) + '.InteractiveEnvironment.cleanup', lambda x, y: None) def test_background_caching_job_not_start_when_such_job_exists(self): p = _build_a_test_stream_pipeline() _setup_test_streaming_cache(p) @@ -163,14 +159,11 @@ def test_background_caching_job_not_start_when_such_job_exists(self): @patch( 'apache_beam.runners.interactive.background_caching_job' - '.has_source_to_cache', - lambda x: True) + '.has_source_to_cache', lambda x: True) # Disable the clean up so that we can keep the test streaming cache. @patch( 'apache_beam.runners.interactive.interactive_environment' - '.InteractiveEnvironment.cleanup', - lambda x, - y: None) + '.InteractiveEnvironment.cleanup', lambda x, y: None) def test_background_caching_job_not_start_when_such_job_is_done(self): p = _build_a_test_stream_pipeline() _setup_test_streaming_cache(p) @@ -293,6 +286,7 @@ def test_source_to_cache_not_changed_when_source_is_removed(self, cell): self.assertNotEqual(signature_with_only_foo, signature_with_foo_bar) class BarPruneVisitor(PipelineVisitor): + def enter_composite_transform(self, transform_node): pruned_parts = list(transform_node.parts) for part in transform_node.parts: diff --git a/sdks/python/apache_beam/runners/interactive/cache_manager.py b/sdks/python/apache_beam/runners/interactive/cache_manager.py index ac592475c057..17877e08dee6 100644 --- a/sdks/python/apache_beam/runners/interactive/cache_manager.py +++ b/sdks/python/apache_beam/runners/interactive/cache_manager.py @@ -40,6 +40,7 @@ class CacheManager(object): 'full' or 'sample') and a cache_label which is a hash of the PCollection derivation. """ + def exists(self, *labels): # type (*str) -> bool @@ -158,13 +159,10 @@ class FileBasedCacheManager(CacheManager): _available_formats = { 'text': ( lambda path: textio.ReadFromText( - path, - coder=Base64Coder(), - compression_type=filesystems.CompressionTypes.BZIP2), - lambda path: textio.WriteToText( - path, - coder=Base64Coder(), - compression_type=filesystems.CompressionTypes.BZIP2)), + path, coder=Base64Coder(), compression_type=filesystems. + CompressionTypes.BZIP2), lambda path: textio.WriteToText( + path, coder=Base64Coder(), compression_type=filesystems. + CompressionTypes.BZIP2)), 'tfrecord': (tfrecordio.ReadFromTFRecord, tfrecordio.WriteToTFRecord) } @@ -311,6 +309,7 @@ def _match(self, *labels): class _CacheVersion(object): """This class keeps track of the timestamp and the corresponding version.""" + def __init__(self): self.current_version = -1 self.current_timestamp = 0 @@ -332,6 +331,7 @@ def get_version(self, timestamp): class ReadCache(beam.PTransform): """A PTransform that reads the PCollections from the cache.""" + def __init__(self, cache_manager, label): self._cache_manager = cache_manager self._label = label @@ -343,6 +343,7 @@ def expand(self, pbegin): class WriteCache(beam.PTransform): """A PTransform that writes the PCollections to the cache.""" + def __init__( self, cache_manager, diff --git a/sdks/python/apache_beam/runners/interactive/caching/expression_cache.py b/sdks/python/apache_beam/runners/interactive/caching/expression_cache.py index 5b1b9effe5c5..62ff686d24ad 100644 --- a/sdks/python/apache_beam/runners/interactive/caching/expression_cache.py +++ b/sdks/python/apache_beam/runners/interactive/caching/expression_cache.py @@ -42,6 +42,7 @@ class ExpressionCache(object): This object can be created and destroyed whenever. This class holds no state and the only side-effect is modifying the given expression. """ + def __init__(self, pcollection_cache=None, computed_cache=None): from apache_beam.runners.interactive import interactive_environment as ie diff --git a/sdks/python/apache_beam/runners/interactive/caching/expression_cache_test.py b/sdks/python/apache_beam/runners/interactive/caching/expression_cache_test.py index c6e46f3cc3ff..7b7252202fb3 100644 --- a/sdks/python/apache_beam/runners/interactive/caching/expression_cache_test.py +++ b/sdks/python/apache_beam/runners/interactive/caching/expression_cache_test.py @@ -23,6 +23,7 @@ class ExpressionCacheTest(unittest.TestCase): + def setUp(self): self._pcollection_cache = {} self._computed_cache = set() diff --git a/sdks/python/apache_beam/runners/interactive/caching/read_cache.py b/sdks/python/apache_beam/runners/interactive/caching/read_cache.py index cf0859d5804a..16573118b8e2 100644 --- a/sdks/python/apache_beam/runners/interactive/caching/read_cache.py +++ b/sdks/python/apache_beam/runners/interactive/caching/read_cache.py @@ -35,6 +35,7 @@ class ReadCache: """Class that facilitates reading cache of computed PCollections. """ + def __init__( self, pipeline: beam_runner_api_pb2.Pipeline, @@ -84,8 +85,8 @@ def read_cache(self) -> Tuple[str, str]: self._pipeline.components.coders[coder_id].CopyFrom( template.components.coders[coder_id]) for windowing_strategy_id in template.components.windowing_strategies: - if (windowing_strategy_id in - self._pipeline.components.windowing_strategies): + if (windowing_strategy_id + in self._pipeline.components.windowing_strategies): continue self._pipeline.components.windowing_strategies[ windowing_strategy_id].CopyFrom( @@ -129,6 +130,7 @@ def _build_runner_api_template( class _ReadCacheTransform(PTransform): """A composite transform encapsulates reading cache of PCollections. """ + def __init__(self, cache_manager: cache.CacheManager, key: str): self._cache_manager = cache_manager self._key = key diff --git a/sdks/python/apache_beam/runners/interactive/caching/read_cache_test.py b/sdks/python/apache_beam/runners/interactive/caching/read_cache_test.py index d32c265e553f..a60cce24506b 100644 --- a/sdks/python/apache_beam/runners/interactive/caching/read_cache_test.py +++ b/sdks/python/apache_beam/runners/interactive/caching/read_cache_test.py @@ -31,6 +31,7 @@ class ReadCacheTest(unittest.TestCase): + def setUp(self): ie.new_env() diff --git a/sdks/python/apache_beam/runners/interactive/caching/reify.py b/sdks/python/apache_beam/runners/interactive/caching/reify.py index c82033dc1b9b..a1fa76ca6060 100644 --- a/sdks/python/apache_beam/runners/interactive/caching/reify.py +++ b/sdks/python/apache_beam/runners/interactive/caching/reify.py @@ -39,6 +39,7 @@ class Reify(beam.DoFn): Internally used to capture window info with each element into cache for replayability. """ + def process(self, e, wv=beam.DoFn.WindowedValueParam): yield test_stream.WindowedValueHolder(wv) @@ -48,6 +49,7 @@ class Unreify(beam.DoFn): Cached values are elements with window info. This unpacks the elements. """ + def process(self, e): # Row coder was used when encoding windowed values. if isinstance(e, beam.Row) and hasattr(e, 'windowed_value'): diff --git a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py index 064246a97087..5008fdd27761 100644 --- a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py +++ b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py @@ -62,6 +62,7 @@ class StreamingCacheSink(beam.PTransform): source/sink types aside from file based. Also, generalize to cases where there might be multiple workers writing to the same sink. """ + def __init__( self, cache_dir, @@ -93,12 +94,14 @@ def size_in_bytes(self): return 0 def expand(self, pcoll): + class StreamingWriteToText(beam.DoFn): """DoFn that performs the writing. Note that the other file writing methods cannot be used in streaming contexts. """ + def __init__(self, full_path, coder=SafeFastPrimitivesCoder()): self._full_path = full_path self._coder = coder @@ -143,6 +146,7 @@ class StreamingCacheSource: This class is used to read from file and send its to the TestStream via the StreamingCacheManager.Reader. """ + def __init__(self, cache_dir, labels, is_cache_complete=None, coder=None): if not coder: coder = SafeFastPrimitivesCoder() @@ -241,6 +245,7 @@ def read(self, tail): class StreamingCache(CacheManager): """Abstraction that holds the logic for reading and writing to cache. """ + def __init__( self, cache_dir, @@ -427,6 +432,7 @@ class Reader(object): This class is also responsible for holding the state of the clock, injecting clock advancement events, and watermark advancement events. """ + def __init__(self, headers, readers): # This timestamp is used as the monotonic clock to order events in the # replay. diff --git a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py index 6f3c17a0ff37..cdee4e7ec055 100644 --- a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py +++ b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py @@ -42,6 +42,7 @@ class StreamingCacheTest(unittest.TestCase): + def setUp(self): pass diff --git a/sdks/python/apache_beam/runners/interactive/caching/write_cache.py b/sdks/python/apache_beam/runners/interactive/caching/write_cache.py index d398e70338b6..39722a7287fe 100644 --- a/sdks/python/apache_beam/runners/interactive/caching/write_cache.py +++ b/sdks/python/apache_beam/runners/interactive/caching/write_cache.py @@ -35,6 +35,7 @@ class WriteCache: """Class that facilitates writing cache for PCollections being computed. """ + def __init__( self, pipeline: beam_runner_api_pb2.Pipeline, @@ -75,9 +76,8 @@ def write_cache(self) -> None: # Copy cache writing subgraph from the template to the pipeline proto. for pcoll_id in template.components.pcollections: - if (pcoll_id in self._pipeline.components.pcollections or - pcoll_id in write_input_placeholder.ignorable_components.pcollections - ): + if (pcoll_id in self._pipeline.components.pcollections or pcoll_id + in write_input_placeholder.ignorable_components.pcollections): continue self._pipeline.components.pcollections[pcoll_id].CopyFrom( template.components.pcollections[pcoll_id]) @@ -88,10 +88,10 @@ def write_cache(self) -> None: self._pipeline.components.coders[coder_id].CopyFrom( template.components.coders[coder_id]) for windowing_strategy_id in template.components.windowing_strategies: - if (windowing_strategy_id in - self._pipeline.components.windowing_strategies or - windowing_strategy_id in - write_input_placeholder.ignorable_components.windowing_strategies): + if (windowing_strategy_id + in self._pipeline.components.windowing_strategies or + windowing_strategy_id + in write_input_placeholder.ignorable_components.windowing_strategies): continue self._pipeline.components.windowing_strategies[ windowing_strategy_id].CopyFrom( @@ -106,8 +106,8 @@ def write_cache(self) -> None: template.components.transforms[transform_id]) for top_level_transform in template.components.transforms[ template_root_transform_id].subtransforms: - if (top_level_transform in - write_input_placeholder.ignorable_components.transforms): + if (top_level_transform + in write_input_placeholder.ignorable_components.transforms): continue self._pipeline.components.transforms[ root_transform_id].subtransforms.append(top_level_transform) @@ -135,6 +135,7 @@ def _build_runner_api_template( class _WriteCacheTransform(PTransform): """A composite transform encapsulates writing cache for PCollections. """ + def __init__(self, cache_manager: cache.CacheManager, key: str): self._cache_manager = cache_manager self._key = key @@ -147,6 +148,7 @@ def expand(self, pcoll: beam.pvalue.PCollection) -> beam.pvalue.PValue: class _PCollectionPlaceHolder: """A placeholder as an input to the cache writing transform. """ + def __init__(self, pcoll: beam.pvalue.PCollection, context: PipelineContext): tmp_pipeline = beam.Pipeline() tmp_pipeline.component_id_map = context.component_id_map diff --git a/sdks/python/apache_beam/runners/interactive/caching/write_cache_test.py b/sdks/python/apache_beam/runners/interactive/caching/write_cache_test.py index 588efdcf5e3a..2b8bea81267c 100644 --- a/sdks/python/apache_beam/runners/interactive/caching/write_cache_test.py +++ b/sdks/python/apache_beam/runners/interactive/caching/write_cache_test.py @@ -31,6 +31,7 @@ class WriteCacheTest(unittest.TestCase): + def setUp(self): ie.new_env() diff --git a/sdks/python/apache_beam/runners/interactive/dataproc/dataproc_cluster_manager.py b/sdks/python/apache_beam/runners/interactive/dataproc/dataproc_cluster_manager.py index 4d260d4a6a56..8c494300f32f 100644 --- a/sdks/python/apache_beam/runners/interactive/dataproc/dataproc_cluster_manager.py +++ b/sdks/python/apache_beam/runners/interactive/dataproc/dataproc_cluster_manager.py @@ -78,6 +78,7 @@ class DataprocClusterManager: """Self-contained cluster manager that controls the lifecyle of a Dataproc cluster connected by one or more pipelines under Interactive Beam. """ + def __init__(self, cluster_metadata: ClusterMetadata) -> None: """Initializes the DataprocClusterManager with properties required to interface with the Dataproc ClusterControllerClient. diff --git a/sdks/python/apache_beam/runners/interactive/dataproc/dataproc_cluster_manager_test.py b/sdks/python/apache_beam/runners/interactive/dataproc/dataproc_cluster_manager_test.py index 0d69eba16508..d6b1e4d2de90 100644 --- a/sdks/python/apache_beam/runners/interactive/dataproc/dataproc_cluster_manager_test.py +++ b/sdks/python/apache_beam/runners/interactive/dataproc/dataproc_cluster_manager_test.py @@ -35,22 +35,26 @@ class MockProperty: + def __init__(self, property, value): object.__setattr__(self, property, value) class MockException(Exception): + def __init__(self, code=-1): self.code = code class MockCluster: + def __init__(self, config_bucket=None): self.config = MockProperty('config_bucket', config_bucket) self.status = MockProperty('state', MockProperty('name', None)) class MockFileSystem: + def _list(self, dir=None): return [ MockProperty( @@ -63,6 +67,7 @@ def open(self, dir=None): class MockFileIO: + def __init__(self, contents): self.contents = contents @@ -73,6 +78,7 @@ def readlines(self): @unittest.skipIf(not _dataproc_imported, 'dataproc package was not imported.') class DataprocClusterManagerTest(unittest.TestCase): """Unit test for DataprocClusterManager""" + def setUp(self): self.patcher = patch( 'apache_beam.runners.interactive.interactive_environment.current_env') diff --git a/sdks/python/apache_beam/runners/interactive/display/display_manager.py b/sdks/python/apache_beam/runners/interactive/display/display_manager.py index e1f248304228..781d2b138d36 100644 --- a/sdks/python/apache_beam/runners/interactive/display/display_manager.py +++ b/sdks/python/apache_beam/runners/interactive/display/display_manager.py @@ -52,6 +52,7 @@ def _formatter(string, pp, cycle): # pylint: disable=unused-argument class DisplayManager(object): """Manages displaying pipeline graph and execution status on the frontend.""" + def __init__( self, pipeline_proto, diff --git a/sdks/python/apache_beam/runners/interactive/display/interactive_pipeline_graph.py b/sdks/python/apache_beam/runners/interactive/display/interactive_pipeline_graph.py index 5a0943e12e6d..e77cffba8777 100644 --- a/sdks/python/apache_beam/runners/interactive/display/interactive_pipeline_graph.py +++ b/sdks/python/apache_beam/runners/interactive/display/interactive_pipeline_graph.py @@ -49,6 +49,7 @@ def format_sample(contents, count=1000): class InteractivePipelineGraph(pipeline_graph.PipelineGraph): """Creates the DOT representation of an interactive pipeline. Thread-safe.""" + def __init__( self, pipeline, diff --git a/sdks/python/apache_beam/runners/interactive/display/pcoll_visualization.py b/sdks/python/apache_beam/runners/interactive/display/pcoll_visualization.py index d767a15a345d..b342a361051e 100644 --- a/sdks/python/apache_beam/runners/interactive/display/pcoll_visualization.py +++ b/sdks/python/apache_beam/runners/interactive/display/pcoll_visualization.py @@ -208,6 +208,7 @@ def visualize( tl = Timeloop() def dynamic_plotting(stream, pv, tl, include_window_info, display_facets): + @tl.job(interval=timedelta(seconds=dynamic_plotting_interval)) def continuous_update_display(): # pylint: disable=unused-variable # Always creates a new PCollVisualization instance when the @@ -285,6 +286,7 @@ class PCollectionVisualization(object): access current interactive environment for materialized PCollection data at the moment of self instantiation through cache. """ + def __init__( self, stream, diff --git a/sdks/python/apache_beam/runners/interactive/display/pcoll_visualization_test.py b/sdks/python/apache_beam/runners/interactive/display/pcoll_visualization_test.py index 7fc76feb7494..24a6f85168cd 100644 --- a/sdks/python/apache_beam/runners/interactive/display/pcoll_visualization_test.py +++ b/sdks/python/apache_beam/runners/interactive/display/pcoll_visualization_test.py @@ -48,6 +48,7 @@ not ie.current_env().is_interactive_ready, '[interactive] dependency is not installed.') class PCollectionVisualizationTest(unittest.TestCase): + def setUp(self): ie.new_env() # Allow unit test to run outside of ipython kernel since we don't test the diff --git a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py index 1f1e315fea09..817f2801a56b 100644 --- a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py +++ b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py @@ -46,6 +46,7 @@ class PipelineGraph(object): """Creates a DOT representing the pipeline. Thread-safe. Runner agnostic.""" + def __init__( self, pipeline: Union[beam_runner_api_pb2.Pipeline, beam.Pipeline], @@ -267,6 +268,7 @@ def _update_graph(self, vertex_dict=None, edge_dict=None): Or (Dict[(str, str), Dict[str, str]]) which maps vertex pairs to edge attributes """ + def set_attrs(ref, attrs): for attr_name, attr_val in attrs.items(): ref.set(attr_name, attr_val) diff --git a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_renderer.py b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_renderer.py index ad46f5d65ea3..d4b6a36540a2 100644 --- a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_renderer.py +++ b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_renderer.py @@ -38,6 +38,7 @@ class PipelineGraphRenderer(BeamPlugin, metaclass=abc.ABCMeta): """Abstract class for renderers, who decide how pipeline graphs are rendered. """ + @classmethod @abc.abstractmethod def option(cls) -> str: @@ -61,6 +62,7 @@ def render_pipeline_graph(self, pipeline_graph: 'PipelineGraph') -> str: class MuteRenderer(PipelineGraphRenderer): """Use this renderer to mute the pipeline display. """ + @classmethod def option(cls) -> str: return 'mute' @@ -72,6 +74,7 @@ def render_pipeline_graph(self, pipeline_graph: 'PipelineGraph') -> str: class TextRenderer(PipelineGraphRenderer): """This renderer simply returns the dot representation in text format. """ + @classmethod def option(cls) -> str: return 'text' @@ -87,6 +90,7 @@ class PydotRenderer(PipelineGraphRenderer): 1. The software Graphviz: https://www.graphviz.org/ 2. The python module pydot: https://pypi.org/project/pydot/ """ + @classmethod def option(cls) -> str: return 'graph' diff --git a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_test.py b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_test.py index 419cd50ac6e9..5e9c919632b5 100644 --- a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_test.py +++ b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_test.py @@ -40,6 +40,7 @@ not ie.current_env().is_interactive_ready, '[interactive] dependency is not installed.') class PipelineGraphTest(unittest.TestCase): + def setUp(self): ie.new_env() diff --git a/sdks/python/apache_beam/runners/interactive/interactive_beam.py b/sdks/python/apache_beam/runners/interactive/interactive_beam.py index e3dc8b8968ad..604eb573aae9 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_beam.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_beam.py @@ -68,6 +68,7 @@ class Options(interactive_options.InteractiveOptions): """Options that guide how Interactive Beam works.""" + @property def enable_recording_replay(self): """Whether replayable source data recorded should be replayed for multiple @@ -273,6 +274,7 @@ class Recordings(): from all defined unbounded sources for that PCollection's pipeline. The following methods allow for introspection into that background recording job. """ + def describe( self, pipeline: Optional[beam.Pipeline] = None) -> Dict[str, Any]: # noqa: F821 diff --git a/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py b/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py index 7af747af5407..d6cd70c29f77 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py @@ -67,6 +67,7 @@ def _get_watched_pcollections_with_variable_names(): @isolated_env class InteractiveBeamTest(unittest.TestCase): + def setUp(self): self._var_in_class_instance = 'a var in class instance, not directly used' @@ -177,7 +178,9 @@ def test_show_handles_deferred_dataframes(self, mocked_visualize): 'apache_beam.runners.interactive.interactive_beam.' 'visualize_computed_pcoll')) def test_show_noop_when_pcoll_container_is_invalid(self, mocked_visualize): + class SomeRandomClass: + def __init__(self, pcoll): self._pcoll = pcoll @@ -256,6 +259,7 @@ def test_recordings_record(self): # written to cache. This is used to make ensure that the pipeline is # functioning properly and that there are no data races with the test. class SizeLimiter(Limiter): + def __init__(self, pipeline): self.pipeline = pipeline self.should_trigger = False @@ -299,6 +303,7 @@ def is_triggered(self): '[interactive] dependency is not installed.') @isolated_env class InteractiveBeamClustersTest(unittest.TestCase): + def setUp(self): self.current_env.options.cache_root = 'gs://fake' self.clusters = self.current_env.clusters diff --git a/sdks/python/apache_beam/runners/interactive/interactive_environment.py b/sdks/python/apache_beam/runners/interactive/interactive_environment.py index 0e3d0060b1a4..de3031a2344e 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_environment.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_environment.py @@ -135,6 +135,7 @@ class InteractiveEnvironment(object): also visualize and introspect those PCollections in user code since they have handles to the variables. """ + def __init__(self): # Registers a cleanup routine when system exits. atexit.register(self.cleanup) @@ -462,8 +463,7 @@ def describe_all_recordings(self): """Returns a description of the recording for all watched pipelnes.""" return { self.pipeline_id_to_pipeline(pid): rm.describe() - for pid, - rm in self._recording_managers.items() + for pid, rm in self._recording_managers.items() } def set_pipeline_result(self, pipeline, result): diff --git a/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py b/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py index 4d5f3f36ce67..f937337b4b66 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py @@ -36,6 +36,7 @@ @isolated_env class InteractiveEnvironmentTest(unittest.TestCase): + def setUp(self): self._p = beam.Pipeline() self._var_in_class_instance = 'a var in class instance' @@ -93,6 +94,7 @@ def test_watch_class_instance(self): '_var_in_class_instance', self._var_in_class_instance) def test_fail_to_set_pipeline_result_key_not_pipeline(self): + class NotPipeline(object): pass @@ -104,6 +106,7 @@ class NotPipeline(object): 'or its subclass' in ctx.exception) def test_fail_to_set_pipeline_result_value_not_pipeline_result(self): + class NotResult(object): pass @@ -115,6 +118,7 @@ class NotResult(object): 'subclass' in ctx.exception) def test_set_pipeline_result_successfully(self): + class PipelineSubClass(beam.Pipeline): pass diff --git a/sdks/python/apache_beam/runners/interactive/interactive_runner.py b/sdks/python/apache_beam/runners/interactive/interactive_runner.py index 17619fbb6ddc..6db5c5964cb3 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_runner.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_runner.py @@ -54,6 +54,7 @@ class InteractiveRunner(runners.PipelineRunner): Allows interactively building and running Beam Python pipelines. """ + def __init__( self, underlying_runner=None, @@ -189,6 +190,7 @@ def exception_handler(e): # TODO: make the StreamingCacheManager and TestStreamServiceController # constructed when the InteractiveEnvironment is imported. class TestStreamVisitor(PipelineVisitor): + def visit_transform(self, transform_node): from apache_beam.testing.test_stream import TestStream if (isinstance(transform_node.transform, TestStream) and @@ -299,6 +301,7 @@ def _configure_flink_options( class PipelineResult(beam.runners.runner.PipelineResult): """Provides access to information about a pipeline.""" + def __init__(self, underlying_result, pipeline_instrument): """Constructor of PipelineResult. diff --git a/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py b/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py index ed27d9e55e06..312e015c2554 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py @@ -53,6 +53,7 @@ def print_with_message(msg): + def printer(elem): print(msg, elem) return elem @@ -68,6 +69,7 @@ class Record(NamedTuple): @isolated_env class InteractiveRunnerTest(unittest.TestCase): + @unittest.skipIf(sys.platform == "win32", "[BEAM-10627]") def test_basic(self): p = beam.Pipeline( @@ -88,7 +90,9 @@ def test_basic(self): @unittest.skipIf(sys.platform == "win32", "[BEAM-10627]") def test_wordcount(self): + class WordExtractingDoFn(beam.DoFn): + def process(self, element): text_line = element.strip() words = text_line.split() @@ -160,7 +164,9 @@ def process(self, element): self.assertEqual(actual_reified, expected_reified) def test_streaming_wordcount(self): + class WordExtractingDoFn(beam.DoFn): + def process(self, element): text_line = element.strip() words = text_line.split() @@ -247,7 +253,9 @@ def process(self, element): pd.testing.assert_frame_equal(expected_counts_df, sorted_counts_df) def test_session(self): + class MockPipelineRunner(object): + def __init__(self): self._in_session = False @@ -457,6 +465,7 @@ def test_dataframe_caching(self): # Only look at the top-level transforms for the isomorphism. The test # doesn't care about the transform implementations, just the overall shape. class TopLevelTracer(beam.pipeline.PipelineVisitor): + def _find_root_producer(self, node: beam.pipeline.AppliedPTransform): if node is None or not node.full_label: return None @@ -538,6 +547,7 @@ def test_defaults_to_efficient_cache(self): '[interactive] dependency is not installed.') @isolated_env class ConfigForFlinkTest(unittest.TestCase): + def setUp(self): self.current_env.options.cache_root = 'gs://fake' diff --git a/sdks/python/apache_beam/runners/interactive/messaging/interactive_environment_inspector.py b/sdks/python/apache_beam/runners/interactive/messaging/interactive_environment_inspector.py index dd7ee947daad..2c3f052a9a8a 100644 --- a/sdks/python/apache_beam/runners/interactive/messaging/interactive_environment_inspector.py +++ b/sdks/python/apache_beam/runners/interactive/messaging/interactive_environment_inspector.py @@ -36,6 +36,7 @@ class InteractiveEnvironmentInspector(object): list_inspectables first then communicates back to the kernel and get_val for usage on the kernel side. """ + def __init__(self, ignore_synthetic=True): self._inspectables = {} self._anonymous = {} diff --git a/sdks/python/apache_beam/runners/interactive/messaging/interactive_environment_inspector_test.py b/sdks/python/apache_beam/runners/interactive/messaging/interactive_environment_inspector_test.py index 6140fff5dd14..e830d2c23a39 100644 --- a/sdks/python/apache_beam/runners/interactive/messaging/interactive_environment_inspector_test.py +++ b/sdks/python/apache_beam/runners/interactive/messaging/interactive_environment_inspector_test.py @@ -39,6 +39,7 @@ sys.version_info < (3, 7), 'The tests require at least Python 3.7 to work.') @isolated_env class InteractiveEnvironmentInspectorTest(unittest.TestCase): + def test_inspect(self): with self.cell: # Cell 1 pipeline = beam.Pipeline(ir.InteractiveRunner()) diff --git a/sdks/python/apache_beam/runners/interactive/non_interactive_runner_test.py b/sdks/python/apache_beam/runners/interactive/non_interactive_runner_test.py index f7fd052fecc4..53c36b97f8bf 100644 --- a/sdks/python/apache_beam/runners/interactive/non_interactive_runner_test.py +++ b/sdks/python/apache_beam/runners/interactive/non_interactive_runner_test.py @@ -39,6 +39,7 @@ def print_with_message(msg): + def printer(elem): print(msg, elem) return elem @@ -73,6 +74,7 @@ def clear_side_effect(): @isolated_env class NonInteractiveRunnerTest(unittest.TestCase): + @unittest.skipIf(sys.platform == "win32", "[BEAM-10627]") def test_basic(self): clear_side_effect() @@ -146,7 +148,9 @@ def test_multiple_collect(self): @unittest.skipIf(sys.platform == "win32", "[BEAM-10627]") def test_wordcount(self): + class WordExtractingDoFn(beam.DoFn): + def process(self, element): text_line = element.strip() words = text_line.split() @@ -259,6 +263,7 @@ def test_dataframes_same_cell_twice(self): @unittest.skipIf(sys.platform == "win32", "[BEAM-10627]") def test_new_runner_and_options(self): + class MyRunner(beam.runners.PipelineRunner): run_count = 0 diff --git a/sdks/python/apache_beam/runners/interactive/options/capture_control.py b/sdks/python/apache_beam/runners/interactive/options/capture_control.py index 826b596bbc6d..6d22c6bc1b82 100644 --- a/sdks/python/apache_beam/runners/interactive/options/capture_control.py +++ b/sdks/python/apache_beam/runners/interactive/options/capture_control.py @@ -37,6 +37,7 @@ class CaptureControl(object): """Options and their utilities that controls how Interactive Beam captures deterministic replayable data from sources.""" + def __init__(self): self._enable_capture_replay = True self._capturable_sources = { diff --git a/sdks/python/apache_beam/runners/interactive/options/capture_control_test.py b/sdks/python/apache_beam/runners/interactive/options/capture_control_test.py index 8b734f7f15c2..aa9581736af4 100644 --- a/sdks/python/apache_beam/runners/interactive/options/capture_control_test.py +++ b/sdks/python/apache_beam/runners/interactive/options/capture_control_test.py @@ -49,7 +49,9 @@ def _build_an_empty_streaming_pipeline(): def _fake_a_running_test_stream_service(pipeline): + class FakeReader: + def read_multiple(self): yield 1 @@ -63,6 +65,7 @@ def read_multiple(self): not ie.current_env().is_interactive_ready, '[interactive] dependency is not installed.') class CaptureControlTest(unittest.TestCase): + def setUp(self): ie.new_env() @@ -155,6 +158,7 @@ def test_timer_terminates_capture_size_checker(self): p = _build_an_empty_streaming_pipeline() class FakeLimiter(capture_limiters.Limiter): + def __init__(self): self.trigger = False diff --git a/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py b/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py index 497772f94c36..7eb196f102a6 100644 --- a/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py +++ b/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py @@ -34,6 +34,7 @@ class Limiter: """Limits an aspect of the caching layer.""" + def is_triggered(self) -> bool: """Returns True if the limiter has triggered, and caching should stop.""" raise NotImplementedError @@ -43,6 +44,7 @@ class ElementLimiter(Limiter): """A `Limiter` that limits reading from cache based on some property of an element. """ + def update(self, e: Any) -> None: # noqa: F821 @@ -55,6 +57,7 @@ def update(self, e: Any) -> None: class SizeLimiter(Limiter): """Limits the cache size to a specified byte limit.""" + def __init__(self, size_limit: int): self._size_limit = size_limit @@ -70,6 +73,7 @@ def is_triggered(self): class DurationLimiter(Limiter): """Limits the duration of the capture.""" + def __init__( self, duration_limit: datetime.timedelta # noqa: F821 @@ -89,6 +93,7 @@ def is_triggered(self): class CountLimiter(ElementLimiter): """Limits by counting the number of elements seen.""" + def __init__(self, max_count): self._max_count = max_count self._count = 0 @@ -125,6 +130,7 @@ class ProcessingTimeLimiter(ElementLimiter): clock forward. This triggers when the duration from the internal clock and the start exceeds the given duration. """ + def __init__(self, max_duration_secs): """Initialize the ProcessingTimeLimiter.""" self._max_duration_us = max_duration_secs * 1e6 diff --git a/sdks/python/apache_beam/runners/interactive/options/capture_limiters_test.py b/sdks/python/apache_beam/runners/interactive/options/capture_limiters_test.py index 06b8e60c3130..b2b4cf38e690 100644 --- a/sdks/python/apache_beam/runners/interactive/options/capture_limiters_test.py +++ b/sdks/python/apache_beam/runners/interactive/options/capture_limiters_test.py @@ -26,6 +26,7 @@ class CaptureLimitersTest(unittest.TestCase): + def test_count_limiter(self): limiter = CountLimiter(5) diff --git a/sdks/python/apache_beam/runners/interactive/options/interactive_options.py b/sdks/python/apache_beam/runners/interactive/options/interactive_options.py index d737640b0b2d..c334b2d28717 100644 --- a/sdks/python/apache_beam/runners/interactive/options/interactive_options.py +++ b/sdks/python/apache_beam/runners/interactive/options/interactive_options.py @@ -30,6 +30,7 @@ class InteractiveOptions(object): """An intermediate facade to query and configure options that guide how Interactive Beam works.""" + def __init__(self): self._capture_control = capture_control.CaptureControl() self._display_timestamp_format = '%Y-%m-%d %H:%M:%S.%f%z' diff --git a/sdks/python/apache_beam/runners/interactive/pipeline_fragment.py b/sdks/python/apache_beam/runners/interactive/pipeline_fragment.py index 20dee2b71163..e621e67d1bc3 100644 --- a/sdks/python/apache_beam/runners/interactive/pipeline_fragment.py +++ b/sdks/python/apache_beam/runners/interactive/pipeline_fragment.py @@ -34,6 +34,7 @@ class PipelineFragment(object): A pipeline fragment is built from the original pipeline definition to include only PTransforms that are necessary to produce the given PCollections. """ + def __init__(self, pcolls, options=None, runner=None): """Constructor of PipelineFragment. @@ -163,6 +164,7 @@ def _calculate_user_transform_labels(self): label_to_user_transform = {} class UserTransformVisitor(PipelineVisitor): + def enter_composite_transform(self, transform_node): self.visit_transform(transform_node) @@ -180,6 +182,7 @@ def _build_correlation_between_pipelines( runner_transforms_to_user_transforms = {} class CorrelationVisitor(PipelineVisitor): + def enter_composite_transform(self, transform_node): self.visit_transform(transform_node) @@ -242,7 +245,9 @@ def _mark_necessary_transforms_and_pcolls(self, runner_pcolls_to_user_pcolls): def _prune_runner_pipeline_to_fragment( self, runner_pipeline, necessary_transforms): + class PruneVisitor(PipelineVisitor): + def enter_composite_transform(self, transform_node): if should_skip_pruning(transform_node): return diff --git a/sdks/python/apache_beam/runners/interactive/pipeline_fragment_test.py b/sdks/python/apache_beam/runners/interactive/pipeline_fragment_test.py index 3e7207fbb117..610b9d22c35c 100644 --- a/sdks/python/apache_beam/runners/interactive/pipeline_fragment_test.py +++ b/sdks/python/apache_beam/runners/interactive/pipeline_fragment_test.py @@ -35,6 +35,7 @@ not ie.current_env().is_interactive_ready, '[interactive] dependency is not installed.') class PipelineFragmentTest(unittest.TestCase): + def setUp(self): ie.new_env() # Assume a notebook frontend is connected to the mocked ipython kernel. diff --git a/sdks/python/apache_beam/runners/interactive/pipeline_instrument.py b/sdks/python/apache_beam/runners/interactive/pipeline_instrument.py index 8e5d50ed3f3f..8b701e561442 100644 --- a/sdks/python/apache_beam/runners/interactive/pipeline_instrument.py +++ b/sdks/python/apache_beam/runners/interactive/pipeline_instrument.py @@ -53,6 +53,7 @@ class PipelineInstrument(object): runner's responsibility to coordinate supported underlying runners to run the pipeline instrumented and recover the original pipeline states if needed. """ + def __init__(self, pipeline, options=None): self._pipeline = pipeline @@ -178,8 +179,7 @@ def _required_components( visited_copy = visited.copy() consuming_transforms = { t_id: t - for t_id, - t in transforms.items() + for t_id, t in transforms.items() if set(outputs).intersection(set(t.inputs.values())) } consuming_transforms = set(consuming_transforms.keys()) @@ -201,8 +201,7 @@ def _required_components( ] producing_transforms = { t_id: t - for t_id, - t in transforms.items() + for t_id, t in transforms.items() if set(inputs).intersection(set(t.outputs.values())) } (t, pc) = self._required_components( @@ -296,8 +295,8 @@ def background_caching_pipeline_proto(self): # Get the IDs of the unbounded sources. required_transform_labels = [src.full_label for src in sources] unbounded_source_ids = [ - k for k, - v in transforms.items() if v.unique_name in required_transform_labels + k for k, v in transforms.items() + if v.unique_name in required_transform_labels ] # The required transforms are the transforms that we want to cut out of @@ -412,6 +411,7 @@ def instrument(self): class InstrumentVisitor(PipelineVisitor): """Visitor utilizes cache to instrument the pipeline.""" + def __init__(self, pin): self._pin = pin @@ -460,6 +460,7 @@ def visit_transform(self, transform_node): is_capture=True) class TestStreamVisitor(PipelineVisitor): + def __init__(self): self.test_stream = None @@ -496,7 +497,9 @@ def preprocess(self): of cacheable PCollections between these 2 instances by replacing 'pcoll' fields in the cacheable dictionary with ones from the running instance. """ + class PreprocessVisitor(PipelineVisitor): + def __init__(self, pin): self._pin = pin @@ -596,8 +599,8 @@ def _read_cache(self, pipeline, pcoll, is_unbounded_source_output): is_cached = self._cache_manager.exists('full', key) is_computed = ( pcoll in self._runner_pcoll_to_user_pcoll and - self._runner_pcoll_to_user_pcoll[pcoll] in - ie.current_env().computed_pcollections) + self._runner_pcoll_to_user_pcoll[pcoll] + in ie.current_env().computed_pcollections) if ((is_cached and is_computed) or is_unbounded_source_output): if key not in self._cached_pcoll_read: # Mutates the pipeline with cache read transform attached @@ -626,6 +629,7 @@ def _replace_with_cached_inputs(self, pipeline): if self.has_unbounded_sources: class CacheableUnboundedPCollectionVisitor(PipelineVisitor): + def __init__(self, pin): self._pin = pin self.unbounded_pcolls = set() @@ -668,6 +672,7 @@ class ReadCacheWireVisitor(PipelineVisitor): """Visitor wires cache read as inputs to replace corresponding original input PCollections in pipeline. """ + def __init__(self, pin): """Initializes with a PipelineInstrument.""" self._pin = pin @@ -769,6 +774,7 @@ def pcoll_to_pcoll_id(pipeline, original_context): Returns: (dict from str to str) a dict mapping str(pcoll) to pcoll_id. """ + class PCollVisitor(PipelineVisitor): """"A visitor that records input and output values to be replaced. @@ -778,6 +784,7 @@ class PCollVisitor(PipelineVisitor): We cannot update input and output values while visiting since that results in validation errors. """ + def __init__(self): self.pcoll_to_pcoll_id = {} diff --git a/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py b/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py index 893603ddbb52..e0046b5cc8d1 100644 --- a/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py +++ b/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py @@ -42,6 +42,7 @@ class PipelineInstrumentTest(unittest.TestCase): + def setUp(self): ie.new_env() @@ -257,6 +258,7 @@ def test_instrument_example_pipeline_to_read_cache(self): class TestReadCacheWireVisitor(PipelineVisitor): """Replace init_pcoll with cached_init_pcoll for all occuring inputs.""" + def enter_composite_transform(self, transform_node): self.visit_transform(transform_node) @@ -327,6 +329,7 @@ def test_instrument_example_unbounded_pipeline_to_read_cache(self): # Test that the TestStream is outputting to the correct PCollection. class TestStreamVisitor(PipelineVisitor): + def __init__(self): self.output_tags = set() @@ -420,6 +423,7 @@ def test_able_to_cache_intermediate_unbounded_source_pcollection(self): # Test that the TestStream is outputting to the correct PCollection. class TestStreamVisitor(PipelineVisitor): + def __init__(self): self.output_tags = set() @@ -505,6 +509,7 @@ def test_instrument_mixed_streaming_batch(self): # Test that the TestStream is outputting to the correct PCollection. class TestStreamVisitor(PipelineVisitor): + def __init__(self): self.output_tags = set() @@ -570,6 +575,7 @@ def test_instrument_example_unbounded_pipeline_direct_from_source(self): # Test that the TestStream is outputting to the correct PCollection. class TestStreamVisitor(PipelineVisitor): + def __init__(self): self.output_tags = set() @@ -641,6 +647,7 @@ def test_instrument_example_unbounded_pipeline_to_read_cache_not_cached(self): # Test that the TestStream is outputting to the correct PCollection. class TestStreamVisitor(PipelineVisitor): + def __init__(self): self.output_tags = set() @@ -716,6 +723,7 @@ def test_instrument_example_unbounded_pipeline_to_multiple_read_cache(self): # Test that the TestStream is outputting to the correct PCollection. class TestStreamVisitor(PipelineVisitor): + def __init__(self): self.output_tags = set() diff --git a/sdks/python/apache_beam/runners/interactive/recording_manager.py b/sdks/python/apache_beam/runners/interactive/recording_manager.py index 6811d3e0d345..12947e357625 100644 --- a/sdks/python/apache_beam/runners/interactive/recording_manager.py +++ b/sdks/python/apache_beam/runners/interactive/recording_manager.py @@ -45,6 +45,7 @@ class ElementStream: """A stream of elements from a given PCollection.""" + def __init__( self, pcoll: beam.pvalue.PCollection, @@ -148,21 +149,23 @@ def read(self, tail: bool = True) -> Any: class Recording: """A group of PCollections from a given pipeline run.""" + def __init__( self, user_pipeline: beam.Pipeline, - pcolls: List[beam.pvalue.PCollection], # noqa: F821 + pcolls: List[beam.pvalue.PCollection], # noqa: F821 result: 'beam.runner.PipelineResult', max_n: int, max_duration_secs: float, - ): + ): self._user_pipeline = user_pipeline self._result = result self._result_lock = threading.Lock() self._pcolls = pcolls - pcoll_var = lambda pcoll: {v: k - for k, v in utils.pcoll_by_name().items()}.get( - pcoll, None) + pcoll_var = lambda pcoll: { + v: k + for k, v in utils.pcoll_by_name().items() + }.get(pcoll, None) self._streams = { pcoll: ElementStream( @@ -254,6 +257,7 @@ def describe(self) -> Dict[str, int]: class RecordingManager: """Manages recordings of PCollections for a given pipeline.""" + def __init__( self, user_pipeline: beam.Pipeline, diff --git a/sdks/python/apache_beam/runners/interactive/recording_manager_test.py b/sdks/python/apache_beam/runners/interactive/recording_manager_test.py index 698a464ae739..bb8a4bc8043f 100644 --- a/sdks/python/apache_beam/runners/interactive/recording_manager_test.py +++ b/sdks/python/apache_beam/runners/interactive/recording_manager_test.py @@ -45,6 +45,7 @@ class MockPipelineResult(beam.runners.runner.PipelineResult): """Mock class for controlling a PipelineResult.""" + def __init__(self): self._state = PipelineState.RUNNING @@ -63,6 +64,7 @@ def cancel(self): class ElementStreamTest(unittest.TestCase): + def setUp(self): self.cache = InMemoryCache() self.p = beam.Pipeline() @@ -142,6 +144,7 @@ def test_read_n(self): def test_read_duration(self): """Test that the stream only reads a 'duration' of elements.""" + def as_windowed_value(element): return WindowedValueHolder(WindowedValue(element, 0, [])) @@ -184,6 +187,7 @@ def as_windowed_value(element): class RecordingTest(unittest.TestCase): + def test_computed(self): """Tests that a PCollection is marked as computed only in a complete state. @@ -284,6 +288,7 @@ def test_describe(self): class RecordingManagerTest(unittest.TestCase): + def test_basic_execution(self): """A basic pipeline to be used as a smoke test.""" @@ -361,6 +366,7 @@ def test_cancel_stops_recording(self): ie.current_env().track_user_pipelines() class SemaphoreLimiter(Limiter): + def __init__(self): self.triggered = False @@ -503,6 +509,7 @@ def test_record_pipeline(self): # written to cache. This is used to make ensure that the pipeline is # functioning properly and that there are no data races with the test. class SizeLimiter(Limiter): + def __init__(self, p): self.pipeline = p self._rm = None diff --git a/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics.py b/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics.py index bf4c4c0380e5..132a903b0bf5 100644 --- a/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics.py +++ b/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics.py @@ -91,6 +91,7 @@ class BeamSqlParser: """A parser to parse beam_sql inputs.""" + def __init__(self): self._parser = argparse.ArgumentParser(usage=_EXAMPLE_USAGE) self._parser.add_argument( @@ -156,6 +157,7 @@ def on_error(error_msg, *args): @magics_class class BeamSqlMagics(Magics): + def __init__(self, shell): super().__init__(shell) # Eagerly initializes the environment. @@ -328,6 +330,7 @@ def pcolls_from_streaming_cache( When the user_pipeline has unbounded sources, we force all cache reads to go through the TestStream even if they are bounded sources. """ + def exception_handler(e): _LOGGER.error(str(e)) return True diff --git a/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics_test.py b/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics_test.py index 3d843a0f6ae8..7f8081eca4af 100644 --- a/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics_test.py +++ b/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics_test.py @@ -46,6 +46,7 @@ not ie.current_env().is_interactive_ready, reason='[interactive] dependency is not installed.') class BeamSqlMagicsTest(unittest.TestCase): + def test_generate_output_name_when_not_provided(self): output_name = None self.assertTrue( @@ -76,10 +77,7 @@ def test_build_query_components_when_single_pcoll_queried(self): with patch('apache_beam.runners.interactive.sql.beam_sql_magics.' 'unreify_from_cache', - lambda pipeline, - cache_key, - cache_manager, - element_type: target): + lambda pipeline, cache_key, cache_manager, element_type: target): processed_query, sql_source, chain = _build_query_components( query, found, 'output') expected_query = 'SELECT * FROM PCOLLECTION where a=1' @@ -97,12 +95,10 @@ def test_build_query_components_when_multiple_pcolls_queried(self): query = 'SELECT * FROM pcoll_1 JOIN pcoll_2 USING (a)' found = {'pcoll_1': pcoll_1, 'pcoll_2': pcoll_2} - with patch('apache_beam.runners.interactive.sql.beam_sql_magics.' - 'unreify_from_cache', - lambda pipeline, - cache_key, - cache_manager, - element_type: pcoll_1): + with patch( + 'apache_beam.runners.interactive.sql.beam_sql_magics.' + 'unreify_from_cache', + lambda pipeline, cache_key, cache_manager, element_type: pcoll_1): processed_query, sql_source, chain = _build_query_components( query, found, 'output') @@ -124,10 +120,7 @@ def test_build_query_components_when_unbounded_pcolls_queried(self): found = {'pcoll': pcoll} with patch('apache_beam.runners.interactive.sql.beam_sql_magics.' - 'pcolls_from_streaming_cache', - lambda a, - b, - c: found): + 'pcolls_from_streaming_cache', lambda a, b, c: found): _, sql_source, chain = _build_query_components(query, found, 'output') self.assertIs(sql_source, pcoll) self.assertIn('pcoll', chain.current.source) @@ -141,9 +134,7 @@ def test_cache_output(self): ie.current_env().set_cache_manager(cache_manager, p_cache_output) ib.watch(locals()) with patch('apache_beam.runners.interactive.display.pcoll_visualization.' - 'visualize_computed_pcoll', - lambda a, - b: None): + 'visualize_computed_pcoll', lambda a, b: None): cache_output('pcoll_co', pcoll_co) self.assertIn(pcoll_co, ie.current_env().computed_pcollections) self.assertTrue( diff --git a/sdks/python/apache_beam/runners/interactive/sql/sql_chain.py b/sdks/python/apache_beam/runners/interactive/sql/sql_chain.py index a6f48661b87b..646599bd0737 100644 --- a/sdks/python/apache_beam/runners/interactive/sql/sql_chain.py +++ b/sdks/python/apache_beam/runners/interactive/sql/sql_chain.py @@ -122,6 +122,7 @@ class SchemaLoadedSqlTransform(beam.PTransform): makes sure only the schemas needed are pickled locally and restored later on workers. """ + def __init__(self, output_name, query, schemas, execution_count): self.output_name = output_name self.query = query @@ -137,6 +138,7 @@ def __init__(self, output_name, query, schemas, execution_count): class _SqlTransformDoFn(beam.DoFn): """The DoFn yields all its input without any transform but a setup to configure the main session.""" + def __init__(self, schemas, annotations): self.pickled_schemas = [pickler.dumps(s) for s in schemas] self.pickled_annotations = [pickler.dumps(a) for a in annotations] @@ -165,8 +167,7 @@ def expand(self, source): self.output_name, tag, self.execution_count) >> beam.ParDo( self._SqlTransformDoFn(self.schemas, self.schema_annotations)) if pcoll.element_type in self.schemas else pcoll - for tag, - pcoll in source.items() + for tag, pcoll in source.items() } elif isinstance(source, beam.pvalue.PCollection): schema_loaded = source | 'load_schemas_{}_{}'.format( diff --git a/sdks/python/apache_beam/runners/interactive/sql/sql_chain_test.py b/sdks/python/apache_beam/runners/interactive/sql/sql_chain_test.py index 42d0804665e2..6e690d9151d8 100644 --- a/sdks/python/apache_beam/runners/interactive/sql/sql_chain_test.py +++ b/sdks/python/apache_beam/runners/interactive/sql/sql_chain_test.py @@ -32,6 +32,7 @@ class SqlChainTest(unittest.TestCase): + def test_init(self): chain = SqlChain() self.assertEqual({}, chain.nodes) diff --git a/sdks/python/apache_beam/runners/interactive/sql/utils.py b/sdks/python/apache_beam/runners/interactive/sql/utils.py index a6e810d5555b..79251dbb35f5 100644 --- a/sdks/python/apache_beam/runners/interactive/sql/utils.py +++ b/sdks/python/apache_beam/runners/interactive/sql/utils.py @@ -115,8 +115,8 @@ def pformat_namedtuple(schema: NamedTuple) -> str: return '{}({})'.format( schema.__name__, ', '.join([ - '{}: {}'.format(k, repr(v)) for k, - v in schema.__annotations__.items() + '{}: {}'.format(k, repr(v)) + for k, v in schema.__annotations__.items() ])) @@ -160,6 +160,7 @@ class OptionsForm: """A form visualized to take inputs from users in IPython Notebooks and generate PipelineOptions to run pipelines. """ + def __init__(self): # The current Python SDK incorrectly parses unparsable pipeline options # Here we ignore all flags for the interactive beam_sql magic @@ -232,6 +233,7 @@ class DataflowOptionsForm(OptionsForm): Only contains minimum fields needed. """ + @staticmethod def _build_default_project() -> str: """Builds a default project id.""" diff --git a/sdks/python/apache_beam/runners/interactive/sql/utils_test.py b/sdks/python/apache_beam/runners/interactive/sql/utils_test.py index d5747b03a919..a12d6a1a600f 100644 --- a/sdks/python/apache_beam/runners/interactive/sql/utils_test.py +++ b/sdks/python/apache_beam/runners/interactive/sql/utils_test.py @@ -49,6 +49,7 @@ class OptionalUnionType(NamedTuple): class UtilsTest(unittest.TestCase): + def test_register_coder_for_schema(self): self.assertNotIsInstance( beam.coders.registry.get_coder(ANamedTuple), beam.coders.RowCoder) @@ -99,6 +100,7 @@ def test_pformat_dict(self): not ie.current_env().is_interactive_ready, reason='[interactive] dependency is not installed.') class OptionsFormTest(unittest.TestCase): + def test_dataflow_options_form(self): p = beam.Pipeline() pcoll = p | beam.Create([1, 2, 3]) diff --git a/sdks/python/apache_beam/runners/interactive/testing/integration/notebook_executor.py b/sdks/python/apache_beam/runners/interactive/testing/integration/notebook_executor.py index 808ede64d60d..7f9871738b5b 100644 --- a/sdks/python/apache_beam/runners/interactive/testing/integration/notebook_executor.py +++ b/sdks/python/apache_beam/runners/interactive/testing/integration/notebook_executor.py @@ -40,6 +40,7 @@ class NotebookExecutor(object): """Executor that reads notebooks, executes it and gathers outputs into static HTML pages that can be served.""" + def __init__(self, path: str) -> None: assert _interactive_integration_ready, ( @@ -148,6 +149,7 @@ def _extract_html(output, sink): class IFrameParser(HTMLParser): """A parser to extract iframe content from given HTML.""" + def __init__(self): self._srcdocs = [] super().__init__() diff --git a/sdks/python/apache_beam/runners/interactive/testing/integration/screen_diff.py b/sdks/python/apache_beam/runners/interactive/testing/integration/screen_diff.py index 743d5614f9a2..dd256b5c0858 100644 --- a/sdks/python/apache_beam/runners/interactive/testing/integration/screen_diff.py +++ b/sdks/python/apache_beam/runners/interactive/testing/integration/screen_diff.py @@ -52,6 +52,7 @@ class ScreenDiffIntegrationTestEnvironment(object): """A test environment to conduct screen diff integration tests for notebooks. """ + def __init__( self, test_notebook_path: str, diff --git a/sdks/python/apache_beam/runners/interactive/testing/integration/tests/screen_diff_test.py b/sdks/python/apache_beam/runners/interactive/testing/integration/tests/screen_diff_test.py index a3f8ace0b53f..eaf64f2cffbd 100644 --- a/sdks/python/apache_beam/runners/interactive/testing/integration/tests/screen_diff_test.py +++ b/sdks/python/apache_beam/runners/interactive/testing/integration/tests/screen_diff_test.py @@ -27,6 +27,7 @@ @pytest.mark.timeout(300) class DataFramesTest(BaseTestCase): + def __init__(self, *args, **kwargs): kwargs['golden_size'] = (1024, 10000) super().__init__(*args, **kwargs) @@ -49,6 +50,7 @@ def test_dataframes(self): @pytest.mark.timeout(300) class InitSquareCubeTest(BaseTestCase): + def __init__(self, *args, **kwargs): kwargs['golden_size'] = (1024, 10000) super().__init__(*args, **kwargs) diff --git a/sdks/python/apache_beam/runners/interactive/testing/mock_env.py b/sdks/python/apache_beam/runners/interactive/testing/mock_env.py index 9b8f349d785a..8c4487dc2381 100644 --- a/sdks/python/apache_beam/runners/interactive/testing/mock_env.py +++ b/sdks/python/apache_beam/runners/interactive/testing/mock_env.py @@ -33,7 +33,9 @@ def isolated_env(cls: Type[unittest.TestCase]): """A class decorator for unittest.TestCase to set up an isolated test environment for Interactive Beam.""" + class IsolatedInteractiveEnvironmentTest(cls): + def setUp(self): self.env_patchers = [] interactive_path = 'apache_beam.runners.interactive' diff --git a/sdks/python/apache_beam/runners/interactive/testing/mock_ipython.py b/sdks/python/apache_beam/runners/interactive/testing/mock_ipython.py index e8eb4c4108c3..249463a0dc28 100644 --- a/sdks/python/apache_beam/runners/interactive/testing/mock_ipython.py +++ b/sdks/python/apache_beam/runners/interactive/testing/mock_ipython.py @@ -42,7 +42,9 @@ def some_test(self, cell): # ... # arbitrary python code """ + class MockedGetIpython(object): + def __init__(self): self._execution_count = 0 # Mock as if the kernel is connected to a notebook frontend. diff --git a/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py b/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py index 6a995de771d8..df4856cd8b53 100644 --- a/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py +++ b/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py @@ -33,6 +33,7 @@ class InMemoryCache(CacheManager): This is only used for checking the pipeline shape. This can't be used for running the pipeline isn't shared between the SDK and the Runner. """ + def __init__(self): self._cached = {} self._pcoders = {} @@ -85,11 +86,13 @@ def _key(self, *labels): class NoopSink(beam.PTransform): + def expand(self, pcoll): return pcoll | beam.Map(lambda x: x) class FileRecordsBuilder(object): + def __init__(self, tag=None): self._header = beam_interactive_api_pb2.TestStreamFileHeader(tag=tag) self._records = [] diff --git a/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker.py b/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker.py index 53ee54ac8a35..f99b7566eee6 100644 --- a/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker.py +++ b/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker.py @@ -38,6 +38,7 @@ class UserPipelineTracker: pipeline can only have one parent user pipeline. A user pipeline can have many derived pipelines. """ + def __init__(self): self._user_pipelines: dict[beam.Pipeline, list[beam.Pipeline]] = {} self._derived_pipelines: dict[beam.Pipeline] = {} diff --git a/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker_test.py b/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker_test.py index f7025b8b75bf..9e3674dd1836 100644 --- a/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker_test.py +++ b/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker_test.py @@ -22,6 +22,7 @@ class UserPipelineTrackerTest(unittest.TestCase): + def test_getting_unknown_pid_returns_none(self): ut = UserPipelineTracker() diff --git a/sdks/python/apache_beam/runners/interactive/utils.py b/sdks/python/apache_beam/runners/interactive/utils.py index 828f23a467c2..a04e85276693 100644 --- a/sdks/python/apache_beam/runners/interactive/utils.py +++ b/sdks/python/apache_beam/runners/interactive/utils.py @@ -270,6 +270,7 @@ def __exit__(self, exc_type, exc_value, traceback): def progress_indicated(func: Callable[..., Any]) -> Callable[..., Any]: """A decorator using a unique progress indicator as a context manager to execute the given function within.""" + @functools.wraps(func) def run_within_progress_indicator(*args, **kwargs): with ProgressIndicator(f'Processing... {func.__name__}', 'Done.'): @@ -286,6 +287,7 @@ def as_json(func: Callable[..., Any]) -> Callable[..., str]: If the object is not parsable, the str() of original object is returned instead. """ + def return_as_json(*args, **kwargs): try: return_value = func(*args, **kwargs) @@ -365,6 +367,7 @@ def watch_sources(pipeline): pcoll_to_name = {v: k for k, v in pcoll_by_name().items()} class CacheableUnboundedPCollectionVisitor(PipelineVisitor): + def __init__(self): self.unbounded_pcolls = set() @@ -401,6 +404,7 @@ class CheckUnboundednessVisitor(PipelineVisitor): Visitor visits all nodes and checks if it is an instance of recordable sources. """ + def __init__(self): self.unbounded_sources = [] diff --git a/sdks/python/apache_beam/runners/interactive/utils_test.py b/sdks/python/apache_beam/runners/interactive/utils_test.py index f3d7f96b0dbb..df28098751a6 100644 --- a/sdks/python/apache_beam/runners/interactive/utils_test.py +++ b/sdks/python/apache_beam/runners/interactive/utils_test.py @@ -54,6 +54,7 @@ class MockStorageClient(): + def __init__(self): pass @@ -76,6 +77,7 @@ def windowed_value(e): class ParseToDataframeTest(unittest.TestCase): + def test_parse_windowedvalue(self): """Tests that WindowedValues are supported but not present. """ @@ -148,6 +150,7 @@ def test_parse_series(self): class ToElementListTest(unittest.TestCase): + def test_test_stream_payload_events(self): """Tests that the to_element_list can limit the count in a single bundle.""" @@ -185,6 +188,7 @@ def test_element_limit_count(self): not ie.current_env().is_interactive_ready, '[interactive] dependency is not installed.') class IPythonLogHandlerTest(unittest.TestCase): + def setUp(self): utils.register_ipython_log_handler() self._interactive_root_logger = logging.getLogger( @@ -242,6 +246,7 @@ def test_child_module_logger_can_override_logging_level(self, mock_emit): not ie.current_env().is_interactive_ready, reason='[interactive] dependency is not installed.') class ProgressIndicatorTest(unittest.TestCase): + def setUp(self): ie.new_env() @@ -289,6 +294,7 @@ def setUp(self): ie.new_env() def test_as_json_decorator(self): + @utils.as_json def dummy(): return MessagingUtilTest.SAMPLE_DATA @@ -299,6 +305,7 @@ def dummy(): class GeneralUtilTest(unittest.TestCase): + def test_pcoll_by_name(self): p = beam.Pipeline() pcoll = p | beam.Create([1]) @@ -351,6 +358,7 @@ def test_create_var_in_main(self): @patch('google.cloud.storage.Client', return_value=MockStorageClient()) @unittest.skipIf(not _http_error_imported, 'http errors are not imported.') class GCSUtilsTest(unittest.TestCase): + @patch('google.cloud.storage.Client.get_bucket') def test_assert_bucket_exists_not_found(self, mock_response, mock_client): with self.assertRaises(ValueError): @@ -368,6 +376,7 @@ def test_assert_bucket_exists_found(self, mock_response, mock_client): class PipelineUtilTest(unittest.TestCase): + def test_detect_pipeline_underlying_runner(self): p = beam.Pipeline(InteractiveRunner(underlying_runner=FlinkRunner())) pipeline_runner = utils.detect_pipeline_runner(p) diff --git a/sdks/python/apache_beam/runners/job/manager.py b/sdks/python/apache_beam/runners/job/manager.py index 9c9265d8fd4e..888675a79747 100644 --- a/sdks/python/apache_beam/runners/job/manager.py +++ b/sdks/python/apache_beam/runners/job/manager.py @@ -34,6 +34,7 @@ class DockerRPCManager(object): """A native co-process to start a contianer that speaks the JobApi """ + def __init__(self, run_command=None): # TODO(BEAM-2431): Change this to a docker container from a command. self.process = subprocess.Popen([ diff --git a/sdks/python/apache_beam/runners/pipeline_context.py b/sdks/python/apache_beam/runners/pipeline_context.py index 13ab665c1eb1..8fb0fdb6edbd 100644 --- a/sdks/python/apache_beam/runners/pipeline_context.py +++ b/sdks/python/apache_beam/runners/pipeline_context.py @@ -55,6 +55,7 @@ class PortableObject(Protocol): + def to_runner_api(self, __context: 'PipelineContext') -> Any: pass @@ -69,6 +70,7 @@ class _PipelineContextMap(Generic[PortableObjectT]): Under the hood it encodes and decodes these objects into runner API representations. """ + def __init__( self, context: 'PipelineContext', @@ -166,6 +168,7 @@ class PipelineContext(object): Used for accessing and constructing the referenced objects of a Pipeline. """ + def __init__( self, proto: Optional[Union[beam_runner_api_pb2.Components, diff --git a/sdks/python/apache_beam/runners/pipeline_context_test.py b/sdks/python/apache_beam/runners/pipeline_context_test.py index 49ff6f744bf1..0b62337ad7d0 100644 --- a/sdks/python/apache_beam/runners/pipeline_context_test.py +++ b/sdks/python/apache_beam/runners/pipeline_context_test.py @@ -27,6 +27,7 @@ class PipelineContextTest(unittest.TestCase): + def test_deduplication(self): context = pipeline_context.PipelineContext() bytes_coder_ref = context.coders.get_id(coders.BytesCoder()) diff --git a/sdks/python/apache_beam/runners/portability/abstract_job_service.py b/sdks/python/apache_beam/runners/portability/abstract_job_service.py index 87162d5feda5..cb3e06be3c47 100644 --- a/sdks/python/apache_beam/runners/portability/abstract_job_service.py +++ b/sdks/python/apache_beam/runners/portability/abstract_job_service.py @@ -69,15 +69,16 @@ class AbstractJobServiceServicer(beam_job_api_pb2_grpc.JobServiceServicer): Experimental: No backward compatibility guaranteed. Servicer for the Beam Job API. """ + def __init__(self): self._jobs: Dict[str, AbstractBeamJob] = {} - def create_beam_job(self, - preparation_id, # stype: str - job_name: str, - pipeline: beam_runner_api_pb2.Pipeline, - options: struct_pb2.Struct - ) -> 'AbstractBeamJob': + def create_beam_job( + self, + preparation_id, # stype: str + job_name: str, + pipeline: beam_runner_api_pb2.Pipeline, + options: struct_pb2.Struct) -> 'AbstractBeamJob': """Returns an instance of AbstractBeamJob specific to this servicer.""" raise NotImplementedError(type(self)) @@ -185,6 +186,7 @@ def DescribePipelineOptions( class AbstractBeamJob(object): """Abstract baseclass for managing a single Beam job.""" + def __init__( self, job_id: str, @@ -265,6 +267,7 @@ def to_runner_api(self) -> beam_job_api_pb2.JobInfo: class JarArtifactManager(object): + def __init__(self, jar_path, root): self._root = root self._zipfile_handle = zipfile.ZipFile(jar_path, 'a') diff --git a/sdks/python/apache_beam/runners/portability/artifact_service.py b/sdks/python/apache_beam/runners/portability/artifact_service.py index b9395caeafaf..481318bae086 100644 --- a/sdks/python/apache_beam/runners/portability/artifact_service.py +++ b/sdks/python/apache_beam/runners/portability/artifact_service.py @@ -95,6 +95,7 @@ def GetArtifact(self, request, context=None): class ArtifactStagingService( beam_artifact_api_pb2_grpc.ArtifactStagingServiceServicer): + def __init__( self, file_writer: Callable[[str, Optional[str]], Tuple[BinaryIO, str]], @@ -142,6 +143,7 @@ def ReverseArtifactRetrievalService(self, responses, context=None): requests = _QueueIter() class ForwardingRetrievalService(object): + def ResolveArtifactss(self, request): requests.put( beam_artifact_api_pb2.ArtifactRequestWrapper( @@ -162,8 +164,7 @@ def resolve(): for key, dependencies in dependency_sets.items(): dependency_sets[key] = list( resolve_as_files( - ForwardingRetrievalService(), - lambda name: self._file_writer( + ForwardingRetrievalService(), lambda name: self._file_writer( os.path.join(staging_token, name)), dependencies)) requests.done() @@ -258,6 +259,7 @@ def offer_artifacts( class BeamFilesystemHandler(object): + def __init__(self, root): self._root = root diff --git a/sdks/python/apache_beam/runners/portability/artifact_service_test.py b/sdks/python/apache_beam/runners/portability/artifact_service_test.py index 17f1e962b9a0..e9fce2311794 100644 --- a/sdks/python/apache_beam/runners/portability/artifact_service_test.py +++ b/sdks/python/apache_beam/runners/portability/artifact_service_test.py @@ -32,6 +32,7 @@ class InMemoryFileManager(object): + def __init__(self, contents=()): self._contents = dict(contents) @@ -55,6 +56,7 @@ def writable(): class ArtifactServiceTest(unittest.TestCase): + def file_artifact(self, path): return beam_runner_api_pb2.ArtifactInformation( type_urn=common_urns.artifact_types.FILE.urn, @@ -129,6 +131,7 @@ def test_push_artifacts(self): dep_big = self.embedded_artifact(data=b'big ' * 100, name='big.txt') class TestArtifacts(object): + def ResolveArtifacts(self, request): replacements = [] for artifact in request.artifacts: diff --git a/sdks/python/apache_beam/runners/portability/expansion_service.py b/sdks/python/apache_beam/runners/portability/expansion_service.py index 4890dd9215e7..f5cd1f48b1ff 100644 --- a/sdks/python/apache_beam/runners/portability/expansion_service.py +++ b/sdks/python/apache_beam/runners/portability/expansion_service.py @@ -38,6 +38,7 @@ class ExpansionServiceServicer( beam_expansion_api_pb2_grpc.ExpansionServiceServicer): + def __init__(self, options=None, loopback_address=None): self._options = options or beam_pipeline.PipelineOptions( flags=[], @@ -83,17 +84,15 @@ def with_pipeline(component, pcoll_id=None): requirements=request.requirements) producers = { pcoll_id: (context.transforms.get_by_id(t_id), pcoll_tag) - for t_id, - t_proto in request.components.transforms.items() for pcoll_tag, - pcoll_id in t_proto.outputs.items() + for t_id, t_proto in request.components.transforms.items() + for pcoll_tag, pcoll_id in t_proto.outputs.items() } transform = with_pipeline( ptransform.PTransform.from_runner_api(request.transform, context)) if len(request.output_coder_requests) == 1: output_coder = { k: context.element_type_from_coder_id(v) - for k, - v in request.output_coder_requests.items() + for k, v in request.output_coder_requests.items() } transform = transform.with_output_types(list(output_coder.values())[0]) elif len(request.output_coder_requests) > 1: @@ -101,10 +100,9 @@ def with_pipeline(component, pcoll_id=None): 'type annotation for multiple outputs is not allowed yet: %s' % request.output_coder_requests) inputs = transform._pvaluish_from_dict({ - tag: - with_pipeline(context.pcollections.get_by_id(pcoll_id), pcoll_id) - for tag, - pcoll_id in request.transform.inputs.items() + tag: with_pipeline( + context.pcollections.get_by_id(pcoll_id), pcoll_id) + for tag, pcoll_id in request.transform.inputs.items() }) if not inputs: inputs = pipeline diff --git a/sdks/python/apache_beam/runners/portability/expansion_service_test.py b/sdks/python/apache_beam/runners/portability/expansion_service_test.py index 7aa2e5f16e5b..ff02b554b269 100644 --- a/sdks/python/apache_beam/runners/portability/expansion_service_test.py +++ b/sdks/python/apache_beam/runners/portability/expansion_service_test.py @@ -62,6 +62,7 @@ @ptransform.PTransform.register_urn('beam:transforms:xlang:count', None) class CountPerElementTransform(ptransform.PTransform): + def expand(self, pcoll): return pcoll | combine.Count.PerElement() @@ -77,6 +78,7 @@ def from_runner_api_parameter( @ptransform.PTransform.register_urn( 'beam:transforms:xlang:filter_less_than_eq', bytes) class FilterLessThanTransform(ptransform.PTransform): + def __init__(self, payload): self._payload = payload @@ -97,6 +99,7 @@ def from_runner_api_parameter(unused_ptransform, payload, unused_context): @ptransform.PTransform.register_urn(TEST_PREFIX_URN, None) @beam.typehints.with_output_types(str) class PrefixTransform(ptransform.PTransform): + def __init__(self, payload): self._payload = payload @@ -115,6 +118,7 @@ def from_runner_api_parameter(unused_ptransform, payload, unused_context): @ptransform.PTransform.register_urn(TEST_MULTI_URN, None) class MutltiTransform(ptransform.PTransform): + def expand(self, pcolls): return { 'main': (pcolls['main1'], pcolls['main2']) @@ -136,6 +140,7 @@ def from_runner_api_parameter( @ptransform.PTransform.register_urn(TEST_GBK_URN, None) class GBKTransform(ptransform.PTransform): + def expand(self, pcoll): return pcoll | 'TestLabel' >> beam.GroupByKey() @@ -150,7 +155,9 @@ def from_runner_api_parameter( @ptransform.PTransform.register_urn(TEST_CGBK_URN, None) class CoGBKTransform(ptransform.PTransform): + class ConcatFn(beam.DoFn): + def process(self, element): (k, v) = element return [(k, v['col1'] + v['col2'])] @@ -172,6 +179,7 @@ def from_runner_api_parameter( @ptransform.PTransform.register_urn(TEST_COMGL_URN, None) class CombineGloballyTransform(ptransform.PTransform): + def expand(self, pcoll): return pcoll \ | beam.CombineGlobally(sum).with_output_types(int) @@ -187,6 +195,7 @@ def from_runner_api_parameter( @ptransform.PTransform.register_urn(TEST_COMPK_URN, None) class CombinePerKeyTransform(ptransform.PTransform): + def expand(self, pcoll): output = pcoll \ | beam.CombinePerKey(sum) @@ -206,6 +215,7 @@ def from_runner_api_parameter( @ptransform.PTransform.register_urn(TEST_FLATTEN_URN, None) class FlattenTransform(ptransform.PTransform): + def expand(self, pcoll): return pcoll.values() | beam.Flatten().with_output_types(int) @@ -220,6 +230,7 @@ def from_runner_api_parameter( @ptransform.PTransform.register_urn(TEST_PARTITION_URN, None) class PartitionTransform(ptransform.PTransform): + def expand(self, pcoll): col1, col2 = pcoll | beam.Partition( lambda elem, n: 0 if elem % 2 == 0 else 1, 2) @@ -237,6 +248,7 @@ def from_runner_api_parameter( class ExtractHtmlTitleDoFn(beam.DoFn): + def process(self, element): from bs4 import BeautifulSoup soup = BeautifulSoup(element, 'html.parser') @@ -245,6 +257,7 @@ def process(self, element): @ptransform.PTransform.register_urn(TEST_PYTHON_BS4_URN, None) class ExtractHtmlTitleTransform(ptransform.PTransform): + def expand(self, pcoll): return pcoll | beam.ParDo(ExtractHtmlTitleDoFn()).with_output_types(str) @@ -259,6 +272,7 @@ def from_runner_api_parameter( @ptransform.PTransform.register_urn('payload', bytes) class PayloadTransform(ptransform.PTransform): + def __init__(self, payload): self._payload = payload @@ -275,7 +289,9 @@ def from_runner_api_parameter(unused_ptransform, payload, unused_context): @ptransform.PTransform.register_urn('map_to_union_types', None) class MapToUnionTypesTransform(ptransform.PTransform): + class CustomDoFn(beam.DoFn): + def process(self, element): if element == 1: return ['1'] @@ -298,6 +314,7 @@ def from_runner_api_parameter( @ptransform.PTransform.register_urn('fib', bytes) class FibTransform(ptransform.PTransform): + def __init__(self, level): self._level = level @@ -327,7 +344,9 @@ def from_runner_api_parameter(unused_ptransform, level, unused_context): @ptransform.PTransform.register_urn(TEST_NO_OUTPUT_URN, None) class NoOutputTransform(ptransform.PTransform): + def expand(self, pcoll): + def log_val(val): logging.debug('Got value: %r', val) diff --git a/sdks/python/apache_beam/runners/portability/flink_runner.py b/sdks/python/apache_beam/runners/portability/flink_runner.py index c9bf15b46e22..e73cbe7c481b 100644 --- a/sdks/python/apache_beam/runners/portability/flink_runner.py +++ b/sdks/python/apache_beam/runners/portability/flink_runner.py @@ -87,6 +87,7 @@ def add_http_scheme(flink_master): class FlinkJarJobServer(job_server.JavaJarJobServer): + def __init__(self, options): super().__init__(options) options = options.view_as(pipeline_options.FlinkRunnerOptions) diff --git a/sdks/python/apache_beam/runners/portability/flink_runner_test.py b/sdks/python/apache_beam/runners/portability/flink_runner_test.py index 30f1a4c06025..e5de425f6537 100644 --- a/sdks/python/apache_beam/runners/portability/flink_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/flink_runner_test.py @@ -257,8 +257,8 @@ def test_expand_kafka_read(self): allow_duplicates=False, expansion_service=self.get_expansion_service())) self.assertTrue( - 'No resolvable bootstrap urls given in bootstrap.servers' in str( - ctx.exception), + 'No resolvable bootstrap urls given in bootstrap.servers' + in str(ctx.exception), 'Expected to fail due to invalid bootstrap.servers, but ' 'failed due to:\n%s' % str(ctx.exception)) @@ -354,6 +354,7 @@ def test_pack_combiners(self): class FlinkRunnerTestStreaming(FlinkRunnerTest): + def create_options(self): options = super().create_options() options.view_as(StandardOptions).streaming = True diff --git a/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py b/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py index 3b302e334a5f..daf641590e91 100644 --- a/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py +++ b/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py @@ -42,6 +42,7 @@ class FlinkUberJarJobServer(abstract_job_service.AbstractJobServiceServicer): The jar contains the Beam pipeline definition, dependencies, and the pipeline artifacts. """ + def __init__(self, master_url, options): super().__init__() self._master_url = master_url @@ -107,6 +108,7 @@ def GetJobMetrics(self, request, context=None): class FlinkBeamJob(abstract_job_service.UberJarBeamJob): """Runs a single Beam job on Flink by staging all contents into a Jar and uploading it via the Flink Rest API.""" + def __init__( self, master_url, @@ -213,6 +215,7 @@ def get_state(self): return state, timestamp def get_state_stream(self): + def _state_iter(): sleep_secs = 1.0 while True: diff --git a/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server_test.py b/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server_test.py index 12ba3940d396..6f4cbd0cff97 100644 --- a/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server_test.py +++ b/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server_test.py @@ -45,6 +45,7 @@ def temp_name(*args, **kwargs): class FlinkUberJarJobServerTest(unittest.TestCase): + @requests_mock.mock() def test_flink_version(self, http_mock): http_mock.get('http://flink/v1/config', json={'flink-version': '3.1.4.1'}) diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py index e69e37495f64..97466306743c 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py @@ -94,6 +94,7 @@ class Buffer(Protocol): + def __iter__(self) -> Iterator[bytes]: pass @@ -105,6 +106,7 @@ def extend(self, other: 'Buffer') -> None: class PartitionableBuffer(Buffer, Protocol): + def copy(self) -> 'PartitionableBuffer': pass @@ -124,6 +126,7 @@ def reset(self) -> None: class ListBuffer: """Used to support parititioning of a list.""" + def __init__(self, coder_impl: Optional[CoderImpl]) -> None: self._coder_impl = coder_impl or CoderImpl() self._inputs: List[bytes] = [] @@ -190,6 +193,7 @@ def reset(self) -> None: class GroupingBuffer(object): """Used to accumulate groupded (shuffled) results.""" + def __init__( self, pre_grouped_coder: coders.Coder, @@ -251,7 +255,8 @@ def partition(self, n: int) -> List[List[bytes]]: index=0, nonspeculative_index=0)).with_value windowed_key_values = lambda key, values: [ - globally_window((key, values))] + globally_window((key, values)) + ] else: # TODO(pabloem, BEAM-7514): Trigger driver needs access to the clock # note that this only comes through if windowing is default - but what @@ -292,6 +297,7 @@ def reset(self) -> None: class WindowGroupingBuffer(object): """Used to partition windowed side inputs.""" + def __init__( self, access_pattern: beam_runner_api_pb2.FunctionSpec, @@ -382,7 +388,9 @@ class _ProcessingQueueManager(object): the time is the real time point at which the inputs should be scheduled, and inputs are dictionaries mapping PCollection name to data buffers. """ + class KeyedQueue(Generic[QUEUE_KEY_TYPE]): + def __init__(self) -> None: self._q: typing.Deque[Tuple[QUEUE_KEY_TYPE, DataInput]] = collections.deque() @@ -551,6 +559,7 @@ def make_process_bundle_descriptor( """Creates a ProcessBundleDescriptor for invoking the WindowFn's merge operation. """ + def make_channel_payload(coder_id: str) -> bytes: data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id) if data_api_service_descriptor: @@ -661,6 +670,7 @@ class FnApiRunnerExecutionContext(object): PCollection IDs to list that functions as buffer for the ``beam.PCollection``. """ + def __init__( self, stages: List[translations.Stage], @@ -990,6 +1000,7 @@ def commit_side_inputs_to_state( class BundleContextManager(object): + def __init__( self, execution_context: FnApiRunnerExecutionContext, @@ -1169,8 +1180,8 @@ def input_for(self, transform_id: str, input_id: str) -> str: input_pcoll in proto.outputs.values()): return read_id # The GrpcRead is followed by the SDF/Truncate -> SDF/Process. - if (proto.spec.urn == - common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn and + if (proto.spec.urn + == common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn and input_pcoll in proto.outputs.values()): read_input = list( self.process_bundle_descriptor.transforms[read_id].inputs.values() diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py index 95bcb7567918..7d166cfffd99 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py @@ -1077,6 +1077,7 @@ def get_cache_token_generator( If False, generator returns a new cache token each time :return A generator which returns a cache token on next(generator) """ + def generate_token( identifier: int) -> beam_fn_api_pb2.ProcessBundleRequest.CacheToken: return beam_fn_api_pb2.ProcessBundleRequest.CacheToken( @@ -1085,6 +1086,7 @@ def generate_token( token="cache_token_{}".format(identifier).encode("utf-8")) class StaticGenerator(object): + def __init__(self) -> None: self._token = generate_token(1) @@ -1096,6 +1098,7 @@ def __next__(self) -> beam_fn_api_pb2.ProcessBundleRequest.CacheToken: return self._token class DynamicGenerator(object): + def __init__(self) -> None: self._counter = 0 self._lock = threading.Lock() @@ -1116,6 +1119,7 @@ def __next__(self) -> beam_fn_api_pb2.ProcessBundleRequest.CacheToken: class ExtendedProvisionInfo(object): + def __init__( self, provision_info: Optional[beam_provision_api_pb2.ProvisionInfo] = None, @@ -1293,15 +1297,15 @@ def _generate_splits_for_testing( self._worker_handler.control_conn.push(split_request).get()) for t in (0.05, 0.1, 0.2): if ('Unknown process bundle' in split_response.error or - split_response.process_bundle_split == - beam_fn_api_pb2.ProcessBundleSplitResponse()): + split_response.process_bundle_split + == beam_fn_api_pb2.ProcessBundleSplitResponse()): time.sleep(t) split_response = self._worker_handler.control_conn.push( split_request).get() logging.info('Got split response %s', split_response) if ('Unknown process bundle' in split_response.error or - split_response.process_bundle_split == - beam_fn_api_pb2.ProcessBundleSplitResponse()): + split_response.process_bundle_split + == beam_fn_api_pb2.ProcessBundleSplitResponse()): # It may have finished too fast. split_result = None elif split_response.error: @@ -1410,6 +1414,7 @@ def process_bundle( class ParallelBundleManager(BundleManager): + def __init__( self, bundle_context_manager: execution.BundleContextManager, @@ -1482,6 +1487,7 @@ class ProgressRequester(threading.Thread): A callback can be passed to call with progress updates. """ + def __init__( self, worker_handler: WorkerHandler, @@ -1526,6 +1532,7 @@ def stop(self): class FnApiMetrics(metric.MetricResults): + def __init__(self, step_monitoring_infos, user_metrics_only=True): """Used for querying metrics from the PipelineResult object. @@ -1551,24 +1558,24 @@ def __init__(self, step_monitoring_infos, user_metrics_only=True): def query(self, filter=None): counters = [ - MetricResult(k, v, v) for k, - v in self._counters.items() if self.matches(filter, k) + MetricResult(k, v, v) for k, v in self._counters.items() + if self.matches(filter, k) ] distributions = [ - MetricResult(k, v, v) for k, - v in self._distributions.items() if self.matches(filter, k) + MetricResult(k, v, v) for k, v in self._distributions.items() + if self.matches(filter, k) ] gauges = [ - MetricResult(k, v, v) for k, - v in self._gauges.items() if self.matches(filter, k) + MetricResult(k, v, v) for k, v in self._gauges.items() + if self.matches(filter, k) ] string_sets = [ - MetricResult(k, v, v) for k, - v in self._string_sets.items() if self.matches(filter, k) + MetricResult(k, v, v) for k, v in self._string_sets.items() + if self.matches(filter, k) ] bounded_tries = [ - MetricResult(k, v, v) for k, - v in self._bounded_tries.items() if self.matches(filter, k) + MetricResult(k, v, v) for k, v in self._bounded_tries.items() + if self.matches(filter, k) ] return { @@ -1586,6 +1593,7 @@ def monitoring_infos(self) -> List[metrics_pb2.MonitoringInfo]: class RunnerResult(runner.PipelineResult): + def __init__(self, state, monitoring_infos_by_stage): super().__init__(state) self._monitoring_infos_by_stage = monitoring_infos_by_stage diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py index 3f036ab27f6e..f35a5b74269e 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py @@ -95,6 +95,7 @@ def _matcher_or_equal_to(value_or_matcher): def has_urn_and_labels(mi, urn, labels): """Returns true if it the monitoring_info contains the labels and urn.""" + def contains_labels(mi, labels): # Check all the labels and their values exist in the monitoring_info return all(item in mi.labels.items() for item in labels.items()) @@ -103,6 +104,7 @@ def contains_labels(mi, labels): class FnApiRunnerTest(unittest.TestCase): + def create_pipeline(self, is_drain=False): return beam.Pipeline(runner=fn_api_runner.FnApiRunner(is_drain=is_drain)) @@ -138,7 +140,9 @@ def test_batch_pardo(self): assert_that(res, equal_to([6, 12, 18])) def test_batch_pardo_override_type_inference(self): + class ArrayMultiplyDoFnOverride(beam.DoFn): + def process_batch(self, batch, *unused_args, **unused_kwargs) -> Iterator[np.ndarray]: assert isinstance(batch, np.ndarray) @@ -207,7 +211,9 @@ def test_batch_rebatch_pardos(self): assert_that(res, equal_to([9, 15, 21])) def test_batch_pardo_fusion_break(self): + class NormalizeDoFn(beam.DoFn): + @no_type_check def process_batch( self, @@ -238,7 +244,9 @@ def infer_output_type(self, input_type): assert_that(res, equal_to([-2, 0, 2])) def test_batch_pardo_dofn_params(self): + class ConsumeParamsDoFn(beam.DoFn): + @no_type_check def process_batch( self, @@ -269,7 +277,9 @@ def infer_output_type(self, input_type): assert_that(res, equal_to([0, 1, 0, 3, 0, 5, 0, 7, 0, 9])) def test_batch_pardo_window_param(self): + class PerWindowDoFn(beam.DoFn): + @no_type_check def process_batch( self, @@ -296,7 +306,9 @@ def infer_output_type(self, input_type): assert_that(res, equal_to([0, 0, 0, 0, 0, 25, 30, 35, 40, 45])) def test_batch_pardo_overlapping_windows(self): + class PerWindowDoFn(beam.DoFn): + @no_type_check def process_batch(self, batch: np.ndarray, @@ -318,15 +330,32 @@ def infer_output_type(self, input_type): | beam.WindowInto(window.SlidingWindows(size=5, period=3)) | beam.ParDo(PerWindowDoFn())) - assert_that(res, equal_to([ 0*-3, 1*-3, # [-3, 2) - 0*0, 1*0, 2*0, 3* 0, 4* 0, # [ 0, 5) - 3*3, 4*3, 5*3, 6* 3, 7* 3, # [ 3, 8) - 6*6, 7*6, 8*6, 9* 6, # [ 6, 11) - 9*9 # [ 9, 14) - ])) + assert_that( + res, + equal_to([ + 0 * -3, + 1 * -3, # [-3, 2) + 0 * 0, + 1 * 0, + 2 * 0, + 3 * 0, + 4 * 0, # [ 0, 5) + 3 * 3, + 4 * 3, + 5 * 3, + 6 * 3, + 7 * 3, # [ 3, 8) + 6 * 6, + 7 * 6, + 8 * 6, + 9 * 6, # [ 6, 11) + 9 * 9 # [ 9, 14) + ])) def test_batch_to_element_pardo(self): + class ArraySumDoFn(beam.DoFn): + @beam.DoFn.yields_elements def process_batch(self, batch: np.ndarray, *unused_args, **unused_kwargs) -> Iterator[np.int64]: @@ -348,7 +377,9 @@ def infer_output_type(self, input_type): assert_that(res, equal_to([99 * 50 * 2])) def test_element_to_batch_pardo(self): + class ArrayProduceDoFn(beam.DoFn): + @beam.DoFn.yields_batches def process(self, element: np.int64, *unused_args, **unused_kwargs) -> Iterator[np.ndarray]: @@ -388,6 +419,7 @@ def test_pardo_large_input(self): assert_that(res, equal_to([(i * 2) + 3 for i in range(5000)])) def test_pardo_side_outputs(self): + def tee(elem, *tags): for tag in tags: if tag in elem: @@ -402,6 +434,7 @@ def tee(elem, *tags): assert_that(xy.y, equal_to(['y', 'xy']), label='y') def test_pardo_side_and_main_outputs(self): + def even_odd(elem): yield elem yield beam.pvalue.TaggedOutput('odd' if elem % 2 else 'even', elem) @@ -421,6 +454,7 @@ def even_odd(elem): assert_that(unnamed.odd, equal_to([1, 3]), label='unnamed.odd') def test_pardo_side_inputs(self): + def cross_product(elem, sides): for side in sides: yield elem, side @@ -565,15 +599,12 @@ def test_multimap_multiside_input(self): side = p | 'side' >> beam.Create([('a', 1), ('b', 2), ('a', 3)]) assert_that( main | 'first map' >> beam.Map( - lambda k, - d, - l: (k, sorted(d[k]), sorted([e[1] for e in l])), + lambda k, d, l: (k, sorted(d[k]), sorted([e[1] for e in l])), beam.pvalue.AsMultiMap(side), beam.pvalue.AsList(side)) | 'second map' >> beam.Map( - lambda k, - d, - l: (k[0], sorted(d[k[0]]), sorted([e[1] for e in l])), + lambda k, d, l: + (k[0], sorted(d[k[0]]), sorted([e[1] for e in l])), beam.pvalue.AsMultiMap(side), beam.pvalue.AsList(side)), equal_to([('a', [1, 3], [1, 2, 3]), ('b', [2], [1, 2, 3])])) @@ -593,6 +624,7 @@ def test_multimap_side_input_type_coercion(self): equal_to([('a', [1, 3]), ('b', [2])])) def test_pardo_unfusable_side_inputs(self): + def cross_product(elem, sides): for side in sides: yield elem, side @@ -604,6 +636,7 @@ def cross_product(elem, sides): equal_to([('a', 'a'), ('a', 'b'), ('b', 'a'), ('b', 'b')])) def test_pardo_unfusable_side_inputs_with_separation(self): + def cross_product(elem, sides): for side in sides: yield elem, side @@ -625,6 +658,7 @@ def test_pardo_state_only(self): # TODO(ccy): State isn't detected with Map/FlatMap. class AddIndex(beam.DoFn): + def process( self, kv, @@ -649,6 +683,7 @@ def test_teststream_pardo_timers(self): timer_spec = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK) class TimerDoFn(beam.DoFn): + def process(self, element, timer=beam.DoFn.TimerParam(timer_spec)): unused_key, ts = element timer.set(ts) @@ -679,6 +714,7 @@ def test_pardo_timers(self): state_spec = userstate.CombiningValueStateSpec('num_called', sum) class TimerDoFn(beam.DoFn): + def process(self, element, timer=beam.DoFn.TimerParam(timer_spec)): unused_key, ts = element timer.set(ts) @@ -711,6 +747,7 @@ def test_pardo_timers_clear(self): 'clear_timer', userstate.TimeDomain.WATERMARK) class TimerDoFn(beam.DoFn): + def process( self, element, @@ -762,6 +799,7 @@ def _run_pardo_state_timers(self, windowed, key_type=None): buffer_size = 3 class BufferDoFn(beam.DoFn): + def process( self, kv, @@ -818,6 +856,7 @@ def is_buffered_correctly(actual): assert_that(actual, is_buffered_correctly) def test_pardo_dynamic_timer(self): + class DynamicTimerDoFn(beam.DoFn): dynamic_timer_spec = userstate.TimerSpec( 'dynamic_timer', userstate.TimeDomain.WATERMARK) @@ -842,7 +881,9 @@ def dynamic_timer_callback( assert_that(actual, equal_to([('key1', 10), ('key2', 20), ('key3', 30)])) def test_sdf(self): + class ExpandingStringsDoFn(beam.DoFn): + def process( self, element, @@ -860,7 +901,9 @@ def process( assert_that(actual, equal_to(list(''.join(data)))) def test_sdf_with_dofn_as_restriction_provider(self): + class ExpandingStringsDoFn(beam.DoFn, ExpandStringsProvider): + def process( self, element, restriction_tracker=beam.DoFn.RestrictionParam()): assert isinstance(restriction_tracker, RestrictionTrackerView) @@ -875,7 +918,9 @@ def process( assert_that(actual, equal_to(list(''.join(data)))) def test_sdf_with_check_done_failed(self): + class ExpandingStringsDoFn(beam.DoFn): + def process( self, element, @@ -894,7 +939,9 @@ def process( _ = (p | beam.Create(data) | beam.ParDo(ExpandingStringsDoFn())) def test_sdf_with_watermark_tracking(self): + class ExpandingStringsDoFn(beam.DoFn): + def process( self, element, @@ -920,7 +967,9 @@ def process( assert_that(actual, equal_to(list(''.join(data)))) def test_sdf_with_dofn_as_watermark_estimator(self): + class ExpandingStringsDoFn(beam.DoFn, beam.WatermarkEstimatorProvider): + def initial_estimator_state(self, element, restriction): return None @@ -955,6 +1004,7 @@ def run_sdf_initiated_checkpointing(self, is_drain=False): counter = beam.metrics.Metrics.counter('ns', 'my_counter') class ExpandStringsDoFn(beam.DoFn): + def process( self, element, @@ -990,7 +1040,9 @@ def test_draining_sdf_with_sdf_initiated_checkpointing(self): self.run_sdf_initiated_checkpointing(is_drain=True) def test_sdf_default_truncate_when_bounded(self): + class SimpleSDF(beam.DoFn): + def process( self, element, @@ -1007,7 +1059,9 @@ def process( assert_that(actual, equal_to(range(10))) def test_sdf_default_truncate_when_unbounded(self): + class SimpleSDF(beam.DoFn): + def process( self, element, @@ -1024,7 +1078,9 @@ def process( assert_that(actual, equal_to([])) def test_sdf_with_truncate(self): + class SimpleSDF(beam.DoFn): + def process( self, element, @@ -1145,8 +1201,7 @@ def test_large_elements(self): side_input_res = ( big | beam.Map( - lambda x, - side: (x[0], side.count(x[0])), + lambda x, side: (x[0], side.count(x[0])), beam.pvalue.AsList(big | beam.Map(lambda x: x[0])))) assert_that( side_input_res, @@ -1177,6 +1232,7 @@ def raise_error(x): self.assertNotIn('StageB', message) def test_error_traceback_includes_user_code(self): + def first(x): return second(x) @@ -1200,7 +1256,9 @@ def third(x): self.assertIn('third', message) def test_no_subtransform_composite(self): + class First(beam.PTransform): + def expand(self, pcolls): return pcolls[0] @@ -1265,6 +1323,7 @@ def raise_expetion(): raise Exception('raise exception when calling callback') class FinalizebleDoFnWithException(beam.DoFn): + def process( self, element, bundle_finalizer=beam.DoFn.BundleFinalizerParam): bundle_finalizer.register(raise_expetion) @@ -1281,6 +1340,7 @@ def test_register_finalizations(self): event_recorder = EventRecorder(tempfile.gettempdir()) class FinalizableSplittableDoFn(beam.DoFn): + def process( self, element, @@ -1349,6 +1409,7 @@ def test_create_value_provider_pipeline_option(self): # provider pipeline options # pylint: disable=unused-variable class FooOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_value_provider_argument( @@ -1371,6 +1432,7 @@ def max_with_counter(values): return max(values) class PackableCombines(beam.PTransform): + def annotations(self): return {python_urns.APPLY_COMBINER_PACKING: b''} @@ -1424,6 +1486,7 @@ def test_group_by_key_with_empty_pcoll_elements(self): # upon repeating bundle processing due to unncessarily incrementing # the sampling counter. class FnApiRunnerMetricsTest(unittest.TestCase): + def assert_has_counter( self, mon_infos, urn, labels, value=None, ge_value=None): # TODO(ajamato): Consider adding a matcher framework @@ -1498,7 +1561,9 @@ def create_pipeline(self): return beam.Pipeline(runner=fn_api_runner.FnApiRunner()) def test_element_count_metrics(self): + class GenerateTwoOutputs(beam.DoFn): + def process(self, element): yield str(element) + '1' yield beam.pvalue.TaggedOutput('SecondOutput', str(element) + '2') @@ -1506,6 +1571,7 @@ def process(self, element): yield beam.pvalue.TaggedOutput('ThirdOutput', str(element) + '3') class PassThrough(beam.DoFn): + def process(self, element): yield element @@ -1737,14 +1803,8 @@ def test_progress_metrics(self): | beam.GroupByKey() | 'm_out' >> beam.FlatMap( lambda x: [ - 1, - 2, - 3, - 4, - 5, - beam.pvalue.TaggedOutput('once', x), - beam.pvalue.TaggedOutput('twice', x), - beam.pvalue.TaggedOutput('twice', x) + 1, 2, 3, 4, 5, beam.pvalue.TaggedOutput('once', x), beam.pvalue. + TaggedOutput('twice', x), beam.pvalue.TaggedOutput('twice', x) ])) res = p.run() @@ -1812,6 +1872,7 @@ def has_mi_for_ptransform(mon_infos, ptransform): class FnApiRunnerTestWithGrpc(FnApiRunnerTest): + def create_pipeline(self, is_drain=False): return beam.Pipeline( runner=fn_api_runner.FnApiRunner( @@ -1821,6 +1882,7 @@ def create_pipeline(self, is_drain=False): class FnApiRunnerTestWithDisabledCaching(FnApiRunnerTest): + def create_pipeline(self, is_drain=False): return beam.Pipeline( runner=fn_api_runner.FnApiRunner( @@ -1833,6 +1895,7 @@ def create_pipeline(self, is_drain=False): class FnApiRunnerTestWithMultiWorkers(FnApiRunnerTest): + def create_pipeline(self, is_drain=False): pipeline_options = PipelineOptions(direct_num_workers=2) p = beam.Pipeline( @@ -1862,6 +1925,7 @@ def test_register_finalizations(self): class FnApiRunnerTestWithGrpcAndMultiWorkers(FnApiRunnerTest): + def create_pipeline(self, is_drain=False): pipeline_options = PipelineOptions( direct_num_workers=2, direct_running_mode='multi_threading') @@ -1892,6 +1956,7 @@ def test_register_finalizations(self): class FnApiRunnerTestWithBundleRepeat(FnApiRunnerTest): + def create_pipeline(self, is_drain=False): return beam.Pipeline( runner=fn_api_runner.FnApiRunner(bundle_repeat=3, is_drain=is_drain)) @@ -1901,6 +1966,7 @@ def test_register_finalizations(self): class FnApiRunnerTestWithBundleRepeatAndMultiWorkers(FnApiRunnerTest): + def create_pipeline(self, is_drain=False): pipeline_options = PipelineOptions(direct_num_workers=2) p = beam.Pipeline( @@ -1930,6 +1996,7 @@ def test_sdf_with_dofn_as_watermark_estimator(self): class FnApiRunnerSplitTest(unittest.TestCase): + def create_pipeline(self, is_drain=False): # Must be GRPC so we can send data and split requests concurrent # to the bundle process request. @@ -2073,6 +2140,7 @@ def split_manager(num_elements): raise def test_nosplit_sdf(self): + def split_manager(num_elements): yield @@ -2109,6 +2177,7 @@ def run_sdf_split_pipeline( # Define an SDF that for each input x produces [(x, k) for k in range(x)]. class EnumerateProvider(beam.transforms.core.RestrictionProvider): + def initial_restriction(self, element): return restriction_trackers.OffsetRange(0, element) @@ -2126,6 +2195,7 @@ def is_bounded(self): return True class EnumerateSdf(beam.DoFn): + def process( self, element, @@ -2157,6 +2227,7 @@ def test_time_based_split_manager(self): elements = [str(x) for x in range(100)] class BundleCountingDoFn(beam.DoFn): + def process(self, element): time.sleep(0.005) yield element @@ -2198,6 +2269,7 @@ def verify_channel_split(self, split_result, last_primary, first_residual): class ElementCounter(object): """Used to wait until a certain number of elements are seen.""" + def __init__(self): self._cv = threading.Condition() self.reset() @@ -2221,6 +2293,7 @@ def set_breakpoint(self, value): self._breakpoints[value].append(event) class Breakpoint(object): + @staticmethod def wait(timeout=10): with self._cv: @@ -2257,6 +2330,7 @@ class EventRecorder(object): The reason why records are written into a tmp file is, the in-memory dataset cannot keep callback records when passing into one DoFn. """ + def __init__(self, tmp_dir): self.tmp_dir = os.path.join(tmp_dir, uuid.uuid4().hex) os.mkdir(self.tmp_dir) @@ -2283,6 +2357,7 @@ def cleanup(self): class ExpandStringsProvider(beam.transforms.core.RestrictionProvider): """A RestrictionProvider that used for sdf related tests.""" + def initial_restriction(self, element): return restriction_trackers.OffsetRange(0, len(element)) @@ -2299,11 +2374,13 @@ def restriction_size(self, element, restriction): class UnboundedOffsetRestrictionTracker( restriction_trackers.OffsetRestrictionTracker): + def is_bounded(self): return False class OffsetRangeProvider(beam.transforms.core.RestrictionProvider): + def __init__(self, use_bounded_offset_range, checkpoint_only=False): self.use_bounded_offset_range = use_bounded_offset_range self.checkpoint_only = checkpoint_only @@ -2316,6 +2393,7 @@ def create_tracker(self, restriction): class CheckpointOnlyOffsetRestrictionTracker( restriction_trackers.OffsetRestrictionTracker): + def try_split(self, unused_fraction_of_remainder): return super().try_split(0.0) @@ -2332,6 +2410,7 @@ def restriction_size(self, element, restriction): class OffsetRangeProviderWithTruncate(OffsetRangeProvider): + def __init__(self): super().__init__(True) @@ -2341,6 +2420,7 @@ def truncate(self, element, restriction): class FnApiBasedLullLoggingTest(unittest.TestCase): + def create_pipeline(self): return beam.Pipeline( runner=fn_api_runner.FnApiRunner( @@ -2369,6 +2449,7 @@ def __reduce__(self): @pytest.mark.it_validatesrunner class FnApiBasedStateBackedCoderTest(unittest.TestCase): + def create_pipeline(self): return beam.Pipeline( runner=fn_api_runner.FnApiRunner(use_state_iterables=True)) @@ -2394,6 +2475,7 @@ def test_gbk_many_values(self): # TODO(robertwb): Why does pickling break when this is inlined? class CustomMergingWindowFn(window.WindowFn): + def assign(self, assign_context): return [ window.IntervalWindow( @@ -2413,6 +2495,7 @@ def get_window_coder(self): class ColoredFixedWindow(window.BoundedWindow): + def __init__(self, end, color): super().__init__(end) self.color = color @@ -2441,6 +2524,7 @@ def is_deterministic(self): class EvenOddWindows(window.NonMergingWindowFn): + def assign(self, context): timestamp = context.timestamp return [ @@ -2454,6 +2538,7 @@ def get_window_coder(self): class ExpectingSideInputsFn(beam.DoFn): + def __init__(self, name): self._name = name @@ -2468,6 +2553,7 @@ def process(self, element, *side_inputs): class ArrayMultiplyDoFn(beam.DoFn): + def process_batch(self, batch: np.ndarray, *unused_args, **unused_kwargs) -> Iterator[np.ndarray]: assert isinstance(batch, np.ndarray) @@ -2483,6 +2569,7 @@ def infer_output_type(self, input_type): class ListPlusOneDoFn(beam.DoFn): + def process_batch(self, batch: List[np.int64], *unused_args, **unused_kwargs) -> Iterator[List[np.int64]]: assert isinstance(batch, list) diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py index c1c7f649f77a..26490477264a 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py @@ -133,6 +133,7 @@ class DataInput(NamedTuple): class Stage(object): """A set of Transforms that can be sent to the worker for processing.""" + def __init__( self, name, # type: str @@ -225,6 +226,7 @@ def is_runner_urn(self, context): for transform in self.transforms) def is_all_sdk_urns(self, context): + def is_sdk_transform(transform): # Execute multi-input flattens in the runner. if transform.spec.urn == common_urns.primitives.FLATTEN.urn and len( @@ -332,8 +334,7 @@ def executable_stage_transform( beam_runner_api_pb2.ExecutableStagePayload.TimerId( transform_id=transform_id, local_name=tag)) main_inputs.update( - pcoll_id for tag, - pcoll_id in transform.inputs.items() + pcoll_id for tag, pcoll_id in transform.inputs.items() if tag not in payload.side_inputs) else: main_inputs.update(transform.inputs.values()) @@ -341,8 +342,8 @@ def executable_stage_transform( main_input_id = only_element(main_inputs - all_outputs) named_inputs = dict({ - '%s:%s' % (side.transform_id, side.local_name): - stage_components.transforms[side.transform_id].inputs[side.local_name] + '%s:%s' % (side.transform_id, side.local_name): stage_components. + transforms[side.transform_id].inputs[side.local_name] for side in side_inputs }, main_input=main_input_id) @@ -367,8 +368,7 @@ def executable_stage_transform( inputs=named_inputs, outputs={ 'output_%d' % ix: pcoll - for ix, - pcoll in enumerate(external_outputs) + for ix, pcoll in enumerate(external_outputs) }, ) @@ -519,8 +519,8 @@ def maybe_length_prefixed_and_safe_coder(self, coder_id): # have the runner treat it as opaque bytes. return coder_id, self.bytes_coder_id elif (coder.spec.urn == common_urns.coders.WINDOWED_VALUE.urn and - self.components.coders[coder.component_coder_ids[1]].spec.urn not in - self._known_coder_urns): + self.components.coders[coder.component_coder_ids[1]].spec.urn + not in self._known_coder_urns): # A WindowedValue coder with an unknown window type. # This needs to be encoded in such a way that we still have access to its # timestmap. @@ -659,8 +659,7 @@ def pipeline_from_stages( roots = {} # type: Dict[str, Any] parents = { child: parent - for parent, - proto in pipeline_proto.components.transforms.items() + for parent, proto in pipeline_proto.components.transforms.items() for child in proto.subtransforms } @@ -798,8 +797,7 @@ def standard_optimize_phases(): pack_combiners, lift_combiners, expand_sdf, - fix_flatten_coders, - # sink_flattens, + fix_flatten_coders, # sink_flattens, greedily_fuse, read_to_impulse, extract_impulse_stages, @@ -1035,6 +1033,7 @@ def pack_per_key_combiners(stages, context, can_pack=lambda s: True): tuples from this PCollection and sends them to the original output PCollections. """ + class _UnpackFn(core.DoFn): """A DoFn that unpacks a packed to multiple tagged outputs. @@ -1043,14 +1042,15 @@ class _UnpackFn(core.DoFn): input = (K, (V1, V2, ...)) output = TaggedOutput(T1, (K, V1)), TaggedOutput(T2, (K, V1)), ... """ + def __init__(self, tags): self._tags = tags def process(self, element): key, values = element return [ - core.pvalue.TaggedOutput(tag, (key, value)) for tag, - value in zip(self._tags, values) + core.pvalue.TaggedOutput(tag, (key, value)) + for tag, value in zip(self._tags, values) ] def _get_fallback_coder_id(): @@ -1088,8 +1088,8 @@ def _get_limit(stage_name): # and group eligible CombinePerKey stages by parent and environment. def get_stage_key(stage): if (len(stage.transforms) == 1 and can_pack(stage.name) and - stage.environment is not None and python_urns.PACKED_COMBINE_FN in - context.components.environments[stage.environment].capabilities): + stage.environment is not None and python_urns.PACKED_COMBINE_FN + in context.components.environments[stage.environment].capabilities): transform = only_transform(stage.transforms) if (transform.spec.urn == common_urns.composites.COMBINE_PER_KEY.urn and len(transform.inputs) == 1 and len(transform.outputs) == 1): @@ -1104,10 +1104,12 @@ def get_stage_key(stage): for stage in ineligible_stages: yield stage - grouped_packable_stages = [(stage_key, subgrouped_stages) for stage_key, - grouped_stages in grouped_eligible_stages.items() - for subgrouped_stages in _group_stages_with_limit( - grouped_stages, _get_limit)] + grouped_packable_stages = [ + (stage_key, subgrouped_stages) + for stage_key, grouped_stages in grouped_eligible_stages.items() + for subgrouped_stages in _group_stages_with_limit( + grouped_stages, _get_limit) + ] for stage_key, packable_stages in grouped_packable_stages: input_pcoll_id, _ = stage_key @@ -1275,6 +1277,7 @@ def lift_combiners(stages, context): ... -> PreCombine -> GBK -> MergeAccumulators -> ExtractOutput -> ... """ + def is_compatible_with_combiner_lifting(trigger): '''Returns whether this trigger is compatible with combiner lifting. @@ -2041,7 +2044,8 @@ def sort_stages(stages, pipeline_context): producers = { pcoll: stage - for stage in all_stages for t in stage.transforms + for stage in all_stages + for t in stage.transforms for pcoll in t.outputs.values() } diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations_test.py index 3ff2421e6265..1ecdbf03a530 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations_test.py @@ -37,8 +37,11 @@ class TranslationsTest(unittest.TestCase): + def test_eliminate_common_key_with_void(self): + class MultipleKeyWithNone(beam.PTransform): + def expand(self, pcoll): _ = pcoll | 'key-with-none-a' >> beam.ParDo(core._KeyWithNone()) _ = pcoll | 'key-with-none-b' >> beam.ParDo(core._KeyWithNone()) @@ -58,7 +61,9 @@ def expand(self, pcoll): self.assertIn('multiple-key-with-none', key_with_none_stages[0].parent) def test_pack_combiners(self): + class MultipleCombines(beam.PTransform): + def annotations(self): return {python_urns.APPLY_COMBINER_PACKING: b''} @@ -89,7 +94,9 @@ def expand(self, pcoll): self.assertNotIn('-perkey', combine_per_key_stages[0].parent) def test_pack_combiners_with_missing_environment_capability(self): + class MultipleCombines(beam.PTransform): + def annotations(self): return {python_urns.APPLY_COMBINER_PACKING: b''} @@ -120,7 +127,9 @@ def expand(self, pcoll): 'Packed', combine_per_key_stage.transforms[0].unique_name) def test_pack_global_combiners(self): + class MultipleCombines(beam.PTransform): + def annotations(self): return {python_urns.APPLY_COMBINER_PACKING: b''} @@ -170,7 +179,9 @@ def test_optimize_empty_pipeline(self): optimized_pipeline_proto, runner, pipeline_options.PipelineOptions()) def test_optimize_single_combine_globally(self): + class SingleCombine(beam.PTransform): + def annotations(self): return {python_urns.APPLY_COMBINER_PACKING: b''} @@ -193,7 +204,9 @@ def expand(self, pcoll): optimized_pipeline_proto, runner, pipeline_options.PipelineOptions()) def test_optimize_multiple_combine_globally(self): + class MultipleCombines(beam.PTransform): + def annotations(self): return {python_urns.APPLY_COMBINER_PACKING: b''} @@ -223,6 +236,7 @@ def test_pipeline_from_sorted_stages_is_toplogically_ordered(self): side = pipeline | 'side' >> Create([3, 4]) class CreateAndMultiplyBySide(beam.PTransform): + def expand(self, pcoll): return ( pcoll | 'main' >> Create([1, 2]) | 'compute' >> beam.FlatMap( @@ -251,7 +265,9 @@ def assert_is_topologically_sorted(transform_id, visited_pcolls): @pytest.mark.it_validatesrunner def test_run_packable_combine_per_key(self): + class MultipleCombines(beam.PTransform): + def annotations(self): return {python_urns.APPLY_COMBINER_PACKING: b''} @@ -281,7 +297,9 @@ def expand(self, pcoll): @pytest.mark.it_validatesrunner def test_run_packable_combine_globally(self): + class MultipleCombines(beam.PTransform): + def annotations(self): return {python_urns.APPLY_COMBINER_PACKING: b''} @@ -309,7 +327,9 @@ def expand(self, pcoll): @pytest.mark.it_validatesrunner def test_run_packable_combine_limit(self): + class MultipleLargeCombines(beam.PTransform): + def annotations(self): # Limit to at most 2 combiners per packed combiner. return {python_urns.APPLY_COMBINER_PACKING: b'2'} @@ -329,6 +349,7 @@ def expand(self, pcoll): label='assert-min-3-globally') class MultipleSmallCombines(beam.PTransform): + def annotations(self): # Limit to at most 4 combiners per packed combiner. return {python_urns.APPLY_COMBINER_PACKING: b'4'} @@ -389,7 +410,9 @@ def test_combineperkey_annotation_propagation(self): Test that the CPK component transforms inherit annotations from the source CPK """ + class MyCombinePerKey(beam.CombinePerKey): + def annotations(self): return {"my_annotation": b""} @@ -411,11 +434,13 @@ def annotations(self): 'MyCombinePerKey(min)/Merge', 'MyCombinePerKey(min)/ExtractOutputs']: assert ( - "my_annotation" in - optimized.components.transforms[transform_id].annotations) + "my_annotation" + in optimized.components.transforms[transform_id].annotations) def test_conditionally_packed_combiners(self): + class RecursiveCombine(beam.PTransform): + def __init__(self, labels): self._labels = labels diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/trigger_manager.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/trigger_manager.py index 021f5950d71d..19d541465c3f 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/trigger_manager.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/trigger_manager.py @@ -59,6 +59,7 @@ class _ReifyWindows(DoFn): """Receives KV pairs, and wraps the values into WindowedValues.""" + def process( self, element, window=DoFn.WindowParam, timestamp=DoFn.TimestampParam): try: @@ -72,6 +73,7 @@ def process( class _GroupBundlesByKey(DoFn): + def start_bundle(self): self.keys = defaultdict(list) @@ -94,6 +96,7 @@ def read_watermark(watermark_state): class TriggerMergeContext(WindowFn.MergeContext): + def __init__( self, all_windows, context: 'FnRunnerStatefulTriggerContext', windowing): super().__init__(all_windows) @@ -319,6 +322,7 @@ def watermark_trigger( class FnRunnerStatefulTriggerContext(TriggerContext): + def __init__( self, processing_time_timer: RuntimeTimer, @@ -409,6 +413,7 @@ def clear_state(self, tag): class PerWindowTriggerContext(TriggerContext): + def __init__(self, window, parent: FnRunnerStatefulTriggerContext): self.window = window self.parent = parent diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/trigger_manager_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/trigger_manager_test.py index 8a071520ad15..f9ce8f117722 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/trigger_manager_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/trigger_manager_test.py @@ -39,7 +39,9 @@ class TriggerManagerTest(unittest.TestCase): + def test_with_trigger_window_that_finish(self): + def tsv(key, value, ts): return TimestampedValue((key, value), timestamp=ts) @@ -79,6 +81,7 @@ def tsv(key, value, ts): ])) def test_fixed_windows_simple_watermark(self): + def tsv(key, value, ts): return TimestampedValue((key, value), timestamp=ts) @@ -116,8 +119,8 @@ def tsv(key, value, ts): equal_to([ ('k1', IntervalWindow(0, 1), [1, 2, 3]), # On the watermark ('k2', IntervalWindow(0, 1), [1, 2, 3]), # On the watermark - ('k1', IntervalWindow(1, 2), [4, 5]), # On the watermark - ('k2', IntervalWindow(1, 2), [4, 5]), # On the watermark + ('k1', IntervalWindow(1, 2), [4, 5]), # On the watermark + ('k2', IntervalWindow(1, 2), [4, 5]), # On the watermark ('k1', IntervalWindow(0, 1), [6]), # After the watermark ])) @@ -202,6 +205,7 @@ def test_fixed_after_count_accumulating(self): ])) def test_sessions_and_complex_trigger_accumulating(self): + def tsv(key, value, ts): return TimestampedValue((key, value), timestamp=ts) @@ -238,11 +242,11 @@ def tsv(key, value, ts): assert_that( result, equal_to([ - ('k1', IntervalWindow(1, 25), {1, 2, 3}), # early - ('k1', IntervalWindow(1, 25), {1, 2, 3}), # on time + ('k1', IntervalWindow(1, 25), {1, 2, 3}), # early + ('k1', IntervalWindow(1, 25), {1, 2, 3}), # on time ('k1', IntervalWindow(30, 40), {4}), # on time - ('k1', IntervalWindow(1, 25), {1, 2, 3, -3, -2}), # late - ('k1', IntervalWindow(1, 40), {1, 2, 3, 4, -3, -2, -1}), # late + ('k1', IntervalWindow(1, 25), {1, 2, 3, -3, -2}), # late + ('k1', IntervalWindow(1, 40), {1, 2, 3, 4, -3, -2, -1}), # late ])) diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/watermark_manager.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/watermark_manager.py index 106eca108297..cbaf0b417c97 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/watermark_manager.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/watermark_manager.py @@ -36,7 +36,9 @@ class WatermarkManager(object): """Manages the watermarks of a pipeline's stages. It works by constructing an internal graph representation of the pipeline, and keeping track of dependencies.""" + class PCollectionNode(object): + def __init__(self, name): self.name = name self._watermark = timestamp.MIN_TIMESTAMP @@ -70,6 +72,7 @@ def watermark(self): return self._watermark class StageNode(object): + def __init__(self, name): # We keep separate inputs and side inputs because side inputs # should hold back a stage's input watermark, to hold back execution diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py index d798e96d3aa3..ea69b0ac8ace 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py @@ -207,10 +207,11 @@ def get_conn_by_worker_id(self, worker_id): with self._lock: return self._connections_by_worker_id[worker_id] - def Control(self, - iterator, # type: Iterable[beam_fn_api_pb2.InstructionResponse] - context # type: ServicerContext - ): + def Control( + self, + iterator, # type: Iterable[beam_fn_api_pb2.InstructionResponse] + context # type: ServicerContext + ): # type: (...) -> Iterator[beam_fn_api_pb2.InstructionRequest] with self._lock: if self._state == self.DONE_STATE: @@ -262,12 +263,13 @@ class WorkerHandler(object): control_conn = None # type: ControlConnection data_conn = None # type: data_plane._GrpcDataChannel - def __init__(self, - control_handler, # type: Any - data_plane_handler, # type: Any - state, # type: sdk_worker.StateHandler - provision_info # type: ExtendedProvisionInfo - ): + def __init__( + self, + control_handler, # type: Any + data_plane_handler, # type: Any + state, # type: sdk_worker.StateHandler + provision_info # type: ExtendedProvisionInfo + ): # type: (...) -> None """Initialize a WorkerHandler. @@ -334,12 +336,13 @@ def wrapper(constructor): return wrapper @classmethod - def create(cls, - environment, # type: beam_runner_api_pb2.Environment - state, # type: sdk_worker.StateHandler - provision_info, # type: ExtendedProvisionInfo - grpc_server # type: GrpcServer - ): + def create( + cls, + environment, # type: beam_runner_api_pb2.Environment + state, # type: sdk_worker.StateHandler + provision_info, # type: ExtendedProvisionInfo + grpc_server # type: GrpcServer + ): # type: (...) -> WorkerHandler constructor, payload_type = cls._registered_environments[environment.urn] return constructor( @@ -356,12 +359,13 @@ def create(cls, class EmbeddedWorkerHandler(WorkerHandler): """An in-memory worker_handler for fn API control, state and data planes.""" - def __init__(self, - unused_payload, # type: None - state, # type: sdk_worker.StateHandler - provision_info, # type: ExtendedProvisionInfo - worker_manager, # type: WorkerHandlerManager - ): + def __init__( + self, + unused_payload, # type: None + state, # type: sdk_worker.StateHandler + provision_info, # type: ExtendedProvisionInfo + worker_manager, # type: WorkerHandlerManager + ): # type: (...) -> None super().__init__( self, data_plane.InMemoryDataChannel(), state, provision_info) @@ -414,6 +418,7 @@ def logging_api_service_descriptor(self): class BasicLoggingService(beam_fn_api_pb2_grpc.BeamFnLoggingServicer): + def Logging(self, log_messages, context=None): # type: (Iterable[beam_fn_api_pb2.LogEntry.List], Any) -> Iterator[beam_fn_api_pb2.LogControl] yield beam_fn_api_pb2.LogControl() @@ -424,6 +429,7 @@ def Logging(self, log_messages, context=None): class BasicProvisionService(beam_provision_api_pb2_grpc.ProvisionServiceServicer ): + def __init__(self, base_info, worker_manager): # type: (beam_provision_api_pb2.ProvisionInfo, WorkerHandlerManager) -> None self._base_info = base_info @@ -448,11 +454,12 @@ class GrpcServer(object): _DEFAULT_SHUTDOWN_TIMEOUT_SECS = 5 - def __init__(self, - state, # type: StateServicer - provision_info, # type: Optional[ExtendedProvisionInfo] - worker_manager, # type: WorkerHandlerManager - ): + def __init__( + self, + state, # type: StateServicer + provision_info, # type: Optional[ExtendedProvisionInfo] + worker_manager, # type: WorkerHandlerManager + ): # type: (...) -> None # Options to have no limits (-1) on the size of the messages @@ -540,11 +547,12 @@ def close(self): class GrpcWorkerHandler(WorkerHandler): """An grpc based worker_handler for fn API control, state and data planes.""" - def __init__(self, - state, # type: StateServicer - provision_info, # type: ExtendedProvisionInfo - grpc_server # type: GrpcServer - ): + def __init__( + self, + state, # type: StateServicer + provision_info, # type: ExtendedProvisionInfo + grpc_server # type: GrpcServer + ): # type: (...) -> None self._grpc_server = grpc_server super().__init__( @@ -604,12 +612,14 @@ def host_from_worker(self): @WorkerHandler.register_environment( common_urns.environments.EXTERNAL.urn, beam_runner_api_pb2.ExternalPayload) class ExternalWorkerHandler(GrpcWorkerHandler): - def __init__(self, - external_payload, # type: beam_runner_api_pb2.ExternalPayload - state, # type: StateServicer - provision_info, # type: ExtendedProvisionInfo - grpc_server # type: GrpcServer - ): + + def __init__( + self, + external_payload, # type: beam_runner_api_pb2.ExternalPayload + state, # type: StateServicer + provision_info, # type: ExtendedProvisionInfo + grpc_server # type: GrpcServer + ): # type: (...) -> None super().__init__(state, provision_info, grpc_server) self._external_payload = external_payload @@ -649,12 +659,14 @@ def host_from_worker(self): @WorkerHandler.register_environment(python_urns.EMBEDDED_PYTHON_GRPC, bytes) class EmbeddedGrpcWorkerHandler(GrpcWorkerHandler): - def __init__(self, - payload, # type: bytes - state, # type: StateServicer - provision_info, # type: ExtendedProvisionInfo - grpc_server # type: GrpcServer - ): + + def __init__( + self, + payload, # type: bytes + state, # type: StateServicer + provision_info, # type: ExtendedProvisionInfo + grpc_server # type: GrpcServer + ): # type: (...) -> None super().__init__(state, provision_info, grpc_server) @@ -690,12 +702,14 @@ def stop_worker(self): @WorkerHandler.register_environment(python_urns.SUBPROCESS_SDK, bytes) class SubprocessSdkWorkerHandler(GrpcWorkerHandler): - def __init__(self, - worker_command_line, # type: bytes - state, # type: StateServicer - provision_info, # type: ExtendedProvisionInfo - grpc_server # type: GrpcServer - ): + + def __init__( + self, + worker_command_line, # type: bytes + state, # type: StateServicer + provision_info, # type: ExtendedProvisionInfo + grpc_server # type: GrpcServer + ): # type: (...) -> None super().__init__(state, provision_info, grpc_server) self._worker_command_line = worker_command_line @@ -720,12 +734,14 @@ def stop_worker(self): @WorkerHandler.register_environment( common_urns.environments.DOCKER.urn, beam_runner_api_pb2.DockerPayload) class DockerSdkWorkerHandler(GrpcWorkerHandler): - def __init__(self, - payload, # type: beam_runner_api_pb2.DockerPayload - state, # type: StateServicer - provision_info, # type: ExtendedProvisionInfo - grpc_server # type: GrpcServer - ): + + def __init__( + self, + payload, # type: beam_runner_api_pb2.DockerPayload + state, # type: StateServicer + provision_info, # type: ExtendedProvisionInfo + grpc_server # type: GrpcServer + ): # type: (...) -> None super().__init__(state, provision_info, grpc_server) self._container_image = payload.container_image @@ -852,10 +868,12 @@ class WorkerHandlerManager(object): Caches ``WorkerHandler``s based on environment id. """ - def __init__(self, - environments, # type: Mapping[str, beam_runner_api_pb2.Environment] - job_provision_info # type: ExtendedProvisionInfo - ): + + def __init__( + self, + environments, # type: Mapping[str, beam_runner_api_pb2.Environment] + job_provision_info # type: ExtendedProvisionInfo + ): # type: (...) -> None self._environments = environments self._job_provision_info = job_provision_info @@ -966,6 +984,7 @@ class StateServicer(beam_fn_api_pb2_grpc.BeamFnStateServicer, ]) class CopyOnWriteState(object): + def __init__(self, underlying): # type: (DefaultDict[bytes, Buffer]) -> None self._underlying = underlying @@ -989,7 +1008,9 @@ def commit(self): return self._underlying class CopyOnWriteList(object): - def __init__(self, + + def __init__( + self, underlying, # type: DefaultDict[bytes, Buffer] overlay, # type: Dict[bytes, Buffer] key # type: bytes @@ -1063,10 +1084,11 @@ def _get_one_interval_key(self, state_key, start): state_key_copy.ordered_list_user_state.range.end = start + 1 return self._to_key(state_key_copy) - def get_raw(self, + def get_raw( + self, state_key, # type: beam_fn_api_pb2.StateKey continuation_token=None # type: Optional[bytes] - ): + ): # type: (...) -> Tuple[bytes, Optional[bytes]] if state_key.WhichOneof('type') not in self._SUPPORTED_STATE_TYPES: @@ -1176,14 +1198,16 @@ def _to_key(state_key): class GrpcStateServicer(beam_fn_api_pb2_grpc.BeamFnStateServicer): + def __init__(self, state): # type: (StateServicer) -> None self._state = state - def State(self, + def State( + self, request_stream, # type: Iterable[beam_fn_api_pb2.StateRequest] context=None # type: Any - ): + ): # type: (...) -> Iterator[beam_fn_api_pb2.StateResponse] # Note that this eagerly mutates state, assuming any failures are fatal. # Thus it is safe to ignore instruction_id. @@ -1213,6 +1237,7 @@ def State(self, class SingletonStateHandlerFactory(sdk_worker.StateHandlerFactory): """A singleton cache for a StateServicer.""" + def __init__(self, state_handler): # type: (sdk_worker.CachingStateHandler) -> None self._state_handler = state_handler @@ -1231,10 +1256,12 @@ def close(self): class ControlFuture(object): - def __init__(self, - instruction_id, # type: str - response=None # type: Optional[beam_fn_api_pb2.InstructionResponse] - ): + + def __init__( + self, + instruction_id, # type: str + response=None # type: Optional[beam_fn_api_pb2.InstructionResponse] + ): # type: (...) -> None self.instruction_id = instruction_id self._response = response diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers_test.py index 832e7ecee801..210bb22bd40c 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers_test.py @@ -27,6 +27,7 @@ class WorkerHandlerManagerTest(unittest.TestCase): + def test_close_all(self): inprocess_env = environments.EmbeddedPythonEnvironment( capabilities=environments.python_sdk_capabilities(), diff --git a/sdks/python/apache_beam/runners/portability/job_server.py b/sdks/python/apache_beam/runners/portability/job_server.py index eee75f66a277..26b7d197c484 100644 --- a/sdks/python/apache_beam/runners/portability/job_server.py +++ b/sdks/python/apache_beam/runners/portability/job_server.py @@ -33,6 +33,7 @@ class JobServer(object): + def start(self): """Starts this JobServer, returning a grpc service to which to submit jobs. """ @@ -44,6 +45,7 @@ def stop(self): class ExternalJobServer(JobServer): + def __init__(self, endpoint, timeout=None): self._endpoint = endpoint self._timeout = timeout @@ -58,6 +60,7 @@ def stop(self): class EmbeddedJobServer(JobServer): + def start(self) -> 'local_job_service.LocalJobServicer': return local_job_service.LocalJobServicer() @@ -68,6 +71,7 @@ def stop(self): class StopOnExitJobServer(JobServer): """Wraps a JobServer such that its stop will automatically be called on exit. """ + def __init__(self, job_server): self._lock = threading.Lock() self._job_server = job_server @@ -91,6 +95,7 @@ def stop(self): class SubprocessJobServer(JobServer): """An abstract base class for JobServers run as an external process.""" + def __init__(self): self._local_temp_root = None self._server = None @@ -118,6 +123,7 @@ def local_temp_dir(self, **kwargs): class JavaJarJobServer(SubprocessJobServer): + def __init__(self, options): super().__init__() options = options.view_as(pipeline_options.JobServerOptions) diff --git a/sdks/python/apache_beam/runners/portability/job_server_test.py b/sdks/python/apache_beam/runners/portability/job_server_test.py index 13b3629b24bf..337a37ee00a0 100644 --- a/sdks/python/apache_beam/runners/portability/job_server_test.py +++ b/sdks/python/apache_beam/runners/portability/job_server_test.py @@ -24,6 +24,7 @@ class JavaJarJobServerStub(JavaJarJobServer): + def java_arguments( self, job_port, artifact_port, expansion_port, artifacts_dir): return [ @@ -47,6 +48,7 @@ def local_jar(url, jar_cache_dir=None): class JavaJarJobServerTest(unittest.TestCase): + def test_subprocess_cmd_and_endpoint(self): pipeline_options = PipelineOptions([ '--job_port=8099', diff --git a/sdks/python/apache_beam/runners/portability/local_job_service.py b/sdks/python/apache_beam/runners/portability/local_job_service.py index a2b4e5e7f939..9c438e61c1d4 100644 --- a/sdks/python/apache_beam/runners/portability/local_job_service.py +++ b/sdks/python/apache_beam/runners/portability/local_job_service.py @@ -78,6 +78,7 @@ class LocalJobServicer(abstract_job_service.AbstractJobServiceServicer): inline calls rather than GRPC (for speed) or launch completely separate subprocesses for the runner and worker(s). """ + def __init__(self, staging_dir=None, beam_job_type=None): super().__init__() self._cleanup_staging_dir = staging_dir is None @@ -88,12 +89,12 @@ def __init__(self, staging_dir=None, beam_job_type=None): endpoints_pb2.ApiServiceDescriptor] = None self._beam_job_type = beam_job_type or BeamJob - def create_beam_job(self, - preparation_id, # stype: str - job_name: str, - pipeline: beam_runner_api_pb2.Pipeline, - options: struct_pb2.Struct - ) -> 'BeamJob': + def create_beam_job( + self, + preparation_id, # stype: str + job_name: str, + pipeline: beam_runner_api_pb2.Pipeline, + options: struct_pb2.Struct) -> 'BeamJob': self._artifact_service.register_job( staging_token=preparation_id, dependency_sets=_extract_dependency_sets( @@ -176,6 +177,7 @@ def GetJobMetrics(self, request, context=None): class SubprocessSdkWorker(object): """Manages a SDK worker implemented as a subprocess communicating over grpc. """ + def __init__( self, worker_command_line: bytes, @@ -235,6 +237,7 @@ class BeamJob(abstract_job_service.AbstractBeamJob): The current state of the pipeline is available as self.state. """ + def __init__( self, job_id: str, @@ -254,6 +257,7 @@ def __init__( self.result = None def pipeline_options(self): + def from_urn(key): assert key.startswith('beam:option:') assert key.endswith(':v1') @@ -363,6 +367,7 @@ def get_message_stream(self): class BeamFnLoggingServicer(beam_fn_api_pb2_grpc.BeamFnLoggingServicer): + def Logging(self, log_bundles, context=None): for log_bundle in log_bundles: for log_entry in log_bundle.log_entries: @@ -374,6 +379,7 @@ def Logging(self, log_bundles, context=None): class JobLogQueues(object): + def __init__(self): self._queues: List[queue.Queue] = [] self._cache = [] @@ -463,6 +469,7 @@ def _extract_dependency_sets( The values can then be resolved and the mapping passed back to _update_dependency_sets to update the dependencies in the original protos. """ + def dependencies_iter(): for env_id, env in envs.items(): for ix, sub_env in enumerate(environments.expand_anyof_environments(env)): diff --git a/sdks/python/apache_beam/runners/portability/local_job_service_test.py b/sdks/python/apache_beam/runners/portability/local_job_service_test.py index 7d9d70d98df8..8b3256ce53e4 100644 --- a/sdks/python/apache_beam/runners/portability/local_job_service_test.py +++ b/sdks/python/apache_beam/runners/portability/local_job_service_test.py @@ -30,6 +30,7 @@ class TestJobServicePlan(JobServiceHandle): + def __init__(self, job_service): self.job_service = job_service self.options = None @@ -41,6 +42,7 @@ def get_pipeline_options(self): class LocalJobServerTest(unittest.TestCase): + def test_end_to_end(self): job_service = local_job_service.LocalJobServicer() diff --git a/sdks/python/apache_beam/runners/portability/portable_runner.py b/sdks/python/apache_beam/runners/portability/portable_runner.py index fe9dcfa62b29..574c03658297 100644 --- a/sdks/python/apache_beam/runners/portability/portable_runner.py +++ b/sdks/python/apache_beam/runners/portability/portable_runner.py @@ -87,6 +87,7 @@ class JobServiceHandle(object): - stage - run """ + def __init__(self, job_service, options, retain_unknown_options=False): self.job_service = job_service self.options = options @@ -169,6 +170,7 @@ def add_runner_options(parser): @staticmethod def encode_pipeline_options( all_options: Dict[str, Any]) -> 'struct_pb2.Struct': + def convert_pipeline_option_value(v): # convert int values: BEAM-5509 if type(v) == int: @@ -181,8 +183,7 @@ def convert_pipeline_option_value(v): # TODO: Define URNs for options. p_options = { 'beam:option:' + k + ':v1': convert_pipeline_option_value(v) - for k, - v in all_options.items() if v is not None + for k, v in all_options.items() if v is not None } return job_utils.dict_to_struct(p_options) @@ -256,6 +257,7 @@ class PortableRunner(runner.PipelineRunner): This runner schedules the job on a job service. The responsibility of running and managing the job lies with the job service used. """ + def __init__(self): self._dockerized_job_server: Optional[job_server.JobServer] = None @@ -423,6 +425,7 @@ def start_and_replace_loopback_environments(pipeline, options): class PortableMetrics(metric.MetricResults): + def __init__(self, job_metrics_response): metrics = job_metrics_response.metrics self.attempted = portable_metrics.from_monitoring_infos(metrics.attempted) @@ -452,6 +455,7 @@ def query(self, filter=None): class PipelineResult(runner.PipelineResult): + def __init__( self, job_service, @@ -529,6 +533,7 @@ def wait_until_finish(self, duration=None): the execution. If None or zero, will wait until the pipeline finishes. :return: The result of the pipeline, i.e. PipelineResult. """ + def read_messages() -> None: previous_state = -1 for message in self._message_stream: diff --git a/sdks/python/apache_beam/runners/portability/portable_runner_test.py b/sdks/python/apache_beam/runners/portability/portable_runner_test.py index 85d1607e9fa1..7d86f0f1d8c2 100644 --- a/sdks/python/apache_beam/runners/portability/portable_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/portable_runner_test.py @@ -145,6 +145,7 @@ def _maybe_kill_subprocess(cls): time.sleep(0.1) def create_options(self): + def get_pipeline_name(): for _, _, _, method_name, _, _ in inspect.stack(): if method_name.find('test') != -1: @@ -187,11 +188,13 @@ def test_pardo_state_with_custom_key_coder(self): # Use a DoFn which has to use FastPrimitivesCoder because the type cannot # be inferred class Input(beam.DoFn): + def process(self, impulse): for i in inputs: yield i class AddIndex(beam.DoFn): + def process(self, kv, index=beam.DoFn.StateParam(index_state_spec)): k, v = kv index.add(1) @@ -225,6 +228,7 @@ def test_draining_sdf_with_sdf_initiated_checkpointing(self): @unittest.skip("https://github.com/apache/beam/issues/19422") class PortableRunnerOptimized(PortableRunnerTest): + def create_options(self): options = super().create_options() options.view_as(DebugOptions).add_experiment('pre_optimize=all') @@ -237,6 +241,7 @@ def create_options(self): # TODO(https://github.com/apache/beam/issues/19422): Delete this test after # PortableRunner supports beam:runner:executable_stage:v1. class PortableRunnerOptimizedWithoutFusion(PortableRunnerTest): + def create_options(self): options = super().create_options() options.view_as(DebugOptions).add_experiment( @@ -248,6 +253,7 @@ def create_options(self): class PortableRunnerTestWithExternalEnv(PortableRunnerTest): + @classmethod def setUpClass(cls): cls._worker_address, cls._worker_server = ( @@ -310,6 +316,7 @@ def create_options(self): class PortableRunnerInternalTest(unittest.TestCase): + def setUp(self) -> None: self.tmp_dir = tempfile.TemporaryDirectory() self.actual_mkdtemp = tempfile.mkdtemp @@ -426,6 +433,7 @@ def hasDockerImage(): not hasDockerImage(), "docker not installed or " "no docker image") class PortableRunnerTestWithLocalDocker(PortableRunnerTest): + def create_options(self): options = super().create_options() options.view_as(PortableOptions).job_endpoint = 'embed' diff --git a/sdks/python/apache_beam/runners/portability/prism_runner.py b/sdks/python/apache_beam/runners/portability/prism_runner.py index 77dc8a214e8e..d010f92fa8e0 100644 --- a/sdks/python/apache_beam/runners/portability/prism_runner.py +++ b/sdks/python/apache_beam/runners/portability/prism_runner.py @@ -56,6 +56,7 @@ class PrismRunner(portable_runner.PortableRunner): """A runner for launching jobs on Prism, automatically downloading and starting a Prism instance if needed. """ + def default_environment( self, options: pipeline_options.PipelineOptions) -> environments.Environment: @@ -217,8 +218,8 @@ def path_to_binary(self) -> str: # We failed to build for some reason. output = process.stdout.decode("utf-8") - if ("not in a module" not in output) and ( - "no required module provides" not in output): + if ("not in a module" not in output) and ("no required module provides" + not in output): # This branch handles two classes of failures: # 1. Go isn't installed, so it needs to be installed by the Beam SDK # developer. diff --git a/sdks/python/apache_beam/runners/portability/sdk_container_builder.py b/sdks/python/apache_beam/runners/portability/sdk_container_builder.py index 489973304f5f..2fa009e3e1d7 100644 --- a/sdks/python/apache_beam/runners/portability/sdk_container_builder.py +++ b/sdks/python/apache_beam/runners/portability/sdk_container_builder.py @@ -66,6 +66,7 @@ class SdkContainerImageBuilder(plugin.BeamPlugin): + def __init__(self, options): self._options = options self._docker_registry_push_url = self._options.view_as( @@ -159,6 +160,7 @@ def _get_subclass_by_key(cls, key: str) -> Type['SdkContainerImageBuilder']: class _SdkContainerImageLocalBuilder(SdkContainerImageBuilder): """SdkContainerLocalBuilder builds the sdk container image with local docker.""" + @classmethod def _builder_key(cls): return 'local_docker' @@ -200,6 +202,7 @@ def _invoke_docker_build_and_push(self, container_image_name): class _SdkContainerImageCloudBuilder(SdkContainerImageBuilder): """SdkContainerLocalBuilder builds the sdk container image with google cloud build.""" + def __init__(self, options): super().__init__(options) self._google_cloud_options = options.view_as(GoogleCloudOptions) diff --git a/sdks/python/apache_beam/runners/portability/sdk_container_builder_test.py b/sdks/python/apache_beam/runners/portability/sdk_container_builder_test.py index 955fe328f171..71aab9d877c6 100644 --- a/sdks/python/apache_beam/runners/portability/sdk_container_builder_test.py +++ b/sdks/python/apache_beam/runners/portability/sdk_container_builder_test.py @@ -29,6 +29,7 @@ class SdkContainerBuilderTest(unittest.TestCase): + def tearDown(self): # Ensures SdkContainerImageBuilder subclasses are cleared gc.collect() @@ -55,11 +56,13 @@ def test_missing_builder_key_throws_value_error(self): def test_multiple_matchings_keys_throws_value_error(self): # pylint: disable=unused-variable class _PluginSdkBuilder(sdk_container_builder.SdkContainerImageBuilder): + @classmethod def _builder_key(cls): return 'test-id' class _PluginSdkBuilder2(sdk_container_builder.SdkContainerImageBuilder): + @classmethod def _builder_key(cls): return 'test-id' @@ -71,6 +74,7 @@ def _builder_key(cls): 'test-id') def test_can_find_new_subclass(self): + class _PluginSdkBuilder(sdk_container_builder.SdkContainerImageBuilder): pass diff --git a/sdks/python/apache_beam/runners/portability/spark_java_job_server_test.py b/sdks/python/apache_beam/runners/portability/spark_java_job_server_test.py index 50490d9c5c15..e9fb19f5cf47 100644 --- a/sdks/python/apache_beam/runners/portability/spark_java_job_server_test.py +++ b/sdks/python/apache_beam/runners/portability/spark_java_job_server_test.py @@ -24,6 +24,7 @@ class SparkTestPipelineOptions(pipeline_options.PipelineOptions): + def view_as(self, cls): # Ensure only SparkRunnerOptions and JobServerOptions are used when calling # default_job_server. If other options classes are needed, the cache key @@ -35,6 +36,7 @@ def view_as(self, cls): class SparkJavaJobServerTest(unittest.TestCase): + def test_job_server_cache(self): # Multiple SparkRunner instances may be created, so we need to make sure we # cache job servers across runner instances. diff --git a/sdks/python/apache_beam/runners/portability/spark_runner.py b/sdks/python/apache_beam/runners/portability/spark_runner.py index 480fbdecdce3..4ee9c2b0e238 100644 --- a/sdks/python/apache_beam/runners/portability/spark_runner.py +++ b/sdks/python/apache_beam/runners/portability/spark_runner.py @@ -78,6 +78,7 @@ def create_job_service_handle(self, job_service, options): class SparkJarJobServer(job_server.JavaJarJobServer): + def __init__(self, options): super().__init__(options) options = options.view_as(pipeline_options.SparkRunnerOptions) diff --git a/sdks/python/apache_beam/runners/portability/spark_uber_jar_job_server.py b/sdks/python/apache_beam/runners/portability/spark_uber_jar_job_server.py index f754b4c330ad..2b84e56a0b8f 100644 --- a/sdks/python/apache_beam/runners/portability/spark_uber_jar_job_server.py +++ b/sdks/python/apache_beam/runners/portability/spark_uber_jar_job_server.py @@ -44,6 +44,7 @@ class SparkUberJarJobServer(abstract_job_service.AbstractJobServiceServicer): The jar contains the Beam pipeline definition, dependencies, and the pipeline artifacts. """ + def __init__(self, rest_url, options): super().__init__() self._rest_url = rest_url @@ -97,6 +98,7 @@ class SparkBeamJob(abstract_job_service.UberJarBeamJob): Note that the Spark Rest API is not enabled by default. It must be enabled by setting the configuration property spark.master.rest.enabled to true.""" + def __init__( self, rest_url, diff --git a/sdks/python/apache_beam/runners/portability/spark_uber_jar_job_server_test.py b/sdks/python/apache_beam/runners/portability/spark_uber_jar_job_server_test.py index a99bec840bee..3014e9b41cd0 100644 --- a/sdks/python/apache_beam/runners/portability/spark_uber_jar_job_server_test.py +++ b/sdks/python/apache_beam/runners/portability/spark_uber_jar_job_server_test.py @@ -59,6 +59,7 @@ def spark_job(): class SparkUberJarJobServerTest(unittest.TestCase): + @requests_mock.mock() def test_get_server_spark_version(self, http_mock): http_mock.get( diff --git a/sdks/python/apache_beam/runners/portability/stager.py b/sdks/python/apache_beam/runners/portability/stager.py index c7142bfddcaf..7f113c95ccd0 100644 --- a/sdks/python/apache_beam/runners/portability/stager.py +++ b/sdks/python/apache_beam/runners/portability/stager.py @@ -214,8 +214,8 @@ def create_job_resources( if not skip_prestaged_dependencies: requirements_cache_path = ( os.path.join(tempfile.gettempdir(), 'dataflow-requirements-cache') if - (setup_options.requirements_cache is None) else - setup_options.requirements_cache) + (setup_options.requirements_cache + is None) else setup_options.requirements_cache) if (setup_options.requirements_cache != SKIP_REQUIREMENTS_CACHE and not os.path.exists(requirements_cache_path)): os.makedirs(requirements_cache_path) diff --git a/sdks/python/apache_beam/runners/portability/stager_test.py b/sdks/python/apache_beam/runners/portability/stager_test.py index 5535989a5786..da147317a596 100644 --- a/sdks/python/apache_beam/runners/portability/stager_test.py +++ b/sdks/python/apache_beam/runners/portability/stager_test.py @@ -44,6 +44,7 @@ class StagerTest(unittest.TestCase): + def setUp(self): self._temp_dir = None self.stager = TestStager() @@ -842,6 +843,7 @@ def test_populate_requirements_cache_with_local_files(self): class TestStager(stager.Stager): + def stage_artifact(self, local_path_to_artifact, artifact_name, sha256): _LOGGER.info( 'File copy from %s to %s.', local_path_to_artifact, artifact_name) diff --git a/sdks/python/apache_beam/runners/render.py b/sdks/python/apache_beam/runners/render.py index 45e66e1ba06a..02fe6a99c58f 100644 --- a/sdks/python/apache_beam/runners/render.py +++ b/sdks/python/apache_beam/runners/render.py @@ -90,6 +90,7 @@ class RenderOptions(pipeline_options.PipelineOptions): """Rendering options.""" + @classmethod def _add_argparse_args(cls, parser): parser.add_argument( @@ -133,6 +134,7 @@ def _add_argparse_args(cls, parser): class PipelineRenderer: + def __init__(self, pipeline, options): self.pipeline = pipeline self.options = options @@ -151,9 +153,8 @@ def __init__(self, pipeline, options): if options.render_leaf_composite_nodes: is_leaf = lambda transform_id: any( re.match( - pattern, - self.pipeline.components.transforms[transform_id].unique_name) - for patterns in options.render_leaf_composite_nodes + pattern, self.pipeline.components.transforms[transform_id]. + unique_name) for patterns in options.render_leaf_composite_nodes for pattern in patterns.split(',')) self.leaf_composites = set() @@ -441,6 +442,7 @@ def run_portable_pipeline(self, pipeline_proto, options): # TODO: If this gets more complex, we could consider taking on a # framework like Flask as a dependency. class RequestHandler(http.server.BaseHTTPRequestHandler): + def do_GET(self): parts = urllib.parse.urlparse(self.path) args = urllib.parse.parse_qs(parts.query) @@ -471,6 +473,7 @@ def do_GET(self): class RenderPipelineResult(runner.PipelineResult): + def __init__(self, server): super().__init__(runner.PipelineState.RUNNING) self.server = server @@ -548,7 +551,9 @@ def render_one(options): def run_server(options): + class RenderBeamJob(local_job_service.BeamJob): + def _invoke_runner(self): return RenderRunner().run_portable_pipeline( self._pipeline_proto, diff --git a/sdks/python/apache_beam/runners/render_test.py b/sdks/python/apache_beam/runners/render_test.py index 67e7afc1c7b9..8997b331593d 100644 --- a/sdks/python/apache_beam/runners/render_test.py +++ b/sdks/python/apache_beam/runners/render_test.py @@ -32,6 +32,7 @@ class RenderRunnerTest(unittest.TestCase): + def test_basic_graph(self): p = beam.Pipeline() _ = ( @@ -80,6 +81,7 @@ def test_composite_collapse(self): class DotRequiringRenderingTest(unittest.TestCase): + @classmethod def setUpClass(cls): try: diff --git a/sdks/python/apache_beam/runners/runner.py b/sdks/python/apache_beam/runners/runner.py index 78022724226a..6f4915ebf1c7 100644 --- a/sdks/python/apache_beam/runners/runner.py +++ b/sdks/python/apache_beam/runners/runner.py @@ -111,6 +111,7 @@ class PipelineRunner(object): provide a new implementation for clear_pvalue(), which is used to wipe out materialized values in order to reduce footprint. """ + def run( self, transform: 'PTransform', @@ -261,6 +262,7 @@ def is_terminal(cls, state): class PipelineResult(object): """A :class:`PipelineResult` provides access to info about a pipeline.""" + def __init__(self, state): self._state = state diff --git a/sdks/python/apache_beam/runners/sdf_utils.py b/sdks/python/apache_beam/runners/sdf_utils.py index 01573656b6ac..03070c4e9967 100644 --- a/sdks/python/apache_beam/runners/sdf_utils.py +++ b/sdks/python/apache_beam/runners/sdf_utils.py @@ -55,6 +55,7 @@ class ThreadsafeRestrictionTracker(object): This wrapper guarantees synchronization of modifying restrictions across multi-thread. """ + def __init__(self, restriction_tracker: 'RestrictionTracker') -> None: from apache_beam.io.iobase import RestrictionTracker if not isinstance(restriction_tracker, RestrictionTracker): @@ -155,6 +156,7 @@ class RestrictionTrackerView(object): time, the RestrictionTrackerView will be fed into the ``DoFn.process`` as a restriction_tracker. """ + def __init__( self, threadsafe_restriction_tracker: ThreadsafeRestrictionTracker) -> None: @@ -182,6 +184,7 @@ class ThreadsafeWatermarkEstimator(object): """A threadsafe wrapper which wraps a WatermarkEstimator with locking mechanism to guarantee multi-thread safety. """ + def __init__(self, watermark_estimator: 'WatermarkEstimator') -> None: from apache_beam.io.iobase import WatermarkEstimator if not isinstance(watermark_estimator, WatermarkEstimator): @@ -220,6 +223,7 @@ class NoOpWatermarkEstimatorProvider(WatermarkEstimatorProvider): """A WatermarkEstimatorProvider which creates NoOpWatermarkEstimator for the framework. """ + def initial_estimator_state(self, element, restriction): return None @@ -230,6 +234,7 @@ class _NoOpWatermarkEstimator(WatermarkEstimator): """A No-op WatermarkEstimator which is provided for the framework if there is no custom one. """ + def observe_timestamp(self, timestamp): pass diff --git a/sdks/python/apache_beam/runners/sdf_utils_test.py b/sdks/python/apache_beam/runners/sdf_utils_test.py index a4510d747d13..a26034b25e65 100644 --- a/sdks/python/apache_beam/runners/sdf_utils_test.py +++ b/sdks/python/apache_beam/runners/sdf_utils_test.py @@ -33,6 +33,7 @@ class ThreadsafeRestrictionTrackerTest(unittest.TestCase): + def test_initialization(self): with self.assertRaises(ValueError): ThreadsafeRestrictionTracker(RangeSource(0, 1)) @@ -79,6 +80,7 @@ def test_self_checkpoint_with_absolute_time(self): class RestrictionTrackerViewTest(unittest.TestCase): + def test_initialization(self): with self.assertRaises(ValueError): RestrictionTrackerView(OffsetRestrictionTracker(OffsetRange(0, 10))) @@ -111,6 +113,7 @@ def test_non_expose_apis(self): class ThreadsafeWatermarkEstimatorTest(unittest.TestCase): + def test_initialization(self): with self.assertRaises(ValueError): ThreadsafeWatermarkEstimator(None) diff --git a/sdks/python/apache_beam/runners/trivial_runner.py b/sdks/python/apache_beam/runners/trivial_runner.py index af8f4f92c4e3..bb8c1755053b 100644 --- a/sdks/python/apache_beam/runners/trivial_runner.py +++ b/sdks/python/apache_beam/runners/trivial_runner.py @@ -52,6 +52,7 @@ class TrivialRunner(runner.PipelineRunner): several features in order to keep it as simple as possible. Where possible pointers are provided which this should serve as a useful starting point. """ + def run_portable_pipeline(self, pipeline, options): # First ensure we are able to run this pipeline. # Specifically, that it does not depend on requirements that were @@ -278,10 +279,10 @@ def group_by_key_and_window(self, input_pcoll, output_pcoll, execution_state): windowing = components.windowing_strategies[ components.pcollections[input_pcoll].windowing_strategy_id] - if (windowing.merge_status == - beam_runner_api_pb2.MergeStatus.Enum.NON_MERGING and - windowing.output_time == - beam_runner_api_pb2.OutputTime.Enum.END_OF_WINDOW): + if (windowing.merge_status + == beam_runner_api_pb2.MergeStatus.Enum.NON_MERGING and + windowing.output_time + == beam_runner_api_pb2.OutputTime.Enum.END_OF_WINDOW): # This is the "easy" case, show how to do it by hand. # Note that we're grouping by encoded key, and also by the window. grouped = collections.defaultdict(list) @@ -322,6 +323,7 @@ def supported_requirements(self) -> Iterable[str]: class ExecutionState: """A helper class holding various values and context during execution.""" + def __init__(self, optimized_pipeline): self.optimized_pipeline = optimized_pipeline self._pcollections_to_encoded_chunks = {} diff --git a/sdks/python/apache_beam/runners/trivial_runner_test.py b/sdks/python/apache_beam/runners/trivial_runner_test.py index 6acfc2fee495..bea1c87882a9 100644 --- a/sdks/python/apache_beam/runners/trivial_runner_test.py +++ b/sdks/python/apache_beam/runners/trivial_runner_test.py @@ -25,6 +25,7 @@ class TrivialRunnerTest(unittest.TestCase): + def test_trivial(self): # The most trivial pipeline, to ensure at least something is working. # (Notably avoids the non-trivial complexity within assert_that.) diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index 89c137fe4366..56be78ebf051 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -131,6 +131,7 @@ class RunnerIOOperation(operations.Operation): """Common baseclass for runner harness IO operations.""" + def __init__( self, name_context: common.NameContext, @@ -156,6 +157,7 @@ def __init__( class DataOutputOperation(RunnerIOOperation): """A sink-like operation that gathers outputs to be sent back to the runner. """ + def set_output_stream( self, output_stream: data_plane.ClosableOutputStream) -> None: self.output_stream = output_stream @@ -172,6 +174,7 @@ def finish(self) -> None: class DataInputOperation(RunnerIOOperation): """A source-like operation that gathers input from the runner.""" + def __init__( self, operation_name: common.NameContext, @@ -295,6 +298,7 @@ def _compute_split( total_buffer_size, allowed_split_points=(), try_split=lambda fraction: None): + def is_valid_split_point(index): return not allowed_split_points or index in allowed_split_points @@ -356,6 +360,7 @@ def reset(self) -> None: class _StateBackedIterable(object): + def __init__( self, state_handler: sdk_worker.CachingStateHandler, @@ -510,6 +515,7 @@ def reset(self) -> None: class ReadModifyWriteRuntimeState(userstate.ReadModifyWriteRuntimeState): + def __init__(self, underlying_bag_state): self._underlying_bag_state = underlying_bag_state @@ -531,6 +537,7 @@ def commit(self) -> None: class CombiningValueRuntimeState(userstate.CombiningValueRuntimeState): + def __init__( self, underlying_bag_state: userstate.AccumulatingRuntimeState, @@ -580,6 +587,7 @@ class _ConcatIterable(object): Unlike itertools.chain, this allows reiteration. """ + def __init__(self, first: Iterable[Any], second: Iterable[Any]) -> None: self.first = first self.second = second @@ -595,6 +603,7 @@ def __iter__(self) -> Iterator[Any]: class SynchronousBagRuntimeState(userstate.BagRuntimeState): + def __init__( self, state_handler: sdk_worker.CachingStateHandler, @@ -633,6 +642,7 @@ def commit(self) -> None: class SynchronousSetRuntimeState(userstate.SetRuntimeState): + def __init__( self, state_handler: sdk_worker.CachingStateHandler, @@ -693,6 +703,7 @@ def commit(self) -> None: class RangeSet: """For Internal Use only. A simple range set for ranges of [x,y).""" + def __init__(self) -> None: # The start points and end points are stored separately in order. self._sorted_starts = SortedList() @@ -725,8 +736,8 @@ def add(self, start: int, end: int) -> None: def __contains__(self, key: int) -> bool: idx = self._sorted_starts.bisect_left(key) - return (idx < len(self._sorted_starts) and self._sorted_starts[idx] == key - ) or (idx > 0 and self._sorted_ends[idx - 1] > key) + return (idx < len(self._sorted_starts) and self._sorted_starts[idx] + == key) or (idx > 0 and self._sorted_ends[idx - 1] > key) def __len__(self) -> int: assert len(self._sorted_starts) == len(self._sorted_ends) @@ -866,6 +877,7 @@ def commit(self) -> None: class OutputTimer(userstate.BaseTimer): + def __init__( self, key, @@ -914,6 +926,7 @@ def clear(self, dynamic_timer_tag='') -> None: class TimerInfo(object): """A data class to store information related to a timer.""" + def __init__(self, timer_coder_impl, output_stream=None): self.timer_coder_impl = timer_coder_impl self.output_stream = output_stream @@ -921,6 +934,7 @@ def __init__(self, timer_coder_impl, output_stream=None): class FnApiUserStateContext(userstate.UserStateContext): """Interface for state and timers from SDK to Fn API servicer of state..""" + def __init__( self, state_handler: sdk_worker.CachingStateHandler, @@ -1079,6 +1093,7 @@ def _verify_descriptor_created_in_a_compatible_env( class BundleProcessor(object): """ A class for processing bundles of elements. """ + def __init__( self, runner_capabilities: FrozenSet[str], @@ -1166,8 +1181,8 @@ def is_side_input(transform_proto, tag): def get_operation(transform_id: str) -> operations.Operation: transform_consumers = { tag: [get_operation(op) for op in pcoll_consumers[pcoll_id]] - for tag, - pcoll_id in descriptor.transforms[transform_id].outputs.items() + for tag, pcoll_id in + descriptor.transforms[transform_id].outputs.items() } # Initialize transform-specific state in the Data Sampler. @@ -1287,8 +1302,8 @@ def process_bundle( timer_info.output_stream.close() return ([ - self.delayed_bundle_application(op, residual) for op, - residual in execution_context.delayed_applications + self.delayed_bundle_application(op, residual) + for op, residual in execution_context.delayed_applications ], self.requires_finalization()) @@ -1427,6 +1442,7 @@ class ExecutionContext: class BeamTransformFactory(object): """Factory for turning transform_protos into executable operations.""" + def __init__( self, runner_capabilities: FrozenSet[str], @@ -1445,10 +1461,9 @@ def __init__( self.state_handler = state_handler self.context = pipeline_context.PipelineContext( descriptor, - iterable_state_read=lambda token, - element_coder_impl: _StateBackedIterable( - state_handler, - beam_fn_api_pb2.StateKey( + iterable_state_read=lambda token, element_coder_impl: + _StateBackedIterable( + state_handler, beam_fn_api_pb2.StateKey( runner=beam_fn_api_pb2.StateKey.Runner(key=token)), element_coder_impl)) self.data_sampler = data_sampler @@ -1479,6 +1494,7 @@ def register_urn( Dict[str, List[operations.Operation]] ], operations.Operation]]: + def wrapper(func): cls._known_urns[urn] = func, parameter_type return func @@ -1539,8 +1555,7 @@ def get_output_coders( ) -> Dict[str, coders.Coder]: return { tag: self.get_windowed_coder(pcoll_id) - for tag, - pcoll_id in transform_proto.outputs.items() + for tag, pcoll_id in transform_proto.outputs.items() } def get_only_output_coder( @@ -1552,8 +1567,7 @@ def get_input_coders( ) -> Dict[str, coders.WindowedValueCoder]: return { tag: self.get_windowed_coder(pcoll_id) - for tag, - pcoll_id in transform_proto.inputs.items() + for tag, pcoll_id in transform_proto.inputs.items() } def get_only_input_coder( @@ -1702,7 +1716,9 @@ def create_dofn_javasdk( common_urns.sdf_components.PAIR_WITH_RESTRICTION.urn, beam_runner_api_pb2.ParDoPayload) def create_pair_with_restriction(*args): + class PairWithRestriction(beam.DoFn): + def __init__(self, fn, restriction_provider, watermark_estimator_provider): self.restriction_provider = restriction_provider self.watermark_estimator_provider = watermark_estimator_provider @@ -1725,7 +1741,9 @@ def process(self, element, *args, **kwargs): common_urns.sdf_components.SPLIT_AND_SIZE_RESTRICTIONS.urn, beam_runner_api_pb2.ParDoPayload) def create_split_and_size_restrictions(*args): + class SplitAndSizeRestrictions(beam.DoFn): + def __init__(self, fn, restriction_provider, watermark_estimator_provider): self.restriction_provider = restriction_provider self.watermark_estimator_provider = watermark_estimator_provider @@ -1748,7 +1766,9 @@ def process(self, element_restriction, *args, **kwargs): common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn, beam_runner_api_pb2.ParDoPayload) def create_truncate_sized_restriction(*args): + class TruncateAndSizeRestriction(beam.DoFn): + def __init__(self, fn, restriction_provider, watermark_estimator_provider): self.restriction_provider = restriction_provider @@ -1851,8 +1871,7 @@ def _create_pardo_operation( input_tags_to_coders = factory.get_input_coders(transform_proto) tagged_side_inputs = [ (tag, beam.pvalue.SideInputData.from_runner_api(si, factory.context)) - for tag, - si in pardo_proto.side_inputs.items() + for tag, si in pardo_proto.side_inputs.items() ] tagged_side_inputs.sort( key=lambda tag_si: sideinputs.get_sideinput_index(tag_si[0])) @@ -1962,7 +1981,9 @@ def create_assign_windows( transform_proto: beam_runner_api_pb2.PTransform, parameter: beam_runner_api_pb2.WindowingStrategy, consumers: Dict[str, List[operations.Operation]]): + class WindowIntoDoFn(beam.DoFn): + def __init__(self, windowing): self.windowing = windowing @@ -2133,6 +2154,7 @@ def create_map_windows( window_mapping_fn = pickler.loads(mapping_fn_spec.payload) class MapWindows(beam.DoFn): + def process(self, element): key, window = element return [(key, window_mapping_fn(window))] @@ -2153,6 +2175,7 @@ def create_merge_windows( window_fn = pickler.loads(mapping_fn_spec.payload) class MergeWindows(beam.DoFn): + def process(self, element): nonce, windows = element @@ -2163,6 +2186,7 @@ def process(self, element): set) # noqa: F821 class RecordingMergeContext(window.WindowFn.MergeContext): + def merge( self, to_be_merged: Iterable[window.BoundedWindow], @@ -2190,7 +2214,9 @@ def create_to_string_fn( transform_proto: beam_runner_api_pb2.PTransform, mapping_fn_spec: beam_runner_api_pb2.FunctionSpec, consumers: Dict[str, List[operations.Operation]]): + class ToString(beam.DoFn): + def process(self, element): key, value = element return [(key, str(value))] diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor_test.py b/sdks/python/apache_beam/runners/worker/bundle_processor_test.py index 0eb4dd9485fd..209acb6dbb38 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor_test.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor_test.py @@ -47,9 +47,12 @@ class FnApiUserStateContextTest(unittest.TestCase): + def testOutputTimerTimestamp(self): + class Coder(object): """Dummy coder to capture the timer result befor encoding.""" + def encode_to_stream(self, timer, *args, **kwargs): self.timer = timer @@ -82,6 +85,7 @@ def encode_to_stream(self, timer, *args, **kwargs): class SplitTest(unittest.TestCase): + def split( self, index, @@ -197,7 +201,9 @@ def element_split(frac, index): class TestOperation(operations.Operation): """Test operation that forwards its payload to consumers.""" + class Spec: + def __init__(self, transform_proto): self.output_coders = [ FastPrimitivesCoder() for _ in transform_proto.outputs @@ -250,7 +256,9 @@ def create_test_op(factory, transform_id, transform_proto, payload, consumers): def create_exception_dofn( factory, transform_id, transform_proto, payload, consumers): """Returns a test DoFn that raises the given exception.""" + class RaiseException(beam.DoFn): + def __init__(self, msg): self.msg = msg.decode() @@ -266,6 +274,7 @@ def process(self, _): class DataSamplingTest(unittest.TestCase): + def test_disabled_by_default(self): """Test that not providing the sampler does not enable Data Sampling. @@ -410,6 +419,7 @@ def test_can_sample_exceptions(self): class EnvironmentCompatibilityTest(unittest.TestCase): + def test_rc_environments_are_compatible_with_released_images(self): # TODO(https://github.com/apache/beam/issues/28084): remove when # resolved. @@ -430,7 +440,9 @@ def test_user_modified_sdks_need_to_be_installed_in_runtime_env(self): class OrderedListStateTest(unittest.TestCase): + class NoStateCache(StateCache): + def __init__(self): super().__init__(max_weight=0) @@ -621,7 +633,9 @@ def test_multiple_iterators(self): self.assertEqual([], list(self.state.read())) def fuzz_test_helper(self, seed=0, lower=0, upper=20): + class NaiveState: + def __init__(self): self._data = [[] for i in range((upper - lower + 1))] self._logs = [] diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py index 2f9de24594b2..5ec03c7a0223 100644 --- a/sdks/python/apache_beam/runners/worker/data_plane.py +++ b/sdks/python/apache_beam/runners/worker/data_plane.py @@ -87,6 +87,7 @@ class ClosableOutputStream(OutputStream): """A Outputstream for use with CoderImpls that has a close() method.""" + def __init__( self, close_callback=None # type: Optional[Callable[[bytes], None]] @@ -135,7 +136,7 @@ def __init__( close_callback=None, # type: Optional[Callable[[bytes], None]] flush_callback=None, # type: Optional[Callable[[bytes], None]] size_flush_threshold=_DEFAULT_SIZE_FLUSH_THRESHOLD, # type: int - large_buffer_warn_threshold_bytes = 512 << 20 # type: int + large_buffer_warn_threshold_bytes=512 << 20 # type: int ): super().__init__(close_callback) self._flush_callback = flush_callback @@ -290,6 +291,7 @@ class DataChannel(metaclass=abc.ABCMeta): data_channel.close() """ + @abc.abstractmethod def input_elements( self, @@ -367,6 +369,7 @@ class InMemoryDataChannel(DataChannel): This channel is two-sided. What is written to one side is read by the other. The inverse() method returns the other side of a instance. """ + def __init__(self, inverse=None, data_buffer_time_limit_ms=0): # type: (Optional[InMemoryDataChannel], int) -> None self._inputs = [] # type: List[DataOrTimers] @@ -739,6 +742,7 @@ def __init__( class BeamFnDataServicer(beam_fn_api_pb2_grpc.BeamFnDataServicer): """Implementation of BeamFnDataServicer for any number of clients""" + def __init__( self, data_buffer_time_limit_ms=0 # type: int @@ -768,6 +772,7 @@ def Data( class DataChannelFactory(metaclass=abc.ABCMeta): """An abstract factory for creating ``DataChannel``.""" + @abc.abstractmethod def create_data_channel(self, remote_grpc_port): # type: (beam_fn_api_pb2.RemoteGrpcPort) -> GrpcClientDataChannel @@ -860,6 +865,7 @@ def close(self): class InMemoryDataChannelFactory(DataChannelFactory): """A singleton factory for ``InMemoryDataChannel``.""" + def __init__(self, in_memory_data_channel): # type: (GrpcClientDataChannel) -> None self._in_memory_data_channel = in_memory_data_channel diff --git a/sdks/python/apache_beam/runners/worker/data_plane_test.py b/sdks/python/apache_beam/runners/worker/data_plane_test.py index 5124bb69e6c8..b6e4e0fecf1e 100644 --- a/sdks/python/apache_beam/runners/worker/data_plane_test.py +++ b/sdks/python/apache_beam/runners/worker/data_plane_test.py @@ -34,6 +34,7 @@ class DataChannelTest(unittest.TestCase): + def test_grpc_data_channel(self): self._grpc_data_channel_test() diff --git a/sdks/python/apache_beam/runners/worker/data_sampler.py b/sdks/python/apache_beam/runners/worker/data_sampler.py index a0f02a51c8ad..fe28db802417 100644 --- a/sdks/python/apache_beam/runners/worker/data_sampler.py +++ b/sdks/python/apache_beam/runners/worker/data_sampler.py @@ -50,6 +50,7 @@ class SampleTimer: """Periodic timer for sampling elements.""" + def __init__(self, timeout_secs: float, sampler: OutputSampler) -> None: self._target_timeout_secs = timeout_secs self._timeout_secs = min(timeout_secs, 0.5) if timeout_secs > 0 else 0.0 @@ -111,6 +112,7 @@ class OutputSampler: This is configurable to only keep `max_samples` (see constructor) sampled elements in memory. Samples are taken every `sample_every_sec`. """ + def __init__( self, coder: Coder, @@ -155,9 +157,8 @@ def flush(self, clear: bool = True) -> List[beam_fn_api_pb2.SampledElement]: exceptions = [s for s in self._exceptions] samples = [s for s in self._samples if id(s) not in seen] else: - exceptions = [ - (self.remove_windowed_value(a), b) for a, b in self._exceptions - ] + exceptions = [(self.remove_windowed_value(a), b) + for a, b in self._exceptions] samples = [ self.remove_windowed_value(s) for s in self._samples if id(s) not in seen @@ -186,8 +187,7 @@ def flush(self, clear: bool = True) -> List[beam_fn_api_pb2.SampledElement]: exception=beam_fn_api_pb2.SampledElement.Exception( instruction_id=exn.instruction_id, transform_id=exn.transform_id, - error=exn.msg)) for s, - exn in exceptions) + error=exn.msg)) for s, exn in exceptions) except Exception as e: # pylint: disable=broad-except _LOGGER.warning('Could not encode sampled exception values: %s' % e) @@ -221,6 +221,7 @@ class DataSampler: Samples generated during execution can then be sampled with the `samples` method. This filters samples from the given pcollection ids. """ + def __init__( self, max_samples: int = 10, diff --git a/sdks/python/apache_beam/runners/worker/data_sampler_test.py b/sdks/python/apache_beam/runners/worker/data_sampler_test.py index 8c47315b7a9e..29b5498bebb2 100644 --- a/sdks/python/apache_beam/runners/worker/data_sampler_test.py +++ b/sdks/python/apache_beam/runners/worker/data_sampler_test.py @@ -40,6 +40,7 @@ class DataSamplerTest(unittest.TestCase): + def make_test_descriptor( self, outputs: Optional[List[str]] = None, @@ -443,6 +444,7 @@ def test_only_sample_exceptions(self): class OutputSamplerTest(unittest.TestCase): + def tearDown(self): self.sampler.stop() diff --git a/sdks/python/apache_beam/runners/worker/log_handler.py b/sdks/python/apache_beam/runners/worker/log_handler.py index 979c7cdb53be..69815acc7194 100644 --- a/sdks/python/apache_beam/runners/worker/log_handler.py +++ b/sdks/python/apache_beam/runners/worker/log_handler.py @@ -111,8 +111,8 @@ def map_log_level( return LOG_LEVEL_TO_LOGENTRY_MAP[level] except KeyError: return max( - beam_level for python_level, - beam_level in LOG_LEVEL_TO_LOGENTRY_MAP.items() + beam_level + for python_level, beam_level in LOG_LEVEL_TO_LOGENTRY_MAP.items() if python_level <= level) def emit(self, record: logging.LogRecord) -> None: diff --git a/sdks/python/apache_beam/runners/worker/log_handler_test.py b/sdks/python/apache_beam/runners/worker/log_handler_test.py index 2cf7dff9d57f..fb07459085a8 100644 --- a/sdks/python/apache_beam/runners/worker/log_handler_test.py +++ b/sdks/python/apache_beam/runners/worker/log_handler_test.py @@ -48,7 +48,9 @@ def create_exception_dofn( factory, transform_id, transform_proto, payload, consumers): """Returns a test DoFn that raises the given exception.""" + class RaiseException(beam.DoFn): + def __init__(self, msg): self.msg = msg.decode() @@ -65,7 +67,9 @@ def process(self, _): class TestOperation(operations.Operation): """Test operation that forwards its payload to consumers.""" + class Spec: + def __init__(self, transform_proto): self.output_coders = [ FastPrimitivesCoder() for _ in transform_proto.outputs @@ -115,6 +119,7 @@ def create_test_op(factory, transform_id, transform_proto, payload, consumers): class BeamFnLoggingServicer(beam_fn_api_pb2_grpc.BeamFnLoggingServicer): + def __init__(self): self.log_records_received = [] @@ -127,6 +132,7 @@ def Logging(self, request_iterator, context): class FnApiLogRecordHandlerTest(unittest.TestCase): + def setUp(self): self.test_logging_service = BeamFnLoggingServicer() self.server = grpc.server(thread_pool_executor.shared_unbounded_instance()) @@ -316,8 +322,7 @@ def test_extracts_transform_id_during_exceptions(self): def _create_test(name, num_logs): setattr( FnApiLogRecordHandlerTest, - 'test_%s' % name, - lambda self: self._verify_fn_log_handler(num_logs)) + 'test_%s' % name, lambda self: self._verify_fn_log_handler(num_logs)) for test_name, num_logs_entries in data.items(): diff --git a/sdks/python/apache_beam/runners/worker/logger.py b/sdks/python/apache_beam/runners/worker/logger.py index 06e2508fb7d2..ed03060e0c35 100644 --- a/sdks/python/apache_beam/runners/worker/logger.py +++ b/sdks/python/apache_beam/runners/worker/logger.py @@ -37,6 +37,7 @@ # context information that changes while work items get executed: # work_item_id, step_name, stage_name. class _PerThreadWorkerData(threading.local): + def __init__(self) -> None: super().__init__() # in the list, as going up and down all the way to zero incurs several @@ -66,6 +67,7 @@ def PerThreadLoggingContext(**kwargs: Any) -> Iterator[None]: class JsonLogFormatter(logging.Formatter): """A JSON formatter class as expected by the logging standard module.""" + def __init__(self, job_id: str, worker_id: str) -> None: super().__init__() self.job_id = job_id diff --git a/sdks/python/apache_beam/runners/worker/logger_test.py b/sdks/python/apache_beam/runners/worker/logger_test.py index 158fd3b5856b..0f77b2571724 100644 --- a/sdks/python/apache_beam/runners/worker/logger_test.py +++ b/sdks/python/apache_beam/runners/worker/logger_test.py @@ -31,6 +31,7 @@ class PerThreadLoggingContextTest(unittest.TestCase): + def thread_check_attribute(self, name): self.assertFalse(name in logger.per_thread_worker_data.get_data()) with logger.PerThreadLoggingContext(**{name: 'thread-value'}): @@ -95,7 +96,9 @@ class JsonLogFormatterTest(unittest.TestCase): } def create_log_record(self, **kwargs): + class Record(object): + def __init__(self, **kwargs): for k, v in kwargs.items(): setattr(self, k, v) diff --git a/sdks/python/apache_beam/runners/worker/opcounters.py b/sdks/python/apache_beam/runners/worker/opcounters.py index 5496bccd014e..4af6bb83bae3 100644 --- a/sdks/python/apache_beam/runners/worker/opcounters.py +++ b/sdks/python/apache_beam/runners/worker/opcounters.py @@ -47,6 +47,7 @@ class TransformIOCounter(object): Some examples of IO can be side inputs, shuffle, or streaming state. """ + def __init__(self, counter_factory, state_sampler): """Create a new IO read counter. @@ -93,6 +94,7 @@ def __exit__(self, exception_type, exception_value, traceback): class NoOpTransformIOCounter(TransformIOCounter): """All operations for IO tracking are no-ops.""" + def __init__(self): super().__init__(None, None) @@ -123,12 +125,12 @@ class SideInputReadCounter(TransformIOCounter): not be the only step that spends time reading from this side input. """ - def __init__(self, - counter_factory, - state_sampler, # type: StateSampler - declaring_step, - input_index - ): + def __init__( + self, + counter_factory, + state_sampler, # type: StateSampler + declaring_step, + input_index): """Create a side input read counter. Args: @@ -166,6 +168,7 @@ def _update_counters_for_requesting_step(self, step_name): class SumAccumulator(object): """Accumulator for collecting byte counts.""" + def __init__(self): self._value = 0 @@ -178,6 +181,7 @@ def value(self): class OperationCounters(object): """The set of basic counters to attach to an Operation.""" + def __init__( self, counter_factory, @@ -186,7 +190,7 @@ def __init__( index, suffix='out', producer_type_hints=None, - producer_batch_converter=None, # type: Optional[BatchConverter] + producer_batch_converter=None, # type: Optional[BatchConverter] ): self._counter_factory = counter_factory self.element_counter = counter_factory.get_counter( @@ -223,6 +227,7 @@ def update_from_batch(self, windowed_batch): self.mean_byte_counter.update_n(mean_element_size, batch_length) def _observable_callback(self, inner_coder_impl, accumulator): + def _observable_callback_inner(value, is_encoded=False): # TODO(ccy): If this stream is large, sample it as well. # To do this, we'll need to compute the average size of elements diff --git a/sdks/python/apache_beam/runners/worker/opcounters_test.py b/sdks/python/apache_beam/runners/worker/opcounters_test.py index 83b345f47c9c..249819f29fae 100644 --- a/sdks/python/apache_beam/runners/worker/opcounters_test.py +++ b/sdks/python/apache_beam/runners/worker/opcounters_test.py @@ -36,16 +36,19 @@ class OldClassThatDoesNotImplementLen(object): + def __init__(self): pass class ObjectThatDoesNotImplementLen(object): + def __init__(self): pass class TransformIoCounterTest(unittest.TestCase): + def test_basic_counters(self): counter_factory = CounterFactory() sampler = statesampler.StateSampler('stage1', counter_factory) @@ -91,6 +94,7 @@ def test_basic_counters(self): class OperationCountersTest(unittest.TestCase): + def verify_counters(self, opcounts, expected_elements, expected_size=None): self.assertEqual(expected_elements, opcounts.element_counter.value()) if expected_size is not None: diff --git a/sdks/python/apache_beam/runners/worker/operation_specs.py b/sdks/python/apache_beam/runners/worker/operation_specs.py index 1b86cdaae561..0823a1e37484 100644 --- a/sdks/python/apache_beam/runners/worker/operation_specs.py +++ b/sdks/python/apache_beam/runners/worker/operation_specs.py @@ -56,25 +56,14 @@ def worker_printable_fields(workerproto): '%s=%s' % (name, value) # _asdict is the only way and cannot subclass this generated class # pylint: disable=protected-access - for name, - value in workerproto._asdict().items() + for name, value in workerproto._asdict().items() # want to output value 0 but not None nor [] if (value or value == 0) and name not in ( - 'coder', - 'coders', - 'output_coders', - 'elements', - 'combine_fn', - 'serialized_fn', - 'window_fn', - 'append_trailing_newlines', - 'strip_trailing_newlines', - 'compression_type', - 'context', - 'start_shuffle_position', - 'end_shuffle_position', - 'shuffle_reader_config', - 'shuffle_writer_config') + 'coder', 'coders', 'output_coders', 'elements', 'combine_fn', + 'serialized_fn', 'window_fn', 'append_trailing_newlines', + 'strip_trailing_newlines', 'compression_type', 'context', + 'start_shuffle_position', 'end_shuffle_position', + 'shuffle_reader_config', 'shuffle_writer_config') ] diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py index 58c807c28dbd..b5983f2320a4 100644 --- a/sdks/python/apache_beam/runners/worker/operations.py +++ b/sdks/python/apache_beam/runners/worker/operations.py @@ -114,16 +114,18 @@ class ConsumerSet(Receiver): the other edge. ConsumerSet are attached to the outputting Operation. """ + @staticmethod - def create(counter_factory, - step_name, # type: str - output_index, - consumers, # type: List[Operation] - coder, - producer_type_hints, - producer_batch_converter, # type: Optional[BatchConverter] - output_sampler=None, # type: Optional[OutputSampler] - ): + def create( + counter_factory, + step_name, # type: str + output_index, + consumers, # type: List[Operation] + coder, + producer_type_hints, + producer_batch_converter, # type: Optional[BatchConverter] + output_sampler=None, # type: Optional[OutputSampler] + ): # type: (...) -> ConsumerSet if len(consumers) == 1: consumer = consumers[0] @@ -152,16 +154,16 @@ def create(counter_factory, producer_batch_converter, output_sampler) - def __init__(self, - counter_factory, - step_name, # type: str - output_index, - consumers, - coder, - producer_type_hints, - producer_batch_converter, - output_sampler - ): + def __init__( + self, + counter_factory, + step_name, # type: str + output_index, + consumers, + coder, + producer_type_hints, + producer_batch_converter, + output_sampler): self.opcounter = opcounters.OperationCounters( counter_factory, step_name, @@ -238,15 +240,16 @@ def __repr__(self): class SingletonElementConsumerSet(ConsumerSet): """ConsumerSet representing a single consumer that can only process elements (not batches).""" - def __init__(self, - counter_factory, - step_name, - output_index, - consumer, # type: Operation - coder, - producer_type_hints, - output_sampler - ): + + def __init__( + self, + counter_factory, + step_name, + output_index, + consumer, # type: Operation + coder, + producer_type_hints, + output_sampler): super().__init__( counter_factory, step_name, @@ -284,15 +287,16 @@ class GeneralPurposeConsumerSet(ConsumerSet): """ MAX_BATCH_SIZE = 4096 - def __init__(self, - counter_factory, - step_name, # type: str - output_index, - coder, - producer_type_hints, - consumers, # type: List[Operation] - producer_batch_converter, - output_sampler): + def __init__( + self, + counter_factory, + step_name, # type: str + output_index, + coder, + producer_type_hints, + consumers, # type: List[Operation] + producer_batch_converter, + output_sampler): super().__init__( counter_factory, step_name, @@ -415,12 +419,13 @@ class Operation(object): one or more receiver operations that will take that as input. """ - def __init__(self, - name_context, # type: common.NameContext - spec, - counter_factory, - state_sampler # type: StateSampler - ): + def __init__( + self, + name_context, # type: common.NameContext + spec, + counter_factory, + state_sampler # type: StateSampler + ): """Initializes a worker operation instance. Args: @@ -489,8 +494,8 @@ def get_output_sampler(output_num): coder, self._get_runtime_performance_hints(), self.get_output_batch_converter(), - get_output_sampler(i)) for i, - coder in enumerate(self.spec.output_coders) + get_output_sampler(i)) + for i, coder in enumerate(self.spec.output_coders) ] self.setup_done = True @@ -700,6 +705,7 @@ def _get_runtime_performance_hints(self): class ReadOperation(Operation): + def start(self): with self.scoped_start_state: super(ReadOperation, self).start() @@ -714,6 +720,7 @@ def start(self): class ImpulseReadOperation(Operation): + def __init__( self, name_context, # type: common.NameContext @@ -751,6 +758,7 @@ def process(self, unused_impulse): class InMemoryWriteOperation(Operation): """A write operation that will write to an in-memory sink.""" + def process(self, o): # type: (WindowedValue) -> None with self.scoped_process_state: @@ -761,6 +769,7 @@ def process(self, o): class _TaggedReceivers(dict): + def __init__(self, counter_factory, step_name): self._counter_factory = counter_factory self._step_name = step_name @@ -794,14 +803,15 @@ def total_output_bytes(self): class DoOperation(Operation): """A Do operation that will execute a custom DoFn for each input element.""" - def __init__(self, - name, # type: common.NameContext - spec, # operation_specs.WorkerDoFn # need to fix this type - counter_factory, - sampler, - side_input_maps=None, - user_state_context=None, - ): + def __init__( + self, + name, # type: common.NameContext + spec, # operation_specs.WorkerDoFn # need to fix this type + counter_factory, + sampler, + side_input_maps=None, + user_state_context=None, + ): super(DoOperation, self).__init__(name, spec, counter_factory, sampler) self.side_input_maps = side_input_maps self.user_state_context = user_state_context @@ -1039,6 +1049,7 @@ def _get_runtime_performance_hints(self): class SdfTruncateSizedRestrictions(DoOperation): + def __init__(self, *args, **kwargs): super(SdfTruncateSizedRestrictions, self).__init__(*args, **kwargs) @@ -1053,6 +1064,7 @@ def try_split( class SdfProcessSizedElements(DoOperation): + def __init__(self, *args, **kwargs): super(SdfProcessSizedElements, self).__init__(*args, **kwargs) self.lock = threading.RLock() @@ -1137,6 +1149,7 @@ def monitoring_infos(self, transform_id, tag_to_pcollection_id): class CombineOperation(Operation): """A Combine operation executing a CombineFn for each input element.""" + def __init__(self, name_context, spec, counter_factory, state_sampler): super(CombineOperation, self).__init__(name_context, spec, counter_factory, state_sampler) @@ -1189,6 +1202,7 @@ class PGBKOperation(Operation): (key, [value]) tuples, performing a best effort group-by-key for values in this bundle, memory permitting. """ + def __init__(self, name_context, spec, counter_factory, state_sampler): super(PGBKOperation, self).__init__(name_context, spec, counter_factory, state_sampler) @@ -1229,6 +1243,7 @@ def flush(self, target): class PGBKCVOperation(Operation): + def __init__( self, name_context, spec, counter_factory, state_sampler, windowing=None): super(PGBKCVOperation, @@ -1343,6 +1358,7 @@ class FlattenOperation(Operation): Receives one or more producer operations, outputs just one list with all the items. """ + def process(self, o): # type: (WindowedValue) -> None with self.scoped_process_state: @@ -1441,6 +1457,7 @@ class SimpleMapTaskExecutor(object): Stores progress of the read operation that is the first operation of a map task. """ + def __init__( self, map_task, diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index 3cb1a26b77f1..d5893394d3c4 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -108,6 +108,7 @@ class ShortIdCache(object): """ Cache for MonitoringInfo "short ids" """ + def __init__(self): # type: () -> None self._lock = threading.Lock() @@ -174,8 +175,8 @@ def __init__( data_sampler=None, # type: Optional[data_sampler.DataSampler] # Unrecoverable SDK harness initialization error (if any) # that should be reported to the runner when proocessing the first bundle. - deferred_exception=None, # type: Optional[Exception] - runner_capabilities=frozenset(), # type: FrozenSet[str] + deferred_exception=None, # type: Optional[Exception] + runner_capabilities=frozenset(), # type: FrozenSet[str] ): # type: (...) -> None self._alive = True @@ -361,8 +362,7 @@ def _request_harness_monitoring_infos(self, request): ).to_runner_api_monitoring_infos(None).values() self._execute( lambda: beam_fn_api_pb2.InstructionResponse( - instruction_id=request.instruction_id, - harness_monitoring_infos=( + instruction_id=request.instruction_id, harness_monitoring_infos=( beam_fn_api_pb2.HarnessMonitoringInfosResponse( monitoring_data={ SHORT_ID_CACHE.get_short_id(info): info.payload @@ -374,8 +374,8 @@ def _request_monitoring_infos(self, request): # type: (beam_fn_api_pb2.InstructionRequest) -> None self._execute( lambda: beam_fn_api_pb2.InstructionResponse( - instruction_id=request.instruction_id, - monitoring_infos=beam_fn_api_pb2.MonitoringInfosMetadataResponse( + instruction_id=request.instruction_id, monitoring_infos= + beam_fn_api_pb2.MonitoringInfosMetadataResponse( monitoring_info=SHORT_ID_CACHE.get_infos( request.monitoring_infos.monitoring_info_id))), request) @@ -641,6 +641,7 @@ def _shutdown_cached_bundle_processors(cached_bundle_processors): class SdkWorker(object): + def __init__( self, bundle_processor_cache, # type: BundleProcessorCache @@ -805,6 +806,7 @@ def maybe_profile(self, instruction_id): class StateHandler(metaclass=abc.ABCMeta): """An abstract object representing a ``StateHandler``.""" + @abc.abstractmethod def get_raw( self, @@ -875,6 +877,7 @@ def done(self): class StateHandlerFactory(metaclass=abc.ABCMeta): """An abstract factory for creating ``DataChannel``.""" + @abc.abstractmethod def create_state_handler(self, api_service_descriptor): # type: (endpoints_pb2.ApiServiceDescriptor) -> CachingStateHandler @@ -895,6 +898,7 @@ class GrpcStateHandlerFactory(StateHandlerFactory): Caches the created channels by ``state descriptor url``. """ + def __init__(self, state_cache, credentials=None, worker_id=None): # type: (StateCache, Optional[grpc.ChannelCredentials], Optional[str]) -> None self._state_handler_cache = {} # type: Dict[str, CachingStateHandler] @@ -946,6 +950,7 @@ def close(self): class CachingStateHandler(metaclass=abc.ABCMeta): + @abc.abstractmethod @contextlib.contextmanager def process_instruction_id(self, bundle_id, cache_tokens): @@ -984,6 +989,7 @@ def done(self): class ThrowingStateHandler(CachingStateHandler): """A caching state handler that errors on any requests.""" + @contextlib.contextmanager def process_instruction_id(self, bundle_id, cache_tokens): # type: (str, Iterable[beam_fn_api_pb2.ProcessBundleRequest.CacheToken]) -> Iterator[None] @@ -1161,6 +1167,7 @@ class GlobalCachingStateHandler(CachingStateHandler): If activated but no cache token is supplied, caching is done at the bundle level. """ + def __init__( self, global_state_cache, # type: StateCache @@ -1297,10 +1304,11 @@ def _lazy_iterator( if not continuation_token: break - def _get_raw(self, + def _get_raw( + self, state_key, # type: beam_fn_api_pb2.StateKey continuation_token # type: Optional[bytes] - ): + ): # type: (...) -> Tuple[coder_impl.create_InputStream, Optional[bytes]] """Call underlying get_raw with performance statistics and detection.""" @@ -1369,6 +1377,7 @@ def _partially_cached_iterable( self._lazy_iterator, state_key, coder, continuation_token)) class ContinuationIterable(Generic[T], CacheAware): + def __init__(self, head, continue_iterator_fn): # type: (Iterable[T], Callable[[], Iterable[T]]) -> None self.head = head @@ -1397,6 +1406,7 @@ def _convert_to_cache_key(state_key): class _Future(Generic[T]): """A simple future object to implement blocking requests. """ + def __init__(self): # type: () -> None self._event = threading.Event() @@ -1429,6 +1439,7 @@ def done(cls): class _DeferredCall(_Future[T]): + def __init__(self, func, *args): # type: (Callable[..., Any], *Any) -> None self._func = func diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py index 3389f0c7afb1..b3c81fd93467 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py @@ -232,8 +232,7 @@ def _load_pipeline_options(options_json): return { re.match(portable_option_regex, k).group('key') if re.match( portable_option_regex, k) else k: v - for k, - v in options.items() + for k, v in options.items() } diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main_test.py index 498a07b70e9e..14c11ba6f301 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker_main_test.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main_test.py @@ -39,6 +39,7 @@ class SdkWorkerMainTest(unittest.TestCase): # Used for testing newly added flags. class MockOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_argument('--eam:option:m_option:v', help='mock option') @@ -108,6 +109,7 @@ def test_create_sdk_harness_log_handler_received_log(self): logstream = io.StringIO() class InMemoryHandler(logging.StreamHandler): + def __init__(self, *unused): super().__init__(stream=logstream) @@ -160,7 +162,7 @@ def test__get_log_level_from_options_dict(self): def test__set_log_level_overrides(self): test_cases = [ - ([], {}), # not provided, as a smoke test + ([], {}), # not provided, as a smoke test ( # single overrides ['{"fake_module_1a.b":"DEBUG","fake_module_1c.d":"INFO"}'], @@ -168,8 +170,7 @@ def test__set_log_level_overrides(self): "fake_module_1a.b": logging.DEBUG, "fake_module_1a.b.f": logging.DEBUG, "fake_module_1c.d": logging.INFO - } - ), + }), ( # multiple overrides, the last takes precedence [ @@ -183,8 +184,7 @@ def test__set_log_level_overrides(self): "fake_module_2c.d": logging.ERROR, "fake_module_2c.d.e": 15, "fake_module_2c.d.f": logging.ERROR - } - ) + }) ] for case, expected in test_cases: overrides = self._overrides_case_to_option_dict(case) diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py index 17bf043d020c..bc35335ffd20 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py @@ -48,6 +48,7 @@ class BeamFnControlServicer(beam_fn_api_pb2_grpc.BeamFnControlServicer): + def __init__(self, requests, raise_errors=True): self.requests = requests self.instruction_ids = set(r.instruction_id for r in requests) @@ -75,6 +76,7 @@ def Control(self, response_iterator, context): class SdkWorkerTest(unittest.TestCase): + def _get_process_bundles(self, prefix, size): return [ beam_fn_api_pb2.ProcessBundleDescriptor( @@ -281,6 +283,7 @@ def test_data_sampling_response(self): coder = FastPrimitivesCoder() class FakeDataSampler: + def samples(self, pcollection_ids): return beam_fn_api_pb2.SampleDataResponse( element_samples={ @@ -332,6 +335,7 @@ def stop(self): class CachingStateHandlerTest(unittest.TestCase): + def test_caching(self): coder = VarIntCoder() @@ -340,6 +344,7 @@ def test_caching(self): class FakeUnderlyingState(object): """Simply returns an incremented counter as the state "value." """ + def set_counter(self, n): self._counter = n @@ -442,6 +447,7 @@ def get_as_list(key): class UnderlyingStateHandler(object): """Simply returns an incremented counter as the state "value." """ + def __init__(self): self._encoded_values = [] self._continuations = False @@ -569,82 +575,66 @@ def clear(): class ShortIdCacheTest(unittest.TestCase): + def testShortIdAssignment(self): TestCase = namedtuple('TestCase', ['expected_short_id', 'info']) test_cases = [ TestCase(*args) for args in [ ( - "1", - metrics_pb2.MonitoringInfo( + "1", metrics_pb2.MonitoringInfo( urn="beam:metric:user:distribution_int64:v1", type="beam:metrics:distribution_int64:v1")), ( - "2", - metrics_pb2.MonitoringInfo( + "2", metrics_pb2.MonitoringInfo( urn="beam:metric:element_count:v1", type="beam:metrics:sum_int64:v1")), ( - "3", - metrics_pb2.MonitoringInfo( + "3", metrics_pb2.MonitoringInfo( urn="beam:metric:ptransform_progress:completed:v1", type="beam:metrics:progress:v1")), ( - "4", - metrics_pb2.MonitoringInfo( + "4", metrics_pb2.MonitoringInfo( urn="beam:metric:user:distribution_double:v1", type="beam:metrics:distribution_double:v1")), ( - "5", - metrics_pb2.MonitoringInfo( + "5", metrics_pb2.MonitoringInfo( urn="TestingSentinelUrn", type="TestingSentinelType")), ( - "6", - metrics_pb2.MonitoringInfo( + "6", metrics_pb2.MonitoringInfo( urn= "beam:metric:pardo_execution_time:finish_bundle_msecs:v1", type="beam:metrics:sum_int64:v1")), # This case and the next one validates that different labels # with the same urn are in fact assigned different short ids. ( - "7", - metrics_pb2.MonitoringInfo( + "7", metrics_pb2.MonitoringInfo( urn="beam:metric:user:sum_int64:v1", - type="beam:metrics:sum_int64:v1", - labels={ - "PTRANSFORM": "myT", - "NAMESPACE": "harness", + type="beam:metrics:sum_int64:v1", labels={ + "PTRANSFORM": "myT", "NAMESPACE": "harness", "NAME": "metricNumber7" })), ( - "8", - metrics_pb2.MonitoringInfo( + "8", metrics_pb2.MonitoringInfo( urn="beam:metric:user:sum_int64:v1", - type="beam:metrics:sum_int64:v1", - labels={ - "PTRANSFORM": "myT", - "NAMESPACE": "harness", + type="beam:metrics:sum_int64:v1", labels={ + "PTRANSFORM": "myT", "NAMESPACE": "harness", "NAME": "metricNumber8" })), ( - "9", - metrics_pb2.MonitoringInfo( + "9", metrics_pb2.MonitoringInfo( urn="beam:metric:user:top_n_double:v1", - type="beam:metrics:top_n_double:v1", - labels={ - "PTRANSFORM": "myT", - "NAMESPACE": "harness", + type="beam:metrics:top_n_double:v1", labels={ + "PTRANSFORM": "myT", "NAMESPACE": "harness", "NAME": "metricNumber7" })), ( - "a", - metrics_pb2.MonitoringInfo( + "a", metrics_pb2.MonitoringInfo( urn="beam:metric:element_count:v1", type="beam:metrics:sum_int64:v1", labels={"PCOLLECTION": "myPCol"})), # validate payload is ignored for shortId assignment ( - "3", - metrics_pb2.MonitoringInfo( + "3", metrics_pb2.MonitoringInfo( urn="beam:metric:ptransform_progress:completed:v1", type="beam:metrics:progress:v1", payload=b"this is ignored!")) @@ -680,8 +670,8 @@ def testShortIdAssignment(self): def monitoringInfoMetadata(info): return { descriptor.name: value - for descriptor, - value in info.ListFields() if not descriptor.name == "payload" + for descriptor, value in info.ListFields() + if not descriptor.name == "payload" } diff --git a/sdks/python/apache_beam/runners/worker/sideinputs.py b/sdks/python/apache_beam/runners/worker/sideinputs.py index ea09aafba317..9644f6c14cf1 100644 --- a/sdks/python/apache_beam/runners/worker/sideinputs.py +++ b/sdks/python/apache_beam/runners/worker/sideinputs.py @@ -53,6 +53,7 @@ class PrefetchingSourceSetIterable(object): """Value iterator that reads concurrently from a set of sources.""" + def __init__( self, sources, @@ -87,6 +88,7 @@ def add_byte_counter(self, reader): reader: A reader that should inherit from ObservableMixin to have bytes tracked. """ + def update_bytes_read(record_size, is_record_size=False, **kwargs): # Let the reader report block size. if is_record_size: @@ -194,6 +196,7 @@ def get_iterator_fn_for_sources( read_counter=None, element_counter=None): """Returns callable that returns iterator over elements for given sources.""" + def _inner(): return iter( PrefetchingSourceSetIterable( @@ -207,6 +210,7 @@ def _inner(): class EmulatedIterable(abc.Iterable): """Emulates an iterable for a side input.""" + def __init__(self, iterator_fn): self.iterator_fn = iterator_fn diff --git a/sdks/python/apache_beam/runners/worker/sideinputs_test.py b/sdks/python/apache_beam/runners/worker/sideinputs_test.py index 2e89b866986f..630b84d39768 100644 --- a/sdks/python/apache_beam/runners/worker/sideinputs_test.py +++ b/sdks/python/apache_beam/runners/worker/sideinputs_test.py @@ -34,6 +34,7 @@ def strip_windows(iterator): class FakeSource(object): + def __init__(self, items, notify_observers=False): self.items = items self._should_notify_observers = notify_observers @@ -43,6 +44,7 @@ def reader(self): class FakeSourceReader(observable.ObservableMixin): + def __init__(self, items, notify_observers=False): super().__init__() self.items = items @@ -69,6 +71,7 @@ def returns_windowed_values(self): class PrefetchingSourceIteratorTest(unittest.TestCase): + def test_single_source_iterator_fn(self): sources = [ FakeSource([0, 1, 2, 3, 4, 5]), @@ -111,6 +114,7 @@ def test_multiple_sources_single_reader_iterator_fn(self): assert list(strip_windows(iterator_fn())) == list(range(11)) def test_source_iterator_single_source_exception(self): + class MyException(Exception): pass @@ -129,6 +133,7 @@ def exception_generator(): self.assertEqual(sorted(seen), [0]) def test_source_iterator_fn_exception(self): + class MyException(Exception): pass @@ -158,7 +163,9 @@ def perpetual_generator(value): class EmulatedCollectionsTest(unittest.TestCase): + def test_emulated_iterable(self): + def _iterable_fn(): for i in range(10): yield i diff --git a/sdks/python/apache_beam/runners/worker/statecache.py b/sdks/python/apache_beam/runners/worker/statecache.py index d4e61cc9297f..0d07f91ac3c8 100644 --- a/sdks/python/apache_beam/runners/worker/statecache.py +++ b/sdks/python/apache_beam/runners/worker/statecache.py @@ -58,6 +58,7 @@ class WeightedValue(object): :arg weight The associated weight of the value. If unspecified, the objects size will be used. """ + def __init__(self, value: Any, weight: int) -> None: self._value = value if weight <= 0: @@ -74,6 +75,7 @@ def value(self) -> Any: class CacheAware(object): """Allows cache users to override what objects are measured.""" + def __init__(self) -> None: pass @@ -168,6 +170,7 @@ def get_deep_size(*objs: Any) -> int: class _LoadingValue(WeightedValue): """Allows concurrent users of the cache to wait for a value to be loaded.""" + def __init__(self) -> None: super().__init__(None, 1) self._wait_event = threading.Event() @@ -210,6 +213,7 @@ class StateCache(object): :arg max_weight The maximum weight of entries to store in the cache in bytes. """ + def __init__(self, max_weight: int) -> None: _LOGGER.info('Creating state cache with size %s', max_weight) self._max_weight = max_weight diff --git a/sdks/python/apache_beam/runners/worker/statecache_test.py b/sdks/python/apache_beam/runners/worker/statecache_test.py index a5d1ff2e01e3..ac951218eb0a 100644 --- a/sdks/python/apache_beam/runners/worker/statecache_test.py +++ b/sdks/python/apache_beam/runners/worker/statecache_test.py @@ -38,10 +38,12 @@ class StateCacheTest(unittest.TestCase): + def test_weakref(self): test_value = WeightedValue('test', 10 << 20) class WeightedValueRef(): + def __init__(self): self.ref = weakref.ref(test_value) @@ -69,6 +71,7 @@ def test_weakref_proxy(self): test_value = WeightedValue('test', 10 << 20) class WeightedValueRef(): + def __init__(self): self.ref = weakref.ref(test_value) @@ -93,7 +96,9 @@ def __init__(self): cache.put('deleted', o_ref) def test_size_of_fails(self): + class BadSizeOf(object): + def __sizeof__(self): raise RuntimeError("TestRuntimeError") @@ -219,6 +224,7 @@ def test_lru(self): 'avg load time 0 ns, loads 0, evictions 5')) def test_get(self): + def check_key(key): self.assertEqual(key, "key") time.sleep(0.5) @@ -269,11 +275,13 @@ def load_key(output): output["time"] = time.time_ns() t1_output = {} - t1 = threading.Thread(target=load_key, args=(t1_output, )) + t1 = threading.Thread( + target=load_key, args=(t1_output, )) t1.start() t2_output = {} - t2 = threading.Thread(target=load_key, args=(t2_output, )) + t2 = threading.Thread( + target=load_key, args=(t2_output, )) t2.start() # Wait for both threads to start @@ -311,7 +319,8 @@ def load_key(output): output["value"] = cache.get("key", wait_for_event) t1_output = {} - t1 = threading.Thread(target=load_key, args=(t1_output, )) + t1 = threading.Thread( + target=load_key, args=(t1_output, )) t1.start() # Wait for the load to start, update the key, and then let the load finish @@ -343,7 +352,9 @@ def test_is_cached_enabled(self): 'avg load time 0 ns, loads 0, evictions 0')) def test_get_referents_for_cache(self): + class GetReferentsForCache(CacheAware): + def __init__(self): self.measure_me = bytearray(1 << 20) self.ignore_me = bytearray(2 << 20) @@ -366,19 +377,21 @@ def test_get_deep_size_builtin_objects(self): built-in objects. """ primitive_test_objects = [ - 1, # int - 2.0, # float - 1+1j, # complex - True, # bool - 'hello,world', # str - b'\00\01\02', # bytes + 1, # int + 2.0, # float + 1 + 1j, # complex + True, # bool + 'hello,world', # str + b'\00\01\02', # bytes ] collection_test_objects = [ - [3, 4, 5], # list - (6, 7), # tuple - {'a', 'b', 'c'}, # set - {'k': 8, 'l': 9}, # dict + [3, 4, 5], # list + (6, 7), # tuple + {'a', 'b', 'c'}, # set + { + 'k': 8, 'l': 9 + }, # dict ] for obj in primitive_test_objects: diff --git a/sdks/python/apache_beam/runners/worker/statesampler.py b/sdks/python/apache_beam/runners/worker/statesampler.py index b9c75f4de93d..5321c1495b15 100644 --- a/sdks/python/apache_beam/runners/worker/statesampler.py +++ b/sdks/python/apache_beam/runners/worker/statesampler.py @@ -89,6 +89,7 @@ def for_test(): class StateSampler(statesampler_impl.StateSampler): + def __init__( self, prefix: str, diff --git a/sdks/python/apache_beam/runners/worker/statesampler_slow.py b/sdks/python/apache_beam/runners/worker/statesampler_slow.py index be801284450a..5452da6dbaee 100644 --- a/sdks/python/apache_beam/runners/worker/statesampler_slow.py +++ b/sdks/python/apache_beam/runners/worker/statesampler_slow.py @@ -24,6 +24,7 @@ class StateSampler(object): + def __init__(self, sampling_period_ms): self._state_stack = [ ScopedState(self, counters.CounterName('unknown'), None) @@ -73,6 +74,7 @@ def reset(self) -> None: class ScopedState(object): + def __init__( self, sampler: StateSampler, diff --git a/sdks/python/apache_beam/runners/worker/worker_id_interceptor_test.py b/sdks/python/apache_beam/runners/worker/worker_id_interceptor_test.py index 0db9c1b4ddc0..ba0168eb31c0 100644 --- a/sdks/python/apache_beam/runners/worker/worker_id_interceptor_test.py +++ b/sdks/python/apache_beam/runners/worker/worker_id_interceptor_test.py @@ -34,6 +34,7 @@ class _ClientCallDetails(collections.namedtuple( class WorkerIdInterceptorTest(unittest.TestCase): + def test_worker_id_insertion(self): worker_id_key = 'worker_id' headers_holder = {} diff --git a/sdks/python/apache_beam/runners/worker/worker_pool_main.py b/sdks/python/apache_beam/runners/worker/worker_pool_main.py index 307261c2d3c3..62aed60f2d58 100644 --- a/sdks/python/apache_beam/runners/worker/worker_pool_main.py +++ b/sdks/python/apache_beam/runners/worker/worker_pool_main.py @@ -57,6 +57,7 @@ def kill_process_gracefully(proc, timeout=10): it to finish. A SIGKILL will be sent if the process has not finished after ``timeout`` seconds. """ + def _kill(): proc.terminate() try: @@ -73,6 +74,7 @@ def _kill(): class BeamFnExternalWorkerPoolServicer( beam_fn_api_pb2_grpc.BeamFnExternalWorkerPoolServicer): + def __init__( self, use_process=False, diff --git a/sdks/python/apache_beam/runners/worker/worker_status.py b/sdks/python/apache_beam/runners/worker/worker_status.py index ecd4dc4e02c0..a2412c792f3a 100644 --- a/sdks/python/apache_beam/runners/worker/worker_status.py +++ b/sdks/python/apache_beam/runners/worker/worker_status.py @@ -145,6 +145,7 @@ def _active_processing_bundles_state(bundle_process_cache): class FnApiWorkerStatusHandler(object): """FnApiWorkerStatusHandler handles worker status request from Runner. """ + def __init__( self, status_address, @@ -276,8 +277,8 @@ def _get_stack_trace(self, sampler_info): return '-NOT AVAILABLE-' def _passed_lull_timeout_since_last_log(self) -> bool: - if (time.time() - self._last_lull_logged_secs > - self.log_lull_timeout_ns / 1e9): + if (time.time() - self._last_lull_logged_secs + > self.log_lull_timeout_ns / 1e9): self._last_lull_logged_secs = time.time() return True else: diff --git a/sdks/python/apache_beam/runners/worker/worker_status_test.py b/sdks/python/apache_beam/runners/worker/worker_status_test.py index 1004d21e7fd3..5ddc4ba6a7d5 100644 --- a/sdks/python/apache_beam/runners/worker/worker_status_test.py +++ b/sdks/python/apache_beam/runners/worker/worker_status_test.py @@ -33,6 +33,7 @@ class BeamFnStatusServicer(beam_fn_api_pb2_grpc.BeamFnWorkerStatusServicer): + def __init__(self, num_request): self.finished = threading.Condition() self.num_request = num_request @@ -50,6 +51,7 @@ def WorkerStatus(self, response_iterator, context): class FnApiWorkerStatusHandlerTest(unittest.TestCase): + def setUp(self): self.num_request = 3 self.test_status_service = BeamFnStatusServicer(self.num_request) @@ -87,6 +89,7 @@ def test_generate_error(self, mock_method): self.fn_status_handler.close() def test_log_lull_in_bundle_processor(self): + def get_state_sampler_info_for_lull(lull_duration_s): return "bundle-id", statesampler.StateSamplerInfo( CounterName('progress-msecs', 'stage_name', 'step_name'), @@ -128,6 +131,7 @@ def get_state_sampler_info_for_lull(lull_duration_s): class HeapDumpTest(unittest.TestCase): + @mock.patch('apache_beam.runners.worker.worker_status.hpy', None) def test_skip_heap_dump(self): result = '%s' % heap_dump() diff --git a/sdks/python/apache_beam/testing/analyzers/load_test_perf_analysis.py b/sdks/python/apache_beam/testing/analyzers/load_test_perf_analysis.py index ee9d04e6260f..01a845b2c6b0 100644 --- a/sdks/python/apache_beam/testing/analyzers/load_test_perf_analysis.py +++ b/sdks/python/apache_beam/testing/analyzers/load_test_perf_analysis.py @@ -35,6 +35,7 @@ class LoadTestMetricsFetcher(perf_analysis_utils.MetricsFetcher): are fetched and returned as a dataclass containing lists of timestamps and metric_values. """ + def fetch_metric_data( self, *, test_config: TestConfigContainer) -> MetricContainer: if test_config.test_name: diff --git a/sdks/python/apache_beam/testing/analyzers/perf_analysis.py b/sdks/python/apache_beam/testing/analyzers/perf_analysis.py index a12b06c8c3eb..93b9fb342bc0 100644 --- a/sdks/python/apache_beam/testing/analyzers/perf_analysis.py +++ b/sdks/python/apache_beam/testing/analyzers/perf_analysis.py @@ -68,7 +68,9 @@ def get_test_config_container( ) -def get_change_point_config(params: Dict[str, Any], ) -> ChangePointConfig: +def get_change_point_config( + params: Dict[str, Any], +) -> ChangePointConfig: """ Args: params: Dict containing parameters to run change point analysis. diff --git a/sdks/python/apache_beam/testing/analyzers/perf_analysis_test.py b/sdks/python/apache_beam/testing/analyzers/perf_analysis_test.py index 5dbeba74b7e9..5875208634ed 100644 --- a/sdks/python/apache_beam/testing/analyzers/perf_analysis_test.py +++ b/sdks/python/apache_beam/testing/analyzers/perf_analysis_test.py @@ -73,6 +73,7 @@ def get_existing_issue_data(**kwargs): class TestChangePointAnalysis(unittest.TestCase): + def setUp(self) -> None: self.single_change_point_series = [0] * 10 + [1] * 10 self.multiple_change_point_series = self.single_change_point_series + [ @@ -248,6 +249,7 @@ def test_change_point_has_anomaly_marker_in_gh_description(self): self.assertTrue(match) def test_change_point_on_noisy_data(self): + def read_csv(path): with FileSystems.open(path) as fp: return pd.read_csv(fp) diff --git a/sdks/python/apache_beam/testing/analyzers/perf_analysis_utils.py b/sdks/python/apache_beam/testing/analyzers/perf_analysis_utils.py index a9015d715e90..85c10063754b 100644 --- a/sdks/python/apache_beam/testing/analyzers/perf_analysis_utils.py +++ b/sdks/python/apache_beam/testing/analyzers/perf_analysis_utils.py @@ -330,6 +330,7 @@ def is_edge_change_point( class MetricsFetcher(metaclass=abc.ABCMeta): + @abc.abstractmethod def fetch_metric_data( self, *, test_config: TestConfigContainer) -> MetricContainer: @@ -341,6 +342,7 @@ def fetch_metric_data( class BigQueryMetricsFetcher(MetricsFetcher): + def fetch_metric_data( self, *, test_config: TestConfigContainer) -> MetricContainer: """ diff --git a/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/preprocess.py b/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/preprocess.py index 2016a2c97658..aa1da415d505 100644 --- a/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/preprocess.py +++ b/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/preprocess.py @@ -81,6 +81,7 @@ def transform_data( pipeline_args: additional DataflowRunner or DirectRunner args passed to the beam pipeline. """ + def preprocessing_fn(inputs): """tf.transform's callback function for preprocessing inputs. diff --git a/sdks/python/apache_beam/testing/benchmarks/cloudml/cloudml_benchmark_test.py b/sdks/python/apache_beam/testing/benchmarks/cloudml/cloudml_benchmark_test.py index ea2d93512fd4..e8b4031325d7 100644 --- a/sdks/python/apache_beam/testing/benchmarks/cloudml/cloudml_benchmark_test.py +++ b/sdks/python/apache_beam/testing/benchmarks/cloudml/cloudml_benchmark_test.py @@ -57,6 +57,7 @@ def _publish_metrics(pipeline, metric_value, metrics_table, metric_name): @pytest.mark.uses_tft class CloudMLTFTBenchmarkTest(unittest.TestCase): + def test_cloudml_benchmark_criteo_small(self): test_pipeline = TestPipeline(is_integration_test=True) extra_opts = {} diff --git a/sdks/python/apache_beam/testing/benchmarks/cloudml/criteo_tft/criteo.py b/sdks/python/apache_beam/testing/benchmarks/cloudml/criteo_tft/criteo.py index d2a0b652ca69..f08920526b75 100644 --- a/sdks/python/apache_beam/testing/benchmarks/cloudml/criteo_tft/criteo.py +++ b/sdks/python/apache_beam/testing/benchmarks/cloudml/criteo_tft/criteo.py @@ -120,6 +120,7 @@ def make_preprocessing_fn(frequency_threshold): Returns: A preprocessing function. """ + def preprocessing_fn(inputs): """User defined preprocessing function for criteo columns. diff --git a/sdks/python/apache_beam/testing/benchmarks/cloudml/pipelines/workflow.py b/sdks/python/apache_beam/testing/benchmarks/cloudml/pipelines/workflow.py index e60e3a47c0d1..82dcdbf287d0 100644 --- a/sdks/python/apache_beam/testing/benchmarks/cloudml/pipelines/workflow.py +++ b/sdks/python/apache_beam/testing/benchmarks/cloudml/pipelines/workflow.py @@ -34,16 +34,17 @@ class _RecordBatchToPyDict(beam.PTransform): """Converts PCollections of pa.RecordBatch to python dicts.""" + def __init__(self, input_feature_spec): self._input_feature_spec = input_feature_spec def expand(self, pcoll): + def format_values(instance): return { k: v.squeeze(0).tolist() if v is not None else self._input_feature_spec[k].default_value - for k, - v in instance.items() + for k, v in instance.items() } return ( @@ -70,6 +71,7 @@ def _synthetic_preprocessing_fn(inputs): class _PredictionHistogramFn(beam.DoFn): + def __init__(self): # Beam Metrics API for Distributions only works with integers but # predictions are floating point numbers. We thus store a "quantized" diff --git a/sdks/python/apache_beam/testing/benchmarks/inference/pytorch_image_classification_benchmarks.py b/sdks/python/apache_beam/testing/benchmarks/inference/pytorch_image_classification_benchmarks.py index 514c9d672850..6979f215d2f2 100644 --- a/sdks/python/apache_beam/testing/benchmarks/inference/pytorch_image_classification_benchmarks.py +++ b/sdks/python/apache_beam/testing/benchmarks/inference/pytorch_image_classification_benchmarks.py @@ -27,6 +27,7 @@ class PytorchVisionBenchmarkTest(LoadTest): + def __init__(self): # TODO (https://github.com/apache/beam/issues/23008) # make get_namespace() method in RunInference static diff --git a/sdks/python/apache_beam/testing/benchmarks/inference/pytorch_language_modeling_benchmarks.py b/sdks/python/apache_beam/testing/benchmarks/inference/pytorch_language_modeling_benchmarks.py index 1d6ecb2bd438..41375b8b1125 100644 --- a/sdks/python/apache_beam/testing/benchmarks/inference/pytorch_language_modeling_benchmarks.py +++ b/sdks/python/apache_beam/testing/benchmarks/inference/pytorch_language_modeling_benchmarks.py @@ -23,6 +23,7 @@ class PytorchLanguageModelingBenchmarkTest(LoadTest): + def __init__(self): # TODO (https://github.com/apache/beam/issues/23008): # make get_namespace() method in RunInference static diff --git a/sdks/python/apache_beam/testing/benchmarks/inference/tensorflow_mnist_classification_cost_benchmark.py b/sdks/python/apache_beam/testing/benchmarks/inference/tensorflow_mnist_classification_cost_benchmark.py index 223b973e5fbe..90a827bc3daf 100644 --- a/sdks/python/apache_beam/testing/benchmarks/inference/tensorflow_mnist_classification_cost_benchmark.py +++ b/sdks/python/apache_beam/testing/benchmarks/inference/tensorflow_mnist_classification_cost_benchmark.py @@ -23,6 +23,7 @@ class TensorflowMNISTClassificationCostBenchmark(DataflowCostBenchmark): + def __init__(self): super().__init__() diff --git a/sdks/python/apache_beam/testing/benchmarks/nexmark/models/auction_bid.py b/sdks/python/apache_beam/testing/benchmarks/nexmark/models/auction_bid.py index 7424a3a48355..2b3dfa407b6b 100644 --- a/sdks/python/apache_beam/testing/benchmarks/nexmark/models/auction_bid.py +++ b/sdks/python/apache_beam/testing/benchmarks/nexmark/models/auction_bid.py @@ -23,6 +23,7 @@ class AuctionBidCoder(FastCoder): + def to_type_hint(self): return AuctionBid diff --git a/sdks/python/apache_beam/testing/benchmarks/nexmark/models/nexmark_model.py b/sdks/python/apache_beam/testing/benchmarks/nexmark/models/nexmark_model.py index 4613d7f90c26..499d156b590d 100644 --- a/sdks/python/apache_beam/testing/benchmarks/nexmark/models/nexmark_model.py +++ b/sdks/python/apache_beam/testing/benchmarks/nexmark/models/nexmark_model.py @@ -33,6 +33,7 @@ class PersonCoder(FastCoder): + def to_type_hint(self): return Person @@ -63,6 +64,7 @@ def __repr__(self): class AuctionCoder(FastCoder): + def to_type_hint(self): return Auction @@ -105,6 +107,7 @@ def __repr__(self): class BidCoder(FastCoder): + def to_type_hint(self): return Bid diff --git a/sdks/python/apache_beam/testing/benchmarks/nexmark/monitor.py b/sdks/python/apache_beam/testing/benchmarks/nexmark/monitor.py index 064fbb11da5d..6ee79fc32175 100644 --- a/sdks/python/apache_beam/testing/benchmarks/nexmark/monitor.py +++ b/sdks/python/apache_beam/testing/benchmarks/nexmark/monitor.py @@ -32,6 +32,7 @@ class Monitor(object): name_prefix: a prefix for this Monitor's metrics' names, intended to be unique in per-monitor basis in pipeline """ + def __init__(self, namespace: str, name_prefix: str) -> None: self.namespace = namespace self.name_prefix = name_prefix @@ -39,6 +40,7 @@ def __init__(self, namespace: str, name_prefix: str) -> None: class MonitorDoFn(beam.DoFn): + def __init__(self, namespace, prefix): self.element_count = Metrics.counter( namespace, prefix + MonitorSuffix.ELEMENT_COUNTER) diff --git a/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_perf.py b/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_perf.py index c29825f95f3e..24d24676205f 100644 --- a/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_perf.py +++ b/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_perf.py @@ -21,6 +21,7 @@ class NexmarkPerf(object): + def __init__( self, runtime_sec=None, diff --git a/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_util.py b/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_util.py index ef53156d8be0..9a6f42151205 100644 --- a/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_util.py +++ b/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_util.py @@ -51,11 +51,13 @@ class Command(object): + def __init__(self, cmd, args): self.cmd = cmd self.args = args def run(self, timeout): + def thread_target(): logging.debug( 'Starting thread for %d seconds: %s', timeout, self.cmd.__name__) @@ -102,6 +104,7 @@ class ParseEventFn(beam.DoFn): 1528098831536,20180630,maria,vehicle' 'b12345,maria,20000,1528098831536' """ + def process(self, elem): model_dict = { 'p': nexmark_model.Person, @@ -142,6 +145,7 @@ class ParseJsonEventFn(beam.DoFn): {"auction":1000,"bidder":1001,"price":32530001,"dateTime":1528098831066,\ "extra":"fdiysaV^]NLVsbolvyqwgticfdrwdyiyofWPYTOuwogvszlxjrcNOORM"} """ + def process(self, elem): json_dict = json.loads(elem) if type(json_dict[FieldNames.DATE_TIME]) is dict: @@ -183,6 +187,7 @@ def process(self, elem): class CountAndLog(beam.PTransform): + def expand(self, pcoll): return ( pcoll diff --git a/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/nexmark_query_util.py b/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/nexmark_query_util.py index 85bece4083c6..b3e14b8f6d4c 100644 --- a/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/nexmark_query_util.py +++ b/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/nexmark_query_util.py @@ -57,35 +57,42 @@ def auction_or_bid(event): class JustBids(beam.PTransform): + def expand(self, pcoll): return pcoll | "IsBid" >> beam.Filter(is_bid) class JustAuctions(beam.PTransform): + def expand(self, pcoll): return pcoll | "IsAuction" >> beam.Filter(is_auction) class JustPerson(beam.PTransform): + def expand(self, pcoll): return pcoll | "IsPerson" >> beam.Filter(is_person) class AuctionByIdFn(beam.DoFn): + def process(self, element): yield element.id, element class BidByAuctionIdFn(beam.DoFn): + def process(self, element): yield element.auction, element class PersonByIdFn(beam.DoFn): + def process(self, element): yield element.id, element class AuctionBySellerFn(beam.DoFn): + def process(self, element): yield element.seller, element diff --git a/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query0.py b/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query0.py index 904e1d208dc0..091975c65dcc 100644 --- a/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query0.py +++ b/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query0.py @@ -31,6 +31,7 @@ class RoundTripFn(beam.DoFn): + def process(self, element): coder = element.CODER byte_value = coder.encode(element) diff --git a/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query1.py b/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query1.py index 2173a93c2abe..5588bc3cb1eb 100644 --- a/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query1.py +++ b/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query1.py @@ -38,8 +38,5 @@ def load(events, metadata=None, pipeline_options=None): | nexmark_query_util.JustBids() | 'ConvertToEuro' >> beam.Map( lambda bid: nexmark_model.Bid( - bid.auction, - bid.bidder, - bid.price * USD_TO_EURO, - bid.date_time, + bid.auction, bid.bidder, bid.price * USD_TO_EURO, bid.date_time, bid.extra))) diff --git a/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query10.py b/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query10.py index 49c428ef78c6..ba6ebb529aa3 100644 --- a/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query10.py +++ b/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query10.py @@ -38,6 +38,7 @@ class OutputFile(object): + def __init__(self, max_timestamp, shard, index, timing, filename): self.max_timestamp = max_timestamp self.shard = shard @@ -112,6 +113,7 @@ def load(events, metadata=None, pipeline_options=None): class ShardEventsDoFn(beam.DoFn): + def process(self, element): shard_number = abs(hash(element) % num_log_shards) shard = 'shard-%05d-of-%05d' % (shard_number, num_log_shards) @@ -119,6 +121,7 @@ def process(self, element): class WriteEventDoFn(beam.DoFn): + def process( self, element, @@ -138,6 +141,7 @@ def process( class WriteIndexDoFn(beam.DoFn): + def process(self, element, pipeline_options, window=beam.DoFn.WindowParam): options = pipeline_options.view_as(GoogleCloudOptions) filename = index_path_for(window) diff --git a/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query3.py b/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query3.py index eb16d2dc36a0..f390c8c37001 100644 --- a/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query3.py +++ b/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query3.py @@ -74,10 +74,8 @@ def load(events, metadata=None, pipeline_options=None): JoinFn(metadata.get('max_auction_waiting_time'))) | 'query3_output' >> beam.Map( lambda t: { - ResultNames.NAME: t[1].name, - ResultNames.CITY: t[1].city, - ResultNames.STATE: t[1].state, - ResultNames.AUCTION_ID: t[0].id + ResultNames.NAME: t[1].name, ResultNames.CITY: t[1].city, + ResultNames.STATE: t[1].state, ResultNames.AUCTION_ID: t[0].id })) diff --git a/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query4.py b/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query4.py index ad6c63a88c37..bd038d1654e9 100644 --- a/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query4.py +++ b/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query4.py @@ -72,6 +72,7 @@ def load(events, metadata=None, pipeline_options=None): class ProjectToCategoryPriceFn(beam.DoFn): + def process(self, element, pane_info=beam.DoFn.PaneInfoParam): yield { ResultNames.CATEGORY: element[0], diff --git a/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query5.py b/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query5.py index a55d31de6091..1e71059400dc 100644 --- a/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query5.py +++ b/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query5.py @@ -64,6 +64,7 @@ class MostBidCombineFn(beam.CombineFn): """ combiner function to find auctions with most bid counts """ + def create_accumulator(self): return [], 0 diff --git a/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query6.py b/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query6.py index 0f8a0eb59325..c3537f44f978 100644 --- a/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query6.py +++ b/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query6.py @@ -62,6 +62,7 @@ class MovingMeanSellingPriceFn(beam.CombineFn): Combiner to keep track of up to max_num_bids of the most recent wining bids and calculate their average selling price. """ + def __init__(self, max_num_bids): self.max_num_bids = max_num_bids diff --git a/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query7.py b/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query7.py index 930eb08f0366..88ae5ec8e7ae 100644 --- a/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query7.py +++ b/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query7.py @@ -52,6 +52,7 @@ def load(events, metadata=None, pipeline_options=None): class SelectMaxBidFn(beam.DoFn): + def process(self, element, max_bid_price): if element.price == max_bid_price: yield element diff --git a/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query8.py b/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query8.py index 59a0459742c1..b8ed9daa63b5 100644 --- a/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query8.py +++ b/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/query8.py @@ -58,6 +58,7 @@ def load(events, metadata=None, pipeline_options=None): class JoinPersonAuctionFn(beam.DoFn): + def process(self, element): _, group = element persons = group[nexmark_query_util.PERSON_TAG] diff --git a/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/winning_bids.py b/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/winning_bids.py index 52ffd483a840..80f7792326a2 100644 --- a/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/winning_bids.py +++ b/sdks/python/apache_beam/testing/benchmarks/nexmark/queries/winning_bids.py @@ -44,6 +44,7 @@ class AuctionOrBidWindow(IntervalWindow): """Windows for open auctions and bids.""" + def __init__(self, start, end, auction_id, is_auction_window): super().__init__(start, end) self.auction = auction_id @@ -71,6 +72,7 @@ def __str__(self): class AuctionOrBidWindowCoder(FastCoder): + def _create_impl(self): return AuctionOrBidWindowCoderImpl() @@ -98,6 +100,7 @@ def decode_from_stream(self, stream, nested): class AuctionOrBidWindowFn(WindowFn): + def __init__(self, expected_duration_micro): self.expected_duration = expected_duration_micro @@ -145,6 +148,7 @@ def get_transformed_output_time(self, window, input_timestamp): class JoinAuctionBidFn(beam.DoFn): + @staticmethod def higher_bid(bid, other): if bid.price > other.price: @@ -171,6 +175,7 @@ def process(self, element): class WinningBids(beam.PTransform): + def __init__(self): #TODO: change this to be calculated by event generation expected_duration = 16667000 diff --git a/sdks/python/apache_beam/testing/benchmarks/wordcount/wordcount.py b/sdks/python/apache_beam/testing/benchmarks/wordcount/wordcount.py index 513ede47e80a..b4322930b0d5 100644 --- a/sdks/python/apache_beam/testing/benchmarks/wordcount/wordcount.py +++ b/sdks/python/apache_beam/testing/benchmarks/wordcount/wordcount.py @@ -23,6 +23,7 @@ class WordcountCostBenchmark(DataflowCostBenchmark): + def __init__(self): super().__init__() diff --git a/sdks/python/apache_beam/testing/datatype_inference.py b/sdks/python/apache_beam/testing/datatype_inference.py index b68f5ec4a125..d7ecf77c5d5f 100644 --- a/sdks/python/apache_beam/testing/datatype_inference.py +++ b/sdks/python/apache_beam/testing/datatype_inference.py @@ -64,9 +64,8 @@ def infer_typehints_schema(data): for row in data: for key, value in row.items(): column_data.setdefault(key, []).append(value) - column_types = OrderedDict([ - (key, infer_element_type(values)) for key, values in column_data.items() - ]) + column_types = OrderedDict([(key, infer_element_type(values)) + for key, values in column_data.items()]) return column_types @@ -101,8 +100,7 @@ def typehint_to_avro_type(value): column_types = infer_typehints_schema(data) avro_fields = [{ "name": str(key), "type": typehint_to_avro_type(value) - } for key, - value in column_types.items()] + } for key, value in column_types.items()] schema_dict = { "namespace": "example.avro", "name": "User", @@ -127,7 +125,6 @@ def infer_pyarrow_schema(data): for row in data: for key, value in row.items(): column_data.setdefault(key, []).append(value) - column_types = OrderedDict([ - (key, pa.array(value).type) for key, value in column_data.items() - ]) + column_types = OrderedDict([(key, pa.array(value).type) + for key, value in column_data.items()]) return pa.schema(list(column_types.items())) diff --git a/sdks/python/apache_beam/testing/datatype_inference_test.py b/sdks/python/apache_beam/testing/datatype_inference_test.py index 001752f8ab27..22dcdf5457c8 100644 --- a/sdks/python/apache_beam/testing/datatype_inference_test.py +++ b/sdks/python/apache_beam/testing/datatype_inference_test.py @@ -109,6 +109,7 @@ def nullify_data_and_schemas(test_data): """Add a row with all columns set to None and adjust the schemas accordingly. """ + def nullify_avro_schema(schema): """Add a 'null' type to every field.""" schema = schema.copy() @@ -149,8 +150,8 @@ def get_collumns_in_order(test_data): OrderedDict([(c, None) for c in columns]) ] test_case["type_schema"] = OrderedDict([ - (k, typehints.Union[v, type(None)]) for k, - v in test_case["type_schema"].items() + (k, typehints.Union[v, type(None)]) + for k, v in test_case["type_schema"].items() ]) test_case["avro_schema"] = nullify_avro_schema(test_case["avro_schema"]) nullified_test_data.append(test_case) @@ -161,6 +162,7 @@ def get_collumns_in_order(test_data): class DatatypeInferenceTest(unittest.TestCase): + @parameterized.expand([(d["name"], d["data"], d["type_schema"]) for d in TEST_DATA]) def test_infer_typehints_schema(self, _, data, schema): diff --git a/sdks/python/apache_beam/testing/extra_assertions.py b/sdks/python/apache_beam/testing/extra_assertions.py index 7820956c7c5e..5c9acf630706 100644 --- a/sdks/python/apache_beam/testing/extra_assertions.py +++ b/sdks/python/apache_beam/testing/extra_assertions.py @@ -20,6 +20,7 @@ class ExtraAssertionsMixin(object): + def assertUnhashableCountEqual(self, data1, data2): """Assert that two containers have the same items, with special treatment for numpy arrays. diff --git a/sdks/python/apache_beam/testing/extra_assertions_test.py b/sdks/python/apache_beam/testing/extra_assertions_test.py index 174fb54e2fa8..0f978bd5b2d2 100644 --- a/sdks/python/apache_beam/testing/extra_assertions_test.py +++ b/sdks/python/apache_beam/testing/extra_assertions_test.py @@ -26,6 +26,7 @@ class ExtraAssertionsMixinTest(ExtraAssertionsMixin, unittest.TestCase): + def test_assert_array_count_equal_strings(self): data1 = ["±♠Ωℑ", "hello", "world"] data2 = ["hello", "±♠Ωℑ", "world"] diff --git a/sdks/python/apache_beam/testing/load_tests/co_group_by_key_test.py b/sdks/python/apache_beam/testing/load_tests/co_group_by_key_test.py index 617e00d40f26..a8c11b355377 100644 --- a/sdks/python/apache_beam/testing/load_tests/co_group_by_key_test.py +++ b/sdks/python/apache_beam/testing/load_tests/co_group_by_key_test.py @@ -96,6 +96,7 @@ def __init__(self): self.iterations = self.get_option_or_default('iterations', 1) class _UngroupAndReiterate(beam.DoFn): + def __init__(self, input_tag, co_input_tag): self.input_tag = input_tag self.co_input_tag = co_input_tag diff --git a/sdks/python/apache_beam/testing/load_tests/combine_test.py b/sdks/python/apache_beam/testing/load_tests/combine_test.py index 9452730b88f2..4024f9e932ac 100644 --- a/sdks/python/apache_beam/testing/load_tests/combine_test.py +++ b/sdks/python/apache_beam/testing/load_tests/combine_test.py @@ -81,6 +81,7 @@ class CombineTest(LoadTest): + def __init__(self): super().__init__() self.fanout = self.get_option_or_default('fanout', 1) @@ -93,6 +94,7 @@ def __init__(self): sys.exit(1) class _GetElement(beam.DoFn): + def process(self, element): yield element diff --git a/sdks/python/apache_beam/testing/load_tests/dataflow_cost_benchmark.py b/sdks/python/apache_beam/testing/load_tests/dataflow_cost_benchmark.py index 96a1cd31e298..96d276bbe499 100644 --- a/sdks/python/apache_beam/testing/load_tests/dataflow_cost_benchmark.py +++ b/sdks/python/apache_beam/testing/load_tests/dataflow_cost_benchmark.py @@ -44,6 +44,7 @@ class DataflowCostBenchmark(LoadTest): calculate the cost of the job later, as different accelerators have different billing rates per hour of use. """ + def __init__( self, metrics_namespace: Optional[str] = None, diff --git a/sdks/python/apache_beam/testing/load_tests/group_by_key_test.py b/sdks/python/apache_beam/testing/load_tests/group_by_key_test.py index 38724fc17391..3fee43574666 100644 --- a/sdks/python/apache_beam/testing/load_tests/group_by_key_test.py +++ b/sdks/python/apache_beam/testing/load_tests/group_by_key_test.py @@ -80,6 +80,7 @@ class GroupByKeyTest(LoadTest): + def __init__(self): super().__init__() self.fanout = self.get_option_or_default('fanout', 1) diff --git a/sdks/python/apache_beam/testing/load_tests/load_test.py b/sdks/python/apache_beam/testing/load_tests/load_test.py index 20dea3932b49..8ce5e52a4f98 100644 --- a/sdks/python/apache_beam/testing/load_tests/load_test.py +++ b/sdks/python/apache_beam/testing/load_tests/load_test.py @@ -32,6 +32,7 @@ class LoadTestOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_argument( @@ -91,6 +92,7 @@ class LoadTest(object): If using InfluxDB with Basic HTTP authentication enabled, provide the following environment options: `INFLUXDB_USER` and `INFLUXDB_USER_PASSWORD`. """ + def __init__(self, metrics_namespace=None): # Be sure to set blocking to false for timeout_ms to work properly self.pipeline = TestPipeline(is_integration_test=True, blocking=False) diff --git a/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py b/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py index caadbaca1e1e..c40fa80483a6 100644 --- a/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py +++ b/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py @@ -193,6 +193,7 @@ class MetricsReader(object): A :class:`MetricsReader` retrieves metrics from pipeline result, prepares it for publishers and setup publishers. """ + def __init__( self, project_name=None, @@ -287,8 +288,8 @@ def _prepare_extra_metrics( if not extra_metrics: extra_metrics = {} return [ - Metric(ts, metric_id, v, label=k).as_dict() for k, - v in extra_metrics.items() + Metric(ts, metric_id, v, label=k).as_dict() + for k, v in extra_metrics.items() ] def publish_values(self, labeled_values): @@ -299,8 +300,7 @@ def publish_values(self, labeled_values): """ metric_dicts = [ Metric(time.time(), uuid.uuid4().hex, value, label=label).as_dict() - for label, - value in labeled_values + for label, value in labeled_values ] for publisher in self.publishers: @@ -334,6 +334,7 @@ def _get_distributions(self, distributions, metric_id): class Metric(object): """Metric base class in ready-to-save format.""" + def __init__( self, submit_timestamp, metric_id, value, metric=None, label=None): """Initializes :class:`Metric` @@ -369,6 +370,7 @@ class CounterMetric(Metric): submit_timestamp (float): date-time of saving metric to database metric_id (uuid): unique id to identify test run """ + def __init__(self, counter_metric, submit_timestamp, metric_id): value = counter_metric.result super().__init__(submit_timestamp, metric_id, value, counter_metric) @@ -382,6 +384,7 @@ class DistributionMetric(Metric): submit_timestamp (float): date-time of saving metric to database metric_id (uuid): unique id to identify test run """ + def __init__(self, dist_metric, submit_timestamp, metric_id, metric_type): custom_label = dist_metric.key.metric.namespace + \ '_' + parse_step(dist_metric.key.step) + \ @@ -405,6 +408,7 @@ class RuntimeMetric(Metric): with runtime name metric_id(uuid): unique id to identify test run """ + def __init__(self, runtime_list, metric_id): value = self._prepare_runtime_metrics(runtime_list) submit_timestamp = time.time() @@ -431,6 +435,7 @@ def _prepare_runtime_metrics(self, distributions): class MetricsPublisher: """Base class for metrics publishers.""" + def publish(self, results): raise NotImplementedError @@ -438,6 +443,7 @@ def publish(self, results): class ConsoleMetricsPublisher(MetricsPublisher): """A :class:`ConsoleMetricsPublisher` publishes collected metrics to console output.""" + def publish(self, results): if len(results) > 0: log = "Load test results for test: %s and timestamp: %s:" \ @@ -454,6 +460,7 @@ def publish(self, results): class BigQueryMetricsPublisher(MetricsPublisher): """A :class:`BigQueryMetricsPublisher` publishes collected metrics to BigQuery output.""" + def __init__(self, project_name, table, dataset, bq_schema=None): if not bq_schema: bq_schema = SCHEMA @@ -472,6 +479,7 @@ def publish(self, results): class BigQueryClient(object): """A :class:`BigQueryClient` publishes collected metrics to BigQuery output.""" + def __init__(self, project_name, table, dataset, bq_schema=None): self.schema = bq_schema self._namespace = table @@ -522,6 +530,7 @@ def save(self, results): class InfluxDBMetricsPublisherOptions(object): + def __init__( self, measurement: str, @@ -544,6 +553,7 @@ def http_auth_enabled(self) -> bool: class InfluxDBMetricsPublisher(MetricsPublisher): """Publishes collected metrics to InfluxDB database.""" + def __init__(self, options: InfluxDBMetricsPublisherOptions): self.options = options @@ -571,6 +581,7 @@ def publish( def _build_payload( self, results: List[Mapping[str, Union[float, str, int]]]) -> str: + def build_kv(mapping, key): return '{}={}'.format(key, mapping[key]) @@ -590,6 +601,7 @@ def build_kv(mapping, key): class MeasureTime(beam.DoFn): """A distribution metric prepared to be added to pipeline as ParDo to measure runtime.""" + def __init__(self, namespace): """Initializes :class:`MeasureTime`. @@ -663,6 +675,7 @@ def process(self, element, timestamp=beam.DoFn.TimestampParam): class AssignTimestamps(beam.DoFn): """DoFn to assigned timestamps to elements.""" + def __init__(self): # Avoid having to use save_main_session self.time_fn = time.time diff --git a/sdks/python/apache_beam/testing/load_tests/microbenchmarks_test.py b/sdks/python/apache_beam/testing/load_tests/microbenchmarks_test.py index 34d4080c072f..03887fe8c45f 100644 --- a/sdks/python/apache_beam/testing/load_tests/microbenchmarks_test.py +++ b/sdks/python/apache_beam/testing/load_tests/microbenchmarks_test.py @@ -59,6 +59,7 @@ class MicroBenchmarksLoadTest(LoadTest): + def __init__(self): super().__init__() diff --git a/sdks/python/apache_beam/testing/load_tests/pardo_test.py b/sdks/python/apache_beam/testing/load_tests/pardo_test.py index 989ed2168a66..144a94430a3e 100644 --- a/sdks/python/apache_beam/testing/load_tests/pardo_test.py +++ b/sdks/python/apache_beam/testing/load_tests/pardo_test.py @@ -87,6 +87,7 @@ class ParDoTest(LoadTest): + def __init__(self): super().__init__() self.iterations = self.get_option_or_default('iterations') @@ -100,7 +101,9 @@ def __init__(self): 'state_cache_size=1000') def test(self): + class BaseCounterOperation(beam.DoFn): + def __init__(self, number_of_counters, number_of_operations): self.number_of_operations = number_of_operations self.counters = [] @@ -124,6 +127,7 @@ def process(self, element, state=state_param): yield element class CounterOperation(BaseCounterOperation): + def process(self, element): for _ in range(self.number_of_operations): for counter in self.counters: diff --git a/sdks/python/apache_beam/testing/load_tests/sideinput_test.py b/sdks/python/apache_beam/testing/load_tests/sideinput_test.py index 3b5dfdf38cd9..e064b6104b70 100644 --- a/sdks/python/apache_beam/testing/load_tests/sideinput_test.py +++ b/sdks/python/apache_beam/testing/load_tests/sideinput_test.py @@ -106,8 +106,10 @@ def materialize_as(self): 'these: {}'.format(list(self.SIDE_INPUT_TYPES.keys()))) def test(self): + class SequenceSideInputTestDoFn(beam.DoFn): """Iterate over first n side_input elements.""" + def __init__(self, first_n: int): self._first_n = first_n @@ -126,6 +128,7 @@ def process( class MappingSideInputTestDoFn(beam.DoFn): """Iterates over first n keys in the dictionary and checks the value.""" + def __init__(self, first_n: int): self._first_n = first_n @@ -142,6 +145,7 @@ def process( class AssignTimestamps(beam.DoFn): """Produces timestamped values. Timestamps are equal to the value of the element.""" + def __init__(self): # Avoid having to use save_main_session self.window = window @@ -150,6 +154,7 @@ def process(self, element: int) -> Iterable[window.TimestampedValue]: yield self.window.TimestampedValue(element, element) class GetSyntheticSDFOptions(beam.DoFn): + def __init__( self, elements_per_record: int, key_size: int, value_size: int): self.elements_per_record = elements_per_record diff --git a/sdks/python/apache_beam/testing/metric_result_matchers.py b/sdks/python/apache_beam/testing/metric_result_matchers.py index 3c0386535213..3828c65b2de8 100644 --- a/sdks/python/apache_beam/testing/metric_result_matchers.py +++ b/sdks/python/apache_beam/testing/metric_result_matchers.py @@ -62,6 +62,7 @@ def _matcher_or_equal_to(value_or_matcher): class MetricResultMatcher(BaseMatcher): """A PyHamcrest matcher that validates counter MetricResults.""" + def __init__( self, namespace=None, @@ -141,6 +142,7 @@ def describe_mismatch(self, metric_result, mismatch_description): class DistributionMatcher(BaseMatcher): """A PyHamcrest matcher that validates counter distributions.""" + def __init__( self, sum_value=None, count_value=None, min_value=None, max_value=None): self.sum_value = _matcher_or_equal_to(sum_value) diff --git a/sdks/python/apache_beam/testing/metric_result_matchers_test.py b/sdks/python/apache_beam/testing/metric_result_matchers_test.py index 3657356a9fe0..feff268b4841 100644 --- a/sdks/python/apache_beam/testing/metric_result_matchers_test.py +++ b/sdks/python/apache_beam/testing/metric_result_matchers_test.py @@ -103,6 +103,7 @@ def _create_metric_result(data_dict): class MetricResultMatchersTest(unittest.TestCase): + def test_matches_all_for_counter(self): metric_result = _create_metric_result(EVERYTHING_COUNTER) matcher = MetricResultMatcher( diff --git a/sdks/python/apache_beam/testing/pipeline_verifiers.py b/sdks/python/apache_beam/testing/pipeline_verifiers.py index 225e6d0dbae1..6989444bf522 100644 --- a/sdks/python/apache_beam/testing/pipeline_verifiers.py +++ b/sdks/python/apache_beam/testing/pipeline_verifiers.py @@ -56,6 +56,7 @@ class PipelineStateMatcher(BaseMatcher): Matcher compares the actual pipeline terminate state with expected. By default, `PipelineState.DONE` is used as expected state. """ + def __init__(self, expected_state=PipelineState.DONE): self.expected_state = expected_state @@ -85,6 +86,7 @@ class FileChecksumMatcher(BaseMatcher): Use apache_beam.io.filebasedsink to fetch file(s) from given path. File checksum is a hash string computed from content of file(s). """ + def __init__(self, file_path, expected_checksum, sleep_secs=None): """Initialize a FileChecksumMatcher object diff --git a/sdks/python/apache_beam/testing/pipeline_verifiers_test.py b/sdks/python/apache_beam/testing/pipeline_verifiers_test.py index 085339003699..a47a4ef0250a 100644 --- a/sdks/python/apache_beam/testing/pipeline_verifiers_test.py +++ b/sdks/python/apache_beam/testing/pipeline_verifiers_test.py @@ -45,6 +45,7 @@ class PipelineVerifiersTest(unittest.TestCase): + def setUp(self): self._mock_result = Mock() patch_retry(self, verifiers) diff --git a/sdks/python/apache_beam/testing/synthetic_pipeline.py b/sdks/python/apache_beam/testing/synthetic_pipeline.py index b18de244e3f8..e47bec5de9a5 100644 --- a/sdks/python/apache_beam/testing/synthetic_pipeline.py +++ b/sdks/python/apache_beam/testing/synthetic_pipeline.py @@ -145,6 +145,7 @@ def initial_splitting_zipf( class SyntheticStep(beam.DoFn): """A DoFn of which behavior can be controlled through prespecified parameters. """ + def __init__( self, per_element_delay_sec=0, @@ -189,6 +190,7 @@ def process(self, element): class NonLiquidShardingOffsetRangeTracker(OffsetRestrictionTracker): """An OffsetRangeTracker that doesn't allow splitting. """ + def try_split(self, split_offset): pass # Don't split. @@ -207,6 +209,7 @@ class SyntheticSDFStepRestrictionProvider(RestrictionProvider): If initial_splitting_uneven_chunks, produces uneven chunks. """ + def __init__( self, num_records, @@ -264,10 +267,12 @@ def get_synthetic_sdf_step( size_estimate_override=None, ): """A function which returns a SyntheticSDFStep with given parameters. """ + class SyntheticSDFStep(beam.DoFn): """A SplittableDoFn of which behavior can be controlled through prespecified parameters. """ + def __init__( self, per_element_delay_sec_arg, @@ -335,6 +340,7 @@ def process( class SyntheticSource(iobase.BoundedSource): """A custom source of a specified size. """ + def __init__(self, input_spec): """Initiates a synthetic source. @@ -344,6 +350,7 @@ def __init__(self, input_spec): Raises: ValueError: if input parameters are invalid. """ + def maybe_parse_byte_size(s): return parse_byte_size(s) if isinstance(s, str) else int(s) @@ -506,6 +513,7 @@ class SyntheticSDFSourceRestrictionProvider(RestrictionProvider): } """ + def initial_restriction(self, element): return OffsetRange(0, element['num_records']) @@ -591,6 +599,7 @@ class SyntheticSDFAsSource(beam.DoFn): During runtime, the DoFnRunner.process_with_sized_restriction() will feed a 'RestrictionTracker' based on a restriction to SDF.process(). """ + def process( self, element, @@ -606,6 +615,7 @@ def process( class ShuffleBarrier(beam.PTransform): + def expand(self, pc): return ( pc @@ -615,13 +625,13 @@ def expand(self, pc): class SideInputBarrier(beam.PTransform): + def expand(self, pc): return ( pc | beam.Map(rotate_key) | beam.Map( - lambda elem, - ignored: elem, + lambda elem, ignored: elem, beam.pvalue.AsIter(pc | beam.FlatMap(lambda elem: None)))) @@ -641,6 +651,7 @@ def merge_using_gbk(name, pc1, pc2): def merge_using_side_input(name, pc1, pc2): """Merges two given PCollections using side inputs.""" + def join_fn(val, _): # Ignoring side input return val @@ -658,7 +669,9 @@ def expand_using_gbk(name, pc): def expand_using_second_output(name, pc): """Expands a given PCollection into two copies using side outputs.""" + class ExpandFn(beam.DoFn): + def process(self, element): yield beam.pvalue.TaggedOutput('second_out', element) yield element @@ -906,6 +919,7 @@ def run(argv=None, save_main_session=True): class StatefulLoadGenerator(beam.PTransform): """A PTransform for generating random data using Timers API.""" + def __init__(self, input_options, num_keys=100): self.num_records = input_options['num_records'] self.key_size = input_options['key_size'] @@ -914,6 +928,7 @@ def __init__(self, input_options, num_keys=100): @typehints.with_output_types(Tuple[bytes, bytes]) class GenerateKeys(beam.DoFn): + def __init__(self, num_keys, key_size): self.num_keys = num_keys self.key_size = key_size diff --git a/sdks/python/apache_beam/testing/synthetic_pipeline_test.py b/sdks/python/apache_beam/testing/synthetic_pipeline_test.py index 7e973a3ca7d7..5af622c7c0b8 100644 --- a/sdks/python/apache_beam/testing/synthetic_pipeline_test.py +++ b/sdks/python/apache_beam/testing/synthetic_pipeline_test.py @@ -168,6 +168,7 @@ def test_synthetic_step_split_provider_no_liquid_sharding(self): self.assertEqual(tracker.try_split(3), None) def test_synthetic_source(self): + def assert_size(element, expected_size): assert len(element) == expected_size diff --git a/sdks/python/apache_beam/testing/test_pipeline_test.py b/sdks/python/apache_beam/testing/test_pipeline_test.py index 06946c7a7efb..ecc70db41f91 100644 --- a/sdks/python/apache_beam/testing/test_pipeline_test.py +++ b/sdks/python/apache_beam/testing/test_pipeline_test.py @@ -33,6 +33,7 @@ # A simple matcher that is ued for testing extra options appending. class SimpleMatcher(BaseMatcher): + def _matches(self, item): return True @@ -49,6 +50,7 @@ class TestPipelineTest(unittest.TestCase): # Used for testing pipeline option creation. class TestParsingOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_argument('--job', action='store', help='mock job') diff --git a/sdks/python/apache_beam/testing/test_stream.py b/sdks/python/apache_beam/testing/test_stream.py index 12a87bf0a68a..b4a48f4efda5 100644 --- a/sdks/python/apache_beam/testing/test_stream.py +++ b/sdks/python/apache_beam/testing/test_stream.py @@ -58,6 +58,7 @@ @total_ordering class Event(metaclass=ABCMeta): # type: ignore[misc] """Test stream event to be emitted during execution of a TestStream.""" + @abstractmethod def __eq__(self, other): raise NotImplementedError @@ -101,6 +102,7 @@ def from_runner_api(proto, element_coder): class ElementEvent(Event): """Element-producing test stream event.""" + def __init__(self, timestamped_values, tag=None): self.timestamped_values = timestamped_values self.tag = tag @@ -142,6 +144,7 @@ def __repr__(self): class WatermarkEvent(Event): """Watermark-advancing test stream event.""" + def __init__(self, new_watermark, tag=None): self.new_watermark = Timestamp.of(new_watermark) self.tag = tag @@ -177,6 +180,7 @@ def __repr__(self): class ProcessingTimeEvent(Event): """Processing time-advancing test stream event.""" + def __init__(self, advance_by): self.advance_by = Duration.of(advance_by) @@ -213,6 +217,7 @@ class WindowedValueHolderMeta(type): The override is needed because WindowedValueHolder elements encoded then decoded become Row elements. """ + def __instancecheck__(cls, other): """Checks if a beam.Row typed instance is a WindowedValueHolder. """ @@ -268,6 +273,7 @@ class TestStream(PTransform): output or only one output tag has been used. Otherwise a dictionary of output names to PCollections will be returned. """ + def __init__( self, coder=coders.FastPrimitivesCoder(), @@ -423,6 +429,7 @@ def from_runner_api_parameter(ptransform, payload, context): class TimingInfo(object): + def __init__(self, processing_time, watermark): self._processing_time = Timestamp.of(processing_time) self._watermark = Timestamp.of(watermark) @@ -468,6 +475,7 @@ class ReverseTestStream(PTransform): and elements come in order and are outputted in the same order that they came in. """ + def __init__( self, sample_resolution_sec, output_tag, coder=None, output_format=None): self._sample_resolution_sec = sample_resolution_sec diff --git a/sdks/python/apache_beam/testing/test_stream_it_test.py b/sdks/python/apache_beam/testing/test_stream_it_test.py index 0e293eda3713..f5bd4dda9f5c 100644 --- a/sdks/python/apache_beam/testing/test_stream_it_test.py +++ b/sdks/python/apache_beam/testing/test_stream_it_test.py @@ -44,6 +44,7 @@ def supported(runners): runners = [runners] def inner(fn): + @wraps(fn) def wrapped(self): if self.runner_name not in runners: @@ -59,6 +60,7 @@ def wrapped(self): class TestStreamIntegrationTests(unittest.TestCase): + @classmethod def setUpClass(cls): cls.test_pipeline = TestPipeline(is_integration_test=True) @@ -80,6 +82,7 @@ def test_basic_execution(self): ]).advance_watermark_to_infinity()) class RecordFn(beam.DoFn): + def process( self, element=beam.DoFn.ElementParam, @@ -123,6 +126,7 @@ def test_multiple_outputs(self): numbers_elements, tag='numbers')) class RecordFn(beam.DoFn): + def process( self, element=beam.DoFn.ElementParam, diff --git a/sdks/python/apache_beam/testing/test_stream_service.py b/sdks/python/apache_beam/testing/test_stream_service.py index 1f63cbf7274f..177c403a52e3 100644 --- a/sdks/python/apache_beam/testing/test_stream_service.py +++ b/sdks/python/apache_beam/testing/test_stream_service.py @@ -30,6 +30,7 @@ class TestStreamServiceController( This server is used as a way for TestStreams to receive events from file. """ + def __init__(self, reader, endpoint=None, exception_handler=None): self._server = grpc.server(ThreadPoolExecutor(max_workers=10)) self._server_started = False diff --git a/sdks/python/apache_beam/testing/test_stream_service_test.py b/sdks/python/apache_beam/testing/test_stream_service_test.py index a04fa2303d08..a031b542be33 100644 --- a/sdks/python/apache_beam/testing/test_stream_service_test.py +++ b/sdks/python/apache_beam/testing/test_stream_service_test.py @@ -36,6 +36,7 @@ class EventsReader: + def __init__(self, expected_key): self._expected_key = expected_key @@ -57,6 +58,7 @@ def read_multiple(self, keys): class TestStreamServiceTest(unittest.TestCase): + def setUp(self): self.controller = TestStreamServiceController( EventsReader(expected_key=[('full', EXPECTED_KEY)])) diff --git a/sdks/python/apache_beam/testing/test_stream_test.py b/sdks/python/apache_beam/testing/test_stream_test.py index 13b9e3c17696..b314aed86f39 100644 --- a/sdks/python/apache_beam/testing/test_stream_test.py +++ b/sdks/python/apache_beam/testing/test_stream_test.py @@ -52,6 +52,7 @@ class TestStreamTest(unittest.TestCase): + def test_basic_test_stream(self): test_stream = (TestStream() .advance_watermark_to(0) @@ -111,6 +112,7 @@ def test_basic_execution(self): .advance_watermark_to_infinity()) # yapf: disable class RecordFn(beam.DoFn): + def process( self, element=beam.DoFn.ElementParam, @@ -154,6 +156,7 @@ def test_multiple_outputs(self): .add_elements(numbers_elements, tag='numbers')) # yapf: disable class RecordFn(beam.DoFn): + def process( self, element=beam.DoFn.ElementParam, @@ -279,6 +282,7 @@ def test_dicts_not_interpreted_as_windowed_values(self): .advance_watermark_to_infinity()) # yapf: disable class RecordFn(beam.DoFn): + def process( self, element=beam.DoFn.ElementParam, @@ -312,6 +316,7 @@ def test_windowed_values_interpreted_correctly(self): .advance_watermark_to_infinity()) # yapf: disable class RecordFn(beam.DoFn): + def process( self, element=beam.DoFn.ElementParam, @@ -493,6 +498,7 @@ def test_basic_execution_batch_sideinputs(self): | beam.Map(lambda t: window.TimestampedValue(t, t))) class RecordFn(beam.DoFn): + def process( self, elm=beam.DoFn.ElementParam, @@ -527,6 +533,7 @@ def test_basic_execution_sideinputs(self): side_stream = test_stream['side'] class RecordFn(beam.DoFn): + def process( self, elm=beam.DoFn.ElementParam, @@ -559,6 +566,7 @@ def test_basic_execution_batch_sideinputs_fixed_windows(self): | beam.WindowInto(window.FixedWindows(2))) class RecordFn(beam.DoFn): + def process( self, elm=beam.DoFn.ElementParam, @@ -609,6 +617,7 @@ def test_basic_execution_sideinputs_fixed_windows(self): | 'side windowInto' >> beam.WindowInto(window.FixedWindows(3))) class RecordFn(beam.DoFn): + def process( self, elm=beam.DoFn.ElementParam, @@ -710,6 +719,7 @@ def test_basic_execution_with_service(self): ] class InMemoryEventReader: + def read_multiple(self, unused_keys): for e in test_stream_proto_events: yield e @@ -720,6 +730,7 @@ def read_multiple(self, unused_keys): test_stream = TestStream(coder=coder, endpoint=service.endpoint) class RecordFn(beam.DoFn): + def process( self, element=beam.DoFn.ElementParam, @@ -748,6 +759,7 @@ def process( class ReverseTestStreamTest(unittest.TestCase): + def test_basic_execution(self): test_stream = (TestStream() .advance_watermark_to(0) diff --git a/sdks/python/apache_beam/testing/test_utils.py b/sdks/python/apache_beam/testing/test_utils.py index 03d77df950d5..656b81c3a861 100644 --- a/sdks/python/apache_beam/testing/test_utils.py +++ b/sdks/python/apache_beam/testing/test_utils.py @@ -39,6 +39,7 @@ class TempDir(object): """Context Manager to create and clean-up a temporary directory.""" + def __init__(self): self._tempdir = tempfile.mkdtemp() @@ -157,6 +158,7 @@ class PullResponseMessage(object): Utility class for ``create_pull_response``. """ + def __init__( self, data, @@ -230,6 +232,7 @@ def read_files_from_pattern(file_pattern): class LCGenerator: """A pure Python implementation of linear congruential generator.""" + def __init__(self, a=0x5DEECE66D, c=0xB, bits=48): self._a = a self._c = c diff --git a/sdks/python/apache_beam/testing/test_utils_test.py b/sdks/python/apache_beam/testing/test_utils_test.py index cef19f67b4ea..9277be838bbd 100644 --- a/sdks/python/apache_beam/testing/test_utils_test.py +++ b/sdks/python/apache_beam/testing/test_utils_test.py @@ -32,6 +32,7 @@ class TestUtilsTest(unittest.TestCase): + def setUp(self): utils.patch_retry(self, utils) self.tmpdir = tempfile.mkdtemp() diff --git a/sdks/python/apache_beam/testing/util.py b/sdks/python/apache_beam/testing/util.py index 8532d1c1f97d..d5031f7175aa 100644 --- a/sdks/python/apache_beam/testing/util.py +++ b/sdks/python/apache_beam/testing/util.py @@ -67,7 +67,9 @@ def contains_in_any_order(iterable): Arguments: iterable: An iterable of hashable objects. """ + class InAnyOrder(object): + def __init__(self, iterable): self._counter = collections.Counter(iterable) @@ -84,6 +86,7 @@ def __repr__(self): class _EqualToPerWindowMatcher(object): + def __init__(self, expected_window_to_elements): self._expected_window_to_elements = expected_window_to_elements @@ -151,6 +154,7 @@ def equal_to_per_window(expected_window_to_elements): # other. However, only permutations of the top level are checked. Therefore # [1,2] and [2,1] are considered equal and [[1,2]] and [[2,1]] are not. def equal_to(expected, equals_fn=None): + def _equal(actual, equals_fn=equals_fn): expected_list = list(expected) @@ -201,6 +205,7 @@ def matches_all(expected): expected: A list of elements or hamcrest matchers to be used to match the elements of a single PCollection. """ + def _matches(actual): from hamcrest.core import assert_that as hamcrest_assert from hamcrest.library.collection import contains_inanyorder @@ -212,6 +217,7 @@ def _matches(actual): def is_empty(): + def _empty(actual): actual = list(actual) if actual: @@ -226,6 +232,7 @@ def is_not_empty(): some data in it. :return: """ + def _not_empty(actual): actual = list(actual) if not actual: @@ -289,6 +296,7 @@ def assert_that( use_global_window = True class ReifyTimestampWindow(DoFn): + def process( self, element, timestamp=DoFn.TimestampParam, window=DoFn.WindowParam): # This returns TestWindowedValue instead of @@ -297,10 +305,12 @@ def process( return [TestWindowedValue(element, timestamp, [window])] class AddWindow(DoFn): + def process(self, element, window=DoFn.WindowParam): yield element, window class AssertThat(PTransform): + def expand(self, pcoll): if reify_windows: pcoll = pcoll | ParDo(ReifyTimestampWindow()) diff --git a/sdks/python/apache_beam/testing/util_test.py b/sdks/python/apache_beam/testing/util_test.py index ba3c743c03f3..0bb9b118a2be 100644 --- a/sdks/python/apache_beam/testing/util_test.py +++ b/sdks/python/apache_beam/testing/util_test.py @@ -41,6 +41,7 @@ class UtilTest(unittest.TestCase): + def test_assert_that_passes(self): with TestPipeline() as p: assert_that(p | Create([1, 2, 3]), equal_to([1, 2, 3])) diff --git a/sdks/python/apache_beam/tools/coders_microbenchmark.py b/sdks/python/apache_beam/tools/coders_microbenchmark.py index a5a0c25c4ef9..86abb3d72efb 100644 --- a/sdks/python/apache_beam/tools/coders_microbenchmark.py +++ b/sdks/python/apache_beam/tools/coders_microbenchmark.py @@ -58,7 +58,9 @@ def coder_benchmark_factory(coder, generate_fn): coder: coder to use to encode an element. generate_fn: a callable that generates an element. """ + class CoderBenchmark(object): + def __init__(self, num_elements_per_benchmark): self._coder = coders.IterableCoder(coder) self._list = [generate_fn() for _ in range(num_elements_per_benchmark)] @@ -80,7 +82,9 @@ def batch_row_coder_benchmark_factory(generate_fn, use_batch): coder: coder to use to encode an element. generate_fn: a callable that generates an element. """ + class CoderBenchmark(object): + def __init__(self, num_elements_per_benchmark): self._use_batch = use_batch row_instance = generate_fn() diff --git a/sdks/python/apache_beam/tools/fn_api_runner_microbenchmark.py b/sdks/python/apache_beam/tools/fn_api_runner_microbenchmark.py index a73b2282e3e8..c80587638f2e 100644 --- a/sdks/python/apache_beam/tools/fn_api_runner_microbenchmark.py +++ b/sdks/python/apache_beam/tools/fn_api_runner_microbenchmark.py @@ -115,6 +115,7 @@ def _build_serial_stages( def run_single_pipeline(size): + def _pipeline_runner(): with beam.Pipeline(runner=FnApiRunner()) as p: for i in range(NUM_PARALLEL_STAGES): diff --git a/sdks/python/apache_beam/tools/microbenchmarks_test.py b/sdks/python/apache_beam/tools/microbenchmarks_test.py index 9c7d0074ebd4..d5cef2bbaca3 100644 --- a/sdks/python/apache_beam/tools/microbenchmarks_test.py +++ b/sdks/python/apache_beam/tools/microbenchmarks_test.py @@ -28,6 +28,7 @@ class MicrobenchmarksTest(unittest.TestCase): + def test_coders_microbenchmark(self): # Right now, we don't evaluate performance impact, only check that # microbenchmark code can successfully run. diff --git a/sdks/python/apache_beam/tools/runtime_type_check_microbenchmark.py b/sdks/python/apache_beam/tools/runtime_type_check_microbenchmark.py index 6ba01b9a5172..79b158e34dea 100644 --- a/sdks/python/apache_beam/tools/runtime_type_check_microbenchmark.py +++ b/sdks/python/apache_beam/tools/runtime_type_check_microbenchmark.py @@ -41,12 +41,14 @@ @beam.typehints.with_input_types(Tuple[int, ...]) class SimpleInput(beam.DoFn): + def process(self, element, *args, **kwargs): yield element @beam.typehints.with_output_types(Tuple[int, ...]) class SimpleOutput(beam.DoFn): + def process(self, element, *args, **kwargs): yield element @@ -54,6 +56,7 @@ def process(self, element, *args, **kwargs): @beam.typehints.with_input_types( Tuple[int, str, Tuple[float, ...], Iterable[int], Union[str, int]]) class NestedInput(beam.DoFn): + def process(self, element, *args, **kwargs): yield element @@ -61,6 +64,7 @@ def process(self, element, *args, **kwargs): @beam.typehints.with_output_types( Tuple[int, str, Tuple[float, ...], Iterable[int], Union[str, int]]) class NestedOutput(beam.DoFn): + def process(self, element, *args, **kwargs): yield element diff --git a/sdks/python/apache_beam/tools/teststream_microbenchmark.py b/sdks/python/apache_beam/tools/teststream_microbenchmark.py index 7c5bb6135b5c..9c02e09fd269 100644 --- a/sdks/python/apache_beam/tools/teststream_microbenchmark.py +++ b/sdks/python/apache_beam/tools/teststream_microbenchmark.py @@ -58,6 +58,7 @@ class RekeyElements(beam.DoFn): + def process(self, element): _, values = element return [(random.randint(0, 1000), v) for v in values] @@ -77,6 +78,7 @@ def _build_serial_stages(input_pc, num_serial_stages, stage_count): def run_single_pipeline(size): + def _pipeline_runner(): with beam.Pipeline(runner=DirectRunner()) as p: ts = TestStream().advance_watermark_to(0) diff --git a/sdks/python/apache_beam/tools/utils.py b/sdks/python/apache_beam/tools/utils.py index e3df3f2c1c6f..9f91de3ab2bb 100644 --- a/sdks/python/apache_beam/tools/utils.py +++ b/sdks/python/apache_beam/tools/utils.py @@ -123,6 +123,7 @@ def run_benchmarks(benchmark_suite, verbose=True): A dictionary of the form string -> list of floats. Keys of the dictionary are benchmark names, values are execution times in seconds for each run. """ + def run(benchmark: BenchmarkFactoryFn, size: int): # Contain each run of a benchmark inside a function so that any temporary # objects can be garbage-collected after the run. diff --git a/sdks/python/apache_beam/transforms/batch_dofn_test.py b/sdks/python/apache_beam/transforms/batch_dofn_test.py index d2aceb371492..a84596483137 100644 --- a/sdks/python/apache_beam/transforms/batch_dofn_test.py +++ b/sdks/python/apache_beam/transforms/batch_dofn_test.py @@ -31,22 +31,26 @@ class ElementDoFn(beam.DoFn): + def process(self, element: int, *args, **kwargs) -> Iterator[float]: yield element / 2 class BatchDoFn(beam.DoFn): + def process_batch(self, batch: List[int], *args, **kwargs) -> Iterator[List[float]]: yield [element / 2 for element in batch] class NoReturnAnnotation(beam.DoFn): + def process_batch(self, batch: List[int], *args, **kwargs): yield [element * 2 for element in batch] class OverrideTypeInference(beam.DoFn): + def process_batch(self, batch, *args, **kwargs): yield [element * 2 for element in batch] @@ -58,6 +62,7 @@ def get_output_batch_type(self, input_element_type): class EitherDoFn(beam.DoFn): + def process(self, element: int, *args, **kwargs) -> Iterator[float]: yield element / 2 @@ -67,6 +72,7 @@ def process_batch(self, batch: List[int], *args, class ElementToBatchDoFn(beam.DoFn): + @beam.DoFn.yields_batches def process(self, element: int, *args, **kwargs) -> Iterator[List[int]]: yield [element] * element @@ -76,6 +82,7 @@ def infer_output_type(self, input_element_type): class BatchToElementDoFn(beam.DoFn): + @beam.DoFn.yields_elements def process_batch(self, batch: List[int], *args, **kwargs) -> Iterator[Tuple[int, int]]: @@ -146,6 +153,7 @@ def get_test_class_name(cls, num, params_dict): ], class_name_func=get_test_class_name) class BatchDoFnParameterizedTest(unittest.TestCase): + def test_process_defined(self): self.assertEqual(self.dofn._process_defined, self.expected_process_defined) @@ -169,6 +177,7 @@ def test_can_yield_batches(self): class NoInputAnnotation(beam.DoFn): + def process_batch(self, batch, *args, **kwargs): yield [element * 2 for element in batch] @@ -177,6 +186,7 @@ class MismatchedBatchProducingDoFn(beam.DoFn): """A DoFn that produces batches from both process and process_batch, with mismatched return types (one yields floats, the other ints). Should yield a construction time error when applied.""" + @beam.DoFn.yields_batches def process(self, element: int, *args, **kwargs) -> Iterator[List[int]]: yield [element] @@ -190,6 +200,7 @@ class MismatchedElementProducingDoFn(beam.DoFn): """A DoFn that produces elements from both process and process_batch, with mismatched return types (one yields floats, the other ints). Should yield a construction time error when applied.""" + def process(self, element: int, *args, **kwargs) -> Iterator[float]: yield element / 2 @@ -199,12 +210,14 @@ def process_batch(self, batch: List[int], *args, **kwargs) -> Iterator[int]: class NoElementOutputAnnotation(beam.DoFn): + def process_batch(self, batch: List[int], *args, **kwargs) -> Iterator[List[int]]: yield [element * 2 for element in batch] class BatchDoFnTest(unittest.TestCase): + def test_map_pardo(self): # verify batch dofn accessors work well with beam.Map generated DoFn # checking this in parameterized test causes a circular reference issue @@ -223,7 +236,9 @@ def test_no_input_annotation_raises(self): _ = pc | beam.ParDo(NoInputAnnotation()) def test_unsupported_dofn_param_raises(self): + class BadParam(beam.DoFn): + @no_type_check def process_batch(self, batch: List[int], key=beam.DoFn.KeyParam): yield batch * key @@ -265,8 +280,7 @@ def test_cant_infer_batchconverter_input_raises(self): pc = p | beam.Create(['a', 'b', 'c']) with self.assertRaisesRegex( - TypeError, - # Error should mention "input", and the name of the DoFn + TypeError, # Error should mention "input", and the name of the DoFn r'input.*BatchDoFn.*'): _ = pc | beam.ParDo(BatchDoFn()) diff --git a/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py b/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py index 647e08db7aaa..49a772440acc 100644 --- a/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py +++ b/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py @@ -37,6 +37,7 @@ @pytest.mark.it_validatesrunner class CombineFnLifecycleTest(unittest.TestCase): + def setUp(self): self.pipeline = TestPipeline(is_integration_test=True) @@ -60,6 +61,7 @@ def test_combining_value_state(self): {'runner': fn_api_runner.FnApiRunner, 'pickler': 'cloudpickle'}, ]) # yapf: disable class LocalCombineFnLifecycleTest(unittest.TestCase): + def tearDown(self): CallSequenceEnforcingCombineFn.instances.clear() diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py index 8b05e8da1df5..210a2e7d27e7 100644 --- a/sdks/python/apache_beam/transforms/combiners.py +++ b/sdks/python/apache_beam/transforms/combiners.py @@ -74,6 +74,7 @@ class CombinerWithoutDefaults(ptransform.PTransform): """Super class to inherit without_defaults to built-in Combiners.""" + def __init__(self, has_defaults=True): super().__init__() self.has_defaults = has_defaults @@ -89,8 +90,10 @@ def without_defaults(self): class Mean(object): """Combiners for computing arithmetic means of elements.""" + class Globally(CombinerWithoutDefaults): """combiners.Mean.Globally computes the arithmetic mean of the elements.""" + def expand(self, pcoll): if self.has_defaults: return pcoll | core.CombineGlobally(MeanCombineFn()) @@ -99,6 +102,7 @@ def expand(self, pcoll): class PerKey(ptransform.PTransform): """combiners.Mean.PerKey finds the means of the values for each key.""" + def expand(self, pcoll): return pcoll | core.CombinePerKey(MeanCombineFn()) @@ -109,6 +113,7 @@ def expand(self, pcoll): @with_output_types(float) class MeanCombineFn(core.CombineFn): """CombineFn for computing an arithmetic mean.""" + def create_accumulator(self): return (0, 0) @@ -136,10 +141,12 @@ def for_input_type(self, input_type): class Count(object): """Combiners for counting elements.""" + @with_input_types(T) @with_output_types(int) class Globally(CombinerWithoutDefaults): """combiners.Count.Globally counts the total number of elements.""" + def expand(self, pcoll): if self.has_defaults: return pcoll | core.CombineGlobally(CountCombineFn()) @@ -150,6 +157,7 @@ def expand(self, pcoll): @with_output_types(Tuple[K, int]) class PerKey(ptransform.PTransform): """combiners.Count.PerKey counts how many elements each unique key has.""" + def expand(self, pcoll): return pcoll | core.CombinePerKey(CountCombineFn()) @@ -157,6 +165,7 @@ def expand(self, pcoll): @with_output_types(Tuple[T, int]) class PerElement(ptransform.PTransform): """combiners.Count.PerElement counts how many times each element occurs.""" + def expand(self, pcoll): paired_with_void_type = typehints.Tuple[pcoll.element_type, Any] output_type = typehints.KV[pcoll.element_type, int] @@ -172,6 +181,7 @@ def expand(self, pcoll): @with_output_types(int) class CountCombineFn(core.CombineFn): """CombineFn for computing PCollection size.""" + def create_accumulator(self): return 0 @@ -201,6 +211,7 @@ class Of(CombinerWithoutDefaults): to which it is applied, where "greatest" is determined by a function supplied as the `key` or `reverse` arguments. """ + def __init__(self, n, key=None, reverse=False): """Creates a global Top operation. @@ -256,6 +267,7 @@ class PerKey(ptransform.PTransform): "greatest" is determined by a function supplied as the `key` or `reverse` arguments. """ + def __init__(self, n, key=None, reverse=False): """Creates a per-key Top operation. @@ -325,6 +337,7 @@ def SmallestPerKey(pcoll, n, *, key=None): @with_input_types(T) @with_output_types(Tuple[None, List[T]]) class _TopPerBundle(core.DoFn): + def __init__(self, n, key, reverse): self._n = n self._compare = operator.gt if reverse else None @@ -359,6 +372,7 @@ def finish_bundle(self): @with_input_types(Tuple[None, Iterable[List[T]]]) @with_output_types(List[T]) class _MergeTopPerBundle(core.DoFn): + def __init__(self, n, key, reverse): self._n = n self._compare = operator.gt if reverse else None @@ -433,6 +447,7 @@ class TopCombineFn(core.CombineFn): reverse: (optional) whether to order things smallest to largest, rather than largest to smallest """ + def __init__(self, n, key=None, reverse=False): self._n = n self._compare = operator.gt if reverse else operator.lt @@ -546,11 +561,13 @@ def extract_output(self, accumulator, *args, **kwargs): class Largest(TopCombineFn): + def default_label(self): return 'Largest(%s)' % self._n class Smallest(TopCombineFn): + def __init__(self, n): super().__init__(n, reverse=True) @@ -567,6 +584,7 @@ class Sample(object): @with_output_types(List[T]) class FixedSizeGlobally(CombinerWithoutDefaults): """Sample n elements from the input PCollection without replacement.""" + def __init__(self, n): super().__init__() self._n = n @@ -588,6 +606,7 @@ def default_label(self): @with_output_types(Tuple[K, List[V]]) class FixedSizePerKey(ptransform.PTransform): """Sample n elements associated with each key without replacement.""" + def __init__(self, n): self._n = n @@ -605,6 +624,7 @@ def default_label(self): @with_output_types(List[T]) class SampleCombineFn(core.CombineFn): """CombineFn for all Sample transforms.""" + def __init__(self, n): super().__init__() # Most of this combiner's work is done by a TopCombineFn. We could just @@ -640,6 +660,7 @@ def teardown(self): class _TupleCombineFnBase(core.CombineFn): + def __init__(self, *combiners, merge_accumulators_batch_size=None): self._combiners = [core.CombineFn.maybe_from_callable(c) for c in combiners] self._named_combiners = combiners @@ -680,21 +701,21 @@ def merge_accumulators(self, accumulators, *args, **kwargs): if len(accumulators_batch) == 1: break result = [ - c.merge_accumulators(a, *args, **kwargs) for c, - a in zip(self._combiners, zip(*accumulators_batch)) + c.merge_accumulators(a, *args, **kwargs) + for c, a in zip(self._combiners, zip(*accumulators_batch)) ] return result def compact(self, accumulator, *args, **kwargs): return [ - c.compact(a, *args, **kwargs) for c, - a in zip(self._combiners, accumulator) + c.compact(a, *args, **kwargs) + for c, a in zip(self._combiners, accumulator) ] def extract_output(self, accumulator, *args, **kwargs): return tuple( - c.extract_output(a, *args, **kwargs) for c, - a in zip(self._combiners, accumulator)) + c.extract_output(a, *args, **kwargs) + for c, a in zip(self._combiners, accumulator)) def teardown(self, *args, **kwargs): for c in reversed(self._combiners): @@ -708,11 +729,11 @@ class TupleCombineFn(_TupleCombineFnBase): combining the k-th element of each tuple with the k-th CombineFn, outputting a new N-tuple of combined values. """ + def add_input(self, accumulator, element, *args, **kwargs): return [ - c.add_input(a, e, *args, **kwargs) for c, - a, - e in zip(self._combiners, accumulator, element) + c.add_input(a, e, *args, **kwargs) + for c, a, e in zip(self._combiners, accumulator, element) ] def with_common_input(self): @@ -726,10 +747,11 @@ class SingleInputTupleCombineFn(_TupleCombineFnBase): applying each CombineFn to each input, producing an N-tuple of the outputs corresponding to each of the N CombineFn's outputs. """ + def add_input(self, accumulator, element, *args, **kwargs): return [ - c.add_input(a, element, *args, **kwargs) for c, - a in zip(self._combiners, accumulator) + c.add_input(a, element, *args, **kwargs) + for c, a in zip(self._combiners, accumulator) ] @@ -737,6 +759,7 @@ def add_input(self, accumulator, element, *args, **kwargs): @with_output_types(List[T]) class ToList(CombinerWithoutDefaults): """A global CombineFn that condenses a PCollection into a single list.""" + def expand(self, pcoll): if self.has_defaults: return pcoll | self.label >> core.CombineGlobally(ToListCombineFn()) @@ -749,6 +772,7 @@ def expand(self, pcoll): @with_output_types(List[T]) class ToListCombineFn(core.CombineFn): """CombineFn for to_list.""" + def create_accumulator(self): return [] @@ -767,6 +791,7 @@ def extract_output(self, accumulator): @with_output_types(T) class ConcatListCombineFn(core.CombineFn): """CombineFn for concatenating lists together.""" + def create_accumulator(self): return [] @@ -789,6 +814,7 @@ class ToDict(CombinerWithoutDefaults): If multiple values are associated with the same key, only one of the values will be present in the resulting dict. """ + def expand(self, pcoll): if self.has_defaults: return pcoll | self.label >> core.CombineGlobally(ToDictCombineFn()) @@ -801,6 +827,7 @@ def expand(self, pcoll): @with_output_types(Dict[K, V]) class ToDictCombineFn(core.CombineFn): """CombineFn for to_dict.""" + def create_accumulator(self): return {} @@ -823,6 +850,7 @@ def extract_output(self, accumulator): @with_output_types(Set[T]) class ToSet(CombinerWithoutDefaults): """A global CombineFn that condenses a PCollection into a set.""" + def expand(self, pcoll): if self.has_defaults: return pcoll | self.label >> core.CombineGlobally(ToSetCombineFn()) @@ -835,6 +863,7 @@ def expand(self, pcoll): @with_output_types(Set[T]) class ToSetCombineFn(core.CombineFn): """CombineFn for ToSet.""" + def create_accumulator(self): return set() @@ -851,6 +880,7 @@ def extract_output(self, accumulator): class _CurriedFn(core.CombineFn): """Wrapped CombineFn with extra arguments.""" + def __init__(self, fn, args, kwargs): self.fn = fn self.args = args @@ -890,6 +920,7 @@ def curry_combine_fn(fn, args, kwargs): class PhasedCombineFnExecutor(object): """Executor for phases of combine operations.""" + def __init__(self, phase, fn, args, kwargs): self.combine_fn = curry_combine_fn(fn, args, kwargs) @@ -927,11 +958,13 @@ def convert_to_accumulator(self, element): class Latest(object): """Combiners for computing the latest element""" + @with_input_types(T) @with_output_types(T) class Globally(CombinerWithoutDefaults): """Compute the element with the latest timestamp from a PCollection.""" + @staticmethod def add_timestamp(element, timestamp=core.DoFn.TimestampParam): return [(element, timestamp)] @@ -955,6 +988,7 @@ def expand(self, pcoll): class PerKey(ptransform.PTransform): """Compute elements with the latest timestamp for each key from a keyed PCollection""" + @staticmethod def add_timestamp(element, timestamp=core.DoFn.TimestampParam): key, value = element @@ -973,6 +1007,7 @@ def expand(self, pcoll): class LatestCombineFn(core.CombineFn): """CombineFn to get the element with the latest timestamp from a PCollection.""" + def create_accumulator(self): return (None, window.MIN_TIMESTAMP) diff --git a/sdks/python/apache_beam/transforms/combiners_test.py b/sdks/python/apache_beam/transforms/combiners_test.py index a8979239f831..b48a3db32203 100644 --- a/sdks/python/apache_beam/transforms/combiners_test.py +++ b/sdks/python/apache_beam/transforms/combiners_test.py @@ -62,6 +62,7 @@ class SortedConcatWithCounters(beam.CombineFn): """CombineFn for incrementing three different counters: counter, distribution, gauge, at the same time concatenating words.""" + def __init__(self): beam.CombineFn.__init__(self) self.word_counter = Metrics.counter(self.__class__, 'word_counter') @@ -91,6 +92,7 @@ def extract_output(self, acc): class CombineTest(unittest.TestCase): + def test_builtin_combines(self): with TestPipeline() as pipeline: @@ -208,6 +210,7 @@ def test_top_key(self): [('a', [4, 3, 2])]) def test_sharded_top_combine_fn(self): + def test_combine_fn(combine_fn, shards, expected): accumulators = [ combine_fn.add_inputs(combine_fn.create_accumulator(), shard) @@ -222,6 +225,7 @@ def test_combine_fn(combine_fn, shards, expected): [1000, 999, 999, 998, 998]) def test_combine_per_key_top_display_data(self): + def individual_test_per_key_dd(combineFn): transform = beam.CombinePerKey(combineFn) dd = DisplayData.create_from(transform) @@ -238,6 +242,7 @@ def individual_test_per_key_dd(combineFn): individual_test_per_key_dd(combine.Largest(5)) def test_combine_sample_display_data(self): + def individual_test_per_key_dd(sampleFn, n): trs = [sampleFn(n)] for transform in trs: @@ -289,7 +294,9 @@ def test_top_shorthands(self): assert_that(result_kbot, equal_to([('a', [0, 1, 1, 1])]), label='kbot') def test_top_no_compact(self): + class TopCombineFnNoCompact(combine.TopCombineFn): + def compact(self, accumulator): return accumulator @@ -312,6 +319,7 @@ def compact(self, accumulator): assert_that(result_kbot, equal_to([('a', [0, 1, 1, 1])]), label='KBot') def test_global_sample(self): + def is_good_sample(actual): assert len(actual) == 1 assert sorted(actual[0]) in [[1, 1, 2], [1, 2, 2]], actual @@ -343,6 +351,7 @@ def test_per_key_sample(self): result = pcoll | 'sample' >> combine.Sample.FixedSizePerKey(3) def matcher(): + def match(actual): for _, samples in actual: equal_to([3])([len(samples)]) @@ -402,6 +411,7 @@ def __init__(self): CountedAccumulator.count += 1 class CountedAccumulatorCombineFn(beam.CombineFn): + def create_accumulator(self): return CountedAccumulator() @@ -433,6 +443,7 @@ def test_to_list_and_to_dict1(self): | 'to list wo defaults' >> combine.ToList().without_defaults()) def matcher(expected): + def match(actual): equal_to(expected[0])(actual[0]) @@ -457,6 +468,7 @@ def test_to_list_and_to_dict2(self): | 'to dict wo defaults' >> combine.ToDict().without_defaults()) def matcher(): + def match(actual): equal_to([1])([len(actual)]) equal_to(pairs)(actual[0].items()) @@ -481,6 +493,7 @@ def test_to_set(self): | 'to set wo defaults' >> combine.ToSet().without_defaults()) def matcher(expected): + def match(actual): equal_to(expected[0])(actual[0]) @@ -500,7 +513,9 @@ def test_combine_globally_without_default(self): assert_that(result, equal_to([])) def test_combine_globally_with_default_side_input(self): + class SideInputCombine(PTransform): + def expand(self, pcoll): side = pcoll | CombineGlobally(sum).as_singleton_view() main = pcoll.pipeline | Create([None]) @@ -792,6 +807,7 @@ def test_custormized_counters_in_combine_fn_empty(self): class LatestTest(unittest.TestCase): + def test_globally(self): l = [ window.TimestampedValue(3, 100), @@ -845,6 +861,7 @@ def test_per_key_empty(self): class LatestCombineFnTest(unittest.TestCase): + def setUp(self): self.fn = combine.LatestCombineFn() @@ -891,7 +908,9 @@ def test_with_input_types_decorator_violation(self): @pytest.mark.it_validatesrunner class CombineValuesTest(unittest.TestCase): + def test_gbk_immediately_followed_by_combine(self): + def merge(vals): return "".join(vals) @@ -912,6 +931,7 @@ def merge(vals): # @pytest.mark.it_validatesrunner class TimestampCombinerTest(unittest.TestCase): + def test_combiner_earliest(self): """Test TimestampCombiner with EARLIEST.""" options = PipelineOptions(streaming=True) @@ -980,6 +1000,7 @@ def test_combiner_latest(self): class CombineGloballyTest(unittest.TestCase): + def test_combine_globally_for_unbounded_source_with_default(self): # this error is logged since the below combination is ill-defined. with self.assertLogs() as captured_logs: diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index b420d1d66d09..1cbeca21e932 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -141,6 +141,7 @@ class DoFnProcessContext(DoFnContext): for this element. Not used by the pipeline code. """ + def __init__(self, label, element=None, state=None): """Initialize a processing context object with an element and state. @@ -180,6 +181,7 @@ class ProcessContinuation(object): If produced, indicates that there is more work to be done for the current input element. """ + def __init__(self, resume_delay=0): """Initializes a ProcessContinuation object. @@ -252,6 +254,7 @@ class RestrictionProvider(object): be invoked with a single parameter of type ``Timestamp`` or as an integer that gives the watermark in number of seconds. """ + def create_tracker(self, restriction): # type: (...) -> iobase.RestrictionTracker @@ -393,8 +396,8 @@ def get_function_args_defaults(f): parameter.POSITIONAL_ONLY, parameter.POSITIONAL_OR_KEYWORD ] args = [ - name for name, - p in signature.parameters.items() if p.kind in _SUPPORTED_ARG_TYPES + name for name, p in signature.parameters.items() + if p.kind in _SUPPORTED_ARG_TYPES ] defaults = [ p.default for p in signature.parameters.values() @@ -417,6 +420,7 @@ class WatermarkEstimatorProvider(object): or, if no WatermarkEstimatorProvider is provided, the DoFn itself must be a WatermarkEstimatorProvider. """ + def initial_estimator_state(self, element, restriction): """Returns the initial state of the WatermarkEstimator with given element and restriction. @@ -436,6 +440,7 @@ def estimator_state_coder(self): class _DoFnParam(object): """DoFn parameter.""" + def __init__(self, param_id): self.param_id = param_id @@ -453,6 +458,7 @@ def __repr__(self): class _RestrictionDoFnParam(_DoFnParam): """Restriction Provider DoFn parameter.""" + def __init__(self, restriction_provider=None): # type: (typing.Optional[RestrictionProvider]) -> None if (restriction_provider is not None and @@ -466,6 +472,7 @@ def __init__(self, restriction_provider=None): class _StateDoFnParam(_DoFnParam): """State DoFn parameter.""" + def __init__(self, state_spec): # type: (StateSpec) -> None if not isinstance(state_spec, StateSpec): @@ -476,6 +483,7 @@ def __init__(self, state_spec): class _TimerDoFnParam(_DoFnParam): """Timer DoFn parameter.""" + def __init__(self, timer_spec): # type: (TimerSpec) -> None if not isinstance(timer_spec, TimerSpec): @@ -486,6 +494,7 @@ def __init__(self, timer_spec): class _BundleFinalizerParam(_DoFnParam): """Bundle Finalization DoFn parameter.""" + def __init__(self): self._callbacks = [] self.param_id = "FinalizeBundle" @@ -513,6 +522,7 @@ def reset(self): class _WatermarkEstimatorParam(_DoFnParam): """WatermarkEstimator DoFn parameter.""" + def __init__( self, watermark_estimator_provider: typing. @@ -527,6 +537,7 @@ def __init__( class _ContextParam(_DoFnParam): + def __init__( self, context_manager_constructor, args=(), kwargs=None, *, name=None): class_name = self.__class__.__name__.strip('_') @@ -649,6 +660,7 @@ def from_callable(fn): def unbounded_per_element(): """A decorator on process fn specifying that the fn performs an unbounded amount of work per input element.""" + def wrapper(process_fn): process_fn.unbounded_per_element = True return process_fn @@ -801,12 +813,12 @@ def default_type_hints(self): self.process_batch) or typehints.decorators.IOTypeHints.empty() # Then we deconflict with the typehint from process, if it exists - if (process_batch_type_hints.output_types != - typehints.decorators.IOTypeHints.empty().output_types): - if (process_type_hints.output_types != - typehints.decorators.IOTypeHints.empty().output_types and - process_batch_type_hints.output_types != - process_type_hints.output_types): + if (process_batch_type_hints.output_types + != typehints.decorators.IOTypeHints.empty().output_types): + if (process_type_hints.output_types + != typehints.decorators.IOTypeHints.empty().output_types and + process_batch_type_hints.output_types + != process_type_hints.output_types): raise TypeError( f"DoFn {self!r} yields element from both process and " "process_batch, but they have mismatched output typehints:\n" @@ -986,6 +998,7 @@ class CallableWrapperDoFn(DoFn): The purpose of this class is to conveniently wrap simple functions and use them in transforms. """ + def __init__(self, fn, fullargspec=None): """Initializes a CallableWrapperDoFn object wrapping a callable. @@ -1080,6 +1093,7 @@ class CombineFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn): **apply** will be called with an empty list at expansion time to get the default value. """ + def default_label(self): return self.__class__.__name__ @@ -1246,6 +1260,7 @@ def get_accumulator_coder(self): class _ReiterableChain(object): """Like itertools.chain, but allowing re-iteration.""" + def __init__(self, iterables): self.iterables = iterables @@ -1378,6 +1393,7 @@ class NoSideInputsCallableWrapperCombineFn(CallableWrapperCombineFn): This is identical to its parent, but avoids accepting and passing *args and **kwargs for efficiency as they are known to be empty. """ + def create_accumulator(self): return [] @@ -1412,6 +1428,7 @@ class PartitionFn(WithTypeHints): A PartitionFn specifies how individual values in a PCollection will be placed into separate partitions, indexed by an integer. """ + def default_label(self): return self.__class__.__name__ @@ -1439,6 +1456,7 @@ class CallableWrapperPartitionFn(PartitionFn): Instances of this class wrap simple functions for use in Partition operations. """ + def __init__(self, fn): """Initializes a PartitionFn object wrapping a callable. @@ -1561,6 +1579,7 @@ class ParDo(PTransformWithSideInputs): replaced by values from the :class:`~apache_beam.pvalue.PCollection` in the exact positions where they appear in the argument lists. """ + def __init__(self, fn, *args, **kwargs): super().__init__(fn, *args, **kwargs) # TODO(robertwb): Change all uses of the dofn attribute to use fn instead. @@ -1888,10 +1907,10 @@ def to_runner_api_parameter(self, context, **extra_kwargs): # type: ignore[over # transformation is currently irreversible given how # remove_objects_from_args and insert_values_in_args # are currently implemented. - side_inputs={(SIDE_INPUT_PREFIX + '%s') % ix: - si.to_runner_api(context) - for ix, - si in enumerate(self.side_inputs)})) + side_inputs={ + (SIDE_INPUT_PREFIX + '%s') % ix: si.to_runner_api(context) + for ix, si in enumerate(self.side_inputs) + })) @staticmethod @PTransform.register_urn( @@ -1909,8 +1928,8 @@ def from_runner_api_parameter(unused_ptransform, pardo_payload, context): # to_runner_api_parameter above). indexed_side_inputs = [( get_sideinput_index(tag), - pvalue.AsSideInput.from_runner_api(si, context)) for tag, - si in pardo_payload.side_inputs.items()] + pvalue.AsSideInput.from_runner_api(si, context)) + for tag, si in pardo_payload.side_inputs.items()] result.side_inputs = [si for _, si in sorted(indexed_side_inputs)] return result @@ -1932,6 +1951,7 @@ def _add_type_constraint_from_consumer(self, full_label, input_type_hints): class _MultiParDo(PTransform): + def __init__(self, do_transform, tags, main_tag, allow_unknown_tags=None): super().__init__(do_transform.label) self._do_transform = do_transform @@ -1953,8 +1973,10 @@ class DoFnInfo(object): """This class represents the state in the ParDoPayload's function spec, which is the actual DoFn together with some data required for invoking it. """ + @staticmethod def register_stateless_dofn(urn): + def wrapper(cls): StatelessDoFnInfo.REGISTERED_DOFNS[urn] = cls cls._stateless_dofn_urn = urn @@ -1989,6 +2011,7 @@ def serialized_dofn_data(self): class PickledDoFnInfo(DoFnInfo): + def __init__(self, serialized_data): self._serialized_data = serialized_data @@ -2257,6 +2280,7 @@ def FlatMapTuple(fn, *args, **kwargs): # pylint: disable=invalid-name class _ExceptionHandlingWrapper(ptransform.PTransform): """Implementation of ParDo.with_exception_handling.""" + def __init__( self, fn, @@ -2312,6 +2336,7 @@ def expand(self, pcoll): if self._threshold < 1.0: class MaybeWindow(ptransform.PTransform): + @staticmethod def expand(pcoll): if self._threshold_windowing: @@ -2344,6 +2369,7 @@ def check_threshold(bad, total, threshold, window=DoFn.WindowParam): class _ExceptionHandlingWrapperDoFn(DoFn): + def __init__( self, fn, dead_letter_tag, exc_class, partial, on_failure_callback): self._fn = fn @@ -2387,6 +2413,7 @@ def process(self, *args, **kwargs): class _PValueWithErrors(object): """This wraps a PCollection such that transforms can be chained in a linear manner while still accumulating any errors.""" + def __init__(self, pcoll, exception_handling_args, upstream_errors=()): self._pcoll = pcoll self._exception_handling_args = exception_handling_args @@ -2453,6 +2480,7 @@ class _MaybePValueWithErrors(object): exception_handling_args is non-trivial. It is useful for handling error-catching and non-error-catching code in a uniform manner. """ + def __init__(self, pvalue, exception_handling_args=None): if isinstance(pvalue, _PValueWithErrors): assert exception_handling_args is None @@ -2490,6 +2518,7 @@ def as_result(self, error_post_processing=None): class _SubprocessDoFn(DoFn): """Process method run in a subprocess, turning hard crashes into exceptions. """ + def __init__(self, fn, timeout=None): self._fn = fn self._serialized_fn = pickler.dumps(fn) @@ -2590,6 +2619,7 @@ def _remote_teardown(cls): class _TimeoutDoFn(DoFn): """Process method run in a separate thread allowing timeouts. """ + def __init__(self, fn, timeout=None): self._fn = fn self._timeout = timeout @@ -2796,6 +2826,7 @@ def as_singleton_view(self): return self._clone(as_view=True) def expand(self, pcoll): + def add_input_types(transform): type_hints = self.get_type_hints() if type_hints.input_types: @@ -2894,6 +2925,7 @@ def from_runner_api_parameter(unused_ptransform, combine_payload, context): @DoFnInfo.register_stateless_dofn(python_urns.KEY_WITH_NONE_DOFN) class _KeyWithNone(DoFn): + def process(self, v): yield None, v @@ -2917,6 +2949,7 @@ class CombinePerKey(PTransformWithSideInputs): Returns: A PObject holding the result of the combine operation. """ + def with_hot_key_fanout(self, fanout): """A per-key combine operation like self but with two levels of aggregation. @@ -3014,6 +3047,7 @@ def runner_api_requires_keyed_input(self): # TODO(robertwb): Rename to CombineGroupedValues? class CombineValues(PTransformWithSideInputs): + def make_fn(self, fn, has_side_inputs): return CombineFn.maybe_from_callable(fn, has_side_inputs) @@ -3146,6 +3180,7 @@ def expand(self, pcoll): 'SlidingWindows. See: https://github.com/apache/beam/issues/20528') class SplitHotCold(DoFn): + def start_bundle(self): # Spreading a hot key across all possible sub-keys for all bundles # would defeat the goal of not overwhelming downstream reducers @@ -3164,6 +3199,7 @@ def process(self, element): yield pvalue.TaggedOutput('hot', ((self._nonce % fanout, key), value)) class PreCombineFn(CombineFn): + def __init__(self): # Deepcopy of the combine_fn to avoid sharing state between lifted # stages when using cloudpickle. @@ -3185,6 +3221,7 @@ def extract_output(accumulator): return (True, accumulator) class PostCombineFn(CombineFn): + def __init__(self): # Deepcopy of the combine_fn to avoid sharing state between lifted # stages when using cloudpickle. @@ -3238,7 +3275,9 @@ class GroupByKey(PTransform): The implementation here is used only when run on the local direct runner. """ + class ReifyWindows(DoFn): + def process( self, element, window=DoFn.WindowParam, timestamp=DoFn.TimestampParam): try: @@ -3454,6 +3493,7 @@ def _unpickle_dynamic_named_tuple(type_name, field_names, values): class _GroupAndAggregate(PTransform): + def __init__(self, grouping, aggregations): self._grouping = grouping self._aggregations = aggregations @@ -3487,8 +3527,7 @@ def expand(self, pcoll): TupleCombineFn( *[combine_fn for _, combine_fn, __ in self._aggregations])) | MapTuple( - lambda key, - value: _dynamic_named_tuple('Result', result_fields) + lambda key, value: _dynamic_named_tuple('Result', result_fields) (*(key + value)))) @@ -3529,10 +3568,13 @@ def expand(self, pcoll): return ( _MaybePValueWithErrors(pcoll, self._exception_handling_args) | Map( lambda x: pvalue.Row( - **{name: expr(x) - for name, expr in self._fields}))).as_result() + **{ + name: expr(x) + for name, expr in self._fields + }))).as_result() def infer_output_type(self, input_type): + def extract_return_type(expr): expr_hints = get_type_hints(expr) if (expr_hints and expr_hints.has_simple_output_type() and @@ -3562,8 +3604,10 @@ class Partition(PTransformWithSideInputs): The result of this PTransform is a simple list of the output PCollections representing each of n partitions, in order. """ + class ApplyPartitionFnFn(DoFn): """A DoFn that applies a PartitionFn.""" + def process(self, element, partitionfn, n, *args, **kwargs): partition = partitionfn.partition_for(element, n, *args, **kwargs) if not 0 <= partition < n: @@ -3586,14 +3630,16 @@ def expand(self, pcoll): class Windowing(object): - def __init__(self, - windowfn, # type: WindowFn - triggerfn=None, # type: typing.Optional[TriggerFn] - accumulation_mode=None, # type: typing.Optional[beam_runner_api_pb2.AccumulationMode.Enum.ValueType] - timestamp_combiner=None, # type: typing.Optional[beam_runner_api_pb2.OutputTime.Enum.ValueType] - allowed_lateness=0, # type: typing.Union[int, float] - environment_id=None, # type: typing.Optional[str] - ): + + def __init__( + self, + windowfn, # type: WindowFn + triggerfn=None, # type: typing.Optional[TriggerFn] + accumulation_mode=None, # type: typing.Optional[beam_runner_api_pb2.AccumulationMode.Enum.ValueType] + timestamp_combiner=None, # type: typing.Optional[beam_runner_api_pb2.OutputTime.Enum.ValueType] + allowed_lateness=0, # type: typing.Union[int, float] + environment_id=None, # type: typing.Optional[str] + ): """Class representing the window strategy. Args: @@ -3715,8 +3761,10 @@ class WindowInto(ParDo): element with the same input value and timestamp, with its new set of windows determined by the windowing function. """ + class WindowIntoFn(DoFn): """A DoFn that applies a WindowInto operation.""" + def __init__(self, windowing): # type: (Windowing) -> None self.windowing = windowing @@ -3827,6 +3875,7 @@ class Flatten(PTransform): if there's a chance there may be none), this argument is the only way to provide pipeline information and should be considered mandatory. """ + def __init__(self, **kwargs): super().__init__() self.pipeline = kwargs.pop( @@ -3884,6 +3933,7 @@ class FlattenWith(PTransform): Root PTransforms can be passed as well as PCollections, in which case their outputs will be flattened. """ + def __init__(self, *others): self._others = others @@ -3903,6 +3953,7 @@ def expand(self, pcoll): class Create(PTransform): """A transform that creates a PCollection from an iterable.""" + def __init__(self, values, reshuffle=True): """Initializes a Create transform. @@ -3957,6 +4008,7 @@ def expand(self, pbegin): # transforms (e.g. Write). class MaybeReshuffle(PTransform): + def expand(self, pcoll): if len(serialized_values) > 1 and reshuffle: from apache_beam.transforms.util import Reshuffle @@ -3995,6 +4047,7 @@ def _create_source(serialized_values, coder): @typehints.with_output_types(bytes) class Impulse(PTransform): """Impulse primitive.""" + def expand(self, pbegin): if not isinstance(pbegin, pvalue.PBegin): raise TypeError( diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py index 54afb365d2d8..32e1b9645ac2 100644 --- a/sdks/python/apache_beam/transforms/core_test.py +++ b/sdks/python/apache_beam/transforms/core_test.py @@ -34,12 +34,15 @@ class TestDoFn1(beam.DoFn): + def process(self, element): yield element class TestDoFn2(beam.DoFn): + def process(self, element): + def inner_func(x): yield x @@ -48,6 +51,7 @@ def inner_func(x): class TestDoFn3(beam.DoFn): """mixing return and yield is not allowed""" + def process(self, element): if not element: return -1 @@ -56,6 +60,7 @@ def process(self, element): class TestDoFn4(beam.DoFn): """test the variable name containing return""" + def process(self, element): my_return = element yield my_return @@ -63,6 +68,7 @@ def process(self, element): class TestDoFn5(beam.DoFn): """test the variable name containing yield""" + def process(self, element): my_yield = element return my_yield @@ -70,6 +76,7 @@ def process(self, element): class TestDoFn6(beam.DoFn): """test the variable name containing return""" + def process(self, element): return_test = element yield return_test @@ -77,6 +84,7 @@ def process(self, element): class TestDoFn7(beam.DoFn): """test the variable name containing yield""" + def process(self, element): yield_test = element return yield_test @@ -84,6 +92,7 @@ def process(self, element): class TestDoFn8(beam.DoFn): """test the code containing yield and yield from""" + def process(self, element): if not element: yield from [1, 2, 3] @@ -92,6 +101,7 @@ def process(self, element): class TestDoFn9(beam.DoFn): + def process(self, element): if len(element) > 3: raise ValueError('Not allowed to have long elements') @@ -100,23 +110,27 @@ def process(self, element): class TestDoFn10(beam.DoFn): """test process returning None explicitly""" + def process(self, element): return None class TestDoFn11(beam.DoFn): """test process returning None (no return and no yield)""" + def process(self, element): pass class TestDoFn12(beam.DoFn): """test process returning None (return statement without a value)""" + def process(self, element): return class CreateTest(unittest.TestCase): + @pytest.fixture(autouse=True) def inject_fixtures(self, caplog): self._caplog = caplog @@ -159,11 +173,14 @@ def test_dofn_with_implicit_return_none_return_without_value(self): class PartitionTest(unittest.TestCase): + def test_partition_boundedness(self): + def partition_fn(val, num_partitions): return val % num_partitions class UnboundedDoFn(beam.DoFn): + @beam.DoFn.unbounded_per_element() def process(self, element): yield element @@ -187,6 +204,7 @@ def process(self, element): class FlattenTest(unittest.TestCase): + def test_flatten_identical_windows(self): with beam.testing.test_pipeline.TestPipeline() as p: source1 = p | "c1" >> beam.Create( @@ -218,6 +236,7 @@ def test_flatten_mismatched_windows(self): class ExceptionHandlingTest(unittest.TestCase): + def test_routes_failures(self): with beam.Pipeline() as pipeline: good, bad = ( @@ -280,6 +299,7 @@ def failure_callback(e, el): class FlatMapTest(unittest.TestCase): + def test_default(self): with beam.Pipeline() as pipeline: diff --git a/sdks/python/apache_beam/transforms/create_source.py b/sdks/python/apache_beam/transforms/create_source.py index 2fbc925afdda..be2b65062ed0 100644 --- a/sdks/python/apache_beam/transforms/create_source.py +++ b/sdks/python/apache_beam/transforms/create_source.py @@ -23,6 +23,7 @@ class _CreateSource(iobase.BoundedSource): """Internal source that is used by Create()""" + def __init__(self, serialized_values, coder): self._coder = coder self._serialized_values = [] diff --git a/sdks/python/apache_beam/transforms/create_test.py b/sdks/python/apache_beam/transforms/create_test.py index 37f32d478008..1e5e9aa68cf1 100644 --- a/sdks/python/apache_beam/transforms/create_test.py +++ b/sdks/python/apache_beam/transforms/create_test.py @@ -32,6 +32,7 @@ class CreateTest(unittest.TestCase): + def setUp(self): self.coder = FastPrimitivesCoder() @@ -139,6 +140,7 @@ def test_create_uses_coder_for_pickling(self): class _Unpicklable(object): + def __init__(self, value): self.value = value @@ -153,6 +155,7 @@ def __setstate__(self, state): class _UnpicklableCoder(coders.Coder): + def encode(self, value): return str(value.value).encode() diff --git a/sdks/python/apache_beam/transforms/cy_combiners.py b/sdks/python/apache_beam/transforms/cy_combiners.py index b5cc7493a29a..deb89022fe42 100644 --- a/sdks/python/apache_beam/transforms/cy_combiners.py +++ b/sdks/python/apache_beam/transforms/cy_combiners.py @@ -66,6 +66,7 @@ def __hash__(self): class CountAccumulator(object): + def __init__(self): self.value = 0 @@ -84,6 +85,7 @@ def extract_output(self): class SumInt64Accumulator(object): + def __init__(self): self.value = 0 @@ -114,6 +116,7 @@ def extract_output(self): class MinInt64Accumulator(object): + def __init__(self): self.value = INT64_MAX @@ -137,6 +140,7 @@ def extract_output(self): class MaxInt64Accumulator(object): + def __init__(self): self.value = INT64_MIN @@ -160,6 +164,7 @@ def extract_output(self): class MeanInt64Accumulator(object): + def __init__(self): self.sum = 0 self.count = 0 @@ -192,6 +197,7 @@ def extract_output(self): class DistributionInt64Accumulator(object): + def __init__(self): self.sum = 0 self.count = 0 @@ -262,6 +268,7 @@ class DistributionInt64Fn(AccumulatorCombineFn): class SumDoubleAccumulator(object): + def __init__(self): self.value = 0 @@ -278,6 +285,7 @@ def extract_output(self): class MinDoubleAccumulator(object): + def __init__(self): self.value = _POS_INF @@ -296,6 +304,7 @@ def extract_output(self): class MaxDoubleAccumulator(object): + def __init__(self): self.value = _NEG_INF @@ -314,6 +323,7 @@ def extract_output(self): class MeanDoubleAccumulator(object): + def __init__(self): self.sum = 0 self.count = 0 @@ -349,6 +359,7 @@ class MeanFloatFn(AccumulatorCombineFn): class AllAccumulator(object): + def __init__(self): self.value = True @@ -364,6 +375,7 @@ def extract_output(self): class AnyAccumulator(object): + def __init__(self): self.value = False diff --git a/sdks/python/apache_beam/transforms/dataflow_distribution_counter_test.py b/sdks/python/apache_beam/transforms/dataflow_distribution_counter_test.py index eda658e245c6..f73e312b5c37 100644 --- a/sdks/python/apache_beam/transforms/dataflow_distribution_counter_test.py +++ b/sdks/python/apache_beam/transforms/dataflow_distribution_counter_test.py @@ -26,6 +26,7 @@ class DataflowDistributionAccumulatorTest(unittest.TestCase): + def test_calculate_bucket_index_with_input_0(self): counter = DataflowDistributionCounter() index = counter.calculate_bucket_index(0) diff --git a/sdks/python/apache_beam/transforms/deduplicate.py b/sdks/python/apache_beam/transforms/deduplicate.py index 916b071fdf02..ab0aa998b257 100644 --- a/sdks/python/apache_beam/transforms/deduplicate.py +++ b/sdks/python/apache_beam/transforms/deduplicate.py @@ -55,6 +55,7 @@ class DeduplicatePerKey(ptransform.PTransform): Does not preserve any order the input PCollection might have had. """ + def __init__(self, processing_time_duration=None, event_time_duration=None): if processing_time_duration is None and event_time_duration is None: raise ValueError( @@ -72,6 +73,7 @@ def _create_deduplicate_fn(self): event_time_duration = self.event_time_duration class DeduplicationFn(core.DoFn): + def process( self, kv, @@ -113,6 +115,7 @@ class Deduplicate(ptransform.PTransform): value as input and uses value as key to deduplicate among certain amount of time duration. """ + def __init__(self, processing_time_duration=None, event_time_duration=None): if processing_time_duration is None and event_time_duration is None: raise ValueError( diff --git a/sdks/python/apache_beam/transforms/deduplicate_test.py b/sdks/python/apache_beam/transforms/deduplicate_test.py index 392dac20fb82..e5d38716cbe0 100644 --- a/sdks/python/apache_beam/transforms/deduplicate_test.py +++ b/sdks/python/apache_beam/transforms/deduplicate_test.py @@ -43,6 +43,7 @@ @pytest.mark.no_sickbay_streaming @pytest.mark.it_validatesrunner class DeduplicateTest(unittest.TestCase): + def __init__(self, *args, **kwargs): self.runner = None self.options = None diff --git a/sdks/python/apache_beam/transforms/display.py b/sdks/python/apache_beam/transforms/display.py index 14cd485d1f8e..2453716f3028 100644 --- a/sdks/python/apache_beam/transforms/display.py +++ b/sdks/python/apache_beam/transforms/display.py @@ -63,6 +63,7 @@ class HasDisplayData(object): It implements only the display_data method and a _get_display_data_namespace method. """ + def display_data(self): # type: () -> dict @@ -96,6 +97,7 @@ def _get_display_data_namespace(self): class DisplayData(object): """ Static display data associated with a pipeline component. """ + def __init__( self, namespace, # type: str @@ -136,6 +138,7 @@ def to_proto(self): # type: (...) -> List[beam_runner_api_pb2.DisplayData] """Returns a List of Beam proto representation of Display data.""" + def create_payload(dd) -> Optional[beam_runner_api_pb2.LabelledPayload]: try: display_data_dict = dd.get_dict() @@ -221,8 +224,7 @@ def create_from_options(cls, pipeline_options): items = { k: (v if DisplayDataItem._get_value_type(v) is not None else str(v)) - for k, - v in pipeline_options.display_data().items() + for k, v in pipeline_options.display_data().items() } return cls(pipeline_options._get_display_data_namespace(), items) diff --git a/sdks/python/apache_beam/transforms/display_test.py b/sdks/python/apache_beam/transforms/display_test.py index c91ad41e8d1c..f47c0ffc8960 100644 --- a/sdks/python/apache_beam/transforms/display_test.py +++ b/sdks/python/apache_beam/transforms/display_test.py @@ -92,11 +92,13 @@ def describe_to(self, description): class DisplayDataTest(unittest.TestCase): + def test_display_data_item_matcher(self): with self.assertRaises(ValueError): DisplayDataItemMatcher() def test_inheritance_ptransform(self): + class MyTransform(beam.PTransform): pass @@ -106,6 +108,7 @@ class MyTransform(beam.PTransform): self.assertEqual(display_pt.display_data(), {}) def test_inheritance_dofn(self): + class MyDoFn(beam.DoFn): pass @@ -114,7 +117,9 @@ class MyDoFn(beam.DoFn): self.assertEqual(display_dofn.display_data(), {}) def test_unsupported_type_display_data(self): + class MyDisplayComponent(HasDisplayData): + def display_data(self): return {'item_key': 'item_value'} @@ -122,7 +127,9 @@ def display_data(self): DisplayData.create_from_options(MyDisplayComponent()) def test_value_provider_display_data(self): + class TestOptions(PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_value_provider_argument( @@ -161,7 +168,9 @@ def test_create_list_display_data(self): 'extra_packages', str(['package1', 'package2'])))) def test_unicode_type_display_data(self): + class MyDoFn(beam.DoFn): + def display_data(self): return { 'unicode_string': 'my string', @@ -177,7 +186,9 @@ def test_base_cases(self): """ Tests basic display data cases (key:value, key:dict) It does not test subcomponent inclusion """ + class MyDoFn(beam.DoFn): + def __init__(self, my_display_data=None): self.my_display_data = my_display_data @@ -219,7 +230,9 @@ def display_data(self): hc.assert_that(dd.items, hc.has_items(*expected_items)) def test_drop_if_none(self): + class MyDoFn(beam.DoFn): + def display_data(self): return { 'some_val': DisplayDataItem('something').drop_if_none(), @@ -236,7 +249,9 @@ def display_data(self): hc.assert_that(dd.items, hc.has_items(*expected_items)) def test_subcomponent(self): + class SpecialDoFn(beam.DoFn): + def display_data(self): return {'dofn_value': 42} diff --git a/sdks/python/apache_beam/transforms/dofn_lifecycle_test.py b/sdks/python/apache_beam/transforms/dofn_lifecycle_test.py index 73068657ac4a..633a2a65e57c 100644 --- a/sdks/python/apache_beam/transforms/dofn_lifecycle_test.py +++ b/sdks/python/apache_beam/transforms/dofn_lifecycle_test.py @@ -28,6 +28,7 @@ class CallSequenceEnforcingDoFn(beam.DoFn): + def __init__(self): self._setup_called = False self._start_bundle_calls = 0 @@ -76,6 +77,7 @@ def teardown(self): @pytest.mark.it_validatesrunner class DoFnLifecycleTest(unittest.TestCase): + def test_dofn_lifecycle(self): with TestPipeline() as p: _ = ( @@ -86,6 +88,7 @@ def test_dofn_lifecycle(self): class LocalDoFnLifecycleTest(unittest.TestCase): + def test_dofn_lifecycle(self): from apache_beam.runners.direct import direct_runner from apache_beam.runners.portability import fn_api_runner diff --git a/sdks/python/apache_beam/transforms/enrichment.py b/sdks/python/apache_beam/transforms/enrichment.py index 5bb1e2024e79..f67b1f348c61 100644 --- a/sdks/python/apache_beam/transforms/enrichment.py +++ b/sdks/python/apache_beam/transforms/enrichment.py @@ -88,6 +88,7 @@ class EnrichmentSourceHandler(Caller[InputT, OutputT]): Ensure that the implementation of ``__call__`` method returns a tuple of `beam.Row` objects. """ + def get_cache_key(self, request: InputT) -> str: """Returns the request to be cached. This is how the response will be looked up in the cache as well. @@ -130,6 +131,7 @@ class Enrichment(beam.PTransform[beam.PCollection[InputT], client-side adaptive throttling using :class:`apache_beam.io.components.adaptive_throttler.AdaptiveThrottler`. """ + def __init__( self, source_handler: EnrichmentSourceHandler, diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py index 06b40bf38cc1..70adb326ee19 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py @@ -75,6 +75,7 @@ class BigQueryEnrichmentHandler(EnrichmentSourceHandler[Union[Row, list[Row]], NOTE: Elements cannot be batched when using the `query_fn` parameter. """ + def __init__( self, project: str, diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py index 98ac6244910c..ac5fb067e9e9 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py @@ -29,6 +29,7 @@ class TestBigQueryEnrichment(unittest.TestCase): + def setUp(self) -> None: self.project = 'apache-beam-testing' diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py index 6bf57cefacbe..f3606cb557db 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py @@ -51,6 +51,7 @@ def _row_key_fn(request: beam.Row) -> bytes: class ValidateResponse(beam.DoFn): """ValidateResponse validates if a PCollection of `beam.Row` has the required fields.""" + def __init__( self, n_fields: int, @@ -152,6 +153,7 @@ def create_rows(table): @pytest.mark.uses_testcontainer class TestBigTableEnrichment(unittest.TestCase): + def setUp(self): self.project_id = 'apache-beam-testing' self.instance_id = 'beam-test' diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_test.py index 1c5cb4064e0e..9942099e3266 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_test.py @@ -26,6 +26,7 @@ class TestBigTableEnrichmentHandler(unittest.TestCase): + @parameterized.expand([('product_id', _row_key_fn), ('', None)]) def test_bigtable_enrichment_invalid_args(self, row_key, row_key_fn): with self.assertRaises(ValueError): diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store.py b/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store.py index f8e8b4db1d7f..c6d698cf5a06 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store.py @@ -91,6 +91,7 @@ class FeastFeatureStoreEnrichmentHandler(EnrichmentSourceHandler[beam.Row, transform. To filter the features to enrich, use the `join_fn` param in :class:`apache_beam.transforms.enrichment.Enrichment`. """ + def __init__( self, feature_store_yaml_path: str, diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store_it_test.py index 9c4dab3d68b8..e42232559de4 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store_it_test.py @@ -48,6 +48,7 @@ def _entity_row_fn(request: beam.Row) -> Mapping[str, Any]: @pytest.mark.uses_feast class TestFeastEnrichmentHandler(unittest.TestCase): + def setUp(self) -> None: self.feature_store_yaml_file = ( 'gs://apache-beam-testing-enrichment/' diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store_test.py index 764086ab2c98..9259481a0bbc 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store_test.py @@ -29,6 +29,7 @@ class TestFeastFeatureStoreHandler(unittest.TestCase): + def setUp(self) -> None: self.feature_store_yaml_file = ( 'gs://apache-beam-testing-enrichment/' diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store.py b/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store.py index b6de3aa1c826..16590ce6384f 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store.py @@ -61,6 +61,7 @@ class VertexAIFeatureStoreEnrichmentHandler(EnrichmentSourceHandler[beam.Row, exist. So make sure the feature store instance exists or set `exception_level` as `ExceptionLevel.RAISE`. """ + def __init__( self, project: str, @@ -201,6 +202,7 @@ class VertexAIFeatureStoreLegacyEnrichmentHandler(EnrichmentSourceHandler): object.You can specify the features names using `feature_ids` to fetch specific features. """ + def __init__( self, project: str, diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store_it_test.py index c5482309a251..c49e6c45110f 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store_it_test.py @@ -46,6 +46,7 @@ class ValidateResponse(beam.DoFn): """ValidateResponse validates if a PCollection of `beam.Row` has the required fields.""" + def __init__(self, expected_fields): self.expected_fields = expected_fields @@ -64,6 +65,7 @@ def process(self, element: beam.Row, *args, **kwargs): @pytest.mark.uses_testcontainer class TestVertexAIFeatureStoreHandler(unittest.TestCase): + def setUp(self) -> None: self.project = 'apache-beam-testing' self.location = 'us-central1' diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store_test.py index 352146ecc078..13d06d5b0bcc 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store_test.py @@ -28,6 +28,7 @@ class TestVertexAIFeatureStoreHandlerInit(unittest.TestCase): + def test_raise_error_duplicate_api_endpoint_online_store(self): with self.assertRaises(ValueError): _ = VertexAIFeatureStoreEnrichmentHandler( diff --git a/sdks/python/apache_beam/transforms/enrichment_it_test.py b/sdks/python/apache_beam/transforms/enrichment_it_test.py index 4a45fae2e869..75d93275e5d6 100644 --- a/sdks/python/apache_beam/transforms/enrichment_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_it_test.py @@ -55,6 +55,7 @@ class SampleHTTPEnrichment(EnrichmentSourceHandler[Request, beam.Row]): """Implements ``EnrichmentSourceHandler`` to call the ``EchoServiceGrpc``'s HTTP handler. """ + def __init__(self, url: str): self.url = url + '/v1/echo' # append path to the mock API. @@ -93,6 +94,7 @@ def __call__(self, request: Request, *args, **kwargs): class ValidateFields(beam.DoFn): """ValidateFields validates if a PCollection of `beam.Row` has certain fields.""" + def __init__(self, n_fields: int, fields: List[str]): self.n_fields = n_fields self._fields = fields diff --git a/sdks/python/apache_beam/transforms/enrichment_test.py b/sdks/python/apache_beam/transforms/enrichment_test.py index 23b5f1828c15..e6fd6a4e641e 100644 --- a/sdks/python/apache_beam/transforms/enrichment_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_test.py @@ -28,6 +28,7 @@ class TestEnrichmentTransform(unittest.TestCase): + def test_cross_join(self): left = {'id': 1, 'key': 'city'} right = {'id': 1, 'value': 'durham'} diff --git a/sdks/python/apache_beam/transforms/environments.py b/sdks/python/apache_beam/transforms/environments.py index 77704e0522b2..d34198d6ae24 100644 --- a/sdks/python/apache_beam/transforms/environments.py +++ b/sdks/python/apache_beam/transforms/environments.py @@ -111,11 +111,12 @@ class Environment(object): _known_urns = {} # type: Dict[str, Tuple[Optional[type], ConstructorFn]] _urn_to_env_cls = {} # type: Dict[str, type] - def __init__(self, + def __init__( + self, capabilities=(), # type: Iterable[str] artifacts=(), # type: Iterable[beam_runner_api_pb2.ArtifactInformation] resource_hints=None, # type: Optional[Mapping[str, bytes]] - ): + ): # type: (...) -> None self._capabilities = capabilities self._artifacts = sorted(artifacts, key=lambda x: x.SerializeToString()) @@ -174,26 +175,29 @@ def register_urn( @classmethod @overload - def register_urn(cls, - urn, # type: str - parameter_type, # type: Type[T] - constructor # type: Callable[[T, Iterable[str], Iterable[beam_runner_api_pb2.ArtifactInformation], PipelineContext], Any] - ): + def register_urn( + cls, + urn, # type: str + parameter_type, # type: Type[T] + constructor # type: Callable[[T, Iterable[str], Iterable[beam_runner_api_pb2.ArtifactInformation], PipelineContext], Any] + ): # type: (...) -> None pass @classmethod @overload - def register_urn(cls, - urn, # type: str - parameter_type, # type: None - constructor # type: Callable[[bytes, Iterable[str], Iterable[beam_runner_api_pb2.ArtifactInformation], PipelineContext], Any] - ): + def register_urn( + cls, + urn, # type: str + parameter_type, # type: None + constructor # type: Callable[[bytes, Iterable[str], Iterable[beam_runner_api_pb2.ArtifactInformation], PipelineContext], Any] + ): # type: (...) -> None pass @classmethod def register_urn(cls, urn, parameter_type, constructor=None): + def register(constructor): if isinstance(constructor, type): constructor.from_runner_api_parameter = register( @@ -232,10 +236,11 @@ def to_runner_api(self, context): resource_hints=self.resource_hints()) @classmethod - def from_runner_api(cls, - proto, # type: Optional[beam_runner_api_pb2.Environment] - context # type: PipelineContext - ): + def from_runner_api( + cls, + proto, # type: Optional[beam_runner_api_pb2.Environment] + context # type: PipelineContext + ): # type: (...) -> Optional[Environment] if proto is None or not proto.urn: return None @@ -282,16 +287,18 @@ def from_options(cls, options): @Environment.register_urn(common_urns.environments.DEFAULT.urn, None) class DefaultEnvironment(Environment): """Used as a stub when context is missing a default environment.""" + def to_runner_api_parameter(self, context): return common_urns.environments.DEFAULT.urn, None @staticmethod - def from_runner_api_parameter(payload, # type: beam_runner_api_pb2.DockerPayload + def from_runner_api_parameter( + payload, # type: beam_runner_api_pb2.DockerPayload capabilities, # type: Iterable[str] artifacts, # type: Iterable[beam_runner_api_pb2.ArtifactInformation] resource_hints, # type: Mapping[str, bytes] context # type: PipelineContext - ): + ): # type: (...) -> DefaultEnvironment return DefaultEnvironment( capabilities=capabilities, @@ -302,6 +309,7 @@ def from_runner_api_parameter(payload, # type: beam_runner_api_pb2.DockerPayloa @Environment.register_urn( common_urns.environments.DOCKER.urn, beam_runner_api_pb2.DockerPayload) class DockerEnvironment(Environment): + def __init__( self, container_image=None, # type: Optional[str] @@ -339,12 +347,13 @@ def to_runner_api_parameter(self, context): beam_runner_api_pb2.DockerPayload(container_image=self.container_image)) @staticmethod - def from_runner_api_parameter(payload, # type: beam_runner_api_pb2.DockerPayload + def from_runner_api_parameter( + payload, # type: beam_runner_api_pb2.DockerPayload capabilities, # type: Iterable[str] artifacts, # type: Iterable[beam_runner_api_pb2.ArtifactInformation] resource_hints, # type: Mapping[str, bytes] context # type: PipelineContext - ): + ): # type: (...) -> DockerEnvironment return DockerEnvironment( container_image=payload.container_image, @@ -401,6 +410,7 @@ def default_docker_image(): @Environment.register_urn( common_urns.environments.PROCESS.urn, beam_runner_api_pb2.ProcessPayload) class ProcessEnvironment(Environment): + def __init__( self, command, # type: str @@ -451,12 +461,13 @@ def to_runner_api_parameter(self, context): os=self.os, arch=self.arch, command=self.command, env=self.env)) @staticmethod - def from_runner_api_parameter(payload, + def from_runner_api_parameter( + payload, capabilities, # type: Iterable[str] artifacts, # type: Iterable[beam_runner_api_pb2.ArtifactInformation] resource_hints, # type: Mapping[str, bytes] context # type: PipelineContext - ): + ): # type: (...) -> ProcessEnvironment return ProcessEnvironment( command=payload.command, @@ -510,6 +521,7 @@ def from_options(cls, options): @Environment.register_urn( common_urns.environments.EXTERNAL.urn, beam_runner_api_pb2.ExternalPayload) class ExternalEnvironment(Environment): + def __init__( self, url, # type: str @@ -547,12 +559,13 @@ def to_runner_api_parameter(self, context): params=self.params)) @staticmethod - def from_runner_api_parameter(payload, # type: beam_runner_api_pb2.ExternalPayload + def from_runner_api_parameter( + payload, # type: beam_runner_api_pb2.ExternalPayload capabilities, # type: Iterable[str] artifacts, # type: Iterable[beam_runner_api_pb2.ArtifactInformation] resource_hints, # type: Mapping[str, bytes] context # type: PipelineContext - ): + ): # type: (...) -> ExternalEnvironment return ExternalEnvironment( payload.endpoint.url, @@ -605,17 +618,19 @@ def resolve_anyof_environment(env_proto, *preferred_types): @Environment.register_urn(python_urns.EMBEDDED_PYTHON, None) class EmbeddedPythonEnvironment(Environment): + def to_runner_api_parameter(self, context): # type: (PipelineContext) -> Tuple[str, None] return python_urns.EMBEDDED_PYTHON, None @staticmethod - def from_runner_api_parameter(unused_payload, # type: None + def from_runner_api_parameter( + unused_payload, # type: None capabilities, # type: Iterable[str] artifacts, # type: Iterable[beam_runner_api_pb2.ArtifactInformation] resource_hints, # type: Mapping[str, bytes] context # type: PipelineContext - ): + ): # type: (...) -> EmbeddedPythonEnvironment return EmbeddedPythonEnvironment(capabilities, artifacts, resource_hints) @@ -636,6 +651,7 @@ def default(cls): @Environment.register_urn(python_urns.EMBEDDED_PYTHON_GRPC, bytes) class EmbeddedPythonGrpcEnvironment(Environment): + def __init__( self, state_cache_size=None, @@ -682,12 +698,13 @@ def to_runner_api_parameter(self, context): return python_urns.EMBEDDED_PYTHON_GRPC, payload @staticmethod - def from_runner_api_parameter(payload, # type: bytes + def from_runner_api_parameter( + payload, # type: bytes capabilities, # type: Iterable[str] artifacts, # type: Iterable[beam_runner_api_pb2.ArtifactInformation] resource_hints, # type: Mapping[str, bytes] context # type: PipelineContext - ): + ): # type: (...) -> EmbeddedPythonGrpcEnvironment if payload: config = EmbeddedPythonGrpcEnvironment.parse_config( @@ -742,17 +759,19 @@ def default(cls): @Environment.register_urn(python_urns.EMBEDDED_PYTHON_LOOPBACK, None) class PythonLoopbackEnvironment(EmbeddedPythonEnvironment): """Used as a stub when the loopback worker has not yet been started.""" + def to_runner_api_parameter(self, context): # type: (PipelineContext) -> Tuple[str, None] return python_urns.EMBEDDED_PYTHON_LOOPBACK, None @staticmethod - def from_runner_api_parameter(unused_payload, # type: None + def from_runner_api_parameter( + unused_payload, # type: None capabilities, # type: Iterable[str] artifacts, # type: Iterable[beam_runner_api_pb2.ArtifactInformation] resource_hints, # type: Mapping[str, bytes] context # type: PipelineContext - ): + ): # type: (...) -> PythonLoopbackEnvironment return PythonLoopbackEnvironment( capabilities=capabilities, @@ -762,6 +781,7 @@ def from_runner_api_parameter(unused_payload, # type: None @Environment.register_urn(python_urns.SUBPROCESS_SDK, bytes) class SubprocessSDKEnvironment(Environment): + def __init__( self, command_string, # type: str @@ -789,12 +809,13 @@ def to_runner_api_parameter(self, context): return python_urns.SUBPROCESS_SDK, self.command_string.encode('utf-8') @staticmethod - def from_runner_api_parameter(payload, # type: bytes + def from_runner_api_parameter( + payload, # type: bytes capabilities, # type: Iterable[str] artifacts, # type: Iterable[beam_runner_api_pb2.ArtifactInformation] resource_hints, # type: Mapping[str, bytes] context # type: PipelineContext - ): + ): # type: (...) -> SubprocessSDKEnvironment return SubprocessSDKEnvironment( payload.decode('utf-8'), capabilities, artifacts, resource_hints) @@ -819,6 +840,7 @@ def from_command_string(cls, command_string): common_urns.environments.ANYOF.urn, beam_runner_api_pb2.AnyOfEnvironmentPayload) class AnyOfEnvironment(Environment): + def __init__(self, environments): self._environments = environments @@ -832,12 +854,13 @@ def to_runner_api_parameter(self, context): ])) @staticmethod - def from_runner_api_parameter(payload, # type: beam_runner_api_pb2.AnyOfEnvironmentPayload + def from_runner_api_parameter( + payload, # type: beam_runner_api_pb2.AnyOfEnvironmentPayload capabilities, # type: Iterable[str] artifacts, # type: Iterable[beam_runner_api_pb2.ArtifactInformation] resource_hints, # type: Mapping[str, bytes] context # type: PipelineContext - ): + ): # type: (...) -> AnyOfEnvironment return AnyOfEnvironment([ Environment.from_runner_api(env, context) diff --git a/sdks/python/apache_beam/transforms/environments_test.py b/sdks/python/apache_beam/transforms/environments_test.py index c32a85579fcb..becc8e541758 100644 --- a/sdks/python/apache_beam/transforms/environments_test.py +++ b/sdks/python/apache_beam/transforms/environments_test.py @@ -39,6 +39,7 @@ class RunnerApiTest(unittest.TestCase): + def test_environment_encoding(self): for environment in (DockerEnvironment(), DockerEnvironment(container_image='img'), @@ -84,6 +85,7 @@ def test_default_capabilities(self): class EnvironmentOptionsTest(unittest.TestCase): + def setUp(self) -> None: self.tmp_dir = tempfile.TemporaryDirectory() self.actual_mkdtemp = tempfile.mkdtemp diff --git a/sdks/python/apache_beam/transforms/error_handling.py b/sdks/python/apache_beam/transforms/error_handling.py index 8671c66a12e0..ad6eb1097e31 100644 --- a/sdks/python/apache_beam/transforms/error_handling.py +++ b/sdks/python/apache_beam/transforms/error_handling.py @@ -50,6 +50,7 @@ class ErrorHandler: In this case, any non-recoverable errors should fail the pipeline (e.g. propagate exceptions in `process` methods) rather than silently ignore errors. """ + def __init__(self, consumer): self._consumer = consumer self._creation_traceback = traceback.format_stack()[-2] @@ -99,6 +100,7 @@ def verify_closed(self): class _IdentityPTransform(transforms.PTransform): + def expand(self, pcoll): return pcoll @@ -109,6 +111,7 @@ class CollectingErrorHandler(ErrorHandler): This ErrorHandler requires the set of errors be retrieved via `output()` and consumed (or explicitly discarded). """ + def __init__(self): super().__init__(_IdentityPTransform()) self._creation_traceback = traceback.format_stack()[-2] diff --git a/sdks/python/apache_beam/transforms/error_handling_test.py b/sdks/python/apache_beam/transforms/error_handling_test.py index 4d8c2d23dc14..44a0887345d2 100644 --- a/sdks/python/apache_beam/transforms/error_handling_test.py +++ b/sdks/python/apache_beam/transforms/error_handling_test.py @@ -25,6 +25,7 @@ class PTransformWithErrors(beam.PTransform): + def __init__(self, limit): self._limit = limit self._error_handler = None @@ -62,6 +63,7 @@ def exception_throwing_map(x, limit): class ErrorHandlingTest(unittest.TestCase): + def test_error_handling(self): with beam.Pipeline() as p: pcoll = p | beam.Create(['a', 'bb', 'cccc']) @@ -87,6 +89,7 @@ def test_error_handling_pardo(self): assert_that(error_pcoll, equal_to(['error: cccc']), label='CheckBad') def test_error_handling_pardo_with_exception_handling_kwargs(self): + def side_effect(*args): beam._test_error_handling_pardo_with_exception_handling_kwargs_val = True diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py index 9ca5886f4cc2..54b06a8edef1 100644 --- a/sdks/python/apache_beam/transforms/external.py +++ b/sdks/python/apache_beam/transforms/external.py @@ -96,6 +96,7 @@ class PayloadBuilder(object): """ Abstract base class for building payloads to pass to ExternalTransform. """ + def build(self): """ :return: ExternalConfigurationPayload @@ -138,6 +139,7 @@ class SchemaBasedPayloadBuilder(PayloadBuilder): Base class for building payloads based on a schema that provides type information for each configuration value to encode. """ + def _get_named_tuple_instance(self): raise NotImplementedError() @@ -152,6 +154,7 @@ class ImplicitSchemaPayloadBuilder(SchemaBasedPayloadBuilder): """ Build a payload that generates a schema from the provided values. """ + def __init__(self, values): self._values = values @@ -163,8 +166,8 @@ def _get_named_tuple_instance(self): } schema = named_fields_to_schema([ - (key, convert_to_typing_type(instance_to_type(value))) for key, - value in values.items() + (key, convert_to_typing_type(instance_to_type(value))) + for key, value in values.items() ]) return named_tuple_from_schema(schema)(**values) @@ -173,6 +176,7 @@ class NamedTupleBasedPayloadBuilder(SchemaBasedPayloadBuilder): """ Build a payload based on a NamedTuple schema. """ + def __init__(self, tuple_instance): """ :param tuple_instance: an instance of a typing.NamedTuple @@ -185,6 +189,7 @@ def _get_named_tuple_instance(self): class SchemaTransformPayloadBuilder(PayloadBuilder): + def __init__(self, identifier, **kwargs): self._identifier = identifier self._kwargs = kwargs @@ -207,12 +212,14 @@ def build(self): class ExplicitSchemaTransformPayloadBuilder(SchemaTransformPayloadBuilder): + def __init__(self, identifier, schema_proto, **kwargs): self._identifier = identifier self._schema_proto = schema_proto self._kwargs = kwargs def build(self): + def dict_to_row_recursive(field_type, py_value): if py_value is None: return None @@ -227,8 +234,7 @@ def dict_to_row_recursive(field_type, py_value): elif type_info == 'map_type': return { key: dict_to_row_recursive(field_type.map_type.value_type, value) - for key, - value in py_value.items() + for key, value in py_value.items() } else: return py_value @@ -401,6 +407,7 @@ class SchemaAwareExternalTransform(ptransform.PTransform): keys map to the field names of the schema of the SchemaTransform (in-order). """ + def __init__( self, identifier, @@ -533,6 +540,7 @@ class JavaExternalTransform(ptransform.PTransform): :param classpath: (Optional) A list paths to additional jars to place on the expansion service classpath. """ + def __init__(self, class_name, expansion_service=None, classpath=None): if expansion_service and classpath: raise ValueError( @@ -585,6 +593,7 @@ class AnnotationBasedPayloadBuilder(SchemaBasedPayloadBuilder): """ Build a payload based on an external transform's type annotations. """ + def __init__(self, transform, **values): """ :param transform: a PTransform instance or class. type annotations will @@ -596,8 +605,8 @@ def __init__(self, transform, **values): def _get_named_tuple_instance(self): schema = named_fields_to_schema([ - (k, convert_to_typing_type(v)) for k, - v in self._transform.__init__.__annotations__.items() + (k, convert_to_typing_type(v)) + for k, v in self._transform.__init__.__annotations__.items() if k in self._values ]) return named_tuple_from_schema(schema)(**self._values) @@ -607,6 +616,7 @@ class DataclassBasedPayloadBuilder(SchemaBasedPayloadBuilder): """ Build a payload based on an external transform that uses dataclasses. """ + def __init__(self, transform): """ :param transform: a dataclass-decorated PTransform instance from which to @@ -771,8 +781,7 @@ def fix_output(pcoll, tag): self._outputs = { tag: fix_output(result_context.pcollections.get_by_id(pcoll_id), tag) - for tag, - pcoll_id in self._expanded_transform.outputs.items() + for tag, pcoll_id in self._expanded_transform.outputs.items() } return self._output_to_pvalueish(self._outputs) @@ -807,6 +816,7 @@ def service(expansion_service): yield stub def _resolve_artifacts(self, components, service, dest): + def _resolve_artifacts_for(env): if env.urn == common_urns.environments.ANYOF.urn: env.CopyFrom( @@ -893,13 +903,11 @@ def _normalize(coder_proto): subtransforms=proto.subtransforms, inputs={ tag: pcoll_renames.get(pcoll, pcoll) - for tag, - pcoll in proto.inputs.items() + for tag, pcoll in proto.inputs.items() }, outputs={ tag: pcoll_renames.get(pcoll, pcoll) - for tag, - pcoll in proto.outputs.items() + for tag, pcoll in proto.outputs.items() }, display_data=proto.display_data, environment_id=proto.environment_id) @@ -914,13 +922,11 @@ def _normalize(coder_proto): subtransforms=self._expanded_transform.subtransforms, inputs={ tag: pcoll_renames.get(pcoll, pcoll) - for tag, - pcoll in self._expanded_transform.inputs.items() + for tag, pcoll in self._expanded_transform.inputs.items() }, outputs={ tag: pcoll_renames.get(pcoll, pcoll) - for tag, - pcoll in self._expanded_transform.outputs.items() + for tag, pcoll in self._expanded_transform.outputs.items() }, annotations=self._expanded_transform.annotations, environment_id=self._expanded_transform.environment_id) @@ -928,6 +934,7 @@ def _normalize(coder_proto): class ExpansionAndArtifactRetrievalStub( beam_expansion_api_pb2_grpc.ExpansionServiceStub): + def __init__(self, channel, **kwargs): self._channel = channel self._kwargs = kwargs @@ -959,6 +966,7 @@ class JavaJarExpansionService(object): expansion service using the jar file. These arguments will be appended to the default arguments. """ + def __init__( self, path_to_jar, extra_args=None, classpath=None, append_args=None): if extra_args and append_args: @@ -1059,6 +1067,7 @@ class BeamJarExpansionService(JavaJarExpansionService): expansion service using the jar file. These arguments will be appended to the default arguments. """ + def __init__( self, gradle_target, diff --git a/sdks/python/apache_beam/transforms/external_it_test.py b/sdks/python/apache_beam/transforms/external_it_test.py index e24b70f3d3d7..9dad86ba8f7a 100644 --- a/sdks/python/apache_beam/transforms/external_it_test.py +++ b/sdks/python/apache_beam/transforms/external_it_test.py @@ -33,10 +33,13 @@ class ExternalTransformIT(unittest.TestCase): + @pytest.mark.it_postcommit def test_job_python_from_python_it(self): + @ptransform.PTransform.register_urn('simple', None) class SimpleTransform(ptransform.PTransform): + def expand(self, pcoll): return pcoll | beam.Map(lambda x: 'Simple(%s)' % x) diff --git a/sdks/python/apache_beam/transforms/external_java.py b/sdks/python/apache_beam/transforms/external_java.py index e3984fa8ef20..e737f592d980 100644 --- a/sdks/python/apache_beam/transforms/external_java.py +++ b/sdks/python/apache_beam/transforms/external_java.py @@ -51,6 +51,7 @@ class JavaExternalTransformTest(object): expansion_service_port: Optional[int] = None class _RunWithExpansion(object): + def __init__(self): self._server = None diff --git a/sdks/python/apache_beam/transforms/external_test.py b/sdks/python/apache_beam/transforms/external_test.py index c95a5d19f0cd..aa5b36c0f04c 100644 --- a/sdks/python/apache_beam/transforms/external_test.py +++ b/sdks/python/apache_beam/transforms/external_test.py @@ -121,6 +121,7 @@ def test_optional_error(self): class ExternalTuplePayloadTest(PayloadBase, unittest.TestCase): + def get_payload_from_typing_hints(self, values): TestSchema = typing.NamedTuple( 'TestSchema', @@ -147,6 +148,7 @@ class ExternalImplicitPayloadTest(unittest.TestCase): ImplicitSchemaPayloadBuilder works very differently than the other payload builders """ + def test_implicit_payload_builder(self): builder = ImplicitSchemaPayloadBuilder(PayloadBase.values) result = builder.build() @@ -177,6 +179,7 @@ def test_implicit_payload_builder_with_bytes(self): class ExternalTransformTest(unittest.TestCase): + def test_pipeline_generation(self): pipeline = beam.Pipeline() _ = ( @@ -389,7 +392,9 @@ def test_sanitize_java_traceback(self): class ExternalAnnotationPayloadTest(PayloadBase, unittest.TestCase): + def get_payload_from_typing_hints(self, values): + class AnnotatedTransform(beam.ExternalTransform): URN = 'beam:external:fakeurn:v1' @@ -418,6 +423,7 @@ def __init__( return get_payload(AnnotatedTransform(**values)) def get_payload_from_beam_typehints(self, values): + class AnnotatedTransform(beam.ExternalTransform): URN = 'beam:external:fakeurn:v1' @@ -447,7 +453,9 @@ def __init__( class ExternalDataclassesPayloadTest(PayloadBase, unittest.TestCase): + def get_payload_from_typing_hints(self, values): + @dataclasses.dataclass class DataclassTransform(beam.ExternalTransform): URN = 'beam:external:fakeurn:v1' @@ -463,6 +471,7 @@ class DataclassTransform(beam.ExternalTransform): return get_payload(DataclassTransform(**values)) def get_payload_from_beam_typehints(self, values): + @dataclasses.dataclass class DataclassTransform(beam.ExternalTransform): URN = 'beam:external:fakeurn:v1' @@ -479,6 +488,7 @@ class DataclassTransform(beam.ExternalTransform): class SchemaTransformPayloadBuilderTest(unittest.TestCase): + def test_build_payload(self): ComplexType = typing.NamedTuple( "ComplexType", [ @@ -508,6 +518,7 @@ def test_build_payload(self): class SchemaAwareExternalTransformTest(unittest.TestCase): + class MockDiscoveryService: # define context manager enter and exit functions def __enter__(self): @@ -575,6 +586,7 @@ def test_rearrange_kwargs_based_on_discovery(self, mock_service): class JavaClassLookupPayloadBuilderTest(unittest.TestCase): + def _verify_row(self, schema, row_payload, expected_values): row = RowCoder(schema).decode(row_payload) @@ -719,6 +731,7 @@ def test_implicit_builder_with_constructor_method(self): class JavaJarExpansionServiceTest(unittest.TestCase): + def setUp(self): SubprocessServer._cache._live_owners = set() @@ -743,6 +756,7 @@ def test_classpath(self): @mock.patch.object(JavaJarServer, 'local_jar') def test_classpath_with_url(self, local_jar): + def _side_effect_fn(path): return path[path.rindex('/') + 1:] @@ -765,6 +779,7 @@ def _side_effect_fn(path): @mock.patch.object(JavaJarServer, 'local_jar') def test_classpath_with_gradle_artifact(self, local_jar): + def _side_effect_fn(path): return path[path.rindex('/') + 1:] diff --git a/sdks/python/apache_beam/transforms/external_transform_provider.py b/sdks/python/apache_beam/transforms/external_transform_provider.py index b22cd4b24cb6..892342beb9af 100644 --- a/sdks/python/apache_beam/transforms/external_transform_provider.py +++ b/sdks/python/apache_beam/transforms/external_transform_provider.py @@ -200,6 +200,7 @@ class ExternalTransformProvider: row_restriction=restriction) | 'Some processing' >> beam.Map(...)) """ + def __init__(self, expansion_services, urn_pattern=STANDARD_URN_PATTERN): f"""Initialize an ExternalTransformProvider diff --git a/sdks/python/apache_beam/transforms/external_transform_provider_it_test.py b/sdks/python/apache_beam/transforms/external_transform_provider_it_test.py index 6b26206908fb..2c9703b781e1 100644 --- a/sdks/python/apache_beam/transforms/external_transform_provider_it_test.py +++ b/sdks/python/apache_beam/transforms/external_transform_provider_it_test.py @@ -44,6 +44,7 @@ class NameAndTypeUtilsTest(unittest.TestCase): + def test_snake_case_to_upper_camel_case(self): test_cases = [("", ""), ("test", "Test"), ("test_name", "TestName"), ("test_double_underscore", "TestDoubleUnderscore"), @@ -93,6 +94,7 @@ def test_infer_name_from_identifier(self): "EXPANSION_JARS environment var is not provided, " "indicating that jars have not been built") class ExternalTransformProviderIT(unittest.TestCase): + def test_generate_sequence_signature_and_doc(self): provider = ExternalTransformProvider( BeamJarExpansionService(":sdks:java:io:expansion-service:shadowJar")) diff --git a/sdks/python/apache_beam/transforms/fully_qualified_named_transform_test.py b/sdks/python/apache_beam/transforms/fully_qualified_named_transform_test.py index f4a4f75126a1..5da7b553d26d 100644 --- a/sdks/python/apache_beam/transforms/fully_qualified_named_transform_test.py +++ b/sdks/python/apache_beam/transforms/fully_qualified_named_transform_test.py @@ -32,6 +32,7 @@ class FullyQualifiedNamedTransformTest(unittest.TestCase): + def test_test_transform(self): with beam.Pipeline() as p: assert_that( @@ -115,11 +116,12 @@ def test_callable_transform(self): | FullyQualifiedNamedTransform( '__callable__', # the next argument is a callable to be applied ( - python_callable.PythonCallableWithSource(""" + python_callable.PythonCallableWithSource( + """ def func(pcoll, x): return pcoll | beam.Map(lambda e: e + x) """), - 'x' # arguments passed to the callable + 'x' # arguments passed to the callable ), {}), equal_to(['ax', 'bx', 'cx'])) @@ -133,16 +135,16 @@ def test_constructor_transform(self): '__constructor__', # the next argument constructs a PTransform (), { - 'source': python_callable.PythonCallableWithSource(""" + 'source': python_callable.PythonCallableWithSource( + """ class MyTransform(beam.PTransform): def __init__(self, x): self._x = x def expand(self, pcoll): return pcoll | beam.Map(lambda e: e + self._x) """), - 'x': 'x' # arguments passed to the above constructor - } - ), + 'x': 'x' # arguments passed to the above constructor + }), equal_to(['ax', 'bx', 'cx'])) def test_glob_filter(self): @@ -184,6 +186,7 @@ def test_resolve(self): class _TestTransform(beam.PTransform): + @classmethod def create(cls, *args, **kwargs): return cls(*args, **kwargs) diff --git a/sdks/python/apache_beam/transforms/periodicsequence.py b/sdks/python/apache_beam/transforms/periodicsequence.py index 613661b22957..ce0af0b47c1a 100644 --- a/sdks/python/apache_beam/transforms/periodicsequence.py +++ b/sdks/python/apache_beam/transforms/periodicsequence.py @@ -33,6 +33,7 @@ class ImpulseSeqGenRestrictionProvider(core.RestrictionProvider): + def initial_restriction(self, element): start, end, interval = element if isinstance(start, Timestamp): @@ -92,6 +93,7 @@ class ImpulseSeqGenDoFn(beam.DoFn): ImpulseSeqGenDoFn guarantees that elements would not be output prior to given runtime timestamp. ''' + @beam.DoFn.unbounded_per_element() def process( self, @@ -154,6 +156,7 @@ class PeriodicSequence(PTransform): runtime timestamp. The PCollection generated by PeriodicSequence is unbounded. ''' + def __init__(self): pass @@ -173,6 +176,7 @@ class PeriodicImpulse(PTransform): but can be used as first transform in pipeline. The PCollection generated by PeriodicImpulse is unbounded. ''' + def __init__( self, start_timestamp=Timestamp.now(), diff --git a/sdks/python/apache_beam/transforms/periodicsequence_it_test.py b/sdks/python/apache_beam/transforms/periodicsequence_it_test.py index e900ba4cd855..d476ce2e1acb 100644 --- a/sdks/python/apache_beam/transforms/periodicsequence_it_test.py +++ b/sdks/python/apache_beam/transforms/periodicsequence_it_test.py @@ -40,6 +40,7 @@ StandardOptions).streaming, "Watermark tests are only valid for streaming jobs.") class PeriodicSequenceIT(unittest.TestCase): + def setUp(self): self.test_pipeline = TestPipeline(is_integration_test=True) @@ -54,7 +55,9 @@ def test_periodicsequence_outputs_valid_watermarks_it(self): we make sure that there's not a long gap between an element being emitted and being correctly aggregated. """ + class FindLongGaps(DoFn): + def process(self, element): emitted_at, unused_count = element processed_at = time.time() diff --git a/sdks/python/apache_beam/transforms/periodicsequence_test.py b/sdks/python/apache_beam/transforms/periodicsequence_test.py index 221520c94622..c1d683e250f8 100644 --- a/sdks/python/apache_beam/transforms/periodicsequence_test.py +++ b/sdks/python/apache_beam/transforms/periodicsequence_test.py @@ -37,6 +37,7 @@ class PeriodicSequenceTest(unittest.TestCase): + def test_periodicsequence_outputs_valid_sequence(self): start_offset = 1 start_time = time.time() + start_offset diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py index 4848dc4aade8..ab2054d24e2b 100644 --- a/sdks/python/apache_beam/transforms/ptransform.py +++ b/sdks/python/apache_beam/transforms/ptransform.py @@ -121,6 +121,7 @@ class _PValueishTransform(object): This visits a PValueish, contstructing a (possibly mutated) copy. """ + def visit_nested(self, node, *args): if isinstance(node, (tuple, list)): args = [self.visit(x, *args) for x in node] @@ -138,6 +139,7 @@ def visit_nested(self, node, *args): class _SetInputPValues(_PValueishTransform): + def visit(self, node, replacements): if id(node) in replacements: return replacements[id(node)] @@ -196,6 +198,7 @@ def _release_materialized_pipeline(pipeline): class _MaterializedResult(object): + def __init__(self, pipeline_id, result_id): # type: (int, int) -> None self._pipeline_id = pipeline_id @@ -210,6 +213,7 @@ def __reduce__(self): class _MaterializedDoOutputsTuple(pvalue.DoOutputsTuple): + def __init__(self, deferred, results_by_tag): super().__init__(None, None, deferred._tags, deferred._main_tag) self._deferred = deferred @@ -223,6 +227,7 @@ def __getitem__(self, tag): class _AddMaterializationTransforms(_PValueishTransform): + def _materialize_transform(self, pipeline): result = _allocate_materialized_result(pipeline) @@ -232,6 +237,7 @@ def _materialize_transform(self, pipeline): from apache_beam import ParDo class _MaterializeValuesDoFn(DoFn): + def process(self, element): result.elements.append(element) @@ -253,6 +259,7 @@ def visit(self, node): class _FinalizeMaterialization(_PValueishTransform): + def visit(self, node): if isinstance(node, _MaterializedResult): return node.elements @@ -307,6 +314,7 @@ class _ZipPValues(object): [('a', pc1, 'A'), ('b', pc2, 'B'), ('b', pc3, 'B')] """ + def visit(self, pvalueish, sibling, pairs=None, context=None): if pairs is None: pairs = [] @@ -708,26 +716,29 @@ def register_urn( @classmethod @overload - def register_urn(cls, - urn, # type: str - parameter_type, # type: Type[T] - constructor # type: Callable[[beam_runner_api_pb2.PTransform, T, PipelineContext], Any] - ): + def register_urn( + cls, + urn, # type: str + parameter_type, # type: Type[T] + constructor # type: Callable[[beam_runner_api_pb2.PTransform, T, PipelineContext], Any] + ): # type: (...) -> None pass @classmethod @overload - def register_urn(cls, - urn, # type: str - parameter_type, # type: None - constructor # type: Callable[[beam_runner_api_pb2.PTransform, bytes, PipelineContext], Any] - ): + def register_urn( + cls, + urn, # type: str + parameter_type, # type: None + constructor # type: Callable[[beam_runner_api_pb2.PTransform, bytes, PipelineContext], Any] + ): # type: (...) -> None pass @classmethod def register_urn(cls, urn, parameter_type, constructor=None): + def register(constructor): if isinstance(constructor, type): constructor.from_runner_api_parameter = register( @@ -758,10 +769,11 @@ def to_runner_api(self, context, has_parts=False, **extra_kwargs): if isinstance(typed_param, str) else typed_param) @classmethod - def from_runner_api(cls, - proto, # type: Optional[beam_runner_api_pb2.PTransform] - context # type: PipelineContext - ): + def from_runner_api( + cls, + proto, # type: Optional[beam_runner_api_pb2.PTransform] + context # type: PipelineContext + ): # type: (...) -> Optional[PTransform] if proto is None or proto.spec is None or not proto.spec.urn: return None @@ -811,6 +823,7 @@ def _unpickle_transform(unused_ptransform, pickled_bytes, unused_context): class _ChainedPTransform(PTransform): + def __init__(self, *parts): # type: (*PTransform) -> None super().__init__(label=self._chain_label(parts)) @@ -842,6 +855,7 @@ class PTransformWithSideInputs(PTransform): and side inputs to that code. This internal-use-only class contains common functionality for :class:`PTransform` s that fit this model. """ + def __init__(self, fn, *args, **kwargs): # type: (WithTypeHints, *Any, **Any) -> None if isinstance(fn, type) and issubclass(fn, WithTypeHints): @@ -988,6 +1002,7 @@ def default_label(self): class _PTransformFnPTransform(PTransform): """A class wrapper for a function-based transform.""" + def __init__(self, fn, *args, **kwargs): super().__init__() self._fn = fn @@ -1119,6 +1134,7 @@ def label_from_callable(fn): class _NamedPTransform(PTransform): + def __init__(self, transform, label): super().__init__(label) self.transform = transform @@ -1165,6 +1181,7 @@ def annotate_yaml(constructor): Should only be used for transforms that are fully defined by their constructor arguments. """ + @wraps(constructor) def wrapper(*args, **kwargs): transform = constructor(*args, **kwargs) @@ -1199,8 +1216,7 @@ def wrapper(*args, **kwargs): # The outermost call is expected to be the most specific. 'yaml_provider': 'python', 'yaml_type': 'PyTransform', - 'yaml_args': config, - } + 'yaml_args': config, } return transform return wrapper diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index f9f6b230866e..a8edb339c4c4 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -69,6 +69,7 @@ class PTransformTest(unittest.TestCase): + def assertStartswith(self, msg, prefix): self.assertTrue( msg.startswith(prefix), '"%s" does not start with "%s"' % (msg, prefix)) @@ -113,7 +114,9 @@ def test_named_annotations(self): self.assertEqual(t.annotations(), {'test': 'value', 'another': 'value'}) def test_do_with_do_fn(self): + class AddNDoFn(beam.DoFn): + def process(self, element, addon): return [element + addon] @@ -123,7 +126,9 @@ def process(self, element, addon): assert_that(result, equal_to([11, 12, 13])) def test_do_with_unconstructed_do_fn(self): + class MyDoFn(beam.DoFn): + def process(self): pass @@ -206,6 +211,7 @@ def test_read_metrics(self): from apache_beam.io.utils import CountingSource class CounterDoFn(beam.DoFn): + def __init__(self): # This counter is unused. self.received_records = Metrics.counter( @@ -229,8 +235,10 @@ def process(self, element): @pytest.mark.it_validatesrunner def test_par_do_with_multiple_outputs_and_using_yield(self): + class SomeDoFn(beam.DoFn): """A custom DoFn using yield.""" + def process(self, element): yield element if element % 2 == 0: @@ -248,6 +256,7 @@ def process(self, element): @pytest.mark.it_validatesrunner def test_par_do_with_multiple_outputs_and_using_return(self): + def some_fn(v): if v % 2 == 0: return [v, pvalue.TaggedOutput('even', v)] @@ -267,9 +276,8 @@ def test_undeclared_outputs(self): nums = pipeline | 'Some Numbers' >> beam.Create([1, 2, 3, 4]) results = nums | 'ClassifyNumbers' >> beam.FlatMap( lambda x: [ - x, - pvalue.TaggedOutput('even' if x % 2 == 0 else 'odd', x), - pvalue.TaggedOutput('extra', x) + x, pvalue.TaggedOutput('even' if x % 2 == 0 else 'odd', x), pvalue + .TaggedOutput('extra', x) ]).with_outputs() assert_that(results[None], equal_to([1, 2, 3, 4])) assert_that(results.odd, equal_to([1, 3]), label='assert:odd') @@ -305,7 +313,9 @@ def incorrect_par_do_fn(x): self.assertStartswith(cm.exception.args[0], expected_error_prefix) def test_do_fn_with_finish(self): + class MyDoFn(beam.DoFn): + def process(self, element): pass @@ -318,6 +328,7 @@ def finish_bundle(self): # May have many bundles, but each has a start and finish. def matcher(): + def match(actual): equal_to(['finish'])(list(set(actual))) equal_to([1])([actual.count('finish')]) @@ -330,6 +341,7 @@ def test_do_fn_with_windowing_in_finish_bundle(self): windowfn = window.FixedWindows(2) class MyDoFn(beam.DoFn): + def process(self, element): yield TimestampedValue('process' + str(element), 5) @@ -343,9 +355,8 @@ def finish_bundle(self): | beam.ParDo(MyDoFn()) | WindowInto(windowfn) | 'create tuple' >> beam.Map( - lambda v, - t=beam.DoFn.TimestampParam, - w=beam.DoFn.WindowParam: (v, t, w.start, w.end))) + lambda v, t=beam.DoFn.TimestampParam, w=beam.DoFn.WindowParam: + (v, t, w.start, w.end))) expected_process = [ ('process1', Timestamp(5), Timestamp(4), Timestamp(6)) ] @@ -354,7 +365,9 @@ def finish_bundle(self): assert_that(result, equal_to(expected_process + expected_finish)) def test_do_fn_with_start(self): + class MyDoFn(beam.DoFn): + def __init__(self): self.state = 'init' @@ -372,6 +385,7 @@ def process(self, element): # May have many bundles, but each has a start and finish. def matcher(): + def match(actual): equal_to(['started'])(list(set(actual))) equal_to([1])([actual.count('started')]) @@ -381,7 +395,9 @@ def match(actual): assert_that(result, matcher()) def test_do_fn_with_start_error(self): + class MyDoFn(beam.DoFn): + def start_bundle(self): return [1] @@ -419,6 +435,7 @@ def test_filter(self): assert_that(result, equal_to([2, 4])) class _MeanCombineFn(beam.CombineFn): + def create_accumulator(self): return (0, 0) @@ -457,8 +474,7 @@ def test_combine_with_side_input_as_arg(self): divisor = pipeline | 'Divisor' >> beam.Create([2]) result = pcoll | 'Max' >> beam.CombineGlobally( # Multiples of divisor only. - lambda vals, - d: max(v for v in vals if v % d == 0), + lambda vals, d: max(v for v in vals if v % d == 0), pvalue.AsSingleton(divisor)).without_defaults() filt_vals = [v for v in values if v % 2 == 0] assert_that(result, equal_to([max(filt_vals)])) @@ -492,8 +508,7 @@ def test_combine_per_key_with_side_input_as_arg(self): ([('a', x) for x in vals_1] + [('b', x) for x in vals_2])) divisor = pipeline | 'Divisor' >> beam.Create([2]) result = pcoll | beam.CombinePerKey( - lambda vals, - d: max(v for v in vals if v % d == 0), + lambda vals, d: max(v for v in vals if v % d == 0), pvalue.AsSingleton(divisor)) # Multiples of divisor only. m_1 = max(v for v in vals_1 if v % 2 == 0) m_2 = max(v for v in vals_2 if v % 2 == 0) @@ -544,7 +559,9 @@ def test_group_by_key_allow_unsafe_triggers(self): assert_that(pcoll, equal_to([(1, [1, 2, 3, 4])])) def test_group_by_key_reiteration(self): + class MyDoFn(beam.DoFn): + def process(self, gbk_result): key, value_list = gbk_result sum_val = 0 @@ -566,6 +583,7 @@ def test_group_by_key_deterministic_coder(self): global MyObject # for pickling of the class instance class MyObject: + def __init__(self, value): self.value = value @@ -576,6 +594,7 @@ def __hash__(self): return hash(self.value) class MyObjectCoder(beam.coders.Coder): + def encode(self, o): return pickle.dumps((o.value, random.random())) @@ -589,6 +608,7 @@ def to_type_hint(self): return MyObject class MydeterministicObjectCoder(beam.coders.Coder): + def encode(self, o): return pickle.dumps(o.value) @@ -650,7 +670,9 @@ def test_group_by_key_fake_deterministic_coder(self): assert_that(grouped, equal_to([[None]])) def test_partition_with_partition_fn(self): + class SomePartitionFn(beam.PartitionFn): + def partition_for(self, element, num_partitions, offset): return (element % 3) + offset @@ -687,10 +709,7 @@ def test_partition_with_callable_and_side_input(self): side_input = pipeline | 'Side Input' >> beam.Create([100, 1000]) partitions = ( pcoll | 'part' >> beam.Partition( - lambda e, - n, - offset, - si_list: ((e + len(si_list)) % 3) + offset, + lambda e, n, offset, si_list: ((e + len(si_list)) % 3) + offset, 4, 1, pvalue.AsList(side_input))) @@ -864,7 +883,9 @@ def test_apply_to_list(self): join_input | beam.CoGroupByKey() | SortLists) def test_multi_input_ptransform(self): + class DisjointUnion(PTransform): + def expand(self, pcollections): return ( pcollections @@ -876,12 +897,14 @@ def expand(self, pcollections): self.assertEqual([1, 2, 3], sorted(([1, 2], [2, 3]) | DisjointUnion())) def test_apply_to_crazy_pvaluish(self): + class NestedFlatten(PTransform): """A PTransform taking and returning nested PValueish. Takes as input a list of dicts, and returns a dict with the corresponding values flattened. """ + def _extract_input_pvalues(self, pvalueish): pvalueish = list(pvalueish) return pvalueish, sum([list(p.values()) for p in pvalueish], []) @@ -909,6 +932,7 @@ def test_named_tuple(self): MinMax = collections.namedtuple('MinMax', ['min', 'max']) class MinMaxTransform(PTransform): + def expand(self, pcoll): return MinMax( min=pcoll | beam.CombineGlobally(min).without_defaults(), @@ -922,7 +946,9 @@ def expand(self, pcoll): self.assertEqual(sorted(flat), [1, 8]) def test_tuple_twice(self): + class Duplicate(PTransform): + def expand(self, pcoll): return pcoll, pcoll @@ -943,7 +969,9 @@ def test_resource_hint_application_is_additive(self): class TestGroupBy(unittest.TestCase): + def test_lambdas(self): + def normalize(key, values): return tuple(key) if isinstance(key, tuple) else key, sorted(values) @@ -971,11 +999,14 @@ def normalize(key, values): 'GroupTwo') def test_fields(self): + def normalize(key, values): if isinstance(key, tuple): key = beam.Row( - **{name: value - for name, value in zip(type(key)._fields, key)}) + **{ + name: value + for name, value in zip(type(key)._fields, key) + }) return key, sorted(v.value for v in values) with TestPipeline() as p: @@ -1022,10 +1053,13 @@ def normalize(key, values): 'GroupSquareNonzero') def test_aggregate(self): + def named_tuple_to_row(t): return beam.Row( - **{name: value - for name, value in zip(type(t)._fields, t)}) + **{ + name: value + for name, value in zip(type(t)._fields, t) + }) with TestPipeline() as p: pcoll = p | beam.Create(range(-2, 3)) | beam.Map( @@ -1034,15 +1068,15 @@ def named_tuple_to_row(t): assert_that( pcoll - | beam.GroupBy('square', big=lambda x: x.value > 1) - .aggregate_field('value', sum, 'sum') - .aggregate_field(lambda x: x.sign == 1, all, 'positive') + | beam.GroupBy('square', big=lambda x: x.value > 1).aggregate_field( + 'value', sum, 'sum').aggregate_field( + lambda x: x.sign == 1, all, 'positive') | beam.Map(named_tuple_to_row), equal_to([ - beam.Row(square=0, big=False, sum=0, positive=False), # [0], - beam.Row(square=1, big=False, sum=0, positive=False), # [-1, 1] + beam.Row(square=0, big=False, sum=0, positive=False), # [0], + beam.Row(square=1, big=False, sum=0, positive=False), # [-1, 1] beam.Row(square=4, big=False, sum=-2, positive=False), # [-2] - beam.Row(square=4, big=True, sum=2, positive=True), # [2] + beam.Row(square=4, big=True, sum=2, positive=True), # [2] ])) def test_pickled_field(self): @@ -1059,6 +1093,7 @@ def test_pickled_field(self): class SelectTest(unittest.TestCase): + def test_simple(self): with TestPipeline() as p: rows = ( @@ -1109,6 +1144,7 @@ def SamplePTransform(pcoll): class PTransformLabelsTest(unittest.TestCase): + class CustomTransform(beam.PTransform): pardo: Optional[beam.PTransform] = None @@ -1184,6 +1220,7 @@ def check_label(self, ptransform, expected_label): self.assertEqual(expected_label, re.sub(r'\d{3,}', '#', actual_label)) def test_default_labels(self): + def my_function(*args): pass @@ -1198,6 +1235,7 @@ def my_function(*args): self.check_label(beam.CombinePerKey(sum), 'CombinePerKey(sum)') class MyDoFn(beam.DoFn): + def process(self, unused_element): pass @@ -1212,6 +1250,7 @@ def test_label_propogation(self): self.check_label('TestCPK' >> beam.CombinePerKey(sum), 'TestCPK') class MyDoFn(beam.DoFn): + def process(self, unused_element): pass @@ -1219,6 +1258,7 @@ def process(self, unused_element): class PTransformTestDisplayData(unittest.TestCase): + def test_map_named_function(self): tr = beam.Map(len) dd = DisplayData.create_from(tr) @@ -1269,6 +1309,7 @@ def test_filter_anonymous_function(self): class PTransformTypeCheckTestCase(TypeHintTestCase): + def assertStartswith(self, msg, prefix): self.assertTrue( msg.startswith(prefix), '"%s" does not start with "%s"' % (msg, prefix)) @@ -1277,9 +1318,11 @@ def setUp(self): self.p = TestPipeline() def test_do_fn_pipeline_pipeline_type_check_satisfied(self): + @with_input_types(int, int) @with_output_types(int) class AddWithFive(beam.DoFn): + def process(self, element, five): return [element + five] @@ -1292,9 +1335,11 @@ def process(self, element, five): self.p.run() def test_do_fn_pipeline_pipeline_type_check_violated(self): + @with_input_types(str, str) @with_output_types(str) class ToUpperCaseWithPrefix(beam.DoFn): + def process(self, element, prefix): return [prefix + element.upper()] @@ -1311,6 +1356,7 @@ def test_do_fn_pipeline_runtime_type_check_satisfied(self): @with_input_types(int, int) @with_output_types(int) class AddWithNum(beam.DoFn): + def process(self, element, num): return [element + num] @@ -1328,6 +1374,7 @@ def test_do_fn_pipeline_runtime_type_check_violated(self): @with_input_types(int, int) @with_output_types(int) class AddWithNum(beam.DoFn): + def process(self, element, num): return [element + num] @@ -1340,6 +1387,7 @@ def process(self, element, num): self.p.run() def test_pardo_does_not_type_check_using_type_hint_decorators(self): + @with_input_types(a=int) @with_output_types(typing.List[str]) def int_to_str(a): @@ -1355,6 +1403,7 @@ def int_to_str(a): | 'ToStr' >> beam.FlatMap(int_to_str)) def test_pardo_properly_type_checks_using_type_hint_decorators(self): + @with_input_types(a=str) @with_output_types(typing.List[str]) def to_all_upper_case(a): @@ -1422,6 +1471,7 @@ def test_map_properly_type_checks_using_type_hints_methods(self): self.p.run() def test_map_does_not_type_check_using_type_hints_decorator(self): + @with_input_types(s=str) @with_output_types(str) def upper(s): @@ -1437,6 +1487,7 @@ def upper(s): | 'Upper' >> beam.Map(upper)) def test_map_properly_type_checks_using_type_hints_decorator(self): + @with_input_types(a=bool) @with_output_types(int) def bool_to_int(a): @@ -1475,6 +1526,7 @@ def test_filter_type_checks_using_type_hints_method(self): self.p.run() def test_filter_does_not_type_check_using_type_hints_decorator(self): + @with_input_types(a=float) def more_than_half(a): return a > 0.50 @@ -1488,6 +1540,7 @@ def more_than_half(a): | 'Half' >> beam.Filter(more_than_half)) def test_filter_type_checks_using_type_hints_decorator(self): + @with_input_types(b=int) def half(b): return bool(random.choice([0, 1])) @@ -1501,6 +1554,7 @@ def half(b): int).with_output_types(bool)) def test_pardo_like_inheriting_output_types_from_annotation(self): + def fn1(x: str) -> int: return 1 @@ -1891,6 +1945,7 @@ def test_pipeline_runtime_checking_violation_with_side_inputs_via_method(self): "instead found 1.0, an instance of {}.".format(int, float)) def test_combine_properly_pipeline_type_checks_using_decorator(self): + @with_output_types(int) @with_input_types(ints=typing.Iterable[int]) def sum_ints(ints): @@ -1906,6 +1961,7 @@ def sum_ints(ints): self.p.run() def test_combine_properly_pipeline_type_checks_without_decorator(self): + def sum_ints(ints): return sum(ints) @@ -1919,6 +1975,7 @@ def sum_ints(ints): self.p.run() def test_combine_func_type_hint_does_not_take_iterable_using_decorator(self): + @with_output_types(int) @with_input_types(a=int) def bad_combine(a): @@ -1937,6 +1994,7 @@ def bad_combine(a): e.exception.args[0]) def test_combine_pipeline_type_propagation_using_decorators(self): + @with_output_types(int) @with_input_types(ints=typing.Iterable[int]) def sum_ints(ints): @@ -2006,6 +2064,7 @@ def test_combine_pipeline_type_check_using_methods(self): with_input_types(str).with_output_types(str))) def matcher(expected): + def match(actual): equal_to(expected)(list(actual[0])) @@ -2378,6 +2437,7 @@ def test_sample_globally_pipeline_satisfied(self): self.assertCompatible(typing.Iterable[int], d.element_type) def matcher(expected_len): + def match(actual): equal_to([expected_len])([len(actual[0])]) @@ -2397,6 +2457,7 @@ def test_sample_globally_runtime_satisfied(self): self.assertCompatible(typing.Iterable[int], d.element_type) def matcher(expected_len): + def match(actual): equal_to([expected_len])([len(actual[0])]) @@ -2417,6 +2478,7 @@ def test_sample_per_key_pipeline_satisfied(self): typing.Tuple[int, typing.Iterable[int]], d.element_type) def matcher(expected_len): + def match(actual): for _, sample in actual: equal_to([expected_len])([len(sample)]) @@ -2440,6 +2502,7 @@ def test_sample_per_key_runtime_satisfied(self): typing.Tuple[int, typing.Iterable[int]], d.element_type) def matcher(expected_len): + def match(actual): for _, sample in actual: equal_to([expected_len])([len(sample)]) @@ -2458,6 +2521,7 @@ def test_to_list_pipeline_check_satisfied(self): self.assertCompatible(typing.List[int], d.element_type) def matcher(expected): + def match(actual): equal_to(expected)(actual[0]) @@ -2477,6 +2541,7 @@ def test_to_list_runtime_check_satisfied(self): self.assertCompatible(typing.List[str], d.element_type) def matcher(expected): + def match(actual): equal_to(expected)(actual[0]) @@ -2598,6 +2663,7 @@ def test_eager_execution_tagged_outputs(self): @parameterized_class([{'use_subprocess': False}, {'use_subprocess': True}]) class DeadLettersTest(unittest.TestCase): + @classmethod def die(cls, x): if cls.use_subprocess: @@ -2761,6 +2827,7 @@ def test_increment_counter(self): return class CounterDoFn(beam.DoFn): + def __init__(self): self.records_counter = Metrics.counter(self.__class__, 'recordsCounter') @@ -2875,7 +2942,9 @@ def test_threshold(self): class TestPTransformFn(TypeHintTestCase): + def test_type_checking_fail(self): + @beam.ptransform_fn def MyTransform(pcoll): return pcoll | beam.ParDo(lambda x: [x]).with_output_types(str) @@ -2886,6 +2955,7 @@ def MyTransform(pcoll): _ = (p | beam.Create([1, 2]) | MyTransform().with_output_types(int)) def test_type_checking_success(self): + @beam.ptransform_fn def MyTransform(pcoll): return pcoll | beam.ParDo(lambda x: [x]).with_output_types(int) @@ -2908,6 +2978,7 @@ def MyTransform(pcoll, type_hints, test_arg): class PickledObject(object): + def __init__(self, value): self.value = value diff --git a/sdks/python/apache_beam/transforms/resources_test.py b/sdks/python/apache_beam/transforms/resources_test.py index 939bdcd62651..99392e4377f1 100644 --- a/sdks/python/apache_beam/transforms/resources_test.py +++ b/sdks/python/apache_beam/transforms/resources_test.py @@ -25,6 +25,7 @@ class ResourcesTest(unittest.TestCase): + @parameterized.expand([ param( name='min_ram', diff --git a/sdks/python/apache_beam/transforms/sideinputs.py b/sdks/python/apache_beam/transforms/sideinputs.py index 0ff2a388b9e1..8a6fa766753f 100644 --- a/sdks/python/apache_beam/transforms/sideinputs.py +++ b/sdks/python/apache_beam/transforms/sideinputs.py @@ -76,6 +76,7 @@ def get_sideinput_index(tag: str) -> int: class SideInputMap(object): """Represents a mapping of windows to side input values.""" + def __init__(self, view_class: 'pvalue.AsSideInput', view_options, iterable): self._window_mapping_fn = view_options.get( 'window_mapping_fn', _global_window_mapping_fn) @@ -98,6 +99,7 @@ def is_globally_windowed(self) -> bool: class _FilteringIterable(object): """An iterable containing only those values in the given window. """ + def __init__(self, iterable, target_window): self._iterable = iterable self._target_window = target_window diff --git a/sdks/python/apache_beam/transforms/sideinputs_test.py b/sdks/python/apache_beam/transforms/sideinputs_test.py index 4c6df9f9d8ec..6b193ec673d2 100644 --- a/sdks/python/apache_beam/transforms/sideinputs_test.py +++ b/sdks/python/apache_beam/transforms/sideinputs_test.py @@ -38,6 +38,7 @@ class SideInputsTest(unittest.TestCase): + def create_pipeline(self): return TestPipeline() @@ -100,17 +101,17 @@ def test_sliding_windows(self): window.SlidingWindows(size=6, period=2), expected=[ # Element 1 falls in three windows - (1, [1]), # [-4, 2) - (1, [1, 2]), # [-2, 4) + (1, [1]), # [-4, 2) + (1, [1, 2]), # [-2, 4) (1, [1, 2, 4]), # [0, 6) # as does 2, - (2, [1, 2]), # [-2, 4) + (2, [1, 2]), # [-2, 4) (2, [1, 2, 4]), # [0, 6) - (2, [2, 4]), # [2, 8) + (2, [2, 4]), # [2, 8) # and 4. (4, [1, 2, 4]), # [0, 6) - (4, [2, 4]), # [2, 8) - (4, [4]), # [4, 10) + (4, [2, 4]), # [2, 8) + (4, [4]), # [4, 10) ]) def test_windowed_iter(self): @@ -228,13 +229,12 @@ def test_as_list_and_as_dict_side_inputs(self): side_list = pipeline | 'side list' >> beam.Create(a_list) side_pairs = pipeline | 'side pairs' >> beam.Create(some_pairs) results = main_input | 'concatenate' >> beam.Map( - lambda x, - the_list, - the_dict: [x, the_list, the_dict], + lambda x, the_list, the_dict: [x, the_list, the_dict], beam.pvalue.AsList(side_list), beam.pvalue.AsDict(side_pairs)) def matcher(expected_elem, expected_list, expected_pairs): + def match(actual): [[actual_elem, actual_list, actual_dict]] = actual equal_to([expected_elem])([actual_elem]) @@ -256,13 +256,12 @@ def test_as_singleton_without_unique_labels(self): main_input = pipeline | 'main input' >> beam.Create([1]) side_list = pipeline | 'side list' >> beam.Create(a_list) results = main_input | beam.Map( - lambda x, - s1, - s2: [x, s1, s2], + lambda x, s1, s2: [x, s1, s2], beam.pvalue.AsSingleton(side_list), beam.pvalue.AsSingleton(side_list)) def matcher(expected_elem, expected_singleton): + def match(actual): [[actual_elem, actual_singleton1, actual_singleton2]] = actual equal_to([expected_elem])([actual_elem]) @@ -281,13 +280,12 @@ def test_as_singleton_with_different_defaults(self): main_input = pipeline | 'main input' >> beam.Create([1]) side_list = pipeline | 'side list' >> beam.Create(a_list) results = main_input | beam.Map( - lambda x, - s1, - s2: [x, s1, s2], + lambda x, s1, s2: [x, s1, s2], beam.pvalue.AsSingleton(side_list, default_value=2), beam.pvalue.AsSingleton(side_list, default_value=3)) def matcher(expected_elem, expected_singleton1, expected_singleton2): + def match(actual): [[actual_elem, actual_singleton1, actual_singleton2]] = actual equal_to([expected_elem])([actual_elem]) @@ -308,13 +306,12 @@ def test_as_list_twice(self): main_input = pipeline | 'main input' >> beam.Create([1]) side_list = pipeline | 'side list' >> beam.Create(a_list) results = main_input | beam.Map( - lambda x, - ls1, - ls2: [x, ls1, ls2], + lambda x, ls1, ls2: [x, ls1, ls2], beam.pvalue.AsList(side_list), beam.pvalue.AsList(side_list)) def matcher(expected_elem, expected_list): + def match(actual): [[actual_elem, actual_list1, actual_list2]] = actual equal_to([expected_elem])([actual_elem]) @@ -333,13 +330,12 @@ def test_as_dict_twice(self): main_input = pipeline | 'main input' >> beam.Create([1]) side_kvs = pipeline | 'side kvs' >> beam.Create(some_kvs) results = main_input | beam.Map( - lambda x, - dct1, - dct2: [x, dct1, dct2], + lambda x, dct1, dct2: [x, dct1, dct2], beam.pvalue.AsDict(side_kvs), beam.pvalue.AsDict(side_kvs)) def matcher(expected_elem, expected_kvs): + def match(actual): [[actual_elem, actual_dict1, actual_dict2]] = actual equal_to([expected_elem])([actual_elem]) @@ -403,6 +399,7 @@ def test_multi_triggered_gbk_side_input(self): | 'Values' >> Map(lambda k_vs: k_vs[1])) class RecordFn(beam.DoFn): + def process( self, elm=beam.DoFn.ElementParam, diff --git a/sdks/python/apache_beam/transforms/sql_test.py b/sdks/python/apache_beam/transforms/sql_test.py index a7da253c4617..4e299e7047d4 100644 --- a/sdks/python/apache_beam/transforms/sql_test.py +++ b/sdks/python/apache_beam/transforms/sql_test.py @@ -48,8 +48,8 @@ @pytest.mark.xlang_sql_expansion_service @unittest.skipIf( - TestPipeline().get_pipeline_options().view_as(StandardOptions).runner is - None, + TestPipeline().get_pipeline_options().view_as(StandardOptions).runner + is None, "Must be run with a runner that supports staging java artifacts.") class SqlTransformTest(unittest.TestCase): """Tests that exercise the cross-language SqlTransform (implemented in java). diff --git a/sdks/python/apache_beam/transforms/stats.py b/sdks/python/apache_beam/transforms/stats.py index 0d56b60b050f..0ad66759525e 100644 --- a/sdks/python/apache_beam/transforms/stats.py +++ b/sdks/python/apache_beam/transforms/stats.py @@ -145,6 +145,7 @@ def _get_sample_size_from_est_error(est_err): @typehints.with_output_types(int) class Globally(PTransform): """ Approximate.Globally approximate number of unique values""" + def __init__(self, size=None, error=None): self._sample_size = ApproximateUnique.parse_input_params(size, error) @@ -159,6 +160,7 @@ def expand(self, pcoll): @typehints.with_output_types(typing.Tuple[K, int]) class PerKey(PTransform): """ Approximate.PerKey approximate number of unique values per key""" + def __init__(self, size=None, error=None): self._sample_size = ApproximateUnique.parse_input_params(size, error) @@ -242,6 +244,7 @@ class ApproximateUniqueCombineFn(CombineFn): ApproximateUniqueCombineFn computes an estimate of the number of unique values that were combined. """ + def __init__(self, sample_size, coder): self._sample_size = sample_size coder = coders.typecoders.registry.verify_deterministic( @@ -306,6 +309,7 @@ class ApproximateQuantiles(object): out: [0, 2, 5, 7, 100] """ + @staticmethod def _display_data(num_quantiles, key, reverse, weighted, input_batched): return { @@ -343,6 +347,7 @@ class Globally(PTransform): weighted. Provides a way to accumulate multiple elements at a time more efficiently. """ + def __init__( self, num_quantiles, @@ -398,6 +403,7 @@ class PerKey(PTransform): weighted. Provides a way to accumulate multiple elements at a time more efficiently. """ + def __init__( self, num_quantiles, @@ -431,6 +437,7 @@ def display_data(self): class _QuantileSpec(object): """Quantiles computation specifications.""" + def __init__(self, buffer_size, num_buffers, weighted, key, reverse): # type: (int, int, bool, Any, bool) -> None self.buffer_size = buffer_size @@ -476,6 +483,7 @@ class _QuantileBuffer(object): """A single buffer in the sense of the referenced algorithm. (see http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.6.6513&rep=rep1 &type=pdf and ApproximateQuantilesCombineFn for further information)""" + def __init__( self, elements, weights, weighted, level=0, min_val=None, max_val=None): # type: (List, List, bool, int, Any, Any) -> None @@ -509,6 +517,7 @@ class _QuantileState(object): """ Compact summarization of a collection on which quantiles can be estimated. """ + def __init__(self, unbuffered_elements, unbuffered_weights, buffers, spec): # type: (List, List, List[_QuantileBuffer], _QuantileSpec) -> None self.buffers = buffers diff --git a/sdks/python/apache_beam/transforms/stats_test.py b/sdks/python/apache_beam/transforms/stats_test.py index bf634c003a07..150311cd2571 100644 --- a/sdks/python/apache_beam/transforms/stats_test.py +++ b/sdks/python/apache_beam/transforms/stats_test.py @@ -639,6 +639,7 @@ def _build_quantilebuffer_test_data(): class ApproximateQuantilesBufferTest(unittest.TestCase): """ Approximate Quantiles Buffer Tests to ensure we are calculating the optimal buffers.""" + @parameterized.expand(_build_quantilebuffer_test_data) def test_efficiency( self, epsilon, maxInputSize, expectedNumBuffers, expectedBufferSize): diff --git a/sdks/python/apache_beam/transforms/timestamped_value_type_test.py b/sdks/python/apache_beam/transforms/timestamped_value_type_test.py index 46449bb1ef72..8ca8689be3e2 100644 --- a/sdks/python/apache_beam/transforms/timestamped_value_type_test.py +++ b/sdks/python/apache_beam/transforms/timestamped_value_type_test.py @@ -46,6 +46,7 @@ def ConvertToTimestampedValue_3(plant: Dict[str, Any]) -> TimestampedValue[T]: class TypeCheckTimestampedValueTestCase(unittest.TestCase): + def setUp(self): self.opts = beam.options.pipeline_options.PipelineOptions( runtime_type_check=True) diff --git a/sdks/python/apache_beam/transforms/timeutil.py b/sdks/python/apache_beam/transforms/timeutil.py index 87294b0dcf4d..ef0c2d95ca3b 100644 --- a/sdks/python/apache_beam/transforms/timeutil.py +++ b/sdks/python/apache_beam/transforms/timeutil.py @@ -60,6 +60,7 @@ def is_event_time(domain): class TimestampCombinerImpl(metaclass=ABCMeta): """Implementation of TimestampCombiner.""" + @abstractmethod def assign_output_time(self, window, input_timestamp): raise NotImplementedError @@ -85,6 +86,7 @@ def merge(self, unused_result_window, merging_timestamps): class DependsOnlyOnWindow(TimestampCombinerImpl, metaclass=ABCMeta): """TimestampCombinerImpl that only depends on the window.""" + def merge(self, result_window, unused_merging_timestamps): # Since we know that the result only depends on the window, we can ignore # the given timestamps. @@ -93,6 +95,7 @@ def merge(self, result_window, unused_merging_timestamps): class OutputAtEarliestInputTimestampImpl(TimestampCombinerImpl): """TimestampCombinerImpl outputting at earliest input timestamp.""" + def assign_output_time(self, window, input_timestamp): return input_timestamp @@ -103,6 +106,7 @@ def combine(self, output_timestamp, other_output_timestamp): class OutputAtEarliestTransformedInputTimestampImpl(TimestampCombinerImpl): """TimestampCombinerImpl outputting at earliest input timestamp.""" + def __init__(self, window_fn): self.window_fn = window_fn @@ -115,6 +119,7 @@ def combine(self, output_timestamp, other_output_timestamp): class OutputAtLatestInputTimestampImpl(TimestampCombinerImpl): """TimestampCombinerImpl outputting at latest input timestamp.""" + def assign_output_time(self, window, input_timestamp): return input_timestamp @@ -124,6 +129,7 @@ def combine(self, output_timestamp, other_output_timestamp): class OutputAtEndOfWindowImpl(DependsOnlyOnWindow): """TimestampCombinerImpl outputting at end of window.""" + def assign_output_time(self, window, unused_input_timestamp): return window.max_timestamp() diff --git a/sdks/python/apache_beam/transforms/transforms_keyword_only_args_test.py b/sdks/python/apache_beam/transforms/transforms_keyword_only_args_test.py index 28566ba55a03..ea92dab5a309 100644 --- a/sdks/python/apache_beam/transforms/transforms_keyword_only_args_test.py +++ b/sdks/python/apache_beam/transforms/transforms_keyword_only_args_test.py @@ -29,6 +29,7 @@ class KeywordOnlyArgsTests(unittest.TestCase): + def test_side_input_keyword_only_args(self): with TestPipeline() as pipeline: @@ -114,6 +115,7 @@ def test_do_fn_keyword_only_args(self): with TestPipeline() as pipeline: class MyDoFn(beam.DoFn): + def process(self, element, *s, bound=500): return [min(sum(s) + element, bound)] diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py index 63895704727f..7773fccf3cd0 100644 --- a/sdks/python/apache_beam/transforms/trigger.py +++ b/sdks/python/apache_beam/transforms/trigger.py @@ -78,12 +78,14 @@ class _StateTag(metaclass=ABCMeta): """An identifier used to store and retrieve typed, combinable state. The given tag must be unique for this step.""" + def __init__(self, tag): self.tag = tag class _ReadModifyWriteStateTag(_StateTag): """StateTag pointing to an element.""" + def __repr__(self): return 'ValueStateTag(%s)' % (self.tag) @@ -93,6 +95,7 @@ def with_prefix(self, prefix): class _SetStateTag(_StateTag): """StateTag pointing to an element.""" + def __repr__(self): return 'SetStateTag({tag})'.format(tag=self.tag) @@ -122,6 +125,7 @@ def with_prefix(self, prefix): return _CombiningValueStateTag(prefix + self.tag, self.combine_fn) def without_extraction(self): + class NoExtractionCombineFn(core.CombineFn): setup = self.combine_fn.setup create_accumulator = self.combine_fn.create_accumulator @@ -136,6 +140,7 @@ class NoExtractionCombineFn(core.CombineFn): class _ListStateTag(_StateTag): """StateTag pointing to a list of elements.""" + def __repr__(self): return 'ListStateTag(%s)' % self.tag @@ -144,6 +149,7 @@ def with_prefix(self, prefix): class _WatermarkHoldStateTag(_StateTag): + def __init__(self, tag, timestamp_combiner_impl): super().__init__(tag) self.timestamp_combiner_impl = timestamp_combiner_impl @@ -192,6 +198,7 @@ class TriggerFn(metaclass=ABCMeta): See https://beam.apache.org/documentation/programming-guide/#triggers """ + @abstractmethod def on_element(self, element, window, context): """Called when a new element arrives in a window. @@ -320,6 +327,7 @@ def to_runner_api(self, unused_context): class DefaultTrigger(TriggerFn): """Semantically Repeatedly(AfterWatermark()), but more optimized.""" + def __init__(self): pass @@ -424,6 +432,7 @@ def has_ontime_pane(self): class Always(TriggerFn): """Repeatedly invoke the given trigger, never finishing.""" + def __init__(self): pass @@ -472,6 +481,7 @@ class _Never(TriggerFn): Data may still be released at window closing. """ + def __init__(self): pass @@ -695,6 +705,7 @@ def has_ontime_pane(self): class Repeatedly(TriggerFn): """Repeatedly invoke the given trigger, never finishing.""" + def __init__(self, underlying): self.underlying = underlying @@ -743,6 +754,7 @@ def has_ontime_pane(self): class _ParallelTriggerFn(TriggerFn, metaclass=ABCMeta): + def __init__(self, *triggers): self.triggers = triggers @@ -774,8 +786,7 @@ def should_fire(self, time_domain, watermark, window, context): return self.combine_op( trigger.should_fire( time_domain, watermark, window, self._sub_context(context, ix)) - for ix, - trigger in enumerate(self.triggers)) + for ix, trigger in enumerate(self.triggers)) def on_fire(self, watermark, window, context): finished = [] @@ -936,6 +947,7 @@ def has_ontime_pane(self): class OrFinally(AfterAny): + @staticmethod def from_runner_api(proto, context): return OrFinally( @@ -953,6 +965,7 @@ def to_runner_api(self, context): class TriggerContext(object): + def __init__(self, outer, window, clock): self._outer = outer self._window = window @@ -979,6 +992,7 @@ def clear_state(self, tag): class NestedContext(object): """Namespaced context useful for defining composite triggers.""" + def __init__(self, outer, prefix): self._outer = outer self._prefix = prefix @@ -1008,6 +1022,7 @@ class SimpleState(metaclass=ABCMeta): Only timers must hold the watermark (by their timestamp). """ + @abstractmethod def set_timer( self, window, name, time_domain, timestamp, dynamic_timer_tag=''): @@ -1042,6 +1057,7 @@ class UnmergedState(SimpleState): This class must be implemented by each backend. """ + @abstractmethod def set_global_state(self, tag, value): pass @@ -1195,6 +1211,7 @@ def create_trigger_driver( class TriggerDriver(metaclass=ABCMeta): """Breaks a series of bundle and timer firings into window (pane)s.""" + @abstractmethod def process_elements( self, @@ -1236,6 +1253,7 @@ def process_entire_key(self, key, windowed_values): class _UnwindowedValues(observable.ObservableMixin): """Exposes iterable of windowed values as iterable of unwindowed values.""" + def __init__(self, windowed_values): super().__init__() self._windowed_values = windowed_values @@ -1303,6 +1321,7 @@ def process_timer( class CombiningTriggerDriver(TriggerDriver): """Uses a phased_combine_fn to process output of wrapped TriggerDriver.""" + def __init__(self, phased_combine_fn, underlying): self.phased_combine_fn = phased_combine_fn self.underlying = underlying @@ -1383,6 +1402,7 @@ def process_elements( merged_away = {} class TriggerMergeContext(WindowFn.MergeContext): + def merge(_, to_be_merged, merge_result): # pylint: disable=no-self-argument for window in to_be_merged: if window != merge_result: @@ -1419,8 +1439,8 @@ def merge(_, to_be_merged, merge_result): # pylint: disable=no-self-argument ( element_output_time for element_output_time in ( self.timestamp_combiner_impl.assign_output_time( - window, timestamp) for unused_value, - timestamp in elements) + window, timestamp) + for unused_value, timestamp in elements) if element_output_time >= output_watermark)) if output_time is not None: state.add_state(window, self.WATERMARK_HOLD, output_time) @@ -1519,6 +1539,7 @@ class InMemoryUnmergedState(UnmergedState): Used for batch and testing. """ + def __init__(self, defensive_copy=False): # TODO(robertwb): Clean defensive_copy. It is too expensive in production. self.timers = collections.defaultdict(dict) diff --git a/sdks/python/apache_beam/transforms/trigger_test.py b/sdks/python/apache_beam/transforms/trigger_test.py index 962a06e485df..380bb8cdae9e 100644 --- a/sdks/python/apache_beam/transforms/trigger_test.py +++ b/sdks/python/apache_beam/transforms/trigger_test.py @@ -75,11 +75,13 @@ class CustomTimestampingFixedWindowsWindowFn(FixedWindows): """WindowFn for testing custom timestamping.""" + def get_transformed_output_time(self, unused_window, input_timestamp): return input_timestamp + 100 class TriggerTest(unittest.TestCase): + def run_trigger_simple( self, window_fn, @@ -197,8 +199,10 @@ def test_fixed_watermark(self): AfterWatermark(), AccumulationMode.ACCUMULATING, [(1, 'a'), (2, 'b'), (13, 'c')], - {IntervalWindow(0, 10): [set('ab')], - IntervalWindow(10, 20): [set('c')]}, + { + IntervalWindow(0, 10): [set('ab')], + IntervalWindow(10, 20): [set('c')] + }, 1, 2, 3, @@ -225,36 +229,38 @@ def test_fixed_watermark_with_early(self): def test_fixed_watermark_with_early_late(self): self.run_trigger_simple( FixedWindows(100), # pyformat break - AfterWatermark(early=AfterCount(3), - late=AfterCount(2)), + AfterWatermark(early=AfterCount(3), late=AfterCount(2)), AccumulationMode.DISCARDING, zip(range(9), 'abcdefghi'), - {IntervalWindow(0, 100): [ - set('abcd'), set('efgh'), # early - set('i'), # on time - set('vw'), set('xy') # late - ]}, + { + IntervalWindow(0, 100): [ + set('abcd'), + set('efgh'), # early + set('i'), # on time + set('vw'), + set('xy') # late + ] + }, 2, late_data=zip(range(5), 'vwxyz')) def test_sessions_watermark_with_early_late(self): self.run_trigger_simple( Sessions(10), # pyformat break - AfterWatermark(early=AfterCount(2), - late=AfterCount(1)), + AfterWatermark(early=AfterCount(2), late=AfterCount(1)), AccumulationMode.ACCUMULATING, [(1, 'a'), (15, 'b'), (7, 'c'), (30, 'd')], { IntervalWindow(1, 25): [ - set('abc'), # early - set('abc'), # on time - set('abcxy') # late + set('abc'), # early + set('abc'), # on time + set('abcxy') # late ], IntervalWindow(30, 40): [ - set('d'), # on time + set('d'), # on time ], IntervalWindow(1, 40): [ - set('abcdxyz') # late + set('abcdxyz') # late ], }, 2, @@ -303,13 +309,16 @@ def test_repeatedly_after_first(self): Repeatedly(AfterAny(AfterCount(3), AfterWatermark())), AccumulationMode.ACCUMULATING, zip(range(7), 'abcdefg'), - {IntervalWindow(0, 100): [ - set('abc'), - set('abcdef'), - set('abcdefg'), - set('abcdefgx'), - set('abcdefgxy'), - set('abcdefgxyz')]}, + { + IntervalWindow(0, 100): [ + set('abc'), + set('abcdef'), + set('abcdefg'), + set('abcdefgx'), + set('abcdefgxy'), + set('abcdefgxyz') + ] + }, 1, late_data=zip(range(3), 'xyz')) @@ -350,8 +359,10 @@ def test_sessions_default(self): AccumulationMode.ACCUMULATING, [(1, 'a'), (2, 'b'), (15, 'c'), (16, 'd'), (30, 'z'), (9, 'e'), (10, 'f'), (30, 'y')], - {IntervalWindow(1, 26): [set('abcdef')], - IntervalWindow(30, 40): [set('yz')]}, + { + IntervalWindow(1, 26): [set('abcdef')], + IntervalWindow(30, 40): [set('yz')] + }, 1, 2, 3, @@ -381,9 +392,11 @@ def test_sessions_after_count(self): AccumulationMode.ACCUMULATING, [(1, 'a'), (15, 'b'), (6, 'c'), (30, 's'), (31, 't'), (50, 'z'), (50, 'y')], - {IntervalWindow(1, 25): [set('abc')], - IntervalWindow(30, 41): [set('st')], - IntervalWindow(50, 60): [set('yz')]}, + { + IntervalWindow(1, 25): [set('abc')], + IntervalWindow(30, 41): [set('st')], + IntervalWindow(50, 60): [set('yz')] + }, 1, 2, 3) @@ -412,8 +425,10 @@ def test_sessions_after_each(self): AfterEach(AfterCount(2), AfterCount(3)), AccumulationMode.ACCUMULATING, zip(range(10), 'abcdefghij'), - {IntervalWindow(0, 11): [set('ab')], - IntervalWindow(0, 15): [set('abcdef')]}, + { + IntervalWindow(0, 11): [set('ab')], + IntervalWindow(0, 15): [set('abcdef')] + }, 2) self.run_trigger_simple( @@ -421,9 +436,11 @@ def test_sessions_after_each(self): Repeatedly(AfterEach(AfterCount(2), AfterCount(3))), AccumulationMode.ACCUMULATING, zip(range(10), 'abcdefghij'), - {IntervalWindow(0, 11): [set('ab')], - IntervalWindow(0, 15): [set('abcdef')], - IntervalWindow(0, 17): [set('abcdefgh')]}, + { + IntervalWindow(0, 11): [set('ab')], + IntervalWindow(0, 15): [set('abcdef')], + IntervalWindow(0, 17): [set('abcdefgh')] + }, 2) def test_picklable_output(self): @@ -438,6 +455,7 @@ def test_picklable_output(self): class MayLoseDataTest(unittest.TestCase): + def _test(self, trigger, lateness, expected): windowing = WindowInto( GlobalWindows(), @@ -524,6 +542,7 @@ def test_after_each_all_may_finish(self): class RunnerApiTest(unittest.TestCase): + def test_trigger_encoding(self): for trigger_fn in (DefaultTrigger(), AfterAll(AfterCount(1), AfterCount(10)), @@ -540,6 +559,7 @@ def test_trigger_encoding(self): class TriggerPipelineTest(unittest.TestCase): + def test_after_processing_time(self): test_options = PipelineOptions( flags=['--allow_unsafe_triggers', '--streaming']) @@ -678,9 +698,11 @@ def test_after_count_streaming(self): assert_that( results, - equal_to(list({ - 'A': [1, 2, 3], # 4 - 6 discarded because trigger finished - 'B': [1, 2, 3]}.items()))) + equal_to( + list({ + 'A': [1, 2, 3], # 4 - 6 discarded because trigger finished + 'B': [1, 2, 3] + }.items()))) def test_always(self): with TestPipeline() as p: @@ -706,12 +728,9 @@ def format_result(k, vs): result, equal_to( list({ - 'A-2': {10, 11}, - # Elements out of windows are also emitted. - 'A-6': {1, 2, 3, 4, 5}, - # A,1 is emitted twice. - 'B-5': {6, 7, 8, 9}, - # B,6 is emitted twice. + 'A-2': {10, 11}, # Elements out of windows are also emitted. + 'A-6': {1, 2, 3, 4, 5}, # A,1 is emitted twice. + 'B-5': {6, 7, 8, 9}, # B,6 is emitted twice. 'B-3': {10, 15, 16}, }.items()))) @@ -840,6 +859,7 @@ def _run_log_test(self, spec): self._run_log(spec) def _run_log(self, spec): + def parse_int_list(s): """Parses strings like '[1, 2, 3]'.""" s = s.strip() @@ -1010,6 +1030,7 @@ class _ConcatCombineFn(beam.CombineFn): class TriggerDriverTranscriptTest(TranscriptTest): + def _execute( self, window_fn, @@ -1092,6 +1113,7 @@ def fire_timers(): class BaseTestStreamTranscriptTest(TranscriptTest): """A suite of TestStream-based tests based on trigger transcript entries. """ + def _execute( self, window_fn, @@ -1182,6 +1204,7 @@ class Check(beam.DoFn): The key is ignored, but all items must be on the same key to share state. """ + def __init__(self, allow_out_of_order=True): # Some runners don't support cross-stage TestStream semantics. self.allow_out_of_order = allow_out_of_order @@ -1284,6 +1307,7 @@ class WeakTestStreamTranscriptTest(BaseTestStreamTranscriptTest): class BatchTranscriptTest(TranscriptTest): + def _execute( self, window_fn, @@ -1324,6 +1348,7 @@ def _execute( merged_away = set() class MergeContext(WindowFn.MergeContext): + def merge(_, to_be_merged, merge_result): for window in to_be_merged: if window != merge_result: diff --git a/sdks/python/apache_beam/transforms/userstate.py b/sdks/python/apache_beam/transforms/userstate.py index 3b876bf9dbfb..408dad408421 100644 --- a/sdks/python/apache_beam/transforms/userstate.py +++ b/sdks/python/apache_beam/transforms/userstate.py @@ -50,6 +50,7 @@ class StateSpec(object): """Specification for a user DoFn state cell.""" + def __init__(self, name: str, coder: Coder) -> None: if not isinstance(name, str): raise TypeError("name is not a string") @@ -68,6 +69,7 @@ def to_runner_api( class ReadModifyWriteStateSpec(StateSpec): """Specification for a user DoFn value state cell.""" + def to_runner_api( self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec: return beam_runner_api_pb2.StateSpec( @@ -79,6 +81,7 @@ def to_runner_api( class BagStateSpec(StateSpec): """Specification for a user DoFn bag state cell.""" + def to_runner_api( self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec: return beam_runner_api_pb2.StateSpec( @@ -90,6 +93,7 @@ def to_runner_api( class SetStateSpec(StateSpec): """Specification for a user DoFn Set State cell""" + def to_runner_api( self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec: return beam_runner_api_pb2.StateSpec( @@ -101,6 +105,7 @@ def to_runner_api( class CombiningValueStateSpec(StateSpec): """Specification for a user DoFn combining value state cell.""" + def __init__( self, name: str, @@ -152,6 +157,7 @@ def to_runner_api( class OrderedListStateSpec(StateSpec): """Specification for a user DoFn ordered list state cell.""" + def to_runner_api( self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec: return beam_runner_api_pb2.StateSpec( @@ -307,6 +313,7 @@ def validate_stateful_dofn(dofn: 'DoFn') -> None: class BaseTimer(object): + def clear(self, dynamic_timer_tag: str = '') -> None: raise NotImplementedError @@ -319,6 +326,7 @@ def set(self, timestamp: Timestamp, dynamic_timer_tag: str = '') -> None: class RuntimeTimer(BaseTimer): """Timer interface object passed to user code.""" + def __init__(self) -> None: self._timer_recordings: Dict[str, _TimerTuple] = {} self._cleared = False @@ -335,6 +343,7 @@ def set(self, timestamp: Timestamp, dynamic_timer_tag: str = '') -> None: class RuntimeState(object): """State interface object passed to user code.""" + def prefetch(self) -> None: # The default implementation here does nothing. pass @@ -344,6 +353,7 @@ def finalize(self) -> None: class ReadModifyWriteRuntimeState(RuntimeState): + def read(self) -> Any: raise NotImplementedError(type(self)) @@ -358,6 +368,7 @@ def commit(self) -> None: class AccumulatingRuntimeState(RuntimeState): + def read(self) -> Iterable[Any]: raise NotImplementedError(type(self)) @@ -385,6 +396,7 @@ class CombiningValueRuntimeState(AccumulatingRuntimeState): class OrderedListRuntimeState(AccumulatingRuntimeState): """Ordered list state interface object passed to user code.""" + def read(self) -> Iterable[Tuple[Timestamp, Any]]: raise NotImplementedError(type(self)) @@ -403,6 +415,7 @@ def clear_range( class UserStateContext(object): """Wrapper allowing user state and timers to be accessed by a DoFnInvoker.""" + def get_timer( self, timer_spec: TimerSpec, diff --git a/sdks/python/apache_beam/transforms/userstate_test.py b/sdks/python/apache_beam/transforms/userstate_test.py index 5dd6c61d6add..707838ce1fd8 100644 --- a/sdks/python/apache_beam/transforms/userstate_test.py +++ b/sdks/python/apache_beam/transforms/userstate_test.py @@ -115,6 +115,7 @@ def on_expiry_family( class InterfaceTest(unittest.TestCase): + def _validate_dofn(self, dofn): # Construction of DoFnSignature performs validation of the given DoFn. # In particular, it ends up calling userstate._validate_stateful_dofn. @@ -202,6 +203,7 @@ def test_stateful_dofn_detection(self): self.assertTrue(is_stateful_dofn(TestStatefulDoFn())) def test_good_signatures(self): + class BasicStatefulDoFn(DoFn): BUFFER_STATE = BagStateSpec('buffer', BytesCoder()) EXPIRY_TIMER = TimerSpec('expiry1', TimeDomain.WATERMARK) @@ -448,13 +450,16 @@ def setUp(self): StatefulDoFnOnDirectRunnerTest.all_records = [] def record_dofn(self): + class RecordDoFn(DoFn): + def process(self, element): StatefulDoFnOnDirectRunnerTest.all_records.append(element) return RecordDoFn() def test_simple_stateful_dofn(self): + class SimpleTestStatefulDoFn(DoFn): BUFFER_STATE = BagStateSpec('buffer', BytesCoder()) EXPIRY_TIMER = TimerSpec('expiry', TimeDomain.WATERMARK) @@ -494,6 +499,7 @@ def expiry_callback( StatefulDoFnOnDirectRunnerTest.all_records) def test_clearing_bag_state(self): + class BagStateClearingStatefulDoFn(beam.DoFn): BAG_STATE = BagStateSpec('bag_state', StrUtf8Coder()) @@ -536,6 +542,7 @@ def clear_values(self, bag_state=beam.DoFn.StateParam(BAG_STATE)): self.assertEqual(['extra'], StatefulDoFnOnDirectRunnerTest.all_records) def test_two_timers_one_function(self): + class BagStateClearingStatefulDoFn(beam.DoFn): BAG_STATE = BagStateSpec('bag_state', StrUtf8Coder()) @@ -575,6 +582,7 @@ def emit_values(self, bag_state=beam.DoFn.StateParam(BAG_STATE)): StatefulDoFnOnDirectRunnerTest.all_records) def test_simple_read_modify_write_stateful_dofn(self): + class SimpleTestReadModifyWriteStatefulDoFn(DoFn): VALUE_STATE = ReadModifyWriteStateSpec('value', StrUtf8Coder()) @@ -597,6 +605,7 @@ def process(self, element, last_element=DoFn.StateParam(VALUE_STATE)): StatefulDoFnOnDirectRunnerTest.all_records) def test_clearing_read_modify_write_state(self): + class SimpleClearingReadModifyWriteStatefulDoFn(DoFn): VALUE_STATE = ReadModifyWriteStateSpec('value', StrUtf8Coder()) @@ -624,6 +633,7 @@ def process(self, element, last_element=DoFn.StateParam(VALUE_STATE)): StatefulDoFnOnDirectRunnerTest.all_records) def test_simple_set_stateful_dofn(self): + class SimpleTestSetStatefulDoFn(DoFn): BUFFER_STATE = SetStateSpec('buffer', VarIntCoder()) EXPIRY_TIMER = TimerSpec('expiry', TimeDomain.WATERMARK) @@ -658,6 +668,7 @@ def expiry_callback(self, buffer=DoFn.StateParam(BUFFER_STATE)): self.assertEqual([[1, 2, 3]], StatefulDoFnOnDirectRunnerTest.all_records) def test_clearing_set_state(self): + class SetStateClearingStatefulDoFn(beam.DoFn): SET_STATE = SetStateSpec('buffer', StrUtf8Coder()) @@ -701,6 +712,7 @@ def clear_values(self, set_state=beam.DoFn.StateParam(SET_STATE)): StatefulDoFnOnDirectRunnerTest.all_records) def test_stateful_set_state_portably(self): + class SetStatefulDoFn(beam.DoFn): SET_STATE = SetStateSpec('buffer', VarIntCoder()) @@ -721,6 +733,7 @@ def process(self, element, set_state=beam.DoFn.StateParam(SET_STATE)): assert_that(actual_values, equal_to([1, 3, 6, 10, 10])) def test_stateful_set_state_clean_portably(self): + class SetStateClearingStatefulDoFn(beam.DoFn): SET_STATE = SetStateSpec('buffer', VarIntCoder()) @@ -823,6 +836,7 @@ def emit_values( StatefulDoFnOnDirectRunnerTest.all_records) def test_simple_stateful_dofn_combining(self): + class SimpleTestStatefulDoFn(DoFn): BUFFER_STATE = CombiningValueStateSpec( 'buffer', ListCoder(VarIntCoder()), ToListCombineFn()) @@ -860,6 +874,7 @@ def expiry_callback( StatefulDoFnOnDirectRunnerTest.all_records) def test_timer_output_timestamp(self): + class TimerEmittingStatefulDoFn(DoFn): EMIT_TIMER_1 = TimerSpec('emit1', TimeDomain.WATERMARK) EMIT_TIMER_2 = TimerSpec('emit2', TimeDomain.WATERMARK) @@ -888,6 +903,7 @@ def emit_callback_3(self): yield 'timer3' class TimestampReifyingDoFn(DoFn): + def process(self, element, ts=DoFn.TimestampParam): yield (element, int(ts)) @@ -905,6 +921,7 @@ def process(self, element, ts=DoFn.TimestampParam): sorted(StatefulDoFnOnDirectRunnerTest.all_records)) def test_timer_output_timestamp_and_window(self): + class TimerEmittingStatefulDoFn(DoFn): EMIT_TIMER_1 = TimerSpec('emit1', TimeDomain.WATERMARK) @@ -940,6 +957,7 @@ def emit_callback_1( sorted(StatefulDoFnOnDirectRunnerTest.all_records)) def test_timer_default_tag(self): + class DynamicTimerDoFn(DoFn): EMIT_TIMER_FAMILY = TimerSpec('emit', TimeDomain.WATERMARK) @@ -966,6 +984,7 @@ def emit_callback( sorted(StatefulDoFnOnDirectRunnerTest.all_records)) def test_dynamic_timer_simple_dofn(self): + class DynamicTimerDoFn(DoFn): EMIT_TIMER_FAMILY = TimerSpec('emit', TimeDomain.WATERMARK) @@ -995,6 +1014,7 @@ def emit_callback( @pytest.mark.no_xdist @pytest.mark.timeout(10) def test_dynamic_timer_clear_then_set_timer(self): + class EmitTwoEvents(DoFn): EMIT_CLEAR_SET_TIMER = TimerSpec('emitclear', TimeDomain.WATERMARK) @@ -1033,6 +1053,7 @@ def emit_callback( assert_that(res, equal_to([('emit1', 10), ('emit2', 20), ('emit3', 40)])) def test_dynamic_timer_clear_timer(self): + class DynamicTimerDoFn(DoFn): EMIT_TIMER_FAMILY = TimerSpec('emit', TimeDomain.WATERMARK) @@ -1065,6 +1086,7 @@ def emit_callback( sorted(StatefulDoFnOnDirectRunnerTest.all_records)) def test_dynamic_timer_multiple(self): + class DynamicTimerDoFn(DoFn): EMIT_TIMER_FAMILY1 = TimerSpec('emit_family_1', TimeDomain.WATERMARK) EMIT_TIMER_FAMILY2 = TimerSpec('emit_family_2', TimeDomain.WATERMARK) @@ -1107,6 +1129,7 @@ def emit_callback_2( sorted(StatefulDoFnOnDirectRunnerTest.all_records)) def test_dynamic_timer_and_simple_timer(self): + class DynamicTimerDoFn(DoFn): EMIT_TIMER_FAMILY = TimerSpec('emit', TimeDomain.WATERMARK) GC_TIMER = TimerSpec('gc', TimeDomain.WATERMARK) @@ -1145,6 +1168,7 @@ def gc(self, ts=DoFn.TimestampParam): sorted(StatefulDoFnOnDirectRunnerTest.all_records)) def test_index_assignment(self): + class IndexAssigningStatefulDoFn(DoFn): INDEX_STATE = CombiningValueStateSpec('index', sum) @@ -1170,6 +1194,7 @@ def process(self, element, state=DoFn.StateParam(INDEX_STATE)): StatefulDoFnOnDirectRunnerTest.all_records) def test_hash_join(self): + class HashJoinStatefulDoFn(DoFn): BUFFER_STATE = BagStateSpec('buffer', BytesCoder()) UNMATCHED_TIMER = TimerSpec('unmatched', TimeDomain.WATERMARK) diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index a03652de2496..0d906fb39e07 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -161,6 +161,7 @@ class CoGroupByKey(PTransform): (or if there's a chance there may be none), this argument is the only way to provide pipeline information, and should be considered mandatory. """ + def __init__(self, *, pipeline=None): self.pipeline = pipeline @@ -231,6 +232,7 @@ def expand(self, pcolls): class _CoGBKImpl(PTransform): + def __init__(self, *, pipeline=None): self.pipeline = pipeline @@ -570,6 +572,7 @@ def stats(self): class _GlobalWindowsBatchingDoFn(DoFn): + def __init__(self, batch_size_estimator, element_size_fn): self._batch_size_estimator = batch_size_estimator self._element_size_fn = element_size_fn @@ -604,6 +607,7 @@ def finish_bundle(self): class _SizedBatch(): + def __init__(self): self.elements = [] self.size = 0 @@ -670,6 +674,7 @@ def _pardo_stateful_batch_elements( 'batch_estimator', coders.PickleCoder()) class _StatefulBatchElementsDoFn(DoFn): + def process( self, element, @@ -761,6 +766,7 @@ class SharedKey(): """A class that holds a per-process UUID used to key elements for streaming BatchElements. """ + def __init__(self): self.key = uuid.uuid4().hex @@ -773,6 +779,7 @@ class WithSharedKey(DoFn): """A DoFn that keys elements with a per-process UUID. Used in streaming BatchElements. """ + def __init__(self): self.shared_handle = shared.Shared() @@ -843,6 +850,7 @@ class BatchElements(PTransform): record_metrics: (optional) whether or not to record beam metrics on distributions of the batch size. Defaults to True. """ + def __init__( self, min_batch_size=1, @@ -900,6 +908,7 @@ class _IdentityWindowFn(NonMergingWindowFn): Will raise an exception when used after DoFns that return TimestampedValue elements. """ + def __init__(self, window_coder): """Create a new WindowFn with compatible coder. To be applied to PCollections with windows that are compatible with the @@ -932,6 +941,7 @@ class ReshufflePerKey(PTransform): in particular checkpointing, and preventing fusion of the surrounding transforms. """ + def expand(self, pcoll): windowing_saved = pcoll.windowing if windowing_saved.is_default(): @@ -1062,9 +1072,7 @@ def WithKeys(pcoll, k, *args, **kwargs): for arg in args) and all(isinstance(kwarg, AsSideInput) for kwarg in kwargs.values()): return pcoll | Map( - lambda v, - *args, - **kwargs: (k(v, *args, **kwargs), v), + lambda v, *args, **kwargs: (k(v, *args, **kwargs), v), *args, **kwargs) return pcoll | Map(lambda v: (k(v, *args, **kwargs), v)) @@ -1081,6 +1089,7 @@ class GroupIntoBatches(PTransform): Windows are preserved (batches will contain elements from the same window) """ + def __init__( self, batch_size, max_buffering_duration_secs=None, clock=time.time): """Create a new GroupIntoBatches. @@ -1136,6 +1145,7 @@ class WithShardedKey(PTransform): override the default sharding to do a better load balancing during the execution time. """ + def __init__( self, batch_size, max_buffering_duration_secs=None, clock=time.time): """Create a new GroupIntoBatches with sharded output. @@ -1189,6 +1199,7 @@ class _GroupIntoBatchesParams: :class:`apache_beam.utils.GroupIntoBatches` transform, used to define how elements should be batched. """ + def __init__(self, batch_size, max_buffering_duration_secs): self.batch_size = batch_size self.max_buffering_duration_secs = ( @@ -1208,8 +1219,8 @@ def _validate(self): 'batch_size must be a positive value') assert ( self.max_buffering_duration_secs is not None and - self.max_buffering_duration_secs >= 0), ( - 'max_buffering_duration must be a non-negative value') + self.max_buffering_duration_secs + >= 0), ('max_buffering_duration must be a non-negative value') def get_payload(self): return beam_runner_api_pb2.GroupIntoBatchesPayload( @@ -1232,6 +1243,7 @@ def _pardo_group_into_batches( BUFFERING_TIMER = TimerSpec('buffering_end', TimeDomain.REAL_TIME) class _GroupIntoBatchesDoFn(DoFn): + def process( self, element, @@ -1330,7 +1342,9 @@ class LogElements(PTransform): `logging.INFO`, `logging.WARNING`, `logging.ERROR`). If not specified, the log is printed to stdout. """ + class _LoggingFn(DoFn): + def __init__( self, prefix='', with_timestamp=False, with_window=False, level=None): super().__init__() @@ -1391,11 +1405,13 @@ def expand(self, input): class Reify(object): """PTransforms for converting between explicit and implicit form of various Beam values.""" + @typehints.with_input_types(T) @typehints.with_output_types(T) class Timestamp(PTransform): """PTransform to wrap a value in a TimestampedValue with it's associated timestamp.""" + @staticmethod def add_timestamp_info(element, timestamp=DoFn.TimestampParam): yield TimestampedValue(element, timestamp) @@ -1409,6 +1425,7 @@ class Window(PTransform): """PTransform to convert an element in a PCollection into a tuple of (element, timestamp, window), wrapped in a TimestampedValue with it's associated timestamp.""" + @staticmethod def add_window_info( element, timestamp=DoFn.TimestampParam, window=DoFn.WindowParam): @@ -1422,6 +1439,7 @@ def expand(self, pcoll): class TimestampInValue(PTransform): """PTransform to wrap the Value in a KV pair in a TimestampedValue with the element's associated timestamp.""" + @staticmethod def add_timestamp_info(element, timestamp=DoFn.TimestampParam): key, value = element @@ -1436,6 +1454,7 @@ class WindowInValue(PTransform): """PTransform to convert the Value in a KV pair into a tuple of (value, timestamp, window), with the whole element being wrapped inside a TimestampedValue.""" + @staticmethod def add_window_info( element, timestamp=DoFn.TimestampParam, window=DoFn.WindowParam): @@ -1683,6 +1702,7 @@ class Tee(PTransform): | Tee(SomeSideTransform()) | ...) """ + def __init__( self, *consumers: Union[PTransform[PCollection[T], Any], @@ -1719,6 +1739,7 @@ class WaitOn(PTransform): This barrier often induces a fusion break. """ + def __init__(self, *to_be_waited_on): self._to_be_waited_on = to_be_waited_on diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index d86509c7dde3..947d234f9873 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -82,6 +82,7 @@ class CoGroupByKeyTest(unittest.TestCase): + def test_co_group_by_key_on_tuple(self): with TestPipeline() as pipeline: pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2), @@ -182,6 +183,7 @@ def test_co_group_by_key_on_one(self): class FakeClock(object): + def __init__(self, now=time.time()): self._now = now @@ -288,8 +290,8 @@ def test_global_batch_timestamps(self): | beam.Create(range(3), reshuffle=False) | util.BatchElements(min_batch_size=2, max_batch_size=2) | beam.Map( - lambda batch, - timestamp=beam.DoFn.TimestampParam: (len(batch), timestamp))) + lambda batch, timestamp=beam.DoFn.TimestampParam: + (len(batch), timestamp))) assert_that( res, equal_to([ @@ -301,12 +303,19 @@ def test_sized_batches(self): with TestPipeline() as p: res = ( p - | beam.Create([ - 'a', 'a', # First batch. - 'aaaaaaaaaa', # Second batch. - 'aaaaa', 'aaaaa', # Third batch. - 'a', 'aaaaaaa', 'a', 'a' # Fourth batch. - ], reshuffle=False) + | beam.Create( + [ + 'a', + 'a', # First batch. + 'aaaaaaaaaa', # Second batch. + 'aaaaa', + 'aaaaa', # Third batch. + 'a', + 'aaaaaaa', + 'a', + 'a' # Fourth batch. + ], + reshuffle=False) | util.BatchElements( min_batch_size=10, max_batch_size=10, element_size_fn=len) | beam.Map(lambda batch: ''.join(batch)) @@ -330,10 +339,10 @@ def test_sized_windowed_batches(self): assert_that( res, equal_to([ - 'a' * (1+2), # Elements in [1, 3) - 'a' * (3+4), # Elements in [3, 6) + 'a' * (1 + 2), # Elements in [1, 3) + 'a' * (3 + 4), # Elements in [3, 6) 'a' * 5, - 'a' * 6, # Elements in [6, 9) + 'a' * 6, # Elements in [6, 9) 'a' * 7, ])) @@ -552,8 +561,8 @@ def test_stateful_buffering_timer_in_fixed_window_streaming(self): start_time = timestamp.Timestamp(0) test_stream = ( TestStream().add_elements([ - TimestampedValue(value, start_time + i) for i, - value in enumerate(BatchElementsTest._create_test_data()) + TimestampedValue(value, start_time + i) + for i, value in enumerate(BatchElementsTest._create_test_data()) ]).advance_processing_time(150).advance_watermark_to( start_time + window_duration).advance_watermark_to( start_time + window_duration + @@ -651,11 +660,13 @@ def test_stateful_grows_to_max_batch(self): class IdentityWindowTest(unittest.TestCase): + def test_window_preserved(self): expected_timestamp = timestamp.Timestamp(5) expected_window = window.IntervalWindow(1.0, 2.0) class AddWindowDoFn(beam.DoFn): + def process(self, element): yield WindowedValue(element, expected_timestamp, [expected_window]) @@ -691,6 +702,7 @@ def test_no_window_context_fails(self): expected_window = window.GlobalWindow() class AddTimestampDoFn(beam.DoFn): + def process(self, element): yield window.TimestampedValue(element, expected_timestamp) @@ -728,6 +740,7 @@ def process(self, element): class ReshuffleTest(unittest.TestCase): + def test_reshuffle_contents_unchanged(self): with TestPipeline() as pipeline: data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 3)] @@ -794,14 +807,13 @@ def test_reshuffle_windows_unchanged(self): data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)] expected_data = [ TestWindowedValue(v, t - .001, [w]) - for (v, t, w) in [((1, contains_in_any_order([2, 1])), - 4.0, - IntervalWindow(1.0, 4.0)), - ((2, contains_in_any_order([2, 1])), - 4.0, + for (v, t, w) in [((1, contains_in_any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)), ( - (3, [1]), 3.0, IntervalWindow(1.0, 3.0)), ( - (1, [4]), 6.0, IntervalWindow(4.0, 6.0))] + (2, contains_in_any_order([2, 1])), 4.0, + IntervalWindow(1.0, 4.0)), (( + 3, [1]), 3.0, IntervalWindow(1.0, 3.0)), (( + 1, + [4]), 6.0, IntervalWindow(4.0, 6.0))] ] before_reshuffle = ( pipeline @@ -828,13 +840,12 @@ def test_reshuffle_window_fn_preserved(self): data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)] expected_windows = [ TestWindowedValue(v, t, [w]) - for (v, t, w) in [((1, 1), 1.0, IntervalWindow(1.0, 3.0)), ( - (2, 1), 1.0, IntervalWindow(1.0, 3.0)), ( - (3, 1), 1.0, IntervalWindow(1.0, 3.0)), ( - (1, 2), 2.0, IntervalWindow(2.0, 4.0)), ( + for (v, t, w) in [((1, 1), 1.0, IntervalWindow(1.0, 3.0)), (( + 2, 1), 1.0, IntervalWindow(1.0, 3.0)), (( + 3, 1), 1.0, IntervalWindow(1.0, 3.0)), (( + 1, 2), 2.0, IntervalWindow(2.0, 4.0)), ( (2, 2), 2.0, - IntervalWindow(2.0, 4.0)), ((1, 4), - 4.0, + IntervalWindow(2.0, 4.0)), ((1, 4), 4.0, IntervalWindow(4.0, 6.0))] ] expected_merged_windows = [ @@ -843,8 +854,7 @@ def test_reshuffle_window_fn_preserved(self): w) in [((1, any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)), ( (2, any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)), ( (3, [1]), 3.0, - IntervalWindow(1.0, 3.0)), ((1, [4]), - 6.0, + IntervalWindow(1.0, 3.0)), ((1, [4]), 6.0, IntervalWindow(4.0, 6.0))] ] before_reshuffle = ( @@ -1012,6 +1022,7 @@ def format_with_timestamp(element, timestamp=beam.DoFn.TimestampParam): class WithKeysTest(unittest.TestCase): + def setUp(self): self.l = [1, 2, 3] @@ -1019,7 +1030,8 @@ def test_constant_k(self): with TestPipeline() as p: pc = p | beam.Create(self.l) with_keys = pc | util.WithKeys('k') - assert_that(with_keys, equal_to([('k', 1), ('k', 2), ('k', 3)], )) + assert_that(with_keys, equal_to([('k', 1), ('k', 2), ('k', 3)], + )) def test_callable_k(self): with TestPipeline() as p: @@ -1044,9 +1056,7 @@ def test_sideinputs(self): si1 = AsList(p | "side input 1" >> beam.Create([1, 2, 3])) si2 = AsSingleton(p | "side input 2" >> beam.Create([10])) with_keys = pc | util.WithKeys( - lambda x, - the_list, - the_singleton: x + sum(the_list) + the_singleton, + lambda x, the_list, the_singleton: x + sum(the_list) + the_singleton, si1, the_singleton=si2) assert_that(with_keys, equal_to([(17, 1), (18, 2), (19, 3)])) @@ -1129,8 +1139,8 @@ def test_buffering_timer_in_fixed_window_streaming(self): start_time = timestamp.Timestamp(0) test_stream = ( TestStream().add_elements([ - TimestampedValue(value, start_time + i) for i, - value in enumerate(GroupIntoBatchesTest._create_test_data()) + TimestampedValue(value, start_time + i) + for i, value in enumerate(GroupIntoBatchesTest._create_test_data()) ]).advance_processing_time(150).advance_watermark_to( start_time + window_duration).advance_watermark_to( start_time + window_duration + @@ -1287,6 +1297,7 @@ def test_runner_api(self): class ToStringTest(unittest.TestCase): + def test_tostring_elements(self): with TestPipeline() as p: result = (p | beam.Create([1, 1, 2, 3]) | util.ToString.Element()) @@ -1324,6 +1335,7 @@ def test_tostring_kvs_empty_delimeter(self): class LogElementsTest(unittest.TestCase): + @pytest.fixture(scope="function") def _capture_stdout_log(request, capsys): with TestPipeline() as p: @@ -1389,6 +1401,7 @@ def test_setting_level_uses_appropriate_log_channel(self): class ReifyTest(unittest.TestCase): + def test_timestamp(self): l = [ TimestampedValue('a', 100), @@ -1476,6 +1489,7 @@ def test_window_in_value(self): class RegexTest(unittest.TestCase): + def test_find(self): with TestPipeline() as p: result = ( @@ -1842,6 +1856,7 @@ def count_side_effects(element): class WaitOnTest(unittest.TestCase): + def test_find(self): # We need shared reference that survives pickling. def increment_global_counter(): diff --git a/sdks/python/apache_beam/transforms/validate_runner_xlang_test.py b/sdks/python/apache_beam/transforms/validate_runner_xlang_test.py index 8e8e79648250..99dc61ef9427 100644 --- a/sdks/python/apache_beam/transforms/validate_runner_xlang_test.py +++ b/sdks/python/apache_beam/transforms/validate_runner_xlang_test.py @@ -74,6 +74,7 @@ class CrossLanguageTestPipelines(object): + def __init__(self, expansion_service=None): self.expansion_service = expansion_service or ( 'localhost:%s' % os.environ.get('EXPANSION_PORT')) diff --git a/sdks/python/apache_beam/transforms/window.py b/sdks/python/apache_beam/transforms/window.py index fc20174ca1e2..0c9656d48189 100644 --- a/sdks/python/apache_beam/transforms/window.py +++ b/sdks/python/apache_beam/transforms/window.py @@ -121,8 +121,10 @@ def get_impl( class WindowFn(urns.RunnerApiFn, metaclass=abc.ABCMeta): """An abstract windowing function defining a basic assign and merge.""" + class AssignContext(object): """Context passed to WindowFn.assign().""" + def __init__( self, timestamp: TimestampTypes, @@ -149,6 +151,7 @@ def assign(self, class MergeContext(object): """Context passed to WindowFn.merge() to perform merging, if any.""" + def __init__(self, windows: Iterable['BoundedWindow']) -> None: self.windows = list(windows) @@ -200,6 +203,7 @@ class BoundedWindow(object): Attributes: end: End of window. """ + def __init__(self, end: TimestampTypes) -> None: self._end = Timestamp.of(end) @@ -256,6 +260,7 @@ class IntervalWindow(windowed_value._IntervalWindowBase, BoundedWindow): start: Start of window as seconds since Unix epoch. end: End of window as seconds since Unix epoch. """ + def __lt__(self, other): if self.end != other.end: return self.end < other.end @@ -280,6 +285,7 @@ class TimestampedValue(Generic[V]): value: The underlying value. timestamp: Timestamp associated with the value as seconds since Unix epoch. """ + def __init__(self, value: V, timestamp: TimestampTypes) -> None: self.value = value self.timestamp = Timestamp.of(timestamp) @@ -334,6 +340,7 @@ def _getTimestampFromProto() -> Timestamp: class NonMergingWindowFn(WindowFn): + def is_merging(self) -> bool: return False @@ -343,6 +350,7 @@ def merge(self, merge_context: WindowFn.MergeContext) -> None: class GlobalWindows(NonMergingWindowFn): """A windowing function that assigns everything to one global window.""" + @classmethod def windowed_batch( cls, @@ -404,6 +412,7 @@ class FixedWindows(NonMergingWindowFn): value in range [0, size). If it is not it will be normalized to this range. """ + def __init__(self, size: DurationTypes, offset: TimestampTypes = 0): """Initialize a ``FixedWindows`` function for a given size and offset. @@ -467,6 +476,7 @@ class SlidingWindows(NonMergingWindowFn): t=N * period + offset where t=0 is the epoch. The offset must be a value in range [0, period). If it is not it will be normalized to this range. """ + def __init__( self, size: DurationTypes, @@ -487,9 +497,8 @@ def assign(self, context: WindowFn.AssignContext) -> List[IntervalWindow]: (interval_start := Timestamp(micros=s)), interval_start + self.size, ) for s in range( - start.micros, - timestamp.micros - self.size.micros, - -self.period.micros) + start.micros, timestamp.micros - + self.size.micros, -self.period.micros) ] def get_window_coder(self) -> coders.IntervalWindowCoder: @@ -536,6 +545,7 @@ class Sessions(WindowFn): Attributes: gap_size: Size of the gap between windows as floating-point seconds. """ + def __init__(self, gap_size: DurationTypes) -> None: if gap_size <= 0: raise ValueError('The size parameter must be strictly positive.') diff --git a/sdks/python/apache_beam/transforms/window_test.py b/sdks/python/apache_beam/transforms/window_test.py index 3d73f92fb368..c5ede54ce56d 100644 --- a/sdks/python/apache_beam/transforms/window_test.py +++ b/sdks/python/apache_beam/transforms/window_test.py @@ -58,6 +58,7 @@ def context(element, timestamp): class ReifyWindowsFn(core.DoFn): + def process(self, element, window=core.DoFn.WindowParam): key, values = element yield "%s @ %s" % (key, window), values @@ -67,6 +68,7 @@ class TestCustomWindows(NonMergingWindowFn): """A custom non merging window fn which assigns elements into interval windows [0, 3), [3, 5) and [5, element timestamp) based on the element timestamps. """ + def assign(self, context): timestamp = context.timestamp if timestamp < 3: @@ -81,6 +83,7 @@ def get_window_coder(self): class WindowTest(unittest.TestCase): + def test_timestamped_value_cmp(self): self.assertEqual(TimestampedValue('a', 2), TimestampedValue('a', 2)) self.assertEqual(TimestampedValue('a', 2), TimestampedValue('a', 2.0)) @@ -145,6 +148,7 @@ def merge(*timestamps): running = set() class TestMergeContext(WindowFn.MergeContext): + def __init__(self): super().__init__(running) @@ -346,6 +350,7 @@ def test_window_assignment_through_multiple_gbk_idempotency(self): class RunnerApiTest(unittest.TestCase): + def test_windowfn_encoding(self): for window_fn in (GlobalWindows(), FixedWindows(37), diff --git a/sdks/python/apache_beam/transforms/write_ptransform_test.py b/sdks/python/apache_beam/transforms/write_ptransform_test.py index ce402d8d3062..600fb8621780 100644 --- a/sdks/python/apache_beam/transforms/write_ptransform_test.py +++ b/sdks/python/apache_beam/transforms/write_ptransform_test.py @@ -83,6 +83,7 @@ def write(self, value): class WriteToTestSink(PTransform): + def __init__(self, return_init_result=True, return_write_results=True): self.return_init_result = return_init_result self.return_write_results = return_write_results diff --git a/sdks/python/apache_beam/typehints/arrow_type_compatibility.py b/sdks/python/apache_beam/typehints/arrow_type_compatibility.py index 34a37a886bab..a7d3747641ac 100644 --- a/sdks/python/apache_beam/typehints/arrow_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/arrow_type_compatibility.py @@ -56,8 +56,8 @@ def beam_schema_from_arrow_schema(arrow_schema: pa.Schema) -> schema_pb2.Schema: if arrow_schema.metadata: schema_id = arrow_schema.metadata.get(BEAM_SCHEMA_ID_KEY, None) schema_options = [ - _hydrate_beam_option(value) for key, - value in arrow_schema.metadata.items() + _hydrate_beam_option(value) + for key, value in arrow_schema.metadata.items() if key.startswith(BEAM_SCHEMA_OPTION_KEY_PREFIX) ] else: @@ -78,14 +78,14 @@ def _beam_field_from_arrow_field(arrow_field: pa.Field) -> schema_pb2.Field: if arrow_field.metadata: field_options = [ - _hydrate_beam_option(value) for key, - value in arrow_field.metadata.items() + _hydrate_beam_option(value) + for key, value in arrow_field.metadata.items() if key.startswith(BEAM_FIELD_OPTION_KEY_PREFIX) ] if isinstance(arrow_field.type, pa.StructType): beam_fieldtype.row_type.schema.options.extend([ - _hydrate_beam_option(value) for key, - value in arrow_field.metadata.items() + _hydrate_beam_option(value) + for key, value in arrow_field.metadata.items() if key.startswith(BEAM_SCHEMA_OPTION_KEY_PREFIX) ]) if BEAM_SCHEMA_ID_KEY in arrow_field.metadata: @@ -295,6 +295,7 @@ def _arrow_type_from_beam_fieldtype( class PyarrowBatchConverter(BatchConverter): + def __init__(self, element_type: RowTypeConstraint): super().__init__(pa.Table, element_type) self._beam_schema = typing_to_runner_api(element_type).row_type.schema @@ -320,8 +321,8 @@ def from_typehints(element_type, def produce_batch(self, elements): arrays = [ pa.array([getattr(el, name) for el in elements], - type=self._arrow_schema.field(name).type) for name, - _ in self._element_type._fields + type=self._arrow_schema.field(name).type) + for name, _ in self._element_type._fields ] return pa.Table.from_arrays(arrays, schema=self._arrow_schema) @@ -331,8 +332,7 @@ def explode_batch(self, batch: pa.Table): yield self._element_type.user_type( **{ name: val.as_py() - for name, - val in zip(self._arrow_schema.names, row_values) + for name, val in zip(self._arrow_schema.names, row_values) }) def combine_batches(self, batches: List[pa.Table]): @@ -357,6 +357,7 @@ def __reduce__(self): class PyarrowArrayBatchConverter(BatchConverter): + def __init__(self, element_type: type): super().__init__(pa.Array, element_type) self._element_type = element_type diff --git a/sdks/python/apache_beam/typehints/arrow_type_compatibility_test.py b/sdks/python/apache_beam/typehints/arrow_type_compatibility_test.py index 1e9ab3f27bd9..9ff8af710927 100644 --- a/sdks/python/apache_beam/typehints/arrow_type_compatibility_test.py +++ b/sdks/python/apache_beam/typehints/arrow_type_compatibility_test.py @@ -38,6 +38,7 @@ @pytest.mark.uses_pyarrow class ArrowTypeCompatibilityTest(unittest.TestCase): + @parameterized.expand([(beam_schema, ) for beam_schema in get_test_beam_schemas_protos()]) def test_beam_schema_survives_roundtrip(self, beam_schema): @@ -46,6 +47,7 @@ def test_beam_schema_survives_roundtrip(self, beam_schema): self.assertEqual(beam_schema, roundtripped) + @parameterized_class([ { 'batch_typehint': pa.Table, @@ -88,18 +90,17 @@ def test_beam_schema_survives_roundtrip(self, beam_schema): { 'batch_typehint': pa.Array, 'element_typehint': row_type.RowTypeConstraint.from_fields([ - ("bar", Optional[float]), # noqa: F821 - ("baz", Optional[str]), # noqa: F821 - ]), - 'batch': pa.array([ - { - 'bar': i / 100, 'baz': str(i) - } if i % 7 else None for i in range(100) + ("bar", Optional[float]), # noqa: F821 + ("baz", Optional[str]), # noqa: F821 ]), + 'batch': pa.array([{ + 'bar': i / 100, 'baz': str(i) + } if i % 7 else None for i in range(100)]), } ]) @pytest.mark.uses_pyarrow class ArrowBatchConverterTest(unittest.TestCase): + def create_batch_converter(self): return BatchConverter.from_typehints( element_type=self.element_typehint, batch_type=self.batch_typehint) @@ -194,20 +195,21 @@ def test_hash(self): class ArrowBatchConverterErrorsTest(unittest.TestCase): + @parameterized.expand([ - ( - pa.RecordBatch, - row_type.RowTypeConstraint.from_fields([ - ("bar", Optional[float]), # noqa: F821 - ("baz", Optional[str]), # noqa: F821 - ]), - r'batch type must be pa\.Table or pa\.Array', - ), - ( - pa.Table, - Any, - r'Element type .* must be compatible with Beam Schemas', - ), + ( + pa.RecordBatch, + row_type.RowTypeConstraint.from_fields([ + ("bar", Optional[float]), # noqa: F821 + ("baz", Optional[str]), # noqa: F821 + ]), + r'batch type must be pa\.Table or pa\.Array', + ), + ( + pa.Table, + Any, + r'Element type .* must be compatible with Beam Schemas', + ), ]) def test_construction_errors( self, batch_typehint, element_typehint, error_regex): diff --git a/sdks/python/apache_beam/typehints/batch.py b/sdks/python/apache_beam/typehints/batch.py index 35351b147d48..789d4680a3a2 100644 --- a/sdks/python/apache_beam/typehints/batch.py +++ b/sdks/python/apache_beam/typehints/batch.py @@ -51,6 +51,7 @@ class BatchConverter(Generic[B, E]): + def __init__(self, batch_type, element_type): self._batch_type = batch_type self._element_type = element_type @@ -74,6 +75,7 @@ def estimate_byte_size(self, batch): @staticmethod def register(*, name: str): + def do_registration( batch_converter_constructor: Callable[[type, type], 'BatchConverter']): if name in BATCH_CONVERTER_REGISTRY: @@ -169,6 +171,7 @@ def estimate_byte_size(self, batch): class NumpyBatchConverter(BatchConverter): + def __init__( self, batch_type, @@ -241,7 +244,9 @@ def estimate_byte_size(self, batch): # specifying shape, seems to be coming after # https://www.python.org/dev/peps/pep-0646/ class NumpyTypeHint(): + class NumpyTypeConstraint(typehints.TypeConstraint): + def __init__(self, dtype, shape=()): self.dtype = np.dtype(dtype) self.shape = shape @@ -289,7 +294,8 @@ def __getitem__(self, value): raise ValueError else: dtype = value - return self.NumpyTypeConstraint(dtype, shape=(N, )) + return self.NumpyTypeConstraint( + dtype, shape=(N, )) NumpyArray = NumpyTypeHint() diff --git a/sdks/python/apache_beam/typehints/batch_test.py b/sdks/python/apache_beam/typehints/batch_test.py index 3fbad76fce06..bcd320deb5b9 100644 --- a/sdks/python/apache_beam/typehints/batch_test.py +++ b/sdks/python/apache_beam/typehints/batch_test.py @@ -56,6 +56,7 @@ }, ]) class BatchConverterTest(unittest.TestCase): + def create_batch_converter(self): return BatchConverter.from_typehints( element_type=self.element_typehint, batch_type=self.batch_typehint) @@ -150,6 +151,7 @@ def test_hash(self): class BatchConverterErrorsTest(unittest.TestCase): + @parameterized.expand([ ( typing.List[int], diff --git a/sdks/python/apache_beam/typehints/decorators.py b/sdks/python/apache_beam/typehints/decorators.py index 7050df7016e5..e37a84256c6e 100644 --- a/sdks/python/apache_beam/typehints/decorators.py +++ b/sdks/python/apache_beam/typehints/decorators.py @@ -352,11 +352,10 @@ def strip_pcoll_helper( my_type: any, has_my_type: Callable[[], bool], my_key: str, - special_containers: List[ - Union['PBegin', 'PDone', 'PCollection']], # noqa: F821 + special_containers: List[Union['PBegin', 'PDone', + 'PCollection']], # noqa: F821 error_str: str, - source_str: str - ) -> 'IOTypeHints': + source_str: str) -> 'IOTypeHints': from apache_beam.pvalue import PCollection if not has_my_type() or not my_type or len(my_type[0]) != 1: @@ -468,6 +467,7 @@ def debug_str(self): return '\n'.join([self.__repr__()] + self.origin) def __eq__(self, other): + def same(a, b): if a is None or not any(a): return b is None or not any(b) @@ -489,6 +489,7 @@ def __reduce__(self): class WithTypeHints(object): """A mixin class that provides the ability to set and retrieve type hints. """ + def __init__(self, *unused_args, **unused_kwargs): self._type_hints = IOTypeHints.empty() @@ -571,8 +572,8 @@ def _unpack_positional_arg_hints(arg, hint): (arg, tuple_constraint, hint)) if isinstance(hint, typehints.TupleConstraint): return tuple( - _unpack_positional_arg_hints(a, t) for a, - t in zip(arg, hint.tuple_types)) + _unpack_positional_arg_hints(a, t) + for a, t in zip(arg, hint.tuple_types)) return (typehints.Any, ) * len(arg) return hint @@ -923,6 +924,7 @@ def gen(): iteration. If the generator received is already wrapped, then it is simply returned to avoid nested wrapping. """ + def wrapper(gen): if isinstance(gen, GeneratorWrapper): return gen @@ -945,6 +947,7 @@ class GeneratorWrapper(object): be called with the result of each yielded 'step' in the internal generator. """ + def __init__(self, gen, interleave_func): self.internal_gen = gen self.interleave_func = interleave_func diff --git a/sdks/python/apache_beam/typehints/decorators_test.py b/sdks/python/apache_beam/typehints/decorators_test.py index dd110ced5bb8..810a074a9b67 100644 --- a/sdks/python/apache_beam/typehints/decorators_test.py +++ b/sdks/python/apache_beam/typehints/decorators_test.py @@ -42,6 +42,7 @@ class IOTypeHintsTest(unittest.TestCase): + def test_get_signature(self): # Basic coverage only to make sure function works. def fn(a, b=1, *c, **d): @@ -56,6 +57,7 @@ def test_get_signature_builtin(self): self.assertEqual(s.return_annotation, List[Any]) def test_from_callable_without_annotations(self): + def fn(a, b=None, *args, **kwargs): return a, b, args, kwargs @@ -186,6 +188,7 @@ def test_with_defaults_noop_does_not_grow_origin(self): self.assertNotEqual(expected_id, id(th)) def test_from_callable(self): + def fn( a: int, b: str = '', @@ -202,6 +205,7 @@ def fn( self.assertEqual(th.output_types, ((Tuple[Any, ...], ), {})) def test_from_callable_partial_annotations(self): + def fn(a: int, b=None, *args, foo: List[int], **kwargs): return a, b, args, foo, kwargs @@ -214,7 +218,9 @@ def fn(a: int, b=None, *args, foo: List[int], **kwargs): self.assertEqual(th.output_types, ((Any, ), {})) def test_from_callable_class(self): + class Class(object): + def __init__(self, unused_arg: int): pass @@ -223,7 +229,9 @@ def __init__(self, unused_arg: int): self.assertEqual(th.output_types, ((Class, ), {})) def test_from_callable_method(self): + class Class(object): + def method(self, arg: T = None) -> None: pass @@ -236,6 +244,7 @@ def method(self, arg: T = None) -> None: self.assertEqual(th.output_types, ((None, ), {})) def test_from_callable_convert_to_beam_types(self): + def fn( a: typing.List[int], b: str = '', @@ -253,6 +262,7 @@ def fn( self.assertEqual(th.output_types, ((Tuple[Any, ...], ), {})) def test_from_callable_partial(self): + def fn(a: int) -> int: return a @@ -262,6 +272,7 @@ def fn(a: int) -> int: self.assertRegex(th.debug_str(), r'unknown') def test_getcallargs_forhints(self): + def fn( a: int, b: str = '', @@ -298,6 +309,7 @@ def fn(a=List[int], b=None, *args, foo=(), **kwargs) -> Tuple[Any, ...]: }) def test_getcallargs_forhints_missing_arg(self): + def fn(a, b=None, *args, foo, **kwargs): return a, b, args, foo, kwargs @@ -307,6 +319,7 @@ def fn(a, b=None, *args, foo, **kwargs): decorators.getcallargs_forhints(fn, 5) def test_origin_annotated(self): + def annotated(e: str) -> str: return e @@ -321,7 +334,9 @@ def annotated(e: str) -> str: class WithTypeHintsTest(unittest.TestCase): + def test_get_type_hints_no_settings(self): + class Base(WithTypeHints): pass @@ -330,6 +345,7 @@ class Base(WithTypeHints): self.assertEqual(th.output_types, None) def test_get_type_hints_class_decorators(self): + @decorators.with_input_types(int, str) @decorators.with_output_types(int) class Base(WithTypeHints): @@ -340,7 +356,9 @@ class Base(WithTypeHints): self.assertEqual(th.output_types, ((int, ), {})) def test_get_type_hints_class_defaults(self): + class Base(WithTypeHints): + def default_type_hints(self): return decorators.IOTypeHints( input_types=((int, str), {}), output_types=((int, ), {}), origin=[]) @@ -350,9 +368,11 @@ def default_type_hints(self): self.assertEqual(th.output_types, ((int, ), {})) def test_get_type_hints_precedence_defaults_over_decorators(self): + @decorators.with_input_types(int) @decorators.with_output_types(str) class Base(WithTypeHints): + def default_type_hints(self): return decorators.IOTypeHints( input_types=((float, ), {}), output_types=None, origin=[]) @@ -362,7 +382,9 @@ def default_type_hints(self): self.assertEqual(th.output_types, ((str, ), {})) def test_get_type_hints_precedence_instance_over_defaults(self): + class Base(WithTypeHints): + def default_type_hints(self): return decorators.IOTypeHints( input_types=((float, ), {}), output_types=((str, ), {}), origin=[]) @@ -375,6 +397,7 @@ def test_inherits_does_not_modify(self): # See BEAM-8629. @decorators.with_output_types(int) class Subclass(WithTypeHints): + def __init__(self): pass # intentionally avoiding super call @@ -388,6 +411,7 @@ def __init__(self): class DecoratorsTest(unittest.TestCase): + def tearDown(self): decorators._disable_from_callable = False @@ -397,6 +421,7 @@ def test_disable_type_annotations(self): self.assertTrue(decorators._disable_from_callable) def test_no_annotations_on_same_function(self): + def fn(a: int) -> int: return a @@ -409,6 +434,7 @@ def fn(a: int) -> int: _ = ['a', 'b', 'c'] | Map(fn) def test_no_annotations_on_diff_function(self): + def fn(a: int) -> int: return a diff --git a/sdks/python/apache_beam/typehints/intrinsic_one_ops_test.py b/sdks/python/apache_beam/typehints/intrinsic_one_ops_test.py index adffc945baad..08ff23ea4654 100644 --- a/sdks/python/apache_beam/typehints/intrinsic_one_ops_test.py +++ b/sdks/python/apache_beam/typehints/intrinsic_one_ops_test.py @@ -27,6 +27,7 @@ class IntrinsicOneOpsTest(unittest.TestCase): + def test_unary_intrinsic_ops_are_in_the_same_order_as_in_cpython(self): if sys.version_info >= (3, 12): dis_order = dis.__dict__['_intrinsic_1_descs'] diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py index 3f57a573b505..bbdeb45ffa5a 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py @@ -332,8 +332,9 @@ def convert_to_beam_type(typ): # This is needed to fix https://github.com/apache/beam/issues/33356 pass - elif (typ_module != 'typing') and (typ_module != - 'collections.abc') and not is_builtin(typ): + elif (typ_module + != 'typing') and (typ_module + != 'collections.abc') and not is_builtin(typ): # Only translate primitives and types from collections.abc and typing. return typ if (typ_module == 'collections.abc' and diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py index 15b5da99fb0c..6d114548fa64 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py @@ -62,6 +62,7 @@ class _TestEnum(enum.Enum): class NativeTypeCompatibilityTest(unittest.TestCase): + def test_convert_to_beam_type(self): test_cases = [ ('raw bytes', bytes, bytes), diff --git a/sdks/python/apache_beam/typehints/opcodes.py b/sdks/python/apache_beam/typehints/opcodes.py index 7bea621841f6..17a299b4b2c0 100644 --- a/sdks/python/apache_beam/typehints/opcodes.py +++ b/sdks/python/apache_beam/typehints/opcodes.py @@ -68,6 +68,7 @@ def pop_three(state, unused_arg): def push_value(v): + def pusher(state, unused_arg): state.stack.append(v) diff --git a/sdks/python/apache_beam/typehints/pandas_type_compatibility.py b/sdks/python/apache_beam/typehints/pandas_type_compatibility.py index ca9523f28349..9d136eb6f8ba 100644 --- a/sdks/python/apache_beam/typehints/pandas_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/pandas_type_compatibility.py @@ -91,8 +91,7 @@ PANDAS_TO_BEAM = { pd.Series([], dtype=dtype).dtype: fieldtype - for dtype, - fieldtype in _BIDIRECTIONAL + for dtype, fieldtype in _BIDIRECTIONAL } BEAM_TO_PANDAS = {fieldtype: dtype for dtype, fieldtype in _BIDIRECTIONAL} @@ -150,6 +149,7 @@ def create_pandas_batch_converter( class DataFrameBatchConverter(BatchConverter): + def __init__( self, element_type: RowTypeConstraint, @@ -196,8 +196,10 @@ def make_null_checking_generator(series): for values in zip(*iterators): yield self._element_type.user_type( - **{column: value - for column, value in zip(self._columns, values)}) + **{ + column: value + for column, value in zip(self._columns, values) + }) def combine_batches(self, batches: List[pd.DataFrame]): return pd.concat(batches) @@ -215,6 +217,7 @@ class DataFrameBatchConverterDropIndex(DataFrameBatchConverter): When producing a DataFrame from Rows, a meaningless index will be generated. When exploding a DataFrame into Rows, the index will be dropped. """ + def _get_series(self, batch: pd.DataFrame): return [batch[column] for column in batch.columns] @@ -233,6 +236,7 @@ class DataFrameBatchConverterKeepIndex(DataFrameBatchConverter): This is tracked via options on the Beam schema. Each field in the schema that should map to the index is tagged in an option with name 'dataframe:index'. """ + def __init__(self, element_type: RowTypeConstraint, index_columns: List[Any]): super().__init__(element_type) self._index_columns = index_columns @@ -254,6 +258,7 @@ def produce_batch(self, elements): class SeriesBatchConverter(BatchConverter): + def __init__( self, element_type: type, diff --git a/sdks/python/apache_beam/typehints/pandas_type_compatibility_test.py b/sdks/python/apache_beam/typehints/pandas_type_compatibility_test.py index ff66df1ce968..455d666a441c 100644 --- a/sdks/python/apache_beam/typehints/pandas_type_compatibility_test.py +++ b/sdks/python/apache_beam/typehints/pandas_type_compatibility_test.py @@ -117,6 +117,7 @@ }, ]) class PandasBatchConverterTest(unittest.TestCase): + def create_batch_converter(self): return BatchConverter.from_typehints( element_type=self.element_typehint, batch_type=self.batch_typehint) @@ -210,20 +211,21 @@ def test_hash(self): class PandasBatchConverterErrorsTest(unittest.TestCase): + @parameterized.expand([ - ( - Any, - row_type.RowTypeConstraint.from_fields([ - ("bar", Optional[float]), # noqa: F821 - ("baz", Optional[str]), # noqa: F821 - ]), - r'batch type must be pd\.Series or pd\.DataFrame', - ), - ( - pd.DataFrame, - Any, - r'Element type must be compatible with Beam Schemas', - ), + ( + Any, + row_type.RowTypeConstraint.from_fields([ + ("bar", Optional[float]), # noqa: F821 + ("baz", Optional[str]), # noqa: F821 + ]), + r'batch type must be pd\.Series or pd\.DataFrame', + ), + ( + pd.DataFrame, + Any, + r'Element type must be compatible with Beam Schemas', + ), ]) def test_construction_errors( self, batch_typehint, element_typehint, error_regex): diff --git a/sdks/python/apache_beam/typehints/pytorch_type_compatibility.py b/sdks/python/apache_beam/typehints/pytorch_type_compatibility.py index f008174bcc03..e455481839b0 100644 --- a/sdks/python/apache_beam/typehints/pytorch_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/pytorch_type_compatibility.py @@ -24,6 +24,7 @@ class PytorchBatchConverter(BatchConverter): + def __init__( self, batch_type, @@ -88,7 +89,9 @@ def estimate_byte_size(self, batch): class PytorchTypeHint(): + class PytorchTypeConstraint(typehints.TypeConstraint): + def __init__(self, dtype, shape=()): self.dtype = dtype self.shape = shape @@ -136,7 +139,8 @@ def __getitem__(self, value): raise ValueError else: dtype = value - return self.PytorchTypeConstraint(dtype, shape=(N, )) + return self.PytorchTypeConstraint( + dtype, shape=(N, )) PytorchTensor = PytorchTypeHint() diff --git a/sdks/python/apache_beam/typehints/pytorch_type_compatibility_test.py b/sdks/python/apache_beam/typehints/pytorch_type_compatibility_test.py index d1f5c0d271ee..4f864c35e3fd 100644 --- a/sdks/python/apache_beam/typehints/pytorch_type_compatibility_test.py +++ b/sdks/python/apache_beam/typehints/pytorch_type_compatibility_test.py @@ -52,6 +52,7 @@ ]) @pytest.mark.uses_pytorch class PytorchBatchConverterTest(unittest.TestCase): + def create_batch_converter(self): return BatchConverter.from_typehints( element_type=self.element_typehint, batch_type=self.batch_typehint) @@ -136,6 +137,7 @@ def test_hash(self): class PytorchBatchConverterErrorsTest(unittest.TestCase): + @parameterized.expand([ ( Any, diff --git a/sdks/python/apache_beam/typehints/row_type.py b/sdks/python/apache_beam/typehints/row_type.py index 880a897bbbe8..38044f2d0d5e 100644 --- a/sdks/python/apache_beam/typehints/row_type.py +++ b/sdks/python/apache_beam/typehints/row_type.py @@ -44,6 +44,7 @@ def _user_type_is_generated(user_type: type) -> bool: class RowTypeConstraint(typehints.TypeConstraint): + def __init__( self, fields: Sequence[Tuple[str, type]], @@ -85,8 +86,7 @@ def __init__( """ # Recursively wrap row types in a RowTypeConstraint self._fields = tuple((name, RowTypeConstraint.from_user_type(typ) or typ) - for name, - typ in fields) + for name, typ in fields) self._user_type = user_type @@ -207,6 +207,7 @@ class GeneratedClassRowTypeConstraint(RowTypeConstraint): Since the generated user_type cannot be pickled, we supply a custom __reduce__ function that will regenerate the user_type. """ + def __init__( self, fields, diff --git a/sdks/python/apache_beam/typehints/schema_registry.py b/sdks/python/apache_beam/typehints/schema_registry.py index a73e97f43f70..2e2cd6da639f 100644 --- a/sdks/python/apache_beam/typehints/schema_registry.py +++ b/sdks/python/apache_beam/typehints/schema_registry.py @@ -24,6 +24,7 @@ # Registry of typings for a schema by UUID class SchemaTypeRegistry(object): + def __init__(self): self.by_id = {} self.by_typing = {} diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py index 2c5d35a68cc2..7edbcc60c2be 100644 --- a/sdks/python/apache_beam/typehints/schemas.py +++ b/sdks/python/apache_beam/typehints/schemas.py @@ -240,6 +240,7 @@ def schema_field( class SchemaTranslation(object): + def __init__(self, schema_registry: SchemaTypeRegistry = SCHEMA_REGISTRY): self.schema_registry = schema_registry @@ -579,6 +580,7 @@ def named_tuple_from_schema(self, schema: schema_pb2.Schema) -> type: def _named_tuple_reduce_method(serialized_schema): + def __reduce__(self): return _hydrate_namedtuple_instance, (serialized_schema, tuple(self)) @@ -643,6 +645,7 @@ def union_schema_type(element_types): class _Ephemeral: """Helper class for wrapping unpicklable objects.""" + def __init__(self, obj): self.obj = obj @@ -652,6 +655,7 @@ def __reduce__(self): # Registry of typings for a schema by UUID class LogicalTypeRegistry(object): + def __init__(self): self.by_urn = {} self.by_logical_type = {} @@ -806,6 +810,7 @@ def from_runner_api(cls, logical_type_proto): class NoArgumentLogicalType(LogicalType[LanguageT, RepresentationT, None]): + @classmethod def argument_type(cls): # type: () -> type @@ -828,6 +833,7 @@ class PassThroughLogicalType(LogicalType[LanguageT, LanguageT, ArgT]): """A base class for LogicalTypes that use the same type as the underlying representation type. """ + def to_language_type(self, value): return value @@ -867,6 +873,7 @@ class MillisInstant(NoArgumentLogicalType[Timestamp, np.int64]): To do this, re-register this class with :func:`~LogicalType.register_logical_type`. """ + @classmethod def representation_type(cls): # type: () -> type @@ -899,6 +906,7 @@ def to_language_type(self, value): class MicrosInstant(NoArgumentLogicalType[Timestamp, MicrosInstantRepresentation]): """Microsecond-precision instant logical type that handles ``Timestamp``.""" + @classmethod def urn(cls): return common_urns.micros_instant.urn @@ -925,6 +933,7 @@ def to_language_type(self, value): @LogicalType.register_logical_type class PythonCallable(NoArgumentLogicalType[PythonCallableWithSource, str]): """A logical type for PythonCallableSource objects.""" + @classmethod def urn(cls): return common_urns.python_callable.urn @@ -956,6 +965,7 @@ class DecimalLogicalType(NoArgumentLogicalType[decimal.Decimal, bytes]): """A logical type for decimal objects handling values consistent with that encoded by ``BigDecimalCoder`` in the Java SDK. """ + @classmethod def urn(cls): return common_urns.decimal.urn @@ -985,6 +995,7 @@ class FixedPrecisionDecimalLogicalType( FixedPrecisionDecimalArgumentRepresentation]): """A wrapper of DecimalLogicalType that contains the precision value. """ + def __init__(self, precision=-1, scale=0): self.precision = precision self.scale = scale @@ -1036,6 +1047,7 @@ def _from_typing(cls, typ): @LogicalType.register_logical_type class FixedBytes(PassThroughLogicalType[bytes, np.int32]): """A logical type for fixed-length bytes.""" + @classmethod def urn(cls): return common_urns.fixed_bytes.urn @@ -1069,6 +1081,7 @@ def argument(self): @LogicalType.register_logical_type class VariableBytes(PassThroughLogicalType[bytes, np.int32]): """A logical type for variable-length bytes with specified maximum length.""" + @classmethod def urn(cls): return common_urns.var_bytes.urn @@ -1099,6 +1112,7 @@ def argument(self): @LogicalType.register_logical_type class FixedString(PassThroughLogicalType[str, np.int32]): """A logical type for fixed-length string.""" + @classmethod def urn(cls): return common_urns.fixed_char.urn @@ -1132,6 +1146,7 @@ def argument(self): @LogicalType.register_logical_type class VariableString(PassThroughLogicalType[str, np.int32]): """A logical type for variable-length string with specified maximum length.""" + @classmethod def urn(cls): return common_urns.var_char.urn diff --git a/sdks/python/apache_beam/typehints/schemas_test.py b/sdks/python/apache_beam/typehints/schemas_test.py index 15144c6c2c17..e785d19c23b9 100644 --- a/sdks/python/apache_beam/typehints/schemas_test.py +++ b/sdks/python/apache_beam/typehints/schemas_test.py @@ -71,8 +71,8 @@ basic_array_types = [Sequence[typ] for typ in all_primitives] basic_map_types = [ - Mapping[key_type, value_type] for key_type, - value_type in itertools.product(all_primitives, all_primitives) + Mapping[key_type, value_type] for key_type, value_type in itertools.product( + all_primitives, all_primitives) ] @@ -129,8 +129,8 @@ def get_test_beam_fieldtype_protos(): basic_map_types = [ schema_pb2.FieldType( map_type=schema_pb2.MapType(key_type=key_type, value_type=value_type)) - for key_type, - value_type in itertools.product(all_primitives, all_primitives) + for key_type, value_type in itertools.product( + all_primitives, all_primitives) ] selected_schemas = [ @@ -139,8 +139,8 @@ def get_test_beam_fieldtype_protos(): schema=schema_pb2.Schema( id='32497414-85e8-46b7-9c90-9a9cc62fe390', fields=[ - schema_pb2.Field(name='field%d' % i, type=typ) for i, - typ in enumerate(all_primitives) + schema_pb2.Field(name='field%d' % i, type=typ) + for i, typ in enumerate(all_primitives) ]))), schema_pb2.FieldType( row_type=schema_pb2.RowType( @@ -184,8 +184,8 @@ def get_test_beam_fieldtype_protos(): schema=schema_pb2.Schema( id='a-schema-with-options', fields=[ - schema_pb2.Field(name='field%d' % i, type=typ) for i, - typ in enumerate(all_primitives) + schema_pb2.Field(name='field%d' % i, type=typ) + for i, typ in enumerate(all_primitives) ], options=[ schema_pb2.Option(name='a_flag'), @@ -270,8 +270,7 @@ def get_test_beam_fieldtype_protos(): value=schema_pb2.FieldValue( atomic_value=schema_pb2.AtomicTypeValue( string='str'))), - ]) for i, - typ in enumerate(all_primitives) + ]) for i, typ in enumerate(all_primitives) ] + [ schema_pb2.Field( name='nested', @@ -571,8 +570,7 @@ def test_unknown_primitive_maps_to_any(self): def test_unknown_atomic_raise_valueerror(self): self.assertRaises( - ValueError, - lambda: typing_from_runner_api( + ValueError, lambda: typing_from_runner_api( schema_pb2.FieldType(atomic_type=schema_pb2.UNSPECIFIED))) def test_int_maps_to_int64(self): @@ -720,6 +718,7 @@ def test_named_fields_roundtrip(self, named_fields): }, ]) class PickleTest(unittest.TestCase): + def test_generated_class_pickle_instance(self): schema = schema_pb2.Schema( id="some-uuid", diff --git a/sdks/python/apache_beam/typehints/sharded_key_type.py b/sdks/python/apache_beam/typehints/sharded_key_type.py index 079ad48c2170..c5aa8e0bd7b0 100644 --- a/sdks/python/apache_beam/typehints/sharded_key_type.py +++ b/sdks/python/apache_beam/typehints/sharded_key_type.py @@ -35,6 +35,7 @@ class ShardedKeyTypeConstraint(typehints.TypeConstraint, metaclass=typehints.GetitemConstructor): + def __init__(self, key_type): typehints.validate_composite_type_param( key_type, error_msg_prefix='Parameter to ShardedKeyType hint') diff --git a/sdks/python/apache_beam/typehints/sharded_key_type_test.py b/sdks/python/apache_beam/typehints/sharded_key_type_test.py index 631b72bf0030..5e1e92837462 100644 --- a/sdks/python/apache_beam/typehints/sharded_key_type_test.py +++ b/sdks/python/apache_beam/typehints/sharded_key_type_test.py @@ -27,6 +27,7 @@ class ShardedKeyTypeConstraintTest(TypeHintTestCase): + def test_compatibility(self): constraint1 = ShardedKeyType[int] constraint2 = ShardedKeyType[str] diff --git a/sdks/python/apache_beam/typehints/testing/strategies.py b/sdks/python/apache_beam/typehints/testing/strategies.py index a388e25eea25..82b5fa473b0e 100644 --- a/sdks/python/apache_beam/typehints/testing/strategies.py +++ b/sdks/python/apache_beam/typehints/testing/strategies.py @@ -34,6 +34,7 @@ def field_names(): + @st.composite def field_name_candidates(draw): """Strategy to produce valid field names for Beam schema types.""" @@ -75,6 +76,7 @@ def _named_fields_from_types(types): def types(): """Strategy to produce types that are convertible to Beam schema FieldType instances.""" + def _extend_types(types): optionals = types.map(lambda typ: Optional[typ]) sequences = types.map(lambda typ: Sequence[typ]) diff --git a/sdks/python/apache_beam/typehints/trivial_inference.py b/sdks/python/apache_beam/typehints/trivial_inference.py index fe9007ed63ca..18bbdfee88de 100644 --- a/sdks/python/apache_beam/typehints/trivial_inference.py +++ b/sdks/python/apache_beam/typehints/trivial_inference.py @@ -99,6 +99,7 @@ def union_list(xs, ys): class Const(object): + def __init__(self, value): self.value = value self.type = instance_to_type(value) @@ -126,6 +127,7 @@ def unwrap_all(xs): class FrameState(object): """Stores the state of the frame at a particular point of execution. """ + def __init__(self, f, local_vars=None, stack=(), kw_names=None): self.f = f self.co = f.__code__ @@ -209,6 +211,7 @@ def union(a, b): def finalize_hints(type_hint): """Sets type hint for empty data structures to Any.""" + def visitor(tc, unused_arg): if isinstance(tc, typehints.DictConstraint): empty_union = typehints.Union[()] @@ -258,6 +261,7 @@ def key_value_types(kv_type): class BoundMethod(object): """Used to create a bound method when we only know the type of the instance. """ + def __init__(self, func, type): """Instantiates a bound method object. @@ -546,8 +550,8 @@ def infer_return_type_func(f, input_types, debug=False, depth=0): # See https://github.com/python/cpython/issues/102403 for context. if (pop_count == 1 and last_real_opname == 'GET_ITER' and len(state.stack) > 1 and isinstance(state.stack[-2], Const) and - getattr(state.stack[-2].value, '__name__', None) in ( - '', '', '', '')): + getattr(state.stack[-2].value, '__name__', None) + in ('', '', '', '')): pop_count += 1 if depth <= 0 or pop_count > len(state.stack): return_type = Any diff --git a/sdks/python/apache_beam/typehints/trivial_inference_test.py b/sdks/python/apache_beam/typehints/trivial_inference_test.py index 48ccc8a6a2ed..0cfe268d088d 100644 --- a/sdks/python/apache_beam/typehints/trivial_inference_test.py +++ b/sdks/python/apache_beam/typehints/trivial_inference_test.py @@ -32,6 +32,7 @@ class TrivialInferenceTest(unittest.TestCase): + def assertReturnType(self, expected, f, inputs=(), depth=5): self.assertEqual( expected, @@ -48,8 +49,8 @@ def testJumpOffsets(self): def testBuildListUnpack(self): # Lambda uses BUILD_LIST_UNPACK opcode in Python 3. self.assertReturnType( - typehints.List[int], - lambda _list: [*_list, *_list, *_list], [typehints.List[int]]) + typehints.List[int], lambda _list: [*_list, *_list, *_list], + [typehints.List[int]]) def testBuildTupleUnpack(self): # Lambda uses BUILD_TUPLE_UNPACK opcode in Python 3. @@ -63,16 +64,12 @@ def testBuildTupleUnpack(self): def testBuildSetUnpackOrUpdate(self): self.assertReturnType( typehints.Set[typehints.Union[int, str]], - lambda _list1, - _list2: {*_list1, *_list2, *_list2}, + lambda _list1, _list2: {*_list1, *_list2, *_list2}, [typehints.List[int], typehints.List[str]]) def testBuildMapUnpackOrUpdate(self): self.assertReturnType( - typehints.Dict[str, typehints.Union[int, str, float]], - lambda a, - b, - c: { + typehints.Dict[str, typehints.Union[int, str, float]], lambda a, b, c: { **a, **b, **c }, [ @@ -96,6 +93,7 @@ def testTuples(self): typehints.Tuple[str, int, float], lambda x: (x, 0, 1.0), [str]) def testGetItem(self): + def reverse(ab): return ab[-1], ab[0] @@ -121,6 +119,7 @@ def testGetItemSlice(self): self.assertReturnType(typehints.List[str], lambda: test_list[:], []) def testUnpack(self): + def reverse(a_b): (a, b) = a_b return b, a @@ -146,41 +145,34 @@ def reverse(a_b): def testBuildMap(self): self.assertReturnType( - typehints.Dict[typehints.Any, typehints.Any], - lambda k, - v: {}, [int, float]) + typehints.Dict[typehints.Any, typehints.Any], lambda k, v: {}, + [int, float]) self.assertReturnType( typehints.Dict[int, float], lambda k, v: {k: v}, [int, float]) self.assertReturnType( - typehints.Tuple[str, typehints.Dict[int, float]], - lambda k, - v: ('s', { + typehints.Tuple[str, typehints.Dict[int, float]], lambda k, v: + ('s', { k: v }), [int, float]) self.assertReturnType( typehints.Dict[int, typehints.Union[float, str]], - lambda k1, - v1, - k2, - v2: { + lambda k1, v1, k2, v2: { k1: v1, k2: v2 }, [int, float, int, str]) # Constant map. self.assertReturnType( - typehints.Dict[str, typehints.Union[int, float]], - lambda a, - b: { + typehints.Dict[str, typehints.Union[int, float]], lambda a, b: { 'a': a, 'b': b }, [int, float]) self.assertReturnType( typehints.Tuple[int, typehints.Dict[str, typehints.Union[int, float]]], - lambda a, - b: (4, { + lambda a, b: (4, { 'a': a, 'b': b }), [int, float]) def testNoneReturn(self): + def func(a): if a == 5: return a @@ -199,21 +191,20 @@ def testSimpleList(self): def testListComprehension(self): self.assertReturnType( - typehints.List[int], - lambda xs: [x for x in xs], [typehints.Tuple[int, ...]]) + typehints.List[int], lambda xs: [x for x in xs], + [typehints.Tuple[int, ...]]) def testTupleListComprehension(self): self.assertReturnType( - typehints.List[int], - lambda xs: [x for x in xs], [typehints.Tuple[int, int, int]]) + typehints.List[int], lambda xs: [x for x in xs], + [typehints.Tuple[int, int, int]]) self.assertReturnType( - typehints.List[typehints.Union[int, float]], - lambda xs: [x for x in xs], [typehints.Tuple[int, float]]) + typehints.List[typehints.Union[int, float]], lambda xs: [x for x in xs], + [typehints.Tuple[int, float]]) expected = typehints.List[typehints.Tuple[str, int]] self.assertReturnType( - expected, - lambda kvs: [(kvs[0], v) for v in kvs[1]], + expected, lambda kvs: [(kvs[0], v) for v in kvs[1]], [typehints.Tuple[str, typehints.Iterable[int]]]) self.assertReturnType( typehints.List[typehints.Tuple[str, typehints.Union[str, int], int]], @@ -221,6 +212,7 @@ def testTupleListComprehension(self): [typehints.Iterable[typehints.Tuple[str, int]]]) def testGenerator(self): + def foo(x, y): yield x yield y @@ -231,8 +223,8 @@ def foo(x, y): def testGeneratorComprehension(self): self.assertReturnType( - typehints.Iterable[int], - lambda xs: (x for x in xs), [typehints.Tuple[int, ...]]) + typehints.Iterable[int], lambda xs: (x for x in xs), + [typehints.Tuple[int, ...]]) def testBinOp(self): self.assertReturnType(int, lambda a, b: a + b, [int, int]) @@ -240,9 +232,8 @@ def testBinOp(self): self.assertReturnType( typehints.Any, lambda a, b: a + b, [int, typehints.Any]) self.assertReturnType( - typehints.List[typehints.Union[int, str]], - lambda a, - b: a + b, [typehints.List[int], typehints.List[str]]) + typehints.List[typehints.Union[int, str]], lambda a, b: a + b, + [typehints.List[int], typehints.List[str]]) def testCall(self): f = lambda x, *args: x @@ -253,6 +244,7 @@ def testCall(self): typehints.Tuple[int, typehints.Any], lambda: (1, f(x=1.0))) def testCallNullaryMethod(self): + class Foo: pass @@ -260,6 +252,7 @@ class Foo: typehints.Tuple[Foo, typehints.Any], lambda x: (x, x.unknown()), [Foo]) def testCallNestedLambda(self): + class Foo: pass @@ -284,10 +277,11 @@ def testBuiltins(self): def testGetAttr(self): self.assertReturnType( - typehints.Tuple[str, typehints.Any], - lambda: (typehints.__doc__, typehints.fake)) + typehints.Tuple[str, typehints.Any], lambda: + (typehints.__doc__, typehints.fake)) def testSetAttr(self): + def fn(obj, flag): if flag == 1: obj.attr = 1 @@ -300,6 +294,7 @@ def fn(obj, flag): self.assertReturnType(typehints.Union[int, float], fn, [int]) def testSetDeleteGlobal(self): + def fn(flag): # pylint: disable=global-variable-undefined global global_var @@ -314,7 +309,9 @@ def fn(flag): self.assertReturnType(typehints.Union[int, str], fn, [int]) def testMethod(self): + class A(object): + def m(self, x): return x @@ -337,6 +334,7 @@ def call_function_on_any(s): self.assertReturnType(int, call_function_on_any, [str]) def testAlwaysReturnsEarly(self): + def some_fn(v): if v: return 1 @@ -367,6 +365,7 @@ def testSet(self): # yapf: enable def testDepthFunction(self): + def f(i): return i @@ -374,7 +373,9 @@ def f(i): self.assertReturnType(int, lambda i: f(i), [int], depth=1) def testDepthMethod(self): + class A(object): + def m(self, x): return x @@ -392,16 +393,13 @@ def fn(x1, x2, *unused_args): self.assertReturnType( typehints.Tuple[typehints.Union[str, float, int], typehints.Union[str, float, int]], - lambda x1, - x2, - _list: fn(x1, x2, *_list), [str, float, typehints.List[int]]) + lambda x1, x2, _list: fn(x1, x2, *_list), + [str, float, typehints.List[int]]) # No *args self.assertReturnType( typehints.Tuple[typehints.Union[str, typehints.List[int]], typehints.Union[str, typehints.List[int]]], - lambda x1, - x2, - _list: fn(x1, x2, *_list), [str, typehints.List[int]]) + lambda x1, x2, _list: fn(x1, x2, *_list), [str, typehints.List[int]]) def testCallFunctionEx(self): # Test when fn arguments are built using BUiLD_LIST. @@ -410,22 +408,22 @@ def fn(*args): self.assertReturnType( typehints.List[typehints.Union[str, float]], - lambda x1, - x2: fn(*[x1, x2]), [str, float]) + lambda x1, x2: fn(*[x1, x2]), [str, float]) def testCallFunctionExKwargs(self): + def fn(x1, x2, **unused_kwargs): return x1, x2 # Keyword args are currently unsupported for CALL_FUNCTION_EX. self.assertReturnType( - typehints.Any, - lambda x1, - x2, - _dict: fn(x1, x2, **_dict), [str, float, typehints.List[int]]) + typehints.Any, lambda x1, x2, _dict: fn(x1, x2, **_dict), + [str, float, typehints.List[int]]) def testInstanceToType(self): + class MyClass(object): + def method(self): pass @@ -467,16 +465,14 @@ def method(self): def testRow(self): self.assertReturnType( row_type.RowTypeConstraint.from_fields([('x', int), ('y', str)]), - lambda x, - y: beam.Row(x=x + 1, y=y), [int, str]) + lambda x, y: beam.Row(x=x + 1, y=y), [int, str]) self.assertReturnType( row_type.RowTypeConstraint.from_fields([('x', int), ('y', str)]), lambda x: beam.Row(x=x, y=str(x)), [int]) def testRowAttr(self): self.assertReturnType( - typehints.Tuple[int, str], - lambda row: (row.x, getattr(row, 'y')), + typehints.Tuple[int, str], lambda row: (row.x, getattr(row, 'y')), [row_type.RowTypeConstraint.from_fields([('x', int), ('y', str)])]) def testPyCallable(self): diff --git a/sdks/python/apache_beam/typehints/typecheck.py b/sdks/python/apache_beam/typehints/typecheck.py index 7e84779a0f15..3af9510b2e48 100644 --- a/sdks/python/apache_beam/typehints/typecheck.py +++ b/sdks/python/apache_beam/typehints/typecheck.py @@ -45,6 +45,7 @@ class AbstractDoFnWrapper(DoFn): """An abstract class to create wrapper around DoFn""" + def __init__(self, dofn): super().__init__() self.dofn = dofn @@ -86,6 +87,7 @@ def teardown(self): class OutputCheckWrapperDoFn(AbstractDoFnWrapper): """A DoFn that verifies against common errors in the output type.""" + def __init__(self, dofn, full_label): super().__init__(dofn) self.full_label = full_label @@ -124,6 +126,7 @@ def _check_type(output): class TypeCheckWrapperDoFn(AbstractDoFnWrapper): """A wrapper around a DoFn which performs type-checking of input and output. """ + def __init__(self, dofn, type_hints, label=None): super().__init__(dofn) self._process_fn = self.dofn._process_argspec_fn() @@ -214,6 +217,7 @@ def type_check(type_constraint, datum, is_input): class TypeCheckCombineFn(core.CombineFn): """A wrapper around a CombineFn performing type-checking of input and output. """ + def __init__(self, combinefn, type_hints, label=None): self._combinefn = combinefn self._input_type_hint = type_hints.input_types @@ -298,6 +302,7 @@ def visit_transform(self, applied_transform): class PerformanceTypeCheckVisitor(pipeline.PipelineVisitor): + def visit_transform(self, applied_transform): transform = applied_transform.transform full_label = applied_transform.full_label diff --git a/sdks/python/apache_beam/typehints/typecheck_test.py b/sdks/python/apache_beam/typehints/typecheck_test.py index 32307c5202e9..81010764872f 100644 --- a/sdks/python/apache_beam/typehints/typecheck_test.py +++ b/sdks/python/apache_beam/typehints/typecheck_test.py @@ -48,6 +48,7 @@ class MyDoFn(beam.DoFn): + def __init__(self, output_filename): super().__init__() self.output_filename = output_filename @@ -77,12 +78,14 @@ def process(self, element: int, *args, **kwargs) -> Iterable[int]: class MyDoFnBadAnnotation(MyDoFn): + def process(self, element: int, *args, **kwargs) -> int: # Should raise an exception about return type not being iterable. return super().process() class RuntimeTypeCheckTest(unittest.TestCase): + def setUp(self): self.p = TestPipeline( options=PipelineOptions( @@ -134,6 +137,7 @@ def test_wrapper_pipeline_type_check(self): class PerformanceRuntimeTypeCheckTest(unittest.TestCase): + def setUp(self): self.p = Pipeline( options=PipelineOptions( @@ -174,9 +178,11 @@ def test_simple_output_error(self): e.exception.args[0]) def test_simple_input_error_with_kwarg_typehints(self): + @with_input_types(element=int) @with_output_types(int) class ToInt(beam.DoFn): + def process(self, element, *args, **kwargs): yield int(element) @@ -204,9 +210,11 @@ def incorrect_par_do_fn(x): self.assertStartswith(cm.exception.args[0], "'int' object is not iterable ") def test_simple_type_satisfied(self): + @with_input_types(int, int) @with_output_types(int) class AddWithNum(beam.DoFn): + def process(self, element, num): return [element + num] @@ -292,15 +300,18 @@ def test_pipeline_runtime_checking_violation_composite_type_output(self): "instead found 4.0, an instance of {}.".format(int, float)) def test_downstream_input_type_hint_error_has_descriptive_error_msg(self): + @with_input_types(int) @with_output_types(int) class IntToInt(beam.DoFn): + def process(self, element, *args, **kwargs): yield element @with_input_types(str) @with_output_types(int) class StrToInt(beam.DoFn): + def process(self, element, *args, **kwargs): yield int(element) diff --git a/sdks/python/apache_beam/typehints/typed_pipeline_test.py b/sdks/python/apache_beam/typehints/typed_pipeline_test.py index 820f78fa9ef5..ed7c126ca3b9 100644 --- a/sdks/python/apache_beam/typehints/typed_pipeline_test.py +++ b/sdks/python/apache_beam/typehints/typed_pipeline_test.py @@ -39,11 +39,13 @@ class MainInputTest(unittest.TestCase): + def assertStartswith(self, msg, prefix): self.assertTrue( msg.startswith(prefix), '"%s" does not start with "%s"' % (msg, prefix)) def test_bad_main_input(self): + @typehints.with_input_types(str, int) def repeat(s, times): return s * times @@ -69,6 +71,7 @@ def test_non_function_fails(self): [1, 2, 3] | beam.Map(str.upper) def test_loose_bounds(self): + @typehints.with_input_types(typing.Union[int, float]) @typehints.with_output_types(str) def format_number(x): @@ -78,9 +81,11 @@ def format_number(x): self.assertEqual(['1', '2', '3'], sorted(result)) def test_typed_dofn_class(self): + @typehints.with_input_types(int) @typehints.with_output_types(str) class MyDoFn(beam.DoFn): + def process(self, element): return [str(element)] @@ -96,7 +101,9 @@ def process(self, element): [1, 2, 3] | (beam.ParDo(MyDoFn()) | 'again' >> beam.ParDo(MyDoFn())) def test_typed_dofn_method(self): + class MyDoFn(beam.DoFn): + def process(self, element: int) -> typehints.Tuple[str]: return tuple(str(element)) @@ -116,6 +123,7 @@ def test_typed_dofn_method_with_class_decorators(self): @typehints.with_input_types(typehints.Tuple[int, int]) @typehints.with_output_types(int) class MyDoFn(beam.DoFn): + def process(self, element: int) -> typehints.Tuple[str]: yield element[0] @@ -185,6 +193,7 @@ def do_fn(element: typehints.Tuple[int, int]) -> typehints.Generator[str]: _ = [1, 2, 3] | (pardo | 'again' >> pardo) def test_filter_type_hint(self): + @typehints.with_input_types(int) def filter_fn(data): return data % 2 @@ -208,7 +217,9 @@ def test_partition(self): assert_that(res_odd, equal_to([1, 3]), label='odd_check') def test_typed_dofn_multi_output(self): + class MyDoFn(beam.DoFn): + def process(self, element): if element % 2: yield beam.pvalue.TaggedOutput('odd', element) @@ -240,7 +251,9 @@ def process(self, element): _ = res['undeclared tag'] def test_typed_dofn_multi_output_no_tags(self): + class MyDoFn(beam.DoFn): + def process(self, element): if element % 2: yield beam.pvalue.TaggedOutput('odd', element) @@ -291,6 +304,7 @@ def MyMap(pcoll): _ = ['a'] | MyMap() def test_typed_ptransform_fn_multi_input_types_pos(self): + @beam.ptransform_fn @beam.typehints.with_input_types(str, int) def multi_input(pcoll_tuple, additional_arg): @@ -305,6 +319,7 @@ def multi_input(pcoll_tuple, additional_arg): _ = (pcoll2, pcoll1) | 'fails' >> multi_input('additional_arg') def test_typed_ptransform_fn_multi_input_types_kw(self): + @beam.ptransform_fn @beam.typehints.with_input_types(strings=str, integers=int) def multi_input(pcoll_dict, additional_arg): @@ -324,7 +339,9 @@ def multi_input(pcoll_dict, additional_arg): } | 'fails' >> multi_input('additional_arg') def test_typed_dofn_method_not_iterable(self): + class MyDoFn(beam.DoFn): + def process(self, element: int) -> str: return str(element) @@ -332,7 +349,9 @@ def process(self, element: int) -> str: _ = [1, 2, 3] | beam.ParDo(MyDoFn()) def test_typed_dofn_method_return_none(self): + class MyDoFn(beam.DoFn): + def process(self, unused_element: int) -> None: pass @@ -340,7 +359,9 @@ def process(self, unused_element: int) -> None: self.assertListEqual([], result) def test_typed_dofn_method_return_optional(self): + class MyDoFn(beam.DoFn): + def process( self, unused_element: int) -> typehints.Optional[typehints.Iterable[int]]: @@ -350,7 +371,9 @@ def process( self.assertListEqual([], result) def test_typed_dofn_method_return_optional_not_iterable(self): + class MyDoFn(beam.DoFn): + def process(self, unused_element: int) -> typehints.Optional[int]: pass @@ -358,6 +381,7 @@ def process(self, unused_element: int) -> typehints.Optional[int]: _ = [1, 2, 3] | beam.ParDo(MyDoFn()) def test_typed_callable_not_iterable(self): + def do_fn(element: int) -> int: return element @@ -366,6 +390,7 @@ def do_fn(element: int) -> int: _ = [1, 2, 3] | beam.ParDo(do_fn) def test_typed_dofn_kwonly(self): + class MyDoFn(beam.DoFn): # TODO(BEAM-5878): A kwonly argument like # timestamp=beam.DoFn.TimestampParam would not work here. @@ -383,6 +408,7 @@ def process(self, element: int, *, side_input: str) -> \ _ = [1, 2, 3] | beam.ParDo(my_do_fn, side_input=1) def test_typed_dofn_var_kwargs(self): + class MyDoFn(beam.DoFn): def process(self, element: int, **side_inputs: typehints.Dict[str, str]) \ -> typehints.Generator[typehints.Optional[str]]: @@ -398,6 +424,7 @@ def process(self, element: int, **side_inputs: typehints.Dict[str, str]) \ _ = [1, 2, 3] | beam.ParDo(my_do_fn, a=1) def test_typed_callable_string_literals(self): + def do_fn(element: 'int') -> 'typehints.List[str]': return [[str(element)] * 2] @@ -409,6 +436,7 @@ def test_typed_ptransform_fn(self): @beam.ptransform_fn @typehints.with_input_types(int) def MyMap(pcoll): + def fn(element: int): yield element @@ -424,6 +452,7 @@ def test_typed_ptransform_fn_conflicting_hints(self): @beam.ptransform_fn @typehints.with_input_types(str) def MyMap(pcoll): + def fn(element: float): yield element @@ -437,7 +466,9 @@ def fn(element: float): _ = [b'a'] | MyMap() def test_typed_dofn_string_literals(self): + class MyDoFn(beam.DoFn): + def process(self, element: 'int') -> 'typehints.List[str]': return [[str(element)] * 2] @@ -445,6 +476,7 @@ def process(self, element: 'int') -> 'typehints.List[str]': self.assertEqual([['1', '1'], ['2', '2']], sorted(result)) def test_typed_map(self): + def fn(element: int) -> int: return element * 2 @@ -461,6 +493,7 @@ def fn(element: int) -> typehints.Optional[int]: self.assertCountEqual([None, 2, 3], result) def test_typed_flatmap(self): + def fn(element: int) -> typehints.Iterable[int]: yield element * 2 @@ -468,6 +501,7 @@ def fn(element: int) -> typehints.Iterable[int]: self.assertCountEqual([2, 4, 6], result) def test_typed_flatmap_output_hint_not_iterable(self): + def fn(element: int) -> int: return element * 2 @@ -477,6 +511,7 @@ def fn(element: int) -> int: _ = [1, 2, 3] | beam.FlatMap(fn) def test_typed_flatmap_output_value_not_iterable(self): + def fn(element: int) -> typehints.Iterable[int]: return element * 2 @@ -485,6 +520,7 @@ def fn(element: int) -> typehints.Iterable[int]: _ = [1, 2, 3] | beam.FlatMap(fn) def test_typed_flatmap_optional(self): + def fn(element: int) -> typehints.Optional[typehints.Iterable[int]]: if element > 1: yield element * 2 @@ -497,13 +533,16 @@ def fn2(element: int) -> int: self.assertCountEqual([4, 6], result) def test_typed_ptransform_with_no_error(self): + class StrToInt(beam.PTransform): + def expand( self, pcoll: beam.pvalue.PCollection[str]) -> beam.pvalue.PCollection[int]: return pcoll | beam.Map(lambda x: int(x)) class IntToStr(beam.PTransform): + def expand( self, pcoll: beam.pvalue.PCollection[int]) -> beam.pvalue.PCollection[str]: @@ -512,13 +551,16 @@ def expand( _ = ['1', '2', '3'] | StrToInt() | IntToStr() def test_typed_ptransform_with_bad_typehints(self): + class StrToInt(beam.PTransform): + def expand( self, pcoll: beam.pvalue.PCollection[str]) -> beam.pvalue.PCollection[int]: return pcoll | beam.Map(lambda x: int(x)) class IntToStr(beam.PTransform): + def expand( self, pcoll: beam.pvalue.PCollection[str]) -> beam.pvalue.PCollection[str]: @@ -530,13 +572,16 @@ def expand( _ = ['1', '2', '3'] | StrToInt() | IntToStr() def test_typed_ptransform_with_bad_input(self): + class StrToInt(beam.PTransform): + def expand( self, pcoll: beam.pvalue.PCollection[str]) -> beam.pvalue.PCollection[int]: return pcoll | beam.Map(lambda x: int(x)) class IntToStr(beam.PTransform): + def expand( self, pcoll: beam.pvalue.PCollection[int]) -> beam.pvalue.PCollection[str]: @@ -549,11 +594,14 @@ def expand( _ = [1, 2, 3] | StrToInt() | IntToStr() def test_typed_ptransform_with_partial_typehints(self): + class StrToInt(beam.PTransform): + def expand(self, pcoll) -> beam.pvalue.PCollection[int]: return pcoll | beam.Map(lambda x: int(x)) class IntToStr(beam.PTransform): + def expand( self, pcoll: beam.pvalue.PCollection[int]) -> beam.pvalue.PCollection[str]: @@ -564,12 +612,15 @@ def expand( _ = [1, 2, 3] | StrToInt() | IntToStr() def test_typed_ptransform_with_bare_wrappers(self): + class StrToInt(beam.PTransform): + def expand( self, pcoll: beam.pvalue.PCollection) -> beam.pvalue.PCollection: return pcoll | beam.Map(lambda x: int(x)) class IntToStr(beam.PTransform): + def expand( self, pcoll: beam.pvalue.PCollection[int]) -> beam.pvalue.PCollection[str]: @@ -578,11 +629,14 @@ def expand( _ = [1, 2, 3] | StrToInt() | IntToStr() def test_typed_ptransform_with_no_typehints(self): + class StrToInt(beam.PTransform): + def expand(self, pcoll): return pcoll | beam.Map(lambda x: int(x)) class IntToStr(beam.PTransform): + def expand( self, pcoll: beam.pvalue.PCollection[int]) -> beam.pvalue.PCollection[str]: @@ -596,12 +650,14 @@ def test_typed_ptransform_with_generic_annotations(self): T = typing.TypeVar('T') class IntToInt(beam.PTransform): + def expand( self, pcoll: beam.pvalue.PCollection[T]) -> beam.pvalue.PCollection[T]: return pcoll | beam.Map(lambda x: x) class IntToStr(beam.PTransform): + def expand( self, pcoll: beam.pvalue.PCollection[T]) -> beam.pvalue.PCollection[str]: @@ -610,7 +666,9 @@ def expand( _ = [1, 2, 3] | IntToInt() | IntToStr() def test_typed_ptransform_with_do_outputs_tuple_compiles(self): + class MyDoFn(beam.DoFn): + def process(self, element: int, *args, **kwargs): if element % 2: yield beam.pvalue.TaggedOutput('odd', 1) @@ -618,6 +676,7 @@ def process(self, element: int, *args, **kwargs): yield beam.pvalue.TaggedOutput('even', 1) class MyPTransform(beam.PTransform): + def expand(self, pcoll: beam.pvalue.PCollection[int]): return pcoll | beam.ParDo(MyDoFn()).with_outputs('odd', 'even') @@ -626,6 +685,7 @@ def expand(self, pcoll: beam.pvalue.PCollection[int]): _ = [1, 2, 3] | MyPTransform() def test_typed_ptransform_with_unknown_type_vars_tuple_compiles(self): + @typehints.with_input_types(typing.TypeVar('T')) @typehints.with_output_types(typing.TypeVar('U')) def produces_unkown(e): @@ -636,6 +696,7 @@ def accepts_int(e): return e class MyPTransform(beam.PTransform): + def expand(self, pcoll): unknowns = pcoll | beam.Map(produces_unkown) ints = pcoll | beam.Map(int) @@ -645,7 +706,9 @@ def expand(self, pcoll): class NativeTypesTest(unittest.TestCase): + def test_good_main_input(self): + @typehints.with_input_types(typing.Tuple[str, int]) def munge(s_i): (s, i) = s_i @@ -655,6 +718,7 @@ def munge(s_i): self.assertEqual([('apples', 10), ('pears', 6)], sorted(result)) def test_bad_main_input(self): + @typehints.with_input_types(typing.Tuple[str, str]) def munge(s_i): (s, i) = s_i @@ -664,6 +728,7 @@ def munge(s_i): [('apple', 5), ('pear', 3)] | beam.Map(munge) def test_bad_main_output(self): + @typehints.with_input_types(typing.Tuple[int, int]) @typehints.with_output_types(typing.Tuple[str, str]) def munge(a_b): @@ -675,6 +740,7 @@ def munge(a_b): class SideInputTest(unittest.TestCase): + def _run_repeat_test(self, repeat): self._run_repeat_test_good(repeat) self._run_repeat_test_bad(repeat) @@ -704,6 +770,7 @@ def _run_repeat_test_bad(self, repeat): ['a', 'bb', 'c'] | beam.Map(repeat) def test_basic_side_input_hint(self): + @typehints.with_input_types(str, int) def repeat(s, times): return s * times @@ -711,6 +778,7 @@ def repeat(s, times): self._run_repeat_test(repeat) def test_keyword_side_input_hint(self): + @typehints.with_input_types(str, times=int) def repeat(s, times): return s * times @@ -718,6 +786,7 @@ def repeat(s, times): self._run_repeat_test(repeat) def test_default_typed_hint(self): + @typehints.with_input_types(str, int) def repeat(s, times=3): return s * times @@ -725,6 +794,7 @@ def repeat(s, times=3): self._run_repeat_test(repeat) def test_default_untyped_hint(self): + @typehints.with_input_types(str) def repeat(s, times=3): return s * times @@ -734,6 +804,7 @@ def repeat(s, times=3): @OptionsContext(pipeline_type_check=True) def test_varargs_side_input_hint(self): + @typehints.with_input_types(str, int) def repeat(s, *times): return s * times[0] @@ -789,6 +860,7 @@ def test_var_keyword_side_input_hint(self): str, ignored=str)) def test_deferred_side_inputs(self): + @typehints.with_input_types(str, int) def repeat(s, times): return s * times @@ -804,6 +876,7 @@ def repeat(s, times): main_input | 'bis' >> beam.Map(repeat, pvalue.AsSingleton(bad_side_input)) def test_deferred_side_input_iterable(self): + @typehints.with_input_types(str, typing.Iterable[str]) def concat(glue, items): return glue.join(sorted(items)) @@ -820,7 +893,9 @@ def concat(glue, items): class CustomTransformTest(unittest.TestCase): + class CustomTransform(beam.PTransform): + def _extract_input_pvalues(self, pvalueish): return pvalueish, (pvalueish['in0'], pvalueish['in1']) @@ -872,6 +947,7 @@ def test_flat_type_hint(self): class AnnotationsTest(unittest.TestCase): + def test_pardo_wrapper_builtin_method(self): th = beam.ParDo(str.strip).get_type_hints() self.assertEqual(th.input_types, ((str, typehints.Any), {})) @@ -889,7 +965,9 @@ def test_pardo_wrapper_builtin_func(self): self.assertIsNone(th.output_types) def test_pardo_dofn(self): + class MyDoFn(beam.DoFn): + def process(self, element: int) -> typehints.Generator[str]: yield str(element) @@ -898,7 +976,9 @@ def process(self, element: int) -> typehints.Generator[str]: self.assertEqual(th.output_types, ((str, ), {})) def test_pardo_dofn_not_iterable(self): + class MyDoFn(beam.DoFn): + def process(self, element: int) -> str: return str(element) @@ -906,6 +986,7 @@ def process(self, element: int) -> str: _ = beam.ParDo(MyDoFn()).get_type_hints() def test_pardo_wrapper(self): + def do_fn(element: int) -> typehints.Iterable[str]: return [str(element)] @@ -924,6 +1005,7 @@ def do_fn(element: int) -> typehints.Iterable[typehints.Tuple[str, int]]: self.assertEqual(th.output_types, ((typehints.Tuple[str, int], ), {})) def test_pardo_wrapper_not_iterable(self): + def do_fn(element: int) -> str: return str(element) @@ -932,6 +1014,7 @@ def do_fn(element: int) -> str: _ = beam.ParDo(do_fn).get_type_hints() def test_flat_map_wrapper(self): + def map_fn(element: int) -> typehints.Iterable[int]: return [element, element + 1] @@ -962,6 +1045,7 @@ def tuple_map_fn(a: str, b: str, c: str) -> typehints.Iterable[str]: self.assertEqual(th.output_types, ((str, ), {})) def test_map_wrapper(self): + def map_fn(unused_element: int) -> int: return 1 @@ -992,6 +1076,7 @@ def tuple_map_fn(a: str, b: str, c: str) -> str: self.assertEqual(th.output_types, ((str, ), {})) def test_filter_wrapper(self): + def filter_fn(element: int) -> bool: return bool(element % 2) @@ -1001,6 +1086,7 @@ def filter_fn(element: int) -> bool: class TestFlatMapTuple(unittest.TestCase): + def test_flatmaptuple(self): # Regression test. See # https://github.com/apache/beam/issues/33014 diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py index 0e18e887c2a0..330e378b70be 100644 --- a/sdks/python/apache_beam/typehints/typehints.py +++ b/sdks/python/apache_beam/typehints/typehints.py @@ -107,6 +107,7 @@ class CompositeTypeHintError(TypeError): class GetitemConstructor(type): """A metaclass that makes Cls[arg] an alias for Cls(arg).""" + def __getitem__(cls, arg): return cls(arg) @@ -119,6 +120,7 @@ class TypeConstraint(object): another :class:`CompositeTypeHint`. It binds and enforces a specific version of a generalized TypeHint. """ + def _consistent_with_check_(self, sub): """Returns whether sub is consistent with self. @@ -223,6 +225,7 @@ class IndexableTypeConstraint(TypeConstraint): """An internal common base-class for all type constraints with indexing. E.G. SequenceTypeConstraint + Tuple's of fixed size. """ + def _constraint_for_index(self, idx): """Returns the type at the given index. This is used to allow type inference to determine the correct type for a specific index. On lists this will also @@ -246,6 +249,7 @@ class SequenceTypeConstraint(IndexableTypeConstraint): inner_type: The type which every element in the sequence should be an instance of. """ + def __init__(self, inner_type, sequence_type): self.inner_type = normalize(inner_type) self._sequence_type = sequence_type @@ -329,6 +333,7 @@ class CompositeTypeHint(object): * Example: 'Coordinates = List[Tuple[int, int]]' """ + def __getitem___(self, py_type): """Given a type creates a TypeConstraint instance parameterized by the type. @@ -437,6 +442,7 @@ class AnyTypeConstraint(TypeConstraint): function arguments or return types. All other TypeConstraint's are equivalent to 'Any', and its 'type_check' method is a no-op. """ + def __eq__(self, other): return type(self) == type(other) @@ -453,6 +459,7 @@ def type_check(self, instance): class TypeVariable(AnyTypeConstraint): + def __init__(self, name, use_name_in_eq=True): self.name = name self.use_name_in_eq = use_name_in_eq @@ -502,7 +509,9 @@ class UnionHint(CompositeTypeHint): * Union[int, str] == Union[str, int] """ + class UnionConstraint(TypeConstraint): + def __init__(self, union_types): self.union_types = set(normalize(t) for t in union_types) @@ -619,6 +628,7 @@ class OptionalHint(UnionHint): The Optional[X] factory function proxies to Union[X, type(None)] """ + def __getitem__(self, py_type): # A single type must have been passed. if isinstance(py_type, abc.Sequence): @@ -661,7 +671,9 @@ class TupleHint(CompositeTypeHint): As an example, Tuple[str, ...] indicates a tuple of any length with each element being an instance of 'str'. """ + class TupleSequenceConstraint(SequenceTypeConstraint): + def __init__(self, type_param): super().__init__(type_param, tuple) @@ -677,6 +689,7 @@ def _consistent_with_check_(self, sub): return super()._consistent_with_check_(sub) class TupleConstraint(IndexableTypeConstraint): + def __init__(self, type_params): self.tuple_types = tuple(normalize(t) for t in type_params) @@ -703,8 +716,8 @@ def _consistent_with_check_(self, sub): return ( isinstance(sub, self.__class__) and len(sub.tuple_types) == len(self.tuple_types) and all( - is_consistent_with(sub_elem, elem) for sub_elem, - elem in zip(sub.tuple_types, self.tuple_types))) + is_consistent_with(sub_elem, elem) + for sub_elem, elem in zip(sub.tuple_types, self.tuple_types))) def type_check(self, tuple_instance): if not isinstance(tuple_instance, tuple): @@ -789,7 +802,9 @@ class ListHint(CompositeTypeHint): * ['1', '2', '3'] satisfies List[str] """ + class ListConstraint(SequenceTypeConstraint): + def __init__(self, list_type): super().__init__(list_type, list) @@ -812,6 +827,7 @@ class KVHint(CompositeTypeHint): accepts exactly two type-parameters. The first represents the required key-type and the second the required value-type. """ + def __getitem__(self, type_params): if not isinstance(type_params, tuple): raise TypeError( @@ -847,7 +863,9 @@ class DictHint(CompositeTypeHint): Dict[K, V] Represents a dictionary where all keys are of a particular type and all values are of another (possible the same) type. """ + class DictConstraint(TypeConstraint): + def __init__(self, key_type, value_type): self.key_type = normalize(key_type) self.value_type = normalize(value_type) @@ -969,7 +987,9 @@ class SetHint(CompositeTypeHint): Set[X] defines a type-hint for a set of homogeneous types. 'X' may be either a built-in Python type or a another nested TypeConstraint. """ + class SetTypeConstraint(SequenceTypeConstraint): + def __init__(self, type_param): super().__init__(type_param, set) @@ -994,7 +1014,9 @@ class FrozenSetHint(CompositeTypeHint): This is a mirror copy of SetHint - consider refactoring common functionality. """ + class FrozenSetTypeConstraint(SequenceTypeConstraint): + def __init__(self, type_param): super(FrozenSetHint.FrozenSetTypeConstraint, self).__init__(type_param, frozenset) @@ -1022,7 +1044,9 @@ class CollectionHint(CompositeTypeHint): __contains__, __iter__, and __len__. This acts as a parent type for sets but has fewer guarantees for mixins. """ + class CollectionTypeConstraint(SequenceTypeConstraint): + def __init__(self, type_param): super().__init__(type_param, abc.Collection) @@ -1078,7 +1102,9 @@ class IterableHint(CompositeTypeHint): Iterable[X] defines a type-hint for an object implementing an '__iter__' method which yields objects which are all of the same type. """ + class IterableTypeConstraint(SequenceTypeConstraint): + def __init__(self, iter_type): super(IterableHint.IterableTypeConstraint, self).__init__(iter_type, abc.Iterable) @@ -1120,7 +1146,9 @@ class IteratorHint(CompositeTypeHint): underlying lazily generated sequence. See decorators.interleave_type_check for further information. """ + class IteratorTypeConstraint(TypeConstraint): + def __init__(self, t): self.yielded_type = normalize(t) @@ -1175,6 +1203,7 @@ class WindowedTypeConstraint(TypeConstraint, metaclass=GetitemConstructor): Attributes: inner_type: The type which the element should be an instance of. """ + def __init__(self, inner_type): self.inner_type = normalize(inner_type) @@ -1220,6 +1249,7 @@ class GeneratorHint(IteratorHint): Subscriptor is in the form [yield_type, send_type, return_type], however only yield_type is supported. The 2 others are expected to be None. """ + def __getitem__(self, type_params): if isinstance(type_params, tuple) and len(type_params) == 3: yield_type, send_type, return_type = type_params diff --git a/sdks/python/apache_beam/typehints/typehints_test.py b/sdks/python/apache_beam/typehints/typehints_test.py index 6611dcecab01..a0cf024925b5 100644 --- a/sdks/python/apache_beam/typehints/typehints_test.py +++ b/sdks/python/apache_beam/typehints/typehints_test.py @@ -62,6 +62,7 @@ def check_or_interleave(hint, value, var): def check_type_hints(f): + @functools.wraps(f) def wrapper(*args, **kwargs): hints = get_type_hints(f) @@ -115,6 +116,7 @@ class NonBuiltInGeneric(typing.NamedTuple('Entry', [('Field1', T), class TypeHintTestCase(unittest.TestCase): + def assertCompatible(self, base, sub): # pylint: disable=invalid-name base, sub = native_type_compatibility.convert_to_beam_types([base, sub]) self.assertTrue( @@ -128,6 +130,7 @@ def assertNotCompatible(self, base, sub): # pylint: disable=invalid-name class TypeVariableTestCase(TypeHintTestCase): + def test_eq_with_name_check(self): use_name_in_eq = True self.assertNotEqual( @@ -150,6 +153,7 @@ def test_eq_with_different_name_checks(self): class AnyTypeConstraintTestCase(TypeHintTestCase): + def test_any_compatibility(self): self.assertCompatible(typehints.Any, typehints.List[int]) self.assertCompatible(typehints.Any, DummyTestClass1) @@ -185,6 +189,7 @@ def test_type_check(self): class UnionHintTestCase(TypeHintTestCase): + def test_getitem_must_be_valid_type_param_cant_be_object_instance(self): with self.assertRaises(TypeError) as e: typehints.Union[5] @@ -348,6 +353,7 @@ def visitor(hint, arg): class OptionalHintTestCase(TypeHintTestCase): + def test_getitem_sequence_not_allowed(self): with self.assertRaises(TypeError) as e: typehints.Optional[int, str] @@ -369,6 +375,7 @@ def test_is_optional(self): class TupleHintTestCase(TypeHintTestCase): + def test_getitem_invalid_ellipsis_type_param(self): error_msg = ( 'Ellipsis can only be used to type-hint an arbitrary length ' @@ -553,6 +560,7 @@ def test_builtin_and_type_compatibility(self): class ListHintTestCase(TypeHintTestCase): + def test_getitem_invalid_composite_type_param(self): with self.assertRaises(TypeError): typehints.List[4] @@ -627,6 +635,7 @@ def test_builtin_is_typing_generic(self): class KVHintTestCase(TypeHintTestCase): + def test_getitem_param_must_be_tuple(self): with self.assertRaises(TypeError) as e: typehints.KV[4] @@ -657,6 +666,7 @@ def test_enforce_kv_type_constraint(self): class DictHintTestCase(TypeHintTestCase): + def test_getitem_param_must_be_tuple(self): with self.assertRaises(TypeError) as e: typehints.Dict[4] @@ -775,7 +785,9 @@ def test_builtin_and_type_compatibility(self): class BaseSetHintTest: + class CommonTests(TypeHintTestCase): + def test_getitem_invalid_composite_type_param(self): try: self.beam_type[list] @@ -847,6 +859,7 @@ class FrozenSetHintTestCase(BaseSetHintTest.CommonTests): class CollectionHintTestCase(TypeHintTestCase): + def test_type_constraint_compatibility(self): self.assertCompatible(typehints.Collection[int], typehints.Set[int]) self.assertCompatible(typehints.Iterable[int], typehints.Collection[int]) @@ -876,6 +889,7 @@ def test_getitem_invalid_composite_type_param(self): class IterableHintTestCase(TypeHintTestCase): + def test_getitem_invalid_composite_type_param(self): with self.assertRaises(TypeError) as e: typehints.Iterable[5] @@ -960,7 +974,9 @@ def test_type_check_violation_valid_composite_type(self): class TestGeneratorWrapper(TypeHintTestCase): + def test_functions_as_regular_generator(self): + def count(n): for i in range(n): yield i @@ -978,6 +994,7 @@ def count(n): class GeneratorHintTestCase(TypeHintTestCase): + def test_repr(self): hint = typehints.Iterator[typehints.Set[str]] self.assertEqual('Iterator[Set[]]', repr(hint)) @@ -992,6 +1009,7 @@ def test_conversion(self): typehints.Iterator[int], typehints.Generator[int, None, None]) def test_generator_return_hint_invalid_yield_type(self): + @check_type_hints @with_output_types(typehints.Iterator[int]) def all_upper(s): @@ -1009,6 +1027,7 @@ def all_upper(s): e.exception.args[0]) def test_generator_argument_hint_invalid_yield_type(self): + def wrong_yield_gen(): for e in ['a', 'b']: yield e @@ -1030,6 +1049,7 @@ def increment(a): class TakesDecoratorTestCase(TypeHintTestCase): + def test_must_be_primitive_type_or_constraint(self): with self.assertRaises(TypeError) as e: t = [1, 2] @@ -1058,6 +1078,7 @@ def unused_foo(a): e.exception.args[0]) def test_basic_type_assertion(self): + @check_type_hints @with_input_types(a=int) def foo(a): @@ -1073,6 +1094,7 @@ def foo(a): e.exception.args[0]) def test_composite_type_assertion(self): + @check_type_hints @with_input_types(a=typehints.List[int]) def foo(a): @@ -1090,6 +1112,7 @@ def foo(a): e.exception.args[0]) def test_valid_simple_type_arguments(self): + @with_input_types(a=str) def upper(a): return a.upper() @@ -1098,6 +1121,7 @@ def upper(a): self.assertEqual('M', upper('m')) def test_any_argument_type_hint(self): + @check_type_hints @with_input_types(a=typehints.Any) def foo(a): @@ -1106,6 +1130,7 @@ def foo(a): self.assertEqual(4, foo('m')) def test_valid_mix_positional_and_keyword_arguments(self): + @check_type_hints @with_input_types(typehints.List[int], elem=typehints.List[int]) def combine(container, elem): @@ -1114,6 +1139,7 @@ def combine(container, elem): self.assertEqual([1, 2, 3], combine([1, 2], [3])) def test_invalid_only_positional_arguments(self): + @check_type_hints @with_input_types(int, int) def sub(a, b): @@ -1130,6 +1156,7 @@ def sub(a, b): e.exception.args[0]) def test_valid_only_positional_arguments(self): + @with_input_types(int, int) def add(a, b): return a + b @@ -1138,7 +1165,9 @@ def add(a, b): class InputDecoratorTestCase(TypeHintTestCase): + def test_valid_hint(self): + @with_input_types(int, int) def unused_add(a, b): return a + b @@ -1167,7 +1196,9 @@ def unused_foo(a): class OutputDecoratorTestCase(TypeHintTestCase): + def test_valid_hint(self): + @with_output_types(int) def unused_foo(): return 5 @@ -1208,6 +1239,7 @@ def unused_foo(): return 4, 'f' def test_type_check_violation(self): + @check_type_hints @with_output_types(int) def foo(a): @@ -1224,6 +1256,7 @@ def foo(a): e.exception.args[0]) def test_type_check_simple_type(self): + @check_type_hints @with_output_types(str) def upper(a): @@ -1232,6 +1265,7 @@ def upper(a): self.assertEqual('TEST', upper('test')) def test_type_check_composite_type(self): + @check_type_hints @with_output_types(typehints.List[typehints.Tuple[int, int]]) def bar(): @@ -1240,6 +1274,7 @@ def bar(): self.assertEqual([(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)], bar()) def test_any_return_type_hint(self): + @check_type_hints @with_output_types(typehints.Any) def bar(): @@ -1249,7 +1284,9 @@ def bar(): class CombinedReturnsAndTakesTestCase(TypeHintTestCase): + def test_enable_and_disable_type_checking_takes(self): + @with_input_types(a=int) def int_to_str(a): return str(a) @@ -1270,6 +1307,7 @@ def int_to_str(a): # pylint: disable=function-redefined int_to_str('a') def test_enable_and_disable_type_checking_returns(self): + @with_output_types(str) def int_to_str(a): return a @@ -1290,6 +1328,7 @@ def int_to_str(a): # pylint: disable=function-redefined int_to_str(9) def test_valid_mix_pos_and_keyword_with_both_orders(self): + @with_input_types(str, start=int) @with_output_types(str) def to_upper_with_slice(string, start): @@ -1298,6 +1337,7 @@ def to_upper_with_slice(string, start): self.assertEqual('ELLO', to_upper_with_slice('hello', 1)) def test_simple_takes_and_returns_hints(self): + @check_type_hints @with_output_types(str) @with_input_types(a=str) @@ -1322,6 +1362,7 @@ def to_lower(a): # pylint: disable=function-redefined to_lower('a') def test_composite_takes_and_returns_hints(self): + @check_type_hints @with_input_types(it=typehints.List[int]) @with_output_types(typehints.List[typehints.Tuple[int, int]]) @@ -1347,6 +1388,7 @@ def expand_ints(it): # pylint: disable=function-redefined class DecoratorHelpers(TypeHintTestCase): + def test_hint_helper(self): self.assertTrue(is_consistent_with(Any, int)) self.assertTrue(is_consistent_with(int, Any)) @@ -1363,6 +1405,7 @@ def test_positional_arg_hints(self): _positional_arg_hints(['x', 'y'], {'x': int})) def test_getcallargs_forhints(self): + def func(a, b_c, *d): return a, b_c, d @@ -1402,6 +1445,7 @@ def test_getcallargs_forhints_builtins(self): class TestGetYieldedType(unittest.TestCase): + def test_iterables(self): self.assertEqual(int, typehints.get_yielded_type(typehints.Iterable[int])) self.assertEqual(int, typehints.get_yielded_type(typehints.Iterator[int])) @@ -1419,6 +1463,7 @@ def test_not_iterable(self): class TestCoerceToKvType(TypeHintTestCase): + def test_coercion_success(self): cases = [ ((Any, ), typehints.KV[Any, Any]), @@ -1445,7 +1490,9 @@ def test_coercion_fail(self): class TestParDoAnnotations(unittest.TestCase): + def test_with_side_input(self): + class MyDoFn(DoFn): def process(self, element: float, side_input: str) -> \ Iterable[KV[str, float]]: @@ -1456,7 +1503,9 @@ def process(self, element: float, side_input: str) -> \ self.assertEqual(th.output_types, ((KV[str, float], ), {})) def test_pep484_annotations(self): + class MyDoFn(DoFn): + def process(self, element: int) -> Iterable[str]: pass @@ -1466,8 +1515,11 @@ def process(self, element: int) -> Iterable[str]: class TestPTransformAnnotations(unittest.TestCase): + def test_pep484_annotations(self): + class MyPTransform(PTransform): + def expand(self, pcoll: PCollection[int]) -> PCollection[str]: return pcoll | Map(lambda num: str(num)) @@ -1476,7 +1528,9 @@ def expand(self, pcoll: PCollection[int]) -> PCollection[str]: self.assertEqual(th.output_types, ((str, ), {})) def test_annotations_without_input_pcollection_wrapper(self): + class MyPTransform(PTransform): + def expand(self, pcoll: int) -> PCollection[str]: return pcoll | Map(lambda num: str(num)) @@ -1491,7 +1545,9 @@ def expand(self, pcoll: int) -> PCollection[str]: self.assertIn(error_str, log.output[0]) def test_annotations_without_output_pcollection_wrapper(self): + class MyPTransform(PTransform): + def expand(self, pcoll: PCollection[int]) -> str: return pcoll | Map(lambda num: str(num)) @@ -1508,7 +1564,9 @@ def expand(self, pcoll: PCollection[int]) -> str: self.assertEqual(th.output_types, None) def test_annotations_without_input_internal_type(self): + class MyPTransform(PTransform): + def expand(self, pcoll: PCollection) -> PCollection[str]: return pcoll | Map(lambda num: str(num)) @@ -1517,7 +1575,9 @@ def expand(self, pcoll: PCollection) -> PCollection[str]: self.assertEqual(th.output_types, ((str, ), {})) def test_annotations_without_output_internal_type(self): + class MyPTransform(PTransform): + def expand(self, pcoll: PCollection[int]) -> PCollection: return pcoll | Map(lambda num: str(num)) @@ -1526,7 +1586,9 @@ def expand(self, pcoll: PCollection[int]) -> PCollection: self.assertEqual(th.output_types, ((Any, ), {})) def test_annotations_without_any_internal_type(self): + class MyPTransform(PTransform): + def expand(self, pcoll: PCollection) -> PCollection: return pcoll | Map(lambda num: str(num)) @@ -1535,7 +1597,9 @@ def expand(self, pcoll: PCollection) -> PCollection: self.assertEqual(th.output_types, ((Any, ), {})) def test_annotations_without_input_typehint(self): + class MyPTransform(PTransform): + def expand(self, pcoll) -> PCollection[str]: return pcoll | Map(lambda num: str(num)) @@ -1544,7 +1608,9 @@ def expand(self, pcoll) -> PCollection[str]: self.assertEqual(th.output_types, ((str, ), {})) def test_annotations_without_output_typehint(self): + class MyPTransform(PTransform): + def expand(self, pcoll: PCollection[int]): return pcoll | Map(lambda num: str(num)) @@ -1553,7 +1619,9 @@ def expand(self, pcoll: PCollection[int]): self.assertEqual(th.output_types, ((Any, ), {})) def test_annotations_without_any_typehints(self): + class MyPTransform(PTransform): + def expand(self, pcoll): return pcoll | Map(lambda num: str(num)) @@ -1562,7 +1630,9 @@ def expand(self, pcoll): self.assertEqual(th.output_types, None) def test_annotations_with_pbegin(self): + class MyPTransform(PTransform): + def expand(self, pcoll: PBegin): return pcoll | Map(lambda num: str(num)) @@ -1571,7 +1641,9 @@ def expand(self, pcoll: PBegin): self.assertEqual(th.output_types, ((Any, ), {})) def test_annotations_with_pdone(self): + class MyPTransform(PTransform): + def expand(self, pcoll) -> PDone: return pcoll | Map(lambda num: str(num)) @@ -1580,7 +1652,9 @@ def expand(self, pcoll) -> PDone: self.assertEqual(th.output_types, ((Any, ), {})) def test_annotations_with_none_input(self): + class MyPTransform(PTransform): + def expand(self, pcoll: None) -> PCollection[str]: return pcoll | Map(lambda num: str(num)) @@ -1597,7 +1671,9 @@ def expand(self, pcoll: None) -> PCollection[str]: self.assertEqual(th.output_types, ((str, ), {})) def test_annotations_with_none_output(self): + class MyPTransform(PTransform): + def expand(self, pcoll) -> None: return pcoll | Map(lambda num: str(num)) @@ -1606,7 +1682,9 @@ def expand(self, pcoll) -> None: self.assertEqual(th.output_types, ((Any, ), {})) def test_annotations_with_arbitrary_output(self): + class MyPTransform(PTransform): + def expand(self, pcoll) -> str: return pcoll | Map(lambda num: str(num)) @@ -1615,7 +1693,9 @@ def expand(self, pcoll) -> str: self.assertEqual(th.output_types, None) def test_annotations_with_arbitrary_input_and_output(self): + class MyPTransform(PTransform): + def expand(self, pcoll: int) -> str: return pcoll | Map(lambda num: str(num)) @@ -1639,7 +1719,9 @@ def expand(self, pcoll: int) -> str: self.assertEqual(th.output_types, None) def test_typing_module_annotations_are_converted_to_beam_annotations(self): + class MyPTransform(PTransform): + def expand( self, pcoll: PCollection[typing.Dict[str, str]] ) -> PCollection[typing.Dict[str, str]]: @@ -1650,6 +1732,7 @@ def expand( self.assertEqual(th.input_types, ((typehints.Dict[str, str], ), {})) def test_nested_typing_annotations_are_converted_to_beam_annotations(self): + class MyPTransform(PTransform): def expand(self, pcoll: PCollection[typing.Union[int, typing.Any, typing.Dict[str, float]]]) \ @@ -1667,7 +1750,9 @@ def expand(self, pcoll: float]], ), {})) def test_mixed_annotations_are_converted_to_beam_annotations(self): + class MyPTransform(PTransform): + def expand(self, pcoll: typing.Any) -> typehints.Any: return pcoll @@ -1687,6 +1772,7 @@ def test_pipe_operator_as_union(self): class TestNonBuiltInGenerics(unittest.TestCase): + def test_no_error_thrown(self): input = NonBuiltInGeneric[str] output = typehints.normalize(input) diff --git a/sdks/python/apache_beam/utils/annotations.py b/sdks/python/apache_beam/utils/annotations.py index 4bf265c9fd73..84630c62a9f8 100644 --- a/sdks/python/apache_beam/utils/annotations.py +++ b/sdks/python/apache_beam/utils/annotations.py @@ -74,6 +74,7 @@ class BeamDeprecationWarning(DeprecationWarning): class _WarningMessage: """Utility class for assembling the warning message.""" + def __init__(self, label, since, current, extra_message, custom_message): """Initialize message, leave only name as placeholder.""" if custom_message is None: diff --git a/sdks/python/apache_beam/utils/counters.py b/sdks/python/apache_beam/utils/counters.py index 57d73fa283eb..fb99db615a74 100644 --- a/sdks/python/apache_beam/utils/counters.py +++ b/sdks/python/apache_beam/utils/counters.py @@ -208,6 +208,7 @@ def _str_internal(self): class AccumulatorCombineFnCounter(Counter): """Counter optimized for a mutating accumulator that holds all the logic.""" + def __init__(self, name, combine_fn): # type: (CounterName, cy_combiners.AccumulatorCombineFn) -> None assert isinstance(combine_fn, cy_combiners.AccumulatorCombineFn) @@ -228,6 +229,7 @@ def reset(self): class CounterFactory(object): """Keeps track of unique counters.""" + def __init__(self): self.counters = {} # type: Dict[CounterName, Counter] diff --git a/sdks/python/apache_beam/utils/counters_test.py b/sdks/python/apache_beam/utils/counters_test.py index 3b579cc4b4cd..5c6ac50b67b1 100644 --- a/sdks/python/apache_beam/utils/counters_test.py +++ b/sdks/python/apache_beam/utils/counters_test.py @@ -28,6 +28,7 @@ class CounterNameTest(unittest.TestCase): + def test_name_string_representation(self): counter_name = CounterName('counter_name', 'stage_name', 'step_name') @@ -80,6 +81,7 @@ def test_hash_two_objects(self): class CounterTest(unittest.TestCase): + def setUp(self): self.counter_factory = counters.CounterFactory() @@ -124,6 +126,7 @@ def test_distribution_counter(self): }, ]) class GeneralCounterTest(unittest.TestCase): + def setUp(self): self.counter_factory = counters.CounterFactory() diff --git a/sdks/python/apache_beam/utils/histogram.py b/sdks/python/apache_beam/utils/histogram.py index a0fd7129466e..d03a5b266eeb 100644 --- a/sdks/python/apache_beam/utils/histogram.py +++ b/sdks/python/apache_beam/utils/histogram.py @@ -26,6 +26,7 @@ class Histogram(object): """A histogram that supports estimated percentile with linear interpolation. """ + def __init__(self, bucket_type): self._lock = threading.Lock() self._bucket_type = bucket_type @@ -98,6 +99,7 @@ def p50(self): return self.get_linear_interpolation(0.50) def get_percentile_info(self): + def _format(f): if f == float('-inf'): return '<%s' % self._bucket_type.range_from() @@ -176,6 +178,7 @@ def __hash__(self): class BucketType(object): + def range_from(self): """Lower bound of a starting bucket.""" raise NotImplementedError @@ -207,6 +210,7 @@ def accumulated_bucket_size(self, end_index): class LinearBucket(BucketType): + def __init__(self, start, width, num_buckets): """Create a histogram with linear buckets. diff --git a/sdks/python/apache_beam/utils/histogram_test.py b/sdks/python/apache_beam/utils/histogram_test.py index a688d8a5ff4b..3671101b0576 100644 --- a/sdks/python/apache_beam/utils/histogram_test.py +++ b/sdks/python/apache_beam/utils/histogram_test.py @@ -26,6 +26,7 @@ class HistogramTest(unittest.TestCase): + @patch('apache_beam.utils.histogram._LOGGER') def test_out_of_range(self, mock_logger): histogram = Histogram(LinearBucket(0, 20, 5)) diff --git a/sdks/python/apache_beam/utils/interactive_utils_test.py b/sdks/python/apache_beam/utils/interactive_utils_test.py index 76e28ab7ee0a..beb43b665604 100644 --- a/sdks/python/apache_beam/utils/interactive_utils_test.py +++ b/sdks/python/apache_beam/utils/interactive_utils_test.py @@ -43,6 +43,7 @@ def corrupted_ipython(): not ie.current_env().is_interactive_ready, '[interactive] dependency is not installed.') class IPythonTest(unittest.TestCase): + @patch('IPython.get_ipython', new_callable=mock_get_ipython) def test_is_in_ipython_when_in_ipython_kernel(self, kernel): self.assertTrue(is_in_ipython()) diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py index aecb1284a1d4..3dfead286c64 100644 --- a/sdks/python/apache_beam/utils/multi_process_shared.py +++ b/sdks/python/apache_beam/utils/multi_process_shared.py @@ -64,6 +64,7 @@ class _SingletonProxy: """Proxies the shared object so we can release it with better errors and no risk of dangling references in the multiprocessing manager infrastructure. """ + def __init__(self, entry): # Guard names so as to not conflict with names of underlying object. self._SingletonProxy_entry = entry @@ -110,6 +111,7 @@ def __dir__(self): class _SingletonEntry: """Represents a single, refcounted entry in this process.""" + def __init__(self, constructor, initialize_eagerly=True): self.constructor = constructor self.refcount = 0 @@ -191,6 +193,7 @@ class _SingletonRegistrar(multiprocessing.managers.BaseManager): # singletonProxy_call__ calls (which is a wrapper around the underlying # object's __call__ function) class _AutoProxyWrapper: + def __init__(self, proxyObject: multiprocessing.managers.BaseProxy): self._proxyObject = proxyObject @@ -246,6 +249,7 @@ def method(self, arg): always_proxy: whether to direct all calls through the proxy, rather than call the object directly for the process that created it """ + def __init__( self, constructor: Callable[[], T], diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py index 0b7957632368..86bd278ba2de 100644 --- a/sdks/python/apache_beam/utils/multi_process_shared_test.py +++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py @@ -25,6 +25,7 @@ class CallableCounter(object): + def __init__(self, start=0): self.running = start self.lock = threading.Lock() @@ -42,6 +43,7 @@ def error(self, msg): class Counter(object): + def __init__(self, start=0): self.running = start self.lock = threading.Lock() @@ -59,6 +61,7 @@ def error(self, msg): class CounterWithBadAttr(object): + def __init__(self, start=0): self.running = start self.lock = threading.Lock() @@ -83,6 +86,7 @@ def __getattribute__(self, __name: str) -> Any: class MultiProcessSharedTest(unittest.TestCase): + @classmethod def setUpClass(cls): cls.shared = multi_process_shared.MultiProcessShared( diff --git a/sdks/python/apache_beam/utils/plugin.py b/sdks/python/apache_beam/utils/plugin.py index 6c8cd58be31c..e4f3dfec8532 100644 --- a/sdks/python/apache_beam/utils/plugin.py +++ b/sdks/python/apache_beam/utils/plugin.py @@ -23,6 +23,7 @@ class BeamPlugin(object): """Plugin base class to be extended by dependent users such as FileSystem. Any instantiated subclass will be imported at worker startup time.""" + @classmethod def get_all_subclasses(cls): """Get all the subclasses of the BeamPlugin class.""" @@ -35,6 +36,7 @@ def get_all_subclasses(cls): @classmethod def get_all_plugin_paths(cls): """Get full import paths of the BeamPlugin subclass.""" + def fullname(o): return o.__module__ + "." + o.__name__ diff --git a/sdks/python/apache_beam/utils/processes_test.py b/sdks/python/apache_beam/utils/processes_test.py index 13425550dbbe..aa25204e389c 100644 --- a/sdks/python/apache_beam/utils/processes_test.py +++ b/sdks/python/apache_beam/utils/processes_test.py @@ -28,6 +28,7 @@ class Exec(unittest.TestCase): + def setUp(self): pass @@ -85,6 +86,7 @@ def test_method_forwarding_windows(self, *unused_mocks): class TestErrorHandlingCheckCall(unittest.TestCase): + @classmethod def setUpClass(cls): cls.mock_get_patcher = mock.patch(\ @@ -133,6 +135,7 @@ def test_check_call_pip_install_non_existing_package(self): class TestErrorHandlingCheckOutput(unittest.TestCase): + @classmethod def setUpClass(cls): cls.mock_get_patcher = mock.patch(\ @@ -174,6 +177,7 @@ def test_check_output_pip_install_non_existing_package(self): class TestErrorHandlingCall(unittest.TestCase): + @classmethod def setUpClass(cls): cls.mock_get_patcher = mock.patch(\ diff --git a/sdks/python/apache_beam/utils/profiler_test.py b/sdks/python/apache_beam/utils/profiler_test.py index e6991ca26819..7b81af40e525 100644 --- a/sdks/python/apache_beam/utils/profiler_test.py +++ b/sdks/python/apache_beam/utils/profiler_test.py @@ -26,6 +26,7 @@ class ProfilerTest(unittest.TestCase): + @parameterized.expand([ param(enable_cpu_memory=(True, True)), param(enable_cpu_memory=(True, False)), diff --git a/sdks/python/apache_beam/utils/proto_utils_test.py b/sdks/python/apache_beam/utils/proto_utils_test.py index c40967cd2c0f..19528cc5462d 100644 --- a/sdks/python/apache_beam/utils/proto_utils_test.py +++ b/sdks/python/apache_beam/utils/proto_utils_test.py @@ -24,6 +24,7 @@ class TestProtoUtils(unittest.TestCase): + def test_from_micros_duration(self): ts = proto_utils.from_micros(duration_pb2.Duration, MAX_TIMESTAMP.micros) expected = duration_pb2.Duration( diff --git a/sdks/python/apache_beam/utils/python_callable.py b/sdks/python/apache_beam/utils/python_callable.py index f6f507300ea8..f207e9519d08 100644 --- a/sdks/python/apache_beam/utils/python_callable.py +++ b/sdks/python/apache_beam/utils/python_callable.py @@ -43,6 +43,7 @@ class PythonCallableWithSource(object): is a valid chunk of source code. """ + def __init__(self, source: str) -> None: self._source = source self._callable = self.load_from_source(source) diff --git a/sdks/python/apache_beam/utils/python_callable_test.py b/sdks/python/apache_beam/utils/python_callable_test.py index 6fc6a1f04a69..084ca64ed837 100644 --- a/sdks/python/apache_beam/utils/python_callable_test.py +++ b/sdks/python/apache_beam/utils/python_callable_test.py @@ -22,6 +22,7 @@ class PythonCallableWithSourceTest(unittest.TestCase): + def test_builtin(self): self.assertEqual(PythonCallableWithSource.load_from_source('str'), str) diff --git a/sdks/python/apache_beam/utils/retry.py b/sdks/python/apache_beam/utils/retry.py index 485fc9d627e9..7fc03ed11f3d 100644 --- a/sdks/python/apache_beam/utils/retry.py +++ b/sdks/python/apache_beam/utils/retry.py @@ -89,6 +89,7 @@ class FuzzedExponentialIntervals(object): (None). You may need to increase num_retries to effectively use this feature. """ + def __init__( self, initial_delay_secs, @@ -222,6 +223,7 @@ def retry_if_valid_input_but_server_error_and_timeout_filter(exception): class Clock(object): """A simple clock implementing sleep().""" + def sleep(self, value): time.sleep(value) @@ -281,8 +283,10 @@ def with_exponential_backoff( @retry.with_exponential_backoff() make_http_request(args) """ + def real_decorator(fun): """The real decorator whose purpose is to return the wrapped function.""" + @functools.wraps(fun) def wrapper(*args, **kwargs): retry_intervals = iter( diff --git a/sdks/python/apache_beam/utils/retry_test.py b/sdks/python/apache_beam/utils/retry_test.py index 05cc6e6a7f40..015e296e8909 100644 --- a/sdks/python/apache_beam/utils/retry_test.py +++ b/sdks/python/apache_beam/utils/retry_test.py @@ -37,6 +37,7 @@ class FakeClock(object): """A fake clock object implementing sleep() and recording calls.""" + def __init__(self): self.calls = [] @@ -46,6 +47,7 @@ def sleep(self, value): class FakeLogger(object): """A fake logger object implementing log() and recording calls.""" + def __init__(self): self.calls = [] @@ -73,6 +75,7 @@ def _test_no_retry_function(a, b): class RetryTest(unittest.TestCase): + def setUp(self): self.clock = FakeClock() self.logger = FakeLogger() @@ -204,6 +207,7 @@ def test_log_calls_for_transient_failure(self): class DummyClass(object): + def __init__(self, results): self.index = 0 self.results = results @@ -225,6 +229,7 @@ class RetryStateTest(unittest.TestCase): The test_call_two_objects would test this inside the same test. """ + def test_two_failures(self): dummy = DummyClass(["Error", "Error", "Success"]) dummy.func() diff --git a/sdks/python/apache_beam/utils/sharded_key.py b/sdks/python/apache_beam/utils/sharded_key.py index f6492779ef34..e35901e6f48b 100644 --- a/sdks/python/apache_beam/utils/sharded_key.py +++ b/sdks/python/apache_beam/utils/sharded_key.py @@ -27,6 +27,7 @@ class ShardedKey(object): key: The user key. shard_id: An opaque byte string that uniquely represents a shard of the key. """ + def __init__( self, key, diff --git a/sdks/python/apache_beam/utils/shared.py b/sdks/python/apache_beam/utils/shared.py index bb04d1a19fb0..0125f9cb56a3 100644 --- a/sdks/python/apache_beam/utils/shared.py +++ b/sdks/python/apache_beam/utils/shared.py @@ -104,6 +104,7 @@ class _SharedControlBlock(object): We need this so we can call constructors for distinct Shared elements in the SharedMap concurrently. """ + def __init__(self): self._lock = threading.Lock() self._ref = None @@ -192,6 +193,7 @@ class _SharedMap(object): Related issues: BEAM-562 - DoFn reuse """ + def __init__(self): # Lock that protects cache_map self._lock = threading.Lock() diff --git a/sdks/python/apache_beam/utils/shared_test.py b/sdks/python/apache_beam/utils/shared_test.py index 48bd18263d59..ed1de10d6e59 100644 --- a/sdks/python/apache_beam/utils/shared_test.py +++ b/sdks/python/apache_beam/utils/shared_test.py @@ -26,6 +26,7 @@ class Count(object): + def __init__(self): self._lock = threading.Lock() self._total = 0 @@ -50,6 +51,7 @@ def get_total(self): class Marker(object): + def __init__(self, count): self._count = count self._count.add_ref() @@ -59,6 +61,7 @@ def __del__(self): class NamedObject(object): + def __init__(self, name): self._name = name @@ -67,6 +70,7 @@ def get_name(self): class Sequence(object): + def __init__(self): self._sequence = 0 @@ -81,6 +85,7 @@ def acquire_fn(): class SharedTest(unittest.TestCase): + def testKeepalive(self): count = Count() shared_handle = shared.Shared() diff --git a/sdks/python/apache_beam/utils/subprocess_server.py b/sdks/python/apache_beam/utils/subprocess_server.py index b1080cb643af..22356d5b1be1 100644 --- a/sdks/python/apache_beam/utils/subprocess_server.py +++ b/sdks/python/apache_beam/utils/subprocess_server.py @@ -68,6 +68,7 @@ class _SharedCache: finally: cache.purge(token) """ + def __init__(self, constructor, destructor): self._constructor = constructor self._destructor = destructor @@ -122,6 +123,7 @@ class SubprocessServer(object): with SubprocessServer(GrpcStubClass, [executable, arg, ...]) as stub: stub.CallService(...) """ + def __init__(self, stub_class, cmd, port=None): """Creates the server object. diff --git a/sdks/python/apache_beam/utils/subprocess_server_test.py b/sdks/python/apache_beam/utils/subprocess_server_test.py index c0c8e5694b86..d57175657d2a 100644 --- a/sdks/python/apache_beam/utils/subprocess_server_test.py +++ b/sdks/python/apache_beam/utils/subprocess_server_test.py @@ -34,6 +34,7 @@ class JavaJarServerTest(unittest.TestCase): + def test_gradle_jar_release(self): self.assertEqual( 'https://repo.maven.apache.org/maven2/org/apache/beam/' @@ -99,6 +100,7 @@ def test_beam_services(self): subprocess_server.JavaJarServer.path_to_beam_jar(':some:target')) def test_local_jar(self): + class Handler(socketserver.BaseRequestHandler): timeout = 1 @@ -168,6 +170,7 @@ def test_classpath_jar(self): class CacheTest(unittest.TestCase): + @staticmethod def with_prefix(prefix): return '%s-%s' % (prefix, random.random()) diff --git a/sdks/python/apache_beam/utils/thread_pool_executor.py b/sdks/python/apache_beam/utils/thread_pool_executor.py index e1e8ad5c43a6..9883b319c2bb 100644 --- a/sdks/python/apache_beam/utils/thread_pool_executor.py +++ b/sdks/python/apache_beam/utils/thread_pool_executor.py @@ -24,6 +24,7 @@ class _WorkItem(object): + def __init__(self, future, fn, args, kwargs): self._future = future self._fn = fn @@ -40,6 +41,7 @@ def run(self): class _Worker(threading.Thread): + def __init__(self, idle_worker_queue, work_item): super().__init__() self._idle_worker_queue = idle_worker_queue @@ -72,6 +74,7 @@ def shutdown(self): class UnboundedThreadPoolExecutor(_base.Executor): + def __init__(self): self._idle_worker_queue = queue.Queue() self._max_idle_threads = 16 @@ -122,6 +125,7 @@ def shutdown(self, wait=True): class _SharedUnboundedThreadPoolExecutor(UnboundedThreadPoolExecutor): + def shutdown(self, wait=True): # Prevent shutting down the shared thread pool pass diff --git a/sdks/python/apache_beam/utils/thread_pool_executor_test.py b/sdks/python/apache_beam/utils/thread_pool_executor_test.py index b382224f6850..3ffb0d3c307f 100644 --- a/sdks/python/apache_beam/utils/thread_pool_executor_test.py +++ b/sdks/python/apache_beam/utils/thread_pool_executor_test.py @@ -30,6 +30,7 @@ class UnboundedThreadPoolExecutorTest(unittest.TestCase): + def setUp(self): self._lock = threading.Lock() self._worker_idents = [] diff --git a/sdks/python/apache_beam/utils/timestamp.py b/sdks/python/apache_beam/utils/timestamp.py index 3f585eecae08..0e35ac6cce65 100644 --- a/sdks/python/apache_beam/utils/timestamp.py +++ b/sdks/python/apache_beam/utils/timestamp.py @@ -52,6 +52,7 @@ class Timestamp(object): especially after arithmetic operations (for example, 10000000 % 0.1 evaluates to 0.0999999994448885). """ + def __init__( self, seconds: Union[int, float] = 0, @@ -289,6 +290,7 @@ class Duration(object): especially after arithmetic operations (for example, 10000000 % 0.1 evaluates to 0.0999999994448885). """ + def __init__( self, seconds: Union[int, float] = 0, diff --git a/sdks/python/apache_beam/utils/timestamp_test.py b/sdks/python/apache_beam/utils/timestamp_test.py index f8d6cfdeafee..299139defce8 100644 --- a/sdks/python/apache_beam/utils/timestamp_test.py +++ b/sdks/python/apache_beam/utils/timestamp_test.py @@ -31,6 +31,7 @@ class TimestampTest(unittest.TestCase): + def test_of(self): interval = Timestamp(123) self.assertEqual(id(interval), id(Timestamp.of(interval))) @@ -194,6 +195,7 @@ def test_equality(self): class DurationTest(unittest.TestCase): + def test_of(self): interval = Duration(123) self.assertEqual(id(interval), id(Duration.of(interval))) diff --git a/sdks/python/apache_beam/utils/urns.py b/sdks/python/apache_beam/utils/urns.py index 2647a0200bde..d0d68c9c2d93 100644 --- a/sdks/python/apache_beam/utils/urns.py +++ b/sdks/python/apache_beam/utils/urns.py @@ -133,6 +133,7 @@ def register_urn(cls, urn, parameter_type, fn=None): A corresponding to_runner_api_parameter method would be expected that returns the tuple ('beam:fn:foo', FooPayload) """ + def register(fn): cls._known_urns[urn] = parameter_type, fn return fn @@ -149,14 +150,12 @@ def register_pickle_urn(cls, pickle_urn): """Registers and implements the given urn via pickling. """ inspect.currentframe().f_back.f_locals['to_runner_api_parameter'] = ( - lambda self, - context: + lambda self, context: (pickle_urn, wrappers_pb2.BytesValue(value=pickler.dumps(self)))) cls.register_urn( pickle_urn, wrappers_pb2.BytesValue, - lambda proto, - unused_context: pickler.loads(proto.value)) + lambda proto, unused_context: pickler.loads(proto.value)) def to_runner_api( self, context: 'PipelineContext') -> beam_runner_api_pb2.FunctionSpec: diff --git a/sdks/python/apache_beam/utils/windowed_value.py b/sdks/python/apache_beam/utils/windowed_value.py index f6232ce2f6b0..ebc0c9dbe7aa 100644 --- a/sdks/python/apache_beam/utils/windowed_value.py +++ b/sdks/python/apache_beam/utils/windowed_value.py @@ -78,6 +78,7 @@ class PaneInfo(object): whether it's an early/on time/late firing, if it's the last or first firing from a window, and the index of the firing. """ + def __init__(self, is_first, is_last, timing, index, nonspeculative_index): self._is_first = is_first self._is_last = is_last @@ -205,6 +206,7 @@ class WindowedValue(object): the pane that contained this value. If None, will be set to PANE_INFO_UNKNOWN. """ + def __init__( self, value, @@ -288,6 +290,7 @@ def create(value, timestamp_micros, windows, pane_info=PANE_INFO_UNKNOWN): class WindowedBatch(object): """A batch of N windowed values, each having a value, a timestamp and set of windows.""" + def with_values(self, new_values): # type: (Any) -> WindowedBatch @@ -312,6 +315,7 @@ class HomogeneousWindowedBatch(WindowedBatch): """A WindowedBatch with Homogeneous event-time information, represented internally as a WindowedValue. """ + def __init__(self, wv): self._wv = wv @@ -391,6 +395,7 @@ def from_windowed_values( class _IntervalWindowBase(object): """Optimized form of IntervalWindow storing only microseconds for endpoints. """ + def __init__(self, start, end): # type: (TimestampTypes, TimestampTypes) -> None if start is not None: diff --git a/sdks/python/apache_beam/utils/windowed_value_test.py b/sdks/python/apache_beam/utils/windowed_value_test.py index 1e4892aa9bd3..7de07cbfaf27 100644 --- a/sdks/python/apache_beam/utils/windowed_value_test.py +++ b/sdks/python/apache_beam/utils/windowed_value_test.py @@ -32,6 +32,7 @@ class WindowedValueTest(unittest.TestCase): + def test_timestamps(self): wv = windowed_value.WindowedValue(None, 3, ()) self.assertEqual(wv.timestamp, Timestamp.of(3)) @@ -88,6 +89,7 @@ def test_pickle(self): class WindowedBatchTest(unittest.TestCase): + def test_homogeneous_windowed_batch_with_values(self): pane_info = windowed_value.PaneInfo( True, True, windowed_value.PaneInfoTiming.ON_TIME, 0, 0) @@ -155,6 +157,7 @@ def test_homogeneous_from_windowed_values(self): @parameterized_class(('wb', ), [(wb, ) for wb in WINDOWED_BATCH_INSTANCES]) class WindowedBatchUtilitiesTest(unittest.TestCase): + def test_hash(self): wb_copy = copy.copy(self.wb) self.assertFalse(self.wb is wb_copy) diff --git a/sdks/python/apache_beam/yaml/examples/testing/examples_test.py b/sdks/python/apache_beam/yaml/examples/testing/examples_test.py index 109e98410852..192734986c4f 100644 --- a/sdks/python/apache_beam/yaml/examples/testing/examples_test.py +++ b/sdks/python/apache_beam/yaml/examples/testing/examples_test.py @@ -94,6 +94,7 @@ def _fn(row): def check_output(expected: List[str]): + def _check_inner(actual: List[PCollection[str]]): formatted_actual = actual | beam.Flatten() | beam.Map( lambda row: str(beam.Row(**row._asdict()))) @@ -229,6 +230,7 @@ def bigquery_data(): def create_test_method( pipeline_spec_file: str, custom_preprocessors: List[Callable[..., Union[Dict, List]]]): + @mock.patch('apache_beam.Pipeline', TestPipeline) def test_yaml_example(self): with open(pipeline_spec_file, encoding="utf-8") as f: @@ -287,7 +289,8 @@ def parse_test_methods(cls, path: str): @classmethod def create_test_suite(cls, name: str, path: str): - return type(name, (unittest.TestCase, ), dict(cls.parse_test_methods(path))) + return type( + name, (unittest.TestCase, ), dict(cls.parse_test_methods(path))) @classmethod def register_test_preprocessor(cls, test_names: Union[str, List]): diff --git a/sdks/python/apache_beam/yaml/generate_yaml_docs.py b/sdks/python/apache_beam/yaml/generate_yaml_docs.py index 693df6179a2d..96f83b36c0da 100644 --- a/sdks/python/apache_beam/yaml/generate_yaml_docs.py +++ b/sdks/python/apache_beam/yaml/generate_yaml_docs.py @@ -232,8 +232,7 @@ def transform_docs(transform_base, transforms, providers, extra_docs=''): longest( lambda t: longest( lambda p: add_transform_links( - t, p.description(t), providers.keys()), - providers[t]), + t, p.description(t), providers.keys()), providers[t]), transforms).replace('::\n', '\n\n :::yaml\n'), '', extra_docs, diff --git a/sdks/python/apache_beam/yaml/json_utils.py b/sdks/python/apache_beam/yaml/json_utils.py index 76cc80bc2036..a1a76f206435 100644 --- a/sdks/python/apache_beam/yaml/json_utils.py +++ b/sdks/python/apache_beam/yaml/json_utils.py @@ -43,14 +43,17 @@ schema_pb2.INT16: 'integer', schema_pb2.INT32: 'integer', schema_pb2.FLOAT: 'number', - **{v: k - for k, v in JSON_ATOMIC_TYPES_TO_BEAM.items()} + **{ + v: k + for k, v in JSON_ATOMIC_TYPES_TO_BEAM.items() + } } def json_schema_to_beam_schema( json_schema: Dict[str, Any]) -> schema_pb2.Schema: """Returns a Beam schema equivalent for the given Json schema.""" + def maybe_nullable(beam_type, nullable): if nullable: beam_type.nullable = True @@ -227,6 +230,7 @@ def parse(s: bytes): class _PicklableFromConstructor: + def __init__(self, constructor): self._constructor = constructor self._value = None diff --git a/sdks/python/apache_beam/yaml/main_test.py b/sdks/python/apache_beam/yaml/main_test.py index d5fbfedc0349..1b5b22d51d05 100644 --- a/sdks/python/apache_beam/yaml/main_test.py +++ b/sdks/python/apache_beam/yaml/main_test.py @@ -42,6 +42,7 @@ class MainTest(unittest.TestCase): + def test_pipeline_spec_from_file(self): with tempfile.TemporaryDirectory() as tmpdir: yaml_path = os.path.join(tmpdir, 'test.yaml') diff --git a/sdks/python/apache_beam/yaml/options.py b/sdks/python/apache_beam/yaml/options.py index e80141c40b1d..16e1e43749a9 100644 --- a/sdks/python/apache_beam/yaml/options.py +++ b/sdks/python/apache_beam/yaml/options.py @@ -19,6 +19,7 @@ class YamlOptions(pipeline_options.PipelineOptions): + @classmethod def _add_argparse_args(cls, parser): parser.add_argument( diff --git a/sdks/python/apache_beam/yaml/readme_test.py b/sdks/python/apache_beam/yaml/readme_test.py index 555d1d0b583f..dc23f185039f 100644 --- a/sdks/python/apache_beam/yaml/readme_test.py +++ b/sdks/python/apache_beam/yaml/readme_test.py @@ -38,6 +38,7 @@ class FakeSql(beam.PTransform): + def __init__(self, query): self.query = query @@ -101,6 +102,7 @@ def guess_name_and_type(expr): class FakeReadFromPubSub(beam.PTransform): + def __init__(self, topic, format, schema): pass @@ -114,6 +116,7 @@ def expand(self, p): class FakeWriteToPubSub(beam.PTransform): + def __init__(self, topic, format): pass @@ -122,6 +125,7 @@ def expand(self, pcoll): class FakeAggregation(beam.PTransform): + def __init__(self, **unused_kwargs): pass @@ -134,6 +138,7 @@ class _Fakes: fn = str class SomeTransform(beam.PTransform): + def __init__(*args, **kwargs): pass @@ -153,12 +158,14 @@ def expand(self, pcoll): class TestProvider(yaml_provider.InlineProvider): + def _affinity(self, other): # Always try to choose this one. return float('inf') class TestEnvironment: + def __enter__(self): self.tempdir = tempfile.TemporaryDirectory() return self @@ -257,8 +264,7 @@ def test(self): # in precommits with mock.patch( 'apache_beam.yaml.yaml_provider.ExternalProvider.create_transform', - lambda *args, - **kwargs: _Fakes.SomeTransform(*args, **kwargs)): + lambda *args, **kwargs: _Fakes.SomeTransform(*args, **kwargs)): p = beam.Pipeline(options=PipelineOptions(**options)) yaml_transform.expand_pipeline( p, modified_yaml, yaml_provider.merge_providers([test_provider])) diff --git a/sdks/python/apache_beam/yaml/yaml_combine.py b/sdks/python/apache_beam/yaml/yaml_combine.py index b7499f3b0c7a..d37c28ad17e3 100644 --- a/sdks/python/apache_beam/yaml/yaml_combine.py +++ b/sdks/python/apache_beam/yaml/yaml_combine.py @@ -101,6 +101,7 @@ class PyJsYamlCombine(beam.PTransform): language: The language used to define (and execute) the custom callables in `combine`. Defaults to generic. """ + def __init__( self, group_by: Iterable[str], diff --git a/sdks/python/apache_beam/yaml/yaml_combine_test.py b/sdks/python/apache_beam/yaml/yaml_combine_test.py index caf3de10078b..f000364d82bc 100644 --- a/sdks/python/apache_beam/yaml/yaml_combine_test.py +++ b/sdks/python/apache_beam/yaml/yaml_combine_test.py @@ -32,6 +32,7 @@ class YamlCombineTest(unittest.TestCase): + def test_multiple_aggregations(self): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( pickle_library='cloudpickle')) as p: diff --git a/sdks/python/apache_beam/yaml/yaml_enrichment_test.py b/sdks/python/apache_beam/yaml/yaml_enrichment_test.py index e26d6140af23..ceb1ff5092fd 100644 --- a/sdks/python/apache_beam/yaml/yaml_enrichment_test.py +++ b/sdks/python/apache_beam/yaml/yaml_enrichment_test.py @@ -28,6 +28,7 @@ class FakeEnrichmentTransform: + def __init__(self, enrichment_handler, handler_config, timeout=30): self._enrichment_handler = enrichment_handler self._handler_config = handler_config @@ -41,6 +42,7 @@ def __call__(self, enrichment_handler, *, handler_config, timeout=30): class EnrichmentTransformTest(unittest.TestCase): + def test_enrichment_with_bigquery(self): input_data = [ Row(label="item1", rank=0), diff --git a/sdks/python/apache_beam/yaml/yaml_errors.py b/sdks/python/apache_beam/yaml/yaml_errors.py index dace44ca09f6..a65d433bbf71 100644 --- a/sdks/python/apache_beam/yaml/yaml_errors.py +++ b/sdks/python/apache_beam/yaml/yaml_errors.py @@ -55,6 +55,7 @@ def map_errors_to_standard_format(input_type): def maybe_with_exception_handling(inner_expand): + def expand(self, pcoll): wrapped_pcoll = beam.core._MaybePValueWithErrors( pcoll, self._exception_handling_args) @@ -65,6 +66,7 @@ def expand(self, pcoll): def maybe_with_exception_handling_transform_fn(transform_fn): + @functools.wraps(transform_fn) def expand(pcoll, error_handling=None, **kwargs): wrapped_pcoll = beam.core._MaybePValueWithErrors( diff --git a/sdks/python/apache_beam/yaml/yaml_io.py b/sdks/python/apache_beam/yaml/yaml_io.py index a6525aef9877..d2ac3d99df29 100644 --- a/sdks/python/apache_beam/yaml/yaml_io.py +++ b/sdks/python/apache_beam/yaml/yaml_io.py @@ -77,8 +77,8 @@ def write_to_text(pcoll, path: str): """ try: field_names = [ - name for name, - _ in schemas.named_fields_from_element_type(pcoll.element_type) + name for name, _ in schemas.named_fields_from_element_type( + pcoll.element_type) ] except Exception as exn: raise ValueError( @@ -167,7 +167,9 @@ def write_to_bigquery( described at https://beam.apache.org/documentation/sdks/yaml-errors/ Otherwise permanently failing records will cause pipeline failure. """ + class WriteToBigQueryHandlingErrors(beam.PTransform): + def default_label(self): return 'WriteToBigQuery' diff --git a/sdks/python/apache_beam/yaml/yaml_io_test.py b/sdks/python/apache_beam/yaml/yaml_io_test.py index 393e31de0e6d..276bb3d614b7 100644 --- a/sdks/python/apache_beam/yaml/yaml_io_test.py +++ b/sdks/python/apache_beam/yaml/yaml_io_test.py @@ -32,6 +32,7 @@ class FakeReadFromPubSub: + def __init__( self, topic, @@ -65,6 +66,7 @@ def __call__( class FakeWriteToPubSub: + def __init__( self, topic, messages, id_attribute=None, timestamp_attribute=None): self._topic = topic @@ -81,6 +83,7 @@ def __call__(self, topic, *, with_attributes, id_label, timestamp_attribute): class YamlPubSubTest(unittest.TestCase): + def test_simple_read(self): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( pickle_library='cloudpickle')) as p: @@ -208,9 +211,10 @@ def test_read_avro(self): ''' % json.dumps(self._avro_schema)) assert_that( result, - equal_to( - [beam.Row(label='37a', rank=1), # linebreak - beam.Row(label='389a', rank=2)])) + equal_to([ + beam.Row(label='37a', rank=1), # linebreak + beam.Row(label='389a', rank=2) + ])) def test_read_json(self): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( diff --git a/sdks/python/apache_beam/yaml/yaml_join_test.py b/sdks/python/apache_beam/yaml/yaml_join_test.py index 5d43b1cdb3ab..9101b9e7d17f 100644 --- a/sdks/python/apache_beam/yaml/yaml_join_test.py +++ b/sdks/python/apache_beam/yaml/yaml_join_test.py @@ -27,6 +27,7 @@ class ToRow(beam.PTransform): + def expand(self, pcoll): return pcoll | beam.Map(lambda row: beam.Row(**row._asdict())) @@ -51,10 +52,11 @@ def expand(self, pcoll): @unittest.skipIf( - TestPipeline().get_pipeline_options().view_as(StandardOptions).runner is - None, + TestPipeline().get_pipeline_options().view_as(StandardOptions).runner + is None, 'Do not run this test on precommit suites.') class YamlJoinTest(unittest.TestCase): + def test_basic_join(self): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( pickle_library='cloudpickle')) as p: diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py b/sdks/python/apache_beam/yaml/yaml_mapping.py index 7f7da7aca6a9..55673ca91c17 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping.py @@ -180,6 +180,7 @@ def _check_mapping_arguments( # that cannot be pickled without implementing the __getstate__ and # __setstate__ methods. class _CustomJsObjectWrapper(JsObjectWrapper): + def __init__(self, js_obj): super().__init__(js_obj.__dict__['_obj']) @@ -493,6 +494,7 @@ class _Validate(beam.PTransform): invalid elements will be passed to the specified error output along with information about how the schema was invalidated. """ + def __init__( self, schema: Dict[str, Any], @@ -596,9 +598,10 @@ def explode_zip(base, fields): pcoll | beam.FlatMap( lambda row: - (explode_cross_product if cross_product else explode_zip) - ({name: getattr(row, name) - for name in all_fields}, to_explode))) + (explode_cross_product if cross_product else explode_zip)({ + name: getattr(row, name) + for name in all_fields + }, to_explode))) def infer_output_type(self, input_type): return row_type.RowTypeConstraint.from_fields([( diff --git a/sdks/python/apache_beam/yaml/yaml_mapping_test.py b/sdks/python/apache_beam/yaml/yaml_mapping_test.py index 2c5feec18278..5669dee33a01 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping_test.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping_test.py @@ -35,6 +35,7 @@ class YamlMappingTest(unittest.TestCase): + def test_basic(self): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( pickle_library='cloudpickle')) as p: diff --git a/sdks/python/apache_beam/yaml/yaml_ml.py b/sdks/python/apache_beam/yaml/yaml_ml.py index e958ea70aff8..f49eb880fbaa 100644 --- a/sdks/python/apache_beam/yaml/yaml_ml.py +++ b/sdks/python/apache_beam/yaml/yaml_ml.py @@ -59,6 +59,7 @@ def inference_output_type(self): @staticmethod def parse_processing_transform(processing_transform, typ): + def _parse_config(callable=None, path=None, name=None): if callable and (path or name): raise ValueError( @@ -109,6 +110,7 @@ def validate(model_handler_spec): @classmethod def register_handler_type(cls, type_name): + def apply(constructor): cls.handler_types[type_name] = constructor return constructor @@ -131,6 +133,7 @@ def create_handler(cls, model_handler_spec) -> "ModelHandlerProvider": @ModelHandlerProvider.register_handler_type('VertexAIModelHandlerJSON') class VertexAIModelHandlerJSONProvider(ModelHandlerProvider): + def __init__( self, endpoint_id: str, diff --git a/sdks/python/apache_beam/yaml/yaml_ml_test.py b/sdks/python/apache_beam/yaml/yaml_ml_test.py index bc354136bee1..de3b49c9d5a8 100644 --- a/sdks/python/apache_beam/yaml/yaml_ml_test.py +++ b/sdks/python/apache_beam/yaml/yaml_ml_test.py @@ -42,6 +42,7 @@ class MLTransformTest(unittest.TestCase): + def test_ml_transform(self): ml_opts = beam.options.pipeline_options.PipelineOptions( pickle_library='cloudpickle', yaml_experimental_features=['ML']) diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py index b3518c568653..cf3c5471417f 100755 --- a/sdks/python/apache_beam/yaml/yaml_provider.py +++ b/sdks/python/apache_beam/yaml/yaml_provider.py @@ -69,6 +69,7 @@ class Provider: """Maps transform types names and args to concrete PTransform instances.""" + def available(self) -> bool: """Returns whether this provider is available to use in this environment.""" raise NotImplementedError(type(self)) @@ -243,6 +244,7 @@ def provider_from_spec(cls, spec): @classmethod def register_provider_type(cls, type_name): + def apply(constructor): cls._provider_types[type_name] = constructor return constructor @@ -270,14 +272,9 @@ def maven_jar( classifier=None, appendix=None): return ExternalJavaProvider( - urns, - lambda: subprocess_server.JavaJarServer.path_to_maven_jar( - artifact_id=artifact_id, - group_id=group_id, - version=version, - repository=repository, - classifier=classifier, - appendix=appendix)) + urns, lambda: subprocess_server.JavaJarServer.path_to_maven_jar( + artifact_id=artifact_id, group_id=group_id, version=version, + repository=repository, classifier=classifier, appendix=appendix)) @ExternalProvider.register_provider_type('beamJar') @@ -289,8 +286,7 @@ def beam_jar( version=beam_version, artifact_id=None): return ExternalJavaProvider( - urns, - lambda: subprocess_server.JavaJarServer.path_to_beam_jar( + urns, lambda: subprocess_server.JavaJarServer.path_to_beam_jar( gradle_target=gradle_target, version=version, artifact_id=artifact_id) ) @@ -322,6 +318,7 @@ def cache_artifacts(self): class ExternalJavaProvider(ExternalProvider): + def __init__(self, urns, jar_provider): super().__init__( urns, lambda: external.JavaJarExpansionService(jar_provider())) @@ -342,14 +339,15 @@ def python(urns, packages=()): return ExternalPythonProvider(urns, packages) else: return InlineProvider({ - name: - python_callable.PythonCallableWithSource.load_from_source(constructor) + name: python_callable.PythonCallableWithSource.load_from_source( + constructor) for (name, constructor) in urns.items() }) @ExternalProvider.register_provider_type('pythonPackage') class ExternalPythonProvider(ExternalProvider): + def __init__(self, urns, packages: Iterable[str]): super().__init__(urns, PypiExpansionService(packages)) @@ -382,6 +380,7 @@ def _affinity(self, other: "Provider"): @ExternalProvider.register_provider_type('yaml') class YamlProvider(Provider): + def __init__(self, transforms: Mapping[str, Mapping[str, Any]]): if not isinstance(transforms, dict): raise ValueError('Transform mapping must be a dict.') @@ -474,6 +473,7 @@ def fn_takes_side_inputs(fn): class InlineProvider(Provider): + def __init__(self, transform_factories, no_input_transforms=()): self._transform_factories = transform_factories self._no_input_transforms = set(no_input_transforms) @@ -512,9 +512,8 @@ def type_of(p): for param in cls.get_docs(factory).params } - names_and_types = [ - (name, typing_to_runner_api(type_of(p))) for name, p in params.items() - ] + names_and_types = [(name, typing_to_runner_api(type_of(p))) + for name, p in params.items()] return schema_pb2.Schema( fields=[ schema_pb2.Field(name=name, type=type, description=docs.get(name)) @@ -526,6 +525,7 @@ def description(self, typ): @classmethod def description_from_callable(cls, factory): + def empty_if_none(s): return s or '' @@ -561,11 +561,13 @@ def requires_inputs(self, typ, args): class MetaInlineProvider(InlineProvider): + def create_transform(self, type, args, yaml_create_transform): return self._transform_factories[type](yaml_create_transform, **args) class SqlBackedProvider(Provider): + def __init__( self, transforms: Mapping[str, Callable[..., beam.PTransform]], @@ -637,6 +639,7 @@ def dicts_to_rows(o): class YamlProviders: + class AssertEqual(beam.PTransform): """Asserts that the input contains exactly the elements provided. @@ -661,6 +664,7 @@ class AssertEqual(beam.PTransform): elements: The set of elements that should belong to the PCollection. YAML/JSON-style mappings will be interpreted as Beam rows. """ + def __init__(self, elements: Iterable[Any]): self._elements = elements @@ -772,6 +776,7 @@ class Flatten(beam.PTransform): Note that in YAML transforms can always take a list of inputs which will be implicitly flattened. """ + def __init__(self): # Suppress the "label" argument from the superclass for better docs. # pylint: disable=useless-parent-delegation @@ -819,6 +824,7 @@ class WindowInto(beam.PTransform): Args: windowing: the type and parameters of the windowing to perform """ + def __init__(self, windowing): self._window_transform = self._parse_window_spec(windowing) @@ -938,6 +944,7 @@ def create_builtin_provider(): class TranslatingProvider(Provider): + def __init__( self, transforms: Mapping[str, Callable[..., beam.PTransform]], @@ -1119,6 +1126,7 @@ def __exit__(self, *args): @ExternalProvider.register_provider_type('renaming') class RenamingProvider(Provider): + def __init__(self, transforms, mappings, underlying_provider, defaults=None): if isinstance(underlying_provider, dict): underlying_provider = ExternalProvider.provider_from_spec( diff --git a/sdks/python/apache_beam/yaml/yaml_provider_unit_test.py b/sdks/python/apache_beam/yaml/yaml_provider_unit_test.py index 175f9388a0c6..79a90a2b6a45 100644 --- a/sdks/python/apache_beam/yaml/yaml_provider_unit_test.py +++ b/sdks/python/apache_beam/yaml/yaml_provider_unit_test.py @@ -32,6 +32,7 @@ class WindowIntoTest(unittest.TestCase): + def __init__(self, methodName="runWindowIntoTest"): unittest.TestCase.__init__(self, methodName) self.parse_duration = YamlProviders.WindowInto._parse_duration @@ -101,8 +102,7 @@ def test_include_file(self): self.INLINE_PROVIDER, { 'include': self.to_include - }, - ]) + }, ]) ] self.assertEqual([ @@ -118,8 +118,7 @@ def test_include_url(self): self.INLINE_PROVIDER, { 'include': 'file:///' + self.to_include - }, - ]) + }, ]) ] self.assertEqual([ @@ -135,8 +134,7 @@ def test_nested_include(self): self.INLINE_PROVIDER, { 'include': self.to_include_nested - }, - ]) + }, ]) ] self.assertEqual([ @@ -148,6 +146,7 @@ def test_nested_include(self): class YamlDefinedProider(unittest.TestCase): + def test_yaml_define_provider(self): providers = ''' - type: yaml diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py index 12161d3d580d..fd0bade64719 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform.py +++ b/sdks/python/apache_beam/yaml/yaml_transform.py @@ -91,6 +91,7 @@ def validate_against_schema(pipeline, strictness): def memoize_method(func): + def wrapper(self, *args): if not hasattr(self, '_cache'): self._cache = {} @@ -130,6 +131,7 @@ def empty_if_explicitly_empty(io): class LightweightScope(object): + def __init__(self, transforms): self._transforms = transforms self._transforms_by_uuid = {t['__uuid__']: t for t in self._transforms} @@ -166,6 +168,7 @@ def get_transform_id(self, transform_name): class Scope(LightweightScope): """To look up PCollections (typically outputs of prior transforms) by name.""" + def __init__( self, root, @@ -472,6 +475,7 @@ def expand_composite_transform(spec, scope): scope.input_providers) class CompositePTransform(beam.PTransform): + @staticmethod def expand(inputs): inner_scope.compute_all() @@ -500,6 +504,7 @@ def expand_chain_transform(spec, scope): def chain_as_composite(spec): + def is_not_output_of_last_transform(new_transforms, value): return ( ('name' in new_transforms[-1] and @@ -844,8 +849,10 @@ def lift_config(spec): if 'config' not in spec: common_params = 'name', 'type', 'input', 'output', 'transforms' return { - 'config': {k: v - for (k, v) in spec.items() if k not in common_params}, + 'config': { + k: v + for (k, v) in spec.items() if k not in common_params + }, **{ k: v for (k, v) in spec.items() # @@ -943,6 +950,7 @@ def validate_transform_references(spec): class _BeamFileIOLoader(jinja2.BaseLoader): + def get_source(self, environment, path): with FileSystems.open(path) as fin: source = fin.read().decode() @@ -959,6 +967,7 @@ def expand_jinja( class YamlTransform(beam.PTransform): + def __init__(self, spec, providers={}): # pylint: disable=dangerous-default-value if isinstance(spec, str): spec = yaml.load(spec, Loader=SafeLineLoader) diff --git a/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py b/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py index 2a5a96aa42df..8cfacaa2442d 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py @@ -30,6 +30,7 @@ class ScopeTest(unittest.TestCase): + def get_scope_by_spec(self, p, spec): spec = yaml.load(spec, Loader=SafeLineLoader) @@ -113,6 +114,7 @@ def test_create_ptransform(self): class TestProvider(yaml_provider.InlineProvider): + def __init__(self, transform, name): super().__init__({ name: lambda: beam.Map(lambda x: (x or ()) + (name, )), # or None @@ -138,6 +140,7 @@ def _affinity(self, other): class ProviderAffinityTest(unittest.TestCase): + @staticmethod def create_scope(s, providers): providers_dict = collections.defaultdict(list) @@ -253,6 +256,7 @@ def test_best_provider_based_on_distant_follower(self): class LightweightScopeTest(unittest.TestCase): + @staticmethod def get_spec(): pipeline_yaml = ''' diff --git a/sdks/python/apache_beam/yaml/yaml_transform_test.py b/sdks/python/apache_beam/yaml/yaml_transform_test.py index b9caca4ca9f4..e8a66691188d 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_test.py @@ -54,16 +54,19 @@ def expand(self, p): class SumGlobally(beam.PTransform): + def expand(self, pcoll): return pcoll | beam.CombineGlobally(sum).without_defaults() class SizeLimiter(beam.PTransform): + def __init__(self, limit, error_handling): self._limit = limit self._error_handling = error_handling def expand(self, pcoll): + def raise_on_big(row): if len(row.element) > self._limit: raise ValueError(row.element) @@ -84,6 +87,7 @@ def raise_on_big(row): class YamlTransformE2ETest(unittest.TestCase): + def test_composite(self): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( pickle_library='cloudpickle')) as p: @@ -417,6 +421,7 @@ def test_annotations(self): class ErrorHandlingTest(unittest.TestCase): + def test_error_handling_outputs(self): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( pickle_library='cloudpickle')) as p: @@ -582,6 +587,7 @@ def test_mapping_errors(self): class YamlWindowingTest(unittest.TestCase): + def test_explicit_window_into(self): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( pickle_library='cloudpickle')) as p: @@ -719,10 +725,11 @@ class AnnotatingProvider(yaml_provider.InlineProvider): provider (as identified by name) was used, along with any prior history of the given element. """ + def __init__(self, name, transform_names): super().__init__({ - transform_name: - lambda: beam.Map(lambda x: (x if type(x) == tuple else ()) + (name, )) + transform_name: lambda: beam.Map( + lambda x: (x if type(x) == tuple else ()) + (name, )) for transform_name in transform_names.strip().split() }) self._name = name @@ -774,8 +781,7 @@ def test_prefers_same_provider(self): 'provider1', # All of the providers vend A, but since the input was produced # by provider1, we prefer to use that again. - 'provider1', - # Similarly for C. + 'provider1', # Similarly for C. 'provider1')]), label='StartWith1') @@ -795,10 +801,8 @@ def test_prefers_same_provider(self): result2, equal_to([( # provider2 was necessarily chosen for P2 - 'provider2', - # Unlike above, we choose provider2 to implement A. - 'provider2', - # Likewise for C. + 'provider2', # Unlike above, we choose provider2 to implement A. + 'provider2', # Likewise for C. 'provider2')]), label='StartWith2') @@ -848,6 +852,7 @@ def test_prefers_same_provider_class(self): @beam.transforms.ptransform.annotate_yaml class LinearTransform(beam.PTransform): """A transform used for testing annotate_yaml.""" + def __init__(self, a, b): self._a = a self._b = b diff --git a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py index 5bc9de24bb38..c2c49c3a140d 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py @@ -48,6 +48,7 @@ def new_pipeline(): class MainTest(unittest.TestCase): + def assertYaml(self, expected, result): result = SafeLineLoader.strip_metadata(result) expected = yaml.load(expected, Loader=SafeLineLoader) @@ -925,6 +926,7 @@ def test_only_element(self): class YamlTransformTest(unittest.TestCase): + def test_init_with_string(self): provider1 = InlineProvider({"MyTransform1": lambda: beam.Map(lambda x: x)}) provider2 = InlineProvider({"MyTransform2": lambda: beam.Map(lambda x: x)}) diff --git a/sdks/python/apache_beam/yaml/yaml_udf_test.py b/sdks/python/apache_beam/yaml/yaml_udf_test.py index 1a50568c3d20..801f88901539 100644 --- a/sdks/python/apache_beam/yaml/yaml_udf_test.py +++ b/sdks/python/apache_beam/yaml/yaml_udf_test.py @@ -44,6 +44,7 @@ def as_rows(): class YamlUDFMappingTest(unittest.TestCase): + def __init__(self, method_name='runYamlMappingTest'): super().__init__(method_name) self.data = [ @@ -136,8 +137,8 @@ def test_map_to_fields_filter_inline_py(self): @staticmethod @unittest.skipIf( - TestPipeline().get_pipeline_options().view_as(StandardOptions).runner is - None, + TestPipeline().get_pipeline_options().view_as(StandardOptions).runner + is None, 'Do not run this test on precommit suites.') def test_map_to_fields_sql_reserved_keyword(): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( @@ -168,8 +169,8 @@ def test_map_to_fields_sql_reserved_keyword(): @staticmethod @unittest.skipIf( - TestPipeline().get_pipeline_options().view_as(StandardOptions).runner is - None, + TestPipeline().get_pipeline_options().view_as(StandardOptions).runner + is None, 'Do not run this test on precommit suites.') def test_map_to_fields_sql_reserved_keyword_append(): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( diff --git a/sdks/python/apache_beam/yaml/yaml_utils.py b/sdks/python/apache_beam/yaml/yaml_utils.py index 63beb90f0711..0193d2f0d9dd 100644 --- a/sdks/python/apache_beam/yaml/yaml_utils.py +++ b/sdks/python/apache_beam/yaml/yaml_utils.py @@ -24,12 +24,14 @@ class SafeLineLoader(SafeLoader): """A yaml loader that attaches line information to mappings and strings.""" + class TaggedString(str): """A string class to which we can attach metadata. This is primarily used to trace a string's origin back to its place in a yaml file. """ + def __reduce__(self): # Pickle as an ordinary string. return str, (str(self), ) @@ -55,8 +57,8 @@ def create_uuid(cls): def strip_metadata(cls, spec, tagged_str=True): if isinstance(spec, Mapping): return { - cls.strip_metadata(key, tagged_str): - cls.strip_metadata(value, tagged_str) + cls.strip_metadata(key, tagged_str): cls.strip_metadata( + value, tagged_str) for (key, value) in spec.items() if key not in ('__line__', '__uuid__') } diff --git a/sdks/python/apache_beam/yaml/yaml_utils_test.py b/sdks/python/apache_beam/yaml/yaml_utils_test.py index 4fd2c793e57e..f61a30747a97 100644 --- a/sdks/python/apache_beam/yaml/yaml_utils_test.py +++ b/sdks/python/apache_beam/yaml/yaml_utils_test.py @@ -24,6 +24,7 @@ class SafeLineLoaderTest(unittest.TestCase): + def test_get_line(self): pipeline_yaml = ''' type: composite diff --git a/sdks/python/tox.ini b/sdks/python/tox.ini index ed1f723d6d4d..db2f504e638c 100644 --- a/sdks/python/tox.ini +++ b/sdks/python/tox.ini @@ -265,7 +265,7 @@ commands = [testenv:py3-yapf] # keep the version of yapf in sync with the 'rev' in .pre-commit-config.yaml and pyproject.toml deps = - yapf==0.29.0 + yapf==0.43.0 commands = yapf --version time yapf --in-place --parallel --recursive apache_beam @@ -273,7 +273,7 @@ commands = [testenv:py3-yapf-check] # keep the version of yapf in sync with the 'rev' in .pre-commit-config.yaml and pyproject.toml deps = - yapf==0.29.0 + yapf==0.43.0 commands = yapf --version time yapf --diff --parallel --recursive apache_beam