Skip to content

Commit

Permalink
Merge pull request #86 from jsosulski/allow_returning_of_epochs
Browse files: browse the repository at this point in the history
Allow retrieval of epochs instead of np.ndarray in process_raw
  • Loading branch information
Sylvain Chevallier authored Jun 24, 2020
2 parents 3bd566c + ad68920 commit 75a73d6
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 12 deletions.
33 changes: 23 additions & 10 deletions moabb/paradigms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def prepare_process(self, dataset):
"""
pass

def process_raw(self, raw, dataset):
def process_raw(self, raw, dataset, return_epochs=False):
"""
Process one raw data file.
Expand All @@ -83,10 +83,16 @@ def process_raw(self, raw, dataset):
The dataset corresponding to the raw file. Mainly used to access
dataset-specific information.
return_epochs: boolean
This flag specifies whether to return only the data array or the
complete processed mne.Epochs
returns
-------
X : np.ndarray
X : Union[np.ndarray, mne.Epochs]
the data that will be used as features for the model
Note: if return_epochs=True, this is mne.Epochs
if return_epochs=False, this is np.ndarray
labels: np.ndarray
the labels for training / evaluating the model
Expand Down Expand Up @@ -141,7 +147,10 @@ def process_raw(self, raw, dataset):
if self.resample is not None:
epochs = epochs.resample(self.resample)
# rescale to work with uV
X.append(dataset.unit_factor * epochs.get_data())
if return_epochs:
X.append(epochs)
else:
X.append(dataset.unit_factor * epochs.get_data())

inv_events = {k: v for v, k in event_id.items()}
labels = np.array([inv_events[e] for e in epochs.events[:, -1]])
Expand All @@ -155,7 +164,7 @@ def process_raw(self, raw, dataset):
metadata = pd.DataFrame(index=range(len(labels)))
return X, labels, metadata

def get_data(self, dataset, subjects=None):
def get_data(self, dataset, subjects=None, return_epochs=False):
"""
Return the data for a list of subjects.
Expand Down Expand Up @@ -197,7 +206,8 @@ def get_data(self, dataset, subjects=None):
for subject, sessions in data.items():
for session, runs in sessions.items():
for run, raw in runs.items():
proc = self.process_raw(raw, dataset)
proc = self.process_raw(raw, dataset,
return_epochs=return_epochs)

if proc is None:
# this means the run did not contain any selected event
Expand All @@ -211,12 +221,15 @@ def get_data(self, dataset, subjects=None):
metadata.append(met)

# grow X and labels in a memory efficient way. can be slow
if len(X) > 0:
X = np.append(X, x, axis=0)
labels = np.append(labels, lbs, axis=0)
if not return_epochs:
if len(X) > 0:
X = np.append(X, x, axis=0)
labels = np.append(labels, lbs, axis=0)
else:
X = x
labels = lbs
else:
X = x
labels = lbs
X.append(x)

metadata = pd.concat(metadata, ignore_index=True)
return X, labels, metadata
7 changes: 5 additions & 2 deletions moabb/paradigms/p300.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def is_valid(self, dataset):
def used_events(self, dataset):
pass

def process_raw(self, raw, dataset):
def process_raw(self, raw, dataset, return_epochs=False):
# find the events, first check stim_channels then annotations
stim_channels = mne.utils._get_stim_channel(
None, raw.info, raise_error=False)
Expand Down Expand Up @@ -126,7 +126,10 @@ def process_raw(self, raw, dataset):
if self.resample is not None:
epochs = epochs.resample(self.resample)
# rescale to work with uV
X.append(dataset.unit_factor * epochs.get_data())
if return_epochs:
X.append(epochs)
else:
X.append(dataset.unit_factor * epochs.get_data())

inv_events = {k: v for v, k in event_id.items()}
labels = np.array([inv_events[e] for e in epochs.events[:, -1]])
Expand Down

0 comments on commit 75a73d6

Please sign in to comment.