Auto-Encoding Variational Bayes

Kingmaらがvariational autoencoderを提案した以下の論文のノートを公開します.
D.P. Kingma, M. Welling, “Auto-Encoding Variational Bayes,” ICLR 2014.

生成モデルと推論モデル

生成モデル(generative model)

ここでは,生成モデル(generative model)の考え方について説明します.

$D$次元のデータ$\bm{x}^{(i)} \in\mathbb{R}^D$からなるデータセット${\rm X} = \left\{ \bm{x}^{(i)} \right\}_{i=1}^{N}$があるとし,各データ$\bm{x}^{(i)}$は確率過程(random process)により生成されると考えます.ここでは,以下のような2段階の生成モデルを考えます.

まず,各データ$\bm{x}^{(i)}$は裏に隠れている単純な確率分布から生成されると考えます.この直接観測できない確率分布を事前確率(prior)$p(\bm{z})$と呼び,確率変数$\bm{z}$を潜在変数(latent variable)と呼びます.一般的に生成されるデータは高次元ですが,潜在変数は低次元だと考えます.また,事前確率は単純であると考えますので,ここでは標準正規分布(standard Gaussian)$\mathcal{N}(\bm{z} ; \bm{0}, I)$とします.まず,図に示すように単純な事前確率$p(\bm{z})$からある潜在変数$\bm{z}^{(i)}$がサンプリングされます.

次に,各データ$\bm{x}^{(i)}$は尤度(likelihood)と呼ばれる確率分布$p_{\theta}(\bm{x}|\bm{z})$に従って生成されるとします.この過程は非線形な写像になると考えますので,尤度は複雑な分布になります.詳しくは後述しますが,尤度は非線形写像であるニューラルネットワークと問題に応じた確率分布を組み合わせて表現します.図に示すように,事前確率からサンプリングされたある潜在変数$\bm{z}^{(i)}$から複雑で非線形な確率過程を経て,最終的にデータ$\bm{x}^{(i)}$がサンプリングされます.

推論モデル(inference model)

今,潜在変数$\bm{z}$からデータ$\bm{x}$が生成される過程(generation)を考えましたが,この逆の過程を考えたいと思います.この過程は観測されたデータ$\bm{x}$から裏に隠れている潜在変数$\bm{z}$を推定することに相当しますので,推論(inference)あるいは認識(recognition)と呼ばれます.この確率分布を事後確率(posterior)$p_{\theta}(\bm{z}|\bm{x})$と呼び,事後確率はベイズ則・周辺確率・条件付き確率の関係を順に使うと

\begin{eqnarray}
p_{\theta}(\bm{z}|\bm{x}) &=& \frac{p_{\theta}(\bm{x}|\bm{z}) p(\bm{z})}{p_{\theta}(\bm{x})} \\
&=& \frac{p_{\theta}(\bm{x}|\bm{z}) p(\bm{z})}{\int p_{\theta}(\bm{x, z}) d\bm{z}} \\
&=& \frac{p_{\theta}(\bm{x}|\bm{z}) p(\bm{z})}{\int p_{\theta}(\bm{x|z})p(\bm{z}) d\bm{z}} \label{posterior}
\end{eqnarray}

のように事前確率$p(\bm{z})$と尤度$p_{\theta}(\bm{x}|\bm{z})$によって表すことができます.推論・認識の過程も非線形な写像になると考えられますので,事後確率も複雑な分布になります.図に示すように,あるデータ$\bm{x}^{(i)}$から複雑で非線形な確率過程を経て,ある潜在変数$\bm{z}^{(i)}$が得られます.

tractability

生成モデルと推論モデルは事前確率$p(\bm{z})$,尤度$p_{\theta}(\bm{x}|\bm{z})$,事後確率$p_{\theta}(\bm{z}|\bm{x})$の3種類の確率分布によって表現されました.ここではこれらのtractabilityについて考えます.

tractabilityとは現実的な時間での計算可能性です.それぞれの確率が現実的な時間で計算できる場合にtractable,計算できない場合にintractableと言います.

事前確率$p(\bm{z})$のtractability

まず,事前確率$p(\bm{z})$は標準正規分布でしたので,現実的な時間で計算可能です.したがって事前確率はtractableになります.

尤度$p_{\theta}(\bm{x}|\bm{z})$のtractability

尤度$p_{\theta}(\bm{x}|\bm{z})$は複雑な確率分布ですが,非線形写像であるニューラルネットワークと問題に応じた確率分布によって表現されます.これらも現実的な時間で計算可能ですので,tractableになります.

事後確率$p_{\theta}(\bm{z}|\bm{x})$のtractability

事後確率$p_{\theta}(\bm{z}|\bm{x})$は式(\ref{posterior})のように事前確率$p(\bm{z})$と尤度$p_{\theta}(\bm{x}|\bm{z})$を使用して表せました.事前確率と尤度はtractableですが,分母の全ての$\bm{z}$に関して積分する部分が現実的な時間で計算できませんので,intractableになります.

tractableな事後確率の近似$q_{\phi}(\bm{z}|\bm{x})$の導入

図に示すように,生成モデルの事前確率$p(\bm{z})$,尤度$p_{\theta}(\bm{x}|\bm{z})$はtractableですが,推論・認識モデルの事後確率$p_{\theta}(\bm{z}|\bm{x})$はintractableになり現実的な時間で計算できません.

そこで,事後確率の近似であるtractableな確率分布$q_{\phi}(\bm{z}|\bm{x})$を導入します.詳しくは後述しますが,推論・認識の過程も複雑で非線形な確率過程と考えられますので,ニューラルネットワークと例えば正規分布を組み合わせて表現します(これらは現実的な時間で計算できるのでtractableです).

複雑で非線形な確率過程の表現

推論モデルにおける$q_{\phi}(\bm{z}|\bm{x})$の表現

観測されたデータ$\bm{x}$から潜在変数$\bm{z}$を求める推論の過程は複雑で非線形な確率過程と考えられるため,図のように非線形写像であるニューラルネットワークと例えば正規分布$\mathcal{N}(\bm{z} ; \bm{\mu_{\phi}}, \mathrm{diag}(\bm{\sigma_{\phi}} \odot \bm{\sigma_{\phi}}))$を組み合わせて表現します.

この過程は観測されたデータから潜在的な意味のようなものを取り出す過程と見なせます.データをエンコードするという意味で,推論モデルで使用するニューラルネットワークをエンコーダ(encoder)と呼びます.

また,潜在空間の各軸は互いに独立した意味を持っていると考えますので,共分散が$0$である対角行列$\mathrm{diag}(\bm{\sigma_{\phi}} \odot \bm{\sigma_{\phi}})$を共分散行列とする正規分布を使用します.ここで,$\odot$は2つのベクトルを要素ごとに掛け算する演算子を表します.

生成モデルにおける$p_{\theta}(\bm{x}|\bm{z})$の表現

低次元の潜在変数$\bm{z}$から高次元のデータ$\bm{x}$を生成する過程も複雑で非線形な確率過程と考えられるため,非線形写像であるニューラルネットワークと問題に応じた確率分布を組み合わせて表現します.

この過程はエンコードされたデータをデコードして元に戻す過程であると見なせますので,生成モデルで使用するニューラルネットワークをデコーダ(decoder)と呼びます.

正規分布による表現

確率分布としては,推論モデルと同様に正規分布$\mathcal{N}(\bm{x} ; \bm{\mu_{\theta}}, \mathrm{diag}(\bm{\sigma_{\theta}} \odot \bm{\sigma_{\theta}}))$を用います.

観測されるデータの空間の各軸は独立ではないと考えられますが,高次元の共分散を考慮すると推定するパラメータの数が多くなってしまいますので,ここでは共分散が$0$である対角行列$\mathrm{diag}(\bm{\sigma_{\theta}} \odot \bm{\sigma_{\theta}})$を共分散行列とする正規分布を使用します.

また,尤度$p_{\theta}(\bm{x}|\bm{z})$からデータ$\bm{x}$を生成する際には,サンプリングは行わず,尤度が最も大きい$\bm{x}$を出力すればよいでしょう.正規分布の場合は$\bm{\mu_{\theta}}$が出力されることになります.

ベルヌーイ分布による表現

対象とするデータ$\bm{x}$が2値データであるとき,ベルヌーイ分布(Bernoulli distribution)$f(\bm{x}; \bm{r})$を使用します.

尤度$p_{\theta}(\bm{x}|\bm{z})$からデータ$\bm{x}$を生成する際には,サンプリングは行わず,尤度が大きいほうの$\bm{x}$を出力すればよいでしょう.

誤差関数

誤差関数(error function)

生成モデルのパラメータ$\theta$と推論・認識モデルのパラメータ$\phi$はわかりませんので,観測できるデータセット${\rm X}$から学習により推定することになります.

その際に学習がうまくいっているかどうかを表す基準を決める必要があります.ここでは基準として,事後確率$p_{\theta}(\bm{z}|\bm{x})$とその近似として導入した$q_{\phi}(\bm{z}|\bm{x})$との差異を使用し,この差異がなるべく小さくなる$\theta$と$\phi$を求めたいと思います.確率分布どうしの差異を測る指標としてKL Divergenceを用います.$p_{\theta}(\bm{z}|\bm{x})$に対する$q_{\phi}(\bm{z}|\bm{x})$のKL Divergenceは
\begin{equation}
D_{KL}\left(q_{\phi}(\bm{z}|\bm{x}) || p_{\theta}(\bm{z}|\bm{x})\right) = \mathbb{E}_{\bm{z} \sim q_{\phi}(\bm{z}|\bm{x})} \left[ \log \frac{q_{\phi}(\bm{z}|\bm{x})}{p_{\theta}(\bm{z}|\bm{x})} \right]
\end{equation}
と表せます.ベイズ則を適用し,整理すると
\begin{eqnarray}
&&D_{KL}\left(q_{\phi}(\bm{z}|\bm{x}) || p_{\theta}(\bm{z}|\bm{x})\right) \nonumber \\
&=& \mathbb{E}_{\bm{z} \sim q_{\phi}(\bm{z}|\bm{x})} \left[ \log \frac{q_{\phi}(\bm{z}|\bm{x})p_{\theta}(\bm{x})}{p_{\theta}(\bm{x}|\bm{z})p(\bm{z})} \right] \nonumber \\
&=& \mathbb{E}_{\bm{z} \sim q_{\phi}(\bm{z}|\bm{x})} \left[ \log \frac{1}{p_{\theta}(\bm{x}|\bm{z})} \frac{q_{\phi}(\bm{z}|\bm{x})}{p(\bm{z})} p_{\theta}(\bm{x}) \right] \nonumber \\
&=& \mathbb{E}_{\bm{z} \sim q_{\phi}(\bm{z}|\bm{x})} \left[ – \log p_{\theta}(\bm{x}|\bm{z}) + \log \frac{q_{\phi}(\bm{z}|\bm{x})}{p(\bm{z})} + \log p_{\theta}(\bm{x}) \right] \nonumber \\
&=& -\mathbb{E}_{\bm{z} \sim q_{\phi}(\bm{z}|\bm{x})} \left[ \log p_{\theta}(\bm{x}|\bm{z}) \right]
+ \mathbb{E}_{\bm{z} \sim q_{\phi}(\bm{z}|\bm{x})} \left[ \log \frac{q_{\phi}(\bm{z}|\bm{x})}{p(\bm{z})} \right]
+ \mathbb{E}_{\bm{z} \sim q_{\phi}(\bm{z}|\bm{x})} \left[ \log p_{\theta}(\bm{x}) \right] \nonumber \\
&=& -\mathbb{E}_{\bm{z} \sim q_{\phi}(\bm{z}|\bm{x})} \left[ \log p_{\theta}(\bm{x}|\bm{z}) \right]
+ D_{KL}\left(q_{\phi}(\bm{z}|\bm{x}) || p(\bm{z})\right)
+ \log p_{\theta}(\bm{x}) \label{kl}
\end{eqnarray}
となります.右辺第3項はintractableですが,右辺第1項と第2項はtractableですので,右辺第1項と第2項の和を誤差関数$E(\theta, \phi; \bm{x})$(error function)とし,誤差ががなるべく小さくなる$\theta$と$\phi$を求めることにしたいと思います.

\begin{equation}
E(\theta, \phi; \bm{x}) =
-\mathbb{E}_{\bm{z} \sim q_{\phi}(\bm{z}|\bm{x})} \left[ \log p_{\theta}(\bm{x}|\bm{z}) \right]
+ D_{KL}\left(q_{\phi}(\bm{z}|\bm{x}) || p(\bm{z})\right) \label{error}
\end{equation}

データの対数尤度のLower Bound

誤差関数の別の解釈を考えたいと思います.式(\ref{kl})を書き換えると
\begin{equation}
\log p_{\theta}(\bm{x}) = \mathbb{E}_{\bm{z} \sim q_{\phi}(\bm{z}|\bm{x})} \left[ \log p_{\theta}(\bm{x}|\bm{z}) \right]
– D_{KL}\left(q_{\phi}(\bm{z}|\bm{x}) || p(\bm{z})\right)
+D_{KL}\left(q_{\phi}(\bm{z}|\bm{x}) || p_{\theta}(\bm{z}|\bm{x})\right) \label{log_likelihood}
\end{equation}
となります.右辺第3項はintractableですが,KL Divergenceは$0$以上になりますので,

\begin{equation}
\log p_{\theta}(\bm{x}) \geq \mathbb{E}_{\bm{z} \sim q_{\phi}(\bm{z}|\bm{x})} \left[ \log p_{\theta}(\bm{x}|\bm{z}) \right]
– D_{KL}\left(q_{\phi}(\bm{z}|\bm{x}) || p(\bm{z})\right)
\end{equation}

と表せます.これは,データの対数尤度$\log p_{\theta}(\bm{x})$のlower bound $L(\theta, \phi; \bm{x})$が右辺であることを表しています.

\begin{equation}
L(\theta, \phi; \bm{x}) =
\mathbb{E}_{\bm{z} \sim q_{\phi}(\bm{z}|\bm{x})} \left[ \log p_{\theta}(\bm{x}|\bm{z}) \right]
– D_{KL}\left(q_{\phi}(\bm{z}|\bm{x}) || p(\bm{z})\right) \label{lower}
\end{equation}

したがって,このlower boundがなるべく大きくなる$\theta$と$\phi$を求めればよいことになります.式(\ref{error})と式(\ref{lower})を比較すると,lower boundは誤差関数の符号が逆転したものですので,誤差関数がなるべく小さくなる$\theta$と$\phi$を求めることと等価になります.

復元誤差(reconstruction error)と正則化(regularization)

誤差関数$-\mathbb{E}_{\bm{z} \sim q_{\phi}(\bm{z}|\bm{x})} \left[ \log p_{\theta}(\bm{x}|\bm{z}) \right] + D_{KL}\left(q_{\phi}(\bm{z}|\bm{x}) || p(\bm{z})\right)$の第1項は,あるデータ$\bm{x}$をエンコードし,潜在変数$\bm{z}$を求め,それをデコードしたものが元のデータと同じになる場合に小さな値となることから,復元誤差(reconstruction error)と考えることができます.

また,第2項は事後確率の分布$q_{\phi}(\bm{z}|\bm{x})$が事前確率$p(\bm{z})$の分布と近くなった場合に小さくなります.事後確率の分布には自由度がありますので,これがなるべく単純な分布となるようにする効果があります.したがって第2項は正則化(regularization)の役割を果たしていると考えられます.

復元誤差は実際のデータとエンコード・デコードしたデータとの差ですので,データの空間における誤差になります.一方,正則化項は事後確率と事前確率との差ですので,潜在変数の空間における誤差になります.このように,データ空間と潜在空間の両方で最適化が行われることがわかります.

復元誤差の計算

復元誤差はデータセットの各データ$\bm{x}^{(i)}$に対して求め,平均すればよいです.あるデータ$\bm{x}^{(i)}$に対する復元誤差は,潜在変数$\bm{z}$を$L$個サンプリングし,期待値を求めるとすると
\begin{equation}
-\mathbb{E}_{\bm{z} \sim q_{\phi}(\bm{z}|\bm{x}^{(i)})} \left[ \log p_{\theta}(\bm{x}^{(i)}|\bm{z}) \right]
= -\frac{1}{L} \sum_{l=1}^{L} \log p_{\theta}\left(\bm{x}^{(i)}|\bm{z}^{(i,l)}\right)
\end{equation}
となります.バッチサイズが大きいバッチ学習を行う場合には$L$は小さな値で十分です.論文には,バッチサイズが100程度であれば,$L$は1としてもよいと書かれています.

まとめると,バッチサイズを$M$とし$M$が大きい値である場合,復元誤差は
\begin{equation}
-\frac{1}{M} \sum_{i=1}^{M} \log p_{\theta}\left(\bm{x}^{(i)}|\bm{z}^{(i)}\right)
\end{equation}
と求められます.

ここで,$\bm{z}^{(i)}$は$q_{\phi}(\bm{z}|\bm{x}^{(i)})$からサンプリングされた潜在変数ですが,サンプリングの操作が入ると$\bm{z}$に関する勾配計算ができなくなるので,勾配計算ができるように変更します.

今,$\bm{z}$が平均$\bm{\mu_{\phi}}$分散$\bm{\sigma_{\phi}} \odot\bm{\sigma_{\phi}}$の正規分布からサンプリングされるとすると
\begin{equation}
\bm{z} \sim q_{\phi}(\bm{z}|\bm{x}^{(i)}) =  \mathcal{N}(\bm{z} ; \bm{\mu_{\phi}}, \mathrm{diag}(\bm{\sigma_{\phi}} \odot \bm{\sigma_{\phi}}))
\end{equation}
となりますが,一旦標準正規分布から確率変数$\bm{\epsilon}$をサンプリングし,
\begin{equation}
\bm{\epsilon} \sim \mathcal{N}(\bm{\epsilon} ; \bm{0}, I)
\end{equation}
$\bm{\epsilon}$を平均$\bm{\mu_{\phi}}$だけ平行移動させ,標準偏差$\bm{\sigma_{\phi}}$で拡大縮小すれば,同じような潜在変数$\bm{z}$が得られると思います.
\begin{equation}
\bm{z} =\bm{\mu_{\phi}} + \bm{\sigma_{\phi}} \odot \bm{\epsilon}
\end{equation}
このようにすることで,$\bm{z}$に関する勾配計算ができるようになります.

正則化項の計算

正則化項は$p(\bm{z})$に対する$q_{\phi}(\bm{z}|\bm{x})$のKL Divergenceですが,想定している分布が正規分布である場合には解析的に求められます.

平均ベクトルが$\bm{\mu_0}$,分散共分散行列が$\Sigma_0$の正規分布$\mathcal{N}(\bm{z}; \bm{\mu_0},\Sigma_0)$と平均ベクトルが$\bm{\mu_1}$,分散共分散行列が$\Sigma_1$の正規分布$\mathcal{N}(\bm{z}; \bm{\mu_1},\Sigma_1)$があるとします.このとき$\mathcal{N}(\bm{z}; \bm{\mu_1},\Sigma_1)$に対する$\mathcal{N}(\bm{z}; \bm{\mu_0},\Sigma_0)$のKL Divergenceは
\begin{eqnarray}
&& D_{KL}\left(\mathcal{N}(\bm{z}; \bm{\mu_0},\Sigma_0) \| \mathcal{N}(\bm{z}; \bm{\mu_1},\Sigma_1)\right) \nonumber \\
&=& \frac{1}{2} \left\{ \mathrm{tr} \left( \Sigma_1^{-1}\Sigma_0 \right)
+ (\bm{\mu}_1-\bm{\mu}_0)^T \Sigma_1^{-1} (\bm{\mu}_1-\bm{\mu}_0)
– D
+ \log \frac{\mathrm{det}(\Sigma_1)}{\mathrm{det}(\Sigma_0)}
\right\}
\end{eqnarray}

と求められます.ここで,$D$は次元数です.

今,$p(\bm{z})$は標準正規分布$\mathcal{N}(\bm{z}; \bm{0},I)$ですので,$\mathcal{N}(\bm{z}; \bm{0},I)$に対する$\mathcal{N}(\bm{z}; \bm{\mu},\Sigma)$のKL Divergenceは
\begin{equation}
D_{KL}\left(\mathcal{N}(\bm{z}; \bm{\mu},\Sigma) \| \mathcal{N}(\bm{z}; \bm{0},I)\right)
= \frac{1}{2} \left\{ \mathrm{tr} \left( \Sigma \right)
+ \bm{\mu}^T \bm{\mu}
– D
– \log \mathrm{det}(\Sigma)
\right\}
\end{equation}
となります.

また,分散共分散行列$\Sigma$が対角行列$\mathrm{diag}(\bm{\sigma} \odot \bm{\sigma})$のとき,$\mathcal{N}(\bm{z}; \bm{0},I)$に対する$\mathcal{N}(\bm{z}; \bm{\mu},\mathrm{diag}(\bm{\sigma} \odot \bm{\sigma}))$のKL Divergenceは
\begin{equation}
D_{KL}\left(\mathcal{N}(\bm{z}; \bm{\mu},\mathrm{diag}(\bm{\sigma} \odot \bm{\sigma})) \| \mathcal{N}(\bm{z}; \bm{0},I)\right)
= \frac{1}{2} \sum_{j=1}^D \left( \sigma_j^2 + \mu_j^2 – 1 – \log \sigma_j^2 \right)
\end{equation}
と求められます.

Variational Autoencoder

Variational Autoencoderは,尤度に正規分布を使用した場合には

のように,尤度にベルヌーイ分布を使用した場合には

のようになり,符号化器(encoder)・復号化器(decoder)・生成器(generator)の3つの要素で構成されます.

学習時には,encoderとdecoderにより学習データセットを処理し,誤差関数が最小となるようにencoder networkとdecoder networkのパラメータを推定します.

うまく学習でき,最適なパラメータが推定できたとすると,generatorにより様々なデータが生成できるようになります.

encoder networkとdecoder networkの構造は問題に応じて決めることになりますが,正規分布の分散を推定する際には対数をとることが多いので,図もそのようにしてあります.(図中の$\log_{\odot}$は要素ごとに対数をとる演算子です.)

付録

データの対数尤度

データの対数尤度は式(\ref{log_likelihood})と表わせますが,上では誤差関数からデータの対数尤度を求めました.誤差関数を介さずにデータの対数尤度を求める方法を書いておきます.

\begin{eqnarray}
&& \log p_{\theta}(\bm{x}) \\
&=& \mathbb{E}_{\bm{z} \sim q_{\phi}(\bm{z}|\bm{x})} \left[\log p_{\theta}(\bm{x}) \right] \\
&=& \mathbb{E}_{\bm{z} \sim q_{\phi}(\bm{z}|\bm{x})} \left[ \log \frac{p_{\theta}(\bm{x}|\bm{z}) p(\bm{z})}{p_{\theta}(\bm{z}|\bm{x})} \right] \\
&=& \mathbb{E}_{\bm{z} \sim q_{\phi}(\bm{z}|\bm{x})} \left[ \log \frac{p_{\theta}(\bm{x}|\bm{z}) p(\bm{z})}{p_{\theta}(\bm{z}|\bm{x})}
\frac{q_{\phi}(\bm{z}|\bm{x})}{q_{\phi}(\bm{z}|\bm{x})} \right] \\
&=& \mathbb{E}_{\bm{z} \sim q_{\phi}(\bm{z}|\bm{x})} \left[ \log p_{\theta}(\bm{x}|\bm{z}) \frac{p(\bm{z})}{q_{\phi}(\bm{z}|\bm{x})}
\frac{q_{\phi}(\bm{z}|\bm{x})}{p_{\theta}(\bm{z}|\bm{x})} \right] \\
&=& \mathbb{E}_{\bm{z} \sim q_{\phi}(\bm{z}|\bm{x})} \left[ \log p_{\theta}(\bm{x}|\bm{z}) \right]
+ \mathbb{E}_{\bm{z} \sim q_{\phi}(\bm{z}|\bm{x})} \left[ \frac{p(\bm{z})}{q_{\phi}(\bm{z}|\bm{x})} \right]
+ \mathbb{E}_{\bm{z} \sim q_{\phi}(\bm{z}|\bm{x})} \left[ \frac{q_{\phi}(\bm{z}|\bm{x})}{p_{\theta}(\bm{z}|\bm{x})} \right] \\
&=& \mathbb{E}_{\bm{z} \sim q_{\phi}(\bm{z}|\bm{x})} \left[ \log p_{\theta}(\bm{x}|\bm{z}) \right]
– D_{KL}\left(q_{\phi}(\bm{z}|\bm{x}) || p(\bm{z})\right)
+ D_{KL}\left(q_{\phi}(\bm{z}|\bm{x}) || p_{\theta}(\bm{z}|\bm{x})\right)
\end{eqnarray}

式(\ref{log_likelihood})と同じになります.