hyeonzzz's Tech Blog

[딥러닝 파이토치 교과서] 13장. 생성 모델 - (2) 본문

Deep Learning/Pytorch

[딥러닝 파이토치 교과서] 13장. 생성 모델 - (2)

hyeonzzz 2024. 7. 12. 19:38

13.3 적대적 생성 신경망(GAN)이란

GAN (Generative Adversarial Network)

위조지폐범 : 진짜와 같은 위조 화폐를 만들어 경찰을 속임 (= 생성 모델, 생성자)

경찰 : 진짜 화폐와 위조 화폐를 판별하여 위조지폐범을 검거 (= 분류 모델, 판별자)

 

위조지폐범과 경찰의 경쟁

- 위조지폐범은 진짜와 같은 위조지폐를 만들 수 있게 되고, 경찰은 위조지폐와 실제 화폐를 구분할 수 없는 상태에 이른다

 

적대적 학습

판별자 학습 -> 생성자 학습

 

판별자 학습

1. 실제 이미지를 입력 -> 네트워크가 해당 이미지를 진짜로 분류하도록 학습시킴

2. 생성자가 생성한 모조 이미지를 입력 -> 해당 이미지를 가짜로 분류하도록 학습시킴

-> 실제 이미지를 진짜로 / 모조 이미지를 가짜로

 

 

생성자 - 분류에 성공할 확률 낮춤

판별자 - 분류에 성공할 확률 높임  -> 서로 경쟁적으로 발전

 

 

13.3.1 GAN 동작 원리

  • 생성자 (Generator) G - 원래 이미지와 최대한 비슷한 이미지를 만들도록 학습
  • 판별자 (Discriminator) D - 원래 이미지와 생성자 G가 만든 이미지를 잘 구분하도록 학습

 

판별자 D

  • 이미지 x가 입력으로 주어졌을 때 판별자 D의 출력에 해당하는 D(x)가 진짜 이미지일 확률을 반환한다.

생성자 G

  • 진짜와 같은 모조 이미지를 노이즈 데이터를 사용하여 만들어낸다.

 예를 들어 실제 이미지인 알파벳 z가 입력으로 주어졌을 때

  • 판별자는 z를 학습한다. 
  • 생성자는 임의의 노이즈 데이터를 사용하여 모조 이미지 z'(G(z))를 생성한다.
  • G(z)를 다시 판별자 D의 입력으로 주면 판별자는 G(z)가 실제 이미지일 확률을 반환한다.

판별자 D를 학습 시킬 때는 생성자 G를 고정시킨 채 실제 이미지는 높은 확률을 반환하는 방향으로, 모조 이미지는 낮은 확률을 반환하는 방향으로 가중치를 업데이트한다.

 

GAN의 목적 함수 (가치 함수)

  • x~P_data(x) : 실제 데이터에 대한 확률 분포에서 샘플링한 데이터
  • z~P_z(z) : 가우시안 분포를 사용하는 임의의 노이즈에서 샘플링한 데이터
  • D(x) : 판별자 / 1에 가까우면 진짜 데이터, 0에 가까우면 가짜 데이터
  • D(G(z)) : 생성자 G가 생성한 이미지 / 1에 가까우면 진짜 데이터, 0에 가까우면 가짜 데이터로 판단

판별자가 G(z)를 입력받을 경우 1로 예측하도록 하는 것이 목표이다

가치 함수(목적 함수)
- 큰게 좋을 수도 있고 작은게 좋을 수도 있다 (최적화 대상)
- 판별자의 관점에서는 최대, 생성자의 관점에서는 최소

 

판별자 D 부분

  • 최상의 결과 : D(x)=1, D(G(z))=0인 경우이기 때문에 max
  • log(D(x))와 log(1- D(G(z))) 모두 최대가 되어야 한다

 

생성자 G 부분

  • 최상의 결과 : D(G(z))=1인 경우이기 때문에 min

GAN을 학습시키려면 판별자와 생성자의 파라미터를 번갈아 가며 업데이트해야 한다.

판별자의 파라미터를 업데이트할 때는 생성자의 파라미터를 고정시키고, 생성자의 파라미터를 업데이트할 때는 판별자의 파라미터를 고정해야 한다.

 

 

13.3.2 GAN 구현

data_fake = generator(torch.randn(b_size, nz).to(device)).detach()
  • torch.randn(b_size, nz) : 생성자에 노이즈 벡터 제공, 평균이 0이고 표준편차가 1인 가우시안 정규분포 이용, (b_size * nz) 크기
  • detach() : detach()를 통해 떼어 낸 데이터를 이용하여 판별자를 학습시키고 그 결과를 loss_d에 붙여 넣는다

 

생성자와 판별자에 대한 오차

 

처음 몇 에포크 동안 생성자의 오차는 증가하고 판별자의 오차는 감소한다

: 학습 초기 단계에 생성자는 좋은 가짜 이미지를 생성하지 못하기에 판별자가 실제 이미지와 가짜 이미지를 쉽게 구분할 수 있기 때문

학습이 진행됨에 따라 생성자의 오차는 감소하고 판별자의 오차는 증가한다

: 생성자는 진짜와 같은 가짜 이미지를 만들며 판별자는 가짜 이미지 중 일부를 진짜로 분류하기 때문

 

 

13.4 GAN 파생 기술

GAN은 생성자와 판별자가 대결하면서 학습하기 때문에 학습이 매우 불안정하다

한쪽으로 치우친 훈련이 발생하면 성능에 문제가 생겨 정상적인 분류가 불가능하다

 

제약을 해결한 모델들

  • DCGAN : GAN 학습에 CNN을 사용
  • cGAN : 가짜 이미지 생성을 위해 출력에 어떤 조건을 주어 변형 (시드 역할을 하는 임의의 노이즈와 함께 조건 추가)
  • CycleGAN : 사진이 주어졌을 때 다른 사진으로 변형 (예를 들어 말을 얼룩말로)

 

13.4.1 DCGAN

생성자 네트워크

  • 임의의 입력을 받아들여 판별자에서 사용할 수 있는 이미지 데이터를 생성한다
  • 출력은 64 * 64
  • 노이즈 데이터는 가로 *세로 형태가 아니기 때문에 입력 형태를 가로 * 세로로 reshape 해야 한다
  • 형태가 변형된 입력은 합성곱층으로 넘겨진 후 이미지 형태의 출력을 위해 분수-스트라이드 합성곱을 사용하여 출력 값을 키운다

 

생성자 네트워크의 특징

  • 풀링층을 모두 없애고, 분수-스트라이드 합성곱을 사용한다
  • 배치 정규화를 이용하여 층이 많아도 안정적으로 기울기를 계산할 수 있도록 했다 (단, 최종 출력층에서는 사용하지 않는다)
  • 활성화 함수는 ReLU를 사용, 최종 출력층에서는 tanh 사용

 

판별자 네트워크

  • 64 * 64 이미지를 입력받아 진짜 혹은 가짜의 1차원 결과를 출력한다
  • 활성화 함수로 LeakyReLU를 사용, 최종 출력층에서는 sigmoid 사용

 

판별자 네트워크의 특징

  • 풀링층을 모두 없애고, 스트라이드 합성곱을 사용한다
  • 배치 정규화를 이용하여 층이 많아도 안정적으로 기울기를 계산할 수 있도록 했다 (단, 최종 출력층에서는 사용하지 않는다)
  • 활성화 함수는 LeakyReLU를 사용, 최종 출력층에서는 sigmoid 사용

 

13.4.2 cGAN

기존 GAN 이용할 경우

  • 출력을 만들어 낼 때 사람이 통제할 수 있는 부분이 없다

cGAN 이용할 경우

  • 입력 이미지에 새로운 객체를 추가하거나 이미지에 자동으로 문자열 태그를 붙일 수 있다
  • 사람이 통제할 수 있다!

 

cGAN 원리

 

MNIST 데이터셋을 사용하여 데이터를 훈련시킨 후 숫자 1을 출력할 때

  • 생성자에 노이즈 벡터와 더불어 그것을 뜻하는 조건 C(예를 들어 [0,0,1])를 넣어 준다
  • 판별자에도 조건 C([0,0,1])가 추가되어야 한다

생성자와 판별자에 조건이 추가되면서 이미지에 대한 변형이 가능해진다

 

 

13.4.3 CycleGAN

GAN과 DCGAN -> 랜덤 노이즈를 입력으로 하므로 무작위 데이터가 생성되기 때문에 원하는 결과를 얻기 어렵다

 

PIX2PIX

임의의 노이즈 벡터가 아닌 이미지를 입력으로 받아 다른 스타일의 이미지를 출력하는 지도 학습 알고리즘

입력을 위한 데이터셋과 PIX2PIX를 거쳐서 나올 정답 이미지가 필요하다

 

생성자 네트워크

  • 입력과 출력이 모두 이미지 -> 인코더-디코더 구조
  • 인코더 : 입력 데이터의 특징을 찾아낸다
  • 디코더 : 이미지를 생성하는 역할
  • 출력층의 활성화 함수 : tanh (-1 ~ 1) , 따라서 입력 또한 (-1 ~ 1) 값으로 변경해야함

 

판별자 네트워크

  • 스트라이드가 2인 합성곱층, 뒤의 두 계층은 스트라이드가 1인 valid 합성곱을 이용하여 최종적으로 30 * 30 형태의 데이터 출력
  • 출력에서 차이가 있는 이유 : 판별자를 이미지의 각 부분별로 진행하기 위해서 (이미지의 각 부분이 진짜인지 아닌지 판별한다)

목적함수

 

생성자는 판별자를 속이는 것 말고도 생성한 이미지가 정답과 같아야 하는 과제가 있기 때문에

L1 손실함수를 사용한다

 

L1 손실함수가 추가된 최종 손실 함수는 다음과 같다

 

CycleGAN