Skip to content

Commit

Permalink
Add global cursor with fallback
Browse files Browse the repository at this point in the history
  • Loading branch information
tolik0 committed Dec 18, 2024
1 parent a01c0b5 commit 357a925
Show file tree
Hide file tree
Showing 4 changed files with 1,086 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
import threading
import logging
from collections import OrderedDict
from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional

from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
from airbyte_cdk.sources.declarative.incremental.global_substream_cursor import iterate_with_last_flag_and_state, Timer
from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
from airbyte_cdk.sources.message import MessageRepository
Expand Down Expand Up @@ -77,6 +79,15 @@ def __init__(
# The dict is ordered to ensure that once the maximum number of partitions is reached,
# the oldest partitions can be efficiently removed, maintaining the most recent partitions.
self._cursor_per_partition: OrderedDict[str, Cursor] = OrderedDict()
self._state = {"states": []}
self._semaphore_per_partition = OrderedDict()
self._finished_partitions = set()
self._lock = threading.Lock()
self._timer = Timer()
self._global_cursor = None
self._new_global_cursor = None
self._lookback_window = 0
self._parent_state = None
self._over_limit = 0
self._partition_serializer = PerPartitionKeySerializer()

Expand All @@ -91,7 +102,7 @@ def state(self) -> MutableMapping[str, Any]:
states = []
for partition_tuple, cursor in self._cursor_per_partition.items():
cursor_state = cursor._connector_state_converter.convert_to_state_message(
cursor._cursor_field, cursor.state
self.cursor_field, cursor.state
)
if cursor_state:
states.append(
Expand All @@ -101,16 +112,40 @@ def state(self) -> MutableMapping[str, Any]:
}
)
state: dict[str, Any] = {"states": states}

state["state"] = self._global_cursor
if self._lookback_window is not None:
state["lookback_window"] = self._lookback_window
if self._parent_state is not None:
state["parent_state"] = self._parent_state
print(state)
return state

def close_partition(self, partition: Partition) -> None:
self._cursor_per_partition[self._to_partition_key(partition._stream_slice.partition)].close_partition_without_emit(partition=partition)
print(f"Closing partition {self._to_partition_key(partition._stream_slice.partition)}")
self._cursor_per_partition[self._to_partition_key(partition._stream_slice.partition)].close_partition(partition=partition)
with (self._lock):
self._semaphore_per_partition[self._to_partition_key(partition._stream_slice.partition)].acquire()
cursor = self._cursor_per_partition[self._to_partition_key(partition._stream_slice.partition)]
cursor_state = cursor._connector_state_converter.convert_to_state_message(
cursor._cursor_field, cursor.state
)
print(f"State {cursor_state} {cursor.state}")
if self._to_partition_key(partition._stream_slice.partition) in self._finished_partitions \
and self._semaphore_per_partition[self._to_partition_key(partition._stream_slice.partition)]._value == 0:
if self._new_global_cursor is None \
or self._new_global_cursor[self.cursor_field.cursor_field_key] < cursor_state[self.cursor_field.cursor_field_key]:
self._new_global_cursor = copy.deepcopy(cursor_state)

def ensure_at_least_one_state_emitted(self) -> None:
"""
The platform expect to have at least one state message on successful syncs. Hence, whatever happens, we expect this method to be
called.
"""
if not any(semaphore_item[1]._value for semaphore_item in self._semaphore_per_partition.items()):
self._global_cursor = self._new_global_cursor
self._lookback_window = self._timer.finish()
self._parent_state = self._partition_router.get_stream_state()
self._emit_state_message()

def _emit_state_message(self) -> None:
Expand All @@ -127,6 +162,7 @@ def _emit_state_message(self) -> None:

def stream_slices(self) -> Iterable[StreamSlice]:
slices = self._partition_router.stream_slices()
self._timer.start()
for partition in slices:
yield from self.generate_slices_from_partition(partition)

Expand All @@ -143,8 +179,15 @@ def generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[Str
)
cursor = self._create_cursor(partition_state)
self._cursor_per_partition[self._to_partition_key(partition.partition)] = cursor

for cursor_slice in cursor.stream_slices():
self._semaphore_per_partition[self._to_partition_key(partition.partition)] = threading.Semaphore(0)

for cursor_slice, is_last_slice, _ in iterate_with_last_flag_and_state(
cursor.stream_slices(),
lambda: None,
):
self._semaphore_per_partition[self._to_partition_key(partition.partition)].release()
if is_last_slice:
self._finished_partitions.add(self._to_partition_key(partition.partition))
yield StreamSlice(
partition=partition, cursor_slice=cursor_slice, extra_fields=partition.extra_fields
)
Expand Down Expand Up @@ -208,6 +251,7 @@ def _set_initial_state(self, stream_state: StreamState) -> None:
self._cursor_per_partition[self._to_partition_key(state["partition"])] = (
self._create_cursor(state["cursor"])
)
self._semaphore_per_partition[self._to_partition_key(state["partition"])] = threading.Semaphore(0)

# set default state for missing partitions if it is per partition with fallback to global
if "state" in stream_state:
Expand All @@ -217,6 +261,7 @@ def _set_initial_state(self, stream_state: StreamState) -> None:
self._partition_router.set_initial_state(stream_state)

def observe(self, record: Record) -> None:
print(self._to_partition_key(record.associated_slice.partition), record)
self._cursor_per_partition[self._to_partition_key(record.associated_slice.partition)].observe(record)

def _to_partition_key(self, partition: Mapping[str, Any]) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@
InMemoryMessageRepository,
LogAppenderMessageRepositoryDecorator,
MessageRepository,
NoopMessageRepository,
)
from airbyte_cdk.sources.streams.concurrent.cursor import ConcurrentCursor, CursorField
from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import (
Expand Down Expand Up @@ -773,6 +774,7 @@ def create_concurrent_cursor_from_datetime_based_cursor(
stream_namespace: Optional[str],
config: Config,
stream_state: MutableMapping[str, Any],
message_repository: Optional[MessageRepository] = None,
**kwargs: Any,
) -> ConcurrentCursor:
component_type = component_definition.get("type")
Expand Down Expand Up @@ -908,7 +910,7 @@ def create_concurrent_cursor_from_datetime_based_cursor(
stream_name=stream_name,
stream_namespace=stream_namespace,
stream_state=stream_state,
message_repository=self._message_repository,
message_repository=message_repository or self._message_repository,
connector_state_manager=state_manager,
connector_state_converter=connector_state_converter,
cursor_field=cursor_field,
Expand Down Expand Up @@ -961,6 +963,7 @@ def create_concurrent_cursor_from_perpartition_cursor(
stream_name=stream_name,
stream_namespace=stream_namespace,
config=config,
message_repository=NoopMessageRepository()
)
)

Expand Down
10 changes: 1 addition & 9 deletions airbyte_cdk/sources/streams/concurrent/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def _get_concurrent_state(
)

def observe(self, record: Record) -> None:
print(f"Observing record: {record}")
most_recent_cursor_value = self._most_recent_cursor_value_per_partition.get(
record.associated_slice
)
Expand All @@ -240,15 +241,6 @@ def observe(self, record: Record) -> None:
def _extract_cursor_value(self, record: Record) -> Any:
return self._connector_state_converter.parse_value(self._cursor_field.extract_value(record))

def close_partition_without_emit(self, partition: Partition) -> None:
slice_count_before = len(self.state.get("slices", []))
self._add_slice_to_state(partition)
if slice_count_before < len(
self.state["slices"]
): # only emit if at least one slice has been processed
self._merge_partitions()
self._has_closed_at_least_one_slice = True

def close_partition(self, partition: Partition) -> None:
slice_count_before = len(self.state.get("slices", []))
self._add_slice_to_state(partition)
Expand Down
Loading

0 comments on commit 357a925

Please sign in to comment.