diff --git a/docs/auth.md b/docs/auth.md new file mode 100644 index 0000000..d6f2cd1 --- /dev/null +++ b/docs/auth.md @@ -0,0 +1,111 @@ +# Authentication + +OpenAPI allows declaring security schemes and security requirements of operations. + +Lapidary allows to declare python methods that create or consume `https.Auth` objects. + +## Generating auth tokens + + +A `/login/` or `/authenticate/` endpoint that returns the token is quite common with simpler authentication schemes like http or apiKey, yet their support is poor in OpenAPI. There's no way to connect +such endpoint to a security scheme as in the case of OIDC. + +A function that handles such an endpoint can declare that it returns an Auth object, but it's not obvious to the user of python API which security scheme the method returns. + +```python +import pydantic +import typing_extensions as ty +from lapidary.runtime import POST, ClientBase, RequestBody, APIKeyAuth, Responses +from httpx import Auth + + +class LoginRequest(pydantic.BaseModel): + ... + + +class LoginResponse(pydantic.BaseModel): + token: str + + +class Client(ClientBase): + @POST('/login') + def login( + self: ty.Self, + *, + body: ty.Annotated[LoginRequest, RequestBody({'application/json': LoginRequest})], + ) -> ty.Annotated[ + Auth, + Responses({ + '200': { + 'application/json': ty.Annotated[ + LoginResponse, + APIKeyAuth( + in_='header', + name='Authorization', + format='Token {body.token}' + ), + ] + } + }), + ]: + """Authenticates with the "primary" security scheme""" +``` + +The top return Annotated declares the returned type, the inner one declares the processing steps for the actual response. +First the response is parsed as LoginResponse, then that object is passed to ApiKeyAuth which is a callable object. + +The result of the call, in this case an Auth object, is returned by the `login` function. + +The innermost Annotated is not necessary from the python syntax standpoint. It's done this way since it kind of matches the semantics of Annotated, but it could be replaced with a simple tuple or other type in the future. + +## Using auth tokens + +OpenApi allows operations to declare a collection of alternative groups of security requirements. + +The second most trivial example (the first being no security) is a single required security scheme. +```yaml +security: +- primary: [] +``` +The name of the security scheme corresponds to a key in `components.securitySchemes` object. + +This can be represented as a simple parameter, named for example `primary_auth` and of type `httpx.Auth`. +The parameter could be annotated as `Optional` if the security requirement is optional for the operation. + +In case of multiple alternative groups of security requirements, it gets harder to properly describe which schemes are required and in what combination. + +Lapidary takes all passed `httpx.Auth` parameters and passes them to `httpx.AsyncClient.send(..., auth=auth_params)`, leaving the responsibility to select the right ones to the user. + +If multiple `Auth` parameters are passed, they're wrapped in `lapidary.runtime.aauth.MultiAuth`, which is just reexported `_MultiAuth` from `https_auth` package. + +#### Example + +Auth object returned by the login operation declared in the above example can be used by another operation. + +```python +from typing import Annotated, Self +from httpx import Auth +from lapidary.runtime import ClientBase, GET, POST + + +class Client(ClientBase): + @POST('/login') + def login( + self: Self, + body: ..., + ) -> Annotated[ + Auth, + ... + ]: + """Authenticates with the "primary" security scheme""" + + @GET('/private') + def private( + self: Self, + *, + primary_auth: Auth, + ): + pass +``` + +In this example the method `client.private` can be called with the auth object returned by `client.login`. diff --git a/docs/python_representation.md b/docs/python_representation.md index bbaae2c..1768bc9 100644 --- a/docs/python_representation.md +++ b/docs/python_representation.md @@ -64,7 +64,7 @@ TBD ## Auth -TBD +See [Auth](auth.md). ## Servers diff --git a/src/lapidary/runtime/__init__.py b/src/lapidary/runtime/__init__.py index d658f41..6d85e63 100644 --- a/src/lapidary/runtime/__init__.py +++ b/src/lapidary/runtime/__init__.py @@ -1,5 +1,8 @@ from .absent import ABSENT, Absent +from .auth import APIKeyAuth from .client_base import ClientBase -from .model.params import ParamStyle +from .model.params import ParamLocation, ParamStyle +from .model.request import RequestBody +from .model.response_map import Responses from .operation import DELETE, GET, HEAD, PATCH, POST, PUT, TRACE from .param import Cookie, Header, Path, Query diff --git a/src/lapidary/runtime/auth.py b/src/lapidary/runtime/auth.py new file mode 100644 index 0000000..6f8aa8b --- /dev/null +++ b/src/lapidary/runtime/auth.py @@ -0,0 +1,55 @@ +__all__ = [ + 'APIKeyAuth', + 'MultiAuth', +] + +import abc +import dataclasses as dc +from typing import Mapping +import typing as ty + +import httpx +import httpx_auth +from httpx_auth.authentication import _MultiAuth as MultiAuth, _MultiAuth + +from .compat import typing as ty +from .model.api_key import CookieApiKey + +AuthType = ty.TypeVar("AuthType", bound=httpx.Auth) + + +class AuthFactory(abc.ABC): + @abc.abstractmethod + def __call__(self, body: object) -> httpx.Auth: + pass + + +APIKeyAuthLocation: ty.TypeAlias = ty.Literal['cookie', 'header', 'query'] + +api_key_in: ty.Mapping[APIKeyAuthLocation, ty.Type[AuthType]] = { + 'cookie': CookieApiKey, + 'header': httpx_auth.authentication.HeaderApiKey, + 'query': httpx_auth.authentication.QueryApiKey, +} + + +@dc.dataclass +class APIKeyAuth(AuthFactory): + in_: APIKeyAuthLocation + name: str + format: str + + def __call__(self, body: object) -> httpx.Auth: + typ = api_key_in[self.in_] + return typ(self.format.format(body=body), self.name) + + +def get_auth(params: Mapping[str, ty.Any]) -> ty.Optional[httpx.Auth]: + auth_params = [value for value in params.values() if isinstance(value, httpx.Auth)] + auth_num = len(auth_params) + if auth_num == 0: + return None + elif auth_num == 1: + return auth_params[0] + else: + return MultiAuth(*auth_params) diff --git a/src/lapidary/runtime/auth/__init__.py b/src/lapidary/runtime/auth/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/lapidary/runtime/auth/api_key.py b/src/lapidary/runtime/auth/api_key.py deleted file mode 100644 index b525cb3..0000000 --- a/src/lapidary/runtime/auth/api_key.py +++ /dev/null @@ -1,25 +0,0 @@ -from dataclasses import dataclass - -import httpx - -from .common import PageFlowGenT -from ..model.params import ParamLocation - - -@dataclass(eq=False, order=False, frozen=True) -class ApiKeyAuth(httpx.Auth): - api_key: str - name: str - placement: ParamLocation - - def auth_flow(self, request: httpx.Request) -> PageFlowGenT: - value = self.api_key - if self.placement is ParamLocation.header: - request.headers[self.name] = value - elif self.placement is ParamLocation.query: - request.url.params[self.name] = value - elif self.placement is ParamLocation.cookie: - request.headers.update({'Cookie': f'{self.name}={value}'}) - else: - raise ValueError(self.placement) - yield request diff --git a/src/lapidary/runtime/auth/common.py b/src/lapidary/runtime/auth/common.py deleted file mode 100644 index 38f7a48..0000000 --- a/src/lapidary/runtime/auth/common.py +++ /dev/null @@ -1,5 +0,0 @@ -import typing as ty - -import httpx - -PageFlowGenT = ty.Generator[httpx.Request, httpx.Response, None] diff --git a/src/lapidary/runtime/auth/http.py b/src/lapidary/runtime/auth/http.py deleted file mode 100644 index 339228e..0000000 --- a/src/lapidary/runtime/auth/http.py +++ /dev/null @@ -1,14 +0,0 @@ -from typing import Optional - -from .api_key import ApiKeyAuth -from ..model.params import ParamLocation - - -class HTTPAuth(ApiKeyAuth): - def __init__(self, scheme: str, token: str, bearer_format: Optional[str] = None): - value_format_ = bearer_format if bearer_format and scheme.lower() == 'bearer' else '{token}' - super().__init__( - api_key=value_format_.format(token=token), - name='Authorization', - placement=ParamLocation.header, - ) diff --git a/src/lapidary/runtime/client_base.py b/src/lapidary/runtime/client_base.py index 4d68ae3..ecd1e29 100644 --- a/src/lapidary/runtime/client_base.py +++ b/src/lapidary/runtime/client_base.py @@ -1,32 +1,26 @@ -from abc import ABC -from collections.abc import Callable, Mapping -from functools import partial +import abc +from collections.abc import Mapping import logging import httpx -from .model import ResponseMap +from .auth import get_auth from .compat import typing as ty from .model.op import LapidaryOperation, get_operation_model from .request import RequestFactory, build_request -from .response import handle_response logger = logging.getLogger(__name__) -class ClientBase(ABC): +class ClientBase(abc.ABC): def __init__( self, - response_map: ResponseMap = None, - *, - _app: ty.Optional[Callable[..., ty.Any]] = None, **kwargs, ): - self._response_map = response_map or {} if 'base_url' not in kwargs: raise ValueError('Missing base_url.') - self._client = httpx.AsyncClient( app=_app, **kwargs) + self._client = httpx.AsyncClient(**kwargs) async def __aenter__(self): await self._client.__aenter__() @@ -52,24 +46,7 @@ async def _request( ty.cast(RequestFactory, self._client.build_request), ) - if logger.isEnabledFor(logging.DEBUG): - logger.debug("%s", f'{request.method} {request.url} {request.headers}') + logger.debug("%s %s %s", request.method, request.url, request.headers) - response_handler = partial(handle_response, operation.response_map) - - auth = get_auth(actual_params) - - response = await self._client.send(request, auth=auth) - return response_handler(response) - - -def get_auth(params: Mapping[str, ty.Any]) -> ty.Optional[httpx.Auth]: - auth_params = [value for value in params.values() if isinstance(value, httpx.Auth)] - auth_num = len(auth_params) - if auth_num == 0: - return None - elif auth_num == 1: - return auth_params[0] - else: - from httpx_auth.authentication import _MultiAuth - return _MultiAuth(*auth_params) + response = await self._client.send(request, auth=get_auth(actual_params)) + return operation.handle_response(response) diff --git a/src/lapidary/runtime/model/__init__.py b/src/lapidary/runtime/model/__init__.py index 4a9d736..e69de29 100644 --- a/src/lapidary/runtime/model/__init__.py +++ b/src/lapidary/runtime/model/__init__.py @@ -1,2 +0,0 @@ -from .params import FullParam, ParamLocation -from .response_map import ResponseMap, ReturnTypeInfo diff --git a/src/lapidary/runtime/model/api_key.py b/src/lapidary/runtime/model/api_key.py new file mode 100644 index 0000000..f14c5f1 --- /dev/null +++ b/src/lapidary/runtime/model/api_key.py @@ -0,0 +1,24 @@ +import httpx +import httpx_auth.authentication as authx + +from ..compat import typing as ty + + +class CookieApiKey(httpx.Auth, authx.SupportMultiAuth): + """Describes an API Key requests authentication.""" + + def __init__(self, api_key: str, cookie_name: str = None): + """ + :param api_key: The API key that will be sent. + :param cookie_name: Name of the query parameter. "api_key" by default. + """ + self.api_key = api_key + if not api_key: + raise Exception("API Key is mandatory.") + self.cookie_parameter_name = cookie_name or "api_key" + + def auth_flow( + self, request: httpx.Request + ) -> ty.Generator[httpx.Request, httpx.Response, None]: + request.headers['Cookie'] = f'{self.cookie_parameter_name}={self.api_key}' + yield request diff --git a/src/lapidary/runtime/model/op.py b/src/lapidary/runtime/model/op.py index 9ee5ddc..a9f612f 100644 --- a/src/lapidary/runtime/model/op.py +++ b/src/lapidary/runtime/model/op.py @@ -47,7 +47,7 @@ def process_params(self, actual_params: ty.Mapping[str, ty.Any]) -> ProcessedPar formal_param = self.params.get(param_name) if not formal_param: - raise KeyError(param_name) + raise TypeError(f'Operation {self.method} {self.path} got an unexpected argument {param_name}') if value is ABSENT: continue @@ -65,6 +65,35 @@ def process_params(self, actual_params: ty.Mapping[str, ty.Any]) -> ProcessedPar request_body=request_body, ) + def handle_response(self, response: httpx.Response) -> ty.Any: + """ + Possible special cases: + Exception + Auth + """ + + from ..response import find_type, parse_model + + response.read() + + typ = find_type(response, self.response_map) + + if typ is None: + response.raise_for_status() + return response.content + + obj: ty.Any = parse_model(response, typ) + + if '__metadata__' in dir(typ): + for anno in typ.__metadata__: + if callable(anno): + obj = anno(obj) + + if isinstance(obj, Exception): + raise obj + else: + return obj + @dc.dataclass class Operation: diff --git a/src/lapidary/runtime/model/params.py b/src/lapidary/runtime/model/params.py index f98f5b6..1029fbd 100644 --- a/src/lapidary/runtime/model/params.py +++ b/src/lapidary/runtime/model/params.py @@ -5,6 +5,7 @@ import uuid import httpx +import pydantic from ..compat import typing as ty @@ -107,6 +108,8 @@ def serialize_param(value, style: ParamStyle, explode_list: bool) -> ty.Iterator yield from values else: yield ','.join(values) + elif isinstance(value, (pydantic.BaseModel, pydantic.RootModel)): + yield value.model_dump_json() else: raise NotImplementedError(value) diff --git a/src/lapidary/runtime/model/request.py b/src/lapidary/runtime/model/request.py index 1599286..8375ba0 100644 --- a/src/lapidary/runtime/model/request.py +++ b/src/lapidary/runtime/model/request.py @@ -1,15 +1,14 @@ import dataclasses as dc from ..compat import typing as ty -from lapidary.runtime.types_ import Serializer @dc.dataclass(frozen=True) class RequestBody: - content: ty.Mapping[str, ty.Tuple[ty.Type, Serializer]] + content: ty.Mapping[str, ty.Type] @dc.dataclass(frozen=True) class RequestBodyModel: param_name: str - serializers: ty.Mapping[str, ty.Tuple[ty.Type, Serializer]] + serializers: ty.Mapping[str, ty.Type] diff --git a/src/lapidary/runtime/model/response_map.py b/src/lapidary/runtime/model/response_map.py index d060ef5..9b46282 100644 --- a/src/lapidary/runtime/model/response_map.py +++ b/src/lapidary/runtime/model/response_map.py @@ -6,20 +6,10 @@ MimeType = str ResponseCode = str - -@dc.dataclass -class ReturnTypeInfo: - type: ty.Type - iterator: bool = False - - -MimeMap = ty.Mapping[MimeType, ReturnTypeInfo] +MimeMap = ty.Mapping[MimeType, ty.Type] ResponseMap = ty.Mapping[ResponseCode, MimeMap] @dc.dataclass(frozen=True) class Responses: responses: ResponseMap - - def get_response(self, response_code: int, mime_type: str) -> ReturnTypeInfo: - pass diff --git a/src/lapidary/runtime/request.py b/src/lapidary/runtime/request.py index a7a7763..c3bbf31 100644 --- a/src/lapidary/runtime/request.py +++ b/src/lapidary/runtime/request.py @@ -3,13 +3,31 @@ from .compat import typing as ty from .http_consts import ACCEPT, CONTENT_TYPE, MIME_JSON from .mime import find_mime -from .model import ResponseMap from .model.op import OperationModel +from .model.params import serialize_param from .model.request import RequestBodyModel +from .model.response_map import ResponseMap +from .param import ParamStyle from .types_ import Serializer -# accepts parameters of httpx.Client.build_request -RequestFactory = ty.Callable[..., httpx.Request] + +class RequestFactory(ty.Protocol): + def __call__( + self, + method: str, + url: str, + *, + content: ty.Optional[httpx._types.RequestContent] = None, + data: ty.Optional[httpx._types.RequestData] = None, + files: ty.Optional[httpx._types.RequestFiles] = None, + json: ty.Optional[ty.Any] = None, + params: ty.Optional[httpx._types.QueryParamTypes] = None, + headers: ty.Optional[httpx._types.HeaderTypes] = None, + cookies: ty.Optional[httpx._types.CookieTypes] = None, + timeout: ty.Union[httpx._types.TimeoutTypes, httpx._client.UseClientDefault] = httpx.USE_CLIENT_DEFAULT, + extensions: ty.Optional[httpx._types.RequestExtensions] = None, + ) -> httpx.Request: + pass def get_accept_header(response_map: ty.Optional[ResponseMap]) -> ty.Optional[str]: @@ -56,13 +74,15 @@ def find_request_body_serializer( model: ty.Optional[RequestBodyModel], obj: ty.Any, ) -> ty.Tuple[str, Serializer]: - # find the serializer by type obj_type = type(obj) - for content_type, ser_info in model.serializers.items(): - type_, serializer = ser_info - if type_ == obj_type: - return content_type, serializer + for content_type, typ in model.serializers.items(): + if typ == obj_type: + async def serialize(model): + for item in serialize_param(model, ParamStyle.simple, explode_list=False): + yield item.encode() + return content_type, serialize + # return content_type, lambda model: serialize_param(model, ParamStyle.simple, explode_list=False) raise TypeError(f'Unknown serializer for {type(obj)}') diff --git a/src/lapidary/runtime/response.py b/src/lapidary/runtime/response.py index ce74de5..c002bc3 100644 --- a/src/lapidary/runtime/response.py +++ b/src/lapidary/runtime/response.py @@ -7,44 +7,13 @@ from .compat import typing as ty from .http_consts import CONTENT_TYPE from .mime import find_mime -from .model import ResponseMap -from .model.response_map import ReturnTypeInfo +from .model.response_map import ResponseMap logger = logging.getLogger(__name__) - - -def handle_response( - response_map: ResponseMap, - response: httpx.Response, -) -> Any: - response.read() - - type_info = find_type(response, response_map) - - if type_info is None: - response.raise_for_status() - return response.content - - try: - obj: Any = parse_model(response, type_info.type) - except pydantic.ValidationError as error: - raise ValueError(response.content) from error - - if isinstance(obj, Exception): - raise obj - elif type_info.iterator: - return aiter2(obj) - else: - return obj - T = ty.TypeVar('T') P = ty.TypeVar('P') -async def aiter2(values: Iterable[T]) -> AsyncIterator[T]: - """Turn Iterable to AsyncIterator (AsyncGenerator really).""" - for value in values: - yield value def parse_model(response: httpx.Response, typ: ty.Type[T]) -> T: if inspect.isclass(typ): diff --git a/src/lapidary/runtime/types_.py b/src/lapidary/runtime/types_.py index d78702d..3c02bb0 100644 --- a/src/lapidary/runtime/types_.py +++ b/src/lapidary/runtime/types_.py @@ -1,4 +1,7 @@ +import httpx + from .compat import typing as ty Serializer: ty.TypeAlias = ty.Callable[[ty.Any], ty.Union[str, bytes]] RequestContent: ty.TypeAlias = ty.Union[str, bytes, ty.Iterable[bytes], ty.AsyncIterable[bytes]] +PageFlowGenT = ty.Generator[httpx.Request, httpx.Response, None] diff --git a/tests/__init__.py b/tests/__init__.py index 871a258..884913b 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -2,7 +2,6 @@ from lapidary.runtime import ClientBase, GET, PUT, ParamStyle, Path from lapidary.runtime.compat import typing as ty -from lapidary.runtime.model import ReturnTypeInfo from lapidary.runtime.model.request import RequestBody from lapidary.runtime.model.response_map import Responses @@ -25,7 +24,7 @@ async def get_cat( id_p: ty.Annotated[int, Path('id', ParamStyle.simple)], ) -> ty.Annotated[Cat, Responses({ 'default': { - 'application/json': ReturnTypeInfo(Cat) + 'application/json': Cat } })]: ... @@ -37,7 +36,7 @@ async def put_cat( body: ty.Annotated[Cat, RequestBody({'application/json': (Cat, lambda model: model.model_dump_json())})], ) -> ty.Annotated[Cat, Responses({ 'default': { - 'application/json': ReturnTypeInfo(Cat) + 'application/json': Cat } })]: ... diff --git a/tests/test_client.py b/tests/test_client.py index 2e3da15..26fcbec 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,25 +1,26 @@ import logging import unittest +import httpx +import pydantic from starlette.applications import Starlette from starlette.responses import JSONResponse from starlette.routing import Route -from lapidary.runtime import ClientBase, GET -from lapidary.runtime.model import ReturnTypeInfo +from lapidary.runtime import APIKeyAuth, ClientBase, GET, POST from lapidary.runtime.compat import typing as ty from lapidary.runtime.model.response_map import Responses -import logging - logger = logging.getLogger(__name__) logging.getLogger('lapidary').setLevel(logging.DEBUG) +class AuthResponse(pydantic.BaseModel): + api_key: str + + class Client(ClientBase): @GET('/strings') - async def get_strings(self: typing_extensions.Self) -> typing.Annotated[ - List[str], Responses({ async def get_strings(self: ty.Self) -> ty.Annotated[ ty.List[str], Responses({ @@ -28,7 +29,26 @@ async def get_strings(self: ty.Self) -> ty.Annotated[ } }) ]: - ... + pass + + @POST('/login') + async def login(self: ty.Self, + ) -> ty.Annotated[ + httpx.Auth, + Responses({ + '200': { + 'application/json': ty.Annotated[ + AuthResponse, + APIKeyAuth( + 'header', + 'Authorization', + 'Token {body.api_key}' + ), + ] + } + }) + ]: + pass class TestClient(unittest.IsolatedAsyncioTestCase): @@ -40,7 +60,19 @@ async def handler(_): Route('/strings', handler), ]) - client = Client('http://example.com', _app=app) + client = Client(base_url='http://example.com', app=app) response = await client.get_strings() self.assertIsInstance(response, list) self.assertEqual(['a', 'b', 'c'], response) + + async def test_response_auth(self): + async def handler(_): + return JSONResponse({'api_key': 'token'}) + + app = Starlette(debug=True, routes=[ + Route('/login', handler, methods=['POST']), + ]) + + client = Client(base_url='http://example.com', app=app) + response = await client.login() + self.assertEqual(dict(api_key='Token token', header_name='Authorization'), response.__dict__) diff --git a/tests/test_request.py b/tests/test_request.py index e4cc798..fcd2cb2 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -1,45 +1,49 @@ -import unittest from typing import List +import unittest from unittest.mock import Mock import httpx +from httpx import Cookies, Headers, QueryParams import pydantic -from httpx import QueryParams, Headers, Cookies from lapidary.runtime import ParamLocation -from lapidary.runtime.model import FullParam from lapidary.runtime.model.op import OperationModel +from lapidary.runtime.model.params import FullParam, ParamStyle from lapidary.runtime.model.request import RequestBodyModel -from lapidary.runtime.model.params import ParamStyle from lapidary.runtime.request import build_request -class BuildRequestTestCase(unittest.TestCase): - def test_build_request_from_list(self) -> None: +class BuildRequestAsyncTestCase(unittest.IsolatedAsyncioTestCase): + async def test_build_request_from_list(self) -> None: class MyRequestBodyModel(pydantic.BaseModel): a: str class MyRequestBodyList(pydantic.RootModel): root: List[MyRequestBodyModel] - deser = pydantic.TypeAdapter(List[MyRequestBodyModel]) - request_factory = Mock() build_request( - operation=OperationModel('GET', 'path', {}, RequestBodyModel('body', {'application/json': (MyRequestBodyList, lambda model: deser.dump_json(model))}), {}), + operation=OperationModel('GET', 'path', {}, RequestBodyModel('body', {'application/json': MyRequestBodyList}), {}), actual_params={'body': MyRequestBodyList([MyRequestBodyModel(a='a')])}, request_factory=request_factory ) - request_factory.assert_called_with( + call_args, call_kwargs = request_factory.call_args + call_kwargs['content'] = [item async for item in call_kwargs['content']] + + assert call_args == ( 'GET', 'path', - content=b'[{"a":"a"}]', - params=httpx.QueryParams(), - headers=httpx.Headers({'content-type': 'application/json'}), - cookies=httpx.Cookies(), ) + assert call_kwargs == { + 'content': [b'[{"a":"a"}]'], + 'params': httpx.QueryParams(), + 'headers': httpx.Headers({'content-type': 'application/json'}), + 'cookies': httpx.Cookies(), + } + +class BuildRequestTestCase(unittest.TestCase): def test_build_request_none(self): request_factory = Mock() build_request(