froglog

プログラミングや統計の話など

MCMC と EM アルゴリズムを比べてみた

このエントリについて

前回のエントリで PyStan の MCMC によって GMM (混合正規分布)を学習してみました。 一方、GMM の学習と言えば一般的には EM アルゴリズムが使われることが多いかと思います。

参考:

EM アルゴリズムは山登り法の一種であり、局所最適解(local minimum)に陥る可能性があるのが問題ですが、MCMC は局所最適解にとどまらないという利点を持っています。

この利点が活かされて本当に MCMC の方が EM アルゴリズムよりも良くなるのかを確認したいという欲求で実験してみました、というのがこのエントリです。 たぶん論文探せばあると思いますが自分でやってみたかったんですよ。

やること

やりたいことは以下のとおりです。 アルゴリズムの善し悪しは評価データに対する尤度で見ることにします。

  1. 多次元ベクトルデータを学習用/評価用に分ける
  2. 学習用データを使い、EM アルゴリズムで GMM を学習する
  3. 学習用データを使い、MCMC で GMM を学習する
  4. 評価用データに対するサンプル平均尤度を 2, 3 それぞれについて計算する
  5. 1〜4 をデータの区切りを変えて複数回実施する

使ったデータ

実験用にUCI Machine Learning Repositoryから適当なデータを探してきました。 今回は Wine Quality Data Set を使います。

このデータは酸性度や残存糖度、密度などの測定可能な11のパラメータおよび人の手による品質評価値により数々のワインを表現したものです。 その中でも白ワインのデータを使いました。(サンプル数: 4898) 品質評価値は除いて11次元のデータとして扱います。*1

プログラム

Stan および呼び出し側の Python のソースは以下です。

Stan

Stan code to train multi dimensional GMM (Gaussian ...

フル共分散だと学習対象のモデルパラメータが多くなるため、対角共分散を想定したモデルとなっています。

尤度計算の部分でフル共分散だと multi_normal_log() が使えたのですが対角成分だけでは使えないらしく、次元数で for ループを回していちいち加算しています。

Python

To compare EM algorithm and MCMC on GMM training.

長いですね…。以下解説です。

StanModel の永続化

Stan コードのコンパイルして StanModel インスタンスを作るには数十秒かかり、何回かスクリプトを回して試すときは結構なストレスになります。 これを避けるために StanModel インスタンスPythonpickle モジュールにより永続化しています。 (77行目付近)

pickle はシリアライゼーション、すなわちインスタンスのバイナリストリーム化を担うモジュールであり、かつ StanModel は pickable です。 コンパイル後に永続化しておくことで、次回以降はコンパイル処理は行わず、永続化されたインスタンスをロードするだけの短い処理で済ますことができます。

Cross Validation

4898サンプルの白ワインデータは学習データと評価データに分けて使いますが、1回の学習/評価だけでは心許ないので cross validation (交差検証)を行います。 Python機械学習ライブラリ scikit-learncross_validation モジュールは、データを学習用/評価用に分けるためのいくつかの手法を提供しています。

ここでは Random permutations cross-validation a.k.a. Shuffle & Split という手法を使います。ざっくり説明すると、指定の割合でのランダムなデータ分割結果を指定回数分作って List ライクな結果として返却します。 92行目でその結果を ss に入れています。その後、この ssfor 文を回してそれぞれの分割結果を取得します。

この実験では学習/評価データ分割を500回実施し、そのそれぞれで EM アルゴリズムMCMC を比較します。

EM アルゴリズム

EM アルゴリズムも scikit-learn で提供されています。misture.GMM インスタンスを生成し、学習データを与えて fit() を実行すれば OK。とても簡単。(107行目)

MCMC

MCMC によるモデルパラメータ学習は StanModel インスタンスoptimizing() メソッドで実施します。(114行目)

この実験ではデフォルトの2,000回のイタレーションでは学習が収束しないケースが多かったので、かなり長めの20,000回を指定しました。

尤度計算

アルゴリズムによる学習結果の GMM は、評価データに対する尤度という形で評価されます。 尤度は mixture.GMM.score() メソッドの結果をサンプルあたりの平均にするという形で計算しました。(109, 117行目) つまりサンプル平均尤度です。

上記のメソッドを利用するため、MCMC の結果は一度 mixture.GMM の形式に変換しています。

グラフ化

実験結果は最終的に matplotlib でグラフ化します。(127行目) X軸に EM アルゴリズムの尤度、Y軸に MCMC の尤度を取り、1つの実験結果のプロットが x = y のときの直線の上にくるか下にくるかでどちらが優勢かを判断します。 500回のデータ分割による cross-validation を実施しているので、500点プロットされます。

実験結果

尤度グラフ

ということでプロットされた結果は以下のようになりました。

f:id:soonraah:20141006041715p:plain

なんだこれは…???

EM アルゴリズムの尤度はだいたい -7.0 〜 -6.5 に収まっています。 一方で MCMC の尤度は2つに分かれており、-5.8 付近のグループ(上のかたまり)と-6.9付近のグループ(下のかたまり)に分かれています。 上のかたまりでは EM アルゴリズムに比べて MCMC の尤度がかなり良くなっていますが、下のかたまりでは MCMC が若干負けています。(対数尤度なので 0 に近い方が良い)

上記のスクリプトとは別で同一の学習/評価データで複数の、つまり乱数の異なる optimizing() を実行したところ、上のかたまり相当の尤度になる場合と下のかたまり相当の尤度になる場合があることが分かりました。 つまり Stan の MCMC も常に最適なところまでたどり着けるわけではなく、最適でないところで収束してしまうことがあるようです。(local minimum と言っていいのか?)*2

実行速度

MCMC の実行時間は EM アルゴリズムに比べ数百〜千倍のオーダーとなってしまいます。 予想はしていましたが、やはり遅い。

まとめ

その他、今回は Python の勉強が捗りました。scikit-learn はやはり面白い。

*1:このデータを表現するのに混合正規分布を使うのが本当に適切なのか?という点は考えません。同じフィールドで競わせたときに EM アルゴリズムMCMC どちらが良くなるのか、というのが今の興味だからです。

*2:しかし、①なぜこうもくっきり分かれるのか? ②なぜ下のグループは EM アルゴリズムにちょっぴり負けてしまうのか? ③MCMC の local minimum を避けるにはどうすればいいか? は未解決。分かる方いたら教えてください…