1. Introduction
Medical Image 분석 분야에서 해부학적 구조를 파악하기 위해 Image Segmentation이 중요한 역할을 하게 되었고, 해당 분야에서 FCNNs 방식 중, UNet 기반의 모델들이 좋은 결과를 보임.
→ 하지만 FCNN 구조 특성상, 각 컨볼루션 필터가 입력 이미지의 작은영역(local area)에만 초점을 맞춰 처리하기 때문에 long-range의 dependency, 멀리 떨어진 영역 간의 관계 파악이 어렵다는 한계가 있었음.
[Solution]
FCNNs의 한계를 극복하기 위해 Vision Transformer와 같은 transformer 기반 모델들이 제시되었고, 본 논문에서 3D UNet과 Transformer 구조를 기반으로 한 UNETR(UNet TRansformers) 아키텍쳐를 제안.
1.1 UNETR 구조의 특징
3D 구조의 Medical Image segmentation 문제를 1D sequence-to-sequence 예측 문제로 전환
•
Transformer Encoder를 사용하기 때문에, 3D 이미지 데이터를 쪼개서 1D image patch embedding로 변환해 transformer의 input으로 넣어줌.
→ transformer encoder를 통해 long-range dependency와 global context를 보다 효율적으로 파악 가능
•
Decoder는 CNN-based decoder를 사용하고, transformer encoder와 CNN decoder는 skip-connection으로 직접 연결
→ transformer는 localized information를 정확하게 파악하지 못하기 때문에 CNN 기반 디코더 사용해 segmentation output 예측을 진행
→ skip connection을 통해서 다양한 해상도에서의 representation을 추출하고 병합할 수 있음.
1.2 UNETR의 성과
•
3D medical image segmentation task를 위한 transformer 기반 모델을 제시
→ volumetric data를 직접 사용. (기존 3D medical image segmentation 분야에서 많이 사용된 backbone CNN 모델을 사용하지 않음)
•
BTCV, MSD dataset에 성능을 검증하였으며, BTCV dataset에서 SOTA를 달성함.
•
뇌종양, spleen segmentation task에서 SOTA를 달성함
2. Methodology
2.1 Architecture
[모델 구조]
1.
3D input data를 1D patch embedding 으로 만든 후 linear projection 진행한다.
→ 구조의 3D input 데이터를 ( 구조의 N개의 1D patch embedding으로 변환
•
시퀀스 길이
•
: 각 patch의 해상도
•
linear projection
: 각각의 패치 임베딩을 차원의 embedding space로 투사
•
Position Embedding
: 공간 정보를 보존하기 위해서 1D의 positional embedding 을 patch embedding에 에 더해줌.
•
원래 Transformer 구조에서는 각 sentence의 시작과 끝을 구분하기 [class] token을 추가하지만, sementic segmentation을 위해선 [class]token을 사용하지 않음
2.
Normalization과 Multi-Head-Attention을 적용(Transformer block)한 후, residual network의 구조를 통하여서 합쳐진 데이터를 MLP에 다시 넣는다.
•
MSA(Multi-head Self Attention) 구조는 n개의 parallel한 SA head(self-attention head)로 구성되어 있음
MSA 계산 과정
•
MLP 구조는 2개의 linear layer와 GeLU 활성화 함수를 사용. (ViTransformer의 MLP와 동일한 구조)
3.
Transformer block을 총 N=12번 반복한다.
4.
Transformer 연산이 끝난 후, 나온 embedding을 reshape 과정을 통하여 3D voxel(데이터)형식으로 reconstruction한다.
→ transformer 연산이 끝난 sequence representation은 형태를 갖게 되는데,
이를 tensor 형태로 변형해줌
5.
Upsampling을 수행 하면서, N=9, N=6, N=3, N=0(original)에서 나온 feature들과 skip-connection을 진행한다.
•
Upsampling 과정에서는 기존 UNet과 같이 ConvTranspose(Deconvolution)을 사용하여 해상도를 높여줌 (원래 input data와 같은 차원으로 회복하는 과정)
→ encoder 바로 뒤에 skip-connection을 통해 Deconvolution layer 연결
•
N-1번째 tranformer encoder에서 나온 output과 Deconvolution layer를 거친 feature map을 합치고, (3 x 3 x 3) Conv layer에 넣어줌.
→ 다시 해당 output을 (2x2x2) Deconv layer를 거쳐 Upsampling 진행
→ 위의 과정을 원래 input data와 같은 차원으로 회복될 때 까지 반복
6.
마지막 layer에서 최종 dimension과 channel을 조절한 후 segmentaiton result를 만든다.
•
최종 output은 1 x 1 x 1 Conv layer 와 Softmax 활성화 함수를 거쳐 voxel-wise sementic prediction 값으로 생성됨.
2.2 Loss function
loss function의 soft dice loss와 cross-entropy loss를 조합하여 사용
•
: voxel의 개수
•
: dice loss
•
: class j와 voxel i에서 예측된 probability와 ground truth
[Soft Dice loss]
[Cross Entropy]
3. Experiment Details
3.1 Datasets
•
BTCV (CT)
BTCV dataset은 30 subjects로 구성되어 있으며, 13개의 장기가 annotation되어 있는 dataset 이다.
◦
각각의 CT scan은 80~255 slice, 512×512 pixel의 조영 CT이고, thickness는 1~6 mm이다.
◦
모든 image는 1.0 mm의 voxel space로 resample되었다.
◦
brain tumor segmentation 실험은 4-channel input의 3 class segmentation task로 진행됨
•
MSD (MRI/CT)
MSD dataset에서 brain tumoe segmentation task는 484개의 multi-modal, multi-site MRI data(FLAIR, T1w, T1gd, T2w)로 구성되어 있다.
◦
necrotic/active tumor와 oedema가 annotation되어 있다. (voxel space는 1.0×1.0×1.0 )
◦
Spleen segmentation task는 41개의 CT volume으로 구성되어 있으며, spleen body가 annotation 되어 있다.
◦
Spleen segmentation 실험은 1-channel input의 binary segmentation task, 여러 개의 장기 및 spleen segmentation task에서는 랜덤하게 input 이미지의 sample를 사용하여 진행. volume size (96, 96 ,96)
3.2 Metrics
•
Dice score 와 95% Hausdorff Distance(HD)를 사용
3.3 Implementation Details
4. Experiment Results & Discussion
4.1 Quantitative Evaluations
4.1.1. BTCV Segmentation performance
•
UNETR가 Standard, Free competition 모두에서 SOTA를 보임
→ 전반적으로 0.899 이상의 Dice score 기록
→ 2,3,4 번째로 좋은 방법들과 비교 했을 때, 1.238%, 1.696% and 5.269% 의 성능차이를 보임
•
Standard Competition
: UNETR과 CNN, transformer 기반의 베이스라인 모델들과 성능 비교
→ 모든 종류의 organ segmentation에대해 UNETR가 평균 85.3% 의 Dice score를 기록 (New SOTA 달성)
◦
spleen, liver, stomach와 같은 크기가 큰 organ segmentation에서 2번째 좋은 base-line 모델을 1.043%, 0.830% and 2.125% 차이로 능가
◦
작은 크기의 organ segmentation에서도 좋은 성과를 보임
4.1.2. MSD Segmentation performance
•
MSD dataset 실험에서는 brain tumor, spleen segmentation을 진행
→ UNETR과 CNN, transformer 기반의 베이스라인 모델들과 성능 비교
•
Brain Segmentation, Spleen segmentation 모두에서 UNETR가 모든 semantic class에 대해서 제일 좋은 성능을 보임
•
Tumor Core(TC) sub region segmentation task에 대해 상대적으로 좋은 성적을 보임
4.2 Qualitative Results
4.2.1. BTCV Qualitative Comparision
•
복부에 위치한 장기들에대한 segmentation task에서 UNETR가 좋은 성능을 보임.
•
nnUNet과 비교 했을 때, UNETR가 long-range dependency를 더 잘 학습함.
→ nnUNet은 간과 위 조직을 정확하게 구분하지 못했지만, UNETR는 장기들 사이의 경계를 정확하게 구분함. (row 3)
•
콩팥, 부신과 주위의 조직들의 구분 task도 잘 수행함
→ UNETR가 공간 정보 수집에 효과적임
→ 2D transformer 기반 모델들과 비교 했을 때도, 주변 경계 segmentation task를 훨씬 잘 수행함.
4.2.2. MSD Qualitative Comparision
→ Tumor Core(red, blue), Whole Tumor(red, blue, green), Enhancing Tumor core(green)
•
다른 baseline 모델들과 비교 했을 때, UNETR가 세부정보를 훨씬 잘 파악함.
4.3 Summary
•
BTCV, MSD Dataset 모두에 대해서 UNETR가 CNN, Transformer 기반 모델들보다 좋은 성능을 보임.
→ UNETR는 global, local depency 모두 학습이 가능하기 때문에 segmentation accuracy가 높은 것으로 보임.
→ long-range dependency 파악도 효과적으로 수행하여 target organ과 주변 조직들 사이의 경계를 정확하게 구분(qualitative comparision에서 확인가능)
•
UNETR은 BTCV dataset에서 SOTA를 달성함으로 효과를 입증했으며, gallbladder, adrenal glands와 같이 작은 크기의 organs에서도 좋은 성능을 보임.
5. Ablation
5.1 Decoder Choice
•
UNETR의 CNN decoder 구조 성능 비교를 위해, decoder를 NUP(Naive UPsampling), PUP(Progressive UPsampling), MLA(MuLti-scale Aggregation)으로 ablation 진행
→ UNETR의 decoder가 가장 좋은 성능을 얻음.
5.2 Patch Resolution
•
Patch의 resolution을 32 → 16으로 줄인 결과, 약간의 성능향상을 보임.
5.3 Model and Computational Complexity
•
Parameter size는 UNETR이 가장 크긴 하지만, FLOPs에서 CNN 기반의 model들 보다도 성능이 뛰어나면서 적당한 model complexity를 갖고 있음.
•
inference time은 다른 transformer 기반의 model들 보다 UNETR가 빠름.
6. Conclusion
•
3D medical image segmentation task에서 새로운 transformer 기반의 architecture 인UNETR을 제안.
•
UNETR은 encoder에 transformer를 사용함으로 model의 long-range dependencies를 학습하는 능력을 올렸고, 효과적으로 global contextual representation을 학습할 수 있었음.
•
BTCV, MSD dataset에서 좋은 성능을 얻었으며 BTCV dataset에서는 SOTA를 달성함.