Skip to content

Commit

Permalink
improve coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
mayofaulkner committed Mar 25, 2024
1 parent e163975 commit 7eb5a99
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 33 deletions.
75 changes: 42 additions & 33 deletions one/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,32 +392,37 @@ def register_session(self, ses_path, users=None, file_list=True, **kwargs):
session['data_dataset_session_related'] = ensure_list(recs)
return session, recs

def check_protected_files(self, file_list, created_by=None):
def prepare_files(self, file_list, versions=None):
"""
Check whether a set of files associated to a session are protected
Validates file list for registration and splits files into a list of files per
session path.
Parameters
----------
file_list : list, str, pathlib.Path
A filepath (or list thereof) of ALF datasets to register to Alyx.
created_by : str
Name of Alyx user (defaults to whoever is logged in to ONE instance).
versions : str, list of str
Optional version tags.
Returns
-------
list of dicts, dict
A status for each session whether any of the files specified are protected
datasets or not.If none of the datasets are protected, a response with status
200 is returned, if any of the files are protected a response with status
403 is returned.
list of dicts, list of dicts
A dict containing a list of files for each session
A dict containg a list of versions for each session
"""

F = defaultdict(list) # empty map whose keys will be session paths

V = defaultdict(list) # empty map for versions
if isinstance(file_list, (str, pathlib.Path)):
file_list = [file_list]

if versions is None or isinstance(versions, str):
versions = itertools.repeat(versions)
else:
versions = itertools.cycle(versions)

# Filter valid files and sort by session
for fn in map(pathlib.Path, file_list):
for fn, ver in zip(map(pathlib.Path, file_list), versions):
session_path = get_session_path(fn)
if fn.suffix not in self.file_extensions:
_logger.debug(f'{fn}: No matching extension "{fn.suffix}" in database')
Expand All @@ -428,6 +433,31 @@ def check_protected_files(self, file_list, created_by=None):
_logger.debug('%s', ex.args[0])
continue
F[session_path].append(fn.relative_to(session_path))
V[session_path].append(ver)

return F, V

def check_protected_files(self, file_list, created_by=None):
"""
Check whether a set of files associated to a session are protected
Parameters
----------
file_list : list, str, pathlib.Path
A filepath (or list thereof) of ALF datasets to register to Alyx.
created_by : str
Name of Alyx user (defaults to whoever is logged in to ONE instance).
Returns
-------
list of dicts, dict
A status for each session whether any of the files specified are protected
datasets or not.If none of the datasets are protected, a response with status
200 is returned, if any of the files are protected a response with status
403 is returned.
"""

# Validate files and rearrange into list per session
F, _ = self.prepare_files(file_list)

# For each unique session, make a separate POST request
records = []
Expand Down Expand Up @@ -496,29 +526,8 @@ def register_files(self, file_list,
Server side database error (500 status code)
Revision protected (403 status code)
"""
F = defaultdict(list) # empty map whose keys will be session paths
V = defaultdict(list) # empty map for versions
if isinstance(file_list, (str, pathlib.Path)):
file_list = [file_list]

if versions is None or isinstance(versions, str):
versions = itertools.repeat(versions)
else:
versions = itertools.cycle(versions)

# Filter valid files and sort by session
for fn, ver in zip(map(pathlib.Path, file_list), versions):
session_path = get_session_path(fn)
if fn.suffix not in self.file_extensions:
_logger.debug(f'{fn}: No matching extension "{fn.suffix}" in database')
continue
try:
get_dataset_type(fn, self.dtypes)
except ValueError as ex:
_logger.debug('%s', ex.args[0])
continue
F[session_path].append(fn.relative_to(session_path))
V[session_path].append(ver)
F, V = self.prepare_files(file_list, versions=versions)

# For each unique session, make a separate POST request
records = []
Expand Down
24 changes: 24 additions & 0 deletions one/tests/test_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,30 @@ def test_create_sessions(self):
self.assertEqual(ses[0]['number'], int(session_path.parts[-1]))
self.assertEqual(session_paths[0], session_path)

def test_prepare_files(self):
"""Test for RegistrationClient.prepare_files"""

session_path = self.session_path.parent / next_num_folder(self.session_path.parent)
session_path_2 = session_path.parent / next_num_folder(session_path)
file_list = [session_path.joinpath('wheel.position.npy'),
session_path.joinpath('wheel.timestamps.npy'),
session_path_2.joinpath('wheel.position.npy')]

# Test with file list and version is None
F, V = self.client.prepare_files(file_list)
self.assertTrue(len(F), 2)
self.assertListEqual(sorted(list(F.keys())), sorted([session_path, session_path_2]))
for sess, n in zip([session_path, session_path_2], [2, 1]):
self.assertTrue(len(F[sess]), n)
self.assertTrue(len(V[sess]), n)
self.assertIsNone(V[session_path][0])

# Test with specifying version
versions = ['1.2.2', 'v1.2', '1.3.4']
_, V = self.client.prepare_files(file_list, versions=versions)
self.assertListEqual(V[session_path], versions[:-1])
self.assertListEqual(V[session_path_2], [versions[-1]])

def test_check_protected(self):
"""Test for RegistrationClient.check_protected_files"""

Expand Down

0 comments on commit 7eb5a99

Please sign in to comment.