diff --git a/one/registration.py b/one/registration.py index 8cb0f170..ea64bc0c 100644 --- a/one/registration.py +++ b/one/registration.py @@ -392,6 +392,58 @@ def register_session(self, ses_path, users=None, file_list=True, **kwargs): session['data_dataset_session_related'] = ensure_list(recs) return session, recs + def check_protected_files(self, file_list, created_by=None): + """ + Check whether a set of files associated to a session are protected + Parameters + ---------- + file_list : list, str, pathlib.Path + A filepath (or list thereof) of ALF datasets to register to Alyx. + created_by : str + Name of Alyx user (defaults to whoever is logged in to ONE instance). + + Returns + ------- + list of dicts, dict + A status for each session whether any of the files specified are protected + datasets or not.If none of the datasets are protected, a response with status + 200 is returned, if any of the files are protected a response with status + 403 is returned. + """ + + F = defaultdict(list) # empty map whose keys will be session paths + + if isinstance(file_list, (str, pathlib.Path)): + file_list = [file_list] + + # Filter valid files and sort by session + for fn in map(pathlib.Path, file_list): + session_path = get_session_path(fn) + if fn.suffix not in self.file_extensions: + _logger.debug(f'{fn}: No matching extension "{fn.suffix}" in database') + continue + try: + get_dataset_type(fn, self.dtypes) + except ValueError as ex: + _logger.debug('%s', ex.args[0]) + continue + F[session_path].append(fn.relative_to(session_path)) + + # For each unique session, make a separate POST request + records = [] + for session_path, files in F.items(): + # this is the generic relative path: subject/yyyy-mm-dd/NNN + details = session_path_parts(session_path.as_posix(), as_dict=True, assert_valid=True) + rel_path = PurePosixPath(details['subject'], details['date'], details['number']) + + r_ = {'created_by': created_by or self.one.alyx.user, + 'path': rel_path.as_posix(), + 'filenames': [x.as_posix() for x in files] + } + records.append(self.one.alyx.get('/check-protected', data=r_, clobber=True)) + + return records[0] if len(F.keys()) == 1 else records + def register_files(self, file_list, versions=None, default=True, created_by=None, server_only=False, repository=None, exists=True, dry=False, max_md5_size=None, **kwargs): diff --git a/one/tests/test_registration.py b/one/tests/test_registration.py index 70910fc4..d3d04734 100644 --- a/one/tests/test_registration.py +++ b/one/tests/test_registration.py @@ -222,6 +222,29 @@ def test_create_sessions(self): self.assertEqual(ses[0]['number'], int(session_path.parts[-1])) self.assertEqual(session_paths[0], session_path) + def test_check_protected(self): + """Test for RegistrationClient.check_protected_files""" + + session_path, eid = self.client.create_new_session(self.subject) + file_name = session_path.joinpath('wheel.timestamps.npy') + file_name.touch() + + # register a dataset + rec, = self.client.register_files(str(file_name)) + + # Check if it is protected, it shouldn't be, response 200 + protected = self.client.check_protected_files(str(file_name)) + self.assertEqual(protected['status_code'], 200) + + # Add a protected tag to all the datasets + tag = self.tag['name'] + self.one.alyx.rest('datasets', 'partial_update', id=rec['id'], data={'tags': [tag]}) + + # check protected + protected = self.client.check_protected_files(str(file_name)) + self.assertEqual(protected['status_code'], 403) + self.assertEqual(protected['error'], 'One or more datasets is protected') + def test_register_files(self): """Test for RegistrationClient.register_files""" # Test a few things not checked in register_session