Commit: Merge branch 'release/v0.4.8'
Michael Panchenko committed Jul 31, 2024
2 parents b0e7202 + e2bc043, commit 83632cb
Showing 5 changed files with 102 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
```diff
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.4.7
+current_version = 0.4.8
 commit = False
 tag = False
 allow_dirty = False
```
2 changes: 1 addition & 1 deletion latest_release_notes.md
```diff
@@ -1,4 +1,4 @@
-# Release Notes: 0.4.7
+# Release Notes: 0.4.8
 
 ## Bugfix release
 - Fixed bugs in RemoteStorage related to name collisions and serialization.
```
2 changes: 1 addition & 1 deletion setup.py
```diff
@@ -15,7 +15,7 @@
     package_dir={"": "src"},
     packages=find_packages(where="src"),
     include_package_data=True,
-    version="0.4.7",
+    version="0.4.8",
     description="Utils for accessing data from anywhere",
     install_requires=open("requirements.txt").readlines(),
     setup_requires=["wheel"],
```
2 changes: 1 addition & 1 deletion src/accsr/__init__.py
```diff
@@ -1 +1 @@
-__version__ = "0.4.7"
+__version__ = "0.4.8"
```
112 changes: 98 additions & 14 deletions src/accsr/remote_storage.py
```diff
@@ -10,6 +10,8 @@
 from functools import cached_property
 from pathlib import Path
 from typing import (
+    Any,
+    Callable,
     Dict,
     Generator,
     List,
@@ -18,6 +20,7 @@
     Pattern,
     Protocol,
     Sequence,
+    Tuple,
     Union,
     cast,
     runtime_checkable,
```
```diff
@@ -99,6 +102,7 @@ def _switch_to_dir(path: Optional[str] = None) -> Generator[None, None, None]:
 class Provider(str, Enum):
     GOOGLE_STORAGE = "google_storage"
     S3 = "s3"
+    AZURE_BLOBS = "azure_blobs"
 
 
 @runtime_checkable
@@ -111,7 +115,7 @@ class RemoteObjectProtocol(Protocol):
     name: str
     size: int
     hash: int
-    provider: str
+    provider: Union[Provider, str]
 
     def download(
         self, download_path, overwrite_existing=False
```
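
The two hunks above add the `azure_blobs` provider and widen the protocol's `provider` field to `Union[Provider, str]`. Because `Provider` subclasses `str`, its members compare equal to their raw string values, so either form can be used interchangeably; a standalone sketch of that behavior, mirroring the enum from the diff:

```python
from enum import Enum


class Provider(str, Enum):
    # mirrors the enum in src/accsr/remote_storage.py
    GOOGLE_STORAGE = "google_storage"
    S3 = "s3"
    AZURE_BLOBS = "azure_blobs"


# str-based enum members compare equal to plain strings, so callers may pass
# either Provider.AZURE_BLOBS or the literal "azure_blobs" for the provider field.
assert Provider.AZURE_BLOBS == "azure_blobs"
assert "s3" == Provider.S3
```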
```diff
@@ -134,7 +138,18 @@ def __init__(
         local_path: Optional[str] = None,
         remote_obj: Optional[RemoteObjectProtocol] = None,
         remote_path: Optional[str] = None,
+        remote_obj_overridden_md5_hash: Optional[int] = None,
     ):
+        """
+        :param local_path: path to the local file
+        :param remote_obj: remote object
+        :param remote_path: path to the remote file (always in linux style)
+        :param remote_obj_overridden_md5_hash: pass this to override the hash of the remote object
+            (by default, the hash attribute of the remote object is used).
+            Setting this might be useful for Azure blob storage, as uploads to it may be chunked,
+            and the md5 hash of the remote object becomes different from the hash of the local file.
+            The hash is used to check if the local and remote files are equal.
+        """
         if remote_path is not None:
             remote_path = remote_path.lstrip("/")
         if remote_obj is not None:
```
```diff
@@ -172,6 +187,17 @@ def __init__(
             self.local_size = 0
             self.local_hash = None
 
+        if remote_obj_overridden_md5_hash is not None:
+            if remote_obj is None:
+                raise ValueError(
+                    "remote_obj_overridden_md5_hash can only be set if remote_obj is not None"
+                )
+            self.remote_hash = remote_obj_overridden_md5_hash
+        elif remote_obj is not None:
+            self.remote_hash = remote_obj.hash
+        else:
+            self.remote_hash = None
+
     @property
     def name(self):
         return self.remote_path
```
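
The precedence implemented above: an explicit override wins, otherwise the remote object's own hash is used, and `None` when there is no remote object. A minimal sketch, assuming a `SyncObject` can be built from a remote object alone (as `_get_pull_summary` below does); `FakeRemoteObj` is a hypothetical stand-in, not part of accsr:

```python
from accsr.remote_storage import SyncObject


class FakeRemoteObj:
    # hypothetical stand-in exposing the attributes SyncObject reads;
    # a real object would come from the storage driver
    name = "data/file.csv"
    size = 128
    hash = 111
    provider = "azure_blobs"


# An explicit override takes precedence over the remote object's own hash:
obj = SyncObject(remote_obj=FakeRemoteObj(), remote_obj_overridden_md5_hash=222)
assert obj.remote_hash == 222

# Without an override, remote_obj.hash is used:
obj = SyncObject(remote_obj=FakeRemoteObj())
assert obj.remote_hash == 111

# An override without a remote object raises the ValueError shown in the diff:
# SyncObject(remote_obj_overridden_md5_hash=222)
```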
```diff
@@ -205,7 +231,7 @@ def exists_on_remote(self):
     @property
     def equal_md5_hash_sum(self):
         if self.exists_on_target:
-            return self.local_hash == self.remote_obj.hash
+            return self.local_hash == self.remote_hash
         return False
 
     def to_dict(self, make_serializable=True):
```
```diff
@@ -339,7 +365,7 @@ def add_entry(
         :return: None
         """
         if isinstance(synced_object, str):
-            synced_object = SyncObject(synced_object)
+            synced_object = SyncObject(local_path=synced_object)
         if skip:
             self.skipped_source_files.append(synced_object)
         else:
```
```diff
@@ -395,10 +421,27 @@ class RemoteStorageConfig:
 class RemoteStorage:
     """
     Wrapper around lib-cloud for accessing remote storage services.
-    :param conf:
     """
 
-    def __init__(self, conf: RemoteStorageConfig):
+    def __init__(
+        self,
+        conf: RemoteStorageConfig,
+        add_extra_to_upload: Optional[Callable[[SyncObject], dict]] = None,
+        remote_hash_extractor: Optional[Callable[[RemoteObjectProtocol], int]] = None,
+    ):
+        """
+        :param conf: configuration for the remote storage
+        :param add_extra_to_upload: a function that takes a `SyncObject` and returns a dictionary with extra parameters
+            that should be passed to the `upload_object` method of the storage driver as value of the `extra` kwarg.
+            This can be used to set custom metadata or other parameters. For example, for Azure blob storage, one can
+            set the hash of the local file as metadata by using
+            `add_extra_to_upload = lambda sync_object: {"meta_data": {"md5": sync_object.local_hash}}`.
+        :param remote_hash_extractor: a function that extracts the hash from a `RemoteObjectProtocol` object.
+            This is useful for Azure blob storage, as uploads to it may be chunked, and the md5 hash of the remote object
+            becomes different from the hash of the local file. In that case, one can add the hash of the local file
+            to the metadata using `add_extra_to_upload`, and then use this function to extract the hash from the
+            remote object. If not set, the `hash` attribute of the `RemoteObjectProtocol` object is used.
+        """
         self._bucket: Optional[Container] = None
         self._conf = conf
         self._provider = conf.provider
```
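
Together, the two hooks implement the Azure workaround from the docstring: store the local file's md5 in blob metadata at upload time, then read it back when comparing files. A hedged sketch under two assumptions not confirmed by this excerpt: that the driver exposes uploaded metadata via a `meta_data` attribute on the remote object (as libcloud's `Object` does), and that the stored hash round-trips through `int`:

```python
from accsr.remote_storage import (
    RemoteObjectProtocol,
    RemoteStorage,
    RemoteStorageConfig,
    SyncObject,
)


def add_md5_metadata(sync_object: SyncObject) -> dict:
    # passed to the driver's upload_object as the `extra` kwarg (see docstring above)
    return {"meta_data": {"md5": sync_object.local_hash}}


def extract_md5_from_metadata(remote_obj: RemoteObjectProtocol) -> int:
    # assumption: the driver surfaces blob metadata on a `meta_data` attribute
    return int(remote_obj.meta_data["md5"])


conf: RemoteStorageConfig = ...  # an Azure-pointing config; fields elided here
storage = RemoteStorage(
    conf,
    add_extra_to_upload=add_md5_metadata,
    remote_hash_extractor=extract_md5_from_metadata,
)
```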
```diff
@@ -415,6 +458,8 @@ def __init__(self, conf: RemoteStorageConfig):
         self.driver_kwargs = {
             k: v for k, v in possible_driver_kwargs.items() if v is not None
         }
+        self.add_extra_to_upload = add_extra_to_upload
+        self.remote_hash_extractor = remote_hash_extractor
 
     def create_bucket(self, exist_ok: bool = True):
         try:
```
```diff
@@ -498,15 +543,31 @@ def _execute_sync(
                     f"Cannot push non-existing file: {sync_object.local_path}"
                 )
             assert sync_object.local_path is not None
+
+            extra = (
+                self.add_extra_to_upload(sync_object)
+                if self.add_extra_to_upload is not None
+                else None
+            )
             remote_obj = cast(
                 RemoteObjectProtocol,
                 self.bucket.upload_object(
                     sync_object.local_path,
                     sync_object.remote_path,
+                    extra=extra,
                     verify_hash=False,
                 ),
             )
-            return SyncObject(sync_object.local_path, remote_obj)
+
+            if self.remote_hash_extractor is not None:
+                remote_obj_overridden_md5_hash = self.remote_hash_extractor(remote_obj)
+            else:
+                remote_obj_overridden_md5_hash = None
+            return SyncObject(
+                sync_object.local_path,
+                remote_obj,
+                remote_obj_overridden_md5_hash=remote_obj_overridden_md5_hash,
+            )
 
         elif direction == "pull":
             if None in [sync_object.remote_obj, sync_object.local_path]:
```
```diff
@@ -770,32 +831,45 @@ def _get_pull_summary(
             List[RemoteObjectProtocol], list(self.bucket.list_objects(full_remote_path))
         )
 
-        for obj in tqdm(
+        for remote_obj in tqdm(
             remote_objects,
             desc=f"Scanning remote paths in {self.bucket.name}/{full_remote_path}: ",
         ):
             local_path = None
             collides_with = None
-            if (obj.size == 0) or (
-                self._listed_due_to_name_collision(full_remote_path, obj)
+            if (remote_obj.size == 0) or (
+                self._listed_due_to_name_collision(full_remote_path, remote_obj)
             ):
                 log.debug(
-                    f"Skipping {obj.name} since it was listed due to name collisions"
+                    f"Skipping {remote_obj.name} since it was listed due to name collisions"
                 )
                 skip = True
             else:
-                relative_obj_path = self._get_relative_remote_path(obj)
+                relative_obj_path = self._get_relative_remote_path(remote_obj)
                 skip = self._should_skip(
                     relative_obj_path, include_regex, exclude_regex
                 )
 
             if not skip:
-                local_path = self._get_destination_path(obj, local_base_dir)
+                local_path = self._get_destination_path(remote_obj, local_base_dir)
                 if os.path.isdir(local_path):
                     collides_with = local_path
 
+            remote_obj_overridden_md5_hash = (
+                self.remote_hash_extractor(remote_obj)
+                if self.remote_hash_extractor is not None
+                else None
+            )
+            sync_obj = SyncObject(
+                local_path=local_path,
+                remote_obj=remote_obj,
+                remote_obj_overridden_md5_hash=remote_obj_overridden_md5_hash,
+            )
+
             summary.add_entry(
-                SyncObject(local_path, obj), skip=skip, collides_with=collides_with
+                sync_obj,
+                skip=skip,
+                collides_with=collides_with,
             )
 
         return summary
```
```diff
@@ -898,7 +972,17 @@ def _get_push_summary(
 
             elif matched_remote_obj:
                 remote_obj = matched_remote_obj[0]
-                synced_obj = SyncObject(file, remote_obj, remote_path=remote_path)
+                remote_obj_overridden_md5_hash = (
+                    self.remote_hash_extractor(remote_obj)
+                    if self.remote_hash_extractor is not None and remote_obj is not None
+                    else None
+                )
+                synced_obj = SyncObject(
+                    local_path=file,
+                    remote_obj=remote_obj,
+                    remote_path=remote_path,
+                    remote_obj_overridden_md5_hash=remote_obj_overridden_md5_hash,
+                )
             summary.add_entry(
                 synced_obj,
                 collides_with=collides_with,
```
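
For orientation, a sketch of how the pieces combine at call time, continuing the `storage` object from the earlier example. Here `push` and `pull` are assumed to be the public entry points sitting on top of the `_get_push_summary`/`_get_pull_summary` helpers above; their names and signatures are assumptions, not shown in this excerpt:

```python
# Pushing runs add_extra_to_upload for each file, so the local md5 lands in blob
# metadata alongside each upload; push and pull summaries run remote_hash_extractor
# on every listed remote object, so equality checks compare against the stored md5
# rather than the hash of a chunked upload.
storage.push("data")                      # assumed: local path -> remote bucket
storage.pull("data", local_base_dir=".")  # assumed: remote path -> local dir
```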
