Skip to content

Commit

Permalink
Address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
nateprewitt committed Nov 27, 2023
1 parent 297604d commit 12b3c9c
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 24 deletions.
28 changes: 17 additions & 11 deletions boto3/crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,10 @@ def _create_crt_request_serializer(session, region_name):
)


def _create_crt_s3_client(session, config, region_name, credentials, **kwargs):
def _create_crt_s3_client(
session, config, region_name, credentials, lock, **kwargs
):
"""Create boto3 wrapper class to manage crt lock reference and S3 client."""
lock = acquire_crt_s3_process_lock(PROCESS_LOCK_NAME)
if lock is None:
# If we're unable to acquire the lock, we cannot
# use the CRT in this process and should default to
# the classic s3transfer manager.
return None

cred_wrapper = BotocoreCRTCredentialsWrapper(credentials)
cred_provider = cred_wrapper.to_crt_credentials_provider()
return CRTS3Client(
Expand All @@ -79,13 +74,20 @@ def _create_crt_s3_client(session, config, region_name, credentials, **kwargs):


def _initialize_crt_transfer_primatives(client, config):
lock = acquire_crt_s3_process_lock(PROCESS_LOCK_NAME)
if lock is None:
# If we're unable to acquire the lock, we cannot
# use the CRT in this process and should default to
# the classic s3transfer manager.
return None, None

session = Session()
region_name = client.meta.region_name
credentials = client._get_credentials()

serializer = _create_crt_request_serializer(session, region_name)
s3_client = _create_crt_s3_client(
session, config, region_name, credentials
session, config, region_name, credentials, lock
)
return serializer, s3_client

Expand Down Expand Up @@ -130,10 +132,14 @@ def is_crt_compatible_request(client, crt_s3_client):
if crt_s3_client is None:
return False

is_same_region = client.meta.region_name == crt_s3_client.region
boto3_creds = client._get_credentials()
if boto3_creds is None:
return False

is_same_identity = compare_identity(
client._get_credentials(), crt_s3_client.cred_provider
boto3_creds.get_frozen_credentials(), crt_s3_client.cred_provider
)
is_same_region = client.meta.region_name == crt_s3_client.region
return is_same_region and is_same_identity


Expand Down
38 changes: 25 additions & 13 deletions tests/unit/test_crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,25 @@
from tests import mock, requires_crt

if HAS_CRT:
import awscrt.s3
from awscrt.s3 import CrossProcessLock as CrossProcessLockClass
from s3transfer.crt import BotocoreCRTCredentialsWrapper

import boto3.crt


@pytest.fixture
def mock_crt_process_lock(monkeypatch):
# The process lock is cached at the module layer whenever the
# cross process lock is successfully acquired. This patch ensures that
# test cases will start off with no previously cached process lock and
# if a cross process is instantiated/acquired it will be the mock that
# can be used for controlling lock behavior.
monkeypatch.setattr('s3transfer.crt.CRT_S3_PROCESS_LOCK', None)
with mock.patch('awscrt.s3.CrossProcessLock', spec=True) as mock_lock:
yield mock_lock



def create_test_client(service_name='s3', region_name="us-east-1"):
return boto3.client(
service_name,
Expand All @@ -43,21 +56,20 @@ def create_test_client(service_name='s3', region_name="us-east-1"):

class TestCRTTransferManager:
@requires_crt()
def test_create_crt_transfer_manager_with_lock_in_use(self):
with mock.patch('boto3.crt.acquire_crt_s3_process_lock') as lock:
lock.return_value = None
def test_create_crt_transfer_manager_with_lock_in_use(self, mock_crt_process_lock):
mock_crt_process_lock.return_value.acquire.side_effect = RuntimeError

# Verify we can't create a second CRT client
tm = boto3.crt.create_crt_transfer_manager(USW2_S3_CLIENT, None)
assert tm is None
# Verify we can't create a second CRT client
tm = boto3.crt.create_crt_transfer_manager(USW2_S3_CLIENT, None)
assert tm is None

@requires_crt()
def test_create_crt_transfer_manager(self):
def test_create_crt_transfer_manager(self, mock_crt_process_lock):
tm = boto3.crt.create_crt_transfer_manager(USW2_S3_CLIENT, None)
assert isinstance(tm, s3transfer.crt.CRTTransferManager)

@requires_crt()
def test_crt_singleton_is_returned_every_call(self):
def test_crt_singleton_is_returned_every_call(self, mock_crt_process_lock):
first_s3_client = boto3.crt.get_crt_s3_client(USW2_S3_CLIENT, None)
second_s3_client = boto3.crt.get_crt_s3_client(USW2_S3_CLIENT, None)

Expand All @@ -66,7 +78,7 @@ def test_crt_singleton_is_returned_every_call(self):
assert first_s3_client.crt_client is second_s3_client.crt_client

@requires_crt()
def test_create_crt_transfer_manager_w_client_in_wrong_region(self):
def test_create_crt_transfer_manager_w_client_in_wrong_region(self, mock_crt_process_lock):
"""Ensure we don't return the crt transfer manager if client is in
different region. The CRT isn't able to handle region redirects and
will consistently fail.
Expand Down Expand Up @@ -130,20 +142,20 @@ def no_credentials():
)

@requires_crt()
def test_get_crt_s3_client(self):
def test_get_crt_s3_client(self, mock_crt_process_lock):
config = TransferConfig()
crt_s3_client = boto3.crt.get_crt_s3_client(USW2_S3_CLIENT, config)
assert isinstance(crt_s3_client, boto3.crt.CRTS3Client)
assert isinstance(
crt_s3_client.process_lock, awscrt.s3.CrossProcessLock
crt_s3_client.process_lock, CrossProcessLockClass
)
assert crt_s3_client.region == "us-west-2"
assert isinstance(
crt_s3_client.cred_provider, BotocoreCRTCredentialsWrapper
)

@requires_crt()
def test_get_crt_s3_client_w_wrong_region(self):
def test_get_crt_s3_client_w_wrong_region(self, mock_crt_process_lock):
config = TransferConfig()
crt_s3_client = boto3.crt.get_crt_s3_client(USW2_S3_CLIENT, config)
assert isinstance(crt_s3_client, boto3.crt.CRTS3Client)
Expand Down

0 comments on commit 12b3c9c

Please sign in to comment.