Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add session displacement generation #3231

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
33 changes: 25 additions & 8 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def generate_sorting(
add_spikes_on_borders=False,
num_spikes_per_border=3,
border_size_samples=20,
extra_outputs=False,
seed=None,
):
"""
Expand Down Expand Up @@ -135,10 +136,14 @@ def generate_sorting(
num_segments = len(durations)
unit_ids = np.arange(num_units)

extra_outputs_dict = {
"firing_rates": [],
}

spikes = []
for segment_index in range(num_segments):
num_samples = int(sampling_frequency * durations[segment_index])
samples, labels = synthesize_poisson_spike_vector(
samples, labels, firing_rates_array = synthesize_poisson_spike_vector(
num_units=num_units,
sampling_frequency=sampling_frequency,
duration=durations[segment_index],
Expand Down Expand Up @@ -172,12 +177,17 @@ def generate_sorting(
)
spikes.append(spikes_on_borders)

extra_outputs_dict["firing_rates"].append(firing_rates_array)

spikes = np.concatenate(spikes)
spikes = spikes[np.lexsort((spikes["sample_index"], spikes["segment_index"]))]

sorting = NumpySorting(spikes, sampling_frequency, unit_ids)

return sorting
if extra_outputs:
return sorting, extra_outputs_dict
else:
return sorting


def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None):
Expand Down Expand Up @@ -776,7 +786,7 @@ def synthesize_poisson_spike_vector(
unit_indices = unit_indices[sort_indices]
spike_frames = spike_frames[sort_indices]

return spike_frames, unit_indices
return spike_frames, unit_indices, firing_rates
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure it's worth changing the output of this public function? Alternatively, you could add

firing_rates_array = _ensure_firing_rates(firing_rates, num_units, seed)

at line 154 and keep it as before.

I know @h-mayorquin uses this function.



def synthesize_random_firings(
Expand Down Expand Up @@ -2188,12 +2198,19 @@ def generate_ground_truth_recording(
parent_recording=noise_rec,
upsample_vector=upsample_vector,
)
recording.annotate(is_filtered=True)
recording.set_probe(probe, in_place=True)
recording.set_channel_gains(1.0)
recording.set_channel_offsets(0.0)

setup_inject_templates_recording(recording, probe)
recording.name = "GroundTruthRecording"
sorting.name = "GroundTruthSorting"

return recording, sorting


def setup_inject_templates_recording(recording: BaseRecording, probe: Probe) -> None:
"""
Convenience function to modify a generated
recording in-place with annotation and probe details
"""
recording.annotate(is_filtered=True)
recording.set_probe(probe, in_place=True)
recording.set_channel_gains(1.0)
recording.set_channel_offsets(0.0)
130 changes: 111 additions & 19 deletions src/spikeinterface/generation/drifting_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

"""

from __future__ import annotations
import numpy as np

from probeinterface import generate_multi_columns_probe
Expand All @@ -21,6 +22,7 @@
)
from .drift_tools import DriftingTemplates, make_linear_displacement, InjectDriftingTemplatesRecording
from .noise_tools import generate_noise
from probeinterface import Probe


# this should be moved in probeinterface but later
Expand Down Expand Up @@ -181,7 +183,7 @@ def generate_displacement_vector(
duration : float
Duration of the displacement vector in seconds
unit_locations : np.array
The unit location with shape (num_units, 3)
The unit location with shape (num_units, 2)
displacement_sampling_frequency : float, default: 5.
The sampling frequency of the displacement vector
drift_start_um : list of float, default: [0, 20.]
Expand Down Expand Up @@ -238,22 +240,70 @@ def generate_displacement_vector(
if non_rigid_gradient is None:
displacement_unit_factor[:, m] = 1
else:
gradient_direction = drift_stop_um - drift_start_um
gradient_direction /= np.linalg.norm(gradient_direction)

proj = np.dot(unit_locations, gradient_direction).squeeze()
factors = (proj - np.min(proj)) / (np.max(proj) - np.min(proj))
if non_rigid_gradient < 0:
# reverse
factors = 1 - factors
f = np.abs(non_rigid_gradient)
displacement_unit_factor[:, m] = factors * (1 - f) + f
displacement_unit_factor[:, m] = calculate_displacement_unit_factor(
non_rigid_gradient, unit_locations, drift_start_um, drift_stop_um
)

displacement_vectors = np.concatenate(displacement_vectors, axis=2)

return displacement_vectors, displacement_unit_factor, displacement_sampling_frequency, displacements_steps


def calculate_displacement_unit_factor(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe this could be called something like "simulate_linear_gradient_drift"? i was a bit confused reading it, but it seems to be generating drift which is 0 at the top of the probe and something not zero at the bottom?

maybe someone can help explain what exactly

displacement_unit_factor = factors * (1 - f) + f

ends up producing... is it like there is some global drift plus per-unit linear drift?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @cwindolf thanks a lot for this review. This function is a refactoring of this code. I like simulate_linear_gradient_drift, for this PR I will keep the current naming for consistency with the old code. However I'll make an issue based on some of the points you raise in this PR (e.g. including some things in the within-session drift) and add a note on this there.

I agree I got quite confused the first (few) times looking through the use of non_rigid_gradient. I think the easiest way to see it is with some example values. In the first part of the function, the dot-product of the displacement vector and the unit location (expressed as a vector from the probe origin, I am not sure where it is, maybe bottom-left). In the y-displacement only case, this is just the unit y position. These unit positions are scaled to [0, 1] and called factors.

The f and expression you show ensure that the 'largest' unit location (e.g. near the top of the probe if the origin is bottom left) is scaled by 1 (no change). The scaling is linear across all unit positions, its kind of like a linspace where the max value is 1 and f sets the min value. e.g. looking at the smallest, largest and a middle location unit (i.e. factors 0, 1, and 0.5)

non_rigid_gradient=0.8

0 * 0.2 + 0.8= 0.8
0.5 * 0.2 + 0.8 = 0.9
1 * 0.8 + 0.2 = 1

non_rigid_gradient=0.2

0 * 0.8 + 0.2 = 0.2
0.5 * 0.8 + 0.2 = 0.6
1 * 0.8 + 0.2 = 1

So in the first case, the scaling of the units is only in the range [0.8, 1] of the normalised position of the unit. But for the smaller non_rigid_gradient=0.2, the scaling is between [0.2, 1].

I re-wrote the docstring of the function, let me know if its any clearer, I think there is still room for improvement, I am also not sure how much depth to go into.

non_rigid_gradient: float, unit_locations: np.array, drift_start_um: np.array, drift_stop_um: np.array
) -> np.array:
"""
Introduces a non-rigid drift across the probe, this is a linear
scaling of the displacement based on the unit position.

To introduce non-rigid drift, a set of scaling factors (one per unit)
are generated. These scale the displacement applied to each unit
as a function of unit position. The smaller the `non_rigid_gradient`,
the larger the influence of the unit position is on scaling the
displacement (more non-linearity).

The projections of the gradient vector (x, y)
and unit locations (x, y) are normalised to range between
0 and 1 (i.e. based on relative location to the gradient).

Parameters
----------

non_rigid_gradient : float
A number in the range [0, 1] by which to scale the scaling factors
that are based on unit location. This sets the weighting given to the factors
based on unit locations. When 1, the factors will all equal 1 (no effect),
when 0, the scaling factor based on unit location will be used directly.
Smaller number results in more nonlinearity.
unit_locations : np.array
The unit location with shape (num_units, 2)
drift_start_um : np.array
The start boundary of the motion in the x and y direction.
drift_stop_um : np.array
The stop boundary of the motion in the x and y direction.

Returns
-------
displacement_unit_factor : np.array
An array of scaling factors (one per unit) by which
to scale the displacement.
"""
gradient_direction = drift_stop_um - drift_start_um
gradient_direction /= np.linalg.norm(gradient_direction)

proj = np.dot(unit_locations, gradient_direction).squeeze()
factors = (proj - np.min(proj)) / (np.max(proj) - np.min(proj))

if non_rigid_gradient < 0: # reverse
factors = 1 - factors

f = np.abs(non_rigid_gradient)

displacement_unit_factor = factors * (1 - f) + f

return displacement_unit_factor


def generate_drifting_recording(
num_units=250,
duration=600.0,
Expand Down Expand Up @@ -352,12 +402,9 @@ def generate_drifting_recording(
rng = np.random.default_rng(seed=seed)

# probe
if generate_probe_kwargs is None:
generate_probe_kwargs = _toy_probes[probe_name]
probe = generate_multi_columns_probe(**generate_probe_kwargs)
num_channels = probe.get_contact_count()
probe.set_device_channel_indices(np.arange(num_channels))
probe = generate_probe(generate_probe_kwargs, probe_name)
channel_locations = probe.contact_positions

# import matplotlib.pyplot as plt
# import probeinterface.plotting
# fig, ax = plt.subplots()
Expand Down Expand Up @@ -385,9 +432,7 @@ def generate_drifting_recording(
unit_displacements[:, :, direction] += m

# unit_params need to be fixed before the displacement steps
generate_templates_kwargs = generate_templates_kwargs.copy()
unit_params = _ensure_unit_params(generate_templates_kwargs.get("unit_params", {}), num_units, seed)
generate_templates_kwargs["unit_params"] = unit_params
generate_templates_kwargs = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed)

# generate templates
templates_array = generate_templates(
Expand Down Expand Up @@ -479,3 +524,50 @@ def generate_drifting_recording(
return static_recording, drifting_recording, sorting, extra_infos
else:
return static_recording, drifting_recording, sorting


def generate_probe(generate_probe_kwargs: dict, probe_name: str | None = None) -> Probe:
"""
Generate a probe for use in certain ground-truth recordings.

Parameters
----------

generate_probe_kwargs : dict
The kwargs to pass to `generate_multi_columns_probe()`
probe_name : str | None
The probe type if generate_probe_kwargs is None.
"""
if generate_probe_kwargs is None:
assert probe_name is not None, "`probe_name` must be set if `generate_probe_kwargs` is `None`."
generate_probe_kwargs = _toy_probes[probe_name]
probe = generate_multi_columns_probe(**generate_probe_kwargs)
num_channels = probe.get_contact_count()
probe.set_device_channel_indices(np.arange(num_channels))

return probe


def fix_generate_templates_kwargs(generate_templates_kwargs: dict, num_units: int, seed: int) -> dict:
"""
Fix the generate_template_kwargs such that the same units are created
across calls to `generate_template`. We must explicitly pre-set
the parameters for each unit, done in `_ensure_unit_params()`.

Parameters
----------

generate_templates_kwargs : dict
These kwargs will have the "unit_params" entry edited such that the
parameters are explicitly set for each unit to create (rather than
generated randomly on the fly).
num_units : int
Number of units to fix the kwargs for
seed : int
Random seed.
"""
generate_templates_kwargs = generate_templates_kwargs.copy()
unit_params = _ensure_unit_params(generate_templates_kwargs.get("unit_params", {}), num_units, seed)
generate_templates_kwargs["unit_params"] = unit_params

return generate_templates_kwargs
Loading
Loading