Skip to content

Commit

Permalink
add test and fix
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin committed Dec 10, 2024
1 parent dce4ffd commit 6794957
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 17 deletions.
24 changes: 7 additions & 17 deletions neo/rawio/spikeglxrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,32 +232,22 @@ def _parse_header(self):

self._t_starts = {stream_name: {} for stream_name in stream_names}
self._t_stops = {seg_index: 0.0 for seg_index in range(nb_segment)}
self._t_since_recording_time = {stream_name: {} for stream_name in stream_names}
self._session_start_time = {stream_name: {} for stream_name in stream_names}
self._gate_trigger = {stream_name: {} for stream_name in stream_names}
for seg_index in range(nb_segment):
for stream_name in stream_names:

for stream_name in stream_names:
for seg_index in range(nb_segment):
info = self.signals_info_dict[seg_index, stream_name]

frame_start = float(info["meta"]["firstSample"])
sampling_frequency = info["sampling_rate"]
t_start = frame_start / sampling_frequency

initial_date_time = info["meta"]["fileCreateTime"]
from datetime import datetime
initial_date_time_parsed = datetime.strptime(initial_date_time, "%Y-%m-%dT%H:%M:%S")
initial_timestamp = initial_date_time_parsed.timestamp()
shifted_timestamps = t_start + initial_timestamp


self._t_starts[stream_name][seg_index] = t_start
self._t_since_recording_time[stream_name][seg_index] = (initial_timestamp, shifted_timestamps)
gate_num = info["gate_num"]
trigger_num = info["trigger_num"]
self._gate_trigger[stream_name][seg_index] = f"{gate_num=}, {trigger_num} "
self._session_start_time[stream_name][seg_index] = initial_date_time_parsed
t_stop = info["sample_length"] / info["sampling_rate"]
self._t_stops[seg_index] = max(self._t_stops[seg_index], t_stop)




# fille into header dict
self.header = {}
self.header["nb_block"] = 1
Expand Down
59 changes: 59 additions & 0 deletions neo/test/rawiotest/test_spikeglxrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,65 @@ def test_nidq_digital_channel(self):
atol = 0.001
assert np.allclose(on_diff, 1, atol=atol)

def test_t_start_reading(self):
"""Test that t_start values are correctly read for all streams and segments."""

# Expected t_start values for each stream and segment
expected_t_starts = {
'imec0.ap': {
0: 15.319535472007237,
1: 15.339535431281986,
2: 21.284723325294053,
3: 21.3047232845688
},
'imec1.ap': {
0: 15.319554693264516,
1: 15.339521518106308,
2: 21.284735282142822,
3: 21.304702106984614
},
'imec0.lf': {
0: 15.3191688060872,
1: 15.339168765361949,
2: 21.284356659374016,
3: 21.304356618648765
},
'imec1.lf': {
0: 15.319321358082725,
1: 15.339321516521915,
2: 21.284568614155827,
3: 21.30456877259502
}
}

# Initialize the RawIO
rawio = SpikeGLXRawIO(self.get_local_path("spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI4"))
rawio.parse_header()

# Get list of stream names
stream_names = rawio.header["signal_streams"]["name"]

# Test t_start for each stream and segment
for stream_name, expected_values in expected_t_starts.items():
# Get stream index
stream_index = list(stream_names).index(stream_name)

# Check each segment
for seg_index, expected_t_start in expected_values.items():
actual_t_start = rawio.get_signal_t_start(
block_index=0,
seg_index=seg_index,
stream_index=stream_index
)

# Use numpy.testing for proper float comparison
np.testing.assert_allclose(
actual_t_start,
expected_t_start,
rtol=1e-9,
atol=1e-9,
err_msg=f"Mismatch in t_start for stream '{stream_name}', segment {seg_index}"
)

if __name__ == "__main__":
unittest.main()

0 comments on commit 6794957

Please sign in to comment.