読者です 読者をやめる 読者になる 読者になる

消極的自殺の記録

暁月分明 (tube_worm) が人生という消極的な自殺をしていくにあたっての記録です。

EM アルゴリズムを用いた多クラス分類

[Background]

 EMアルゴリズムという手法があります。気が向いたら細かい説明を載せますが、簡単に説明すると、「観測できない隠れた変数」である潜在変数を含む確率モデルにおいて、最尤解を求めるアルゴリズムです。今回はこのEMアルゴリズムを実装しました。すべてC++で実装しました。
 まず、使うデータですが、三つのガウス分布から100点ずつ点を生成します。

[Code 1]

#include <iostream>
#include <fstream>
#include <string>
#define _USE_MATH_DEFINES
#include <math.h>
using namespace std;

// 一様乱数
double Uniform()
{
	double ret = ((double)rand() + 1.0) / ((double)RAND_MAX + 2.0);
	return ret;
}

// 中心mu、分散sigmaのガウス分布を生成
double rand_normal(double mu, double sigma)
{
	double z = sqrt(-2.0 * log(Uniform())) * sin(2.0 * M_PI * Uniform());
	return mu + sigma * z;
}

int main()
{
	string filename = "data.txt";
	ofstream ofs = ofstream(filename);

	for (int i = 0; i < 100; i++)
	{
		double x1 = rand_normal(0.1, 0.1);
		double y1 = rand_normal(0.1, 0.1);
		double x2 = rand_normal(0.4, 0.1);
		double y2 = rand_normal(0.6, 0.1);
		double x3 = rand_normal(0.6, 0.1);
		double y3 = rand_normal(0.1, 0.1);

		ofs << x1 << "\t" << y1 << endl;
		ofs << x2 << "\t" << y2 << endl;
		ofs << x3 << "\t" << y3 << endl;
	}
	
	return 0;
}

[Result 1]

生成した点をプロットした結果が以下の図の通りです。
f:id:tube_worm:20160329061628p:plain

[Code 2]

 では実際に分類を行いますが、まずはいくつのクラスに分類するのが最も良いのかを見てみます。2~10クラスに分類してみて、それぞれにおけるすべての点の対数尤度を合計したものを見てみます。

#define _CRT_SECURE_NO_WARNINGS 1
#include <iostream>
#include <string>
#include <fstream>
#include <vector>
#define _USE_MATH_DEFINES
#include <math.h>

using namespace std;

int main()
{
	// number of sample
	const int sample = 300;

	// data
	double x[sample];
	double y[sample];

	// data load
	FILE *fp = fopen("data.txt", "r");
	if (fp == NULL) return -1;

	for (int i = 0; i < sample; i++)
	{
		fscanf(fp, "%lf", x+i);
		fscanf(fp, "%lf", y+i);
	}
	
	string outputfilename = "likelihood.txt";
	ofstream ofs = ofstream(outputfilename);

	for (int ncomp = 2; ncomp < 11; ncomp++)
	{
		// def paramaters
		vector<double> pi(ncomp, 1.0 / ncomp);
		double mux[sample];
		double muy[sample];
		for (int z = 0; z < ncomp; z++)
		{
			mux[z] = ((double)rand() + 1.0) / ((double)RAND_MAX + 2.0);
			muy[z] = ((double)rand() + 1.0) / ((double)RAND_MAX + 2.0);
		}

		vector<double> sigma(ncomp, 1);

		double Qdef = 10;
		double beforeQ = -10000;

		// EM algorithm
		while (Qdef > 0.001)
		{
			// E step
			vector<vector<double> > w(ncomp, vector<double>(sample));

			double Q = 0;
			for (int z = 0; z < ncomp; z++)
			{
				for (int h = 0; h < sample; h++)
				{
					double sum = 0;
					for (int zz = 0; zz < ncomp; zz++)
					{
						double dis = (x[h] - mux[zz]) * (x[h] - mux[zz]) + (y[h] - muy[zz]) * (y[h] - muy[zz]);
						sum += exp(-dis / (2 * sigma[zz])) / sigma[zz];
					}
					double dis = (x[h] - mux[z]) * (x[h] - mux[z]) + (y[h] - muy[z]) * (y[h] - muy[z]);
					w[z][h] = exp(-dis / (2 * sigma[z])) / (sum * sigma[z]);

					Q += w[z][h] * ((-dis / (2 * sigma[z])) - log(2 * M_PI * sigma[z])) + w[z][h] * log(pi[z]);
				}
			}
			Qdef = Q - beforeQ;
			beforeQ = Q;

			cout << Q << endl;
			
			// M step
			for (int z = 0; z < ncomp; z++)
			{
				double wsum = 0;
				double dissum = 0;
				double xwsum0 = 0;
				double xwsum1 = 0;
				for (int h = 0; h < sample; h++)
				{
					wsum += w[z][h];
					dissum += ((x[h] - mux[z]) * (x[h] - mux[z]) + (y[h] - muy[z]) * (y[h] - muy[z])) * w[z][h];
					xwsum0 += x[h] * w[z][h];
					xwsum1 += y[h] * w[z][h];
				}
				mux[z] = xwsum0 / wsum;
				muy[z] = xwsum1 / wsum;
				sigma[z] = dissum / (wsum * 2);
				pi[z] = wsum / sample;
			}
		}

		// log likelihood
		double LLH = 0;
		for (int i = 0; i < sample; i++)
		{
			double f = 0;
			for (int z = 0; z < ncomp; z++)
			{
				double dis = (x[i] - mux[z]) * (x[i] - mux[z]) + (y[i]- muy[z]) * (y[i] - muy[z]);
				f += pi[z] * exp(-dis / (2 * sigma[z])) / (2 * M_PI * sigma[z]);
			}
			LLH += log(f);
		}
		cout << "Log Likelihood is ..." << LLH << endl;
		ofs << ncomp << "\t" << LLH << endl;
	}


	return 0;
}

[Result 2]

 横軸にクラス数、縦軸に対数尤度を取ったグラフが下のようになります。
f:id:tube_worm:20160329062111p:plain
 見たところ3が最もよさそうです。(8などのほうが高そうですが、3以降は上昇が緩く飽和してそうなので、過学習を避けるためにも3に設定します。もっとも、元データを3クラスで作っていると知ってしまっているのもありますが……。)

[Code 3]

 では、3クラスで分類し、学習が進んでいる様子をgifで出力するようにしてみます。gnuplot を用いて、各段階での学習した中心と分散とをpngファイルに出力するようにし、Giamを用いてgifにまとめました。(gnuplot からgif animation で出力するやり方がよくわからなかった。) 下のコードを実行するとgnuplot で実行すべきコマンドが"gnuplotcommand.txt" に出力されるので、gnuplotで load "gnuplotcommand.txt" としてやれば終わりです。

#define _CRT_SECURE_NO_WARNINGS 1
#include <iostream>
#include <string>
#include <fstream>
#include <vector>
#define _USE_MATH_DEFINES
#include <math.h>

using namespace std;

int main()
{
	// number of sample
	const int sample = 300;

	// data
	double x[sample];
	double y[sample];

	// data load
	FILE *fp = fopen("data.txt", "r");
	if (fp == NULL) return -1;

	for (int i = 0; i < sample; i++)
	{
		fscanf(fp, "%lf", x+i);
		fscanf(fp, "%lf", y+i);
	}
	
	string outputfilename = "gnuplotcommand.txt";
	ofstream gnuplot = ofstream(outputfilename);

	gnuplot << "set para" << endl;
	gnuplot << "set xrange[-0.4:1.0]" << endl;
	gnuplot << "set yrange[-0.4:1.0]" << endl;

	int ncomp = 3;

	// def paramaters
	vector<double> pi(ncomp, 1.0 / ncomp);
	double mux[sample];
	double muy[sample];
	for (int z = 0; z < ncomp; z++)
	{
		mux[z] = ((double)rand() + 1.0) / ((double)RAND_MAX + 2.0);
		muy[z] = ((double)rand() + 1.0) / ((double)RAND_MAX + 2.0);
	}

	vector<double> sigma(ncomp, 1);

	int count = 0;

	double Qdef = 10;
	double beforeQ = -10000;

	// EM algorithm
	while (Qdef > 0.001)
	{
		// E step
		vector<vector<double> > w(ncomp, vector<double>(sample));

		double Q = 0;
		for (int z = 0; z < ncomp; z++)
		{
			for (int h = 0; h < sample; h++)
			{
				double sum = 0;
				for (int zz = 0; zz < ncomp; zz++)
				{
					double dis = (x[h] - mux[zz]) * (x[h] - mux[zz]) + (y[h] - muy[zz]) * (y[h] - muy[zz]);
					sum += exp(-dis / (2 * sigma[zz])) / sigma[zz];
				}
				double dis = (x[h] - mux[z]) * (x[h] - mux[z]) + (y[h] - muy[z]) * (y[h] - muy[z]);
				w[z][h] = exp(-dis / (2 * sigma[z])) / (sum * sigma[z]);

				Q += w[z][h] * ((-dis / (2 * sigma[z])) - log(2 * M_PI * sigma[z])) + w[z][h] * log(pi[z]);
			}
		}
		Qdef = Q - beforeQ;
		beforeQ = Q;

		cout << Q << endl;
		gnuplot << "plot \"data.txt\"" << endl;
		for (int i = 0; i < ncomp; i++)
		{
			gnuplot << "set label " << i + 1 << " point at " << mux[i] << "," << muy[i] << endl;
			gnuplot << "replot [0:2*pi] " << mux[i] << "+" << sqrt(sigma[i]) << "*cos(t), " << muy[i] << "+" << sqrt(sigma[i]) << "*sin(t)" << endl;
		}
		gnuplot << "set terminal png" << endl;
		gnuplot << "set out \"" << count << ".png\"" << endl;
		gnuplot << "replot" << endl;

		// M step
		for (int z = 0; z < ncomp; z++)
		{
			double wsum = 0;
			double dissum = 0;
			double xwsum0 = 0;
			double xwsum1 = 0;
			for (int h = 0; h < sample; h++)
			{
				wsum += w[z][h];
				dissum += ((x[h] - mux[z]) * (x[h] - mux[z]) + (y[h] - muy[z]) * (y[h] - muy[z])) * w[z][h];
				xwsum0 += x[h] * w[z][h];
				xwsum1 += y[h] * w[z][h];
			}
			mux[z] = xwsum0 / wsum;
			muy[z] = xwsum1 / wsum;
			sigma[z] = dissum / (wsum * 2);
			pi[z] = wsum / sample;
		}
		count++;
	}


	return 0;
}

[Result 3]

f:id:tube_worm:20160329064454g:plain
いい感じ。