Skip to content

Commit

Permalink
Merge pull request #83 from jsosulski/allow_event_lists
Browse files Browse the repository at this point in the history
Allow event lists in P300 paradigm
  • Loading branch information
jsosulski authored Jan 15, 2021
2 parents 009d6fa + fdd9f46 commit 374a01f
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
6 changes: 6 additions & 0 deletions moabb/paradigms/p300.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ def process_raw(self, raw, dataset, return_epochs=False):

# pick events, based on event_id
try:
if (type(event_id['Target']) is list and
type(event_id['NonTarget']) == list):
event_id_new = dict(Target=1, NonTarget=0)
events = mne.merge_events(events, event_id['Target'], 1)
events = mne.merge_events(events, event_id['NonTarget'], 0)
event_id = event_id_new
events = mne.pick_events(events, include=list(event_id.values()))
except RuntimeError:
# skip raw if no event found
Expand Down
13 changes: 8 additions & 5 deletions moabb/tests/paradigms.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,16 @@ class Test_P300(unittest.TestCase):

def test_BaseP300_paradigm(self):
paradigm = SimpleP300()
dataset = FakeDataset(paradigm='p300')
dataset = FakeDataset(paradigm='p300',
event_list=['Target', 'NonTarget'])
X, labels, metadata = paradigm.get_data(dataset, subjects=[1])

# we should have all the same length
self.assertEqual(len(X), len(labels), len(metadata))
# X must be a 3D Array
self.assertEqual(len(X.shape), 3)
# labels must contain 3 values
self.assertEqual(len(np.unique(labels)), 3)
# labels must contain 2 values (Target/NonTarget)
self.assertEqual(len(np.unique(labels)), 2)

# metadata must have subjets, sessions, runs
self.assertTrue('subject' in metadata.columns)
Expand All @@ -160,7 +161,8 @@ def test_BaseP300_tmintmax(self):
def test_BaseP300_filters(self):
# can work with filter bank
paradigm = SimpleP300(filters=[[1, 12], [12, 24]])
dataset = FakeDataset(paradigm='p300')
dataset = FakeDataset(paradigm='p300',
event_list=['Target', 'NonTarget'])
X, labels, metadata = paradigm.get_data(dataset, subjects=[1])

# X must be a 4D Array
Expand All @@ -171,7 +173,8 @@ def test_BaseP300_wrongevent(self):
# test process_raw return empty list if raw does not contain any
# selected event. cetain runs in dataset are event specific.
paradigm = SimpleP300(filters=[[1, 12], [12, 24]])
dataset = FakeDataset(paradigm='p300')
dataset = FakeDataset(paradigm='p300',
event_list=['Target', 'NonTarget'])
raw = dataset.get_data([1])[1]['session_0']['run_0']
# add something on the event channel
raw._data[-1] *= 10
Expand Down

0 comments on commit 374a01f

Please sign in to comment.