0. Backgrounds
- minimax game algorithm
"상대방의 최고의 수가 나에게 가장 최소의 영향을 끼치게 만들자."
본인 차례에서는 최선의 선택을 해야하며 똑똑한 상대방이 나에게 유리한 선택을 할 리가 없기에 상대방 차례에선 나에게 가장 최악의 선택을 하는 방식.
- Saturation (포화)
예를 들어 sigmoid function의 경우 입력이 크거나 작을 때 기울기가 0에 가까워짐. 이렇게 activation function 구간에서 기울기(gradient)가 0에 가까워지는 현상을 saturated라고 하고 이는 vanishing gradient problem을 야기해 학습이 잘 안됨.
- parametric model vs non-parametric model -> 수정 필요. GAN에서는 Nonparametric statistics 개념이 'non-parametric limit setting'과 더 관련있을 듯
parametric model은 parameters 수가 정해진 모델. non-parametric model은 training data의 크기에 따라 parameters 수가 달라지는 모델.
즉, parametric model은 학습 데이터가 특정 분포를 따른다고 가정하고 학습하면서 정해야하는 (분포의) 파라미터의 종류와 수가 정해진 모델.
반면 non-parametric model은 학습 데이터가 특정 분포를 따른다는 가정이 없기 때문에 학습에 따라 튜닝해야할 파라미터가 명확히 정해져 있지 않아 data에 대한 사전 지식이 전혀 없을 때 유용하게 사용될 수 있다.
1. Introduction
이 논문에서 저자는 이전 deep generative model들이 갖는 어려움을 피할 수 있는 새로운 generative model estimation 절차를 제안했다.
그게 바로 GAN (Generative Adversarial Nets)이다.
GAN 모델은 Generative net과 Discriminative net으로 이루어져 있는데 저자는 이 둘의 역할을 경찰과 화폐위조범으로 비유했다.
Generative model(화폐위조범 역할)은 fake data를 최대한 real data와 유사하게 생성하는 것이 목표이고 discrminative model (경찰 역할)은 real data와 fake data를 잘 구별하는 것이 목표이다!
그래서 결국 GAN의 핵심 내용은 각각의 역할을 가진 두모델을 통해 적대적(Adversarial)학습을 하면서 '진짜 같은 가짜'data를 생성하도록 학습시킨다는 것이다.
이때 이 adversarial framework는 많은 모델과 optimization 알고리즘에 대해 특정 학습 algorithm을 사용할 수 있는데 저자는 MLP (multi-layer perceptron)인 Generative model과 Discriminative model을 사용하였고 이 network를 adversarial nets라고 하였다.
MLP를 쓴 이유는 다른 복잡한 방식 필요없이 오직 back propagation, dropout만으로 두 모델을 학습시킬 수 있고 generative model로 부터 sampling할 때 forward propagation만 하면 된다는 장점이 있기 때문이다.
2. Related works (생략)
3. Adversarial nets
먼저 Generator에 대해 다음과 같이 논문에서 표현했다.
Data x에 대해 generator의 distribution Pg 를 배우기 위해 input noise variable Pz (z) 에 대한 prior를 정의한다. (이때 z 는 uniform하게 sample되었다.) 그런 다음 data space로의 mapping을 G(z; θg) 라고 표현한다. 이때 G는 parameters θg를 갖는 mulilayer perceptron으로 표현되는 미분 가능한 함수이다.
다음으로 discriminator을 MLP D(x; θd)로 표현하였다. discriminator은 한 개의 scalar 값을 출력한다. (1 = real data로 판단, 0 = false data로 판단)
즉 D(x)는 x가 generator로 생성된 distribution Pg 로부터 sample된 게 아니라 training data로 부터 sample되었을 확률은 의미한다.
이 Adversarial nets의 학습 목표는 다음과 같다.
1. D(discrimnator) 가 training examples(=real data), samples from G(=fake data)에 정확한 label을 assign하는 확률을 높이는 것.
2. 그리고 동시에 G가 log(1-D(G(Z)))를 최소화하도록 학습하는 것 (= D(G(Z))가 최대여야. 즉 samples from G(=fake generated image) 를 discriminator가 training data로부터 sample되었다고 판별하도록 하는 것. 즉, G가 진짜 같은 이미지를 만들도록 학습!)
이 두 학습목표를 다른 말로는 D, G과 아래의 value function V(G, D)를 갖는 2 player minimax game을 하는 것으로 볼 수 있다.
왜 이렇게 표현할 수 있는 지 좀 더 자세히 살펴보자
(1)번 식에서 첫번째 항: real data x를 discriminator 에 넣었을 때 나오는 결과를 log취했을 때 얻는 기댓값
(1)번 식에서 두번째 항: input noise variables z를 generator에 넣었을 때 나오는 결과(=generated fake data)를 discriminator에 넣었을 때 그 결과를 log(1-결과)했을 때 얻는 기댓값
이 식을 D와 G의 입장에서 각각 살펴보자.
먼저 D의 입장에서 위 Value function V(D,G)의 이상적인 결과는 D가 좋은 성능을 가져 판별을 잘 해낸다고 했을 때 실제 data distribution pdata에서 온 sample일 경우 D(x)는 1이 되기 때문에 첫 항은 0이 되고 G가 생성한 가짜 이미지를 잘 구별할 수 있으므로 D(G(z))는 0이 되어 두번재 항은 0이 된다. 따라서 전체 식 V(D,G) = 0이 된다.
따라서 D의 입장에서 얻을 수 있는 이상적인 결과는 '최댓값' 0 이다.
다음으로 G의 입장에서 위 Value function V(D,G)의 이상적인 결과는 G가 D가 구별 못할만큼 진짜와 같은 데이터를 잘 생성할 때로 첫번재 항은 D가 구별해내는 것에 대한 식으로 G의 성능과 무관하므로 패스하고 두번재 항을 보면 D가 G가 생성해낸 이미지를 가짜가 아닌 진짜라고 판단할 것이므로 D(G(z))가 1이되어 두번재 항은 마이너스 무한대가 된다.
따라서 G의 입장에서 얻을 수 있는 이상적인 결과는 ‘최솟값’ 마이너스 무한대이다.
정리하자면 D는 training data부터로의 sample과 G로 생성된 sample에 진짜인지 가짜인지 맞는 label을 assign할 확률을 최대화하기 위해 학습되고(value function이 최댓값을 갖도록) G는 log(1-D(G(z)))를 최소화하기 위해(value function이 최솟갑을 갖도록) 학습된다!
그렇기에 D는 V(D,G)를 최대화시키려고하고 G는 V(D,G)를 최소화시키려고 하므로 이 논문에서는 D,G를 V(D,G) 갖는 2 player minimax game으로 표현했다@.@
좀 더 이론적으로 학습이 되고 수렴하는 것에 대한 증명은 다음 section에서 다룰 것이고 그 전에 less formal하게 GAN의 학습과정을 보자.
figure 1을 보면 D가 학습됨에 따라 G가 나아지는 것을 볼 수 있다.
GAN은 discriminative distribution을 동시에 update 시켜 px로부터의 sample과 pg로부터의 sample을 구분하도록 학습된다.
이 그림에서 각 선이 의미하는 것은 다음과 같다.
(D) blue, dashed line : discriminative distribution. D outputs a single scalar value. (Data 구분 확률)
(Px) Black, dotted line: training data로 부터 만든 distribution. Data distribution (=real)
(G, Pg) green, solid line: generative distribution (=fake)
아래 수평선은 uniformly sampled된 z의 domain이고 위 수평선은 x의 domain 일부이다.
위로 향하는 화살표는 x = G(z) mapping이 transformed samples에 non-uniform distribution pg가 부과되는 것을 보여준다. (아래 수평선은 sampling간격 균일. 위는 균일하지 않음)
(a) 학습 초기로 black(real), green(generated) 분포가 다르고 D의 성능도 좋지 않음을 볼 수 있다.
(b) D가 학습이 되어 (a)처럼 들쑥날쑥하게 확률을 판단하지 않음을 알 수 있다.
(D ∗ (x) = pdata(x) pdata(x)+pg(x)로 수렴하여 samples from data를 구분하도록 D가 training inner loop에서 학습됨)
(c) 어느 정도 D가 학습이 되면 D의 gradient가 G가 실제 분포를 모사하여 D가 구별하기 힘들도록 G가 학습됨.
(d) 이 과정이 반복되어 학습하면 pg = pdata인 지점에 도달하고 D가 fake와 real을 구분하기 힘들어 D(x) = 1/2가 된다.
근데 이때 D를 training inner loop에서 완전하게 학습하는 것은 계산적으로도 금지되지만 overfitting이 될 수 있는 문제가 있다.
그래서 저자는 D를 optimize하는 k step과 G를 optimize하는 1 step을 번갈아가며 진행했다. 즉, D를 계속 optimizing하는 것이 아니라 D와 G를 번갈아가며 학습. 이렇게 하면 G가 천천히 변하는 한 D는 optimal solution에 머물게 된다고 한다. 이 과정은 Algorithm1에 나타나있다.
또, 학습 초기에는 G의 성능이 구려서 D가 너무 쉽게 가짜와 진짜를 구별할 수 있어 위의 (1) 방정식의 G가 학습하는데 충분한 gradient를 제공하지 못할 수 있다. 이 경우 log(1-D(G(z)))가 포화되기 때문에 이 경우 G가 log(1-D(G(Z)))를 최소화하도록 학습시키는 대신 G가 log(D(G(z)))를 최대화하도록 학습시킬 수 있다고 한다.
4. Theoretical Results
앞에서 말한 GAN의 minimax 방식이 제대로 working하는 것을 보이기 위해선 위의 minimax problem이 global optimum에서 유일해를 가지고 어떤 조건을 만족하면 그 solution으로 수렴함을 보여야한다.
본 논문의 section 4.1에서 이 minimax game이 pg = pdata일 때, global optimum을 가짐을 보여주고
section4.2에서 위의 algorithm1 방식으로 optimize해서 global optimum으로 수렴할 수 있음을 보여준다.
section 4.1 Global optimality of Pg = Pdata
Proposition(명제) 1. G가 고정된 경우, 최적 D는 다음과 같다.
따라서, G가 고정되어 있을 때 최적의 D는 위와 같다.
Theorem(정리) 1. Global minimum of G 는 Pg = Pdata일때 유일하고 그 값은 -log4이다.
optimal G를 구하자면,
minimax game이므로 V(G,D)값이 최소가 되게하는 G일때 D는 global maximum optimum을 가질 때이므로
아래와 같이 식을 정리하면 C(G)가 Global minimum optimum을 가질 때 pg = pdata 이고 그때 c(G)의 값은 -log4이다.
(KL: 쿨백-라이블러 발산. 두 확률분포의 차이를 구하는 것)
(JSD: Jensen-Shannon Divergence)
따라서, minimax 구조에서 pdata = pg일 때 global optimum을 갖는 다는 것을 알 수 있다.
section 4.2 Convergence of Algorithm 1
이 섹션에서는 모델이 위에서 구한 global optimum에 수렴하는지에 대한 내용이다.
위에서 pdata = pg일 때 global optimum을 가지므로 pg가 pdata로 수렴할 수 있으면 global optimum에 수렴한다고 볼 수 있다.
즉, generative model Pg 관점에서 global optimum으로 loss 함수가 수렴할 수 있는 지 따져야 하고
그러기 위해서 D를 고정시킨 후 loss함수가 convex한지 확인 하면 된다. (왜냐면 그래야 gradient descent같은 iterative optimization방식을 통해 해를 찾을 수 있기 때문이다.)
Proposition 2. G, D model이 충분한 capacity가 있고 algorithm1의 매 step에서 discriminator가 주어진 G에 대해 optimum에 도달할 수 있으며,
Pg가 아래 criterion을 개선하기 위해 update된다면 Pg는 Pdata로 수렴한다.
증명) 논문만으로 이해하기 어려워서,,,! 리뷰를 많이 참고했다.. 증명은 어렵지만 결론은 위의 식이 Pg에 대해 convex하므로 iterative optimization(ex. gradient descent. 여기서의 Algorithm1도 iterative optimization)방식으로 global optimum을 찾을 수 있고 그래서 Algorithm1으로 optimization해 수렴 된다는 것!
위의 criterion에서 D를 고정시켰으므로 V(G,D) = U(Pg, D)라고 Pg의 함수라고 생각해보자. D는 고정이므로 Pg관점에서 보면 D는 고정된 상수값이다.
그러므로 U(Pg,D)의 변수는 Pg이고 학습시(train loop에서) Pg가 변할텐데 변하는 Pg에 따라 U(Pg,D)가 Convex한지 증명해야한다.
따라서 U(Pg,D)는 Pg에 대해 convex함수이므로 algorithm1에 의해 Pg는 Pdata로 수렴한다.
따라서 4.1, 4.2를 통해, GAN은 global optimum을 갖고 algorithm1 방식으로 global optimum에 수렴할 수 있다는 것을 알 수 있다.
실제로는 Adversarial nets function G(z; θg)를 통해 제한된 Pg 분포군을 나타내고 저자는 pg자체 대신 θg 를 optimize한다. (즉, GAN을 통해 생성하는 데이터의 분포 pg를 추정해 최적화하는게 아니라 pg를 생성하는데 직접적인 영향을 미치는 parameter인 θg를 추정하고 최적화함. )
또, G를 정의하기 위해 MLP사용하는 것은 parameter space에서의 여러 critical points(비판점. Generative model을 MLP와 같은 것으로 설정하면 최종 loss function이 convex하지 않을 수 있음)가 있지만 실제로 MLP를 사용했을 때 성능이 좋기에 이론적으로 보장되지 않는 문제가 있음에도 쓸만하다고 한다.
5. Experiments
Train datasets
- MNIST[23], the Toronto Face Database (TFD) [28], and CIFAR-10 [21].
사용한 Activation function
- Generator nets: used a mixture of rectifier linear activations [19, 9] and sigmoid activations
- Discriminator net: used maxout [10] activations.
Dropout
- discriminator net 학습시에만 적용
Noise
- 이론적으로는 dropout이랑 noise를 generator의 intermediate layers에 사용가능하지만
저자는 generator network의 bottommost layer의 input으로만 noise사용.
6. Advantages and disadvantages
단점:
- pg(x)의 explicit representation이 없음
- 학습시 D는 G와 함께 synchronized되어야함. (G가 D 업데이트 없이 혼자 너무 학습되면 안됨.)
장점:
- Markov chain이 사용되지 않음
- gradients를 얻기위해 backpropagation만 사용되고 학습 시 inference 필요없음
- 다양한 function이 모델에 incoporated 될 수 있음
- 데이터 샘플로 부터 직접적으로 update되는 게 아니라 discriminator를 통해 gradients가 flowing됨으로써 generator network가 통계적 이점 가짐. 즉, input의 components가 generator의 parmeters로 direct하게 copy되지 않음
- 그리고 Markov chain based models와 달리 sharp, degenerate distribution도 represent가능.
7. Conclusions and Future work (생략)
This framework admits many straightforward extensions:
1. A conditional generative model p(x | c) can be obtained by adding c as input to both G and D.
2. Learned approximate inference can be performed by training an auxiliary network to predict z given x. This is similar to the inference net trained by the wake-sleep algorithm [15] but with the advantage that the inference net may be trained for a fixed generator net after the generator net has finished training.
3. One can approximately model all conditionals p(xS | x6S) where S is a subset of the indices of x by training a family of conditional models that share parameters. Essentially, one can use adversarial nets to implement a stochastic extension of the deterministic MP-DBM [11].
4. Semi-supervised learning: features from the discriminator or inference net could improve performance of classifiers when limited labeled data is available.
5. Efficiency improvements: training could be accelerated greatly by divising better methods for coordinating G and D or determining better distributions to sample z from during training.
This paper has demonstrated the viability of the adversarial modeling framework, suggesting that these research directions could prove useful.
8. References
1. https://tobigs.gitbook.io/tobigs/deep-learning/computer-vision/gan-generative-adversarial-network
3. https://process-mining.tistory.com/131
4. https://89douner.tistory.com/329
5. https://89douner.tistory.com/331
'논문 리뷰' 카테고리의 다른 글
DDPM: Denoising Diffusion Probabilistic Models (1) | 2023.07.25 |
---|---|
DeepLab v2 (0) | 2022.12.30 |
YOLO paper review (0) | 2022.12.02 |
Deep Residual Learning for Image Recognition (ResNet paper) 리뷰 (0) | 2022.11.18 |
FCN: Fully Convolutional Networks for Semantic Segmentation 정리 (0) | 2022.10.14 |