Recurrence With Correlation Network for Medical Image Registration
Abstract
We present the Recurrence with Correlation Network (RWCNet), a medical image registration network with multi-scale features and a cost volume layer. We demonstrate that these architectural features improve medical image registration accuracy on two image registration datasets prepared for the MICCAI 2022 Learn2Reg Workshop Challenge. On the large-displacement National Lung Screening Trial (NLST) dataset, RWCNet achieves a target registration error (TRE) of 2.11mm between corresponding keypoints without fine-tuning. On the OASIS brain MRI dataset, RWCNet achieves an average Dice overlap of 81.7% for 35 different anatomical labels. It outperforms another multi-scale network, the Laplacian Pyramid Image Registration Network (LapIRN), on both datasets. Ablation experiments are performed to highlight the contribution of the various architectural features. While multi-scale features improved validation accuracy on both datasets, the cost volume layer and the number of recurrent steps only improved performance on the NLST dataset. This result suggests that the cost volume layer and iterative refinement with a recurrent neural network (RNN) aid optimization and generalization in large-displacement medical image registration. The code for RWCNet is available at https://github.com/vigsivan/optimization-based-registration.
Keywords: Medical Image Registration, Deep Learning, Optical Flow
1 Introduction
Medical image registration describes the task of spatially aligning one medical image to another. It has a myriad of uses, including image-guided pre-operative planning [1] and atlas creation for population-level studies [2]. Traditionally, image registration has been solved with iterative optimization methods, but there is growing interest in applying deep learning to the task. One advantage of deep learning over conventional optimization is speed: once trained, a network predicts a transform in a single forward pass, which is often orders of magnitude faster than iterative optimization. Given deep learning's success in computer vision, it may also yield more accurate and more robust spatial transforms than conventional optimization [3]. Indeed, recent work on deep learning image registration (DLIR) has shown that it can yield faster and more accurate transforms than conventional methods on some datasets [3, 4, 5].
In terms of inputs and outputs, medical image registration is closely related to optical flow, where the objective is to compute a flow field describing the motion of objects between two images of, typically, the same scene. Optical flow is an important and well-studied computer vision problem with applications including action recognition and pose estimation. Common architectural features of optical flow networks include multi-scale features [6] and cost volume layers [7, 6, 8].
Multi-scale features and cost volumes have each been explored separately for 3D medical image registration in prior work [5, 9], with improvements in registration performance on standard datasets. This paper introduces a novel, optical flow-inspired network architecture for DLIR, the Recurrence with Correlation Network (RWCNet), which combines multi-scale iterative refinement with a cost volume layer. To the authors' knowledge, no other work combines these architectural features for DLIR. The contributions of this paper are as follows:
- We present a novel network architecture for DLIR, the Recurrence with Correlation Network (RWCNet), which outperforms the multi-scale LapIRN network on two standard registration datasets prepared for the MICCAI 2022 Learn2Reg Workshop.
- We perform ablations of the architectural components to measure the contribution of each feature to registration accuracy. The results indicate that these contributions vary by dataset.
2 Related Work
2.0.1 Voxelmorph
Voxelmorph is one of the earliest and best known methods for DLIR [3]. It trains a 3D UNet [10] to predict a dense displacement field with sub-voxel precision, $u$, that jointly optimizes the image fidelity of the transformed image, denoted $\mathcal{L}_{sim}$, and a regularization loss promoting smooth spatial transformations, denoted $\mathcal{L}_{reg}$. The loss function, $\mathcal{L}$, describing the goodness of fit for registering a moving image $M$ to a fixed image $F$ is expressed as:

$$\mathcal{L}(F, M, u) = \mathcal{L}_{sim}\big(F,\ M \circ (\mathrm{Id} + u)\big) + \lambda\, \mathcal{L}_{reg}(u) \qquad (1)$$

where $\mathrm{Id}$ is the identity transform, $\lambda$ weights the regularization term, and $\circ$ resamples the moving image onto a new grid parametrized by the displacement field, $u$. This resampling operation is carried out differentiably using the spatial transformer network (STN) [11]. The loss function described by (1) can be augmented with terms measuring the correspondence of auxiliary data such as segmentations and keypoints. Voxelmorph was found to be competitive with optimization-based registration methods such as Demons [12] and Symmetric Normalization [13].
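As a minimal NumPy sketch of loss (1), assuming MSE as the similarity term and the mean squared finite-difference gradient as the regularizer, the following computes the loss for a single 3D volume. The helper names (`warp_nearest`, `grad_regularizer`, `registration_loss`) and the default weight `lam=0.01` are hypothetical, and nearest-neighbour lookup stands in for the trilinear interpolation an STN would use, so this version is illustrative only and not differentiable.

```python
import numpy as np

# Hypothetical helpers illustrating loss (1); not Voxelmorph's actual code.

def warp_nearest(moving, u):
    """Resample `moving` (D, H, W) at x + u(x); u has shape (3, D, H, W) in voxels.

    Assumption: nearest-neighbour lookup with border clamping (an STN uses
    trilinear interpolation instead).
    """
    grid = np.indices(moving.shape, dtype=float)  # identity transform Id
    coords = np.rint(grid + u).astype(int)        # sampling locations Id + u
    for axis, size in enumerate(moving.shape):
        coords[axis] = np.clip(coords[axis], 0, size - 1)
    return moving[coords[0], coords[1], coords[2]]

def grad_regularizer(u):
    """Mean squared finite-difference gradient of u, averaged over the 3 spatial axes."""
    return np.mean([np.mean(np.diff(u, axis=a) ** 2) for a in (1, 2, 3)])

def registration_loss(fixed, moving, u, lam=0.01):
    """MSE between fixed and warped moving image plus lam times the smoothness term."""
    warped = warp_nearest(moving, u)
    return np.mean((fixed - warped) ** 2) + lam * grad_regularizer(u)
```

A spatially constant displacement has zero gradient, so only non-uniform fields are penalized by the regularizer.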
2.0.2 Laplacian Pyramid Image Registration Network
The Laplacian Pyramid Image Registration Network (LapIRN) [5] learns flow fields at multiple resolutions. At each resolution, a CNN with residual skip connections learns a displacement field by jointly optimizing the image fidelity and smoothness terms of (1). LapIRN is trained in a coarse-to-fine manner: the CNNs at low resolutions are trained before the networks responsible for higher resolutions. The flow fields learned at low resolutions are upsampled and used to warp the moving image at higher resolutions.
Multi-scale refinement operates on the principle that an optimal displacement field at low resolution is also a good initial estimate at high resolution [5]. Moreover, the optimization problem is simpler at coarse resolution: the required displacements are smaller in magnitude (measured in voxels) and there are fewer fine-grained features to match, so refining the flow field from coarse to fine resolution can simplify the overall optimization. LapIRN showed improvements in the large-displacement setting compared to Voxelmorph and conventional approaches.
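The coarse-to-fine hand-off can be sketched as follows. Assuming displacements are stored in voxel units with shape (3, D, H, W), upsampling a field to a grid twice as fine must also double its values, since one coarse voxel spans two fine voxels. The function name is hypothetical, and nearest-neighbour repetition is used for brevity where LapIRN and RWCNet may use smoother (e.g. trilinear) upsampling.

```python
import numpy as np

def upsample_flow(u, factor=2):
    """Hypothetical helper: upsample a displacement field u of shape (3, D, H, W).

    Assumes u is expressed in voxel units of the coarse grid.
    """
    up = u
    for axis in (1, 2, 3):
        up = np.repeat(up, factor, axis=axis)  # nearest-neighbour upsampling
    # One coarse voxel spans `factor` fine voxels, so displacements scale too.
    return up * factor

coarse = np.zeros((3, 2, 2, 2))
coarse[0] = 1.5                    # 1.5 coarse voxels along the first axis
fine = upsample_flow(coarse)
print(fine.shape, fine[0].max())   # (3, 4, 4, 4) 3.0
```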
Like LapIRN, the method presented in this paper learns flow fields from coarse to fine resolutions. However, RWCNet uses a different network at each resolution: a recurrent CNN with feature encoders and a cost volume layer.
2.0.3 RAFT
The architecture of RWCNet is directly inspired by RAFT, which achieved state-of-the-art results on optical flow benchmarks at the time of its publication [8]. The RAFT network architecture has three key characteristics: 1) CNN feature encoders, 2) a cost volume computed between all pairs of pixels and pooled at multiple scales, and 3) a recurrent CNN that iteratively refines the flow field, performing a 'lookup' that samples the global cost volume at each step. RWCNet differs from RAFT in several ways. Computing the cost between all pairs of voxels becomes prohibitively expensive for 3D volumes, and computing the cost volume at multiple resolutions simultaneously is also computationally expensive. Instead, RWCNet consists of three sub-networks that learn displacement fields at three different resolutions, similar to LapIRN for DLIR or PWC-Net [6] for optical flow.
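A back-of-the-envelope calculation shows why an all-pairs cost volume is impractical in 3D. The sketch below assumes float32 entries, a volume of 160 × 192 × 224 voxels (roughly the size of the Learn2Reg OASIS images), and an illustrative window radius of 3; none of these values are taken from the RWCNet configuration, and `cost_volume_bytes` is a hypothetical helper.

```python
from math import prod

def cost_volume_bytes(shape, radius=None, bytes_per_entry=4):
    """Hypothetical helper: memory needed for a cost volume over a grid of `shape`."""
    n_voxels = prod(shape)
    if radius is None:
        # All-pairs volume (as in RAFT): one entry per pair of voxels.
        entries = n_voxels * n_voxels
    else:
        # Local volume: one entry per voxel per displacement in the window.
        entries = n_voxels * (2 * radius + 1) ** len(shape)
    return entries * bytes_per_entry

shape = (160, 192, 224)  # assumed full-resolution volume size
print(f"all pairs: {cost_volume_bytes(shape) / 1e12:.0f} TB")           # all pairs: 189 TB
print(f"radius 3: {cost_volume_bytes(shape, radius=3) / 1e9:.1f} GB")   # radius 3: 9.4 GB
```

The all-pairs cost grows quadratically with the number of voxels, whereas a local window grows only linearly, which is why restricting the correlation range and working on patches keeps memory manageable.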
3 Methods
3.0.1 Sub-network Architecture.
The architecture of the recurrent CNN sub-network is shown in Figure 1. Given a fixed image $F$ and a moving image $M$, the network first extracts fixed and moving features ($f_F$ and $f_M$) by passing both images through a feature extractor network. A voxel-wise correlation between the fixed and moving features, $C^{0}$, is then computed. Because correlating every pair of voxels is prohibitively expensive in 3D, the correlation is restricted so that, for each voxel, only voxels within a range $r$ of it in the other feature map are considered. Additionally, the moving image is passed through a context network whose output is used as the initial hidden state of the RNN, $h^{0}$. Finally, the displacement field is initialized to zero, $u^{0} = 0$.
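A restricted correlation of this kind can be sketched in NumPy as below. It assumes feature maps of shape (C, D, H, W), zero padding outside the volume, and an unnormalized dot product over channels; `local_correlation` is a hypothetical name, and the exact normalization and memory layout used by RWCNet may differ. The output has one channel per displacement in the (2r + 1)^3 window.

```python
import itertools
import numpy as np

def local_correlation(f_fixed, f_moving, r):
    """Hypothetical sketch: correlate each fixed-feature voxel with moving-feature
    voxels offset by up to r along each axis.

    f_fixed, f_moving: arrays of shape (C, D, H, W).
    Returns an array of shape ((2r+1)**3, D, H, W); channel k holds the dot
    product over feature channels for the k-th offset (dz, dy, dx).
    """
    _, d, h, w = f_fixed.shape
    # Zero-pad the spatial axes so out-of-volume neighbours contribute 0.
    padded = np.pad(f_moving, ((0, 0), (r, r), (r, r), (r, r)))
    offsets = list(itertools.product(range(-r, r + 1), repeat=3))
    cost = np.empty((len(offsets), d, h, w), dtype=np.result_type(f_fixed, f_moving))
    for k, (dz, dy, dx) in enumerate(offsets):
        # shifted[:, z, y, x] == f_moving[:, z + dz, y + dy, x + dx] (or 0 outside)
        shifted = padded[:, r + dz : r + dz + d, r + dy : r + dy + h, r + dx : r + dx + w]
        cost[k] = np.sum(f_fixed * shifted, axis=0)
    return cost
```

With r = 3 the cost volume has 343 channels regardless of the number of feature channels C, so its size is set by the search range rather than the feature width.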
The hidden state, flow, cost volume and moving image are fed into an update block, a modified gated recurrent unit (GRU) [14]. The GRU closely follows the implementation used by RAFT [8]; it outputs a new hidden state, $h^{t+1}$, and a displacement update, $\Delta u^{t+1}$, which is added to the aggregate displacement, i.e., $u^{t+1} = u^{t} + \Delta u^{t+1}$. The updated displacement field is used to warp the moving features (generating $\tilde{f}_M^{t+1}$), from which a new cost volume, $C^{t+1}$, is computed for the next RNN time step. This process of updating the displacement field with the GRU cell and recomputing the cost volume is repeated for $T$ RNN time steps.
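The update loop can be summarized by the following sketch, in which the correlation, warping and GRU update block are passed in as callables. The function and argument names are hypothetical; the sketch only illustrates how the displacement is accumulated and the cost volume recomputed at each of the `steps` iterations, not RWCNet's actual layers.

```python
import numpy as np

def recurrent_refinement(f_fixed, f_moving, moving, h0, correlate, warp, update_block, steps):
    """Hypothetical sketch of the iterative update loop.

    correlate(f_fixed, warped_moving_features) -> cost volume
    warp(features, u)                          -> features resampled with displacement u
    update_block(h, u, cost, moving)           -> (new hidden state, displacement update)
    """
    u = np.zeros((3,) + f_fixed.shape[1:])  # u^0: zero displacement, shape (3, D, H, W)
    h = h0                                  # h^0: output of the context network
    warped = f_moving                       # warping with u^0 = 0 is the identity
    for _ in range(steps):
        cost = correlate(f_fixed, warped)            # cost volume C^t
        h, delta = update_block(h, u, cost, moving)
        u = u + delta                                # u^{t+1} = u^t + delta
        warped = warp(f_moving, u)                   # warp the original moving features
    return u, h

# Toy usage with stand-in callables: each step adds 0.5 voxels of displacement.
feats = np.zeros((2, 4, 4, 4))
u, h = recurrent_refinement(
    feats, feats, np.zeros((4, 4, 4)), h0=0,
    correlate=lambda a, b: np.sum(a * b, axis=0, keepdims=True),
    warp=lambda f, disp: f,
    update_block=lambda h, u, cost, m: (h + 1, np.full_like(u, 0.5)),
    steps=4,
)
print(u[0, 0, 0, 0], h)  # 2.0 4
```

Warping the original moving features with the aggregate field, rather than re-warping already warped features, avoids accumulating interpolation error across time steps.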

3.0.2 Coarse to Fine Registration.
We adopt a coarse-to-fine approach to image registration. For each resolution $l$, a new RNN is trained using inputs from the previous resolution, and the weights from previous resolutions are frozen at finer resolutions. At fine resolutions, computing the cost volume for the whole volume becomes prohibitively expensive; our approach therefore divides the input images into uniform, non-overlapping windows or patches. The size of the patches at each resolution is parameterized by the 'patch factor', $p_l$: the patch size at resolution $l$ is $S_l / p_l$, where $S_l$ is the size of the full image at resolution $l$. At higher resolutions, the initial moving image (or patch) is warped using the flow fields computed at lower resolutions. Furthermore, the final hidden state is cached at lower resolutions and added to the initial hidden state at higher resolutions, increasing the non-linearity of the network and providing additional context.
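Splitting a volume into uniform, non-overlapping patches and reassembling it can be sketched as follows. This assumes that each spatial dimension is divisible by the patch factor `p` and that patches have size S/p along every axis, giving p^3 patches per volume; both helper names are hypothetical.

```python
import numpy as np

def split_into_patches(volume, p):
    """Hypothetical helper: split a (D, H, W) volume into p**3 non-overlapping
    patches of size (D/p, H/p, W/p). Assumes each dimension is divisible by p."""
    d, h, w = volume.shape
    if d % p or h % p or w % p:
        raise ValueError("each dimension must be divisible by the patch factor")
    pd, ph, pw = d // p, h // p, w // p
    # Axes become (block_z, z, block_y, y, block_x, x); then move block axes first.
    blocks = volume.reshape(p, pd, p, ph, p, pw).transpose(0, 2, 4, 1, 3, 5)
    return blocks.reshape(p ** 3, pd, ph, pw)

def merge_patches(patches, p):
    """Inverse of split_into_patches."""
    _, pd, ph, pw = patches.shape
    blocks = patches.reshape(p, p, p, pd, ph, pw).transpose(0, 3, 1, 4, 2, 5)
    return blocks.reshape(p * pd, p * ph, p * pw)

vol = np.arange(4 * 6 * 8).reshape(4, 6, 8)
patches = split_into_patches(vol, 2)
print(patches.shape)  # (8, 2, 3, 4)
```

Because the patches do not overlap, the same indexing can be used to place per-patch displacement fields back into a full-volume field.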

3.0.3 Ablation Experiments.
Ablation experiments are performed to assess the contribution of the individual architectural features. To test the impact of multi-scale refinement, the network is trained to learn registration fields at a single resolution. To study the impact of correlation, the cost volume computation is removed; instead, the input to the GRU is simply a concatenation of the input features, which is consistent with Voxelmorph and LapIRN. To study the impact of iterative refinement, the number of RNN time steps is reduced to 2.
4 Experiments
4.1 Datasets
Experiments are performed with the OASIS [15] and NLST [16] datasets prepared for the MICCAI 2022 Learn2Reg workshop challenge [17]. The OASIS dataset consists of 414 T1-weighted MRI scans of individuals aged 18-96, including both nondemented individuals and individuals with Alzheimer's disease. The scans are skull-stripped, resampled onto an isotropic grid, and cropped to a uniform size. Segmentations of 35 anatomical labels are provided for important brain regions. The dataset is split into 395 images for training and 19 for validation. Inter-subject registration in this context could be used for constructing a sub-population brain atlas or for analysing intensity changes in consistent brain regions that are linked to disease progression.
NLST is a lung CT dataset consisting of pairs of inhale/exhale scans; keypoints and masks are provided by the Learn2Reg challenge for semi-supervised training. We use the subset of image pairs released by the Learn2Reg challenge for training and validation (100 out of 150), with a 90:10 training/validation split. Since respiration is accompanied by a large change in lung volume, the displacement fields required to register NLST are large relative to OASIS.
4.2 Training Parameters
The sub-networks at each scale were trained separately, from coarse to fine. At higher resolutions, corresponding patches of the fixed and moving images are passed into the network to decrease GPU memory requirements. Table 1 shows the number of steps and the sizes of the inputs at each resolution. To reduce overfitting, dropout with probability 0.5 is used in the feature networks. The network takes about 30 hours to train on the NLST and OASIS datasets on an NVIDIA A100 with 32GB of RAM.
For OASIS, the similarity component of the loss function was the sum of the mean squared error (MSE) between the warped moving image and fixed image intensities and the Dice loss between the warped moving segmentation and the fixed segmentation. The regularization loss was the mean gradient of the flow field, as used in Voxelmorph. For NLST, the data were range normalized to [0, 1], with -4000 as the minimum value and 16000 as the maximum value. The loss function was a weighted sum of the MSE and the target registration error (TRE), which measures the discrepancy in mm between corresponding keypoints in the fixed and moving images. The mean gradient of the flow field was used as the regularization loss. For both datasets, an Adam optimizer with learning rate of was used.
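The quantities behind these losses can be sketched in NumPy. The snippet assumes keypoints and displacements in voxel units with a per-axis voxel spacing in mm, displacements sampled at the fixed keypoints (a backward-warping convention), and clipping of intensities outside the stated [-4000, 16000] window; all function names are hypothetical. It computes evaluation metrics rather than the differentiable (soft) Dice loss used in training.

```python
import numpy as np

# Hypothetical helpers; conventions (voxel units, clipping) are assumptions.

def range_normalize(ct, lo=-4000.0, hi=16000.0):
    """Linearly map intensities from [lo, hi] to [0, 1], clipping values outside the window."""
    return np.clip((np.asarray(ct, dtype=float) - lo) / (hi - lo), 0.0, 1.0)

def target_registration_error(fixed_kp, moving_kp, disp_at_fixed_kp, spacing_mm):
    """Mean Euclidean distance (mm) between displaced fixed keypoints and moving keypoints.

    All keypoint and displacement arrays have shape (N, 3) in voxel units.
    """
    residual = (fixed_kp + disp_at_fixed_kp - moving_kp) * np.asarray(spacing_mm, dtype=float)
    return float(np.linalg.norm(residual, axis=1).mean())

def mean_dice(seg_a, seg_b, labels):
    """Average Dice overlap between two label maps over the given labels."""
    scores = []
    for lab in labels:
        a, b = seg_a == lab, seg_b == lab
        denom = a.sum() + b.sum()
        if denom > 0:  # skip labels absent from both volumes
            scores.append(2.0 * np.logical_and(a, b).sum() / denom)
    return float(np.mean(scores))
```

For OASIS, `labels` would hold the 35 anatomical label values; for NLST, `spacing_mm` would be the CT voxel spacing.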
LapIRN, another multi-resolution model, was trained on OASIS using the training parameters from [5]. We use the non-diffeomorphic variant, which does not ensure topological consistency of the spatial transformation but achieves greater quantitative accuracy. For NLST, we augment training with a supervised discrepancy loss between the fixed and moving keypoints after the displacement is applied to the moving keypoints. We use an MSE loss instead of the normalized cross-correlation (NCC) loss prescribed by the original paper, to be consistent with the loss used for training RWCNet.
Table 1: Training parameters at each resolution.

Resolution | RNN Steps | Patches Per Image | Patch Factor | Training Steps |
---|---|---|---|---|
0.25 | 12 | 1 |  | 30000 |
0.5 | 12 | 8 |  | 45000 |
1 | 4 | 8 |  | 60000 |
5 Results
Figure 3 shows qualitative results for the NLST and OASIS datasets. Table 2 compares RWCNet with LapIRN on the NLST and OASIS datasets. RWCNet outperforms LapIRN on both datasets. However, on OASIS the difference in Dice is only 0.7 percentage points, which may be within the variation caused by random weight initialization in network training. On NLST, the difference is much more pronounced: the average TRE of RWCNet is 3.4 mm lower. These results suggest that the architectural features of RWCNet substantially aid generalization in the more challenging large-displacement setting.
Figure 3: Qualitative results. (a) OASIS sample results. (b) NLST sample results.
Table 2: Comparison of RWCNet with LapIRN (lower TRE and higher Dice are better).

Experiment | NLST TRE (mm) | OASIS Dice (%) |
---|---|---|
Zero Displacement | 9.73 | 52.4 |
LapIRN [5] | 5.51 | 80.0 |
RWCNet | 2.11 | 80.7 |
Table 3 shows the results of the architectural ablation tests, which provide insight into the role that architecture plays in registration accuracy on different datasets. Unsurprisingly, multi-resolution registration plays a crucial role in the accuracy of RWCNet: registering only at 4× downsampling on the NLST dataset yields a keypoint discrepancy of 5.52 mm, whereas registering at multiple resolutions yields a discrepancy of 2.11 mm. OASIS Dice likewise drops from 80.7% to 74.0%.
Table 3: Architectural ablations of RWCNet.

Experiment | NLST TRE (mm) | OASIS Dice (%) |
---|---|---|
RWCNet | 2.11 | 80.7 |
RWCNet with single resolution (4×) | 5.52 | 74.0 |
RWCNet without correlation | 4.10 | 80.0 |
RWCNet with 2-timestep GRU | 5.17 | 80.1 |
The impact of correlation and of the number of RNN time steps differs markedly between the two datasets. On OASIS, replacing correlation with stacking of the input feature tensors does not drastically affect registration performance: the Dice score drops by 0.7 percentage points. When only 2 RNN time steps are used, the drop is even smaller; the Dice score drops by only 0.6 percentage points. In contrast, on NLST, removing correlation or decreasing the number of RNN time steps drastically decreases performance: the TRE increases from 2.11 mm to 4.10 mm when correlation is not computed, and to 5.17 mm when only 2 time steps are used.
The varying impact of the ablations between the two datasets may be explained by the nature of the datasets. OASIS has relatively small displacements compared to NLST and may therefore benefit less from cost volumes and the RNN structure. This observation could be helpful for designing methods that generalize to large-displacement datasets.
6 Conclusion
This work developed RWCNet for medical image registration and showed good performance on two open datasets. RWCNet includes architectural features common in optical flow: a multi-scale approach, explicit cost volume computation, and iterative refinement with an RNN. The cost volume and iterative refinement were most useful on the NLST dataset, which has large displacements, and had little impact on the OASIS results. Future work should investigate how the contribution of these architectural features changes on other datasets, to confirm their importance for large-deformation problems. This can inform the development of dataset-dependent, self-configuring registration methods.
References
- [1] P. Risholm, A. J. Golby, and W. M. Wells, “Multi-Modal Image Registration for Pre-Operative planning and Image Guided Neurosurgical Procedures,” Neurosurgery clinics of North America, vol. 22, pp. 197–206, Apr. 2011.
- [2] A. Toga and P. Thompson, “The role of image registration in brain mapping,” Image and vision computing, vol. 19, pp. 3–24, Jan. 2001.
- [3] G. Balakrishnan, A. Zhao, M. R. Sabuncu, J. Guttag, and A. V. Dalca, “VoxelMorph: A Learning Framework for Deformable Medical Image Registration,” IEEE Transactions on Medical Imaging, vol. 38, pp. 1788–1800, Aug. 2019. arXiv:1809.05231.
- [4] M. P. Heinrich and L. Hansen, “Voxelmorph++ Going beyond the cranial vault with keypoint supervision and multi-channel instance optimisation,” Feb. 2022. arXiv:2203.00046 [cs].
- [5] T. C. W. Mok and A. C. S. Chung, “Large Deformation Diffeomorphic Image Registration with Laplacian Pyramid Networks,” June 2020. arXiv:2006.16148 [cs, eess].
- [6] D. Sun, X. Yang, M.-Y. Liu, and J. Kautz, “PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume,” June 2018. arXiv:1709.02371 [cs].
- [7] E. Ilg, N. Mayer, T. Saikia, M. Keuper, A. Dosovitskiy, and T. Brox, “FlowNet 2.0: Evolution of Optical Flow Estimation with Deep Networks,” Dec. 2016. arXiv:1612.01925 [cs].
- [8] Z. Teed and J. Deng, “RAFT: Recurrent All-Pairs Field Transforms for Optical Flow,” Aug. 2020. arXiv:2003.12039 [cs].
- [9] M. P. Heinrich, “Closing the Gap between Deep and Conventional Image Registration using Probabilistic Dense Displacement Networks,” July 2019. arXiv:1907.10931 [cs].
- [10] Ö. Çiçek, A. Abdulkadir, S. S. Lienkamp, T. Brox, and O. Ronneberger, “3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation,” June 2016. arXiv:1606.06650 [cs].
- [11] M. Jaderberg, K. Simonyan, A. Zisserman, and K. Kavukcuoglu, “Spatial Transformer Networks,” Feb. 2016. arXiv:1506.02025 [cs].
- [12] J.-P. Thirion, “Non-rigid matching using demons,” in Proceedings CVPR IEEE Computer Society Conference on Computer Vision and Pattern Recognition, pp. 245–251, June 1996. ISSN: 1063-6919.
- [13] B. B. Avants, C. L. Epstein, M. Grossman, and J. C. Gee, “Symmetric diffeomorphic image registration with cross-correlation: evaluating automated labeling of elderly and neurodegenerative brain,” Medical Image Analysis, vol. 12, pp. 26–41, Feb. 2008.
- [14] K. Cho, B. van Merrienboer, D. Bahdanau, and Y. Bengio, “On the Properties of Neural Machine Translation: Encoder-Decoder Approaches,” Oct. 2014. arXiv:1409.1259 [cs, stat].
- [15] D. S. Marcus, T. H. Wang, J. Parker, J. G. Csernansky, J. C. Morris, and R. L. Buckner, “Open Access Series of Imaging Studies (OASIS): cross-sectional MRI data in young, middle aged, nondemented, and demented older adults,” Journal of Cognitive Neuroscience, vol. 19, pp. 1498–1507, Sept. 2007.
- [16] National Lung Screening Trial Research Team, “The National Lung Screening Trial: Overview and Study Design,” Radiology, vol. 258, pp. 243–253, Jan. 2011.
- [17] “Learn2Reg - Grand Challenge.”