-
Notifications
You must be signed in to change notification settings - Fork 190
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add session displacement generation #3231
base: main
Are you sure you want to change the base?
Changes from all commits
f94672b
70307a9
d37c4e8
132dbe4
51c6c8b
30f6efb
5e0eac0
f5fdf08
e996dee
faca0e5
dfe5ac9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ | |
|
||
""" | ||
|
||
from __future__ import annotations | ||
import numpy as np | ||
|
||
from probeinterface import generate_multi_columns_probe | ||
|
@@ -21,6 +22,7 @@ | |
) | ||
from .drift_tools import DriftingTemplates, make_linear_displacement, InjectDriftingTemplatesRecording | ||
from .noise_tools import generate_noise | ||
from probeinterface import Probe | ||
|
||
|
||
# this should be moved in probeinterface but later | ||
|
@@ -181,7 +183,7 @@ def generate_displacement_vector( | |
duration : float | ||
Duration of the displacement vector in seconds | ||
unit_locations : np.array | ||
The unit location with shape (num_units, 3) | ||
The unit location with shape (num_units, 2) | ||
displacement_sampling_frequency : float, default: 5. | ||
The sampling frequency of the displacement vector | ||
drift_start_um : list of float, default: [0, 20.] | ||
|
@@ -238,22 +240,70 @@ def generate_displacement_vector( | |
if non_rigid_gradient is None: | ||
displacement_unit_factor[:, m] = 1 | ||
else: | ||
gradient_direction = drift_stop_um - drift_start_um | ||
gradient_direction /= np.linalg.norm(gradient_direction) | ||
|
||
proj = np.dot(unit_locations, gradient_direction).squeeze() | ||
factors = (proj - np.min(proj)) / (np.max(proj) - np.min(proj)) | ||
if non_rigid_gradient < 0: | ||
# reverse | ||
factors = 1 - factors | ||
f = np.abs(non_rigid_gradient) | ||
displacement_unit_factor[:, m] = factors * (1 - f) + f | ||
displacement_unit_factor[:, m] = calculate_displacement_unit_factor( | ||
non_rigid_gradient, unit_locations, drift_start_um, drift_stop_um | ||
) | ||
|
||
displacement_vectors = np.concatenate(displacement_vectors, axis=2) | ||
|
||
return displacement_vectors, displacement_unit_factor, displacement_sampling_frequency, displacements_steps | ||
|
||
|
||
def calculate_displacement_unit_factor( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe this could be called something like "simulate_linear_gradient_drift"? i was a bit confused reading it, but it seems to be generating drift which is 0 at the top of the probe and something not zero at the bottom? maybe someone can help explain what exactly
ends up producing... is it like there is some global drift plus per-unit linear drift? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hey @cwindolf thanks a lot for this review. This function is a refactoring of this code. I like I agree I got quite confused the first (few) times looking through the use of The
0 * 0.2 + 0.8= 0.8
0 * 0.8 + 0.2 = 0.2 So in the first case, the scaling of the units is only in the range I re-wrote the docstring of the function, let me know if its any clearer, I think there is still room for improvement, I am also not sure how much depth to go into. |
||
non_rigid_gradient: float, unit_locations: np.array, drift_start_um: np.array, drift_stop_um: np.array | ||
) -> np.array: | ||
""" | ||
Introduces a non-rigid drift across the probe, this is a linear | ||
scaling of the displacement based on the unit position. | ||
|
||
To introduce non-rigid drift, a set of scaling factors (one per unit) | ||
are generated. These scale the displacement applied to each unit | ||
as a function of unit position. The smaller the `non_rigid_gradient`, | ||
the larger the influence of the unit position is on scaling the | ||
displacement (more non-linearity). | ||
|
||
The projections of the gradient vector (x, y) | ||
and unit locations (x, y) are normalised to range between | ||
0 and 1 (i.e. based on relative location to the gradient). | ||
|
||
Parameters | ||
---------- | ||
|
||
non_rigid_gradient : float | ||
A number in the range [0, 1] by which to scale the scaling factors | ||
that are based on unit location. This sets the weighting given to the factors | ||
based on unit locations. When 1, the factors will all equal 1 (no effect), | ||
when 0, the scaling factor based on unit location will be used directly. | ||
Smaller number results in more nonlinearity. | ||
unit_locations : np.array | ||
The unit location with shape (num_units, 2) | ||
drift_start_um : np.array | ||
The start boundary of the motion in the x and y direction. | ||
drift_stop_um : np.array | ||
The stop boundary of the motion in the x and y direction. | ||
|
||
Returns | ||
------- | ||
displacement_unit_factor : np.array | ||
An array of scaling factors (one per unit) by which | ||
to scale the displacement. | ||
""" | ||
gradient_direction = drift_stop_um - drift_start_um | ||
gradient_direction /= np.linalg.norm(gradient_direction) | ||
|
||
proj = np.dot(unit_locations, gradient_direction).squeeze() | ||
factors = (proj - np.min(proj)) / (np.max(proj) - np.min(proj)) | ||
|
||
if non_rigid_gradient < 0: # reverse | ||
factors = 1 - factors | ||
|
||
f = np.abs(non_rigid_gradient) | ||
|
||
displacement_unit_factor = factors * (1 - f) + f | ||
|
||
return displacement_unit_factor | ||
|
||
|
||
def generate_drifting_recording( | ||
num_units=250, | ||
duration=600.0, | ||
|
@@ -352,12 +402,9 @@ def generate_drifting_recording( | |
rng = np.random.default_rng(seed=seed) | ||
|
||
# probe | ||
if generate_probe_kwargs is None: | ||
generate_probe_kwargs = _toy_probes[probe_name] | ||
probe = generate_multi_columns_probe(**generate_probe_kwargs) | ||
num_channels = probe.get_contact_count() | ||
probe.set_device_channel_indices(np.arange(num_channels)) | ||
probe = generate_probe(generate_probe_kwargs, probe_name) | ||
channel_locations = probe.contact_positions | ||
|
||
# import matplotlib.pyplot as plt | ||
# import probeinterface.plotting | ||
# fig, ax = plt.subplots() | ||
|
@@ -385,9 +432,7 @@ def generate_drifting_recording( | |
unit_displacements[:, :, direction] += m | ||
|
||
# unit_params need to be fixed before the displacement steps | ||
generate_templates_kwargs = generate_templates_kwargs.copy() | ||
unit_params = _ensure_unit_params(generate_templates_kwargs.get("unit_params", {}), num_units, seed) | ||
generate_templates_kwargs["unit_params"] = unit_params | ||
generate_templates_kwargs = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed) | ||
|
||
# generate templates | ||
templates_array = generate_templates( | ||
|
@@ -479,3 +524,50 @@ def generate_drifting_recording( | |
return static_recording, drifting_recording, sorting, extra_infos | ||
else: | ||
return static_recording, drifting_recording, sorting | ||
|
||
|
||
def generate_probe(generate_probe_kwargs: dict, probe_name: str | None = None) -> Probe: | ||
""" | ||
Generate a probe for use in certain ground-truth recordings. | ||
|
||
Parameters | ||
---------- | ||
|
||
generate_probe_kwargs : dict | ||
The kwargs to pass to `generate_multi_columns_probe()` | ||
probe_name : str | None | ||
The probe type if generate_probe_kwargs is None. | ||
""" | ||
if generate_probe_kwargs is None: | ||
assert probe_name is not None, "`probe_name` must be set if `generate_probe_kwargs` is `None`." | ||
generate_probe_kwargs = _toy_probes[probe_name] | ||
probe = generate_multi_columns_probe(**generate_probe_kwargs) | ||
num_channels = probe.get_contact_count() | ||
probe.set_device_channel_indices(np.arange(num_channels)) | ||
|
||
return probe | ||
|
||
|
||
def fix_generate_templates_kwargs(generate_templates_kwargs: dict, num_units: int, seed: int) -> dict: | ||
""" | ||
Fix the generate_template_kwargs such that the same units are created | ||
across calls to `generate_template`. We must explicitly pre-set | ||
the parameters for each unit, done in `_ensure_unit_params()`. | ||
|
||
Parameters | ||
---------- | ||
|
||
generate_templates_kwargs : dict | ||
These kwargs will have the "unit_params" entry edited such that the | ||
parameters are explicitly set for each unit to create (rather than | ||
generated randomly on the fly). | ||
num_units : int | ||
Number of units to fix the kwargs for | ||
seed : int | ||
Random seed. | ||
""" | ||
generate_templates_kwargs = generate_templates_kwargs.copy() | ||
unit_params = _ensure_unit_params(generate_templates_kwargs.get("unit_params", {}), num_units, seed) | ||
generate_templates_kwargs["unit_params"] = unit_params | ||
|
||
return generate_templates_kwargs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure it's worth changing the output of this public function? Alternatively, you could add
at line 154 and keep it as before.
I know @h-mayorquin uses this function.