From 9d52b325171ffb71d6379f2fd08824c4d20fe0b2 Mon Sep 17 00:00:00 2001 From: Jacob Pennington Date: Fri, 13 Dec 2024 14:11:53 -0800 Subject: [PATCH 1/4] Probe preview now updates every time probe settings changed --- kilosort/gui/probe_view_box.py | 6 ++++-- kilosort/gui/settings_box.py | 12 ++++++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/kilosort/gui/probe_view_box.py b/kilosort/gui/probe_view_box.py index 7bd4d22..58671a5 100644 --- a/kilosort/gui/probe_view_box.py +++ b/kilosort/gui/probe_view_box.py @@ -116,7 +116,7 @@ def update_spots_variables(self, probe, template_args): for ind, (xc, yc) in enumerate(zip(self.xc, self.yc)): self.channel_map_dict[(xc, yc)] = ind - def get_template_spots(self, nC, dmin, dminx, max_dist, x_centers, device): + def get_template_spots(self, nC, dmin, dminx, max_dist, x_centers): ops = { 'yc': self.yc, 'xc': self.xc, 'max_channel_distance': max_dist, 'x_centers': x_centers, 'settings': {'dmin': dmin, 'dminx': dminx}, @@ -125,7 +125,9 @@ def get_template_spots(self, nC, dmin, dminx, max_dist, x_centers, device): ops = template_centers(ops) [ys, xs] = np.meshgrid(ops['yup'], ops['xup']) ys, xs = ys.flatten(), xs.flatten() - iC, ds = nearest_chans(ys, self.yc, xs, self.xc, nC, device=device) + iC, ds = nearest_chans( + ys, self.yc, xs, self.xc, nC, device=self.gui.device + ) igood = ds[0,:] <= max_dist**2 iC = iC[:,igood] diff --git a/kilosort/gui/settings_box.py b/kilosort/gui/settings_box.py index e1ac231..ed45e8a 100644 --- a/kilosort/gui/settings_box.py +++ b/kilosort/gui/settings_box.py @@ -19,6 +19,9 @@ _DEFAULT_DTYPE = 'int16' _ALLOWED_FILE_TYPES = ['.bin', '.dat', '.bat', '.raw'] # For binary data +_PROBE_SETTINGS = [ + 'nearest_chans', 'dmin', 'dminx', 'max_channel_distance', 'x_centers' + ] class SettingsBox(QtWidgets.QGroupBox): settingsUpdated = QtCore.Signal() @@ -247,6 +250,8 @@ def setup(self): ) inp = getattr(self, f'{k}_input') inp.editingFinished.connect(self.update_parameter) + if k in _PROBE_SETTINGS: + inp.editingFinished.connect(self.show_probe_layout()) row_count += rspan layout.addWidget( @@ -550,10 +555,7 @@ def update_settings(self): def get_probe_template_args(self): epw = self.extra_parameters_window - template_args = [ - epw.nearest_chans, epw.dmin, epw.dminx, - epw.max_channel_distance, epw.x_centers, self.gui.device - ] + template_args = [getattr(epw, k) for k in _PROBE_SETTINGS] return template_args @QtCore.Slot() @@ -862,6 +864,8 @@ def __init__(self, parent): layout.addWidget(getattr(self, f'{k}_input'), row_count, col+3, 1, 2) inp = getattr(self, f'{k}_input') inp.editingFinished.connect(self.update_parameter) + if k in _PROBE_SETTINGS: + inp.editingFinished.connect(self.main_settings.show_probe_layout) self.setLayout(layout) From 4d25a4429a5227f628cc9b06e346654d3b9efbe0 Mon Sep 17 00:00:00 2001 From: Jacob Pennington Date: Fri, 13 Dec 2024 14:34:26 -0800 Subject: [PATCH 2/4] added note about minimum version to plotting_example tutorial. --- docs/tutorials/plotting_example.ipynb | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/tutorials/plotting_example.ipynb b/docs/tutorials/plotting_example.ipynb index ee494a8..b89f368 100644 --- a/docs/tutorials/plotting_example.ipynb +++ b/docs/tutorials/plotting_example.ipynb @@ -7,6 +7,13 @@ "# Example plots using kilosort.data_tools" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Note that `kilosort.data_tools` was added in `v4.0.21`, so you will need to update Kilosort4 to at least that version to use these examples. This can be done using `pip install kilosort --upgrade`." + ] + }, { "cell_type": "code", "execution_count": 5, From 596184f125f3aba06562a62ca40f5a3ed268effa Mon Sep 17 00:00:00 2001 From: Jacob Pennington Date: Fri, 13 Dec 2024 14:41:42 -0800 Subject: [PATCH 3/4] changed get_best_channel to return for all clusters --- docs/tutorials/plotting_example.ipynb | 4 ++-- kilosort/data_tools.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/tutorials/plotting_example.ipynb b/docs/tutorials/plotting_example.ipynb index b89f368..4084f51 100644 --- a/docs/tutorials/plotting_example.ipynb +++ b/docs/tutorials/plotting_example.ipynb @@ -49,7 +49,7 @@ "from kilosort.io import load_ops\n", "from kilosort.data_tools import (\n", " mean_waveform, cluster_templates, get_good_cluster, get_cluster_spikes,\n", - " get_spike_waveforms, get_best_channel\n", + " get_spike_waveforms, get_best_channels\n", " )\n", "\n", "\n", @@ -108,7 +108,7 @@ "# Time in s for spike time axis\n", "t2 = spike_times / ops['fs']\n", "# Get single-channel waveform for each spike\n", - "chan = get_best_channel(cluster_id, results_dir)\n", + "chan = get_best_channels(results_dir)[cluster_id]\n", "waves = get_spike_waveforms(spike_times, results_dir, chan=chan)\n", "\n", "# Plot each waveform, using spike time as 3rd dimension\n", diff --git a/kilosort/data_tools.py b/kilosort/data_tools.py index 7e0957e..b33ca87 100644 --- a/kilosort/data_tools.py +++ b/kilosort/data_tools.py @@ -34,7 +34,7 @@ def mean_waveform(cluster_id, results_dir, n_spikes=np.inf, bfile=None, best=Tru """ results_dir = Path(results_dir) if best: - chan = get_best_channel(cluster_id, results_dir) + chan = get_best_channels(results_dir)[cluster_id] else: chan = None @@ -45,11 +45,11 @@ def mean_waveform(cluster_id, results_dir, n_spikes=np.inf, bfile=None, best=Tru return mean_wave -def get_best_channel(cluster_id, results_dir): - """Get channel number with largest template norm for this cluster.""" +def get_best_channels(results_dir): + """Get channel numbers with largest template norm for each cluster.""" templates = np.load(results_dir / 'templates.npy') - chan = (templates**2).sum(axis=1).argmax(axis=-1)[cluster_id] - return chan + best_chans = (templates**2).sum(axis=1).argmax(axis=-1) + return best_chans def get_cluster_spikes(cluster_id, results_dir, n_spikes=np.inf): From a5b43f0e05122e45f71d83525beb3e07bc0d6b56 Mon Sep 17 00:00:00 2001 From: Jacob Pennington Date: Fri, 13 Dec 2024 14:47:39 -0800 Subject: [PATCH 4/4] added get_best_channel back in to avoid disrupting scripts --- kilosort/data_tools.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/kilosort/data_tools.py b/kilosort/data_tools.py index b33ca87..193b6cf 100644 --- a/kilosort/data_tools.py +++ b/kilosort/data_tools.py @@ -51,6 +51,8 @@ def get_best_channels(results_dir): best_chans = (templates**2).sum(axis=1).argmax(axis=-1) return best_chans +def get_best_channel(results_dir, cluster_id): + return get_best_channels(results_dir)[cluster_id] def get_cluster_spikes(cluster_id, results_dir, n_spikes=np.inf): """Get `n_spikes` random spike times assigned to `cluster_id`."""