diff --git a/composer/utils/object_store/gcs_object_store.py b/composer/utils/object_store/gcs_object_store.py index 2a56324a07..5076ddb7fa 100644 --- a/composer/utils/object_store/gcs_object_store.py +++ b/composer/utils/object_store/gcs_object_store.py @@ -171,11 +171,12 @@ def upload_object( if callback is not None: raise ValueError('callback is not supported in gcs upload_object()') + from google.cloud.storage.retry import DEFAULT_RETRY src = filename dest = object_name dest = str(src) if dest == '' else dest blob = self.bucket.blob(self.get_key(dest)) - blob.upload_from_filename(src) + blob.upload_from_filename(src, retry=DEFAULT_RETRY) # pyright: ignore[reportGeneralTypeIssues] def download_object( self, diff --git a/tests/utils/object_store/test_gs_object_store.py b/tests/utils/object_store/test_gs_object_store.py index 1005f22c50..60b004f1ee 100644 --- a/tests/utils/object_store/test_gs_object_store.py +++ b/tests/utils/object_store/test_gs_object_store.py @@ -116,7 +116,8 @@ def test_upload_object(gs_object_store, monkeypatch): gs_object_store.upload_object(destination_blob_name, source_file_name) - mock_blob.upload_from_filename.assert_called_with(source_file_name) + from google.cloud.storage.retry import DEFAULT_RETRY + mock_blob.upload_from_filename.assert_called_with(source_file_name, retry=DEFAULT_RETRY) assert mock_blob.upload_from_filename.call_count == 1