diff --git a/kilosort/gui/main.py b/kilosort/gui/main.py index 20aa87db..11cb03de 100644 --- a/kilosort/gui/main.py +++ b/kilosort/gui/main.py @@ -365,6 +365,8 @@ def load_binary_files(self): tmin = self.params['tmin'] tmax = self.params['tmax'] artifact = self.params['artifact_threshold'] + shift = self.params['shift'] + scale = self.params['scale'] if chan_map.max() >= n_channels: raise ValueError( @@ -382,6 +384,8 @@ def load_binary_files(self): tmin=tmin, tmax=tmax, artifact_threshold=artifact, + shift=shift, + scale=scale, file_object=self.file_object ) @@ -403,6 +407,8 @@ def load_binary_files(self): tmin=tmin, tmax=tmax, artifact_threshold=artifact, + shift=shift, + scale=scale, file_object=self.file_object ) as bin_file: self.context.whitening_matrix = preprocessing.get_whitening_matrix( @@ -424,6 +430,8 @@ def load_binary_files(self): tmin=tmin, tmax=tmax, artifact_threshold=artifact, + shift=shift, + scale=scale, file_object=self.file_object ) diff --git a/kilosort/gui/settings_box.py b/kilosort/gui/settings_box.py index 24118005..e5eb80a9 100644 --- a/kilosort/gui/settings_box.py +++ b/kilosort/gui/settings_box.py @@ -486,7 +486,10 @@ def check_settings(self): if not self.check_valid_binary_path(self.data_file_path): return False - none_allowed = ['dmin', 'nt0min', 'max_channel_distance', 'x_centers'] + none_allowed = [ + 'dmin', 'nt0min', 'max_channel_distance', 'x_centers', + 'shift', 'scale' + ] for k, v in self.settings.items(): if v is None and k not in none_allowed: print(f'`None` not allowed for parameter {k}.') diff --git a/kilosort/io.py b/kilosort/io.py index d82b5879..c8f41ed6 100644 --- a/kilosort/io.py +++ b/kilosort/io.py @@ -337,7 +337,7 @@ def __init__(self, filename: str, n_chan_bin: int, fs: int = 30000, NT: int = 60000, nt: int = 61, nt0min: int = 20, device: torch.device = None, write: bool = False, dtype: str = None, tmin: float = 0.0, tmax: float = np.inf, - file_object=None): + shift=None, scale=None, file_object=None): """ Creates/Opens a BinaryFile for reading and/or writing data that acts like numpy array @@ -363,6 +363,8 @@ def __init__(self, filename: str, n_chan_bin: int, fs: int = 30000, self.NT = NT self.nt = nt self.nt0min = nt0min + self.shift = shift + self.scale = scale if device is None: if torch.cuda.is_available(): device = torch.device('cuda') @@ -482,11 +484,18 @@ def __getitem__(self, *items): # Shift indices by minimum sample index. sample_indices = self._get_shifted_indices(idx) samples = self.file[sample_indices] - # Shift data to +/- 2**15 + if self.dtype == 'uint16': + # Shift data to +/- 2**15 samples = samples.astype('float32') samples = samples - 2**15 + # Typically only need to be used with float32 data + if self.scale is not None: + samples = samples * self.scale + if self.shift is not None: + samples = samples + self.shift + return samples def _get_shifted_indices(self, idx): @@ -526,11 +535,18 @@ def padded_batch_to_torch(self, ibatch, return_inds=False): bend = min(self.imax, np.uint64(bstart + self.NT + 2*self.nt)) data = self.file[bstart : bend] data = data.T - # Shift data to +/- 2**15 + if self.dtype == 'uint16': + # Shift data to +/- 2**15 data = data.astype('float32') data = data - 2**15 + # Typically only need to be used with float32 data + if self.scale is not None: + data = data * self.scale + if self.shift is not None: + data = data + self.shift + nsamp = data.shape[-1] X = torch.zeros((self.n_chan_bin, self.NT + 2*self.nt), device=self.device) @@ -550,7 +566,7 @@ def padded_batch_to_torch(self, ibatch, return_inds=False): bend += self.nt else: X[:] = torch.from_numpy(data).to(self.device).float() - + inds = [bstart, bend] if return_inds: return X, inds @@ -661,10 +677,12 @@ def __init__(self, filename: str, n_chan_bin: int, fs: int = 30000, whiten_mat: torch.Tensor = None, dshift: torch.Tensor = None, device: torch.device = None, do_CAR: bool = True, artifact_threshold: float = np.inf, invert_sign: bool = False, - dtype=None, tmin: float = 0.0, tmax: float = np.inf, file_object=None): + dtype=None, tmin: float = 0.0, tmax: float = np.inf, + shift=None, scale=None, file_object=None): super().__init__(filename, n_chan_bin, fs, NT, nt, nt0min, device, - dtype=dtype, tmin=tmin, tmax=tmax, file_object=file_object) + dtype=dtype, tmin=tmin, tmax=tmax, shift=shift, + scale=scale, file_object=file_object) self.chan_map = chan_map self.whiten_mat = whiten_mat self.hp_filter = hp_filter diff --git a/kilosort/parameters.py b/kilosort/parameters.py index 64f1784c..256115f3 100644 --- a/kilosort/parameters.py +++ b/kilosort/parameters.py @@ -112,6 +112,34 @@ """ }, + 'shift': { + 'gui_name': 'shift', 'type': float, 'min': -np.inf, 'max': np.inf, + 'exclude': [], 'default': None, 'step': 'data', + 'description': + """ + Scalar shift to apply to data before all other operations. In most + cases this should be left as None, but may be necessary for float32 + data for example. If needed, `shift` and `scale` should be set such + that data is roughly in the range -100 to +100. + + If set, data will be `data = data*scale + shift`. + """ + }, + + 'scale': { + 'gui_name': 'scale', 'type': float, 'min': -np.inf, 'max': np.inf, + 'exclude': [], 'default': None, 'step': 'data', + 'description': + """ + Scaling factor to apply to data before all other operations. In most + cases this should be left as None, but may be necessary for float32 + data for example. If needed, `shift` and `scale` should be set such + that data is roughly in the range -100 to +100. + + If set, data will be `data = data*scale + shift`. + """ + }, + ### PREPROCESSING 'artifact_threshold': { 'gui_name': 'artifact threshold', 'type': float, 'min': 0, 'max': np.inf, diff --git a/kilosort/run_kilosort.py b/kilosort/run_kilosort.py index 381a1344..2da7c3c8 100644 --- a/kilosort/run_kilosort.py +++ b/kilosort/run_kilosort.py @@ -288,7 +288,9 @@ def get_run_parameters(ops) -> list: ops['probe']['yc'], ops['settings']['tmin'], ops['settings']['tmax'], - ops['settings']['artifact_threshold'] + ops['settings']['artifact_threshold'], + ops['settings']['shift'], + ops['settings']['scale'] ] return parameters @@ -322,7 +324,7 @@ def compute_preprocessing(ops, device, tic0=np.nan, file_object=None): logger.info('-'*40) n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, \ - xc, yc, tmin, tmax, artifact = get_run_parameters(ops) + xc, yc, tmin, tmax, artifact, shift, scale = get_run_parameters(ops) nskip = ops['settings']['nskip'] whitening_range = ops['settings']['whitening_range'] @@ -333,7 +335,7 @@ def compute_preprocessing(ops, device, tic0=np.nan, file_object=None): chan_map, hp_filter, device=device, do_CAR=do_CAR, invert_sign=invert, dtype=dtype, tmin=tmin, tmax=tmax, artifact_threshold=artifact, - file_object=file_object) + shift=shift, scale=scale, file_object=file_object) whiten_mat = preprocessing.get_whitening_matrix(bfile, xc, yc, nskip=nskip, nrange=whitening_range) @@ -387,7 +389,7 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None, logger.info('-'*40) n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, \ - _, _, tmin, tmax, artifact = get_run_parameters(ops) + _, _, tmin, tmax, artifact, shift, scale = get_run_parameters(ops) hp_filter = ops['preprocessing']['hp_filter'] whiten_mat = ops['preprocessing']['whiten_mat'] @@ -395,7 +397,8 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None, ops['filename'], n_chan_bin, fs, NT, nt, twav_min, chan_map, hp_filter=hp_filter, whiten_mat=whiten_mat, device=device, do_CAR=do_CAR, invert_sign=invert, dtype=dtype, tmin=tmin, tmax=tmax, - artifact_threshold=artifact, file_object=file_object + artifact_threshold=artifact, shift=shift, scale=scale, + file_object=file_object ) ops, st = datashift.run(ops, bfile, device=device, progress_bar=progress_bar) @@ -413,7 +416,8 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None, ops['filename'], n_chan_bin, fs, NT, nt, twav_min, chan_map, hp_filter=hp_filter, whiten_mat=whiten_mat, device=device, dshift=ops['dshift'], do_CAR=do_CAR, dtype=dtype, tmin=tmin, tmax=tmax, - artifact_threshold=artifact, file_object=file_object + artifact_threshold=artifact, shift=shift, scale=scale, + file_object=file_object ) return ops, bfile, st