froglog

プログラミングや統計の話など

PyStan で多次元混合正規分布を学習する

このエントリについて

PyStan の実行環境を用意したので、モデルパラメータ推定に使ってみました。 個人的に慣れのある多次元の混合正規分布GMM: Gaussian Mixture Model)のパラメータを学習してみます。

GMM

複数正規分布の重なりによって表される確率モデルです。 詳しくはググれ

かつて音声認識の仕事をしていたときによく触っていたという慣れがあり、このモデルを選びました。 通常は GMM といえば EM アルゴリズムによる学習が一般的なのかなと思います。 でも今回は MCMC

学習データ

多次元の正規分布にもとづく学習用データを、Python機械学習用ライブラリである scikit-learnmake_classification() メソッドで用意しました。

例えば以下は2次元かつ4混合の GMM から作られたデータのプロットです。 次元相関のある4つの2次元正規分布を見て取れるかと。(重なりあり)

f:id:soonraah:20141006010011p:plain

今回初めて使ってみたのですが、適当な実験データを作るのに sklearn.datasetメソッドはなかなか便利そうです。

PyStan によるモデルパラメータ学習

では実際に PyStan による多次元 GMM のモデルパラメータ学習をコードを交えて見ていきます。 今回は Python のコードと Stan のコードを別ファイルにしていますが、Python の中に Stan コードを文字列で埋め込むこともできます。

Stan

まず Stan のモデル定義を multi_dimensional_gmm.stan として保存します。

Stan code for multi dimension GMM with full covari ...

ちょっとググって見たのですが、多次元かつフルの共分散行列を学習できる Stan コードの例が見つからなかったので、

を参考にして多次元かつ分散を学習できるように手を加えて作りました。*1 *2

Python

上記の Stan のコードを使って実際に学習を回すための Python コードが以下になります。 make_classification() による 学習データ生成も含んでいます。

Python code to train GMM by PyStan.

ポイントとなる部分を見ていきます。

  • 33行目: 2次元かつ4混合の正規分布にもとづく学習データを1000サンプル作成します。
  • 36行目: ファイル multi_dimensional_gmm.stan のモデル定義を読み込んで C++コンパイルします。
  • 40行目: Stan の data セクションで定義されたデータを辞書形式で渡し、学習させます。学習自体はこの1行だけで書けてしまいます。
  • 42行目: 辞書ライクな形式で return された学習結果を表示させます。

コンパイル部分に結構時間がかかり、私の買ったばかりの性能とデザインを研ぎすました iMac で25秒程度要しました。 コンパイルの頻度が高いとシステム化するときに問題になりそう。

学習結果

前述のプロットの学習データのときの学習結果の出力は次のようになりました。 (適当に改行を追加しています)

OrderedDict([
    ('weights', array([ 0.35766754,  0.21940132,  0.30998937,  0.11294177])),
    ('mu', array([
        [ 0.9772859 , -0.90527383],
        [-0.93987368, -0.90560861],
        [ 1.18169255,  1.19923733],
        [-1.00850481,  1.20498538]])), 
    ('sigma', array([
       [[ 0.0873947 , -0.11510449],
        [-0.11510449,  0.87560445]],

       [[ 0.66621811,  0.84946125],
        [ 0.84946125,  1.29226115]],

       [[ 1.78483369,  1.18724188],
        [ 1.18724188,  0.99583823]],

       [[ 0.00758462, -0.04079781],
        [-0.04079781,  0.70870785]]]))])

mu つまり平均の座標がだいたいグラフプロットと合ってるのが分かるかと思います。*3

学習できているようです。

まとめ

  • PyStan の MCMC で多次元 GMM のパラメータを学習できた
  • Stan のコードを C++コンパイルするのに時間がかかる
  • scikit-learn の datasets のメソッドは適当な実験データ作るのに便利

じゃあ EM アルゴリズムと比べてどうなのか?というのが疑問ですね。 こちらはまた別のエントリで実験しようかな。

2014-10-06 追記

怪しい点があったのでデータ等を修正しました。

*1:共分散行列のサンプリングの書き方はあれでいいのか?等、怪しい部分もあるので「これは変だ」と気付いた方がいらっしゃったらご指摘いただけると嬉しいです。

*2:2014-10-06 勘違いしていたところがあったので修正。「~」によるサンプリングの部分は不要でした。

*3:実際、make_classification() はデフォルトで n 次元立方体の頂点、この場合だと (1, 1), (1, -1), (-1, 1), (-1, -1) に分布の平均を置くので概ね近い値になっています。