Automatic ECG Diagnosis using a Deep Neural Network

In this notebook we provide the code for training a deep neural network for automatic ECG diagnosis. The neural network architecture is similar to the residual network proposed by (He et al., 2015), whose shortcut connections allow very deep neural networks to be trained efficiently. Although the residual network was proposed in the context of images, it can easily be adapted to one-dimensional signals. In fact, most architectures proposed for 2D data (images) can be adapted to 1D signals and are worth trying. Examples of image recognition networks that could be adapted include:

  • LeNet-5 (LeCun et al., 1998);
  • AlexNet (Krizhevsky et al., 2012);
  • VGG-16 (Simonyan and Zisserman, 2014);
  • ResNet (He et al., 2015);
  • Inception network (Szegedy et al., 2015).
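To make the adaptation from 2D to 1D concrete, here is a minimal NumPy sketch (not the notebook's Keras implementation) of a residual block acting on a multichannel signal of shape `(length, channels)`. The residual branch F is two "same"-padded 1D convolutions with a ReLU in between, and the block outputs ReLU(x + F(x)), as in the original residual block of (He et al., 2015). The helper names `conv1d_same` and `residual_block`, the random weights, and the 12-channel, 500-sample input are hypothetical and chosen only for illustration.

```python
import numpy as np


def conv1d_same(x, w):
    """1D convolution with 'same' padding and stride 1.

    Like most deep learning libraries, this computes a cross-correlation
    (the kernel is not flipped).
    x: signal of shape (length, in_channels)
    w: kernel of shape (kernel_size, in_channels, out_channels)
    Returns an array of shape (length, out_channels).
    """
    kernel_size = w.shape[0]
    pad_left = (kernel_size - 1) // 2
    pad_right = kernel_size - 1 - pad_left
    # Zero-pad along time only, so the output keeps the input length
    xp = np.pad(x, ((pad_left, pad_right), (0, 0)), mode="constant")
    out = np.empty((x.shape[0], w.shape[2]))
    for t in range(x.shape[0]):
        window = xp[t:t + kernel_size]  # shape (kernel_size, in_channels)
        out[t] = np.tensordot(window, w, axes=([0, 1], [0, 1]))
    return out


def relu(x):
    return np.maximum(x, 0.0)


def residual_block(x, w1, w2):
    """Identity-shortcut residual block: ReLU(x + F(x)), F = conv -> ReLU -> conv."""
    f = conv1d_same(relu(conv1d_same(x, w1)), w2)
    return relu(x + f)  # the shortcut adds the unchanged input back


rng = np.random.default_rng(0)
x = rng.standard_normal((500, 12))            # hypothetical 12-channel signal, 500 samples
w1 = 0.1 * rng.standard_normal((16, 12, 12))  # kernel length 16, 12 -> 12 channels
w2 = 0.1 * rng.standard_normal((16, 12, 12))

y = residual_block(x, w1, w2)
print(y.shape)  # (500, 12)

# With a zero residual branch the block reduces to ReLU(x): the identity
# path is always available, so a block can default to passing its input on.
assert np.allclose(residual_block(x, w1, np.zeros_like(w2)), relu(x))
```

The only change with respect to the image case is that the kernel slides along one (time) axis instead of two; the shortcut connection itself is dimension-agnostic.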

We chose residual networks because they have been employed for arrhythmia detection from ECG signals in (Rajpurkar et al., 2017), an application very similar to ours. The details of the neural network used in (Rajpurkar et al., 2017) are illustrated in the figure below and summarized in the following excerpt from the paper:


"The network consists of 16 residual blocks with 2 convolutional layers per block. The convolutional layers all have a filter length of 16 and have $64k$ filters, where $k$ starts out as 1 and is incremented every 4-th residual block. Every alternate residual block subsamples its inputs by a factor of 2, thus the original input is ultimately subsampled by a factor of $2^8$. When a residual block subsamples the input, the corresponding shortcut connections also subsample their input using a Max Pooling operation with the same subsample factor. Before each convolutional layer we apply Batch Normalization (Ioffe & Szegedy, 2015) and a rectified linear activation, adopting the pre-activation block design. The first and last layers of the network are special-cased due to this pre-activation block structure. We also apply Dropout (Srivastava et al., 2014) between the convolutional layers and after the non-linearity."
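The block schedule in this description can be checked with a little arithmetic. The sketch below assumes that $k$ is incremented after every 4 blocks (so blocks 1 to 4 use 64 filters, blocks 5 to 8 use 128, and so on) and that the 2nd, 4th, ..., 16th blocks are the ones that subsample; the excerpt does not pin down either detail, but the total subsampling factor is $2^8 = 256$ in any case. It also sketches, in NumPy, the max-pooling shortcut of a subsampling block. The names `block_schedule` and `max_pool1d` and the 4096-sample input length are hypothetical.

```python
import numpy as np

N_BLOCKS = 16       # residual blocks in the quoted network
BASE_FILTERS = 64   # each convolution has 64 * k filters
K_PERIOD = 4        # k is incremented every 4th block (assumed: after blocks 4, 8, 12)


def block_schedule():
    """Return (n_filters, subsample_factor) for each residual block."""
    schedule = []
    for i in range(N_BLOCKS):        # i = 0 .. 15
        k = 1 + i // K_PERIOD        # 1, 1, 1, 1, 2, 2, 2, 2, 3, ...
        # Assumption: every alternate block subsamples, starting with the 2nd
        factor = 2 if i % 2 == 1 else 1
        schedule.append((BASE_FILTERS * k, factor))
    return schedule


def max_pool1d(x, factor):
    """Shortcut subsampling: non-overlapping max pooling along time.

    x has shape (length, channels); trailing samples that do not fill a
    whole pooling window are dropped.
    """
    length = (x.shape[0] // factor) * factor
    return x[:length].reshape(-1, factor, x.shape[1]).max(axis=1)


schedule = block_schedule()
print([n for n, _ in schedule])
# [64, 64, 64, 64, 128, 128, 128, 128, 192, 192, 192, 192, 256, 256, 256, 256]

total = int(np.prod([s for _, s in schedule]))
print(total)  # 256, i.e. 2**8

# Hypothetical input of 4096 samples: one output step per 256 input samples
length = 4096
for _, s in schedule:
    length //= s
print(length)  # 16

# For an even length, the shortcut halves the time axis like the subsampled
# main branch; each output row is the max over 2 consecutive input rows.
x = np.arange(12, dtype=float).reshape(6, 2)
print(max_pool1d(x, 2))
```

When the number of filters grows (for example from 64 to 128), the shortcut must also match the channel count of the main branch; the excerpt does not say how this is done, so the sketch leaves it out.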

In this notebook we provide the code for training neural networks with a structure similar to this one.

  • References
    • Rajpurkar, P., Hannun, A.Y., Haghpanahi, M., Bourn, C., Ng, A.Y., 2017. Cardiologist-Level Arrhythmia Detection with Convolutional Neural Networks. arXiv:1707.01836 [cs].
    • He, K., Zhang, X., Ren, S., Sun, J., 2015. Deep Residual Learning for Image Recognition. arXiv:1512.03385 [cs].
    • LeCun, Y., Bottou, L., Bengio, Y., Haffner, P., 1998. Gradient-based learning applied to document recognition. Proceedings of the IEEE 86, 2278–2324.
    • Krizhevsky, A., Sutskever, I., Hinton, G.E., 2012. ImageNet classification with deep convolutional neural networks, in: Advances in Neural Information Processing Systems. pp. 1097–1105.
    • Simonyan, K., Zisserman, A., 2014. Very Deep Convolutional Networks for Large-Scale Image Recognition. arXiv:1409.1556 [cs].
    • Szegedy, C., Liu, W., Jia, Y., Sermanet, P., Reed, S., Anguelov, D., Erhan, D., Vanhoucke, V., Rabinovich, A., 2015. Going deeper with convolutions, in: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. pp. 1–9.
    • Ioffe, S., Szegedy, C., 2015. Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. arXiv:1502.03167 [cs].
# Make `/` perform true division in Python 2 (no effect in Python 3)
from __future__ import division
# Arrays manipulation
import numpy as np
# For plotting
import matplotlib.pyplot as plt
import plotly
import plotly.offline as py
import plotly.tools as tls
# For building and training the neural network
from keras.models import Model
from keras.layers import (Input, Conv1D, MaxPooling1D, Dropout,
                          BatchNormalization, Dense, Activation, Add, Flatten)
from keras.utils.vis_utils import model_to_dot
from keras.optimizers import Adam
from keras.utils import HDF5Matrix
from keras.callbacks import (ModelCheckpoint, TensorBoard, ReduceLROnPlateau)
# IPython features
from IPython.display import SVG