Add retry logic to each batch method of the GCS IO #33539

Open · wants to merge 4 commits into base: master
1 change: 1 addition & 0 deletions CHANGES.md
@@ -68,6 +68,7 @@

* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* Upgraded to protobuf 4 (Java) ([#33192](https://github.com/apache/beam/issues/33192)).
* [GCSIO] Added retry logic to each batch method of the GCS IO (Python) ([#33539](https://github.com/apache/beam/pull/33539)).

## Breaking Changes

75 changes: 53 additions & 22 deletions sdks/python/apache_beam/io/gcp/gcsio.py
@@ -35,8 +35,10 @@
from typing import Optional
from typing import Union

from google.api_core.exceptions import RetryError
from google.cloud import storage
from google.cloud.exceptions import NotFound
from google.cloud.exceptions import from_http_response
from google.cloud.storage.fileio import BlobReader
from google.cloud.storage.fileio import BlobWriter
from google.cloud.storage.retry import DEFAULT_RETRY
@@ -264,9 +266,45 @@ def delete(self, path):
  except NotFound:
    return

def _batch_with_retry(self, requests, fn):
  # Pair each request with its original index so responses can be filled
  # in the caller's order; only still-failing requests survive into the
  # next retry attempt.
  current_requests = [*enumerate(requests)]
  responses = [None for _ in current_requests]

  @self._storage_client_retry
  def run_with_retry():
    current_batch = self.client.batch(raise_exception=False)
    with current_batch:
      for _, request in current_requests:
        fn(request)
    last_retryable_exception = None
    # Iterate over a copy: successful requests are removed from
    # current_requests while walking the batch responses.
    for (i, current_pair), response in zip(
        [*current_requests], current_batch._responses
    ):
      responses[i] = response
      should_retry = (
          response.status_code >= 400 and
          self._storage_client_retry._predicate(from_http_response(response)))
      if should_retry:
        last_retryable_exception = from_http_response(response)
      else:
        current_requests.remove((i, current_pair))
    # Re-raising a retryable error makes the decorator re-invoke this
    # function with only the failed requests.
    if last_retryable_exception:
      raise last_retryable_exception

  try:
    run_with_retry()
  except RetryError:
    # Retry deadline exhausted: return the last responses (with their
    # error codes) instead of raising.
    pass

  return responses
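
For context, `_storage_client_retry` is the storage client's `DEFAULT_RETRY`-based policy, i.e. a `google.api_core.retry.Retry`. Used as a decorator, it re-invokes the wrapped callable with exponential backoff whenever the callable raises an exception matching its predicate, and raises `RetryError` once the overall deadline is spent. A minimal standalone sketch of that mechanism (parameter values are illustrative, not from this PR):

```python
from google.api_core.exceptions import TooManyRequests
from google.api_core.retry import Retry

# Retry only rate-limit errors; values here are illustrative.
retry = Retry(
    predicate=lambda exc: isinstance(exc, TooManyRequests),
    initial=1.0,     # first backoff, seconds
    multiplier=2.0,  # exponential growth factor
    maximum=10.0,    # cap on a single backoff
    timeout=60.0,    # total budget before RetryError is raised
)

attempts = {'n': 0}

@retry
def flaky_call():
  attempts['n'] += 1
  if attempts['n'] < 3:
    raise TooManyRequests('slow down')  # matching error: try again
  return 'ok'

print(flaky_call())  # succeeds on the third attempt
```

Raising the last retryable sub-response's exception from `run_with_retry` is what hooks the batch loop into this machinery: each re-invocation re-issues only the requests still left in `current_requests`.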

def _delete_batch_request(self, path):
  # Called inside a client.batch() context, so the delete is queued and
  # sent when the batch executes.
  bucket_name, blob_name = parse_gcs_path(path)
  bucket = self.client.bucket(bucket_name)
  bucket.delete_blob(blob_name)

def delete_batch(self, paths):
  """Deletes the objects at the given GCS paths.
-  Warning: any exception during batch delete will NOT be retried.

Args:
paths: List of GCS file path patterns or Dict with GCS file path patterns
@@ -285,16 +323,11 @@ def delete_batch(self, paths):
      current_paths = paths[s:s + MAX_BATCH_OPERATION_SIZE]
    else:
      current_paths = paths[s:]
-   current_batch = self.client.batch(raise_exception=False)
-   with current_batch:
-     for path in current_paths:
-       bucket_name, blob_name = parse_gcs_path(path)
-       bucket = self.client.bucket(bucket_name)
-       bucket.delete_blob(blob_name)
+   responses = self._batch_with_retry(
+       current_paths, self._delete_batch_request)
    for i, path in enumerate(current_paths):
      error_code = None
-     resp = current_batch._responses[i]
+     resp = responses[i]
      if resp.status_code >= 400 and resp.status_code != 404:
        error_code = resp.status_code
      final_results.append((path, error_code))
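
A hypothetical call site for the now-retrying `delete_batch` (bucket and object names made up): each result pairs a path with `None` on success or the failing HTTP status code, and 404 counts as success because the object is already gone.

```python
from apache_beam.io.gcp.gcsio import GcsIO

gcs_io = GcsIO()  # assumes default GCP credentials are available
results = gcs_io.delete_batch([
    'gs://my-bucket/tmp/part-0.txt',
    'gs://my-bucket/tmp/part-1.txt',
])
for path, error_code in results:
  if error_code is not None:
    print(f'delete failed for {path}: HTTP {error_code}')
```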
@@ -334,9 +367,16 @@ def copy(self, src, dest):
source_generation=src_generation,
retry=self._storage_client_retry)

def _copy_batch_request(self, pair):
  src_bucket_name, src_blob_name = parse_gcs_path(pair[0])
  dest_bucket_name, dest_blob_name = parse_gcs_path(pair[1])
  src_bucket = self.client.bucket(src_bucket_name)
  src_blob = src_bucket.blob(src_blob_name)
  dest_bucket = self.client.bucket(dest_bucket_name)
  src_bucket.copy_blob(src_blob, dest_bucket, dest_blob_name)

def copy_batch(self, src_dest_pairs):
  """Copies the given GCS objects from src to dest.
-  Warning: any exception during batch copy will NOT be retried.

Args:
src_dest_pairs: list of (src, dest) tuples of gs://<bucket>/<name> files
@@ -354,20 +394,11 @@ def copy_batch(self, src_dest_pairs):
      current_pairs = src_dest_pairs[s:s + MAX_BATCH_OPERATION_SIZE]
    else:
      current_pairs = src_dest_pairs[s:]
-   current_batch = self.client.batch(raise_exception=False)
-   with current_batch:
-     for pair in current_pairs:
-       src_bucket_name, src_blob_name = parse_gcs_path(pair[0])
-       dest_bucket_name, dest_blob_name = parse_gcs_path(pair[1])
-       src_bucket = self.client.bucket(src_bucket_name)
-       src_blob = src_bucket.blob(src_blob_name)
-       dest_bucket = self.client.bucket(dest_bucket_name)
-
-       src_bucket.copy_blob(src_blob, dest_bucket, dest_blob_name)
+   responses = self._batch_with_retry(
+       current_pairs, self._copy_batch_request)
    for i, pair in enumerate(current_pairs):
      error_code = None
-     resp = current_batch._responses[i]
+     resp = responses[i]
      if resp.status_code >= 400:
        error_code = resp.status_code
      final_results.append((pair[0], pair[1], error_code))
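
The analogous hypothetical usage for `copy_batch`; its result triples carry source, destination, and error code, and unlike `delete_batch` a 404 is reported as an error here.

```python
from apache_beam.io.gcp.gcsio import GcsIO

gcs_io = GcsIO()
results = gcs_io.copy_batch([
    ('gs://src-bucket/a.txt', 'gs://dst-bucket/a.txt'),
    ('gs://src-bucket/b.txt', 'gs://dst-bucket/b.txt'),
])
for src, dest, error_code in results:
  if error_code is not None:
    print(f'copy {src} -> {dest} failed: HTTP {error_code}')
```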
68 changes: 68 additions & 0 deletions sdks/python/apache_beam/io/gcp/gcsio_test.py
@@ -482,6 +482,74 @@ def test_copy(self):
      'gs://gcsio-test/non-existent',
      'gs://gcsio-test/non-existent-destination')

@staticmethod
def _fake_batch_responses(status_codes):
  return mock.Mock(
      __enter__=mock.Mock(),
      __exit__=mock.Mock(),
      _responses=[
          mock.Mock(
              **{
                  'json.return_value': {'error': {'message': 'error'}},
                  'request.method': 'BATCH',
                  'request.url': 'contentid://None',
              },
              status_code=code,
          ) for code in status_codes
      ],
  )
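
This mock exposes exactly the attributes `from_http_response` reads (`status_code`, `json()`, `request.method`, `request.url`) when `_batch_with_retry` turns a failed sub-response into an exception for the retry predicate. A quick sketch of that mapping, with an illustrative status code:

```python
from unittest import mock

from google.cloud.exceptions import from_http_response

resp = mock.Mock(
    **{
        'json.return_value': {'error': {'message': 'error'}},
        'request.method': 'BATCH',
        'request.url': 'contentid://None',
    },
    status_code=429,
)
exc = from_http_response(resp)
print(type(exc).__name__)  # TooManyRequests, which DEFAULT_RETRY retries
```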

@mock.patch('apache_beam.io.gcp.gcsio.MAX_BATCH_OPERATION_SIZE', 3)
@mock.patch('time.sleep', mock.Mock())
def test_copy_batch(self):
  src_dest_pairs = [
      (f'gs://source_bucket/file{i}.txt', f'gs://dest_bucket/file{i}.txt')
      for i in range(7)
  ]
  gcs_io = gcsio.GcsIO(
      storage_client=mock.Mock(
          batch=mock.Mock(
              side_effect=[
                  # Chunk 1 (files 0-2): the non-retryable 404 is recorded
                  # for file1; file2's 429 is retried until it succeeds.
                  self._fake_batch_responses([200, 404, 429]),
                  self._fake_batch_responses([429]),
                  self._fake_batch_responses([429]),
                  self._fake_batch_responses([200]),
                  # Chunk 2 (files 3-5): file4's 429 succeeds on retry.
                  self._fake_batch_responses([200, 429, 200]),
                  self._fake_batch_responses([200]),
                  # Chunk 3 (file 6).
                  self._fake_batch_responses([200]),
              ]),
      ))
  results = gcs_io.copy_batch(src_dest_pairs)
  expected = [
      ('gs://source_bucket/file0.txt', 'gs://dest_bucket/file0.txt', None),
      ('gs://source_bucket/file1.txt', 'gs://dest_bucket/file1.txt', 404),
      ('gs://source_bucket/file2.txt', 'gs://dest_bucket/file2.txt', None),
      ('gs://source_bucket/file3.txt', 'gs://dest_bucket/file3.txt', None),
      ('gs://source_bucket/file4.txt', 'gs://dest_bucket/file4.txt', None),
      ('gs://source_bucket/file5.txt', 'gs://dest_bucket/file5.txt', None),
      ('gs://source_bucket/file6.txt', 'gs://dest_bucket/file6.txt', None),
  ]
  self.assertEqual(results, expected)

@mock.patch('time.sleep', mock.Mock())
@mock.patch('time.monotonic', mock.Mock(side_effect=[0, 120]))
def test_copy_batch_timeout_exceeded(self):
  src_dest_pairs = [
      ('gs://source_bucket/file0.txt', 'gs://dest_bucket/file0.txt')
  ]
  gcs_io = gcsio.GcsIO(
      storage_client=mock.Mock(
          batch=mock.Mock(side_effect=[self._fake_batch_responses([429])])))
  results = gcs_io.copy_batch(src_dest_pairs)
  # The retry deadline expires after the first attempt, so the 429 is
  # reported in the results instead of raised.
  expected = [
      ('gs://source_bucket/file0.txt', 'gs://dest_bucket/file0.txt', 429),
  ]
  self.assertEqual(results, expected)
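
The deadline path exercised here: `Retry` measures elapsed time with `time.monotonic()`, so having it return 0 and then 120 exhausts the 120-second retry budget after a single attempt. The resulting `RetryError` is swallowed by `_batch_with_retry`, which is why the caller sees the 429 in the results rather than an exception. A compressed sketch of that behavior (a 1-second timeout standing in for the real 120):

```python
from google.api_core.exceptions import RetryError, TooManyRequests
from google.api_core.retry import Retry

short_retry = Retry(
    predicate=lambda exc: isinstance(exc, TooManyRequests),
    initial=0.1,
    timeout=1.0,  # stand-in for the storage client's 120s deadline
)

@short_retry
def always_throttled():
  raise TooManyRequests('rate limited')

try:
  always_throttled()
except RetryError:
  # _batch_with_retry catches this and returns the last responses,
  # surfacing per-file error codes to the caller.
  pass
```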

def test_copytree(self):
  src_dir_name = 'gs://gcsio-test/source/'
  dest_dir_name = 'gs://gcsio-test/dest/'