Skip to content

Commit

Permalink
🚨 Fix pylint and mypy warnings.
Browse files Browse the repository at this point in the history
  • Loading branch information
Raphael Krupinski committed Jan 9, 2024
1 parent 7f9903d commit b99ac4e
Show file tree
Hide file tree
Showing 11 changed files with 71 additions and 34 deletions.
12 changes: 5 additions & 7 deletions src/lapidary/runtime/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,15 @@

import abc
import dataclasses as dc
from typing import Mapping
import typing as ty

import httpx
import httpx_auth
from httpx_auth.authentication import _MultiAuth as MultiAuth, _MultiAuth
from httpx_auth.authentication import _MultiAuth as MultiAuth

from .compat import typing as ty
from .model.api_key import CookieApiKey

AuthType = ty.TypeVar("AuthType", bound=httpx.Auth)
AuthType: ty.TypeAlias = ty.Callable[[str, str], httpx.Auth]


class AuthFactory(abc.ABC):
Expand All @@ -26,7 +24,7 @@ def __call__(self, body: object) -> httpx.Auth:

APIKeyAuthLocation: ty.TypeAlias = ty.Literal['cookie', 'header', 'query']

api_key_in: ty.Mapping[APIKeyAuthLocation, ty.Type[AuthType]] = {
api_key_in: ty.Mapping[APIKeyAuthLocation, AuthType] = {
'cookie': CookieApiKey,
'header': httpx_auth.authentication.HeaderApiKey,
'query': httpx_auth.authentication.QueryApiKey,
Expand All @@ -41,10 +39,10 @@ class APIKeyAuth(AuthFactory):

def __call__(self, body: object) -> httpx.Auth:
typ = api_key_in[self.in_]
return typ(self.format.format(body=body), self.name)
return typ(self.format.format(body=body), self.name) # type: ignore[misc]


def get_auth(params: Mapping[str, ty.Any]) -> ty.Optional[httpx.Auth]:
def get_auth(params: ty.Mapping[str, ty.Any]) -> ty.Optional[httpx.Auth]:
auth_params = [value for value in params.values() if isinstance(value, httpx.Auth)]
auth_num = len(auth_params)
if auth_num == 0:
Expand Down
8 changes: 6 additions & 2 deletions src/lapidary/runtime/compat.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
try:
__all__ = ['typing']

import sys

if sys.version_info < (3, 12):
import typing_extensions as typing
except ImportError:
else:
import typing
4 changes: 2 additions & 2 deletions src/lapidary/runtime/model/api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
class CookieApiKey(httpx.Auth, authx.SupportMultiAuth):
"""Describes an API Key requests authentication."""

def __init__(self, api_key: str, cookie_name: str = None):
def __init__(self, api_key: str, cookie_name: ty.Optional[str] = None):
"""
:param api_key: The API key that will be sent.
:param cookie_name: Name of the query parameter. "api_key" by default.
"""
self.api_key = api_key
if not api_key:
raise Exception("API Key is mandatory.")
raise ValueError("API Key is mandatory.")
self.cookie_parameter_name = cookie_name or "api_key"

def auth_flow(
Expand Down
16 changes: 11 additions & 5 deletions src/lapidary/runtime/model/op.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import abc
from collections.abc import Iterator
import dataclasses as dc
import inspect
Expand Down Expand Up @@ -105,7 +106,7 @@ def parse_params(sig: inspect.Signature) -> Iterator[ty.Union[FullParam, Request
for name, param in sig.parameters.items():
anno = param.annotation

if anno == ty.Self or (type(anno) is type and issubclass(anno, httpx.Auth)):
if anno == ty.Self or (isinstance(anno, type) and issubclass(anno, httpx.Auth)):
continue

if param.annotation == inspect.Parameter.empty:
Expand Down Expand Up @@ -133,24 +134,29 @@ def parse_params(sig: inspect.Signature) -> Iterator[ty.Union[FullParam, Request
)


def get_response_map(return_anno: ty.Union[ty.Annotated, inspect.Signature.empty]) -> ResponseMap:
annos = [anno for anno in return_anno.__metadata__ if isinstance(anno, Responses)]
def get_response_map(return_anno: type) -> ResponseMap: # type: ignore[valid-type]
if return_anno is inspect.Signature.empty:
raise TypeError('Operation function must have exactly one Responses annotation')
annos = [anno for anno in return_anno.__metadata__ if isinstance(anno, Responses)] # type: ignore[attr-defined]
if len(annos) != 1:
raise TypeError('Operation function must have exactly one Responses annotation')

return annos.pop().responses


class LapidaryOperation(ty.Callable):
class LapidaryOperation(abc.ABC):
lapidary_operation: Operation
lapidary_operation_model: ty.Optional[OperationModel]

def __call__(self, *args, **kwargs) -> ty.Any:
pass


def get_operation_model(
fn: LapidaryOperation,
) -> OperationModel:
base_model: Operation = fn.lapidary_operation
sig = inspect.signature(ty.cast(ty.Callable, fn))
sig = inspect.signature(fn)
params = list(parse_params(sig))
request_body_ = [param for param in params if isinstance(param, RequestBodyModel)]
if len(request_body_) > 1:
Expand Down
4 changes: 2 additions & 2 deletions src/lapidary/runtime/model/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ class FullParam:

location: ParamLocation

style: ty.Optional[ParamStyle]
explode: ty.Optional[bool]
style: ParamStyle
explode: bool

name: str
type: ty.Type
Expand Down
7 changes: 4 additions & 3 deletions src/lapidary/runtime/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ def _operation(
path: str,
) -> ty.Callable:
def wrapper(fn: ty.Callable):
fn.lapidary_operation = Operation(method, path)
fn.lapidary_operation_model = None
fn_ = ty.cast(LapidaryOperation, fn)
fn_.lapidary_operation = Operation(method, path)
fn_.lapidary_operation_model = None

@ft.wraps(fn)
async def operation(self: 'ClientBase', **kwargs) -> ty.Any:
return await self._request(
return await self._request( # pylint: disable=protected-access
ty.cast(LapidaryOperation, fn),
kwargs,
)
Expand Down
20 changes: 11 additions & 9 deletions src/lapidary/runtime/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ def __call__(


def get_accept_header(response_map: ty.Optional[ResponseMap]) -> ty.Optional[str]:
if not response_map:
return None

all_mime_types = {
mime
for mime_map in response_map.values()
Expand Down Expand Up @@ -75,14 +78,13 @@ def find_request_body_serializer(
obj: ty.Any,
) -> ty.Tuple[str, Serializer]:
# find the serializer by type
obj_type = type(obj)

for content_type, typ in model.serializers.items():
if typ == obj_type:
async def serialize(model):
for item in serialize_param(model, ParamStyle.simple, explode_list=False):
yield item.encode()
return content_type, serialize
# return content_type, lambda model: serialize_param(model, ParamStyle.simple, explode_list=False)
if model:
for content_type, typ in model.serializers.items():
if typ == type(obj):
async def serialize(model):
for item in serialize_param(model, ParamStyle.simple, explode_list=False):
yield item.encode()

return content_type, serialize

raise TypeError(f'Unknown serializer for {type(obj)}')
4 changes: 2 additions & 2 deletions src/lapidary/runtime/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def parse_model(response: httpx.Response, typ: ty.Type[T]) -> T:
if inspect.isclass(typ):
if issubclass(typ, Exception):
return typ(response.json()) # type: ignore[return-value]
elif pydantic.BaseModel in inspect.getmro(typ):
return ty.cast(ty.Type[pydantic.BaseModel], typ).model_validate_json(response.content)
elif issubclass(typ, pydantic.BaseModel):
typ.model_validate_json(response.content)

return pydantic.TypeAdapter(typ).validate_json(response.content)

Expand Down
2 changes: 2 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# pylint: disable=unused-argument

import pydantic

from lapidary.runtime import ClientBase, GET, PUT, ParamStyle, Path
Expand Down
12 changes: 12 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import logging
import unittest

Expand All @@ -9,6 +10,7 @@

from lapidary.runtime import APIKeyAuth, ClientBase, GET, POST
from lapidary.runtime.compat import typing as ty
from lapidary.runtime.model.op import get_response_map
from lapidary.runtime.model.response_map import Responses

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -76,3 +78,13 @@ async def handler(_):
client = Client(base_url='http://example.com', app=app)
response = await client.login()
self.assertEqual(dict(api_key='Token token', header_name='Authorization'), response.__dict__)


class TestClientSync(unittest.TestCase):
def test_missing_return_anno(self):
async def operation():
pass

sig = inspect.signature(operation)
with self.assertRaises(TypeError):
get_response_map(sig.return_annotation)
16 changes: 14 additions & 2 deletions tests/test_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,13 @@ def test_build_request_none(self):
def test_request_param_list_simple(self):
request_factory = Mock()
build_request(
operation=OperationModel('GET', 'path', {'q_a': FullParam('a', ParamLocation.query, ParamStyle.form, False, 'q_a', List[str])}, None, {}),
operation=OperationModel(
'GET',
'path',
{'q_a': FullParam('a', ParamLocation.query, ParamStyle.form, False, 'q_a', List[str])},
None,
{},
),
actual_params=dict(q_a=['hello', 'world']),
request_factory=request_factory,
)
Expand All @@ -81,7 +87,13 @@ def test_request_param_list_simple(self):
def test_request_param_list_exploded(self):
request_factory = Mock()
build_request(
operation=OperationModel('GET', 'path', {'q_a': FullParam('a', ParamLocation.query, ParamStyle.form, True, 'q_a', List[str])}, None, {}),
operation=OperationModel(
'GET',
'path',
{'q_a': FullParam('a', ParamLocation.query, ParamStyle.form, True, 'q_a', List[str])},
None,
{},
),
actual_params=dict(q_a=['hello', 'world']),
request_factory=request_factory
)
Expand Down

0 comments on commit b99ac4e

Please sign in to comment.