Skip to content

Commit

Permalink
Fix GCSObjectStore to match function signatures of other object stores (
Browse files Browse the repository at this point in the history
  • Loading branch information
eracah authored Aug 17, 2023
1 parent 10739c4 commit f34f86b
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 24 deletions.
2 changes: 1 addition & 1 deletion composer/utils/file_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]:
raise NotImplementedError(f'There is no implementation for WandB load_object_store via URI. Please use '
'WandBLogger')
elif backend == 'gs':
return GCSObjectStore(uri)
return GCSObjectStore(bucket=bucket_name)
elif backend == 'oci':
return OCIObjectStore(bucket=bucket_name)
else:
Expand Down
44 changes: 24 additions & 20 deletions composer/utils/object_store/gcs_object_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,26 @@ class GCSObjectStore(ObjectStore):
See :ref:`guide to credentials <boto3:guide_credentials>` for more information.
Args:
gcs_root_dir (str, optional): Required. The URL to a Google Cloud Storage object, formatted as gs://bucket/path
bucket (str): The name of the Google Cloud bucket to upload to or download from.
prefix (str, optional): The prefix to use when uploading to or downloading from the bucket. Default is an empty string.
"""

def __init__(
self,
gcs_root_dir: str,
bucket: str,
prefix: str = '',
) -> None:
try:
from google.cloud.storage import Client
except ImportError as e:
raise MissingConditionalImportError('gcs', 'google.cloud.storage') from e

# Format paths
self.bucket_name = bucket.strip('/')
self.prefix = prefix.strip('/')
if self.prefix != '':
self.prefix += '/'

if 'GOOGLE_APPLICATION_CREDENTIALS' in os.environ:
service_account_path = os.environ['GOOGLE_APPLICATION_CREDENTIALS']
self.client = Client.from_service_account_json(service_account_path)
Expand All @@ -83,16 +91,10 @@ def __init__(
raise ValueError(f'GOOGLE_APPLICATION_CREDENTIALS needs to be set for ' +
f'service level accounts or GCS_KEY and GCS_SECRET env variables must be set.')

from composer.utils import parse_uri
backend, self.bucket_name, self.prefix = parse_uri(gcs_root_dir)
if backend == '':
raise ValueError(f"gcs_root_dir ({gcs_root_dir}) doesn't have a valid format")
self.prefix = self.prefix.lstrip('/')

try:
self.bucket = self.client.get_bucket(self.bucket_name, timeout=60.0)
except Exception as e:
_reraise_gcs_errors(gcs_root_dir, e)
_reraise_gcs_errors(self.get_uri(object_name=''), e)

def get_key(self, object_name: str) -> str:
return f'{self.prefix}{object_name}'
Expand Down Expand Up @@ -128,48 +130,50 @@ def get_object_size(self, object_name: str) -> int:
return blob.size # size in bytes

def upload_object(self,
src: Union[str, pathlib.Path],
dest: str = '',
object_name: str,
filename: Union[str, pathlib.Path],
callback: Optional[Callable[[int, int], None]] = None):
"""Uploads a file to the cloud storage bucket.
Args:
src (Union[str, pathlib.Path]): The path to the local file
dest (str, optional): The destination path in the cloud storage bucket where the file will be saved.
object_name (str, optional): The destination path in the cloud storage bucket where the file will be saved.
If not provided or an empty string is given, the file will be uploaded to the root of the bucket with the same
name as the source file. Default is an empty string.
filename (Union[str, pathlib.Path]): The path to the local file
callback: optional
"""
if callback is not None:
raise ValueError('callback is not supported in gcs upload_object()')
src = filename
dest = object_name
dest = str(src) if dest == '' else dest
blob = self.bucket.blob(self.get_key(dest))
blob.upload_from_filename(src)

def download_object(
self,
src: str,
dest: Union[str, pathlib.Path],
object_name: str,
filename: Union[str, pathlib.Path],
overwrite: bool = False,
callback: Optional[Callable[[int, int], None]] = None,
):
"""Downloads an object from the specified source in the cloud storage bucket and saves it to the given destination.
Args:
src (str): The path to the object in the cloud storage bucket that needs to be downloaded.
dest (Union[str, pathlib.Path]): The destination path where the object will be saved locally. It can be a
object_name (str): The path to the object in the cloud storage bucket that needs to be downloaded.
filename (Union[str, pathlib.Path]): The destination path where the object will be saved locally. It can be a
string representing the file path or a pathlib.Path object.
overwrite (bool, optional): If set to True, the function will overwrite the destination file if it already
exists. If set to False, and the destination file exists, a FileExistsError will be raised. Default is False.
callback (Callable[[int, int], None], optional): A callback function that can be used to track the progress of
the download. It takes two integer arguments - the number of bytes downloaded and the total size of the
object. Default is None.
object. Default is None. Unused for GCSObjectStore.
Raises:
FileExistsError: If the destination file already exists and the `overwrite` parameter is set to False.
"""
if callback is not None:
raise ValueError('callback is not supported in gcs upload_object()')
dest = filename
src = object_name

if os.path.exists(dest) and not overwrite:
raise FileExistsError(f'The file at {dest} already exists and overwrite is set to False.')
Expand Down
4 changes: 2 additions & 2 deletions tests/utils/object_store/test_gs_object_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def gs_object_store(monkeypatch):
with mock.patch.dict(os.environ, {'GOOGLE_APPLICATION_CREDENTIALS': 'FAKE_CREDENTIAL'}):
mock_client = mock.MagicMock()
with mock.patch.object(Client, 'from_service_account_json', return_value=mock_client):
yield GCSObjectStore('gs://test-bucket/test-prefix/')
yield GCSObjectStore(bucket='test-bucket', prefix='test-prefix')


def test_get_uri(gs_object_store):
Expand Down Expand Up @@ -50,7 +50,7 @@ def test_upload_object(gs_object_store, monkeypatch):
source_file_name = 'dummy-file.txt'
destination_blob_name = 'dummy-blob.txt'

gs_object_store.upload_object(source_file_name, destination_blob_name)
gs_object_store.upload_object(destination_blob_name, source_file_name)

mock_blob.upload_from_filename.assert_called_with(source_file_name)
assert mock_blob.upload_from_filename.call_count == 1
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/test_file_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def test_maybe_create_object_store_from_uri(monkeypatch):
maybe_create_object_store_from_uri('wandb://my-cool/checkpoint/for/my/model.pt')

maybe_create_object_store_from_uri('gs://my-bucket/path')
mock_gs_obj.assert_called_once_with('gs://my-bucket/path')
mock_gs_obj.assert_called_once_with(bucket='my-bucket')

maybe_create_object_store_from_uri('oci://my-bucket/path')
mock_oci_obj.assert_called_once_with(bucket='my-bucket')
Expand Down

0 comments on commit f34f86b

Please sign in to comment.