DeepExplain issue #29
This is due to the transition to TensorFlow 2: the original DeepExplain package does not support TF2 out of the box. There is an open pull request (marcoancona/DeepExplain#55) that adds TF2 support, as long as you disable eager execution:

import tensorflow as tf
tf.compat.v1.disable_eager_execution()
# ... (the rest of your code) ...

Here's a code snippet that works out of the box with the above pull request (using the MNE sample dataset):

# import tensorflow and disable eager execution right up front
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
import numpy as np
# mne imports
import mne
from mne import io
from mne.datasets import sample
# EEGNet-specific imports
from EEGModels import EEGNet
from tensorflow.keras import utils as np_utils
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.python.keras import backend as K
from tensorflow.keras.models import Model
from deepexplain.tensorflow import DeepExplain
# while the default tensorflow ordering is 'channels_last' we set it here
# to be explicit in case the user has changed the default ordering
K.set_image_data_format('channels_last')
##################### Process, filter and epoch the data ######################
data_path = sample.data_path()
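# Set parameters and read data (str() because newer MNE versions return
# data_path as a pathlib.Path rather than a string)
raw_fname = str(data_path) + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
event_fname = str(data_path) + '/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif'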
tmin, tmax = -0., 1
event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4)
# Setup for reading the raw data
raw = io.Raw(raw_fname, preload=True, verbose=False)
raw.filter(2, None, method='iir') # replace baselining with high-pass
events = mne.read_events(event_fname)
raw.info['bads'] = ['MEG 2443'] # set bad channels
picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                       exclude='bads')
# Read epochs
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=False,
                    picks=picks, baseline=None, preload=True, verbose=False)
labels = epochs.events[:, -1]
# extract raw data. scale by 1000 due to scaling sensitivity in deep learning
X = epochs.get_data()*1000 # format is in (trials, channels, samples)
y = labels
kernels, chans, samples = 1, 60, 151
# take 50/25/25 percent of the data to train/validate/test
X_train = X[0:144,]
Y_train = y[0:144]
X_validate = X[144:216,]
Y_validate = y[144:216]
X_test = X[216:,]
Y_test = y[216:]
# convert labels to one-hot encodings.
Y_train = np_utils.to_categorical(Y_train-1)
Y_validate = np_utils.to_categorical(Y_validate-1)
Y_test = np_utils.to_categorical(Y_test-1)
# convert data to NHWC (trials, channels, samples, kernels) format. Data
# contains 60 channels and 151 time-points. Set the number of kernels to 1.
X_train = X_train.reshape(X_train.shape[0], chans, samples, kernels)
X_validate = X_validate.reshape(X_validate.shape[0], chans, samples, kernels)
X_test = X_test.reshape(X_test.shape[0], chans, samples, kernels)
print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')
# configure the EEGNet-8,2,16 model with kernel length of 32 samples (other
# model configurations may do better, but this is a good starting point)
model = EEGNet(nb_classes = 4, Chans = chans, Samples = samples,
               dropoutRate = 0.5, kernLength = 32, F1 = 8, D = 2, F2 = 16,
               dropoutType = 'Dropout')
# compile the model and set the optimizers
model.compile(loss='categorical_crossentropy', optimizer='adam',
              metrics = ['accuracy'])
# count number of parameters in the model
numParams = model.count_params()
# set a valid path for your system to record model checkpoints
checkpointer = ModelCheckpoint(filepath='/tmp/checkpoint.h5', verbose=1,
                               save_best_only=True)
###############################################################################
# if the classification task was imbalanced (significantly more trials in one
# class versus the others) you can assign a weight to each class during
# optimization to balance it out. This data is approximately balanced so we
# don't need to do this, but it is shown here for illustration/completeness.
###############################################################################
# the syntax is {class_1:weight_1, class_2:weight_2,...}. Here just setting
# the weights all to be 1
class_weights = {0:1, 1:1, 2:1, 3:1}
fittedModel = model.fit(X_train, Y_train, batch_size = 16, epochs = 5,
                        verbose = 2, validation_data=(X_validate, Y_validate),
                        callbacks=[checkpointer], class_weight = class_weights)
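with DeepExplain(session = K.get_session()) as de:
    input_tensor = model.layers[0].input
    # model.layers[-2] is the layer just before the final softmax, so the
    # attributions are computed on the pre-softmax class scores
    fModel = Model(inputs = input_tensor, outputs = model.layers[-2].output)
    target_tensor = fModel(input_tensor)
    # multiplying by the one-hot Y_test keeps only each trial's true-class score
    # can use epsilon-LRP as well if you like.
    attributions = de.explain('deeplift', target_tensor * Y_test, input_tensor, X_test)
    # attributions = de.explain('elrp', target_tensor * Y_test, input_tensor, X_test)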
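In the script above, `target_tensor * Y_test` multiplies each trial's pre-softmax scores (taken from `model.layers[-2]`) element-wise by its one-hot label, so every score except the true class's becomes zero and the attributions refer to that class. The plain-Python sketch below mimics this masking on made-up numbers; `mask_by_label` and `selected_scores` are hypothetical helpers for illustration only, not part of DeepExplain.

```python
from typing import List


def mask_by_label(scores: List[List[float]], one_hot: List[List[int]]) -> List[List[float]]:
    """Element-wise product of per-trial class scores and one-hot labels,
    mirroring `target_tensor * Y_test` (hypothetical helper)."""
    return [[s * y for s, y in zip(row_s, row_y)]
            for row_s, row_y in zip(scores, one_hot)]


def selected_scores(masked: List[List[float]]) -> List[float]:
    """Sum each masked row: only the labelled class's score survives."""
    return [sum(row) for row in masked]


# Made-up pre-softmax scores for two trials and four classes (illustrative only).
scores = [[2.0, -1.0, 0.5, 0.3],
          [0.1, 0.4, 3.2, -0.7]]
labels = [[1, 0, 0, 0],   # trial 0 belongs to class 0
          [0, 0, 1, 0]]   # trial 1 belongs to class 2

masked = mask_by_label(scores, labels)
print(selected_scores(masked))  # -> [2.0, 3.2]
```

Masking by the true label is one choice; multiplying by a one-hot vector for some other class would instead ask which inputs drive that class's score.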
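The script sets every class weight to 1 because the four classes are roughly balanced. For an imbalanced task, one common heuristic (the "balanced" rule also used by scikit-learn's `compute_class_weight`) weights each class by `n_samples / (n_classes * count)`. The sketch below uses a hypothetical `balanced_class_weights` helper to build a dictionary in the same `{class_index: weight}` form as `class_weights` in the script; it assumes 0-based integer labels, as produced by `Y_train-1`.

```python
from collections import Counter
from typing import Dict, Sequence


def balanced_class_weights(labels: Sequence[int]) -> Dict[int, float]:
    """Inverse-frequency weights: n_samples / (n_classes * count_of_class).

    Hypothetical helper; classes absent from `labels` get no entry.
    """
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * count)
            for cls, count in sorted(counts.items())}


# Illustrative, deliberately imbalanced 0-based labels (not the MNE data).
labels = [0, 0, 0, 0, 0, 0, 1, 1, 2, 3]
print(balanced_class_weights(labels))
# -> {0: 0.4166666666666667, 1: 1.25, 2: 2.5, 3: 2.5}
```

For a perfectly balanced label set every weight comes out as 1.0, which matches the all-ones `class_weights` used in the script.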
Alternatively, you could manually fix this by editing /deepexplain/tensorflow/methods.py directly, although this is a pretty bad hack:
I've verified this also works (not extensively tested, however), although the above PR is the better route.
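The pull-request route needs `tf.compat.v1.disable_eager_execution()` to run before any model is built. If the same script should also run under TensorFlow 1.x, where graph mode is already the default, the call can be guarded by version. A minimal sketch, assuming only the public `tf.__version__`, `tf.executing_eagerly()` and `tf.compat.v1.disable_eager_execution()` calls; `is_tf2` and `prepare_graph_mode` are hypothetical helper names.

```python
def is_tf2(version: str) -> bool:
    """True for TensorFlow version strings such as '2.4.1' (hypothetical helper)."""
    major = version.split(".", 1)[0]
    return int(major) >= 2


def prepare_graph_mode() -> None:
    """Disable eager execution on TF 2.x; TF 1.x runs in graph mode by default.

    Call right after importing TensorFlow and before building any model.
    """
    import tensorflow as tf  # imported lazily so is_tf2 works without TensorFlow
    if is_tf2(tf.__version__) and tf.executing_eagerly():
        tf.compat.v1.disable_eager_execution()


print(is_tf2("1.15.5"), is_tf2("2.4.1"))  # -> False True
```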
Very good, that works, thank you!
Does anyone know how to solve it?