Search

UNETR: Transformers for 3D Medical Image Segmentation

카테고리
VCMI논문리뷰
Index
ML/DL
ViT
Semantic Segmentation
UNET
3D Medical Image Segmentation
UNeTR
날짜
2024/01/15

1. Introduction

Medical Image 분석 분야에서 해부학적 구조를 파악하기 위해 Image Segmentation이 중요한 역할을 하게 되었고, 해당 분야에서 FCNNs 방식 중, UNet 기반의 모델들이 좋은 결과를 보임.
→ 하지만 FCNN 구조 특성상, 각 컨볼루션 필터가 입력 이미지의 작은영역(local area)에만 초점을 맞춰 처리하기 때문에 long-range의 dependency, 멀리 떨어진 영역 간의 관계 파악이 어렵다는 한계가 있었음.
[Solution]
FCNNs의 한계를 극복하기 위해 Vision Transformer와 같은 transformer 기반 모델들이 제시되었고, 본 논문에서 3D UNet과 Transformer 구조를 기반으로 한 UNETR(UNet TRansformers) 아키텍쳐를 제안.

1.1 UNETR 구조의 특징

3D 구조의 Medical Image(H,W,S,1)(H,W,S,1) segmentation 문제를 1D sequence-to-sequence 예측 문제로 전환
Transformer Encoder를 사용하기 때문에, 3D 이미지 데이터를 쪼개서 1D image patch embedding로 변환해 transformer의 input으로 넣어줌.
→ transformer encoder를 통해 long-range dependency와 global context를 보다 효율적으로 파악 가능
Decoder는 CNN-based decoder를 사용하고, transformer encoder와 CNN decoder는 skip-connection으로 직접 연결
→ transformer는 localized information를 정확하게 파악하지 못하기 때문에 CNN 기반 디코더 사용해 segmentation output 예측을 진행
→ skip connection을 통해서 다양한 해상도에서의 representation을 추출하고 병합할 수 있음.

1.2 UNETR의 성과

3D medical image segmentation task를 위한 transformer 기반 모델을 제시
→ volumetric data를 직접 사용. (기존 3D medical image segmentation 분야에서 많이 사용된 backbone CNN 모델을 사용하지 않음)
BTCV, MSD dataset에 성능을 검증하였으며, BTCV dataset에서 SOTA를 달성함.
뇌종양, spleen segmentation task에서 SOTA를 달성함

2. Methodology

2.1 Architecture

[모델 구조]

1.
3D input data를 1D patch embedding 으로 만든 후 linear projection 진행한다.
(H,W,D,C)(H,W,D,C) 구조의 3D input 데이터(P3,C)P^3, C) 구조의 N개의 1D patch embedding으로 변환
시퀀스 길이 N=(HWD)/P3 N = (H * W * D) / P^3
(P,P,P)(P, P, P) : 각 patch의 해상도
linear projection
: 각각의 패치 임베딩을 KK 차원의 embedding space로 투사
Position Embedding
: 공간 정보를 보존하기 위해서 1D의 positional embedding EposE_{pos} 을 patch embedding에 EE 에 더해줌.
원래 Transformer 구조에서는 각 sentence의 시작과 끝을 구분하기 [class] token을 추가하지만, sementic segmentation을 위해선 [class]token을 사용하지 않음
2.
Normalization과 Multi-Head-Attention을 적용(Transformer block)한 후, residual network의 구조를 통하여서 합쳐진 데이터를 MLP에 다시 넣는다.
MSA(Multi-head Self Attention) 구조는 n개의 parallel한 SA head(self-attention head)로 구성되어 있음
MSA 계산 과정
MLP 구조는 2개의 linear layer와 GeLU 활성화 함수를 사용. (ViTransformer의 MLP와 동일한 구조)
3.
Transformer block을 총 N=12번 반복한다.
4.
Transformer 연산이 끝난 후, 나온 embedding을 reshape 과정을 통하여 3D voxel(데이터)형식으로 reconstruction한다.
→ transformer 연산이 끝난 sequence representation은 H×W×DP3×K \frac{H \times W \times D}{P^3} \times K 형태를 갖게 되는데,
이를 HP×WP×DP×K \frac{H}{P} \times \frac{W}{P} \times \frac{D}{P} \times K tensor 형태로 변형해줌
5.
Upsampling을 수행 하면서, N=9, N=6, N=3, N=0(original)에서 나온 feature들과 skip-connection을 진행한다.
Upsampling 과정에서는 기존 UNet과 같이 ConvTranspose(Deconvolution)을 사용하여 해상도를 높여줌 (원래 input data와 같은 차원으로 회복하는 과정)
→ encoder 바로 뒤에 skip-connection을 통해 Deconvolution layer 연결
N-1번째 tranformer encoder에서 나온 output과 Deconvolution layer를 거친 feature map을 합치고, (3 x 3 x 3) Conv layer에 넣어줌.
→ 다시 해당 output을 (2x2x2) Deconv layer를 거쳐 Upsampling 진행
→ 위의 과정을 원래 input data와 같은 차원으로 회복될 때 까지 반복
6.
마지막 layer에서 최종 dimension과 channel을 조절한 후 segmentaiton result를 만든다.
최종 output은 1 x 1 x 1 Conv layer 와 Softmax 활성화 함수를 거쳐 voxel-wise sementic prediction 값으로 생성됨.

2.2 Loss function

loss function의 soft dice loss와 cross-entropy loss를 조합하여 사용
II : voxel의 개수
JJ: dice loss
Yi,j,Gi,jY_{i,j}, G_{i,j} : class j와 voxel i에서 예측된 probability와 ground truth
[Soft Dice loss]
[Cross Entropy]

3. Experiment Details

3.1 Datasets

BTCV (CT)
BTCV dataset은 30 subjects로 구성되어 있으며, 13개의 장기가 annotation되어 있는 dataset 이다.
각각의 CT scan은 80~255 slice, 512×512 pixel의 조영 CT이고, thickness는 1~6 mm이다.
모든 image는 1.0 mm의 voxel space로 resample되었다.
brain tumor segmentation 실험은 4-channel input의 3 class segmentation task로 진행됨
MSD (MRI/CT)
MSD dataset에서 brain tumoe segmentation task는 484개의 multi-modal, multi-site MRI data(FLAIR, T1w, T1gd, T2w)로 구성되어 있다.
necrotic/active tumor와 oedema가 annotation되어 있다. (voxel space는 1.0×1.0×1.0 mm3mm^3)
Spleen segmentation task는 41개의 CT volume으로 구성되어 있으며, spleen body가 annotation 되어 있다.
Spleen segmentation 실험은 1-channel input의 binary segmentation task, 여러 개의 장기 및 spleen segmentation task에서는 랜덤하게 input 이미지의 sample를 사용하여 진행. volume size (96, 96 ,96)

3.2 Metrics

Dice score 와 95% Hausdorff Distance(HD)를 사용

3.3 Implementation Details

4. Experiment Results & Discussion

4.1 Quantitative Evaluations

4.1.1. BTCV Segmentation performance

UNETR가 Standard, Free competition 모두에서 SOTA를 보임
→ 전반적으로 0.899 이상의 Dice score 기록
→ 2,3,4 번째로 좋은 방법들과 비교 했을 때, 1.238%, 1.696% and 5.269% 의 성능차이를 보임
Standard Competition
: UNETR과 CNN, transformer 기반의 베이스라인 모델들과 성능 비교
→ 모든 종류의 organ segmentation에대해 UNETR가 평균 85.3% 의 Dice score를 기록 (New SOTA 달성)
spleen, liver, stomach와 같은 크기가 큰 organ segmentation에서 2번째 좋은 base-line 모델을 1.043%, 0.830% and 2.125% 차이로 능가
작은 크기의 organ segmentation에서도 좋은 성과를 보임

4.1.2. MSD Segmentation performance

MSD dataset 실험에서는 brain tumor, spleen segmentation을 진행
→ UNETR과 CNN, transformer 기반의 베이스라인 모델들과 성능 비교
Brain Segmentation, Spleen segmentation 모두에서 UNETR가 모든 semantic class에 대해서 제일 좋은 성능을 보임
Tumor Core(TC) sub region segmentation task에 대해 상대적으로 좋은 성적을 보임

4.2 Qualitative Results

4.2.1. BTCV Qualitative Comparision

복부에 위치한 장기들에대한 segmentation task에서 UNETR가 좋은 성능을 보임.
nnUNet과 비교 했을 때, UNETR가 long-range dependency를 더 잘 학습함.
→ nnUNet은 간과 위 조직을 정확하게 구분하지 못했지만, UNETR는 장기들 사이의 경계를 정확하게 구분함. (row 3)
콩팥, 부신과 주위의 조직들의 구분 task도 잘 수행함
→ UNETR가 공간 정보 수집에 효과적임
→ 2D transformer 기반 모델들과 비교 했을 때도, 주변 경계 segmentation task를 훨씬 잘 수행함.

4.2.2. MSD Qualitative Comparision

→ Tumor Core(red, blue), Whole Tumor(red, blue, green), Enhancing Tumor core(green)
다른 baseline 모델들과 비교 했을 때, UNETR가 세부정보를 훨씬 잘 파악함.

4.3 Summary

BTCV, MSD Dataset 모두에 대해서 UNETR가 CNN, Transformer 기반 모델들보다 좋은 성능을 보임.
→ UNETR는 global, local depency 모두 학습이 가능하기 때문에 segmentation accuracy가 높은 것으로 보임.
→ long-range dependency 파악도 효과적으로 수행하여 target organ과 주변 조직들 사이의 경계를 정확하게 구분(qualitative comparision에서 확인가능)
UNETR은 BTCV dataset에서 SOTA를 달성함으로 효과를 입증했으며, gallbladder, adrenal glands와 같이 작은 크기의 organs에서도 좋은 성능을 보임.

5. Ablation

5.1 Decoder Choice

UNETR의 CNN decoder 구조 성능 비교를 위해, decoder를 NUP(Naive UPsampling), PUP(Progressive UPsampling), MLA(MuLti-scale Aggregation)으로 ablation 진행
→ UNETR의 decoder가 가장 좋은 성능을 얻음.

5.2 Patch Resolution

Patch의 resolution을 32 → 16으로 줄인 결과, 약간의 성능향상을 보임.

5.3 Model and Computational Complexity

Parameter size는 UNETR이 가장 크긴 하지만, FLOPs에서 CNN 기반의 model들 보다도 성능이 뛰어나면서 적당한 model complexity를 갖고 있음.
inference time은 다른 transformer 기반의 model들 보다 UNETR가 빠름.

6. Conclusion

3D medical image segmentation task에서 새로운 transformer 기반의 architecture 인UNETR을 제안.
UNETR은 encoder에 transformer를 사용함으로 model의 long-range dependencies를 학습하는 능력을 올렸고, 효과적으로 global contextual representation을 학습할 수 있었음.
BTCV, MSD dataset에서 좋은 성능을 얻었으며 BTCV dataset에서는 SOTA를 달성함.

UNETR Code