Skip to content

Commit

Permalink
Add Event bus
Browse files Browse the repository at this point in the history
  • Loading branch information
mariotaddeucci committed Dec 15, 2024
1 parent 2fbb2b1 commit e23b266
Show file tree
Hide file tree
Showing 4 changed files with 388 additions and 0 deletions.
Empty file added src/gyjd/database/__init__.py
Empty file.
52 changes: 52 additions & 0 deletions src/gyjd/database/connection_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import os
from pathlib import Path
from typing import Literal

from gyjd.database.sqlite_connection import SQLiteConnection


class ConnectionFactory:
@classmethod
def create_connection(cls, db_name: Literal["event_bus"]) -> SQLiteConnection:
location = Path.home() / "gyjd" / "database" / db_name / f"{db_name}.db"
os.makedirs(location.parent, exist_ok=True)
conn = SQLiteConnection(str(location.absolute()))
getattr(cls, f"_create_{db_name}_schema")(conn)
return conn

@classmethod
def _create_event_bus_schema(cls, conn: SQLiteConnection) -> None:
sql = """
CREATE TABLE IF NOT EXISTS events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
event_type TEXT NOT NULL,
payload TEXT NOT NULL,
processed INTEGER NOT NULL DEFAULT 0,
created_at DATETIME NOT NULL,
processed_at DATETIME
);
CREATE TABLE IF NOT EXISTS subscribers (
task_name TEXT PRIMARY KEY,
event_types TEXT NOT NULL,
function_path TEXT NOT NULL,
mode TEXT NOT NULL DEFAULT 'any',
created_at DATETIME NOT NULL,
max_attempts INTEGER NOT NULL DEFAULT 1,
retry_delay INTEGER NOT NULL DEFAULT 30,
concurrency_limit INTEGER NOT NULL DEFAULT 8
);
CREATE TABLE IF NOT EXISTS tasks (
id INTEGER PRIMARY KEY AUTOINCREMENT,
subscriber_task_name TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
attempt_count INTEGER NOT NULL DEFAULT 0,
handled_events TEXT NOT NULL,
parameters TEXT NOT NULL,
last_attempt_at DATETIME,
created_at DATETIME NOT NULL,
completed_at DATETIME,
scheduled_at DATETIME NOT NULL,
FOREIGN KEY(subscriber_task_name) REFERENCES subscribers(task_name)
);
"""
conn.conn.executescript(sql)
68 changes: 68 additions & 0 deletions src/gyjd/database/sqlite_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import asyncio
import sqlite3
from contextlib import asynccontextmanager, contextmanager
from datetime import datetime
from typing import Literal


class SQLiteConnection:
_MAINTENANCE_TABLE = "maintenance_metadata"

def __init__(self, conn_str: str):
self.conn = sqlite3.connect(conn_str, check_same_thread=False)
self.conn.execute("PRAGMA foreign_keys = ON")
self.conn.execute("PRAGMA journal_mode = WAL;")
self._ensure_maintenance_metadata()
self.auto_maintenance()
self._lock = asyncio.Lock()

def _ensure_maintenance_metadata(self):
sql = f"CREATE TABLE IF NOT EXISTS {self._MAINTENANCE_TABLE} (event_name VARCHAR(36) PRIMARY KEY, last_event_datetime DATETIME);"
self.conn.execute(sql)

def _register_event(self, event_name):
sql = f"INSERT OR REPLACE INTO {self._MAINTENANCE_TABLE} (event_name, last_event_datetime) VALUES (?, ?);"
with self.cursor() as c:
c.execute(sql, (event_name, datetime.utcnow().isoformat()))

def _get_last_event_datetime(self, event_name) -> datetime | None:
sql = f"SELECT last_event_datetime FROM {self._MAINTENANCE_TABLE} WHERE event_name = ?;"
with self.cursor() as c:
c.execute(sql, (event_name,))
result = c.fetchone()
return datetime.fromisoformat(result[0]) if result else None

def vacuum(self):
self.conn.execute("VACUUM;")
self._register_event("vacuum")

def checkpoint(self, mode: Literal["passive", "full", "restart", "truncate"] = "full"):
self.conn.execute(f"PRAGMA wal_checkpoint({mode});")
self._register_event("checkpoint")

def auto_maintenance(self):
last_vacuum = self._get_last_event_datetime("vacuum")
if last_vacuum is None or (datetime.utcnow() - last_vacuum).days > 7:
self.vacuum()

last_checkpoint = self._get_last_event_datetime("checkpoint")
if last_checkpoint is None or (datetime.utcnow() - last_checkpoint).days > 1:
self.checkpoint()

@contextmanager
def cursor(self):
c = self.conn.cursor()
try:
yield c
except Exception:
self.conn.rollback()
raise
finally:
c.close()
self.conn.commit()

@asynccontextmanager
async def async_cursor(self):
async with self._lock:
with self.cursor() as c:
yield c
268 changes: 268 additions & 0 deletions src/gyjd/event_bus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
import asyncio
import importlib
import json
import logging
from datetime import datetime, timedelta
from typing import List, Literal, TypedDict

from gyjd.database.connection_factory import ConnectionFactory

DEFAULT_RETRY_DELAY = 30

logger = logging.getLogger("gyjd")


class MappedEventDict(TypedDict):
event_id: str
event_name: str
payload: dict
event_date: str


MappedEvent = List[MappedEventDict]


class EventBus:
def __init__(self, polling_interval=10):
self._conn = ConnectionFactory.create_connection("event_bus")
self.polling_interval = polling_interval

def add_event(self, event_type: str, payload: dict):
self._conn.conn.execute(
"INSERT INTO events (event_type, payload, created_at) VALUES (?, ?, CURRENT_TIMESTAMP)",
(event_type, json.dumps(payload)),
)

def subscribe(
self,
event_types: List[str],
function_path: str,
task_name: str,
mode: Literal["any", "batch"] = "any",
max_attempts: int = 1,
retry_delay: int = DEFAULT_RETRY_DELAY,
concurrency_limit: int | None = None,
):
if concurrency_limit is None:
concurrency_limit = 8 if mode == "any" else 1

if mode == "batch" and concurrency_limit > 1:
raise ValueError("concurrency_limit must be 1 for batch mode")

with self._conn.cursor() as c:
sql = """
INSERT OR REPLACE INTO subscribers (task_name, event_types, function_path, mode, max_attempts, retry_delay, concurrency_limit, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
"""
c.execute(
sql,
(
task_name,
json.dumps(event_types),
function_path,
mode,
max_attempts,
retry_delay,
concurrency_limit,
),
)

async def process_events(self) -> bool:
logger.info("Looking for events to process")
async with self._conn.async_cursor() as c:
c.execute("SELECT min(id) min_id, max(id) max_id FROM events WHERE processed = 0")
min_id, max_id = c.fetchone()

if not (min_id and max_id):
logger.info("No events to process")
return False

c.execute(
"""
INSERT INTO tasks (subscriber_task_name, handled_events, parameters, created_at, scheduled_at)
WITH subs as (
SELECT task_name, value as event_type
FROM subscribers, json_each(event_types)
WHERE mode = 'any'
)
SELECT
task_name as subscriber_task_name,
'[' || e.id || ']' as handled_events,
'[' || payload || ']' as parameters,
CURRENT_TIMESTAMP as created_at,
CURRENT_TIMESTAMP as scheduled_at
FROM events e join subs using(event_type)
WHERE e.processed = 0 AND e.id BETWEEN ? AND ?
""",
(min_id, max_id),
)

c.execute(
"""
INSERT INTO tasks (subscriber_task_name, handled_events, parameters, created_at, scheduled_at)
WITH subs as (
SELECT task_name, value as event_type
FROM subscribers, json_each(event_types)
WHERE mode = 'batch'
)
SELECT
task_name as subscriber_task_name,
'[' || GROUP_CONCAT(e.id, ', ') || ']' as handled_events,
'[' || GROUP_CONCAT(e.payload, ', ') || ']' as parameters,
CURRENT_TIMESTAMP as created_at,
CURRENT_TIMESTAMP as scheduled_at
FROM events e join subs using(event_type)
WHERE e.processed = 0 AND e.id BETWEEN ? AND ?
GROUP BY 1
""",
(min_id, max_id),
)

c.execute(
"UPDATE events SET processed = 1, processed_at = CURRENT_TIMESTAMP WHERE processed = 0 AND id BETWEEN ? AND ?",
(min_id, max_id),
)

return True

@classmethod
def _load_function(cls, function_path: str):
parts = function_path.split(".")
mod_name = ".".join(parts[:-1])
func_name = parts[-1]
module = importlib.import_module(mod_name)
func = getattr(module, func_name)
return func

async def _run_task(self, task):
(task_id, subscriber_task_name, attempt_count, parameters) = task

async with self._conn.async_cursor() as c:
c.execute(
"SELECT function_path, max_attempts, retry_delay FROM subscribers WHERE task_name = ?",
(subscriber_task_name,),
)
row = c.fetchone()
if not row:
c.execute("UPDATE tasks SET status = 'failed' WHERE id = ?", (task_id,))
return

function_path, max_attempts, retry_delay = row

async with self._conn.async_cursor() as c:
c.execute(
"UPDATE tasks SET attempt_count = attempt_count + 1, last_attempt_at = CURRENT_TIMESTAMP WHERE id = ?",
(task_id,),
)

try:
func = self._load_function(function_path)
parameters = json.loads(parameters)
except Exception:
async with self._conn.async_cursor() as c:
c.execute("UPDATE tasks SET status = 'failed' WHERE id = ?", (task_id,))
return

try:
if asyncio.iscoroutinefunction(func):
await func(parameters)
else:
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, func, parameters)

async with self._conn.async_cursor() as c:
c.execute(
"UPDATE tasks SET status = 'done', completed_at = CURRENT_TIMESTAMP WHERE id = ?",
(task_id,),
)

except Exception as e:
print(e)
async with self._conn.async_cursor() as c:
c.execute("SELECT attempt_count FROM tasks WHERE id = ?", (task_id,))
attempt_count = c.fetchone()[0]
if attempt_count < max_attempts:
new_time = (datetime.utcnow() + timedelta(seconds=retry_delay)).isoformat()
c.execute("UPDATE tasks SET status = 'pending', scheduled_at = ? WHERE id = ?", (new_time, task_id))
else:
c.execute("UPDATE tasks SET status = 'failed' WHERE id = ?", (task_id,))

async def run_tasks(self) -> bool:
logger.info("Looking for tasks to run")
sql = """
WITH bs AS (
SELECT
t.id,
t.subscriber_task_name,
t.attempt_count,
t.parameters,
t.scheduled_at,
s.concurrency_limit,
row_number() OVER (PARTITION BY t.subscriber_task_name ORDER BY t.scheduled_at) AS rn
FROM
tasks t
JOIN subscribers s ON
t.subscriber_task_name = s.task_name
WHERE
t.scheduled_at <= CURRENT_TIMESTAMP AND
t.status = 'pending'
)
SELECT
id,
subscriber_task_name,
attempt_count,
parameters
FROM bs
WHERE rn <= concurrency_limit
ORDER BY scheduled_at
LIMIT 16
"""

async with self._conn.async_cursor() as c:
c.execute(sql)
tasks = c.fetchall()

if not tasks:
logger.info("No tasks to run")
return False

await asyncio.gather(*(self._run_task(t) for t in tasks))

return True

async def run(self):
while await self.run_tasks() or await self.process_events():
pass

async def run_forever(self):
logger.info("Event bus started")
while True:
await self.run()


event_bus = EventBus()


def subscribe(
event_types: List[str],
task_name: str | None = None,
mode: Literal["any", "batch"] = "any",
max_attempts: int = 1,
retry_delay: int = DEFAULT_RETRY_DELAY,
concurrency_limit: int = 8,
):
def decorator(func):
module_name: str = func.__module__
func_name: str = func.__name__
path = f"{module_name}.{func_name}"
nonlocal task_name
if task_name is None:
task_name = func_name
event_bus.subscribe(event_types, path, task_name, mode, max_attempts, retry_delay, concurrency_limit)
return func

return decorator


def emmit(*, event_type: str, payload: dict):
event_bus.add_event(event_type, payload)

0 comments on commit e23b266

Please sign in to comment.