chore: cleanup

phil65 committed Feb 18, 2025
1 parent 2fd40a8 commit 76c0b76
Showing 1 changed file with 49 additions and 69 deletions.

118 changes: 49 additions & 69 deletions src/fsspec_httpx/filesystem.py
@@ -3,6 +3,8 @@
 from __future__ import annotations
 
 import asyncio
+from collections.abc import Mapping
+import contextlib
 from copy import copy
 import io
 import logging
@@ -69,22 +71,7 @@ def __init__(
         encoded: bool = False,
         **storage_options,
     ) -> None:
-        """Initialize the filesystem.
-        Parameters
-        ----------
-        simple_links : bool
-            If True, will consider both HTML <a> tags and anything that looks
-            like a URL; if False, will consider only the former.
-        block_size : int | None
-            Blocks to read bytes; if 0, will default to raw requests file-like
-            objects instead of HTTPFile instances
-        same_scheme : bool
-            When doing ls/glob, if True, only consider paths that have
-            http/https matching the input URLs.
-        client_kwargs : dict | None
-            Passed to httpx.AsyncClient
-        """
+        """Initialize the filesystem."""
         super().__init__(asynchronous=asynchronous, loop=loop, **storage_options)
         self.block_size = block_size or DEFAULT_BLOCK_SIZE
         self.simple_links = simple_links
@@ -153,12 +140,14 @@ async def _get_decompressor(
             return zlib.decompress
         if encoding == "br":
             try:
-                import brotli
-
-                return brotli.decompress
+                import brotli  # pyright: ignore
             except ImportError:
                 msg = "brotli module is required for brotli decompression"
-                raise ImportError(msg)
+                raise ImportError(msg)  # noqa: B904
+            else:
+                return brotli.decompress
 
         return None
@@ -176,10 +165,7 @@ async def _ls_real(self, url: str, detail: bool = True, **kwargs: Any) -> list | dict:
         out = set()
 
         # Extract links
-        if self.simple_links:
-            links = URL_PATTERN.findall(text)
-        else:
-            links = []
+        links = URL_PATTERN.findall(text) if self.simple_links else []
 
         href_matches = HREF_PATTERN.findall(text)
         links.extend(m[2] for m in href_matches)
Expand Down Expand Up @@ -231,7 +217,7 @@ async def _ls(self, url: str, detail: bool = True, **kwargs: Any) -> list | dict
try:
out = await self._ls_real(url, detail=detail, **kwargs)
if not out:
raise FileNotFoundError(url)
raise FileNotFoundError(url) # noqa: TRY301
if self.use_listings_cache:
self.dircache[url] = out
except Exception as e:
@@ -245,7 +231,7 @@ async def _ls(self, url: str, detail: bool = True, **kwargs: Any) -> list | dict
 
     def _raise_not_found_for_status(self, response: httpx.Response, url: str) -> None:
         """Raise FileNotFoundError for 404s, otherwise raises HTTP errors."""
-        if response.status_code == 404:
+        if response.status_code == 404:  # noqa: PLR2004
             raise FileNotFoundError(url)
         response.raise_for_status()
 
@@ -290,10 +276,7 @@ async def _get_file(
         callback.set_size(size)
         self._raise_not_found_for_status(r, rpath)
 
-        if isfilelike(lpath):
-            outfile = lpath
-        else:
-            outfile = open(lpath, "wb")
+        outfile = lpath if isfilelike(lpath) else open(lpath, "wb")
 
         try:
             async for chunk in r.aiter_bytes(chunk_size):
@@ -314,14 +297,15 @@ async def _put_file(
         **kwargs: Any,
     ) -> None:
         if mode != "overwrite":
-            raise NotImplementedError("Only 'overwrite' mode is supported")
+            msg = "Only 'overwrite' mode is supported"
+            raise NotImplementedError(msg)
 
         def gen_chunks():
             if isinstance(lpath, io.IOBase):
                 context = nullcontext(lpath)
                 use_seek = False
             else:
-                context = open(lpath, "rb")
+                context = open(lpath, "rb")  # noqa: SIM115
                 use_seek = True
 
             with context as f:
@@ -356,9 +340,10 @@ async def _exists(self, path: str, **kwargs: Any) -> bool:
             logger.debug(path)
             session = await self.set_session()
             r = await session.get(str(self.encode_url(path)), **kw)
-            return r.status_code < 400
         except httpx.RequestError:
             return False
+        else:
+            return r.status_code < 400  # noqa: PLR2004
 
     async def _isfile(self, path: str, **kwargs: Any) -> bool:
         return await self._exists(path, **kwargs)
@@ -376,7 +361,8 @@ def _open(
     ) -> HTTPFile | HTTPStreamFile:
         """Create a file-like object."""
         if mode != "rb":
-            raise NotImplementedError("Write mode not supported")
+            msg = "Write mode not supported"
+            raise NotImplementedError(msg)
 
         block_size = block_size if block_size is not None else self.block_size
         kw = self.kwargs.copy()
@@ -413,7 +399,7 @@ def _open(
                 loop=self.loop,
                 **kw,
             )
-        except Exception:
+        except Exception:  # noqa: BLE001
             pass
 
         # Default to streaming
@@ -431,10 +417,8 @@ async def open_async(
     ) -> AsyncStreamFile:
         session = await self.set_session()
         if size is None:
-            try:
+            with contextlib.suppress(FileNotFoundError):
                 size = (await self._info(path, **kwargs))["size"]
-            except FileNotFoundError:
-                pass
         return AsyncStreamFile(
             self,
             path,
@@ -476,7 +460,8 @@ async def _info(self, path: str, **kwargs: Any) -> dict[str, Any]:
     async def _glob(self, path: str, maxdepth: int | None = None, **kwargs):
         """Find files by glob-matching."""
         if maxdepth is not None and maxdepth < 1:
-            raise ValueError("maxdepth must be at least 1")
+            msg = "maxdepth must be at least 1"
+            raise ValueError(msg)
 
         ends_with_slash = path.endswith("/")
         path = self._strip_protocol(path)
@@ -499,7 +484,7 @@ async def _glob(self, path: str, maxdepth: int | None = None, **kwargs):
         if "/" in path[:min_idx]:
             min_idx = path[:min_idx].rindex("/")
             root = path[: min_idx + 1]
-            depth = path[min_idx + 1 :].count("/") + 1
+            depth: int | None = path[min_idx + 1 :].count("/") + 1
         else:
             root = ""
             depth = path[min_idx + 1 :].count("/") + 1
@@ -510,7 +495,7 @@ async def _glob(self, path: str, maxdepth: int | None = None, **kwargs):
                 depth_double_stars = path[idx_double_stars:].count("/") + 1
                 depth = depth - depth_double_stars + maxdepth
             else:
-                depth = None
+                depth = None  # type: ignore
 
         allpaths = await self._find(
             root, maxdepth=depth, withdirs=True, detail=True, **kwargs
@@ -527,7 +512,7 @@ async def _glob(self, path: str, maxdepth: int | None = None, **kwargs):
                 and p.endswith("/")
                 else p
             ): info
-            for p, info in sorted(allpaths.items())
+            for p, info in sorted(allpaths.items())  # type: ignore
             if pattern.match(p.rstrip("/"))
         }
 
@@ -595,7 +580,6 @@ def read(self, length: int = -1) -> bytes:
 
     async def async_fetch_all(self) -> None:
         """Read whole file in one shot, without caching."""
-        logger.debug(f"Fetch all for {self}")
         if not isinstance(self.cache, AllBytes):
             r = await self.session.get(str(self.fs.encode_url(self.url)), **self.kwargs)
             r.raise_for_status()
@@ -610,7 +594,7 @@ async def async_fetch_all(self) -> None:
 
     _fetch_all = sync_wrapper(async_fetch_all)
 
-    def _parse_content_range(self, headers: dict[str, str]) -> tuple[int | None, ...]:
+    def _parse_content_range(self, headers: Mapping[str, str]) -> tuple[int | None, ...]:
         """Parse the Content-Range header."""
         content_range = headers.get("Content-Range", "")
         match = re.match(r"bytes (\d+-\d+|\*)/(\d+|\*)", content_range)
@@ -626,27 +610,25 @@ def _parse_content_range(self, headers: dict[str, str]) -> tuple[int | None, ...
 
     async def async_fetch_range(self, start: int, end: int) -> bytes:
         """Download a block of data."""
-        logger.debug(f"Fetch range for {self}: {start}-{end}")
         kwargs = self.kwargs.copy()
         headers = kwargs.pop("headers", {}).copy()
         headers["Range"] = f"bytes={start}-{end - 1}"
-        logger.debug(f"{self.url} : {headers['Range']}")
 
         r = await self.session.get(
             str(self.fs.encode_url(self.url)),
             headers=headers,
             **kwargs,
         )
 
-        if r.status_code == 416:
+        if r.status_code == 416:  # noqa: PLR2004
             # Range request outside file
             return b""
 
         r.raise_for_status()
 
         # Check if server handled range request correctly
         response_is_range = (
-            r.status_code == 206
+            r.status_code == 206  # noqa: PLR2004
             or self._parse_content_range(r.headers)[0] == start
             or int(r.headers.get("Content-Length", end + 1)) <= end - start
         )
@@ -655,11 +637,12 @@ async def async_fetch_range(self, start: int, end: int) -> bytes:
             # Partial content, as expected
             return r.content
         if start > 0:
-            raise ValueError(
+            msg = (
                 "The HTTP server doesn't support range requests. "
                 "Only reading from the beginning is supported. "
                 "Open with block_size=0 for a streaming file interface."
             )
+            raise ValueError(msg)
         # Response is not a range, but we want the start of the file
         content = []
         total_bytes = 0
@@ -691,7 +674,8 @@ def __init__(
         self._stream = None
 
         if mode != "rb":
-            raise ValueError("Write mode not supported")
+            msg = "Write mode not supported"
+            raise ValueError(msg)
 
         self.details = {"name": url, "size": None}
         super().__init__(fs=fs, path=url, mode=mode, cache_type="none", **kwargs)
@@ -706,16 +690,19 @@ async def _init():
     def seek(self, loc: int, whence: int = 0) -> int:
         """Seek to position in file."""
         if not self.seekable():
-            raise ValueError("Stream is not seekable")
+            msg = "Stream is not seekable"
+            raise ValueError(msg)
 
         if whence == 1:  # SEEK_CUR
             loc = self.loc + loc
-        elif whence == 2:  # SEEK_END
-            raise ValueError("Cannot seek from end in streaming file")
+        elif whence == 2:  # SEEK_END  # noqa: PLR2004
+            msg = "Cannot seek from end in streaming file"
+            raise ValueError(msg)
 
         # SEEK_SET or converted SEEK_CUR
         if loc < 0:
-            raise ValueError("Cannot seek before start of file")
+            msg = "Cannot seek before start of file"
+            raise ValueError(msg)
 
         if loc == self.loc:
             return self.loc
@@ -728,13 +715,15 @@ def seek(self, loc: int, whence: int = 0) -> int:
                 self._stream = None
                 self.loc = 0
                 return 0
-            raise ValueError("Cannot seek backwards except to start")
+            msg = "Cannot seek backwards except to start"
+            raise ValueError(msg)
 
         # Check for explicit range support
         headers = self.kwargs.get("headers", {})
         if not headers or headers.get("accept_range") == "none":
             # Either no headers (default) or explicitly disabled ranges
-            raise ValueError("Random access not supported with streaming file")
+            msg = "Random access not supported with streaming file"
+            raise ValueError(msg)
 
         # For forward seeks within buffered data
         if self._content_buffer and loc <= len(self._content_buffer):
@@ -757,7 +746,7 @@ async def _read(self, num: int = -1) -> bytes:
             # Read all remaining data
             chunks = [self._content_buffer]
             async for chunk in self._stream:
-                chunks.append(chunk)
+                chunks.append(chunk)  # noqa: PERF401
             self._content_buffer = b""
             data = b"".join(chunks)
             self.loc += len(data)
@@ -792,9 +781,10 @@ async def _read(self, num: int = -1) -> bytes:
         self.loc += num
         return data[:num]
 
-    read = sync_wrapper(_read)
+    read = sync_wrapper(_read)  # type: ignore
 
     async def _close(self) -> None:
+        assert self.r
         await self.r.aclose()
 
     def close(self) -> None:
@@ -817,7 +807,7 @@ def __init__(
     ) -> None:
         self.url = url
         self.session = session
-        self.r = None
+        self.r: httpx.Response | None = None
         if mode != "rb":
             raise ValueError("Write mode not supported")
 
@@ -875,17 +865,7 @@ async def _file_info(
     size_policy: str = "head",
     **kwargs: Any,
 ) -> dict[str, Any]:
-    """Get details about the file (size/checksum etc.).
-    Parameters
-    ----------
-    url : str
-        File URL
-    session : httpx.AsyncClient
-        HTTP client session
-    size_policy : str
-        Either 'head' or 'get' to determine how to get file size
-    """
+    """Get details about the file (size/checksum etc)."""
     logger.debug("Retrieve file size for %s", url)
     kwargs = kwargs.copy()
     ar = kwargs.pop("allow_redirects", True)
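The recurring shape in this cleanup is lint-driven: return statements move out of try bodies into else blocks, exception messages move into a msg variable before raise, and try/except FileNotFoundError: pass becomes contextlib.suppress. A minimal, self-contained sketch of the two control-flow patterns; probe_size here is a hypothetical stand-in for HTTPFileSystem._info(), not part of the diff:

    import contextlib

    def probe_size(path: str) -> int:
        # Hypothetical helper: always raises, like a probe of a missing file.
        raise FileNotFoundError(path)

    def fetch_size(path: str) -> int:
        try:
            size = probe_size(path)
        except FileNotFoundError:
            return -1
        else:
            # Runs only if the try body did not raise, mirroring the
            # _exists() and _get_decompressor() rewrites above.
            return size

    size: int | None = None
    with contextlib.suppress(FileNotFoundError):
        # Equivalent to try/except/pass, as in the open_async() hunk.
        size = probe_size("https://example.com/missing.bin")
    assert size is None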
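For reference, _parse_content_range, whose signature widens from dict[str, str] to Mapping[str, str] so httpx's header object also type-checks, parses headers like "Content-Range: bytes 0-499/1234". A standalone sketch using the same regex as the diff; the function wrapper and return shape here are illustrative, not the library's exact API:

    import re
    from collections.abc import Mapping

    def parse_content_range(headers: Mapping[str, str]) -> tuple[int | None, ...]:
        # "bytes 0-499/1234" -> (0, 499, 1234); "*" marks unknown fields.
        content_range = headers.get("Content-Range", "")
        match = re.match(r"bytes (\d+-\d+|\*)/(\d+|\*)", content_range)
        if not match:
            return (None, None, None)
        span, total = match.group(1), match.group(2)
        start = end = None
        if span != "*":
            start_str, end_str = span.split("-")
            start, end = int(start_str), int(end_str)
        return (start, end, None if total == "*" else int(total))

    print(parse_content_range({"Content-Range": "bytes 0-499/1234"}))  # (0, 499, 1234)
    print(parse_content_range({"Content-Range": "bytes */1234"}))      # (None, None, 1234)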