diff --git a/twitchio/client.py b/twitchio/client.py index 47e06e72..6c90dc41 100644 --- a/twitchio/client.py +++ b/twitchio/client.py @@ -43,7 +43,7 @@ from .models.teams import Team from .payloads import EventErrorPayload from .user import ActiveExtensions, Extension, PartialUser, User -from .utils import EventWaiter, unwrap_function +from .utils import MISSING, EventWaiter, unwrap_function from .web import AiohttpAdapter from .web.utils import BaseAdapter @@ -124,7 +124,7 @@ def __init__( ) -> None: redirect_uri: str | None = options.get("redirect_uri") scopes: Scopes | None = options.get("scopes") - session: aiohttp.ClientSession | None = options.get("session") + session: aiohttp.ClientSession = options.get("session", MISSING) or MISSING self._bot_id: str | None = bot_id self._http = ManagedHTTPClient( diff --git a/twitchio/http.py b/twitchio/http.py index 329893f4..38fd00c6 100644 --- a/twitchio/http.py +++ b/twitchio/http.py @@ -57,7 +57,7 @@ from .models.subscriptions import BroadcasterSubscription, BroadcasterSubscriptions from .models.videos import Video from .user import ActiveExtensions, PartialUser -from .utils import Colour, _from_json, date_to_datetime_with_z, handle_user_ids, url_encode_datetime # type: ignore +from .utils import MISSING, Colour, _from_json, date_to_datetime_with_z, handle_user_ids, url_encode_datetime # type: ignore if TYPE_CHECKING: @@ -377,10 +377,12 @@ async def __anext__(self) -> T: class HTTPClient: - __slots__ = ("_client_id", "_session", "user_agent") + __slots__ = ("_client_id", "_session", "user_agent", "_should_close") + + def __init__(self, session: aiohttp.ClientSession = MISSING, *, client_id: str) -> None: + self._session: aiohttp.ClientSession = session + self._should_close: bool = session is MISSING - def __init__(self, session: aiohttp.ClientSession | None = None, *, client_id: str) -> None: - self._session: aiohttp.ClientSession | None = session # should be set on the first request self._client_id: str = client_id # User Agent... @@ -393,24 +395,24 @@ def headers(self) -> dict[str, str]: return {"User-Agent": self.user_agent, "Client-ID": self._client_id} async def _init_session(self) -> None: - if self._session and not self._session.closed: + if self._session is not MISSING: + self._session.headers.update(self.headers) return - logger.debug("Initialising a new session on %s.", self.__class__.__qualname__) - - session = self._session or aiohttp.ClientSession() - session.headers.update(self.headers) - - self._session = session + logger.debug("Initialising ClientSession on %s.", self.__class__.__qualname__) + self._session = aiohttp.ClientSession(headers=self.headers) def clear(self) -> None: if self._session and self._session.closed: logger.debug( "Clearing %s session. A new session will be created on the next request.", self.__class__.__qualname__ ) - self._session = None + self._session = MISSING async def close(self) -> None: + if not self._should_close: + return + if self._session and not self._session.closed: try: await self._session.close()