Skip to content

Commit

Permalink
Ensure only a single QUIC timer task per connection
Browse files Browse the repository at this point in the history
This prevents Hypercorn calling aioquic's interface too many times,
due to many tasks running concurrently, triggering exponential backoff
and errors.

Many thanks to @rthalley from whom's work this is based.
  • Loading branch information
pgjones committed May 27, 2024
1 parent 81bbb32 commit ab98383
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 85 deletions.
34 changes: 9 additions & 25 deletions src/hypercorn/asyncio/tcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import asyncio
from ssl import SSLError
from typing import Any, Generator, Optional
from typing import Any, Generator

from .task_group import TaskGroup
from .worker_context import WorkerContext
from .worker_context import AsyncioSingleTask, WorkerContext
from ..config import Config
from ..events import Closed, Event, RawData, Updated
from ..protocol import ProtocolWrapper
Expand Down Expand Up @@ -33,9 +33,7 @@ def __init__(
self.reader = reader
self.writer = writer
self.send_lock = asyncio.Lock()
self.idle_lock = asyncio.Lock()

self._idle_handle: Optional[asyncio.Task] = None
self.idle_task = AsyncioSingleTask()

def __await__(self) -> Generator[Any, None, None]:
return self.run().__await__()
Expand All @@ -54,6 +52,7 @@ async def run(self) -> None:
alpn_protocol = "http/1.1"

async with TaskGroup(self.loop) as task_group:
self._task_group = task_group
self.protocol = ProtocolWrapper(
self.app,
self.config,
Expand All @@ -66,7 +65,7 @@ async def run(self) -> None:
alpn_protocol,
)
await self.protocol.initiate()
await self._start_idle()
await self.idle_task.restart(task_group, self._idle_timeout)
await self._read_data()
except OSError:
pass
Expand All @@ -85,9 +84,9 @@ async def protocol_send(self, event: Event) -> None:
await self._close()
elif isinstance(event, Updated):
if event.idle:
await self._start_idle()
await self.idle_task.restart(self._task_group, self._idle_timeout)
else:
await self._stop_idle()
await self.idle_task.stop()

async def _read_data(self) -> None:
while not self.reader.at_eof():
Expand Down Expand Up @@ -124,28 +123,13 @@ async def _close(self) -> None:
):
pass # Already closed
finally:
await self._stop_idle()
await self.idle_task.stop()

async def _initiate_server_close(self) -> None:
await self.protocol.handle(Closed())
self.writer.close()

async def _start_idle(self) -> None:
async with self.idle_lock:
if self._idle_handle is None:
self._idle_handle = self.loop.create_task(self._run_idle())

async def _stop_idle(self) -> None:
async with self.idle_lock:
if self._idle_handle is not None:
self._idle_handle.cancel()
try:
await self._idle_handle
except asyncio.CancelledError:
pass
self._idle_handle = None

async def _run_idle(self) -> None:
async def _idle_timeout(self) -> None:
try:
await asyncio.wait_for(self.context.terminated.wait(), self.config.keep_alive_timeout)
except asyncio.TimeoutError:
Expand Down
33 changes: 31 additions & 2 deletions src/hypercorn/asyncio/worker_context.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,37 @@
from __future__ import annotations

import asyncio
from typing import Optional, Type, Union
from typing import Callable, Optional, Type, Union

from ..typing import Event
from ..typing import Event, SingleTask, TaskGroup


class AsyncioSingleTask:
def __init__(self) -> None:
self._handle: Optional[asyncio.Task] = None
self._lock = asyncio.Lock()

async def restart(self, task_group: TaskGroup, action: Callable) -> None:
async with self._lock:
if self._handle is not None:
self._handle.cancel()
try:
await self._handle
except asyncio.CancelledError:
pass

self._handle = task_group._task_group.create_task(action()) # type: ignore

async def stop(self) -> None:
async with self._lock:
if self._handle is not None:
self._handle.cancel()
try:
await self._handle
except asyncio.CancelledError:
pass

self._handle = None


class EventWrapper:
Expand All @@ -25,6 +53,7 @@ def is_set(self) -> bool:

class WorkerContext:
event_class: Type[Event] = EventWrapper
single_task_class: Type[SingleTask] = AsyncioSingleTask

def __init__(self, max_requests: Optional[int]) -> None:
self.max_requests = max_requests
Expand Down
68 changes: 44 additions & 24 deletions src/hypercorn/protocol/quic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from dataclasses import dataclass
from functools import partial
from typing import Awaitable, Callable, Dict, Optional, Tuple
from typing import Awaitable, Callable, Dict, Optional, Set, Tuple

from aioquic.buffer import Buffer
from aioquic.h3.connection import H3_ALPN
Expand All @@ -22,7 +23,15 @@
from .h3 import H3Protocol
from ..config import Config
from ..events import Closed, Event, RawData
from ..typing import AppWrapper, TaskGroup, WorkerContext
from ..typing import AppWrapper, SingleTask, TaskGroup, WorkerContext


@dataclass
class _Connection:
cids: Set[bytes]
quic: QuicConnection
task: SingleTask
h3: Optional[H3Protocol] = None


class QuicProtocol:
Expand All @@ -38,8 +47,7 @@ def __init__(
self.app = app
self.config = config
self.context = context
self.connections: Dict[bytes, QuicConnection] = {}
self.http_connections: Dict[QuicConnection, H3Protocol] = {}
self.connections: Dict[bytes, _Connection] = {}
self.send = send
self.server = server
self.task_group = task_group
Expand All @@ -49,7 +57,7 @@ def __init__(

@property
def idle(self) -> bool:
return len(self.connections) == 0 and len(self.http_connections) == 0
return len(self.connections) == 0

async def handle(self, event: Event) -> None:
if isinstance(event, RawData):
Expand All @@ -76,32 +84,46 @@ async def handle(self, event: Event) -> None:
and header.packet_type == PACKET_TYPE_INITIAL
and not self.context.terminated.is_set()
):
connection = QuicConnection(
quic_connection = QuicConnection(
configuration=self.quic_config,
original_destination_connection_id=header.destination_cid,
)
connection = _Connection(
cids={header.destination_cid, quic_connection.host_cid},
quic=quic_connection,
task=self.context.single_task_class(),
)
self.connections[header.destination_cid] = connection
self.connections[connection.host_cid] = connection
self.connections[quic_connection.host_cid] = connection

if connection is not None:
connection.receive_datagram(event.data, event.address, now=self.context.time())
connection.quic.receive_datagram(event.data, event.address, now=self.context.time())
await self._handle_events(connection, event.address)
elif isinstance(event, Closed):
pass

async def send_all(self, connection: QuicConnection) -> None:
for data, address in connection.datagrams_to_send(now=self.context.time()):
async def send_all(self, connection: _Connection) -> None:
for data, address in connection.quic.datagrams_to_send(now=self.context.time()):
await self.send(RawData(data=data, address=address))

timer = connection.quic.get_timer()
if timer is not None:
await connection.task.restart(
self.task_group, partial(self._handle_timer, timer, connection)
)

async def _handle_events(
self, connection: QuicConnection, client: Optional[Tuple[str, int]] = None
self, connection: _Connection, client: Optional[Tuple[str, int]] = None
) -> None:
event = connection.next_event()
event = connection.quic.next_event()
while event is not None:
if isinstance(event, ConnectionTerminated):
pass
await connection.task.stop()
for cid in connection.cids:
del self.connections[cid]
connection.cids = set()
elif isinstance(event, ProtocolNegotiated):
self.http_connections[connection] = H3Protocol(
connection.h3 = H3Protocol(
self.app,
self.config,
self.context,
Expand All @@ -112,24 +134,22 @@ async def _handle_events(
partial(self.send_all, connection),
)
elif isinstance(event, ConnectionIdIssued):
connection.cids.add(event.connection_id)
self.connections[event.connection_id] = connection
elif isinstance(event, ConnectionIdRetired):
connection.cids.remove(event.connection_id)
del self.connections[event.connection_id]

if connection in self.http_connections:
await self.http_connections[connection].handle(event)
if connection.h3 is not None:
await connection.h3.handle(event)

event = connection.next_event()
event = connection.quic.next_event()

await self.send_all(connection)

timer = connection.get_timer()
if timer is not None:
self.task_group.spawn(self._handle_timer, timer, connection)

async def _handle_timer(self, timer: float, connection: QuicConnection) -> None:
async def _handle_timer(self, timer: float, connection: _Connection) -> None:
wait = max(0, timer - self.context.time())
await self.context.sleep(wait)
if connection._close_at is not None:
connection.handle_timer(now=self.context.time())
if connection.quic._close_at is not None:

This comment has been minimized.

Copy link
@rthalley

rthalley May 27, 2024

This referencing of _close_at is reaching into aioquic's internal state, and more importantly isn't right as you still want to handle timers even when closing.

This comment has been minimized.

Copy link
@pgjones

pgjones May 27, 2024

Author Owner

Would you just call handle_timer instead?

This comment has been minimized.

Copy link
@rthalley

rthalley May 27, 2024

Yeah. That's what happens in my patch.

This comment has been minimized.

Copy link
@pgjones

pgjones May 27, 2024

Author Owner

Thanks, ba3d813

connection.quic.handle_timer(now=self.context.time())
await self._handle_events(connection, None)
45 changes: 13 additions & 32 deletions src/hypercorn/trio/tcp_server.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from __future__ import annotations

from math import inf
from typing import Any, Generator, Optional
from typing import Any, Generator

import trio

from .task_group import TaskGroup
from .worker_context import WorkerContext
from .worker_context import TrioSingleTask, WorkerContext
from ..config import Config
from ..events import Closed, Event, RawData, Updated
from ..protocol import ProtocolWrapper
Expand All @@ -25,11 +25,9 @@ def __init__(
self.context = context
self.protocol: ProtocolWrapper
self.send_lock = trio.Lock()
self.idle_lock = trio.Lock()
self.idle_task = TrioSingleTask()
self.stream = stream

self._idle_handle: Optional[trio.CancelScope] = None

def __await__(self) -> Generator[Any, None, None]:
return self.run().__await__()

Expand Down Expand Up @@ -66,7 +64,7 @@ async def run(self) -> None:
alpn_protocol,
)
await self.protocol.initiate()
await self._start_idle()
await self.idle_task.restart(self._task_group, self._idle_timeout)
await self._read_data()
except OSError:
pass
Expand All @@ -87,9 +85,9 @@ async def protocol_send(self, event: Event) -> None:
await self.protocol.handle(Closed())
elif isinstance(event, Updated):
if event.idle:
await self._start_idle()
await self.idle_task.restart(self._task_group, self._idle_timeout)
else:
await self._stop_idle()
await self.idle_task.stop()

async def _read_data(self) -> None:
while True:
Expand Down Expand Up @@ -122,30 +120,13 @@ async def _close(self) -> None:
pass
await self.stream.aclose()

async def _idle_timeout(self) -> None:
with trio.move_on_after(self.config.keep_alive_timeout):
await self.context.terminated.wait()

with trio.CancelScope(shield=True):
await self._initiate_server_close()

async def _initiate_server_close(self) -> None:
await self.protocol.handle(Closed())
await self.stream.aclose()

async def _start_idle(self) -> None:
async with self.idle_lock:
if self._idle_handle is None:
self._idle_handle = await self._task_group._nursery.start(self._run_idle)

async def _stop_idle(self) -> None:
async with self.idle_lock:
if self._idle_handle is not None:
self._idle_handle.cancel()
self._idle_handle = None

async def _run_idle(
self,
task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED,
) -> None:
cancel_scope = trio.CancelScope()
task_status.started(cancel_scope)
with cancel_scope:
with trio.move_on_after(self.config.keep_alive_timeout):
await self.context.terminated.wait()

cancel_scope.shield = True
await self._initiate_server_close()
Loading

0 comments on commit ab98383

Please sign in to comment.