Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Turn on checksum validation for CRT S3 transfer manager #280

Merged
merged 1 commit into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changes/next-release/enhancement-crt-30257.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"type": "enhancement",
"category": "``crt``",
"description": "Automatically configure CRC32 checksums for uploads and checksum validation for downloads through the CRT transfer manager."
}
27 changes: 24 additions & 3 deletions s3transfer/crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def upload(self, fileobj, bucket, key, extra_args=None, subscribers=None):
extra_args = {}
if subscribers is None:
subscribers = {}
self._validate_checksum_algorithm_supported(extra_args)
callargs = CallArgs(
bucket=bucket,
key=key,
Expand All @@ -231,6 +232,17 @@ def delete(self, bucket, key, extra_args=None, subscribers=None):
def shutdown(self, cancel=False):
self._shutdown(cancel)

def _validate_checksum_algorithm_supported(self, extra_args):
checksum_algorithm = extra_args.get('ChecksumAlgorithm')
if checksum_algorithm is None:
return
supported_algorithms = list(awscrt.s3.S3ChecksumAlgorithm.__members__)
if checksum_algorithm.upper() not in supported_algorithms:
raise ValueError(
f'ChecksumAlgorithm: {checksum_algorithm} not supported. '
f'Supported algorithms are: {supported_algorithms}'
)

def _cancel_transfers(self):
for coordinator in self._future_coordinators:
if not coordinator.done():
Expand Down Expand Up @@ -623,11 +635,17 @@ def _get_make_request_args_put_object(
else:
call_args.extra_args["Body"] = call_args.fileobj

checksum_algorithm = call_args.extra_args.pop(
'ChecksumAlgorithm', 'CRC32'
).upper()
checksum_config = awscrt.s3.S3ChecksumConfig(
algorithm=awscrt.s3.S3ChecksumAlgorithm[checksum_algorithm],
location=awscrt.s3.S3ChecksumLocation.TRAILER,
)
# Suppress botocore's automatic MD5 calculation by setting an override
# value that will get deleted in the BotocoreCRTRequestSerializer.
# The CRT S3 client is able automatically compute checksums as part of
# requests it makes, and the intention is to configure automatic
# checksums in a future update.
# As part of the CRT S3 request, we request the CRT S3 client to
# automatically add trailing checksums to its uploads.
call_args.extra_args["ContentMD5"] = "override-to-be-removed"

make_request_args = self._default_get_make_request_args(
Expand All @@ -639,6 +657,7 @@ def _get_make_request_args_put_object(
on_done_after_calls=on_done_after_calls,
)
make_request_args['send_filepath'] = send_filepath
make_request_args['checksum_config'] = checksum_config
return make_request_args

def _get_make_request_args_get_object(
Expand All @@ -652,6 +671,7 @@ def _get_make_request_args_get_object(
):
recv_filepath = None
on_body = None
checksum_config = awscrt.s3.S3ChecksumConfig(validate_response=True)
if isinstance(call_args.fileobj, str):
final_filepath = call_args.fileobj
recv_filepath = self._os_utils.get_temp_filename(final_filepath)
Expand All @@ -673,6 +693,7 @@ def _get_make_request_args_get_object(
)
make_request_args['recv_filepath'] = recv_filepath
make_request_args['on_body'] = on_body
make_request_args['checksum_config'] = checksum_config
return make_request_args

def _default_get_make_request_args(
Expand Down
111 changes: 109 additions & 2 deletions tests/functional/test_crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,11 @@ def _assert_expected_crt_http_request(
str(expected_content_length),
)
if expected_missing_headers is not None:
header_names = [header[0] for header in crt_http_request.headers]
header_names = [
header[0].lower() for header in crt_http_request.headers
]
for expected_missing_header in expected_missing_headers:
self.assertNotIn(expected_missing_header, header_names)
self.assertNotIn(expected_missing_header.lower(), header_names)

def _assert_subscribers_called(self, expected_future=None):
self.assertTrue(self.record_subscriber.on_queued_called)
Expand All @@ -143,6 +145,21 @@ def _assert_subscribers_called(self, expected_future=None):
self.record_subscriber.on_done_future, expected_future
)

def _get_expected_upload_checksum_config(self, **overrides):
checksum_config_kwargs = {
'algorithm': awscrt.s3.S3ChecksumAlgorithm.CRC32,
'location': awscrt.s3.S3ChecksumLocation.TRAILER,
}
checksum_config_kwargs.update(overrides)
return awscrt.s3.S3ChecksumConfig(**checksum_config_kwargs)

def _get_expected_download_checksum_config(self, **overrides):
checksum_config_kwargs = {
'validate_response': True,
}
checksum_config_kwargs.update(overrides)
return awscrt.s3.S3ChecksumConfig(**checksum_config_kwargs)

def _invoke_done_callbacks(self, **kwargs):
callargs = self.s3_crt_client.make_request.call_args
callargs_kwargs = callargs[1]
Expand Down Expand Up @@ -180,6 +197,7 @@ def test_upload(self):
'send_filepath': self.filename,
'on_progress': mock.ANY,
'on_done': mock.ANY,
'checksum_config': self._get_expected_upload_checksum_config(),
},
)
self._assert_expected_crt_http_request(
Expand All @@ -206,6 +224,7 @@ def test_upload_from_seekable_stream(self):
'send_filepath': None,
'on_progress': mock.ANY,
'on_done': mock.ANY,
'checksum_config': self._get_expected_upload_checksum_config(),
},
)
self._assert_expected_crt_http_request(
Expand Down Expand Up @@ -237,6 +256,7 @@ def test_upload_from_nonseekable_stream(self):
'send_filepath': None,
'on_progress': mock.ANY,
'on_done': mock.ANY,
'checksum_config': self._get_expected_upload_checksum_config(),
},
)
self._assert_expected_crt_http_request(
Expand All @@ -251,6 +271,90 @@ def test_upload_from_nonseekable_stream(self):
)
self._assert_subscribers_called(future)

def test_upload_override_checksum_algorithm(self):
future = self.transfer_manager.upload(
self.filename,
self.bucket,
self.key,
{'ChecksumAlgorithm': 'CRC32C'},
[self.record_subscriber],
)
future.result()

callargs_kwargs = self.s3_crt_client.make_request.call_args[1]
self.assertEqual(
callargs_kwargs,
{
'request': mock.ANY,
'type': awscrt.s3.S3RequestType.PUT_OBJECT,
'send_filepath': self.filename,
'on_progress': mock.ANY,
'on_done': mock.ANY,
'checksum_config': self._get_expected_upload_checksum_config(
algorithm=awscrt.s3.S3ChecksumAlgorithm.CRC32C
),
},
)
self._assert_expected_crt_http_request(
callargs_kwargs["request"],
expected_http_method='PUT',
expected_content_length=len(self.expected_content),
expected_missing_headers=[
'Content-MD5',
'x-amz-sdk-checksum-algorithm',
'X-Amz-Trailer',
],
)
self._assert_subscribers_called(future)

def test_upload_override_checksum_algorithm_accepts_lowercase(self):
future = self.transfer_manager.upload(
self.filename,
self.bucket,
self.key,
{'ChecksumAlgorithm': 'crc32c'},
[self.record_subscriber],
)
future.result()

callargs_kwargs = self.s3_crt_client.make_request.call_args[1]
self.assertEqual(
callargs_kwargs,
{
'request': mock.ANY,
'type': awscrt.s3.S3RequestType.PUT_OBJECT,
'send_filepath': self.filename,
'on_progress': mock.ANY,
'on_done': mock.ANY,
'checksum_config': self._get_expected_upload_checksum_config(
algorithm=awscrt.s3.S3ChecksumAlgorithm.CRC32C
),
},
)
self._assert_expected_crt_http_request(
callargs_kwargs["request"],
expected_http_method='PUT',
expected_content_length=len(self.expected_content),
expected_missing_headers=[
'Content-MD5',
'x-amz-sdk-checksum-algorithm',
'X-Amz-Trailer',
],
)
self._assert_subscribers_called(future)

def test_upload_throws_error_for_unsupported_checksum(self):
with self.assertRaisesRegex(
ValueError, 'ChecksumAlgorithm: UNSUPPORTED not supported'
):
self.transfer_manager.upload(
self.filename,
self.bucket,
self.key,
{'ChecksumAlgorithm': 'UNSUPPORTED'},
[self.record_subscriber],
)

def test_download(self):
future = self.transfer_manager.download(
self.bucket, self.key, self.filename, {}, [self.record_subscriber]
Expand All @@ -267,6 +371,7 @@ def test_download(self):
'on_progress': mock.ANY,
'on_done': mock.ANY,
'on_body': None,
'checksum_config': self._get_expected_download_checksum_config(),
},
)
# the recv_filepath will be set to a temporary file path with some
Expand Down Expand Up @@ -304,6 +409,7 @@ def test_download_to_seekable_stream(self):
'on_progress': mock.ANY,
'on_done': mock.ANY,
'on_body': mock.ANY,
'checksum_config': self._get_expected_download_checksum_config(),
},
)
self._assert_expected_crt_http_request(
Expand Down Expand Up @@ -338,6 +444,7 @@ def test_download_to_nonseekable_stream(self):
'on_progress': mock.ANY,
'on_done': mock.ANY,
'on_body': mock.ANY,
'checksum_config': self._get_expected_download_checksum_config(),
},
)
self._assert_expected_crt_http_request(
Expand Down