Gorio Tech Blog search

Contrastive Learning, SimCLR 논문 설명(SimCLRv1, SimCLRv2)

|

이 글에서는 Contrastive Learning을 간략하게 정리한다.


Contrastive Learning

어떤 item들의 “차이”를 학습해서 그 rich representation을 학습하는 것을 말한다. 이 “차이”라는 것은 어떤 기준에 의해 정해진다.

Contrastive Learning은 Positive pair와 Negative pair로 구성된다. 단, Metric Learning과는 다르게 한 번에 3개가 아닌 2개의 point를 사용한다.

한 가지 예시는,

  • 같은 image에 서로 다른 augmentation을 가한 다음
  • 두 positive pair의 feature representation은 거리가 가까워지도록 학습을 하고
  • 다른 image에 서로 다른 augmentation을 가한 뒤
  • 두 negative pair의 feature representation은 거리가 멀어지도록 학습을 시키는

방법이 있다. 아래에서 간략히 소개할 SimCLR도 비슷한 방식이다.

Pair-wise Loss function을 사용하는데, 어떤 입력 쌍이 들어오면, ground truth distance $Y$는 두 입력이 비슷(similar)하면 0, 그렇지 않으면(dissmilar) 1의 값을 갖는다.

Loss function은 일반적으로 다음과 같이 나타낼 수 있다.

[\mathcal{L}(W) = \sum^P_{i=1} L(W, (Y, \vec{X_1}, \vec{X_2})^i)]

[L(W, (Y, \vec{X_1}, \vec{X_2})^i) = (1-Y)L_S(D^i_W) + YL_D(D^i_W)]

이때, 비슷한 경우와 그렇지 않은 경우 loss function을 다른 함수를 사용한다.

예를 들면,

[L(W, (Y, \vec{X_1}, \vec{X_2})^i) = (1-Y)\frac{1}{2}(D_W)^2 + (Y)\frac{1}{2}( \max(0, m-D_W) )^2]

즉 similar한 경우 멀어질수록 loss가 커지고, dissimilar한 경우 가까워질수록 loss가 커진다.


SimCLR(A Simple Framework for Contrastive Learning of Visual Representations)

논문 링크: A Simple Framework for Contrastive Learning of Visual Representations

Github: https://github.com/google-research/simclr

  • 2020년 2월(Arxiv), ICML
  • Google Research
  • Ting Chen, Simon Kornblith, Mohammad Norouzi, Geoffrey Hinton

Mini-batch에 $N$개의 image가 있다고 하면, 각각 다른 종류의 augmentation을 적용하여 $2N$개의 image를 생성한다. 이때 각 이미지에 대해, 나머지 $2N-1$개의 이미지 중 1개만 positive고 나머지 $2N-2$개의 image는 negative가 된다. 이렇게 하면 anchor에 대해 positive와 negative를 어렵지 않게 생성할 수 있고 따라서 contrastive learning을 수행할 수 있다.

위의 그림을 보면.

  1. 이미지 $x$에
  2. 서로 다른 2개의 augmentation을 적용하여 $\tilde{x}_i, \tilde{x}_j$을 생성
  3. 이는 CNN 기반 network $f(\cdot)$를 통과하여 visual representation $h_i, $h_j$로 변환됨
  4. 이 표현을 projection head, 즉 MLP 기반 network인 $g(\cdot)$을 통과하여 $z_i, z_j$를 얻으면
  5. 이 $z_i, z_j$로 contrastive loss를 계산한다.
  6. 위에서 설명한 대로 mini-batch 안의 $N$개의 이미지에 대해 positive와 negative를 정해서 계산한다.

Contrastive loss는 다음과 같이 쓸 수 있다. (NT-Xent(Normalized Temperature-scaled Cross Entropy))

[\ell_{(i, j)} = -\log \frac{\exp(\text{sim}(z_i, z_j) / \tau)}{\sum^{2N}{k=1} \mathbb{1}{[k \ne i]} \exp(\text{sim}(z_i, z_j) / \tau)}]

참고로 이는 self-supervised learning이다(사전에 얻은 label이 필요 없음을 알 수 있다).

Miscellaneous

  • Projection head는 2개의 linear layer로 구성되어 있고, 그 사이에는 ReLU activation을 적용한다.
  • Batch size가 클수록 많은 negative pair를 확보할 수 있으므로 클수록 좋다. SimCLR에서는 $N=4096$을 사용하였다.
  • SGD나 Momemtum 등을 사용하지 않고 대규모 batch size를 사용할 때 좋다고 알려진 LARS optimizer를 사용하였다.
  • Multi-device로 분산학습을 했는데, Batch Norm을 적용할 때는 device별로 따로 계산하지 않고 전체를 통합하여 평균/분산을 계산했다. 이러면 전체 device 간의 분포를 정규화하므로 정보 손실을 줄일 수 있다.
  • Data Augmentation은
    • Cropping/Resizing/Rotating/Cutout 등 이미지의 구도나 구조를 바꾸는 연산과
    • Color Jittering, Color Droppinog, Gaussian Blurring, Soble filtering 등 이미지의 색깔을 변형하는 2가지 방식을 제안하였다.
    • Augmentation 방법을 1개만 하는 것보다는 여러 개 하는 경우가 prediction task의 난이도를 높여 더 좋은 representation을 얻을 수 있다.
    • 7가지의 data augmentation 방법 중 Random crop + Random Color Distortion 방식을 적용하면 가장 좋은 성능을 보인다고 한다.

Experiments

ImageNet에서 같은 모델 크기 대비 훨씬 좋은 성능을 보인다.

3가지 방법으로 평가한 결과는 아래 표에서 볼 수 있다.

  1. 학습된 모델을 고정하고 linear classifier를 추가한 linear evaluation
  2. 학습된 모델과 linear classifier를 모두 학습시킨 fine-tuning
  3. 학습된 모델을 다른 dataset에서 평가하는 transfer learning

ImageNet 말고 다른 dataset에서 평가한 결과는 아래와 같다. Supervised 방식과 비등하거나 더 좋은 결과도 보여준다.


SimCLR v2(Big Self-Supervised Models are Strong Semi-Supervised Learners)

논문 링크: A Simple Framework for Contrastive Learning of Visual Representations

Github: https://github.com/google-research/simclr

  • 2020년 6월(Arxiv), NIPS
  • Google Research
  • Ting Chen, Simon Kornblith, Kevin Swersky, Mohammad Norouzi, Geoffrey Hinton

SimCLR를 여러 방법으로 개선시킨 논문이며, computer vision에서 unsupervised learning 연구에 큰 기여를 했다.

모델은 다음과 같은 과정을 따른다.

  1. Unsupervised(Self-supervised) Pre-training
  2. Supervised Fine-tuning
  3. Distillation using unlabeled data

이때 Unsupervsed 과정에서는 최종 task와는 무관한 데이터를 사용하였기에 task-agnostic이라는 용어를 사용한다.

Unsupervised(Self-supervised) Pre-training

대량의 unlabeled dataset으로 CNN model을 학습시켜 general representation을 모델이 학습하게 된다. SimCLR와 비슷하지만 다른 점은 SimCLRv1은 projection head를 버리고 downstream task를 수행하지만 v2는 1번째 head까지 포함시켜 fine-tuning이 시작된다.
또한 Projection head의 linear layer 개수도 2개에서 3개로 늘었다.

그 이유는 label fraction(label이 되어 있는 비율)이 낮을수록 projection head의 layer가 더 많을수록 성능이 높아지기 때문이라고 한다.

Supervised Fine-tuning

전술했듯 Projection head의 1번째 layer까지 포함하여 fine-tuning을 진행한다.

Distillation via unlaeled dataset

다음 과정을 통해 distillation을 수행한다.

  • 학습시킬 모델은 student model이다.
  • fine-tuning까지 학습된 teacher model을 준비한다. 이때 student model의 크기는 teacher보다 작다.
  • Unlabeled data를 teacher model과 student model에 집어넣고 teacher model의 output distribution을 얻는다. 여기서 가장 높은 값의 pseudo-label을 얻고 이를 student model의 output distribution과 비교하여 loss를 minimize한다.

이 과정을 통해 teacher model이 갖고 있는 지식을 student model이 학습할 수 있게 된다. 그러면서 크기는 더 작기 때문에 효율적인 모델을 만들 수 있는 것이다.

Ground-truth label과 조합하여 가중합 loss를 계산할 수도 있다.

Experiments

더 큰 모델이 더 좋은 성능을 내는 건 어쩔 수 없는 것 같다..

한 가지 눈여겨볼 것은 큰 모델일수록 label fraction이 낮은 dataset에 대해서 더 좋은 성능을 보인다는 것이다.

  • 또 Projection head를 더 깊게 쌓거나 크기가 클수록 Representation을 학습하는 데 더 도움이 된다.
  • Unlabeled data로 distillation을 수행하면 semi-supervised learning을 향상시킬 수 있다. 이때 label이 있는 경우 같이 사용해주면 좋은데, label이 있는 것와 없는 것을 따로 학습시키기보다는 distillation loss로 위에서 언급한 것처럼 가중합시킨 loss를 사용하면 성능이 가장 좋은 것을 확인할 수 있다.

References

Comment  Read more

Metric Learning 설명

|

이 글에서는 Metric Learning을 간략하게 정리한다.


Metric Learning

한 문장으로 요약하면,

  • Object간에 어떤 거리 함수를 학습하는 task

이다.

예를 들어 아래 이미지들을 보자.

뭔가 “이미지 간 거리”를 생각해보면, 1번째 이미지와 2번째 이미지는 거리가 가까울 것 같다. 이와는 대조적으로, 3번째 이미지는 다른 두 개의 이미지보다 거리가 멀 것 같다.

이런 관계를 학습하는 방식이 Metric Learning이다.

위의 경우는 조금 fine-grained한 경우이고, 좀 더 coarse한 경우는,

1번과 2번 간 거리보다는 3번 간 거리가 훨씬 멀 것 같다.


Training Dataset

그러면, 이러한 관계를 어떻게 데이터셋으로 만들 수 있는가?

  • “1번 이미지가 3번 이미지보다는 2번 이미지와 더 가깝다.”

혹은

  • 연관도 순으로 1 » 2 » 3 » 4

와 같이 쓸 수도 있고, 어떤 식으로든 관계를 설정해서 데이터셋으로 쓸 수 있다.

그렇다면, 왜 이렇게 복잡해 보이는(?) 방식으로 데이터를 구성하고 학습을 시키는가?

당연히, 이러한 데이터셋은 hard하게 labeling하는 것보다 훨씬 쉽게 대량의 데이터를 구성할 수 있다.

그리고, 위와 같이 어쨌든 supervision이 있기 때문에(좀 약하긴 하지만) metric learning은 지도학습의 일종이다.


Problem Formulation

크게 3가지로 생각할 수 있다. 이 중 일반적으로 2번째가 많이 쓰인다.

  1. Point-wise Problem: 하나의 학습 샘플은 Query-Item 쌍이 있고 어떤 numerical/ordinal score와 연관된다.
    • 그냥 classification 문제와 비슷하다. 하나의 query-item 쌍이 주어지면, 모델은 그 score를 예측할 수 있어야 한다.
  2. Pair-wise Problem: 하나의 학습 샘플은 순서가 있는 2개의 item으로 구성된다.
    • 모델은 주어진 query에 대해 각 item에 점수(혹은 순위)를 예측한다. 순서는 보존된다.
    • 그 순위(순서)를 최대한 바르게 맞추는 것이 목표가 된다.
  3. List-wise Problem: 하나의 학습 샘플이 2개보다 많은 item의 순서가 있는 list로 구성된다.
    • 계산량이 많고 어려워서 많이 쓰이지는 않는다.
    • NDCG 등으로 평가한다.

NDCG(Normalized Discounted Cumulative Gain)

널리 쓰이는 ranking metric(순위 평가 척도)이다.

먼저 CG를 정의한다.

Cumulative Gain: result list에서 모든 result의 관련도 점수의 합. $rel_i$는 관계가 제일 높으면 1, 그 반대면 0이라 할 수 있다. 물론 0~1 범위가 아니라 실제 점수를 쓸 수도 있다.

[\text{CG}p = \sum^p{i=1}rel_i]

이제 DCG를 정의하자. Discounted CG는 상위에 있는 item을 더 중요하게 친다.

[\text{DCG}p = \sum^p{i=1} \frac{rel_i}{\log_2(i+1)} = rel_1 + \sum^p_{i=2} \frac{rel_i}{\log_2(i+1)}]

Alternative formulation of DCG를 생각할 수도 있다. 이는 관련도상 상위 item을 더 강조한다.

[\text{DCG}p = \sum^p{i=1} \frac{2^{rel_i}-1}{\log_2(i+1)}]

그런데 DCG는 개수가 많거나 관련도 점수가 높으면 한없이 커질 수 있다. 그래서 정규화가 필요하다.

이제 Normalized DCG를 정의하자. 그러러면 먼저 Ideal DCG를 생각해야 한다. IDCG는 위의 식으로 최상의 결과를 출력했을 때 얻을 수 있는 점수라 생각하면 된다.

[\text{DCG}p = \sum^{REL_p}{i=1} \frac{rel_i}{\log_2(i+1)}]

이제 NDCG는 다음과 같다.

[\text{NDCG}_p = \frac{DCG_p}{IDCG_p}]


Triplet Loss

3개의 point 간 거리를 갖고 loss를 구하는 방식이다.

  • 기준점을 Anchor,
  • anchor와 관련이 높은 point를 Positive,
  • 관련이 없거나 먼 point를 Negative

라 하면,

loss function은 유클리드 거리 함수로 생각할 수 있다.

[\mathcal{L}(A, P, N) = max(\Vert \text{f}(A) - \text{f}(P) \Vert^2 - \Vert \text{f}(A) - \text{f}(N) \Vert^2 + \alpha, 0)]

$N$개의 anchor point로 확장하면 다음과 같이 쓸 수 있다. (사실 같은 식이라 다름없다)

[\mathcal{L} = \sum^N_i [ max(\Vert \text{f}(x_i^a) - \text{f}(x_i^p) \Vert^2 - \Vert \text{f}(x_i^a) - \text{f}(x_i^n) \Vert^2 + \alpha, 0)]]

$\alpha$는 margin을 나타낸다.

Training(Dataset)

학습할 때 anchor과 positive, negative를 잘 설정해야 한다.

랜덤으로 point를 뽑으면 잘 되지 않는 경우가 많다.

여기서 사용할 수 있는 방법으로 Online Negative Mining이 있다.

이는 batch를 크게 잡아서, 현재 sample인 anchor-positive-negative에서 다른 sample의 anchor/positive/negative 중 anchor와 positive 관계인 것을 제외하고 가장 가까운 것부터 선택하여 negative를 대체하여 계산할 수 있다. 이러면 학습이 조금 더 되지만, batch size가 커야 되고, 계산량이 매우 많아지는 단점이 존재한다.

학습을 더 잘 하기 위해서 생각해야 할 방법은 semi-hard negative mining이다.

A와 P는 고정이라 할 때, 아래 3개의 Negative 중 어느 것을 선택해야 학습이 잘 될까라는 문제이다.

  • Hard Negative: $d(a, n_1) < d(a, p)$
    • 언뜻 보면 학습이 잘 될 것 같지만, loss를 줄이는 방법은 그냥 모든 $x$에 대해 $f(x)=0$으로 만들어 버리는 것이라(collapsing) 학습이 잘 안 된다.
  • Semi-hard Negative: $d(a, p) < d(a, n_2) < d(a, p) + \alpha$
    • 이 경우에는 위의 collapsing problem이 발생하지 않으면서 N을 P 밖으로 밀어내면서 학습이 제일 잘 된다.
  • Easy Negative: $d(a, p) + \alpha < d(a, n_3)$
    • N3는 이미 $d(a, p) + \alpha$보다 멀리 있기 때문에 negative로 잡아도 학습할 수 있는 것이 없다.

따라서 semi-hard negative로 sample을 잡아 학습하는 것이 제일 좋다고 한다.


FaceNet: A Unified Embedding for Face Recognition and Clustering

논문 링크: FaceNet: A Unified Embedding for Face Recognition and Clustering

Metric Learning을 사용해서 여러 사진이 있을 때 같은 사람의 사진끼리 모으는 task를 수행했다.

Comment  Read more

Attention based Video Models

|

이 글에서는 Attention 기반 Video (Classification) Model을 간략히 소개한다.


Multi-LSTM

논문 링크: Every Moment Counts: Dense Detailed Labeling of Actions in Complex Videos

LRCN과 비슷하다.

다른 점은,

  • Multiple Input: LSTM에 입력이 1개의 frame이 아니라 N개의 최근 frame에 대해 attention을 적용한다.
    • Query: LSTM의 이전 hidden state $h_{i-1}$
    • Key=value: $N$개의 input frame features
    • Attention value: $N$개의 frame freature의 가중합
  • Multiple Output: 각 LSTM cell은 $N$개의 최근 frame에 대한 예측결과를 출력한다.

Action Recognition using Visual Attention

논문 링크: Action Recognition using Visual Attention

  • LSTM의 이전 hidden state(=query)와 입력 이미지의 region feature(7 x 7 x 1024)를 49개의 candidate로 보고 spatial attention을 수행한다. 이를 통해 attention value(1024차원)를 얻는다.
    • Query: LSTM의 직전 hidden state
    • Key=Value: 입력 이미지 $X_t$의 $K \times K$의 region feature
    • Attention Value: region feature의 가중합. Weight: $h_{t-1}$

이 모델의 장점은 Interpretability가 좋다(spatial attention에 의함). 정답을 맞췄을 때 어떤 부분을 보고 맞추었는지, 혹은 반대로 틀렸을때 어디를 보고 틀렸는지를 볼 수 있다(spatial attention의 의미를 생각해보면 알 수 있다).

Comment  Read more

RNN based Video Models

|

이 글에서는 RNN 기반 Video Classification Model을 간략히 소개한다.

RNN 기반 모델은 single embedding으로 전체 seq를 인코딩하므로 정보의 손실이 입력이 길어질수록 커진다. 또 처음에 들어온 입력에 대한 정보는 점점 잊혀지는 단점이 있다.

이는 Attention Is All You Need 이후 발표되는 Attention 기반 모델에서 개선된다.


LRCN: Long-term Recurrent Convolutional Network

논문 링크: Long-term Recurrent Convolutional Networks for Visual Recognition and Description

비디오의 일부 frame을 뽑아 CNN을 통과시킨 뒤 LSTM에 넣고 그 결과를 평균하여 최종 출력을 얻는 상당히 straight-forward 한 방법론을 제시한다.

모든 input frame마다 LSTM cell이 activity class를 예측한다.


Beyond Short Snippets

논문 링크: Beyond Short Snippets: Deep Networks for Video Classification

이전에는 짧은 video snippet(16 frame 등)만을 사용했는데 이 논문에서 (거의?) 처음으로 긴 영상(300 frame)을 다룰 수 있게 되었다.

Pooling 방법을 여러 개 시도하였는데,

  • Conv Pooling: 사실상 max와 갈다.
  • Late Pooling: max 전에 FC layer를 추가했다.
  • Slow Pooling: FC를 사이에 추가하면서 max를 계층적으로 취한다.
  • Local Pooling: max를 지역적으로만 사용하고 softmax 전에 이어 붙인다.
  • Time-domain conv Pooling: 1x1를 max pooling 전에 사용한다.

근데 결과는 Conv Pooling(max pool)이 가장 좋았다고 한다..

LSTM 갖고도 실험을 해보았는데, Multi-layer LSTM을 사용하였다.

frame-level prediction을 aggregate하기 위해 여러 방식을 사용하였다.

  • Max, average, weighted, …

물론 결과는 비슷하다고 한다.


FC-LSTM

논문 링크: Generating Sequences With Recurrent Neural Networks

LSTM cell에서, 각 gate 계산 시 short term에 더해 long term 부분을 집어넣어 계산하였다.


Conv-LSTM

논문 링크: Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting

데이터의 Spatio-Temporal dynamics를 학습하기 위해 FC-LSTM을 확장한 모델이다.

변경점은,

  1. Input $x$와 hidden $(h, c)$가 vector에서 matrix(Tensor)로 확장되었다.
  2. Weights $(W, U)$를 $(x, h)$와 연산할 때 FC 대신 Conv 연산으로 바꾼다.

이로써 장점이 생기는데,

  • Multi-layer LSTM과 비슷하게 ConvLSTM layer를 여러 층 쌓을 수 있다.
  • seq2seq 구조에 적용하면 decoder는 이후 출력을 예측할 수 있다.

Conv-GRU

논문 링크: Delving Deeper into Convolutional Networks for Learning Video Representations

GRU에다가 같은 아이디어를 적용한 논문이다.


Comment  Read more

BST(Behavior Sequence Transformer for E-commerce Recommendation in Alibaba) 설명

|

이번 글에서는 Alibaba에서 발표한 추천시스템 알고리즘인 BST에 대해 다뤄보도록 하겠습니다. 논문 원본은 이 곳에서 확인할 수 있습니다. 본 글에서는 핵심적인 부분에 대해서만 살펴보겠습니다.


Behavior Sequence Transformer for E-commerce Recommendation in Alibaba 설명

1. Background

Alibaba는 Taobao라고 하는 거대한 쇼핑몰을 갖고 있는데, 본 논문은 이 Taobao에서 발생한 로그의 일부를 효과적으로 활용하여 고객들에게 더 나은 구매경험을 제공하기 위한 방법에 대해 설명하고 있습니다.

예를 들어 어떤 고객이 순서를 갖고 여러 아이템을 클릭했다고 하면 분명 노이즈는 존재하긴 하겠습니다만 이러한 고객의 행동 시퀀스에는 구매행동에 관한 시그널이 존재할 가능성이 높습니다.

Wide & Deep이나 Deep Interest Network에서도 이러한 문제를 풀기 위해 방법을 제시하였지만 여러 한계점이 존재합니다.

본 논문에서는 BST라고 하는 알고리즘을 제시하고 있고, 이 방법론은 기존의 한계를 극복하기 위해 고객의 행동 시퀀스 속에 있는 시그널을 효과적으로 포착/통합하기 위해 설계되었습니다.


2. Architecture

Bert4Rec이나 Transformers4Rec에 비해 복잡한 편은 아니지만, 구조를 잘 들여다보면 생각해볼 부분이 꽤 있습니다. 고객의 행동 시퀀스를 $S(u) = { v_1, v_2, …, v_n }$ 이라고 표기해보겠습니다. 이 때 고객 $u$ 는 총 $n$ 개의 아이템에 대해 클릭을 행했다고 가정하는 것입니다.

고객의 특성, 고객이 만든 feedback(click) sequence, 타겟 아이템의 특성 등을 모델링에서 활용할 수 있을 것입니다. 이러한 특성을 활용하는 데에는 여러가지 방법을 취할 수 있습니다. 이제 상세히 설명하겠지만 BST에서는 각 아이템 사이에서 일어나는 상호작용을 Transformer Layer를 통해 포착하고, 이 결과물을 user feature과 함께 연결하여 최종적으로 downstream task를 푸는 방식으로 추천/예측 알고리즘을 전개해나가고 있습니다.

2.1. Embedding Layer

Embedding Layer는 2개로 구분되어 있습니다.

1번째: other features의 embedding layer
2번째: sequence item features의 embedding layer

이 때 sequence item featuers는 고객이 반응한 아이템의 시퀀스를 의미하고, 논문에서는 오직 아이템의 ID와 카테고리만을 사용하고 있습니다. 만약 본인이 해결해야 할 문제에서 아이템의 정보가 부족하고 오직 ID 정도만을 사용할 수 있더라도 이와 같이 접근할 수 있다고 생각하면 될 것 같습니다.

other features로는 user profile features, target item features, context features 등이 있고 쉽게 말해서 위 sequence item features를 제외한 모든 것이라고 보면 됩니다.

항목 표기 Shape
sequence item feature embedding matrix $\mathbf{W}_v$ $(\vert V \vert, d_v)$
other feature embedding matrix $\mathbf{W}_o$ $(\vert D \vert, d_o)$

이 때 $\vert V \vert$ 는 아이템의 수를, $d_o, d_v$는 각각 세팅한 임베딩 dimension을 의미합니다.
$\vert D \vert$ 는 other features의 feature 수 입니다. 그런데 논문에 나온 것처럼 임베딩을 하기 위해서는 사실 모든 other features는 categorical variable이어야 합니다. 따라서 원-핫 인코딩 및 구간화를 통해 모든 변수를 categorical하게 변환하는 과정을 거치거나 아니면 이 대신 간단하게 projection layer를 사용할 수도 있을 것입니다. task의 성격이나 성능 등을 고려하여 설계를 해야될 것으로 보입니다.

sequence item features는 positional embedding 과정을 거치게 됩니다. BST에서는 vanilla transformer에 등장하였던 sin, cos 함수가 아닌 아래와 같은 식으로 시간 순서를 반영합니다.

[pos(v_i) = t(v_t) - t(v_i)]

논문에서는 위와 같이 추천 시간에서 해당 아이템이 클릭된 시간을 뺀 것으로 정의되는데, 이는 Alibaba의 연구 상에서는 이 방법이 더 뛰어났기 때문이라고 합니다. 이 부분은 본인의 task에 맞게 선택적으로 수용하면 될 것 같습니다.

2.2. Transformer Layer

Transformer Layer는 고객의 행동 시퀀스를 통해 여러 아이템 사이의 관계를 포착하는 역할을 수행합니다.

1번째: self-attention layer
2번째: point-wise feed-forward network

먼저 self-attention layerAttention is all you need를 비롯하여 수 많은 논문에서 등장하는 그 형태 그대로입니다.

[Attention(\mathbf{Q, K, V}) = softmax(\frac{\mathbf{QK}^T}{\sqrt{d}}\mathbf{V})]

이 연산은 아이템 임베딩을 input으로 취하고 이를 linear projection을 통해 3개의 행렬로 변환한 뒤 attention layer로 투입하는 역할을 수행합니다.

[\mathbf{S} = MH(\mathbf{E}) = concat(head_1, head_2, …, head_h) \mathbf{W}^H]

[head_i = Attention(\mathbf{EW}^Q, \mathbf{EW}^K, \mathbf{EW}^V)]

이 때 $\mathbf{W}$ 들은 모두 $(d, d)$ 의 shape을 취하고 있습니다.

point-wise feed-forward network는 이후에 비선형성을 강화하는 역할을 수행합니다.

[\grave{\mathbf{S}} = LayerNorm(\mathbf{S} + Dropout(MH(\mathbf{S})))]

[\mathbf{F} = LayerNorm(\grave{\mathbf{S}} + Dropout(LeakyReLU(\grave{\mathbf{S}} \mathbf{W^{1} + b^{1}})\mathbf{W}^{2} + b^2 ))]

여기까지가 이후에 나올 MLP layer 이전의 형태입니다.

2.3. MLP layers and Loss Function

이후 과정은 간단합니다. other features의 결과물과 transformer layer의 결과물을 모두 concat한 뒤 몇 개의 fully connected layer를 거치면 최종 output을 반환하게 됩니다. 손실 함수로는 cross-entropy를 사용하였다고 밝히고 있습니다.

여기까지가 BST의 구조인데, 본인의 task에 맞게 세부 구성을 수정할 수 있습니다. 예를 들어 MLP layer의 input을 결정할 때 논문에서처럼 모두 concat하는 방식을 취할 수도 있지만, SRGNN에서처럼 soft attention을 이용하여 session embedding vector를 만들 수도 있을 것입니다.

[\alpha_i = \mathbf{q}^T \sigma ( \mathbf{W}_s \mathbf{v}_i + \mathbf{c} )]

[\mathbf{s}g = \Sigma{i=1}^n \alpha_i \mathbf{v}_i]

결국은 아이템 간의 상호작용을 어떻게 포착할 것인가, 그리고 그 과정이 끝난 아이템 벡터들을 어떻게 통합할 것인가는 연구자/분석가가 편의에 맞게 설정할 수 있는 것입니다.


3. Experiments and Conclusions

논문에서의 실험은 Taobao 앱 데이터를 통해 이루어 졌습니다. 7일치를 학습데이터로 사용하고 마지막 1일을 테스트 데이터로 사용하였다고 합니다. AUC, CTR, Response Time 등을 측정하였고 Adagrad를 통해 gradient descent를 수행하였습니다.

BST가 비교 대상인 WDL, DIN을 outperform하였는데, 이는 Transformer Layer를 통해 고객의 행동 시퀀스에 내재되어 있는 sequential signal을 더욱 잘 포착했기 때문이라고 논문은 서술하고 있습니다.

BST의 경우 transformer를 이용한 다른 알고리즘과 마찬가지로 시퀀스에 유의미한 정보가 포함되어 있을 경우 그 효과를 발휘할 수 있을 것으로 기대됩니다. 또한 user feature와 같은 보조적인 정보도 충분히 활용할 수 있다는 장점을 갖습니다.

Comment  Read more