쉬어가기: Back-translation을 재해석하기

이번 절에서는 앞서 이야기했던 back-translation을 듀얼리티 관점에서 수식으로 해석해보 겠습니다. 기존의 back-translation이 추상적 관점에서 왜 잘 동작하는지 이야기하고 넘어갔다면, 이번에는 수식의 해석을 통해 그 이유를 파악할 수 있습니다. 이와 관련해 논문[62]에서는 기존의 back-translation에 대한 해석을 듀얼리티 관점에서 수식으로 풀어냈습니다. 한번 살 펴보겠습니다.

먼저, 다음과 같이 N개의 소스 문장 $x$ , 타깃 문장 $y$ 으로 이루어진 양방향 병렬 코퍼스 $\mathcal{B}$ 와 S개의 타깃 문장 $y$ 로만 이루어진 단일 언어 코퍼스 $\mathcal{M}$ 이 있다고 가정합니다.

B={(xn,yn)}n=1NM={ys}s=1S\begin{aligned} \mathcal{B}&=\{(x^n, y^n)\}^N_{n=1} \\ \mathcal{M}&=\{y^s\}^S_{s=1} \end{aligned}

앞서 다룬 DUL과 같이, 우리가 최종적으로 최소화하려는 손실 함수는 다음과 같이 표현할 수 있습니다.

L(θ)=n=1NlogP(ynxn;θ)s=1SlogP(ys)\mathcal{L}(\theta)=-\sum_{n=1}^N{\log{P(y^{n}|x^{n};\theta)}}-\sum_{s=1}^S{\log{P(y^s)}}

DUL과 마찬가지로 $P(y)$ 는 주변 분포의 속성을 활용해 표현할 수 있을 겁니다. 다만 여기서는 좌변이 $P(y)$ 가 아닌 $\log{P(y)}$ 임을 주목해주세요.

logP(y)=logxXP(yx)P(x)=logxXP(xy)P(yx)P(x)P(xy)xXP(xy)logP(yx)P(x)P(xy)=ExP(xy)[logP(yx)+logP(x)P(xy)]=ExP(xy)[logP(yx)]+ExP(xy)[logP(x)P(xy)]=ExP(xy)[logP(yx)]KL(P(xy)P(x))\begin{aligned} \log{P(y)}&=\log{\sum_{x\in\mathcal{X}}{P(y|x)P(x)}} \\ &=\log{\sum_{x\in\mathcal{X}}{P(x|y)\frac{P(y|x)P(x)}{P(x|y)}}} \\ &\ge\sum_{x\in\mathcal{X}}{P(x|y)\log{\frac{P(y|x)P(x)}{P(x|y)}}} \\ &=\mathbb{E}_{\text{x}\sim{P(\text{x}|y)}}\Big[\log{P(y|\text{x})}+\log{\frac{P(\text{x})}{P(\text{x}|y)}}\Big] \\ &=\mathbb{E}_{\text{x}\sim{P(\text{x}|y)}}\Big[\log{P(y|\text{x})}\Big]+\mathbb{E}_{\text{x}\sim{P(\text{x}|y)}}\Big[\log{\frac{P(\text{x})}{P(\text{x}|y)}}\Big] \\ &=\mathbb{E}_{\text{x}\sim{P(\text{x}|y)}}\Big[\log{P(y|\text{x})}\Big]-\text{KL}\big(P(\text{x}|y)||P(\text{x})\big) \end{aligned}

우리는 젠센 부등식Jensen’s inequality 정리를 사용해 항상 $\log{P(y)}$ 보다 작거나 같은 수식으로 정리할 수 있습니다. 젠센 부등식에 대해 더 이야기해보겠습니다. 로그의 함수 곡선은 다음과 같이 생겼습니다. 이때 두 점 $x_1,x_2$ 에 대한 평균을 $x_m=1/2\times(x_1+x_2)$ 라고 하겠습니다. 그럼 $\log{x_m}\ge1/2\times(\log{x_1}+\log{x_2})$ 은 항상 성립하는 것을 알 수 있습니다.

젠슨스 부등식의 예

이 성질을 이용하여 우리는 $\log{P(y)}$ 보다 항상 작거나 같음을 알 수 있습니다. 여기에 음의 부호 를 붙여주면 부등호의 방향은 바뀔 것입니다.

logP(y)ExP(xy)[logP(yx)]+KL(P(xy)P(x))-\log{P(y)}\le-\mathbb{E}_{\text{x}\sim{P(\text{x}|y)}}\Big[\log{P(y|\text{x})}\Big]+\text{KL}\big(P(\text{x}|y)||P(\text{x})\big)

우리는 부등호가 항상 성립함을 확인했으므로, $-\log{P(y)}$ 보다 항상 큰 수식을 최소화하는 것은 마찬가지로 $-\log{P(y)}$ 도 최소화하는 것임을 알 수 있습니다. 그럼 조금 전 최소화하려 했던 $\mathcal{L}(\theta)$ 에 이 수식을 대입해보면 다음과 같이 정리할 수 있습니다.

L(θ)n=1NlogP(ynxn;θ)s=1S(ExP(xys)[logP(ysx;θ)]KL(P(xys)P(x)))n=1NlogP(ynxn;θ)1Ks=1Si=1KlogP(ysxi;θ)+s=1SKL(P(xys)P(x))=L~(θ)\begin{aligned} \mathcal{L}(\theta)&\le-\sum_{n=1}^N{\log{P(y^{n}|x^{n};\theta)}}-\sum_{s=1}^S{\Big(\mathbb{E}_{\text{x}\sim{P(\text{x}|y^s)}}\big[\log{P(y^s|\text{x};\theta)}\big]-\text{KL}\big(P(\text{x}|y^s)||P(\text{x})\big)\Big)} \\ &\approx-\sum_{n=1}^N{\log{P(y^{n}|x^{n};\theta)}}-\frac{1}{K}\sum_{s=1}^S{\sum_{i=1}^K{\log{P(y^s|x_i;\theta)}}}+\sum_{s=1}^S{\text{KL}\big(P(\text{x}|y^s)||P(\text{x})\big)} \\ &=\tilde{\mathcal{L}}(\theta) \end{aligned}

우리는 결국 $\tilde{\mathcal{L}}(\theta)$ 를 다시 정의했습니다. 앞의 $-\log{P(y)}$ 에 대한 부등호화 마찬가지로 $\tilde{\mathcal{L}}(\theta)$ 를 최소화하는 것은 $\mathcal{L}(\theta)$ 를 최소화하는 효과를 만들 수 있습니다. 따라서 $\tilde{\mathcal{L}}(\theta)$ 를 최소화하기 위해 경사하강법을 수행하여 최적화를 수행해야 합니다.

이 새로운 손실 함수 $\theta$ 에 대해 미분하면 다음과 같이 될 것입니다. 쿨백-라이블러 발산(KLD) 부분은 $\theta$ 에 대해 상수이므로 생략될 것입니다.

θL~(θ)=n=1NθlogP(ynxn;θ)1Ks=1Si=1KθlogP(ysxi;θ)\nabla_\theta\tilde{\mathcal{L}}(\theta)=-\sum_{n=1}^N{\nabla_\theta\log{P(y^n|x^n;\theta)}}-\frac{1}{K}\sum_{s=1}^S{\sum_{i=1}^K{\nabla_\theta\log{P(y^s|x_i;\theta)}}}

최종적으로 얻게 된 손실 함수의 각 부분의 의미를 살펴보겠습니다. 첫 번째 항 $\sum{n=1}^N{\nabla\theta\log{P(y^n|x^n;\theta)}}$ 은 당연히 $x^n$ 이 주어졌을 때, $y^n$ 의 확률을 최대로 하는 $\theta$ 를 찾는 것을 가리킵니다. 두 번째 항 $\frac{1}{K}\sum{s=1}^S{\sum{i=1}^K{\nabla_\theta\log{P(y^s|x_i;\theta)}}}$ 의 경우에는 샘플링된 문장 $x^i$ 이 주어졌을 때, 단일 언어 코퍼스의 문장 $y^s$ 가 나올 평균 확률을 최대로 하는 $\theta$ 를 찾는 것이 됩니다. 결국 back-translation을 통해 수행하던 것이 $\tilde{\mathcal{L}}(\theta)$ 를 최소화하는 것임을 알 수 있습니다.