Skip to content

Commit

Permalink
Some backports + gc for memory issue debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
mephenor committed Dec 17, 2024
1 parent 94b9452 commit 40bf5d0
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 7 deletions.
20 changes: 17 additions & 3 deletions src/ghga_connector/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from ghga_connector.core.downloading.batch_processing import FileStager
from ghga_connector.core.main import (
decrypt_file,
download_files,
download_file,
get_wps_token,
upload_file,
)
Expand Down Expand Up @@ -260,18 +260,30 @@ def download(
debug: bool = typer.Option(
False, help="Set this option in order to view traceback for errors."
),
overwrite: bool = typer.Option(
False,
help="Set to true to overwrite already existing files in the output directory.",
),
):
"""Wrapper for the async download function"""
asyncio.run(
async_download(output_dir, my_public_key_path, my_private_key_path, debug)
async_download(
output_dir=output_dir,
my_public_key_path=my_public_key_path,
my_private_key_path=my_private_key_path,
debug=debug,
overwrite=overwrite,
)
)


async def async_download(
*,
output_dir: Path,
my_public_key_path: Path,
my_private_key_path: Path,
debug: bool = False,
overwrite: bool = False,
):
"""Download files asynchronously"""
if not my_public_key_path.is_file():
Expand Down Expand Up @@ -305,6 +317,7 @@ async def async_download(
work_package_information=work_package_information,
)

message_display.display("Preparing files for download...")
stager = FileStager(
wanted_file_ids=list(parameters.file_ids_with_extension),
dcs_api_url=parameters.dcs_api_url,
Expand All @@ -318,7 +331,7 @@ async def async_download(
staged_files = await stager.get_staged_files()
for file_id in staged_files:
message_display.display(f"Downloading file with id '{file_id}'...")
await download_files(
await download_file(
api_url=parameters.dcs_api_url,
client=client,
file_id=file_id,
Expand All @@ -329,6 +342,7 @@ async def async_download(
part_size=CONFIG.part_size,
message_display=message_display,
work_package_accessor=parameters.work_package_accessor,
overwrite=overwrite,
)
staged_files.clear()

Expand Down
7 changes: 7 additions & 0 deletions src/ghga_connector/core/downloading/batch_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,14 @@ async def get_staged_files(self) -> dict[str, URLResponse]:
These values contain the download URLs and file sizes.
The dict should cleared after these files have been downloaded.
"""
self.message_display.display("Updating list of staged files...")
staging_items = list(self.unstaged_retry_times.items())
for file_id, retry_time in staging_items:
if time() >= retry_time:
await self._check_file(file_id=file_id)
if len(self.staged_urls.items()) > 0:
self.started_waiting = time() # reset wait timer
break
if not self.staged_urls and not self._handle_failures():
sleep(1)
self._check_timeout()
Expand Down Expand Up @@ -217,8 +221,10 @@ async def _check_file(self, file_id: str) -> None:
if isinstance(response, URLResponse):
del self.unstaged_retry_times[file_id]
self.staged_urls[file_id] = response
self.message_display.display(f"File {file_id} is ready for download.")
elif isinstance(response, RetryResponse):
self.unstaged_retry_times[file_id] = time() + response.retry_after
self.message_display.display(f"File {file_id} is (still) being staged.")
else:
self.missing_files.append(file_id)

Expand Down Expand Up @@ -251,4 +257,5 @@ def _handle_failures(self) -> bool:
self.io_handler.handle_response(response=response)
self.message_display.display("Downloading remaining files")
self.time_started = time() # reset the timer
self.missing_files = [] # reset list of missing files
return True
12 changes: 10 additions & 2 deletions src/ghga_connector/core/downloading/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
"""Contains a concrete implementation of the abstract downloader"""

import base64
from asyncio import Queue, Semaphore, Task, create_task
import gc
from asyncio import PriorityQueue, Queue, Semaphore, Task, create_task
from collections.abc import Coroutine
from io import BufferedWriter
from pathlib import Path
Expand Down Expand Up @@ -79,7 +80,7 @@ def __init__( # noqa: PLR0913
self._max_wait_time = max_wait_time
self._message_display = message_display
self._work_package_accessor = work_package_accessor
self._queue: Queue[Union[tuple[int, bytes], BaseException]] = Queue()
self._queue: Queue[Union[tuple[int, bytes], BaseException]] = PriorityQueue()
self._semaphore = Semaphore(value=max_concurrent_downloads)
self._retry_handler = HttpxClientConfigurator.retry_handler

Expand Down Expand Up @@ -135,10 +136,16 @@ async def await_download_url(self) -> URLResponse:
wait_time = 0
while wait_time < self._max_wait_time:
try:
self._message_display.display(
f"Fetching work order token for {self._file_id}"
)
url_and_headers = await get_file_authorization(
file_id=self._file_id,
work_package_accessor=self._work_package_accessor,
)
self._message_display.display(
f"Fetching download URL for {self._file_id}"
)
response = await get_download_url(
client=self._client, url_and_headers=url_and_headers
)
Expand Down Expand Up @@ -302,3 +309,4 @@ async def drain_queue_to_file(
downloaded_size += chunk_size
self._queue.task_done()
progress.advance(chunk_size)
gc.collect()
13 changes: 11 additions & 2 deletions src/ghga_connector/core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ async def upload_file( # noqa: PLR0913
message_display.success(f"File with id '{file_id}' has been successfully uploaded.")


async def download_files( # noqa: PLR0913
async def download_file( # noqa: PLR0913
*,
api_url: str,
client: httpx.AsyncClient,
Expand All @@ -105,6 +105,7 @@ async def download_files( # noqa: PLR0913
work_package_accessor: WorkPackageAccessor,
file_id: str,
file_extension: str = "",
overwrite: bool = False,
) -> None:
"""Core command to download a file. Can be called by CLI, GUI, etc."""
if not is_service_healthy(api_url):
Expand All @@ -118,7 +119,15 @@ async def download_files( # noqa: PLR0913
# check output file
output_file = output_dir / f"{file_name}.c4gh"
if output_file.exists():
raise exceptions.FileAlreadyExistsError(output_file=str(output_file))
if overwrite:
message_display.display(
f"A file with name '{output_file}' already exists and will be overwritten."
)
else:
message_display.failure(
f"A file with name '{output_file}' already exists. Skipping."
)
return

# with_suffix() might overwrite existing suffixes, do this instead
output_file_ongoing = output_file.parent / (output_file.name + ".part")
Expand Down

0 comments on commit 40bf5d0

Please sign in to comment.