-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
244 lines (194 loc) · 9.03 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
import streamlit as st
import mne
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
import joblib
from colorama import Fore, Style
import tempfile
import os
from collections import OrderedDict
import librosa
import tensorflow as tf
import pickle
# Electrode pairs ([anode, cathode]) forming the bipolar longitudinal montage
# (10-20 system) that the models consume. Order matters: it fixes the order of
# the derived channels and therefore of the feature columns.
channel_pairs = [
    # Fp1 -> F7 -> T3 -> T5 -> O1
    ['EEG Fp1', 'EEG F7'], ['EEG F7', 'EEG T3'], ['EEG T3', 'EEG T5'], ['EEG T5', 'EEG O1'],
    # NOTE(review): this chain is Fp1-F3, C3-F3, F3-O1 (no P3), unlike the usual
    # Fp1-F3, F3-C3, C3-P3, P3-O1. Presumably it matches how the models were
    # trained -- confirm before "fixing", since it changes the feature layout.
    ['EEG Fp1', 'EEG F3'], ['EEG C3', 'EEG F3'], ['EEG F3', 'EEG O1'],
    # Fp2 -> F4 -> C4 -> P4 -> O2
    ['EEG Fp2', 'EEG F4'], ['EEG F4', 'EEG C4'], ['EEG C4', 'EEG P4'], ['EEG P4', 'EEG O2'],
    # Fp2 -> F8 -> T4 -> T6 -> O2
    ['EEG Fp2', 'EEG F8'], ['EEG F8', 'EEG T4'], ['EEG T4', 'EEG T6'], ['EEG T6', 'EEG O2'],
]

# Names of the bipolar channels, formatted "<anode>-<cathode>" (the naming
# mne.set_bipolar_reference uses by default), in channel_pairs order.
channel_pairs_joined = [f"{anode}-{cathode}" for anode, cathode in channel_pairs]

# Sampling rate in Hz used for windowing and MFCC computation.
target_sampling_rate = 512
def print_decorative_log(message, color=Fore.BLUE, style=Style.RESET_ALL):
    """Print ``message`` to stdout inside a box of '#' characters.

    ``color`` is prepended to the top border and ``style`` is appended to the
    bottom border, so with the defaults the box is drawn in blue and the
    terminal style is reset afterwards.
    """
    # Border width = "# " + message + " #"
    border = "#" * (len(message) + 4)
    print(color + border)
    print(f"# {message} #")
    print(border + style)
def compute_cepstrum_mel(data, sfreq, n_mfcc=20):
    """Return the mel-frequency cepstral coefficients of a 1-D signal.

    Thin wrapper around ``librosa.feature.mfcc``. The returned array has one row
    per coefficient (``n_mfcc`` rows) and one column per analysis frame.

    Args:
        data: the signal samples.
        sfreq: sampling rate of ``data`` in Hz.
        n_mfcc: number of coefficients to keep.
    """
    return librosa.feature.mfcc(y=data, sr=sfreq, n_mfcc=n_mfcc)
def preprocess_raw(raw):
    """Clean an EEG recording and convert it to the bipolar longitudinal montage.

    Steps:
      1. rename any channel whose name contains ``'EEG FP2'`` to ``'EEG Fp2'``;
      2. keep only the electrodes referenced by ``channel_pairs`` plus the ECG
         channel ``'2'`` (typed as ``ecg``);
      3. fit ICA on a 1 Hz high-passed copy and remove ECG-correlated components;
      4. re-reference to the bipolar pairs (``channel_pairs_joined``);
      5. apply oversampled temporal projection and a 40 Hz low-pass filter.

    ``raw`` is modified in place (channels renamed/dropped). Its data must be
    loaded in memory, because MNE filtering requires it (``main`` reads files
    with ``preload=True``).

    Returns:
        A new Raw object containing only the bipolar channels.

    Raises:
        ValueError: if ``raw`` lacks any required electrode or the ECG channel.
    """
    print_decorative_log("Starting Preprocessing Sequence", Fore.GREEN)

    # Unique electrode names in first-seen order (the bipolar longitudinal
    # channels of the 10-20 system), followed by the ECG channel '2'.
    selected_channels = list(dict.fromkeys(ch for pair in channel_pairs for ch in pair))
    selected_channels.append('2')

    # Normalise upper-case Fp2 labels so they match channel_pairs.
    fp2_renames = {name: 'EEG Fp2' for name in raw.ch_names if 'EEG FP2' in name}
    if fp2_renames:
        raw.rename_channels(fp2_renames)

    # Fail early, before anything else is mutated, with a clear message.
    present = set(raw.ch_names)
    missing = [ch for ch in selected_channels if ch not in present]
    if missing:
        raise ValueError(f"Recording is missing required channels: {missing}")

    keep = set(selected_channels)
    raw.drop_channels([name for name in raw.ch_names if name not in keep])
    print_decorative_log("Extra Channels Dropped ... ", Fore.RED)

    # NOTE(review): relies on `pick` honouring the order of `selected_channels`;
    # `reorder_channels` is the explicit alternative if the MNE version differs.
    raw = raw.pick(selected_channels)
    print_decorative_log("Channels Reordered ... ", Fore.YELLOW)

    raw.set_channel_types({'2': 'ecg'})
    print_decorative_log("ECG Channel Selected ... ", Fore.YELLOW)

    # The 1 Hz high-passed copy is used only to fit ICA. The MNE ICA tutorial
    # recommends removing slow drifts before fitting.
    filt_raw = raw.copy().filter(l_freq=1.0, h_freq=None)
    print_decorative_log("Slow drifts removed ... ", Fore.YELLOW)

    # Fit on the filtered copy, detect ECG components against the ECG channel,
    # and apply the exclusion to an unfiltered copy.
    ica = mne.preprocessing.ICA(n_components=15, max_iter="auto", random_state=97)
    ica.fit(filt_raw)
    ecg_indices, _ecg_scores = ica.find_bads_ecg(raw, method="correlation", threshold="auto")
    ica.exclude = ecg_indices
    reconst_raw = raw.copy()
    ica.apply(reconst_raw)
    print_decorative_log("ECG Artifacts Removed... ", Fore.YELLOW)

    anodes = [anode for anode, _ in channel_pairs]
    cathodes = [cathode for _, cathode in channel_pairs]
    raw_bip_ref = mne.set_bipolar_reference(reconst_raw, anode=anodes, cathode=cathodes)
    raw_bip_ref_ch = raw_bip_ref.copy().pick_channels(channel_pairs_joined)
    print_decorative_log("Bipolar Referencing Done ... ", Fore.YELLOW)

    raw_clean = mne.preprocessing.oversampled_temporal_projection(raw_bip_ref_ch)
    # l_freq=None -> low-pass only (the documented spelling of the original 0.0).
    raw_clean.filter(None, 40.0)
    print_decorative_log("Smoothing & Filtering Done ... ", Fore.YELLOW)
    return raw_clean
def _frame_features(frame_data, n_mfcc=20):
    """Turn one multi-channel frame into a standardised feature matrix.

    Each channel's MFCC matrix (n_mfcc x analysis frames) is stacked row-wise,
    then transposed so there is one row per MFCC analysis frame and
    ``channels * n_mfcc`` columns.
    """
    cepstral_features = np.concatenate(
        [compute_cepstrum_mel(channel, target_sampling_rate, n_mfcc) for channel in frame_data],
        axis=0,
    )
    # NOTE(review): the scaler is fitted on this frame alone, so every frame is
    # standardised with its own statistics. If the models were trained with a
    # scaler fitted on the training set, that scaler should be loaded and used
    # here instead -- confirm against the training pipeline.
    return StandardScaler().fit_transform(pd.DataFrame(cepstral_features.T))


def simulate_streaming_data(raw, start_time, end_time, model_name='CNN'):
    """Classify consecutive 10-second windows of a recording and display the results.

    ``raw`` is cropped in place to ``[start_time, end_time]`` seconds and run
    through ``preprocess_raw``. It is resampled to ``target_sampling_rate`` if
    needed and split into non-overlapping 10-second frames; a trailing partial
    frame is ignored. Each frame's MFCC features are classified by the model
    chosen with ``model_name`` (``'CNN'`` selects the Keras model), and the
    result is written to the Streamlit page.

    Args:
        raw: MNE Raw object containing the channels ``preprocess_raw`` requires.
        start_time: crop start in seconds.
        end_time: crop end in seconds.
        model_name: name of the model to load (see ``load_model``).
    """
    st.write("Starting Simulation")
    raw.crop(tmin=start_time, tmax=end_time)
    preprocessed_raw = preprocess_raw(raw)

    # Window length and the MFCC `sr` both assume target_sampling_rate, so bring
    # recordings at other rates to it; otherwise a "10 s" window would be wrong.
    if preprocessed_raw.info['sfreq'] != target_sampling_rate:
        preprocessed_raw.resample(target_sampling_rate)

    data = preprocessed_raw.get_data(picks=channel_pairs_joined)

    window_size = 10  # seconds
    window_samples = int(window_size * target_sampling_rate)
    num_frames = data.shape[1] // window_samples
    print(num_frames)
    if num_frames == 0:
        st.warning(f"The selected range is shorter than one {window_size}-second window; nothing to classify.")
        return

    # Loop-invariant: load the model and build the label map once.
    with st.spinner('Model is being loaded..'):
        model = load_model(model_name)
    class_mapping = {0: 'pre-ictal', 1: 'ictal', 2: 'post-ictal', 3: 'normal'}
    n_mfcc = 20  # Number of MFCC coefficients per channel

    for frame_idx in range(num_frames):
        start_idx = frame_idx * window_samples
        frame_data = data[:, start_idx:start_idx + window_samples]
        frame_scaled = _frame_features(frame_data, n_mfcc)

        # NOTE(review): the model scores every MFCC row of the frame, but only
        # row 0's prediction is reported; confirm whether an aggregate (e.g. a
        # majority vote over rows) was intended.
        if model_name == 'CNN':
            # Add a trailing channel axis: (rows, features) -> (rows, features, 1).
            prediction = model.predict(frame_scaled[..., np.newaxis])
            print("Prediction: ")
            print(np.argmax(prediction, axis=1))
            predicted_class = class_mapping[np.argmax(prediction, axis=1)[0]]
        else:
            prediction = model.predict(frame_scaled)[0]
            predicted_class = class_mapping[prediction]

        st.subheader("Streaming 10 secs")
        st.info(f"Classification Result: {predicted_class}")
        st.write("--------------------------------")
@st.cache(allow_output_mutation=True)
def load_data(file_path):
    """Open an EDF file with MNE and return the Raw object.

    No ``preload`` argument is passed, so MNE's default (data not loaded into
    memory) applies. The result is cached by Streamlit per ``file_path``.

    NOTE(review): this helper is not called in the visible code; ``main`` reads
    the uploaded file directly.
    """
    return mne.io.read_raw_edf(file_path)
@st.cache(allow_output_mutation=True)
def load_model(model_name):
    """Load the pre-trained classifier named ``model_name``.

    ``'CNN'`` loads the Keras model ``cnn_epilepsy_prediction_model.h5``. Any
    other name loads the joblib pickle
    ``<model_name lower-cased>_epilepsy_prediction_model.pkl`` from the working
    directory. Results are cached by Streamlit per ``model_name``.

    NOTE(review): ``st.cache`` is deprecated in newer Streamlit releases in
    favour of ``st.cache_resource``.
    """
    if model_name == 'CNN':
        # Only the .h5 file is needed; previously a .pkl was also (needlessly) loaded first.
        return tf.keras.models.load_model('cnn_epilepsy_prediction_model.h5')
    # joblib.load unpickles arbitrary objects -- only ever point this at trusted local files.
    model_file = f'{model_name.lower()}_epilepsy_prediction_model.pkl'
    return joblib.load(model_file)
def main():
    """Streamlit entry point: choose a model, upload an EDF file, pick a time range, classify.

    Streamlit re-runs this function on every widget interaction, so it must not
    accumulate resources between runs.
    """
    st.title("Epilepsy Detection from Streaming EEG Data - Simulation App")

    # Let TensorFlow grow GPU memory on demand. This replaces the TF1 Session
    # that was previously created (and leaked) on every rerun.
    gpus = tf.config.list_physical_devices('GPU')
    for gpu in gpus:
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as err:
            # Raised if the GPU was already initialised; the app can still run.
            print(f"Could not set GPU memory growth: {err}")
    print("Num GPUs Available: ", len(gpus))
    for gpu in gpus:
        print(gpu)

    model_name = st.sidebar.selectbox("Select Model", ['SVM', 'Random Forest', 'Balanced Random Forest', 'XGBoost', 'AdaBoost', 'CNN'])

    uploaded_file = st.file_uploader("Upload EDF file", type=["edf"])
    if uploaded_file is None:
        return

    # MNE reads EDF from a path, so spill the upload to a temporary file.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".edf") as tmp_file:
        tmp_filename = tmp_file.name
        # getvalue() returns the whole upload regardless of the stream position.
        tmp_file.write(uploaded_file.getvalue())
    print(tmp_filename)
    try:
        # preload=True pulls the samples into memory, so the file can be removed right away.
        raw = mne.io.read_raw_edf(tmp_filename, preload=True)
    finally:
        # Clean up even if the EDF cannot be parsed.
        os.remove(tmp_filename)

    duration = float(raw.times[-1])
    start_time = st.number_input("Start Time (in seconds)", min_value=0.0, max_value=duration, value=0.0)
    end_time = st.number_input("End Time (in seconds)", min_value=start_time, max_value=duration, value=duration)
    if st.button("Start Classification"):
        simulate_streaming_data(raw, start_time, end_time, model_name)


if __name__ == "__main__":
    main()