Skip to content

Commit

Permalink
Ensure user passed session receives headers, fixes #476
Browse files Browse the repository at this point in the history
Only close non-user passed sessions, fixes #477
  • Loading branch information
EvieePy committed Dec 23, 2024
1 parent 44c93a9 commit 587d9c3
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
4 changes: 2 additions & 2 deletions twitchio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from .models.teams import Team
from .payloads import EventErrorPayload
from .user import ActiveExtensions, Extension, PartialUser, User
from .utils import EventWaiter, unwrap_function
from .utils import MISSING, EventWaiter, unwrap_function
from .web import AiohttpAdapter
from .web.utils import BaseAdapter

Expand Down Expand Up @@ -124,7 +124,7 @@ def __init__(
) -> None:
redirect_uri: str | None = options.get("redirect_uri")
scopes: Scopes | None = options.get("scopes")
session: aiohttp.ClientSession | None = options.get("session")
session: aiohttp.ClientSession = options.get("session", MISSING) or MISSING
self._bot_id: str | None = bot_id

self._http = ManagedHTTPClient(
Expand Down
26 changes: 14 additions & 12 deletions twitchio/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
from .models.subscriptions import BroadcasterSubscription, BroadcasterSubscriptions
from .models.videos import Video
from .user import ActiveExtensions, PartialUser
from .utils import Colour, _from_json, date_to_datetime_with_z, handle_user_ids, url_encode_datetime # type: ignore
from .utils import MISSING, Colour, _from_json, date_to_datetime_with_z, handle_user_ids, url_encode_datetime # type: ignore


if TYPE_CHECKING:
Expand Down Expand Up @@ -377,10 +377,12 @@ async def __anext__(self) -> T:


class HTTPClient:
__slots__ = ("_client_id", "_session", "user_agent")
__slots__ = ("_client_id", "_session", "user_agent", "_should_close")

def __init__(self, session: aiohttp.ClientSession = MISSING, *, client_id: str) -> None:
self._session: aiohttp.ClientSession = session
self._should_close: bool = session is MISSING

def __init__(self, session: aiohttp.ClientSession | None = None, *, client_id: str) -> None:
self._session: aiohttp.ClientSession | None = session # should be set on the first request
self._client_id: str = client_id

# User Agent...
Expand All @@ -393,24 +395,24 @@ def headers(self) -> dict[str, str]:
return {"User-Agent": self.user_agent, "Client-ID": self._client_id}

async def _init_session(self) -> None:
if self._session and not self._session.closed:
if self._session is not MISSING:
self._session.headers.update(self.headers)
return

logger.debug("Initialising a new session on %s.", self.__class__.__qualname__)

session = self._session or aiohttp.ClientSession()
session.headers.update(self.headers)

self._session = session
logger.debug("Initialising ClientSession on %s.", self.__class__.__qualname__)
self._session = aiohttp.ClientSession(headers=self.headers)

def clear(self) -> None:
if self._session and self._session.closed:
logger.debug(
"Clearing %s session. A new session will be created on the next request.", self.__class__.__qualname__
)
self._session = None
self._session = MISSING

async def close(self) -> None:
if not self._should_close:
return

if self._session and not self._session.closed:
try:
await self._session.close()
Expand Down

0 comments on commit 587d9c3

Please sign in to comment.