Skip to content

Commit

Permalink
Working on the private PC decomposition.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Nov 14, 2024
1 parent 2a9e082 commit b5c85ff
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
19 changes: 13 additions & 6 deletions src/spikeinterface/working/load_kilosort_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,7 @@ def compute_spike_amplitude_and_depth(
center of mass from the first PC (or, second PC if no signal on first PC).
See `_get_locations_from_pc_features()` for details.
"""
if isinstance(sorter_output, str):
sorter_output = Path(sorter_output)

if not params["pc_features"]:
if params["pc_features"] is None:
raise ValueError("`pc_features` must be loaded into params. Use `load_ks_dir` with `load_pcs=True`.")

if localised_spikes_only:
Expand Down Expand Up @@ -118,10 +115,12 @@ def compute_spike_amplitude_and_depth(

def _get_locations_from_pc_features(params):
"""
Compute locations from the waveform principal component scores.
Notes
-----
Location of of each individual spike is computed from its low-dimensional projection.
During sorting, kilosort computes the '
`pc_features` (num_spikes, num_PC, num_channels) holds the PC values for each spike.
Taking the first component, the subset of 32 channels associated with this
spike are indexed to get the actual channel locations (in um). Then, the channel
Expand All @@ -131,6 +130,13 @@ def _get_locations_from_pc_features(params):
https://github.com/cortex-lab/spikes
"""
# Compute spike depths

# for each spike, a PCA is computed just on that spike (n samples x n channels).
# the components are all different between spikes, so are not saved.
# This gives a (n pc = 3, num channels) set of scores.
# but then how it is possible for some spikes to have zero score onto the principal channel?

breakpoint()
pc_features = params["pc_features"][:, 0, :]
pc_features[pc_features < 0] = 0

Expand All @@ -153,7 +159,7 @@ def _get_locations_from_pc_features(params):

pc_features[no_pc1_signal_spikes] = pc_features_2[no_pc1_signal_spikes]

if any(np.sum(pc_features, axis=1) == 0):
if np.any(np.sum(pc_features, axis=1) == 0):
raise RuntimeError(
"Some spikes do not load at all onto the first"
"or second principal component. It is necessary"
Expand Down Expand Up @@ -319,7 +325,8 @@ def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
As this function strips the spikes and units based on only these two
data structures, they will work following manual reassignment in Phy.
"""
sorter_output = Path(sorter_output)
if isinstance(sorter_output, str):
sorter_output = Path(sorter_output)

params = read_python(sorter_output / "params.py")

Expand Down
4 changes: 3 additions & 1 deletion src/spikeinterface/working/plot_kilosort_drift_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,10 @@ def plot_matplotlib(self, data_plot: dict, **unused_kwargs) -> None:

dp = to_attr(data_plot)

params = load_kilosort_utils.load_ks_dir(dp.sorter_output, load_pcs=True, exclude_noise=dp.exclude_noise)

spike_indexes, spike_amplitudes, spike_locations, _ = load_kilosort_utils.compute_spike_amplitude_and_depth(
dp.sorter_output, dp.localised_spikes_only, dp.exclude_noise, dp.gain, dp.localised_spikes_channel_cutoff
params, dp.localised_spikes_only, dp.gain, dp.localised_spikes_channel_cutoff
)
spike_times = spike_indexes / 30000
spike_depths = spike_locations[:, 1]
Expand Down

0 comments on commit b5c85ff

Please sign in to comment.