Skip to content

Commit

Permalink
Added clear_cache option
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobpennington committed Aug 11, 2024
1 parent 5a1e630 commit 1d11e34
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 25 deletions.
14 changes: 11 additions & 3 deletions kilosort/clustering_qr.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import gc

import numpy as np
import torch
from torch import sparse_coo_tensor as coo
Expand Down Expand Up @@ -301,7 +303,8 @@ def y_centers(ops):
return centers


def run(ops, st, tF, mode = 'template', device=torch.device('cuda'), progress_bar=None):
def run(ops, st, tF, mode = 'template', device=torch.device('cuda'),
progress_bar=None, clear_cache=False):

if mode == 'template':
xy, iC = xy_templates(ops)
Expand Down Expand Up @@ -362,11 +365,16 @@ def run(ops, st, tF, mode = 'template', device=torch.device('cuda'), progress_b

# find new clusters
iclust, iclust0, M, iclust_init = cluster(Xd, nskip=nskip, lam=1,
seed=5, device=device)
seed=5, device=device)
if clear_cache:
gc.collect()
torch.cuda.empty_cache()

xtree, tstat, my_clus = hierarchical.maketree(M, iclust, iclust0)

xtree, tstat = swarmsplitter.split(Xd.numpy(), xtree, tstat, iclust, my_clus, meta = st0)
xtree, tstat = swarmsplitter.split(
Xd.numpy(), xtree, tstat,iclust, my_clus, meta=st0
)

iclust = swarmsplitter.new_clusters(iclust, my_clus, xtree, tstat)

Expand Down
8 changes: 6 additions & 2 deletions kilosort/datashift.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ def kernel2D(x, y, sig = 1):
Kn = np.exp(-ds / (2*sig**2))
return Kn

def run(ops, bfile, device=torch.device('cuda'), progress_bar=None):
def run(ops, bfile, device=torch.device('cuda'), progress_bar=None,
clear_cache=False):
""" this step computes a drift correction model
it returns vertical correction amplitudes for each batch, and for multiple blocks in a batch if nblocks > 1.
"""
Expand All @@ -194,7 +195,10 @@ def run(ops, bfile, device=torch.device('cuda'), progress_bar=None):
return ops, None

# the first step is to extract all spikes using the universal templates
st, _, ops = spikedetect.run(ops, bfile, device=device, progress_bar=progress_bar)
st, _, ops = spikedetect.run(
ops, bfile, device=device, progress_bar=progress_bar,
clear_cache=clear_cache
)

# spikes are binned by amplitude and y-position to construct a "fingerprint" for each batch
F, ysamp = bin_spikes(ops, st)
Expand Down
61 changes: 42 additions & 19 deletions kilosort/run_kilosort.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
def run_kilosort(settings, probe=None, probe_name=None, filename=None,
data_dir=None, file_object=None, results_dir=None,
data_dtype=None, do_CAR=True, invert_sign=False, device=None,
progress_bar=None, save_extra_vars=False,
progress_bar=None, save_extra_vars=False, clear_cache=False,
save_preprocessed_copy=False, bad_channels=None):
"""Run full spike sorting pipeline on specified data.
Expand Down Expand Up @@ -82,6 +82,12 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
not need to specify this.
save_extra_vars : bool; default=False.
If True, save tF and Wall to disk after sorting.
clear_cache : bool; default=False.
If True, force pytorch to free up memory reserved for its cache in
between memory-intensive operations.
Note that setting `clear_cache=True` is NOT recommended unless you
encounter GPU out-of-memory errors, since this can result in slower
sorting.
save_preprocessed_copy : bool; default=False.
If True, save a pre-processed copy of the data (including drift
correction) to `temp_wh.dat` in the results directory and format Phy
Expand Down Expand Up @@ -150,6 +156,8 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
try:
logger.info(f"Kilosort version {kilosort.__version__}")
logger.info(f"Sorting {filename}")
if clear_cache:
logger.info('clear_cache=True')
logger.info('-'*40)

if data_dtype is None:
Expand Down Expand Up @@ -189,15 +197,14 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
print_ops = pprint.pformat(ops_copy, indent=4, sort_dicts=False)
logger.debug(f"Initial ops:\n{print_ops}\n")


# Set preprocessing and drift correction parameters
ops = compute_preprocessing(ops, device, tic0=tic0, file_object=file_object)
np.random.seed(1)
torch.cuda.manual_seed_all(1)
torch.random.manual_seed(1)
ops, bfile, st0 = compute_drift_correction(
ops, device, tic0=tic0, progress_bar=progress_bar,
file_object=file_object
file_object=file_object, clear_cache=clear_cache,
)

# Check scale of data for log file
Expand All @@ -208,14 +215,20 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
io.save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile)

# Sort spikes and save results
st,tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0,
progress_bar=progress_bar)
clu, Wall = cluster_spikes(st, tF, ops, device, bfile, tic0=tic0,
progress_bar=progress_bar)
st,tF, _, _ = detect_spikes(
ops, device, bfile, tic0=tic0, progress_bar=progress_bar,
clear_cache=clear_cache
)
clu, Wall = cluster_spikes(
st, tF, ops, device, bfile, tic0=tic0, progress_bar=progress_bar,
clear_cache=clear_cache
)
ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \
save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0,
save_extra_vars=save_extra_vars,
save_preprocessed_copy=save_preprocessed_copy)
save_sorting(
ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0,
save_extra_vars=save_extra_vars,
save_preprocessed_copy=save_preprocessed_copy
)
except:
# This makes sure the full traceback is written to log file.
logger.exception('Encountered error in `run_kilosort`:')
Expand Down Expand Up @@ -456,7 +469,7 @@ def compute_preprocessing(ops, device, tic0=np.nan, file_object=None):


def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
file_object=None):
file_object=None, clear_cache=False):
"""Compute drift correction parameters and save them to `ops`.
Parameters
Expand Down Expand Up @@ -504,7 +517,8 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
file_object=file_object
)

ops, st = datashift.run(ops, bfile, device=device, progress_bar=progress_bar)
ops, st = datashift.run(ops, bfile, device=device, progress_bar=progress_bar,
clear_cache=clear_cache)
bfile.close()
logger.info(f'drift computed in {time.time()-tic : .2f}s; ' +
f'total {time.time()-tic0 : .2f}s')
Expand All @@ -526,7 +540,8 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
return ops, bfile, st


def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None):
def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None,
clear_cache=False):
"""Detect spikes via template deconvolution.
Parameters
Expand Down Expand Up @@ -563,7 +578,10 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None):
logger.info(' ')
logger.info(f'Extracting spikes using templates')
logger.info('-'*40)
st0, tF, ops = spikedetect.run(ops, bfile, device=device, progress_bar=progress_bar)
st0, tF, ops = spikedetect.run(
ops, bfile, device=device, progress_bar=progress_bar,
clear_cache=clear_cache
)
tF = torch.from_numpy(tF)
logger.info(f'{len(st0)} spikes extracted in {time.time()-tic : .2f}s; ' +
f'total {time.time()-tic0 : .2f}s')
Expand All @@ -576,8 +594,10 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None):
logger.info(' ')
logger.info('First clustering')
logger.info('-'*40)
clu, Wall = clustering_qr.run(ops, st0, tF, mode='spikes', device=device,
progress_bar=progress_bar)
clu, Wall = clustering_qr.run(
ops, st0, tF, mode='spikes', device=device, progress_bar=progress_bar,
clear_cache=clear_cache
)
Wall3 = template_matching.postprocess_templates(Wall, ops, clu, st0, device=device)
logger.info(f'{clu.max()+1} clusters found, in {time.time()-tic : .2f}s; ' +
f'total {time.time()-tic0 : .2f}s')
Expand All @@ -600,7 +620,8 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None):
return st, tF, Wall, clu


def cluster_spikes(st, tF, ops, device, bfile, tic0=np.nan, progress_bar=None):
def cluster_spikes(st, tF, ops, device, bfile, tic0=np.nan, progress_bar=None,
clear_cache=False):
"""Cluster spikes using graph-based methods.
Parameters
Expand Down Expand Up @@ -636,8 +657,10 @@ def cluster_spikes(st, tF, ops, device, bfile, tic0=np.nan, progress_bar=None):
logger.info(' ')
logger.info('Final clustering')
logger.info('-'*40)
clu, Wall = clustering_qr.run(ops, st, tF, mode = 'template', device=device,
progress_bar=progress_bar)
clu, Wall = clustering_qr.run(
ops, st, tF, mode = 'template', device=device, progress_bar=progress_bar,
clear_cache=clear_cache
)
logger.info(f'{clu.max()+1} clusters found, in {time.time()-tic : .2f}s; ' +
f'total {time.time()-tic0 : .2f}s')
logger.debug(f'clu shape: {clu.shape}')
Expand Down
3 changes: 2 additions & 1 deletion kilosort/spikedetect.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,8 @@ def yweighted(yc, iC, adist, xy, device=torch.device('cuda')):
yct = (cF0 * yy[:,xy[:,0]]).sum(0)
return yct

def run(ops, bfile, device=torch.device('cuda'), progress_bar=None):
def run(ops, bfile, device=torch.device('cuda'), progress_bar=None,
clear_cache=False):
sig = ops['settings']['min_template_size']
nsizes = ops['settings']['template_sizes']

Expand Down

0 comments on commit 1d11e34

Please sign in to comment.