Search

2D-UNETR 코드리뷰

카테고리
VCMI코드리뷰
Index
UNETR
날짜
2024/12/11
해당 깃허브 Repository의 tensorflow Keras를 기반으로 구현된 UNETR 2D 모델에 대한 코드리뷰를 진행할 예정이다.
해당 모델은 기존 3D-UNETR 논문을 기반으로 진행되었고, backbone 모델로 사용된 ViT 모델은 keras implementation을 기반으로 구현되었다.

UNETR_2D.py

2D UNETR의 전체 코드는 다음과 같다.
UNETR_2D/models/UNETR_2D.py

1. 라이브러리 import 및 파라미터

주요 라이브러리 import
from math import log2 from tensorflow.keras import Model, layers from .modules import *
Python
복사
UNETR_2D 함수 모델 주요 파라미터
def UNETR_2D( input_shape, # input image 사이즈 (H, W, C) patch_size, # ViT patch size, 논문에서는 16 num_patches, # 이미지에서 나오는 총 패치수 projection_dim, # encoder Embedding dimension transformer_layers, # transformer encoder layer 개수 num_heads, # MHA layer의 head 개수 transformer_units, # transformer의 MLP 개수 data_augmentation = None, #데이터 증강 함수 num_filters = 16, # UNETR 디코더 첫 레이어의 필터 개수, 기본 16 num_classes = 1, # segmentation class 개수 decoder_activation = 'relu', # 디코더 활성화 함수 decoder_kernel_init = 'he_normal', #디코더 커널 초기화 방법 ViT_hidd_mult = 3, # ViT에서 Skip-connection에 사용할 레이어 선택 batch_norm = True, # 배치 정규화 사용 여부 dropout = 0.0 # drop-out 사용 비율 ):
Python
복사
patch_size : ViT patch size
num_patches : 이미지에서 나오는 총 패치수, patch_size x patch_size
→ 만약 256x256 이미지이고 patch_size가 16 이면 num_patches는 16×16=25616\times16 = 256
num_heads: MHA 헤드 수이고, projection_dimnum_heads로 /나누어 떨어져야함
transformer_units : 트랜스포머의 FFN 크기결정, projection_dim의 2배 또는 4배로 설정
data_augmentation : 데이터 증강 함수, tf.keras.layers 혹은 tf.keras.Sequential 을 사용해서 구현
ViT_hidd_mult : Transformer 인코더에서 skip-connection에서 사용할 레이어 선택
transformer_layers=12, ViT_hidd_mult=3이면, [Z3, Z6, Z9] 레이어의 출력을 사용

2. Input Layer 구현 - Linear Projection, Patches, PatchEncoder

# input 데이터 입력 받음 inputs = layers.Input(shape=input_shape) # Augment data 사용 시 구현 augmented = data_augmentation(inputs) if data_augmentation != None else inputs # Patches Class로 입력 이미지를 patch_size x patch_size 개로 자름 patches = Patches(patch_size)(augmented) # PatchEncoder Class로 각 패치를 embedding dimension으로 Linear Projection encoded_patches = PatchEncoder(num_patches, projection_dim)(patches) # Hidden states 선언 hidden_states_out = []
Python
복사
Patches 클래스 구현 - models/modules.py
class Patches(layers.Layer): # 전체 이미지 mini-batch를 패치 mini-batch로 변환하는 레이어 클래스 def __init__(self, patch_size): super(Patches, self).__init__() self.patch_size = patch_size # 패치 크기 def call(self, images): batch_size = tf.shape(images)[0] # 입력 이미지 batch size 추출 patches = tf.image.extract_patches( images=images, # 입력 이미지 sizes=[1, self.patch_size, self.patch_size, 1], # 패치크기 설정 strides=[1, self.patch_size, self.patch_size, 1], # 패치stride 설정 rates=[1, 1, 1, 1], # 샘플링 비율 padding="VALID", # 패딩 없이 유효한 영역만 패치 추출 ) patch_dims = tf.shape(patches)[-1] # 패치의 차원 (W × H × C) # 패치를 mini-batch로 재구성 patches = tf.reshape(patches, [batch_size, -1, patch_dims]) return patches # 배치 형태의 패치 return
Python
복사
Patches 클래스는 ViT와 동일하게 (H,W,C)(H,W,C) 2D input image를 (P2,C)(P^2, C) 차원의 1차원 패치 임베딩으로 변환한다
이때 총 패치 시퀀스의 길이 (num_patches) N NH×WP2\frac{H\times W}{P^2}이다
tf.image.extract_patches() 함수를 사용
patch 관련 파라미터
sizes : 패치 크기 파라미터 [batch, height, width, channels] 1x4 텐서로 구성된다.
→ patch_size가 16, 패치의 batch 사이즈가 1이면 [1, 16, 16, 1]
strides : 패치 간격 설정, 패치 크기와 동일하게 설정하면 겹치치 않게 나누게 된다
rates : 샘플링 비율, 보통 1로 설정된다
padding : 패치 추출시 패딩 처리 방법
patch_dims : 패치 차원 정보 (H, W, C)
tf.reshape(patches, [batch_size, -1, patch_dims]) 로 이미지 단위의 배치를 패치단위의 배치로 재구성하여 반환한다.
PatchEncoder 클래스
class PatchEncoder(layers.Layer): # 패치를 받아 임베딩 공간으로 투영하고 위치 임베딩을 더하는 레이어 def __init__(self, num_patches, projection_dim): super(PatchEncoder, self).__init__() self.num_patches = num_patches # 패치 총 개수 # 패치를 지정된 차원 projection_dim 으로 투영 self.projection = layers.Dense(units=projection_dim) # Position embedding 선언 self.position_embedding = layers.Embedding( input_dim=num_patches, output_dim=projection_dim ) def call(self, patch): # 패치 위치 index 정의(0 ~ num_patches-1) positions = tf.range(start=0, limit=self.num_patches, delta=1) # 패치를 projection_dim 차원으로 투영하고 position embedding 추가 encoded = self.projection(patch) + self.position_embedding(positions) return encoded
Python
복사
PatchEncoder 클래스는 개별 패치를 지정된 embedding dimension 차원으로 linear projection을 수행하는 클래스이다.
z0=[xv1E;xv2E;;xvNE]+Eposz_0 = [x^1_vE;x^2_vE;…;x^N_vE]+E_{pos}
ER(P3,C)×KE \in R^{(P^3,C)\times K}
또한, ViT에서는 각 패치의 전체이미지에서의 position 정보를 보존하기 위해서
패치 인코딩 EposRN×KE_{pos} \in R^{N \times K}를 projection 된 patch embedding에 더해 인코더에 주입하게 된다.

3. ViT Encoder 구현

ViT Encoder Block 들은 for 문을 통해서 구현되어있다.
지정된 transformer_layers 수만큼의 ViT Encoder Block을 생성하고 각 Transformer Encoder는 다음과 같이 구현되어 있다.
for _ in range(transformer_layers): # Layer Normalization 1 x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches) # Multi-Head Attention(Self-Attention) attention_output = layers.MultiHeadAttention( num_heads=num_heads, # 어텐션 헤드 수 key_dim=projection_dim, # 각 어텐션 헤드의 key차원 dropout=0.1 # drop-out 비율 )(x1, x1) # Self-Attention이므로 query, key, value 모두 x1 사용 # Skip Connection 1: Attention 출력에 원래 입력(encoded_patches)을 더함 x2 = layers.Add()([attention_output, encoded_patches]) # Layer Normalization 2: Skip Connection 결과 다시 정규화 x3 = layers.LayerNormalization(epsilon=1e-6)(x2) # MLP x3 = mlp( x3, # 입력 데이터 hidden_units=transformer_units, # MLP의 Desne 레이어의 노드개수 dropout_rate=0.1 ) # Skip Connection 2: MLP 출력에 이전 출력(x2)을 더함 encoded_patches = layers.Add()([x3, x2]) # Hidden State 저장: 각 블록의 출력을 리스트에 추가 hidden_states_out.append(encoded_patches)
Python
복사
각 ViT 블록의 Multi-Head Attention layer는 layers.MultiheadAttention() 함수로 구현되어있고, Self-Attention을 수행하기 때문에 query, key, value 모두 입력데이터인 x1이 들어가야한다.
또한 각 layer normalization 수행 전 Skip Connection 은 layers.Add() 함수를 통해 구현되어 attention 결과에 입력데이터를 더해 입력 데이터의 원본 정보를 유지한다.
ViT Block의 MLP는 GELU activation fuction을 수행하는 Dense 레이어와 Dropout을 수행하는 layer로 구성된다.
def mlp(inputs, hidden_units, dropout_rate): # hidden_units는 각 Dense 레이어의 뉴런 개수를 정의하는 리스트 for units in hidden_units: inputs = layers.Dense(units, activation=tf.nn.gelu)(inputs) # Dense 레이어 + 활성화 함수 (GELU) inputs = layers.Dropout(dropout_rate)(inputs) # 드롭아웃 적용 return inputs
Python
복사

4. Encoder - Decoder 연결 Bottle neck, Decoder Upsampling 과정 구현

다음은 UNETR 디코더 파트로 bottleneck에서 시작하여 최종 출력까지에 대한 코드이다.
# 업스케일링 횟수를 patch_size 맞추어 지정 total_upscale_factor = int(log2(patch_size)) # dropout 값을 레이어 별 리스트로 확장 if type(dropout) is float: # dropout 값이 단일 값(float)일 경우 dropout = [dropout,] * total_upscale_factor # 업스케일링 레벨에 맞게 동일 값 반복 # Bottleneck - ViT의 encoded_patches을 Reshape하여 Bottleneck의 입력으로 변환 z = layers.Reshape( [input_shape[0] // patch_size, input_shape[1] // patch_size, projection_dim] )(encoded_patches) # 업스케일링 block - up_green_block x = up_green_block(z, num_filters * (2 ** (total_upscale_factor - 1)))
Python
복사
UNETR에서도 기존 UNet 과 같이 ConvTranspose(Deconvolution)을 사용해서 Upsampling을 진행하고, up_green_block 함수에서 수행하게 된다.
up_green_block은 stride가 2인 (2x2) filter를 사용해 전치컨볼루션 연산을 수행하게 되고, 입력된 feature map 크기의 2배로 복원이 된다.
def up_green_block(x, filters, name=None): x = layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same', name=name) (x) return x
Python
복사
Decoder Upsampling 과정
1.
먼저 total_upscale_factor - layer 개수 만큼 mid_blue_block을 수행하게 된다. (깊은 레이어일 수록 mid_blue_block 수행 횟수가 적다)
mid_blue_block에서는 2x2 전치컨볼루션과 3x3 Convolution 연산, Batch Normalization, RELU 활성화 함수를 거치게 된다.
def mid_blue_block(x, filters, activation='relu', kernel_initializer='glorot_uniform', batch_norm=True, dropout=0.0): # 2x2 DeConvolution x = up_green_block(x, filters) # 3x3 Conv, BN, RELU x = basic_yellow_block(x, filters, activation=activation, kernel_initializer=kernel_initializer, batch_norm=batch_norm, dropout=dropout) return x
Python
복사
3x3 Convolution 연산, Batch Normalization, RELU 활성화 함수는 basic_yellow_block에서 수행하는 내용과 동일하기 때문에 해당 함수를 호출하여 수행한다.
def basic_yellow_block(x, filters, activation='relu', kernel_initializer='glorot_uniform', batch_norm=True, dropout=0.0): x = layers.Conv2D(filters, (3,3), padding = 'same', kernel_initializer = kernel_initializer)(x) x = layers.BatchNormalization() (x) if batch_norm else x x = layers.Activation(activation) (x) x = layers.Dropout(dropout)(x) if dropout > 0.0 else x return x
Python
복사
2.
N-1 번째 transformer encoder에서 나온 output과 up_green_block에서 나온 feature map을 합치게 된다.
3.
다시 해당 output을 two_yellow 블록에 넣어 upsampling된 특성을 convolution 연산을 통해 학습하고, Batch normalization으로 학습안정성과 일반화성능을 높이고, Activation layer로 비선형성을 도입한다.
→ Skip-Connection을 통해 합쳐진 정보로 부터 복잡한 패턴을 추출할 수 있음
def two_yellow(x, filters, activation='relu', kernel_initializer='glorot_uniform', batch_norm=True, dropout=0.0): x = basic_yellow_block(x, filters, activation=activation, kernel_initializer=kernel_initializer, batch_norm=batch_norm, dropout=dropout) x = basic_yellow_block(x, filters, activation=activation, kernel_initializer=kernel_initializer, batch_norm=batch_norm, dropout=0.0) return x
Python
복사
4.
up_green_block으로 upsampling 진행
# Decoder 업스케일링 반복 for layer in reversed(range(1, total_upscale_factor)): # Skip connection을 위해 Transformer의 hidden states 가져오기 z = layers.Reshape( [input_shape[0] // patch_size, input_shape[1] // patch_size, projection_dim] )(hidden_states_out[(ViT_hidd_mult * layer) - 1]) # Mid-block 처리 (블루 블록) - 레이어에 따라 mid 블록을 추가적으로 적용 for _ in range(total_upscale_factor - layer): # 레이어에 따라 mid 블록을 추가적으로 적용 z = mid_blue_block( z, num_filters * (2 ** layer), # 해당 레이어에 맞는 필터 수 activation=decoder_activation, kernel_initializer=decoder_kernel_init, batch_norm=batch_norm, dropout=dropout[layer], ) # Skip connection과 decoder 블록 병합 x = layers.concatenate([x, z]) # Skip connection으로 이전 단계 출력(x)과 현재 hidden state(z)를 결합 # Yellow block 적용 x = two_yellow( x, num_filters * (2 ** layer), activation=decoder_activation, kernel_initializer=decoder_kernel_init, batch_norm=batch_norm, dropout=dropout[layer], ) # 업스케일링 (다음 단계로 이동) x = up_green_block( x, num_filters * (2 ** (layer - 1)), )
Python
복사

5. 첫번째 Skip-connection 구현

해당 코드는 가장 첫번째 skip-connection을 구현한 것으로, 입력이미지의 초기정보를 보존하기 위해 사용된다.
# 가장 첫번째 skip-connection first_skip = two_yellow(augmented, num_filters, activation=decoder_activation, kernel_initializer=decoder_kernel_init, batch_norm=batch_norm, dropout=dropout[0]) x = layers.concatenate([first_skip, x])
Python
복사
two_yellow 블록을 통해서 feature 추출을 하고, 디코더의 첫번째 결과와 concatenate 한다

6. 최종 output feature map return

# UNETR_2D output x = two_yellow(x, num_filters, activation=decoder_activation, kernel_initializer=decoder_kernel_init, batch_norm=batch_norm, dropout=dropout[0] ) output = layers.Conv2D( num_classes, (1, 1), activation='softmax', name="mask") (x) # semantic segmentation -- ORIGINAL: softmax # Create the Keras model. model = Model(inputs=inputs, outputs=output) return model.
Python
복사
UNETR의 최종 output은 다시 two_yellow block을 거치고, softmax를 포함한 1x1 Convolution layer을 거처 pixel 단위의 segmentation map을 얻는다.
→ binary segmentation에서는 activation function을 sigmoid로 사용하고 [H,W,1]로 출력하는 것이 좋음