MCMC と EM アルゴリズムを比べてみた
このエントリについて
前回のエントリで PyStan の MCMC によって GMM (混合正規分布)を学習してみました。 一方、GMM の学習と言えば一般的には EM アルゴリズムが使われることが多いかと思います。
参考:
EM アルゴリズムは山登り法の一種であり、局所最適解(local minimum)に陥る可能性があるのが問題ですが、MCMC は局所最適解にとどまらないという利点を持っています。
この利点が活かされて本当に MCMC の方が EM アルゴリズムよりも良くなるのかを確認したいという欲求で実験してみました、というのがこのエントリです。 たぶん論文探せばあると思いますが自分でやってみたかったんですよ。
やること
やりたいことは以下のとおりです。 アルゴリズムの善し悪しは評価データに対する尤度で見ることにします。
- 多次元ベクトルデータを学習用/評価用に分ける
- 学習用データを使い、EM アルゴリズムで GMM を学習する
- 学習用データを使い、MCMC で GMM を学習する
- 評価用データに対するサンプル平均尤度を 2, 3 それぞれについて計算する
- 1〜4 をデータの区切りを変えて複数回実施する
使ったデータ
実験用にUCI Machine Learning Repositoryから適当なデータを探してきました。 今回は Wine Quality Data Set を使います。
このデータは酸性度や残存糖度、密度などの測定可能な11のパラメータおよび人の手による品質評価値により数々のワインを表現したものです。 その中でも白ワインのデータを使いました。(サンプル数: 4898) 品質評価値は除いて11次元のデータとして扱います。*1
プログラム
Stan および呼び出し側の Python のソースは以下です。
Stan
Stan code to train multi dimensional GMM (Gaussian ...
フル共分散だと学習対象のモデルパラメータが多くなるため、対角共分散を想定したモデルとなっています。
尤度計算の部分でフル共分散だと multi_normal_log()
が使えたのですが対角成分だけでは使えないらしく、次元数で for ループを回していちいち加算しています。
Python
To compare EM algorithm and MCMC on GMM training.
長いですね…。以下解説です。
StanModel の永続化
Stan コードのコンパイルして StanModel
インスタンスを作るには数十秒かかり、何回かスクリプトを回して試すときは結構なストレスになります。
これを避けるために StanModel
インスタンスを Python の pickle
モジュールにより永続化しています。
(77行目付近)
pickle
はシリアライゼーション、すなわちインスタンスのバイナリストリーム化を担うモジュールであり、かつ StanModel
は pickable です。
コンパイル後に永続化しておくことで、次回以降はコンパイル処理は行わず、永続化されたインスタンスをロードするだけの短い処理で済ますことができます。
Cross Validation
4898サンプルの白ワインデータは学習データと評価データに分けて使いますが、1回の学習/評価だけでは心許ないので cross validation (交差検証)を行います。
Python の機械学習ライブラリ scikit-learn
の cross_validation
モジュールは、データを学習用/評価用に分けるためのいくつかの手法を提供しています。
ここでは Random permutations cross-validation a.k.a. Shuffle & Split という手法を使います。ざっくり説明すると、指定の割合でのランダムなデータ分割結果を指定回数分作って List ライクな結果として返却します。
92行目でその結果を ss
に入れています。その後、この ss
で for
文を回してそれぞれの分割結果を取得します。
この実験では学習/評価データ分割を500回実施し、そのそれぞれで EM アルゴリズムと MCMC を比較します。
EM アルゴリズム
EM アルゴリズムも scikit-learn で提供されています。misture.GMM
インスタンスを生成し、学習データを与えて fit()
を実行すれば OK。とても簡単。(107行目)
MCMC
MCMC によるモデルパラメータ学習は StanModel
インスタンスの optimizing()
メソッドで実施します。(114行目)
この実験ではデフォルトの2,000回のイタレーションでは学習が収束しないケースが多かったので、かなり長めの20,000回を指定しました。
尤度計算
各アルゴリズムによる学習結果の GMM は、評価データに対する尤度という形で評価されます。
尤度は mixture.GMM.score()
メソッドの結果をサンプルあたりの平均にするという形で計算しました。(109, 117行目)
つまりサンプル平均尤度です。
上記のメソッドを利用するため、MCMC の結果は一度 mixture.GMM
の形式に変換しています。
グラフ化
実験結果は最終的に matplotlib
でグラフ化します。(127行目)
X軸に EM アルゴリズムの尤度、Y軸に MCMC の尤度を取り、1つの実験結果のプロットが x = y
のときの直線の上にくるか下にくるかでどちらが優勢かを判断します。
500回のデータ分割による cross-validation を実施しているので、500点プロットされます。
実験結果
尤度グラフ
ということでプロットされた結果は以下のようになりました。
なんだこれは…???
EM アルゴリズムの尤度はだいたい -7.0 〜 -6.5 に収まっています。 一方で MCMC の尤度は2つに分かれており、-5.8 付近のグループ(上のかたまり)と-6.9付近のグループ(下のかたまり)に分かれています。 上のかたまりでは EM アルゴリズムに比べて MCMC の尤度がかなり良くなっていますが、下のかたまりでは MCMC が若干負けています。(対数尤度なので 0 に近い方が良い)
上記のスクリプトとは別で同一の学習/評価データで複数の、つまり乱数の異なる optimizing()
を実行したところ、上のかたまり相当の尤度になる場合と下のかたまり相当の尤度になる場合があることが分かりました。
つまり Stan の MCMC も常に最適なところまでたどり着けるわけではなく、最適でないところで収束してしまうことがあるようです。(local minimum と言っていいのか?)*2
実行速度
MCMC の実行時間は EM アルゴリズムに比べ数百〜千倍のオーダーとなってしまいます。 予想はしていましたが、やはり遅い。
まとめ
- EM アルゴリズム(by scikit-learn)と MCMC (by Stan)の比較を GMM で行った
- MCMC だと EM アルゴリズムでは到達できない精度のモデリングが可能
- しかし MCMC も local minimum っぽいところで収束してしまうことがある
- MCMC はとても遅い
その他、今回は Python の勉強が捗りました。scikit-learn はやはり面白い。