diff --git a/one/registration.py b/one/registration.py index ea64bc0c..94003b10 100644 --- a/one/registration.py +++ b/one/registration.py @@ -392,32 +392,37 @@ def register_session(self, ses_path, users=None, file_list=True, **kwargs): session['data_dataset_session_related'] = ensure_list(recs) return session, recs - def check_protected_files(self, file_list, created_by=None): + def prepare_files(self, file_list, versions=None): """ - Check whether a set of files associated to a session are protected + Validates file list for registration and splits files into a list of files per + session path. + Parameters ---------- file_list : list, str, pathlib.Path A filepath (or list thereof) of ALF datasets to register to Alyx. - created_by : str - Name of Alyx user (defaults to whoever is logged in to ONE instance). + versions : str, list of str + Optional version tags. Returns ------- - list of dicts, dict - A status for each session whether any of the files specified are protected - datasets or not.If none of the datasets are protected, a response with status - 200 is returned, if any of the files are protected a response with status - 403 is returned. + list of dicts, list of dicts + A dict containing a list of files for each session + A dict containg a list of versions for each session """ F = defaultdict(list) # empty map whose keys will be session paths - + V = defaultdict(list) # empty map for versions if isinstance(file_list, (str, pathlib.Path)): file_list = [file_list] + if versions is None or isinstance(versions, str): + versions = itertools.repeat(versions) + else: + versions = itertools.cycle(versions) + # Filter valid files and sort by session - for fn in map(pathlib.Path, file_list): + for fn, ver in zip(map(pathlib.Path, file_list), versions): session_path = get_session_path(fn) if fn.suffix not in self.file_extensions: _logger.debug(f'{fn}: No matching extension "{fn.suffix}" in database') @@ -428,6 +433,31 @@ def check_protected_files(self, file_list, created_by=None): _logger.debug('%s', ex.args[0]) continue F[session_path].append(fn.relative_to(session_path)) + V[session_path].append(ver) + + return F, V + + def check_protected_files(self, file_list, created_by=None): + """ + Check whether a set of files associated to a session are protected + Parameters + ---------- + file_list : list, str, pathlib.Path + A filepath (or list thereof) of ALF datasets to register to Alyx. + created_by : str + Name of Alyx user (defaults to whoever is logged in to ONE instance). + + Returns + ------- + list of dicts, dict + A status for each session whether any of the files specified are protected + datasets or not.If none of the datasets are protected, a response with status + 200 is returned, if any of the files are protected a response with status + 403 is returned. + """ + + # Validate files and rearrange into list per session + F, _ = self.prepare_files(file_list) # For each unique session, make a separate POST request records = [] @@ -496,29 +526,8 @@ def register_files(self, file_list, Server side database error (500 status code) Revision protected (403 status code) """ - F = defaultdict(list) # empty map whose keys will be session paths - V = defaultdict(list) # empty map for versions - if isinstance(file_list, (str, pathlib.Path)): - file_list = [file_list] - - if versions is None or isinstance(versions, str): - versions = itertools.repeat(versions) - else: - versions = itertools.cycle(versions) - # Filter valid files and sort by session - for fn, ver in zip(map(pathlib.Path, file_list), versions): - session_path = get_session_path(fn) - if fn.suffix not in self.file_extensions: - _logger.debug(f'{fn}: No matching extension "{fn.suffix}" in database') - continue - try: - get_dataset_type(fn, self.dtypes) - except ValueError as ex: - _logger.debug('%s', ex.args[0]) - continue - F[session_path].append(fn.relative_to(session_path)) - V[session_path].append(ver) + F, V = self.prepare_files(file_list, versions=versions) # For each unique session, make a separate POST request records = [] diff --git a/one/tests/test_registration.py b/one/tests/test_registration.py index d3d04734..d896fd95 100644 --- a/one/tests/test_registration.py +++ b/one/tests/test_registration.py @@ -222,6 +222,30 @@ def test_create_sessions(self): self.assertEqual(ses[0]['number'], int(session_path.parts[-1])) self.assertEqual(session_paths[0], session_path) + def test_prepare_files(self): + """Test for RegistrationClient.prepare_files""" + + session_path = self.session_path.parent / next_num_folder(self.session_path.parent) + session_path_2 = session_path.parent / next_num_folder(session_path) + file_list = [session_path.joinpath('wheel.position.npy'), + session_path.joinpath('wheel.timestamps.npy'), + session_path_2.joinpath('wheel.position.npy')] + + # Test with file list and version is None + F, V = self.client.prepare_files(file_list) + self.assertTrue(len(F), 2) + self.assertListEqual(sorted(list(F.keys())), sorted([session_path, session_path_2])) + for sess, n in zip([session_path, session_path_2], [2, 1]): + self.assertTrue(len(F[sess]), n) + self.assertTrue(len(V[sess]), n) + self.assertIsNone(V[session_path][0]) + + # Test with specifying version + versions = ['1.2.2', 'v1.2', '1.3.4'] + _, V = self.client.prepare_files(file_list, versions=versions) + self.assertListEqual(V[session_path], versions[:-1]) + self.assertListEqual(V[session_path_2], [versions[-1]]) + def test_check_protected(self): """Test for RegistrationClient.check_protected_files"""