Skip to content

Commit

Permalink
Merge pull request #1 from NeuroTechX/master
Browse files Browse the repository at this point in the history
Update Fork
  • Loading branch information
ragatti authored Aug 1, 2020
2 parents 48c538b + 840c005 commit de1b659
Show file tree
Hide file tree
Showing 17 changed files with 227 additions and 102 deletions.
6 changes: 3 additions & 3 deletions moabb/analysis/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def score_plot(data, pipelines=None):
ax.axvline(0.5, linestyle='--', color='k', linewidth=2)
ax.set_title('Scores per dataset and algorithm')
handles, labels = ax.get_legend_handles_labels()
color_dict = {l: h.get_facecolor()[0] for l, h in zip(labels, handles)}
color_dict = {lb: h.get_facecolor()[0] for lb, h in zip(labels, handles)}
plt.tight_layout()
return fig, color_dict

Expand Down Expand Up @@ -105,8 +105,8 @@ def summary_plot(sig_df, effect_df, p_threshold=0.05, simplify=True):
fmt='', cmap=palette, linewidths=1,
linecolor='0.8', annot_kws={'size': 10}, cbar=False,
vmin=-np.log(0.05), vmax=-np.log(1e-100))
for l in ax.get_xticklabels():
l.set_rotation(45)
for lb in ax.get_xticklabels():
lb.set_rotation(45)
ax.tick_params(axis='y', rotation=0.9)
ax.set_title("Algorithm comparison")
plt.tight_layout()
Expand Down
9 changes: 6 additions & 3 deletions moabb/analysis/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class Results:
'''

def __init__(self, evaluation_class, paradigm_class, suffix='',
overwrite=False):
overwrite=False, hdf5_path=None):
"""
class that will abstract result storage
"""
Expand All @@ -49,8 +49,11 @@ class that will abstract result storage
assert issubclass(evaluation_class, BaseEvaluation)
assert issubclass(paradigm_class, BaseParadigm)

self.mod_dir = os.path.dirname(
os.path.abspath(inspect.getsourcefile(moabb)))
if hdf5_path is None:
self.mod_dir = os.path.dirname(
os.path.abspath(inspect.getsourcefile(moabb)))
else:
self.mod_dir = os.path.abspath(hdf5_path)
self.filepath = os.path.join(self.mod_dir, 'results',
paradigm_class.__name__,
evaluation_class.__name__,
Expand Down
2 changes: 1 addition & 1 deletion moabb/datasets/Weibo2014.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _get_single_subject_data(self, subject):
ch_types[61] = 'misc'
info = mne.create_info(ch_names=ch_names + ['STIM014'],
ch_types=ch_types + ['stim'],
sfreq=200, montage=None)
sfreq=200)
# until we get the channel names montage is None
event_ids = data['label'].ravel()
raw_data = np.transpose(data['data'], axes=[2, 0, 1])
Expand Down
3 changes: 2 additions & 1 deletion moabb/datasets/bbci_eeg_fnirs.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,9 @@ def _convert_one_session(self, data, mrk, session, trig_offset=0):

montage = make_standard_montage('standard_1005')
info = create_info(ch_names=ch_names, ch_types=ch_types,
sfreq=200., montage=montage)
sfreq=200.)
raw = RawArray(data=eeg, info=info, verbose=False)
raw.set_montage(montage)
return {'run_0': raw}

def data_path(self, subject, path=None, force_update=False,
Expand Down
20 changes: 10 additions & 10 deletions moabb/datasets/bnci.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,7 @@ def _load_data_003_2015(subject,
ch_types = ['eeg'] * 8 + ['stim'] * 2
montage = make_standard_montage('standard_1005')

info = create_info(
ch_names=ch_names, ch_types=ch_types, sfreq=sfreq, montage=montage)
info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sfreq)

sessions = {}
sessions['session_0'] = {}
Expand All @@ -325,6 +324,7 @@ def _load_data_003_2015(subject,

eeg_data = np.r_[run[1:-2] * 1e-6, targets, flashs]
raw = RawArray(data=eeg_data, info=info, verbose=verbose)
raw.set_montage(montage)
sessions['session_0']['run_' + str(ri)] = raw

return sessions
Expand Down Expand Up @@ -531,9 +531,9 @@ def _convert_run(run, ch_names=None, ch_types=None, verbose=None):
ch_names = ch_names + ['stim']
ch_types = ch_types + ['stim']
event_id = {ev: (ii + 1) for ii, ev in enumerate(run.classes)}
info = create_info(
ch_names=ch_names, ch_types=ch_types, sfreq=sfreq, montage=montage)
info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sfreq)
raw = RawArray(data=eeg_data.T, info=info, verbose=verbose)
raw.set_montage(montage)
return raw, event_id


Expand All @@ -551,9 +551,9 @@ def _convert_run_p300_sl(run, verbose=None):
eeg_data = np.c_[eeg_data, run.y, flash_stim]
event_id = {ev: (ii + 1) for ii, ev in enumerate(run.classes)}
event_id.update({ev: (ii + 3) for ii, ev in enumerate(run.classes_stim)})
info = create_info(
ch_names=ch_names, ch_types=ch_types, sfreq=sfreq, montage=montage)
info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sfreq)
raw = RawArray(data=eeg_data.T, info=info, verbose=verbose)
raw.set_montage(montage)
return raw, event_id


Expand Down Expand Up @@ -596,9 +596,9 @@ def _convert_run_bbci(run, ch_types, verbose=None):
ch_names = ch_names + ['Target', 'Flash']
ch_types = ch_types + ['stim'] * 2

info = create_info(
ch_names=ch_names, ch_types=ch_types, sfreq=sfreq, montage=montage)
info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sfreq)
raw = RawArray(data=eeg_data.T, info=info, verbose=verbose)
raw.set_montage(montage)
return raw, event_id


Expand Down Expand Up @@ -628,9 +628,9 @@ def _convert_run_epfl(run, verbose=None):
ch_types = ch_types + ['stim']
event_id = {'correct': 1, 'error': 2}

info = create_info(
ch_names=ch_names, ch_types=ch_types, sfreq=sfreq, montage=montage)
info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sfreq)
raw = RawArray(data=eeg_data.T, info=info, verbose=verbose)
raw.set_montage(montage)
return raw, event_id


Expand Down
2 changes: 1 addition & 1 deletion moabb/datasets/braininvaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def data_path(self, subject, path=None, force_update=False,
meta_file = os.path.join('subject{:d}'.format(subject), 'meta.yml')
meta_path = path_folder + meta_file
with open(meta_path, 'r') as stream:
meta = yaml.load(stream)
meta = yaml.load(stream, Loader=yaml.FullLoader)
conditions = []
if self.adaptive:
conditions = conditions + ['adaptive']
Expand Down
12 changes: 11 additions & 1 deletion moabb/datasets/epfl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import datetime as dt
from moabb.datasets.base import BaseDataset
from moabb.datasets import download as dl
from mne.channels import make_standard_montage
from scipy.io import loadmat
import zipfile

Expand Down Expand Up @@ -115,6 +116,13 @@ def _get_single_run_data(self, file_path):
'MA2']
ch_types = ['eeg'] * 32 + ['misc'] * 2

# The last X entries are 0 for all signals. This leads to
# artifacts when epoching and band-pass filtering the data.
# Correct the signals for this.
sig_i = np.where(
np.diff(np.all(signals == 0, axis=0).astype(int)) != 0)[0][0]
signals = signals[:, :sig_i]
signals *= 1e-6 # data is stored as uV, but MNE expects V
# we have to re-reference the signals
# the average signal on the mastoids electrodes is used as reference
references = [32, 33]
Expand Down Expand Up @@ -149,11 +157,13 @@ def _get_single_run_data(self, file_path):
signals = np.concatenate([signals, stim_channel[None, :]])

# create info dictionary
info = mne.create_info(ch_names, sfreq, ch_types, montage='biosemi32')
info = mne.create_info(ch_names, sfreq, ch_types)
info['description'] = 'EPFL P300 dataset'

# create the Raw structure
raw = mne.io.RawArray(signals, info, verbose=False)
montage = make_standard_montage('biosemi32')
raw.set_montage(montage)

return raw

Expand Down
3 changes: 2 additions & 1 deletion moabb/datasets/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ def _generate_raw(self):
eeg_data = np.c_[eeg_data, y]

info = create_info(ch_names=ch_names, ch_types=ch_types,
sfreq=sfreq, montage=montage)
sfreq=sfreq)
raw = RawArray(data=eeg_data.T, info=info, verbose=False)
raw.set_montage(montage)
return raw

def data_path(self, subject, path=None, force_update=False,
Expand Down
3 changes: 2 additions & 1 deletion moabb/datasets/gigadb.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,9 @@ def _get_single_subject_data(self, subject):
"continuous data -- edge effects present")

info = create_info(ch_names=ch_names, ch_types=ch_types,
sfreq=data.srate, montage=montage)
sfreq=data.srate)
raw = RawArray(data=eeg_data, info=info, verbose=False)
raw.set_montage(montage)

return {'session_0': {'run_0': raw}}

Expand Down
5 changes: 3 additions & 2 deletions moabb/datasets/schirrmeister2017.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,9 @@ def get_all_sensors(filename, pattern=None):
"""
with h5py.File(filename, 'r') as h5file:
clab_set = h5file['nfo']['clab'][:].squeeze()
all_sensor_names = [''.join(chr(c) for c in h5file[obj_ref]) for
obj_ref in clab_set]
all_sensor_names = [''.join(
chr(c.squeeze()) for c in h5file[obj_ref])
for obj_ref in clab_set]
if pattern is not None:
all_sensor_names = filter(
lambda sname: re.search(pattern, sname),
Expand Down
7 changes: 5 additions & 2 deletions moabb/evaluations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@ class BaseEvaluation(ABC):
'''

def __init__(self, paradigm, datasets=None, random_state=None, n_jobs=1,
overwrite=False, error_score='raise', suffix=''):
overwrite=False, error_score='raise', suffix='',
hdf5_path=None):
self.random_state = random_state
self.n_jobs = n_jobs
self.error_score = error_score
self.hdf5_path = hdf5_path

# check paradigm
if not isinstance(paradigm, BaseParadigm):
Expand Down Expand Up @@ -82,7 +84,8 @@ def __init__(self, paradigm, datasets=None, random_state=None, n_jobs=1,
self.results = Results(type(self),
type(self.paradigm),
overwrite=overwrite,
suffix=suffix)
suffix=suffix,
hdf5_path=self.hdf5_path)

def process(self, pipelines):
'''Runs all pipelines on all datasets.
Expand Down
52 changes: 33 additions & 19 deletions moabb/paradigms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def prepare_process(self, dataset):
"""
pass

def process_raw(self, raw, dataset):
def process_raw(self, raw, dataset, return_epochs=False):
"""
Process one raw data file.
Expand All @@ -83,10 +83,16 @@ def process_raw(self, raw, dataset):
The dataset corresponding to the raw file. mainly use to access
dataset specific information.
return_epochs: boolean
This flag specifies whether to return only the data array or the
complete processed mne.Epochs
returns
-------
X : np.ndarray
X : Union[np.ndarray, mne.Epochs]
the data that will be used as features for the model
Note: if return_epochs=True, this is mne.Epochs
if return_epochs=False, this is np.ndarray
labels: np.ndarray
the labels for training / evaluating the model
Expand All @@ -95,22 +101,22 @@ def process_raw(self, raw, dataset):
A dataframe containing the metadata
"""
# get events id
event_id = self.used_events(dataset)

# find the events, first check stim_channels then annotations
stim_channels = mne.utils._get_stim_channel(
None, raw.info, raise_error=False)
stim_channels = mne.utils._get_stim_channel(None, raw.info,
raise_error=False)
if len(stim_channels) > 0:
events = mne.find_events(raw, shortest_event=0, verbose=False)
else:
events, _ = mne.events_from_annotations(raw, verbose=False)

channels = () if self.channels is None else self.channels

# picks channels
picks = mne.pick_types(raw.info, eeg=True, stim=False,
include=channels)

# get events id
event_id = self.used_events(dataset)
if self.channels is None:
picks = mne.pick_types(raw.info, eeg=True, stim=False)
else:
picks = mne.pick_types(raw.info, stim=False, include=self.channels)

# pick events, based on event_id
try:
Expand All @@ -137,11 +143,15 @@ def process_raw(self, raw, dataset):
tmin=tmin, tmax=tmax, proj=False,
baseline=None, preload=True,
verbose=False, picks=picks,
event_repeated='drop',
on_missing='ignore')
if self.resample is not None:
epochs = epochs.resample(self.resample)
# rescale to work with uV
X.append(dataset.unit_factor * epochs.get_data())
if return_epochs:
X.append(epochs)
else:
X.append(dataset.unit_factor * epochs.get_data())

inv_events = {k: v for v, k in event_id.items()}
labels = np.array([inv_events[e] for e in epochs.events[:, -1]])
Expand All @@ -155,7 +165,7 @@ def process_raw(self, raw, dataset):
metadata = pd.DataFrame(index=range(len(labels)))
return X, labels, metadata

def get_data(self, dataset, subjects=None):
def get_data(self, dataset, subjects=None, return_epochs=False):
"""
Return the data for a list of subject.
Expand Down Expand Up @@ -197,7 +207,8 @@ def get_data(self, dataset, subjects=None):
for subject, sessions in data.items():
for session, runs in sessions.items():
for run, raw in runs.items():
proc = self.process_raw(raw, dataset)
proc = self.process_raw(raw, dataset,
return_epochs=return_epochs)

if proc is None:
# this mean the run did not contain any selected event
Expand All @@ -211,12 +222,15 @@ def get_data(self, dataset, subjects=None):
metadata.append(met)

# grow X and labels in a memory efficient way. can be slow
if len(X) > 0:
X = np.append(X, x, axis=0)
labels = np.append(labels, lbs, axis=0)
if not return_epochs:
if len(X) > 0:
X = np.append(X, x, axis=0)
labels = np.append(labels, lbs, axis=0)
else:
X = x
labels = lbs
else:
X = x
labels = lbs
X.append(x)

metadata = pd.concat(metadata, ignore_index=True)
return X, labels, metadata
7 changes: 5 additions & 2 deletions moabb/paradigms/p300.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def is_valid(self, dataset):
def used_events(self, dataset):
pass

def process_raw(self, raw, dataset):
def process_raw(self, raw, dataset, return_epochs=False):
# find the events, first check stim_channels then annotations
stim_channels = mne.utils._get_stim_channel(
None, raw.info, raise_error=False)
Expand Down Expand Up @@ -126,7 +126,10 @@ def process_raw(self, raw, dataset):
if self.resample is not None:
epochs = epochs.resample(self.resample)
# rescale to work with uV
X.append(dataset.unit_factor * epochs.get_data())
if return_epochs:
X.append(epochs)
else:
X.append(dataset.unit_factor * epochs.get_data())

inv_events = {k: v for v, k in event_id.items()}
labels = np.array([inv_events[e] for e in epochs.events[:, -1]])
Expand Down
4 changes: 2 additions & 2 deletions moabb/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def parse_pipelines_from_directory(d):
content = _file.read()

# load config
config_dict = yaml.load(content)
config_dict = yaml.load(content, Loader=yaml.FullLoader)
ppl = create_pipeline_from_config(config_dict['pipeline'])
pipeline_configs.append({'paradigms': config_dict['paradigms'],
'pipeline': ppl,
Expand Down Expand Up @@ -187,7 +187,7 @@ def generate_paradigms(pipeline_configs, context={}):
context_params = {}
if options.context is not None:
with open(options.context, 'r') as cfile:
context_params = yaml.load(cfile.read())
context_params = yaml.load(cfile.read(), Loader=yaml.FullLoader)

paradigms = generate_paradigms(pipeline_configs, context_params)

Expand Down
Loading

0 comments on commit de1b659

Please sign in to comment.