One Network to Solve Them All: A Sequential Multi-Task Joint Learning Network Framework for MR Imaging Pipeline
Abstract
In conventional magnetic resonance imaging (MRI) workflows, acquisition, reconstruction, and segmentation are usually processed independently. These tasks are closely related, and treating them separately artificially severs their connections, which may discard clinically important information for the final diagnosis. To exploit these relations for further performance improvement, a sequential multi-task joint learning network model is proposed to train a combined end-to-end pipeline in a differentiable way, aiming to explore the mutual influence among these tasks simultaneously. Our design consists of three cascaded modules: 1) a deep sampling pattern learning module that optimizes the k-space sampling pattern at a predetermined sampling rate; 2) a deep reconstruction module dedicated to reconstructing MR images from the undersampled data acquired with the learned sampling pattern; and 3) a deep segmentation module that encodes the MR images reconstructed by the previous module to segment the tissues of interest. The proposed model captures the latent interactive and cyclic relations among these tasks, so that each task benefits from the others. The proposed framework is verified on the MRB dataset, where it outperforms other state-of-the-art (SOTA) methods in terms of both reconstruction and segmentation. The code is available online: https://github.com/Deep-Imaging-Group/SemuNet
Keywords: Fast MRI, Deep Learning, Sampling Learning, Image Reconstruction, Segmentation
1 Introduction
Magnetic resonance imaging (MRI) is a non-invasive diagnostic imaging technique that enables the study of low-contrast soft tissue structures without harmful radiation risk. However, its long acquisition time leads to increased costs, patient discomfort, and motion artifacts. To overcome these obstacles, fast MRI acquisition is urgently needed. Nevertheless, simply reducing the sampling rate degrades the imaging quality and jeopardizes the subsequent diagnosis. In the past decades, numerous efforts have been made to recover high-quality MR images from undersampled k-space data, e.g., compressed sensing (CS) and, later, deep learning based methods. Despite the fruitful results obtained, two defects remain: 1) current undersampling patterns are empirically hand-tailored, e.g., radial, Cartesian, or Gaussian, which ignores the fact that different images may be suited to different undersampling patterns; 2) a suboptimal sampling pattern leads to suboptimal reconstruction and ultimately affects the subsequent analysis task. In summary, handling the main steps of the imaging pipeline in isolation means that both radiologists and computer-aided intervention systems may be working with suboptimal reconstructed images.
Recently, in the field of signal processing, task-driven methods, which directly train an end-to-end network and neglect the explicit intermediate result, have drawn increasing attention. For example, Bojarski et al. trained a self-driving network that learns commands directly from cameras without recognizing any landmarks [1]. Liu et al. proposed to integrate a denoising network with a segmentation network to improve the denoising performance for segmentation [2]. Encouraged by these promising results, similar ideas were introduced into the field of medical imaging. Wu et al. and Lee et al. proposed to detect pulmonary nodules and intracranial hemorrhage, respectively, directly from the measured data without the step of image reconstruction [3, 4]. In [5, 6], the authors coupled MRI reconstruction with segmentation to improve the performance of both tasks. On the other hand, some recent studies attempted to optimize the k-space sampling patterns in a data-driven manner [7, 8, 9], and the optimized undersampling patterns show significant improvements compared to empirical ones. Unfortunately, these trajectories are learned only from the reconstruction stage and ignore the tissue of interest. Moreover, the methods mentioned above either combine sampling and reconstruction, or join the tasks of reconstruction and segmentation. None of the existing works considers the whole pipeline of medical image analysis, which means that information useful for the final segmentation may be lost at each step.
To fully explore the mutual influence among sequential tasks and further improve the performance of each task simultaneously, in this study, we propose a sequential multi-task joint learning network framework (SemuNet), which jointly optimizes the sampling, reconstruction and segmentation in an end-to-end manner. The proposed framework can be divided into three modules: the sampling pattern learning network (SampNet), the reconstruction network (ReconNet), and the segmentation network (SegNet). Specifically, the well-known U-Net [10] is adopted as the backbone of our proposed ReconNet and SegNet for simplicity, and a probabilistic sampling network is proposed to learn the sampling pattern.
The remainder of this paper is organized as follows. The details of the proposed model, including each module, are elaborated in Section 2. The experimental results are presented and discussed in Section 3 and the final section concludes this paper.
2 Method
In this section, the main modules of the proposed framework SemuNet, including SampNet, ReconNet and SegNet, are first described sequentially in detail. Then other issues of the SemuNet, especially about the training strategy and loss function, are presented.
2.1 SampNet: the sampling pattern learning network
For the problem of CS-MR imaging, the task is to reconstruct an MR image $\hat{x}$ from undersampled measurements $y$ in k-space, such that $\hat{x}$ approximates a fully-sampled MR image $x$. Let $S_\theta$ denote the SampNet parameterized by $\theta$, which outputs a continuous-valued matrix $P$ (i.e., the sampling pattern) used as a partial observation in k-space. The undersampling process can be written as $y = P \odot F x$, where $\odot$ is the Hadamard product and $F$ is the Fourier transform matrix. The goal of SampNet is to optimize the sampling pattern in k-space for a specific dataset. To learn a probabilistic observation matrix in k-space, we adopt an architecture similar to those in [8, 10, 12] for our SampNet. The architecture of SampNet is shown in Fig. 1a. The details of SampNet are given in the supplementary material.
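The undersampling step described above (a Hadamard product between a sampling pattern and the Fourier transform of the image) can be illustrated with a minimal sketch. This is a toy under stated assumptions, not the SampNet implementation: it uses a 1D signal and a naive DFT instead of 2D k-space, a hand-picked binary pattern instead of a learned one, and the helper names `dft` and `undersample` are hypothetical.

```python
import cmath

def dft(signal, inverse=False):
    """Naive O(n^2) discrete Fourier transform; inverse transform if requested."""
    n = len(signal)
    sign = 1 if inverse else -1
    out = []
    for k in range(n):
        acc = sum(signal[t] * cmath.exp(sign * 2j * cmath.pi * k * t / n)
                  for t in range(n))
        out.append(acc / n if inverse else acc)
    return out

def undersample(x, pattern):
    """Element-wise (Hadamard) product of a sampling pattern with the spectrum of x."""
    spectrum = dft(x)
    return [p * s for p, s in zip(pattern, spectrum)]

# Toy "fully-sampled" 1D signal (assumption: stands in for an MR image).
x = [1.0, 2.0, 0.0, -1.0, 3.0, 0.5, -2.0, 1.5]

# Keeping every frequency recovers the signal exactly (up to rounding).
full = undersample(x, [1] * len(x))
recovered = [v.real for v in dft(full, inverse=True)]

# Keep only the DC term and the lowest frequency pair (3 of 8 samples).
mask = [1, 1, 0, 0, 0, 0, 0, 1]
y = undersample(x, mask)

# Zero-filled reconstruction: the kind of aliased input a reconstruction network would refine.
zero_filled = [v.real for v in dft(y, inverse=True)]
```

Because the mask keeps the DC component, the zero-filled result preserves the sum of the signal while losing its high-frequency detail.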
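Since we do not have labels for sampling pattern learning, we propose to merge the SampNet into the ReconNet and SegNet. When the cascaded network converges, the top-$n$ largest values in the continuous-valued sampling matrix $P$ output by SampNet are replaced by Boolean values to produce the final sampling pattern $T$, where $n$ is chosen according to the predetermined sampling rate $\alpha$, with $\alpha = n / (HW)$ for a k-space matrix of size $H \times W$. Accordingly, the Booleanizing operation can be defined as:
$$T_{ij} = \begin{cases} 1, & \text{if } P_{ij} \text{ is among the top-}n \text{ largest values of } P, \\ 0, & \text{otherwise.} \end{cases} \tag{1}$$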
As a result, the pattern is optimized by the knowledge of both high-quality reconstructed images and accurate segmentation labels.
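The Booleanizing step can be sketched as a top-$n$ selection over a score matrix. The snippet below is a minimal illustration, assuming the sampling rate is the fraction of retained k-space entries; the function name `binarize_top_n` and the example scores are hypothetical, and ties are broken by sort order, which the text does not specify.

```python
def binarize_top_n(prob, rate):
    """Return a 0/1 mask that keeps the n largest entries of a 2D score matrix."""
    rows, cols = len(prob), len(prob[0])
    n = round(rate * rows * cols)  # number of k-space locations to keep
    # Flatten to (score, row, col) triples and sort by score, largest first.
    flat = [(prob[i][j], i, j) for i in range(rows) for j in range(cols)]
    flat.sort(key=lambda t: t[0], reverse=True)
    mask = [[0] * cols for _ in range(rows)]
    for _, i, j in flat[:n]:
        mask[i][j] = 1
    return mask

# Hypothetical learned scores on a 4 x 4 grid.
prob = [[0.90, 0.10, 0.40, 0.20],
        [0.30, 0.80, 0.05, 0.60],
        [0.70, 0.15, 0.50, 0.25],
        [0.35, 0.45, 0.55, 0.65]]

mask = binarize_top_n(prob, rate=0.25)  # keeps 4 of the 16 entries
```

With these scores, the four retained locations are those holding 0.90, 0.80, 0.70, and 0.65.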
2.2 ReconNet: the reconstruction network
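Recently, a wide range of network models have been proposed for MRI reconstruction [13]; in this work, we simply utilize a spatial-domain reconstruction network. Let $R_\phi$ denote the ReconNet with parameter set $\phi$, and let $F^{-1}$ be the inverse Fourier transform matrix. The reconstructed image $\hat{x}$ can then be obtained as:
$$\hat{x} = R_\phi\left(F^{-1} y\right) = R_\phi\left(F^{-1}(P \odot F x)\right), \tag{2}$$
where $y$ is the undersampled k-space data of the fully-sampled image $x$, $P$ is the sampling pattern produced by the SampNet with parameters $\theta$, and $\odot$ is the Hadamard product. Then the training procedure can be formulated as the following optimization problem:
$$\min_{\theta,\phi} \; \mathbb{E}_{x}\left[\mathcal{L}_{rec}(x, \hat{x})\right], \tag{3}$$
where $\mathcal{L}_{rec}$ is a reconstruction metric function that measures the similarity between the reconstructed image and the label, and $\mathbb{E}_{x}$ is the expectation over $x$. The ReconNet adopts the well-known U-Net [10] as its backbone, as shown in Fig. 1b, which has demonstrated competitive performance in artifact reduction for MRI [8, 14].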
2.3 SegNet: the segmentation network
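Recently, many networks have been proposed for automatic tissue segmentation [15]. Since U-Net-like architectures have demonstrated excellent performance for medical image segmentation, we also choose the network structure in Fig. 1b as our SegNet for simplicity. We can then formulate the joint learning that simultaneously optimizes sampling, reconstruction, and segmentation as follows:
$$\min_{\theta,\phi,\psi} \; \mathbb{E}_{x}\left[\mathcal{L}_{rec}(x, \hat{x}) + \lambda\,\mathcal{L}_{seg}\left(s, G_\psi(\hat{x})\right)\right], \tag{4}$$
where $\mathcal{L}_{seg}$ is the segmentation metric function that measures the accuracy of the segmentation result against the segmentation labels $s$, $G_\psi$ is the segmentation network with parameter set $\psi$, $\hat{x}$ is the reconstructed image, $\mathcal{L}_{rec}$ is the reconstruction metric function, and $\lambda$ weights the segmentation term.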
The SegNet plays two roles. First, it is treated as a clinical analysis instructor to train ReconNet, such that the reconstruction network can better adapt to tissue segmentation work. Second, it serves as a radiologist, which can provide SampNet with sufficient clinical knowledge.
2.4 SemuNet: the sequential multi-task joint learning network framework
By cascading the previously mentioned SampNet, ReconNet, and SegNet as the basic modules, we propose a deep joint learning framework for the whole MRI pipeline, which can: 1) learn an optimized sampling pattern guided simultaneously by both low- and high-level tasks, i.e., reconstruction and segmentation; 2) reconstruct high-quality MR images with the optimized sampling pattern for the downstream segmentation task; and 3) segment the target tissues more accurately based on the task-driven reconstruction.
Since these modules are cascaded and trained in an end-to-end manner, the features extracted from different tasks are mutually influenced in an interactive way and benefit from each other.

An overview of the proposed joint learning network framework is illustrated in Fig. 1a and Fig. 1c. The networks used in the training and testing stages differ. During the training stage, since we need to learn the sampling pattern for the specific dataset with fully-sampled k-space data as labels, the whole framework has three parts. During the testing stage, since we can directly use the optimized sampling pattern to acquire the undersampled k-space data, SampNet is discarded and the undersampled k-space data are fed into the ReconNet and SegNet in sequence. Finally, the estimated reconstruction and segmentation probability map are obtained.
2.4.1 Training Strategy.
At the beginning of the training stage, the whole network is initialized randomly. The cascaded modules are trained in an end-to-end manner, so that the weights of the three modules are updated simultaneously by backpropagation. This training strategy is adopted so that the learned sampling pattern acquires as much useful information as possible for the subsequent reconstruction and segmentation tasks. Moreover, the proposed SemuNet can be easily adapted to different clinical tasks, since the SegNet can be substituted with any other task network. Our approach not only eases the training effort while forcing ReconNet to fit clinical tasks and keeping SegNet accurate on undersampled MR images, but also enables SampNet to learn more clinically useful features from the k-space data.
2.4.2 Loss function.
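For MR image reconstruction, the $\ell_1$ norm is adopted as the loss function:
$$\mathcal{L}_{rec} = \left\| x - \hat{x} \right\|_1, \tag{5}$$
where $x$ is the fully-sampled image and $\hat{x}$ is the reconstructed image. Cross-entropy loss is utilized for the SegNet:
$$\mathcal{L}_{seg} = -\sum_{i=1}^{N} \sum_{c=1}^{C} l_{ic} \log p_{ic} \tag{6}$$
for $C$ brain tissue class labels and $N$ pixels of an image, where $l_{ic}$ is the pixel-level target label and $p_{ic}$ is the pixel-level Softmax segmentation probability for the $c$-th class of the $i$-th pixel. Then the hybrid loss function for the proposed joint learning network is formulated as:
$$\mathcal{L} = \mathcal{L}_{rec} + \lambda\,\mathcal{L}_{seg}, \tag{7}$$
where $\lambda$ is the weight balancing the two terms.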
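The hybrid loss can be illustrated with a small, framework-free sketch. It rests on several assumptions that the text does not confirm: the reconstruction term is an absolute-error ($\ell_1$) loss averaged over pixels, the cross-entropy is averaged over pixels, labels are given as class indices, and a single weight multiplies the segmentation term. All function names are hypothetical.

```python
import math

def l1_loss(x, x_hat):
    """Mean absolute error between a reference image and its reconstruction."""
    return sum(abs(a - b) for a, b in zip(x, x_hat)) / len(x)

def softmax(logits):
    """Numerically stable softmax over one pixel's class scores."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(labels, logits):
    """Pixel-averaged cross-entropy; labels are class indices, logits are per-pixel scores."""
    total = 0.0
    for label, pixel_logits in zip(labels, logits):
        probs = softmax(pixel_logits)
        total -= math.log(probs[label])
    return total / len(labels)

def hybrid_loss(x, x_hat, labels, logits, weight):
    """Reconstruction loss plus weighted segmentation loss."""
    return l1_loss(x, x_hat) + weight * cross_entropy(labels, logits)

# Two-pixel toy example with two classes and uninformative logits.
x, x_hat = [0.0, 1.0], [0.5, 1.0]
labels, logits = [0, 1], [[0.0, 0.0], [0.0, 0.0]]
loss = hybrid_loss(x, x_hat, labels, logits, weight=0.5)
```

In a joint training loop, a single scalar of this form would be backpropagated through the segmentation, reconstruction, and sampling modules together, which is what lets the segmentation error influence the sampling pattern.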
3 Experiments and Discussion
3.1 Experimental Details
3.1.1 Dataset and baselines.
The brain dataset from the Grand Challenge on MR Brain Image Segmentation workshop (MRB) [17] is used to evaluate the proposed method. The dataset was acquired on a 3.0T MRI scanner and consists of five patients. The data of each patient are provided in four MRI modalities: T1, T1-1mm, T1-IR and T2-FLAIR with size of . The brain tissues of each patient are manually labeled with seven types of tissue (T1): cortical gray matter, basal ganglia, white matter, cerebrospinal fluid in the extracerebral space, ventricles, cerebellum, and brainstem. In our experiment, four T1 datasets are used for training and the remaining one for testing.
3.1.2 Experiment setup.
All implementations are based on PyTorch. All models are trained on one Quadro RTX 8000 GPU with a batch size of 12. The hyperparameter configuration of both ReconNet and SegNet is given in Fig. 1b. Uniform random initialization is used for SampNet and Xavier initialization for ReconNet and SegNet. The whole SemuNet is trained for 600 epochs. After that, the ReconNet and SegNet are fine-tuned for an additional 500 epochs. ADAM [18] optimizer is adopted with an initial learning rate of . is empirically set to .
3.1.3 Baseline.
Two basic variants of our SemuNet framework are built: (1) Baseline = fixed pattern + ReconNet + SegNet; and (2) LOUPE-Seg = LOUPE + SegNet. LOUPE is a recently proposed sampling pattern learning model driven by reconstruction [8]. We first train LOUPE with high-quality MR images, and SegNet is then trained with the data generated by LOUPE. PSNR and SSIM are adopted as quantitative metrics.
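For reference, PSNR can be computed from the mean squared error between a reference image and its reconstruction. The sketch below is a minimal illustration, assuming images are flattened to lists of intensities normalized to a known dynamic range (`data_range`, set to 1.0 here); the exact normalization used in the experiments is not stated, and the function name is hypothetical. SSIM is omitted because it requires windowed local statistics.

```python
import math

def psnr(reference, estimate, data_range=1.0):
    """Peak signal-to-noise ratio in dB for two equally sized, flattened images."""
    mse = sum((r - e) ** 2 for r, e in zip(reference, estimate)) / len(reference)
    if mse == 0:
        return math.inf  # identical images
    return 10 * math.log10(data_range ** 2 / mse)

# Toy four-pixel example: two pixels are off by 0.1, so the MSE is 0.005.
reference = [0.0, 0.5, 1.0, 0.5]
estimate = [0.1, 0.5, 0.9, 0.5]
score = psnr(reference, estimate)  # about 23 dB
```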
3.2 Experiments Results
To validate the performance of the proposed SemuNet, we separately evaluate the results of reconstruction and segmentation.
For reconstruction, we compare the proposed SemuNet with the following methods: (1) Baseline (only its reconstruction result is used); (2) Liu et al. [2] with a fixed pattern (only its reconstruction result is used); (3) LOUPE; and (4) MD-Recon-Net [16] (a recently proposed dual-domain reconstruction network) with a fixed pattern. The learned trajectories for LOUPE and SemuNet, and the fixed patterns used in Baseline, Liu et al. [2], and MD-Recon-Net [16], are shown in Fig. 2a.
Table 1. Average PSNR (dB) and SSIM (%) of the reconstruction results at different sampling rates.

| Rate | Metric | Baseline (Radial) | Baseline (Random) | Liu et al. [2] (Radial) | Liu et al. [2] (Random) | MD-Recon-Net [16] (Radial) | MD-Recon-Net [16] (Random) | LOUPE [8] (Learned) | SemuNet (Learned) |
|---|---|---|---|---|---|---|---|---|---|
| 20% | PSNR | 36.30 | 32.69 | 36.17 | 32.81 | 39.44 | 39.34 | 38.82 | 39.24 |
| 20% | SSIM | 96.67 | 94.01 | 96.24 | 93.76 | 98.09 | 97.80 | 98.09 | 98.56 |
| 10% | PSNR | 31.30 | 30.26 | 31.14 | 30.31 | 32.62 | 33.95 | 33.63 | 34.30 |
| 10% | SSIM | 90.27 | 90.69 | 90.93 | 90.73 | 93.59 | 94.47 | 95.19 | 96.47 |
| 5% | PSNR | 27.45 | 28.95 | 27.26 | 28.87 | 26.83 | 30.12 | 30.96 | 31.20 |
| 5% | SSIM | 85.00 | 88.46 | 84.45 | 88.54 | 85.13 | 92.00 | 91.31 | 93.16 |
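Table 2. DSC (%) of the segmentation results at different sampling rates.

| Rate | Baseline (Random) | Baseline (Radial) | Liu et al. [2] (Random) | Liu et al. [2] (Radial) | LOUPE-Seg | SemuNet |
|---|---|---|---|---|---|---|
| 20% | 70.65 | 73.91 | 71.64 | 73.77 | 76.19 | 76.79 |
| 10% | 68.3 | 70.48 | 67.73 | 70.92 | 72.91 | 75.08 |
| 5% | 66.6 | 64.54 | 64.66 | 63.65 | 70.97 | 72.45 |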
In Fig. 3a, one typical slice reconstructed using different methods is chosen for visual comparison. The proposed SemuNet achieves the smallest reconstruction error and preserves more details than the other methods, as can be confirmed in the magnified regions. The average values of the quantitative metrics on the 48 test samples (from one patient) are listed in Table 1. Our method achieves the highest scores in most situations, which is strong evidence that integrating sampling learning and segmentation tasks can effectively improve the reconstruction performance.
As for segmentation, we compare our method with the following methods: (1) Baseline; (2) LOUPE-Seg; and (3) Liu et al. [2]. The results of one representative slice are shown in Fig. 3b, where each tissue is marked with a different color. The proposed SemuNet provides the visual result closest to the ground truth. The Dice Similarity Coefficient (DSC) [19] is adopted as the quantitative metric and the results are listed in Table 2. The quantitative results are consistent with the subjective evaluation, which confirms that introducing both sampling and reconstruction learning into the segmentation network can further increase the accuracy. It is worth noting that the Baseline and Liu et al. [2] obtain much lower accuracy than the other methods, as shown in both Fig. 3b and Table 2, which shows the merit of undersampled MR image reconstruction with sampling learning as a preprocessing step for the segmentation task. When sampling pattern learning is applied without considering the segmentation task, it also fails to achieve the highest accuracy, since the reconstruction does not fully exploit the latent features transferred from the segmentation task.
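The DSC used above measures the overlap between a predicted and a reference label map. The following sketch computes a per-class Dice score on flattened label maps and averages over classes; how the reported scores are aggregated across tissues and slices is not stated in the text, so the averaging here is an assumption, as are the function names and the toy labels.

```python
def dice(pred, target, label):
    """Dice similarity coefficient for one class on flattened label maps."""
    p = [v == label for v in pred]
    t = [v == label for v in target]
    intersection = sum(a and b for a, b in zip(p, t))
    denom = sum(p) + sum(t)
    # Convention (assumption): a class absent from both maps scores 1.0.
    return 1.0 if denom == 0 else 2.0 * intersection / denom

def mean_dice(pred, target, labels):
    """Unweighted mean of the per-class Dice scores."""
    return sum(dice(pred, target, c) for c in labels) / len(labels)

# Toy six-pixel example with three classes.
pred = [0, 1, 1, 2, 2, 2]
target = [0, 1, 2, 2, 2, 1]

per_class = [dice(pred, target, c) for c in (0, 1, 2)]  # 1.0, 0.5, 2/3
average = mean_dice(pred, target, (0, 1, 2))
```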
4 Conclusion
Sampling pattern learning is an important problem for MR imaging. With the recent developments of fast MRI in the industry, sampling pattern learning techniques that take both reconstruction and analysis tasks into account are of great significance. In this paper, a joint learning framework, SemuNet, is proposed to integrate sampling pattern learning, reconstruction, and segmentation into a unified network. The results demonstrate that the joint learning strategy allows all the tasks to benefit from each other. In future work, more datasets will be used for evaluation and different analysis tasks will be considered.
References
- [1] Bojarski, M., Del Testa, D., Dworakowski, D., Firner, B., Flepp, B., Goyal, P., Jackel, L.D., Monfort, M., Muller, U., Zhang, J., Zhang, X., Zhao, J., Zieba, K.: End to End Learning for Self-Driving Cars. arXiv:1604.07316 [cs]. (2016).
- [2] Liu, D., Wen, B., Jiao, J., Liu, X., Wang, Z., Huang, T.S.: Connecting Image Denoising and High-Level Vision Tasks via Deep Learning. IEEE Transactions on Image Processing. 29, 3695–3706 (2020).
- [3] Wu, D., Kim, K., Dong, B., Fakhri, G.E., Li, Q.: End-to-End Lung Nodule Detection in Computed Tomography. In: Shi, Y., Suk, H.-I., and Liu, M. (eds.) Machine Learning in Medical Imaging. pp. 37–45. Springer International Publishing, Cham (2018).
- [4] Lee, H., Huang, C., Yune, S., Tajmir, S.H., Kim, M., Do, S.: Machine Friendly Machine Learning: Interpretation of Computed Tomography Without Image Reconstruction. Scientific Reports. 9, 15540 (2019).
- [5] Sun, L., Fan, Z., Ding, X., Huang, Y., Paisley, J.: Joint CS-MRI Reconstruction and Segmentation with a Unified Deep Network. In: Chung, A.C.S., Gee, J.C., Yushkevich, P.A., and Bao, S. (eds.) Information Processing in Medical Imaging. pp. 492–504. Springer International Publishing, Cham (2019).
- [6] Fan, Z., Sun, L., Ding, X., Huang, Y., Cai, C., Paisley, J.: A Segmentation-aware Deep Fusion Network for Compressed Sensing MRI. Presented at the Proceedings of the European Conference on Computer Vision (ECCV) (2018).
- [7] Zijlstra, F., Viergever, M.A., Seevinck, P.R.: Evaluation of Variable Density and Data-Driven K-Space Undersampling for Compressed Sensing Magnetic Resonance Imaging. Investigative Radiology. 51, 410–419 (2016).
- [8] Bahadir, C.D., Dalca, A.V., Sabuncu, M.R.: Learning-based Optimization of the Under-sampling Pattern in MRI. arXiv:1901.01960 [cs, eess, stat]. (2019).
- [9] Jin, K.H., Unser, M., Yi, K.M.: Self-Supervised Deep Active Accelerated MRI. arXiv:1901.04547 [cs]. (2019).
- [10] Ronneberger, O., Fischer, P., Brox, T.: U-Net: Convolutional Networks for Biomedical Image Segmentation. In: Navab, N., Hornegger, J., Wells, W.M., and Frangi, A.F. (eds.) Medical Image Computing and Computer-Assisted Intervention – MICCAI 2015. pp. 234–241. Springer International Publishing, Cham (2015).
- [11] Jang, E., Gu, S., Poole, B.: Categorical Reparameterization with Gumbel-Softmax. arXiv:1611.01144 [cs, stat]. (2017).
- [12] Maddison, C.J., Mnih, A., Teh, Y.W.: The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables. arXiv:1611.00712 [cs, stat]. (2017).
- [13] Wang, G., Ye, J.C., Mueller, K., Fessler, J.A.: Image Reconstruction is a New Frontier of Machine Learning. IEEE Transactions on Medical Imaging. 37, 1289–1296 (2018).
- [14] Yang, G., Yu, S., Dong, H., Slabaugh, G., Dragotti, P.L., Ye, X., Liu, F., Arridge, S., Keegan, J., Guo, Y., Firmin, D.: DAGAN: Deep De-Aliasing Generative Adversarial Networks for Fast Compressed Sensing MRI Reconstruction. IEEE Transactions on Medical Imaging. 37, 1310–1321 (2018).
- [15] Hesamian, M.H., Jia, W., He, X., Kennedy, P.: Deep Learning Techniques for Medical Image Segmentation: Achievements and Challenges. J Digit Imaging. 32, 582–596 (2019).
- [16] Ran, M., Xia, W., Huang, Y., Lu, Z., Bao, P., Liu, Y., Sun, H., Zhou, J., Zhang, Y.: MD-Recon-Net: A Parallel Dual-Domain Convolutional Neural Network for Compressed Sensing MRI. IEEE Transactions on Radiation and Plasma Medical Sciences. 5, 120–135 (2021).
- [17] Mendrik, A.M., Vincken, K.L., Kuijf, H.J., Breeuwer, M., Bouvy, W.H., de Bresser, J., Alansary, A., de Bruijne, M., Carass, A., El-Baz, A., Jog, A., Katyal, R., Khan, A.R., van der Lijn, F., Mahmood, Q., Mukherjee, R., van Opbroek, A., Paneri, S., Pereira, S., Persson, M., Rajchl, M., Sarikaya, D., Smedby, Ö., Silva, C.A., Vrooman, H.A., Vyas, S., Wang, C., Zhao, L., Biessels, G.J., Viergever, M.A.: MRBrainS Challenge: Online Evaluation Framework for Brain Image Segmentation in 3T MRI Scans. Comput Intell Neurosci. 2015, (2015).
- [18] Kingma, D.P., Ba, J.: Adam: A Method for Stochastic Optimization. arXiv:1412.6980 [cs]. (2017).
- [19] Crum, W.R., Camara, O., Hill, D.L.G.: Generalized Overlap Measures for Evaluation and Validation in Medical Image Analysis. IEEE Transactions on Medical Imaging. 25, 1451–1461 (2006).
Appendix 0.A Supplementary Material
