diff --git a/.changes/next-release/enhancement-crt-51520.json b/.changes/next-release/enhancement-crt-51520.json new file mode 100644 index 00000000..2bf4a0cf --- /dev/null +++ b/.changes/next-release/enhancement-crt-51520.json @@ -0,0 +1,5 @@ +{ + "type": "enhancement", + "category": "``crt``", + "description": "Add support for uploading and downloading file-like objects using CRT transfer manager. It supports both seekable and non-seekable file-like objects." +} diff --git a/s3transfer/crt.py b/s3transfer/crt.py index 7b5d1301..eda2985f 100644 --- a/s3transfer/crt.py +++ b/s3transfer/crt.py @@ -428,19 +428,12 @@ def _crt_request_from_aws_request(self, aws_request): headers_list.append((name, str(value, 'utf-8'))) crt_headers = awscrt.http.HttpHeaders(headers_list) - # CRT requires body (if it exists) to be an I/O stream. - crt_body_stream = None - if aws_request.body: - if hasattr(aws_request.body, 'seek'): - crt_body_stream = aws_request.body - else: - crt_body_stream = BytesIO(aws_request.body) crt_request = awscrt.http.HttpRequest( method=aws_request.method, path=crt_path, headers=crt_headers, - body_stream=crt_body_stream, + body_stream=aws_request.body, ) return crt_request @@ -453,6 +446,25 @@ def _convert_to_crt_http_request(self, botocore_http_request): crt_request.headers.set("host", url_parts.netloc) if crt_request.headers.get('Content-MD5') is not None: crt_request.headers.remove("Content-MD5") + + # In general, the CRT S3 client expects a content length header. It + # only expects a missing content length header if the body is not + # seekable. However, botocore does not set the content length header + # for GetObject API requests and so we set the content length to zero + # to meet the CRT S3 client's expectation that the content length + # header is set even if there is no body. + if crt_request.headers.get('Content-Length') is None: + if botocore_http_request.body is None: + crt_request.headers.add('Content-Length', "0") + + # Botocore sets the Transfer-Encoding header when it cannot determine + # the content length of the request body (e.g. it's not seekable). + # However, CRT does not support this header, but it supports + # non-seekable bodies. So we remove this header to not cause issues + # in the downstream CRT S3 request. + if crt_request.headers.get('Transfer-Encoding') is not None: + crt_request.headers.remove('Transfer-Encoding') + return crt_request def _capture_http_request(self, request, **kwargs): @@ -555,39 +567,20 @@ def __init__(self, crt_request_serializer, os_utils): def get_make_request_args( self, request_type, call_args, coordinator, future, on_done_after_calls ): - recv_filepath = None - send_filepath = None - s3_meta_request_type = getattr( - S3RequestType, request_type.upper(), S3RequestType.DEFAULT + request_args_handler = getattr( + self, + f'_get_make_request_args_{request_type}', + self._default_get_make_request_args, ) - on_done_before_calls = [] - if s3_meta_request_type == S3RequestType.GET_OBJECT: - final_filepath = call_args.fileobj - recv_filepath = self._os_utils.get_temp_filename(final_filepath) - file_ondone_call = RenameTempFileHandler( - coordinator, final_filepath, recv_filepath, self._os_utils - ) - on_done_before_calls.append(file_ondone_call) - elif s3_meta_request_type == S3RequestType.PUT_OBJECT: - send_filepath = call_args.fileobj - data_len = self._os_utils.get_file_size(send_filepath) - call_args.extra_args["ContentLength"] = data_len - - crt_request = self._request_serializer.serialize_http_request( - request_type, future + return request_args_handler( + request_type=request_type, + call_args=call_args, + coordinator=coordinator, + future=future, + on_done_before_calls=[], + on_done_after_calls=on_done_after_calls, ) - return { - 'request': crt_request, - 'type': s3_meta_request_type, - 'recv_filepath': recv_filepath, - 'send_filepath': send_filepath, - 'on_done': self.get_crt_callback( - future, 'done', on_done_before_calls, on_done_after_calls - ), - 'on_progress': self.get_crt_callback(future, 'progress'), - } - def get_crt_callback( self, future, @@ -613,6 +606,97 @@ def invoke_all_callbacks(*args, **kwargs): return invoke_all_callbacks + def _get_make_request_args_put_object( + self, + request_type, + call_args, + coordinator, + future, + on_done_before_calls, + on_done_after_calls, + ): + send_filepath = None + if isinstance(call_args.fileobj, str): + send_filepath = call_args.fileobj + data_len = self._os_utils.get_file_size(send_filepath) + call_args.extra_args["ContentLength"] = data_len + else: + call_args.extra_args["Body"] = call_args.fileobj + + # Suppress botocore's automatic MD5 calculation by setting an override + # value that will get deleted in the BotocoreCRTRequestSerializer. + # The CRT S3 client is able automatically compute checksums as part of + # requests it makes, and the intention is to configure automatic + # checksums in a future update. + call_args.extra_args["ContentMD5"] = "override-to-be-removed" + + make_request_args = self._default_get_make_request_args( + request_type=request_type, + call_args=call_args, + coordinator=coordinator, + future=future, + on_done_before_calls=on_done_before_calls, + on_done_after_calls=on_done_after_calls, + ) + make_request_args['send_filepath'] = send_filepath + return make_request_args + + def _get_make_request_args_get_object( + self, + request_type, + call_args, + coordinator, + future, + on_done_before_calls, + on_done_after_calls, + ): + recv_filepath = None + on_body = None + if isinstance(call_args.fileobj, str): + final_filepath = call_args.fileobj + recv_filepath = self._os_utils.get_temp_filename(final_filepath) + on_done_before_calls.append( + RenameTempFileHandler( + coordinator, final_filepath, recv_filepath, self._os_utils + ) + ) + else: + on_body = OnBodyFileObjWriter(call_args.fileobj) + + make_request_args = self._default_get_make_request_args( + request_type=request_type, + call_args=call_args, + coordinator=coordinator, + future=future, + on_done_before_calls=on_done_before_calls, + on_done_after_calls=on_done_after_calls, + ) + make_request_args['recv_filepath'] = recv_filepath + make_request_args['on_body'] = on_body + return make_request_args + + def _default_get_make_request_args( + self, + request_type, + call_args, + coordinator, + future, + on_done_before_calls, + on_done_after_calls, + ): + return { + 'request': self._request_serializer.serialize_http_request( + request_type, future + ), + 'type': getattr( + S3RequestType, request_type.upper(), S3RequestType.DEFAULT + ), + 'on_done': self.get_crt_callback( + future, 'done', on_done_before_calls, on_done_after_calls + ), + 'on_progress': self.get_crt_callback(future, 'progress'), + } + class RenameTempFileHandler: def __init__(self, coordinator, final_filename, temp_filename, osutil): @@ -642,3 +726,11 @@ def __init__(self, coordinator): def __call__(self, **kwargs): self._coordinator.set_done_callbacks_complete() + + +class OnBodyFileObjWriter: + def __init__(self, fileobj): + self._fileobj = fileobj + + def __call__(self, chunk, **kwargs): + self._fileobj.write(chunk) diff --git a/tests/__init__.py b/tests/__init__.py index e36c4936..03590fef 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -509,6 +509,9 @@ def write(self, b): def read(self, n=-1): return self._data.read(n) + def readinto(self, b): + return self._data.readinto(b) + class NonSeekableWriter(io.RawIOBase): def __init__(self, fileobj): diff --git a/tests/functional/test_crt.py b/tests/functional/test_crt.py index 0ead2959..152949d2 100644 --- a/tests/functional/test_crt.py +++ b/tests/functional/test_crt.py @@ -11,6 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. import fnmatch +import io import threading import time from concurrent.futures import Future @@ -18,7 +19,15 @@ from botocore.session import Session from s3transfer.subscribers import BaseSubscriber -from tests import HAS_CRT, FileCreator, mock, requires_crt, unittest +from tests import ( + HAS_CRT, + FileCreator, + NonSeekableReader, + NonSeekableWriter, + mock, + requires_crt, + unittest, +) if HAS_CRT: import awscrt @@ -60,13 +69,19 @@ def setUp(self): self.region = 'us-west-2' self.bucket = "test_bucket" self.key = "test_key" + self.expected_content = b'my content' + self.expected_download_content = b'new content' self.files = FileCreator() - self.filename = self.files.create_file('myfile', 'my content') + self.filename = self.files.create_file( + 'myfile', self.expected_content, mode='wb' + ) self.expected_path = "/" + self.bucket + "/" + self.key self.expected_host = "s3.%s.amazonaws.com" % (self.region) self.s3_request = mock.Mock(awscrt.s3.S3Request) self.s3_crt_client = mock.Mock(awscrt.s3.S3Client) - self.s3_crt_client.make_request.return_value = self.s3_request + self.s3_crt_client.make_request.side_effect = ( + self._simulate_make_request_side_effect + ) self.session = Session() self.session.set_config_variable('region', self.region) self.request_serializer = s3transfer.crt.BotocoreCRTRequestSerializer( @@ -81,6 +96,42 @@ def setUp(self): def tearDown(self): self.files.remove_all() + def _assert_expected_crt_http_request( + self, + crt_http_request, + expected_http_method='GET', + expected_host=None, + expected_path=None, + expected_body_content=None, + expected_content_length=None, + expected_missing_headers=None, + ): + if expected_host is None: + expected_host = self.expected_host + if expected_path is None: + expected_path = self.expected_path + self.assertEqual(crt_http_request.method, expected_http_method) + self.assertEqual(crt_http_request.headers.get("host"), expected_host) + self.assertEqual(crt_http_request.path, expected_path) + if expected_body_content is not None: + # Note: The underlying CRT awscrt.io.InputStream does not expose + # a public read method so we have to reach into the private, + # underlying stream to determine the content. We should update + # to use a public interface if a public interface is ever exposed. + self.assertEqual( + crt_http_request.body_stream._stream.read(), + expected_body_content, + ) + if expected_content_length is not None: + self.assertEqual( + crt_http_request.headers.get('Content-Length'), + str(expected_content_length), + ) + if expected_missing_headers is not None: + header_names = [header[0] for header in crt_http_request.headers] + for expected_missing_header in expected_missing_headers: + self.assertNotIn(expected_missing_header, header_names) + def _assert_subscribers_called(self, expected_future=None): self.assertTrue(self.record_subscriber.on_queued_called) self.assertTrue(self.record_subscriber.on_done_called) @@ -99,47 +150,125 @@ def _invoke_done_callbacks(self, **kwargs): on_done(error=None) def _simulate_file_download(self, recv_filepath): - self.files.create_file(recv_filepath, "fake response") + self.files.create_file( + recv_filepath, self.expected_download_content, mode='wb' + ) + + def _simulate_on_body_download(self, on_body_callback): + on_body_callback(chunk=self.expected_download_content, offset=0) def _simulate_make_request_side_effect(self, **kwargs): if kwargs.get('recv_filepath'): self._simulate_file_download(kwargs['recv_filepath']) + if kwargs.get('on_body'): + self._simulate_on_body_download(kwargs['on_body']) self._invoke_done_callbacks() - return mock.DEFAULT + return self.s3_request def test_upload(self): - self.s3_crt_client.make_request.side_effect = ( - self._simulate_make_request_side_effect - ) future = self.transfer_manager.upload( self.filename, self.bucket, self.key, {}, [self.record_subscriber] ) future.result() - callargs = self.s3_crt_client.make_request.call_args - callargs_kwargs = callargs[1] - self.assertEqual(callargs_kwargs["send_filepath"], self.filename) - self.assertIsNone(callargs_kwargs["recv_filepath"]) + callargs_kwargs = self.s3_crt_client.make_request.call_args[1] self.assertEqual( - callargs_kwargs["type"], awscrt.s3.S3RequestType.PUT_OBJECT + callargs_kwargs, + { + 'request': mock.ANY, + 'type': awscrt.s3.S3RequestType.PUT_OBJECT, + 'send_filepath': self.filename, + 'on_progress': mock.ANY, + 'on_done': mock.ANY, + }, + ) + self._assert_expected_crt_http_request( + callargs_kwargs["request"], + expected_http_method='PUT', + expected_content_length=len(self.expected_content), + expected_missing_headers=['Content-MD5'], ) - crt_request = callargs_kwargs["request"] - self.assertEqual("PUT", crt_request.method) - self.assertEqual(self.expected_path, crt_request.path) - self.assertEqual(self.expected_host, crt_request.headers.get("host")) self._assert_subscribers_called(future) - def test_download(self): - self.s3_crt_client.make_request.side_effect = ( - self._simulate_make_request_side_effect + def test_upload_from_seekable_stream(self): + with open(self.filename, 'rb') as f: + future = self.transfer_manager.upload( + f, self.bucket, self.key, {}, [self.record_subscriber] + ) + future.result() + + callargs_kwargs = self.s3_crt_client.make_request.call_args[1] + self.assertEqual( + callargs_kwargs, + { + 'request': mock.ANY, + 'type': awscrt.s3.S3RequestType.PUT_OBJECT, + 'send_filepath': None, + 'on_progress': mock.ANY, + 'on_done': mock.ANY, + }, + ) + self._assert_expected_crt_http_request( + callargs_kwargs["request"], + expected_http_method='PUT', + expected_body_content=self.expected_content, + expected_content_length=len(self.expected_content), + expected_missing_headers=['Content-MD5'], + ) + self._assert_subscribers_called(future) + + def test_upload_from_nonseekable_stream(self): + nonseekable_stream = NonSeekableReader(self.expected_content) + future = self.transfer_manager.upload( + nonseekable_stream, + self.bucket, + self.key, + {}, + [self.record_subscriber], + ) + future.result() + + callargs_kwargs = self.s3_crt_client.make_request.call_args[1] + self.assertEqual( + callargs_kwargs, + { + 'request': mock.ANY, + 'type': awscrt.s3.S3RequestType.PUT_OBJECT, + 'send_filepath': None, + 'on_progress': mock.ANY, + 'on_done': mock.ANY, + }, ) + self._assert_expected_crt_http_request( + callargs_kwargs["request"], + expected_http_method='PUT', + expected_body_content=self.expected_content, + expected_missing_headers=[ + 'Content-MD5', + 'Content-Length', + 'Transfer-Encoding', + ], + ) + self._assert_subscribers_called(future) + + def test_download(self): future = self.transfer_manager.download( self.bucket, self.key, self.filename, {}, [self.record_subscriber] ) future.result() - callargs = self.s3_crt_client.make_request.call_args - callargs_kwargs = callargs[1] + callargs_kwargs = self.s3_crt_client.make_request.call_args[1] + self.assertEqual( + callargs_kwargs, + { + 'request': mock.ANY, + 'type': awscrt.s3.S3RequestType.GET_OBJECT, + 'recv_filepath': mock.ANY, + 'on_progress': mock.ANY, + 'on_done': mock.ANY, + 'on_body': None, + }, + ) # the recv_filepath will be set to a temporary file path with some # random suffix self.assertTrue( @@ -148,42 +277,109 @@ def test_download(self): f'{self.filename}.*', ) ) - self.assertIsNone(callargs_kwargs["send_filepath"]) + self._assert_expected_crt_http_request( + callargs_kwargs["request"], + expected_http_method='GET', + expected_content_length=0, + ) + self._assert_subscribers_called(future) + with open(self.filename, 'rb') as f: + # Check the fake response overwrites the file because of download + self.assertEqual(f.read(), self.expected_download_content) + + def test_download_to_seekable_stream(self): + with open(self.filename, 'wb') as f: + future = self.transfer_manager.download( + self.bucket, self.key, f, {}, [self.record_subscriber] + ) + future.result() + + callargs_kwargs = self.s3_crt_client.make_request.call_args[1] self.assertEqual( - callargs_kwargs["type"], awscrt.s3.S3RequestType.GET_OBJECT + callargs_kwargs, + { + 'request': mock.ANY, + 'type': awscrt.s3.S3RequestType.GET_OBJECT, + 'recv_filepath': None, + 'on_progress': mock.ANY, + 'on_done': mock.ANY, + 'on_body': mock.ANY, + }, + ) + self._assert_expected_crt_http_request( + callargs_kwargs["request"], + expected_http_method='GET', + expected_content_length=0, ) - crt_request = callargs_kwargs["request"] - self.assertEqual("GET", crt_request.method) - self.assertEqual(self.expected_path, crt_request.path) - self.assertEqual(self.expected_host, crt_request.headers.get("host")) self._assert_subscribers_called(future) with open(self.filename, 'rb') as f: # Check the fake response overwrites the file because of download - self.assertEqual(f.read(), b'fake response') + self.assertEqual(f.read(), self.expected_download_content) - def test_delete(self): - self.s3_crt_client.make_request.side_effect = ( - self._simulate_make_request_side_effect + def test_download_to_nonseekable_stream(self): + underlying_stream = io.BytesIO() + nonseekable_stream = NonSeekableWriter(underlying_stream) + future = self.transfer_manager.download( + self.bucket, + self.key, + nonseekable_stream, + {}, + [self.record_subscriber], + ) + future.result() + + callargs_kwargs = self.s3_crt_client.make_request.call_args[1] + self.assertEqual( + callargs_kwargs, + { + 'request': mock.ANY, + 'type': awscrt.s3.S3RequestType.GET_OBJECT, + 'recv_filepath': None, + 'on_progress': mock.ANY, + 'on_done': mock.ANY, + 'on_body': mock.ANY, + }, + ) + self._assert_expected_crt_http_request( + callargs_kwargs["request"], + expected_http_method='GET', + expected_content_length=0, + ) + self._assert_subscribers_called(future) + self.assertEqual( + underlying_stream.getvalue(), self.expected_download_content ) + + def test_delete(self): future = self.transfer_manager.delete( self.bucket, self.key, {}, [self.record_subscriber] ) future.result() - callargs = self.s3_crt_client.make_request.call_args - callargs_kwargs = callargs[1] - self.assertIsNone(callargs_kwargs["send_filepath"]) - self.assertIsNone(callargs_kwargs["recv_filepath"]) + callargs_kwargs = self.s3_crt_client.make_request.call_args[1] self.assertEqual( - callargs_kwargs["type"], awscrt.s3.S3RequestType.DEFAULT + callargs_kwargs, + { + 'request': mock.ANY, + 'type': awscrt.s3.S3RequestType.DEFAULT, + 'on_progress': mock.ANY, + 'on_done': mock.ANY, + }, + ) + self._assert_expected_crt_http_request( + callargs_kwargs["request"], + expected_http_method='DELETE', + expected_content_length=0, ) - crt_request = callargs_kwargs["request"] - self.assertEqual("DELETE", crt_request.method) - self.assertEqual(self.expected_path, crt_request.path) - self.assertEqual(self.expected_host, crt_request.headers.get("host")) self._assert_subscribers_called(future) def test_blocks_when_max_requests_processes_reached(self): + self.s3_crt_client.make_request.return_value = self.s3_request + # We simulate blocking by not invoking the on_done callbacks for + # all of the requests we send. The default side effect invokes all + # callbacks so we need to unset the side effect to avoid on_done from + # being called in the child threads. + self.s3_crt_client.make_request.side_effect = None futures = [] callargs = (self.bucket, self.key, self.filename, {}, []) max_request_processes = 128 # the hard coded max processes diff --git a/tests/integration/test_crt.py b/tests/integration/test_crt.py index 157ae2dc..7881fa63 100644 --- a/tests/integration/test_crt.py +++ b/tests/integration/test_crt.py @@ -11,11 +11,18 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. import glob +import io import os from s3transfer.subscribers import BaseSubscriber from s3transfer.utils import OSUtils -from tests import HAS_CRT, assert_files_equal, requires_crt +from tests import ( + HAS_CRT, + NonSeekableReader, + NonSeekableWriter, + assert_files_equal, + requires_crt, +) from tests.integration import BaseTransferManagerIntegTest if HAS_CRT: @@ -44,13 +51,18 @@ def on_done(self, **kwargs): class TestCRTS3Transfers(BaseTransferManagerIntegTest): """Tests for the high level s3transfer based on CRT implementation.""" + def setUp(self): + super().setUp() + self.s3_key = 's3key.txt' + self.download_path = os.path.join(self.files.rootdir, 'download.txt') + def _create_s3_transfer(self): self.request_serializer = s3transfer.crt.BotocoreCRTRequestSerializer( - self.session + self.session, client_kwargs={'region_name': self.region} ) credetial_resolver = self.session.get_component('credential_provider') self.s3_crt_client = s3transfer.crt.create_s3_crt_client( - self.session.get_config_variable("region"), credetial_resolver + self.region, credetial_resolver ) self.record_subscriber = RecordingSubscriber() self.osutil = OSUtils() @@ -58,6 +70,40 @@ def _create_s3_transfer(self): self.s3_crt_client, self.request_serializer ) + def _upload_with_crt_transfer_manager(self, fileobj, key=None): + if key is None: + key = self.s3_key + self.addCleanup(self.delete_object, key) + with self._create_s3_transfer() as transfer: + future = transfer.upload( + fileobj, + self.bucket_name, + key, + subscribers=[self.record_subscriber], + ) + future.result() + + def _download_with_crt_transfer_manager(self, fileobj, key=None): + if key is None: + key = self.s3_key + self.addCleanup(self.delete_object, key) + with self._create_s3_transfer() as transfer: + future = transfer.download( + self.bucket_name, + key, + fileobj, + subscribers=[self.record_subscriber], + ) + future.result() + + def _assert_expected_s3_object(self, key, expected_size=None): + self.assertTrue(self.object_exists(key)) + if expected_size is not None: + response = self.client.head_object( + Bucket=self.bucket_name, Key=key + ) + self.assertEqual(response['ContentLength'], expected_size) + def _assert_has_public_read_acl(self, response): grants = response['Grants'] public_read = [ @@ -176,6 +222,43 @@ def test_upload_file_above_threshold_with_ssec(self): self.assertEqual(response['SSECustomerAlgorithm'], 'AES256') self._assert_subscribers_called(file_size) + def test_upload_seekable_stream(self): + size = 1024 * 1024 + self._upload_with_crt_transfer_manager(io.BytesIO(b'0' * size)) + self._assert_expected_s3_object(self.s3_key, expected_size=size) + self._assert_subscribers_called(size) + + def test_multipart_upload_seekable_stream(self): + size = 20 * 1024 * 1024 + self._upload_with_crt_transfer_manager(io.BytesIO(b'0' * size)) + self._assert_expected_s3_object(self.s3_key, expected_size=size) + self._assert_subscribers_called(size) + + def test_upload_nonseekable_stream(self): + size = 1024 * 1024 + self._upload_with_crt_transfer_manager(NonSeekableReader(b'0' * size)) + self._assert_expected_s3_object(self.s3_key, expected_size=size) + self._assert_subscribers_called(size) + + def test_multipart_upload_nonseekable_stream(self): + size = 20 * 1024 * 1024 + self._upload_with_crt_transfer_manager(NonSeekableReader(b'0' * size)) + self._assert_expected_s3_object(self.s3_key, expected_size=size) + self._assert_subscribers_called(size) + + def test_upload_empty_file(self): + size = 0 + filename = self.files.create_file_with_size(self.s3_key, filesize=size) + self._upload_with_crt_transfer_manager(filename) + self._assert_expected_s3_object(self.s3_key, expected_size=size) + self._assert_subscribers_called(size) + + def test_upload_empty_stream(self): + size = 0 + self._upload_with_crt_transfer_manager(io.BytesIO(b'')) + self._assert_expected_s3_object(self.s3_key, expected_size=size) + self._assert_subscribers_called(size) + def test_can_send_extra_params_on_download(self): # We're picking the customer provided sse feature # of S3 to test the extra_args functionality of @@ -244,6 +327,65 @@ def test_download_above_threshold(self): file_size = self.osutil.get_file_size(download_path) self._assert_subscribers_called(file_size) + def test_download_seekable_stream(self): + size = 1024 * 1024 + filename = self.files.create_file_with_size(self.s3_key, filesize=size) + self.upload_file(filename, self.s3_key) + + with open(self.download_path, 'wb') as f: + self._download_with_crt_transfer_manager(f) + self._assert_subscribers_called(size) + assert_files_equal(filename, self.download_path) + + def test_multipart_download_seekable_stream(self): + size = 20 * 1024 * 1024 + filename = self.files.create_file_with_size(self.s3_key, filesize=size) + self.upload_file(filename, self.s3_key) + + with open(self.download_path, 'wb') as f: + self._download_with_crt_transfer_manager(f) + self._assert_subscribers_called(size) + assert_files_equal(filename, self.download_path) + + def test_download_nonseekable_stream(self): + size = 1024 * 1024 + filename = self.files.create_file_with_size(self.s3_key, filesize=size) + self.upload_file(filename, self.s3_key) + + with open(self.download_path, 'wb') as f: + self._download_with_crt_transfer_manager(NonSeekableWriter(f)) + self._assert_subscribers_called(size) + assert_files_equal(filename, self.download_path) + + def test_multipart_download_nonseekable_stream(self): + size = 20 * 1024 * 1024 + filename = self.files.create_file_with_size(self.s3_key, filesize=size) + self.upload_file(filename, self.s3_key) + + with open(self.download_path, 'wb') as f: + self._download_with_crt_transfer_manager(NonSeekableWriter(f)) + self._assert_subscribers_called(size) + assert_files_equal(filename, self.download_path) + + def test_download_empty_file(self): + size = 0 + filename = self.files.create_file_with_size(self.s3_key, filesize=size) + self.upload_file(filename, self.s3_key) + + self._download_with_crt_transfer_manager(self.download_path) + self._assert_subscribers_called(size) + assert_files_equal(filename, self.download_path) + + def test_download_empty_stream(self): + size = 0 + filename = self.files.create_file_with_size(self.s3_key, filesize=size) + self.upload_file(filename, self.s3_key) + + with open(self.download_path, 'wb') as f: + self._download_with_crt_transfer_manager(f) + self._assert_subscribers_called(size) + assert_files_equal(filename, self.download_path) + def test_delete(self): transfer = self._create_s3_transfer() filename = self.files.create_file_with_size( diff --git a/tests/unit/test_crt.py b/tests/unit/test_crt.py index b6ad3245..aadd3827 100644 --- a/tests/unit/test_crt.py +++ b/tests/unit/test_crt.py @@ -10,6 +10,8 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +import io + from botocore.credentials import CredentialResolver, ReadOnlyCredentials from botocore.session import Session @@ -171,3 +173,12 @@ def test_set_exception_can_override_previous_exception(self): self.future.set_exception(CustomFutureException()) with self.assertRaises(CustomFutureException): self.future.result() + + +@requires_crt +class TestOnBodyFileObjWriter(unittest.TestCase): + def test_call(self): + fileobj = io.BytesIO() + writer = s3transfer.crt.OnBodyFileObjWriter(fileobj) + writer(chunk=b'content') + self.assertEqual(fileobj.getvalue(), b'content')