Skip to content

Commit

Permalink
💥 Upgrade auth model.
Browse files Browse the repository at this point in the history
  • Loading branch information
Raphael Krupinski committed Jan 9, 2024
1 parent 26c6999 commit 1940bfd
Show file tree
Hide file tree
Showing 21 changed files with 331 additions and 159 deletions.
111 changes: 111 additions & 0 deletions docs/auth.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Authentication

OpenAPI allows declaring security schemes and security requirements of operations.

Lapidary allows to declare python methods that create or consume `https.Auth` objects.

## Generating auth tokens


A `/login/` or `/authenticate/` endpoint that returns the token is quite common with simpler authentication schemes like http or apiKey, yet their support is poor in OpenAPI. There's no way to connect
such endpoint to a security scheme as in the case of OIDC.

A function that handles such an endpoint can declare that it returns an Auth object, but it's not obvious to the user of python API which security scheme the method returns.

```python
import pydantic
import typing_extensions as ty
from lapidary.runtime import POST, ClientBase, RequestBody, APIKeyAuth, Responses
from httpx import Auth


class LoginRequest(pydantic.BaseModel):
...


class LoginResponse(pydantic.BaseModel):
token: str


class Client(ClientBase):
@POST('/login')
def login(
self: ty.Self,
*,
body: ty.Annotated[LoginRequest, RequestBody({'application/json': LoginRequest})],
) -> ty.Annotated[
Auth,
Responses({
'200': {
'application/json': ty.Annotated[
LoginResponse,
APIKeyAuth(
in_='header',
name='Authorization',
format='Token {body.token}'
),
]
}
}),
]:
"""Authenticates with the "primary" security scheme"""
```

The top return Annotated declares the returned type, the inner one declares the processing steps for the actual response.
First the response is parsed as LoginResponse, then that object is passed to ApiKeyAuth which is a callable object.

The result of the call, in this case an Auth object, is returned by the `login` function.

The innermost Annotated is not necessary from the python syntax standpoint. It's done this way since it kind of matches the semantics of Annotated, but it could be replaced with a simple tuple or other type in the future.

## Using auth tokens

OpenApi allows operations to declare a collection of alternative groups of security requirements.

The second most trivial example (the first being no security) is a single required security scheme.
```yaml
security:
- primary: []
```
The name of the security scheme corresponds to a key in `components.securitySchemes` object.

This can be represented as a simple parameter, named for example `primary_auth` and of type `httpx.Auth`.
The parameter could be annotated as `Optional` if the security requirement is optional for the operation.

In case of multiple alternative groups of security requirements, it gets harder to properly describe which schemes are required and in what combination.

Lapidary takes all passed `httpx.Auth` parameters and passes them to `httpx.AsyncClient.send(..., auth=auth_params)`, leaving the responsibility to select the right ones to the user.

If multiple `Auth` parameters are passed, they're wrapped in `lapidary.runtime.aauth.MultiAuth`, which is just reexported `_MultiAuth` from `https_auth` package.

#### Example

Auth object returned by the login operation declared in the above example can be used by another operation.

```python
from typing import Annotated, Self
from httpx import Auth
from lapidary.runtime import ClientBase, GET, POST
class Client(ClientBase):
@POST('/login')
def login(
self: Self,
body: ...,
) -> Annotated[
Auth,
...
]:
"""Authenticates with the "primary" security scheme"""
@GET('/private')
def private(
self: Self,
*,
primary_auth: Auth,
):
pass
```

In this example the method `client.private` can be called with the auth object returned by `client.login`.
2 changes: 1 addition & 1 deletion docs/python_representation.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ TBD

## Auth

TBD
See [Auth](auth.md).

## Servers

Expand Down
5 changes: 4 additions & 1 deletion src/lapidary/runtime/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from .absent import ABSENT, Absent
from .auth import APIKeyAuth
from .client_base import ClientBase
from .model.params import ParamStyle
from .model.params import ParamLocation, ParamStyle
from .model.request import RequestBody
from .model.response_map import Responses
from .operation import DELETE, GET, HEAD, PATCH, POST, PUT, TRACE
from .param import Cookie, Header, Path, Query
55 changes: 55 additions & 0 deletions src/lapidary/runtime/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
__all__ = [
'APIKeyAuth',
'MultiAuth',
]

import abc
import dataclasses as dc
from typing import Mapping
import typing as ty

import httpx
import httpx_auth
from httpx_auth.authentication import _MultiAuth as MultiAuth, _MultiAuth

from .compat import typing as ty
from .model.api_key import CookieApiKey

AuthType = ty.TypeVar("AuthType", bound=httpx.Auth)


class AuthFactory(abc.ABC):
@abc.abstractmethod
def __call__(self, body: object) -> httpx.Auth:
pass


APIKeyAuthLocation: ty.TypeAlias = ty.Literal['cookie', 'header', 'query']

api_key_in: ty.Mapping[APIKeyAuthLocation, ty.Type[AuthType]] = {
'cookie': CookieApiKey,
'header': httpx_auth.authentication.HeaderApiKey,
'query': httpx_auth.authentication.QueryApiKey,
}


@dc.dataclass
class APIKeyAuth(AuthFactory):
in_: APIKeyAuthLocation
name: str
format: str

def __call__(self, body: object) -> httpx.Auth:
typ = api_key_in[self.in_]
return typ(self.format.format(body=body), self.name)


def get_auth(params: Mapping[str, ty.Any]) -> ty.Optional[httpx.Auth]:
auth_params = [value for value in params.values() if isinstance(value, httpx.Auth)]
auth_num = len(auth_params)
if auth_num == 0:
return None
elif auth_num == 1:
return auth_params[0]
else:
return MultiAuth(*auth_params)
Empty file.
25 changes: 0 additions & 25 deletions src/lapidary/runtime/auth/api_key.py

This file was deleted.

5 changes: 0 additions & 5 deletions src/lapidary/runtime/auth/common.py

This file was deleted.

14 changes: 0 additions & 14 deletions src/lapidary/runtime/auth/http.py

This file was deleted.

39 changes: 8 additions & 31 deletions src/lapidary/runtime/client_base.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,26 @@
from abc import ABC
from collections.abc import Callable, Mapping
from functools import partial
import abc
from collections.abc import Mapping
import logging

import httpx

from .model import ResponseMap
from .auth import get_auth
from .compat import typing as ty
from .model.op import LapidaryOperation, get_operation_model
from .request import RequestFactory, build_request
from .response import handle_response

logger = logging.getLogger(__name__)


class ClientBase(ABC):
class ClientBase(abc.ABC):
def __init__(
self,
response_map: ResponseMap = None,
*,
_app: ty.Optional[Callable[..., ty.Any]] = None,
**kwargs,
):
self._response_map = response_map or {}
if 'base_url' not in kwargs:
raise ValueError('Missing base_url.')

self._client = httpx.AsyncClient( app=_app, **kwargs)
self._client = httpx.AsyncClient(**kwargs)

async def __aenter__(self):
await self._client.__aenter__()
Expand All @@ -52,24 +46,7 @@ async def _request(
ty.cast(RequestFactory, self._client.build_request),
)

if logger.isEnabledFor(logging.DEBUG):
logger.debug("%s", f'{request.method} {request.url} {request.headers}')
logger.debug("%s %s %s", request.method, request.url, request.headers)

response_handler = partial(handle_response, operation.response_map)

auth = get_auth(actual_params)

response = await self._client.send(request, auth=auth)
return response_handler(response)


def get_auth(params: Mapping[str, ty.Any]) -> ty.Optional[httpx.Auth]:
auth_params = [value for value in params.values() if isinstance(value, httpx.Auth)]
auth_num = len(auth_params)
if auth_num == 0:
return None
elif auth_num == 1:
return auth_params[0]
else:
from httpx_auth.authentication import _MultiAuth
return _MultiAuth(*auth_params)
response = await self._client.send(request, auth=get_auth(actual_params))
return operation.handle_response(response)
2 changes: 0 additions & 2 deletions src/lapidary/runtime/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +0,0 @@
from .params import FullParam, ParamLocation
from .response_map import ResponseMap, ReturnTypeInfo
24 changes: 24 additions & 0 deletions src/lapidary/runtime/model/api_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import httpx
import httpx_auth.authentication as authx

from ..compat import typing as ty


class CookieApiKey(httpx.Auth, authx.SupportMultiAuth):
"""Describes an API Key requests authentication."""

def __init__(self, api_key: str, cookie_name: str = None):
"""
:param api_key: The API key that will be sent.
:param cookie_name: Name of the query parameter. "api_key" by default.
"""
self.api_key = api_key
if not api_key:
raise Exception("API Key is mandatory.")
self.cookie_parameter_name = cookie_name or "api_key"

def auth_flow(
self, request: httpx.Request
) -> ty.Generator[httpx.Request, httpx.Response, None]:
request.headers['Cookie'] = f'{self.cookie_parameter_name}={self.api_key}'
yield request
31 changes: 30 additions & 1 deletion src/lapidary/runtime/model/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def process_params(self, actual_params: ty.Mapping[str, ty.Any]) -> ProcessedPar

formal_param = self.params.get(param_name)
if not formal_param:
raise KeyError(param_name)
raise TypeError(f'Operation {self.method} {self.path} got an unexpected argument {param_name}')

if value is ABSENT:
continue
Expand All @@ -65,6 +65,35 @@ def process_params(self, actual_params: ty.Mapping[str, ty.Any]) -> ProcessedPar
request_body=request_body,
)

def handle_response(self, response: httpx.Response) -> ty.Any:
"""
Possible special cases:
Exception
Auth
"""

from ..response import find_type, parse_model

response.read()

typ = find_type(response, self.response_map)

if typ is None:
response.raise_for_status()
return response.content

obj: ty.Any = parse_model(response, typ)

if '__metadata__' in dir(typ):
for anno in typ.__metadata__:
if callable(anno):
obj = anno(obj)

if isinstance(obj, Exception):
raise obj
else:
return obj


@dc.dataclass
class Operation:
Expand Down
3 changes: 3 additions & 0 deletions src/lapidary/runtime/model/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import uuid

import httpx
import pydantic

from ..compat import typing as ty

Expand Down Expand Up @@ -107,6 +108,8 @@ def serialize_param(value, style: ParamStyle, explode_list: bool) -> ty.Iterator
yield from values
else:
yield ','.join(values)
elif isinstance(value, (pydantic.BaseModel, pydantic.RootModel)):
yield value.model_dump_json()

else:
raise NotImplementedError(value)
Loading

0 comments on commit 1940bfd

Please sign in to comment.