2021. 3. 2. 19:59ㆍ논문 리뷰
* 개인적으로 이해한 내용이니 혹 틀린부분이 있다면 댓글로 수정요청 부탁드립니다. 감사합니다.
Introduction
GAN은 그동안 많은 발전을 이루어 왔다. 하지만 gradient vanishing 문제가 항상 대두되었고 이를 해결하고자 spectral normailization을 활용한 discriminator를 사용하는 것들이 많아지기 시작했다.
특히 SAGAN 같은 경우에는 spectral normalization을 generator와 discriminator에 둘다 적용하여 GAN의 불완전성을 해소하는데 도움을 주었다.
BigGAN 같은 경우에는 parameter와 batch_size를 늘려 생성되는 이미지의 queality를 향상시켰다.
이 과정들을 통해 generator와 discriminator에 대한 conditioning class 정보가 image 생성에 영향을 준다는 것이 밝혀졌다.
ASGAN의 경우 discriminator에 softmax classifier를 더해 학습하는 방법으로 validation을 해주었고, ProjGAN의 경우 projection discriminator를 probabilistic model assumption을 이용해서 다른 논문들의 귀감이 되었다.
* projection discriminator란?
A Projection Discriminator is a type of discriminator for generative adversarial networks. It is motivated by a probabilistic model in which the distribution of the conditional variable y given x is discrete or uni-modal continuous distributions.
x가 주어진 조건부 변수 y의 분포가 이산 또는 단봉 연속 분포 인 확률이라고 합니다...
이 논문에서는 ACGAN과 ProjGAN이 data-to-class relation을 활용했다는 것에 영감을 받아 embedding of an image 와 embedding of the corresponding label 사이의 관계를 활용해서 conditional contrastive loss(이하 2C loss)를 제안하고 이는 data-to-class 뿐만 아니라 sample들 사이에서의 data-to-data 관계까지 파악할 수 있다.
이를 통해 Tiny ImageNet과 ImageNet datasets 에서 SOTA를 달성하고 discriminator의 overfitting문제를 해결하는데 도움을 준다고 한다.
Background
Generative Adversarial Networks
기본적으로 GAN은 generator와 discriminator로 구성되어있고, 서로 속고 속이고를 반복하는 과정을 통해 실제와 같은 이미지들을 만들어 낸다.
Conditional GANs
가장 일반적인 방법은 label information을 generator와 discriminator에 붙여서 사용하는 방법이다.
ACGAN의 경우 classifier를 추가하여 discriminator로 하여금 이미지들의 class까지 분류할수 있도록 하였다.
ProjGAN의 경우 ACGAN이 만들기 쉬운, 만들었을 때 확률이 가장 적절한 이미지들만 계속해서 만들어낸다는 점을 지적하며 이를 해결하기 위해 projection discriminator를 사용하는 것을 제안했다. 하지만 이는 역시 data-to-data의 관계를 고려하지 않은 방법이다. 게다가 projection discriminator를 활용한 BigGAN에서 discrminator의 overfitting 문제가 해결되지 않은 것이 발견되었다.
아래 그림을 보면 더 이해하기 편하다
Method
ACGAN과 ProjGAN들을 보면 위에서도 언급했듯, data-to-class 관계에만 집중한다는 것을 볼 수 있다. 이 논문에서는 data-to-data 관계까지 모두 고려한 Contrastive Generative Adversarial Networks를 제안한다.
Conditional GANs and Data-to-Class Relations
ACGAN의 discriminator은 주어진 이미지의 classification과 sample들의 확실성을 평가하는데 목적을 둔다.
주어진 data가 어떤 class에 속하는지에 대한 정보(data-to-class)를 이용하여 generator는 주어진 target labels로 classify 될 수 있도록 fake 이미지들을 만든다.
ProjGAN은 fake 이미지가 들어왔을 때 embeddings of real images와 corresponding target embeddings 사이의 inner-product value를 최소화 한다.
ACGAN와 같이 data-to-class 관계를 고려한 loss function을 활용하며 Proxy라는 learnable class embedding을 활용해서 ACGAN 보다는 더 유연하게 관계파악이 가능하다.
Conditional Contrastive Loss
data-to-data 관계에 더 집중하기 위해서 self-supervised learning 혹은 metric learning에 사용되는 loss function을 선택했다.
이는 metric learning이나 self-supervised learning objective를 discriminator와 generator에 더해서 주어진 label에 따라 embedded image features들간의 거리 조절을 더 잘 하게 하는 방법이다.
많이 사용되는 metric learning losses로는 contrastive loss, triplet loss, quadruplet loss등등이 있다.
하지만, triplet이나 quadruplet 같은 경우는 training 할 때 높은 complexity를 야기할 수 있고, 데이터 pair를 잘 못 묶으면 오히려 training하는데 시간이 더 걸리수 있으니 주의 하도록 하자.
2C loss를 소개 하기 전에 이해를 돕기 위해 먼저 NT-Xent loss를 살펴보면
라고 표현할 수 있으며 여기서 t는 push와 pull의 정도를 control하는 상수이다.
위의 식에서는 data augmentation을 요구로 하며 training example들에서의 data-to-class의 관계를 고려할 수 없다. 이 문제를 해결하기 위해서 data augmentations 대신에 embeddings of class label을 사용하는 방법을 생각했다. 이를 적용하면, 아래와 같이 표현이 가능하다.
이 식은 class embedding에서 가까운 sample x_i를 잡아 당기며 다른 것들은 밀어낸다.
하지만 이는 같은 label을 갖는 negative sample들을 밀어낼 수 있으므로 이 논문에서는 cosine similarities를 적용하여 negative sample들에 예외를 주었다.
최종적으로 제안한 2C loss를 최소화 하면 같은 label을 갖는 embedding된 image들 간의 거리를 감소 시키게 된다.
이를 통해 data-to-data relation과 data-to-class relation을 data augmentation과 mining of the training dataset 없이 가능하게 된다.
Contrastive Generative Adversarial Networks
전형적인 GAN들과 같이 ContraGAN 역시 discriminator training step과 generator training step이 존재하며 추가적으로 real or fake images들을 사용하여 2C loss도 계산하게 된다.
위의 그림이 제안된 ContraGAN의 과정을 잘 나타내 준다.
이 과정들을 통해 discriminator는 스스로 같은 class내의 real image embeddings간 거리를 좁혀주고 다른 것들은 거리를 벌려주게 된다.
Differences between 2C Loss and NT-Xent Loss
NT-Xent loss는 unsupervised learning을 의도한 것이다. data-to-data relation을 파악하기 위해 augmented image들을 positive sample로 넣어주고 그 관계를 파악하게 된다.
반면 2C loss는 weak supervision을 위한 것이며 그러하기 때문에 2C loss에 비해 NT-Xent 방법이 같은 class 내에서 image embeddings을 모으기가 어렵다. 학습시간도 더 오래 걸린다.
Experiments
다른 image generation 실험들과 비교하기 위해 CIFAR10, Tiny ImageNet, ImageNet dataset을 사용해서 실험을 진행했으며 Frechet Inception Distance(FID)를 사용해서 scoring을 하였다.
Conclusion
이 논문에서는 이전의 방식에서는 사용하지 않은 2C loss 를 제시했고 이는 data-to-class 뿐만 아니라 data-to-data 의 관계도 학습한다.
실험 결과를 통해 SOTA임을 증명했고 discriminator가 overfitting되지 않도록 도와준다는 것을 증명해 냈다.
'논문 리뷰' 카테고리의 다른 글
ELMo : Embeddings from Language Models (0) | 2021.04.12 |
---|---|
Bert : Pre-training of Deep Bidirectional Transformers for Language Understanding (0) | 2021.04.09 |
Swin Transformer: Hierarchical Vision Transformer using shifted Windows (0) | 2021.04.01 |
Video Transformer Network (0) | 2021.03.25 |
Exploring Simple siamese Representation Learning (0) | 2021.03.10 |