Skip to content

Commit

Permalink
Introduce fast correlogram merge
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishalcrow committed Jan 10, 2025
1 parent a64aed9 commit ea61047
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 3 deletions.
44 changes: 41 additions & 3 deletions src/spikeinterface/postprocessing/correlograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import warnings
import numpy as np
from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension, SortingAnalyzer
from copy import deepcopy

from spikeinterface.core.waveforms_extractor_backwards_compatibility import MockWaveformExtractor

Expand Down Expand Up @@ -93,9 +94,46 @@ def _select_extension_data(self, unit_ids):
def _merge_extension_data(
self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, censor_ms=None, verbose=False, **job_kwargs
):
# recomputing correlogram is fast enough and much easier in this case
new_ccgs, new_bins = _compute_correlograms_on_sorting(new_sorting_analyzer.sorting, **self.params)
new_data = dict(ccgs=new_ccgs, bins=new_bins)

method = "keep_first"
cut_from = 1

correlograms, new_bins = deepcopy(self.get_data())

for new_unit_id, merge_unit_group in zip(new_unit_ids, merge_unit_groups):

merge_unit_group_indices = self.sorting_analyzer.sorting.ids_to_indices(merge_unit_group)

new_col = np.sum(correlograms[merge_unit_group_indices, :, :], axis=0)
correlograms[merge_unit_group_indices[0], :, :] = new_col
correlograms[merge_unit_group_indices[1:], :, :] = 0

new_row = np.sum(correlograms[:, merge_unit_group_indices, :], axis=1)
correlograms[:, merge_unit_group_indices[0], :] = new_row
correlograms[:, merge_unit_group_indices[1:], :] = 0

if new_unit_id not in merge_unit_group:
method = "append"
cut_from = 0

if method == "append":
old_num_units = np.shape(correlograms)[0]
correlograms = np.pad(correlograms, ((0, len(new_unit_ids)), (0, len(new_unit_ids)), (0, 0)))
for a, (new_unit_id, merge_unit_group) in enumerate(zip(new_unit_ids, merge_unit_groups)):

old_loc = self.sorting_analyzer.sorting.ids_to_indices([merge_unit_group[0]])[0]
new_loc = old_num_units + a

correlograms[:, [old_loc, new_loc], :] = correlograms[:, [new_loc, old_loc], :]
correlograms[[old_loc, new_loc], :, :] = correlograms[[new_loc, old_loc], :, :]

units_to_delete = np.concatenate([merge_unit_group[cut_from:] for merge_unit_group in merge_unit_groups])
indices_to_delete = self.sorting_analyzer.sorting.ids_to_indices(units_to_delete)

correlograms = np.delete(correlograms, indices_to_delete, axis=0)
correlograms = np.delete(correlograms, indices_to_delete, axis=1)

new_data = dict(ccgs=correlograms, bins=new_bins)
return new_data

def _run(self, verbose=False):
Expand Down
36 changes: 36 additions & 0 deletions src/spikeinterface/postprocessing/tests/test_extension_merges.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import numpy as np

from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer


def test_correlograms_merge():
"""
When merging in `soft` mode, correlograms sum and we take advantage of this to make
a fast computation. This test checks that we get the same result using this fast
sum as recomputing the correlograms from scratch.
"""

rec, sort = generate_ground_truth_recording()

sorting_analyzer = create_sorting_analyzer(recording=rec, sorting=sort)
sorting_analyzer.compute("correlograms")

trial_merges = [
[["1", "2"]],
[["2", "4", "6", "8"]],
[["1", "4", "7"], ["2", "8"]],
[["4", "1", "8"], ["2", "7", "0"], ["3", "9"], ["5", "6"]],
]

for new_id_strategy in ["append", "take_first"]:
for merge_unit_groups in trial_merges:

# first, compute the correlograms of the merged units using the merge method
merged_sorting_analyzer = sorting_analyzer.merge_units(
merge_unit_groups=merge_unit_groups, new_id_strategy=new_id_strategy
)
computed_correlograms = merged_sorting_analyzer.get_extension("correlograms").get_data()

# Then re-compute, and compare
recomputed_correlograms = merged_sorting_analyzer.compute("correlograms").get_data()
assert np.all(computed_correlograms[0] == recomputed_correlograms[0])

0 comments on commit ea61047

Please sign in to comment.