[Paper Review] U-Net: Convolutional Networks for Biomedical Image Segmentation (+ FCN)

Date:     Updated:

카테고리:

태그:

Main Reference:
- FCN Paper
- U-Net Paper
- jeremyjordan blog


🚀 Introduction

2015년에 등장한 Fully Convolutional Networks(FCN)은 CNN기반 segmentation 모델의 초기 버전이다. 오늘 이야기할 U-Net 역시 2015년에 등장했으며 FCN을 개량한 모델이다. U-Net과 FCN의 구조가 매우 비슷하다고 느껴지기 때문에 FCN의 내용은 본 포스팅에서 잠깐 소개하는 것으로 넘어가려 한다. 또 하나 인상적이었던 것은 의료 이미지 데이터에 관한 내용이다. 본 논문을 포함한 많은 논문의 제목에서 ‘biomedical image 어쩌구’ 라는 것을 강조한다. U-Net의 논문을 읽어보며 알게 된 것은 biomedical image이기 때문에 가능한 data augmentation이 존재한다는 것이다. Segmentation이면 그냥 segmentation이지 왜 의료 이미지라는 사족을 달까 싶었는데, 조금이나마 이유를 알 것 같다.


🚀 About Fully Convolutional Networks(FCN)

fcn

FCN의 구조에는 크게 3가지 특성이 존재한다.

  • Can take arbitrary size as input
  • Upsampling
  • Skip connection


Can take arbitrary size as input

만약 누군가 segmentation 작업을 수행해야 한다면 가장 먼저 드는 생각은 ‘라벨링 된 데이터를 어떻게 구하지?’ 일 것이다. Segmentation task의 특성상 픽셀 단위로 라벨링을 해야 하기 때문에 데이터를 만들기가 매우 어렵다. Input size가 조금 다르다고 나만의 작고 소중한 데이터를 버려서 되겠는가. 이러한 문제 때문에 FCN은 기존 CNN모델에서 fully connected layer를 제거했다. CNN에서 fixed input size를 요구하는 범인이 바로 fc layer이기 때문이다. 따라서 fc layer를 1x1 conv layer로 교체했고, 임의의 크기를 가진 데이터를 input으로 받을 수 있게 되었다.


Upsampling

Convolution Transposed Convolution
upsample1 upsample2

Segmenation의 또 하나의 특징은 output data가 단순 클래스가 아니라는 것이다. Input 데이터의 동일한 크기의 size가 요구되며 픽셀 단위로 클래스를 예측해야 한다. 따라서 앞단 conv layer들을 통해 어떤 object들이 존재하는지(context)를 파악하고, input size와 동일한 output size를 얻기 위해 upsampling 과정이 필요하다. Umsampling 과정에는 크게 두가지 방법이 존재한다. 첫째는 non-trainable한 단순 max-pooling, average-pooling의 역연산 과정이고, 둘째는 trainable한 transposed convolution 과정이다. FCN에서는 학습 가능한 transposed convolution 연산을 통해 upsampling을 수행한다.
‘Pooling의 역연산을 이야기하는데 왜 conv layer가 등장하지?’ 싶을 수 있는데 padding이 없는 convolution은 down-sampling이기 때문이다. 이제 transposed-convolution이 무엇인지 이야기해보자. 기존의 convolution 과정을 matrix 곱으로 표현한다면 왼쪽 그림처럼 Ax=y 형태로 표현 가능하다. 이때 y(2x2)를 통해 x(3x3)를 얻고 싶다면, A 대신 A-transpose 형태를 곱해주면 되기 때문이다. (A * x = y & A^T * y = x)


Skip connection

fcn_result

마지막으로 skip connection 파트이다. FCN은 conv layer를 통해 이미지의 어떤 object들이 존재하는지 알아내고, Upsampling을 통해 object가 어디에 있는지 파악해야 한다. 하지만 앞단 conv layer를 통과하면 object의 위치 정보의 상당 부분을 손실하게 된다. 이러한 문제를 해결하기 위해 conv layer의 중간중간 output을 가져와 upsampling 과정에서 더해주게 된다. 전형적인 spatial pyramid 구조이다. 위 그림은 가장 왼쪽부터 skip connection을 이용하지 않은 경우, 16 stride마다, 8 stride마다 skip connection을 사용한 경우이다. 가장 초기 모델이라 그런지 성능이 좋지 못하다.


🚀 U-Net Architecture

unet

U-Net은 앞서 소개한 FCN을 개량한 버전이다. 전체 구조는 contractng path(축소 경로)와 expansive path(확장 경로)로 이루어져 있다. 앞단의 contracting path에서는 image의 context(what)를 파악하고, 뒷단의 expansive path에서는 localization(where)을 파악한다. Contracting path와 expansive path는 대칭적인 구조를 가지며 네트워크의 전체 구조는 U자형을 이룬다. U-Net 또한 FCN과 마찬가지로 정확한 localization을 위해 앞단의 output을 가져온다. 이때 FCN의 경우 add 연산을 수행하지만 U-Net의 경우 concat 연산을 이용한다. 또 하나의 차이점으로 언급한 것은 upsampling 구간에서 context 정보를 잃지 않기 위해 많은 채널 수를 가진다는 것이다.

  • Contracting path : 2 x (conv + relu + pooling)
  • Expansive path : upsample + concat + 2 x (conv + relu)


🚀 Loss Function

Output segmentation map

output1

먼저 일반적인 segmentation 모델의 output 형태에 대해 알아보자. 위와 같은 최종 라벨링 직전에는 클래스 갯수만큼의 채널을 가진 feature map이 존재한다. U-Net의 경우 흑백인 의료 이미지를 학습했기 때문에 2개의 채널을 가지지만 일반적으로는 k개의 채널을 가지게 된다. 각 채널은 해당 위치의 픽셀이 해당 클래스일 확률을 나타낸다.

before softmax after softmax(final output)
output2 output3


Loss function

\[L = \sum_{x\in \Omega }^{} w(x) log(p_{l(x)}(x))\] \[x\in \Omega \quad with \; \Omega \subset \mathbb{Z}^{2}\] \[l : \Omega \to {1, ..., K} \quad \text{is the true label of each pixel}\] \[p_{k}(x) = \frac{e^{a_{k}(x)}}{\sum_{k^{'}=1}^{K} e^{a_{k^{'}}(x)}}\] \[a_{k}(x) : \text {activation in feature channel k at the pixel position x}\]
  • p_k(x) : pixel x에서 클래스 k에 대한 softmax 값.
  • loss function은 각 픽셀에서 softmax 값들의 weighted cross-entropy
  • weight는 경계를 강조하기 위함


loss

의료 이미지의 경우 위 그림과 같이 다닥다닥 붙어있는 세포들이 많이 존재한다. 따라서 U-Net의 경우 object의 경계를 강조하기 위한 weighted sum을 사용한다.

\[w(x) = w_{c}(x) + w_{0}\cdot exp(-\frac{(d_{1}(x) + d_{2}(x))^{2}}{2\sigma ^{2}})\] \[w_{0} = 10 , \; \sigma = 5 \quad \text{(in paper)}\] \[w_{c}(x) = \frac{1}{frequency} \quad \text{where} \; frequency = \frac{\text{pixel count of class c}}{\text{total number of pixels}}\]
  • w_c(x) : 클래스간 밸런스를 맞춰주기 위한 weight
  • d_1(x) : x 픽셀에서 가장 가까운 object와의 거리
  • d_2(x) : x 픽셀에서 두번째로 가까운 object와의 거리
  • exp^(어쩌구~) : Gaussian 형태. 따라서 d1 + d2의 값이 작을수록 값이 커짐.
    • 특정 세포 내부의 픽셀을 예로 들면 d_1(x) = 0인 상태.
    • 다른 세포와 가까워질수록, 즉 경계에 다다를수록 (d1 + d2)값이 작아짐


Data Augmentation

elastic

U-Net의 경우 flip, rotation 등 일반적인 augmentation 기법도 사용하지만 특히 elastic deformation이 매우 효과적이었다고 강조한다. Elastic deformation(탄성 변환?)은 다른 변환들과는 다르게 비선형적으로 변환한다. 이 기법이 효과적인 이유는 사용하는 데이터가 의료 이미지이기 때문인 것 같다.


Overlap-tile strategy

Prediction Mirroring
overlap mirroring

U-Net의 구조를 자세히 보면 input의 이미지 크기와 output의 이미지 크기가 살짝 다르다는 것을 알 수 있다. 즉 위 그림에서 노란색 박스의 영역을 추론하기 위해서는 파란색 크기의 박스 영역이 필요하다. 이때 모자란 나머지 영역은 기존 데이터의 대칭으로 확장한다.


overlap_tile

U-Net의 경우 하나의 이미지를 패치 단위로 잘라 학습을 시키는데, 결과적으로 위 그림과 같이 겹치는 부분(overlap-tile)이 발생한다. 내가 정확히 이해한 것이지 조금 헷갈리지만 이러한 방법들을 굳이 도입한 이유는 현실적인 문제 때문일 것이다. 논문에서 gpu memory에 대한 이야기가 많이 등장하는데 실험에 사용한 gpu의 memory가 6GB라고 한다. 참고로 내가 게임하려고 몇년 전에 산 gpu memory가 8GB이다. 아무튼 gpu memory를 최대한 효율적으로 사용하기 위해서는 input size가 작은 patch 단위가 유리하다. 의료 이미지의 특성상 패치 단위로 잘라도 큰 문제가 되지 않고, mirroring을 통해 데이터를 확장하여도 크게 티가 나지 않는다.


🚀 Result

result

Experiments 파트에서는 여러 의료 데이터들을 대상으로 기존 모델과의 성능 비교를 보여준다. 근데 의료 용어가 너무 많아서 못알아먹겠다. 아무튼 U-Net이 기존 모델 대비 압도적이라고 한다.


🚀 Code

U-Net from scratch (pytorch)

여담이지만 모델 구조가 굉장히 깔끔한 것 같다. 물론 segmentation task에서 뛰어난 성능을 보여주는 모델은 아니지만 어지간한 classification 모델보다 구현이 간단했다. 사실 Detection 모델을 구현할 때 고통스러운 기억이 있어, segmentation은 얼마나 어려울까 벌벌 떨었는데 이렇게 간단할 줄이야.

inference



맨 위로 이동하기

Paper 카테고리 내 다른 글 보러가기

댓글 남기기