Gorio Tech Blog search

Self-Supervised Learning(자기지도 학습 설명)

|

이 글에서는 Self-Supervised Learning(자기지도 학습)에 대해 알아본다. Self-Supervised Learning은 최근 Deep Learning 연구의 큰 트렌드 중 하나이다.

Self-Supervised Learning의 기본적인 개념과 여러 편의 논문을 간략히 소개하고자 한다.


Self-Supervised Learning

일반적으로 Supervised Learning(지도학습)이 높은 성능의 모델을 만드는 것이 유리하지만, 수많은 데이터에 label을 전부 달아야 한다는 점에서 데이터셋 모으기가 어려우며 따라서 활용하는 방법도 제한적일 수밖에 없다.
이와 같은 문제를 해결하고자 나온 방법이

  • 아주 일부분만 label이 존재하는 데이터셋을 사용하는 Semi-Supervisd Learning(준지도 학습)
  • label이 전혀 없는 데이터셋을 사용하는 Unsupervised Learning(비지도 학습)

이고, 최근 주목받는 연구 방법이 Self-Supervised Learning(자기지도 학습)이다. 보통 Self-Supervised Learning을 연구할 때, 다음과 같은 과정을 따른다:

  1. Pretext task(연구자가 직접 만든 task)를 정의한다.
  2. Label이 없는 데이터셋을 사용하여 1의 Pretext task를 목표로 모델을 학습시킨다.
    • 이때, 데이터 자체의 정보를 적당히 변형/사용하여 (label은 아니지만) 이를 supervision(지도)으로 삼는다.
  3. 2에서 학습시킨 모델을 Downstream task에 가져와 weight는 freeze시킨 채로 transfer learning을 수행한다(2에서 학습한 모델의 성능만을 보기 위해).
  4. 그래서 처음에는 label이 없는 상태에서 직접 supervision을 만들어 학습한 뒤, transfer learning 단계에서는 label이 있는 ImageNet 등에서 Supervised Learning을 수행하여 2에서 학습시킨 모델의 성능(feature를 얼마나 잘 뽑아냈는지 등)을 평가하는 방식이다.

여기서 Self-Supervised Learning의 이름답게 label 등의 직접적인 supervision이 없는 데이터셋에서 스스로 supervision을 만들어 학습하기 때문에, supervision이 전혀 없는 Unsupervised Learning의 분류로 보는 것은 잘못되었다는 시각이 있다.

I now call it “self-supervised learning”, because “unsupervised” is both a loaded and confusing term. In self-supervised learning, the system learns to predict part of its input from other parts of it input. In other words a portion of the input is used as a supervisory signal to a predictor fed with the remaining portion of the input. Self-supervised learning uses way more supervisory signals than supervised learning, and enormously more than reinforcement learning. That’s why calling it “unsupervised” is totally misleading. That’s also why more knowledge about the structure of the world can be learned through self-supervised learning than from the other two paradigms: the data is unlimited, and amount of feedback provided by each example is huge. Self-supervised learning has been enormously successful in natural language processing. For example, the BERT model and similar techniques produce excellent representations of text. BERT is a prototypical example of self-supervised learning: show it a sequence of words on input, mask out 15% of the words, and ask the system to predict the missing words (or a distribution of words). This an example of masked auto-encoder, itself a special case of denoising auto-encoder, itself an example of self-supervised learning based on reconstruction or prediction. But text is a discrete space in which probability distributions are easy to represent. So far, similar approaches haven’t worked quite as well for images or videos because of the difficulty of representing distributions over high-dimensional continuous spaces. Doing this properly and reliably is the greatest challenge in ML and AI of the next few years in my opinion. - Yann Lecun

게시물을 보면 예전에는 이러한 학습 방식을 “주어진” supervision이 없기 때문에 Unsupervised Learning이라 불렀지만, 사실은 모델이 “스스로” supervision을 만들어 가며 학습하기 때문에 Self-Supervised Learning이라 부르는 것이 맞다고 설명하는 내용이다.
그리고 (label은 없어도) 데이터는 무궁무진하며, 그로부터 얻을 수 있는 feedback 역시 엄청나기 때문에 “스스로” supervision을 만드는 Self-Supervised Learning 방식이 매우 중요하게 될 것이라고 언급한다.


Image Representation Learning

Discriminative Unsupervised Feature Learning with Exemplar Convolutional Neural Networks

NIPS 2014

Examples

이 논문에서는 Examplarpretext task로 한 Examplar-CNN이라고 하는 모델을 소개한다.

$N$개의 이미지에서 중요한 부분, 정확히는 considerable gradients를 가지고 있는 32 $\times$ 32 크기의 patch를 하나씩 추출한다(Gradient 크기의 제곱에 비례하는 확률로 patch를 선정함). 이렇게 추출한 Seed patch를 갖고 여러 가지 data augmentation을 진행한다. 위 그림의 오른쪽에서 볼 수 있듯이 translation, scaling, rotation, contrast, color 등을 이동하거나 조정하여 만든다.

분류기는 Data augmentation으로 얻어진 patch들은 하나의 class로 학습해야 한다. 여기서 Loss는 다음과 같이 정의된다.

[L(X) = \sum_{\text{x}i\in X} \sum{T \in T_i} l(i, T\text{x}_i)]

$l(i, T\text{x}_i)$은 Transformed sample과 surrogate true lable $i$ 사이의 loss를 의미한다. 즉 하나의 이미지가 하나의 class를 갖게 되는데, 이는 데이터셋이 커질수록 class의 수도 그만큼 늘어난다는 문제를 갖고 따라서 학습하기는 어렵다는 단점을 안고 있다.

아래는 실험 결과이다.

Examples

Unsupervised Visual Representation Learning by Context Prediction

ICCV 2015

이 논문에서는 context prediction 종류의 pretext task를 제안하는데, 간단하게 말하면 이미지에 $ 3 \times 3 = 9$개의 patch를 가져와, 중간 patch와 다른 1개의 patch를 보여 주고, 다른 1개의 patch가 중간 patch의 어느 뱡향(왼쪽 위, 위, 오른쪽 등 8방향)에 위치하는지를 찾는 문제이다. 아래 그림을 보면 이해가 빠를 것이다:

Examples
위 그림에서 무엇이 정답인지 보기

Answer key: Q1: Bottom right Q2: Top center

두 장의 이미지를 한 번에 입력으로 받아야 하기 때문에 아래와 같이 patch를 받는 부분을 두 부분으로 만들었다.

Examples

그리고 “자명한” 해를 쉽게 찾을 것을 방지하기 위해, 9개의 patch들이 딱 붙어 있지 않도록 하였으며 중심점의 위치도 랜덤하게 약간 이동하여 patch를 추출하여 자명한 경우를 최대한 피하고자 하였다(하지만 충분치는 않았다고 한다).

아래는 실험 결과이다.

Examples

Unsupervised Learning of Visual Representations using Videos

ICCV 2015

CNN에 의미적으로 레이블링된(semantically-labeled) 것이 꼭 필요한지 의문을 가지며 출발하는 이 논문은 시각적 표젼(visual representatinos)을 학습하기 위해 레이블이 없는 십만 개의 영상을 웹에서 가져와 사용하였다. 핵심 아이디어는 visual tracking이 supervision을 준다는 것이다. track으로 연결된 두 개의 패치는 같은 물체나 그 일부분을 포함할 것이기에, deep feature space에서 비슷한 시각적 표현을 갖고 있을 것이라 생각할 수 있다. 이러한 CNN 표현을 학습하기 위해 ranking loss function을 사용하는 Siamese-triplet 네트워크를 사용한다.
ImageNet에서 단일 이미지를 가져오는 대신 레이블이 없는 10만 개의 영상과 VOC 2012 dataset을 사용하여 비지도 학습을 수행, 52% mAP를 달성하였다. 이는 지도학습을 사용한 ImageNet 기존 결과의 54.4% mAP에 거의 근접하는 수치이다.
또한 이 논문은 surface-normal 추정 문제와 같은 것도 충분히 잘 수행함을 보여준다.

Examples
  • (a) 레이블이 없는 영상이 주어지면 그 안에서 비지도 tracking을 수행한다.
  • (b) 첫 frame에서 Query 패치, 마지막 frame에서 추적하는 patch, 다른 영상에서 가져온 무작위 패치로 구성된 총 3개의 패치가 siamese-triplet 네트워크에 입력으로 들어가 학습된다.
  • (c) 학습 목표: Query 패치-추적된 패치 사이의 거리가 Query 패치-무작위 패치 사이의 거리보다 더 작게 되도록 학습한다.

Joint Unsupervised Learning of Deep Representations and Image Clusters

CVPR 2016

JULE는 deep representations와 image cluster의 Joint Unsupervised LEarning을 위한 recurrent framework를 말한다. 이 framework에서, clustering 알고리즘에서 연속적인 동작이 recurrent process의 한 스텝으로 표현되어, CNN의 representation 출력의 위에 쌓아지는 형태를 갖는다. 학습하는 동안, image clusters와 representations는 공동으로 업데이트된다: image clustering은 forward pass로, representation 학습은 backward pass로 수행된다.
좋은 표현은 image clustering에 유익하며, 그 결과는 representation 학습에 지도를 해 줄 수 있을 것이라는 것이 핵심 부분이다. 두 과정을 하나의 모델에 통합하여 합쳐진 가중 triplet loss를 사용하여 end-to-end 방식으로 최적화했기에, 더 강력한 표현뿐만 아니라 더 정확한 image cluster도 얻을 수 있다.
여러 실험은 이 방법이 여러 데이터셋에서 image clustering을 할 때 기존 SOTA를 뛰어넘는다는 것을 보여준다.

Examples
Examples

Colorful Image Colorization

ECCV 2016

이 논문에서는 회색조 이미지가 입력으로 주어지면 적절히 색깔을 입혀 그럴 듯한 컬러 이미지로 바꾸는 작업을 수행한다. 이 문제는 명백히 제약조건이 부족하기 때문에, 기존 접근법은 사용자와 상호작용하거나 채도를 감소시킨 채색에 의존하였다.
여기서는 선명하고 사실적인 채색 결과를 도출하는 완전히 자동화된 접근법을 제시한다. 저자들은 이 문제에 내재하는 불확실성을 포용하여 이 문제를 분류 문제로 생각하고 채색 결과의 색깔 다양성을 증가시키기 위해 학습 시간에서 class별 rebalancing을 사용한다. 시스템은 CNN에서 feed-forward pass로 구성된다.

결과는 ‘채색 튜링 테스트(colorization Turing test)’으로 평가하는데, 사람 지원자에게 실제 이미지와 생성된 이미지 중 어떤 것이 실제 이미지인지 판별하도록 하였다. 결과는 32%가 답을 틀렸다는 점에서 이전 방법에 비해 매우 뛰어난 결과를 보여준다.

더욱이, 채색이라는 문제가 cross-channel encoder로서 기능하는 self-supervised feature 학습을 위한 강력한 pretext task로 기능할 수 있음을 보였다.

아래에서 채색(colorization) 결과를 확인할 수 있다. 꽤 뛰어난 결과를 보여준다.

Examples

전체 모델 구조는 아래와 같다.

Examples

Unsupervised Learning of Visual Representations by Solving Jigsaw Puzzles

ECCV 2016

이 논문은 직소 퍼즐처럼 이미지에 9개의 patch를 만들어 순서를 뒤섞은 뒤 원래 배치를 찾아내는 문제를 pretext task로 지정하였다. 대부분은 너무 비슷한 순열이기 때문에 딱 100개의 순열만을 지정하여 이를 분류하는 것을 목표로 한다.

Examples

다른 문제 간 호환성을 위해 context-free network(CFN, siamese-enread CNN)을 사용하였다. 이미지 패치를 입력으로 하고 CFN은 AlexNet에 비해 적은 parameter를 가지면서 동일한 의미 학습 능력을 유지한다고 한다.


Stacked Denoising Autoencoders: Learning Useful Representations in a Deep Network with a Local Denoising Criterion

입력 이미지에 무작위 noise를 추가한 뒤 원래 이미지를 복원하는 것을 pretext task로 지정하였다.

Examples

Self-supervised learning of visual features through embedding images into text topic spaces

CVPR 2017

대규모의 multimodal (텍스트와 이미지) 문서를 사용함으로써 visual features를 self-supervised learning을 수행한 논문이다. 특정 이미지가 그림으로 나타날 가능성이 더 높은 의미론적 문맥을 예측하는 CNN을 학습함으로써 구별되는 visual features이 학습될 수 있음을 보였다고 한다.

이를 위해 잘 알려진 주제 모델링 기술을 사용하여 텍스트 말뭉치에서 발견 된 숨겨진 의미 구조를 활용하였고, 최근의 자기지도 접근 방식과 비교하여 이미지 분류, object detection 및 multimodal 검색에서 SOTA 성능을 보여주었다.

Examples

Colorization as a Proxy Task for Visual Understanding

여기서는 Proxy Task라는 이름으로 self-supervised learning을 수행했다. 이 논문에서는 ImageNet의 학습 paradigm을 재검토하여 다음 질문에 대한 답을 찾는다:

  • 얼마나 많은 학습 데이터가 필요한가?
  • 얼마나 많은 레이블이 필요한가?
  • 미세조정할 시 얼마나 많은 features가 변화하는가?

그래서 Proxy task로 채색(colorization) 문제를 사용하여 강력한 지도를 해줄 수 있음을 보인다.

Examples

Learning Image Representations by Completing Damaged Jigsaw Puzzles

WACV 2018

이 논문은 직소 퍼즐을 사용하는데 위의 논문과는 달리 일부 손상이 가해진 직소 퍼즐을 pretext task로 삼아 손상된 직소 퍼즐을 맞추는 것을 목표로 하였다.

3 $\times$ 3개의 패치를 만들고 순서를 섞은 뒤, 9개 중 1개를 제거하고 회색조 이미지로 변환한다. 이를 원래 이미지로 복원하는 것이 목표이다.

Examples
Examples

Unsupervised Representation Learning by Predicting Image Rotations

ICLR 2018

Examples

이 논문은 이미지를 0°, 90°, 180°, 270° 회전시킨 후 얼마나 회전시켜야 원본 이미지가 나오는지를 맞히는 4-class 분류 문제를 pretext task로 사용하였다.

Examples

회전시킨 4종류의 이미지를 한 번에 넣는 것이 효율이 좋고, 2 또는 8방향 회전도 고려하였으나 4방향이 가장 좋은 성능을 보인다고 한다.


Cross-Domain Self-supervised Multi-task Feature Learning using Synthetic Imagery

CVPR 2018

보다 일반적인 고수준의 시각표현을 얻기 위해 하나의 task가 아닌 여러 개의 task를 학습시키는 방법을 제안하는 이 논문은 합성된 이미지를 학습한다. 실제 이미지와 합성된 이미지의 도메인 차이를 극복하기 위해 적대적 학습 방법에 기초한 비지도 domain adaptation 방법을 사용한다. 합성된 RGB 이미지가 입력으로 들어오면 네트워크는 그 표면 normal, depth, 등고선 등을 동시 추정하며 실제와 합성된 이미지 domain 간의 차이를 최소화하려 한다.

Examples

Depth prediction에서 기존 feature 학습 방법은 패치의 상대적 위치를 예측하능 등의 task를 pretext task로 한다. 이 논문에서, pretext task는 pixel prediction task를 채택하였다.

Examples

이 논문에서, Instance contour detection(객체 등고선 탐지), Depth Prediction(깊이 추정), Surface normal estimation(표면 수직벡터 추정)이 multi-task를 구성한다.


Self-Supervised Representation Learning by Rotation Feature Decoupling

ICML 2019

이 논문에서는 회전과 관련된 부분과 무관한 부분을 포함하는 split representation을 학습하는 모델을 제시한다. 이미지 회전과 각 instance를 동시에 식별하는데, 회전 식별과 instance 식별을 분리함으로써 회전의 noise 영향을 최소화하여 회전 각도 예측의 정확도를 높이고 이미지 회전과 관계없이 instance를 식별하는 것 역시 향상시키도록 한다.

Examples

방법은 다음으로 구성된다:

  • Rotation Feature Decoupling
    • 이미지 회전 예측
    • Noisy한 회전된 이미지 식별
    • Feature Decoupling
      • 회전 분류
      • 회전과 무관한 부분
      • 이미지 instance 분류
Examples

Unsupervised Deep Learning by Neighbourhood Discovery

ICML 2019


##


##


##


##


Examples
Comment  Read more

Python Selenium 사용법 [파이썬 셀레늄 사용법, 크롤링]

|

이 글에서는 Python 패키지인 Selenium에 대해 알아본다. 파이썬으로 크롤링할 때 Beautifulsoup4와 더불어 빼놓을 수 없는 훌륭한 라이브러리이다.

고급 기술(?)까지 전부 정리했기 때문에, 처음 사용하는 사람이라면 우선 필요한 기능부터 보면 된다.

예시 코드의 일부는 Selenium Documentation의 코드를 참조하였음을 밝혀둔다.

selenium 4 버전부터는 사용법이 바뀐 부분이 있다. 버전 4 이후 바뀐 점을 먼저 보면 좋다.
이전 버전을 사용해야 한다면 버전 4 이후 바뀐 점을 보고 적용하면 된다.

2023.06.09 updated


Install

일반 python 환경이라면 pip(pip3)을, conda 환경이라면 conda를 사용한다.

pip install selenium
conda install selenium

일반적인 파이썬 라이브러리와는 다르게, 하나 더 필요한 것이 있다.

브라우저별로 selenium webdriver를 다운로드해야 한다. 필자는 크롬을 추천한다:

Examples

버전이 여러 개가 있는데, 본인이 사용하는 Chrome의 버전에 맞는 webdriver를 다운받아야 한다.
크롬의 버전은 여기에서 확인하거나 오른쪽 위 점 3개 > 도움말 > Chrome 정보에서 확인할 수 있다.

다운받은 파일을 Python 파일과 같은 디렉토리에 둔다. 다른 곳에 두어도 상관없지만, driver 경로를 입력하기 아주 조금 더 귀찮아질 수 있다.


Import

import selenium
from selenium import webdriver
from selenium.webdriver import ActionChains

from selenium.webdriver.common.keys import Keys
from selenium.webdriver.common.by import By

from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.ui import Select
from selenium.webdriver.support.ui import WebDriverWait

버전 4 이후 바뀐 점

  • driver를 만들 때 executable_path 인자를 주지 않는다. ```python

    selenium 4 이상에서는 executable_path를 인자로 주지 않는다.

    driver = webdriver.Chrome()

이전 버전에서는 아래와 같이 쓴다.

driver = webdriver.Chrome(executable_path=’chromedriver’)


- `find_element`류 함수들의 사용법이 함수명에 `by_xxx`를 쓰는 대신 `By.XXX` 인자를 주는 것으로 바뀌었다.

```python
# selenium 4 이상에서는 By.XXX를 인자로 준다.
search_box = driver.find_element(By.XPATH, '//*[@id="tsf"]/div[2]/div[1]/div[1]/div/div[2]/input')

# 이전 버전에서는 아래와 같이 쓴다.
search_box = driver.find_element_by_xpath('//*[@id="tsf"]/div[2]/div[1]/div[1]/div/div[2]/input')

불러오기(Driver & Web Load)

URL = 'https://www.miraeassetdaewoo.com/hki/hki3028/r01.do'

driver = webdriver.Chrome()

driver.get(url=URL)

먼저 webdriver.Chrome(executable_path) 함수를 사용하여 드라이버를 로드한다. 여기서는 driver라는 변수에 저장한다.
조금 전 다운로드한 파일 이름이 chromedriver.exe라면 경로는 같은 디렉토리인 경우 그냥 파일 이름(chromedriver)만 입력하면 된다. 확장자는 필요 없으며, 파일 경로가 다르다면 상대 경로나 절대 경로를 이용하자.

그리고 get(url) 함수를 사용하면, 해당 URL을 브라우저에서 띄운다.

현재 url 얻기

참고로, 현재 url은 다음 코드를 쓰면 된다.

print(driver.current_url)

현재 브라우저 title 얻기

print(driver.title)

브라우저 닫기

driver.close()

Wait till Load Webpage(로딩 대기)

브라우저에서 해당 웹 페이지의 요소들을 로드하는 데 시간이 좀 걸린다. 따라서 element가 존재하지 않는다는 error를 보고 싶지 않다면 해당 요소가 전부 준비가 될 때까지 대기해야 한다.

Implicit Waits(암묵적 대기)

driver.implicitly_wait(time_to_wait=5)

찾으려는 element가 로드될 때까지 지정한 시간만큼 대기할 수 있도록 설정한다. 이는 한 webdriver에 영구적으로 작용한다. 인자는 초 단위이며, Default 값은 0이다. 위의 예시는 5초까지 기다려 준다는 의미이다.

Explicit Waits(명시적 대기)

간단하게는 time.sleep(secs) 함수를 사용하여 무조건 몇 초간 대기하는 방법이 있다. 생각하기 귀찮다면 매우 편리하긴 하지만, 형편없는 효율을 자랑하므로 지양하자.

아래 코드를 살펴보자.

from selenium import webdriver

from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC

driver = webdriver.Chrome()
driver.get(url='https://www.google.com/')
try:
    element = WebDriverWait(driver, 5).until(
        EC.presence_of_element_located((By.CLASS_NAME, 'gLFyf'))
    )
    print(element)
finally:
    driver.quit()

위의 코드는 웹페이지에서 class가 gLFyf인 어떤 element를 찾을 수 있는지를 최대 5초 동안 매 0.5초마다 시도한다. expected_conditions(EC)는 만약 element를 찾을 수 있었으면 True를, 아니라면 False를 반환한다.

이 예시에서는 element가 존재하는지를 조건으로 사용했는데, 이것 말고도 아래와 같은 예시들이 있다:

제목이 어떤 문자열인지, 어떤 문자열을 포함하는지, 특정/모든 요소가 로드되었거나/볼 수 있거나/볼 수 없거나/클릭 가능하거나 등등의 여러 조건이 가능하다.

  • title_is
  • title_contains
  • presence_of_element_located
  • visibility_of_element_located
  • visibility_of
  • presence_of_all_elements_located
  • text_to_be_present_in_element
  • text_to_be_present_in_element_value
  • frame_to_be_available_and_switch_to_it
  • invisibility_of_element_located
  • element_to_be_clickable
  • staleness_of
  • element_to_be_selected
  • element_located_to_be_selected
  • element_selection_state_to_be
  • element_located_selection_state_to_be
  • alert_is_present

Custom으로 조건을 설정하는 것도 가능한데, __init__ 함수와 __call__ 함수를 구현한 class를 작성하면 된다. 여기를 참조한다.

참고로, until(method, message='') 함수는 method의 반환값이 False인 동안 계속 method를 실행한다.
반대로 until_not(method, message='') 함수는 True인 동안 실행한다.


시작하기 전에…

먼저 다음 코드를 실행해보자. 앞에서 설명한 대로 chromedriver.exe를 같은 위치에 놓았다면 문제 없이 실행될 것이다.
구글 홈페이지가 바뀌었다면 오류가 생길 수도 있다. 이건 글쎄..뭐..어쩔 수 없다.

from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys
from time import sleep

options = webdriver.ChromeOptions()
options.add_argument('window-size=1920,1080')

driver = webdriver.Chrome(options=options)
driver.implicitly_wait(5)

driver.get(url='https://www.google.com/')

search_box = driver.find_element(By.XPATH, '//*[@id="APjFqb"]')

search_box.send_keys('greeksharifa.github.io')
search_box.send_keys(Keys.RETURN)

elements = driver.find_elements(By.XPATH, '//*[@id="rso"]/div[*]')

for element in elements:
    print(element.text)
    print(element.text, file=open('gorio.txt', 'w', encoding='utf-8'))

sleep(3)
driver.close()

구글 검색의 결과가 출력되면서 동시에 gorio.txt라는 파일이 생성되며 같은 내용이 출력되어 있음을 확인할 수 있다.


요소 찾기(Locating Elements)

Selenium은 다양한 요소(element)를 찾는 방법을 지원한다. HTML을 조금 다뤄봤다면, class, id, parent-child 관계 등등을 알고 있을 것이다. 모른다면… 크롤링 하기 전에 그 공부부터(?)

먼저 어떤 요소를 찾을지부터 정해야 한다. 그냥 크롬을 켜서, Ctrl + Shift + C를 눌러 원하는 요소를 클릭한다.

Examples

그러면 해당 요소의 정보를 볼 수 있을 것이다. 원하는 요소를 아래쪽 Elements 창에서 우클릭하여 내용을 복사하거나(Copy element 혹은 Copy XPath 등), 아니면 그냥 직접 보고 입력해도 된다.

Examples
<input class="gLFyf gsfi" maxlength="2048" name="q" type="text" 
jsaction="paste:puy29d" aria-autocomplete="both" aria-haspopup="false" 
autocapitalize="off" autocomplete="off" autocorrect="off" autofocus="" 
role="combobox" spellcheck="false" title="검색" value="" aria-label="검색" 
data-ved="0ahUKEwjsxYnKytzsAhUVQd4KHXjpCvsQ39UDCAQ">

각 요소에는 class나, XPath나, id 등의 여러 속성이 존재한다. 이 속성이나 경로를 갖고 요소를 찾을 수 있다.
위에서 예시로 든 것은 구글의 검색창을 나타낸다. 이 검색창은 id가 없고 class가 특징적인 것 같으니, class로 찾아보자.

class로 찾는 방법은 다음과 같다.

search_box = driver.find_element(By.CLASS_NAME, 'gLFyf')
# 이전 버전
# search_box = driver.find_element_by_class_name('gLFyf')

# 아래는 키보드 입력을 해 주는 코드이다. 나중에 설명하겠지만 한번 해 보자.
search_box.send_keys('gorio')

그러면 search_box라는 변수에는 구글의 검색창 요소가 담겨 있는 것이다. 선택한 요소에 키보드 입력을 보내거나, 클릭하는 등의 행동을 할 수 있다. 이건 조금 있다가 살펴보자.

class 말고도 선택하는 방법은 많다. 위에서 class_name로 끝나는 함수를 쓰는 대신, id, name, xpath, css_selector, 등으로 끝나는 함수를 사용할 수 있다.

Examples

selenium 버전 4 이후 알아야 하는 함수는 find_elementfind_elements 뿐이다.

총 18개의 함수를 지원한다. 9개의 쌍이 있는데, find_element로 시작하는 함수는 조건에 맞는 요소를 하나만 반환하고, find_elements로 시작하는 함수는 해당 조건을 만족하는 모든 요소를 반복가능한(iterable) 형태로 반환한다. 버전 4 이후로는 기본 함수 2개 외에는 전부 deprecated되었다.

위에서 볼 수 있듯이 class나, css selector, id, name, tag_name, xpath, link_text, partial_link_text 등으로 선택 가능하다.

맨 위의 함수인 find_element 함수는 인자를 3개 받는다. self는 설명할 필요 없고, by는 조금 전 explicit_waits에서 보았던 그 selenium.webdriver.common.by이다. by도 CLASS_NAME 등으로 속성을 지정 가능하다.

  • ID = “id”
  • XPATH = “xpath”
  • LINK_TEXT = “link text”
  • PARTIAL_LINK_TEXT = “partial link text”
  • NAME = “name”
  • TAG_NAME = “tag name”
  • CLASS_NAME = “class name”
  • CSS_SELECTOR = “css selector”

link_text<a> tag의 링크를 대상으로 찾는다. 비슷하게 partial_link_text는 요소의 링크가 전달한 링크를 일부분으로 포함되어 있으면 해당 요소가 선택된다.

<html>
 <body>
  <p>Are you sure you want to do this?</p>
  <a href="continue.html">Continue</a>
  <a href="cancel.html">Cancel</a>
  <p class="content">Site content goes here.</p>
</body>
<html>

여기서 다음 코드로 찾을 수 있다.

continue_link = driver.find_element(By.LINK_TEXT, 'Continue')
continue_link = driver.find_element(By.PARTIAL_LINK_TEXT, 'Conti')

# 이전 버전
# continue_link = driver.find_element_by_link_text('Continue')
# continue_link = driver.find_element_by_partial_link_text('Conti')

CSS Selector는 다음 코드를 쓰면 된다.

content = driver.find_element(By.CSS_SELECTOR, 'p.content')
# 이전 버전
# content = driver.find_element_by_css_selector('p.content')

여기서 find_element_by_xpath는 매우 강력한 찾기 기능을 제공한다. 말하자면 웹페이지 상에서 해당 요소의 전체경로(혹은 상대경로)를 갖고 찾기 기능을 수행할 수 있는데, 원하는 요소에서 Copy XPath만 한 다음 그대로 갖다 쓰면 해당 요소를 정확히 찾아준다.

search_box = driver.find_element(By.XPATH, '//*[@id="tsf"]/div[2]/div[1]/div[1]/div/div[2]/input')
# 이전 버전
# search_box = driver.find_element_by_xpath('//*[@id="tsf"]/div[2]/div[1]/div[1]/div/div[2]/input')

XPath로 요소 찾기

표현식 설명
nodename nodename을 name으로 갖는 모든 요소 선택
/ root 요소에서 선택
// 현재 요소의 자손 요소를 선택
. 현재 요소를 선택
.. 현재 요소의 부모 요소를 선택
@ 속성(attibutes)를 선택
* 모든 요소에 매치됨
@* 모든 속성 요소에 매치됨
node() 모든 종류의 모든 요소에 매치됨
| OR 조건의 기능

예시는 다음과 같다.

표현식 설명
/div root 요소의 div 요소
./div 현재 요소의 자식 요소 중 div 요소
/* name에 상관없이 root 요소를 선택
./* 또는 * context 요소의 모든 자식 요소를 선택
//div 현재 웹페이지에서 모든 div 요소를 선택
.//div 현재 요소의 모든 자손 div 요소를 선택
//* 현재 웹페이지의 모든 요소를 선택
.//* 현재 요소의 모든 자손 요소를 선택
/div/p[0] root > div > p 요소 중 첫 번째 p 요소를 선택
/div/p[position()<3] root > div > p 요소 중 첫 두 p 요소를 선택
/div/p[last()] root > div > p 요소 중 마지막 p 요소를 선택
/bookstore/book[price>35.00] root > bookstore > book 요소 중 price 속성이 35.00 초과인 요소들을 선택
//*[@id="tsf"]/div[2]/ id가 tsf인 모든 요소의 자식 div 요소 중 3번째 요소를 선택
//title | //price title 또는 price 요소를 선택

텍스트 입력(키보드 입력)

어떤 요소를 find_element... 함수를 통해 선택했다고 하자.

search_box = driver.find_element(By.XPATH, '//*[@id="tsf"]/div[2]/div[1]/div[1]/div/div[2]/input')

선택한 요소에 키보드 입력을 명령으로 주어 텍스트 입력 등을 수행할 수 있다.
키보드 입력은 send_keys(*value) 함수를 통해 할 수 있다.

search_box.send_keys('greeksharifa.github.io')

기본적으로 send_keys(*value) 함수는 문자열을 받는다. 그런데, enter 같은 특수 키 입력의 경우도 문자열로 처리할 수 있지만(RETURN = '\ue006'), 다음과 같이 입력할 수 있다.

from selenium.webdriver.common.keys import Keys

search_box.send_keys(Keys.RETURN)

경우에 따라 조금씩 다르긴 하지만, 일반적으로 enter 키의 경우 Keys.ENTER 또는 Keys.RETURN으로 입력할 수 있다.

아래는 입력할 수 있는 Keys의 목록이다.

class Keys(object):
    """
    Set of special keys codes.
    """

    NULL = '\ue000'
    CANCEL = '\ue001'  # ^break
    HELP = '\ue002'
    BACKSPACE = '\ue003'
    BACK_SPACE = BACKSPACE
    TAB = '\ue004'
    CLEAR = '\ue005'
    RETURN = '\ue006'
    ENTER = '\ue007'
    SHIFT = '\ue008'
    LEFT_SHIFT = SHIFT
    CONTROL = '\ue009'
    LEFT_CONTROL = CONTROL
    ALT = '\ue00a'
    LEFT_ALT = ALT
    PAUSE = '\ue00b'
    ESCAPE = '\ue00c'
    SPACE = '\ue00d'
    PAGE_UP = '\ue00e'
    PAGE_DOWN = '\ue00f'
    END = '\ue010'
    HOME = '\ue011'
    LEFT = '\ue012'
    ARROW_LEFT = LEFT
    UP = '\ue013'
    ARROW_UP = UP
    RIGHT = '\ue014'
    ARROW_RIGHT = RIGHT
    DOWN = '\ue015'
    ARROW_DOWN = DOWN
    INSERT = '\ue016'
    DELETE = '\ue017'
    SEMICOLON = '\ue018'
    EQUALS = '\ue019'

    NUMPAD0 = '\ue01a'  # number pad keys
    NUMPAD1 = '\ue01b'
    NUMPAD2 = '\ue01c'
    NUMPAD3 = '\ue01d'
    NUMPAD4 = '\ue01e'
    NUMPAD5 = '\ue01f'
    NUMPAD6 = '\ue020'
    NUMPAD7 = '\ue021'
    NUMPAD8 = '\ue022'
    NUMPAD9 = '\ue023'
    MULTIPLY = '\ue024'
    ADD = '\ue025'
    SEPARATOR = '\ue026'
    SUBTRACT = '\ue027'
    DECIMAL = '\ue028'
    DIVIDE = '\ue029'

    F1 = '\ue031'  # function  keys
    F2 = '\ue032'
    F3 = '\ue033'
    F4 = '\ue034'
    F5 = '\ue035'
    F6 = '\ue036'
    F7 = '\ue037'
    F8 = '\ue038'
    F9 = '\ue039'
    F10 = '\ue03a'
    F11 = '\ue03b'
    F12 = '\ue03c'

    META = '\ue03d'
    COMMAND = '\ue03d'

텍스트 입력 지우기

위에 나와 있는 것처럼, 입력한 텍스트를 지우는 방법은 Keys.BACKSPACE 또는 Keys.BACK_SPACE를 사용하는 것이다.

만약 전체를 지우고 싶다면 Keys가 아니라, 선택한 요소에서 clear() 함수를 호출하면 된다.

search_box.clear()

파일 업로드

파일을 받는 <input>을 선택한 뒤, send_keys(file_path)를 호출하면 된다.

upload = driver.find_element(By.TAG_NAME, 'input')
# 이전 버전
# upload = driver.find_element_by_tag('input')
upload.send_keys(file_path)

상호작용

클릭하기(click)

클릭은 어렵지 않다. find_element 함수로 요소를 선택한 다음에, click() 함수를 호출하면 끝이다.

search_box.send_keys(Keys.RETURN)까지 입력했다면, 검색창은 다음과 같다.

Examples

이제 첫 번째 검색 결과를 클릭해 보자. Ctrl + Shift + C을 누른 뒤 제목 부분을 클릭하고, 위에서 설명한 Copy XPath를 이용하여 XPath를 얻자. XPath는 다음과 같다.

//*[@id="rso"]/div[1]/div/div/div[1]/div/div/div[1]/div/a/h3

다음 코드를 써 보자.

posting = driver.find_element(By.XPATH, '//*[@id="rso"]/div[1]/div/div/div[1]/div/div/div[1]/div/a/h3')
# 이전 버전
# posting = driver.find_element_by_xpath('//*[@id="rso"]/div[1]/div/div/div[1]/div/div/div[1]/div/a/h3')
posting.click()

실행하면 첫 번째 검색 결과가 클릭되어 다음과 같은 화면을 볼 수 있을 것이다.

Examples

옵션 선택 및 제출(submit)

XPath 등으로 select 요소를 선택한 다음에 각 옵션을 선택할 수 있지만, 그것보다 더 좋은 방법이 있다.

from selenium.webdriver.support.ui import Select

select = Select(driver.find_element(By.NAME, 'select_name'))
# select = Select(driver.find_element_by_name('select_name'))

select.select_by_index(index=2)
select.select_by_visible_text(text="option_text")
select.select_by_value(value='고리오')

selenium.webdriver.support.ui.Selectselect 요소를 선택하여 쉽게 다룰 수 있도록 한다.
위 코드에서 볼 수 있듯이 select 내에서 인덱스로 선택하거나, 옵션의 텍스트, 혹은 어떤 값을 통해 선택이 가능하다.

특정 선택을 해제하려면 다음 코드를 사용한다.


select.deselect_by_index(index=2)
select.deselect_by_visible_text(text="option_text")
select.deselect_by_value(value='고리오')

# 전부 해제
select.deselect_all()

선택된 옵션 리스트를 얻으려면 select.all_selected_options으로 얻을 수 있고, 첫 번째 선택된 옵션은 select.first_selected_option, 가능한 옵션을 모두 보려면 select.options를 사용하면 된다.

제출(submit)하려면 요소를 찾은 뒤 click()을 수행해도 되지만, 다음과 같이 써도 된다.

submit_btn.submit()

만약 선택한 요소가 form 형식이 아니라면 NoSuchElementException 오류를 볼 수 있을 것이다.

Drag and Drop

어떤 일련의 동작을 수행하기 위해서는 ActionChains를 사용하면 된다. source 요소에서 target 요소로 Drag & Drop을 수행한다고 하자.

from selenium.webdriver import ActionChains

action_chains = ActionChains(driver)
action_chains.drag_and_drop(source, target).perform()

Window / Frame 이동

최신 웹 페이지에서는 frame 같은 것을 잘 사용하지 않지만, 예전에 만들어진 사이트라면 frame을 사용한 경우가 있다.
이렇게 frame 안에 들어 있는 요소는 find_element 함수를 써도 그냥 찾아지지 않는다. find_element 함수는 frame 내에 있는 요소를 찾아주지 못한다.

그래서 특정 frame으로 이동해야 할 때가 있다.

driver.switch_to_frame("frameName")
driver.switch_to_window("windowName")

# frame 내 subframe으로도 접근이 가능하다. 점(.)을 쓰자.
driver.switch_to_frame("frameName.0.child")

windowName을 알고 싶다면 다음과 같은 링크가 있는지 살펴보자.

<a href="somewhere.html" target="windowName">Click here to open a new window</a>

혹은 webdriver는 window 목록에 접근할 수 있기 때문에, 다음과 같이 모든 window를 순회하는 것도 가능하다.

for handle in driver.window_handles:
    driver.switch_to_window(handle)

frame 밖으로 나가려면 다음과 같이 쓰면 기본 frame으로 되돌아간다.

driver.switch_to_default_content()

경고창으로 이동할 수도 있다.

alert = driver.switch_to.alert

경고창은 여기서 다룬다.


JavaScript 코드 실행

driver.execute_script() 함수를 실행할 수 있다.

아래는 Name이 search_box인 요소의 값을 query의 값으로 변경하는 코드이다.

driver.execute_script("document.getElementsByName('id')[0].value=\'"+query+"\'")

브라우저 창 다루기

뒤로가기, 앞으로 가기

브라우저는 뒤로가기(back)와 앞으로 가기(forward) 기능을 제공한다. 이를 selenium으로 구현이 가능하다.

driver.forward()
driver.back()

아주 간단하다.

화면 이동(맨 밑으로 내려가기 등)

크롤링을 하다 보면 화면의 끝으로 내려가야 내용이 동적으로 추가되는 경우를 자주 볼 수 있다.
이런 경우에는 웹페이지의 최하단으로 내려가는 코드를 실행할 필요가 있다.

driver.execute_script('window.scrollTo(0, document.body.scrollHeight);')

물론 전체를 내려가야 할 필요가 없다면 document.body.scrollHeight) 대신 지정된 값만큼 이동해도 된다.

특정 요소까지 계속 찾으려면 ActionChain을 써도 된다.

from selenium.webdriver import ActionChains

some_tag = driver.find_element(By.ID, 'gorio')
# some_tag = driver.find_element_by_id('gorio')

ActionChains(driver).move_to_element(some_tag).perform()

브라우저 최소화/최대화

driver.minimize_window()
driver.maximize_window()

Headless 설정

말하자면 브라우저 창을 띄우지 않고 수행하는 방법이다.

여기를 참조한다.

브라우저 크기 설정

여기를 참조한다.

스크린샷 저장

driver.save_screenshot('screenshot.png')

Option(ChromeOption)

여러 옵션을 설정할 수 있다. 브라우저의 창 크기, 해당 기기의 정보 등을 설정할 수 있다.

기본적인 사용법은 다음과 같다. 브라우저가 실행될 때 창 크기를 설정할 수 있다.

options = webdriver.ChromeOptions()
options.add_argument('window-size=1920,1080')

driver = webdriver.Chrome('chromedriver', options=options)

다른 기능들은 여기에 적어 두었다. 코드를 보면 역할을 짐작할 수 있을 것이다.

options.add_argument('headless')
options.add_argument('window-size=1920x1080')
options.add_argument('disable-gpu')

options.add_argument('start-maximized')
options.add_argument('disable-infobars')
options.add_argument('--disable-extensions')

options.add_experimental_option('excludeSwitches', ['enable-automation'])
options.add_experimental_option('useAutomationExtension', False)
options.add_argument('--disable-blink-features=AutomationControlled')

options.add_experimental_option('debuggerAddress', '127.0.0.1:9222')

ActionChains (마우스, 키보드 입력 등 연속 동작 실행)

from selenium.webdriver import ActionChains

menu = driver.find_element(By.CSS_SELECTOR, '.nav')
hidden_submenu = driver.find_element(By.CSS_SELECTOR, '.nav #submenu1')
# menu = driver.find_element_by_css_selector('.nav')
# hidden_submenu = driver.find_element_by_css_selector('.nav #submenu1')

ActionChains(driver).move_to_element(menu).click(hidden_submenu).perform()

# 위 한 줄은 아래와 같은 동작을 수행한다.
actions = ActionChains(driver)
actions.move_to_element(menu)
actions.click(hidden_submenu)
actions.perform()

마우스 클릭, Drag & Drop, 키보드 입력 등을 연속적으로 수행할 수 있다.

  • on_element 인자를 받는 함수는, 해당 인자가 주어지지 않으면 현재 마우스 위치를 기준으로 한다.
  • element 인자를 받는 함수는, 해당 인자가 주어지지 않으면 현재 선택이 되어 있는 요소를 기준으로 한다.
  • key_down, key_up 함수는 Ctrl 등의 키를 누를 때 쓰면 된다.
# Ctrl + C를 누른다.
ActionChains(driver).key_down(Keys.CONTROL).send_keys('c').key_up(Keys.CONTROL).perform()
동작 수행 함수 설명
click(on_element=None) 인자로 주어진 요소를 왼쪽 클릭한다.
click_and_hold(on_element=None) 인자로 주어진 요소를 왼쪽 클릭하고 누르고 있는다.
release(on_element=None) 마우스를 주어진 요소에서 뗀다.
context_click(on_element=None) 인자로 주어진 요소를 오른쪽 클릭한다.
double_click(on_element=None) 인자로 주어진 요소를 왼쪽 더블클릭한다.
drag_and_drop(source, target) source 요소에서 마우스 왼쪽 클릭하여 계속 누른 채로 target까지 이동한 뒤 마우스를 놓는다.
drag_and_drop_by_offset(source, xoffset, yoffset) 위와 비슷하지만 offset만큼 이동하여 마우스를 놓는다.
key_down(value, element=None) value로 주어진 키를 누르고 떼지 않는다.
key_up(value, element=None) value로 주어진 키를 뗀다.
move_to_element(to_element) 마우스를 해당 요소의 중간 위치로 이동한다.
move_to_element_with_offset(to_element, xoffset, yoffset) 마우스를 해당 요소에서 offset만큼 이동한 위치로 이동한다.
pause(seconds) 주어진 시간(초 단위)만큼 입력을 중지한다.
perform() 이미 쌓여 있는(stored) 모든 행동을 수행한다(chaining).
reset_actions() 이미 쌓여 있는(stored) 모든 행동을 제거한다.
send_keys(*keys_to_send) 키보드 입력을 현재 focused된 요소에 보낸다.
send_keys_to_element(element, *keys_to_send) 키보드 입력을 주어진 요소에 보낸다.

경고 창 다루기(alerts)

브라우저를 사용하다보면 상단에 경고창이 뜰 때가 있다. (확인/취소 등)
이 경고창을 무시하는 등의 처리를 할 수 있는 기능을 제공한다.

아래 코드는 경고창에서 수락/거절을 누르거나, 경고창의 내용을 출력, 혹은 경고창에 특정 키 입력을 보낼 수 있다.

from selenium.webdriver.common.alert import Alert

Alert(driver).accept()
Alert(driver).dismiss()

print(Alert(driver).text)
Alert(driver).send_keys(keysToSend=Keys.ESCAPE)

각각의 코드는 알아보기 어렵지 않을 것이다.


기타 기능

  • Touch Actions: 마우스/키보드 입력과 비슷하게 chaining이 가능하다. 터치와 관련한 여러 기능을 제공한다. selenium.webdriver.common.touch_actions.TouchActions
  • Proxy: Proxy 기능을 사용할 수 있다. selenium.webdriver.common.proxy.Proxy
  • 쿠키(Cookies): 쿠키를 추가하거나 가져올 수 있다.
# Go to the correct domain
driver.get('http://www.example.com')

# Now set the cookie. This one's valid for the entire domain
cookie = {name : foo, value : bar}
driver.add_cookie(cookie)

# And now output all the available cookies for the current URL
driver.get_cookies()

References

Comment  Read more

Adversarial AutoEncoder (AAE) 설명

|

본 글에서는 VAEGAN을 결합한 Adversarial Autoencoder (이하 AAE)에 대한 논문을 리뷰하면서 이론에 대해 설명하고 이를 Tensorflow로 구현하는 과정을 보여줄 것이다. 이 알고리즘을 이해하기 위해서는 앞서 언급한 2가지 알고리즘에 대해 숙지하고 있어야 하며, VAE에 대해 알고 싶다면 이 글을, GAN에 대해 알고 싶다면 이 글을 참조하길 바란다.


1. Adversarial Autoencoders Paper Review

1.1. Introduction

오디오, 이미지, 영상 등과 같은 rich distribution을 포착하기 위해 Scalable한 생성 모델을 구성하는 것은 머신러닝에서도 굉장히 중요하고도 어려운 문제로 여겨진다. RBM, DBNs, DBM 등은 MCMC 기반의 알고리즘으로 학습되었다. 이러한 방법론들은 학습이 진행될수록 부정확하게 Log Likelihood의 Gradient를 포착하는 경향을 보였다. 왜냐하면 Markov Chain에서 온 Sample들은 여러 Mode를 빠르게 혼합하지 못하기 때문이다.

최근에는 이 대신 Direct Back-propagation을 이용하여 이 같은 단점을 극복한 VAE, GAN과 같은 알고리즘이 제안되었다.

본 논문에서는 autoencoder를 생성 모델로 변환하는 AAE라는 알고리즘을 제안할 것이다. 우리의 모델에서 이 autoencoder는 autoencoder의 Latent Representation의 Aggregated PosteriorArbitrary Prior에 연결하는 2개의 목적함수 (Traditional Reconstruction Error Criterion, Adversarial Training Criterion)로 학습을 진행할 것이다. 본 논문에서는 이러한 학습 방법이 VAE의 학습과 강력한 연관성을 보여준다는 것을 보여줄 것이다. Encoder가 데이터 분포를 Prior 분포로 변환하는 방법에 대해 학습하고 Decoder는 Imposed Prior를 데이터 분포에 매핑하는 Deep 생성 모델을 학습하게 된다.

1.1.1. Generative Adversarial Networks

GAN은 생성 모델 G와 판별 모델 D라는 2개의 신경망 사이의 Min-Max 적대적 게임을 구축하는 프레임워크이다. 판별 모델 $D(x)$ 는 데이터 공간의 포인트 $\mathbf{x}$ 가 실제 데이터 분포 (Positive Samples) 로 부터 추출되었는지를 계산하는 신경망이다. 생성자는 $G(\mathbf{z})$ 라는 함수를 사용하게 되는데, 이 함수는 Prior $p(\mathbf{z})$ 로부터 추출된 Sample $\mathbf{z}$ 를 데이터 공간에 연결시키는 역할을 한다. $G(\mathbf{z})$ 는 최대한 판별 모델로 하여금 Sample이 실제 데이터 분포로부터 추출되었다고 믿게 만드는, 속이는 역할을 하게 된다. 이 생성자는 x에 대하여 $D(x)$ 의 Gradient를 레버리지하여 학습된다. 그리고 이를 활용하여 파라미터를 조정한다. 이 게임의 해는 아래와 같이 표현할 수 있다.

[\underset{G}{min} \underset{D}{max} E_{\mathbf{x} \sim p_{data}} [logD(\mathbf{x})] + E_{\mathbf{z} \sim p(\mathbf{z})} [log(1 - D(G(\mathbf{z})))]]

alternating SGD를 이용하여 2단계로 학습이 진행된다. 먼저 판별자가 생성자로부터 생성된 가짜 Sample로부터 진짜 Sample을 구별하는 방법을 학습하고, 생성자는 이후 생성된 Sample을 통해 판별자를 속이는 방법을 학습한다.


1.2. Adversarial Autoencoders

잠시 기호에 대해 살펴보자.

[p(\mathbf{z}), q(\mathbf{z} \mathbf{x}), p(\mathbf{x} \mathbf{z})]

위 기호는 차례대로 1) Code에 투사하고 싶은 Prior 분포, 2) Encoding 분포, 3) Decoding 분포를 의미한다.

[p_d (\mathbf{x}), p(\mathbf{x})]

위 기호는 차례대로 4) 실제 데이터 분포, 5) Model 분포를 의미한다. Encoding 함수는 autoencoder의 (잠재 표현) Hidden Code 벡터에 대한 Aggregated Posterior 분포를 아래와 같이 정의한다.

[q(\mathbf{z}) = \int_{\mathbf{x}} q(\mathbf{z} \mathbf{x}) p_d (\mathbf{x}) d\mathbf{x}]

Adversarial AutoencoderAggregated Posterior인 $q(\mathbf{z})$ 를 Arbitrary Prior인 $p(\mathbf{z})$ 와 매칭시킴으로써 regualarized 된다. 그렇게 하기 위해서 이 적대적 네트워크는 아래 그림과 같이 autoencoder의 Hidden Code 벡터의 상위에 추가된다.

autoencoder는 그동안 Reconstruction Error를 최소화한다. 적대적 네트워크의 생성자는 사실 autoencoder의 encoder이다.

[q(\mathbf{z} \mathbf{x})]

Encoder는 Hidden Code인 $q(\mathbf{z})$ 가 실제 Prior 분포 $p(\mathbf{z})$ 로부터 왔다고 판별 모델을 착각하게 만들어 Aggregated Posterior 분포가 판별 모델을 속이도록 만든다.

적대적 네트워크와 autoencoder 모두 2단계를 통해 SGD로 결합하여 학습된다. (Reconstruction 단계 Regularization 단계) Reconstruction 단계에서 autoencoder는 Encoder와 Decoder가 Input에 대한 Reconstruction Error를 최소화하도록 업데이트하게 된다. Regularization 단계에서 적대적 네트워크는 먼저 판별 모델이 진짜 Sample을 구별하도록 업데이트한 후, 생성자가 판별 네트워크를 속이도록 업데이트를 진행한다.

이러한 학습과정이 끝나면, autoencoder의 Decoder는 투사된 Prior인 $p(\mathbf{z})$ 를 실제 데이터 분포에 매핑하는 생성 모델을 정의하게 된다.

AAE의 Encoder를 정의하는 방법에는 다음과 같이 3가지가 존재한다.

[Encoder: q(\mathbf{z} \mathbf{x})]

1) Deterministic
Encoder가 $\mathbf{x}$ 의 deterministic 함수라고 가정해보자. 그렇다면 이 때의 Encoder는 가장 기본적인 형태의 autoencoder의 Encoder과 유사할 것이고 $q(\mathbf{z})$ 내부의 Stochasticity는 오직 실제 데이터 분포 $p_d (\mathbf{x})$ 에서만 찾을 수 있게 된다.

2) Gaussian Posterior
Encoder가 Encoder 네트워크에 의해 예측된 평균과 분산을 따르는 정규 분포라고 가정해보자.

[z_i \sim \mathcal{N} (\mu_i(\mathbf{x}), \sigma_i(\mathbf{x}))]

이 때 $q(\mathbf{z})$ 의 Stochasticity는 실제 데이터 분포와 Encoder의 결과의 정규 분포의 Randomness 모두에서 나온다. VAE에서 보았듯이 Encoder 네트워크의 Back-propagation은 Reparametrization Trick을 통해 이루어진다.

3) Universal Approximator Posterior
AAE의 Encoder네트워크가 정규 분포와 같은 고정된 분포에서 추출한 Random Noise $\eta$ 와 Input $\mathbf{x}$ 의 함수 $f(\mathbf{x}, \eta)$ 라고 해보자. 우리는 여러 $\eta$ 를 Sampling 한 뒤 이에 따른 $f(\mathbf{x}, \eta)$ 를 평가하여 아래와 같은 (Encoder) 임의의 사후 분포를 추출할 수 있다.

[q(\mathbf{z x})]

우리는 Aggregated Posterior인 $q(\mathbf{z})$ 를 아래와 같이 표현할 수 있을 것이다.

[q(\mathbf{z x}) = \int_{\eta} q(\mathbf{z x}, \eta) p_{\eta} (\eta) d \eta]
[\rightarrow q(\mathbf{z}) = \int_{\mathbf{x}} \int_{\eta} q(\mathbf{z x}, \eta) p_d(\mathbf{x}) p_{\eta} p_{\eta}(\eta) d\eta d\mathbf{x}]
[q(\mathbf{z x})]

이 때 위와 같은 Posterior는 더 이상 Gaussian일 필요가 없고, Encoder는 Input $\mathbf{x}$ 가 주어졌을 때 어떠한 임의의 사후 분포도 학습할 수 있다. Aggregated Posterior $q(\mathbf{z})$ 로부터 Sampling을 하는 효과적인 방법이 존재하기 때문에, 적대적 학습 과정은 Encoder 네트워크 $f(\mathbf{x}, \eta)$ 를 통한 직접적인 Back-propagation으로 $q(\mathbf{z})$ 를 $p(\mathbf{x})$ 에 연결할 수 있다.

지금까지 확인한 것처럼, 위 Posterior를 여러 다른 종류로 선택하게 되면 이는 또 다른 학습 Dymamics를 가진 다른 종류의 모델로 귀결된다. 예를 들어 1) Deterministic Case에서 네트워크는 데이터 분포로부터 Stochasticity를 뽑아내서 $q(\mathbf{z})$ 를 $p(\mathbf{x})$ 와 매칭시켜야 한다. 그러나 데이터의 경험적 분포가 학습 데이터 셋에서 고정되어 있기 때문에 Mapping이 Deterministic하면 부드럽지 못한 $q(\mathbf{z})$ 를 만들지도 모른다.

하지만 나머지 2개의 케이스에서는 네트워크가 Stochasticity의 추가적인 원천에 접근할 수 있기 때문에 이는 $q(\mathbf{z})$ 를 부드럽게 만들어 Adversarial Regularization 단계에서 개선이 이루어진다. 그럼에도 불구하고, 굉장히 광범위한 Hyper-parameter Search를 통해서 우리는 Posterior의 각 경우에 맞는 유사한 Test-Likelihood를 얻을 수 있었고 따라서 지금부터 본 논문에서는 Posterior의 Deterministic한 버전을 통해서 논의를 전개해 나가도록 할 것이다.

1.2.1. Relationship to Variational Autoencoders

VAE에서는 KL divergence 페널티를 사용하여 autoencoder의 Hidden Code 벡터Prior 분포를 투사하지만, Hidden Code의 Aggregated PosteriorPrior 분포에 매칭시키는 적대적 학습 과정을 사용할 것이다. VAE는 아래와 같이 $\mathbf{x}$ 의 Negative Log-Likelihood에 대한 상한선을 최소화한다.

[E_{\mathbf{x} \sim p_d(\mathbf{x})} [-logp(\mathbf{x})] < E_{\mathbf{x}} [-log(p(\mathbf{x z}))] + E_{\mathbf{x}} [KL(q(\mathbf{z x})   p(\mathbf{z}))]]
[= E_{\mathbf{x}} [-log(p(\mathbf{x z}))] - E_{\mathbf{x}} [H(q(\mathbf{z x}))] + E_{q(\mathbf{z})} [-logp(\mathbf{z})]]
[= E_{\mathbf{x}} [-log(p(\mathbf{x z}))] - E_{\mathbf{x}}[\Sigma_i log \sigma_i (\mathbf{x})] + E_{q(\mathbf{z})} [-logp(\mathbf{z})] + Const]

[= Reconstruction - Entropy + CrossEntropy(q(\mathbf{z}), p(\mathbf{z}))]

첫 번째 항은 Reconstruction 항으로, 나머지 항은 Regularization 항으로 생각할 수 있다. 만약 Regularization 항이 없다면 모델은 단순히 Input을 재현하는 기본적인 autoencoder의 형태를 취할 것이다. 두 번째 항이 사후 분포의 분산을 크게 만든다면, 세 번째 항은 Aggregated Posterior $q(\mathbf{z})$ 와 Prior $p(\mathbf{z})$ 사이의 Cross-Entropy를 최소화한다.

위 목적함수에서 KL divergence 또는 Cross-Entropy 항은 $q(\mathbf{z})$ 가 $p(\mathbf{z})$ 의 Modes를 고르도록 만든다. AAE에서 우리는 두 번째 두 항을 $q(\mathbf{z})$ 가 $p(\mathbf{z})$ 의 전체 분포와 매칭되게 하는 적대적 학습 과정으로 대체하였다.

본 섹션에서 우리는 구체적인 Prior 분포 $p(\mathbf{z})$ 를 Coding 분포에 투사하는 능력에 대해 AAEVAE를 비교해볼 것이다.

위 그림에서 E부분은 MNIST 숫자 데이터 셋을 학습한 AAE로부터 테스트 데이터의 Coding Space $\mathbf{z}$ 를 보여주며, 이 때 구형의 2차원 Gaussian Prior 분포가 Hidden Codes $\mathbf{z}$ 에 투사되었다.

A부분을 보면, 학습된 Manifold는 날카로운 변화를 보이는데 이는 Coding Space가 채워져 있고 ‘구멍’은 존재하지 않음을 의미한다. 실제로 Coding Space의 날카로운 변화는 $\mathbf{z}$ 내부에서 덧붙여져 생성된 이미지들이 데이터 Manifold 위에 있음을 의미한다.

반대로 C부분을 보면, VAE의 Coding Space는 AAE의 그것과 같은 구조를 보인다는 것을 확인할 수 있다. 이 경우 VAE가 대체로 2차원의 정규 분포의 형태와 일치한다는 것을 알 수 있는데, 그러나 어떠한 데이터 포인트도 Coding Space의 일부 Local Region에 매핑되지 않은 것을 볼 때 VAEAAE 만큼 데이터 Manifold를 잘 포착하지 못했다는 것을 알 수 있다.

B, D 부분을 보면 AAEVAE의 Coding Space의 투사된 분포가 10개의 2차원 Gaussian의 혼합임을 확인할 수 있다. AAE는 성공적으로 Aggregated PosteriorPrior 분포에 매칭시켰다. 반대로 VAE는 10개의 Gaussian 혼합과는 구조적으로 다른 결과를 보여주는데, 이는 VAE가 앞서 언급하였듯이 분포의 Modes를 매칭시키는 데 집중하기 때문인 것으로 파악된다.

VAE에서 Monte-Carlo Sampling로 KL divergence를 Back-propagate하기 위해서는 Prior 분포의 정확한 함수 형태를 알고 있어야 한다. 그러나 AAE에서는 $q{\mathbf{z}}$ 를 $p(\mathbf{z})$ 에 매칭시키기 위해 Prior 분포로부터 Sampling만 할 수 있으면 된다는 점이 VAEAAE의 큰 차이이다.

1.2.3 절에서 우리는 AAE가 분포의 정확한 함수 형태를 알지 못한다 하더라도 (Swiss Roll 분포와 같이) 복잡한 분포를 투사할 수 있음을 증명할 것이다.

1.2.2. Relationship to GANs and GMMNs

(생략)

1.2.3. Incorporating Label Information in the Adversarial Regularization

데이터에 Label이 존재하는 경우 우리는 학습 과정에서 이 정보를 활영하여 더 나은 형태의 Hidden Code의 분포를 얻을 수 있다. 본 섹션에서 우리는 autoencoder의 잠재 표현을 규제하기 위해 부분적 또는 완전한 Label 정보를 활용하는 방법에 대해 설명할 것이다. 이러한 구조를 살펴보기 위해 1.2.1 절에서 확인하였던 B그림을 참조해보자. 이 때 AAE는 10개의 2차원 Gaussian의 혼합에 적합한 것으로 보인다. 지금부터 우리는 이 정규분포의 혼합의 각 Mode가 MNIST의 1개의 Label을 표현한다는 것을 보이도록 할 것이다.

위 그림은 준지도 학습에 관한 학습 과정을 보여준다. 판별 네트워크의 Input에 One-Hot 벡터가 추가되어 분포의 Mode에 Label이 개입하도록 하였다. 이 One-Hot 벡터는 Class Label이 주어졌을 때 판별 네트워크의 적절한 결정 범위를 선택하는 스위치 역할을 하게 된다.

만약 Label이 존재하지 않을 경우 이 One-Hot 벡터는 다른 Class를 갖게 된다. 예를 들어, 10개의 2차원 Gaussian 혼합을 투사하는 경우, One-Hot 벡터는 11개의 Class를 갖게 될 것이다.

첫 10개의 Class의 각각은 적절한 개별 Mixture Component를 위한 결정 범위를 선택하게 된다. 모델에 Label이 없는 데이터 포인트가 주어졌을 때, Extra Class가 등장하여 정규 분포의 Full Mixture를 위한 결정 범위를 선택하게 된다.

AAE 학습의 Positive 단계에서 우리는 One-Hot 벡터를 통해 판별자에게 Mixture Component의 Label에 대한 정보를 제공한다. Label이 없는 경우를 위한 Positive Sample은 특정 Class가 아닌 Gaussian의 Full Mixture로부터 오게 된다. Negative 단계에서 우리는 One-Hot 벡터를 통해 판별자에게 학습 이미지의 Label을 제공한다.

위 그림의 A 부분을 보면, 10K labeled MNIST 예시와 40K unlabeled MNIST 예시로 학습되었고, 10개의 2차원 Gaussian의 혼합 Prior로 학습된 AAE의 잠재 표현을 보여주고 있다. 이경우 Prior의 i번째 Mixture Component는 준지도 방식으로 i번째 Class에 할당되어 있다.

B 부분을 보면, 첫 번째 3개의 Mixture Component의 Manifold를 확인할 수 있다. Class와 독립적으로 각 Mixture Component 안에 Style Representation이 일관적으로 표현되어 있음을 알 수 있다. 예를 들어 B 부분의 좌상단 부분을 보면 Upright 필기 스타일로, 우하단을 보면 기울어진 스타일로 표현되어 있음을 알 수 있다.

이 방법은 Parametric 형식 없이고 임의의 표현으로 확장될 수 있다. C 부분은 Coding Space $\mathbf{z}$ 를 묘사하고 있고, D 부분은 잠재 공간의 Swiss Roll 축을 따라 생성된 이미지를 강조하고 있다.


1.3. Likelihood Analysis of Adversarial Autoencoders

이전 섹션에서 소개되었던 실험은 오직 질적인 결과만을 보여주었다. 이번 Chapter에서는 MNIST와 Toronto Face 데이터셋에 기반하여 hold-out 이미지를 생하는 모델의 Likelihood를 비교하여 데이터 분포를 포착하는 생성 모델로서 AAE의 성능을 평가해볼 것이다.

우리는 MNIST와 TFD에 AAE를 학습시켰는데, 이 때 모델은 근본적인 Hidden Code에 고차원적 정규 분포를 투사한다. 아래 그림은 데이터셋의 일부를 가져온 것이다.

AAE의 성능을 평가하기 위해 hold-out 테스트셋을 대상으로 Log Likelihood를 계산하였다. 사실 Likelihood를 사용하여 평가를 하는 방식은 그리 직관적이지는 못한데, 왜냐하면 사실 이미지의 확률을 직접적으로 계산한다는 것은 불가능하기 때문이다. 따라서 우리는 논문1, 논문2, 논문3 에 제시되었던 것처럼 True Log Likelihood의 Lower Bound를 계산하였다.

우리는 모델이 생성한 10,000개의 Sample에 대해 Gaussian Parzen Window (Kernel Density Estimator)를 계산하였고, 이 분포 하에서 테스트 데이터의 Likelihood를 계산하였다. Parzen Window의 자유 파라미터인 $\sigma$ 는 CV를 통해 선택되었다. 다음은 테스트 결과이다.

(중략)

1.4. Supervised Adversarial Autoencoders

준지도 학습은 머신러닝에서 오래된 개념적인 문제이다. 최근에 생성 모델들은 준지도 학습에 있어서 유명한 방법론들이 외덨는데, 왜냐하면 이러한 모델들은 원칙에 의거한 방법으로 Variation의 많은 다른 잠재 요인들로부터 Class Label 정보를 구분해낼 수 있었기 때문이다.

이번 Chapter에서 우리는 먼저 완전한 지도학습 시나리오에 초점을 맞추고 이미지의 Style 정보로부터 Class Label 정보를 분리해내는 AAE의 구조에 대해 논할 것이다. 이후 우리는 이러한 구조를 1.5 절에서 준지도 학습 세팅으로 확장해볼 것이다.

우리는 Label 정보를 통합하는 측면에서, Label의 One-Hot 벡터 인코딩을 Decoder에 제공하기 위해 네트워크 구조를 변경하였다. Decoder는 Label을 식별하는 One-Hot 벡터와 이미지를 재구성하는 Hidden Code $\mathbf{z}$ 를 사용한다. 이러한 구조는 네트워크가 Hidden Code $\mathbf{z}$ 에 있는 Label과 독립적인 모든 정보를 유지하도록 만든다.

아래 그림은 Hidden Code가 15차원 Gaussian으로 설정된 MNIST 데이터셋에 학습된 네트워크의 결과를 보여준 것이다.

위 그림의 각 행은 Hidden Code $\mathbf{z}$ 는 특정한 값으로 고정되어 있지만 Label은 구조적으로 탐색되었을 때 재구성된 이미지를 나타낸다. 행 속에서 재구성된 이미지들은 상당히 Style 측면에서 일관적인 것을 알 수 있다.

이 실험에서 One-Hot 벡터는 이미지의 central 숫자와 관련있는 Label을 의미한다. 각 행의 Style 정보는 제일 왼쪽과 제일 오른쪽의 숫자의 Label에 대한 정보를 담고 있는데, 왜냐하면 양 끝의 숫자들은 One-Hot 인코딩 속에서 Label 정보가 주어지지 않았기 때문이다.

1.5. Semi-Supervised Adversarial Autoencoders

이전 Chapter에 이어 이제 AAE를 준지도학습에 사용해 볼 것이다. 이 때 준지도학습은 오직 labeled된 데이터를 사용하여 얻어진 분류 성능을 개선하기 위해 unlabeled된 데이터의 generative description을 이용하는 것을 말한다. 구체적으로, 우리는 범주형 분포를 갖는 잠재 class 변수 $\mathbf{y}$ 와 정규 분포를 갖는 연속형 잠재 변수 $\mathbf{z}$ 로부터 데이터가 생성되었다고 가정할 것이다.

[p(\mathbf{y}) = Cat(\mathbf{y}), p(\mathbf{z}) = \mathcal{N} (\mathbf{z 0, I})]
[Encoder: q(\mathbf{z, y x})]

네트워크 구조를 변경하여 AAE의 추론 네트워크가 위와 같은 Encoder를 사용하여 범주형 Class 변수 $\mathbf{y}$ 와 연속형 잠재 변수 $\mathbf{z}$ 를 모두 예측하도록 해보자. 구조는 아래 그림과 같다.

Decoder는 Class Label을 One-Hot 벡터로 활용하고 연속형 Hidden Code $\mathbf{z}$ 를 활용하여 이미지를 재구성한다. 첫 번째 적대적 네트워크는 범주형 분포를 Label Representation에 투사한다.

이 적대적 네트워크는 잠재 Class 변수 $\mathbf{y}$ 가 어떠한 Style 정보도 갖고 있지 않고 $\mathbf{y}$ 의 Aggregated Posterior 분포가 범주형 변수와 일치한다는 것을 보장한다. 두 번째 적대적 네트워크는 정규 분포를 Style Representation에 투사하여 잠재 변수 $\mathbf{z}$ 가 연속형 정규 변수라는 것을 보장한다.

적대적 네트워크와 autoencoder 모두 3가지 단계를 통해 학습된다. 3단계는 Reconstruction, Regularization, Semi-supervised Classification 단계를 의미한다.

Reconstruction 단계에서, autoencoder는 unlabeled 미니배치에 대해 Input의 Reconstruction Error를 최소화하기 위해 Encoder와 Decoder를 업데이트한다.

Regularization 단계에서 각 적대적 네트워크는 먼저 범주형 및 정규 Prior를 이용하여 생성된 진짜 Sample과 autoencoder에 의해 계산된 Hidden Code인 생성된 가짜 Sample를 구별하기 위해 판별 네트워크를 업데이트한다. 이후 적대적 네트워크는 판별 네트워크를 혼란시키기 위해 생성자를 업데이트한다.

[q(\mathbf{y x})]

마지막으로 Semi-supervised Classification 단계에서 autoencoder는 Labeled 미니배치에 대한 Cross-Entropy Cost를 최소화하기 위해 위 함수를 업데이트한다.

아래 표는 테스트 결과이다.

1.6. Unsupervised Clustering with Adversarial Autoencoders

어떠한 지도 없이 unlabeled 데이터로부터 강력한 표현을 학습할 수 있는 방법은 없을까? 이번 Chapter에서는 AAE가 오직 비지도 방법으로 연속형의 잠재 Style 변수로부터 이산적인 Class 변수를 구분해낼 수 있다는 것을 보여줄 것이다.

네트워크 구조는 사실 이전 Chapter에서 보았던 것과 유사하다. 다만, 이 때 준지도 분류 단계를 제거하였기 때문에 labeled 미니배치에 대해 네트워크를 학습하는 과정은 존재하지 않는다.

[q(\mathbf{y x})]

또 다른 점은, 위 추론 네트워크가 데이터가 속하는 클러스터를 의미하는 카테고리의 수만큼의 차원을 갖는 One-Hot 벡터를 예측하는 데 사용된다는 것이다.

위 그림은 클러스터의 수가 16일 때 MNIST 데이터셋에 대한 AAE의 클러스터링 성능을 보여준다. 각 행의 첫 번째 이미지는 클러스터 Head로, Style 변수를 0으로 고정하고 Label 변수를 1-of-16 One-Hot 벡터로 설정하여 생성된 숫자이다. 나머지 이미지들은 추론 네트워크에 기반하여 상응하는 카테고리로 배정된 테스트 데이터셋 이미지이다. 우리는 AAE가 몇몇 이산적 Style을 Class Label로서 포착한 것을 알 수 있다. 예를 들어 클러스터 11, 16을 보면 6과 1이 기울어져서 클러스터 10, 15와 다른 클러스터로 분류하였다. 또 클러스터 4, 6을 보면 같은 숫자 2라도 Style에 따라 다른 클러스터로 분류되어 있음을 알 수 있다.

AAE의 비지도 클러스터링 성능을 평가하기 위해 실험을 진행하였다. 평가 방법은 다음과 같다. 학습이 끝나면 각 클러스터 i에 대해 아래와 같은 식을 최대화하는 Validation Example $x_n$ 을 찾는다.

[q(y_i x_n)]

이후 $x_n$의 Label을 클러스터 i에 존재하는 모든 포인트에 할당한다. 그리고 나서 각 클러스터에 배정된 Class Label에 기반하여 테스트 에러를 계산한다. 그 결과는 다음과 같다. Cluster의 개수가 증가할 수록 에러는 감소하는 것을 알 수 있다.

1.7. Dimentionality Reduction with Adversarial Autoencoders

고차원 데이터의 시각화는 많은 분야에서 굉장히 중요한 문제이다. 왜냐하면 데이터의 생성 과정에 대한 이해를 촉진하고 분석가로 하여금 데이터에 대한 유용한 정보를 추출할 수 있도록 만들어주기 때문이다. 데이터 시각화의 유명한 방법은 유사한 개체에 상응하는 가까운 데이터 포인트에 존재하는 저차원의 Embedding을 학습하는 것이다. 지난 10년간 t-SNE와 같은 많은 비모수적 차원 축소 기법이 제안되었다. 이러한 방법의 큰 단점은 새로운 데이터 포인트를 위한 Embedding을 찾기 위해 사용될 수 있는 모수적 Encoder가 존재하지 않는다는 것이다. autoencoder는 이러한 Embeddding에 필요한 비선형적 Mapping을 제공해줄 수 있는 흥미로운 대안이지만, non-regularized autoencoder는 manifold를 많은 domain으로 분리해버려 유사한 이미지에 대한 다른 Code를 만들어 버리는 문제점을 갖고 있다.

이번 Chapter에서는 우리는 AAE를 차원 축소와 데이터 시각화를 위한 기법으로 소개할 것이다. 이 알고리즘에서 Adversarial Regularization은 유사한 이미지들의 Hidden Code를 서로 붙여주고 autoencoder에 의해 학습된 Embedding에서 전형적으로 맞닥드리는 문제였던 manifold를 분리하는 현상을 방지한다.

우리는 $m$ 개의 Class Label이 존재하는 데이터셋을 갖고 있고, 우리는 이를 $n$ 차원으로 줄이고 싶다고 하자. 이 때 $n$ 은 시각화를 위해 2~3이 적당할 것이다. 모델의 구조를 아래와 같이 바꿔보자.

이 때 최종 Representation은 클러스터의 $n$ 차원의 분포 Representation을 $n$ 차원의 Style Representation에 추가하는 것으로 얻을 수 있다. 클러스터 Head Representation은 $m$ 차원의 One-Hot Class Label 벡터를 $m * n$ 크기의 행렬 $W_c$ 와 곱함으로써 얻을 수 있다. 이 때 $W_c$ 의 행은 SGD에 의해 학습된 $m$ 개의 클러스터 Head Representation을 나타낸다. 모든 2개의 클러스터 Head 사이의 유클리디안 거리를 규제하는 추가적인 Cost 함수를 소개할 것이다. 구체적으로 만약 유클리디안 거리가 어떤 threshold $\eta$ 보다 크면, Cost 함수는 0이 될 것이고, 그 반대의 경우라면 이 Cost 함수가 거리를 선형적으로 규제하게 될 것이다.

위 그림 중 a, b 부분을 보면 $m=10, n=2$ 일 때, 1000개/100개의 Label을 가진 MNIST 데이터셋에 대해 준지도학습 차원 축소의 결과를 보여준다. 각각 4.2, 6.08%의 준지도 분류 Error를 보여주면서 상당히 깨끗하게 숫자 클러스터를 분류한 것을 알 수 있다. 2차원 제약조건 때문에 분류 Error는 고차원 케이스만큼 좋지는 못하다. 그리고 각 클러스터의 Style 분포는 딱히 정규분포라고 보기 어렵다.

c 부분은 $n=2$ 차원, $m=20$ 개의 클러스터가 존재한다고 가정한 비지도 차원축소 결과를 나타낸다. 우리는 네트워크가 숫자 클러스터와 sub-클러스터의 깨끗한 구분을 보여준다는 것을 알 수 있다. 예를 들어 네트워크는 2개의 다른 클러스터를 숫자 1(초록 클러스터)에 할당하였는데, 이는 그 숫자가 곧거나 기울어져있는 것에 따라 결정되었다. 이는 빨간 클러스터인 숫자 2의 경우에도 동일하게 적용된다.

앞서 보았던 AAE의 구조는 또한 $n>2$ 일 때 더 큰 차원으로 이미지를 임베딩하는 데에 사용될 수 있다. 예를 들어 d 부분을 보면, 100개의 라벨이 존재할 때 $n=10$ 차원의 준지도 차원 축소의 결과를 보여준다. 이 경우 우리는 $W_c = 10 \mathbf{I}$ 로 고정했기 때문에 클러스터 Head는 10차원 simplex의 코너를 의미한다. Style Representation은 표준편차1의 10차원 정규 분포로 학습되었고, 최종 Representation을 구성하기 위해 클러스터 Head에 직접 추가되었다.

네트워크가 학습되면 10차원의 학습된 Representation을 시각화하기 위해, 10차원 Representation을 클러스터 Head가 2차원 원에 균일하게 위치하는 데이터 포인트에 mapping되는 2차원 공간에 mapping되도록 선형 변환을 사용하였다. 우리는 고차원 케이스에서 Style Representation이 정말 정규 분포를 갖는다는 것을 확인할 수 있다. 총 100개의 Label에서 이 모델은 3.9%의 분류 에러를 보여주었는데, 이는 1.5 Chapter에서 보여주었던 1.9%의 에러보다 좋지 못한 결과이다.

1.8. Conclusion

본 논문에서 우리는 확률론적 autoencoder에 있는 이산형/연속형 잠재 변수들을 위한 변분 추론 알고리즘으로서 GAN 프레임워크를 사용하였다. 이 방법은 AAE라고 명명하며, MNIST와 Toronto Face 데이터셋에 대해 경쟁력 있는 테스트 Likelihood를 얻은 생성 autoencoder로 정의할 수 있다. 우리는 이 방법이 어떻게 준지도 학습 시나리오로 확장될 수 있는지 보여주었으며 이 방법이 MNIST, SVHN 데이터셋에 대해 경쟁력있는 준지도 분류 성능을 갖는다는 것을 보여주었다. 최종적으로 우리는 이 AAE 알고리즘이 이미지의 Style과 Content를 구별하고 비지도 클러스터링이나 차원축소, 데이터 시각화 등에도 사용될 수 있다는 것을 증명하였다.


2. 핵심 내용

코드로 넘어가기 전에, AAE의 핵심에 대해서 한 번만 더 짚고 넘어가도록 하겠다.

[\mathcal{L} (\theta, \phi; \mathbf{x}^{(i)}) = -D_{KL} (q_{\phi} (\mathbf{z} \mathbf{x}^{(i)})   p_{\theta} (\mathbf{z}) ) + E_{q_{\phi} (\mathbf{z} \mathbf{x}^{(i)})} [log p_{\theta} (\mathbf{x}^{(i)} \mathbf{z}) ]]

위 식은 이 글에서 상세히 설명하였듯이, VAE의 목적함수이다. 이 때 KL-divergence는 일종의 규제 역할을 하게 되는데, 의미론적으로 보면 아래 두 분포를 유사하게 만드는 역할을 하게 된다.

[q_{\phi} (\mathbf{z} \mathbf{x}^{(i)}), p_{\theta} (\mathbf{z})]

Encoder와 Prior를 유사하게 만들어야 하는데, VAE에서는 이 두 분포에서 Sampling이 가능하고, numerical하게 계산이 가능해야 한다는 전제가 필수적이다. 그렇기 때문에 이들 분포로 정규 분포가 널리 사용되는 것이다.

AAE는 두 전제 중 2번째 전제를 불필요하게 만드는 알고리즘이다. 두 분포를 유사하게 만들기 위해 위와 같은 KL-divergence를 계산하는 것이 아니라 GAN으로 하여금 이 역할을 수행하게 만든 것이다. 즉, GAN이 KL-divergence를 대체한 것이다.

따라서 전체 Loss는 1) Autoencoder의 Reconstruction Loss, 2) Discriminator Loss, 3) Generator Loss 이렇게 3가지로 구분할 수 있다.

이제 Tensorflow로 구현을 시작해보자.


3. Tensorflow로 구현

논문을 자세히 읽어보았다면 이 알고리즘을 구현하는 방법은 꽤 다양하다는 것을 알 수 있을 것이다. 본 글에서는 MNIST 데이터셋을 이용하여 Supervised Adversarial Autoencoder를 구현해보겠다. 본 Chapter에서의 구현 방법은 굉장히 간단하고 분포를 변경하거나 Layer의 구조를 선택하는 등 변형에 있어 다소 불편하게 되어 있는데, 추후 글에서 불균형 학습을 해결하기 위해 AAE를 활용하는 방법에 대해 다루면서 개선된 코드를 제공하도록 하겠다.

autoencoder는 Deterministic하게 구성하였으며, Prior로는 정규 분포를 사용하였다. (논문에서는 Gaussian Mixture와 Swiss Roll 분포를 예로 들었다.) 모델 구조를 살펴보자.

class AAE(tf.keras.Model):
    def __init__(self, latent_dim):
        super(AAE, self).__init__()
        self.latent_dim = latent_dim
        self.encoder = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=(28*28 + 1)),
                tf.keras.layers.Flatten(),
                tf.keras.layers.Dense(units=512, activation='relu'),
                tf.keras.layers.Dropout(rate=0.2),
                tf.keras.layers.Dense(units=256, activation='relu'),
                tf.keras.layers.Dropout(rate=0.2),
                tf.keras.layers.Dense(units=64, activation='relu'),
                tf.keras.layers.Dense(units=latent_dim),
            ])

        self.decoder = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=(latent_dim+1)),
                tf.keras.layers.Dense(units=64, activation='relu'),
                tf.keras.layers.Dropout(rate=0.2),
                tf.keras.layers.Dense(units=256, activation='relu'),
                tf.keras.layers.Dropout(rate=0.2),
                tf.keras.layers.Dense(units=512, activation='relu'),
                tf.keras.layers.Dense(units=784),
            ])

        self.discriminator = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=(latent_dim+1)),
                tf.keras.layers.Dense(units=64, activation='relu'),
                tf.keras.layers.Dropout(rate=0.2),
                tf.keras.layers.Dense(units=128, activation='relu'),
                tf.keras.layers.Dropout(rate=0.2),
                tf.keras.layers.Dense(units=64, activation='relu'),
                tf.keras.layers.Dropout(rate=0.2),
                tf.keras.layers.Dense(units=1),
            ])

    @tf.function
    def encode(self, x, y):
        inputs = tf.concat([x, y], 1)
        z = self.encoder(inputs)
        return z
    
    @tf.function
    def discriminate(self, z, y):
        inputs = tf.concat([z, y], 1)
        output = self.discriminator(inputs)
        return output

    @tf.function
    def decode(self, z, y, apply_sigmoid=False):
        inputs = tf.concat([z, y], 1)
        logits = self.decoder(inputs)
        if apply_sigmoid:
            probs = tf.sigmoid(logits)
            return probs
        return logits

Discriminator의 마지막 Layer는 True와 Fake를 구별하기 위해 1개의 Node로 마무리하였다. Encoder의 결과물인 $\mathbf{z}$ 가 Fake z이며, Prior에서 나온 결과물이 True z에 해당한다.

AAE는 학습이 3단계로 이루어지기 때문에 Loss도 개별적으로 계산하는 것이 좋다. 아래 코드를 보면 이해가 될 것이다.

def compute_reconstruction_loss(x, x_logit):
    # Reconstruction Loss
    marginal_likelihood = tf.reduce_sum(x * tf.math.log(x_logit) + (1 - x) * tf.math.log(1 - x_logit), axis=[1])
    loglikelihood = tf.reduce_mean(marginal_likelihood)
    reconstruction_loss = -loglikelihood
    return reconstruction_loss

def compute_discriminator_loss(fake_output, true_output):
    # Discriminator Loss
    d_loss_true = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=true_output,
                                                                         labels=tf.ones_like(true_output)))
    d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_output,
                                                                         labels=tf.zeros_like(fake_output)))
    discriminator_loss = d_loss_true + d_loss_fake
    return discriminator_loss

def compute_generator_loss(fake_output):
    # Generator Loss
    generator_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_output,
                                                                            labels=tf.ones_like(fake_output)))
    return generator_loss

이 때 x는 실제 데이터 Input을, x_logit은 autoencoder를 통과한 결과물이다. fake_output은 Encoder의 결과물인 z가 Discriminator를 통과한 후의 결과물이며, true_output은 Prior true_z가 Discriminator를 통과한 후의 결과물이다.

학습 함수는 아래와 같다. 이 때 기본적인 GAN의 특성상 Generator의 학습 속도가 느린 점을 반영하여 Generator는 2번 학습하도록 하였다.

@tf.function
def train_step(model, x, y, r_optimizer, d_optimizer, g_optimizer):
    # Results
    x = tf.reshape(x, [-1, 784])
    y = tf.reshape(y, [-1, 1])

    # Propagation
    with tf.GradientTape() as tape:
        z = model.encode(x, y)
        x_logit = model.decode(z, y, True)
        x_logit = tf.clip_by_value(x_logit, 1e-8, 1 - 1e-8)
        reconstruction_loss = compute_reconstruction_loss(x, x_logit)
    r_gradients = tape.gradient(reconstruction_loss, model.trainable_variables)
    r_optimizer.apply_gradients(zip(r_gradients, model.trainable_variables))

    with tf.GradientTape() as tape:
        z = model.encode(x, y)
        true_z = tf.random.normal(shape=(z.shape))
        fake_output = model.discriminate(z, y)
        true_output = model.discriminate(true_z, y)
        discriminator_loss = compute_discriminator_loss(fake_output, true_output)
    d_gradients = tape.gradient(discriminator_loss, model.trainable_variables)
    d_optimizer.apply_gradients(zip(d_gradients, model.trainable_variables))

    for _ in range(2):
        with tf.GradientTape() as tape:
            z = model.encode(x, y)
            fake_output = model.discriminate(z, y)
            generator_loss = compute_generator_loss(fake_output)
        g_gradients = tape.gradient(generator_loss, model.trainable_variables)
        g_optimizer.apply_gradients(zip(g_gradients, model.trainable_variables))

    total_loss = reconstruction_loss + discriminator_loss + generator_loss

    return total_loss

Epoch30, 잠재 차원은 4로하고, Discriminator의 학습률은 다소 낮게 설정하여 학습을 진행하여 보자. (Reference 2 강의 참조)

epochs = 30
latent_dim = 4
model = AAE(latent_dim)

r_optimizer = tf.keras.optimizers.Adam(1e-4)
d_optimizer = tf.keras.optimizers.Adam(1e-4/5)
g_optimizer = tf.keras.optimizers.Adam(1e-4)

# Train
for epoch in range(1, epochs + 1):
    train_losses = []
    for x, y in train_dataset:
        total_loss = train_step(model, x, y, r_optimizer, d_optimizer, g_optimizer)
        train_losses.append(total_loss)

    print('Epoch: {}, Loss: {:.2f}'.format(epoch, np.mean(train_losses)))

Epoch 30 이후의 Loss는 126.8까지 감소하였다. 이전 글에서 보았던 Test 함수를 통해 결과물을 확인해보자. 위쪽 그림이 원본, 아래쪽 그림이 AAE의 결과물이다.

CVAE의 결과보다 더욱 선명한 결과를 보여준 점이 만족스럽다. 튜닝에 더욱 신경을 쓴다면 더욱 좋은 결과가 있지 않을까 생각해본다.


Reference

1) 논문 원본
2) 오토인코더의 모든 것

Comment  Read more

OpenAI GPT-3 - Language Models are Few-Shot Learners(GPT3 논문 설명)

|

이 글에서는 2020년 5월 Tom B. Brown 등이 발표한 OpenAI GPT-3: Language Models are Few-Shot Learners를 살펴보도록 한다.

GPT-3을 이용한 API가 공개되어 있다.

이 논문은 총 페이지수가 75페이지 정도는 된다.. 하지만 일부분을 제외하고 대부분 여기에 적었다.


OpenAI GPT-3 - Language Models are Few-Shot Learners

논문 링크: OpenAI GPT-3 - Language Models are Few-Shot Learners

API: GPT-3을 이용한 API

Github for Data: GPT-3

초록(Abstract)

최근 연구들은, 방대한 텍스트 말뭉치(corpus)로 사전학습(pre-training)한 후 특정 task에 맞춰 미세조정(find-tuning)하는 방법을 통해, 많은 NLP task에서 상당한 발전을 이루었다. (그러나, ) 이는 그 모델의 구조에 있어서는 task 종류에 민감하지 않지만(task-agnostic), (학습) 방법은 여전히 수천, 수만의 예시 데이터를 통해 어느 task에 특화된(task-specific) 미세조정 단계를 요구한다. 이와는 대조적으로, 사람은 일반적으로 단지 몇 개의 예시, 혹은 간단한 지시사항만으로도 (현 NLP 시스템에게는 여전히 많이 어려운) 새로운 언어 task를 수행할 수 있다.
이 논문에서 우리는 언어모델의 크기를 키우는 것이 task에 대한 일반성과(task-agnostic), few-shot 성능을 높이고, 미세조정 접근법을 사용한 이전의 state-of-the-art와도 비등한 성능을 확보할 수 있음을 보인다. 구체적으로, 이전의 그 어떤 비-희박 언어모델보다 10배는 많은 1750억 개의 인자를 가지는 자기회귀(auto-regressive) 언어모델인 GPT-3 를 학습시켜, few-shot 세팅에서 성능을 측정하였다. 모든 종류의 task에 대해, GPT-3는 어떤 gradient의 update나 미세조정을 거치지 않고 오직 few-shot 설명(문제 설명 등)을 취하였다. GPT-3는 번역, 질답(QA), cloze task 등의 많은 NLP 데이터셋과, 문장에서 새로운 단어를 쓰는 단어 해석, 3자리 연산, domain adaptation 등등 수많은 task에서 강력한 수행능력을 얻었다. 그와 동시에, GPT-3가 few-shot 학습을 할 때 여전히 어려워하는 몇몇 데이터셋과 더불어 거대한 웹 데이터셋으로 학습할 때 GPT-3가 방법론적인 문제를 맞이하는 데이터셋(이 무엇이 있는지)을 확인하였다. 마지막으로, GPT-3는 어떤 기사를 사람 또는 기계가 썼는지 판별하는 문제에서 사람이 봐도 어려움을 느낄 정도의 기사를 써낼 수 있음을 보인다. 그리고, GPT-3가 넓게 보아 어떤 사회적 영향력을 가질 수 있는지를 논의한다.


1. 서론(Introduction)

최근 몇 년간 NLP 시스템에서 사전학습된 언어 표현(representations)을 사용하는 추세가 있었고, downstream transfer을 위한 task-agnostic한 방향으로 적용되었다. 먼저, 단어 벡터를 사용하는 단일 레이어 표현을 학습시켜 task-specific한 모델 구조에 입력되고, 더 강력한 표현을 얻기 위해 다중 레이어 표현을 활용하는 RNN과 문맥적 state가 사용되었다(여전히 task-specific한 모델 구조를 가졌음). 그리고 더 최근에는 사전학습된 재귀, 또는 transformer 언어 모델이 task-specific한 모델 구조의 필요성을 제거하고 직접 미세조정하는 방식이 사용되었다.

이러한 최근의 패러다임은 독해, 질답, 원문함의 등등 수많은 어려운 NLP task들에서 상당한 발전을 이루어 냈으며, 새로운 모델구조와 알고리즘에 기반하여 더 많은 진전을 이루었다. 하지만, 이 방법의 큰 한계는 모델구조가 task-agnostic하더라도, 여전히 task-specific한 데이터셋과 task-specific한 미세조정 단계를 필요로 한다는 것이다: 특정 task에서 더 강력한 성능을 위해서는 일반적으로 해당 task에 초점이 맞춰진 수천~수만개의 데이터에 대해서 미세조정을 진행해야 한다. 이러한 한계를 없애는 것은 여러 이유에서 가치가 있다.

  1. 현실적인 관점에서, 새 task마다 레이블링이 전부 되어 있는 큰 데이터셋을 필요로 하는 것은 언어모델의 활용성을 제한한다. 문법교정에서 파생되는 어떤 것이든, 추상 개념의 예시를 생성하는 것과, 짧은 이야기를 비평하는 것 등을 포함하여 광범위한 분야에서 유용한 언어task들이 있다. 이러한 많은 task들에 대해 각각 그에 맞는 큰 규모의 감독학습용 데이터셋을 구하는 것은 어렵다(그 과정이 모든 새로운 task마다 반복되어야 하는 경우에는 특히 더).
  2. 학습 데이터에 존재하는 거짓 상관관계를 활용할 수 있는 가능성이 모델의 표현력과 학습 분포의 협소함에 따라 크게 증가한다. 이는 사전학습과 미세조정 패러다임에 문제를 야기하는데, 모델이 사전학습 동안에 정보를 습득할 수 있도록 큰 크기를 갖게 설계되었지만, 아주 좁은 task 분포에 미세조정(국한)된다. 예를 들어 Pretrained transformers improve out of distribution robustness는 더 큰 모델은 분포 외 데이터를 반드시 더 잘 일반화하지는 않는다. 모델은 학습 (데이터) 분포에 너무 맞춰져 있고 그 밖의 것은 잘 일반화하지 못하기 때문에 사전학습-미세조정 패러다임 하에서는 일반화가 잘 이루어질 수 없다는 증거가 존재한다. 따라서 특정 벤치마크에서 미세조정된 모델의 성능은 해당 부분에서는 인간 수준일지 몰라도 보다 근본적인 task에서는 실제 성능이 과장되었을 수 있다.
  3. 인간은 대부분의 언어 task를 배우기 위해 대규모 감독학습용 데이터셋이 필요하지 않다 - 자연어로 된 간단한 지시문 혹은 아주 적은 수의 예시만 있어도 어떤 사람이 새로운 task를 충분히 능숙하게 수행하도록 만들 수 있다(예: 이 문장이 기쁜 혹은 슬픈 무언가를 말하는지 선택하라. 또는, 여기 용감한 행동을 하는 사람 예시 2개가 있다. 용감한 행동의 세 번째 예시를 들라). 현재 NLP 기술에서 이런 개념적인 한계를 제쳐놓더라도, 이러한 적응 능력은 현실적으로 이점이 있다 - 이는 사람을 균일하게 여러 task와 기술들을 섞거나 전환하게 할 수 있다(긴 대화문에 무언가 더 추가하는 것 등). 더 넓은 곳에서 유용하게 쓰이려면, NLP 시스템을 (사람만큼) 유동적이고 일반성을 갖도록 할 것 이다.

이러한 문제들을 다루는 가능성 있는 방법은 언어모델의 문맥에서, 모델이 학습하는 동안 여러 기술과 패턴인식 능력을 키우고, 추론 시간에는 이를 원하는 task에 빠르게 적용시키거나 인식시키는 방법인 meta-learning이다.

Examples

무감독 사전학습 동안, 언어모델은 여러 기술들과 패턴인식 능력을 키워 이를 추론 시간에 사용한다. 각 sequence에 대해 forward-pass 안에서 일어나는 내부 반복 과정을 문맥 내 학습 이라고 부른다. 이 다이어그램에서 문장들은 사전하습 동안 모델이 데이터 표현을 볼 수 있게 하지는 않지만, 모델은 어떤 하위 작업들이 한 개의 sequence 내에서 일어난다는 사실은 알 수 있다.

우리가 문맥 내 학습이라고 부르는 이것을 통해 시도하려는 최근 연구는 사전학습된 언어모델의 텍스트 입력을 task specification의 형태로 사용한다: 이 모델은 자연어 지시문과 task 설명이라는 조건 속에서 ‘다음에 무엇이 올 것인지를 예측’한다.

이 방법은 처음에는 가능성을 보였지만, 여전히 미세조정에 비하면 갈 길이 멀다 - 예를 들어 Language models are unsupervised multitask learners 는 Natural Questions에서 4%만을, 55 F1 CoQa를 달성하였는데 이는 최신 결과보다 35점이나 뒤떨어진 결과이다. Meta-learning은 명백히 언어 task를 푸는 현실적인 방법으로서 실행 가능하기 위한 상당한 개선을 요구한다.

언어모델링의 다른 최신 경향은 앞으로 갈 길을 제시할 수 있다. 최근 몇 년간 Transformer 언어 모델의 크기(parameter의 수)는 1억 개부터 3, 15, 80, 110, 170억 개까지 증가하였다. 크기가 증가할 때마다 텍스트 합성과 downstream NLP task에서 상당한 성능 개선을 보여주었고, 이러한 log loss는 많은 downstream task와 관련하여 scale에 따라 개선되는 경향이 뚜렷하다. 문맥 내 학습이 모델의 parameter 안에서 많은 기술과 task를 습득하기 때문에, 문맥 내 학습 능력은 scale에 따라 그 능력이 더 증가한다고 보는 것은 설득력이 있다.

이 논문에서, 우리는 1750억 개의 parameter를 가지는 자기회귀 언어모델(GPT-3)을 학습함으로서 이 가설을 테스트하고, 그 문맥 내 학습능력을 측정한다. 구체적으로, GPT-3을 학습셋에 직접 포함되어 있지 않은 task에 대해 빠르게 적응할 수 있는지를 테스트하도록 고안된 여러 최신 task를 포함하여 20개 이상의 NLP 데이터셋에 대해 평가를 진행한다. 각 task에 대해, GPT-3를 3가지 조건에서 평가한다:

  • few-shot learning, 혹은 모델의 10/100개의 문맥창(context window)에 맞는 설명 또는 예시(demonstration)을 허용하는 문맥 내 학습 조건,
  • one-shot learning, 딱 한 개의 예시만을 허용하는 조건,
  • zero-shot learning, 어떤 예시도 허용되지 않고, 모델에 주어지는 것은 오직 자연어로 된 지시문인 조건.

GPT-3은 전통적인 미세조정 조건에서 평가할 수도 있지만, 이는 추후 연구로 남겨둔다.

Examples

위 그림은 우리가 연구한 조건에서, 모델이 단어에서 관련 없는 기호를 제거하도록 하는 task에서 few-shot learning 결과를 보여준다. 모델 성능은 자연어 지시문이 포함되면, 모델 문맥에 주어지는 예시의 수($K$)가 증가하면 높아진다. Few-shot learning 성능은 모델 크기에 따라서도 크게 증가한다. 모델의 크기와 문맥 내 예시의 수와 관련한 일반적인 경향은 우리가 연구하는 대부분의 task에 대해서도 성립한다. 우리는 이러한 “학습” 곡선은 어떤 가중치 업데이트나 미세조정을 거치지 않았으며, 단지 조건으로 주어지는 예시의 수를 늘렸을 뿐이다.

대략, NLP task들에서 GPT-3은 zero-shot과 one-shot 조건에서 훌륭한 결과를, few-shot 조건에서도 SOTA와 비슷하거나 경우에 따라서는 넘어서는 결과를 보여주었다. 예로, GPT-3은 CoQA, zero-shot에서 81.5 F1을(심지어 기존 SOTA는 미세조정 모델이다), few-shot에서는 85.0 F1을 달성했다. 비슷하게, TriviaQA에서는 zero-shot에서는 64.3%, one-shot에서는 68.0%, few-shot에서는 71.2%로, few-shot의 경우는 같은 closed-book 세팅에서 SOTA를 달성한 미세조정 모델의 것과 같다.

GPT-3은 또한 단어해독(순서 맞추기), 연산 수행, 정의된 것을 단 한 번만 보고서 문장에서 새로운 단어를 사용하는 등 즉석에서 추론하는 task와 빠른 적응력을 측정하는 task들에서 one-shot과 few-shot에서 숙련된 결과를 내놓음을 보여주었다. 또한 few-shot 세팅에서, GPT-3은 사람이 보기에도 주어진 기사가 인간 혹은 기계가 썼는지 분간하기 어려운 기사를 생성해낼 수 있다.

그와 동시에, GPT-3이 few-shot에서 어려움을 겪는 몇몇 task를 확인하였다. 여기에는 자연어 추론문제인 ANLI, 독해 데이터셋인 RACE와 QuAC 등이 포함된다. 이러한 한계를 포함하여 GPT-3의 장단점을 보여줌으로써, 우리는 언어모델에서 few-shot learning의 연구를 촉진하고 어떤 개선이 가장 필요한지 관심을 모을 수 있을 것이다.

전체 결과의 느낌은 아래 그림에서 볼 수 있다. zero-shot 성능은 모델 사이즈에 따라 천천히 증가하는 것에 비해, few-shot은 더 가파르게 증가하며, 더 큰 모델일수록 문맥 내 학습에서 월등함을 보여준다. 논문에서 그림 3.8을 보면 SuperGLUE에서 더 자세한 분석을 볼 수 있다.

Examples

또, 데이터 오염(학습 데이터셋과 테스트 데이터셋이 겹치는 문제)에 대해서도 체계적으로 연구했다 - Common Crawl 등을 통해 얻은 데이터셋에서 거대한 모델을 학습시킬 때 생기는 문제로, 웹에서 모은 데이터가 있기 때문에 가질 수 있는 문제이다(즉, train/test set이 우연히 겹치는 부분이 적지 않을 수 있다). 이 논문에서 데이터 오염과 그 왜곡 효과를 측정하는 체계적 도구를 개발했다. GPT-3의 성능은 대부분의 데이터셋에서 데이터 오염에 미미한 영향만을 받았지만, 우리는 약간의 데이터셋에서 오염이 충분히 큰 영향을 가질 수 있음을 보이고, 또 그 심각도에 따라 그러한 데이터셋에는 별표(*)를 하여 결과에 포함하지 않았다.

위의 모든 것에 더하여, 우리는 zero, one, few-shot 세팅에서 GPT-3의 성능을 비교하기 위해 1.25억~130억 개의 parameter를 가지는 작은(?) 모델을 학습시켰다. 폭넓게, 대부분의 데이터셋에서 모든 3가지 조건에서 상대적 smooth scaling을 찾았다: 주목할 만한 패턴은 zero, one, few-shot 성능은 종종 모델 크기에 따라 증가하며, 이는 더 큰 모델은 더 능숙한 mera-learner임을 시사한다.

마지막으로, GPT-3에 의해 보여진 광범위한 역량에서, 우리는 편향성, 공정성, 나아가 사화적 영향력과, 에 점에서 GPT-3의 특징에 대한 예비 분석을 시도하고 논의할 것이다.

이 논문의 남은 부분은 다음과 같이 구성된다. Section 2에서는 GPT-3을 학습시키고 평가하는 접근법과 방법을 소개한다. Section 3에서는 zero, one, few-shot 세팅에서 전체 범위의 task에 대한 결과를 보여준다. Section 4에서는 데이터 오염에 대한 문제를, Section 5에서는 GPT-3의 한계에 대해 논한다. Section 6에서는 GPT-3의 영향력을, Section 7에서는 관련 연구를 보고 Section 8에서는 결론을 다룬다.


2. 접근법(Approach)

모델, 데이터, 학습 등 기본적인 접근법은 Language models are unsupervised multitask learners와 비슷하지만, 모델의 크기를 키웠고, 데이터셋의 크기와 다양성, 학습량을 전부 늘렸다. 문맥 내 학습도 위 논문(GPT-2)와 비슷하지만, 이 논문에서는 문맥 내 학습을 위해 세팅을 다르게 하는 체계적인 방법을 보인다. 그래서, GPT-3을 평가하거나 원칙적으로 평가 가능할 수 있게 하는 여러 세팅들을 정의하고 대조하는 것으로 이 section을 시작한다. 이러한 세팅은 task-specific한 데이터에 얼마나 의존하려는 경향이 있는지를 보는 것이라 할 수 있다. 구체적으로, 4가지로 나누어 볼 수 있다:

  1. 미세조정(Fine-Tuning, FT)은 최근에 가장 일반적인 접근법으로, 사전학습된 모델을 원하는 task에 맞도록 감독학습 데이터셋으로 학습시키는 과정을 포함한다. 보통 수천~수만 개의 레이블링된 예시를 필요로 한다.
    • 이러한 미세조정(fine-tuning)의 주된 장점은 많은 벤치마크에서 강력한 성능을 가지는 것이다.
    • 주된 단점은 모든 task마다 큰 데이터셋을 새로이 필요로 하며, 분포 외의 데이터에 대해서는 일반화를 잘 못하며, 학습 데이터에 거짓/비논리적인 특성이 있는 경우 이를 흡수할 수도, 사람에 비해 불공정한 비교로 이어질 수도 있다.
    • 이 논문에서는 task-agnostic한 성능을 가지는 것이 목적이기 때문에 GPT-3은 미세조정을 진행하지 않는다. 단, 나중에는 추후 연구로 괜찮은 방향이기에 미세조정을 사용할 수도 있다.
  2. Few-Shot(FS)은 모델이 추론 시간에서 단 몇 개의 예시만을 볼 수 있되 가중치 업데이트는 허용되지 않는 조건이다. 그림 2.1에서 보듯이, 일반적인 데이터셋에서 예시는 문맥과 원하는 답이 있고(예로는 영어-독일어 번역), few-shot은 단 $K$개의 문맥과 답이 주어진다. 이후 마지막으로 단 한 개의 문맥이 주어지면, 모델은 (정확한) 답을 생성해 내야 한다.
    • 우리는 보통 $K$는 10~100 정도로 설정했고, 이는 모델의 문맥창($n_{ctx} = 2048$)에 잘 맞을만한 개수이다.
    • few-shot의 주된 장점은 task-specific한 데이터에 대한 필요를 크게 줄여주며(즉, 몇 개 없어도 됨) 지나치게 크고 좁은 분포를 갖는 미세조정용 데이터셋을 학습할 가능성을 줄일 수 있다.
    • 주된 단점은 이 방법은 미세조정 모델의 SOTA에는 한참 뒤떨어지는 성능을 갖는다는 점이다. 또한, 적은 수라 해도 여전히 task-specific한 데이터를 필요로 한다.
    • 이름에서 알 수 있듯이, 언어 모델에서 few-shot learning은 기계학습에서 다른 문맥에서 사용된 few-shot learning과 연관이 있다 - 둘 다 넓은 분포를 갖는 task에 기반한 학습 방법이며(이 경우에는 사전학습 데이터에서) 새로운 task에 빠르게 적응하는 방법이다.
  3. One-Shot(IS)은 few-shot과 비슷하나 단 한 개의 예시와, task에 대한 자연어 지시문이 제공된다는 점이 다르다. one-shot이 few나 zero-shot과 다른 점은 이 방법이 사람이 소통하는 방법과 가장 흡사한 방법이기 때문이다.
    • 예를 들어, Mechanical Turk와 같이, 사람에게 데이터셋을 만들어내라는 요청을 할 경우, 보통 task에 대해 하나의 예시를 주게 된다(물론 지시문과 함께). 이와 대조적으로, 예시가 아예 없다면 task의 내용이나 형식에 대해 소통하는 것이 어려울 수 있다.
  4. Zero-Shot(0S)은 one-shot과 비슷하지만 단 하나의 예시도 없으며, 모델은 단지 task에 대한 지시문만을 받는다.
    • 이 방법은 최대의 편의성을 갖는데, robustness나, 거짓 상관관계 등을 걱정할 필요가 없다.
    • 단, 가장 어려운 조건이다.
    • 어떤 경우에는 사람조차도 예시가 없으면 task에 대해 제대로 이해하지 못할 수도 있고, 따라서 이 조건은 “불공정할 정도로 어렵다”고 할 수 있다.
    • 예를 들어, 누군가 ‘200m 달리기를 위한 세계기록 표를 만들라’고 한다면, 이 요청은 상당히 모호할 수 있는데, 표가 어떤 형식을 가져야 하고, 어떤 내용이 들어가야 하는지에 대한 명확한 설명이 없기 때문이다.
    • 그럼에도 불구하고, 적어도 zero-shot의 조건은 사람이 task를 수행하는 것과 가장 가까운 방식이다 - 예로, 아래 그림에서 사람은 단지 텍스트 지시문만을 보고도 무엇을 해야 할지 알 수 있을 것이다.
Examples

위 그림은 영어-독일어 번역 예시에 대한 4가지 방법을 보여준다. 이 논문에서는 zero, one, few-shot에 집중하는데, 경쟁 상대로 보는 것이 아닌 비교 상대로 보기 위함이며, 다른 문제 세팅에서 특정 벤치마크에서의 성능과 sample의 효율성 사이에서 균형을 찾는다. 특히 few-shot 결과는 미세조정 모델보다 아주 약간 못한 결과를 보임을 강조한다. 궁극적으로, 하지만, 사람이 하는 것과 거의 비슷한 one-shot이나 zero-shot에서는 추후 연구의 중요한 목표로 둔다.

Section 2.1~2.3에서는 모델, 학습 데이터, 학습 과정을 자세히 설명한다. Section 2.4에서는 어떻게 few, one, zero-shot 평가를 진행했는지를 자세히 말한다.

2.1 Model and Architectures

모델 초기화, 사전 정규화, 내부 서술된 되돌릴 수 있는 토큰화 과정을 포함하여 GPT-2와 동일한 구조를 갖지만, Sparse Transformer와 비슷하게, Transformer 레이어 내에서 밀집/희박한 국소 집중 패턴을 번갈아 사용하였다. 모델 크기에 따른 기계학습 성능의 의존도를 살펴보기 위해, 1.25억 개부터 1750억 개의 parameter를 가지는 8가지 다른 크기의 모델을 학습시켰고, 가장 큰 마지막 것은 GPT-3라 부르는 모델이다. 이전 연구에서 충분한 학습 데이터를 갖고 있으면 validation loss는 크기에 대한 함수로 부드러운 멱법칙을 따를 것이라 하였다; 여러 다른 크기의 학습 모델은 validation loss와 downstream 언어 과제들에 대한 가설을 모두 검증할 수 있게 해 준다.

Examples

위 표는 8가지 모델의 크기와 구조를 보여준다. $n_{params}$는 학습가능한 parameter의 전체 개수, $n_{layers}$는 레이어 수, $d_{model}$은 각 bottleneck 레이어 안에 있는 unit의 수(이 논문에서, 항상 $d_{ff} = 4 \times d_{model}$이다), $d_{head}$는 각 attention head의 차원이다. 모든 모델은 $n_{ctx} = 2048$ 토큰을 가진다.
각 GPU 노드 등 데이터 이동을 줄이기 위해 깊이와 너비에 따라 여러 GPU에 모델을 나누었다. 또한 각 모델의 정확한 parameter 수는 계산효율성과 GPU의 부하 균형에 맞게 정했다. 이전 연구는 validation loss는 합리적인 범위 내에서는 parameter의 작은 차이에 크게 민감하지 않음을 시사한다.

2.2 Training Dataset

언어모델을 위한 데이터셋은 빠르게 확장되어 거의 1조 개의 단어로 구성된 Common Crawl 데이터셋에서 정점을 찍고 있다. 데이터셋의 이 크기는 같은 데이터를 두 번 쓰지 않아도 이 논문의 가장 큰 모델을 학습시키에도 충분하다. 하지만, 필터링을 전혀 또는 거의 거치지 않은 Common Crawl 데이터는 조정된 데이터셋에 비해 낮은 품질을 갖는 경향이 있다. 따라서, 데이터셋의 품질을 높이기 위한 3가지 방법이 사용되었다:

  1. 고품질 출처와 연관성이 있는 것만을 받아 정제하고
  2. 과적합을 정확히 측정하기 위한 온전성을 남겨두고 중복을 피하기 위해 문서 수준에서 중복 제거 작업을 수행하였고,
  3. 다양성을 증가시키기 위해 고품질 출처로 알려진 말뭉치를 추가하고 섞어서 사용하였다.

자세한 것은 부록 A를 참조하라. 추가 데이터셋으로는 WebText, Books1와 Books2, 영어 위키피디아가 있다.

아래 표는 혼합된 데이터셋 구성을 보여준다. CommonCrawl의 경우 45TB의 데이터셋을 정제하여 570GB로 만들었다(4천억 개의 byte pair encoded 토큰으로 구성됨). 학습에서, 데이터의 사용은 데이터셋의 크기에 비례하지 않고, 고품질일수록 많이 선택되었다. 이는 고품질과 과적합 사이의 trade-off가 있는 것이다.

Examples

인터넷에서 가져온 데이터로 사전학습한 언어모델에서 가장 큰 방법론적 문제는, 특히 굉장한 양의 내용을 기억하려는 큰 모델에서, 사전학습 동안 무심코 본 정보를 test나 dev set에서 다시금 마주하게 되는 (데이터) 오염 문제이다. 이러한 오염을 줄이기 위해서, 논문에서 살펴보는 모든 벤치마크의 dev/test set와 겹치는 어떤 부분이든 제거하려는 노력을 하였다.
안타깝게도, 일부 겹치는 부분을 무시하는 버그기 필터링 과정에서 있었고, 학습의 비용 문제로 인해 다시 모델을 학습하는 것은 비현실적이었다. 그래서 이 영향을 Section 4에서 살펴보고, 데이터 오염을 더욱 공격적으로 제거하는 추후 연구를 할 것이다.

2.3 Training Process

일반적으로 큰 모델일수록 더 큰 batch를 쓰지만, learning rate는 더 작게 해야 한다. 학습하는 동안 gradient noise scale을 측정하고 이를 batch size를 선택하는 가이드로 사용하였다. Table 2.1에서 사용한 parameter setting을 보여 준다. 메모리 부족 없이 더 큰 모델을 학습시키기 위해, 각 행렬곱 내에서 모델 병렬화와 네트워크의 레이어에서 모델 병렬화를 섞어 사용하였다. 모든 모델은 고대역 클러스터의 일부분으로 V100 GPU에서 학습되었다. 자세한 학습 과정은 부록 B에 있다.

2.4 Evaluation

few-shot learning에서, evaluation set의 각 예제에 대해 training set에서 조건으로 $K$개의 샘플을 뽑아 평가하였다(task에 따라 1~2개의 개행문자를 구분자로 사용함). LAMBADA와 StoryCloze에는 감독학습 셋이 없으므로 train-eval set 대신 dev-test set을 사용했다. Original Winograd는 set이 하나뿐이므로 그냥 같은 곳에서 뽑았다.

$K$는 모델의 context window이 허용하는 범위에 따라 $0 \sim \infty$가 될 수 있는데, 모든 모델에 대해 $n_{ctx}=2048$이고 보통 $10 \sim 100$개의 샘플에 맞는다. 큰 $K$가 항상 좋지는 않기 때문에 분리된 dev/test set이 있다면 적은 $K$를 dev set에 사용하여 최고의 값을 test set에서 사용했다. 일부(부록 G)에 대해서는 demonstration에 더해 자연어 지시(prompt)를 사용했다.

각 모델 크기와 학습 세팅(zero, one, few-shot)에 따라 test set에서의 최종 결과가 나와 있다. test set이 비공개인 경우에는, 테스트를 위해 모델을 올리는 것은 모델이 너무 커서 불가능하기에 dev set으로 결과를 얻었다. 적은 수의 데이터셋에 대해서는 test 서버에 제출했고 단 200B개의 few-shot 결과만 제출하였으며, 다른 모든 것에 대해 dev set 결과를 얻었다.


3. 결과(Results)

아래 그림은 10만 개 정도의 parameter를 갖는 작은 모델을 포함한 결과를 보여 준다.

Examples

언어모델링 성능은 학습 계산량의 효율에 따라 지수적으로 증가한다. 누군가는 cross-entropy loss에서 이러한 발전이 단지 겉만 그럴싸한 학습 말뭉치의 작은 차이로 인한 것이라 생각할 수 있지만, 그런 것이 아닌 일관된 개선을 의미함을 보일 수 있다.

광범위한 데이터셋에 대해 Section 2에서 언급된 8개의 모델(GPT-3과 작은 모델들)을 테스트하였고, 비슷한 데이터셋끼리 묶어 9개로 나누었다.

  • Section 3.1에서는 전통적인 언어모델링 task와 비슷한 것들(Cloze 등), 문장/문단 완성 task에 대해 평가하였다.
  • Section 3.2에서는 “closed book” QA task를,
  • Section 3.3에서는 언어 간 번역능력을,
  • Section 3.4에서는 Winograd Schema와 같은 task를,
  • Section 3.5에서는 상식추론/질답을,
  • Section 3.6에서는 독해를,
  • Section 3.7에서는 SuperGLUE를,
  • Section 3.8에서는 자연어추론(NLI)를,
  • Section 3.9에서는 즉석 추론, 적응 기술, open-ended 텍스트 합성 등 문맥 내 학습능력을 탐색하는 추가적인 task를 개발하였다.

3.1. Language Modeling, Cloze, and Completion Tasks

전통적인 언어모델링 task에 대해 GPT-3의 성능을 측정하였다. 흥미있는(of interest) 한 개의 단어를 예측하거나, 문장/문단을 완성하거나, 텍스트를 완성시키는 가능한 것들 중 하나를 선택하는 task이다.

3.1.1 Language Modeling

Penn Tree Bank(PTB)에 대해 zero-shot perplexity를 계산했으나, Wikipedia와 연관된 4개의 task는 학습데이터에 포함된 부분이 있기 때문에 결과에서 생략하였다.
이전 SOTA보다 15 point 앞서는 20.50 Perplexity를 기록하였다. 여기서는 데이터셋의 명확한 구분이 없기 때문에 zero-shot만 테스트했다.

Penn Tree Bank
Penn Tree Bank는 말뭉치 주석(corpus annotation) 중 구문 주석(syntactic annotation) 말뭉치의 일종으로, 기존의 구조 분석보다 정교한 tree structure의 집합이다. 330만 어절 이상의 월스트리트 저널(Wall Street Journal (WSJ))의 문장들로 이루어져 있으며 공개되어 있는 데이터셋이다. Treebank-3은 1999년에 나왔으며 2499개의 story부터 만들어진 98732개의 syntactic annotation story를 포함한다.

3.1.2 LAMBADA

LAMBADA dataset은 텍스트에서 장거리 의존성 모델링을 테스트한다 - 모델은 문맥 문단을 읽어 문장의 마지막 단어를 예측해야 한다. 최근 언어모델의 크기를 키우는 것은 이 어려운 벤치마크에 대한 성능을 경감시킨다는 것이 제안되어 왔다. 모델을 두 배 키워도 1.5% 정도의 향상만이 있었다.
그러나 큰 모델은 여전히 유효한 연구 방향이다. GPT-3은 이전 SOTA보다 8% 향상된 76%의 정확도를 보였다.

LAMBADA는 이 데이터셋에서 흔히 나타나는 문제를 다루는 방법으로 few-show learning의 유연성을 설명해줄 수도 있다. LAMBADA 문장의 완성은 언제나 문장의 마지막 단어이지만, 표준 언어모델은 이 부분을 알 수 없다. 이 문제는 과거에 stop-word로 다루어져 왔다. Few-shot learning 세팅은 대신 이 task를 cloze-test처럼 “frame”화 하여 언어모델이 딱 하나의 단어가 필요함을 알도록 할 수 있다. 여기서 ‘빈칸 채우기’ 형식을 활용했다.

Alice was friends with Bob. Alice went to visit her friend _____. -> Bob
George bought some baseball equipment, a ball, a glove, and a _____. -> 

이렇게 했을 때 GPT-3은 86.4%의 정확도를 보여 이전보다 18% 향상된 결과를 얻었다. 여기서 few-show 성능은 모델의 크기에 따라 크게 향상될 수 있다는 것을 알 수 있다.

LAMBADA dataset
LAMBADA dataset(LAnguage Modeling Broadened to Account for Discourse Aspects)은 단어 예측 task로서 계산모델이 텍스트를 이해했는지를 판별할 수 있는 dataset이다. 이 데이터셋은 전체 문맥이 주어졌을 때 마지막 단어가 무엇일지를 사람이 맞추어 생성된 서술형 구절들로 이루어져 있다. 계산모델은 여기서 단지 지역적인 문맥뿐 아니라 더 넓은 범위의 담화에서 정보들을 얻어 사용할 수 있어야 한다.

예시:

Context: “Yes, I thought I was going to lose the baby.” “I was scared too,” he stated, sincerity flooding his eyes. “You were ?” “Yes, of course. Why do you even ask?” “This baby wasn’t exactly planned for.”
Target sentence: “Do you honestly think that I would want you to have a ?”
Target word: miscarriage

3.1.3 HellaSwag

HellaSwag dataset은 어떤 이야기나 지시문 집합의 끝맺음 문장으로 어느 것이 가장 좋을지를 선택하는 문제를 다룬다. 사람에게도 살짝 어려운 문제이지만(95.6% 정확도), GPT-3은 78.1%(one-shot), 79.3%(few-shot)을 달성하며 종전의 미세조정된 15억 개의 parameter를 가진 모델(75.4%)를 뛰어넘었다. 그러나 여전히 미세조정된 multi-task 모델인 ALUM(85.6%)에 비하면 낮은 점수이다.

HellaSwag
HellaSwag dataset은 task를 다루는 데 있어 상식(commonsense)가 필요하다. video caption인 ActivityNet Captions dataset에서의 데이터만 사용한다(original SWAG dataset은 LSMDC의 caption 데이터도 포함한다). 시간정보를 포함하는 서술(temporal description)과 각 caption에 대한 activity label을 포함한다.

예시:

Pick the best ending to the context. How to catch dragonflies. Use a long-handled aerial net with a wide opening. Select an aerial net that is 18 inches (46 cm) in diameter or larger. Look for one with a nice long handle.
a) Loop 1 piece of ribbon over the handle. Place the hose or hose on your net and tie the string securely.
b) Reach up into the net with your feet. Move your body and head forward when you lift up your feet.
c) If possible, choose a dark-colored net over a light one. Darker nets are more difficult for dragonflies to see, making the net more difficult to avoid.
d) If it’s not strong enough for you to handle, use a hand held net with one end shorter than the other. The net should have holes in the bottom of the net.

3.1.4 StoryCloze

StoryCloze 2016 dataset에서는 few-shot에서 종전 기록보다 4.1% 낮은 87.7%을 기록하였으나, zero-shot에서는 거의 10%가량 향상되었다(83.2%).

StoryCloze 2016 dataset
StoryCloze 2016 dataset은 5문장의 긴 story에서 가장 적절한 끝맺음 문장을 선택하는 문제로, 3744개의 test set을 보유하고 있다.
Context: Karen was assigned a roommate her first year of college. Her roommate asked her to go to a nearby city for a concert. Karen agreed happily. The show was absolutely exhilarating.
Right Ending: Karen became good friends with her roommate.
Wrong Ending: Karen hated her roommate.

3.2. Closed Book Question Answering

이 Section에서는 광범위한 사실적 지식(broad factual knowledge)에 대한 QA 능력을 측정한다. 가능한 질의가 방대하기 때문에, 이 task는 보통 연관된 텍스트를 찾는 정보 검색 시스템에 더해, 질문과 검색한 텍스트가 주어지면 답변을 생성하는 모델을 함께 사용하는 접근법을 사용해 왔다. 이러한 세팅은 “open-book” 과 같은 방식으로 쓸 수 있다. 최근에는 보조적인 정보라는 조건 없이도 질문에 대한 답변을 잘 생성하는 충분히 큰 모델이 제안되었다. 이러한 방식은 “closed book”으로 불린다.
여기서는 GPT-3을 Natural Questions, WebQuestions, TriviaQA 3개에 대해 측정하였다. 기존의 closed-book 세팅에 더해 few, one, zero-shot 평가를 진행(더 엄격한 조건)하였다.

아래에 결과가 있다.

  • TriviaQA: zero-shot에서 64.3%, one-shot에서 68.0%, few-shot에서 71.2%를 달성하였다. zero-shot 결과는 T5-11B를 14.2% 차이로 능가하는 성능을 보여 주었다. one-shot에서도 3.7%의 차이로 SOTA를 제치는 등의 결과를 얻었다.
  • WebQuestions(WebQs): fine-tune 모델에 비하면 조금 낮지만 비슷한 수준의 성능을 보여준다.
  • Natural Questions(NQs): WebQs에서와 비슷하게 zero~few-shot에서의 큰 발전은 분포의 이동을 제안할 수 있으며, TriviaQA나 WebQS에 비해 더 낮은 경쟁력을 보여주는 것을 설명할 수 있다. 특히, NQs의 질문은 Wikipedia에서 아주 fine-grained한 지식을 물어보기에 특히 이는 GPT-3의 용량과 광범위 사전학습 분포의 한계를 테스트해볼 수 있다.
Examples

TriviaQA
TriviaQA: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension은 650k개의 질문-답변-증거(인용구) triple를 포함하는 독해를 위한 데이터셋이다. 95k개의 질문-답변 쌍을 포함한다. 답변을 하기 위해서는 문장 여럿을 살펴봐야 한다.
Question: The Dodecanese Campaign of WWII that was an attempt by the Allied forces to capture islands in the Aegean Sea was the inspiration for which acclaimed 1961 commando film?
Answer: The Guns of Navarone
Excerpt: The Dodecanese Campaign of World War II was an attempt by Allied forces to capture the Italianheld Dodecanese islands in the Aegean Sea following the surrender of Italy in September 1943, and use them as bases against the German-controlled Balkans. The failed campaign, and in particular the Battle of Leros, inspired the 1957 novel The Guns of Navarone and the successful 1961 movie of the same name.

WebQuestions(WebQs)
Freebase를 사용하여 대답할 수 있는 6642개의 질문-답변 데이터셋이다. 2013년까지 웹에서 자주 질문이 이루어진 것들로 이루어져 있다.
targetValue: Jazmyn Bieber / Jaxon Bieber
utterance: What is the name of justin bieber brother?

Natural Questions(NQs)
구글 검색엔진에서 모아진 익명으로 이루어진 질의들로, 5개의 상위 결과에서 Wikipedia와 연관한 질문, 일반적으로 문단 수준인 긴 답변, 1개 또는 그 이상의 객체로 이루어진 짧은 답변으로 구성되며, 없을 경우 None으로 표시된다. 307k/7.8k/7.8k개의 train/dev/test 데이터가 있다.
Question: where is blood pumped after it leaves the right ventricle?
Short Answer: None
Long Answer: From the right ventricle , blood is pumped through the semilunar pulmonary valve into the left and right main pulmonary arteries ( one for each lung ) , which branch into smaller pulmonary arteries that spread throughout the lungs.

3.3. Translation

GPT-2에서는 용량 문제 때문에 영어만 존재하는 데이터셋을 만들기 위한 필터를 사용하였다. 그럼에도 다언어 역량을 가질 수 있음을 보였는데, GPT-3에서는 훨씬 더 커진 크기 덕분에 실제로 여러 언어에 대한 표현(representation)을 얻을 수 있게 되었다(물론, 연구로 더 많은 향상이 있을 수 있다). GPT-3 학습을 위해 사용된 Common Crawl 데이터는 사실 93%가 영어이지만 7%는 다른 언어를 포함한다(자세한 것은 여기서 볼 수 있음). 더 나은 번역 능력을 얻기 위해 분석을 영어 이외에도 널리 연구된 독일어와 루마니아어(Romanian)에 대해 수행했다.

기존에는 두 언어를 연결하기 위해 단일 언어 데이터셋의 쌍을 back-translation으로 사전학습을 결합시켜 사용했다. 이에 반해 GPT-3은 여러 언어들이 단어, 문장, 문서 단위로 자연스레 섞인 데이터를 그냥 학습하였다. 또한 어느 특별한 문제만을 위해 특별 제작되지 않았다.

결과는 기존 NMT 모델에 비해 좋지 않지만, 단 한 개의 예시만이 주어지는 task에서는 7 BLEU score만큼 향상되었으며 이전 연구와 거의 근접한 성능을 보여준다. 아래 그림들에서 자세한 결과를 볼 수 있다.

Examples

GPT-3은 다른 무감독 NMT에 비해 다른 언어 → 영어로의 번역은 아주 잘 하지만 그 반대는 상당히 성능이 낮다. 이는 GPT-2의 byte-level BPE tokenizer가 거의 영어에 맞춰져 있기 때문으로 보인다.

그리고, 여전히 모델이 커질수록 성능이 증가함에는 변함이 없다.

3.4. Winograd-Style Tasks

Winograd Schemas Challenge는 자연어 문장에서 특정 대명사가 어떤 대상을 지칭하는지를 판별하는 고전적인 자연어처리 문제로, 문법적으로는 (답이) 모호하지만 의미적으로는 (사람에게는) 명확한 문제이다. 최근에는 미세조정된 언어모델이 거의 사람 수준의 성능을 보였지만 더 어려운 버전인 Winogrande dataset에서는 여전히 사람에 비해 크게 뒤떨어지는 결과를 보였다.

GPT-3은 273개의 Winograd Schemas의 원래 세트에 대해 테스트를 진행하였으며, GPT-2에서 사용된 “partial evaluation” 방법을 사용했다. 이 세팅은 SuperGLUE benchmark와는 조금 다른데, 이진 분류로 표현되며 객체 추출이 필요하다.
Winograd에서, GPT-3은 zero/one/few-shot에서 각각 88.5%, 89.7%, 88.6%의 성능을 보였으며, 문맥 내 학습에서 크게 명확하진 않지만 모든 경우에서 SOTA 및 사람보다 조금의 차이밖에 나지 않는 결과를 얻었다.
여기서 약간의 데이터 오염 문제가 있었지만, 결과에 미친 영향은 미미하다.

더 어려운 버전인 Winogrande dataset에서는, zero/one/few-shot에서 각각 70.2%, 73.2%, 77.7%의 정확도를 보였다. 비교를 하자면, 미세조정된 RoBERTa가 79%를, SOTA가 84.6%(T5), 사람이 94.0%이다.

Examples

Winograd Schemas Challenge
Winograd Schemas Challenge은 튜링 테스트의 약점을 보완하고자 나온 데이터셋으로 어떤 대상이 어떠한지를 물을 때 적절한 대상을 찾는 문제이다. The trophy doesn’t fit in the brown suitcase because it’s too big. What is too big?
Answer 0: the trophy
Answer 1: the suitcase
Joan made sure to thank Susan for all the help she had given. Who had given the help?
Answer 0: Joan
Answer 1: Susan

Winogrande dataset
Winogrande dataset은 44k개의 문제를 포함하여 기존 WSC보다 더 어렵고 규모가 큰 데이터셋이다. 언어적으로 편향되어 있기 때문에(어떤 단어는 특정 단어들과 같이 나올 확률이 높은 등) 언어모델이 쉽게 판별할 수 있는 질문들은 제거되었다.

Examples

3.5. Common Sense Reasoning

이제 문장 완성, 독해, 광범위한 지식 질답과는 다른 물리학적/과학적 추론을 다루고자 한다. PhysicalQA(PIQA)는 어떻게 물리학적으로 세계가 움젝이고 세상에 대한 현실 이해를 관찰하는 것을 목적으로 한 상식 질문을 묻는 데이터셋이다.
GPT-3은 zero/one/few-shot에서 81.0%, 80.5%, 82.8%의 성능을 보였으며, 이는 이전의 SOTA였던 미세조정된 RoBERTa의 79.4%보다 더 높은 수치이다.
PIQA는 상대적으로 모델 크기가 커져도 성능이 많이 향상되지 않으며 또한 사람에 비하면 10% 이상 낮지만, GPT-3의 zero-shot 결과는 현재 SOTA를 뛰어넘는다. 참고로 PhysicalQA에서도 데이터 오염 문제가 있었다.

ARC(AI2 Reasoning Challenge)는 3~9학년(초등~중등)의 과학 시험에서 모은 7787개의 다지선다형 문제 데이터셋이다. 이 데이터셋의 “Challenge” 버전은 단순 통계적 혹은 정보 검색만으로는 맞게 답할 수 없는 것들만 필터링한 것으로서 GPT-3은 zero/one/few-shot에서 51.4%, 53.2%, 51.5%의 성능을 보였다. 이는 UnifiedQA에서 미세조정된 RoBERTa가 보인 55.9%와 견줄 수 있는 수준이다.
이 데이터셋의 “easy” 버전에서는 GPT-3은 zero/one/few-shot에서 68.8%, 71.2%, 70.1%의 성능을 보여 RoBERTa를 살짝 앞질렀다. 그러나 이는 UnifiedQA에 비하면 27%/22%만큼이나 낮은 수치이다.

초등 수준의 과학적 사실을 다루는 질문으로 구성된 OpenBookQA에서는, zero~few-shot에서 상당한 발전을 보였으나 SOTA에 비하면 20점 이상 낮은 수치이다. few-shot 성능은 미세조정된 BERT Large와 비슷하다.

Examples

전체적으로, PIQA, ARC에서는 큰 향상이 없었으나, OpenBookQA에서는 꽤 진전이 있었다.

PhysicalQA(PIQA)
PhysicalQA(PIQA): Reasoning about Physical Commonsense in Natural Language은 어떤 (실생활) 목표가 자연어로 주어지면, 모델은 적절한 해답을 선택해야 한다.
Goal: To separate egg whites from the yolk using a water bottle, you should…
a. Squeeze the water bottle and press it against the yolk. Release, which creates suction and lifts the yolk.
b. Place the water bottle and press it against the yolk. Keep pushing, which creates suction and lifts the yolk.

ARC(AI2 Reasoning Challenge)
ARC(AI2 Reasoning Challenge): Think you have Solved Question Answering? Try ARC, the AI2 Reasoning Challenge은 Challenge set과 Easy set으로 구분되어, Challenge set은 정보기반 알고리즘과 단어 co-occurence 알고리즘으로 제대로 답변할 수 없는 질문들로만 구성되어 있다.
What is a worldwide increase in temperature called?
(A) greenhouse effect
(B) global warming
(C) ozone depletion
(D) solar heating

OpenBookQA
OpenBookQA : Can a Suit of Armor Conduct Electricity A New Dataset for Open Book Question Answering은 1326개의 초등 수준 과학적 사실에 기반하였으며 질문의 수는 6k 정도이다.
Question: Which of these would let the most heat travel through?
A) a new pair of jeans.
B) a steel spoon in a cafeteria.
C) a cotton candy at a store.
D) a calvin klein cotton hat.
Science Fact: Metal is a thermal conductor.
Common Knowledge: Steel is made of metal. Heat travels through a thermal conductor.

3.6. Reading Comprehension

추상적 / 다지선다 등 5개의 데이터셋에 대해 독해력을 측정한다. 여러 다른 답변 형식에서도 데이터셋간 장벽을 뛰어넘는 GPT-3의 범용성을 확인하였다.

  • 자유 형식 대화 데이터셋인 CoQA에서 최고의 성능(사람보다 3 point 낮음)과,
  • 구조화된 대화와 교사-학생 상호작용의 답변 선택 모델링을 요구하는 QuAC에서는 ELMo baseline보다 13 F1 score가 낮은 나쁜 성능을 보여주었다.
  • 독해 문맥에서 이산적 추론과 산술능력을 평가하는 데이터셋인 DROP에서는 GPT-3이 few-shot에서 미세조정된 BERT baseline을 뛰어넘었지만 SOTA나 사람에 비하면 아주 못 미치는 성적을 거두었다.
  • SQuAD 2.0에서는 zero-shot에서 10 F1(69.8) score를 향상시켰고 이는 원 논문의 가장 좋은 미세조정 모델보다 약간 더 높은 점수이다.
  • 중/고등 다지선다형 영어시험 문제를 모은 RACE에서는 상당히 약한 모습을 보였다(SOTA에 비해 45%나 낮음).
Examples

CoQA: A Conversational Question Answering Challenge
Conversational Question Answering systems을 만드는 데 필요한 대규모 데이터셋으로 coca(코카)라 읽는다. 텍스트 구절을 이해하고 대화에서 나타나는 상호 연결된 질문들에 대답하는 능력을 측정한다. 127k개의 질문과 8k개의 대화문을 포함한다.
Jessica went to sit in her rocking chair. Today was her birthday and she was turning 80. Her granddaughter Annie was coming over in the afternoon and Jessica was very excited to see her. Her daughter Melanie and Melanie’s husband Josh were coming as well. Jessica had . . . Q1: Who had a birthday?
A1: Jessica
R1: Jessica went to sit in her rocking chair. Today was her birthday and she was turning 80. Q2: How old would she be?
A2: 80
R2: she was turning 80

Q5: Who?
A5: Annie, Melanie and Josh
R5: Her granddaughter Annie was coming over in the afternoon and Jessica was very excited to see her. Her daughter Melanie and Melanie’s husband Josh were coming as well.

QuAC : Question Answering in Context
14k개의 정보를 찾는(information-seeking) QA 대화로 총 100k개의 질문을 포함하는 Question Answering in Context 데이터셋이다.
Section: Daffy Duck, Origin & History
STUDENT: What is the origin of Daffy Duck?
TEACHER: → first appeared in Porky’s Duck Hunt
STUDENT: What was he like in that episode?
TEACHER: → assertive, unrestrained, combative
STUDENT: Was he the star?
TEACHER: → No, barely more than an unnamed bit player in this short
STUDENT: Who was the star?
TEACHER: → No answer

DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs
Discrete Reasoning Over the content of Paragraphs의 약자로 문단 내용의 이산적 추론능력을 필요로 하는 독해력 측정용 benchmark dataset이다. Subtraction, Comparison, Selection, Addition, Count and Sort, Conference Resolution, Other Arithmetic, Set of spans, Other 등의 카테고리로 나뉘어 총 96k개의 질문을 포함한다.
Reasoning: Subtraction (28.8%)
Passage(some parts shortened): That year, his Untitled (1981), a painting of a haloed, black-headed man with a bright red skeletal body, depicted amid the artists signature scrawls, was sold by Robert Lehrman for $16.3 million, well above its $12 million high estimate.
Question: How many more dollars was the Untitled (1981) painting sold for than the 12 million dollar estimation?
Answer: 4300000
BiDAF: $16.3 million

Know What You Don’t Know: Unanswerable Questions for SQuAD:
SQuAD 2.0(Stanford Question Answering Dataset)은 기존 독해력 측정 데이터셋의 대답 가능한 질문이나 쉽게 판별할 수 있는 대답 불가능한 질문이라는 약점을 보완하여 사람이 생성한 질문으로 구성된다. 모델은 정답을 맞추는 것만 아니라 답변 중에 정답이 있는지도 판별해야 한다.
Article: Endangered Species Act
Paragraph: “…Other legislation followed, including the Migratory Bird Conservation Act of 1929, a 1937 treaty prohibiting the hunting of right and gray whales, and the Bald Eagle Protection Act of 1940. These later laws had a low cost to society—the species were relatively rare—and little opposition was raised.”
Question 1: “Which laws faced significant opposition?”
Plausible Answer: later laws
Question 2: “What was the name of the 1937 treaty?”
Plausible Answer: Bald Eagle Protection Act

RACE: Large-scale ReAding Comprehension Dataset From Examinations:
중국의 중/고등(12~18세) 영어시험으로부터 모은 28k개의 구절(passages)와 100k 정도의 (영어교사가 만든) 질문들로 구성된 데이터셋으로 이해력과 추론 능력을 평가하기 위한 다양한 주제들로 구성되어 있다.

Examples

자유 형식 대화 데이터셋인 CoQA에서 최고의 성능(사람보다 3 point 낮음)과, 구조화된 대화와 교사-학생 상호작용의 답변 선택 모델링을 요구하는 QuAC에서는 ELMo 기준보다 13 F1 score가 낮은 나쁜 성능을 보여주었다.

3.7. SuperGLUE

조금 더 많은 NLP task에 대한 결과를 모으고 BERT와 RoBERTa와 더 체계적으로 비교하기 위해 SuperBLUE benchmark 테스트도 진행하였다. few-shot에서는 모든 task에 대해 32개의 예제를 사용하였고 이는 training set에서 무작위로 선택하였다. WSC와 MultiRC를 제외하고는 각 문제당 문맥에서 새로운 예제를 선택했다. 위 두 데이터셋에서는 training set에서 무작위로 선택한 것과 같은 세트를 사용했다.

Examples

GPT-3으 몇몇 데이터셋에서는 SOTA에 근접한 결과를 얻었으나 그렇지 못한 데이터셋도 여럿 있는 것을 볼 수 있다. 특히 두 문장에서 사용된 동일한 철자의 단어가 같은 의미로 사용되었는지를 보는 WiC dataset에서는 49.4%로 찍는 거랑 다를 바가 없었다.

Examples

그리고 모델의 크기가 커질수록, few-shot에 사용되는 예제의 수가 많을수록 성능이 증가함을 볼 수 있다.

3.8. NLI

Natural Language Inference (NLI)은 두 문장 간의 관계를 이해하는 것을 측정한다. 실제로는, 보통 2~3개의 분류 문제로 모델은 두 번째 문장이 첫 번째 문장과 같은 논리를 따르는지, 모순되는지, 혹은 그럴지도 모를지(중립적)를 판별한다. SuperGLUE는 NLI dataset으로 이진 분류 RTE를 포함한다. RTE에서는 가장 큰 모델인 GPT-3만 찍는 것보다 나은 56%를 기록하였으나 few-shot에서는 single-task 미세조정 BERT Large와 비슷하다.

그리고 ANLI 데이터셋에 대해서도 테스트를 진행하였는데, few-shot에서조차 GPT-3보다 작은 모델은 전부 형편없는 모습(~33%)을 보여준다. 전체 결과는 아래 그림과 Appendix H에서 볼 수 있다.

Examples

Adversarial NLI: A New Benchmark for Natural Language Understanding:
적대적으로 만들어진 자연어 추론 문제를 3 round(R1, R2, R3)에 걸쳐 진행하는 데이터셋으로 Context, Hypothesis, Reason, Round, Labels(orig, pred, valid), Annotations으로 구성된다.

Examples

3.9. Synthetic and Qualitative Tasks

GPT-3의 few-shot(혹은 zero나 one) 능력의 범위를 보려면 즉석 계산적 추론이나, 새로운 패턴을 찾아내거나, 새 task에 빠르게 적응(적용)하는지를 측정해보면 된다.
그래서 세 가지를 측정한다: 산술능력, 단어 재조합, SAT 유추.

3.9.1 Arithmetic

10가지 경우를 테스트하였는데, 각 숫자는 각 범위 내에서 동일한 확률로 선택되었다.

  • 2자리 덧셈: “Q: What is 48 plus 76? A: 124.”
  • 2자리 뺄셈: “Q: What is 34 minus 53? A: -19”
  • 3, 4, 5자리 덧셈/뺄셈
  • 2자리 곱셈: “Q: What is 24 times 42? A: 1008”
  • 괄호를 포함한 1자리 복합연산: “Q: What is 6+(4*8)? A: 38”
Examples

2/3자리 계산은 거의 100%에 가까운 정확도를 보여주지만, 자리수가 많아질수록 성능은 떨어졌다. 또한, 2자리 곱셈에서는 29.2%, 1자리 복합연산은 21.3%의 정확도를 얻었다(few-shot).
종합적으로, GPT-3은 보통 수준의 복잡한 산술연산에서 합리적인 수준의 성능을 보였다.

Examples

3.9.2 Word Scrambling and Manipulation Tasks

적은 수의 예로부터 새로운 symbolic manipulation을 학습하는 능력을 측정하기 위해 다음 5가지 “문자조합” task를 설정했다.

  • 단어 내 철자를 회전시켜 원래 단어를 만들기(Cycle letters in word (CL)). lyinevitab = inevitably
  • 처음과 마지막을 제외한 철자가 뒤섞여 있을 때 원래 단어 만들기(Anagrams of all but first and last characters (A1)). criroptuon = corruption
  • A1과 비슷하지만 처음/마지막 각 2글자가 섞이지 않음(Anagrams of all but first and last 2 characters (A2)). opoepnnt → opponent
  • 구두점들과 빈칸이 각 철자 사이에 올 때 원래 단어 만들기(Random insertion in word (RI)). s.u!c/c!e.s s i/o/n = succession
  • 거꾸로 된 단어에서 원래 단어 만들기(Reversed words (RW)). stcejbo → objects

각 1만 개의 예제를 만들기 위해 4~15글자 사이의 가장 빈도가 높은 단어들을 선정하였다.
few-shot 결과는 아래와 같다. 모델의 크기가 커질수록 성능도 조금씩 증가한다.

Examples

어떤 모델도 단어를 뒤집는 RW task에 성공하지 못했다. 또한 one/zero-shot에서는 성능이 매우 떨어진다.

Examples

여기서 “문맥 내 정보”는 큰 모델일수록 더 잘 활용한다는 것을 보였다.

또한, 언어모델의 이러한 성공은 단지 BPE token을 잘 쓰는 것 뿐 아니라 그 하부구조를 잘 이해하고 분석하였음을 알 수 있다. 그리고 CL, A1, A2 task는 전단사상(bijective)이 아니기 때문에(즉, 하나로 결정된 것이 아니기에) 자명하지 않은 패턴매칭과 계산적인 능력에서 연관이 있다고 할 수 있다.

3.9.3 SAT Analogies

텍스트의 전형적인 분포와 연관된 무언가 흔치 않은 다른 task에 테스트하기 위해, 374개의 “SAT analogy” 문제를 모았다. 예시는 다음과 같다:

Audacious is to boldness as
(a) sanctimonious is to hypocrisy,
(b) anonymous is to identity,
(c) remorseful is to misdeed,
(d) deleterious is to result,
(e) impressionable is to temptation”

5개의 단어 쌍 중 주어진 단어 쌍과 같은 관계를 갖는 정답지를 골라야 한다(정답은 a이다).
여기서 GPT-3은 zero/one/few-shot에서 각각 53.7%, 59.1%, 65.2%의 성능을 보였고, 대학 졸업자 평균이 57%(찍으면 20%)이다.
아래 그림에서 보듯 1750억 개 짜리 모델은 130억 개의 모델보다 10% 가량 더 성능이 높다.

Examples

3.9.4 News Article Generation

생성적 언어모델을 질적으로 테스트하기 위한 방법으로 뉴스기사로 첫 문장을 준 뒤 이후 문장을 생성하는 식으로 측정해 왔다. GPT-3을 학습한 데이터셋은 뉴스기사에 별 비중을 두지 않읐으므로 무조건으로 뉴스기사를 생성하는 것은 비효율적이다. 그래서 GPT-3에서는 3개의 이전 뉴스기사를 모델의 문맥에 포함시켜 few-shot 학습능력을 측정했다. 제목과 부제목이 주어지면, 모델은 “뉴스” 장르에서 짧은 기사를 생성할 수 있다.

GPT-3이 생성한 기사가 사람이 쓴 것과 구별되는지를 사람이 얼마나 잘 구별하는지를 측정하였다. 이를 통해 GPT-3의 뉴스기사 생성능력을 볼 수 있다.

사람이 판별할 때는 ‘사람이 쓴 기사’와 ‘기계가 생성한 기사’ 두 개를 보고, 다음 5가지 중에 고른다: 1) 확실히 사람이 썼다. 2) 사람이 쓴 것 같다. 3) 잘 모르겠다. 4) 기계가 쓴 것 같다. 5) 확실히 기계가 썼다.

Examples

정답률이 50%에 근접하면 기계가 썼는지 사람이 썼는지 분간이 안 된다는 뜻이다(기계가 사람만큼 잘 썼다/또는 사람처럼 비슷하게 썼다).

물론 모델이 커질수록 구별하기는 점점 더 어려워진다.

Examples

아래는 GPT-3이 생성한 기사이다. 구분하기 가장 어려운 것과 쉬운 것을 보여주고 있다.

Examples

3.9.5 Learning and Using Novel Words

발달적 언어학에서 연구되는 한 task는 새로운 단어를 학습하고 사용하는 능력을 측정한다(예를 들면 단 한 번 정의된 단어를 보고 나서 사용하거나, 한 번의 용례로 단어의 의미를 유추하는 것).
테스트는 다음과 같이 했다. “Gigamuru”와 같이 실제로 없는 단어를 정의하고, 이를 문장에서 사용해 보게 하였다. 1~5개의 없는 단어를 정의하고 문장에서 사용하여, 이 task는 one/few-shot 세팅으로 구성된다. 모든 정의는 사람이 직접 하였고, 첫 번째 답은 사람이 정의하였고, 나머지는 GPT-3이 한 것이다:

Examples

모든 경우에서 정확하거나 합리적인 수준으로 사용하였음을 볼 수 있다. 적어도, GPT-3은 새로운 단어를 사용하는 능력은 꽤 수준이 있는 것 같다.

3.9.6 Correcting English Grammar

few-shot에 적합한 또 다른 task는 영문법을 교정하는 것이다. 다음과 같이 few-shot 예시를 주었다:

"Poor English Input: <sentence>\n Good English Output: <sentence>".

결과는 아래 그림에서 볼 수 있다.

Examples

4. 벤치마크를 외웠는지 측정하고 예방하기(Measuring and Preventing Memorization Of Benchmarks)


5. 한계(Limitations)


6. 광범위한 영향(Broader Impacts)


7. 관련 연구(Related Work)

이 논문


8. 결론(Conclusion)

큰 크기

Acknowledgements

언제나 있는 감사의 인사


Appendix A: Details of Common Crawl Filtering


Appendix B: Details of Model Training


Appendix C: tails of Test Set Contamination Studies


Appendix D: Total Compute Used to Train Language Models


Appendix E: Human Quality Assessment of Synthetic News Articles


Appendix F: Additional Samples from GPT-3


Appendix G: Details of Task Phrasing and Specifications


Appendix H: Results on All Tasks for All Model Sizes


Refenrences

논문 참조. 많은 레퍼런스가 있다.


Citation

@misc{brown2020language,
      title={Language Models are Few-Shot Learners}, 
      author={Tom B. Brown and Benjamin Mann and Nick Ryder and Melanie Subbiah and Jared Kaplan and Prafulla Dhariwal and Arvind Neelakantan and Pranav Shyam and Girish Sastry and Amanda Askell and Sandhini Agarwal and Ariel Herbert-Voss and Gretchen Krueger and Tom Henighan and Rewon Child and Aditya Ramesh and Daniel M. Ziegler and Jeffrey Wu and Clemens Winter and Christopher Hesse and Mark Chen and Eric Sigler and Mateusz Litwin and Scott Gray and Benjamin Chess and Jack Clark and Christopher Berner and Sam McCandlish and Alec Radford and Ilya Sutskever and Dario Amodei},
      year={2020},
      eprint={2005.14165},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

Comment  Read more

Conditional Variational AutoEncoder (CVAE) 설명

|

본 글에서는 Variational AutoEncoder를 개선한 Conditional Variational AutoEncoder (이하 CVAE)에 대해 설명하도록 할 것이다. 먼저 논문을 리뷰하면서 이론적인 배경에 대해 탐구하고, Tensorflow 코드(이번 글에서는 정확히 구현하지는 않았다.)로 살펴보는 시간을 갖도록 하겠다. VAE에 대해 알고 싶다면 이 글을 참조하길 바란다.


1. Learning Structured Output Representation using Deep Conditional Generative Models

1.1. Introduction

구조화된 Output 예측에서, 모델이 확률적인 추론을 하고 다양한 예측을 수행하는 것은 매우 중요하다. 왜냐하면 우리는 단지 분류 문제에서처럼 many-to-one 함수를 모델링하는 것이 아니라 하나의 Input에서 많은 가능한 Output을 연결짓는 모델이 필요하기 때문이다. CNN이 이러한 문제에 효과적인 모습을 보여주었지만, 사실 CNN은 복수의 mode를 갖는 분포를 모델링하기에는 적합하지 않다.

이 문제를 다루기 위해서는 Output Representation Learning과 구조화된 예측을 위한 새로운 Deep Conditional Generative Model이 필요하다. 즉, 고차원의 Output Space를 Input 관측값에 조건화되어 있는 생성 모델로 모델링해야하는 것이다. 변분 추론Directed Graphical Model의 최근 발전에 기반하여 본 논문은 CVAE를 새로운 모델로서 제안한다. 이 모델은 Directed Graphical Model로서 Input 관측값이 Output을 생성하는 Gaussian 잠재 변수에 대한 Prior를 조절한다. 모델은 조건부 Log Likelihood를 최대화하도록 학습하게 되며, 우리는 이 과정을 SGVB: Stochastic Gradient Variational Bayes의 프레임워크 안에서 설명할 것이다. SGVB에 대해 미리 알고 싶다면 이 글을 참조하도록 하라. 또한 더욱 Robust한 예측 모델을 만들기 위해 우리는 Input Noise Injection이나 Multi-scale Prediction Training Method 등을 소개할 것이다.

실험에서 본 모델의 효과성을 보이도록 할 것인데, 특히 데이터가 일부만 주어졌을 때 구조화된 Output을 모델링하는 데에 있어 확률적 뉴런의 중요성을 보여줄 것이다. 데이터셋은 Caltech-UCSD Birds 200과 LFW를 사용하였다.

(중략)

1.3. Preliminary: Variational Auto-Encoder

이 Chapter 역시 대부분 생략하도록 하겠다. 자세한 설명은 글 서두에 있는 링크를 클릭하여 살펴보도록 하자. 최종적으로 VAE의 목적함수만 정리하고 넘어가겠다.

[\tilde{\mathcal{L}}{VAE} (\theta, \phi; \mathbf{x}^{(i)}) = -KL (q{\phi} (\mathbf{z} \mathbf{x}^{(i)})   p_{\theta} (\mathbf{z}) ) + \frac{1}{L} \Sigma_{l=1}^L logp_{\theta} (\mathbf{x z^{(l)}})]

1.4. Deep Conditional Generative Models for Structured output Prediction

변수에는 3가지 종류가 있다. Input 변수 x, Output 변수 y, 잠재 변수 z가 바로 그것이다. $\mathbf{x}$ 가 주어졌을 때, $\mathbf{z}$ 는 아래와 같은 사전 확률로 부터 추출된다.

[p_{\theta}(\mathbf{z x})]

그리고 output $\mathbf{y}$ 는 아래 분포로 부터 생성된다.

[p_{\theta}(\mathbf{y x, z})]

baseline CNN과 비교하여 잠재 변수 $\mathbf{z}$ 는 Input이 주어졌을 때 Output 변수에 대한 조건부 분포에서 복수의 mode를 모델링하는 것을 허용하기 때문에 제안된 CGM (조건부 생성 모델) 을 one-to-many mapping 모델링에 적합하게 만든다. 위 식에 따르면 잠재 변수의 사전 확률은 Input 변수에 의해 조절되는 것처럼 보이지만, 이러한 제한은 잠재 변수를 Input 변수에 독립적으로 만들어 해소할 수 있다.

[p_{\theta} (\mathbf{z x}) = p_{\theta} (\mathbf{z})]

Deep CGM은 조건부 Conditional Log Likelihood를 최대화하면서 학습된다. 이 목적 함수는 종종 intractable 하기 때무에 우리는 SGVB 프레임워크를 적용할 것이다. ELBO는 아래와 같다.

[log p_{\theta} (\mathbf{y x}) \geq \tilde{\mathcal{L}}{VAE} (\mathbf{x, y} ; \theta, \phi) = -KL( q{\theta} (\mathbf{z x, y})   p_{\theta} (\mathbf{z x}) ) + E_{q_{\phi} (\mathbf{z} \mathbf{x, y})} { logp_{\theta} (\mathbf{y} \mathbf{x, z}) }]

물론 위 식의 우변의 두 번째 항은 Monte-Carlo Estimation을 통해 경험적으로 값을 얻을 수 있다. 이 기법에 대해 알고 싶다면 이 글을 참조하라. 다시 포현하면 아래와 같다.

[\tilde{\mathcal{L}}{VAE} (\mathbf{x, y} ; \theta, \phi) = -KL( q{\theta} (\mathbf{z x, y})   p_{\theta} (\mathbf{z x}) ) + \frac{1}{L} \Sigma_{l=1}^L logp_{\theta} (\mathbf{x z^{(l)}})]

$L$ 은 Sample의 개수이며 이 때,

[\mathbf{z}^{(l)} = g_{\phi} (\mathbf{x, y} , {\epsilon}^{(l)} ), {\epsilon}^{(l)} \sim \mathcal{N} (\mathbf{0}, \mathbf{I})]

본 논문은 이 모델을 CVAE라고 부를 것이다. 이 모델은 복수의 MLP로 구성되는 데 크게 3가지의 요소를 갖고 있다.

1) Recognition Network

[q_{\phi} (\mathbf{z x, y})]

2) Prior Network

[p_{\theta} (\mathbf{z x})]

3) Generation Network

[p_{\theta} (\mathbf{y x, z})]

이 네트워크 구조를 디자인할 때, baseline CNN 위에 CVAE의 구성요소를 올릴 것이다. 아래 그림에서 (d)를 확인해보자.

직접적인 Input $\mathbf{x}$ 뿐만 아니라 CNN으로부터 만들어진 최초의 예측 값 $\hat{\mathbf{y}}$ 는 Prior Network로 투입된다. 이러한 순환 연결은 구조화된 Output 예측 문제에서 효과적으로 합성곱 네트워크를 깊게 만들면서 이전의 추측을 수정하여 예측값을 연속적으로 업데이트하는 과정에 적용된 바 있다. 우리는 또한 이러한 순환 연결이, 설사 단 한 번의 반복에 그치더라도, 굉장한 성능 향상을 이끌어낸다는 사실을 발견했다. 네트워크 구조에 대한 자세한 사항은 이후에 설명할 것이다.

1.4.1. Output Inference and Estimation of the Conditional Likelihood

모델 파라미터가 학습되면 CGM의 생성 과정을 따라 Input으로부터 Output에 대한 예측을 수행하게 된다. 모델을 평가하기 위해서 우리는 $\mathbf{z}$ 에 대한 Sampling 없이 Deterministic Inference를 수행할 수 있다.

[\mathbf{y^*} = \underset{y}{argmax} p_{\theta} (\mathbf{y x, z^}), \mathbf{z^} = E[\mathbf{z x}]]

또는 사전 확률로부터 복수의 $\mathbf{z}$ 를 추출한 뒤 사후 확률의 평균을 사용하여 예측을 수행할 수도 있다.

[\mathbf{y^*} = \underset{y}{argmax} \frac{1}{L} \Sigma_{l=1}^L p_{\theta} (\mathbf{y x, z^{(l)}}), \mathbf{z}^{(l)} \sim p_{\theta} (\mathbf{z x})]

CGM을 평가하는 또 다른 방법은 테스트 데이터의 조건부 Likelihood를 비교하는 것이다. 아주 직관적인 방법은 사전 확률 네트워크로부터 $\mathbf{z}$ 를 추출하고 Likelihood의 평균을 취하는 것이다. 물론 이 방법은 Monte-Carlo Sampling이다.

[p_{\theta} (\mathbf{y x}) \approx \frac{1}{S} \Sigma_{s=1}^S p_{\theta} (\mathbf{y x, z^{(s)}}), \mathbf{z}^{(s)} \sim p_{\theta} (\mathbf{z x})]

사실 이 몬테카를로 샘플링은 굉장히 많은 샘플을 필요로 한다. 이것이 어려울 경우 Importance Sampling을 통해 조건부 Likelihood를 추정할 수 있다.

[p_{\theta} (\mathbf{y x}) \approx \frac{1}{S} \Sigma_{s=1}^S \frac{p_{\theta} (\mathbf{y x, z^{(s)}}) p_{\theta} (\mathbf{z^{(s)} x}) } { q_{\phi} (\mathbf{ z^{(s)} x, y }) } , \mathbf{z}^{(s)} \sim q_{\phi} (\mathbf{ z x, y })]

1.4.2. Learning to predict structured output

SGVB가 Deep Generative Model을 학습하는 데에 있어 효과적 것이 증명되긴 하였지만, 학습 과정에서 형성된 Output 변수들에 대한 Conditional Auto-Encoding은 테스트 과정에서 예측을 할 때 최적화되지 않았을 수도 있다.

즉, CVAE가 학습을 할 때 아래와 같은 인식 네트워크를 사용할 것인데,

[q_{\phi} (\mathbf{z x, y})]

테스트 과정에서는 아래와 같은 Prior 네트워크로부터 sample $\mathbf{z}$ 를 추출하여 예측을 수행한다는 것이다.

[p_{\theta} (\mathbf{z x})]

인식 네트워크에서 $\mathbf{y}$ 는 Input으로 주어지기 때문에, 학습의 목표는 $\mathbf{y}$ 의 Reconstruction인데, 이는 사실 예측보다 쉬운 작업이다.

[\tilde{\mathcal{L}}{VAE} (\mathbf{x, y} ; \theta, \phi) = -KL( q{\theta} (\mathbf{z x, y})   p_{\theta} (\mathbf{z x}) ) + \frac{1}{L} \Sigma_{l=1}^L logp_{\theta} (\mathbf{x z^{(l)}})]

위 식에서 Negative 쿨백-라이블리 발산 항은 2개의 파이프라인의 차이를 줄이려고 한다. 따라서 이러한 특성을 활용하여, 학습 및 테스트 과정에서 잠재 변수의 Encoding의 차이를 줄이기 위한 방법이 있다. 바로 목적 함수의 Negative 쿨백-라이블리 발산 항에 더욱 큰 가중치를 할당하는 것이다. 예를 들어 다음과 같은 형상을 생각해볼 수 있겠다.

[- (1 + \beta) KL( q_{\theta} (\mathbf{z x, y})   p_{\theta} (\mathbf{z x}) ), \beta \ge 0]

그러나 본 논문에서의 실험에서 이와 같은 조치는 큰 효력을 발휘하지 못하였다.

대신, 학습과 테스트 과정 상의 예측 파이프라인을 일치(consistent) 시키는 방향으로 네트워크를 학습시키는 것을 제안한다. 이는 Prior 네트워크와 인식 네트워크를 동일하게 만드는 방식으로 적용할 수 있는데, 그렇게 하면 아래와 같은 목적함수를 얻게 된다.

[\tilde{\mathcal{L}}{GSNN} (\mathbf{x, y} ; \theta, \phi) = \frac{1}{L} \Sigma{l=1}^L logp_{\theta} (\mathbf{x z^{(l)}})]

[\mathbf{z}^{(l)} = g_{\phi} (\mathbf{x, y} , {\epsilon}^{(l)} ), {\epsilon}^{(l)} \sim \mathcal{N} (\mathbf{0}, \mathbf{I})]

우리는 이 모델을 GSNN: Gaussian Stochastic Neural Network라고 부를 것이다. GSNNCVAE에서의 인식 네트워크와 Prior 네트워크를 동일하게 만듦으로써 만들 수 있다. 따라서 CVAE에서 사용하였던 Reparameterization Trick과 같은 학습 트릭은 GSNN에서도 사용할 수 있다. 비슷하게 테스트 과정에서의 추론과 Conditional Likelihood 추정 또한 CVAE의 그것과 같다. 마지막으로, 우리는 두 모델의 목적 함수를 결합하여 다음과 같은 Hybrid 목적 함수를 얻을 수 있다.

[\tilde{\mathcal{L}}{hybrid} = \alpha \tilde{\mathcal{L}}{CVAE} + (1-\alpha) \tilde{\mathcal{L}}_{GSNN}]

이 때 $\alpha$는 두 목적 함수 사이의 균형을 맞춰준다. 만약 $\alpha=1$ 이면, 그냥 CVAE의 목적 함수와 동일함을 알 수 있다. 만약 반대로 $\alpha = 0$ 이면, 우리는 그냥 인식 네트워크 없이 GSNN을 학습시키는 것이라고 생각할 수 있다.

1.4.3. CVAE for Image Segmentation and Labelling

Semantic Segmentation은 중요한 구조화된 Output 예측 과제이다. 이 Chapter에서는 이러한 문제를 해결하기 위한 Robust한 예측 모델을 학습시키는 전략을 제시할 것이다. 특히 관측되지 않은 데이터에 대해 잘 일반화될 수 있는 high-capacity 신경망을 학습시키기 위해 우리는 1) Multi-scale 예측 목적 함수와 2) 구조화된 Input Noise와 함께 신경망을 학습시킬 것을 제안한다.

1.4.3.1. Training with multi-scale prediction objective

이미지 크기가 커질 수록, 정교하게 픽셀 레벨의 예측을 하는 것은 굉장히 어려워진다. Multi-scale 접근 방법은 Input에 대해 Multi-scale 이미지 피라미드를 형성하는 관점에서 사용되어 왔지만 Multi-scale Output 예측을 위해서는 잘 사용되지 않았다.

본 논문에서 우리는 다른 scale로 Output을 예측하는 네트워크를 학습시킬 것을 제안한다. 그렇게 함으로써, global-to-loca, coarse-to-fine-grained한 픽셀 레벨의 semantic label에 대한 예측을 수행할 수 있다. 위 그림은 3가지 scale로 학습을 진행하는 모습에 대한 예시이다.

1.4.3.2. Training with Input Omission Noise

깊은 신경망의 뉴런에 Noise를 추가하는 것은 대표적인 규제 방법 중 하나이다. 우리는 Semantic Segmentation에 대해서 간단한 규제 테크닉을 제안한다. Input 데이터 $\mathbf{x}$ 를 Noise Process에 따라 오염시켜 $\tilde{\mathbf{x}}$ 로 만들고 목적함수 $\tilde{\mathcal{L} (\mathbf{\tilde{x}, y})}$ 로 네트워크를 최적화하는 것이다.

Noise Process는 임의로 정할 수 있는데, 본 문제에서는 Random Block Omission Noise를 제안한다. 특히 우리는 이미지의 40% 이하의 면적에 대해 사각형의 마스크를 랜덤하게 생성하고, 그 부분의 픽셀 값을 0으로 만드는 방법을 사용하였다. 이는 Block 폐쇄 혹은 결측값을 시뮬레이션한 것으로, 예측 문제를 더욱 어렵게 만드는 요인으로 파악할 수 있다.

이렇게 제안된 전략은 또한 Denoising 학습 방법과 연관되어 있다고도 볼 수 있는데, 우리는 Input 데이터에만 Noise를 투사하고 Missing Input을 재구성하지는 않는다는 점이 다르다.

1.5. Experiments

(논문 원본 참조)

1.6. Conclusion

구조화된 Output 변수에 대해 복수의 Mode를 갖는 분포를 모델링하는 것은 구조화된 예측 문제에 대해 좋은 성과를 내는 데에 있어 중요한 이슈이다. 본 연구에서 우리는 가우시안 잠재 변수를 이용하여 Conditional Deep Generative Model에 근거한 확률적 신경망을 제안하였다.

제안된 모델은 scalable하며 추론과 학습에 있어 효율적이다. 우리는 Output 공간이 복수의 Mode를 갖는 분포에 대해 확률적인 추론을 하는 것의 중요성을 역설하였고, Segmentation 정확도, 조건부 Log Likelihood 추정, 생성된 Sample의 시각화 측면에서 모두 뛰어난 성과를 냈다는 것을 보여주었다.


2. Tensorflow로 확인

VAE를 다루었던 이전 글에서 크게 바뀐 부분은 없다.
본래 이 논문에 나와있는 내용에 충실히 따라서 구현을 해야겠지만… 이 논문 이후에 나온 다른 논문들에 더 집중하기 위해 본 글에서는 간단히 $y$ 를 Input으로 추가했을 때 어떤 효과가 나오는지 정도만 확인을 하도록 하겠다.

Convolutional 형태를 취했던 이전 모델과 달리 $y$ 를 Input으로 넣기 위해 모두 Flatten한 상태로 네트워크를 구성하였다. 이번에는 Label 데이터도 같이 불러온다.

train_dataset = (tf.data.Dataset.from_tensor_slices(
    (tf.cast(train_images, tf.float32), tf.cast(train_labels, tf.float32)))
                 .shuffle(train_size).batch(batch_size))
test_dataset = (tf.data.Dataset.from_tensor_slices(
    (tf.cast(test_images, tf.float32), tf.cast(test_labels, tf.float32)))
                .shuffle(test_size).batch(batch_size))

모델은 아래와 같다. encode, decode 단계에서 $y$ 가 Input으로 추가되어 있는 모습을 확인할 수 있다.

class CVAE(tf.keras.Model):
    def __init__(self, latent_dim):
        super(CVAE, self).__init__()
        self.latent_dim = latent_dim
        self.encoder = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=(28*28 + 1)),
                tf.keras.layers.Flatten(),
                tf.keras.layers.Dense(units=512, activation='relu'),
                tf.keras.layers.Dropout(rate=0.2),
                tf.keras.layers.Dense(units=256, activation='relu'),
                tf.keras.layers.Dropout(rate=0.2),
                tf.keras.layers.Dense(units=64, activation='relu'),
                tf.keras.layers.Dense(units=latent_dim + latent_dim),
            ])

        self.decoder = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=(latent_dim+1)),
                tf.keras.layers.Dense(units=64, activation='relu'),
                tf.keras.layers.Dropout(rate=0.2),
                tf.keras.layers.Dense(units=256, activation='relu'),
                tf.keras.layers.Dropout(rate=0.2),
                tf.keras.layers.Dense(units=512, activation='relu'),
                tf.keras.layers.Dense(units=784),
            ])

    @tf.function
    def encode(self, x, y):
        inputs = tf.concat([x, y], 1)
        mean, logvar = tf.split(self.encoder(inputs), num_or_size_splits=2, axis=1)
        stddev = 1e-8 + tf.nn.softplus(logvar)
        return mean, stddev

    def reparameterize(self, mean, stddev):
        eps = tf.random.normal(shape=mean.shape)
        z = mean + eps * stddev
        return z

    @tf.function
    def decode(self, z, y, apply_sigmoid=False):
        inputs = tf.concat([z, y], 1)
        logits = self.decoder(inputs)
        if apply_sigmoid:
            probs = tf.sigmoid(logits)
            return probs
        return logits

학습 및 테스트 코드는 아래와 같다.

optimizer = tf.keras.optimizers.Adam(1e-4)

def compute_loss(model, x, y):
    x = tf.reshape(x, [-1, 784])
    y = tf.reshape(y, [-1, 1])
    mean, stddev = model.encode(x, y)
    z = model.reparameterize(mean, stddev)
    x_logit = model.decode(z, y, True)
    x_logit = tf.clip_by_value(x_logit, 1e-8, 1-1e-8)

    # Loss
    marginal_likelihood = tf.reduce_sum(x * tf.math.log(x_logit) + (1 - x) * tf.math.log(1 - x_logit), axis=[1])
    loglikelihood = tf.reduce_mean(marginal_likelihood)

    kl_divergence = -0.5 * tf.reduce_sum(1 + tf.math.log(1e-8 + tf.square(stddev)) - tf.square(mean) - tf.square(stddev),
                                         axis=[1])
    kl_divergence = tf.reduce_mean(kl_divergence)

    ELBO = loglikelihood - kl_divergence
    loss = -ELBO

    return loss


@tf.function
def train_step(model, x, y, optimizer):
    with tf.GradientTape() as tape:
        loss = compute_loss(model, x, y)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    return loss


epochs = 30
latent_dim = 2
model = CVAE(latent_dim)

# Train
for epoch in range(1, epochs + 1):
    train_losses = []
    for x, y in train_dataset:
        loss = train_step(model, x, y, optimizer)
        train_losses.append(loss)

    print('Epoch: {}, Loss: {:.2f}'.format(epoch, np.mean(train_losses)))


# Test
def generate_images(model, test_x, test_y):
    test_x = tf.reshape(test_x, [-1, 784])
    test_y = tf.reshape(test_y, [-1, 1])
    mean, stddev = model.encode(test_x, test_y)
    z = model.reparameterize(mean, stddev)

    predictions = model.decode(z, test_y, True)
    predictions = tf.clip_by_value(predictions, 1e-8, 1 - 1e-8)
    predictions = tf.reshape(predictions, [-1, 28, 28, 1])

    fig = plt.figure(figsize=(4, 4))

    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        plt.imshow(predictions[i, :, :, 0], cmap='gray')
        plt.axis('off')

    plt.show()


num_examples_to_generate = 16
random_vector_for_generation = tf.random.normal(shape=[num_examples_to_generate, latent_dim])
test_x, test_y = next(iter(test_dataset))
test_x, test_y = test_x[0:num_examples_to_generate, :, :, :], test_y[0:num_examples_to_generate, ]

for i in range(test_x.shape[0]):
    plt.subplot(4, 4, i + 1)
    plt.imshow(test_x[i, :, :, 0], cmap='gray')
    plt.axis('off')

plt.show()

generate_images(model, test_x, test_y)

VAE와 동일하게 Epoch 30이후의 결과를 확인하면 다음과 같다. 기존 이미지와 상당히 유사하게 새로운 이미지를 생성한 것을 확인할 수 있다. (Loss도 127까지 줄어들었다. 다만 좀 흐릿하긴 하다.)


Reference

논문 원본

Comment  Read more