Существуют различные состояния здоровья и заболевания, которые могут вызывать аномальные формы волны ЭЭГ, они могут варьироваться от эпилепсии до болезни Паркинсона (БП). В этой короткой статье я собираюсь реализовать очень простую безусловную GAN для генерации аномальных и нормальных данных LFP. В моем случае я обучаю свои данные, собранные у участников, страдающих эпилепсией, и для этого есть доступные наборы данных с открытым исходным кодом.

GAN — это сеть, состоящая из двух основных компонентов: дискриминатора и генератора. Рациональность заключается в том, что дискриминатор является бинарным предсказателем того, насколько реальны входные данные, а генератор обучен создавать реалистичные выходные данные из шума или в формальной форме из скрытого пространства. Процесс обучения включает в себя создание поддельных выборок из сети генератора, а также извлечение реальных выборок из доступных наборов данных и обучение дискриминатора различению между ними, что приводит к ошибке, которая затем используется для обновления генератора. сеть. Хитрость здесь заключается в том, чтобы передать метки как настоящие (одна из них в двоичной классификации означает True или искомый класс), поэтому генератор вынужден делать как можно более реальные выходные данные, пытаясь обмануть дискриминатор.

Если вы заинтересованы в создании собственного GAN для очень простой задачи, ознакомьтесь с этим полезным учебником.

Чтобы создать GAN, давайте сначала определим сеть генератора:

def make_gen_model(latent_dim, n_outputs=256):
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(32, activation='relu', kernel_initializer='he_uniform', input_dim=latent_dim))
model.add(tf.keras.layers.Dense(64, activation='relu', kernel_initializer='he_uniform'))
model.add(tf.keras.layers.Dense(128, activation='relu', kernel_initializer='he_uniform'))
model.add(tf.keras.layers.Dense(n_outputs, activation='linear'))
return model

Далее, давайте определим дискриминатор.

def make_disc_model(n_inputs=256):
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(128, activation='relu', kernel_initializer='he_uniform', input_dim=n_inputs))
model.add(tf.keras.layers.Dense(128, activation='relu', kernel_initializer='he_uniform'))
model.add(tf.keras.layers.Dense(32, activation='relu', kernel_initializer='he_uniform'))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
return model

Наконец, давайте создадим модель GAN.

def make_gan_model(disc_model, gen_modl):
disc_model.trainable = False
model = tf.keras.Sequential()
model.add(gen_modl)
model.add(disc_model)
model.compile(loss='binary_crossentropy', optimizer='adam')
return model

Наконец, давайте создадим функцию для обучения сети GAN и дискриминатора.

def train(disc_model, gen_model, gan_model, epochs, latent_dim = 5):
batch_size = 128
half_batch = int(batch_size / 2)
for i in range(epochs):
x_real, y_real = generate_real_sample(half_batch)
x_fake, y_fake = generate_fake_sample(gen_model, latent_dim, half_batch)
disc_model.train_on_batch(x_real, y_real)
disc_model.train_on_batch(x_fake, y_fake)
x_gan = generate_latent_points(latent_dim, batch_size)
y_gan = np.ones((batch_size, 1))
gan_model.train_on_batch(x_gan, y_gan)

Это общий вид основных компонентов, необходимых для создания сети. В следующем бите я пройдусь по результатам.

Есть несколько физиологических особенностей, связанных с эпизодами судорог, которые являются результатом судорог. Как видно из отчетов и данных, собранных на сегодняшний день, обычно наблюдается рост на всех частотах и ​​всплеск в определенных диапазонах, скажем, 15–20 Гц. Сказав это, давайте посмотрим, чему научилась сеть, сначала обучаясь на данных LFP, где ничего патологического не происходит, а другая сеть, когда происходят припадки.

Нормальный: на рис. 1 показано, как сравниваются реальные и поддельные образцы для обычных данных LFP. Это после 5000 эпох со 128 выборками на партию.

Как видите, общая структура скопирована довольно хорошо, однако в модели по-прежнему не хватает внимания к конкретным частям. Одним из них, например, является провал на частоте 50 Гц для реальных выборок, поскольку реальные данные были пропущены через режекторный фильтр на частоте 50 Гц (гул сети), но сгенерированные данные не имеют этого аспекта.

Отклонение от нормы. На рис. 2 показано, как сравниваются результаты сети, обученной на выборках, взятых из эпох, когда происходят припадки.

Как видите, есть общий подъем в диапазоне 0–25 Гц, который также улавливает генератор, но опять же провал на 50 Гц сохраняется не так хорошо.

Можно спросить о причине развития таких сетей. В наши дни создается множество моделей для граничных вычислений, и не так много данных для обучения моделей из-за дорогостоящего процесса. Такой подход позволил бы увеличить обучающие наборы для менее сложных классификаторов, предназначенных для реализации краев.

Возможные изменения в процессе, описанном в этой статье, заключаются в разработке глубокой сверточной генеративно-состязательной сети (DCGAN) и обучении на спектрограммах данных временных рядов, что является нормой в таких областях, как обработка звука. Другое возможное редактирование могло бы состоять в том, чтобы иметь частотно-временное представление с использованием Continus Wavelet Transform. Я планирую написать еще одну короткую статью, используя DCGAN для этой цели.

Наконец, я создал простой репозиторий GitHub, чтобы поиграть с ним самостоятельно.