Flemme: A Flexible and Modular Learning Platform for Medical Images
Abstract
With the rapid development of computer vision and the emergence of powerful network backbones and architectures, the application of deep learning in medical imaging has become increasingly significant. Unlike natural images, medical images lack huge volumes of data but feature more modalities, making it difficult to train a general model that performs satisfactorily across various datasets. In practice, practitioners often have to manually create and test models that combine independent backbones and architectures, which is a laborious and time-consuming process. We propose Flemme, a FLExible and Modular learning platform for MEdical images. Our platform separates encoders from the model architectures so that different models can be constructed via various combinations of supported encoders and architectures. We construct encoders using building blocks based on convolution, transformer, and state-space model (SSM) to process both 2D and 3D image patches. A base architecture is implemented following an encoder-decoder style, with several derived architectures for image segmentation, reconstruction, and generation tasks. In addition, we propose a general hierarchical architecture incorporating a pyramid loss to optimize and fuse vertical features. Experiments demonstrate that this simple design leads to an average improvement of 5.60% in Dice score and 7.81% in mean intersection over union (mIoU) for segmentation models, as well as an enhancement of 5.57% in peak signal-to-noise ratio (PSNR) and 8.22% in structural similarity (SSIM) for reconstruction models. We further utilize Flemme as an analytical tool to assess the effectiveness and efficiency of various encoders across different tasks. Code is available at https://github.com/wlsdzyzl/flemme.
Index Terms:
deep learning platform, medical images, convolution, transformer, state-space model
I Introduction
Since AlexNet [1] competed in the ImageNet Large Scale Visual Recognition Challenge [2], convolutional neural networks (CNNs) have become dominant in computer vision. In particular, the most preferred choice in medical imaging is U-Net [3], which uses skip connections to enhance feature fusion across different stages. However, the limitation of convolution lies in its focus on local feature extraction, lacking the ability to explore long-range dependencies. In recent years, following the immense success of transformers in natural language processing (NLP) [4, 5], treating images as sequences has garnered significant attention. Dosovitskiy et al.[6] propose the vision transformer (ViT) to encode sequences of image patches. While ViT outperforms CNN-based models [7] in image recognition, the quadratic computational complexity of the multi-head self-attention (MSA) mechanism makes it difficult to scale to larger images. Swin Transformer [8] introduces a shifted window attention mechanism to reduce computational complexity and significantly improves the applicability of transformers for vision tasks. On the other hand, state-space models (SSMs) have also made significant advances in sequence modeling. Gu et al.[9] propose a general sequence model backbone named Mamba with linear time complexity, allowing for consideration of a much longer range of dependencies compared to transformers. Zhu et al.[10] and Liu et al.[11] process image patches with vision Mamba blocks and introduce bidirectional and cross-scan strategies, respectively, to effectively integrate visual information from different directions.
Although new methods continue to emerge, several aspects remain worthy of exploration. Firstly, model performance depends highly on the actual training techniques and deployment. For instance, Liu et al.[12] demonstrate that CNNs can be "modernized" to achieve performance comparable to vision transformers. In real-world applications, the improvement in model performance needs to be balanced against the computational costs of training and inference. Secondly, beyond classification and segmentation, powerful model architectures have been proposed for image reconstruction and generation, such as the variational auto-encoder (VAE) [13] and diffusion models [14, 15]. Applying advanced backbones to these architectures for medical images is promising and worth anticipating. Thirdly, unlike natural images, medical images often lack large-scale, high-quality annotations but encompass a wide range of modalities, making it hard to train a general model that performs well across various medical image datasets. In practice, researchers and engineers often need to manually build and test models with multiple backbones and architectures for specific tasks. However, combining independent methods can be labor-intensive and hard to analyze.
Given the aforementioned challenges, the lack of a universal platform that supports the fast construction of diverse models significantly raises the barrier to adopting the latest technologies and prolongs research. We propose Flemme, a FLExible and Modular learning platform for MEdical images. Our platform decouples the encoder from the model architecture, enabling the fast construction of models by combining different encoders and architectures. We employ convolution, vision transformer, and SSM as backbones to create encoders and decoders for both 2D and 3D images. Our base model architecture follows an encoder-decoder structure, with further derivations for segmentation, reconstruction, and generation tasks. In addition to fusing horizontal features of the same scale among different stages like U-Net [3], we propose a general hierarchical framework for vertical feature fusion across different scales. To summarize, Flemme is distinguished by the following features:
• We adopt a modular design with state-of-the-art encoders and architectures for fast model construction.
• We implement a novel backbone-agnostic hierarchical architecture combining a pyramid loss for generic vertical feature fusion and optimization.
• We support flexible context encoding and allow extensions of new modules for more data types and tasks.
We demonstrate the applications of the proposed platform on various medical image datasets. For all supported models, we can set hyper-parameters such as network depth, normalization, and training strategies in the same manner. Benefiting from this advantage, we conduct extensive experiments with various encoders on different tasks to fairly compare and analyze the performance, as well as the time and memory complexity. We hope our work will set new benchmarks and serve as a valuable research tool for future investigations into the potential of convolutions, transformers, and SSMs in medical imaging.
II Methods
An overview of the proposed platform is presented in Fig. 1, which shows examples of how to construct different models with various encoders and architectures for different tasks. Our platform mainly consists of three modules: encoder & decoder, context embedding, and model architectures, of which context embedding is optional. The remainder of this section is organized as follows: Sec. II-A gives mathematical formulations of various building blocks and derived encoders, Sec. II-B introduces our context encoding strategies, and Sec. II-C elaborates on the supported architectures for different tasks.
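To illustrate this modular design, the sketch below shows how a model could be assembled by pairing an encoder with an architecture. The configuration keys and the build_model helper are hypothetical and do not reflect Flemme's actual configuration schema; they only convey the combinatorial idea.

```python
# Hypothetical configuration in the spirit of Fig. 1; key names and the
# build_model() helper are illustrative, not Flemme's actual API.
config = {
    "architecture": "H-SeM",      # SeM / H-SeM / AE / H-AE / DDPM
    "encoder": {
        "backbone": "Mamba-U",    # ResNet / U-Net / CAtten-U / Swin-U / Mamba-U
        "dimension": 2,           # 2D or 3D image patches
        "down_sampling_layers": 2,
        "middle_layers": 2,
    },
    "context": None,              # e.g. a class label, prompt image, or time step
    "loss": ["dice", "cross_entropy"],
}

# model = build_model(config)    # swapping "encoder" or "architecture" yields a new model
```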


II-A Encoder and Decoder
We implement all encoders and the corresponding decoders following the pipelines illustrated in Fig. 2. An encoder consists of a patch encoding block, down-sampling layers, middle layers, and optional densely connected layers. The patch encoding block transforms input images into patch embeddings. A basic element of the down-sampling layers contains a down-sampling block and a building block to compress image patches into smaller but deeper feature maps. The middle layers contain only building blocks for further information extraction. Finally, we can choose to flatten the feature maps and obtain 1-D vector embeddings through dense layers, or directly pass the feature maps into the decoder.
The decoder performs the reverse process of the encoder. It takes the latent features as input and reshapes them into feature maps if the inputs are 1-D vector embeddings. The up-sampling layers consist of up-sampling blocks and building blocks for feature expansion and localization. Final layers are optional for further feature enhancement. A patch decoding block then maps the feature maps back to pixel space.
For convolution-based encoders, we use convolution to implement patch encoding and down-sampling, and transposed convolution for patch decoding and up-sampling. For encoders and decoders using transformers or SSMs as backbones, patch encoding and decoding involve additional feature permutation to facilitate sequence modeling, with down-sampling and up-sampling implemented through the patch merging and patch expansion operations introduced in [8]. The main difference between various encoders lies in the employed building blocks, which are detailed in the remainder of this section.
II-A1 Convolution
We introduce two types of building blocks based on convolution: ConvBlock and ResConvBlock. A ConvBlock consists of a convolution, a normalization layer, and an activation function. Formally, we consider a forward pass of ConvBlock defined as:
$\mathrm{ConvBlock}(x) = \sigma\big(\mathrm{GN}(\mathrm{Conv}(x))\big)$   (1)
where $\sigma$ and $\mathrm{GN}$ denote the rectified linear unit (ReLU) and group normalization [16], which are the default choices of activation and normalization functions. Two MLPs, named Gate and Bias, can optionally introduce a gate-bias mechanism if there is a context vector embedding input $c$ [17]:
$\mathrm{ConvBlock}(x, c) = \mathrm{Gate}(c) \odot \sigma\big(\mathrm{GN}(\mathrm{Conv}(x))\big) + \mathrm{Bias}(c)$   (2)
As illustrated in Fig. 1 (a), ResConvBlock contains two ConvBlocks and introduces a skip connection for residual learning:
$\mathrm{ResConvBlock}(x, c) = \mathrm{CB}_2\big(\mathrm{CB}_1(x) + \mathrm{MLP}(c)\big) + x$   (3)
where $\mathrm{MLP}$ is a multi-layer perceptron for context embedding projection, and $\mathrm{CB}_1$ and $\mathrm{CB}_2$ are the first and second ConvBlock, respectively. Note that $\mathrm{CB}_2$ in ResConvBlock follows a similar manner to [7] and has no activation function.
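For concreteness, a minimal PyTorch sketch of Eqns. (1)-(3) follows. Module and argument names, as well as the exact point where the projected context embedding is injected, are our assumptions for illustration rather than Flemme's implementation.

```python
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    """Sketch of Eqns. (1)-(2): Conv -> GroupNorm -> ReLU with an optional
    gate-bias conditioning on a context embedding (illustrative only)."""
    def __init__(self, in_ch, out_ch, ctx_dim=None, groups=8, act=True):
        super().__init__()
        # out_ch must be divisible by `groups` for GroupNorm
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.norm = nn.GroupNorm(groups, out_ch)
        self.act = nn.ReLU() if act else nn.Identity()
        if ctx_dim is not None:
            self.gate = nn.Linear(ctx_dim, out_ch)
            self.bias = nn.Linear(ctx_dim, out_ch)

    def forward(self, x, c=None):
        h = self.act(self.norm(self.conv(x)))
        if c is not None:                         # gate-bias mechanism, Eqn. (2)
            h = self.gate(c)[:, :, None, None] * h + self.bias(c)[:, :, None, None]
        return h

class ResConvBlock(nn.Module):
    """Sketch of Eqn. (3): two ConvBlocks with a residual connection; the context
    embedding is projected by an MLP and injected between the blocks (our reading)."""
    def __init__(self, ch, ctx_dim=None, groups=8):
        super().__init__()
        self.block1 = ConvBlock(ch, ch, groups=groups)
        self.block2 = ConvBlock(ch, ch, groups=groups, act=False)  # no activation, as in [7]
        self.ctx_mlp = nn.Linear(ctx_dim, ch) if ctx_dim is not None else None

    def forward(self, x, c=None):
        h = self.block1(x)
        if c is not None and self.ctx_mlp is not None:
            h = h + self.ctx_mlp(c)[:, :, None, None]
        return x + self.block2(h)
```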
II-A2 Transformer
Following Dosovitskiy et al.[6], we implement ViTBlock and further introduce the shifted-window strategy [8] to construct SwinBlock, reducing time and space complexity. Specifically, we partition image patches into smaller windows, and a window-based MSA (W-MSA) is employed to capture dependencies only among patches within the same window. Relative position encodings for both 2D and 3D windows are used to further improve the modeling capability. By shifting the windows at different stages, overlaps among windows are created to increase the receptive field. Fig. 1 (b) gives an intuitive overview of SwinBlock, which can be formulated as follows:
$\hat{z} = \mathrm{DP}\big(\mathrm{WSA}(\mathrm{LN}(z))\big) + z, \quad \mathrm{SwinBlock}(z) = \mathrm{DP}\big(\mathrm{MLP}(\mathrm{LN}(\hat{z}))\big) + \hat{z}$   (4)
where $\mathrm{LN}$ and $\mathrm{DP}$ refer to layer normalization [18] and drop path [19], and $\mathrm{WSA}$ represents a series of operations, including patch shift, window partition, applying window-based MSA, and the reversed operations. Similar to ResConvBlock, we also introduce ResSwinBlock by simply replacing the ConvBlocks with SwinBlocks in Eqn. (3).
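The block-level flow of Eqn. (4) can be sketched as below. The sketch keeps only the cyclic shift, window partition, and window-based attention; relative position bias, attention masking for shifted windows, and drop path are omitted, and all names are illustrative rather than Flemme's implementation.

```python
import torch
import torch.nn as nn

class TinySwinBlock(nn.Module):
    """Simplified 2D shifted-window block following the flow of Eqn. (4)."""
    def __init__(self, dim, num_heads=4, window=8, shift=0):
        super().__init__()
        # dim must be divisible by num_heads; H and W must be divisible by `window`
        self.window, self.shift = window, shift
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):                      # x: (B, H, W, C)
        B, H, W, C = x.shape
        shortcut, w = x, self.window
        h = self.norm1(x)
        if self.shift:                         # cyclic shift creates overlap between windows
            h = torch.roll(h, shifts=(-self.shift, -self.shift), dims=(1, 2))
        # window partition: (B * num_windows, window*window, C) sequences
        h = h.reshape(B, H // w, w, W // w, w, C).permute(0, 1, 3, 2, 4, 5).reshape(-1, w * w, C)
        h, _ = self.attn(h, h, h)              # window-based multi-head self-attention
        # reverse window partition and cyclic shift
        h = h.reshape(B, H // w, W // w, w, w, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, C)
        if self.shift:
            h = torch.roll(h, shifts=(self.shift, self.shift), dims=(1, 2))
        x = shortcut + h                       # first residual branch
        return x + self.mlp(self.norm2(x))     # MLP branch with second residual
```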
II-A3 State-Space Model
An SSM is a linear time-invariant system that maps a 1-D input sequence $x(t)$ to an output $y(t)$ through a hidden state $h(t)$. We adopt Mamba [9, 20] as the implementation of SSM to construct MambaBlock. In addition, we traverse the 2D and 3D image patches in a cross-scan manner and jointly model the obtained sequences to enhance the spatial awareness of Mamba. As depicted in Fig. 1 (c), MambaBlock processes image patches as follows:
$\mathrm{MambaBlock}(z) = \mathrm{CSM}\big(W_{\mathrm{in}}\, z\big)$   (5)
where $W_{\mathrm{in}}$ projects image patches into an inner space, and $\mathrm{CSM}$ indicates multiple operations, including cross scanning, sequence modeling with Mamba, and scan merging. ResMambaBlock is introduced by using MambaBlock as a replacement for ConvBlock in Eqn. (3).
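The cross-scan idea can be sketched as follows for 2D patches: features are flattened along four scanning orders, modeled as sequences, and merged back into a feature map. The sequence model here is a stand-in (a GRU) for an actual Mamba layer, and the whole module is illustrative only.

```python
import torch
import torch.nn as nn

class CrossScan2D(nn.Module):
    """Illustrative cross-scan wrapper; the GRU is a placeholder for a Mamba/SSM layer."""
    def __init__(self, dim):
        super().__init__()
        self.seq_model = nn.GRU(dim, dim, batch_first=True)   # placeholder for Mamba

    def forward(self, x):                                     # x: (B, H, W, C)
        B, H, W, C = x.shape
        seqs = [
            x.reshape(B, H * W, C),                           # row-major scan
            x.transpose(1, 2).reshape(B, H * W, C),           # column-major scan
        ]
        seqs += [s.flip(1) for s in seqs]                     # the two reversed directions
        outs = []
        for i, s in enumerate(seqs):
            y, _ = self.seq_model(s)                          # sequence modeling per direction
            if i >= 2:                                        # undo the reversal
                y = y.flip(1)
            outs.append(y)
        # scan merging: bring column-major outputs back to row-major order, then sum
        merged = outs[0] + outs[2]
        for y in (outs[1], outs[3]):
            merged = merged + y.reshape(B, W, H, C).transpose(1, 2).reshape(B, H * W, C)
        return merged.reshape(B, H, W, C)
```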
II-A4 U-shaped Networks
For all the aforementioned encoders, we also implement their U-shaped variants. The U-shaped encoder outputs feature maps from down-sampling layers, which are concatenated with the feature maps from up-sampling layers of corresponding spatial resolutions in the decoder. U-shaped networks have demonstrated superiority and become the de-facto standard in various imaging tasks [3, 14, 21, 22, 23]. Fig. 4 gives an illustration of a U-shaped structure and horizontal feature integration.
Because the building blocks are capable of handling context embeddings, the encoder and decoder can also incorporate an additional context embedding as input. To summarize, the basic encoding and decoding processes can be formulated as follows:
$z = \mathcal{E}(x, c) = \mathrm{Mid}\big(\mathrm{Down}(\mathrm{PE}(x), c), c\big), \quad \hat{x} = \mathcal{D}(z, c) = \mathrm{PD}\big(\mathrm{Up}(z, c)\big)$   (6)
where $\mathcal{E}$ and $\mathcal{D}$ represent the encoder and decoder, $\mathrm{PE}$ and $\mathrm{PD}$ refer to the patch encoding and decoding operations, $\mathrm{Down}$ and $\mathrm{Up}$ refer to the down-sampling and up-sampling layers, and $\mathrm{Mid}$ refers to the middle layers, respectively.
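Putting these pieces together, the skeleton below sketches a U-shaped encoder-decoder with horizontal skip concatenation as described above. It uses plain convolutions for every block and illustrative layer names, since the choice of building block is orthogonal to the skeleton itself.

```python
import torch
import torch.nn as nn

class TinyUNet(nn.Module):
    """Minimal U-shaped encoder-decoder following Fig. 2 / Eqn. (6): patch encoding,
    down-sampling, middle layers, up-sampling with skip concatenation, patch decoding."""
    def __init__(self, in_ch=1, out_ch=1, base=32, depth=2):
        super().__init__()
        self.patch_enc = nn.Conv2d(in_ch, base, 3, padding=1)
        chs = [base * 2 ** i for i in range(depth + 1)]
        self.down = nn.ModuleList(
            nn.Conv2d(chs[i], chs[i + 1], 4, stride=2, padding=1) for i in range(depth))
        self.middle = nn.Conv2d(chs[-1], chs[-1], 3, padding=1)
        self.up = nn.ModuleList(
            nn.ConvTranspose2d(chs[i + 1], chs[i], 4, stride=2, padding=1)
            for i in reversed(range(depth)))
        self.fuse = nn.ModuleList(
            nn.Conv2d(2 * chs[i], chs[i], 3, padding=1) for i in reversed(range(depth)))
        self.patch_dec = nn.Conv2d(base, out_ch, 3, padding=1)

    def forward(self, x):
        h = self.patch_enc(x)
        skips = []
        for down in self.down:               # encoder: keep features for horizontal fusion
            skips.append(h)
            h = down(h)
        h = self.middle(h)
        for up, fuse, skip in zip(self.up, self.fuse, reversed(skips)):
            h = up(h)                        # decoder: upsample, then concatenate the skip
            h = fuse(torch.cat([h, skip], dim=1))
        return self.patch_dec(h)
```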

II-B Context embedding
Encoding contexts is necessary for conditional imaging tasks [24] and diffusion models [14, 25]. Context refers to any additional input other than the input image, which can be a class label, a prompt image, or a time step index from the diffusion process. We compute one-hot and sinusoidal position embeddings [4] for class labels and time step indices, respectively. When the context is an image, any of the encoders introduced in Sec. II-A can be employed.
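As an example, a sinusoidal time-step embedding can be computed as in the sketch below; the function name and dimensionality are illustrative, and one-hot class embeddings are a single call to torch.nn.functional.one_hot.

```python
import math
import torch

def sinusoidal_embedding(t, dim):
    """Sinusoidal position embedding [4] for a batch of time step indices.
    `dim` is assumed to be even; this is an illustrative sketch."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    angles = t.float()[:, None] * freqs[None, :]                        # (B, half)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)    # (B, dim)

# example: embeddings for three diffusion step indices
emb = sinusoidal_embedding(torch.tensor([0, 10, 250]), dim=128)         # shape (3, 128)
```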
TABLE I: Quantitative results of segmentation models (Dice and mIoU) on six datasets.

| Encoder | Archi | CVC-ClinicDB | | Echonet | | ISIC | | TN3K | | BraTS21 | | ImageCAS | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| | | Dice | mIoU | Dice | mIoU | Dice | mIoU | Dice | mIoU | Dice | mIoU | Dice | mIoU |
| ResNet | SeM | 0.8112 | 0.7287 | 0.9193 | 0.8538 | 0.8827 | 0.8055 | 0.7105 | 0.6048 | 0.3091 | 0.2813 | 0.7076 | 0.5497 |
| | H-SeM | 0.8098 | 0.7278 | 0.9183 | 0.8521 | 0.8836 | 0.8087 | 0.6257 | 0.5227 | 0.2557 | 0.2525 | 0.7015 | 0.5425 |
| U-Net | SeM | 0.8457 | 0.7699 | 0.9232 | 0.8604 | 0.8864 | 0.8113 | 0.7344 | 0.6329 | 0.5008 | 0.4262 | 0.7450 | 0.5961 |
| | H-SeM | 0.8492 | 0.7727 | 0.9245 | 0.8626 | 0.8915 | 0.8184 | 0.7395 | 0.6449 | 0.7984 | 0.7182 | 0.7595 | 0.6148 |
| CAtten-U | SeM | 0.5431 | 0.4230 | 0.9146 | 0.8464 | 0.4121 | 0.2836 | 0.6096 | 0.4864 | - | - | - | - |
| | H-SeM | 0.8186 | 0.7401 | 0.9231 | 0.8600 | 0.7834 | 0.6858 | 0.7208 | 0.6179 | - | - | - | - |
| Swin-U | SeM | 0.6874 | 0.5800 | 0.9125 | 0.8428 | 0.8515 | 0.7675 | 0.6051 | 0.4761 | 0.8133 | 0.7361 | 0.7345 | 0.5833 |
| | H-SeM | 0.8388 | 0.7572 | 0.9182 | 0.8521 | 0.8980 | 0.8299 | 0.6729 | 0.5525 | 0.8406 | 0.7670 | 0.7475 | 0.5997 |
| Mamba-U | SeM | 0.8618 | 0.7937 | 0.9225 | 0.8595 | 0.8999 | 0.8326 | 0.7196 | 0.6082 | 0.8430 | 0.7754 | 0.7644 | 0.6215 |
| | H-SeM | 0.8700 | 0.8033 | 0.9234 | 0.8609 | 0.9050 | 0.8394 | 0.7506 | 0.6472 | 0.8450 | 0.7748 | 0.7678 | 0.6259 |
II-C Model Architecture
II-C1 Encoder-decoder Architecture
Our base architecture follows an encoder-decoder style with optional components for encoding context, as mentioned in Sec. II-B. A forward pass of the base architecture is defined as:
$z = \mathcal{E}\big(x, \phi_{\mathcal{E}}(c), \tau(t)\big), \quad \hat{y} = \mathcal{D}\big(z, \phi_{\mathcal{D}}(c), \tau(t)\big)$   (7)
where $\phi_{\mathcal{E}}$ and $\phi_{\mathcal{D}}$ are functions encoding the context for the encoder $\mathcal{E}$ and decoder $\mathcal{D}$, $\tau$ computes the sinusoidal positional embedding, $z$ is the latent embedding computed by the encoder, and $c$ and $t$ are input contexts. Usually, $c$ is an input condition such as a prompt image or class label, and $t$ is a time step index. The base architecture is not directly trainable because it lacks a loss function; it serves as a foundational structure for the other architectures illustrated in Fig. 3, including the Segmentation Model (SeM), the Auto-Encoder (AE), and the De-noising Diffusion Probabilistic Model (DDPM), for medical image segmentation, reconstruction, and generation, respectively. We explain these architectures in the remaining part of this section. To simplify the notation, we omit the optional input contexts in the following formulas.
II-C2 Segmentation
As shown in Fig. 3 (a), SeM is a simple extension of the base architecture that introduces a segmentation loss. It takes raw images as inputs and produces predictions through the forward pass defined in Eqn. (7). The loss is computed between the prediction and the ground truth:
$\mathcal{L}_{\mathrm{seg}} = \mathcal{L}_{\mathrm{ce}}(\hat{y}, y) + \mathcal{L}_{\mathrm{dice}}(\hat{y}, y)$   (8)
where $\hat{y}$ is the prediction and $y$ specifies the ground truth. By default, the segmentation loss is a combination of cross-entropy and Dice loss, denoted by $\mathcal{L}_{\mathrm{ce}}$ and $\mathcal{L}_{\mathrm{dice}}$, respectively.
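A minimal sketch of this hybrid loss for binary masks is given below; the function name and the soft-Dice formulation are our illustrative choices rather than Flemme's exact implementation.

```python
import torch
import torch.nn.functional as F

def dice_ce_loss(logits, target, eps=1e-6):
    """Hybrid segmentation loss of Eqn. (8): cross-entropy plus soft Dice.
    `target` is a float tensor of 0/1 values with the same shape as `logits`."""
    ce = F.binary_cross_entropy_with_logits(logits, target)
    prob = torch.sigmoid(logits)
    dims = tuple(range(1, target.dim()))                  # reduce per sample
    inter = (prob * target).sum(dims)
    dice = 1.0 - (2 * inter + eps) / (prob.sum(dims) + target.sum(dims) + eps)
    return ce + dice.mean()
```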
II-C3 Reconstruction
We implement an auto-encoder (AE) for medical image restoration, whose reconstruction loss is computed between predictions and inputs, enabling unsupervised representation learning:
$\mathcal{L}_{\mathrm{rec}} = \mathrm{MSE}(\hat{x}, x)$   (9)
where $\mathrm{MSE}$ is the mean squared error, serving as the default reconstruction loss. As shown in Fig. 3 (b), AE can be regularized by a distribution loss to learn a more continuous latent representation and gain a certain generation capability [13]. The learning process of AE can also be supervised by a ground-truth target.
II-C4 Generation
Fig. 3 (c) illustrates the pipeline of DDPM [14], in which the base architecture is used to construct a noise predictor, denoted as the $\epsilon$-model. In the diffusion process, we add random Gaussian noise to the input image $x_0$ to get the $t$-th step noisy image $x_t$, which can be formulated as follows:
$x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon, \quad \epsilon \sim \mathcal{N}(0, \mathbf{I})$   (10)
where $\alpha_t$ and $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$ are hyper-parameters related to the noise schedule. The $\epsilon$-model takes noisy images as inputs and computes the loss between the noise and its prediction:
$\mathcal{L}_{\epsilon} = \big\| \epsilon - \hat{\epsilon} \big\|_2^2$   (11)
where $\hat{\epsilon}$ is the predicted noise. In the reversed diffusion process, we gradually remove the predicted noise from the current noisy image to recover the image of the previous time step. Refer to [14] for more details about the de-noising process.
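A compact sketch of one training step implementing Eqns. (10)-(11) is shown below; the eps_model signature, the name alphas_bar, and the uniform sampling of t are assumptions for illustration.

```python
import torch

def ddpm_training_step(eps_model, x0, alphas_bar, cond=None):
    """Sketch of one de-noising training step: sample t, noise the image (Eqn. (10)),
    and regress the added noise (Eqn. (11)). `alphas_bar` is a 1-D tensor of cumulative
    products on the same device as x0; eps_model(x_t, t, cond) is an assumed signature."""
    B = x0.shape[0]
    t = torch.randint(0, len(alphas_bar), (B,), device=x0.device)
    a_bar = alphas_bar[t].view(B, *([1] * (x0.dim() - 1)))
    noise = torch.randn_like(x0)
    x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise    # forward diffusion
    pred = eps_model(x_t, t, cond)                            # predicted noise
    return ((noise - pred) ** 2).mean()                       # noise prediction loss
```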
II-C5 Hierarchical Architecture
Inspired by Ronneberger et al.[3], who propose connecting the feature maps from the encoder and decoder for horizontal feature fusion, we introduce a generic hierarchical architecture to refine and fuse features vertically. Specifically, we assign a building block and a feature decoding block to each stage of the decoder to generate multi-resolution predictions, denoted as $\hat{y}_1, \dots, \hat{y}_S$. We introduce a combination layer, where all predictions are weighted with a learnable position encoding and integrated to produce the final output:
$\hat{y} = \mathrm{Comb}\big( [\, w_1 \odot \mathrm{U}(\hat{y}_1), \dots, w_S \odot \mathrm{U}(\hat{y}_S) \,] \big)$   (12)
In the above, $[\cdot]$ denotes concatenation, $\mathrm{U}$ scales $\hat{y}_i$ to $\hat{y}_S$'s size, and $w_i$ is the learnable position encoding. Meanwhile, the target is scaled to the corresponding shapes of the predictions to construct a pyramid loss for feature refinement:
$\mathcal{L}_{\mathrm{pyr}} = \sum_{i=1}^{S} \mathcal{L}(\hat{y}_i, y_i)$   (13)
where $\mathcal{L}$ is the original non-hierarchical loss and $y_i$ is the target scaled to the shape of $\hat{y}_i$. The overall loss is computed through the following equation:
$\mathcal{L}_{\mathrm{overall}} = \mathcal{L}(\hat{y}, y) + \mathcal{L}_{\mathrm{pyr}}$   (14)
The hierarchical architecture is an improvement over our base architecture, aiming to discover and combine local fine details and global structures, both of which are important for high-quality segmentation and reconstruction. Therefore, we derive the hierarchical versions of SeM and AE, denoted as H-SeM and H-AE, respectively. Vertical feature fusion is not recommended for DDPM, because we predict noise instead of reconstructing the image in the reverse diffusion process; noise usually does not contain clear global structures, and scaling the noise map may cause a severe loss of details. Fig. 4 gives an example of a segmentation model constructed with H-SeM and a U-shaped encoder using ConvBlock.
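The pyramid loss of Eqns. (13)-(14) can be sketched as follows; the helper name, the nearest-neighbour rescaling of the target, and the way the final prediction's loss is added reflect our reading rather than the exact implementation.

```python
import torch
import torch.nn.functional as F

def pyramid_loss(preds, target, base_loss):
    """Sketch of Eqn. (13): rescale the target to every intermediate prediction's
    resolution and sum the non-hierarchical loss. `preds` lists the multi-resolution
    outputs of the hierarchical decoder; `target` has shape (B, C, ...)."""
    total = 0.0
    for p in preds:
        y = F.interpolate(target, size=p.shape[2:], mode="nearest")
        total = total + base_loss(p, y)
    return total

# usage with the hybrid segmentation loss (our reading of Eqn. (14)):
# loss = pyramid_loss(intermediate_preds, gt_mask, dice_ce_loss) + dice_ce_loss(final_pred, gt_mask)
```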

III Experimental Results
III-A Datasets and Evaluation Metrics
For segmentation, we evaluate our methods on six public datasets. CVC-ClinicDB [26] is a polyp segmentation dataset of 612 images extracted from 31 colonoscopy video sequences. Echonet [27] includes 10,030 labeled echocardiogram images for chamber segmentation. ISIC [28] was published by the International Skin Imaging Collaboration as a large-scale dataset of 1279 dermoscopy images for lesion segmentation. TN3K [29] is an open-access dataset containing 3493 thyroid nodule images with high-quality nodule mask labels. BraTS21 [30] is a magnetic resonance imaging (MRI) dataset that provides 2000 3D images with segmentation labels of the different glioma sub-regions. ImageCAS [31] contains 1000 3D images for coronary artery segmentation. Images from CVC-ClinicDB, Echonet, ISIC, and TN3K are resized to 320×256, 128×128, 384×256, and 256×256 pixels, respectively. For BraTS21, all images are cropped to the region of non-zero values and resized to a shape of 120×190×120 voxels. For ImageCAS, the images are resized to a shape of 192×192×96 voxels. We also evaluate reconstruction accuracy on the FastMRI dataset introduced by Zbontar et al.[32], which aims to accelerate magnetic resonance imaging by taking fewer measurements. Specifically, we focus on the single-coil knee MRI reconstruction task, which contains 1172 cases. Each case contains a raw k-space volume and the corresponding emulated single-coil (ESC) image. We randomly masked 99% of low-frequency k-space lines, leading to a theoretical 10-fold speedup over fully-sampled single-coil imaging. Zero-filled reconstruction was further performed on the masked k-space volume to obtain the noisy images, which serve as the input of the reconstruction models. Similar to [32], we treat 3D images as multiple 2D slices, and all images are cropped to the central 320×320 pixel region to compensate for readout-direction oversampling. The first 5 slices of each sample are discarded due to limited information, resulting in a dataset containing 36,017 images. This dataset is also used for the evaluation of generation models.
For all datasets, we randomly split samples into 5 folds, where the first 3 folds are used for training, and the 4th fold is used for validation. Evaluations are performed on the 5th fold. We compute Dice score and mIoU for segmentation evaluations. PSNR [33] and SSIM [34] are used to evaluate reconstruction accuracy. Because there are no universal metrics for evaluating medical image generation, we use the noisy images as input conditions of generation models for high-quality MRI reconstruction so that the results can be evaluated through reconstruction metrics.
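A sketch of how the binary Dice score and IoU can be computed from predicted and ground-truth masks is given below; the function is illustrative and reports the foreground IoU for the binary case.

```python
import torch

def dice_and_iou(pred_mask, gt_mask, eps=1e-6):
    """Binary Dice score and IoU for segmentation evaluation; both inputs are
    boolean (or 0/1) tensors of identical shape. Illustrative sketch."""
    pred, gt = pred_mask.bool(), gt_mask.bool()
    inter = (pred & gt).sum().float()
    dice = (2 * inter + eps) / (pred.sum() + gt.sum() + eps)
    iou = (inter + eps) / ((pred | gt).sum() + eps)
    return dice.item(), iou.item()
```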
III-B Experimental Setup
III-B1 Segmentation
We construct segmentation models with the SeM architecture and different encoders, including ResNet [7], U-Net [3], CAtten-U [14], Swin-U [21], and Mamba-U [23], which are built with ResConvBlock, ConvBlock, ConvBlock plus MSA, SwinBlock, and MambaBlock, respectively. Among these, U-Net, CAtten-U, Swin-U, and Mamba-U are U-shaped encoders. The corresponding hierarchical models are constructed using H-SeM. Note that the huge GPU memory requirements prevent us from training models with CAtten-U on 3D image patches. Models trained on 2D and 3D image datasets contain 2 and 3 down-sampling layers, respectively. The number of middle layers is set to 2. We employ a hybrid loss function that combines Dice and cross-entropy loss for all segmentation models, which are trained using the Adam [35] optimizer for 500 epochs. The learning rate decays linearly from its initial value.
III-B2 Reconstruction
We construct reconstruction models with the AE architecture and different encoders, including ResNet, U-Net, CAtten-U, Swin-U, and Mamba-U. The H-AE architecture is used to construct their hierarchical versions. We train all reconstruction models for 50 epochs with an L1 loss. Other settings remain the same as for the segmentation models.
III-B3 Generation
We construct generation models with the DDPM architecture and use different encoders for noise prediction, namely ResNet, U-Net, Swin-U, and Mamba-U. We use noisy images as input conditions for MRI generation. All models are trained for 200 epochs with an L2 loss. The sampling process is accelerated using DDIM [15]. For a given condition, we integrate 5 or 10 generated samples as the final result, denoted as DDPM (5) and DDPM (10). Other settings remain the same as for the reconstruction models.
All experiments on 2D segmentation and reconstruction are conducted on a server with 8 NVIDIA RTX 3090 GPUs. Models for generation and 3D segmentation are trained on 4 NVIDIA A800 GPUs.

TABLE II: Quantitative results of reconstruction (AE/H-AE) and generation (DDPM) models on FastMRI (PSNR and SSIM).

| Encoder | Archi | PSNR | SSIM | Archi | PSNR | SSIM |
|---|---|---|---|---|---|---|
| ResNet | AE | 19.52 | 0.2741 | DDPM (5) | 9.89 | 0.0010 |
| | H-AE | 19.55 | 0.2715 | DDPM (10) | 10.14 | 0.0016 |
| U-Net | AE | 20.12 | 0.3429 | DDPM (5) | 11.42 | 0.0338 |
| | H-AE | 20.16 | 0.3431 | DDPM (10) | 11.69 | 0.0453 |
| CAtten-U | AE | 13.53 | 0.1326 | DDPM (5) | - | - |
| | H-AE | 16.79 | 0.1976 | DDPM (10) | - | - |
| Swin-U | AE | 18.44 | 0.2903 | DDPM (5) | 18.32 | 0.2721 |
| | H-AE | 20.38 | 0.3481 | DDPM (10) | 18.65 | 0.2940 |
| Mamba-U | AE | 20.81 | 0.3648 | DDPM (5) | 19.45 | 0.3093 |
| | H-AE | 20.99 | 0.3702 | DDPM (10) | 19.85 | 0.3309 |
III-C Quantitative and Qualitative Results
Table I shows the quantitative results for the segmentation models. From the encoder's point of view, Mamba-based methods show the clearest advantage. Mamba-U consistently ranks in the top two across all six segmentation datasets and achieves the highest accuracy on three of them. U-Net also demonstrates strong performance, securing a top-two ranking on four datasets and achieving the highest accuracy on two of them. Swin-U achieves top-two performance on two datasets. ResNet tends to overfit, highlighting that the U-shaped design significantly enhances the robustness of the segmentation models. CAtten-U has the poorest performance, suggesting that a straightforward combination of convolution and attention mechanisms may not be optimal for medical segmentation; moreover, this approach severely increases the computational and memory burden. From the model architecture's perspective, our hierarchical design improves performance across all U-shaped networks, which demonstrates the effectiveness of the proposed vertical feature fusion. A comprehensive comparison of qualitative results is provided in Fig. 5.
Table II shows the quantitative results for the reconstruction and generation models. All U-shaped reconstruction models benefit from our hierarchical architecture. When using conditional DDPM for medical image restoration, a larger ensemble number leads to higher accuracy. Similar to the segmentation models, Mamba-U achieves the highest reconstruction accuracy. Convolution-based encoders also have satisfactory capability for reconstruction, although they perform poorly in terms of generation. We notice that sequence modeling-based encoders generally outperform convolution-based encoders for reconstruction and generation, indicating that exploring long-range dependencies is particularly effective for high-frequency component prediction. Qualitative results are illustrated in Fig. 6.
TABLE III: Number of parameters, GPU memory consumption, and training/inference time (500 iterations, batch size 2) of models built with different building blocks.

| Building block | #Layers | Patch | #Params (M) | Memory (GB) | Training time (s) | Inference time (s) |
|---|---|---|---|---|---|---|
| ConvBlock | 16 | 64² | 3.88 | 0.056 | 8.69 | 0.73 |
| | | 128² | 3.88 | 0.22 | 8.75 | 0.74 |
| | | 64³ | 12.62 | 0.62 | 9.16 | 1.45 |
| | | 128³ | 12.62 | 2.95 | 30.77 | 9.31 |
| | 32 | 64² | 6.98 | 0.16 | 14.17 | 1.23 |
| | | 128² | 6.98 | 0.20 | 14.52 | 1.24 |
| | | 64³ | 21.92 | 0.77 | 14.94 | 2.47 |
| | | 128³ | 21.92 | 4.24 | 53.26 | 15.81 |
| SwinBlock | 16 | 64² | 4.38 | 0.29 | 26.76 | 4.02 |
| | | 128² | 4.38 | 0.54 | 27.95 | 4.14 |
| | | 64³ | 4.75 | 4.97 | 45.41 | 15.54 |
| | | 128³ | 4.75 | 38.53 | 332.42 | 120.16 |
| | 32 | 64² | 8.53 | 0.41 | 33.52 | 7.49 |
| | | 128² | 8.53 | 0.86 | 33.91 | 7.73 |
| | | 64³ | 8.53 | 8.32 | 87.49 | 30.45 |
| | | 128³ | 8.53 | 65.41 | 629.67 | 230.64 |
| MambaBlock | 16 | 64² | 5.53 | 0.34 | 28.13 | 4.55 |
| | | 128² | 5.53 | 0.73 | 27.62 | 4.53 |
| | | 64³ | 5.86 | 4.35 | 60.28 | 17.55 |
| | | 128³ | 5.86 | 34.21 | 505.96 | 149.46 |
| | 32 | 64² | 10.85 | 0.50 | 35.02 | 8.77 |
| | | 128² | 10.85 | 1.24 | 35.38 | 8.80 |
| | | 64³ | 11.24 | 7.11 | 110.08 | 32.44 |
| | | 128³ | 11.24 | 55.73 | 920.41 | 271.75 |
III-D Time and Memory Consumption
We also compare the training time and GPU memory consumption of different building blocks, as shown in Table III. We construct models with the SeM architecture and encoders using the building blocks ConvBlock, SwinBlock, and MambaBlock. Each model contains two down-sampling and two up-sampling layers. The batch size is set to 2. We measure the runtime and reserved GPU memory for 500 training and inference iterations on a single A800 GPU. As the network depth increases, the number of parameters, GPU memory consumption, and training/inference time of all models increase linearly. While ConvBlock achieves the highest time and memory efficiency, the gaps between the other building blocks and ConvBlock are acceptable for 2D images. When we switch the input to 3D images, SwinBlock and MambaBlock do not experience a significant increase in the number of parameters because they process images as sequences. However, the sequence length for 3D images increases cubically with resolution, which leads to a sudden leap in runtime and memory consumption. As the image size increases, the advantages of ConvBlock in time and space complexity become more significant. Models using ConvBlock train and infer 5.26× and 6.97× faster than models using SwinBlock and MambaBlock for input patches with a size of 64³. When the patch size grows to 128³, it is 11.3× and 16.86× faster to train a CNN compared to training models with SwinBlock and MambaBlock under the same experimental settings. In practice, the differences in training duration are even more pronounced because CNNs can be trained with much larger batch sizes. Due to the prevalence of 3D data in medical imaging and the superior performance of CNNs as analyzed above and in Sec. III-C, we believe that CNNs will continue to hold an irreplaceable position in medical image segmentation and reconstruction.
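The measurement protocol can be approximated with the sketch below, which times 500 forward passes and reads the peak reserved GPU memory. It covers only the inference half (training iterations would additionally include the backward pass and optimizer step), and the function is illustrative rather than the exact benchmarking code.

```python
import time
import torch

def profile_inference(model, patch_shape, iters=500, device="cuda"):
    """Rough sketch of the Sec. III-D protocol: wall-clock time for 500 inference
    iterations and peak reserved GPU memory, batch size 2, on a single GPU."""
    model = model.to(device).eval()
    x = torch.randn(2, *patch_shape, device=device)       # e.g. (1, 128, 128) for a 128^2 patch
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.synchronize(device)
    start = time.time()
    with torch.no_grad():
        for _ in range(iters):
            model(x)
    torch.cuda.synchronize(device)
    elapsed = time.time() - start
    mem_gb = torch.cuda.max_memory_reserved(device) / 1024 ** 3
    return elapsed, mem_gb
```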
IV Conclusion, Discussion, and Future Works
In this paper, we present Flemme, a general learning platform aiming at the rapid and flexible development of deep learning models for medical images. We introduce various encoders based on convolution, transformer, and SSM backbones, combining SeM, AE, and DDPM architectures to facilitate model creation for medical image segmentation, reconstruction, and generation. We also implement H-SeM and H-AE by employing a generic hierarchical architecture for vertical feature refinement and fusion. Extensive experiments on multiple datasets with multiple modalities showcase that this design improves model performance across different encoders.
Benefiting from the advantages of our platform in model creation and hyper-parameter control, we conduct a fair comparison and analysis of various encoders regarding runtime, memory consumption, and performance across different tasks. We confirm that CNNs still play an essential role in medical image segmentation and reconstruction, although we notice notable limitations of convolution compared to sequence-modeling backbones in generation. Given these findings, we recommend SSM-based networks as the optimal choice when sufficient computational resources are available, especially for high-frequency component prediction.
For future work, we are actively expanding our platform to handle diverse data types such as point clouds and graphs, whose importance in biomedical applications is becoming increasingly evident. Additionally, high-quality medical image generation, as well as automated quality assessment of generated medical images, will also be a focus of our future research.
Acknowledgment
This work is supported by the Natural Science Foundation of China (Grant 62371270) and Shenzhen Key Laboratory of Ubiquitous Data Enabling (No.ZDSYS20220527171406015).
References
- [1] A. Krizhevsky, I. Sutskever, and G. E. Hinton, “Imagenet classification with deep convolutional neural networks,” Advances in neural information processing systems, vol. 25, 2012.
- [2] J. Deng, W. Dong, R. Socher, L.-J. Li, K. Li, and L. Fei-Fei, “Imagenet: A large-scale hierarchical image database,” in 2009 IEEE conference on computer vision and pattern recognition. IEEE, 2009, pp. 248–255.
- [3] O. Ronneberger, P. Fischer, and T. Brox, “U-net: Convolutional networks for biomedical image segmentation,” in Medical image computing and computer-assisted intervention–MICCAI 2015: 18th international conference, Munich, Germany, October 5-9, 2015, proceedings, part III 18. Springer, 2015, pp. 234–241.
- [4] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, and I. Polosukhin, “Attention is all you need,” Advances in neural information processing systems, vol. 30, 2017.
- [5] J. Devlin, M.-W. Chang, K. Lee, and K. Toutanova, “Bert: Pre-training of deep bidirectional transformers for language understanding,” arXiv preprint arXiv:1810.04805, 2018.
- [6] A. Dosovitskiy, L. Beyer, A. Kolesnikov, D. Weissenborn, X. Zhai, T. Unterthiner, M. Dehghani, M. Minderer, G. Heigold, S. Gelly et al., “An image is worth 16x16 words: Transformers for image recognition at scale,” arXiv preprint arXiv:2010.11929, 2020.
- [7] K. He, X. Zhang, S. Ren, and J. Sun, “Deep residual learning for image recognition,” in Proceedings of the IEEE conference on computer vision and pattern recognition, 2016, pp. 770–778.
- [8] Z. Liu, Y. Lin, Y. Cao, H. Hu, Y. Wei, Z. Zhang, S. Lin, and B. Guo, “Swin transformer: Hierarchical vision transformer using shifted windows,” in Proceedings of the IEEE/CVF international conference on computer vision, 2021, pp. 10012–10022.
- [9] A. Gu and T. Dao, “Mamba: Linear-time sequence modeling with selective state spaces,” arXiv preprint arXiv:2312.00752, 2023.
- [10] L. Zhu, B. Liao, Q. Zhang, X. Wang, W. Liu, and X. Wang, “Vision mamba: Efficient visual representation learning with bidirectional state space model,” arXiv preprint arXiv:2401.09417, 2024.
- [11] Y. Liu, Y. Tian, Y. Zhao, H. Yu, L. Xie, Y. Wang, Q. Ye, and Y. Liu, “Vmamba: Visual state space model,” 2024.
- [12] Z. Liu, H. Mao, C.-Y. Wu, C. Feichtenhofer, T. Darrell, and S. Xie, “A convnet for the 2020s,” in Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, 2022, pp. 11976–11986.
- [13] D. P. Kingma and M. Welling, “Auto-encoding variational bayes,” arXiv preprint arXiv:1312.6114, 2013.
- [14] J. Ho, A. Jain, and P. Abbeel, “Denoising diffusion probabilistic models,” Advances in neural information processing systems, vol. 33, pp. 6840–6851, 2020.
- [15] J. Song, C. Meng, and S. Ermon, “Denoising diffusion implicit models,” arXiv preprint arXiv:2010.02502, 2020.
- [16] Y. Wu and K. He, “Group normalization,” in Proceedings of the European conference on computer vision (ECCV), 2018, pp. 3–19.
- [17] W. Grathwohl, R. T. Chen, J. Bettencourt, I. Sutskever, and D. Duvenaud, “Ffjord: Free-form continuous dynamics for scalable reversible generative models,” arXiv preprint arXiv:1810.01367, 2018.
- [18] J. L. Ba, J. R. Kiros, and G. E. Hinton, “Layer normalization,” arXiv preprint arXiv:1607.06450, 2016.
- [19] G. Larsson, M. Maire, and G. Shakhnarovich, “Fractalnet: Ultra-deep neural networks without residuals,” arXiv preprint arXiv:1605.07648, 2016.
- [20] T. Dao and A. Gu, “Transformers are ssms: Generalized models and efficient algorithms through structured state space duality,” arXiv preprint arXiv:2405.21060, 2024.
- [21] H. Cao, Y. Wang, J. Chen, D. Jiang, X. Zhang, Q. Tian, and M. Wang, “Swin-unet: Unet-like pure transformer for medical image segmentation,” in European conference on computer vision. Springer, 2022, pp. 205–218.
- [22] J. Chen, Y. Lu, Q. Yu, X. Luo, E. Adeli, Y. Wang, L. Lu, A. L. Yuille, and Y. Zhou, “Transunet: Transformers make strong encoders for medical image segmentation,” arXiv preprint arXiv:2102.04306, 2021.
- [23] Z. Wang, J.-Q. Zheng, Y. Zhang, G. Cui, and L. Li, “Mamba-unet: Unet-like pure visual mamba for medical image segmentation,” arXiv preprint arXiv:2402.05079, 2024.
- [24] A. Van den Oord, N. Kalchbrenner, L. Espeholt, O. Vinyals, A. Graves et al., “Conditional image generation with pixelcnn decoders,” Advances in neural information processing systems, vol. 29, 2016.
- [25] J. Ho and T. Salimans, “Classifier-free diffusion guidance,” arXiv preprint arXiv:2207.12598, 2022.
- [26] J. Bernal, F. J. Sánchez, G. Fernández-Esparrach, D. Gil, C. Rodríguez, and F. Vilariño, “Wm-dova maps for accurate polyp highlighting in colonoscopy: Validation vs. saliency maps from physicians,” Computerized medical imaging and graphics, vol. 43, pp. 99–111, 2015.
- [27] D. Ouyang, B. He, A. Ghorbani, M. P. Lungren, E. A. Ashley, D. H. Liang, and J. Y. Zou, “Echonet-dynamic: a large new cardiac motion video data resource for medical machine learning,” in NeurIPS ML4H Workshop, 2019, pp. 1–11.
- [28] D. Gutman, N. C. Codella, E. Celebi, B. Helba, M. Marchetti, N. Mishra, and A. Halpern, “Skin lesion analysis toward melanoma detection: A challenge at the international symposium on biomedical imaging (isbi) 2016, hosted by the international skin imaging collaboration (isic),” arXiv preprint arXiv:1605.01397, 2016.
- [29] H. Gong, G. Chen, R. Wang, X. Xie, M. Mao, Y. Yu, F. Chen, and G. Li, “Multi-task learning for thyroid nodule segmentation with thyroid region prior,” in 2021 IEEE 18th international symposium on biomedical imaging (ISBI). IEEE, 2021, pp. 257–261.
- [30] U. Baid, S. Ghodasara, S. Mohan, M. Bilello, E. Calabrese, E. Colak, K. Farahani, J. Kalpathy-Cramer, F. C. Kitamura, S. Pati et al., “The rsna-asnr-miccai brats 2021 benchmark on brain tumor segmentation and radiogenomic classification,” arXiv preprint arXiv:2107.02314, 2021.
- [31] A. Zeng, C. Wu, G. Lin, W. Xie, J. Hong, M. Huang, J. Zhuang, S. Bi, D. Pan, N. Ullah et al., “Imagecas: A large-scale dataset and benchmark for coronary artery segmentation based on computed tomography angiography images,” Computerized Medical Imaging and Graphics, vol. 109, p. 102287, 2023.
- [32] J. Zbontar, F. Knoll, A. Sriram, T. Murrell, Z. Huang, M. J. Muckley, A. Defazio, R. Stern, P. Johnson, M. Bruno et al., “fastmri: An open dataset and benchmarks for accelerated mri,” arXiv preprint arXiv:1811.08839, 2018.
- [33] A. Hore and D. Ziou, “Image quality metrics: Psnr vs. ssim,” in 2010 20th international conference on pattern recognition. IEEE, 2010, pp. 2366–2369.
- [34] Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, “Image quality assessment: from error visibility to structural similarity,” IEEE transactions on image processing, vol. 13, no. 4, pp. 600–612, 2004.
- [35] D. P. Kingma and J. Ba, “Adam: A method for stochastic optimization,” arXiv preprint arXiv:1412.6980, 2014.