Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Finally Use Mffpy for read_raw_egi #12981

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
327 changes: 146 additions & 181 deletions mne/io/egi/egimff.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
"""EGI NetStation Load Function."""

import datetime
import fnmatch
import itertools
import math
import os.path as op
import re
Expand All @@ -14,15 +16,14 @@
import numpy as np

from ..._fiff.constants import FIFF
from ..._fiff.meas_info import _empty_info, _ensure_meas_date_none_or_dt, create_info
from ..._fiff.meas_info import create_info
from ..._fiff.proj import setup_proj
from ..._fiff.utils import _create_chs, _mult_cal_one
from ...annotations import Annotations
from ...channels.montage import make_dig_montage
from ...evoked import EvokedArray
from ...utils import _check_fname, _check_option, _soft_import, logger, verbose, warn
from ..base import BaseRaw
from .events import _combine_triggers, _read_events, _triage_include_exclude
from .general import (
_block_r,
_extract,
Expand All @@ -35,6 +36,134 @@
REFERENCE_NAMES = ("VREF", "Vertex Reference")


# TODO: Running list
# - [ ] Add support for reading in the PNS data
# - [ ] Add tutorial for reading calibration data
# - [ ] Add support for reading in the channel status (bad channels)
# - [ ] Replace _read_header with mffpy functions?


def _read_mff(input_fname):
"""Read EGI MFF file."""
mff_reader = _get_mff_reader(input_fname)
eeg = _get_eeg_data(mff_reader)
info = _get_info(mff_reader)
annotations = _get_annotations(mff_reader, info)
return eeg, info, annotations


def _get_mff_reader(input_fname):
mffpy = _import_mffpy()
mff_reader = mffpy.Reader(input_fname)
mff_reader.set_unit("EEG", "V") # XXX: set PNS unit
return mff_reader


def _get_montage(mff_reader):
mffpy = _import_mffpy()
xml_files = mff_reader.directory.files_by_type[".xml"]
sensor_fname = fnmatch.filter(xml_files, "sensorLayout")
assert len(sensor_fname) == 1 # XXX: remove
sensor_fname = sensor_fname[0]
with mff_reader.directory.filepointer(sensor_fname) as fp:
sensor_layout = mffpy.XML.from_file(fp).get_content()["sensors"]
n_eeg_channels = mff_reader.num_channels["EEG"] # XXX: PNS?
ch_pos = dict()
hsp = list()
for ch in sensor_layout.values():
# XXX: the y coordinate seems to be inverted? Need to investigate
loc = np.array([ch["x"], -(ch["y"]), ch["z"]]) / 1000
if ch["number"] <= n_eeg_channels:
assert ch["type"] in [0, 1] # XXX: remove
name = f"E{ch['number']}" if ch["name"] == "None" else ch["name"]
ch_pos[name] = loc
elif ch["type"] == 2: # type 2 seems to be headshape points or COM..
if ch["name"] == "COM":
continue
hsp.append(loc)
# XXX: this is still wonky. MNE will complain that the head radius is unusually big
montage = make_dig_montage(ch_pos=ch_pos, coord_frame="head", hsp=hsp)
return montage


def _get_info(mff_reader):
montage = _get_montage(mff_reader)
ch_names = montage.ch_names
ch_types = ["eeg"] * len(ch_names) # XXX: refactor this when adding PNS support
meas_date = mff_reader.startdatetime.astimezone(datetime.timezone.utc)
sfreq = mff_reader.sampling_rates["EEG"] # XXX: check PNS sfreq?
info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
info.set_montage(montage)
info.set_meas_date(meas_date)
return info


def _get_eeg_data(mff_reader):
sfreq = mff_reader.sampling_rates["EEG"] # XXX: check PNS sfreq
n_channels = np.sum(list(mff_reader.num_channels.values()))
epochs = mff_reader.epochs

data_blocks, start_secs, end_secs = [], [], []
for epoch in epochs:
data_chunk, _ = mff_reader.get_physical_samples_from_epoch(epoch)["EEG"] # XXX
data_blocks.append(data_chunk)
start_secs.append(epoch.t0)
end_secs.append(epoch.t1)

first_samp = int(start_secs[0] * sfreq)
last_samp = int(end_secs[-1] * sfreq)
interval = (1 / sfreq) * 1000
all_samps = np.arange(first_samp, last_samp + 1, interval)
eeg = np.zeros((n_channels, len(all_samps)), dtype=np.float64)
for this_chunk, start, end in zip(data_blocks, start_secs, end_secs):
start = int(start * sfreq)
end = int(end * sfreq)
eeg[:, start:end] = this_chunk
return eeg


def _get_gap_annotations(mff_reader):
epochs = mff_reader.epochs
start_secs = [epoch.t0 for epoch in epochs]
end_secs = [epoch.t1 for epoch in epochs]
gap_durations = np.array(start_secs[1:]) - np.array(end_secs[:-1])
descriptions = "BAD_ACQ_SKIP" * len(gap_durations)
gap_onsets = np.array(end_secs[:-1])
gap_annots = Annotations(gap_onsets, gap_durations, descriptions)
return gap_annots


def _get_event_annotations(mff_reader, mne_info):
mffpy = _import_mffpy()
xml_files = mff_reader.directory.files_by_type[".xml"]
events_xmls = fnmatch.filter(xml_files, "Events*")
if not events_xmls:
raise RuntimeError("No events found in MFF file.")
mff_events = {}
for event_file in events_xmls:
with mff_reader.directory.filepointer(event_file) as fp:
categories = mffpy.XML.from_file(fp)
mff_events[event_file] = categories.get_content()["event"]
onsets = []
descriptions = []
mff_events = list(itertools.chain.from_iterable(mff_events.values()))
for event in mff_events:
onset_dt = event["beginTime"].astimezone(datetime.timezone.utc)
ts = (onset_dt - mne_info["meas_date"]).total_seconds()
onsets.append(ts)
# XXX: we could use event["duration"] but it always seems to be 1000ms?
descriptions.append(event["code"])
durations = [0] * len(onsets)
event_annots = Annotations(onsets, durations, descriptions)
return event_annots


def _get_annotations(mff_reader, mne_info):
event_annots = _get_event_annotations(mff_reader, mne_info)
gap_annots = _get_gap_annotations(mff_reader)
return event_annots + gap_annots


def _read_mff_header(filepath):
"""Read mff header."""
_soft_import("defusedxml", "reading EGI MFF data")
Expand Down Expand Up @@ -380,14 +509,14 @@ class RawMff(BaseRaw):
def __init__(
self,
input_fname,
eog=None,
misc=None,
include=None,
exclude=None,
preload=False,
channel_naming="E%d",
eog=None, # XXX: allow user to specify EOG channels?
misc=None, # XXX: allow user to specify misc channels?
include=None, # XXX: Now We dont create stim channels. Remove this?
exclude=None, # XXX: Ditto. But maybe we can exclude events from annots.
preload=False, # XXX: Make this work again
channel_naming="E%d", # XXX: Do we need to still support this?
*,
events_as_annotations=True,
events_as_annotations=True, # XXX: This is now the only way. Remove?
verbose=None,
):
"""Init the RawMff class."""
Expand All @@ -401,183 +530,19 @@ def __init__(
)
)
logger.info(f"Reading EGI MFF Header from {input_fname}...")
egi_info = _read_header(input_fname)
if eog is None:
eog = []
if misc is None:
misc = np.where(np.array(egi_info["chan_type"]) != "eeg")[0].tolist()

logger.info(" Reading events ...")
egi_events, egi_info, mff_events = _read_events(input_fname, egi_info)
cals = _get_eeg_calibration_info(input_fname, egi_info)
logger.info(" Assembling measurement info ...")
event_codes = egi_info["event_codes"]
include = _triage_include_exclude(include, exclude, egi_events, egi_info)
if egi_info["n_events"] > 0 and not events_as_annotations:
logger.info(' Synthesizing trigger channel "STI 014" ...')
if all(ch.startswith("D") for ch in include):
# support the DIN format DIN1, DIN2, ..., DIN9, DI10, DI11, ... DI99,
# D100, D101, ..., D255 that we get when sending 0-255 triggers on a
# parallel port.
events_ids = list()
for ch in include:
while not ch[0].isnumeric():
ch = ch[1:]
events_ids.append(int(ch))
else:
events_ids = np.arange(len(include)) + 1
egi_info["new_trigger"] = _combine_triggers(
egi_events[[c in include for c in event_codes]], remapping=events_ids
)
self.event_id = dict(
zip([e for e in event_codes if e in include], events_ids)
)
if egi_info["new_trigger"] is not None:
egi_events = np.vstack([egi_events, egi_info["new_trigger"]])
else:
self.event_id = None
egi_info["new_trigger"] = None
assert egi_events.shape[1] == egi_info["last_samps"][-1]

meas_dt_utc = egi_info["meas_dt_local"].astimezone(datetime.timezone.utc)
info = _empty_info(egi_info["sfreq"])
info["meas_date"] = _ensure_meas_date_none_or_dt(meas_dt_utc)
info["utc_offset"] = egi_info["utc_offset"]
info["device_info"] = dict(type=egi_info["device"])

# read in the montage, if it exists
ch_names, mon = _read_locs(input_fname, egi_info, channel_naming)
# Second: Stim
ch_names.extend(list(egi_info["event_codes"]))
n_extra = len(event_codes) + len(misc) + len(eog) + len(egi_info["pns_names"])
if egi_info["new_trigger"] is not None:
ch_names.append("STI 014") # channel for combined events
n_extra += 1

# Third: PNS
ch_names.extend(egi_info["pns_names"])

cals = np.concatenate([cals, np.ones(n_extra)])
assert len(cals) == len(ch_names), (len(cals), len(ch_names))

# Actually create channels as EEG, then update stim and PNS
ch_coil = FIFF.FIFFV_COIL_EEG
ch_kind = FIFF.FIFFV_EEG_CH
chs = _create_chs(ch_names, cals, ch_coil, ch_kind, eog, (), (), misc)

sti_ch_idx = [
i
for i, name in enumerate(ch_names)
if name.startswith("STI") or name in event_codes
]
for idx in sti_ch_idx:
chs[idx].update(
{
"unit_mul": FIFF.FIFF_UNITM_NONE,
"cal": cals[idx],
"kind": FIFF.FIFFV_STIM_CH,
"coil_type": FIFF.FIFFV_COIL_NONE,
"unit": FIFF.FIFF_UNIT_NONE,
}
)
chs = _add_pns_channel_info(chs, egi_info, ch_names)
info["chs"] = chs
info._unlocked = False
info._update_redundant()

if mon is not None:
info.set_montage(mon, on_missing="ignore")

ref_idx = np.flatnonzero(np.isin(mon.ch_names, REFERENCE_NAMES))
if len(ref_idx):
ref_idx = ref_idx.item()
ref_coords = info["chs"][int(ref_idx)]["loc"][:3]
for chan in info["chs"]:
if chan["kind"] == FIFF.FIFFV_EEG_CH:
chan["loc"][3:6] = ref_coords

file_bin = op.join(input_fname, egi_info["eeg_fname"])
egi_info["egi_events"] = egi_events

# Check how many channels to read are from EEG
keys = ("eeg", "sti", "pns")
idx = dict()
idx["eeg"] = np.where([ch["kind"] == FIFF.FIFFV_EEG_CH for ch in chs])[0]
idx["sti"] = np.where([ch["kind"] == FIFF.FIFFV_STIM_CH for ch in chs])[0]
idx["pns"] = np.where(
[
ch["kind"] in (FIFF.FIFFV_ECG_CH, FIFF.FIFFV_EMG_CH, FIFF.FIFFV_BIO_CH)
for ch in chs
]
)[0]
# By construction this should always be true, but check anyway
if not np.array_equal(
np.concatenate([idx[key] for key in keys]), np.arange(len(chs))
):
raise ValueError(
"Currently interlacing EEG and PNS channels is not supported"
)
egi_info["kind_bounds"] = [0]
for key in keys:
egi_info["kind_bounds"].append(len(idx[key]))
egi_info["kind_bounds"] = np.cumsum(egi_info["kind_bounds"])
assert egi_info["kind_bounds"][0] == 0
assert egi_info["kind_bounds"][-1] == info["nchan"]
first_samps = [0]
last_samps = [egi_info["last_samps"][-1] - 1]

annot = dict(onset=list(), duration=list(), description=list())

if len(idx["pns"]):
# PNS Data is present and should be read:
egi_info["pns_filepath"] = op.join(input_fname, egi_info["pns_fname"])
# Check for PNS bug immediately
pns_samples = np.sum(egi_info["pns_sample_blocks"]["samples_block"])
eeg_samples = np.sum(egi_info["samples_block"])
if pns_samples == eeg_samples - 1:
warn("This file has the EGI PSG sample bug")
annot["onset"].append(last_samps[-1] / egi_info["sfreq"])
annot["duration"].append(1 / egi_info["sfreq"])
annot["description"].append("BAD_EGI_PSG")
elif pns_samples != eeg_samples:
raise RuntimeError(
"PNS samples (%d) did not match EEG samples (%d)"
% (pns_samples, eeg_samples)
)
eeg, info, annots = _read_mff(input_fname)

super().__init__(
info,
preload=preload,
orig_format="single",
filenames=[file_bin],
first_samps=first_samps,
last_samps=last_samps,
raw_extras=[egi_info],
preload=eeg, # XXX: Make eager/lazy loading work again
orig_format="single", # XXX: Check if this is still correct
filenames=[input_fname], # XXX: multiple files? I need an example
first_samps=(0,), # XXX: multiple files?
last_samps=None, # XXX: multiple files?
raw_extras=(None,), # XXX: do we still need this?
verbose=verbose,
)

# Annotate acquisition skips
for first, prev_last in zip(
egi_info["first_samps"][1:], egi_info["last_samps"][:-1]
):
gap = first - prev_last
assert gap >= 0
if gap:
annot["onset"].append((prev_last - 0.5) / egi_info["sfreq"])
annot["duration"].append(gap / egi_info["sfreq"])
annot["description"].append("BAD_ACQ_SKIP")

# create events from annotations
if events_as_annotations:
for code, samples in mff_events.items():
if code not in include:
continue
annot["onset"].extend(np.array(samples) / egi_info["sfreq"])
annot["duration"].extend([0.0] * len(samples))
annot["description"].extend([code] * len(samples))

if len(annot["onset"]):
self.set_annotations(Annotations(**annot))
self.set_annotations(annots)

def _read_segment_file(self, data, idx, fi, start, stop, cals, mult):
"""Read a chunk of data."""
Expand Down