diff --git a/requirements.txt b/requirements.txt index 33e5a09..f7f2780 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,4 @@ pyyaml humanfriendly pydantic<2.0.0 requests -boto3>=1.23.0,<2.0.0 +s3fs>=2023.1.0 diff --git a/tarn/__version__.py b/tarn/__version__.py index 2c7bffb..2d7893e 100644 --- a/tarn/__version__.py +++ b/tarn/__version__.py @@ -1 +1 @@ -__version__ = '0.12.0' +__version__ = '0.13.0' diff --git a/tarn/compat.py b/tarn/compat.py index 9c414c5..ceb6988 100644 --- a/tarn/compat.py +++ b/tarn/compat.py @@ -25,14 +25,9 @@ from typing import Self except ImportError: Self = Any -try: - # just a convenience lib for typing - from mypy_boto3_s3 import S3Client -except ImportError: - S3Client = Any # we will try to support both versions 1 and 2 while they are more or less popular try: - from pydantic import field_validator as _field_validator, model_validator, BaseModel + from pydantic import BaseModel, field_validator as _field_validator, model_validator def field_validator(*args, always=None, **kwargs): @@ -55,7 +50,7 @@ class NoExtra(BaseModel): except ImportError: - from pydantic import root_validator, validator as _field_validator, BaseModel + from pydantic import BaseModel, root_validator, validator as _field_validator def model_validator(mode: str): diff --git a/tarn/location/__init__.py b/tarn/location/__init__.py index 02e0b18..5c72c66 100644 --- a/tarn/location/__init__.py +++ b/tarn/location/__init__.py @@ -1,12 +1,12 @@ from .disk_dict import DiskDict from .fanout import Fanout -from .interface import Location, Writable +from .interface import Location, ReadOnly from .levels import Level, Levels from .nginx import Nginx from .redis import RedisLocation from .s3 import S3 -from .ssh import SCP, SFTP from .small import Small +from .ssh import SCP, SFTP # TODO: deprecated SmallLocation = Small diff --git a/tarn/location/disk_dict/config.py b/tarn/location/disk_dict/config.py index 39abb87..2030f60 100644 --- 
a/tarn/location/disk_dict/config.py +++ b/tarn/location/disk_dict/config.py @@ -8,7 +8,7 @@ from pydantic import Field from yaml import safe_dump, safe_load -from ...compat import field_validator, get_path_group, model_validator, model_validate, model_dump, NoExtra +from ...compat import NoExtra, field_validator, get_path_group, model_dump, model_validate, model_validator from ...interface import PathOrStr from ...tools import DummyLabels, DummyLocker, DummySize, DummyUsage, LabelsStorage, Locker, SizeTracker, UsageTracker from ...utils import mkdir diff --git a/tarn/location/disk_dict/location.py b/tarn/location/disk_dict/location.py index fe7f51d..7ad542a 100644 --- a/tarn/location/disk_dict/location.py +++ b/tarn/location/disk_dict/location.py @@ -12,16 +12,16 @@ from ...digest import key_to_relative from ...exceptions import CollisionError, StorageCorruption from ...interface import Key, MaybeLabels, MaybeValue, PathOrStr, Value -from ...tools import Locker, SizeTracker, UsageTracker, LabelsStorage +from ...tools import LabelsStorage, Locker, SizeTracker, UsageTracker from ...utils import adjust_permissions, create_folders, get_size, match_buffers, match_files -from ..interface import Meta, Writable -from .config import StorageConfig, init_storage, load_config, root_params, CONFIG_NAME +from ..interface import Location, Meta +from .config import CONFIG_NAME, StorageConfig, init_storage, load_config, root_params logger = logging.getLogger(__name__) MaybePath = Optional[Path] -class DiskDict(Writable): +class DiskDict(Location): def __init__(self, root: PathOrStr, levels: Optional[Sequence[int]] = None): root = Path(root) config = root / CONFIG_NAME @@ -65,7 +65,7 @@ def contents(self) -> Iterable[Tuple[Key, Self, Meta]]: key = bytes.fromhex(''.join(file.relative_to(self.root).parts)) with self.locker.read(key): - yield key, self, str(self.root) + yield key, self, DiskDictMeta(key, self.usage_tracker, self.labels) @contextmanager def read(self, key: Key, 
return_labels: bool) -> ContextManager[Union[None, Value, Tuple[Value, MaybeLabels]]]: @@ -77,7 +77,7 @@ def read(self, key: Key, return_labels: bool) -> ContextManager[Union[None, Valu if file.is_dir(): file = file / 'data' - self.usage_tracker.update(key) + self.touch(key) try: if return_labels: yield file, self.labels.get(key) @@ -142,7 +142,7 @@ def write(self, key: Key, value: Value, labels: MaybeLabels) -> ContextManager[M # metadata self.size_tracker.inc(get_size(file)) - self.usage_tracker.update(key) + self.touch(key) self.labels.update(key, labels) yield file @@ -172,6 +172,13 @@ def delete(self, key: Key) -> bool: return True + def touch(self, key: Key) -> bool: + file = self._key_to_path(key) + if not file.exists(): + return False + self.usage_tracker.update(key) + return True + def _key_to_path(self, key: Key): assert key, 'The key must be non-empty' return self.root / key_to_relative(key, self.levels) @@ -195,7 +202,7 @@ def __eq__(self, other): class DiskDictMeta(Meta): - def __init__(self, key, usage, labels): + def __init__(self, key: Key, usage: UsageTracker, labels: LabelsStorage): self._key, self._usage, self._labels = key, usage, labels @property diff --git a/tarn/location/fanout.py b/tarn/location/fanout.py index 0ee5329..54abee6 100644 --- a/tarn/location/fanout.py +++ b/tarn/location/fanout.py @@ -4,10 +4,10 @@ from ..compat import Self from ..interface import Key, Keys, MaybeLabels, MaybeValue, Meta, Value from ..utils import is_binary_io -from .interface import Location, Writable +from .interface import Location -class Fanout(Writable): +class Fanout(Location): def __init__(self, *locations: Location): hashes = _get_not_none(locations, 'hash') assert len(hashes) <= 1, hashes @@ -33,27 +33,25 @@ def read(self, key: Key, return_labels: bool) -> ContextManager[Union[None, Valu @contextmanager def write(self, key: Key, value: Value, labels: MaybeLabels) -> ContextManager[MaybeValue]: for location in self._locations: - if isinstance(location, 
Writable): - if is_binary_io(value): - offset = value.tell() - leave = False - with location.write(key, value, labels) as written: - if written is not None: - leave = True - yield written - # see more info on the "leave" trick in `Levels` - if leave: - return - if is_binary_io(value) and offset != value.tell(): - value.seek(offset) + if is_binary_io(value): + offset = value.tell() + leave = False + with location.write(key, value, labels) as written: + if written is not None: + leave = True + yield written + # see more info on the "leave" trick in `Levels` + if leave: + return + if is_binary_io(value) and offset != value.tell(): + value.seek(offset) yield None def delete(self, key: Key) -> bool: deleted = False for location in self._locations: - if isinstance(location, Writable): - if location.delete(key): - deleted = True + if location.delete(key): + deleted = True return deleted @@ -78,6 +76,12 @@ def contents(self) -> Iterable[Tuple[Key, Self, Meta]]: for location in self._locations: yield from location.contents() + def touch(self, key: Key) -> bool: + touched = False + for location in self._locations: + touched = location.touch(key) or touched + return touched + def _get_not_none(seq, name): result = set() diff --git a/tarn/location/interface.py b/tarn/location/interface.py index 219d103..4766ecf 100644 --- a/tarn/location/interface.py +++ b/tarn/location/interface.py @@ -41,8 +41,6 @@ def read_batch(self, keys: Keys) -> Iterable[Tuple[Key, Union[None, Tuple[Value, def contents(self) -> Iterable[Tuple[Key, Self, Meta]]: pass - -class Writable(Location, ABC): @abstractmethod def write(self, key: Key, value: Value, labels: MaybeLabels) -> ContextManager[MaybeValue]: pass @@ -51,5 +49,23 @@ def write(self, key: Key, value: Value, labels: MaybeLabels) -> ContextManager[M @abstractmethod def delete(self, key: Key) -> bool: pass + @abstractmethod + def touch(self, key: Key) -> bool: + """ + Update usage date for a given `key` + """ + pass + + +class ReadOnly(Location): + def write(self, key: Key, 
value: Value, labels: MaybeLabels) -> ContextManager[MaybeValue]: + yield None + + def delete(self, key: Key) -> bool: + return False + + def touch(self, key: Key) -> bool: + return False + Locations = Sequence[Location] diff --git a/tarn/location/levels.py b/tarn/location/levels.py index e183118..14cc3a9 100644 --- a/tarn/location/levels.py +++ b/tarn/location/levels.py @@ -5,7 +5,7 @@ from ..compat import Self from ..interface import Key, Keys, MaybeLabels, MaybeValue, Meta, Value -from ..location import Location, Writable +from ..location import Location from ..utils import is_binary_io from .fanout import _get_not_none @@ -17,7 +17,7 @@ class Level(NamedTuple): name: Optional[str] = None -class Levels(Writable): +class Levels(Location): def __init__(self, *levels: Union[Level, Location]): levels = [ level if isinstance(level, Level) else Level(level, write=True, replicate=True) @@ -54,29 +54,27 @@ def read(self, key: Key, return_labels: bool) -> ContextManager[Union[None, Valu def write(self, key: Key, value: Value, labels: MaybeLabels) -> ContextManager[MaybeValue]: for config in self._levels: location = config.location - if config.write and isinstance(location, Writable): - if is_binary_io(value): - offset = value.tell() - leave = False - with location.write(key, value, labels) as written: - if written is not None: - # we must leave the loop after the first successful write - leave = True - yield written - # but the context manager might have silenced the error, so we need an extra return here - if leave: - return - if is_binary_io(value) and offset != value.tell(): - value.seek(offset) + if is_binary_io(value): + offset = value.tell() + leave = False + with location.write(key, value, labels) as written: + if written is not None: + # we must leave the loop after the first successful write + leave = True + yield written + # but the context manager might have silenced the error, so we need an extra return here + if leave: + return + if is_binary_io(value) and 
offset != value.tell(): + value.seek(offset) yield None def delete(self, key: Key) -> bool: deleted = False for config in self._levels: - if config.write and isinstance(config.location, Writable): - if config.location.delete(key): - deleted = True + if config.location.delete(key): + deleted = True return deleted @@ -102,11 +100,17 @@ def contents(self) -> Iterable[Tuple[Key, Self, Meta]]: for level in self._levels: yield from level.location.contents() + def touch(self, key: Key) -> bool: + touched = False + for level in self._levels: + touched = level.location.touch(key) or touched + return touched + @contextmanager def _replicate(self, key: Key, value: Value, labels: MaybeLabels, index: int): for config in islice(self._levels, index): location = config.location - if config.replicate and isinstance(location, Writable): + if config.replicate: if is_binary_io(value): offset = value.tell() with _propagate_exception(location.write(key, value, labels)) as written: diff --git a/tarn/location/nginx.py b/tarn/location/nginx.py index 3b73d96..ae62105 100644 --- a/tarn/location/nginx.py +++ b/tarn/location/nginx.py @@ -8,10 +8,10 @@ from ..config import load_config_buffer from ..digest import key_to_relative from ..interface import MaybeLabels, Meta -from .interface import Key, Keys, Location, MaybeValue +from .interface import Key, Keys, MaybeValue, ReadOnly -class Nginx(Location): +class Nginx(ReadOnly): def __init__(self, url: str): if not url.endswith('/'): url += '/' diff --git a/tarn/location/redis.py b/tarn/location/redis.py index 5198549..f1aa0b5 100644 --- a/tarn/location/redis.py +++ b/tarn/location/redis.py @@ -8,10 +8,10 @@ from ..digest import value_to_buffer from ..exceptions import CollisionError, StorageCorruption from ..interface import Key, MaybeLabels, Meta, Value -from .interface import Writable +from .interface import Location -class RedisLocation(Writable): +class RedisLocation(Location): def __init__(self, *args, prefix: AnyStr = b'', **kwargs): # TODO: legacy mode if 
len(args) == 2 and isinstance(args[1], str) and not prefix: @@ -41,9 +41,9 @@ def read(self, key: Key, return_labels: bool) -> ContextManager: if content is None: yield return - self.update_usage_date(key) + self.touch(key) if return_labels: - labels = self.get_labels(key) + labels = self._get_labels(key) with value_to_buffer(self.redis.get(content_key)) as buffer: yield buffer, labels return @@ -60,8 +60,8 @@ def write(self, key: Key, value: Value, labels: MaybeLabels) -> ContextManager: content = self.redis.get(content_key) if content is None: self.redis.set(content_key, value.read()) - self.update_labels(key, labels) - self.update_usage_date(key) + self._update_labels(key, labels) + self.touch(key) with value_to_buffer(self.redis.get(content_key)) as buffer: yield buffer return @@ -70,36 +70,40 @@ def write(self, key: Key, value: Value, labels: MaybeLabels) -> ContextManager: raise CollisionError( f'Written value and the new one does not match: {key}' ) - self.update_labels(key, labels) - self.update_usage_date(key) + self._update_labels(key, labels) + self.touch(key) with value_to_buffer(self.redis.get(content_key)) as buffer: yield buffer except StorageCorruption: self.delete(key) - def get_labels(self, key: Key) -> MaybeLabels: + def _get_labels(self, key: Key) -> MaybeLabels: labels_key = b'labels' + self.prefix + key labels_bytes = self.redis.get(labels_key) if labels_bytes is None: return return list(json.loads(labels_bytes)) - def update_labels(self, key: Key, labels: MaybeLabels): + def _update_labels(self, key: Key, labels: MaybeLabels): labels_key = b'labels' + self.prefix + key - old_labels = self.get_labels(key) or [] + old_labels = self._get_labels(key) or [] if labels is not None: labels = list(set(old_labels).union(labels)) self.redis.set(labels_key, json.dumps(labels)) - def get_usage_date(self, key: Key) -> Optional[datetime]: + def _get_usage_date(self, key: Key) -> Optional[datetime]: usage_date_key = b'usage_date' + self.prefix + key 
usage_date = self.redis.get(usage_date_key) if usage_date is not None: return datetime.fromtimestamp(float(usage_date)) - def update_usage_date(self, key: Key): + def touch(self, key: Key) -> bool: + content_key = self.prefix + key + if not self.redis.exists(content_key): + return False usage_date_key = b'usage_date' + self.prefix + key self.redis.set(usage_date_key, datetime.now().timestamp()) + return True def delete(self, key: Key): content_key = self.prefix + key @@ -128,16 +132,16 @@ def _is_url(url): class RedisMeta(Meta): - def __init__(self, key, location): + def __init__(self, key: Key, location: RedisLocation): self._key, self._location = key, location @property def last_used(self) -> Optional[datetime]: - return self._location.get_usage_date(self._key) + return self._location._get_usage_date(self._key) @property def labels(self) -> MaybeLabels: - return self._location.get_labels(self._key) + return self._location._get_labels(self._key) def __str__(self): return f'{self.last_used}, {self.labels}' diff --git a/tarn/location/s3.py b/tarn/location/s3.py index be0b121..9e8e35f 100644 --- a/tarn/location/s3.py +++ b/tarn/location/s3.py @@ -1,40 +1,37 @@ import warnings from contextlib import contextmanager from datetime import datetime -from io import SEEK_CUR, SEEK_SET from pickle import PicklingError -from typing import Any, BinaryIO, ContextManager, Iterable, Mapping, Optional, Tuple, Union +from typing import Any, ContextManager, Iterable, Optional, Tuple, Union -import boto3 -from botocore.exceptions import ClientError, ConnectionError +from s3fs.core import S3FileSystem -from ..compat import S3Client from ..digest import key_to_relative, value_to_buffer from ..exceptions import CollisionError, StorageCorruption from ..interface import Key, MaybeLabels, Meta, Value from ..utils import match_buffers -from .interface import Writable +from .interface import Location -class S3(Writable): - def __init__(self, s3_client_or_url: Union[S3Client, str], bucket_name: str, 
service_name: str = 's3', **kwargs): +class S3(Location): + def __init__(self, s3fs_or_url: Optional[Union[S3FileSystem, str]], bucket_name: str, **kwargs): self.bucket = bucket_name - if isinstance(s3_client_or_url, str): - self.s3 = boto3.client(service_name=service_name, endpoint_url=s3_client_or_url, **kwargs) + if s3fs_or_url is None: + self.s3 = S3FileSystem(**kwargs) + elif isinstance(s3fs_or_url, str): + self.s3 = S3FileSystem(client_kwargs={'endpoint_url': s3fs_or_url}, **kwargs) else: - self.s3 = s3_client_or_url - self._s3_client_or_url = s3_client_or_url + assert isinstance(s3fs_or_url, S3FileSystem), 's3fs_or_url should be either None, or str, or S3FileSystem' + self.s3 = s3fs_or_url + self._s3fs_or_url = s3fs_or_url self._kwargs = kwargs def contents(self) -> Iterable[Tuple[Key, Any, Meta]]: - paginator = self.s3.get_paginator('list_objects_v2') - response_iterator = paginator.paginate(Bucket=self.bucket) - for response in response_iterator: - if 'Contents' in response: - for obj in response['Contents']: - path = obj['Key'] - key = self._path_to_key(path) - yield key, self, S3Meta(path=path, location=self) + for directory, _, files in self.s3.walk(self.bucket): + for file in files: + path = f'{directory}/{file}' + key = self._path_to_key(path) + yield key, self, S3Meta(path=path, location=self) @contextmanager def read( @@ -43,23 +40,15 @@ def read( try: path = self._key_to_path(key) try: - self.update_usage_date(path) + self.touch(key) if return_labels: - with self._get_buffer(path) as buffer: - yield buffer, self.get_labels(path) + with self.s3.open(path, 'rb') as buffer: + yield buffer, self._get_labels(path) else: - with self._get_buffer(path) as buffer: + with self.s3.open(path, 'rb') as buffer: yield buffer - except ClientError as e: - if ( - e.response['ResponseMetadata']['HTTPStatusCode'] == 404 - or e.response['Error']['Code'] == 'NoSuchKey' - ): # file doesn't exist - yield - else: - raise - except ConnectionError: + except 
FileNotFoundError: yield except StorageCorruption: self.delete(key) @@ -70,165 +59,94 @@ def write(self, key: Key, value: Value, labels: MaybeLabels) -> ContextManager: path = self._key_to_path(key) with value_to_buffer(value) as value: try: - with self._get_buffer(path) as obj_body_buffer: + with self.s3.open(path, 'rb') as buffer: try: - match_buffers(value, obj_body_buffer, context=key.hex()) + match_buffers(value, buffer, context=key.hex()) except ValueError as e: raise CollisionError( f'Written value and the new one does not match: {key}' ) from e - self.update_labels(path, labels) - self.update_usage_date(path) - with self._get_buffer(path) as buffer: - yield buffer - return - except ClientError as e: - if ( - e.response['ResponseMetadata']['HTTPStatusCode'] == 404 - or e.response['Error']['Code'] == 'NoSuchKey' - ): - self.s3.upload_fileobj( - Bucket=self.bucket, Key=path, Fileobj=value - ) - self.update_labels(path, labels) - self.update_usage_date(path) - with self._get_buffer(path) as buffer: + self._update_labels(path, labels) + self.touch(key) + with self.s3.open(path, 'rb') as buffer: yield buffer return - else: - raise - except ConnectionError: - yield + except FileNotFoundError: + with self.s3.open(path, 'wb') as buffer: + buffer.write(value.read()) + self._update_labels(path, labels) + self.touch(key) + with self.s3.open(path, 'rb') as buffer: + yield buffer + return except StorageCorruption: self.delete(key) def delete(self, key: Key): path = self._key_to_path(key) - self.s3.delete_object(Bucket=self.bucket, Key=path) + self.s3.delete(path) return True - def update_labels(self, path: str, labels: MaybeLabels): + def touch(self, key: Key): + try: + path = self._key_to_path(key) + tags_dict = self.s3.get_tags(path) + tags_dict['usage_date'] = str(datetime.now().timestamp()) + self.s3.put_tags(path, tags_dict) + return True + except FileNotFoundError: + return False + + def _update_labels(self, path: str, labels: MaybeLabels): if labels is not None: - 
tags_dict = self._tags_to_dict( - self.s3.get_object_tagging(Bucket=self.bucket, Key=path)['TagSet'] - ) + tags_dict = self.s3.get_tags(path) tags_dict.update({f'_{label}': f'_{label}' for label in labels}) - tags = self._dict_to_tags(tags_dict) - self.s3.put_object_tagging( - Bucket=self.bucket, Key=path, Tagging={'TagSet': tags} - ) - - def get_labels(self, path: str) -> MaybeLabels: - tags_dict = self._tags_to_dict( - self.s3.get_object_tagging(Bucket=self.bucket, Key=path)['TagSet'] - ) + self.s3.put_tags(path, tags_dict) + + def _get_labels(self, path: str) -> MaybeLabels: + tags_dict = self.s3.get_tags(path) return [dict_key[1:] for dict_key in tags_dict if dict_key.startswith('_')] - def update_usage_date(self, path: str): - try: - tags_dict = self._tags_to_dict( - self.s3.get_object_tagging(Bucket=self.bucket, Key=path)['TagSet'] - ) - tags_dict['usage_date'] = str(datetime.now().timestamp()) - tags = self._dict_to_tags(tags_dict) - self.s3.put_object_tagging( - Bucket=self.bucket, Key=path, Tagging={'TagSet': tags} - ) - except KeyError: - warnings.warn(f'Cannot update usage date for the key {self._path_to_key(path)}', stacklevel=2) - - def get_usage_date(self, path: str) -> Optional[datetime]: - tags_dict = self._tags_to_dict( - self.s3.get_object_tagging(Bucket=self.bucket, Key=path)['TagSet'] - ) + def _get_usage_date(self, path: str) -> Optional[datetime]: + tags_dict = self.s3.get_tags(path) if 'usage_date' in tags_dict: return datetime.fromtimestamp(float(tags_dict['usage_date'])) return None - def _get_buffer(self, path): - # with self.s3.get_object(Bucket=self.bucket, Key=path)['Body'] as obj_body: - # assert False, obj_body.read() - return StreamingBodyBuffer(self.s3.get_object, Bucket=self.bucket, Key=path) - def _key_to_path(self, key: Key): - return str(key_to_relative(key, [2, -1])) + return f'{self.bucket}/{str(key_to_relative(key, [2, -1]))}' def _path_to_key(self, path: str): - return bytes.fromhex(path.replace('/', '')) - - @staticmethod - 
def _tags_to_dict(tags: Iterable[Mapping[str, str]]) -> Mapping[str, str]: - return {tag['Key']: tag['Value'] for tag in tags} - - @staticmethod - def _dict_to_tags(tag_dict: Mapping[str, str]) -> Iterable[Mapping[str, str]]: - return [{'Key': key, 'Value': value} for key, value in tag_dict.items()] + path = ''.join(path.split('/')[1:]) + try: + return bytes.fromhex(path) + except ValueError: + raise ValueError(path) @classmethod - def _from_args(cls, s3_client_or_url, bucket_name, kwargs): - return cls(s3_client_or_url, bucket_name, **kwargs) + def _from_args(cls, s3fs_or_url, bucket_name, kwargs): + return cls(s3fs_or_url, bucket_name, **kwargs) def __reduce__(self): - if isinstance(self._s3_client_or_url, str): - return self._from_args, (self._s3_client_or_url, self.bucket, self._kwargs) + if isinstance(self._s3fs_or_url, (str, type(None))): + return self._from_args, (self._s3fs_or_url, self.bucket, self._kwargs) raise PicklingError('Cannot pickle S3FileSystem') def __eq__(self, other): return isinstance(other, S3) and self.__reduce__() == other.__reduce__() -class StreamingBodyBuffer(BinaryIO): - def __init__(self, getter, **kwargs): - super().__init__() - self.getter, self.kwargs = getter, kwargs - self._streaming_body = getter(**kwargs).get('Body') - - def seek(self, offset: int, whence: int = SEEK_SET) -> int: - # we can either return to the begining of the stream or do nothing - # everythnig else is too expensive - if whence == SEEK_SET: - if offset == 0: - self._streaming_body = self.getter(**self.kwargs).get('Body') - return 0 - if offset == self.tell(): - return offset - - if whence == SEEK_CUR: - if offset == 0: - return self.tell() - if offset == -self.tell(): - self._streaming_body = self.getter(**self.kwargs).get('Body') - return 0 - - raise NotImplementedError('Cannot seek anywhere but the begining of the stream') - - def seekable(self): - return True - - def __enter__(self): - return self - - def __exit__(self, type, value, traceback): - self.close() - - def 
__getattribute__(self, attr) -> Any: - if attr in ('seek', 'getter', 'kwargs', '__enter__', '__exit__'): - return super().__getattribute__(attr) - streaming_body = super().__getattribute__('_streaming_body') - return getattr(streaming_body, attr) - - class S3Meta(Meta): - def __init__(self, path, location): + def __init__(self, path: str, location: S3): self._path, self._location = path, location @property def last_used(self) -> Optional[datetime]: - return self._location.get_usage_date(self._path) + return self._location._get_usage_date(self._path) @property def labels(self) -> MaybeLabels: - return self._location.get_labels(self._path) + return self._location._get_labels(self._path) def __str__(self): return f'{self.last_used}, {self.labels}' diff --git a/tarn/location/small.py b/tarn/location/small.py index 4a78e52..08c1ba4 100644 --- a/tarn/location/small.py +++ b/tarn/location/small.py @@ -5,11 +5,11 @@ from ..compat import Self from ..interface import Key, Keys, MaybeLabels, Value from ..utils import value_to_buffer -from .interface import Meta, Writable +from .interface import Location, Meta -class Small(Writable): - def __init__(self, location: Writable, max_size: int): +class Small(Location): + def __init__(self, location: Location, max_size: int): self.location = location self.max_size = max_size # TODO: remove @@ -38,6 +38,9 @@ def write(self, key: Key, value: Value, labels: MaybeLabels) -> ContextManager: def delete(self, key: Key) -> bool: return self.location.delete(key) + def touch(self, key: Key) -> bool: + return self.location.touch(key) + # `read` is only guaranteed to return _at most_ n bytes, so we might need several calls def read_at_least(buffer, n): diff --git a/tarn/location/ssh/interface.py b/tarn/location/ssh/interface.py index 52eb3ae..ab4dec8 100644 --- a/tarn/location/ssh/interface.py +++ b/tarn/location/ssh/interface.py @@ -13,15 +13,15 @@ from ...compat import Self, remove_file, rmtree from ...digest import key_to_relative from 
...interface import Key, Keys, MaybeLabels, Meta, PathOrStr, Value -from ..interface import Location from ..disk_dict.config import load_config +from ..interface import ReadOnly class UnknownHostException(SSHException): pass -class SSHRemote(Location, ABC): +class SSHRemote(ReadOnly, ABC): exceptions = () def __init__(self, hostname: str, root: PathOrStr, port: int = SSH_PORT, username: str = None, password: str = None, diff --git a/tarn/location/ssh/sftp.py b/tarn/location/ssh/sftp.py index 5fc834d..851d8bb 100644 --- a/tarn/location/ssh/sftp.py +++ b/tarn/location/ssh/sftp.py @@ -7,7 +7,7 @@ class SFTP(SSHRemote): - exceptions = () + exceptions = (FileNotFoundError, ) @contextmanager def _client(self) -> ContextManager[SFTPClient]: diff --git a/tarn/tools/usage.py b/tarn/tools/usage.py index ca53607..70a008e 100644 --- a/tarn/tools/usage.py +++ b/tarn/tools/usage.py @@ -21,7 +21,7 @@ def update(self, key: Key): @abstractmethod def get(self, key: Key) -> Optional[datetime]: - """ Deletes the usage time for a given `key` """ + """ Get the usage time for a given `key` """ @abstractmethod def delete(self, key: Key): diff --git a/tests/conftest.py b/tests/conftest.py index 91cf4ed..b892af3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,9 +4,8 @@ from pathlib import Path from typing import Iterator -import boto3 import pytest -from botocore.exceptions import ClientError +from s3fs import S3FileSystem from tarn import DiskDict, HashKeyStorage, PickleKeyStorage from tarn.config import StorageConfig, init_storage @@ -118,23 +117,22 @@ def bucket_name(): @pytest.fixture def s3_kwargs(inside_ci, bucket_name): if inside_ci: - s3 = boto3.client('s3', endpoint_url='http://127.0.0.1:8001', aws_access_key_id='admin', aws_secret_access_key='adminadminadminadmin') + s3fs = S3FileSystem( + client_kwargs={'endpoint_url': 'http://127.0.0.1:8001'}, + key='admin', + secret='adminadminadminadmin', + ) kwargs = { - 'service_name': 's3', - 's3_client_or_url': 
'http://127.0.0.1:8001', - 'aws_access_key_id': 'admin', - 'aws_secret_access_key': 'adminadminadminadmin', + 's3fs_or_url': 'http://127.0.0.1:8001', + 'key': 'admin', + 'secret': 'adminadminadminadmin', 'bucket_name': bucket_name, } else: - s3 = boto3.client('s3', endpoint_url='http://10.0.1.2:11354') + s3fs = S3FileSystem(endpoint_url='http://10.0.1.2:11354') kwargs = { - 'service_name': 's3', - 's3_client_or_url': 'http://10.0.1.2:11354', + 's3fs_or_url': 'http://10.0.1.2:11354', 'bucket_name': bucket_name, } - try: - s3.head_bucket(Bucket=bucket_name) - except ClientError: - s3.create_bucket(Bucket=bucket_name) + s3fs.mkdirs(bucket_name, exist_ok=True) return kwargs diff --git a/tests/test_locations/test_interface.py b/tests/test_locations/test_interface.py index 0258f78..bb57202 100644 --- a/tests/test_locations/test_interface.py +++ b/tests/test_locations/test_interface.py @@ -2,7 +2,7 @@ import pytest -from tarn import DiskDict, Fanout, Level, Levels, StorageCorruption, Writable +from tarn import DiskDict, Fanout, Level, Levels, Location, StorageCorruption def _mkdir(x): @@ -24,7 +24,7 @@ def _mkdir(x): ), lambda x: Fanout(DiskDict(_mkdir(x / 'one')), DiskDict(_mkdir(x / 'two'))), ]) -def location(request, temp_dir) -> Writable: +def location(request, temp_dir) -> Location: return request.param(temp_dir) diff --git a/tests/test_locations/test_sftp.py b/tests/test_locations/test_sftp.py index dc7b779..a14ead1 100644 --- a/tests/test_locations/test_sftp.py +++ b/tests/test_locations/test_sftp.py @@ -25,9 +25,12 @@ def test_storage_ssh(storage_factory): local.read(load_text, key) # add a remote - both = HashKeyStorage(local._local, get_ssh_location(STORAGE_ROOT)) + ssh_location = get_ssh_location(STORAGE_ROOT) + both = HashKeyStorage(local._local, ssh_location) assert both.read(load_text, key) == load_text(__file__) local.read(load_text, key) + with ssh_location.read(b'123213213332', False) as v: + assert v is None def test_wrong_host():