A Recurrent Latent Variable Model for Sequential Data

Chungらがvariational autoencoderを時系列データに適用したvariational reccurent neural networkに関する以下の論文のノートを公開します.
J. Chung, K. Kastner, L. Dinh, K. Goel, A.C. Courville, and Y. Bengio, “A Recurrent Latent Variable Model for Sequential Data,” NIPS 2015.

variational autoencoderについて説明した以前の投稿をまだ読んでいない人は先にそちらを読んでください.

系列データに対するモデル

事前確率・尤度・事後確率の近似の考え方

通常のvariational autoencoderとvariational recurrent neural networkの違いは,対象とするデータが系列データになることです.系列データに対してどのようなモデルを用いればよいかは考え方によって変わりますので,様々なモデルが考えられます.ここでは,論文に書かれている考え方について説明します.

$D$次元のデータ$\bm{x}_t^{(i)} \in\mathbb{R}^D$からなる系列データ$\bm{x}^{(i)} = \left\{\bm{x}_t^{(i)} \right\}_{t=1}^{T^{(i)}}$があるとし,系列データ$\bm{x}^{(i)}$からなる系列データセット${\rm X} = \left\{ \bm{x}^{(i)} \right\}_{i=1}^{N}$があるとします.

通常のvariational autoencoderでは,事前確率$p(\bm{z})$と尤度$p(\bm{x}|\bm{z})$からなる2段階の確率過程として生成モデルを考えました.データ$\bm{x}$と潜在変数$\bm{z}$の同時確率$p(\bm{x}, \bm{z})$が条件付き確率の関係を使うと
\begin{equation}
p(\bm{x}, \bm{z}) = p(\bm{x}|\bm{z}) p(\bm{z})
\end{equation}
と表せることから,variational autoencoderにおける生成モデルでは同時確率$p(\bm{x}, \bm{z})$を事前確率$p(\bm{z})$と尤度$p(\bm{x}|\bm{z})$からなる2段階の確率過程に分解していたと考えられます.

では,系列データに対してデータと潜在変数の同時確率$p(\bm{x}_1, \cdots \bm{x}_T , \bm{z}_1, \cdots \bm{z}_T)$を考えましょう.これを系列データの事前確率$p(\bm{z}_1, \cdots \bm{z}_T)$と系列データの尤度$p(\bm{x}_1, \cdots \bm{x}_T |\bm{z}_1, \cdots \bm{z}_T)$に分解することは可能ですが,論文では,連鎖律を使用して各時刻$t$の確率過程に分解してから各時刻における事前確率と尤度に分解しています.連鎖律を使用すると,同時確率は

\begin{equation}
p(\bm{x}_1, \cdots \bm{x}_T , \bm{z}_1, \cdots \bm{z}_T) =
p(\bm{x}_T,\bm{z}_T | \bm{x}_1, \cdots \bm{x}_{T-1}, \bm{z}_1, \cdots \bm{z}_{T-1})
\cdots
p(\bm{x}_2,\bm{z}_2 | \bm{x}_1, \bm{z}_1)
p(\bm{x}_1, \bm{z}_1)
\end{equation}

と各時刻における確率過程に分解できます.

今,この関係を以下のように書くことにします.

\begin{equation}
p(\bm{x}_{\leq T}, \bm{z}_{\leq T}) =
\prod_{t=1}^{T}
p(\bm{x}_t,\bm{z}_t | \bm{x}_{<t}, \bm{z}_{<t})
\end{equation}

ただし,$p(\bm{x}_0, \bm{z}_0) = 1$とします.

次に,各時刻の確率に対して条件付き確率の関係を使用すると

\begin{equation}
p(\bm{x}_{\leq T}, \bm{z}_{\leq T}) =
\prod_{t=1}^{T}
p(\bm{x}_t | \bm{x}_{<t}, \bm{z}_{\leq t})
p(\bm{z}_t | \bm{x}_{<t}, \bm{z}_{<t})
\end{equation}

となります.論文では,$p(\bm{z}_t | \bm{x}_{<t}, \bm{z}_{<t})$を時刻$t$における事前確率,$p(\bm{x}_t | \bm{x}_{<t}, \bm{z}_{\leq t})$を時刻$t$における尤度と考えます.

また,事後確率の近似についても様々なモデルが考えられますが,論文ではある時刻$t$の事後確率は過去の潜在変数$\bm{z}_{<t}$と過去と現在のデータ$\bm{x}_{\leq t}$に依存すると考え,$q(\bm{z}_t | \bm{x}_{\leq t}, \bm{z}_{<t})$としています.

$t=1$のときのモデル

まず,時刻$t=1$のときにモデルがどのように表現されるかについて考えましょう.

$t=1$のとき,事前確率は$p(\bm{z}_1)$,尤度は$p(\bm{x}_1 |\bm{z}_1)$,事後確率の近似は$q(\bm{z}_1 |\bm{x}_1)$となります.いずれも通常のvariational autoencoderの表現と変わりませんので,モデルも同じになります.(尤度で使用する確率分布は問題によって変わりますが,図では正規分布としています.)

$t>1$のときのモデル

次に,時刻$t>1$のときにモデルがどのように表現されるかについて考えましょう.

$t>1$のとき,事前確率は$p(\bm{z}_t | \bm{x}_{<t}, \bm{z}_{<t})$,尤度は$p(\bm{x}_t | \bm{x}_{<t}, \bm{z}_{\leq t})$,事後確率の近似は$q(\bm{z}_t | \bm{x}_{\leq t}, \bm{z}_{<t})$であり,いずれも過去の時刻のデータ$\bm{x}_{<t}$と過去の時刻の潜在変数$\bm{z}_{<t}$に依存します.

論文では,再帰ニューラルネットワーク(recurrent neural network)を使用して過去のデータや潜在変数に対する依存関係を表現しています.再帰ネットワーク$f_{\theta}$では,時刻$t$におけるデータ$\bm{x}_t$と潜在変数$\bm{z}_t$を入力とし,図のように再帰的に隠れ層の状態$\bm{h}_t$を求めます.ただし,$\bm{h}_0=\bm{0}$とします.このとき,$\bm{h}_{t-1}$は過去のデータ$\bm{x}_{<t}$と潜在変数$\bm{z}_{<t}$の依存関係を内部状態として保存していると考えます.

この隠れ層の状態$\bm{h}_{t-1}$を事前確率,尤度,事後確率の近似を計算する際に使用すれば,過去の時刻のデータや潜在変数の依存関係を表現できると考えます.また,時刻$t=1$のとき,事前確率は単純だと考え標準正規分布として表現しましたが,$t>1$のときは複雑な確率過程になると考えニューラルネットワーク$\varphi_{\tau}^{\mathrm{prior}}$と正規分布を使用して表現することになります.

特徴抽出器の表現

まとめると,variational recurrent neural networkの構造は図のように表せます.このうち,尤度と事後確率はそれぞれニューラルネットワークと正規分布を使用して表現されています.この2種類のニューラルネットワークが多層であるとすると,前段のネットワークは特徴抽出器の役割をしていると考えられます.decoder networkの前段は潜在変数$\bm{z}_t$の特徴抽出器の役割を,encoder networkの前段はデータ$\bm{x}_t$の特徴抽出器の役割を果たしていると見なせます.

一方,variational recurrent neural networkでは,隠れ層の状態が潜在変数$\bm{z}_t$とデータ$\bm{x}_t$を入力とする再帰型ニューラルネットワークによって変化するように表現されています.この入力部にも特徴抽出器を配置するほうがよいと考えられますので,論文で提案されているモデルでは,decoder networkとencoder networkの前段の特徴抽出器と同じものを再帰型ニューラルネットワークの入力部にも配置しています.

decoder networkの前段の特徴抽出器を$\varphi_{\tau}^{\bm{z}}$,後段のネットワークを$\varphi_{\tau}^{\mathrm{dec}}$,encoder networkの前段の特徴抽出器を$\varphi_{\tau}^{\bm{x}}$,後段のネットワークを$\varphi_{\tau}^{\mathrm{enc}}$とすると,論文で提案されているモデルは図のように表せます.

誤差関数

誤差関数は,各時刻$t$における誤差関数を定義し,それを時間方向に加算したものとすればよいでしょう.また,各時刻$t$における誤差関数は通常のvariational autoencoderの誤差関数と同様に求めればよいです.

時刻$t$における誤差関数

時刻$t$における誤差関数としては,通常のvariational autoencoderと同様に事後確率$p(\bm{z}_t | \bm{x}_{\leq t}, \bm{z}_{<t})$に対する事後確率の近似$q(\bm{z}_t | \bm{x}_{\leq t}, \bm{z}_{<t})$のKL Divergenceを考えます.

\begin{eqnarray}
&&D_{KL}\left(q(\bm{z}_t | \bm{x}_{\leq t}, \bm{z}_{<t}) ||
p(\bm{z}_t | \bm{x}_{\leq t}, \bm{z}_{<t})\right) \nonumber \\
&=& \mathbb{E}_{\bm{z}_t \sim q(\bm{z}_t | \bm{x}_{\leq t}, \bm{z}_{<t})}
\left[ \log \frac{q(\bm{z}_t | \bm{x}_{\leq t}, \bm{z}_{<t})}{p(\bm{z}_t | \bm{x}_{\leq t}, \bm{z}_{<t})} \right] \nonumber \\
&=& \mathbb{E}_{\bm{z}_t \sim q(\bm{z}_t | \bm{x}_{\leq t}, \bm{z}_{<t})}
\left[ \log \frac{q(\bm{z}_t | \bm{x}_{\leq t}, \bm{z}_{<t})
p(\bm{x}_t | \bm{x}_{< t}, \bm{z}_{<t})}
{p(\bm{x}_t | \bm{x}_{< t}, \bm{z}_{\leq t})
p(\bm{z}_t | \bm{x}_{< t}, \bm{z}_{<t})} \right] \nonumber \\
&=& \mathbb{E}_{\bm{z}_t \sim q(\bm{z}_t | \bm{x}_{\leq t}, \bm{z}_{<t})}
\left[ \log \frac{1}{p(\bm{x}_t | \bm{x}_{< t}, \bm{z}_{\leq t})}
\frac{q(\bm{z}_t | \bm{x}_{\leq t}, \bm{z}_{<t})}{p(\bm{z}_t | \bm{x}_{< t}, \bm{z}_{<t})}
p(\bm{x}_t | \bm{x}_{< t}, \bm{z}_{<t})
\right] \nonumber \\
&=& \mathbb{E}_{\bm{z}_t \sim q(\bm{z}_t | \bm{x}_{\leq t}, \bm{z}_{<t})}
\left[ -\log p(\bm{x}_t | \bm{x}_{< t}, \bm{z}_{\leq t})
+ \log \frac{q(\bm{z}_t | \bm{x}_{\leq t}, \bm{z}_{<t})}{p(\bm{z}_t | \bm{x}_{< t}, \bm{z}_{<t})}
+ \log p(\bm{x}_t | \bm{x}_{< t}, \bm{z}_{<t})
\right] \nonumber \\
&=& -\mathbb{E}_{\bm{z}_t \sim q(\bm{z}_t | \bm{x}_{\leq t}, \bm{z}_{<t})}
\left[\log p(\bm{x}_t | \bm{x}_{< t}, \bm{z}_{\leq t}) \right]
+ \mathbb{E}_{\bm{z}_t \sim q(\bm{z}_t | \bm{x}_{\leq t}, \bm{z}_{<t})}
\left[\log \frac{q(\bm{z}_t | \bm{x}_{\leq t}, \bm{z}_{<t})}{p(\bm{z}_t | \bm{x}_{< t}, \bm{z}_{<t})} \right]
+ \mathbb{E}_{\bm{z}_t \sim q(\bm{z}_t | \bm{x}_{\leq t}, \bm{z}_{<t})}
\left[\log p(\bm{x}_t | \bm{x}_{< t}, \bm{z}_{<t}) \right] \nonumber \\
&=& -\mathbb{E}_{\bm{z}_t \sim q(\bm{z}_t | \bm{x}_{\leq t}, \bm{z}_{<t})}
\left[\log p(\bm{x}_t | \bm{x}_{< t}, \bm{z}_{\leq t}) \right]
+ D_{KL}\left(q(\bm{z}_t | \bm{x}_{\leq t}, \bm{z}_{<t}) ||p(\bm{z}_t | \bm{x}_{< t}, \bm{z}_{<t})\right)
+\log p(\bm{x}_t | \bm{x}_{< t}, \bm{z}_{<t}) \label{kl}
\end{eqnarray}

右辺第3項はintractableですが,右辺第1項と第2項はtractableですので,右辺第1項と第2項の和を時刻$t$における誤差関数$E_t$とします.

\begin{equation}
E_t = -\mathbb{E}_{\bm{z}_t \sim q(\bm{z}_t | \bm{x}_{\leq t}, \bm{z}_{<t})}
\left[\log p(\bm{x}_t | \bm{x}_{< t}, \bm{z}_{\leq t}) \right]
+ D_{KL}\left(q(\bm{z}_t | \bm{x}_{\leq t}, \bm{z}_{<t}) ||p(\bm{z}_t | \bm{x}_{< t}, \bm{z}_{<t})\right)
\end{equation}

時刻$t$における誤差関数の計算

$E_t$の第1項である再現誤差の求め方は通常のvariational autoencoderにおける再現誤差の求め方と同じです.

$E_t$の第2項である正則化項の求め方は,$t=1$のときは通常のvariational autoencoderにおける正則化項の求め方と同じですが,$t>1$のときは異なります.通常のvariational autoencoderでは事前確率は単純だと考えて標準正規分布としましたが,variational recurrent neural networkでは事前確率は複雑だと考えてニューラルネットワークと正規分布を組み合わせて表現したためです.

平均ベクトルが$\bm{\mu_0}$,分散共分散行列が対角行列$\mathrm{diag}(\bm{\sigma}_0 \odot \bm{\sigma}_0)$の正規分布$\mathcal{N}(\bm{z}; \bm{\mu_0},\mathrm{diag}(\bm{\sigma}_0 \odot \bm{\sigma}_0))$と平均ベクトルが$\bm{\mu_1}$,分散共分散行列が対角行列$\mathrm{diag}(\bm{\sigma}_1 \odot \bm{\sigma}_1)$の正規分布$\mathcal{N}(\bm{z}; \bm{\mu_1},\mathrm{diag}(\bm{\sigma}_1 \odot \bm{\sigma}_1))$があるとします.このとき$\mathcal{N}(\bm{z}; \bm{\mu_1},\mathrm{diag}(\bm{\sigma}_1 \odot \bm{\sigma}_1))$に対する$\mathcal{N}(\bm{z}; \bm{\mu_0},\mathrm{diag}(\bm{\sigma}_0 \odot \bm{\sigma}_0))$のKL Divergenceは
\begin{eqnarray}
&& D_{KL}\left(\mathcal{N}(\bm{z}; \bm{\mu_0},\mathrm{diag}(\bm{\sigma}_0 \odot \bm{\sigma}_0)) \| \mathcal{N}(\bm{z}; \bm{\mu_1},\mathrm{diag}(\bm{\sigma}_1 \odot \bm{\sigma}_1))\right) \nonumber \\
&=& \frac{1}{2} \sum_{j=1}^{D} \left(
\frac{\sigma_{0,j}^2}{\sigma_{1,j}^2}
+ \frac{(\mu_{1,j}-\mu_{0,j})^2}{\sigma_{1,j}^2}
– 1
+ \log \sigma_{1,j}^2
– \log \sigma_{0,j}^2
\right)
\end{eqnarray}

と求められます.ここで,$D$は次元数です.

系列データに対する誤差関数

系列データに対する誤差関数$E$は,各時刻$t$における誤差関数$E_t$を時間方向に加算すればよいです.

\begin{equation}
E = \sum_{t=1}^{T} E_t
\end{equation}