해당 깃허브 Repository의 tensorflow Keras를 기반으로 구현된 UNETR 2D 모델에 대한 코드리뷰를 진행할 예정이다.
해당 모델은 기존 3D-UNETR 논문을 기반으로 진행되었고, backbone 모델로 사용된 ViT 모델은 keras implementation을 기반으로 구현되었다.
UNETR_2D.py
2D UNETR의 전체 코드는 다음과 같다.
UNETR_2D/models/UNETR_2D.py
1. 라이브러리 import 및 파라미터
•
주요 라이브러리 import
from math import log2
from tensorflow.keras import Model, layers
from .modules import *
Python
복사
•
UNETR_2D 함수 모델 주요 파라미터
def UNETR_2D(
input_shape, # input image 사이즈 (H, W, C)
patch_size, # ViT patch size, 논문에서는 16
num_patches, # 이미지에서 나오는 총 패치수
projection_dim, # encoder Embedding dimension
transformer_layers, # transformer encoder layer 개수
num_heads, # MHA layer의 head 개수
transformer_units, # transformer의 MLP 개수
data_augmentation = None, #데이터 증강 함수
num_filters = 16, # UNETR 디코더 첫 레이어의 필터 개수, 기본 16
num_classes = 1, # segmentation class 개수
decoder_activation = 'relu', # 디코더 활성화 함수
decoder_kernel_init = 'he_normal', #디코더 커널 초기화 방법
ViT_hidd_mult = 3, # ViT에서 Skip-connection에 사용할 레이어 선택
batch_norm = True, # 배치 정규화 사용 여부
dropout = 0.0 # drop-out 사용 비율
):
Python
복사
◦
patch_size : ViT patch size
◦
num_patches : 이미지에서 나오는 총 패치수, patch_size x patch_size
→ 만약 256x256 이미지이고 patch_size가 16 이면 num_patches는
◦
num_heads: MHA 헤드 수이고, projection_dim 이 num_heads로 /나누어 떨어져야함
◦
transformer_units : 트랜스포머의 FFN 크기결정, projection_dim의 2배 또는 4배로 설정
◦
data_augmentation : 데이터 증강 함수, tf.keras.layers 혹은 tf.keras.Sequential 을 사용해서 구현
◦
ViT_hidd_mult : Transformer 인코더에서 skip-connection에서 사용할 레이어 선택
→ transformer_layers=12, ViT_hidd_mult=3이면, [Z3, Z6, Z9] 레이어의 출력을 사용
2. Input Layer 구현 - Linear Projection, Patches, PatchEncoder
# input 데이터 입력 받음
inputs = layers.Input(shape=input_shape)
# Augment data 사용 시 구현
augmented = data_augmentation(inputs) if data_augmentation != None else inputs
# Patches Class로 입력 이미지를 patch_size x patch_size 개로 자름
patches = Patches(patch_size)(augmented)
# PatchEncoder Class로 각 패치를 embedding dimension으로 Linear Projection
encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
# Hidden states 선언
hidden_states_out = []
Python
복사
•
Patches 클래스 구현 - models/modules.py
class Patches(layers.Layer):
# 전체 이미지 mini-batch를 패치 mini-batch로 변환하는 레이어 클래스
def __init__(self, patch_size):
super(Patches, self).__init__()
self.patch_size = patch_size # 패치 크기
def call(self, images):
batch_size = tf.shape(images)[0] # 입력 이미지 batch size 추출
patches = tf.image.extract_patches(
images=images, # 입력 이미지
sizes=[1, self.patch_size, self.patch_size, 1], # 패치크기 설정
strides=[1, self.patch_size, self.patch_size, 1], # 패치stride 설정
rates=[1, 1, 1, 1], # 샘플링 비율
padding="VALID", # 패딩 없이 유효한 영역만 패치 추출
)
patch_dims = tf.shape(patches)[-1] # 패치의 차원 (W × H × C)
# 패치를 mini-batch로 재구성
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
return patches # 배치 형태의 패치 return
Python
복사
Patches 클래스는 ViT와 동일하게 2D input image를 차원의 1차원 패치 임베딩으로 변환한다
이때 총 패치 시퀀스의 길이 (num_patches) 은 이다
→ tf.image.extract_patches() 함수를 사용
◦
patch 관련 파라미터
▪
sizes : 패치 크기 파라미터 [batch, height, width, channels] 1x4 텐서로 구성된다.
→ patch_size가 16, 패치의 batch 사이즈가 1이면 [1, 16, 16, 1]
▪
strides : 패치 간격 설정, 패치 크기와 동일하게 설정하면 겹치치 않게 나누게 된다
▪
rates : 샘플링 비율, 보통 1로 설정된다
▪
padding : 패치 추출시 패딩 처리 방법
▪
patch_dims : 패치 차원 정보 (H, W, C)
tf.reshape(patches, [batch_size, -1, patch_dims]) 로 이미지 단위의 배치를 패치단위의 배치로 재구성하여 반환한다.
•
PatchEncoder 클래스
class PatchEncoder(layers.Layer):
# 패치를 받아 임베딩 공간으로 투영하고 위치 임베딩을 더하는 레이어
def __init__(self, num_patches, projection_dim):
super(PatchEncoder, self).__init__()
self.num_patches = num_patches # 패치 총 개수
# 패치를 지정된 차원 projection_dim 으로 투영
self.projection = layers.Dense(units=projection_dim)
# Position embedding 선언
self.position_embedding = layers.Embedding(
input_dim=num_patches,
output_dim=projection_dim
)
def call(self, patch):
# 패치 위치 index 정의(0 ~ num_patches-1)
positions = tf.range(start=0, limit=self.num_patches, delta=1)
# 패치를 projection_dim 차원으로 투영하고 position embedding 추가
encoded = self.projection(patch) + self.position_embedding(positions)
return encoded
Python
복사
PatchEncoder 클래스는 개별 패치를 지정된 embedding dimension 차원으로 linear projection을 수행하는 클래스이다.
◦
또한, ViT에서는 각 패치의 전체이미지에서의 position 정보를 보존하기 위해서
패치 인코딩 를 projection 된 patch embedding에 더해 인코더에 주입하게 된다.
3. ViT Encoder 구현
ViT Encoder Block 들은 for 문을 통해서 구현되어있다.
지정된 transformer_layers 수만큼의 ViT Encoder Block을 생성하고 각 Transformer Encoder는 다음과 같이 구현되어 있다.
for _ in range(transformer_layers):
# Layer Normalization 1
x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
# Multi-Head Attention(Self-Attention)
attention_output = layers.MultiHeadAttention(
num_heads=num_heads, # 어텐션 헤드 수
key_dim=projection_dim, # 각 어텐션 헤드의 key차원
dropout=0.1 # drop-out 비율
)(x1, x1) # Self-Attention이므로 query, key, value 모두 x1 사용
# Skip Connection 1: Attention 출력에 원래 입력(encoded_patches)을 더함
x2 = layers.Add()([attention_output, encoded_patches])
# Layer Normalization 2: Skip Connection 결과 다시 정규화
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
# MLP
x3 = mlp(
x3, # 입력 데이터
hidden_units=transformer_units, # MLP의 Desne 레이어의 노드개수
dropout_rate=0.1
)
# Skip Connection 2: MLP 출력에 이전 출력(x2)을 더함
encoded_patches = layers.Add()([x3, x2])
# Hidden State 저장: 각 블록의 출력을 리스트에 추가
hidden_states_out.append(encoded_patches)
Python
복사
•
각 ViT 블록의 Multi-Head Attention layer는 layers.MultiheadAttention() 함수로 구현되어있고, Self-Attention을 수행하기 때문에 query, key, value 모두 입력데이터인 x1이 들어가야한다.
•
또한 각 layer normalization 수행 전 Skip Connection 은 layers.Add() 함수를 통해 구현되어 attention 결과에 입력데이터를 더해 입력 데이터의 원본 정보를 유지한다.
•
ViT Block의 MLP는 GELU activation fuction을 수행하는 Dense 레이어와 Dropout을 수행하는 layer로 구성된다.
def mlp(inputs, hidden_units, dropout_rate):
# hidden_units는 각 Dense 레이어의 뉴런 개수를 정의하는 리스트
for units in hidden_units:
inputs = layers.Dense(units, activation=tf.nn.gelu)(inputs) # Dense 레이어 + 활성화 함수 (GELU)
inputs = layers.Dropout(dropout_rate)(inputs) # 드롭아웃 적용
return inputs
Python
복사
4. Encoder - Decoder 연결 Bottle neck, Decoder Upsampling 과정 구현
다음은 UNETR 디코더 파트로 bottleneck에서 시작하여 최종 출력까지에 대한 코드이다.
# 업스케일링 횟수를 patch_size 맞추어 지정
total_upscale_factor = int(log2(patch_size))
# dropout 값을 레이어 별 리스트로 확장
if type(dropout) is float: # dropout 값이 단일 값(float)일 경우
dropout = [dropout,] * total_upscale_factor # 업스케일링 레벨에 맞게 동일 값 반복
# Bottleneck - ViT의 encoded_patches을 Reshape하여 Bottleneck의 입력으로 변환
z = layers.Reshape(
[input_shape[0] // patch_size, input_shape[1] // patch_size, projection_dim]
)(encoded_patches)
# 업스케일링 block - up_green_block
x = up_green_block(z, num_filters * (2 ** (total_upscale_factor - 1)))
Python
복사
•
UNETR에서도 기존 UNet 과 같이 ConvTranspose(Deconvolution)을 사용해서 Upsampling을 진행하고, up_green_block 함수에서 수행하게 된다.
◦
up_green_block은 stride가 2인 (2x2) filter를 사용해 전치컨볼루션 연산을 수행하게 되고, 입력된 feature map 크기의 2배로 복원이 된다.
def up_green_block(x, filters, name=None):
x = layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same', name=name) (x)
return x
Python
복사
•
Decoder Upsampling 과정
1.
먼저 total_upscale_factor - layer 개수 만큼 mid_blue_block을 수행하게 된다. (깊은 레이어일 수록 mid_blue_block 수행 횟수가 적다)
•
mid_blue_block에서는 2x2 전치컨볼루션과 3x3 Convolution 연산, Batch Normalization, RELU 활성화 함수를 거치게 된다.
def mid_blue_block(x, filters, activation='relu',
kernel_initializer='glorot_uniform', batch_norm=True, dropout=0.0):
# 2x2 DeConvolution
x = up_green_block(x, filters)
# 3x3 Conv, BN, RELU
x = basic_yellow_block(x, filters, activation=activation,
kernel_initializer=kernel_initializer, batch_norm=batch_norm,
dropout=dropout)
return x
Python
복사
•
3x3 Convolution 연산, Batch Normalization, RELU 활성화 함수는 basic_yellow_block에서 수행하는 내용과 동일하기 때문에 해당 함수를 호출하여 수행한다.
def basic_yellow_block(x, filters, activation='relu',
kernel_initializer='glorot_uniform', batch_norm=True, dropout=0.0):
x = layers.Conv2D(filters, (3,3), padding = 'same',
kernel_initializer = kernel_initializer)(x)
x = layers.BatchNormalization() (x) if batch_norm else x
x = layers.Activation(activation) (x)
x = layers.Dropout(dropout)(x) if dropout > 0.0 else x
return x
Python
복사
2.
N-1 번째 transformer encoder에서 나온 output과 up_green_block에서 나온 feature map을 합치게 된다.
3.
다시 해당 output을 two_yellow 블록에 넣어 upsampling된 특성을 convolution 연산을 통해 학습하고, Batch normalization으로 학습안정성과 일반화성능을 높이고, Activation layer로 비선형성을 도입한다.
→ Skip-Connection을 통해 합쳐진 정보로 부터 복잡한 패턴을 추출할 수 있음
def two_yellow(x, filters, activation='relu', kernel_initializer='glorot_uniform', batch_norm=True, dropout=0.0):
x = basic_yellow_block(x, filters, activation=activation, kernel_initializer=kernel_initializer, batch_norm=batch_norm, dropout=dropout)
x = basic_yellow_block(x, filters, activation=activation, kernel_initializer=kernel_initializer, batch_norm=batch_norm, dropout=0.0)
return x
Python
복사
4.
up_green_block으로 upsampling 진행
# Decoder 업스케일링 반복
for layer in reversed(range(1, total_upscale_factor)):
# Skip connection을 위해 Transformer의 hidden states 가져오기
z = layers.Reshape(
[input_shape[0] // patch_size, input_shape[1] // patch_size, projection_dim]
)(hidden_states_out[(ViT_hidd_mult * layer) - 1])
# Mid-block 처리 (블루 블록) - 레이어에 따라 mid 블록을 추가적으로 적용
for _ in range(total_upscale_factor - layer): # 레이어에 따라 mid 블록을 추가적으로 적용
z = mid_blue_block(
z,
num_filters * (2 ** layer), # 해당 레이어에 맞는 필터 수
activation=decoder_activation,
kernel_initializer=decoder_kernel_init,
batch_norm=batch_norm,
dropout=dropout[layer],
)
# Skip connection과 decoder 블록 병합
x = layers.concatenate([x, z]) # Skip connection으로 이전 단계 출력(x)과 현재 hidden state(z)를 결합
# Yellow block 적용
x = two_yellow(
x,
num_filters * (2 ** layer),
activation=decoder_activation,
kernel_initializer=decoder_kernel_init,
batch_norm=batch_norm,
dropout=dropout[layer],
)
# 업스케일링 (다음 단계로 이동)
x = up_green_block(
x,
num_filters * (2 ** (layer - 1)),
)
Python
복사
5. 첫번째 Skip-connection 구현
해당 코드는 가장 첫번째 skip-connection을 구현한 것으로, 입력이미지의 초기정보를 보존하기 위해 사용된다.
# 가장 첫번째 skip-connection
first_skip = two_yellow(augmented, num_filters, activation=decoder_activation, kernel_initializer=decoder_kernel_init, batch_norm=batch_norm, dropout=dropout[0])
x = layers.concatenate([first_skip, x])
Python
복사
•
two_yellow 블록을 통해서 feature 추출을 하고, 디코더의 첫번째 결과와 concatenate 한다
6. 최종 output feature map return
# UNETR_2D output
x = two_yellow(x, num_filters, activation=decoder_activation,
kernel_initializer=decoder_kernel_init, batch_norm=batch_norm,
dropout=dropout[0] )
output = layers.Conv2D( num_classes, (1, 1), activation='softmax',
name="mask") (x) # semantic segmentation -- ORIGINAL: softmax
# Create the Keras model.
model = Model(inputs=inputs, outputs=output)
return model.
Python
복사
UNETR의 최종 output은 다시 two_yellow block을 거치고, softmax를 포함한 1x1 Convolution layer을 거처 pixel 단위의 segmentation map을 얻는다.
→ binary segmentation에서는 activation function을 sigmoid로 사용하고 [H,W,1]로 출력하는 것이 좋음