Differentiable Projection from Optical Coherence Tomography B-Scan without Retinal Layer Segmentation Supervision
Abstract
Projection maps (PMs) computed from optical coherence tomography (OCT) B-scans are an important tool for diagnosing retinal diseases, and obtaining them typically requires retinal layer segmentation. In this study, we present a novel end-to-end framework to predict PMs from B-scans. Instead of segmenting retinal layers explicitly, we represent them implicitly as predicted coordinates. By pixel interpolation on uniformly sampled coordinates between retinal layers, the corresponding PMs can be easily obtained with pooling. Notably, all the operators are differentiable; therefore, this Differentiable Projection Module (DPM) enables end-to-end training with the ground truth of PMs rather than retinal layer segmentation. Our framework produces high-quality PMs, significantly outperforming baselines, including a vanilla CNN without DPM and an optimization-based DPM without a deep prior. Furthermore, the proposed DPM, as a novel neural representation of areas/volumes between curves/surfaces, could be of independent interest for geometric deep learning.
Index Terms— OCT, Retinal, B-Scan, Differentiable Projection, Shape Modeling, Geometric Deep Learning.
1 Introduction
Optical coherence tomography (OCT) is an important modality in retinal imaging thanks to its high resolution and non-invasiveness in 3D [1]. 2D projection maps (PMs) between retinal layers from OCT B-scans are popular in retinal disease diagnosis, as they provide information on retinal pathology invisible in conventional fundus images [2, 3]. Generating these 2D projection maps requires retinal layer segmentation on each B-scan slice, followed by aggregating (e.g., averaging or taking the maximum of) pixels between certain layers. Traditional algorithms for retinal layer segmentation are typically based on prior knowledge of retinal layer structures, e.g., boundary tracking [4], adaptive thresholding [5], gradient information in dual scales [6], and texture and shape analysis [7]. However, hand-crafted algorithms are hard to generalize to the real world. Data-driven deep learning has been dominating medical image analysis [8, 9], and researchers have developed deep learning-based methods for retinal layer segmentation [10, 11] with proven superiority over traditional algorithms. Nevertheless, the performance of deep learning approaches is built upon numerous retinal layer segmentation labels, which are especially tedious to annotate for the hundreds of slices in OCT B-scans.
In this study, we present an alternative strategy to obtain projection maps from OCT B-scans, WITHOUT explicit supervision of retinal layer segmentation. Instead, trained with pairs of OCT B-scans and the corresponding PMs, our end-to-end framework directly outputs the final target PMs. Although the PM ground truth is produced with retinal layer segmentation, paired B-scans and PMs could be easier to collect retrospectively as they are more likely to be stored in picture archiving and communication systems (PACS). As an example, the public OCTA-500 dataset [12] used in this study (Fig. 2 and Sec. 2.1) provides paired B-scans and PMs rather than retinal layer segmentation.
To our knowledge, this is the first study to predict projection maps from OCT B-scans in an end-to-end fashion. Unfortunately, a vanilla CNN without any geometric prior could only produce low-quality PMs. Therefore, we design a novel Differentiable Projection Module (DPM) to simulate the procedure of projection-map generation, inspired by spatial transformer networks [13]. As illustrated in Fig. 1, instead of segmenting retinal layers explicitly, our CNN backbone predicts them implicitly as coordinates of curves in 2D views or surfaces in 3D, which are converted into areas/volumes by uniform point sampling between layers. These samples are interpolated into pixel values from the source B-scans and finally projected into PMs via (average/max) pooling. All the above operations are differentiable and can be seamlessly integrated into neural networks, making the model end-to-end trainable with the supervision of PM ground truth. The proposed method produces high-quality PMs with proven superiority over baselines (Sec. 3). Independent of the clinical use cases, the introduced neural representation of areas/volumes between curves/surfaces could be a technical contribution in shape modeling and geometric deep learning. To facilitate open research, our code is open-sourced on GitHub: https://github.com/dingdingdingyi/projection-from-OCT.
2 Materials and Methods
2.1 Dataset Overview
We conducted experiments on the public OCTA-500 [12] dataset, specifically the OCTA_6M subset, which contains 300 subjects with a 6 mm × 6 mm FOV. As demonstrated in Fig. 2, each sample consists of paired OCT/OCTA B-scans and projection maps, where B2/B3 are projected from OCT and B5/B6 from OCTA. Each slice of the B-scans corresponds to a column in the projection maps. B2 and B3 are the focus of this study. B2, projected by averaging pixels between the ILM and OPL layers (i.e., the red and green curves in Fig. 4), shows the vessels in the inner retina with high reflection; B3, projected by averaging pixels between the OPL and BM layers (i.e., the green and blue curves in Fig. 4), shows the vessel shadows in the outer retina with low reflection. We use the official split, with 180/20/100 subjects (72,000/8,000/40,000 B-scans) for training, validation and evaluation, respectively.
2.2 Method Overview
As depicted in Fig. 1, we first use a 2D U-Net [14] at the B-scan level to output the coordinates of retinal layers, which are then processed by the proposed Differentiable Projection Module (DPM) to generate the projection maps. As the ground truth of each target PM is min-max normalized over the whole PM, we design a Conditional Min-Max (CMM) normalization trick during training to rescale the intensity values. The model parameters are trained via loss back-propagation between the prediction and the ground truth of the PMs.
2.3 U-Net Backbone
Given an input B-scan of size $H \times W$, we use a downsampled image as network input to reduce memory. The 2D U-Net backbone produces a feature map, to which we apply a convolution with 3 output channels followed by a horizontal bilinear upsampling layer. The output, representing the predicted coordinates of the 3 retinal layers (i.e., ILM, OPL and BM), is mapped into the valid range of vertical coordinates.
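To make this concrete, the snippet below is a minimal PyTorch sketch of such a coordinate head; the class name, the 1×1 kernel, the vertical mean pooling, and the sigmoid-based rescaling to pixel rows are our assumptions for illustration rather than the exact head used here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LayerCoordinateHead(nn.Module):
    """Hypothetical head mapping U-Net features to per-column layer coordinates."""

    def __init__(self, in_channels: int, image_height: int, out_width: int, num_layers: int = 3):
        super().__init__()
        self.image_height = image_height
        self.out_width = out_width
        # One output channel per retinal layer (ILM, OPL, BM).
        self.conv = nn.Conv2d(in_channels, num_layers, kernel_size=1)

    def forward(self, feat: torch.Tensor) -> torch.Tensor:
        # feat: (B, C, h, w) feature map from the U-Net backbone.
        x = self.conv(feat)                                  # (B, 3, h, w)
        x = x.mean(dim=2)                                    # collapse the vertical axis -> (B, 3, w)
        # Horizontal upsampling back to the full B-scan width.
        x = F.interpolate(x, size=self.out_width, mode="linear", align_corners=True)
        # Map into the valid range of vertical pixel coordinates.
        return torch.sigmoid(x) * (self.image_height - 1)    # (B, 3, W)
```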
2.4 Differentiable Projection Module (DPM)
For B2/B3, the projection maps are generated by averaging pixels between 2 layers (Sec. 2.1). For simplicity, we describe the Differentiable Projection Module (DPM) for 2 layers. As shown in Fig. 3, we uniformly sample $N$ points between the 2 layers (including the end points), which represent the vertical coordinates of the sampled area. We then assign to each vertical coordinate the abscissa of its column, composing a spatial position grid of size $N \times W$. For each coordinate, we bilinearly interpolate the pixel value from the B-scan. This pixel interpolation operation, introduced by spatial transformer networks [13], can be easily implemented in modern deep learning frameworks, e.g., in PyTorch [15]. The sampled output of size $N \times W$ represents the warped image of pixels between the 2 layers; it is average pooled vertically into a $1 \times W$ vector, which forms one column of the target projection map.
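The whole module can be expressed with a few differentiable tensor operations on top of PyTorch's `grid_sample`, which provides the bilinear sampling of [13]. The sketch below is a minimal illustration under our assumptions about tensor shapes; the function name and the hyper-parameter `num_samples` (the number of sampled points $N$) are ours.

```python
import torch
import torch.nn.functional as F

def differentiable_projection(bscan, top, bottom, num_samples=32):
    """Differentiable Projection Module (sketch).

    bscan:  (B, 1, H, W) source B-scan slice.
    top:    (B, W) predicted vertical coordinates of the upper layer (e.g., ILM), in pixels.
    bottom: (B, W) predicted vertical coordinates of the lower layer (e.g., OPL), in pixels.
    Returns a (B, W) projection column: pixels sampled between the two layers,
    averaged vertically. `num_samples` is an assumed hyper-parameter.
    """
    B, _, H, W = bscan.shape
    # Uniformly sample N vertical positions between the two layers (endpoints included).
    t = torch.linspace(0.0, 1.0, num_samples, device=bscan.device).view(1, num_samples, 1)
    ys = top.unsqueeze(1) + t * (bottom - top).unsqueeze(1)                    # (B, N, W)
    # Each sampled vertical coordinate keeps the abscissa of its own column.
    xs = torch.arange(W, device=bscan.device, dtype=bscan.dtype).view(1, 1, W).expand(B, num_samples, W)
    # Normalize coordinates to [-1, 1] for grid_sample (x first, then y).
    grid = torch.stack([xs / (W - 1) * 2 - 1, ys / (H - 1) * 2 - 1], dim=-1)   # (B, N, W, 2)
    warped = F.grid_sample(bscan, grid, mode="bilinear", align_corners=True)   # (B, 1, N, W)
    return warped.mean(dim=2).squeeze(1)                                       # (B, W)
```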
2.5 Conditional Min-Max Normalization Trick (CMM)
Our model produces a target PM column-by-column (i.e., slice-by-slice in B-scans); however, the PM ground truth provided in the OCTA-500 dataset [12] is min-max normalized over the whole PM. The accurate min-max values of the prediction can only be obtained once all slices are available, which is infeasible during training due to memory constraints. Inspired by the auto-decoding (NOT auto-encoding) training technique [16], we propose a Conditional Min-Max normalization trick to rescale the intensity of the projection maps:
$$\mathrm{CMM}(\hat{y},\, s) = \frac{\hat{y} - m_{s}}{M_{s} - m_{s}} \tag{1}$$
where $M_{s}$ and $m_{s}$ are learnable parameters for the maximum and minimum values of the projected B-scans, conditional on the training subject ID $s$; specifically, one pair of such parameters is learned for each of the 180 training subjects in our study.
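Since the learnable maximum/minimum pair is selected by the training subject ID, in the spirit of auto-decoding [16], CMM reduces to indexing two small parameter vectors. A minimal sketch, with assumed initialization and numerical clamping, could look as follows.

```python
import torch
import torch.nn as nn

class ConditionalMinMax(nn.Module):
    """Conditional Min-Max normalization (sketch, auto-decoder style).

    One learnable (max, min) pair per training subject; the subject ID selects
    which pair rescales the pre-normalized projection column. The initial
    values and the clamping of (M - m) are assumptions.
    """
    def __init__(self, num_subjects: int):
        super().__init__()
        self.max_vals = nn.Parameter(torch.ones(num_subjects))
        self.min_vals = nn.Parameter(torch.zeros(num_subjects))

    def forward(self, y_pred: torch.Tensor, subject_id: torch.Tensor) -> torch.Tensor:
        # y_pred: (B, W) pre-normalized projection columns; subject_id: (B,) long tensor.
        m = self.min_vals[subject_id].unsqueeze(-1)
        M = self.max_vals[subject_id].unsqueeze(-1)
        return (y_pred - m) / (M - m).clamp(min=1e-6)
```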
2.6 Training and Inference
The loss function of our model is composed of 2 terms: an L-1 loss $\mathcal{L}_{1}$ and a feature loss $\mathcal{L}_{feat}$. The $\mathcal{L}_{1}$ term measures pixel-wise similarity, while $\mathcal{L}_{feat}$ measures structural similarity by comparing features extracted with a pretrained network [17]. The projection loss for B2 is defined as
$$\mathcal{L}_{B2} = \mathcal{L}_{1}\!\left(\mathrm{CMM}(\hat{y}_{B2},\, s),\, y_{B2}\right) + \mathcal{L}_{feat}\!\left(\mathrm{CMM}(\hat{y}_{B2},\, s),\, y_{B2}\right) \tag{2}$$
where $y_{B2}$ is the B2 ground truth and $\hat{y}_{B2}$ is the pre-normalized prediction. When training the B2 and B3 projection maps simultaneously, the final loss is
$$\mathcal{L} = \mathcal{L}_{B2} + \alpha\, \mathcal{L}_{B3} \tag{3}$$
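As an illustration, a possible PyTorch implementation of the combined per-PM loss is sketched below; the choice of truncated VGG-16 features for $\mathcal{L}_{feat}$ and the equal weighting of the two terms are assumptions in the spirit of [17], not necessarily our exact configuration.

```python
import torch
import torch.nn as nn
import torchvision.models as models

class ProjectionLoss(nn.Module):
    """L1 + perceptual feature loss (sketch).

    The VGG-16 feature layer and equal term weights below are assumptions;
    the pretrained feature network follows the perceptual-loss idea of [17].
    """
    def __init__(self):
        super().__init__()
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features[:16].eval()
        for p in vgg.parameters():
            p.requires_grad_(False)
        self.feat_net = vgg
        self.l1 = nn.L1Loss()

    def forward(self, pred_pm: torch.Tensor, gt_pm: torch.Tensor) -> torch.Tensor:
        # pred_pm, gt_pm: (B, 1, H, W) (partial) projection maps in [0, 1].
        pixel_term = self.l1(pred_pm, gt_pm)
        # Repeat the single channel to 3 channels for the ImageNet-pretrained network.
        f_pred = self.feat_net(pred_pm.repeat(1, 3, 1, 1))
        f_gt = self.feat_net(gt_pm.repeat(1, 3, 1, 1))
        feat_term = self.l1(f_pred, f_gt)
        return pixel_term + feat_term

# Final loss when training B2 and B3 jointly (Eq. 3):
#   loss = proj_loss(pred_b2, gt_b2) + alpha * proj_loss(pred_b3, gt_b3)
```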
The normalization trick is only used during training. During inference, we predict the pre-normalized outputs slice-by-slice, and then normalize them using the real min-max values of the predicted PMs.
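In code, this inference-time rescaling amounts to a plain min-max normalization over the assembled prediction (variable names below are illustrative):

```python
import torch

# columns: list of (W,) pre-normalized projection columns, one per B-scan slice.
pm = torch.stack(columns, dim=0)                      # (num_slices, W) predicted PM
pm = (pm - pm.min()) / (pm.max() - pm.min() + 1e-8)   # real min-max normalization over the whole PM
```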
3 Experiments
3.1 Experiment Settings and Details
To validate the effectiveness of the proposed methods, we design two baselines: 1) CNN only: a vanilla CNN without DPM; 2) DPM only: an optimization-based DPM without the CNN part as a deep prior, i.e., optimizing the layer coordinates directly on each data pair. We use PSNR and SSIM to assess the predicted projection maps, where PSNR focuses on pixel similarity and SSIM focuses on structural similarity.
Our framework is implemented in PyTorch [15]. All networks were trained on 4 NVIDIA RTX 2080 Ti GPUs with an Adam optimizer [18] using a batch size of 72. The learning rate starts from its initial value and decays exponentially by a fixed ratio after every epoch. The balancing weight $\alpha$ in Eq. 3 between B2 and B3 is chosen with a simple grid search.
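For reference, this training setup roughly corresponds to the following PyTorch skeleton; the initial learning rate and decay ratio shown are placeholders, and `model`, `train_loader`, `proj_loss` and `alpha` stand for the components described above.

```python
import torch

# Illustrative values only: the actual initial learning rate and decay ratio
# are tuned separately; `model` is the U-Net + DPM + CMM pipeline,
# `proj_loss` the Eq. 2 loss, and `alpha` the Eq. 3 balancing weight.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

for epoch in range(num_epochs):
    for bscan, gt_b2, gt_b3, subject_id in train_loader:   # batch size 72 in our setting
        pred_b2, pred_b3 = model(bscan, subject_id)         # CMM-normalized PM columns
        loss = proj_loss(pred_b2, gt_b2) + alpha * proj_loss(pred_b3, gt_b3)  # Eq. 3
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()                                        # exponential decay after every epoch
```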
Table 1. Quantitative comparison with baselines and the ablation without CMM (PSNR/SSIM on B2 and B3).

| Model | PSNR (B2) | SSIM (B2) | PSNR (B3) | SSIM (B3) |
|---|---|---|---|---|
| CNN only | 28.0424 | 0.2869 | 27.9896 | 0.2332 |
| DPM only | 27.8129 | 0.0663 | 27.8800 | 0.0683 |
| CNN + DPM | 28.7781 | 0.7575 | 28.7288 | 0.7195 |
| w/o CMM | 28.2814 | 0.5758 | 28.3460 | 0.6974 |
3.2 Quantitative Results
We quantitatively compare the performance of our method against the CNN-only and DPM-only methods to assess the effect of each component of our model, using PSNR and SSIM for evaluation. Table 1 gives a quantitative comparison of the results. From Table 1, we observe that the DPM-only method has the poorest performance, as it has no component to understand the input B-scan globally and blindly adjusts the layer coordinates on the basis of the target projection map. The CNN-only method also performs poorly; we assume this is because it predicts the projection directly with a CNN, ignoring the layered structure required to obtain the target projection map.
In addition, we assess the contribution of CMM. The comparison demonstrates that adding the CMM module is beneficial for learning layered structural information. However, the improvement in SSIM on B3 is not obvious. We assume the reason is that B3 displays the vessel shadows in the outer retina with low reflection, where vessel information dominates and tissue edges are distinct, so the accuracy of predicting the ILM and BM layers has only a limited impact on the B3 projection.
3.3 Qualitative Results
In Fig. 4, we show projection maps obtained from the CNN-only model, the DPM-only model, our method (CNN+DPM) and the variant without CMM, together with the layers predicted by our method. The CNN-only model can reproduce most blood vessels successfully but suffers from serious distortion in details. The DPM-only model totally fails to produce a projection map: the image is full of noise because the predicted layers cross each other, although the contours of some thick blood vessels are visible. From the last three images, we can clearly see that adding the CMM module helps produce a clearer projection map and more precise layers.
3.4 Robustness on Extreme Cases
In Fig. 5 (a), we illustrate two cases with lesions, in which predicting layers is difficult. In these cases, the layer boundaries are blurred by the presence of the diseased areas, which complicates the segmentation task. Our model learns the structural features of the diseased areas and obtains acceptable results.
In Fig. 5 (b), we illustrate two typical failure cases. In the first example, our model fails to correctly segment the bottom layer of the B-scan, resulting in serious distortion of the bottom part of the predicted projection map. The reason is that the bottom-right part of this poor-quality image lacks structural information. In the second example, the layers predicted by our model cannot fit the target layers in the B-scan well, leading to a poor projection. In this B-scan the retina appears unusually steep, which is rare in the used dataset; consequently, our model, trained mostly on flat B-scans, tends to predict smoother layers.
3.5 Transferring from OCT to OCTA
In this section, we transfer the predicted layers to the OCTA B-scans. OCTA is a non-invasive technique that shows details of blood vessels that have low inherent contrast in OCT images. Given that OCTA images are generated from the OCT images, we directly transfer the layers predicted on OCT to OCTA, without retraining the model on OCTA, and generate good-quality projection maps. In Fig. 6, we illustrate two cases to show the results of this transfer. Fig. 6 (b) shows the projection results on OCTA; the synthesized projection map reproduces the shape of the blood vessels well and is structurally coherent with its corresponding ground truth.
Fig. 5. (a) Success on extreme cases; (b) failure on extreme cases.
Fig. 6. (a) Prediction on OCT; (b) transfer to OCTA.
4 Conclusion
In this study, we present a novel end-to-end framework to implicitly learn the layer boundaries of an OCT B-scan from its projection map. Integrated with a 2D U-Net, the proposed end-to-end trainable CNN-DPM encourages the model to learn the structural features of the input B-scan and to predict layers supervised only by the final projection maps. The qualitative and quantitative results demonstrate that the projection maps generated by our framework have similar texture to the real projection maps, indicating that the layers are accurately predicted. However, we predict the layers of each B-scan of a subject independently, ignoring the structural continuity between slices, which leads to streak distortions in the generated projections. In future work, we will improve our network to perform segmentation on the 3D OCT volume.
5 Compliance with Ethical Standards
This research study was conducted retrospectively using open-source human subject data. Ethical approval was not required, as confirmed by the license.
6 Acknowledgement
This work was supported by the National Natural Science Foundation of China (U20B2072, 61976137). This work was also partially supported by Grant YG2021ZD18 from Shanghai Jiaotong University Medical Engineering Cross Research.
References
- [1] David Huang, Eric A Swanson, et al., “Optical coherence tomography,” Science, vol. 254, no. 5035, pp. 1178–1181, 1991.
- [2] Meindert Niemeijer, Mona K Garvin, et al., “Vessel segmentation in 3D spectral OCT scans of the retina,” in Medical Imaging 2008: Image Processing. International Society for Optics and Photonics, 2008, vol. 6914, p. 69141R.
- [3] Qiang Chen, Sijie Niu, Songtao Yuan, Wen Fan, and Qinghuai Liu, “High–low reflectivity enhancement based retinal vessel projection for SD-OCT images,” Medical Physics, vol. 43, no. 10, pp. 5464–5474, 2016.
- [4] Andrew Lang, Aaron Carass, et al., “Retinal layer segmentation of macular OCT images using boundary classification,” Biomedical Optics Express, vol. 4, no. 7, pp. 1133–1152, 2013.
- [5] Hiroshi Ishikawa, Daniel M Stein, et al., “Macular segmentation with optical coherence tomography,” Investigative Ophthalmology & Visual Science, vol. 46, no. 6, pp. 2012–2017, 2005.
- [6] Qi Yang, Charles A Reisman, et al., “Automated layer segmentation of macular OCT images using dual-scale gradient information,” Optics Express, vol. 18, no. 20, pp. 21293–21307, 2010.
- [7] Vedran Kajić, Boris Považay, et al., “Robust segmentation of intraretinal layers in the normal human fovea using a novel statistical model based on texture and shape analysis,” Optics Express, vol. 18, no. 14, pp. 14730–14744, 2010.
- [8] Dinggang Shen, Guorong Wu, and Heung-Il Suk, “Deep learning in medical image analysis,” Annual Review of Biomedical Engineering, vol. 19, pp. 221–248, 2017.
- [9] Geert Litjens, Thijs Kooi, et al., “A survey on deep learning in medical image analysis,” Medical Image Analysis, vol. 42, pp. 60–88, 2017.
- [10] Abhay Shah, Leixin Zhou, Michael D Abrámoff, and Xiaodong Wu, “Multiple surface segmentation using convolution neural nets: application to retinal layer segmentation in OCT images,” Biomedical Optics Express, vol. 9, no. 9, pp. 4509–4526, 2018.
- [11] Leyuan Fang, David Cunefare, et al., “Automatic segmentation of nine retinal layer boundaries in OCT images of non-exudative AMD patients using deep learning and graph search,” Biomedical Optics Express, vol. 8, no. 5, pp. 2732–2744, 2017.
- [12] Mingchao Li, Yuhan Zhang, et al., “IPN-V2 and OCTA-500: Methodology and dataset for retinal image segmentation,” arXiv preprint arXiv:2012.07261, 2020.
- [13] Max Jaderberg, Karen Simonyan, Andrew Zisserman, et al., “Spatial transformer networks,” Advances in Neural Information Processing Systems, vol. 28, pp. 2017–2025, 2015.
- [14] Olaf Ronneberger, Philipp Fischer, and Thomas Brox, “U-Net: Convolutional networks for biomedical image segmentation,” in International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, 2015, pp. 234–241.
- [15] Adam Paszke, Sam Gross, et al., “PyTorch: An imperative style, high-performance deep learning library,” Advances in Neural Information Processing Systems, vol. 32, pp. 8026–8037, 2019.
- [16] Jeong Joon Park, Peter Florence, et al., “DeepSDF: Learning continuous signed distance functions for shape representation,” in Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2019, pp. 165–174.
- [17] Justin Johnson, Alexandre Alahi, and Li Fei-Fei, “Perceptual losses for real-time style transfer and super-resolution,” in European Conference on Computer Vision. Springer, 2016, pp. 694–711.
- [18] Diederik P Kingma and Jimmy Ba, “Adam: A method for stochastic optimization,” arXiv preprint arXiv:1412.6980, 2014.