
VizECGNet: Visual ECG Image Network for Cardiovascular Diseases Classification with Multi-Modal Training and Knowledge Distillation

Abstract

An electrocardiogram (ECG) captures the heart's electrical signal to assess various heart conditions. In practice, ECG data is stored as either digitized signals or printed images. Despite the emergence of numerous deep learning models for digitized signals, many hospitals prefer image storage due to cost considerations. Recognizing that raw ECG signals are unavailable in many clinical settings, we propose VizECGNet, which uses only printed ECG graphics to determine the prognosis of multiple cardiovascular diseases. During training, cross-modal attention modules (CMAM) integrate information from the two modalities, image and signal, while self-modal attention modules (SMAM) capture the inherent long-range dependencies in the ECG data of each modality. Additionally, we apply knowledge distillation to increase the similarity between the two predictions produced by the modality streams. This multi-modal deep learning architecture enables the use of ECG images alone during inference. With image input, VizECGNet achieves improvements of 3.50%, 8.21%, and 7.38% in precision, recall, and F1-Score, respectively, over signal-based ECG classification models.

\star denotes equal contribution. This work was supported in part by the National Research Foundation of Korea (NRF) under Grant NRF-2021R1A2C2010893 and in part by the Institute of Information and Communications Technology Planning & Evaluation (IITP) grant funded by the Korea government (MSIT) (No. RS-2022-00155915, Artificial Intelligence Convergence Innovation Human Resources Development (Inha University)).

Index Terms—  Deep Learning, Signal Processing, Multi-Modality Learning, ECG Classification

1 Introduction

Cardiovascular disease ranks as the second leading cause of death, following cancer. Physicians employ electrocardiograms (ECGs) as a diagnostic tool to monitor the heart’s electrical activity. During an ECG, electrodes are placed on the skin, typically on the chest, arms, and legs, to detect and record the heart’s electrical impulses. Healthcare professionals analyze these signals to diagnose various heart conditions, including arrhythmias, myocardial infarction (heart attack), and heart failure. Despite its utility, manual analysis of ECG signals is complex and prone to human error, potentially leading to the oversight of subtle yet critical diagnostic patterns.

With advancements in deep learning and signal processing, researchers have sought to automate cardiovascular disease diagnosis. For instance, [1] introduced a binary classification model for atrial fibrillation detection, utilizing a 1D convolutional neural network (CNN) on single-lead ECG data. Subsequently, RhythmNet [2] and Stacked CNN-LSTM [3] addressed the periodic nature of normal signals by incorporating recurrent neural networks (RNN and LSTM) to detect abnormalities based on long-term dependencies. Additionally, [4] proposed a method that extracts multi-scale features to classify diseases from abnormal ECG signals.

In clinical practice, physicians generally prefer multi-lead ECG signals, gathered from multiple electrodes, over single-lead ECG signals for diagnostic purposes. Consequently, numerous studies have focused on developing diagnostic tools based on multi-lead signals. For example, [5] employed residual learning and recurrent neural networks to classify seven abnormal signals from 12-lead ECG data. Similarly, [6] classified nine abnormal signals by combining expert-derived features from 12-lead ECG signals with features extracted by a 1D CNN. Addressing the challenge posed by the concurrent manifestation of multiple cardiovascular diseases in a 12-lead ECG signal, 1D RANet [7] tackled a multi-label classification task.

Fig. 1: Overall architecture of the proposed VizECGNet, which mainly comprises CMAM and SMAM. (a) Overall block diagram of our network. (b) Overview of CMAM. (c) Overview of SMAM.

We observed that ECG signal classification models primarily rely on digitized signals for identifying abnormalities. Our primary motivation arises from the unavailability of digitized ECG signals, particularly in smaller clinics, due to factors such as the high maintenance costs of databases and the use of legacy devices. With this motivation, we introduce VizECGNet, a model for classifying cardiovascular diseases using printed ECG graphics (images). We utilize features extracted from both the image and the signal. Recognizing that features from each modality exhibit distinct characteristics, we apply a cross-modal attention module (CMAM) to fuse these features. Additionally, a self-modal attention module (SMAM) refines the extracted features across the heterogeneous modalities, emphasizing the discriminative features crucial for distinguishing normal and abnormal signals. These modality-specific features are then fed into fully-connected layers, with knowledge distillation applied between the two predictions to prevent performance degradation when only images are used during inference. Compared to existing models, our approach demonstrates superior classification performance in precision, recall, and F1-Score on a large-scale ECG dataset. The main contributions of this paper are as follows:

  • We propose a novel ECG multi-label classification model (VizECGNet) based on multi-modal learning and knowledge distillation that exploits the complex time-series characteristics of 12-lead ECG signals while using only images during inference.

  • We also propose two attention modules, CMAM and SMAM, enabling the model to exchange information between modalities and emphasize the discriminative features in each modality stream.

  • We experimentally achieve state-of-the-art performance in precision, recall, and F1-Score compared with signal-, image-, and hybrid-based classification models on a large-scale 12-lead ECG dataset.

2 Method

VizECGNet combines a 1D and a 2D CNN for multi-label classification of 12-lead ECG signals. Our model is a novel structure that couples multi-modal learning with knowledge distillation. For multi-modal learning, we adopt attention mechanisms across modalities and within each modality, called CMAM and SMAM, respectively. The features of each modality are forwarded through fully-connected layers for knowledge distillation. Only ECG images are used to predict cardiovascular diseases during inference. Fig. 1 illustrates the overall structure of VizECGNet.
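For orientation, the sketch below gives a hedged, high-level reading of this two-stream design with placeholder sub-modules (not the authors' implementation): both streams run during training, while inference uses the image stream alone. Inputs are assumed to be flattened feature vectors for brevity.

```python
from typing import Optional

import torch
import torch.nn as nn

class TwoStreamSketch(nn.Module):
    """Illustrative two-stream wiring: train with both modalities, infer with images only."""

    def __init__(self, feat_dim: int = 512, num_classes: int = 6):
        super().__init__()
        # Stand-ins for the 1D/2D CNN extractors, CMAM/SMAM fusion, and MLP heads.
        self.signal_encoder = nn.LazyLinear(feat_dim)
        self.image_encoder = nn.LazyLinear(feat_dim)
        self.fuse = nn.Identity()                    # CMAM + SMAM would refine features here
        self.signal_head = nn.Linear(feat_dim, num_classes)
        self.image_head = nn.Linear(feat_dim, num_classes)

    def forward(self, image: torch.Tensor, signal: Optional[torch.Tensor] = None):
        z_i = self.fuse(self.image_encoder(image))
        p_i = torch.sigmoid(self.image_head(z_i))
        if signal is None:                           # inference: image-only prediction
            return p_i
        z_s = self.fuse(self.signal_encoder(signal))
        p_s = torch.sigmoid(self.signal_head(z_s))
        return p_s, p_i                              # training: both predictions feed the losses
```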

2.1 Cross- and Self-Modal Attention Modules

The main goal of VizECGNet is to capture the correlations between the two modalities and extract discriminative features via the cross- and self-modal attention modules. To achieve this goal, we extract features from the 12-lead ECG signals $\mathbf{X}^{s}=\{\mathbf{x}^{s}_{1},\dots,\mathbf{x}^{s}_{12}\}$ and the image $\mathbf{X}^{i}$ using CNN-based feature extractors $\mathbf{f}$ and $\mathbf{g}$, respectively. For each $l$-th single-lead signal $\mathbf{x}^{s}_{l}=[x^{s}_{l,1},\dots,x^{s}_{l,T}]$ with time length $T$, we use twelve 1D CNN-based feature extractors $\mathbf{f}=\{f_{1},\dots,f_{12}\}$ with shared parameters and fuse the per-lead features as follows:

\mathbf{z}^{s}_{\text{avg}}=\frac{1}{12}\sum_{l=1}^{12}f_{l}(\mathbf{x}^{s}_{l})\in\mathbb{R}^{C^{s}\times T^{s}}. \qquad (1)
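As a concrete illustration of Eq. (1), the sketch below encodes each of the twelve leads with a single weight-shared 1D CNN and averages the resulting feature maps. The backbone layers are placeholders standing in for the paper's extractor $\mathbf{f}$, not its exact architecture.

```python
import torch
import torch.nn as nn

class SharedLeadEncoder(nn.Module):
    """Applies one weight-shared 1D CNN to all 12 leads and averages the features (Eq. 1)."""

    def __init__(self, channels: int = 512):
        super().__init__()
        # Hypothetical shallow 1D CNN standing in for the paper's extractor f.
        self.net = nn.Sequential(
            nn.Conv1d(1, 64, kernel_size=7, stride=2, padding=3), nn.ReLU(),
            nn.Conv1d(64, channels, kernel_size=3, stride=2, padding=1), nn.ReLU(),
        )

    def forward(self, x_leads: torch.Tensor) -> torch.Tensor:
        # x_leads: (B, 12, T). Encode each lead with the same weights, then average.
        feats = [self.net(x_leads[:, l:l + 1, :]) for l in range(12)]  # 12 x (B, C^s, T^s)
        return torch.stack(feats, dim=0).mean(dim=0)                   # z^s_avg: (B, C^s, T^s)
```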

Similar to the ECG signals $\mathbf{X}^{s}$, we extract feature maps from the 12-lead ECG image $\mathbf{X}^{i}$ using a 2D CNN-based feature extractor $\mathbf{g}$ as follows:

\mathbf{z}^{i}=\mathbf{g}(\mathbf{X}^{i})\in\mathbb{R}^{C^{i}\times H^{i}\times W^{i}}, \qquad (2)

where $C^{s}=C^{i}=512$. Subsequently, we apply the cross-modal attention module to the two extracted features $\mathbf{z}^{s}_{\text{avg}}$ and $\mathbf{z}^{i}$. For clarity, assume we attend to modality $m$ conditioned on modality $n$. Then, CMAM can be written as follows:

\bar{\mathbf{z}}^{m\leftarrow n}=\text{Softmax}\left(\mathbf{Q}_{n}\mathbf{K}^{T}_{n}\right)\mathbf{V}_{m} \qquad (3)

where $\mathbf{Q}_{n}=W_{\mathbf{Q}_{n}}\mathbf{z}^{n}$, $\mathbf{K}_{n}=W_{\mathbf{K}_{n}}\mathbf{z}^{n}$, and $\mathbf{V}_{m}=W_{\mathbf{V}_{m}}\mathbf{z}^{m}$. The refined features $\bar{\mathbf{z}}^{s\leftarrow i}$ and $\bar{\mathbf{z}}^{i\leftarrow s}$ thus carry information from the other modality. We then apply SMAM to extract discriminative features from the two information-mixed features in each modality stream as follows:

\begin{cases}\hat{\mathbf{z}}^{s}=\text{Softmax}\left(\mathbf{Q}_{s}\mathbf{K}^{T}_{s}\right)\mathbf{V}_{s}\\ \hat{\mathbf{z}}^{i}=\text{Softmax}\left(\mathbf{Q}_{i}\mathbf{K}^{T}_{i}\right)\mathbf{V}_{i}\end{cases} \qquad (4)

where $\mathbf{X}_{m}=W_{\mathbf{X}_{m}}\bar{\mathbf{z}}^{m\leftarrow n}$ for $m\in\{s,i\}$ and $\mathbf{X}\in\{\mathbf{Q},\mathbf{K},\mathbf{V}\}$. The final refined discriminative features $\hat{\mathbf{z}}^{s}$ and $\hat{\mathbf{z}}^{i}$ are forwarded to the fully-connected layers of each modality stream to classify abnormal ECG signal types.
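Since the two streams have different spatial sizes ($T^{s}$ vs. $H^{i}W^{i}$) but the same channel width ($C^{s}=C^{i}=512$), one dimensionally consistent reading of Eqs. (3) and (4) is attention over the channel dimension. The sketch below follows that reading; it is an assumption for illustration, not the authors' exact implementation, and the layer sizes are placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelAttention(nn.Module):
    """Softmax(Q K^T) V with Q, K projected from one feature and V from another."""

    def __init__(self, channels: int = 512):
        super().__init__()
        self.w_q = nn.Linear(channels, channels, bias=False)
        self.w_k = nn.Linear(channels, channels, bias=False)
        self.w_v = nn.Linear(channels, channels, bias=False)

    def forward(self, z_qk: torch.Tensor, z_v: torch.Tensor) -> torch.Tensor:
        # z_qk: (B, N_qk, C) supplies queries/keys; z_v: (B, N_v, C) supplies values.
        q = self.w_q(z_qk).transpose(1, 2)                 # (B, C, N_qk)
        k = self.w_k(z_qk).transpose(1, 2)                 # (B, C, N_qk)
        v = self.w_v(z_v).transpose(1, 2)                  # (B, C, N_v)
        attn = F.softmax(q @ k.transpose(1, 2), dim=-1)    # (B, C, C) channel attention map
        return (attn @ v).transpose(1, 2)                  # (B, N_v, C)

# CMAM: queries/keys come from the other modality, values from the current one.
cmam_s, cmam_i = ChannelAttention(), ChannelAttention()
# SMAM: plain self-attention on each refined stream.
smam_s, smam_i = ChannelAttention(), ChannelAttention()

B, C = 2, 512
z_sig = torch.randn(B, 128, C)        # flattened signal features, (B, T^s, C)
z_img = torch.randn(B, 7 * 7, C)      # flattened image features, (B, H^i*W^i, C)

z_bar_s = cmam_s(z_img, z_sig)        # Eq. (3): refined signal features using image Q/K
z_bar_i = cmam_i(z_sig, z_img)        # Eq. (3): refined image features using signal Q/K
z_hat_s = smam_s(z_bar_s, z_bar_s)    # Eq. (4): self-attention on the signal stream
z_hat_i = smam_i(z_bar_i, z_bar_i)    # Eq. (4): self-attention on the image stream
```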

2.2 Knowledge Distillation

Knowledge distillation is a technique commonly used in machine learning to transfer knowledge from a complex model (the teacher) to a simpler model (the student). To classify abnormal signals from printed ECG images alone during inference while still acquiring knowledge about the 12-lead ECG signals, we distill knowledge from the signal stream into the image stream. First, the class probability distributions $p^{s}$ and $p^{i}$ are predicted by a classifier for each modality stream as follows:

p^{m}=\text{MLP}_{m}\left(\text{GAP}\left(\hat{\mathbf{z}}^{m}\right)\right) \qquad (5)

where $\text{MLP}_{m}(\cdot)$ is a multi-layer perceptron (MLP) for the signal or image modality stream and GAP denotes global average pooling. For each prediction, the classification loss $\mathcal{L}_{cls}$ with respect to the shared label $t$ is computed as follows:

\mathcal{L}_{cls}=\sum_{c=1}^{C}\left(\mathcal{L}_{BCE}\left(t_{c},p^{s}_{c}\right)+\mathcal{L}_{BCE}\left(t_{c},p^{i}_{c}\right)\right) \qquad (6)

where $C$ is the number of classes and $\mathcal{L}_{BCE}$ is the binary cross-entropy loss. Since a single ECG recording in our dataset can carry multiple diseases, a binary classification is performed for each class. Finally, we compute the knowledge distillation loss $\mathcal{L}_{KD}$ between the two modality streams to reduce the difference between the probability distributions $p^{s}$ and $p^{i}$ as follows:

\mathcal{L}_{KD}(p^{s},p^{i})=\sum_{c=1}^{C}\mathcal{L}_{KL}(p^{s}_{c}\,\|\,p^{i}_{c})=\sum_{c=1}^{C}\sum_{x\in\mathcal{X}}p^{s}_{c}(x)\log\left(\frac{p^{s}_{c}(x)}{p^{i}_{c}(x)}\right) \qquad (7)

where $\mathcal{L}_{KL}$ is the Kullback-Leibler divergence between the two probability distributions $p^{s}_{c}$ and $p^{i}_{c}$ for each class $c$. The final loss $\mathcal{L}_{total}=\lambda_{1}\mathcal{L}_{cls}+\lambda_{2}\mathcal{L}_{KD}$ is used to update the parameters of both modality streams. To reduce hyperparameter sensitivity, we fix $\lambda_{1}=\lambda_{2}=1$.
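For illustration, here is a hedged sketch of Eqs. (5)-(7) with hypothetical names (`ModalityHead`, `vizecg_loss`): each stream applies global average pooling and an MLP with sigmoid outputs, both streams incur a per-class BCE loss, and a class-wise Bernoulli KL term distills the signal prediction into the image prediction. Detaching the signal prediction (treating it as the teacher) is an assumption the paper does not state explicitly, and the hidden size is a placeholder.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ModalityHead(nn.Module):
    """Eq. (5): global average pooling followed by a per-modality MLP (multi-label sigmoid)."""

    def __init__(self, channels: int = 512, num_classes: int = 6):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(channels, 256), nn.ReLU(),
                                 nn.Linear(256, num_classes))

    def forward(self, z_hat: torch.Tensor) -> torch.Tensor:
        pooled = z_hat.mean(dim=1)               # GAP over the N positions of (B, N, C)
        return torch.sigmoid(self.mlp(pooled))   # p^m: one probability per class

def vizecg_loss(p_s: torch.Tensor, p_i: torch.Tensor, target: torch.Tensor,
                eps: float = 1e-7) -> torch.Tensor:
    """Eqs. (6)-(7): BCE for both streams plus class-wise KL(p^s || p^i), lambda_1 = lambda_2 = 1."""
    cls = F.binary_cross_entropy(p_s, target, reduction="sum") \
        + F.binary_cross_entropy(p_i, target, reduction="sum")
    ps = p_s.detach().clamp(eps, 1 - eps)        # signal stream acts as the teacher (assumption)
    pi = p_i.clamp(eps, 1 - eps)
    kd = (ps * (ps / pi).log()
          + (1 - ps) * ((1 - ps) / (1 - pi)).log()).sum()
    return cls + kd
```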

| Modality | Method | Parameters (M) | Speed (ms) | Inference Data Type | Precision | Recall | F1-Score |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Signal | InceptionTime [8] | 0.49 | 15.9 | Signal | 46.13 | 39.98 | 42.84 |
| Signal | XResNet1D-101 [9] | 13.94 | 23.7 | Signal | 49.26 | 42.96 | 45.89 |
| Signal | Transformer [10] | 24.17 | 32.1 | Signal | 51.34 | 41.73 | 46.04 |
| Signal | ACNet [11] | 262.00 | 24.5 | Signal | 53.83 | 45.12 | 49.09 |
| Signal | 1D RANet [7] | 3.93 | 23.5 | Signal | 57.73 | 48.51 | 51.51 |
| Image | ResNet18 [12] | 11.64 | 21.4 | Image | 52.62 | 47.14 | 49.73 |
| Image | MobileNetV3 [13] | 18.20 | 12.8 | Image | 39.56 | 31.83 | 35.28 |
| Signal & Image | MHM [14] | 15.44 | 28.1 | Signal | 44.16 | 39.42 | 41.68 |
| Signal & Image | VizECGNet (Ours) | 11.17 | 37.1 | Signal | 63.20 | 59.11 | 61.09 |
| Signal & Image | MHM [14] | 15.44 | 30.2 | Image | 37.18 | 32.71 | 34.83 |
| Signal & Image | VizECGNet (Ours) | 11.17 | 39.3 | Image | 61.23 | 56.72 | 58.89 |
Table 1: Experimental results on the large-scale ECG dataset. We also report the number of trainable parameters (M) and the inference speed (ms) of each method. Red and blue mark the best and second-best results, respectively.

3 Experimental Results

3.1 Experimental Settings and Implementation Details

We implemented VizECGNet in PyTorch 1.11 and Python 3.8. The large-scale 12-lead ECG dataset [15] used in this paper is multi-label (1dAVb, RBBB, LBBB, SB, AF, ST). The signal of each lead contains 4096 time samples. To remove the trend of each signal, we zero-center the voltage over time and apply a detrending process. Finally, we render the 12-lead ECG signals into images using the ecg_plot Python library for multi-modal learning. We compared VizECGNet with five signal-based models (InceptionTime [8], XResNet1D-101 [9], Transformer [10], ACNet [11], and 1D RANet [7]), two image-based models (ResNet18 [12] and MobileNetV3 [13]), and a multi-modal model (MHM [14]). Since default training settings generally perform poorly on ECG data, for a fair comparison we tuned the hyperparameters of all models to work best on the ECG dataset. We train all models end-to-end using the Adam optimizer. The initial learning rate starts at 10^-3 and is decreased to 10^-6 with the cosine annealing learning rate scheduler [16], using a batch size of 16 and 300 epochs, until the losses of all models converged. For evaluation, we used three metrics (precision, recall, and macro-averaged F1-Score). To efficiently extract features from the signals and images, we use ResNet18 as the feature extractor.
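As a reference for reproduction, a minimal sketch of the optimization setup described above is given below; `model`, `loader`, and `loss_fn` are placeholders for the two-stream network, the ECG data loader, and the combined loss of Eqs. (6) and (7).

```python
import torch
from torch import nn
from torch.utils.data import DataLoader

def train(model: nn.Module, loader: DataLoader, loss_fn, epochs: int = 300) -> None:
    # Adam with cosine annealing from 1e-3 down to 1e-6, matching the settings above.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)
    for _ in range(epochs):
        for signals, images, labels in loader:      # batches of 16 in the paper
            p_s, p_i = model(images, signals)       # two-stream forward pass during training
            loss = loss_fn(p_s, p_i, labels)        # classification + distillation loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()                            # one scheduler step per epoch
```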

3.2 Results Analysis

As shown in Table 1, VizECGNet outperforms the single- and multi-modality models on all evaluation metrics for both inference data types. When predicting from ECG signals, VizECGNet achieves a 9.58% and 19.68% higher F1-Score than 1D RANet and MHM, respectively. Furthermore, when using printed ECG images, VizECGNet achieves a 17.21% and 24.06% higher F1-Score than ResNet18 and MHM, respectively. Note that MobileNetV3 and ResNet18 use only plain ECG images for training and evaluation; this training strategy prevents the two image-based models from capturing the abnormal-signal characteristics of each lead. Meanwhile, although 1D RANet receives the 12-lead ECG signals with their complex structure, its performance remains low, suggesting that it still struggles to interpret such complex data. In contrast, VizECGNet achieves high classification performance because it exchanges information between modalities through multi-modal learning and, during training, distills knowledge about abnormalities arising from subtle differences in the 12-lead signals into the image modality stream.

We also examined generalization to real printed ECG images to confirm the practical utility of the model (Fig. 2), and we include the inference results of the two image-based models (ResNet18 and MobileNetV3). Fig. 2(e) shows each model's predicted probability for each disease (AF, RBBB, LBBB, 1dAVB). Our model predicted positive for all diseases, whereas the other two models predicted negative for all but well-characterized diseases such as 1dAVB. These results suggest that VizECGNet is more practical than the other models because it transfers to real printed images even though it was trained on synthetic ECG images rendered from 12-lead ECG signals.

Fig. 2: Example of real printed ECG images and the prediction results of VizECGNet and the image-based models (ResNet18 and MobileNetV3). (a) AF. (b) RBBB. (c) LBBB. (d) 1dAVB. (e) Prediction probability for each cardiovascular disease. Red, yellow, and green bars denote VizECGNet, ResNet18, and MobileNetV3, respectively.

3.3 Ablation Study

In this section, we analyze the effectiveness of the two attention modules (CMAM and SMAM). Table 2 lists the results of the ablation study over four configurations. Attention is used to extract discriminative features from otherwise entangled representations; in our model, additional useful information is obtained by attending both within each modality and across modalities. The experimental results show that when both attention modules are applied, the model achieves the best scores on all evaluation metrics.

| Settings | Precision | Recall | F1-Score |
| --- | --- | --- | --- |
| VizECGNet (+SMAM+CMAM) | 61.23 | 56.72 | 58.89 |
| w/o SMAM | 59.53 | 49.06 | 53.83 |
| w/o CMAM | 60.12 | 53.72 | 56.70 |
| w/o SMAM & CMAM | 57.85 | 47.90 | 52.37 |

Table 2: Ablation study of the attention modules in VizECGNet on the large-scale ECG dataset. Red and blue mark the best and second-best results, respectively.

4 Conclusion

We propose VizECGNet, which combines multi-modal learning with knowledge distillation to classify abnormal electrical activity in ECG signals. Experimental results on a large-scale ECG dataset demonstrate that VizECGNet outperforms traditional heart disease classification models. The cross- and self-modal attention modules (CMAM and SMAM) allow the model to focus on discriminative features across the two modalities, and knowledge distillation prevents the performance drop when only images are used during inference. These results indicate that the model can be fully exploited in settings, such as developing countries, that only have access to ECG printers and no signal digitization pipeline. To further verify the generalization ability of VizECGNet, we plan to train and evaluate it on various 12-lead datasets, package it as an application, and test it in a real clinical setting.

References

  • [1] Zhaohan Xiong, Martin K Stiles, and Jichao Zhao, “Robust ECG signal classification for detection of atrial fibrillation using a novel neural network,” in 2017 Computing in Cardiology (CinC). IEEE, 2017, pp. 1–4.
  • [2] Zhaohan Xiong, Martyn P Nash, Elizabeth Cheng, Vadim V Fedorov, Martin K Stiles, and Jichao Zhao, “ECG signal classification for the detection of cardiac arrhythmias using a convolutional recurrent neural network,” Physiological Measurement, vol. 39, no. 9, pp. 094006, 2018.
  • [3] Jen Hong Tan, Yuki Hagiwara, Winnie Pang, Ivy Lim, Shu Lih Oh, Muhammad Adam, Ru San Tan, Ming Chen, and U Rajendra Acharya, “Application of stacked convolutional and long short-term memory network for accurate identification of CAD ECG signals,” Computers in Biology and Medicine, vol. 94, pp. 19–26, 2018.
  • [4] Jonathan Rubin, Saman Parvaneh, Asif Rahman, Bryan Conroy, and Saeed Babaeizadeh, “Densely connected convolutional networks for detection of atrial fibrillation from short single-lead ECG recordings,” Journal of Electrocardiology, vol. 51, no. 6, pp. S18–S21, 2018.
  • [5] Yu-Jhen Chen, Chien-Liang Liu, Vincent S Tseng, Yu-Feng Hu, and Shih-Ann Chen, “Large-scale classification of 12-lead ECG with deep learning,” in 2019 IEEE EMBS International Conference on Biomedical & Health Informatics (BHI). IEEE, 2019, pp. 1–4.
  • [6] Zhongdi Liu, Xiang’Ao Meng, Jiajia Cui, Zhipei Huang, and Jiankang Wu, “Automatic identification of abnormalities in 12-lead ECGs using expert features and convolutional neural networks,” in 2018 International Conference on Sensor Networks and Signal Processing (SNSP). IEEE, 2018, pp. 163–167.
  • [7] Yamin Liu, Hanshuang Xie, Qineng Cao, Jiayi Yan, Fan Wu, Huaiyu Zhu, and Yun Pan, “Multi-label classification of multi-lead ECG based on deep 1D convolutional neural networks with residual and attention mechanism,” in 2021 Computing in Cardiology (CinC). IEEE, 2021, vol. 48, pp. 1–4.
  • [8] Hassan Ismail Fawaz, Benjamin Lucas, Germain Forestier, Charlotte Pelletier, Daniel F Schmidt, Jonathan Weber, Geoffrey I Webb, Lhassane Idoumghar, Pierre-Alain Muller, and François Petitjean, “InceptionTime: Finding AlexNet for time series classification,” Data Mining and Knowledge Discovery, vol. 34, no. 6, pp. 1936–1962, 2020.
  • [9] Tong He, Zhi Zhang, Hang Zhang, Zhongyue Zhang, Junyuan Xie, and Mu Li, “Bag of tricks for image classification with convolutional neural networks,” in Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2019, pp. 558–567.
  • [10] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin, “Attention is all you need,” Advances in Neural Information Processing Systems, vol. 30, 2017.
  • [11] Xiaohan Ding, Yuchen Guo, Guiguang Ding, and Jungong Han, “ACNet: Strengthening the kernel skeletons for powerful CNN via asymmetric convolution blocks,” in Proceedings of the IEEE/CVF International Conference on Computer Vision, 2019, pp. 1911–1920.
  • [12] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, “Deep residual learning for image recognition,” in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 2016, pp. 770–778.
  • [13] Andrew Howard, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang, Yukun Zhu, Ruoming Pang, Vijay Vasudevan, et al., “Searching for MobileNetV3,” in Proceedings of the IEEE/CVF International Conference on Computer Vision, 2019, pp. 1314–1324.
  • [14] Zhi Qiao, Zhen Zhang, Xian Wu, Shen Ge, and Wei Fan, “MHM: Multi-modal clinical data based hierarchical multi-label diagnosis prediction,” in Proceedings of the 43rd International ACM SIGIR Conference on Research and Development in Information Retrieval, 2020, pp. 1841–1844.
  • [15] AH Ribeiro, GM Paixao, EM Lima, et al., “CODE-15%: A large scale annotated dataset of 12-lead ECGs,” Zenodo, Jun. 2021.
  • [16] Ilya Loshchilov and Frank Hutter, “SGDR: Stochastic gradient descent with warm restarts,” arXiv preprint arXiv:1608.03983, 2016.