1. Introduction
1.1 MedSAM-2의 등장 배경
기본 CNN 기반 모델이나 ViT 기반 Medical Image Segmentation 모델들은 다음과 같은 한계점이 있음.
1.
Generalization Issue
•
기존 모델들은 특정 장기나 조직에 대해서만 학습하는 Task-specific 방식 위주이기 때문에 새로운 데이터나 다른 의료 영상 모달리티(CT, MRI)에 대한 성능 저하 문제가 존재
2.
2D vs 3D 문제
•
기존 모델들은 대부분 2D 이미지를 기반으로 설계 되어 CT나 MRI 같은 3D 의료 영상 처리에 한계가 존재
3.
SAM 및 SAM-2의 한계
•
SAM
◦
Fine-tuning 비용을 줄이기 위해 Zero-shot learning 이 가능한 Interactive Segmentation model인 Segment Anything Model(SAM)이 개발되었지만, 이미지 마다 계속 user prompt를 입력해야하는 비효율성이 존재
•
SAM-2
◦
비디오 객체 추적을 위해 이전 이미지 슬라이드 정보를 반영할 수 있는 Memory Bank를 도입한 SAM-2 네트워크가 개발 되었지만 의료 영상은 프레임 간 시간적 연관성에 의존하지 않는 경우가 많이 때문에 Medical Image Segmentation 성능은 좋지 않음.
1.2 MedSAM-2 Contribution
MedSAM-2는 SAM-2 기반의 최초 의료 영상 분할을 위한 Generalized Auto-Tracking 모델로, 2D 및 3D 의료 영상분할을 통합적으로 처리할 수 있는 네트워크
1.
MedSAM-2: 최초의 SAM-2 기반 범용 의료 영상 분할 모델
•
2D 및 3D 의료 영상 분할을 통합적으로 처리할 수 있는 모델
•
기존 모델처럼 특정 장기나 조직에 국한되지 않고, 다양한 의료 영상에서 최소한의 사용자 개입(minimal user intervention)으로도 높은 성능 유지
2.
Self-Sorting Memory Bank
•
IoU Confidence & Dissimilarity 기반으로 동적으로 중요한 임베딩을 선택하는 Self-Sorting Memory Bank 메커니즘 제안
•
순서가 없는(unordered) 의료 이미지에서도 효과적인 분할 가능 → 기존 SAM-2의 한계를 극복
•
중복된 정보를 줄이고, 더 다양한 의료 영상 특징을 학습하여 일반화 성능 향상
3.
One-Prompt Segmentation 기능 도입
•
단 하나의 prompt만으로 여러 이미지를 분할 가능→ 2D 의료 영상에서도 높은 일반화 성능 보임
•
기존 SAM 기반 모델과 달리 모든 프레임마다 prompt를 제공할 필요 없음
→ 사용자 개입을 획기적으로 감소
4.
광범위한 벤치마크 평가 & SOTA 성능 달성
•
14개 이상의 벤치마크, 25개 세분화된 분할 작업(task)에서 평가
•
기존 fully-supervised segmentation 모델 및 SAM 기반 interactive 모델보다 더 높은 성능 기록
•
2D 및 3D segmentation 모두에서 최고 성능(SOTA, State-of-the-Art) 달성
2. Method
2.1 MedSAM-2 Architecture
MedSAM-2의 구조는 기존 SAM2에 Self-Sorting Memory bank 와 Weighted Resampling이 추가된 구조이다.
1.
Input image X를 Image encoder에 입력하여 Feature Embedding으로 변환
•
•
기존 SAM2의 MAE pre-trained Hiera image encoder가 사용되었다.
◦
multi-scale feature를 추출하기 위해서 Hiera image encoder의 Stage 3, Stage 4 레이어를 사용해 한 input image에 대해 각각 16, 32 stride downsampling을 진행
▪
stride 16 feature :
▪
stride 32 feature :
◦
Windowed absolute positional embedding을 추가해서 윈도우간 전역적 위치 정보를 보존
2.
각 input image에 대한 prompt는 Prompt Encoder 를 거쳐 Prompt Embedding으로 변환
•
3.
다음 Memory Attention Layer에서는 Self-Sorting Memory Bank에서 추출된 과거 memory embedding 과 Memory attention을 수행한다.
•
현재 이미지 임베딩 와 관련된 메모리 임베딩 과 Prompt 을 반영하기 위해 Memory attention을 진행
•
L개의 attention block으로 구성되어있고, 각 attention block은 self-attention과 cross-attention 을 진행
•
4.
이후 Mask Decoder에서는 최종 Segmentation mask를 생성하게 된다.
•
Mask decoder는 image embedding 와 memory embedding 을 입력받음
◦
self-sorting memory bank에서 추출된 object pointer token을 output token으로 사용하고, prompt embedding 부착해서 같이 처리
•
two-way transformer block을 사용
1.
Self Attention
•
output token + prompt token 의 self-attention 수행
2.
Token to image Attention
•
output token + prompt token 의 self-attention 결과와 image embedding 간의 attention을 수행
•
output token과 prompt token의 정보를 활용해서 이미지의 특정 feature에 집중할 수 있도록 조정하는 역할
•
이후 mlp를 거쳐 비선형성 추가 및 차원변환
3.
Image to Token Attention layer
•
image embedding 과 mlp를 거친 output token 간의 attention 수행
•
output token이 이미지의 어느 부분과 중요한 연관성이 있는지 학습
•
2x conv. trans. block은 Image Encoder Stage 1, Stage 2 레이어로, upsampling 레이어로 사용됨 (최종 segmentation mask 생성)
[Mask Decoder 주요 Output]
a.
Segmentation mask
•
dot product per mask 단계를 거친 최종 segmentation 결과
•
입력된 prompt가 모호하면, 각 frame 에대해서 multiple mask를 동시에 출력
b.
Object Pointer(Obj ptr)
•
각 mask를 대표하는 output token을 Object pointer로 지정하고, memory bank에 저장
c.
IoU Scores
•
multiple mask를 생성하기 때문에 각 Mask prediction에 대해서 IoU 점수를 계산
→ IoU 점수에 따라 가장 신뢰할 수 있는 Mask를 선택
d.
Occlusion Scores
•
Occlusion prediction head가 생성하는 값으로, 객체가 현재 이미지에서 가려졌는 지(occluded) 여부에 대한 예측 점수
Occlusion Prediction Head
•
SAM2에서 PVS 문제를 다루기 때문에 각 이미지에 segment 대상에 대한 정보가 없을 수도 있음.
•
Additional head를 두어서 객체가 현재 프레임에서 가려졌는 지(occluded) 여부를 예측
→token to image attn. 블록에 포함되어있고, prediction결과를 occlusion token으로 생성
→ 이후 MLP head를 거쳐 occlusion score 변환 (확률로 표현)
•
이후 memory bank에 같이 저장
5.
Memory Encoder를 통해 segmentation mask를 downsampling
•
convolutional 모듈을 통해서 downsampling 진행
•
light-weight convolution layer를 사용해서 image encoder에서 나온 element 단위의 image embedding을 추가
6.
Self-Sorting Memory Bank ()
•
가장 최근 K개의 memory embedding을 저장하는 SAM2와 달리 MedSAM2에서는 동적으로 가장 중요한 memory embedding을 선별해서 저장(Self-Sorting)
→ Medical image는 순서에 독립적인 데이터이기 때문에 일반적인 영상 데이터와는 다르게 slice 순서가 중요한 정보가 아닐 수 있음
•
각 time step(=slide)에 대해서 업데이트
•
IoU confidence score와 Dissimilarity score 기반으로 memory embedding 선별
◦
IoU Confidence Score
▪
≥ 임계값 이면 메모리 임베딩 를 memory bank에 추가
▪
< 이면 memory bank 그대로 유지
▪
◦
Dissimilarity score
▪
각 메모리 임베딩에 대해서 불일치도 를 계산
•
▪
C에 있는 임베딩 중, 수치가 가장 높은 K개 메모리 임베딩 저장
•
g.
Memory Bank Resampling
•
현재 이미지 임베딩 와 관련된 메모리 임베딩을 강조하기 위해서 Resampling 진행
•
와 간의 유사도 score(Probability Distribution)를 계산해서 resampling
◦
•
score 를 기준으로 와 유사한 메모리 임베딩을 우선적으로 처리
2.2 Unified Approach for 2D and 3D Images
•
MedSAM-2는 self-sorting memory bank를 사용하여 2D 와 3D medical image segmentation 문제를 모두 해결할 수 있다고 제시
•
또한, 2D medical image에서는 프롬프트를 1개만 사용하는 ‘One Prompt Segmenation’ 기법을 제시
3D Medical Image 처리를 위한 6개의 orientation
•
3D 의료 영상은 volume () 기반 이미지 이므로 3D volume 처리를 위해 6개의 orientation 과정을 self-sorting memory bank 에서 수행
•
6개의 Orientation을 모두 적용하게 되면 다양한 해부학적 정보를 학습할 수 있음
→ 하지만, 가장 좋은 성능의 순서 조합은 아직 파악하지 못했고, resampling 과정에서 가장 좋은 orientation 조합을 동적으로 찾음
•
Inference 과정에서는 input data 에 대한 multiple orientation 을 원본데이터와 함께 넣어서 예측측
One-prompt 2D Segmentation
•
MRI나 CT의 Axial, Coronal, Sagittal 슬라이드들은 일반적인 비디오 프레임과 다르게 일정한 순서대로 배치되지 않을 수 있음
→ iVOS, 기존 SAM2의 Memory Bank와 같이 순서대로 프레임을 기억하는 방식을 적용하면 segmentation 오류가 발생할 수 있음
•
MedSAM-2는 Weighted Resampling을 통해서 feature similarity를 기준으로 메모리 embedding을 선택하여 참고
•
이러한 medical image의 sequence-independent한 특징을 고려해서 One-prompt segmentation을 도입함
◦
Self-sorting memory bank에서 1개의 프롬프트만으로도 모든 슬라이스에서 자동으로 segmentation 확장을 할 수 있음.
◦
기존 SAM2와 다르게 이미지 마다 별도로 프롬프트를 입력해지 않아도 되어서 사용자 개입이 최소화 됨.
3. Experiment
Dataset
Implementation
Metrics and Variants
3.1 Universal Medical Image Segmentation
3D Medical Image(BTCV) Evaluation - Dice
•
MedSAM-2는 평균 Dice Score 89.0%를 기록해 모든 비교 모델 중 가장 높은 성능을 보임.
•
기존 가장 성능이 좋은 Interactive Segmentation 모델대비 더 더 높은 성능 기록(3.2%)
→ MedSAM2는 훨씬 적은 prompt 만으로도 더 좋은 성능을 보임.
•
일반적인 Task-tailored(TransUNet, Swin-UNETR)보다 더 높은 성능을 기록
3D Medical Image(BTCV) Qualitative Comparison
Universal 2D Medical Image Evaluation
Zero-shot setting에서 11개의 2D Medical Image Dataset(unseen task)에 대한 Segmentation 성능평가 진행
•
MedSAM-2가 76.8%로 2D 이미지에서도 모든 비교모델을 능가하며 최고 성능 달성
→ 기존 최고 성능을 기록한 One-Prompt 모델 보다도 6.8% 높은 성능을 보임.
•
기존 One-Prompt 및 interactive 모델보다 zero-shot 환경에서 더 높은 성능을 유지하면서도, 훨씬 적은 사용자 개입을 요구함.
3.2 One-prompt 2D Segmentation
MedSAM-2의 One-prompt Segmentation 성능을 평가하기 위해서 기존의 Few/One-Shot Learning model 들과 비교
•
각각의 모델을 5번 테스트하여 서로다른 prompt와 input sequence를 사용해서 평가
•
MedSAM-2의 평균 Dice score가 다른모델들 보다 높고, 성능 변동성(variance)가 더 낮게 측정됨
•
다양한 task와 prompt 유형에 대해서도 뛰어난 generalization 능력을 보임
3.3 Analysis and Ablation Study
Mutual Information Analysis of Self-Sorting Memory Bank
ISIC 데이터셋을 사용해서 시간에 따른 Self-sorting memory bank가 저장한 임베딩 변화 분석
•
초기에는 Mutual information이 2.54로 높게 측정되었지만, 시간이 지나면서 memory bank가 점점 더 다양한 특징을 저장하여 상호 정보량이 1.43으로 감소
•
self-sorting mechanism을 통해서 중복되는 정보보다 다양한 특성을 반영하도록 memory bank 최적화 됨
→ 모델의 일반화 성능 향상됨
Prompt Frequency Analysis on 2D and 3D Medical Images
REFUGE(2D), BTCV(3D) 데이터셋에서 prompt frequency이 segmentation 성능에 미치는 영향을 실험
•
prompt frequency가 1.00 일때 가장 높은 성능을 보였고, 이를 통해 프롬프트를 더 자주 제공할 수록 성능이 향상됨을 확인
•
또한, SAM2에 비해서 MedSAM-2가 프롬프트 빈도 변화에 더 robust 하다는 것을 확인
◦
SAM2 는 프롬프트 수가 적을 수록 성능이 크게 감소함(BTCV-7.5%, REFUGE-33.1%).
Ablation Study
Ablation study에서는 CadVidSet, BTCV-Aorta 데이터셋에서 MedSAM-2의 주요 설계 요소 3가지를 평가
•
IoU Threshold를 사용했을 때 불필요한 중복 데이터를 줄여 더 신뢰도 높은 샘플을 학습할 수 있다는 것을 확인
•
Dissimilarity 기반 저장 방식을 통해 더 다양한 특징을 학습할 수 있어 학습 성능이 개선됨. (57.8→64.5)
•
Memory Resampling 기법을 통해서 중요한 sample을 우선적으로 학습해 성능 개선됨 (64.5→72.9)