diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index ba12a5c462..ac2e43469a 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -3,6 +3,7 @@ import warnings import numpy as np from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension, SortingAnalyzer +from copy import deepcopy from spikeinterface.core.waveforms_extractor_backwards_compatibility import MockWaveformExtractor @@ -93,9 +94,46 @@ def _select_extension_data(self, unit_ids): def _merge_extension_data( self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, censor_ms=None, verbose=False, **job_kwargs ): - # recomputing correlogram is fast enough and much easier in this case - new_ccgs, new_bins = _compute_correlograms_on_sorting(new_sorting_analyzer.sorting, **self.params) - new_data = dict(ccgs=new_ccgs, bins=new_bins) + + method = "keep_first" + cut_from = 1 + + correlograms, new_bins = deepcopy(self.get_data()) + + for new_unit_id, merge_unit_group in zip(new_unit_ids, merge_unit_groups): + + merge_unit_group_indices = self.sorting_analyzer.sorting.ids_to_indices(merge_unit_group) + + new_col = np.sum(correlograms[merge_unit_group_indices, :, :], axis=0) + correlograms[merge_unit_group_indices[0], :, :] = new_col + correlograms[merge_unit_group_indices[1:], :, :] = 0 + + new_row = np.sum(correlograms[:, merge_unit_group_indices, :], axis=1) + correlograms[:, merge_unit_group_indices[0], :] = new_row + correlograms[:, merge_unit_group_indices[1:], :] = 0 + + if new_unit_id not in merge_unit_group: + method = "append" + cut_from = 0 + + if method == "append": + old_num_units = np.shape(correlograms)[0] + correlograms = np.pad(correlograms, ((0, len(new_unit_ids)), (0, len(new_unit_ids)), (0, 0))) + for a, (new_unit_id, merge_unit_group) in enumerate(zip(new_unit_ids, merge_unit_groups)): + + old_loc = self.sorting_analyzer.sorting.ids_to_indices([merge_unit_group[0]])[0] + new_loc = old_num_units + a + + correlograms[:, [old_loc, new_loc], :] = correlograms[:, [new_loc, old_loc], :] + correlograms[[old_loc, new_loc], :, :] = correlograms[[new_loc, old_loc], :, :] + + units_to_delete = np.concatenate([merge_unit_group[cut_from:] for merge_unit_group in merge_unit_groups]) + indices_to_delete = self.sorting_analyzer.sorting.ids_to_indices(units_to_delete) + + correlograms = np.delete(correlograms, indices_to_delete, axis=0) + correlograms = np.delete(correlograms, indices_to_delete, axis=1) + + new_data = dict(ccgs=correlograms, bins=new_bins) return new_data def _run(self, verbose=False): diff --git a/src/spikeinterface/postprocessing/tests/test_extension_merges.py b/src/spikeinterface/postprocessing/tests/test_extension_merges.py new file mode 100644 index 0000000000..b42d5d84b4 --- /dev/null +++ b/src/spikeinterface/postprocessing/tests/test_extension_merges.py @@ -0,0 +1,36 @@ +import numpy as np + +from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer + + +def test_correlograms_merge(): + """ + When merging in `soft` mode, correlograms sum and we take advantage of this to make + a fast computation. This test checks that we get the same result using this fast + sum as recomputing the correlograms from scratch. + """ + + rec, sort = generate_ground_truth_recording() + + sorting_analyzer = create_sorting_analyzer(recording=rec, sorting=sort) + sorting_analyzer.compute("correlograms") + + trial_merges = [ + [["1", "2"]], + [["2", "4", "6", "8"]], + [["1", "4", "7"], ["2", "8"]], + [["4", "1", "8"], ["2", "7", "0"], ["3", "9"], ["5", "6"]], + ] + + for new_id_strategy in ["append", "take_first"]: + for merge_unit_groups in trial_merges: + + # first, compute the correlograms of the merged units using the merge method + merged_sorting_analyzer = sorting_analyzer.merge_units( + merge_unit_groups=merge_unit_groups, new_id_strategy=new_id_strategy + ) + computed_correlograms = merged_sorting_analyzer.get_extension("correlograms").get_data() + + # Then re-compute, and compare + recomputed_correlograms = merged_sorting_analyzer.compute("correlograms").get_data() + assert np.all(computed_correlograms[0] == recomputed_correlograms[0])