EM アルゴリズムを用いた多クラス分類
[Background]
EMアルゴリズムという手法があります。気が向いたら細かい説明を載せますが、簡単に説明すると、「観測できない隠れた変数」である潜在変数を含む確率モデルにおいて、最尤解を求めるアルゴリズムです。今回はこのEMアルゴリズムを実装しました。すべてC++で実装しました。
まず、使うデータですが、三つのガウス分布から100点ずつ点を生成します。
[Code 1]
#include <iostream> #include <fstream> #include <string> #define _USE_MATH_DEFINES #include <math.h> using namespace std; // 一様乱数 double Uniform() { double ret = ((double)rand() + 1.0) / ((double)RAND_MAX + 2.0); return ret; } // 中心mu、分散sigmaのガウス分布を生成 double rand_normal(double mu, double sigma) { double z = sqrt(-2.0 * log(Uniform())) * sin(2.0 * M_PI * Uniform()); return mu + sigma * z; } int main() { string filename = "data.txt"; ofstream ofs = ofstream(filename); for (int i = 0; i < 100; i++) { double x1 = rand_normal(0.1, 0.1); double y1 = rand_normal(0.1, 0.1); double x2 = rand_normal(0.4, 0.1); double y2 = rand_normal(0.6, 0.1); double x3 = rand_normal(0.6, 0.1); double y3 = rand_normal(0.1, 0.1); ofs << x1 << "\t" << y1 << endl; ofs << x2 << "\t" << y2 << endl; ofs << x3 << "\t" << y3 << endl; } return 0; }
[Result 1]
生成した点をプロットした結果が以下の図の通りです。
[Code 2]
では実際に分類を行いますが、まずはいくつのクラスに分類するのが最も良いのかを見てみます。2~10クラスに分類してみて、それぞれにおけるすべての点の対数尤度を合計したものを見てみます。
#define _CRT_SECURE_NO_WARNINGS 1 #include <iostream> #include <string> #include <fstream> #include <vector> #define _USE_MATH_DEFINES #include <math.h> using namespace std; int main() { // number of sample const int sample = 300; // data double x[sample]; double y[sample]; // data load FILE *fp = fopen("data.txt", "r"); if (fp == NULL) return -1; for (int i = 0; i < sample; i++) { fscanf(fp, "%lf", x+i); fscanf(fp, "%lf", y+i); } string outputfilename = "likelihood.txt"; ofstream ofs = ofstream(outputfilename); for (int ncomp = 2; ncomp < 11; ncomp++) { // def paramaters vector<double> pi(ncomp, 1.0 / ncomp); double mux[sample]; double muy[sample]; for (int z = 0; z < ncomp; z++) { mux[z] = ((double)rand() + 1.0) / ((double)RAND_MAX + 2.0); muy[z] = ((double)rand() + 1.0) / ((double)RAND_MAX + 2.0); } vector<double> sigma(ncomp, 1); double Qdef = 10; double beforeQ = -10000; // EM algorithm while (Qdef > 0.001) { // E step vector<vector<double> > w(ncomp, vector<double>(sample)); double Q = 0; for (int z = 0; z < ncomp; z++) { for (int h = 0; h < sample; h++) { double sum = 0; for (int zz = 0; zz < ncomp; zz++) { double dis = (x[h] - mux[zz]) * (x[h] - mux[zz]) + (y[h] - muy[zz]) * (y[h] - muy[zz]); sum += exp(-dis / (2 * sigma[zz])) / sigma[zz]; } double dis = (x[h] - mux[z]) * (x[h] - mux[z]) + (y[h] - muy[z]) * (y[h] - muy[z]); w[z][h] = exp(-dis / (2 * sigma[z])) / (sum * sigma[z]); Q += w[z][h] * ((-dis / (2 * sigma[z])) - log(2 * M_PI * sigma[z])) + w[z][h] * log(pi[z]); } } Qdef = Q - beforeQ; beforeQ = Q; cout << Q << endl; // M step for (int z = 0; z < ncomp; z++) { double wsum = 0; double dissum = 0; double xwsum0 = 0; double xwsum1 = 0; for (int h = 0; h < sample; h++) { wsum += w[z][h]; dissum += ((x[h] - mux[z]) * (x[h] - mux[z]) + (y[h] - muy[z]) * (y[h] - muy[z])) * w[z][h]; xwsum0 += x[h] * w[z][h]; xwsum1 += y[h] * w[z][h]; } mux[z] = xwsum0 / wsum; muy[z] = xwsum1 / wsum; sigma[z] = dissum / (wsum * 2); pi[z] = wsum / sample; } } // log likelihood double LLH = 0; for (int i = 0; i < sample; i++) { double f = 0; for (int z = 0; z < ncomp; z++) { double dis = (x[i] - mux[z]) * (x[i] - mux[z]) + (y[i]- muy[z]) * (y[i] - muy[z]); f += pi[z] * exp(-dis / (2 * sigma[z])) / (2 * M_PI * sigma[z]); } LLH += log(f); } cout << "Log Likelihood is ..." << LLH << endl; ofs << ncomp << "\t" << LLH << endl; } return 0; }
[Result 2]
横軸にクラス数、縦軸に対数尤度を取ったグラフが下のようになります。
見たところ3が最もよさそうです。(8などのほうが高そうですが、3以降は上昇が緩く飽和してそうなので、過学習を避けるためにも3に設定します。もっとも、元データを3クラスで作っていると知ってしまっているのもありますが……。)
[Code 3]
では、3クラスで分類し、学習が進んでいる様子をgifで出力するようにしてみます。gnuplot を用いて、各段階での学習した中心と分散とをpngファイルに出力するようにし、Giamを用いてgifにまとめました。(gnuplot からgif animation で出力するやり方がよくわからなかった。) 下のコードを実行するとgnuplot で実行すべきコマンドが"gnuplotcommand.txt" に出力されるので、gnuplotで load "gnuplotcommand.txt" としてやれば終わりです。
#define _CRT_SECURE_NO_WARNINGS 1 #include <iostream> #include <string> #include <fstream> #include <vector> #define _USE_MATH_DEFINES #include <math.h> using namespace std; int main() { // number of sample const int sample = 300; // data double x[sample]; double y[sample]; // data load FILE *fp = fopen("data.txt", "r"); if (fp == NULL) return -1; for (int i = 0; i < sample; i++) { fscanf(fp, "%lf", x+i); fscanf(fp, "%lf", y+i); } string outputfilename = "gnuplotcommand.txt"; ofstream gnuplot = ofstream(outputfilename); gnuplot << "set para" << endl; gnuplot << "set xrange[-0.4:1.0]" << endl; gnuplot << "set yrange[-0.4:1.0]" << endl; int ncomp = 3; // def paramaters vector<double> pi(ncomp, 1.0 / ncomp); double mux[sample]; double muy[sample]; for (int z = 0; z < ncomp; z++) { mux[z] = ((double)rand() + 1.0) / ((double)RAND_MAX + 2.0); muy[z] = ((double)rand() + 1.0) / ((double)RAND_MAX + 2.0); } vector<double> sigma(ncomp, 1); int count = 0; double Qdef = 10; double beforeQ = -10000; // EM algorithm while (Qdef > 0.001) { // E step vector<vector<double> > w(ncomp, vector<double>(sample)); double Q = 0; for (int z = 0; z < ncomp; z++) { for (int h = 0; h < sample; h++) { double sum = 0; for (int zz = 0; zz < ncomp; zz++) { double dis = (x[h] - mux[zz]) * (x[h] - mux[zz]) + (y[h] - muy[zz]) * (y[h] - muy[zz]); sum += exp(-dis / (2 * sigma[zz])) / sigma[zz]; } double dis = (x[h] - mux[z]) * (x[h] - mux[z]) + (y[h] - muy[z]) * (y[h] - muy[z]); w[z][h] = exp(-dis / (2 * sigma[z])) / (sum * sigma[z]); Q += w[z][h] * ((-dis / (2 * sigma[z])) - log(2 * M_PI * sigma[z])) + w[z][h] * log(pi[z]); } } Qdef = Q - beforeQ; beforeQ = Q; cout << Q << endl; gnuplot << "plot \"data.txt\"" << endl; for (int i = 0; i < ncomp; i++) { gnuplot << "set label " << i + 1 << " point at " << mux[i] << "," << muy[i] << endl; gnuplot << "replot [0:2*pi] " << mux[i] << "+" << sqrt(sigma[i]) << "*cos(t), " << muy[i] << "+" << sqrt(sigma[i]) << "*sin(t)" << endl; } gnuplot << "set terminal png" << endl; gnuplot << "set out \"" << count << ".png\"" << endl; gnuplot << "replot" << endl; // M step for (int z = 0; z < ncomp; z++) { double wsum = 0; double dissum = 0; double xwsum0 = 0; double xwsum1 = 0; for (int h = 0; h < sample; h++) { wsum += w[z][h]; dissum += ((x[h] - mux[z]) * (x[h] - mux[z]) + (y[h] - muy[z]) * (y[h] - muy[z])) * w[z][h]; xwsum0 += x[h] * w[z][h]; xwsum1 += y[h] * w[z][h]; } mux[z] = xwsum0 / wsum; muy[z] = xwsum1 / wsum; sigma[z] = dissum / (wsum * 2); pi[z] = wsum / sample; } count++; } return 0; }
[Result 3]
いい感じ。