2018년 9월 1일 토요일

Spectral Normalization - part.3

이 정리는 Spectral Normalization 논문에 관한 정리입니다. 이번 포스팅에서는 Lipschitz constant의 정의와 의미 그리고 Spectral Normalization이 무엇이고 이것이 어떻게 Lipschitz constant를 control하게 되는지에 대해 알아보겠습니다.

* 아래 포스팅에 등장하는 수식들은 번호 혹은 영어로 태그가 붙어 있습니다. 번호로 태그 된 수식은 논문에서 해당 번호로 똑같이 태그 되어 있는 수식들이고 영어로 태그된 수식은 논문에는 등장은 하지만 따로 태그가 붙어 있지 않은 수식들입니다.

Gradient analysis of the spectrally normalized weights

이번에는 gradient $\partial V(D, G) \over \partial W$를 따져봄으로써 spectral normalization을 통해 $W$가 어떻게 학습이 되는지를 자세하게 알아보도록 하겠습니다.

우선 $D$의 임의의 hidden layer가 있고 이 layer의 input $h$가 있다고 가정하겠습니다. 이 $h$는 weight $\bar{W}_{SN}(W)$와 곱해져 output으로 $\bar{W}_{SN}(W) h$가 나올 것입니다. 그리고 $\bar{W}_{SN}(W) = \frac{W}{\sigma (W)}$입니다. $\bar{W}_{SN}(W)$는 이제 간단하게 $\bar{W}_{SN}$으로 표기하도록 하겠습니다.

이제 궁극적으로 밝히고자 하는 것에 대해 먼저 명확하게 하고 넘어가겠습니다. 분석하고자 하는 것은 $\partial V(D, G) \over \partial W$ 입니다. 왜냐하면 이 값으로 $D$를 학습시키기 때문입니다. 중간에 있는 hidden layer $h$를 고려하여 이 식을 풀어 써보도록 하겠습니다. (행렬을 곱할 때 행렬의 크기를 맞춰야 하는 문제는 고려하지 않고 그냥 Scalar처럼 쓴 식이라 transpose나 곱의 순서 등이 정확하지 않을 수 있습니다 ㅜㅜ)

$\begin{align}
{{\partial {V(D, G)}} \over {\partial W}}
&= {{\partial {V(D,G)}} \over {\partial {\bar{W}_{SN} h}}} {{\partial {\bar{W}_{SN} h}} \over {\partial {\bar{W}_{SN}}}} {{\partial{\bar{W}_{SN}}} \over {\partial{W}}}\tag{a} \\
&= {{\partial {V(D,G)}} \over {\partial {\bar{W}_{SN} h}}} h {{\partial {\bar{W}_{SN}}} \over {\partial {W}}} \tag{b} \\
\end{align}$

식 (a)는 $h$를 포함해 chain rule로 나타낸 식이고 식 (b)는 $h$는 $W$와 관련이 없으니 미분 밖으로 빠져 나온 것입니다. 우선 ${{\partial {V(D,G)}} \over {\partial {\bar{W}_{SN} h}}}$는 $V(D, G)$로부터 $h$까지 타고 내려온 gradient입니다. 이 부분은 Spectral normalization과 상관 없이 back propagation을 사용하면 무조건 생기는 upstream gradient term이므로 간단히 $\delta = {{\partial {V(D,G)}} \over {\partial {\bar{W}_{SN} h}}}$라고 하겠습니다. 이렇게 하여 식 (b)를 다시 쓰면 아래와 같이 쓸 수 있습니다.

$\begin{align}
{{\partial {V(D, G)}} \over {\partial W}} = {\delta} h {{\partial {\bar{W}_{SN}}} \over {\partial {W}}} \tag{c}
\end{align}$

이제 여기서 $\delta$와 $h$는 크게 의미가 없습니다. 중요한 부분은 ${{\partial {\bar{W}_{SN}}} \over {\partial {W}}}$ 입니다. 이 term이 논문의 식 (9)와 식 (10)에 설명 되어 있습니다.

$\begin{align}
{{\partial{\bar{W}_{SN}(W)}} \over {\partial{W_{ij}}}}
&= \frac{1}{\sigma(W)} E_{ij} - \frac{1}{\sigma(W)^2} {{\partial{\sigma (W)}} \over {\partial{W_{ij}}}} W \tag{d} \\
&= \frac{1}{\sigma(W)} E_{ij} - \frac{[{u_1} {v_1^T}]_{ij}}{\sigma(W)^2} W \tag{9} \\
&= \frac{1}{\sigma(W)} E_{ij} - \frac{1}{\sigma(W)} [{u_1} {v_1^T}]_{ij} \frac{W}{\sigma(W)} \tag{e} \\
&= \frac{1}{\sigma(W)} (E_{ij} - [{u_1} {v_1^T}]_{ij} \bar{W}_{SN}) \tag{10} \\
\end{align}$

식 (e)는 논문에는 없지만 이해에 도움이 될 것 같은 식을 한 줄 추가한 것입니다. 먼저 식 (d)는 단순히 $\frac{W}{\sigma(W)}$를 곱의 미분법을 사용하여 $W$로 미분한 식입니다. $E_{ij}$는 $(i, j)$만 1이고 나머지는 전부 0을 갖는 행렬을 의미합니다. 이러한 $E_{ij}$가 나온것은 행렬의 미분 방법 때문입니다. 아래 식은 행렬 $W$를 scalar $W_{ij}$로 미분하는 식입니다.

$\frac{\partial \mathbf{W}}{\partial W_{ij}} =
\begin{bmatrix}
\frac{\partial W_{11}}{\partial W_{ij}} & \frac{\partial W_{12}}{\partial W_{ij}} & \cdots & \frac{\partial W_{1n}}{\partial W_{ij}}\\
\frac{\partial W_{21}}{\partial W_{ij}} & \frac{\partial W_{22}}{\partial W_{ij}} & \cdots & \frac{\partial W_{2n}}{\partial W_{ij}}\\
\vdots & \vdots & \ddots & \vdots\\
\frac{\partial W_{m1}}{\partial W_{ij}} & \frac{\partial W_{m2}}{\partial W_{ij}} & \cdots & \frac{\partial W_{mn}}{\partial W_{ij}} \tag{f} \\
\end{bmatrix}$

식 (f)의 분자에서 $W_{ij}$만이 변수로 취급되고 나머지는 모두 상수로 취급 되므로 미분한 결과가 $(i, j)$만 1이고 나머지는 0인 행렬 $E_{ij}$가 나오는 것입니다.

그리고 식 (9)에서 ${\partial{\sigma(W)} \over \partial{W_{ij}}} = [{u_1} {v_1^T}]_{ij}$ 임을 이용합니다. $u_1$과 $v_1$은 $W$를 Singular value decomposition을 한 $W = U \Sigma V^T$의 first left singular vector와 first right singular vector를 의미합니다. 이 부분에 대한 수학적인 증명이 있는데 그건 생략하도록 하고 그냥 Matrix의 spectral norm의 derivative에 대한 정의처럼 받아들이는 것이 좋을 것 같습니다. $\sigma(W)$가 $W$의 first singular value 임을 고려하면 직관적으로 와 닿기도 합니다.

이제 식 (10)을 식 (c)에 대입하여 ${\partial {V(D, G)}} \over {\partial W}$를 구할 수 있습니다. 그런데 이 부분을 그냥 바로 식으로 계산하는 것은 너무 복잡합니다. 행렬의 곱들로 이루어져 있고 행렬의 미분에 의해 생기는 행렬의 크기 변화 등의 문제 때문에 필요한 것들을 하나하나 맞추려면 해야할 것이 너무 많습니다.

이럴 때 저는 그냥 전부 행렬의 곱이 아니라 scalar의 곱으로 가정을 하고 computational graph를 그려서 확인합니다. 물론 이렇게 하면 완전히 정확한 식이 나오지는 않지만 gradient가 어디서 어떻게 만들어져서 어떻게 흘러가는지 등의 필요한 것들은 전부 얻을 수 있습니다. 아래 그림이 computational graph입니다.


< ${\partial V(D, G)} \over {\partial W}$ 의 computational graph >

위의 gradient는 행렬의 곱이 아니라 scalar의 곱으로 가정하고 chain rule만 적용한 gradient이므로 transpose나 곱셈의 순서등은 정확하지 않습니다.

그리고 위의 그림은 data가 1개일 때를 가정한 그림입니다. 이제 논문에서처럼 mini batch를 사용한다고 가정하겠습니다. 이는 간단하게 mini batch의 data의 모든 gradient들의 평균을 사용하면 됩니다. 이제 ${\partial V(D, G)} \over {\partial W}$은 아래와 같이 쓸 수 있습니다.

$\begin{align}
{{\partial {V(D, G)}} \over {\partial W}}
&= \hat{E}[{{\partial {V(D,G)}} \over {\partial {\bar{W}_{SN} h}}} h {\frac{1}{\sigma (W)}}] - \hat{E}[{{\partial {V(D,G)}} \over {\partial {\bar{W}_{SN} h}}} hW {\frac{1}{{\sigma (W)}^2}} u_1 v_1^T] \tag{g} \\
&= {\frac{1}{\sigma (W)}} \hat{E}[{{\partial {V(D,G)}} \over {\partial {\bar{W}_{SN} h}}} h] - {\frac{1}{\sigma (W)}} \hat{E}[{{\partial {V(D,G)}} \over {\partial {\bar{W}_{SN} h}}} h \frac{W}{\sigma (W)}] u_1 v_1^T \tag{h} \\
&= {\frac{1}{\sigma (W)}} \hat{E}[{{\partial {V(D,G)}} \over {\partial {\bar{W}_{SN} h}}} h] - {\frac{1}{\sigma (W)}} \hat{E}[\delta h \bar{W}_{SN}] u_1 v_1^T \tag{i} \\
&= {\frac{1}{\sigma (W)}} \hat{E}[\delta h] - {\frac{1}{\sigma (W)}} \lambda u_1 v_1^T \tag{j} \\
& where \, \delta = {{\partial {V(D,G)}} \over {\partial {\bar{W}_{SN} h}}}, \, \lambda = \hat{E}[\delta h \bar{W}_{SN}] \nonumber \\
\end{align}$

식 (j)와 논문에 나온 식 (12)를 비교했을 때 곱셈 순서나 transpose만 제외하면 정확히 일치함을 볼 수 있습니다. Spectral normalization의 gradient는 위의 computational graph처럼 만들어지는 것입니다. 순서와 transpose를 맞춘 논문에서 주어진 식 (12)를 보도록 하겠습니다.

${{\partial V(D, G)} \over {\partial W}} = \frac{1}{\sigma (W)} (\hat{E}[\delta h^T] - \lambda u_1 v_1^T) \, \, where \, \delta = ({{\partial {V(D,G)}} \over {\partial {\bar{W}_{SN} h}}})^T, \, \lambda = \hat{E}[\delta^T (\bar{W}_{SN} h)] \tag{12}$

여기서 첫 번 째 term인 $\frac{1}{\sigma(W)} \hat{E}[\delta h^T]$는 위의 computational graph에서 $W$로 흘러들어가는 gradient 중 위 쪽 gradient에 해당합니다. 이는 $\sigma(W)$로 normalize 하지 않았어도 만들어지는 gradient입니다.

그리고 두 번 째 term인 $-\frac{1}{\sigma(W)} \lambda u_1 v_1^T$가 $\sigma (W)$로 $W$를 normalize 했기 때문에 추가적으로 생긴 gradient입니다. 즉, 첫 번 째 term은 spectral normalization과 상관 없이 원래 생기는 gradient이고 Spectral norm으로 인해 생긴 추가적으로 얻은 효과가 바로 두 번 째 term에 해당하는 것입니다.

이 두 번 째 term은 첫 번 째 term에 대한 regularization으로 해석할 수 있습니다. $\lambda$가 positive가 될 때 첫 번 째 term에 대해 regularization이 걸린다고 볼 수 있는데요. 이 $\lambda$가 positive가 될 가능성이 가장 높은 경우는 $\delta^T$와 $\bar{W}_{SN}$이 비슷한 방향을 가리키고 있을 때입니다.  예를 들어 $\delta^T$와 $\bar{W}_{SN}$의 각 요소들의 positive, negative가 전부 똑같다면 $\lambda$는 100% positive가 될 것입니다.

$\delta = {{\partial {V(D,G)}} \over {\partial {\bar{W}_{SN} h}}}$는 $\bar{W}_{SN} h$가 학습되는 방향을 의미합니다. GAN에서 $D$는 $V(D, G)$를 maximize 하는 방향으로 움직이므로 gradient descent가 아니라 gradient ascent로 학습이 되기 때문에 $-{{\partial {V(D,G)}} \over {\partial {\bar{W}_{SN} h}}}$가 아니라 ${{\partial {V(D,G)}} \over {\partial {\bar{W}_{SN} h}}}$로 학습이 됩니다.

이렇게 보면 $\lambda$가 positive가 되는 상황은 $\bar{W}_{SN} h$가 현재 가리키는 방향과 $\bar{W}_{SN} h$이 학습 되는 방향이 같을 때이고 이 때 regularization을 걸게 되는 것입니다. 이렇게 해서 spectral normalization을 통해 $D$의 각 layer가 한 쪽 방향으로 너무 sensitive해지는 것을 막아서 각 layer로 하여금 좀 더 여러 방향을 볼 수 있도록 하는 효과를 갖게 된다고 할 수 있습니다.

여기까지 Spectral normalization의 gradient analysis에 관한 포스팅이었습니다. 굉장히 험난한 과정이었네요ㅎㅎ 이번 포스팅은 여기서 마치도록 하겠습니다.

댓글 3개:

  1. 이해하기 쉽게 잘 해설해 놓으셔서 많은 도움이 되었습니다. 감사합니다. 꾸벅

    답글삭제
  2. 논문 리뷰하는 데에 너무 큰 도움이 되었습니다 감사합니다.

    답글삭제