From 54c11f34587e4f2a03e485bebd52308af1a586a7 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 12 Nov 2024 22:02:52 +0100 Subject: [PATCH 1/7] Add medecine wraper for estimate motion --- .../sortingcomponents/motion/medecine.py | 69 +++++++++++++++++++ .../motion/motion_estimation.py | 4 +- 2 files changed, 71 insertions(+), 2 deletions(-) create mode 100644 src/spikeinterface/sortingcomponents/motion/medecine.py diff --git a/src/spikeinterface/sortingcomponents/motion/medecine.py b/src/spikeinterface/sortingcomponents/motion/medecine.py new file mode 100644 index 0000000000..176e3e081c --- /dev/null +++ b/src/spikeinterface/sortingcomponents/motion/medecine.py @@ -0,0 +1,69 @@ +import numpy as np + +from .motion_utils import Motion +import tempfile +import shutil +from pathlib import Path + + +class MedecineRegistration: + """ + """ + + name = "medecine" + need_peak_location = True + params_doc = """ + + """ + + @classmethod + def run( + cls, + recording, + peaks, + peak_locations, + direction, + + #unsed need to be adapted + rigid, + win_shape, + win_step_um, + win_scale_um, + win_margin_um, + verbose, + progress_bar, + extra, + + bin_um=5.0, + hist_margin_um=20.0, + bin_s=2.0, + + ): + + from medicine import run_medicine + + folder = Path(tempfile.gettempdir()) + + run_medicine( + peak_amplitudes=peaks['amplitude'], + peak_depths=peak_locations[direction], + peak_times=peaks['sample_index'] / recording.get_sampling_frequency(), + output_dir=folder, + plot_figures=False, + ) + + # Load motion estimated by MEDiCINe + motion_array = np.load(folder / 'motion.npy') + time_bins = np.load(folder / 'time_bins.npy') + depth_bins = np.load(folder / 'depth_bins.npy') + + # Use interpolation to correct for motion estimated by MEDiCINe + motion = Motion( + displacement=[motion_array], + temporal_bins_s=[time_bins], + spatial_bins_um=depth_bins, + ) + + shutil.rmtree(folder) + + return motion diff --git a/src/spikeinterface/sortingcomponents/motion/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion/motion_estimation.py index 8a4daeb808..5732966adc 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_estimation.py @@ -11,7 +11,7 @@ from .decentralized import DecentralizedRegistration from .iterative_template import IterativeTemplateRegistration from .dredge import DredgeLfpRegistration, DredgeApRegistration - +from .medecine import MedecineRegistration # estimate_motion > infer_motion def estimate_motion( @@ -130,7 +130,7 @@ def estimate_motion( return motion -_methods_list = [DecentralizedRegistration, IterativeTemplateRegistration, DredgeLfpRegistration, DredgeApRegistration] +_methods_list = [DecentralizedRegistration, IterativeTemplateRegistration, DredgeLfpRegistration, DredgeApRegistration, MedecineRegistration] estimate_motion_methods = {m.name: m for m in _methods_list} method_doc = make_multi_method_doc(_methods_list) estimate_motion.__doc__ = estimate_motion.__doc__.format(method_doc=method_doc) From 28d9d158e19c3bc5a520fdb858560305df726ca4 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 13 Nov 2024 08:49:35 +0100 Subject: [PATCH 2/7] oups --- src/spikeinterface/sortingcomponents/motion/medecine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/motion/medecine.py b/src/spikeinterface/sortingcomponents/motion/medecine.py index 176e3e081c..cf99e75dad 100644 --- a/src/spikeinterface/sortingcomponents/motion/medecine.py +++ b/src/spikeinterface/sortingcomponents/motion/medecine.py @@ -64,6 +64,7 @@ def run( spatial_bins_um=depth_bins, ) - shutil.rmtree(folder) + # TODO check why not working + # shutil.rmtree(folder) return motion From f6909f014fccde39aa576091e523558909fca32a Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 27 Nov 2024 13:57:48 +0100 Subject: [PATCH 3/7] medecine api --- .../sortingcomponents/motion/medecine.py | 62 ++++++++++++++----- 1 file changed, 46 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion/medecine.py b/src/spikeinterface/sortingcomponents/motion/medecine.py index cf99e75dad..3e2a0e1ebe 100644 --- a/src/spikeinterface/sortingcomponents/motion/medecine.py +++ b/src/spikeinterface/sortingcomponents/motion/medecine.py @@ -5,6 +5,8 @@ import shutil from pathlib import Path +from .motion_utils import get_spatial_windows + class MedecineRegistration: """ @@ -34,35 +36,63 @@ def run( progress_bar, extra, - bin_um=5.0, - hist_margin_um=20.0, - bin_s=2.0, + # bin_um=5.0, + # hist_margin_um=20.0, + bin_s=1.0, + time_kernel_width=30., + amplitude_threshold_quantile=0., + + + #### + training_steps=10_000, ): from medicine import run_medicine - folder = Path(tempfile.gettempdir()) + # folder = Path(tempfile.gettempdir()) + + if rigid: + # force one bin + num_depth_bins = 1 + else: - run_medicine( + # we use the spatial window mechanism only to estimate the number one spatial bins + dim = ["x", "y", "z"].index(direction) + contact_depths = recording.get_channel_locations()[:, dim] + + deph_range = max(contact_depths) - min(contact_depths) + if win_margin_um is not None: + deph_range = deph_range - 2 * win_margin_um + num_depth_bins = max(int(np.round(deph_range / win_scale_um)), 1) + print('num_depth_bins', num_depth_bins) + + + trainer, motion = run_medicine( peak_amplitudes=peaks['amplitude'], peak_depths=peak_locations[direction], peak_times=peaks['sample_index'] / recording.get_sampling_frequency(), - output_dir=folder, + time_bin_size=bin_s, + num_depth_bins=num_depth_bins, + training_steps=training_steps, + time_kernel_width=time_kernel_width, + amplitude_threshold_quantile=amplitude_threshold_quantile, + output_dir=None, plot_figures=False, + return_motion=True, ) # Load motion estimated by MEDiCINe - motion_array = np.load(folder / 'motion.npy') - time_bins = np.load(folder / 'time_bins.npy') - depth_bins = np.load(folder / 'depth_bins.npy') - - # Use interpolation to correct for motion estimated by MEDiCINe - motion = Motion( - displacement=[motion_array], - temporal_bins_s=[time_bins], - spatial_bins_um=depth_bins, - ) + # motion_array = np.load(folder / 'motion.npy') + # time_bins = np.load(folder / 'time_bins.npy') + # depth_bins = np.load(folder / 'depth_bins.npy') + + # # Use interpolation to correct for motion estimated by MEDiCINe + # motion = Motion( + # displacement=[motion_array], + # temporal_bins_s=[time_bins], + # spatial_bins_um=depth_bins, + # ) # TODO check why not working # shutil.rmtree(folder) From e7e76f6025f99e7819e2879a7c5c961b1461f3c2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Nov 2024 12:59:15 +0000 Subject: [PATCH 4/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/motion/medecine.py | 23 +++++++------------ .../motion/motion_estimation.py | 9 +++++++- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion/medecine.py b/src/spikeinterface/sortingcomponents/motion/medecine.py index 3e2a0e1ebe..7237efffec 100644 --- a/src/spikeinterface/sortingcomponents/motion/medecine.py +++ b/src/spikeinterface/sortingcomponents/motion/medecine.py @@ -9,8 +9,7 @@ class MedecineRegistration: - """ - """ + """ """ name = "medecine" need_peak_location = True @@ -25,8 +24,7 @@ def run( peaks, peak_locations, direction, - - #unsed need to be adapted + # unsed need to be adapted rigid, win_shape, win_step_um, @@ -35,19 +33,15 @@ def run( verbose, progress_bar, extra, - # bin_um=5.0, # hist_margin_um=20.0, bin_s=1.0, - time_kernel_width=30., - amplitude_threshold_quantile=0., - - + time_kernel_width=30.0, + amplitude_threshold_quantile=0.0, #### training_steps=10_000, - ): - + from medicine import run_medicine # folder = Path(tempfile.gettempdir()) @@ -65,13 +59,12 @@ def run( if win_margin_um is not None: deph_range = deph_range - 2 * win_margin_um num_depth_bins = max(int(np.round(deph_range / win_scale_um)), 1) - print('num_depth_bins', num_depth_bins) - + print("num_depth_bins", num_depth_bins) trainer, motion = run_medicine( - peak_amplitudes=peaks['amplitude'], + peak_amplitudes=peaks["amplitude"], peak_depths=peak_locations[direction], - peak_times=peaks['sample_index'] / recording.get_sampling_frequency(), + peak_times=peaks["sample_index"] / recording.get_sampling_frequency(), time_bin_size=bin_s, num_depth_bins=num_depth_bins, training_steps=training_steps, diff --git a/src/spikeinterface/sortingcomponents/motion/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion/motion_estimation.py index 5732966adc..e2eabebaef 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_estimation.py @@ -13,6 +13,7 @@ from .dredge import DredgeLfpRegistration, DredgeApRegistration from .medecine import MedecineRegistration + # estimate_motion > infer_motion def estimate_motion( recording, @@ -130,7 +131,13 @@ def estimate_motion( return motion -_methods_list = [DecentralizedRegistration, IterativeTemplateRegistration, DredgeLfpRegistration, DredgeApRegistration, MedecineRegistration] +_methods_list = [ + DecentralizedRegistration, + IterativeTemplateRegistration, + DredgeLfpRegistration, + DredgeApRegistration, + MedecineRegistration, +] estimate_motion_methods = {m.name: m for m in _methods_list} method_doc = make_multi_method_doc(_methods_list) estimate_motion.__doc__ = estimate_motion.__doc__.format(method_doc=method_doc) From 133cb2a32093198e376a1b356e9a601caba0b369 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 8 Jan 2025 09:54:40 +0100 Subject: [PATCH 5/7] Update with medine API changes + wrap all params. --- .../sortingcomponents/motion/medecine.py | 57 +++++++++++-------- 1 file changed, 34 insertions(+), 23 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion/medecine.py b/src/spikeinterface/sortingcomponents/motion/medecine.py index 7237efffec..4145342be5 100644 --- a/src/spikeinterface/sortingcomponents/motion/medecine.py +++ b/src/spikeinterface/sortingcomponents/motion/medecine.py @@ -33,13 +33,20 @@ def run( verbose, progress_bar, extra, - # bin_um=5.0, - # hist_margin_um=20.0, bin_s=1.0, - time_kernel_width=30.0, + + ## medecine specific kwargs propagated to the lib + motion_bound=800, + time_kernel_width=30, + activity_network_hidden_features=(256, 256), amplitude_threshold_quantile=0.0, - #### + batch_size=4096, training_steps=10_000, + initial_motion_noise=0.1, + motion_noise_steps=2000, + optimizer=None, + learning_rate=0.0005, + epsilon=1e-3, ): from medicine import run_medicine @@ -61,33 +68,37 @@ def run( num_depth_bins = max(int(np.round(deph_range / win_scale_um)), 1) print("num_depth_bins", num_depth_bins) - trainer, motion = run_medicine( + if optimizer is None: + import torch + optimizer = torch.optim.Adam + + trainer, time_bins, depth_bins, pred_motion = run_medicine( peak_amplitudes=peaks["amplitude"], peak_depths=peak_locations[direction], peak_times=peaks["sample_index"] / recording.get_sampling_frequency(), time_bin_size=bin_s, num_depth_bins=num_depth_bins, - training_steps=training_steps, - time_kernel_width=time_kernel_width, - amplitude_threshold_quantile=amplitude_threshold_quantile, + output_dir=None, plot_figures=False, - return_motion=True, - ) - # Load motion estimated by MEDiCINe - # motion_array = np.load(folder / 'motion.npy') - # time_bins = np.load(folder / 'time_bins.npy') - # depth_bins = np.load(folder / 'depth_bins.npy') - - # # Use interpolation to correct for motion estimated by MEDiCINe - # motion = Motion( - # displacement=[motion_array], - # temporal_bins_s=[time_bins], - # spatial_bins_um=depth_bins, - # ) + motion_bound=motion_bound, + time_kernel_width=time_kernel_width, + activity_network_hidden_features=activity_network_hidden_features, + amplitude_threshold_quantile=amplitude_threshold_quantile, + batch_size=batch_size, + training_steps=training_steps, + initial_motion_noise=initial_motion_noise, + motion_noise_steps=motion_noise_steps, + optimizer=optimizer, + learning_rate=learning_rate, + epsilon=epsilon, + ) - # TODO check why not working - # shutil.rmtree(folder) + motion = Motion( + displacement=[np.array(pred_motion)], + temporal_bins_s=[np.array(time_bins)], + spatial_bins_um=np.array(depth_bins), + ) return motion From 0190e886d321e8a63b7eb0c2d999dc4c074cbe6b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 8 Jan 2025 08:55:19 +0000 Subject: [PATCH 6/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/motion/medecine.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion/medecine.py b/src/spikeinterface/sortingcomponents/motion/medecine.py index 4145342be5..1a29f4ea47 100644 --- a/src/spikeinterface/sortingcomponents/motion/medecine.py +++ b/src/spikeinterface/sortingcomponents/motion/medecine.py @@ -34,7 +34,6 @@ def run( progress_bar, extra, bin_s=1.0, - ## medecine specific kwargs propagated to the lib motion_bound=800, time_kernel_width=30, @@ -70,6 +69,7 @@ def run( if optimizer is None: import torch + optimizer = torch.optim.Adam trainer, time_bins, depth_bins, pred_motion = run_medicine( @@ -78,10 +78,8 @@ def run( peak_times=peaks["sample_index"] / recording.get_sampling_frequency(), time_bin_size=bin_s, num_depth_bins=num_depth_bins, - output_dir=None, plot_figures=False, - motion_bound=motion_bound, time_kernel_width=time_kernel_width, activity_network_hidden_features=activity_network_hidden_features, From 9a17967567ff062f28f4aaac5a8622a892ed407e Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 8 Jan 2025 14:02:48 +0100 Subject: [PATCH 7/7] oups --- .../sortingcomponents/motion/{medecine.py => medicine.py} | 6 +++--- .../sortingcomponents/motion/motion_estimation.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) rename src/spikeinterface/sortingcomponents/motion/{medecine.py => medicine.py} (96%) diff --git a/src/spikeinterface/sortingcomponents/motion/medecine.py b/src/spikeinterface/sortingcomponents/motion/medicine.py similarity index 96% rename from src/spikeinterface/sortingcomponents/motion/medecine.py rename to src/spikeinterface/sortingcomponents/motion/medicine.py index 4145342be5..8b56cb394c 100644 --- a/src/spikeinterface/sortingcomponents/motion/medecine.py +++ b/src/spikeinterface/sortingcomponents/motion/medicine.py @@ -8,10 +8,10 @@ from .motion_utils import get_spatial_windows -class MedecineRegistration: +class MedicineRegistration: """ """ - name = "medecine" + name = "medicine" need_peak_location = True params_doc = """ @@ -35,7 +35,7 @@ def run( extra, bin_s=1.0, - ## medecine specific kwargs propagated to the lib + ## medicine specific kwargs propagated to the lib motion_bound=800, time_kernel_width=30, activity_network_hidden_features=(256, 256), diff --git a/src/spikeinterface/sortingcomponents/motion/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion/motion_estimation.py index e2eabebaef..a285ad5064 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_estimation.py @@ -11,7 +11,7 @@ from .decentralized import DecentralizedRegistration from .iterative_template import IterativeTemplateRegistration from .dredge import DredgeLfpRegistration, DredgeApRegistration -from .medecine import MedecineRegistration +from .medicine import MedicineRegistration # estimate_motion > infer_motion @@ -136,7 +136,7 @@ def estimate_motion( IterativeTemplateRegistration, DredgeLfpRegistration, DredgeApRegistration, - MedecineRegistration, + MedicineRegistration, ] estimate_motion_methods = {m.name: m for m in _methods_list} method_doc = make_multi_method_doc(_methods_list)