PyStan で多次元混合正規分布を学習する
このエントリについて
PyStan の実行環境を用意したので、モデルパラメータ推定に使ってみました。 個人的に慣れのある多次元の混合正規分布(GMM: Gaussian Mixture Model)のパラメータを学習してみます。
GMM
複数の正規分布の重なりによって表される確率モデルです。 詳しくはググれ。
かつて音声認識の仕事をしていたときによく触っていたという慣れがあり、このモデルを選びました。 通常は GMM といえば EM アルゴリズムによる学習が一般的なのかなと思います。 でも今回は MCMC。
学習データ
多次元の正規分布にもとづく学習用データを、Python の機械学習用ライブラリである scikit-learn の make_classification() メソッドで用意しました。
例えば以下は2次元かつ4混合の GMM から作られたデータのプロットです。 次元相関のある4つの2次元正規分布を見て取れるかと。(重なりあり)
今回初めて使ってみたのですが、適当な実験データを作るのに sklearn.dataset
のメソッドはなかなか便利そうです。
PyStan によるモデルパラメータ学習
では実際に PyStan による多次元 GMM のモデルパラメータ学習をコードを交えて見ていきます。 今回は Python のコードと Stan のコードを別ファイルにしていますが、Python の中に Stan コードを文字列で埋め込むこともできます。
Stan
まず Stan のモデル定義を multi_dimensional_gmm.stan として保存します。
Stan code for multi dimension GMM with full covari ...
ちょっとググって見たのですが、多次元かつフルの共分散行列を学習できる Stan コードの例が見つからなかったので、
を参考にして多次元かつ分散を学習できるように手を加えて作りました。*1 *2
Python
上記の Stan のコードを使って実際に学習を回すための Python コードが以下になります。
make_classification()
による 学習データ生成も含んでいます。
Python code to train GMM by PyStan.
ポイントとなる部分を見ていきます。
- 33行目: 2次元かつ4混合の正規分布にもとづく学習データを1000サンプル作成します。
- 36行目: ファイル multi_dimensional_gmm.stan のモデル定義を読み込んで C++ でコンパイルします。
- 40行目: Stan の
data
セクションで定義されたデータを辞書形式で渡し、学習させます。学習自体はこの1行だけで書けてしまいます。 - 42行目: 辞書ライクな形式で return された学習結果を表示させます。
コンパイル部分に結構時間がかかり、私の買ったばかりの性能とデザインを研ぎすました iMac で25秒程度要しました。 コンパイルの頻度が高いとシステム化するときに問題になりそう。
学習結果
前述のプロットの学習データのときの学習結果の出力は次のようになりました。 (適当に改行を追加しています)
OrderedDict([ ('weights', array([ 0.35766754, 0.21940132, 0.30998937, 0.11294177])), ('mu', array([ [ 0.9772859 , -0.90527383], [-0.93987368, -0.90560861], [ 1.18169255, 1.19923733], [-1.00850481, 1.20498538]])), ('sigma', array([ [[ 0.0873947 , -0.11510449], [-0.11510449, 0.87560445]], [[ 0.66621811, 0.84946125], [ 0.84946125, 1.29226115]], [[ 1.78483369, 1.18724188], [ 1.18724188, 0.99583823]], [[ 0.00758462, -0.04079781], [-0.04079781, 0.70870785]]]))])
mu
つまり平均の座標がだいたいグラフプロットと合ってるのが分かるかと思います。*3
学習できているようです。
まとめ
- PyStan の MCMC で多次元 GMM のパラメータを学習できた
- Stan のコードを C++ でコンパイルするのに時間がかかる
- scikit-learn の datasets のメソッドは適当な実験データ作るのに便利
じゃあ EM アルゴリズムと比べてどうなのか?というのが疑問ですね。 こちらはまた別のエントリで実験しようかな。
2014-10-06 追記
怪しい点があったのでデータ等を修正しました。