Skip to content

Commit

Permalink
Merge pull request #29 from appliedAI-Initiative/feature/setting-extra-and-custom-hash-extractor
Browse files Browse the repository at this point in the history

Feature: permitted setting extra field for upload and overriding hash extraction
  • Loading branch information
MischaPanch authored Jul 31, 2024
2 parents 3e4ee0b + 5ba77b0 commit 802300f
Showing 1 changed file with 98 additions and 14 deletions.
112 changes: 98 additions & 14 deletions src/accsr/remote_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from functools import cached_property
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
Generator,
List,
Expand All @@ -18,6 +20,7 @@
Pattern,
Protocol,
Sequence,
Tuple,
Union,
cast,
runtime_checkable,
Expand Down Expand Up @@ -99,6 +102,7 @@ def _switch_to_dir(path: Optional[str] = None) -> Generator[None, None, None]:
class Provider(str, Enum):
    """Remote-storage backend identifiers.

    The string values are the provider names understood by the underlying
    libcloud storage driver (see RemoteStorage, which wraps libcloud).
    """

    GOOGLE_STORAGE = "google_storage"
    S3 = "s3"
    AZURE_BLOBS = "azure_blobs"


@runtime_checkable
Expand All @@ -111,7 +115,7 @@ class RemoteObjectProtocol(Protocol):
name: str
size: int
hash: int
provider: str
provider: Union[Provider, str]

def download(
self, download_path, overwrite_existing=False
Expand All @@ -134,7 +138,18 @@ def __init__(
local_path: Optional[str] = None,
remote_obj: Optional[RemoteObjectProtocol] = None,
remote_path: Optional[str] = None,
remote_obj_overridden_md5_hash: Optional[int] = None,
):
"""
:param local_path: path to the local file
:param remote_obj: remote object
:param remote_path: path to the remote file (always in linux style)
:param remote_obj_overridden_md5_hash: pass this to override the hash of the remote object
(by default, the hash attribute of the remote object is used).
Setting this might be useful for Azure blob storage, as uploads to it may be chunked,
and the md5 hash of the remote object becomes different from the hash of the local file.
The hash is used to check if the local and remote files are equal.
"""
if remote_path is not None:
remote_path = remote_path.lstrip("/")
if remote_obj is not None:
Expand Down Expand Up @@ -172,6 +187,17 @@ def __init__(
self.local_size = 0
self.local_hash = None

if remote_obj_overridden_md5_hash is not None:
if remote_obj is None:
raise ValueError(
"remote_obj_overridden_md5_hash can only be set if remote_obj is not None"
)
self.remote_hash = remote_obj_overridden_md5_hash
elif remote_obj is not None:
self.remote_hash = remote_obj.hash
else:
self.remote_hash = None

@property
def name(self):
    """Return the remote path of this sync object (always in linux style)."""
    return self.remote_path
Expand Down Expand Up @@ -205,7 +231,7 @@ def exists_on_remote(self):
@property
def equal_md5_hash_sum(self):
    """Whether the local and remote copies of the file have the same md5 hash.

    Uses ``self.remote_hash``, which may have been overridden at construction
    time via ``remote_obj_overridden_md5_hash`` (useful e.g. for Azure blob
    storage, where chunked uploads change the remote object's own hash).
    Returns False if the file does not exist on the sync target.
    """
    # NOTE(review): the scraped diff retained the deleted pre-merge line
    # `self.local_hash == self.remote_obj.hash` alongside its replacement;
    # only the merged version (comparing against self.remote_hash) is kept.
    if self.exists_on_target:
        return self.local_hash == self.remote_hash
    return False

def to_dict(self, make_serializable=True):
Expand Down Expand Up @@ -339,7 +365,7 @@ def add_entry(
:return: None
"""
if isinstance(synced_object, str):
synced_object = SyncObject(synced_object)
synced_object = SyncObject(local_path=synced_object)
if skip:
self.skipped_source_files.append(synced_object)
else:
Expand Down Expand Up @@ -395,10 +421,27 @@ class RemoteStorageConfig:
class RemoteStorage:
"""
Wrapper around lib-cloud for accessing remote storage services.
:param conf:
"""

def __init__(self, conf: RemoteStorageConfig):
def __init__(
self,
conf: RemoteStorageConfig,
add_extra_to_upload: Optional[Callable[[SyncObject], dict]] = None,
remote_hash_extractor: Optional[Callable[[RemoteObjectProtocol], int]] = None,
):
"""
:param conf: configuration for the remote storage
:param add_extra_to_upload: a function that takes a `SyncObject` and returns a dictionary with extra parameters
that should be passed to the `upload_object` method of the storage driver as value of the `extra` kwarg.
This can be used to set custom metadata or other parameters. For example, for Azure blob storage, one can
set the hash of the local file as metadata by using
`add_extra_to_upload = lambda sync_object: {"meta_data": {"md5": sync_object.local_hash}}`.
:param remote_hash_extractor: a function that extracts the hash from a `RemoteObjectProtocol` object.
This is useful for Azure blob storage, as uploads to it may be chunked, and the md5 hash of the remote object
becomes different from the hash of the local file. In that case, one can add the hash of the local file
to the metadata using `add_extra_to_upload`, and then use this function to extract the hash from the
remote object. If not set, the `hash` attribute of the `RemoteObjectProtocol` object is used.
"""
self._bucket: Optional[Container] = None
self._conf = conf
self._provider = conf.provider
Expand All @@ -415,6 +458,8 @@ def __init__(self, conf: RemoteStorageConfig):
self.driver_kwargs = {
k: v for k, v in possible_driver_kwargs.items() if v is not None
}
self.add_extra_to_upload = add_extra_to_upload
self.remote_hash_extractor = remote_hash_extractor

def create_bucket(self, exist_ok: bool = True):
try:
Expand Down Expand Up @@ -498,15 +543,31 @@ def _execute_sync(
f"Cannot push non-existing file: {sync_object.local_path}"
)
assert sync_object.local_path is not None

extra = (
self.add_extra_to_upload(sync_object)
if self.add_extra_to_upload is not None
else None
)
remote_obj = cast(
RemoteObjectProtocol,
self.bucket.upload_object(
sync_object.local_path,
sync_object.remote_path,
extra=extra,
verify_hash=False,
),
)
return SyncObject(sync_object.local_path, remote_obj)

if self.remote_hash_extractor is not None:
remote_obj_overridden_md5_hash = self.remote_hash_extractor(remote_obj)
else:
remote_obj_overridden_md5_hash = None
return SyncObject(
sync_object.local_path,
remote_obj,
remote_obj_overridden_md5_hash=remote_obj_overridden_md5_hash,
)

elif direction == "pull":
if None in [sync_object.remote_obj, sync_object.local_path]:
Expand Down Expand Up @@ -770,32 +831,45 @@ def _get_pull_summary(
List[RemoteObjectProtocol], list(self.bucket.list_objects(full_remote_path))
)

for obj in tqdm(
for remote_obj in tqdm(
remote_objects,
desc=f"Scanning remote paths in {self.bucket.name}/{full_remote_path}: ",
):
local_path = None
collides_with = None
if (obj.size == 0) or (
self._listed_due_to_name_collision(full_remote_path, obj)
if (remote_obj.size == 0) or (
self._listed_due_to_name_collision(full_remote_path, remote_obj)
):
log.debug(
f"Skipping {obj.name} since it was listed due to name collisions"
f"Skipping {remote_obj.name} since it was listed due to name collisions"
)
skip = True
else:
relative_obj_path = self._get_relative_remote_path(obj)
relative_obj_path = self._get_relative_remote_path(remote_obj)
skip = self._should_skip(
relative_obj_path, include_regex, exclude_regex
)

if not skip:
local_path = self._get_destination_path(obj, local_base_dir)
local_path = self._get_destination_path(remote_obj, local_base_dir)
if os.path.isdir(local_path):
collides_with = local_path

remote_obj_overridden_md5_hash = (
self.remote_hash_extractor(remote_obj)
if self.remote_hash_extractor is not None
else None
)
sync_obj = SyncObject(
local_path=local_path,
remote_obj=remote_obj,
remote_obj_overridden_md5_hash=remote_obj_overridden_md5_hash,
)

summary.add_entry(
SyncObject(local_path, obj), skip=skip, collides_with=collides_with
sync_obj,
skip=skip,
collides_with=collides_with,
)

return summary
Expand Down Expand Up @@ -898,7 +972,17 @@ def _get_push_summary(

elif matched_remote_obj:
remote_obj = matched_remote_obj[0]
synced_obj = SyncObject(file, remote_obj, remote_path=remote_path)
remote_obj_overridden_md5_hash = (
self.remote_hash_extractor(remote_obj)
if self.remote_hash_extractor is not None and remote_obj is not None
else None
)
synced_obj = SyncObject(
local_path=file,
remote_obj=remote_obj,
remote_path=remote_path,
remote_obj_overridden_md5_hash=remote_obj_overridden_md5_hash,
)
summary.add_entry(
synced_obj,
collides_with=collides_with,
Expand Down

0 comments on commit 802300f

Please sign in to comment.