diff --git a/CHANGES.md b/CHANGES.md
index 1f8f13305c83..44f5fe88c4dc 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -68,6 +68,7 @@
 
 * X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
 * Upgraded to protobuf 4 (Java) ([#33192](https://github.com/apache/beam/issues/33192)).
+* [GCSIO] Added retry logic to each batch method of the GCS IO (Python) ([#33539](https://github.com/apache/beam/pull/33539)).
 
 ## Breaking Changes
 
diff --git a/sdks/python/apache_beam/io/gcp/gcsio.py b/sdks/python/apache_beam/io/gcp/gcsio.py
index e0dcffa86dff..3e2f5d4cf635 100644
--- a/sdks/python/apache_beam/io/gcp/gcsio.py
+++ b/sdks/python/apache_beam/io/gcp/gcsio.py
@@ -35,8 +35,10 @@
 from typing import Optional
 from typing import Union
 
+from google.api_core.exceptions import RetryError
 from google.cloud import storage
 from google.cloud.exceptions import NotFound
+from google.cloud.exceptions import from_http_response
 from google.cloud.storage.fileio import BlobReader
 from google.cloud.storage.fileio import BlobWriter
 from google.cloud.storage.retry import DEFAULT_RETRY
@@ -264,9 +266,45 @@ def delete(self, path):
     except NotFound:
       return
 
+  def _batch_with_retry(self, requests, fn):
+    current_requests = [*enumerate(requests)]
+    responses = [None for _ in current_requests]
+
+    @self._storage_client_retry
+    def run_with_retry():
+      current_batch = self.client.batch(raise_exception=False)
+      with current_batch:
+        for _, request in current_requests:
+          fn(request)
+      last_retryable_exception = None
+      for (i, current_pair), response in zip(
+          [*current_requests], current_batch._responses
+      ):
+        responses[i] = response
+        should_retry = (
+            response.status_code >= 400 and
+            self._storage_client_retry._predicate(from_http_response(response)))
+        if should_retry:
+          last_retryable_exception = from_http_response(response)
+        else:
+          current_requests.remove((i, current_pair))
+      if last_retryable_exception:
+        raise last_retryable_exception
+
+    try:
+      run_with_retry()
+    except RetryError:
+      pass
+
+    return responses
+
+  def _delete_batch_request(self, path):
+    bucket_name, blob_name = parse_gcs_path(path)
+    bucket = self.client.bucket(bucket_name)
+    bucket.delete_blob(blob_name)
+
   def delete_batch(self, paths):
     """Deletes the objects at the given GCS paths.
 
-    Warning: any exception during batch delete will NOT be retried.
     Args:
       paths: List of GCS file path patterns or Dict with GCS file path patterns
@@ -285,16 +323,11 @@ def delete_batch(self, paths):
         current_paths = paths[s:s + MAX_BATCH_OPERATION_SIZE]
       else:
         current_paths = paths[s:]
-      current_batch = self.client.batch(raise_exception=False)
-      with current_batch:
-        for path in current_paths:
-          bucket_name, blob_name = parse_gcs_path(path)
-          bucket = self.client.bucket(bucket_name)
-          bucket.delete_blob(blob_name)
-
+      responses = self._batch_with_retry(
+          current_paths, self._delete_batch_request)
       for i, path in enumerate(current_paths):
         error_code = None
-        resp = current_batch._responses[i]
+        resp = responses[i]
         if resp.status_code >= 400 and resp.status_code != 404:
           error_code = resp.status_code
         final_results.append((path, error_code))
@@ -334,9 +367,16 @@ def copy(self, src, dest):
         source_generation=src_generation,
         retry=self._storage_client_retry)
 
+  def _copy_batch_request(self, pair):
+    src_bucket_name, src_blob_name = parse_gcs_path(pair[0])
+    dest_bucket_name, dest_blob_name = parse_gcs_path(pair[1])
+    src_bucket = self.client.bucket(src_bucket_name)
+    src_blob = src_bucket.blob(src_blob_name)
+    dest_bucket = self.client.bucket(dest_bucket_name)
+    src_bucket.copy_blob(src_blob, dest_bucket, dest_blob_name)
+
   def copy_batch(self, src_dest_pairs):
     """Copies the given GCS objects from src to dest.
 
-    Warning: any exception during batch copy will NOT be retried.
     Args:
       src_dest_pairs: list of (src, dest) tuples of gs://<bucket>/<name> files
@@ -354,20 +394,11 @@ def copy_batch(self, src_dest_pairs):
         current_pairs = src_dest_pairs[s:s + MAX_BATCH_OPERATION_SIZE]
       else:
         current_pairs = src_dest_pairs[s:]
-      current_batch = self.client.batch(raise_exception=False)
-      with current_batch:
-        for pair in current_pairs:
-          src_bucket_name, src_blob_name = parse_gcs_path(pair[0])
-          dest_bucket_name, dest_blob_name = parse_gcs_path(pair[1])
-          src_bucket = self.client.bucket(src_bucket_name)
-          src_blob = src_bucket.blob(src_blob_name)
-          dest_bucket = self.client.bucket(dest_bucket_name)
-
-          src_bucket.copy_blob(src_blob, dest_bucket, dest_blob_name)
-
+      responses = self._batch_with_retry(
+          current_pairs, self._copy_batch_request)
       for i, pair in enumerate(current_pairs):
         error_code = None
-        resp = current_batch._responses[i]
+        resp = responses[i]
         if resp.status_code >= 400:
           error_code = resp.status_code
         final_results.append((pair[0], pair[1], error_code))
diff --git a/sdks/python/apache_beam/io/gcp/gcsio_test.py b/sdks/python/apache_beam/io/gcp/gcsio_test.py
index 7b79030b4b71..1faae2b2a8f1 100644
--- a/sdks/python/apache_beam/io/gcp/gcsio_test.py
+++ b/sdks/python/apache_beam/io/gcp/gcsio_test.py
@@ -482,6 +482,74 @@ def test_copy(self):
           'gs://gcsio-test/non-existent',
           'gs://gcsio-test/non-existent-destination')
 
+  @staticmethod
+  def _fake_batch_responses(status_codes):
+    return mock.Mock(
+        __enter__=mock.Mock(),
+        __exit__=mock.Mock(),
+        _responses=[
+            mock.Mock(
+                **{
+                    'json.return_value': {
+                        'error': {
+                            'message': 'error'
+                        }
+                    },
+                    'request.method': 'BATCH',
+                    'request.url': 'contentid://None',
+                },
+                status_code=code,
+            ) for code in status_codes
+        ],
+    )
+
+  @mock.patch('apache_beam.io.gcp.gcsio.MAX_BATCH_OPERATION_SIZE', 3)
+  @mock.patch('time.sleep', mock.Mock())
+  def test_copy_batch(self):
+    src_dest_pairs = [
+        (f'gs://source_bucket/file{i}.txt', f'gs://dest_bucket/file{i}.txt')
+        for i in range(7)
+    ]
+    gcs_io = gcsio.GcsIO(
+        storage_client=mock.Mock(
+            batch=mock.Mock(
+                side_effect=[
+                    self._fake_batch_responses([200, 404, 429]),
+                    self._fake_batch_responses([429]),
+                    self._fake_batch_responses([429]),
+                    self._fake_batch_responses([200]),
+                    self._fake_batch_responses([200, 429, 200]),
+                    self._fake_batch_responses([200]),
+                    self._fake_batch_responses([200]),
+                ]),
+        ))
+    results = gcs_io.copy_batch(src_dest_pairs)
+    expected = [
+        ('gs://source_bucket/file0.txt', 'gs://dest_bucket/file0.txt', None),
+        ('gs://source_bucket/file1.txt', 'gs://dest_bucket/file1.txt', 404),
+        ('gs://source_bucket/file2.txt', 'gs://dest_bucket/file2.txt', None),
+        ('gs://source_bucket/file3.txt', 'gs://dest_bucket/file3.txt', None),
+        ('gs://source_bucket/file4.txt', 'gs://dest_bucket/file4.txt', None),
+        ('gs://source_bucket/file5.txt', 'gs://dest_bucket/file5.txt', None),
+        ('gs://source_bucket/file6.txt', 'gs://dest_bucket/file6.txt', None),
+    ]
+    self.assertEqual(results, expected)
+
+  @mock.patch('time.sleep', mock.Mock())
+  @mock.patch('time.monotonic', mock.Mock(side_effect=[0, 120]))
+  def test_copy_batch_timeout_exceeded(self):
+    src_dest_pairs = [
+        ('gs://source_bucket/file0.txt', 'gs://dest_bucket/file0.txt')
+    ]
+    gcs_io = gcsio.GcsIO(
+        storage_client=mock.Mock(
+            batch=mock.Mock(side_effect=[self._fake_batch_responses([429])])))
+    results = gcs_io.copy_batch(src_dest_pairs)
+    expected = [
+        ('gs://source_bucket/file0.txt', 'gs://dest_bucket/file0.txt', 429),
+    ]
+    self.assertEqual(results, expected)
+
   def test_copytree(self):
     src_dir_name = 'gs://gcsio-test/source/'
     dest_dir_name = 'gs://gcsio-test/dest/'
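
The core of the change above is `_batch_with_retry`: each pass through the
decorated inner function re-issues only the requests whose responses are still
retryable under the storage client's retry predicate, and re-raising the last
retryable error is what makes the `@self._storage_client_retry` wrapper apply
its backoff and call the function again. The sketch below illustrates the same
pattern as a self-contained, runnable toy, not the Beam code itself:
`flaky_server`, its failure schedule, the simplified `batch_with_retry` (whose
`fn` returns a status code directly instead of enqueuing into a GCS batch
context), and the 5-second deadline are all invented for this example.

from google.api_core.exceptions import RetryError, TooManyRequests
from google.api_core.retry import Retry

failures = {'b': 1}  # request 'b' fails once with HTTP 429, then succeeds

def flaky_server(request):
  if failures.get(request, 0) > 0:
    failures[request] -= 1
    return 429
  return 200

def batch_with_retry(requests, fn, retry):
  current_requests = [*enumerate(requests)]
  responses = [None] * len(current_requests)

  @retry
  def run_with_retry():
    last_retryable_exception = None
    # Iterate over a copy so finished requests can be removed mid-loop.
    for i, request in [*current_requests]:
      status = fn(request)
      responses[i] = status
      if status == 429:  # retryable under Retry's default predicate
        last_retryable_exception = TooManyRequests(f'{request}: {status}')
      else:
        current_requests.remove((i, request))
    if last_retryable_exception:
      # Re-raising makes the Retry wrapper invoke run_with_retry() again,
      # now with a shrunken current_requests list.
      raise last_retryable_exception

  try:
    run_with_retry()
  except RetryError:
    pass  # deadline exhausted; responses keep the last observed statuses
  return responses

print(batch_with_retry(['a', 'b', 'c'], flaky_server, Retry(deadline=5)))
# [200, 200, 200] -- the transient 429 on 'b' was retried to success

As in the real method, a non-retryable failure (a 404, say) simply stays in
`responses`, which is why `delete_batch` and `copy_batch` report per-path error
codes instead of aborting the whole batch.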