Skip to content

Commit

Permalink
Merge pull request #12 from neuro-ml/dev
Browse files Browse the repository at this point in the history
`fetch` is a generator now
  • Loading branch information
maxme1 authored Dec 19, 2022
2 parents 51cd084 + a6fc11d commit d073366
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 42 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ env:

jobs:
test:
runs-on: ubuntu-latest
runs-on: ubuntu-20.04
strategy:
matrix:
python-version: [ 3.6, 3.7, 3.8, 3.9, '3.10' ]
Expand Down
2 changes: 1 addition & 1 deletion tarn/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.3'
__version__ = '0.2.0'
29 changes: 15 additions & 14 deletions tarn/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class RemoteStorage:

@abstractmethod
def fetch(self, keys: Sequence[Key], store: Callable[[Key, Path], Any],
config: HashConfig) -> Sequence[Tuple[Any, bool]]:
config: HashConfig) -> Iterable[Tuple[Any, bool]]:
"""
Fetches the value for ``key`` from a remote location.
"""
Expand Down Expand Up @@ -143,34 +143,35 @@ def read(self, key, context, *, fetch: bool) -> Tuple[Any, Union[None, bool]]:

return None, status

def fetch(self, keys: Iterable[Key], context, *, verbose: bool) -> Sequence[Key]:
def fetch(self, keys: Sequence[Key], context, *, verbose: bool) -> Iterable[Key]:
""" Fetch the `keys` from remote. Yields the keys that were successfully fetched """

def store(k, base):
status = self._replicate(k, base, context, self.levels)
bar.update()
return status if status is WriteError else k

keys = set(keys)
bar = tqdm(disable=not verbose, total=len(keys))
present = set()
for key in keys:
present = 0
for key in list(keys):
if self._contains(key, context):
present.add(key)
present += 1
keys.remove(key)
bar.update()
yield key

keys -= present
logger.info(f'Fetch: {len(present)} keys already present, fetching {len(keys)}')
logger.info('%s keys already present, fetching %s', present, len(keys))

for remote in self.remote:
if not keys:
break

logger.info(f'Trying remote {remote}')
keys -= {
k for k, success in remote.fetch(list(keys), store, self.hash)
if success and k is not WriteError
}

return list(keys)
logger.info('Trying remote %s', remote)
for key, success in remote.fetch(list(keys), store, self.hash):
if success and key is not WriteError:
keys.remove(key)
yield key

def _contains(self, key, context):
for layer in self.levels:
Expand Down
19 changes: 17 additions & 2 deletions tarn/local/storage.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import warnings
from pathlib import Path
from typing import Sequence, Iterable, Callable, Union

Expand Down Expand Up @@ -50,8 +51,22 @@ def write(self, file: PathLike, error: bool = True) -> Union[Key, None]:

return key

def fetch(self, keys: Iterable[Key], *, verbose: bool) -> Sequence[Key]:
return self.storage.fetch(keys, None, verbose=verbose)
def fetch(self, keys: Sequence[Key], *, verbose: bool, legacy: bool = True) -> Iterable[Key]:
""" Fetch the `keys` from remote. Yields the keys that were successfully fetched """
keys = list(keys)
result = self.storage.fetch(keys, None, verbose=verbose)
if legacy:
warnings.warn(
'In a future release fetch will yield the successfully processed keys. '
'Pass legacy=False to adopt this behaviour early on', UserWarning,
)
warnings.warn(
'In a future release fetch will yield the successfully processed keys. '
'Pass legacy=False to adopt this behaviour early on', DeprecationWarning,
)
result = list(set(keys) - set(result))

return result

def resolve(self, key: Key, *, fetch: bool = True) -> Path:
""" This is not safe, but it's fast. """
Expand Down
14 changes: 6 additions & 8 deletions tarn/remote/http.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import shutil
import tempfile
from pathlib import Path
from typing import Sequence, Callable, Any, Tuple
from typing import Sequence, Callable, Any, Tuple, Iterable
from urllib.error import HTTPError, URLError
from urllib.parse import urljoin
from urllib.request import urlretrieve
Expand All @@ -23,29 +23,27 @@ def __init__(self, url: str, optional: bool = False):
self.levels = self.hash = None

def fetch(self, keys: Sequence[Key], store: Callable[[Key, Path], Any],
config: HashConfig) -> Sequence[Tuple[Any, bool]]:
config: HashConfig) -> Iterable[Tuple[Any, bool]]:

results = []
with tempfile.TemporaryDirectory() as temp_dir:
source = Path(temp_dir) / 'source'
if keys and not self._get_config(config):
return [(None, False)] * len(keys)
yield from [(None, False)] * len(keys)
return

for key in keys:
try:
self._fetch_tree(key_to_relative(key, self.levels), source)

value = store(key, source)
shutil.rmtree(source)
results.append((value, True))
yield value, True

except requests.exceptions.ConnectionError:
results.append((None, False))
yield None, False

shutil.rmtree(source, ignore_errors=True)

return results

def _fetch_one(self, relative, local):
try:
urlretrieve(urljoin(self.url, str(relative)), str(local))
Expand Down
19 changes: 9 additions & 10 deletions tarn/remote/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import socket
import tempfile
from pathlib import Path
from typing import Union, Sequence, Callable, Any, Tuple
from typing import Union, Sequence, Callable, Any, Tuple, Iterable

import paramiko
from paramiko import SSHClient, AuthenticationException, SSHException
Expand Down Expand Up @@ -49,7 +49,7 @@ def __init__(self, hostname: str, root: PathLike, port: int = SSH_PORT, username
self.optional = optional

def fetch(self, keys: Sequence[Key], store: Callable[[Key, Path], Any],
config: HashConfig) -> Sequence[Tuple[Any, bool]]:
config: HashConfig) -> Iterable[Tuple[Any, bool]]:

try:
self.ssh.connect(
Expand All @@ -64,36 +64,35 @@ def fetch(self, keys: Sequence[Key], store: Callable[[Key, Path], Any],
if not self.optional:
raise

return [(None, False)] * len(keys)
yield from [(None, False)] * len(keys)
return

results = []
try:
with SCPClient(self.ssh.get_transport()) as scp, tempfile.TemporaryDirectory() as temp_dir:
source = Path(temp_dir) / 'source'
if keys and not self._get_config(scp, config):
return [(None, False)] * len(keys)
yield from [(None, False)] * len(keys)
return

for key in keys:
try:
scp.get(str(self.root / key_to_relative(key, self.levels)), str(source), recursive=True)
if not source.exists():
results.append((None, False))
yield None, False

else:
value = store(key, source)
shutil.rmtree(source)
results.append((value, True))
yield value, True

except (SCPException, socket.timeout, SSHException):
results.append((None, False))
yield None, False

shutil.rmtree(source, ignore_errors=True)

finally:
self.ssh.close()

return results

def _get_config(self, scp, config):
try:
if self.levels is None:
Expand Down
8 changes: 5 additions & 3 deletions tests/test_remote/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,14 @@ def test_missing(storage_factory):

key = remote.write(__file__)
missing = key[:-1] + 'x'
result = location.fetch([key, missing], lambda k, base: (base / 'data').exists(), location.hash)
result = list(location.fetch([key, missing], lambda k, base: (base / 'data').exists(), location.hash))
assert result == [(True, True), (None, False)]


def test_wrong_address():
with pytest.raises(requests.exceptions.ConnectionError):
HTTPLocation('http://localhost/wrong').fetch(['some-key'], lambda *args: True, None)
list(HTTPLocation('http://localhost/wrong').fetch(['some-key'], lambda *args: True, None))

assert HTTPLocation('http://localhost/wrong', True).fetch(['some-key'], lambda *args: True, None) == [(None, False)]
assert list(HTTPLocation(
'http://localhost/wrong', True
).fetch(['some-key'], lambda *args: True, None)) == [(None, False)]
7 changes: 4 additions & 3 deletions tests/test_remote/test_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_storage_ssh(storage_factory):

def test_wrong_host():
with pytest.raises((NoValidConnectionsError, SSHException)):
SSHLocation('localhost', '/').fetch(['some-key'], lambda *args: True, None)
assert SSHLocation('localhost', '/', optional=True).fetch(
['some-key'], lambda *args: True, None) == [(None, False)]
list(SSHLocation('localhost', '/').fetch(['some-key'], lambda *args: True, None))
assert list(SSHLocation(
'localhost', '/', optional=True
).fetch(['some-key'], lambda *args: True, None)) == [(None, False)]

0 comments on commit d073366

Please sign in to comment.