Skip to content

Commit

Permalink
Fix issues with recent PRs
Browse files Browse the repository at this point in the history
This fixes some linting issues and restructures some code to be
consistent with the code around it.
  • Loading branch information
pgjones committed May 15, 2024
1 parent ecce16a commit cd6c97b
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 12 deletions.
13 changes: 10 additions & 3 deletions src/quart_schema/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class MsgSpecValidationError(Exception): # type: ignore

T = TypeVar("T", bound=Model)

JsonSchemaMode = Literal['validation', 'serialization']
JsonSchemaMode = Literal["validation", "serialization"]


def convert_response_return_value(
Expand Down Expand Up @@ -186,9 +186,16 @@ def model_load(
raise exception_class(error)


def model_schema(model_class: Type[Model], *, preference: Optional[str] = None, schema_mode: JsonSchemaMode = "validation") -> dict:
def model_schema(
model_class: Type[Model],
*,
preference: Optional[str] = None,
schema_mode: JsonSchemaMode = "validation",
) -> dict:
if _use_pydantic(model_class, preference):
return TypeAdapter(model_class).json_schema(ref_template=PYDANTIC_REF_TEMPLATE, mode=schema_mode)
return TypeAdapter(model_class).json_schema(
ref_template=PYDANTIC_REF_TEMPLATE, mode=schema_mode
)
elif _use_msgspec(model_class, preference):
_, schema = schema_components([model_class], ref_template=MSGSPEC_REF_TEMPLATE)
return list(schema.values())[0]
Expand Down
11 changes: 7 additions & 4 deletions src/quart_schema/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,13 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
if source == DataSource.JSON:
data = await request.get_json()
else:
data = (await request.form).to_dict(flat=False)
for key, value in data.items():
if len(value) == 1:
data[key] = value[0]
data = {}
form = await request.form
for key in form:
if len(form.getlist(key)) > 1:
data[key] = form.getlist(key)
else:
data[key] = form[key]
if source == DataSource.FORM_MULTIPART:
files = await request.files
for key in files:
Expand Down
8 changes: 3 additions & 5 deletions tests/test_openapi.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from typing import Dict, List, Optional, Tuple, Type
from typing_extensions import Annotated

import pytest
from pydantic import BaseModel, Field, computed_field
from pydantic import BaseModel, computed_field, Field
from pydantic.dataclasses import dataclass
from pydantic.functional_serializers import PlainSerializer
from quart import Quart

from quart_schema import (
Expand Down Expand Up @@ -273,13 +271,13 @@ class EmployeeWithComputedField(BaseModel):
first_name: str
last_name: str

@computed_field
@computed_field # type: ignore[misc]
@property
def full_name(self) -> str:
return f"{self.first_name} {self.last_name}"


async def test_response_model_with_computed_field():
async def test_response_model_with_computed_field() -> None:
"""
Test that routes returning a response model that has one or more computed fields have the
appropriate properties in the generated JSON schema.
Expand Down
24 changes: 24 additions & 0 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,30 @@ async def item(data: Any) -> ResponseReturnValue:
assert response.status_code == status


class MultiItem(BaseModel):
multi: List[int]
single: int


async def test_request_form_validation_multi() -> None:
app = Quart(__name__)
QuartSchema(app)

@app.route("/", methods=["POST"])
@validate_request(MultiItem, source=DataSource.FORM)
async def item(data: MultiItem) -> MultiItem:
return data

test_client = app.test_client()
response = await test_client.post(
"/",
data=b"multi=1&multi=2&single=2",
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
assert response.status_code == 200
assert await response.get_json() == {"multi": [1, 2], "single": 2}


async def test_request_file_validation() -> None:
app = Quart(__name__)
QuartSchema(app)
Expand Down

0 comments on commit cd6c97b

Please sign in to comment.