Added new shift and scale parameters for float32 data
jacobpennington committed Jun 8, 2024
1 parent 2e9d0f0 commit 971697d
Showing 5 changed files with 74 additions and 13 deletions.
8 changes: 8 additions & 0 deletions kilosort/gui/main.py
@@ -365,6 +365,8 @@ def load_binary_files(self):
tmin = self.params['tmin']
tmax = self.params['tmax']
artifact = self.params['artifact_threshold']
shift = self.params['shift']
scale = self.params['scale']

if chan_map.max() >= n_channels:
raise ValueError(
@@ -382,6 +384,8 @@ def load_binary_files(self):
tmin=tmin,
tmax=tmax,
artifact_threshold=artifact,
shift=shift,
scale=scale,
file_object=self.file_object
)

@@ -403,6 +407,8 @@ def load_binary_files(self):
tmin=tmin,
tmax=tmax,
artifact_threshold=artifact,
shift=shift,
scale=scale,
file_object=self.file_object
) as bin_file:
self.context.whitening_matrix = preprocessing.get_whitening_matrix(
@@ -424,6 +430,8 @@ def load_binary_files(self):
tmin=tmin,
tmax=tmax,
artifact_threshold=artifact,
shift=shift,
scale=scale,
file_object=self.file_object
)

5 changes: 4 additions & 1 deletion kilosort/gui/settings_box.py
@@ -486,7 +486,10 @@ def check_settings(self):
if not self.check_valid_binary_path(self.data_file_path):
return False

none_allowed = ['dmin', 'nt0min', 'max_channel_distance', 'x_centers']
none_allowed = [
'dmin', 'nt0min', 'max_channel_distance', 'x_centers',
'shift', 'scale'
]
for k, v in self.settings.items():
if v is None and k not in none_allowed:
print(f'`None` not allowed for parameter {k}.')
30 changes: 24 additions & 6 deletions kilosort/io.py
@@ -337,7 +337,7 @@ def __init__(self, filename: str, n_chan_bin: int, fs: int = 30000,
NT: int = 60000, nt: int = 61, nt0min: int = 20,
device: torch.device = None, write: bool = False,
dtype: str = None, tmin: float = 0.0, tmax: float = np.inf,
file_object=None):
shift=None, scale=None, file_object=None):
"""
Creates/Opens a BinaryFile for reading and/or writing data that acts like numpy array
@@ -363,6 +363,8 @@ def __init__(self, filename: str, n_chan_bin: int, fs: int = 30000,
self.NT = NT
self.nt = nt
self.nt0min = nt0min
self.shift = shift
self.scale = scale
if device is None:
if torch.cuda.is_available():
device = torch.device('cuda')
@@ -482,11 +484,18 @@ def __getitem__(self, *items):
# Shift indices by minimum sample index.
sample_indices = self._get_shifted_indices(idx)
samples = self.file[sample_indices]
# Shift data to +/- 2**15

if self.dtype == 'uint16':
# Shift data to +/- 2**15
samples = samples.astype('float32')
samples = samples - 2**15

# Typically only needed for float32 data
if self.scale is not None:
samples = samples * self.scale
if self.shift is not None:
samples = samples + self.shift

return samples

def _get_shifted_indices(self, idx):
@@ -526,11 +535,18 @@ def padded_batch_to_torch(self, ibatch, return_inds=False):
bend = min(self.imax, np.uint64(bstart + self.NT + 2*self.nt))
data = self.file[bstart : bend]
data = data.T
# Shift data to +/- 2**15

if self.dtype == 'uint16':
# Shift data to +/- 2**15
data = data.astype('float32')
data = data - 2**15

# Typically only needed for float32 data
if self.scale is not None:
data = data * self.scale
if self.shift is not None:
data = data + self.shift

nsamp = data.shape[-1]
X = torch.zeros((self.n_chan_bin, self.NT + 2*self.nt), device=self.device)

@@ -550,7 +566,7 @@ def padded_batch_to_torch(self, ibatch, return_inds=False):
bend += self.nt
else:
X[:] = torch.from_numpy(data).to(self.device).float()

inds = [bstart, bend]
if return_inds:
return X, inds
@@ -661,10 +677,12 @@ def __init__(self, filename: str, n_chan_bin: int, fs: int = 30000,
whiten_mat: torch.Tensor = None, dshift: torch.Tensor = None,
device: torch.device = None, do_CAR: bool = True,
artifact_threshold: float = np.inf, invert_sign: bool = False,
dtype=None, tmin: float = 0.0, tmax: float = np.inf, file_object=None):
dtype=None, tmin: float = 0.0, tmax: float = np.inf,
shift=None, scale=None, file_object=None):

super().__init__(filename, n_chan_bin, fs, NT, nt, nt0min, device,
dtype=dtype, tmin=tmin, tmax=tmax, file_object=file_object)
dtype=dtype, tmin=tmin, tmax=tmax, shift=shift,
scale=scale, file_object=file_object)
self.chan_map = chan_map
self.whiten_mat = whiten_mat
self.hp_filter = hp_filter
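
As a quick reference for the `io.py` changes above, here is a minimal sketch of reading a float32 recording through `BinaryRWFile` with the new arguments. The path, channel count, and scale/shift values are placeholders, and the final check simply restates the `data*scale + shift` logic added to `__getitem__`, assuming the default `tmin=0` and slice indexing over samples.

```python
import numpy as np
from kilosort.io import BinaryRWFile

# Placeholder values: an existing float32 recording with 64 channels, stored in volts.
filename = 'recording.bin'
n_chan_bin = 64
scale = 1e5    # bring ~ +/- 1 mV signals to roughly +/- 100
shift = 0.0

bfile = BinaryRWFile(
    filename, n_chan_bin, fs=30000, dtype='float32',
    shift=shift, scale=scale,
)

# Samples come back already transformed, so with tmin=0 this should hold
# (up to floating-point error).
raw = np.memmap(filename, dtype='float32', mode='r').reshape(-1, n_chan_bin)
chunk = bfile[:1000]
assert np.allclose(chunk, raw[:1000] * scale + shift)
```
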
28 changes: 28 additions & 0 deletions kilosort/parameters.py
@@ -112,6 +112,34 @@
"""
},

'shift': {
'gui_name': 'shift', 'type': float, 'min': -np.inf, 'max': np.inf,
'exclude': [], 'default': None, 'step': 'data',
'description':
"""
Scalar shift to apply to the data before all other operations. In most
cases this should be left as None, but it may be necessary for float32
data, for example. If needed, `shift` and `scale` should be set such
that the data falls roughly in the range -100 to +100.
If set, the data will be transformed as `data = data*scale + shift`.
"""
},

'scale': {
'gui_name': 'scale', 'type': float, 'min': -np.inf, 'max': np.inf,
'exclude': [], 'default': None, 'step': 'data',
'description':
"""
Scaling factor to apply to the data before all other operations. In most
cases this should be left as None, but it may be necessary for float32
data, for example. If needed, `shift` and `scale` should be set such
that the data falls roughly in the range -100 to +100.
If set, the data will be transformed as `data = data*scale + shift`.
"""
},

### PREPROCESSING
'artifact_threshold': {
'gui_name': 'artifact threshold', 'type': float, 'min': 0, 'max': np.inf,
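
To make the "-100 to +100" guidance in these descriptions concrete, here is a small worked sketch with made-up numbers: float32 data in volts spanning -1 mV to +3 mV (a +/- 2 mV signal around a +1 mV DC offset). The voltage range and offset are illustrative assumptions, not values from this commit; only the `data*scale + shift` transform comes from the parameter descriptions above.

```python
import numpy as np

# Hypothetical float32 data in volts: signal spanning -1 mV to +3 mV,
# i.e. +/- 2 mV around a +1 mV DC offset (illustrative values only).
rng = np.random.default_rng(0)
data = rng.uniform(-1e-3, 3e-3, size=(30000, 64)).astype('float32')

# Pick scale so the +/- 2 mV swing maps to +/- 100, then use shift to cancel
# the (already scaled) offset, since the transform is data*scale + shift.
scale = 5e4            # 2 mV * 5e4 = 100
shift = -1e-3 * scale  # -50, removes the scaled +1 mV offset

transformed = data * scale + shift
print(transformed.min(), transformed.max())  # approximately -100 and +100
```
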
16 changes: 10 additions & 6 deletions kilosort/run_kilosort.py
@@ -288,7 +288,9 @@ def get_run_parameters(ops) -> list:
ops['probe']['yc'],
ops['settings']['tmin'],
ops['settings']['tmax'],
ops['settings']['artifact_threshold']
ops['settings']['artifact_threshold'],
ops['settings']['shift'],
ops['settings']['scale']
]

return parameters
@@ -322,7 +324,7 @@ def compute_preprocessing(ops, device, tic0=np.nan, file_object=None):
logger.info('-'*40)

n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, \
xc, yc, tmin, tmax, artifact = get_run_parameters(ops)
xc, yc, tmin, tmax, artifact, shift, scale = get_run_parameters(ops)
nskip = ops['settings']['nskip']
whitening_range = ops['settings']['whitening_range']

@@ -333,7 +335,7 @@ def compute_preprocessing(ops, device, tic0=np.nan, file_object=None):
chan_map, hp_filter, device=device, do_CAR=do_CAR,
invert_sign=invert, dtype=dtype, tmin=tmin,
tmax=tmax, artifact_threshold=artifact,
file_object=file_object)
shift=shift, scale=scale, file_object=file_object)
whiten_mat = preprocessing.get_whitening_matrix(bfile, xc, yc, nskip=nskip,
nrange=whitening_range)

@@ -387,15 +389,16 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
logger.info('-'*40)

n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, \
_, _, tmin, tmax, artifact = get_run_parameters(ops)
_, _, tmin, tmax, artifact, shift, scale = get_run_parameters(ops)
hp_filter = ops['preprocessing']['hp_filter']
whiten_mat = ops['preprocessing']['whiten_mat']

bfile = io.BinaryFiltered(
ops['filename'], n_chan_bin, fs, NT, nt, twav_min, chan_map,
hp_filter=hp_filter, whiten_mat=whiten_mat, device=device, do_CAR=do_CAR,
invert_sign=invert, dtype=dtype, tmin=tmin, tmax=tmax,
artifact_threshold=artifact, file_object=file_object
artifact_threshold=artifact, shift=shift, scale=scale,
file_object=file_object
)

ops, st = datashift.run(ops, bfile, device=device, progress_bar=progress_bar)
@@ -413,7 +416,8 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
ops['filename'], n_chan_bin, fs, NT, nt, twav_min, chan_map,
hp_filter=hp_filter, whiten_mat=whiten_mat, device=device,
dshift=ops['dshift'], do_CAR=do_CAR, dtype=dtype, tmin=tmin, tmax=tmax,
artifact_threshold=artifact, file_object=file_object
artifact_threshold=artifact, shift=shift, scale=scale,
file_object=file_object
)

return ops, bfile, st
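
Putting it together, here is a sketch of how the new parameters would typically be supplied when running the full pipeline, assuming the usual `run_kilosort(settings=..., probe_name=...)` entry point; the paths, channel count, scale value, probe file name, and the `data_dtype` keyword are assumptions for illustration rather than details taken from this commit.

```python
from kilosort import run_kilosort

# Placeholder settings for a float32 recording stored in volts. `shift` and
# `scale` default to None, so they only need to be set when the data must be
# rescaled into roughly the -100 to +100 range.
settings = {
    'data_dir': '/path/to/recording',  # placeholder path
    'n_chan_bin': 385,                 # placeholder channel count
    'scale': 1e5,                      # applied as data*scale + shift
    'shift': 0.0,
}

results = run_kilosort(
    settings=settings,
    probe_name='neuropixPhase3B1_kilosortChanMap.mat',  # example probe file
    data_dtype='float32',                               # keyword assumed
)
```
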
