Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: don't ignore empty dirs when unpacking model and bento #5073

Merged
merged 1 commit into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions src/bentoml/_internal/cloud/bento.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import tarfile
import typing as t
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from tempfile import NamedTemporaryFile

import attrs
Expand All @@ -19,6 +18,7 @@
from ..bento import BentoStore
from ..configuration.containers import BentoMLContainer
from ..tag import Tag
from ..utils.filesystem import safe_extract_tarfile
from .base import FILE_CHUNK_SIZE
from .base import UPLOAD_RETRY_COUNT
from .base import CallbackIOWrapper
Expand Down Expand Up @@ -520,14 +520,7 @@ def _do_pull_bento(
tar = tarfile.open(fileobj=tar_file, mode="r")
with self.spinner.spin(text=f'Extracting bento "{_tag}" tar file'):
with fs.open_fs("temp://") as temp_fs:
for member in tar.getmembers():
f = tar.extractfile(member)
if f is None:
continue
p = Path(member.name)
if p.parent != Path("."):
temp_fs.makedirs(p.parent.as_posix(), recreate=True)
temp_fs.writebytes(member.name, f.read())
safe_extract_tarfile(tar, temp_fs.getsyspath("/"))
bento = Bento.from_fs(temp_fs)
bento = bento.save(bento_store)
self.spinner.log(f'[bold green]Successfully pulled bento "{_tag}"')
Expand Down
11 changes: 2 additions & 9 deletions src/bentoml/_internal/cloud/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import typing as t
import warnings
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from tempfile import NamedTemporaryFile

import attrs
Expand All @@ -20,6 +19,7 @@
from ..models import Model as StoredModel
from ..models import ModelStore
from ..tag import Tag
from ..utils.filesystem import safe_extract_tarfile
from .base import FILE_CHUNK_SIZE
from .base import UPLOAD_RETRY_COUNT
from .base import CallbackIOWrapper
Expand Down Expand Up @@ -482,14 +482,7 @@ def _do_pull_model(
tar = tarfile.open(fileobj=tar_file, mode="r")
with self.spinner.spin(text=f'Extracting model "{_tag}" tar file'):
with fs.open_fs("temp://") as temp_fs:
for member in tar.getmembers():
f = tar.extractfile(member)
if f is None:
continue
p = Path(member.name)
if p.parent != Path("."):
temp_fs.makedirs(str(p.parent), recreate=True)
temp_fs.writebytes(member.name, f.read())
safe_extract_tarfile(tar, temp_fs.getsyspath("/"))
model = StoredModel.from_fs(temp_fs).save(model_store)
self.spinner.log(f'[bold green]Successfully pulled model "{_tag}"')
return model
Expand Down
2 changes: 1 addition & 1 deletion src/bentoml/_internal/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
import starlette.datastructures
from starlette.background import BackgroundTasks

from .utils.filesystem import TempfilePool
from .utils.http import Cookie
from .utils.temp import TempfilePool

if TYPE_CHECKING:
import starlette.requests
Expand Down
93 changes: 93 additions & 0 deletions src/bentoml/_internal/utils/filesystem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from __future__ import annotations

import logging
import os
import shutil
import tarfile
import tempfile
from collections import deque
from functools import partial
from pathlib import Path
from threading import Lock

import fs

# Module-level logger; safe_extract_tarfile uses it to report archive members
# that are skipped instead of aborting the whole extraction.
logger = logging.getLogger(__name__)


class TempfilePool:
    """Pool of temporary directories that hands out previously released ones.

    A directory is created with ``tempfile.mkdtemp`` only when the pool has
    nothing to offer; released directories are emptied and queued for reuse.
    """

    def __init__(
        self,
        suffix: str | None = None,
        prefix: str | None = None,
        dir: str | None = None,
    ) -> None:
        # Factory for fresh directories, pre-bound to the naming options.
        self._new = partial(tempfile.mkdtemp, suffix=suffix, prefix=prefix, dir=dir)
        self._lock = Lock()
        self._pool: deque[str] = deque([])

    def cleanup(self) -> None:
        """Delete every directory currently sitting in the pool."""
        while self._pool:
            shutil.rmtree(self._pool.popleft(), ignore_errors=True)

    def acquire(self) -> str:
        """Return a pooled directory path, creating a new one if none is free."""
        with self._lock:
            if self._pool:
                return self._pool.popleft()
            return self._new()

    def release(self, dir: str) -> None:
        """Empty ``dir`` and put it back into the pool for later reuse."""
        for entry in Path(dir).iterdir():
            if not entry.is_dir():
                entry.unlink()
            else:
                shutil.rmtree(entry)
        with self._lock:
            self._pool.append(dir)


def _is_within_directory(directory: str, target: str) -> bool:
    """Return True if ``target`` resolves to ``directory`` or a path below it."""
    # realpath (rather than abspath) so that ".." components AND symlinks that
    # were extracted earlier are both resolved before comparing.
    root = os.path.realpath(directory)
    resolved = os.path.realpath(target)
    try:
        common = os.path.commonpath([root, resolved])
    except ValueError:
        # Raised e.g. on Windows when the two paths are on different drives.
        return False
    return os.path.normcase(common) == os.path.normcase(root)


def safe_extract_tarfile(tar: tarfile.TarFile, destination: str) -> None:
    """Extract ``tar`` into ``destination``, skipping unsafe or broken members.

    Directories (including empty ones) are created, symlinks are extracted via
    ``TarFile``'s member extraction, and regular files are copied out and get
    their timestamp restored. Problematic members are logged with a warning
    and skipped instead of aborting the extraction.

    Args:
        tar: An opened tar archive to read members from.
        destination: Directory to extract into; created if it does not exist.
    """
    # Borrowed from pip but continue on error
    os.makedirs(destination, exist_ok=True)
    for member in tar.getmembers():
        fn = member.name
        path = os.path.join(destination, fn)
        # Reject members whose resolved location escapes the destination
        # ("../x", absolute names, or paths routed through a symlink).
        if not _is_within_directory(destination, path):
            logger.warning(
                "The tar file has a file (%s) trying to unpack to "
                "outside target directory",
                fn,
            )
            continue
        if member.isdir():
            os.makedirs(path, exist_ok=True)
        elif member.issym():
            try:
                tar._extract_member(member, path)
            except Exception as exc:
                # Some corrupt tar files seem to produce this
                # (specifically bad symlinks)
                logger.warning("In the tar file the member %s is invalid: %s", fn, exc)
                continue
        else:
            try:
                fp = tar.extractfile(member)
            except (KeyError, AttributeError) as exc:
                # Some corrupt tar files seem to produce this
                # (specifically bad symlinks)
                logger.warning("In the tar file the member %s is invalid: %s", fn, exc)
                continue
            os.makedirs(os.path.dirname(path), exist_ok=True)
            if fp is None:
                # Not a regular file or link (nothing to copy out).
                continue
            # Close the member stream even if opening/writing the target fails.
            with fp, open(path, "wb") as destfp:
                shutil.copyfileobj(fp, destfp)
            # Update the timestamp (useful for cython compiled files)
            tar.utime(member, path)
45 changes: 0 additions & 45 deletions src/bentoml/_internal/utils/temp.py

This file was deleted.

Loading