Skip to content

Commit

Permalink
Merge pull request #843 from MouseLand/jacob/working_changes
Browse files Browse the repository at this point in the history
Jacob/working changes
  • Loading branch information
jacobpennington authored Dec 30, 2024
2 parents 00f0ca8 + 9233270 commit 10620b5
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 6 deletions.
9 changes: 7 additions & 2 deletions kilosort/clustering_qr.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ def run(ops, st, tF, mode = 'template', device=torch.device('cuda'),

clu = np.zeros(nsp, 'int32')
Wall = torch.zeros((0, ops['Nchan'], ops['settings']['n_pcs']))
Nfilt = None
nearby_chans_empty = 0
nmax = 0
prog = tqdm(np.arange(len(ycent)), miniters=20 if progress_bar else None,
Expand Down Expand Up @@ -433,9 +434,13 @@ def run(ops, st, tF, mode = 'template', device=torch.device('cuda'),
except:
logger.exception(f'Error in clustering_qr.run on center {ii}')
logger.debug(f'Xd shape: {Xd.shape}')
logger.debug(f'iclust shape: {iclust.shape}')
logger.debug(f'clu shape: {clu.shape}')
logger.debug(f'Nfilt: {Nfilt}')
logger.debug(f'num spikes: {nsp}')
try:
logger.debug(f'iclust shape: {iclust.shape}')
except UnboundLocalError:
logger.debug('iclust not yet assigned')
pass
raise

if nearby_chans_empty == len(ycent):
Expand Down
12 changes: 12 additions & 0 deletions kilosort/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,18 @@
default of 7 bins for a 30kHz sampling rate.
"""
},

'position_limit': {
'gui_name': 'position limit', 'type': float, 'min': 0, 'max': np.inf,
'exclude': [], 'default': 100, 'step': 'postprocessing',
'description':
"""
Maximum distance (in microns) between channels that can be used
to estimate spike positions in `postprocessing.compute_spike_positions`.
This does not affect spike sorting, only how positions are estimated
after sorting is complete.
"""
},
}

# Add default values to descriptions
Expand Down
8 changes: 8 additions & 0 deletions kilosort/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,22 @@ def remove_duplicates(spike_times, spike_clusters, dt=15):

def compute_spike_positions(st, tF, ops):
'''Get x,y positions of spikes relative to probe.'''
# Determine channel weightings for nearest channels
# based on norm of PC features. Channels that are far away have 0 weight,
# determined by `ops['settings']['position_limit']`.
tmass = (tF**2).sum(-1)
tmask = ops['iCC_mask'][:, ops['iU'][st[:,1]]].T.to(tmass.device)
tmass = tmass * tmask
tmass = tmass / tmass.sum(1, keepdim=True)

# Get x,y coordinates of nearest channels.
xc = torch.from_numpy(ops['xc']).to(tmass.device)
yc = torch.from_numpy(ops['yc']).to(tmass.device)
chs = ops['iCC'][:, ops['iU'][st[:,1]]].cpu()
xc0 = xc[chs.T]
yc0 = yc[chs.T]

# Estimate spike positions as weighted sum of coordinates of nearby channels.
xs = (xc0 * tmass).sum(1).cpu().numpy()
ys = (yc0 * tmass).sum(1).cpu().numpy()

Expand Down
49 changes: 45 additions & 4 deletions kilosort/template_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,56 @@
logger = logging.getLogger(__name__)


def prepare_extract(ops, U, nC, device=torch.device('cuda')):
ds = (ops['xc'] - ops['xc'][:, np.newaxis])**2 + (ops['yc'] - ops['yc'][:, np.newaxis])**2
def prepare_extract(xc, yc, U, nC, position_limit, device=torch.device('cuda')):
"""Identify desired channels based on distances and template norms.
Parameters
----------
xc : np.ndarray
X-coordinates of contact positions on probe.
yc : np.ndarray
Y-coordinates of contact positions on probe.
U : torch.Tensor
TODO
nC : int
Number of nearest channels to use.
position_limit : float
Max distance (in microns) between channels that are used to estimate
spike positions in `postprocessing.compute_spike_positions`.
Returns
-------
iCC : np.ndarray
For each channel, indices of nC nearest channels.
iCC_mask : np.ndarray
For each channel, a 1 if the channel is within 100um and a 0 otherwise.
Used to control spike position estimate in post-processing.
iU : torch.Tensor
For each template, index of channel with greatest norm.
Ucc : torch.Tensor
For each template, spatial PC features corresponding to iCC.
"""
ds = (xc - xc[:, np.newaxis])**2 + (yc - yc[:, np.newaxis])**2
iCC = np.argsort(ds, 0)[:nC]
iCC = torch.from_numpy(iCC).to(device)
iCC_mask = np.sort(ds, 0)[:nC]
iCC_mask = iCC_mask < position_limit**2
iCC_mask = torch.from_numpy(iCC_mask).to(device)
iU = torch.argmax((U**2).sum(1), -1)
Ucc = U[torch.arange(U.shape[0]),:,iCC[:,iU]]
return iCC, iU, Ucc

return iCC, iCC_mask, iU, Ucc


def extract(ops, bfile, U, device=torch.device('cuda'), progress_bar=None):
nC = ops['settings']['nearest_chans']
iCC, iU, Ucc = prepare_extract(ops, U, nC, device=device)
position_limit = ops['settings']['position_limit']
iCC, iCC_mask, iU, Ucc = prepare_extract(
ops['xc'], ops['yc'], U, nC, position_limit, device=device
)
ops['iCC'] = iCC
ops['iCC_mask'] = iCC_mask
ops['iU'] = iU
nt = ops['nt']

Expand Down Expand Up @@ -85,6 +123,7 @@ def extract(ops, bfile, U, device=torch.device('cuda'), progress_bar=None):

return st, tF, ops


def align_U(U, ops, device=torch.device('cuda')):
Uex = torch.einsum('xyz, zt -> xty', U.to(device), ops['wPCA'])
X = Uex.reshape(-1, ops['Nchan']).T
Expand All @@ -108,6 +147,7 @@ def postprocess_templates(Wall, ops, clu, st, device=torch.device('cuda')):
Wall3 = Wall3.transpose(1,2).to(device)
return Wall3


def prepare_matching(ops, U):
nt = ops['nt']
W = ops['wPCA'].contiguous()
Expand All @@ -122,6 +162,7 @@ def prepare_matching(ops, U):

return ctc


def run_matching(ops, X, U, ctc, device=torch.device('cuda')):
Th = ops['Th_learned']
nt = ops['nt']
Expand Down

0 comments on commit 10620b5

Please sign in to comment.