diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index ff75789aab..d4c44db49b 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -95,6 +95,7 @@ def generate_sorting( add_spikes_on_borders=False, num_spikes_per_border=3, border_size_samples=20, + extra_outputs=False, seed=None, ): """ @@ -135,10 +136,14 @@ def generate_sorting( num_segments = len(durations) unit_ids = np.arange(num_units) + extra_outputs_dict = { + "firing_rates": [], + } + spikes = [] for segment_index in range(num_segments): num_samples = int(sampling_frequency * durations[segment_index]) - samples, labels = synthesize_poisson_spike_vector( + samples, labels, firing_rates_array = synthesize_poisson_spike_vector( num_units=num_units, sampling_frequency=sampling_frequency, duration=durations[segment_index], @@ -172,12 +177,17 @@ def generate_sorting( ) spikes.append(spikes_on_borders) + extra_outputs_dict["firing_rates"].append(firing_rates_array) + spikes = np.concatenate(spikes) spikes = spikes[np.lexsort((spikes["sample_index"], spikes["segment_index"]))] sorting = NumpySorting(spikes, sampling_frequency, unit_ids) - return sorting + if extra_outputs: + return sorting, extra_outputs_dict + else: + return sorting def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None): @@ -776,7 +786,7 @@ def synthesize_poisson_spike_vector( unit_indices = unit_indices[sort_indices] spike_frames = spike_frames[sort_indices] - return spike_frames, unit_indices + return spike_frames, unit_indices, firing_rates def synthesize_random_firings( @@ -2188,12 +2198,19 @@ def generate_ground_truth_recording( parent_recording=noise_rec, upsample_vector=upsample_vector, ) - recording.annotate(is_filtered=True) - recording.set_probe(probe, in_place=True) - recording.set_channel_gains(1.0) - recording.set_channel_offsets(0.0) - + setup_inject_templates_recording(recording, probe) recording.name = "GroundTruthRecording" sorting.name = 
def setup_inject_templates_recording(recording: BaseRecording, probe: Probe) -> None:
    """
    Finalise a generated recording by modifying it in place.

    The recording is annotated with `is_filtered=True`, the probe is
    attached with `in_place=True`, the channel gains are set to 1.0
    and the channel offsets are set to 0.0.

    Parameters
    ----------
    recording : BaseRecording
        The generated recording to modify in place.
    probe : Probe
        The probe to attach to `recording`.
    """
    unit_gain = 1.0
    zero_offset = 0.0

    recording.annotate(is_filtered=True)
    recording.set_probe(probe, in_place=True)
    recording.set_channel_gains(unit_gain)
    recording.set_channel_offsets(zero_offset)
def calculate_displacement_unit_factor(
    non_rigid_gradient: float, unit_locations: np.array, drift_start_um: np.array, drift_stop_um: np.array
) -> np.array:
    """
    Compute per-unit scaling factors that make a drift non-rigid.

    Each unit location is projected onto the (normalised) drift direction
    `drift_stop_um - drift_start_um`. The projections are rescaled to the
    range [0, 1] and then blended with the rigid case:

        factor = relative_position * (1 - |non_rigid_gradient|) + |non_rigid_gradient|

    so the unit furthest along the drift direction always receives a factor
    of 1 and the unit furthest against it receives `|non_rigid_gradient|`.
    The smaller `|non_rigid_gradient|`, the larger the influence of the unit
    position on the displacement.

    Parameters
    ----------
    non_rigid_gradient : float
        Weighting in the range [-1, 1]. When the absolute value is 1 all
        factors equal 1 (rigid motion); when it is 0 the relative unit
        position is used directly. A negative value reverses the relative
        positions before blending, so the gradient runs the other way.
    unit_locations : np.array
        The unit locations with shape (num_units, 2).
    drift_start_um : np.array
        The start boundary of the motion in the x and y direction.
    drift_stop_um : np.array
        The stop boundary of the motion in the x and y direction.

    Returns
    -------
    displacement_unit_factor : np.array
        A (num_units,) array of scaling factors by which to scale the
        displacement of each unit. When no gradient can be defined (the
        drift has zero length, or all units project onto the same point,
        e.g. a single unit) every factor is 1, i.e. the motion is rigid.
    """
    # Convert to float up-front: integer inputs would otherwise make the
    # normalisation fail, and the inputs must never be modified in place.
    drift_start_um = np.asarray(drift_start_um, dtype=float).ravel()
    drift_stop_um = np.asarray(drift_stop_um, dtype=float).ravel()
    unit_locations = np.asarray(unit_locations, dtype=float)
    num_units = unit_locations.shape[0]

    gradient_direction = drift_stop_um - drift_start_um
    gradient_norm = np.linalg.norm(gradient_direction)

    if num_units == 0 or gradient_norm == 0:
        # No units, or no drift direction: nothing to grade, fall back to rigid.
        return np.ones(num_units)

    gradient_direction = gradient_direction / gradient_norm

    # 1-D by construction, so a single unit keeps shape (1,).
    proj = unit_locations @ gradient_direction
    proj_min = np.min(proj)
    proj_span = np.max(proj) - proj_min

    if proj_span == 0:
        # All units sit at the same position along the drift direction,
        # normalising would divide by zero.
        return np.ones(num_units)

    factors = (proj - proj_min) / proj_span

    if non_rigid_gradient < 0:  # reverse
        factors = 1 - factors

    f = np.abs(non_rigid_gradient)

    displacement_unit_factor = factors * (1 - f) + f

    return displacement_unit_factor
def generate_probe(generate_probe_kwargs: dict, probe_name: str | None = None) -> Probe:
    """
    Build the probe used by the ground-truth recording generators.

    Parameters
    ----------
    generate_probe_kwargs : dict
        The kwargs forwarded to `generate_multi_columns_probe()`. When
        `None`, the kwargs registered under `probe_name` in `_toy_probes`
        are used instead.
    probe_name : str | None
        Name of the toy probe to use when `generate_probe_kwargs` is `None`.

    Returns
    -------
    probe : Probe
        The generated probe, with device channel indices set to
        `0 .. num_contacts - 1`.
    """
    probe_kwargs = generate_probe_kwargs

    if probe_kwargs is None:
        assert probe_name is not None, "`probe_name` must be set if `generate_probe_kwargs` is `None`."
        probe_kwargs = _toy_probes[probe_name]

    new_probe = generate_multi_columns_probe(**probe_kwargs)

    device_channel_indices = np.arange(new_probe.get_contact_count())
    new_probe.set_device_channel_indices(device_channel_indices)

    return new_probe
def generate_session_displacement_recordings(
    num_units=250,
    recording_durations=(10, 10, 10),
    recording_shifts=((0, 0), (0, 25), (0, 50)),
    non_rigid_gradient=None,
    recording_amplitude_scalings=None,
    shift_units_outside_probe=False,
    sampling_frequency=30000.0,
    probe_name="Neuropixel-128",
    generate_probe_kwargs=None,
    generate_unit_locations_kwargs=dict(
        margin_um=20.0,
        minimum_z=5.0,
        maximum_z=45.0,
        minimum_distance=18.0,
        max_iteration=100,
        distance_strict=False,
    ),
    generate_templates_kwargs=dict(
        ms_before=1.5,
        ms_after=3.0,
        mode="ellipsoid",
        unit_params=dict(
            alpha=(150.0, 500.0),
            spatial_decay=(10, 45),
        ),
    ),
    generate_sorting_kwargs=dict(firing_rates=(2.0, 8.0), refractory_period_ms=4.0),
    generate_noise_kwargs=dict(noise_levels=(12.0, 15.0), spatial_decay=25.0),
    extra_outputs=False,
    seed=None,
):
    """
    Generate a set of recordings simulating probe drift across recording
    sessions.

    Rigid drift can be added in the (x, y) direction in `recording_shifts`.
    These drifts can be made non-rigid (scaled dependent on the unit location)
    with the `non_rigid_gradient` parameter. Amplitude of units can be scaled
    (e.g. template signal removed by scaling with zero) by specifying scaling
    factors in `recording_amplitude_scalings`.

    Parameters
    ----------
    num_units : int
        The number of units in the generated recordings.
    recording_durations : list
        An array of length (num_recordings,) specifying the
        duration that each created recording should be.
    recording_shifts : list
        An array of length (num_recordings,) in which each element
        is a 2-element array specifying the (x, y) shift for the recording.
        Typically, the first recording will have shift (0, 0) so all further
        recordings are shifted relative to it. e.g. to create two recordings,
        the second shifted by 50 um in the x-direction and 250 um in the y
        direction : ((0, 0), (50, 250)).
    non_rigid_gradient : float
        Factor which sets the level of non-rigidity in the displacement.
        See `calculate_displacement_unit_factor` for details.
    recording_amplitude_scalings : dict
        A dict with keys:
            "method" - order by which to apply the scalings.
                "by_passed_order" - scalings are applied to the unit templates
                    in order passed
                "by_firing_rate" - scalings are applied to the units in order of
                    maximum to minimum firing rate
                "by_amplitude_and_firing_rate" - scalings are applied to the units
                    in order of amplitude * firing_rate (maximum to minimum)
            "scalings" - a list of numpy arrays, one for each recording, with
                each entry an array of length num_units holding the unit scalings.
                e.g. for 3 recordings, 2 units: ((1, 1), (1, 1), (0.5, 0.5)).
        Cannot be combined with `shift_units_outside_probe`.
    shift_units_outside_probe : bool
        By default (`False`), when units are shifted across sessions, new units are
        not introduced into the recording (e.g. the region in which units
        have been shifted out of is left at baseline level). In reality,
        when the probe shifts new units from outside the original recorded
        region are shifted into the recording. When `True`, new units
        are shifted into the generated recording (the number of generated
        units is then three times `num_units`).
    generate_sorting_kwargs : dict
        Only `firing_rates` and `refractory_period_ms` are expected if passed.
    extra_outputs : bool
        When `True`, a dict of variables used in the generation process
        is returned in addition to the recordings and sortings.
    seed : int | None
        Seed forwarded to every generation step.

    All other parameters are used as in `generate_drifting_recording()`.

    Returns
    -------
    output_recordings : list
        A list of recordings with units shifted (i.e. replicated probe shift).
    output_sortings : list
        A list of corresponding sorting objects.
    extra_outputs_dict (optional) : dict
        When `extra_outputs` is `True`, a dict containing variables used
        in the generation process.
        "unit_locations" : A list (length num recordings) of shifted unit locations
        "templates_array_moved" : list[np.array]
            A list (length num recordings) of (num_units, num_samples, num_channels)
            arrays of templates that have been shifted (and scaled, if requested).
        "firing_rates" : A list (length num recordings) of the unit firing
            rates returned by `generate_sorting` for each recording.

    Notes
    -----
    It is important to consider what unit properties are maintained
    across the session. Here, all `generate_template_kwargs` are fixed
    across sessions, to be sure the unit properties do not change.
    The firing rates passed to `generate_sorting` for each unit are
    also fixed across sessions. When a seed is set, the exact spike times
    will also be fixed across recordings. Otherwise, when seed is `None`
    the actual spike times will be different across recordings, although
    all other unit properties will be maintained (except any location
    shifting and template scaling applied).
    """
    # The kwargs defaults are shared mutable dicts, so work on private copies.
    generate_unit_locations_kwargs = copy.deepcopy(generate_unit_locations_kwargs)
    generate_templates_kwargs = copy.deepcopy(generate_templates_kwargs)
    generate_sorting_kwargs = copy.deepcopy(generate_sorting_kwargs)
    generate_noise_kwargs = copy.deepcopy(generate_noise_kwargs)

    _check_generate_session_displacement_arguments(
        num_units, recording_durations, recording_shifts, recording_amplitude_scalings, shift_units_outside_probe
    )

    probe = generate_probe(generate_probe_kwargs, probe_name)
    channel_locations = probe.contact_positions

    # Create the starting unit locations (which will be shifted).
    unit_locations = generate_unit_locations(
        num_units,
        channel_locations,
        seed=seed,
        **generate_unit_locations_kwargs,
    )

    # Fix generate template kwargs, so they are the same for every created recording.
    # Also fix unit firing rates across recordings.
    fixed_generate_templates_kwargs = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed)

    fixed_firing_rates = _ensure_firing_rates(generate_sorting_kwargs["firing_rates"], num_units, seed)
    fixed_generate_sorting_kwargs = copy.deepcopy(generate_sorting_kwargs)
    fixed_generate_sorting_kwargs["firing_rates"] = fixed_firing_rates

    if shift_units_outside_probe:
        # Create a new set of templates one probe-width above and
        # one probe-width below the original templates. The number of
        # units is duplicated for each section, so the new num units
        # is 3x the old num units.
        num_units, unit_locations, fixed_generate_templates_kwargs, fixed_generate_sorting_kwargs = (
            _update_kwargs_for_extended_units(
                num_units,
                channel_locations,
                unit_locations,
                generate_unit_locations_kwargs,
                generate_templates_kwargs,
                generate_sorting_kwargs,
                fixed_generate_templates_kwargs,
                fixed_generate_sorting_kwargs,
                seed,
            )
        )

    # The template kwargs are fixed, so the number of samples before
    # the template peak is the same for every recording.
    ms_before = fixed_generate_templates_kwargs["ms_before"]
    nbefore = int(sampling_frequency * ms_before / 1000.0)

    # Start looping over parameters, creating recordings shifted
    # and scaled as required
    extra_outputs_dict = {
        "unit_locations": [],
        "templates_array_moved": [],
        "firing_rates": [],
    }
    output_recordings = []
    output_sortings = []

    # Unscaled templates of the first recording, used as the amplitude
    # reference when ordering units for scaling.
    first_rec_templates = None

    for rec_idx, (shift, duration) in enumerate(zip(recording_shifts, recording_durations)):

        displacement_vector, displacement_unit_factor = _get_inter_session_displacements(
            shift,
            non_rigid_gradient,
            num_units,
            unit_locations,
        )

        # Move the canonical `unit_locations` according to the set (x, y) shifts
        unit_locations_moved = unit_locations.copy()
        unit_locations_moved[:, :2] += displacement_vector[0, :][np.newaxis, :] * displacement_unit_factor

        # Generate the sorting (e.g. spike times) for the recording
        sorting, sorting_extra_outputs = generate_sorting(
            num_units=num_units,
            sampling_frequency=sampling_frequency,
            durations=[duration],
            **fixed_generate_sorting_kwargs,
            extra_outputs=True,
            seed=seed,
        )
        sorting.set_property("gt_unit_locations", unit_locations_moved)

        # Generate the noise in the recording
        noise = generate_noise(
            probe=probe,
            sampling_frequency=sampling_frequency,
            durations=[duration],
            seed=seed,
            **generate_noise_kwargs,
        )

        # Generate the (possibly shifted, scaled) unit templates
        templates_array_moved = generate_templates(
            channel_locations,
            unit_locations_moved,
            sampling_frequency=sampling_frequency,
            seed=seed,
            **fixed_generate_templates_kwargs,
        )

        if recording_amplitude_scalings is not None:

            if rec_idx == 0:
                # The scaling below is applied in place, so snapshot the
                # templates first. Otherwise later recordings would order
                # their units by the already-scaled first-session amplitudes
                # (and a zero scaling would leave no amplitude at all).
                first_rec_templates = templates_array_moved.copy()

            _amplitude_scale_templates_in_place(
                first_rec_templates, templates_array_moved, recording_amplitude_scalings, sorting_extra_outputs, rec_idx
            )

        # Bring it all together in a `InjectTemplatesRecording` and
        # propagate all relevant metadata to the recording.
        recording = InjectTemplatesRecording(
            sorting=sorting,
            templates=templates_array_moved,
            nbefore=nbefore,
            amplitude_factor=None,
            parent_recording=noise,
            num_samples=noise.get_num_samples(0),
            upsample_vector=None,
            check_borders=False,
        )

        setup_inject_templates_recording(recording, probe)

        recording.name = "InterSessionDisplacementRecording"
        sorting.name = "InterSessionDisplacementSorting"

        output_recordings.append(recording)
        output_sortings.append(sorting)
        extra_outputs_dict["unit_locations"].append(unit_locations_moved)
        extra_outputs_dict["templates_array_moved"].append(templates_array_moved)
        extra_outputs_dict["firing_rates"].append(sorting_extra_outputs["firing_rates"][0])

    if extra_outputs:
        return output_recordings, output_sortings, extra_outputs_dict
    else:
        return output_recordings, output_sortings
+ """ + displacement_vector = np.atleast_2d(shift) + + if non_rigid_gradient is None or (shift[0] == 0 and shift[1] == 0): + displacement_unit_factor = np.ones((num_units, 1)) + else: + displacement_unit_factor = calculate_displacement_unit_factor( + non_rigid_gradient, + unit_locations[:, :2], + drift_start_um=np.array([0, 0], dtype=float), + drift_stop_um=np.array(shift, dtype=float), + ) + displacement_unit_factor = displacement_unit_factor[:, np.newaxis] + + return displacement_vector, displacement_unit_factor + + +def _amplitude_scale_templates_in_place( + first_rec_templates, moved_templates, recording_amplitude_scalings, sorting_extra_outputs, rec_idx +): + """ + Scale a set of templates given a set of scaling values. The scaling + values can be applied in the order passed, or instead in order of + the unit firing range (max to min) or unit amplitude * firing rate (max to min). + This will chang the `templates_array` in place. This must be done after + the templates are moved. + + Parameters + ---------- + first_rec_templates : np.array + The (num_units, num_samples, num_channels) templates array from the + first recording. Scaling by amplitude scales based on the amplitudes in + the first session. + moved_templates : np.array + A (num_units, num_samples, num_channels) array moved templates to the + current recording, that will be scaled. + recording_amplitude_scalings : dict + see `generate_session_displacement_recordings()`. + sorting_extra_outputs : dict + Extra output of `generate_sorting` holding the firing frequency of all units. + The unit order is assumed to match the templates. + rec_idx : int + The index of the recording for which the templates are being scaled. + + Notes + ----- + This method is used in the context of inter-session displacement. Often, + units may drop out of the recording across sessions. This simulates this by + directly scaling the template (e.g. if scaling by 0, the template is completely + dropped out). 
The provided scalings can be applied in the order passed, or + in the order of unit firing rate or firing rate * amplitude. The idea is + that it may be desired to remove to downscale the most activate neurons + that contribute most significantly to activity histograms. Similarly, + if amplitude is included in activity histograms the amplitude may + also want to be considered when ordering the units to downscale. + """ + method = recording_amplitude_scalings["method"] + + if method in ["by_amplitude_and_firing_rate", "by_firing_rate"]: + + firing_rates_hz = sorting_extra_outputs["firing_rates"][0] + + if method == "by_amplitude_and_firing_rate": + neg_ampl = np.min(np.min(first_rec_templates, axis=2), axis=1) + assert np.all(neg_ampl < 0), "assumes all amplitudes are negative here." + score = firing_rates_hz * neg_ampl + else: + score = firing_rates_hz + + order_idx = np.argsort(score) + ordered_rec_scalings = recording_amplitude_scalings["scalings"][rec_idx][order_idx, np.newaxis, np.newaxis] + + elif method == "by_passed_order": + + ordered_rec_scalings = recording_amplitude_scalings["scalings"][rec_idx][:, np.newaxis, np.newaxis] + + else: + raise ValueError("`recording_amplitude_scalings['method']` not recognised.") + + moved_templates *= ordered_rec_scalings + + +def _check_generate_session_displacement_arguments( + num_units, recording_durations, recording_shifts, recording_amplitude_scalings, shift_units_outside_probe +): + """ + Function to check the input arguments related to recording + shift and scale parameters are the correct size. + """ + expected_num_recs = len(recording_durations) + + if len(recording_shifts) != expected_num_recs: + raise ValueError( + "`recording_shifts` and `recording_durations` must be " + "the same length, the number of recordings to generate." 
def _update_kwargs_for_extended_units(
    num_units,
    channel_locations,
    unit_locations,
    generate_unit_locations_kwargs,
    generate_templates_kwargs,
    generate_sorting_kwargs,
    fixed_generate_templates_kwargs,
    fixed_generate_sorting_kwargs,
    seed,
):
    """
    Extend the units so that shifting the probe brings new units into view.

    In a real recording, moving the probe up / down both shifts the
    previously recorded units and introduces new ones. To mimic this,
    two additional sets of `num_units` units are generated, using the
    channel locations offset in y by `-max(y)` and by `+max(y)`. The unit
    locations, the template `unit_params` and the firing rates are each
    extended with these new units, and the number of units is tripled.

    The already-fixed parameters of the original units are kept unchanged
    within the extended arrays, so the original units are the same whether
    or not the extension is used. The new sets are generated with
    `seed + 1` and `seed - 1` (or `None` when `seed` is `None`) so they do
    not duplicate the original units. Note that this maintains the density
    of neurons above / below the probe (it is not random).

    `fixed_generate_templates_kwargs["unit_params"]` and
    `fixed_generate_sorting_kwargs["firing_rates"]` are updated in place
    and the dicts are also returned.

    Returns
    -------
    num_units : int
        Three times the input `num_units`.
    unit_locations : np.array
        The extended unit locations.
    fixed_generate_templates_kwargs : dict
        The template kwargs with extended "unit_params".
    fixed_generate_sorting_kwargs : dict
        The sorting kwargs with extended "firing_rates".
    """
    if seed is None:
        seed_top = None
        seed_bottom = None
    else:
        seed_top = seed + 1
        seed_bottom = seed - 1

    # Offset (in y) applied to the channel locations for the new unit sets.
    y_offset = np.max(channel_locations[:, 1])

    shifted_channels_top = channel_locations.copy()
    shifted_channels_top[:, 1] -= y_offset
    extend_top_locations = generate_unit_locations(
        num_units,
        shifted_channels_top,
        seed=seed_top,
        **generate_unit_locations_kwargs,
    )

    shifted_channels_bottom = channel_locations.copy()
    shifted_channels_bottom[:, 1] += y_offset
    extend_bottom_locations = generate_unit_locations(
        num_units,
        shifted_channels_bottom,
        seed=seed_bottom,
        **generate_unit_locations_kwargs,
    )

    # NOTE(review): locations are stacked (bottom, original, top) whereas the
    # unit params and firing rates below are stacked (top, original, bottom).
    unit_locations = np.r_[extend_bottom_locations, unit_locations, extend_top_locations]

    # Template parameters for the new units, then merged around the originals.
    top_unit_params = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed_top)["unit_params"]
    bottom_unit_params = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed_bottom)[
        "unit_params"
    ]

    merged_unit_params = fixed_generate_templates_kwargs["unit_params"]
    for key in list(merged_unit_params):
        merged_unit_params[key] = np.r_[top_unit_params[key], merged_unit_params[key], bottom_unit_params[key]]

    # Firing rates for the new units, merged in the same order as the params.
    requested_rates = generate_sorting_kwargs["firing_rates"]
    firing_rates_top = _ensure_firing_rates(requested_rates, num_units, seed_top)
    firing_rates_bottom = _ensure_firing_rates(requested_rates, num_units, seed_bottom)

    fixed_generate_sorting_kwargs["firing_rates"] = np.r_[
        firing_rates_top, fixed_generate_sorting_kwargs["firing_rates"], firing_rates_bottom
    ]

    # One new set on either side of the existing units.
    num_units = num_units * 3

    return num_units, unit_locations, fixed_generate_templates_kwargs, fixed_generate_sorting_kwargs
`generate_session_displacement_recordings` that + returns a recordings / sorting in which the units are shifted + across sessions. This is achieved by shifting the unit locations + in both (x, y) on the generated templates that are used in + `InjectTemplatesRecording()`. + """ + + @pytest.fixture(scope="function") + def options(self): + """ + Set a set of base options that can be used in + `generate_session_displacement_recordings() ("kwargs") + and provide general information on the generated recordings. + These can be edited in the tests as required. + """ + options = { + "kwargs": { + "recording_durations": (10, 10, 25, 33), + "recording_shifts": ((0, 0), (2, -100), (-3, 275), (4, 1e6)), + "num_units": 5, + "extra_outputs": True, + "seed": 42, + }, + "num_recs": 4, + "y_bin_um": 10, + } + options["kwargs"]["generate_probe_kwargs"] = dict( + num_columns=1, + num_contact_per_column=128, + xpitch=16, + ypitch=options["y_bin_um"], + contact_shapes="square", + contact_shape_params={"width": 12}, + ) + + return options + + ### Tests + def test_x_y_rigid_shifts_are_properly_set(self, options): + """ + The session displacement works by generating a set of + templates shared across all recordings, but set with + different `unit_locations()`. Check here that the + (x, y) displacements passed in `recording_shifts` are properly + propagated. + + First, check the set `unit_locations` are moved as expected according + to the (x, y) shifts). Next, check the templates themselves are + moved as expected. The x-axis shift has the effect of changing + the template amplitude, and is not possible to test. However, + the y-axis shift shifts the maximum signal channel, so we check + the maximum signal channel o fthe templates is shifted as expected. + This implicitly tests the x-axis case as if the x-axis `unit_locations` + are shifted as expected, and the unit-locations are propagated + to the template, then (x, y) will both be working. 
+ """ + output_recordings, _, extra_outputs = generate_session_displacement_recordings(**options["kwargs"]) + num_units = options["kwargs"]["num_units"] + recording_shifts = options["kwargs"]["recording_shifts"] + + # test unit locations are shifted as expected according + # to the record shifts + locations_1 = extra_outputs["unit_locations"][0] + + for rec_idx in range(1, 4): + + shifts = recording_shifts[rec_idx] + + assert np.array_equal( + locations_1 + np.r_[shifts, 0].astype(np.float32), extra_outputs["unit_locations"][rec_idx] + ) + + # Check that the generated templates are correctly shifted + # For each generated unit, check that the max loading channel is + # shifted as expected. In the case that the unit location is off the + # probe, check the maximum signal channel is the min / max channel on + # the probe, or zero (the unit is too far to reach the probe). + min_channel_loc = output_recordings[0].get_channel_locations()[0, 1] + max_channel_loc = output_recordings[0].get_channel_locations()[-1, 1] + for unit_idx in range(num_units): + + start_pos = self._get_peak_chan_loc_in_um( + extra_outputs["templates_array_moved"][0][unit_idx], + options["y_bin_um"], + ) + + for rec_idx in range(1, options["num_recs"]): + + new_pos = self._get_peak_chan_loc_in_um( + extra_outputs["templates_array_moved"][rec_idx][unit_idx], options["y_bin_um"] + ) + + y_shift = recording_shifts[rec_idx][1] + if start_pos + y_shift > max_channel_loc: + assert new_pos == max_channel_loc or new_pos == 0 + elif start_pos + y_shift < min_channel_loc: + assert new_pos == min_channel_loc or new_pos == 0 + else: + assert np.isclose(new_pos, start_pos + y_shift, options["y_bin_um"]) + + # Confidence check the correct templates are + # loaded to the recording object. 
+ for rec_idx in range(options["num_recs"]): + assert np.array_equal( + output_recordings[rec_idx].templates, + extra_outputs["templates_array_moved"][rec_idx], + ) + + def _get_peak_chan_loc_in_um(self, template_array, y_bin_um): + """ + Convenience function to get the maximally loading + channel y-position in um for the template. + """ + return np.argmax(np.max(template_array, axis=0)) * y_bin_um + + def test_recordings_length(self, options): + """ + Test that the `recording_durations` that sets the + length of each recording changes the recording + length as expected. + """ + output_recordings = generate_session_displacement_recordings(**options["kwargs"])[0] + + for rec, expected_rec_length in zip(output_recordings, options["kwargs"]["recording_durations"]): + assert rec.get_total_duration() == expected_rec_length + + def test_spike_times_and_firing_rates_across_recordings(self, options): + """ + Check the randomisation of spike times across recordings. + When a seed is set, this is passed to `generate_sorting` + and so the spike times across all records are expected + to be identical. However, if no seed is set, then the spike + times will be different across recordings. 
+ """ + options["kwargs"]["recording_durations"] = (10,) * options["num_recs"] + + output_sortings_same, extra_outputs_same = generate_session_displacement_recordings(**options["kwargs"])[1:3] + + options["kwargs"]["seed"] = None + output_sortings_different, extra_outputs_different = generate_session_displacement_recordings( + **options["kwargs"] + )[1:3] + + for unit_idx in range(options["kwargs"]["num_units"]): + for rec_idx in range(1, options["num_recs"]): + + # Exact spike times are not preserved when seed is None + assert np.array_equal( + output_sortings_same[0].get_unit_spike_train(unit_idx), + output_sortings_same[rec_idx].get_unit_spike_train(unit_idx), + ) + assert not np.array_equal( + output_sortings_different[0].get_unit_spike_train(unit_idx), + output_sortings_different[rec_idx].get_unit_spike_train(unit_idx), + ) + # Firing rates should always be preserved. + assert np.array_equal( + extra_outputs_same["firing_rates"][0][unit_idx], + extra_outputs_same["firing_rates"][rec_idx][unit_idx], + ) + assert np.array_equal( + extra_outputs_different["firing_rates"][0][unit_idx], + extra_outputs_different["firing_rates"][rec_idx][unit_idx], + ) + + @pytest.mark.parametrize("dim_idx", [0, 1]) + def test_x_y_shift_non_rigid(self, options, dim_idx): + """ + Check that the non-rigid shift changes the channel location + as expected. Non-rigid shifts are calculated depending on the + position of the channel. The `non_rigid_gradient` parameter + determines how much the position or 'distance' of the channel + (w.r.t the gradient of movement) affects the scaling. When + 0, the displacement is scaled by the distance. When 0, the + distance is ignored and all scalings are 1. + + This test checks the generated `unit_locations` under extreme + cases, when `non_rigid_gradient` is `None` or 0, which are equivalent, + and when it is `1`, and the displacement is directly propotional to + the unit position. 
+ """ + options["kwargs"]["recording_shifts"] = ((0, 0), (10, 15), (15, 20), (20, 25)) + + _, _, rigid_info = generate_session_displacement_recordings( + **options["kwargs"], + non_rigid_gradient=None, + ) + _, _, nonrigid_max_info = generate_session_displacement_recordings( + **options["kwargs"], + non_rigid_gradient=0, + ) + _, _, nonrigid_none_info = generate_session_displacement_recordings( + **options["kwargs"], + non_rigid_gradient=1, + ) + + initial_locations = rigid_info["unit_locations"][0] + + # For each recording (i.e. each recording as different displacement + # w.r.t the first recording), check the rigid and nonrigid shifts + # are as expected. + for rec_idx in range(1, options["num_recs"]): + + shift = options["kwargs"]["recording_shifts"][rec_idx][dim_idx] + + # Get the rigid shift between the first recording and this shifted recording + # Check shifts for all unit locations are all the same. + shifts_rigid = self._get_shifts(rigid_info, rec_idx, dim_idx, initial_locations) + shifts_rigid = np.round(shifts_rigid, 5) + + assert np.unique(shifts_rigid).size == 1 + + # Get the nonrigid shift between the first recording and this recording. + # The shift for each unit should be directly proportional to its position. + y_shifts_nonrigid = self._get_shifts(nonrigid_max_info, rec_idx, dim_idx, initial_locations) + + distance = np.linalg.norm(initial_locations, axis=1) + norm_distance = (distance - np.min(distance)) / (np.max(distance) - np.min(distance)) + + assert np.unique(y_shifts_nonrigid).size == options["kwargs"]["num_units"] + + # There is some small rounding error due to difference in distance computation, + # the main thing is the relative order not the absolute value. 
+ assert np.allclose(y_shifts_nonrigid, shift * norm_distance, rtol=0, atol=0.5) + + # then do again with non-ridig-gradient 1 and check it matches rigid case + shifts_rigid_2 = self._get_shifts(nonrigid_none_info, rec_idx, dim_idx, initial_locations) + assert np.array_equal(shifts_rigid, np.round(shifts_rigid_2, 5)) + + def _get_shifts(self, extras_dict, rec_idx, dim_idx, initial_locations): + return extras_dict["unit_locations"][rec_idx][:, dim_idx] - initial_locations[:, dim_idx] + + def test_displacement_with_peak_detection(self, options): + """ + This test checks that the session displacement occurs + as expected under normal usage. Create a recording with a + single unit and a y-axis displacement. Find the peak + locations and check the shifted peak location is as expected, + within the tolerate of the y-axis pitch. + """ + # The seed is important here, otherwise the unit positions + # might go off the end of the probe. These kwargs are + # chosen to make the recording as small as possible as this + # test is slow for larger recordings. + y_shift = 50 + options["kwargs"]["recording_shifts"] = ((0, 0), (0, y_shift)) + options["kwargs"]["recording_durations"] = (0.5, 0.5) + options["num_recs"] = 2 + options["kwargs"]["num_units"] = 1 + options["kwargs"]["generate_probe_kwargs"]["num_contact_per_column"] = 18 + + output_recordings, _, _ = generate_session_displacement_recordings( + **options["kwargs"], generate_noise_kwargs=dict(noise_levels=(1.0, 2.0), spatial_decay=1.0) + ) + + first_recording = output_recordings[0] + + # Peak location of unshifted recording + peaks = detect_peaks(first_recording, method="by_channel") + peak_locs = localize_peaks(first_recording, peaks, method="center_of_mass") + first_pos = np.mean(peak_locs["y"]) + + # Find peak location on shifted recording and check it is + # the original location + shift. 
+ shifted_recording = output_recordings[1] + peaks = detect_peaks(shifted_recording, method="by_channel") + peak_locs = localize_peaks(shifted_recording, peaks, method="center_of_mass") + + new_pos = np.mean(peak_locs["y"]) + + assert np.isclose(new_pos, first_pos + y_shift, rtol=0, atol=options["y_bin_um"]) + + def test_amplitude_scalings(self, options): + """ + Test that the templates are scaled by the passed scaling factors + in the specified order. The order can be in the passed order, + in the order of highest-to-lowest firing unit, or in the order + of (amplitude * firing_rate) (highest to lowest unit). + """ + # Setup arguments to create an unshifted set of recordings + # where the templates are to be scaled with `true_scalings` + options["kwargs"]["recording_durations"] = (10, 10) + options["kwargs"]["recording_shifts"] = ((0, 0), (0, 0)) + options["kwargs"]["num_units"] == 5, + + true_scalings = np.array([0.1, 0.2, 0.3, 0.4, 0.5]) + + recording_amplitude_scalings = { + "method": "by_passed_order", + "scalings": (np.ones(5), true_scalings), + } + + _, output_sortings, extra_outputs = generate_session_displacement_recordings( + **options["kwargs"], + recording_amplitude_scalings=recording_amplitude_scalings, + ) + + # Check that the unit templates are scaled in the order + # the scalings were passed. + test_scalings = self._calculate_scalings_from_output(extra_outputs) + assert np.allclose(test_scalings, true_scalings) + + # Now run, again applying the scalings in the order of + # unit firing rates (highest to lowest). 
+ firing_rates = np.array([5, 4, 3, 2, 1]) + generate_sorting_kwargs = dict(firing_rates=firing_rates, refractory_period_ms=4.0) + recording_amplitude_scalings["method"] = "by_firing_rate" + _, output_sortings, extra_outputs = generate_session_displacement_recordings( + **options["kwargs"], + recording_amplitude_scalings=recording_amplitude_scalings, + generate_sorting_kwargs=generate_sorting_kwargs, + ) + + test_scalings = self._calculate_scalings_from_output(extra_outputs) + assert np.allclose(test_scalings, true_scalings[np.argsort(firing_rates)]) + + # Finally, run again applying the scalings in the order of + # unit amplitude * firing_rate + recording_amplitude_scalings["method"] = "by_amplitude_and_firing_rate" # TODO: method -> order + amplitudes = np.min(np.min(extra_outputs["templates_array_moved"][0], axis=2), axis=1) + firing_rate_by_amplitude = np.argsort(amplitudes * firing_rates) + + _, output_sortings, extra_outputs = generate_session_displacement_recordings( + **options["kwargs"], + recording_amplitude_scalings=recording_amplitude_scalings, + generate_sorting_kwargs=generate_sorting_kwargs, + ) + + test_scalings = self._calculate_scalings_from_output(extra_outputs) + assert np.allclose(test_scalings, true_scalings[firing_rate_by_amplitude]) + + def _calculate_scalings_from_output(self, extra_outputs): + first, second = extra_outputs["templates_array_moved"] + first_min = np.min(np.min(first, axis=2), axis=1) + second_min = np.min(np.min(second, axis=2), axis=1) + test_scalings = second_min / first_min + return test_scalings + + def test_metadata(self, options): + """ + Check that metadata required to be set of generated recordings is present + on all output recordings. 
+ """ + output_recordings, output_sortings, extra_outputs = generate_session_displacement_recordings( + **options["kwargs"], generate_noise_kwargs=dict(noise_levels=(1.0, 2.0), spatial_decay=1.0) + ) + num_chans = output_recordings[0].get_num_channels() + + for i in range(len(output_recordings)): + assert output_recordings[i].name == "InterSessionDisplacementRecording" + assert output_recordings[i]._annotations["is_filtered"] is True + assert output_recordings[i].has_probe() + assert np.array_equal(output_recordings[i].get_channel_gains(), np.ones(num_chans)) + assert np.array_equal(output_recordings[i].get_channel_offsets(), np.zeros(num_chans)) + + assert np.array_equal( + output_sortings[i].get_property("gt_unit_locations"), extra_outputs["unit_locations"][i] + ) + assert output_sortings[i].name == "InterSessionDisplacementSorting" + + def test_shift_units_outside_probe(self, options): + """ + When `shift_units_outside_probe` is `True`, a new set of + units above and below the probe (y dimension) are created, + such that they may be shifted into the recording. + + Here, check that these new units are created when `shift_units_outside_probe` + is on and that the kwargs for the central set of units match those + as when `shift_units_outside_probe` is `False`. + """ + num_sessions = len(options["kwargs"]["recording_durations"]) + _, _, baseline_outputs = generate_session_displacement_recordings( + **options["kwargs"], + ) + + _, _, outside_probe_outputs = generate_session_displacement_recordings( + **options["kwargs"], shift_units_outside_probe=True + ) + + num_units = options["kwargs"]["num_units"] + num_extended_units = num_units * 3 + + for ses_idx in range(num_sessions): + + # There are 3x the number of units when new units are created + # (one new set above, and one new set below the probe). 
+ for key in ["unit_locations", "templates_array_moved", "firing_rates"]: + assert outside_probe_outputs[key][ses_idx].shape[0] == num_extended_units + + assert np.array_equal( + baseline_outputs[key][ses_idx], outside_probe_outputs[key][ses_idx][num_units:-num_units] + ) + + # The kwargs of the units in the central positions should be identical + # to those when `shift_units_outside_probe` is `False`. + lower_unit_pos = outside_probe_outputs["unit_locations"][ses_idx][-num_units:][:, 1] + upper_unit_pos = outside_probe_outputs["unit_locations"][ses_idx][:num_units][:, 1] + middle_unit_pos = baseline_outputs["unit_locations"][ses_idx][:, 1] + + assert np.min(upper_unit_pos) > np.max(middle_unit_pos) + assert np.max(lower_unit_pos) < np.min(middle_unit_pos) + + def test_same_as_generate_ground_truth_recording(self): + """ + It is expected that inter-session displacement randomly + generated recording and injected motion recording will + use exactly the same method to generate the ground-truth + recording (without displacement or motion). To check this, + set their kwargs equal and seed, then generate a non-displaced + recording. It should be identical to the static recroding + generated by `generate_drifting_recording()`. 
+ """ + + # Set some shared kwargs + num_units = 5 + duration = 10 + sampling_frequency = 30000.0 + probe_name = "Neuropixel-128" + generate_probe_kwargs = None + generate_unit_locations_kwargs = dict() + generate_templates_kwargs = dict(ms_before=1.5, ms_after=3) + generate_sorting_kwargs = dict(firing_rates=1) + generate_noise_kwargs = dict() + seed = 42 + + # Generate a inter-session displacement recording with no displacement + no_shift_recording, _ = generate_session_displacement_recordings( + num_units=num_units, + recording_durations=[duration], + recording_shifts=((0, 0),), + sampling_frequency=sampling_frequency, + probe_name=probe_name, + generate_probe_kwargs=generate_probe_kwargs, + generate_unit_locations_kwargs=generate_unit_locations_kwargs, + generate_templates_kwargs=generate_templates_kwargs, + generate_sorting_kwargs=generate_sorting_kwargs, + generate_noise_kwargs=generate_noise_kwargs, + seed=seed, + ) + no_shift_recording = no_shift_recording[0] + + # Generate a drifting recording with no drift + static_recording, _, _ = generate_drifting_recording( + num_units=num_units, + duration=duration, + sampling_frequency=sampling_frequency, + probe_name=probe_name, + generate_probe_kwargs=generate_probe_kwargs, + generate_unit_locations_kwargs=generate_unit_locations_kwargs, + generate_templates_kwargs=generate_templates_kwargs, + generate_sorting_kwargs=generate_sorting_kwargs, + generate_noise_kwargs=generate_noise_kwargs, + generate_displacement_vector_kwargs=dict( + motion_list=[ + dict( + drift_mode="zigzag", + non_rigid_gradient=None, + t_start_drift=1.0, + t_end_drift=None, + period_s=200, + ), + ] + ), + seed=seed, + ) + + # Check the templates and raw data match exactly. + assert np.array_equal( + no_shift_recording.get_traces(start_frame=0, end_frame=10), + static_recording.get_traces(start_frame=0, end_frame=10), + ) + + assert np.array_equal(no_shift_recording.templates, static_recording.drifting_templates.templates_array)