Search

[논문리뷰]Medical SAM2 : Segment Medical Images as Video vis Segment Anything Model2

카테고리
VCMI논문리뷰
Index
날짜
2025/03/19

1. Introduction

1.1 MedSAM-2의 등장 배경

기본 CNN 기반 모델이나 ViT 기반 Medical Image Segmentation 모델들은 다음과 같은 한계점이 있음.
1.
Generalization Issue
기존 모델들은 특정 장기나 조직에 대해서만 학습하는 Task-specific 방식 위주이기 때문에 새로운 데이터나 다른 의료 영상 모달리티(CT, MRI)에 대한 성능 저하 문제가 존재
2.
2D vs 3D 문제
기존 모델들은 대부분 2D 이미지를 기반으로 설계 되어 CT나 MRI 같은 3D 의료 영상 처리에 한계가 존재
3.
SAM 및 SAM-2의 한계
SAM
Fine-tuning 비용을 줄이기 위해 Zero-shot learning 이 가능한 Interactive Segmentation model인 Segment Anything Model(SAM)이 개발되었지만, 이미지 마다 계속 user prompt를 입력해야하는 비효율성이 존재
SAM-2
비디오 객체 추적을 위해 이전 이미지 슬라이드 정보를 반영할 수 있는 Memory Bank를 도입한 SAM-2 네트워크가 개발 되었지만 의료 영상은 프레임 간 시간적 연관성에 의존하지 않는 경우가 많이 때문에 Medical Image Segmentation 성능은 좋지 않음.

1.2 MedSAM-2 Contribution

MedSAM-2는 SAM-2 기반의 최초 의료 영상 분할을 위한 Generalized Auto-Tracking 모델로, 2D 및 3D 의료 영상분할을 통합적으로 처리할 수 있는 네트워크
1.
MedSAM-2: 최초의 SAM-2 기반 범용 의료 영상 분할 모델
2D 및 3D 의료 영상 분할을 통합적으로 처리할 수 있는 모델
기존 모델처럼 특정 장기나 조직에 국한되지 않고, 다양한 의료 영상에서 최소한의 사용자 개입(minimal user intervention)으로도 높은 성능 유지
2.
Self-Sorting Memory Bank
IoU Confidence & Dissimilarity 기반으로 동적으로 중요한 임베딩을 선택하는 Self-Sorting Memory Bank 메커니즘 제안
순서가 없는(unordered) 의료 이미지에서도 효과적인 분할 가능 → 기존 SAM-2의 한계를 극복
중복된 정보를 줄이고, 더 다양한 의료 영상 특징을 학습하여 일반화 성능 향상
3.
One-Prompt Segmentation 기능 도입
단 하나의 prompt만으로 여러 이미지를 분할 가능→ 2D 의료 영상에서도 높은 일반화 성능 보임
기존 SAM 기반 모델과 달리 모든 프레임마다 prompt를 제공할 필요 없음
→ 사용자 개입을 획기적으로 감소
4.
광범위한 벤치마크 평가 & SOTA 성능 달성
14개 이상의 벤치마크, 25개 세분화된 분할 작업(task)에서 평가
기존 fully-supervised segmentation 모델 및 SAM 기반 interactive 모델보다 더 높은 성능 기록
2D 및 3D segmentation 모두에서 최고 성능(SOTA, State-of-the-Art) 달성

2. Method

2.1 MedSAM-2 Architecture

MedSAM-2의 구조는 기존 SAM2에 Self-Sorting Memory bank 와 Weighted Resampling이 추가된 구조이다.
1.
Input image X를 Image encoder에 입력하여 Feature Embedding으로 변환
Ft=Eimg(xt)F_t = \Epsilon_{img}(x_t)
기존 SAM2의 MAE pre-trained Hiera image encoder가 사용되었다.
multi-scale feature를 추출하기 위해서 Hiera image encoder의 Stage 3, Stage 4 레이어를 사용해 한 input image에 대해 각각 16, 32 stride downsampling을 진행
stride 16 feature : H16×W16\frac{H}{16} \times \frac{W}{16}
stride 32 feature : H32×W32\frac{H}{32} \times \frac{W}{32}
Windowed absolute positional embedding을 추가해서 윈도우간 전역적 위치 정보를 보존
2.
각 input image에 대한 prompt는 Prompt Encoder 를 거쳐 Prompt Embedding으로 변환
Qt=Eprompt(pt)Q_t = \Epsilon_{prompt}(p_t)
3.
다음 Memory Attention Layer에서는 Self-Sorting Memory Bank에서 추출된 과거 memory embedding (M~tsort)(\tilde M_t^{sort}) 과 Memory attention을 수행한다.
현재 이미지 임베딩 Ft F_t 와 관련된 메모리 임베딩 M~tsort\tilde M_t^{sort} 과 Prompt Q1 Q_1 을 반영하기 위해 Memory attention을 진행
L개의 attention block으로 구성되어있고, 각 attention block은 self-attention과 cross-attention 을 진행
MemoryAttention=A(Ft,M~tsort,Q1)Memory Attention = A(F_t, \tilde M_t^{sort}, Q_1)
4.
이후 Mask Decoder에서는 최종 Segmentation mask를 생성하게 된다.
Mask decoder는 image embedding FtF_t 와 memory embedding M~tsort\tilde M_t^{sort}을 입력받음
self-sorting memory bank에서 추출된 object pointer token을 output token으로 사용하고, prompt embedding Q1Q_1 부착해서 같이 처리
two-way transformer block을 사용
1.
Self Attention
output token + prompt token 의 self-attention 수행
2.
Token to image Attention
output token + prompt token 의 self-attention 결과와 image embedding 간의 attention을 수행
output token과 prompt token의 정보를 활용해서 이미지의 특정 feature에 집중할 수 있도록 조정하는 역할
이후 mlp를 거쳐 비선형성 추가 및 차원변환
3.
Image to Token Attention layer
image embeddingmlp를 거친 output token 간의 attention 수행
output token이 이미지의 어느 부분과 중요한 연관성이 있는지 학습
2x conv. trans. block은 Image Encoder Stage 1, Stage 2 레이어로, upsampling 레이어로 사용됨 (최종 segmentation mask 생성)
[Mask Decoder 주요 Output]
a.
Segmentation mask
dot product per mask 단계를 거친 최종 segmentation 결과
입력된 prompt가 모호하면, 각 frame 에대해서 multiple mask를 동시에 출력
b.
Object Pointer(Obj ptr)
각 mask를 대표하는 output token을 Object pointer로 지정하고, memory bank에 저장
c.
IoU Scores
multiple mask를 생성하기 때문에 각 Mask prediction에 대해서 IoU 점수를 계산
→ IoU 점수에 따라 가장 신뢰할 수 있는 Mask를 선택
d.
Occlusion Scores
Occlusion prediction head가 생성하는 값으로, 객체가 현재 이미지에서 가려졌는 지(occluded) 여부에 대한 예측 점수
Occlusion Prediction Head
SAM2에서 PVS 문제를 다루기 때문에 각 이미지에 segment 대상에 대한 정보가 없을 수도 있음.
Additional head를 두어서 객체가 현재 프레임에서 가려졌는 지(occluded) 여부를 예측
token to image attn. 블록에 포함되어있고, prediction결과를 occlusion token으로 생성
→ 이후 MLP head를 거쳐 occlusion score 변환 (확률로 표현)
이후 memory bank에 같이 저장
5.
Memory Encoder를 통해 segmentation mask를 downsampling
convolutional 모듈을 통해서 downsampling 진행
light-weight convolution layer를 사용해서 image encoder에서 나온 element 단위의 image embedding을 추가
6.
Self-Sorting Memory Bank (MtsortM_t^{sort})
가장 최근 K개의 memory embedding을 저장하는 SAM2와 달리 MedSAM2에서는 동적으로 가장 중요한 memory embedding을 선별해서 저장(Self-Sorting)
→ Medical image는 순서에 독립적인 데이터이기 때문에 일반적인 영상 데이터와는 다르게 slice 순서가 중요한 정보가 아닐 수 있음
각 time step(=slide)에 대해서 MtsortM_t^{sort} 업데이트
IoU confidence score와 Dissimilarity score 기반으로 memory embedding 선별
IoU Confidence Score
ct1c_{t-1} ≥ 임계값 cthreshc_{thresh} 이면 메모리 임베딩 Et1E_{t-1}를 memory bank에 추가
ct1c_{t-1} < cthreshc_{thresh} 이면 memory bank 그대로 유지
C=Mt1sortEt1C = M_{t-1}^{sort} \cup {E_{t-1}}
Dissimilarity score DiD_i
각 메모리 임베딩에 대해서 불일치도 DiD_i를 계산
Di=ΣjC1Sim(Ei,Ej),EiCD_i = \Sigma_{j \in C} 1 - Sim(E_i, E_j), \forall E_i \in C
C에 있는 임베딩 중, DiD_i 수치가 가장 높은 K개 메모리 임베딩 저장
Mtsort=TopK(Di)EiCM_t^{sort} = Top K (D_i)_{E_i \in C}
g.
Memory Bank Resampling(M~tsort) (\tilde M_t^{sort})
현재 이미지 임베딩 FtF_t 와 관련된 메모리 임베딩을 강조하기 위해서 Resampling 진행
FtF_t Mtsort M_t^{sort} 간의 유사도 score(Probability Distribution)를 계산해서 resampling
Pi,t=sim(Ft,Ei)EjMtsortsim(Ft,Ej)P_{i,t} = \frac{sim(F_t, E_i)}{\sum _{E_j \in M_t^{sort}}sim(F_t, E_j)}
Pi,tP_{i,t} score 를 기준으로 FtF_t 와 유사한 메모리 임베딩을 우선적으로 처리

2.2 Unified Approach for 2D and 3D Images

MedSAM-2는 self-sorting memory bank를 사용하여 2D 와 3D medical image segmentation 문제를 모두 해결할 수 있다고 제시
또한, 2D medical image에서는 프롬프트를 1개만 사용하는 ‘One Prompt Segmenation’ 기법을 제시

3D Medical Image 처리를 위한 6개의 orientation

3D 의료 영상은 volume (VRH×W×DV\in \R ^{H\times W \times D}) 기반 이미지 이므로 3D volume 처리를 위해 6개의 orientation 과정을 self-sorting memory bank 에서 수행
6개의 Orientation을 모두 적용하게 되면 다양한 해부학적 정보를 학습할 수 있음
→ 하지만, 가장 좋은 성능의 순서 조합은 아직 파악하지 못했고, resampling 과정에서 가장 좋은 orientation 조합을 동적으로 찾음
Inference 과정에서는 input data 에 대한 multiple orientation 을 원본데이터와 함께 넣어서 예측측

One-prompt 2D Segmentation

MRI나 CT의 Axial, Coronal, Sagittal 슬라이드들은 일반적인 비디오 프레임과 다르게 일정한 순서대로 배치되지 않을 수 있음
→ iVOS, 기존 SAM2의 Memory Bank와 같이 순서대로 프레임을 기억하는 방식을 적용하면 segmentation 오류가 발생할 수 있음
MedSAM-2는 Weighted Resampling을 통해서 feature similarity를 기준으로 메모리 embedding을 선택하여 참고
이러한 medical image의 sequence-independent한 특징을 고려해서 One-prompt segmentation을 도입함
Self-sorting memory bank에서 1개의 프롬프트만으로도 모든 슬라이스에서 자동으로 segmentation 확장을 할 수 있음.
기존 SAM2와 다르게 이미지 마다 별도로 프롬프트를 입력해지 않아도 되어서 사용자 개입이 최소화 됨.

3. Experiment

Dataset

Implementation

Metrics and Variants

3.1 Universal Medical Image Segmentation

3D Medical Image(BTCV) Evaluation - Dice

MedSAM-2는 평균 Dice Score 89.0%를 기록해 모든 비교 모델 중 가장 높은 성능을 보임.
기존 가장 성능이 좋은 Interactive Segmentation 모델대비 더 더 높은 성능 기록(3.2%)
→ MedSAM2는 훨씬 적은 prompt 만으로도 더 좋은 성능을 보임.
일반적인 Task-tailored(TransUNet, Swin-UNETR)보다 더 높은 성능을 기록

3D Medical Image(BTCV) Qualitative Comparison

Universal 2D Medical Image Evaluation

Zero-shot setting에서 11개의 2D Medical Image Dataset(unseen task)에 대한 Segmentation 성능평가 진행
MedSAM-2가 76.8%로 2D 이미지에서도 모든 비교모델을 능가하며 최고 성능 달성
→ 기존 최고 성능을 기록한 One-Prompt 모델 보다도 6.8% 높은 성능을 보임.
기존 One-Prompt 및 interactive 모델보다 zero-shot 환경에서 더 높은 성능을 유지하면서도, 훨씬 적은 사용자 개입을 요구함.

3.2 One-prompt 2D Segmentation

MedSAM-2의 One-prompt Segmentation 성능을 평가하기 위해서 기존의 Few/One-Shot Learning model 들과 비교
각각의 모델을 5번 테스트하여 서로다른 prompt와 input sequence를 사용해서 평가
MedSAM-2의 평균 Dice score가 다른모델들 보다 높고, 성능 변동성(variance)가 더 낮게 측정됨
다양한 task와 prompt 유형에 대해서도 뛰어난 generalization 능력을 보임

3.3 Analysis and Ablation Study

Mutual Information Analysis of Self-Sorting Memory Bank

ISIC 데이터셋을 사용해서 시간에 따른 Self-sorting memory bank가 저장한 임베딩 변화 분석
초기에는 Mutual information이 2.54로 높게 측정되었지만, 시간이 지나면서 memory bank가 점점 더 다양한 특징을 저장하여 상호 정보량이 1.43으로 감소
self-sorting mechanism을 통해서 중복되는 정보보다 다양한 특성을 반영하도록 memory bank 최적화 됨
→ 모델의 일반화 성능 향상됨

Prompt Frequency Analysis on 2D and 3D Medical Images

REFUGE(2D), BTCV(3D) 데이터셋에서 prompt frequency이 segmentation 성능에 미치는 영향을 실험
prompt frequency가 1.00 일때 가장 높은 성능을 보였고, 이를 통해 프롬프트를 더 자주 제공할 수록 성능이 향상됨을 확인
또한, SAM2에 비해서 MedSAM-2가 프롬프트 빈도 변화에 더 robust 하다는 것을 확인
SAM2 는 프롬프트 수가 적을 수록 성능이 크게 감소함(BTCV-7.5%, REFUGE-33.1%).

Ablation Study

Ablation study에서는 CadVidSet, BTCV-Aorta 데이터셋에서 MedSAM-2의 주요 설계 요소 3가지를 평가
IoU Threshold를 사용했을 때 불필요한 중복 데이터를 줄여 더 신뢰도 높은 샘플을 학습할 수 있다는 것을 확인
Dissimilarity 기반 저장 방식을 통해 더 다양한 특징을 학습할 수 있어 학습 성능이 개선됨. (57.8→64.5)
Memory Resampling 기법을 통해서 중요한 sample을 우선적으로 학습해 성능 개선됨 (64.5→72.9)