froglog

プログラミングや統計の話など

線形分類器のオンライン学習を実装してみた

このエントリについて

最近読んだオンライン学習の本が分かりやすくて面白かったので、紹介されているアルゴリズムを実装して遊んでみました。

書籍紹介

オンライン機械学習 (機械学習プロフェッショナルシリーズ)

オンライン機械学習 (機械学習プロフェッショナルシリーズ)

タイトルのとおりオンライン学習について書かれた書籍です。 かなり分かりやすく書かれていて、数学*1が苦手な私でも比較的楽に読むことができました。 機械学習の入門書としても良いかもしれません。 主に2値分類の線形分類器について書かれていますが、SVM や深層学習の話も出てきます。

このエントリでは本書のセクション "4.1 高度なオンライン学習" で紹介されている線形分類器のオンライン学習アルゴリズムを実装して動かしてみます。

ソースコードはこちら、Scala です。

アルゴリズム

セクション "4.1 高度なオンライン学習" では以下のオンライン学習アルゴリズムが紹介されています。

  • Perceptron*2
  • PA (Passive-Aggressive)*3
  • PA-I
  • PA-II
  • CW (Confidence Weighted Learning)*4
  • AROW (Adaptive Regularization of Weight Vectors)*5
  • SCW-I (Soft Confidence-Weighted Learning)*6
  • SCW-II

上から古い順になっており、新しいものほど性能が良いとされているとのこと。

これらのアルゴリズムの学習はすべて以下の形のアルゴリズムで表すことができます。

A base class of online learning for binary linear ...

学習はメソッド train() で行います。 w は学習しようとしている線形分類器の重みベクトル、またはそれを生成する正規分布の平均ベクトルで、sigma はその共分散行列です。 これらを更新していくのがこのオンライン学習の目的です。 *7

アルゴリズムごとに異なるのはパラメータ更新の条件判定に用いる e 、および wsigma の更新に使われる alphabeta になります。 各アルゴリズムの具体的な説明は省きますが、こちらにそれぞれの実装があるのでご興味のある方はどうぞ。

実験

実際にデータを突っ込んで性能を見てみます。

データ

Breast Cancer Wisconsin (Diagnostic) Data Set を使いました。 30次元で数値化された乳房の組織の画像の特徴およびその組織が良性か悪性かというラベルを含むデータです。 実験はこの良性/悪性の2値分類をするタスクとなります。

手順

こんな感じ

  • 以下繰り返し
    1. サンプルを1個読む
    2. 更新前の線形分類器で分類して結果を記録
    3. そのサンプルと正解ラベルを使って分類器を学習・更新

学習サンプル数に対するエラーの数をカウントして、それをアルゴリズム間で比較します。

ハイパーパラメータは全データを使って軽く手でチューニングしています。 (本当はデータ分けてチューニングした方がいいんだけど) なのでアルゴリズムごとにこのデータで出せるベストの性能を目指している形になります。

実験結果

f:id:soonraah:20160606050927p:plain

AROW が一番良いという結果になりました。 SCW が AROW に負けるのはデータのせいなのか、チューニングが足りないのか、実装に問題があるのか… SCW はハイパーパラメータが2つあることに加え、ハイパーパラメータに対するエラー数の動きが予測しにくく、チューニングが難しそうでした。

共分散行列を使わない*8 Perceptron と PA 系はやはり他と比べてかなり性能が落ちます。 直前に学習したサンプルに振り回されてる感がありました。

まとめ

以上の結果から実際に業務等で使うとしたら AROW が使い勝手良さそうです。 実装もすごくシンプル。 SCW もポテンシャルはあるはずなんですが…

今回オンライン学習を初めて実装してみましたが、関数型プログラミングと相性が良いと感じました。 実際 var メンバーを使わずに実装することができました。 並列化とか考えだすと難しくなりそうな気はしますけどね。

ここで挙げたアルゴリズムの元論文の導出とか見ると数式だらけでウゲーってなりますが、実装は比較的シンプルになります。 オンライン学習ってそういうもんなんですかね。