Skip to content

Commit

Permalink
Merge pull request #116 from alejoe91/fix_scipy
Browse files Browse the repository at this point in the history
Remove decimate for instability
  • Loading branch information
alejoe91 authored Jul 18, 2022
2 parents 8769f5f + 065e32b commit 7df5231
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 16 deletions.
4 changes: 2 additions & 2 deletions MEArec/generators/recgensteps.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
Important:
When tmp_mode=='memmap' : theses functions must assign and add directly the buffer.
When tmp_mode is Noe : theses functions return the buffer and the assignament is done externally.
When tmp_mode=='memmap' : these functions must assign and add directly the buffer.
When tmp_mode is None : these functions return the buffer and the assignament is done externally.
Expand Down
12 changes: 7 additions & 5 deletions MEArec/generators/recordinggenerator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# don't enter here without a good guide! (only one person in the world)

from distutils.log import DEBUG
import numpy as np
import time
from copy import deepcopy
Expand Down Expand Up @@ -34,9 +35,9 @@
use_loader = False


debug = True
DEBUG = False

if debug:
if DEBUG:
import matplotlib.pyplot as plt
plt.ion()
plt.show()
Expand Down Expand Up @@ -315,11 +316,14 @@ def generate_recordings(self, tmp_mode=None, tmp_folder=None, n_jobs=0, template
else:
params['recordings']['dtype'] = rec_params['dtype']
dtype = params['recordings']['dtype']

assert np.dtype(dtype).kind in ("i", "f"), "Only integers and float dtypes are supported"

params['recordings']['adc_bit_depth'] = rec_params.get('adc_bit_depth', None)
adc_bit_depth = params['recordings']['adc_bit_depth']
params['recordings']['lsb'] = rec_params.get('lsb', None)
lsb = params['recordings']['lsb']
if lsb is None:
if lsb is None and np.dtype(dtype).kind == "i":
lsb = 1
params['recordings']['gain'] = rec_params.get('gain', None)
gain = params['recordings']['gain']
Expand Down Expand Up @@ -1380,8 +1384,6 @@ def run_several_chunks(func, chunk_indexes, fs, lsb, args, n_jobs, tmp_mode, ass
or in paralell if n_jobs>1
The function can return
"""

# create task list
Expand Down
21 changes: 14 additions & 7 deletions MEArec/tests/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
else:
use_loader = False

LOCAL_TMP = True
DEBUG = False

if DEBUG:
import matplotlib.pyplot as plt
plt.ion()
plt.show()


class TestGenerators(unittest.TestCase):
Expand All @@ -36,7 +41,7 @@ def setUpClass(self):
# Set seed
np.random.seed(2308)

if not LOCAL_TMP:
if not DEBUG:
self.test_dir = Path(tempfile.mkdtemp())
else:
self.test_dir = Path('./tmp').absolute()
Expand Down Expand Up @@ -100,7 +105,7 @@ def setUpClass(self):
@classmethod
def tearDownClass(self):
# Remove the directory after the test
if not LOCAL_TMP:
if not DEBUG:
shutil.rmtree(self.test_dir)

def test_gen_templates(self):
Expand Down Expand Up @@ -693,28 +698,30 @@ def test_recordings_backend(self):
rec_params['spiketrains']['n_exc'] = ne
rec_params['spiketrains']['n_inh'] = ni
rec_params['spiketrains']['duration'] = duration
n_jitter = 2
n_jitter = 10
rec_params['templates']['n_jitters'] = n_jitter
rec_params['recordings']['modulation'] = 'none'
rec_params['recordings']['filter'] = False


rec_params['seeds']['templates'] = 0
rec_params['seeds']['spiketrains'] = 0
rec_params['seeds']['convolution'] = 0
rec_params['seeds']['noise'] = 0


n_jobs = [1, 2]
chunk_durations = [0, 1]

for n in n_jobs:
for ch in chunk_durations:
print('Test recording backend with', n, 'jobs - chunk', ch)
rec_params['chunk_duration'] = n
rec_params['recordings']['chunk_duration'] = ch

recgen_memmap = mr.gen_recordings(params=rec_params, tempgen=self.tempgen, tmp_mode='memmap',
verbose=False, n_jobs=n)
recgen_np = mr.gen_recordings(params=rec_params, tempgen=self.tempgen, tmp_mode=None, verbose=False,
n_jobs=n)

assert np.allclose(np.array(recgen_np.recordings), recgen_memmap.recordings.copy(), atol=1e-4)
del recgen_memmap, recgen_np

Expand Down Expand Up @@ -955,4 +962,4 @@ def test_simulate_cell(self):
TestGenerators().setUpClass()
# TestGenerators().test_gen_recordings_drift()
# TestGenerators().test_default_params()
TestGenerators().test_recording_custom_drifts()
TestGenerators().test_recordings_backend()
2 changes: 1 addition & 1 deletion MEArec/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3457,7 +3457,7 @@ def _jitter_parallel(i, template, upsample, fs, n_jitters, jitter, drifting, ver
t_jitt = np.pad(temp_up, [(0, 0), (0, np.abs(shift))], 'constant')[:, -nsamples_up:]
else:
t_jitt = temp_up
temp_down = ss.decimate(t_jitt, upsample, axis=1)
temp_down = t_jitt[:, ::upsample]
templates_jitter[n] = temp_down
else:
if verbose:
Expand Down
2 changes: 1 addition & 1 deletion MEArec/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version = '1.8.0.dev0'
version = '1.8.0'

0 comments on commit 7df5231

Please sign in to comment.