Search

Swin Transformer 코드리뷰

카테고리
VCMI코드리뷰
Index
ML/DL
Transformer
Semantic Segmentation
SwinTransformer
날짜
2024/05/08
이번 코드 리뷰에서는 Swin-Unet에서 Encoder로 사용된 Swin Transformer 구현 코드에 대해서 다루고자 한다.
참고한 코드의 깃허브 링크는 다음과 같다.
vision_transformer.py
networks

1. Swin-Unet 클래스 호출

먼저, Swin-Unet 클래스에서 Swin Transformer를 인코더로 호출하는 코드를 확인하고자한다.
Swin-Unet/train.py에서 SwinUnet 클래스 인스턴스를 생성하고, 체크포인트를 로드해주는 코드를 확인할 수 있다.
from networks.vision_transformer import SwinUnet as ViT_seg if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) # SwinUnet 인스턴스 생성 및 config, img_size, num_class 지정 및 체크포인트 로드 net = ViT_seg(config, img_size=args.img_size, num_classes=args.num_classes).cuda() net.load_from(config)
Python
복사
net.load_from(config) 코드에서 Encoder로 Swin Transformer를 설정해주는 부분이 있고, 해당 함수는 Swin-Unet 클래스의 메서드 함수로 구현되어 있다.
load_from 함수 : 사전학습된 Swin Transformer 인코더를 Swin Unet 모델에 불러오는 함수
def load_from(self, config): # 모델의 사전 학습된 체크포인트 경로를 config에서 불러오기 pretrained_path = config.MODEL.PRETRAIN_CKPT # 사전 학습된 경로가 None이 아니면 진행 if pretrained_path is not None: # 사전 학습된 경로를 출력 print("pretrained_path:{}".format(pretrained_path)) # CUDA 및 CPU 설정 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 사전 학습된 모델을 지정된 device에 로드 pretrained_dict = torch.load(pretrained_path, map_location=device) # 불러온 사전 학습된 모델에 'model' 키가 없는 경우에 대한 예외처리 if "model" not in pretrained_dict: print("---start load pretrained modle by splitting---") # 사전 학습된 모델의 키 수정 pretrained_dict = {k[17:]:v for k,v in pretrained_dict.items()} # 수정된 키를 반복하여 'output'이 포함된 키는 삭제 for k in list(pretrained_dict.keys()): if "output" in k: print("delete key:{}".format(k)) del pretrained_dict[k] # 수정된 사전 학습된 모델을 현재 모델에 로드 msg = self.swin_unet.load_state_dict(pretrained_dict,strict=False) return # 사전학습된 Swin Transformer 인코더의 model config 불러오기 pretrained_dict = pretrained_dict['model'] print("---start load pretrained model of swin encoder---") # 현재 swin-unet의 stat_dict 불러오기 model_dict = self.swin_unet.state_dict() # 사전 학습된 config 내용을 deep copy full_dict = copy.deepcopy(pretrained_dict) # 사전 학습된 swin transformer의 가중치를 Swin-Unet 디코더로 전달하고, 재매핑 for k, v in pretrained_dict.items(): if "layers." in k: #인코더 각 layer의 가중치 확인 # Swin-Unet의 디코더 층으로 가중치 전달 current_layer_num = 3-int(k[7:8]) # Swin-Transformer 인코더 층 이름을 디코더 층 이름으로 변환 current_k = "layers_up." + str(current_layer_num) + k[8:] # 디코더의 새로운 층에 대한 인코더 가중치 할당 full_dict.update({current_k:v}) # 가중치가 맞지 않는 경우 해당 키를 삭제 for k in list(full_dict.keys()): if k in model_dict: if full_dict[k].shape != model_dict[k].shape: print("delete:{};shape pretrain:{};shape model:{}".format(k,v.shape,model_dict[k].shape)) del full_dict[k] # 최종 state_dict를 swin_unet 모델에 적재 및 확인 메세지 출력 msg = self.swin_unet.load_state_dict(full_dict, strict=False) else: # 사전 학습된 모델 경로가 None일 경우 메시지 출력 print("none pretrain")
Python
복사
self.swin_unet.state_dict() 에서는 해당 SwinUnet 클래스 인스턴스의 state_dict() 메소드를 호출하고 있다.
SwinUnet 클래스의 생성자 함수를 확인하면, Swin Transformer 인코더는 SwinTransformerSys 라는 별개의 클래스로 구현되어있으며, SwinUnet 클래스의 swin_unet 속성은 SwinTransformerSys 클래스의 인스턴스를 생성해서 구현됨을 알 수 있다.
from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm from torch.nn.modules.utils import _pair from scipy import ndimage from .swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys class SwinUnet(nn.Module): def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False): super(SwinUnet, self).__init__() self.num_classes = num_classes self.zero_head = zero_head self.config = config self.swin_unet = SwinTransformerSys(img_size=config.DATA.IMG_SIZE, patch_size=config.MODEL.SWIN.PATCH_SIZE, in_chans=config.MODEL.SWIN.IN_CHANS, num_classes=self.num_classes, embed_dim=config.MODEL.SWIN.EMBED_DIM, depths=config.MODEL.SWIN.DEPTHS, num_heads=config.MODEL.SWIN.NUM_HEADS, window_size=config.MODEL.SWIN.WINDOW_SIZE, mlp_ratio=config.MODEL.SWIN.MLP_RATIO, qkv_bias=config.MODEL.SWIN.QKV_BIAS, qk_scale=config.MODEL.SWIN.QK_SCALE, drop_rate=config.MODEL.DROP_RATE, drop_path_rate=config.MODEL.DROP_PATH_RATE, ape=config.MODEL.SWIN.APE, patch_norm=config.MODEL.SWIN.PATCH_NORM, use_checkpoint=config.TRAIN.USE_CHECKPOINT)
Python
복사
Swin Transformer 인스턴스 생성할 때 지정해주어야 하는 주요 파라미터는 다음과 같다.
img_size - 입력 이미지의 크기 지정 (H, W)
patch_size - 이미지 패치 사이즈
in_chans - 입력 채널의 수 (RGB : 3)
num_classes - 모델이 최종적으로 예측해야 하는 클래스의 수
embed_dim - 각 패치에 대해서 Linear Projection 할 때 차원 수
depths - Swin Transformer 모델의 각 스테이지에서 블록의 수
num_heads - MSA 에서 사용되는 헤드의 수
window_size - Swin Transformer의 각 윈도우의 크기
qkv_bias - Query, Key, Value 계산에 바이어스를 사용할지 여부
ape - 절대 위치 인코딩(Absolute Position Encoding) 사용 여부를 결정
patch_norm - 패치 레벨 정규화 여부를 지정

2. SwinTransformerSys 클래스

networks/swin_transformer_unet_skip_expand_decoder_sys.py 에서 Swin Transformer 인코더가 구현된 코드를 확인할 수 있다.
Swin Transformer Architecture
Swin Transformer 인코더의 전체 구조는 다음과 같다.
1.
Patch Partition Layer
2.
Linear Embedding Layer
→ 임의의 차원 C로 각 이미지 패치를 선형 변환하는 레이어 (48→ C)
3.
Swin Transformer Blocks
[Stage 1]
linear projection을 거친 이미지 패치는 Swin Transformer Blocks를 통과하고, self-attention 연산이 수행된다.
[Stage 2,3,4]
계층적 특성 맵을 생성하기 위해서 네트워크가 깊어짐에 따라서 각 Stage 별로 레이어에서 처리하는 패치를 병합하여 토큰 수를 줄이는 구조 (Patch Merging Layer)
→ Stage 2, 3, 4 에서 Swin Transformer Block을 통과하기 전에 Patch Merging 레이어를 두어 패치 병합을 수행한다.
(Stage 1: H4\frac {H}{4} x W4\frac {W}{4}Stage 2: H8\frac {H}{8} x W8\frac {W}{8}Stage 3: H16\frac {H}{16} x W16\frac {W}{16}Stage 4: H32\frac {H}{32} x W32\frac {W}{32} )
[SwinTransformerSys 클래스 전체 코드]

1. Patch Partition

먼저, Patch Partition Layer에서는 RGB 입력 이미지를 4x4 패치로 분할한다.
각 이미지 패치의 feature 차원은 4 x 4 x 3 (H x W x C) = 48이 되고, 해당 과정은 PatchEmedding 이라는 별도의 클래스에서 구현되어 있다.
PatchEmbedding 레이어 적용시 patch_size, input chanel(in_chans), 변환하고자하는 차원 C 값인 embed_dim 을 설정해주어야한다.
# 인코더 레이어 구성 self.layers = nn.ModuleList() for i_layer in range(self.num_layers): layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), input_resolution=(patches_resolution[0] // (2 ** i_layer), patches_resolution[1] // (2 ** i_layer)), depth=depths[i_layer], num_heads=num_heads[i_layer], window_size=window_size, mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], norm_layer=norm_layer, downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, use_checkpoint=use_checkpoint) self.layers.append(layer)
Python
복사
Relative position bias
Swin Transformer 구조에서는 각 이미지의 위치정보를 relative position bias를 self-attention 수행과정에서 추가해주는 방식으로 학습한다.
→ ViT에서 사용되었던 Position Embedding 기법을 계속 사용하고 싶다면 ape를 True로 설정하면 된다.
# 1. Patch Partition : 이미지를 패치로 분할하고, 패치에 포지션 임베딩을 적용여부를 설정 self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None) num_patches = self.patch_embed.num_patches patches_resolution = self.patch_embed.patches_resolution self.patches_resolution = patches_resolution # Absolute Position Embedding을 원할 경우 (디폴트는 Relative Position bias) if self.ape: self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) trunc_normal_(self.absolute_pos_embed, std=.02) # drop-out 설정 self.pos_drop = nn.Dropout(p=drop_rate) # drop path rate 계산 dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
Python
복사
PatchEmbedding 클래스
SwinTransformerSys 클래스 내부에 별로도 구현된 Patch Embedding 전용 클래스이다.
컨볼루션 레이어를 이용해서 임의의 차원 값 C로 Linear Projection을 수행한다.
→ Stage 1의 Linear Embedding 레이어
class PatchEmbed(nn.Module): def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): super().__init__() img_size = to_2tuple(img_size) # 이미지 크기를 튜플 형태로 변환 patch_size = to_2tuple(patch_size) # 패치 크기를 튜플 형태로 변환 # 각 차원의 패치 개수를 계산 (H/4, W/4) patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] self.img_size = img_size # 이미지 크기 변수 설정 self.patch_size = patch_size # 패치 크기 변수 설정 self.patches_resolution = patches_resolution # 패치 해상도 변수 설정 # 전체 패치 수를 계산하여 변수 설정 self.num_patches = patches_resolution[0] * patches_resolution[1] self.in_chans = in_chans # 입력 채널 수 self.embed_dim = embed_dim # 임베딩 차원 수 # 컨볼루션 레이어를 이용한 패치 Linear Projection Layer (48→ C) self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) if norm_layer is not None: # 정규화 레이어가 있는 경우에 대한 예외처리 self.norm = norm_layer(embed_dim) # 정규화 레이어를 초기화 else: self.norm = None # 정규화 레이어 없음 def forward(self, x): B, C, H, W = x.shape assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." x = self.proj(x).flatten(2).transpose(1, 2) if self.norm is not None: x = self.norm(x) return x def flops(self): Ho, Wo = self.patches_resolution flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) if self.norm is not None: flops += Ho * Wo * self.embed_dim return flops
Python
복사

2. Swin Transformer Block (Stage 1~4)

위 그림과 같이, PatchEmbedding을 통해서 Linear Embedding이 완료되었다면, Stage 1의 Swin Transformer Block이 수행되어야한다. 하나의 Swin Transformer Blocks는 2개의 Swin Transformer를 연달아 통과하는 구조로 구성되어있다.
Swin Transformer Block는 ViT의 Multi-Head Self Attention(MSA) 블록이 Shifted Window MSA를 수행하는 블록으로 교체되었고, 다른 레이어는 동일하게 구성되어있다.
BasicLayer 클래스
: Swin Transformer 아키텍쳐에서 하나의 Stage에 해당하는 Swin Transformer 블록 구현 클래스
단일 SwinTransformerBlock에 대해 구현한 코드는 별개의 클래스를 따로 생성하여 구현한 것을 확인하였고, for 문을 통해서 여러 개의 Transformer block을 연결한다.
Stage 2,3,4 의 경우, downsample() 함수를 통해 Patch Merging 레이어를 수행하여 계층적 feature 맵을 생성한다.
class BasicLayer(nn.Module): def __init__(self, dim, input_resolution, depth, num_heads, window_size, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): super().__init__() self.dim = dim # 입력 차원 self.input_resolution = input_resolution # 입력 해상도 self.depth = depth # 레이어 깊이 self.use_checkpoint = use_checkpoint # 체크포인팅 사용 여부 # for문을 사용해서 한 스테이지 내의 depth(블록 수)만큼의 Swin Transformer 블록 구축 self.blocks = nn.ModuleList([ SwinTransformerBlock(dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) for i in range(depth)]) # Stage 2, 3, 4의 Patch Merging Layer if downsample is not None: self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) else: self.downsample = None def forward(self, x): for blk in self.blocks: if self.use_checkpoint: x = checkpoint.checkpoint(blk, x) else: x = blk(x) if self.downsample is not None: x = self.downsample(x) return x def extra_repr(self) -> str: return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" def flops(self): flops = 0 for blk in self.blocks: flops += blk.flops() if self.downsample is not None: flops += self.downsample.flops() return flops
Python
복사
클래스 주요 파라미터
SwinTransformerBlock 클래스
: 단일 Swin-Transformer Block 구현 클래스
Swin Transformer 블록에서는
1.
Linear Normalization Layer → Window MSA Layer / SW- MSA Layer
2.
Linear Normalization Layer → MLP Layer
순으로 수행하게 되며, 1번 모듈과 2번 모듈은 잔차연결을 통해서 연결된다.
MLP 레이어는 2개의 레이어로 구성되어있고, GELU 활성화 함수를 사용한다.
class SwinTransformerBlock(nn.Module): def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): super().__init__() self.dim = dim # 입력 차원 self.input_resolution = input_resolution # 입력 해상도 self.num_heads = num_heads # 어텐션 헤드 수 self.window_size = window_size # 윈도우 크기 self.shift_size = shift_size # 시프트 크기 self.mlp_ratio = mlp_ratio # MLP 차원 비율 if min(self.input_resolution) <= self.window_size: # 입력 해상도보다 윈도우 크기가 크거나 같은 경우 윈도우 파티셔닝을 하지 않음 self.shift_size = 0 self.window_size = min(self.input_resolution) assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size" # 1. 첫 번째 정규화 레이어 self.norm1 = norm_layer(dim) # 2. WA 레이어 self.attn = WindowAttention( dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) # 윈도우 어텐션 self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() # 드롭 패스 # 3. 두 번째 정규화 레이어 self.norm2 = norm_layer(dim) # 4. MLP 레이어 mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) # SW-MSA를 위한 어텐션 마스크 계산 if self.shift_size > 0: H, W = self.input_resolution img_mask = torch.zeros((1, H, W, 1)) # 이미지 마스크 초기화 h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) cnt = 0 for h in h_slices: for w in w_slices: img_mask[:, h, w, :] = cnt cnt += 1 mask_windows = window_partition(img_mask, self.window_size) mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) else: attn_mask = None self.register_buffer("attn_mask", attn_mask) # 어텐션 마스크 버퍼 등록 def forward(self, x): '''SwinTransformerBlock에서 어텐션과 피드포워드 네트워크를 처리하는 forward 함수''' H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" shortcut = x # 잔차연결을 위한 입력값을 shortcut에 저장 x = self.norm1(x) # 1번째 정규화 레이어 norm1 적용 x = x.view(B, H, W, C) # 순환 시프트 적용 if self.shift_size > 0: # shift_size가 0보다 큰 경우 # 입력 텐서를 지정된 크기만큼 순환 shift (SW-MSA를 위한 과정) shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) else: shifted_x = x # 윈도우 분할 - 순환 시프트된 텐서를 윈도우 단위로 분할 (window_size, window_size, C) x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C # W-MSA/SW-MSA - 분할된 윈도우들에 대해 어텐션 메커니즘 적용 attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C # 윈도우 병합 attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) # window_reverse 함수로 어텐션을 적용한 윈도우들을 원래 이미지 형태로 병합 shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C # 순환 시프트 복원 if self.shift_size > 0: x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) else: x = shifted_x x = x.view(B, H * W, C) # FFN x = shortcut + self.drop_path(x) x = x + self.drop_path(self.mlp(self.norm2(x))) return x def extra_repr(self) -> str: return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" def flops(self): flops = 0 H, W = self.input_resolution # norm1 flops += self.dim * H * W # W-MSA/SW-MSA nW = H * W / self.window_size / self.window_size flops += nW * self.attn.flops(self.window_size * self.window_size) # mlp flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio # norm2 flops += self.dim * H * W return flops
Python
복사
WindowAttention 클래스
: 윈도우 기반 어텐션 메커니즘을 구현한 클래스
Window Self Attention에서는 relative position bias (B)를 self-attention 연산에 추가해주는 방식으로 각 이미지의 위치정보를 고려한다.
1.
윈도우 크기에 기반해서 계산 된 각 토큰 쌍의 relative position index에 대한 바이어스 값을 저장하는 relative position bias table를 생성
2.
각 토큰 간의 상대적 위치 인덱스를 생성
→ 높이 방향 기준(coords_h), 너비 방향 기준(coords_w)으로 상대적 위치 인덱스 생성
각각의 x-axis matrix와, y-axis matrix에 (windowsize1)(window size-1) 의 값을 더해줌
→ index로 나타내기 위해서 값의 범위가 0부터 시작되도록 스케일링하는 것!
x-axis matrix에 (2 x windowsizewindowsize -1) 값을 곱해주고, x -axis matrix와 y-axis matrix를 합해주면 최종 relative position index 행렬이 완성됨
3.
어텐션 연산 수행 시, 해당 테이블을 참조해서 계산
class WindowAttention(nn.Module): def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): super().__init__() self.dim = dim # 입력 차원 self.window_size = window_size # 윈도우 크기 (Wh, Ww) self.num_heads = num_heads # 어텐션 헤드 수 head_dim = dim // num_heads # 각 헤드의 차원 self.scale = qk_scale or head_dim ** -0.5 # 스케일링 인자 # 1. relative position 테이블 정의 self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 상대 위치에 따른 바이어스 값 # 2. 각 토큰 간의 relative position 인덱스를 계산 coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) # 좌표 격자 생성 coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 좌표 평탄화 coords_flatten = torch.flatten(coords, 1) # 상대 좌표 계산 relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 차원 재배열 relative_coords = relative_coords.permute(1, 2, 0).contiguous() # index 값이 0부터 시작하도록 조정 relative_coords[:, :, 0] += self.window_size[0] - 1 #x-axis matrix relative_coords[:, :, 1] += self.window_size[1] - 1 #y-axis matrix relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 #x-axis matrix : 2*window_size -1 곱하기 relative_position_index = relative_coords.sum(-1) # 최종 상대 위치 인덱스 self.register_buffer("relative_position_index", relative_position_index) # QKV(쿼리, 키, 값) 선형 변환을 정의 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) # 어텐션 드롭아웃 self.proj = nn.Linear(dim, dim) # 결과 프로젝션 레이어 self.proj_drop = nn.Dropout(proj_drop) # 프로젝션 드롭아웃 # relative position bias table 초기화 trunc_normal_(self.relative_position_bias_table, std=.02) # 소프트맥스 레이어로 최종 어텐션 확률 계산 self.softmax = nn.Softmax(dim=-1) def forward(self, x, mask=None): """ parameter : x: 입력 특징 벡터 (num_windows*B, N, C) mask: 어텐션 마스크 (num_windows, Wh*Ww, Wh*Ww) 또는 None """ B_, N, C = x.shape qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] # 쿼리, 키, 값 분리 q = q * self.scale # 쿼리 스케일링 attn = (q @ k.transpose(-2, -1)) # 어텐션 스코어 계산 # relative position bias 추가 relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) relative_position_bias = relative_position
Python
복사
MLP 클래스
class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) # fully-connected layer x = self.act(x) # GeLU 활성화함수 x = self.drop(x) # drop-out x = self.fc2(x) # fully-connected layer x = self.drop(x) # GeLU 활성화함수 return x
Python
복사
MLP 레이어는 2개의 레이어로 구성되어있고, GELU 활성화 함수를 사용한다.
Shifted Window Multi-head Self-Attention(SW-MSA)
Swin Transformer에서는 W-MSA 이후 Shift Window Multihead Self-Attention을 수행하게 된다.
이때 Cycle shift 방법을 사용해서 윈도우를 shift하게 되고, A, B, C구역을 우측 하단으로 shift하고 mask를 씌워서 self-attention을 수행한다.
여기서 mask self-attention을 사용하는 이유는 shift한 A, B, C 구역이 실제 원본 이미지에서는 서로 인접한 위치가 아니었기 때문에 attention 계산의 의미가 없으므로, 실제 인접한 부분은 0, 인접하지 않은 부분을 -100으로 마스킹하여 attention을 수행하게 된다.
→ attention matrix에 큰 음수(-100, -inf)를 더해 주면 softmax에서 이를 무시하게 됨.
# SW-MSA를 위한 어텐션 마스크 계산 if self.shift_size > 0: H, W = self.input_resolution # 이미지 마스크 초기화 img_mask = torch.zeros((1, H, W, 1)) # 너비, 높이에 대한 슬라이스 정의 (window_size, shift_size) h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) # 각 윈도우 위치마다 상대적 위치 인덱스(cnt) 할당 cnt = 0 for h in h_slices: for w in w_slices: img_mask[:, h, w, :] = cnt cnt += 1 # window_partiton 함수로 이미지 마스크를 윈도우 크기에 맞게 분할 mask_windows = window_partition(img_mask, self.window_size) mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # 어텐션 마스크 생성 # 윈도우의 값을 1차원으로 평활화 attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # 서로 인접한 부분은 0, 인접하지 않은 부분은 -100으로 마스크 생성 attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) else: attn_mask = None
Python
복사
SwinTransformerBlock 클래스의 forward 함수로 전체 흐름 파악
def forward(self, x): '''SwinTransformerBlock에서 어텐션과 피드포워드 네트워크를 처리하는 forward 함수''' H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" shortcut = x # 잔차연결을 위한 입력값을 shortcut에 저장 x = self.norm1(x) # 1번째 정규화 레이어 norm1 적용 x = x.view(B, H, W, C) # 1. 순환 시프트(Cyclie shift) 적용 if self.shift_size > 0: # shift_size가 0보다 큰 경우 # 입력 텐서를 지정된 크기만큼 순환 shift (SW-MSA를 위한 과정) shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) else: shifted_x = x # 윈도우 분할 - 순환 시프트된 텐서를 윈도우 단위로 분할 (window_size, window_size, C) x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C # 2. W-MSA/SW-MSA - 분할된 윈도우들에 대해 어텐션 메커니즘 적용 attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C # 3. 윈도우 병합 attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) # window_reverse 함수로 어텐션을 적용한 윈도우들을 원래 이미지 형태로 병합 shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C # 4. 순환 시프트 복원(reverse cyclic shift) if self.shift_size > 0: x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) else: x = shifted_x x = x.view(B, H * W, C) # FFN x = shortcut + self.drop_path(x) # shortcut에 드롭패스 적용결과 추가 # 2번째 normalization 적용 후 mlp 수행 x = x + self.drop_path(self.mlp(self.norm2(x)))
Python
복사

3. Patch Merging 클래스

Stage 2, 3, 4 에서는 Swin Transformer Block을 통과하기 전에 Patch Merging 레이어를 두어 패치 병합을 수행한다.
Patch Merging Class로 네트워크가 깊어짐에 따라서 각 Stage 별로 레이어에서 처리하는 패치를 병합하여 토큰 수를 줄여 계층적 특성 맵을 생성하게 된다.
(Stage 1: H4\frac {H}{4} x W4\frac {W}{4}x C C Stage 2: H8\frac {H}{8} x W8\frac {W}{8} x 2C2C Stage 3: H16\frac {H}{16} x W16\frac {W}{16} x 3C3CStage 4: H32\frac {H}{32} x W32\frac {W}{32} x 4C4C )
Patch Merging Class는 인풋 이미지 혹은 feature map의 각 2x2 패치를 단일 feature map으로 병합해서 차원 수를 늘리고, 공간 해상도를 줄인다.