Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

serializer as composition #298

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions cashews/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from .formatter import default_formatter
from .helpers import add_prefix, all_keys_lower, memory_limit
from .key import get_cache_key_template, noself
from .key_context import context as key_context
from .key_context import register as register_key_context
from .validation import invalidate_further
from .wrapper import Cache, TransactionMode, register_backend

Expand All @@ -28,6 +26,7 @@
locked = cache.locked

invalidate = cache.invalidate
key_context = cache.template_context

mem = Cache(name="mem")
mem.setup(
Expand Down Expand Up @@ -75,5 +74,4 @@
"TransactionMode",
"register_backend",
"key_context",
"register_key_context",
]
35 changes: 22 additions & 13 deletions cashews/backends/diskcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,24 @@
from diskcache import Cache, FanoutCache

from cashews._typing import Key, Value
from cashews.serialize import SerializerMixin
from cashews.serialize import DEFAULT_SERIALIZER, Serializer
from cashews.utils import Bitarray

from .interface import NOT_EXIST, UNLIMITED, Backend


class _DiskCache(Backend):
class DiskCache(Backend):
def __init__(self, *args, directory=None, shards=8, **kwargs: Any) -> None:
serializer = kwargs.pop("serializer", DEFAULT_SERIALIZER)
self.__is_init = False
self._set_locks: dict[str, asyncio.Lock] = {}
self._sharded = shards > 1
if not self._sharded:
self._cache = Cache(directory=directory, **kwargs)
else:
self._cache = FanoutCache(directory=directory, shards=shards, **kwargs)
super().__init__(**kwargs)
super().__init__(serializer=serializer, **kwargs)
self._serializer: Serializer

async def init(self):
self.__is_init = True
Expand All @@ -46,6 +48,7 @@ async def set(
expire: float | None = None,
exist: bool | None = None,
) -> bool:
value = await self._serializer.encode(self, key=key, value=value, expire=expire)
future = self._run_in_executor(self._set, key, value, expire, exist)
if exist is not None:
# we should have async lock until value real set
Expand All @@ -69,25 +72,34 @@ async def set_raw(self, key: Key, value: Any, **kwargs: Any):
return self._cache.set(key, value, **kwargs)

async def get(self, key: Key, default: Value | None = None) -> Value:
return await self._run_in_executor(self._cache.get, key, default)
value = await self._run_in_executor(self._cache.get, key, default)
return await self._serializer.decode(self, key=key, value=value, default=default)

async def get_raw(self, key: Key) -> Value:
return self._cache.get(key)

async def get_many(self, *keys: Key, default: Value | None = None) -> tuple[Value]:
return await self._run_in_executor(self._get_many, keys, default)
async def get_many(self, *keys: Key, default: Value | None = None) -> tuple[Value | None, ...]:
if not keys:
return ()
values = await self._run_in_executor(self._get_many, keys, default)
values = await asyncio.gather(
*[self._serializer.decode(self, key=key, value=value, default=default) for key, value in zip(keys, values)]
)
return tuple(None if isinstance(value, Bitarray) else value for value in values)

def _get_many(self, keys: list[Key], default: Value | None = None):
values = []
for key in keys:
val = self._cache.get(key, default=default)
if isinstance(val, Bitarray):
val = None
values.append(val)
return values

async def set_many(self, pairs: Mapping[Key, Value], expire: float | None = None):
return await self._run_in_executor(self._set_many, pairs, expire)
_pairs = {}
for key, value in pairs.items():
value = await self._serializer.encode(self, key=key, value=value, expire=expire)
_pairs[key] = value
return await self._run_in_executor(self._set_many, _pairs, expire)

def _set_many(self, pairs: Mapping[Key, Value], expire: float | None = None):
for key, value in pairs.items():
Expand Down Expand Up @@ -215,6 +227,7 @@ async def is_locked(
return await self.exists(key)

async def unlock(self, key: Key, value: Value) -> bool:
value = await self._serializer.encode(self, key=key, value=value, expire=None)
return await self._run_in_executor(self._unlock, key, value)

def _unlock(self, key: Key, value: Value) -> bool:
Expand Down Expand Up @@ -269,7 +282,3 @@ async def set_pop(self, key: Key, count: int = 100) -> Iterable[str]:

async def get_keys_count(self) -> int:
return await self._run_in_executor(lambda: len(self._cache))


class DiskCache(SerializerMixin, _DiskCache):
pass
62 changes: 5 additions & 57 deletions cashews/backends/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
import uuid
from abc import ABCMeta, abstractmethod
from contextlib import asynccontextmanager
from contextvars import ContextVar
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Iterable, Mapping, overload

from cashews.commands import ALL, Command
from cashews.exceptions import CacheBackendInteractionError, LockedError
from cashews.serialize import Serializer

if TYPE_CHECKING: # pragma: no cover
from cashews._typing import Default, Key, OnRemoveCallback, Value
Expand Down Expand Up @@ -172,62 +171,11 @@ async def lock(self, key: Key, expire: float, wait: bool = True) -> AsyncGenerat
return


class ControlMixin:
enable_by_default = True

def __init__(self, *args, **kwargs) -> None:
self.__disable: ContextVar[set[Command]] = ContextVar(str(id(self)), default=set())
self._control_set = False
super().__init__(*args, **kwargs)

@property
def _disable(self) -> set[Command]:
return self.__disable.get()

def _set_disable(self, value: set[Command]) -> None:
self.__disable.set(value)
self._control_set = True

def is_disable(self, *cmds: Command) -> bool:
if not self._control_set:
return not self.enable_by_default
_disable = self._disable
if not cmds and _disable:
return True
for cmd in cmds:
if cmd in _disable:
return True
return False

def is_enable(self, *cmds: Command) -> bool:
return not self.is_disable(*cmds)

@property
def is_full_disable(self) -> bool:
if not self._control_set:
return not self.enable_by_default
return self._disable == ALL

def disable(self, *cmds: Command) -> None:
if not cmds:
_disable = ALL.copy()
else:
_disable = self._disable.copy()
_disable.update(cmds)
self._set_disable(_disable)

def enable(self, *cmds: Command) -> None:
if not cmds:
_disable = set()
else:
_disable = self._disable.copy()
_disable -= set(cmds)
self._set_disable(_disable)


class Backend(ControlMixin, _BackendInterface, metaclass=ABCMeta):
def __init__(self, *args, **kwargs) -> None:
class Backend(_BackendInterface, metaclass=ABCMeta):
def __init__(self, *args, serializer: Serializer | None = None, **kwargs) -> None:
super().__init__()
self._id = uuid.uuid4().hex
self._serializer = serializer
self._on_remove_callbacks: list[OnRemoveCallback] = []

def on_remove_callback(self, callback: OnRemoveCallback) -> None:
Expand Down
22 changes: 13 additions & 9 deletions cashews/backends/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from copy import copy
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Mapping, overload

from cashews.serialize import SerializerMixin
from cashews.utils import Bitarray, get_obj_size

from .interface import NOT_EXIST, UNLIMITED, Backend
Expand All @@ -22,7 +21,7 @@
_missed = object()


class _Memory(Backend):
class Memory(Backend):
"""
Inmemory backend lru with ttl
"""
Expand Down Expand Up @@ -74,17 +73,22 @@ async def set(
) -> bool:
if exist is not None and (key in self.store) is not exist:
return False
if self._serializer:
value = await self._serializer.encode(self, key=key, value=value, expire=expire)
self._set(key, value, expire)
return True

async def set_raw(self, key: Key, value: Value, **kwargs: Any) -> None:
self.store[key] = value
self.store[key] = (None, value)

async def get(self, key: Key, default: Value | None = None) -> Value:
return await self._get(key, default=default)

async def get_raw(self, key: Key) -> Value:
return self.store.get(key)
val = self.store.get(key)
if val:
return val[1]
return None

async def get_many(self, *keys: Key, default: Value | None = None) -> tuple[Value | None, ...]:
values = []
Expand All @@ -97,6 +101,8 @@ async def get_many(self, *keys: Key, default: Value | None = None) -> tuple[Valu

async def set_many(self, pairs: Mapping[Key, Value], expire: float | None = None):
for key, value in pairs.items():
if self._serializer:
value = await self._serializer.encode(self, key=key, value=value, expire=expire)
self._set(key, value, expire)

async def scan(self, pattern: str, batch_size: int = 100) -> AsyncIterator[Key]: # type: ignore
Expand Down Expand Up @@ -200,7 +206,9 @@ async def _get(self, key: Key, default: Default | None = None) -> Value | None:
if expire_at and expire_at < time.time():
await self._delete(key)
return default
return value
if not self._serializer:
return value
return await self._serializer.decode(self, key=key, value=value, default=default)

async def _key_exist(self, key: Key) -> bool:
return (await self._get(key, default=_missed)) is not _missed
Expand Down Expand Up @@ -279,7 +287,3 @@ async def close(self):
del self.__remove_expired_stop
self.__remove_expired_stop = None
self.__is_init = False


class Memory(SerializerMixin, _Memory):
pass
7 changes: 2 additions & 5 deletions cashews/backends/redis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from cashews.picklers import DEFAULT_PICKLE
from cashews.serialize import SerializerMixin

from .backend import _Redis

__all__ = ["Redis"]


class Redis(SerializerMixin, _Redis):
pickle_type = DEFAULT_PICKLE
class Redis(_Redis):
pass
17 changes: 11 additions & 6 deletions cashews/backends/redis/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from cashews._typing import Key, Value
from cashews.backends.interface import Backend
from cashews.serialize import DEFAULT_SERIALIZER, Serializer

from .client import Redis, SafePipeline, SafeRedis

Expand Down Expand Up @@ -76,7 +77,8 @@ def __init__(
self._kwargs = kwargs
self._address = address
self.__is_init = False
super().__init__()
super().__init__(serializer=kwargs.pop("serializer", None))
self._serializer: Serializer = self._serializer or DEFAULT_SERIALIZER

@property
def is_init(self) -> bool:
Expand Down Expand Up @@ -105,6 +107,7 @@ async def set(
expire: float | None = None,
exist=None,
) -> bool:
value = await self._serializer.encode(self, key=key, value=value, expire=expire)
nx = xx = False
if exist is True:
xx = True
Expand All @@ -118,6 +121,7 @@ async def set_many(self, pairs: Mapping[Key, Value], expire: float | None = None
px = int(expire * 1000) if expire else None
async with self._pipeline as pipe:
for key, value in pairs.items():
value = await self._serializer.encode(self, key=key, value=value, expire=expire)
await pipe.set(key, value, px=px)
await pipe.execute()

Expand Down Expand Up @@ -211,23 +215,24 @@ async def get_size(self, key: Key) -> int:

async def get(self, key: Key, default: Value | None = None) -> Value:
value = await self._client.get(key)
return self._transform_value(value, default)
return await self._transform_value(key, value, default)

async def get_many(self, *keys: Key, default: Value | None = None) -> tuple[Value | None, ...]:
if not keys:
return ()
values = await self._client.mget(*keys)
if values is None:
return tuple([default] * len(keys))
return tuple(self._transform_value(value, default) for value in values)
return tuple(
await asyncio.gather(*[self._transform_value(key, value, default) for key, value in zip(keys, values)])
)

@staticmethod
def _transform_value(value: bytes | None, default: Value | None):
async def _transform_value(self, key: Key, value: bytes | None, default: Value | None):
if value is None:
return default
if value.isdigit():
return int(value)
return value
return await self._serializer.decode(self, key=key, value=value, default=default)

async def incr(self, key: Key, value: int = 1, expire: float | None = None) -> int:
if not expire:
Expand Down
7 changes: 4 additions & 3 deletions cashews/backends/redis/client_side.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,12 @@ def __init__(
self._expire_for_recently_update = 5
self._listen_started = asyncio.Event()
self.__listen_stop = asyncio.Event()
super().__init__(*args, suppress=suppress, **kwargs)
kwargs["suppress"] = suppress
super().__init__(*args, **kwargs)

async def init(self):
self._listen_started = asyncio.Event()
self.__listen_stop = asyncio.Event()
self._listen_started.clear()
self.__listen_stop.clear()
await self._local_cache.init()
await self._recently_update.init()
await super().init()
Expand Down
2 changes: 2 additions & 0 deletions cashews/backends/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ class TransactionBackend(Backend):
"_local_cache",
"_to_delete",
"__disable",
"_id",
]

def __init__(self, backend: Backend):
self._backend = backend
self._local_cache = Memory()
self._to_delete: set[Key] = set()
super().__init__()
self._id = backend._id

def _key_is_delete(self, key: Key) -> bool:
if key in self._to_delete:
Expand Down
Loading
Loading