diff --git a/src/spikeinterface/working/load_kilosort_utils.py b/src/spikeinterface/working/load_kilosort_utils.py index 2423bc02a0..e4047261d5 100644 --- a/src/spikeinterface/working/load_kilosort_utils.py +++ b/src/spikeinterface/working/load_kilosort_utils.py @@ -7,7 +7,7 @@ from scipy import stats -# TODO: spike_times -> spike_indexes +# TODO: spike_times -> spike_indices """ Notes ----- @@ -15,13 +15,17 @@ - things might be useful in future for making a sorting analyzer - compute template amplitude as average of spike amplitude. """ +######################################################################################################################## +# Get Spike Data +######################################################################################################################## + def compute_spike_amplitude_and_depth( sorter_output: str | Path, localised_spikes_only, exclude_noise, gain: float | None = None, - localised_spikes_channel_cutoff: int = None, # TODO + localised_spikes_channel_cutoff: int = None, ) -> tuple[np.ndarray, ...]: """ Compute the amplitude and depth of all detected spikes from the kilosort output. @@ -46,8 +50,8 @@ def compute_spike_amplitude_and_depth( Returns ------- - spike_indexes : np.ndarray - (num_spikes,) array of spike indexes. + spike_indices : np.ndarray + (num_spikes,) array of spike indices. spike_amplitudes : np.ndarray (num_spikes,) array of corresponding spike amplitudes. 
spike_depths : np.ndarray @@ -66,7 +70,7 @@ def compute_spike_amplitude_and_depth( if isinstance(sorter_output, str): sorter_output = Path(sorter_output) - params = _load_ks_dir(sorter_output, load_pcs=True, exclude_noise=exclude_noise) + params = load_ks_dir(sorter_output, load_pcs=True, exclude_noise=exclude_noise) if localised_spikes_only: localised_templates = [] @@ -81,10 +85,52 @@ def compute_spike_amplitude_and_depth( localised_template_by_spike = np.isin(params["spike_templates"], localised_templates) - _strip_spikes(params, localised_template_by_spike) + params["spike_templates"] = params["spike_templates"][localised_template_by_spike] + params["spike_indices"] = params["spike_indices"][localised_template_by_spike] + params["spike_clusters"] = params["spike_clusters"][localised_template_by_spike] + params["temp_scaling_amplitudes"] = params["temp_scaling_amplitudes"][localised_template_by_spike] + params["pc_features"] = params["pc_features"][localised_template_by_spike] + + spike_locations, spike_max_sites = _get_locations_from_pc_features(params) + + # Amplitude is calculated for each spike as the template amplitude + # multiplied by the `template_scaling_amplitudes`. + template_amplitudes_unscaled, *_ = get_unwhite_template_info( + params["templates"], + params["whitening_matrix_inv"], + params["channel_positions"], + ) + spike_amplitudes = template_amplitudes_unscaled[params["spike_templates"]] * params["temp_scaling_amplitudes"] + + if gain is not None: + spike_amplitudes *= gain + compute_template_amplitudes_from_spikes(params["templates"], params["spike_templates"], spike_amplitudes) + + if localised_spikes_only: + # Interpolate the channel ids to location. + # Remove spikes > 5 um from average position + # Above we already removed non-localized templates, but that on its own is insufficient. + # Note for IMEC probe adding a constant term kills the regression making the regressors rank deficient + # TODO: a couple of approaches. 
1) do everything in 3D, draw a sphere around prediction, take spikes only within the sphere + # 2) do separate for x, y. But resolution will be much lower, making things noisier, also harder to determine threshold. + # 3) just use depth. Probably go for that. check with others. + spike_depths = spike_locations[:, 1] + b = stats.linregress(spike_depths, spike_max_sites).slope + i = np.abs(spike_max_sites - b * spike_depths) <= 5 + + params["spike_indices"] = params["spike_indices"][i] + spike_amplitudes = spike_amplitudes[i] + spike_locations = spike_locations[i, :] + spike_max_sites = spike_max_sites[i] + + return params["spike_indices"], spike_amplitudes, spike_locations, spike_max_sites + + +def _get_locations_from_pc_features(params): + """ """ # Compute spike depths - pc_features = params["pc_features"][:, 0, :] # Do this compute + pc_features = params["pc_features"][:, 0, :] pc_features[pc_features < 0] = 0 # Some spikes do not load at all onto the first PC. To avoid biasing the @@ -109,58 +155,28 @@ def compute_spike_amplitude_and_depth( "to extend this code section to handle more components." ) - # Get the channel indexes corresponding to the 32 channels from the PC. + # Get the channel indices corresponding to the 32 channels from the PC. spike_features_indices = params["pc_features_indices"][params["spike_templates"], :] # Compute the spike locations as the center of mass of the PC scores spike_feature_coords = params["channel_positions"][spike_features_indices, :] - norm_weights = pc_features / np.sum(pc_features, axis=1)[:, np.newaxis] # TOOD: see why they use square + norm_weights = ( + pc_features / np.sum(pc_features, axis=1)[:, np.newaxis] + ) # TODO: discuss use of square. Probably do not use to keep in line with COM in SI. 
spike_locations = spike_feature_coords * norm_weights[:, :, np.newaxis] spike_locations = np.sum(spike_locations, axis=1) # TODO: now max site per spike is computed from PCs, not as the channel max site as previous - spike_sites = spike_features_indices[np.arange(spike_features_indices.shape[0]), np.argmax(norm_weights, axis=1)] + spike_max_sites = spike_features_indices[ + np.arange(spike_features_indices.shape[0]), np.argmax(norm_weights, axis=1) + ] - # Amplitude is calculated for each spike as the template amplitude - # multiplied by the `template_scaling_amplitudes`. - template_amplitudes_unscaled, *_ = get_unwhite_template_info( - params["templates"], - params["whitening_matrix_inv"], - params["channel_positions"], - ) - spike_amplitudes = template_amplitudes_unscaled[params["spike_templates"]] * params["temp_scaling_amplitudes"] - - if gain is not None: - spike_amplitudes *= gain - - if localised_spikes_only: - # Interpolate the channel ids to location. - # Remove spikes > 5 um from average position - # Above we already removed non-localized templates, but that on its own is insufficient. - # Note for IMEC probe adding a constant term kills the regression making the regressors rank deficient - # TODO: a couple of approaches. 1) do everything in 3D, draw a sphere around prediction, take spikes only within the sphere - # 2) do separate for x, y. But resolution will be much lower, making things noisier, also harder to determine threshold. - # 3) just use depth. Probably go for that. check with others. 
- spike_depths = spike_locations[:, 1] - b = stats.linregress(spike_depths, spike_sites).slope - i = np.abs(spike_sites - b * spike_depths) <= 5 # TODO: need to expose this + return spike_locations, spike_max_sites - params["spike_indexes"] = params["spike_indexes"][i] - spike_amplitudes = spike_amplitudes[i] - spike_locations = spike_locations[i, :] - return params["spike_indexes"], spike_amplitudes, spike_locations, spike_sites - - -def _strip_spikes_in_place(params, indices): - """ """ - params["spike_templates"] = params["spike_templates"][ - indices - ] # TODO: make an function for this. because we do this a lot - params["spike_indexes"] = params["spike_indexes"][indices] - params["spike_clusters"] = params["spike_clusters"][indices] - params["temp_scaling_amplitudes"] = params["temp_scaling_amplitudes"][indices] - params["pc_features"] = params["pc_features"][indices] # TODO: be conciststetn! change indees to indices +######################################################################################################################## +# Get Template Data +######################################################################################################################## def get_unwhite_template_info( @@ -213,7 +229,7 @@ def get_unwhite_template_info( template_amplitudes_unscaled = np.max(template_amplitudes_per_channel, axis=1) - # Zero any small channel amplitudes + # Zero any small channel amplitudes TODO: removed this. # threshold_values = 0.3 * template_amplitudes_unscaled TODO: remove this to be more general. Agree? 
# template_amplitudes_per_channel[template_amplitudes_per_channel < threshold_values[:, np.newaxis]] = 0 @@ -253,9 +269,11 @@ def get_unwhite_template_info( ) -def compute_template_amplitudes_from_spikes(): - # Take the average of all spike amplitudes to get actual template amplitudes - # (since tempScalingAmps are equal mean for all templates) +def compute_template_amplitudes_from_spikes(templates, spike_templates, spike_amplitudes): + """ + Take the average of all spike amplitudes to get actual template amplitudes + (since tempScalingAmps are equal mean for all templates) + """ num_indices = templates.shape[0] sum_per_index = np.zeros(num_indices, dtype=np.float64) np.add.at(sum_per_index, spike_templates, spike_amplitudes) @@ -264,7 +282,12 @@ def compute_template_amplitudes_from_spikes(): return template_amplitudes -def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool = False) -> dict: +######################################################################################################################## +# Load Parameters from KS Directory +######################################################################################################################## + + +def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool = False) -> dict: """ Loads the output of Kilosort into a `params` dict. 
@@ -300,7 +323,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool params = read_python(sorter_output / "params.py") - spike_indexes = np.load(sorter_output / "spike_times.npy") + spike_indices = np.load(sorter_output / "spike_times.npy") spike_templates = np.load(sorter_output / "spike_templates.npy") if (clusters_path := sorter_output / "spike_clusters.csv").is_dir(): @@ -328,7 +351,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool noise_cluster_ids = cluster_ids[cluster_groups == 0] not_noise_clusters_by_spike = ~np.isin(spike_clusters.ravel(), noise_cluster_ids) - spike_indexes = spike_indexes[not_noise_clusters_by_spike] + spike_indices = spike_indices[not_noise_clusters_by_spike] spike_templates = spike_templates[not_noise_clusters_by_spike] temp_scaling_amplitudes = temp_scaling_amplitudes[not_noise_clusters_by_spike] @@ -343,7 +366,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool cluster_groups = 3 * np.ones(cluster_ids.size) new_params = { - "spike_indexes": spike_indexes.squeeze(), + "spike_indices": spike_indices.squeeze(), "spike_templates": spike_templates.squeeze(), "spike_clusters": spike_clusters.squeeze(), "pc_features": pc_features, diff --git a/src/spikeinterface/working/plot_kilosort_drift_map.py b/src/spikeinterface/working/plot_kilosort_drift_map.py index 449403dba9..ecac38495f 100644 --- a/src/spikeinterface/working/plot_kilosort_drift_map.py +++ b/src/spikeinterface/working/plot_kilosort_drift_map.py @@ -1,8 +1,8 @@ from pathlib import Path -from spikeinterface.widgets.base import BaseWidget, to_attr import matplotlib.axis import scipy.signal -from spikeinterface.core import read_python + +# from spikeinterface.core import read_python import numpy as np import pandas as pd @@ -10,6 +10,8 @@ from scipy import stats import load_kilosort_utils +from spikeinterface.widgets.base import BaseWidget, to_attr + class 
KilosortDriftMapWidget(BaseWidget): """ @@ -399,5 +401,24 @@ def _filter_large_amplitude_spikes( return spike_times, spike_amplitudes, spike_depths -KilosortDriftMapWidget(r"D:\data\New folder\CA_528_1\imec0_ks2") +KilosortDriftMapWidget( + "/Users/joeziminski/data/bombcelll/sorter_output", + only_include_large_amplitude_spikes=False, + localised_spikes_only=True, +) plt.show() + +""" + sorter_output: str | Path, + only_include_large_amplitude_spikes: bool = True, + decimate: None | int = None, + add_histogram_plot: bool = False, + add_histogram_peaks_and_boundaries: bool = True, + add_drift_events: bool = True, + weight_histogram_by_amplitude: bool = False, + localised_spikes_only: bool = False, + exclude_noise: bool = False, + gain: float | None = None, + large_amplitude_only_segment_size: float = 800.0, + localised_spikes_channel_cutoff: int = 20, +"""