diff --git a/kilosort/gui/main.py b/kilosort/gui/main.py index f762261c..438ff4dc 100644 --- a/kilosort/gui/main.py +++ b/kilosort/gui/main.py @@ -340,6 +340,9 @@ def set_parameters(self): params = settings.copy() params['save_preprocessed_copy'] = self.run_box.save_preproc_check.isChecked() + params['clear_cache'] = self.run_box.clear_cache_check.isChecked() + params['do_CAR'] = self.run_box.do_CAR_check.isChecked() + params['invert_sign'] = self.run_box.invert_sign_check.isChecked() assert params diff --git a/kilosort/gui/run_box.py b/kilosort/gui/run_box.py index e9b4e402..6854e9ca 100644 --- a/kilosort/gui/run_box.py +++ b/kilosort/gui/run_box.py @@ -24,6 +24,9 @@ def __init__(self, parent): self.run_all_button = QtWidgets.QPushButton("Run") self.spike_sort_button = QtWidgets.QPushButton("Spikesort") self.save_preproc_check = QtWidgets.QCheckBox("Save Preprocessed Copy") + self.clear_cache_check = QtWidgets.QCheckBox("Clear PyTorch Cache") + self.do_CAR_check = QtWidgets.QCheckBox("CAR") + self.invert_sign_check = QtWidgets.QCheckBox("Invert Sign") self.buttons = [ self.run_all_button @@ -44,7 +47,7 @@ def __init__(self, parent): self.remote_widgets = None self.progress_bar = QtWidgets.QProgressBar() - self.layout.addWidget(self.progress_bar, 3, 0, 2, 2) + self.layout.addWidget(self.progress_bar, 5, 0, 3, 4) self.setup() @@ -64,8 +67,36 @@ def setup(self): """ self.save_preproc_check.setToolTip(preproc_text) - self.layout.addWidget(self.run_all_button, 0, 0, 2, 2) - self.layout.addWidget(self.save_preproc_check, 2, 0, 1, 2) + self.clear_cache_check.setCheckState(QtCore.Qt.CheckState.Unchecked) + cache_text = """ + If enabled, force pytorch to free up memory reserved for its cache in + between memory-intensive operations. + Note that setting `clear_cache=True` is NOT recommended unless you + encounter GPU out-of-memory errors, since this can result in slower + sorting. + """ + self.clear_cache_check.setToolTip(cache_text) + + self.do_CAR_check.setCheckState(QtCore.Qt.CheckState.Checked) + car_text = """ + If enabled, apply common average reference during preprocessing + (recommended). + """ + self.do_CAR_check.setToolTip(car_text) + + self.invert_sign_check.setCheckState(QtCore.Qt.CheckState.Unchecked) + invert_sign_text = """ + If enabled, flip positive/negative values in data to conform to + standard expected by Kilosort4. This is NOT recommended unless you + know your data is using the opposite sign. + """ + self.invert_sign_check.setToolTip(invert_sign_text) + + self.layout.addWidget(self.run_all_button, 0, 0, 3, 4) + self.layout.addWidget(self.save_preproc_check, 3, 0, 1, 2) + self.layout.addWidget(self.clear_cache_check, 3, 2, 1, 2) + self.layout.addWidget(self.do_CAR_check, 4, 0, 1, 2) + self.layout.addWidget(self.invert_sign_check, 4, 2, 1, 2) self.setLayout(self.layout) diff --git a/kilosort/gui/sorter.py b/kilosort/gui/sorter.py index c9a47de7..19e5440c 100644 --- a/kilosort/gui/sorter.py +++ b/kilosort/gui/sorter.py @@ -52,17 +52,13 @@ def run(self): try: logger.info(f"Kilosort version {kilosort.__version__}") logger.info(f"Sorting {self.data_path}") + clear_cache = settings['clear_cache'] + if clear_cache: + logger.info('clear_cache=True') logger.info('-'*40) tic0 = time.time() - # TODO: make these options in GUI - do_CAR=True - invert_sign=False - - if not do_CAR: - logger.info("Skipping common average reference.") - if probe['chanMap'].max() >= settings['n_chan_bin']: raise ValueError( f'Largest value of chanMap exceeds channel count of data, ' @@ -74,9 +70,13 @@ def run(self): data_dtype = settings['data_dtype'] device = self.device save_preprocessed_copy = settings['save_preprocessed_copy'] + do_CAR = settings['do_CAR'] + invert_sign = settings['invert_sign'] + if not do_CAR: + logger.info("Skipping common average reference.") ops = initialize_ops(settings, probe, data_dtype, do_CAR, - invert_sign, device, save_preprocessed_copy) + invert_sign, device, save_preprocessed_copy) # Remove some stuff that doesn't need to be printed twice, # then pretty-print format for log file. ops_copy = ops.copy() @@ -94,7 +94,7 @@ def run(self): torch.random.manual_seed(1) ops, bfile, st0 = compute_drift_correction( ops, self.device, tic0=tic0, progress_bar=self.progress_bar, - file_object=self.file_object + file_object=self.file_object, clear_cache=clear_cache ) # Check scale of data for log file @@ -113,7 +113,7 @@ def run(self): # Sort spikes and save results st, tF, Wall0, clu0 = detect_spikes( ops, self.device, bfile, tic0=tic0, - progress_bar=self.progress_bar + progress_bar=self.progress_bar, clear_cache=clear_cache ) self.Wall0 = Wall0 @@ -123,7 +123,7 @@ def run(self): clu, Wall = cluster_spikes( st, tF, ops, self.device, bfile, tic0=tic0, - progress_bar=self.progress_bar + progress_bar=self.progress_bar, clear_cache=clear_cache ) ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \ save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0)