Transformer-XL?
Transformer-XL은 기존 Transformer가 갖고 있는 문제를 해결하기 위해 나온 모델이다. XL은 extra Long의 의미로, fixed-length의 문제를 너머 개선된 Transformer라는 의미다.
Transfomer의 핵심은 Self-Attention으로 Query에 대한 단일 segment(ex. 문장) 내의 관계가 함축되어 있는 embedding을 추출한다는 것이다.
즉, Transformer의 embedding vector(encoded)는 Represention Layer로 생각할 수도 있으며, 이를 가장 잘 활용한 모델이 BERT(s)다
BERT는 Transformer를 block으로 활용하여 각 Represention layer(Transformer의 결과들)을 bi-directional하게 조합하여 엄청난 결과를 만들어냈다.
그러면 Transformer는 어떤 문제가 있기에 Transformer-XL이 나온 것일까?
Why Transformer-XL?
Transformer는 Query에 대한 단일 segment 내의 관계를 함축하는 모델(모듈)이다.
위는 두 가지 문제점을 갖게 된다.
-
Fixed-length(고정된 길이)를 너머의 longer-term dependency를 포착할 수 없다.
-
다른 segment에 대한 관계를 포착할 수 없다는 문제를 갖게 된다. (lack of contextual information),
위에 대한 문제를 Context Fragmentation이라고 부른다.
이를 해결하기 위해 나온 모델이 Transformer-XL이며,
Transformer-XL은 Transformer에 두 가지 구성을 추가하게 된다.
- Segment-level Recurrence Mechanism
- Positional Encoding Scheme
1. Longer-term dependency problem
위 문제는 초기 Language Model들이 갖고 있던 문제다.
Transformer가 RNN 모델과 차별점을 갖게 되는 문제이며, 이는 결국 gradient 문제다.
아래 링크에서 gredient 문제를 잘 설명하고 있으니 자세한 부분은 생략.
RNN gredient vanishing
중요한 것은 Transformer는 위 문제를 Attention 계산법에서 해결했다는 것이다.
수식에서 표현된 것과 같이 각 Query의 row(token/input)들은 독립적으로 K와 V의 계산에 걸려있으므로,
gradient 계산에는 이전 row는 다음 row에 영향을 미치지 않게 된다.
그러나, 중요한 것은 Transformer는 각 segment에 대해 attention value를 구하기 때문에,
pre-defined context length 너머의 longer-term dependency를 포착할 수 없는 한계를 갖게 된다.
그림에서 본 것과 같이 다음 문장(Query)은 Attention Value 계산에 들어가지 않기 때문이다.
2. Context Fragmentation
위 또한 고정된 길이의 Query에 대해 Attention Value를 계산하기 때문에 문제가 발생한다.
당연히 각 Query에 대한 관계를 포착할 수 있겠지만, Transformer는 contextual information이 부족할 수밖에 없고, 이는 contextual 한 task들은 성능에서 제약이 발생할 수밖에 없다.
위 논문에서는 이에 대한 문제를 이해하기 쉽게 도식화하였다. 다음 segment에 대해 information flow가 이어지지 않고, 단일 segment에서 정보가 끊기는 것을 볼 수 있다.
Transformer-XL’s features
1. Segment-level Recurrence Mechanism
그러면 Transformer-XL은 어떻게 위 문제들을 해결할 수 있을까?
첫 번째는 Recurrence Mechanism을 추가한 것이다.
Recurrence Mechanism은
위 식으로 표현될 수 있는데, 여기서 중요한 부분은 이전 segment의 hidden states는 cached 된다는 것이다.
Cache?
cached된 hidden states는 이후 hidden states 계산에 사용되는데, recurrent한 식을 보면 알 수 있듯이 단순히 연속된 segments에 대한 관계뿐만 아니라 전체 segments에 대한 관계/의존관계를 파악할 수 있다.
또한, cache는 evaluation phase에서 계산될 때 재사용되기 때문에 계산은 더 빠를 수밖에 없다.
추가적으로 결국 transformer의 embedding결과는 이전 hidden states를 계산에서 사용하게 되는데, 이는 Recurrence가 같은 layer뿐만 아니라 layer들끼리에서도 이뤄지게 된다는 것을 알 수 있다.
SG는 stop-gradient로 이전 segment는 이후 segment gredient 계산에 영향을 미치지 않게 된다. (단순히 이전 segment를 참조한다는 것으로 이해할 수 있다)
2. Relative Positional Encodings
positonal Encoding을 할 수 있는 방법 중 단순한 방법은 absolute하게 처리하는 것이다.
그러나, 이는 각 segment에 대한 positional한 차이를 구분할 수 없게 되는 문제점을 안게 된다.
이에 Transformer-XL은 다음과 같은 relative positional encoding을 도입하게 된다.
위 식을 self-attention with relative positional encoding에서 처리한 positional encoding처럼 생각할 수 있지만,
transformer-XL에서의 positional encoding은 sinusoid한 positional encoding을 도입하므로써 일반화가능한 term을 구하게 된다. 그리고 해당 relative counterpart부분은 trainable 하지 않다는 점이다. 단순히 position에 대한 정보만 제공하고 있고, trainable한 부분은 임베딩, weights, v term으로 구성된다. 또한 relative한 term은 recurrent 계산법에 따라 absolute postion을 구할 수 있게 된다.
위 두 가지 feature가 결합된 결과의 식은 위 모형이다.
위 식에서 masked 부분은 transformer-XL에서도 optional하게 처리되는 부분이다 (higgingface source code 참고)
또한 많은 사람들이 착각하는 부분은 masked softmax와 transformer decoder가 차이가 있다 생각하지만,
위 둘 다 sublayer에서 처리하는 것으로 같은 것이다.
relative positional encoding은 계산비용이 크지 않나?
결론적으로는 아니다. 이에 대한 부분은 나중 post를 통해서 소개할 수도 있지만, 결론적으론 matrix decomposition과 matrix approximation을 이용하여 left-shifted matrix를 구하게 되기 때문에 계산 복잡도는 O(n)이다
repo="ghk829" issue-term="pathname" theme="github-dark" crossorigin="anonymous" async>