- 积分
- 26294
- 贡献
-
- 精华
- 在线时间
- 小时
- 注册时间
- 2012-6-1
- 最后登录
- 1970-1-1
|
登录后查看更多精彩内容~
您需要 登录 才可以下载或查看,没有帐号?立即注册
x
今天看了贝叶斯的相关理论,就想自己实现一个贝叶斯分类器。花了一个下午,结合在网上找到的一些资料,实现了分类,下面介绍一下。理论方面的东西我就不多说了,直接讲实现。
主要是分两个步骤:
1、产生数据,这里的数据是自己产生的,可以用自己需要分类的数据去替换
2、分类
产生数据的程序如下:
% Generate synthetic 2-D training and test data for three classes and save
% each set to its own .mat file (train1.mat ... test3.mat), which the
% classifier script loads. Replace this with your own data if needed.
%
% Every set is a 2 x N matrix: one sample per column,
% row 1 = horizontal coordinate ("heng"), row 2 = vertical coordinate ("zong").
% Samples are drawn uniformly from an axis-aligned rectangle per class:
%   class 1: x in [1,4], y in [2,4]
%   class 2: x in [6,9], y in [0,2]
%   class 3: x in [3,5], y in [-2,0]
clear
clc
N = 100;   % number of samples per set

% Draw N points uniformly from the rectangle [xLo,xHi] x [yLo,yHi].
% (The x row is drawn before the y row, as in the original code.)
makeSet = @(xLo, xHi, yLo, yHi) ...
    [xLo + (xHi - xLo) * rand(1, N); ...
     yLo + (yHi - yLo) * rand(1, N)];

%% training sets
train1 = makeSet(1, 4, 2, 4);
train2 = makeSet(6, 9, 0, 2);
train3 = makeSet(3, 5, -2, 0);
%% test sets (same distributions, independent draws)
test1 = makeSet(1, 4, 2, 4);
test2 = makeSet(6, 9, 0, 2);
test3 = makeSet(3, 5, -2, 0);

%% save as mat
% Save ONLY the named variable in each file. A bare "save train1" would dump
% the entire workspace into every file, so each later load would overwrite
% all the other sets as well.
save('train1.mat', 'train1');
save('train2.mat', 'train2');
save('train3.mat', 'train3');
save('test1.mat', 'test1');
save('test2.mat', 'test2');
save('test3.mat', 'test3');
实现分类的代码如下:
% Bayes classifier with Gaussian class-conditional densities (quadratic
% discriminant). For each class k this section estimates
%   g_k(x) = x'*W_k*x + w_k'*x + w_k0
% with W_k  = -1/2*inv(S_k),  w_k = inv(S_k)*m_k,
%      w_k0 = -1/2*m_k'*inv(S_k)*m_k - 1/2*log(det(S_k)) + log(P(k)),
% where m_k and S_k are the sample mean and covariance of training class k,
% and P(k) is the prior estimated from the training-set sizes.

outPut = zeros(3, 3);  % confusion matrix: row = true class, column = assigned class
class1 = [];           % test samples assigned to class 1
class2 = [];           % test samples assigned to class 2
class3 = [];           % test samples assigned to class 3

% Each data set is a 2 x N matrix: one sample per column.
% Load only the named variable from each file, so any other variables that
% were saved alongside it cannot overwrite the sets already loaded.
load('train1.mat', 'train1');
load('train2.mat', 'train2');
load('train3.mat', 'train3');
load('test1.mat', 'test1');
load('test2.mat', 'test2');
load('test3.mat', 'test3');
% Alternative: draw Gaussian samples directly (requires mvnrnd).
% train1=mvnrnd([1 1],[4 0;0 5],100)'; %2 X N
% train2=mvnrnd([7 2],[7 0;0 4],100)';
% train3=mvnrnd([2 7],[2 0;0 4],100)';
% % test samples
% test1=mvnrnd([1 1],[4 0;0 5],100)'; %2 X N
% test2=mvnrnd([7 2],[7 0;0 4],100)';
% test3=mvnrnd([2 7],[2 0;0 4],100)';
%---------------------------------------------------%
trainSets = {train1, train2, train3};

% Priors: fraction of training samples in each class. Samples are columns,
% so count them with size(X,2); length(X) would return the larger dimension.
nTrain = cellfun(@(X) size(X, 2), trainSets);
P = nTrain / sum(nTrain);

quadW = cell(1, 3);   % W_k
linW  = cell(1, 3);   % w_k
means = cell(1, 3);   % m_k (column vectors)
bias  = zeros(1, 3);  % w_k0
for k = 1:3
    X = trainSets{k};
    % cov needs more observations than dimensions for an invertible estimate;
    % with fewer columns, X' would also be misread by cov as a single variable.
    if size(X, 2) <= size(X, 1)
        error('Training class %d needs more samples than dimensions.', k);
    end
    S = cov(X');        % cov expects observations in rows
    Sinv = inv(S);      % computed once, reused in all three terms
    m = mean(X, 2);     % mean over samples -> column vector
    quadW{k} = -1/2 * Sinv;
    linW{k}  = Sinv * m;
    means{k} = m;
    bias(k)  = -1/2 * m' * Sinv * m - 1/2 * log(det(S)) + log(P(k));
end
[W1, W2, W3] = quadW{:};
[w1, w2, w3] = linW{:};
[Ave1, Ave2, Ave3] = means{:};
w10 = bias(1);
w20 = bias(2);
w30 = bias(3);
%-----------------------------------------------------------%
% Assign every test sample to the class with the largest discriminant
%   g_k(x) = x'*W_k*x + w_k'*x + w_k0
% For a sample of true class i assigned to class k, increment outPut(i,k)
% (off-diagonal entries are misclassifications) and append the sample as a
% new column of class<k>.
testSets = {test1, test2, test3};
Wq = {W1, W2, W3};
wl = {w1, w2, w3};
w0 = [w10, w20, w30];
assigned = {class1, class2, class3};
for i = 1:3
    X = testSets{i};
    for j = 1:size(X, 2)          % use the actual number of test samples
        x = X(:, j);
        g = zeros(1, 3);
        for k = 1:3
            g(k) = x' * Wq{k} * x + wl{k}' * x + w0(k);
        end
        % max returns the first maximum, so exact ties are broken the same
        % way for every true class (the old branches favoured the true class).
        [~, best] = max(g);
        outPut(i, best) = outPut(i, best) + 1;
        assigned{best} = [assigned{best}, x];
    end
end
[class1, class2, class3] = assigned{:};
outPut
%---------------------------------------------------%
% Plot the samples in three panels: training sets, test sets, and test
% samples grouped by the class they were assigned to.
markers = {'go', 'b+', 'r.'};

subplot(3,1,1)
trainSets = {train1, train2, train3};
for k = 1:3
    plot(trainSets{k}(1,:), trainSets{k}(2,:), markers{k}, 'LineWidth', 2), hold on
end
title('训练样本分布情况')
legend('训练样本1','训练样本2','训练样本3')

subplot(3,1,2)
testSets = {test1, test2, test3};
for k = 1:3
    plot(testSets{k}(1,:), testSets{k}(2,:), markers{k}, 'LineWidth', 2), hold on
end
title('测试样本分布情况')
legend('测试样本1','测试样本2','测试样本3')

subplot(3,1,3)
% A class may receive no test samples. Indexing an empty [] with (1,:) would
% error, so skip such classes and build the legend only from what was drawn.
% These groups are predicted classes (they can mix true classes), so label
% them as such.
assigned = {class1, class2, class3};
labels = {'判为第1类', '判为第2类', '判为第3类'};
drawn = false(1, 3);
for k = 1:3
    if ~isempty(assigned{k})
        plot(assigned{k}(1,:), assigned{k}(2,:), markers{k}, 'LineWidth', 2), hold on
        drawn(k) = true;
    end
end
title('测试样本分类后分布情况')
if any(drawn)
    legend(labels{drawn})
end
注意:要先运行第一个程序,它生成的数据文件(train1.mat ~ test3.mat)供第二个程序读取。下面是运行结果,贴出来看一下:
|
评分
-
查看全部评分
|