Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(AsyncRetriever): Allow for streams using AsyncRetriever and DatetimeBasedCursor to perform checkpointing #226

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions airbyte_cdk/sources/declarative/declarative_stream.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import logging
from dataclasses import InitVar, dataclass, field
from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union
Expand All @@ -13,7 +14,7 @@
)
from airbyte_cdk.sources.declarative.interpolation import InterpolatedString
from airbyte_cdk.sources.declarative.migrations.state_migration import StateMigration
from airbyte_cdk.sources.declarative.retrievers import SimpleRetriever
from airbyte_cdk.sources.declarative.retrievers import AsyncRetriever, SimpleRetriever
from airbyte_cdk.sources.declarative.retrievers.retriever import Retriever
from airbyte_cdk.sources.declarative.schema import DefaultSchemaLoader
from airbyte_cdk.sources.declarative.schema.schema_loader import SchemaLoader
Expand Down Expand Up @@ -189,7 +190,10 @@ def state_checkpoint_interval(self) -> Optional[int]:
return None

def get_cursor(self) -> Optional[Cursor]:
    """Return the cursor owned by this stream's retriever, if it has one.

    Returns:
        The retriever's ``cursor`` when the retriever is a ``SimpleRetriever``
        or an ``AsyncRetriever``; ``None`` for any other retriever (or when no
        retriever is set).
    """
    # A single isinstance() call with a tuple covers both supported retriever
    # types. isinstance(None, ...) is False, so a missing retriever falls
    # through to the ``None`` return without a separate truthiness guard.
    if isinstance(self.retriever, (SimpleRetriever, AsyncRetriever)):
        return self.retriever.cursor
    return None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
AsyncJobOrchestrator,
AsyncPartition,
)
from airbyte_cdk.sources.declarative.incremental import DatetimeBasedCursor
from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
from airbyte_cdk.sources.declarative.partition_routers.single_partition_router import (
SinglePartitionRouter,
)
Expand Down Expand Up @@ -35,6 +37,14 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._job_orchestrator_factory = self.job_orchestrator_factory
self._job_orchestrator: Optional[AsyncJobOrchestrator] = None
self._parameters = parameters
if isinstance(self.stream_slicer, DatetimeBasedCursor):
self._cursor: Optional[DeclarativeCursor] = self.stream_slicer
else:
self._cursor = None

@property
def cursor(self) -> Optional[DeclarativeCursor]:
    """Expose the cursor stored on this instance as ``_cursor``.

    ``__post_init__`` sets ``_cursor`` to the stream slicer when that slicer
    is a ``DatetimeBasedCursor`` and to ``None`` otherwise, so callers must
    handle a ``None`` result.
    """
    stored_cursor = self._cursor
    return stored_cursor

def stream_slices(self) -> Iterable[StreamSlice]:
slices = self.stream_slicer.stream_slices()
Expand Down
69 changes: 50 additions & 19 deletions airbyte_cdk/sources/declarative/retrievers/async_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
from airbyte_cdk.models import FailureType
from airbyte_cdk.sources.declarative.async_job.job_orchestrator import AsyncPartition
from airbyte_cdk.sources.declarative.extractors.record_selector import RecordSelector
from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import (
AsyncJobPartitionRouter,
)
from airbyte_cdk.sources.declarative.retrievers.retriever import Retriever
from airbyte_cdk.sources.source import ExperimentalClassWarning
from airbyte_cdk.sources.streams.core import StreamData
from airbyte_cdk.sources.types import Config, StreamSlice, StreamState
from airbyte_cdk.sources.types import Config, Record, StreamSlice, StreamState
from airbyte_cdk.utils.traced_exception import AirbyteTracedException


Expand All @@ -35,27 +36,23 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:

@property
def state(self) -> StreamState:
"""
As a first iteration for sendgrid, there is no state to be managed
"""
return {}

@state.setter
def state(self, value: StreamState) -> None:
"""
As a first iteration for sendgrid, there is no state to be managed
"""
pass

def _get_stream_state(self) -> StreamState:
"""
Gets the current state of the stream.

Returns:
StreamState: Mapping[str, Any]
"""
return self.stream_slicer.cursor.get_stream_state() if self.stream_slicer.cursor else {}

@state.setter
def state(self, value: StreamState) -> None:
"""State setter, accept state serialized by state getter."""
if self.stream_slicer.cursor:
self.stream_slicer.cursor.set_initial_state(value)

return self.state
@property
def cursor(self) -> Optional[DeclarativeCursor]:
    """Return the cursor exposed by this retriever's stream slicer.

    The value is read from ``self.stream_slicer.cursor`` on every access and
    may be ``None`` when the slicer has no cursor.
    """
    slicer = self.stream_slicer
    return slicer.cursor

def _validate_and_get_stream_slice_partition(
self, stream_slice: Optional[StreamSlice] = None
Expand Down Expand Up @@ -88,13 +85,47 @@ def read_records(
records_schema: Mapping[str, Any],
stream_slice: Optional[StreamSlice] = None,
) -> Iterable[StreamData]:
stream_state: StreamState = self._get_stream_state()
_slice = stream_slice or StreamSlice(partition={}, cursor_slice={}) # None-check

stream_state: StreamState = self.state
partition: AsyncPartition = self._validate_and_get_stream_slice_partition(stream_slice)
records: Iterable[Mapping[str, Any]] = self.stream_slicer.fetch_records(partition)
most_recent_record_from_slice = None

yield from self.record_selector.filter_and_transform(
for stream_data in self.record_selector.filter_and_transform(
all_data=records,
stream_state=stream_state,
records_schema=records_schema,
stream_slice=stream_slice,
)
stream_slice=_slice,
):
if self.cursor and stream_data:
self.cursor.observe(_slice, stream_data)

most_recent_record_from_slice = self._get_most_recent_record(
most_recent_record_from_slice, stream_data, _slice
)
yield stream_data

if self.cursor:
# DatetimeBasedCursor doesn't expect a partition field, but for AsyncRetriever streams this will
# be the slice range
slice_no_partition = StreamSlice(cursor_slice=_slice.cursor_slice, partition={})
self.cursor.close_slice(slice_no_partition, most_recent_record_from_slice)

def _get_most_recent_record(
    self,
    current_most_recent: Optional[Record],
    current_record: Optional[Record],
    stream_slice: StreamSlice,
) -> Optional[Record]:
    """Pick the more recent of two records according to the cursor.

    Args:
        current_most_recent: The most recent record seen so far in the slice,
            or ``None`` if no record has been tracked yet.
        current_record: The record that was just read; may be ``None`` or
            empty, in which case it is ignored.
        stream_slice: Accepted for interface compatibility; not used here.

    Returns:
        ``None`` when the retriever has no cursor. Otherwise the record the
        cursor ranks as most recent; ties keep ``current_most_recent``.
    """
    if not self.cursor:
        # Without a cursor there is no ordering to apply and nothing to track.
        return None
    if not current_record:
        # A missing/empty record must not wipe out what was already tracked;
        # previously this case returned None and lost the running maximum.
        return current_most_recent
    if not current_most_recent:
        return current_record
    if self.cursor.is_greater_than_or_equal(current_most_recent, current_record):
        return current_most_recent
    return current_record
Loading