Skip to content

Commit

Permalink
♻️ Make handling of operation functions parameters more object oriented.
Browse files Browse the repository at this point in the history
  • Loading branch information
Raphael Krupinski committed Jan 25, 2024
1 parent b8499aa commit 2c33e30
Show file tree
Hide file tree
Showing 14 changed files with 464 additions and 574 deletions.
173 changes: 31 additions & 142 deletions poetry.lock

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@ python = "^3.9"
httpx = {extras = ["http2"], version = "^0.26.0"}
httpx-auth = "^0.19.0"
pydantic = {extras = ["email"], version = "^2.5.2"}
pytest-asyncio = "^0.23.3"
python-mimeparse = "^1.6.0"
typing-extensions = { python = "<3.12", version = "^4.9.0" }

[tool.poetry.group.dev.dependencies]
fastapi = "^0.109.0"
mypy = "^1.0.1"
pylint = "^2.16.2"
pylint = "^3.0.3"
pytest = "^7.1"

[build-system]
Expand All @@ -52,8 +53,7 @@ addopts = "--color=yes"
[tool.mypy]
mypy_path = "src"
namespace_packages = true
ignore_missing_imports = true
explicit_package_bases = true
python_version = "3.9"

[tool.pylint]
disable = [
Expand All @@ -71,5 +71,6 @@ disable = [
disable = [
"C0116", "C0103",
"E0401",
"R0801", "R0903"
"R0801", "R0903",
"W0621"
]
4 changes: 1 addition & 3 deletions src/lapidary/runtime/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from .absent import ABSENT, Absent
from .auth import APIKeyAuth
from .client_base import ClientBase, USER_AGENT
from .model.op import Operation
from .model.params import ParamLocation, ParamStyle
from .model.request import RequestBody
from .model.params import ParamStyle, RequestBody
from .model.response_map import Responses
from .operation import delete, get, head, patch, post, put, trace
from .param import Cookie, Header, Path, Query
11 changes: 0 additions & 11 deletions src/lapidary/runtime/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,3 @@ class APIKeyAuth(AuthFactory):
def __call__(self, body: object) -> httpx.Auth:
typ = api_key_in[self.in_]
return typ(self.format.format(body=body), self.name) # type: ignore[misc]


def get_auth(params: ty.Mapping[str, ty.Any]) -> ty.Optional[httpx.Auth]:
auth_params = [value for value in params.values() if isinstance(value, httpx.Auth)]
auth_num = len(auth_params)
if auth_num == 0:
return None
elif auth_num == 1:
return auth_params[0]
else:
return MultiAuth(*auth_params)
33 changes: 18 additions & 15 deletions src/lapidary/runtime/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@

import httpx

from .auth import get_auth
from .compat import typing as ty
from .model.op import LapidaryOperation, get_operation_model
from .request import RequestFactory, build_request
from .model.op import OperationModel, get_operation_model
from .request import build_request

logger = logging.getLogger(__name__)

Expand All @@ -18,17 +17,17 @@
class ClientBase(abc.ABC):
def __init__(
self,
base_url: str,
user_agent: ty.Optional[str] = USER_AGENT,
_http_client: ty.Optional[httpx.AsyncClient] = None,
**httpx_kwargs,
):
if 'base_url' not in httpx_kwargs:
raise ValueError('Missing base_url.')
headers = httpx.Headers(httpx_kwargs.pop('headers', None)) or httpx.Headers()
if user_agent:
headers['User-Agent'] = user_agent


self._client = httpx.AsyncClient(**httpx_kwargs, headers=headers)
self._client = _http_client or httpx.AsyncClient(base_url=base_url, headers=headers, **httpx_kwargs)
self._lapidary_operations: ty.MutableMapping[str, OperationModel] = {}

async def __aenter__(self):
await self._client.__aenter__()
Expand All @@ -39,22 +38,26 @@ async def __aexit__(self, __exc_type=None, __exc_value=None, __traceback=None) -

async def _request(
self,
fn: LapidaryOperation,
method: str,
path: str,
fn: ty.Callable[..., ty.Awaitable],
actual_params: Mapping[str, ty.Any],
):
if not fn.lapidary_operation_model:
operation = get_operation_model(fn)
fn.lapidary_operation_model = operation
if fn.__name__ not in self._lapidary_operations:
operation = get_operation_model(method, path, fn)
self._lapidary_operations[fn.__name__] = operation
else:
operation = fn.lapidary_operation_model
operation = self._lapidary_operations[fn.__name__]

request = build_request(
request, auth = build_request(
operation,
actual_params,
ty.cast(RequestFactory, self._client.build_request),
self._client.build_request,
)

logger.debug("%s %s %s", request.method, request.url, request.headers)

response = await self._client.send(request, auth=get_auth(actual_params))
response = await self._client.send(request, auth=auth)
await response.aread()

return operation.handle_response(response)
141 changes: 23 additions & 118 deletions src/lapidary/runtime/model/op.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,34 @@
import abc
from collections.abc import Iterator
import dataclasses as dc
import inspect

import httpx

from .params import FullParam, Param, ParamLocation, ProcessedParams, serialize_param
from .request import RequestBody, RequestBodyModel
from .params import RequestPart, parse_params, find_annotations
from .response_map import ResponseMap, Responses
from ..absent import ABSENT
from ..compat import typing as ty
from ..response import find_type, parse_model


@dc.dataclass(frozen=True)
class InputData:
query: httpx.QueryParams
headers: httpx.Headers
cookies: httpx.Cookies
path: ty.Mapping[str, str]
request_body: ty.Any
if ty.TYPE_CHECKING:
from .request import RequestBuilder


@dc.dataclass(frozen=True)
class OperationModel:
method: str
path: str
params: ty.Mapping[str, FullParam]
request_body: ty.Optional[RequestBodyModel]
params: ty.Mapping[str, RequestPart]
response_map: ResponseMap

def process_params(self, actual_params: ty.Mapping[str, ty.Any]) -> ProcessedParams:
containers: ty.Mapping[ParamLocation, ty.List[ty.Any]] = {
ParamLocation.cookie: [],
ParamLocation.header: [],
ParamLocation.query: [],
ParamLocation.path: [],
}
request_body: ty.Any = None

for param_name, value in actual_params.items():
if self.request_body and param_name == self.request_body.param_name:
request_body = value
continue

if isinstance(value, httpx.Auth):
continue

formal_param = self.params.get(param_name)
if not formal_param:
raise TypeError(f'Operation {self.method} {self.path} got an unexpected argument {param_name}')

if value is ABSENT:
def process_params(
self,
actual_params: ty.Mapping[str, ty.Any],
request: 'RequestBuilder',
) -> None:
for param_name, param_handler in self.params.items():
if param_name not in actual_params:
continue

placement = formal_param.location

value = [(formal_param.alias, value) for value in serialize_param(value, formal_param.style, formal_param.explode)]
containers[placement].extend(value)

return ProcessedParams(
query=httpx.QueryParams(containers[ParamLocation.query]),
headers=httpx.Headers(containers[ParamLocation.header]),
cookies=httpx.Cookies(containers[ParamLocation.cookie]),
path={item[0]: item[1] for item in containers[ParamLocation.path]},
request_body=request_body,
)
param_handler.apply(request, actual_params[param_name])

def handle_response(self, response: httpx.Response) -> ty.Any:
"""
Expand All @@ -73,15 +37,10 @@ def handle_response(self, response: httpx.Response) -> ty.Any:
Auth
"""

from ..response import find_type, parse_model

response.read()

typ = find_type(response, self.response_map)

if typ is None:
response.raise_for_status()
return response.content
return None

obj: ty.Any = parse_model(response, typ)

Expand All @@ -96,77 +55,23 @@ def handle_response(self, response: httpx.Response) -> ty.Any:
return obj


@dc.dataclass
class Operation:
method: str
path: str


def parse_params(sig: inspect.Signature) -> Iterator[ty.Union[FullParam, RequestBodyModel]]:
for name, param in sig.parameters.items():
anno = param.annotation

if anno == ty.Self or (isinstance(anno, type) and issubclass(anno, httpx.Auth)):
continue

if param.annotation == inspect.Parameter.empty:
raise TypeError(f"Parameter '{name} is missing annotation'")

param_annos = [a for a in anno.__metadata__ if isinstance(a, (Param, RequestBody))]
if len(param_annos) != 1:
raise ValueError(f'{param.name}: expected exactly one annotation of type RequestBody, ')

param_anno = param_annos.pop()

if isinstance(param_anno, RequestBody):
yield RequestBodyModel(
name,
param_anno.content
)
else:
yield FullParam(
name=param.name,
alias=param_anno.alias or param.name,
location=param_anno.location,
type=param.annotation,
style=param_anno.get_style(),
explode=param_anno.get_explode(),
)


def get_response_map(return_anno: type) -> ResponseMap: # type: ignore[valid-type]
if return_anno is inspect.Signature.empty:
raise TypeError('Operation function must have exactly one Responses annotation')
annos = [anno for anno in return_anno.__metadata__ if isinstance(anno, Responses)] # type: ignore[attr-defined]
def get_response_map(return_anno: type) -> ResponseMap:
annos = find_annotations(return_anno, Responses)
if len(annos) != 1:
raise TypeError('Operation function must have exactly one Responses annotation')

return annos.pop().responses


class LapidaryOperation(abc.ABC):
lapidary_operation: Operation
lapidary_operation_model: ty.Optional[OperationModel]

def __call__(self, *args, **kwargs) -> ty.Any:
pass
return annos[0].responses


def get_operation_model(
fn: LapidaryOperation,
method: str,
path: str,
fn: ty.Callable,
) -> OperationModel:
base_model: Operation = fn.lapidary_operation
sig = inspect.signature(fn)
params = list(parse_params(sig))
request_body_ = [param for param in params if isinstance(param, RequestBodyModel)]
if len(request_body_) > 1:
raise ValueError()
request_body = request_body_.pop() if request_body_ else None

return OperationModel(
method=base_model.method,
path=base_model.path,
params={param.name: param for param in params if isinstance(param, (FullParam, httpx.Auth))},
request_body=request_body,
method=method,
path=path,
params=parse_params(sig),
response_map=get_response_map(sig.return_annotation),
)
Loading

0 comments on commit 2c33e30

Please sign in to comment.