完整版KNN算法实验报告Word格式文档下载.docx
《完整版KNN算法实验报告Word格式文档下载.docx》由会员分享,可在线阅读,更多相关《完整版KNN算法实验报告Word格式文档下载.docx(11页珍藏版)》请在冰豆网上搜索。
数学表达:
目标函数值可以是离散值(分类问题),也可以是连续值(回归问题).函数形势为f:
n维空间R—〉一维空间R。
第一步:
将数据集分为训练集(DTrn)和测试集(DTES)。
第二步:
在测试集给定一个实例Xq;
在训练集(DTrn)中找到与这个实例Xq的K-最近邻子集{X1、、、、XK},即:
DKNN。
第三步:
计算这K-最近邻子集得目标值,经过加权平均:
^f(Xq)=(f(X1)+...+f(XK))/k作为f(Xq)的近似估计。
改进的地方:
对kNN算法的一个明显的改进是对k个最近邻的贡献加权,将较大的权值赋给较近的近邻,相应的算法称为距离加权kNN回归算法,则公式1则修改为:
^f(Xq)=(w1*f(X1)+...+wk*f(XK))/(w1+...wk)一般地距离权值wi和距离成反比关系,例如,wi近似=1/d(xq;
xi).K值的选择:
需要消除K值过低,预测目标容易产生变动性,同时高k值时,预测目标有过平滑现象。
推定k值的有益途径是通过有效参数的数目这个概念。
有效参数的数目是和k值相关的,大致等于n/k,其中,n是这个训练数据集中实例的数目。
缺点:
(1)在大训练集寻找最近邻的时间是难以忍受的。
(2)在训练数据集中要求的观测值的数目,随着维数p的增长以指数方式增长。
这是因为和最近邻的期望距离随着维数p的增多而急剧上升,除非训练数据集的大小随着p以指数方式增长。
这种现象被称为“维数灾难”。
解决办法有下面几个:
(1)通过降维技术来减少维数,如主成分分析,因子分析,变量选择(因子选择)从而减少计算距离的时间;
(2)用复杂的数据结构,如搜索树去加速最近邻的确定。
这个方法经常通过公式2公式1设定“几乎是最近邻”的目标去提高搜索速度;
(3)编辑训练数据去减少在训练集中的冗余和几乎是冗余的点,从而加速搜索最近邻。
在个别例子中去掉在训练数据集中的一些观察点,对分类效果没有影响,原因是这些点被包围属于同类的观测点中。
三注意事项
KNN算法的实现要注意:
1.用TreeMap<
String,TreeMap<
String,Double>
>
保存测试集和训练集。
2.注意要以"
类目_文件名"
作为每个文件的key,才能避免同名不同内容的文件出现。
3.注意设置JM参数,否则会出现JAVAheap溢出错误。
4.本程序用向量夹角余弦计算相似度。
四代码
//KNN.java
packagecqu.KNN;
importjava.util.ArrayList;
importjava.util.Comparator;
importjava.util.HashMap;
importjava.util.List;
importjava.util.Map;
importjava.util.PriorityQueue;
//KNN算法主体类
publicclassKNN
{
/***设置优先级队列的比较函数,距离越大,优先级越高*/
privateComparator<
KNNNode>
comparator=newComparator<
()
{
publicintcompare(KNNNodeo1,KNNNodeo2)
if(o1.getDistance()>
=o2.getDistance())
return-1;
}
else
return1;
}
};
/***获取K个不同的随机数*@paramk随机数的个数*@parammax随机数最大的范围*@return生成的随机数数组*/
publicList<
Integer>
getRandKNum(intk,intmax)
List<
rand=newArrayList<
(k);
for(inti=0;
i<
k;
i++)
inttemp=(int)(Math.random()*max);
if(!
rand.contains(temp))
rand.add(temp);
i--;
returnrand;
/***计算测试元组与训练元组之前的距离*@paramd1测试元组*@paramd2训练元组*@return距离值*/
publicdoublecalDistance(List<
Double>
d1,List<
d2)
doubledistance=0.00;
d1.size();
distance+=(d1.get(i)-d2.get(i))*(d1.get(i)-d2.get(i));
returndistance;
/***执行KNN算法,获取测试元组的类别*@paramdatas训练数据集*@paramtestData测试元组*@paramk设定的K值*@return测试元组的类别*/
publicStringknn(List<
List<
datas,List<
testData,intk)
PriorityQueue<
pq=newPriorityQueue<
(k,comparator);
randNum=getRandKNum(k,datas.size());
intindex=randNum.get(i);
currData=datas.get(index);
Stringc=currData.get(currData.size()-1).toString();
KNNNodenode=newKNNNode(index,calDistance(testData,currData),c);
pq.add(node);
datas.size();
t=datas.get(i);
doubledistance=calDistance(testData,t);
KNNNodetop=pq.peek();
if(top.getDistance()>
distance)
pq.remove();
pq.add(newKNNNode(i,distance,t.get(t.size()-1).toString()));
returngetMostClass(pq);
/***获取所得到的k个最近邻元组的多数类*@parampq存储k个最近近邻元组的优先级队列*@return多数类的名称*/
privateStringgetMostClass(PriorityQueue<
pq)
{
Map<
String,Integer>
classCount=newHashMap<
();
intpqsize=pq.size();
pqsize;
KNNNodenode=pq.remove();
Stringc=node.getC();
if(classCount.containsKey(c))
classCount.put(c,classCount.get(c)+1);
classCount.put(c,1);
intmaxIndex=-1;
intmaxCount=0;
Object[]classes=classCount.keySet().toArray();
classes.length;
i++)
if(classCount.get(classes[i])>
maxCount)
maxIndex=i;
maxCount=classCount.get(classes[i]);
returnclasses[maxIndex].toString();
}
//KNNNode.java
publicclassKNNNode
privateintindex;
//元组标号
privatedoubledistance;
//与测试元组的距离
privateStringc;
//所属类别
publicKNNNode(intindex,doubledistance,Stringc)
super();
this.index=index;
this.distance=distance;
this.c=c;
publicintgetIndex()
returnindex;
publicvoidsetIndex(intindex)
publicdoublegetDistance()
publicvoidsetDistance(doubledistance)
publicStringgetC()
returnc;
publicvoidsetC(Stringc)
//TestKNN.java
importjava.io.BufferedReader;
importjava.io.File;
importjava.io.FileReader;
//KNN算法测试类
publicclassTestKNN
/***从数据文件中读取数据*@paramdatas存储数据的集合对象*@parampath数据文件的路径*/
publicvoidread(List<
datas,Stringpath)
try{
BufferedReaderbr=newBufferedReader(newFileReader(newFile(path)));
Stringreader=br.readLine();
while(reader!
=null)
Stringt[]=reader.split("
"
);
ArrayList<
list=newArrayList<
t.length;
list.add(Double.parseDouble(t[i]));
datas.add(list);
reader=br.readLine();
catch(Exceptione)
e.printStackTrace();
/***程序执行入口*@paramargs*/
publicstaticvoidmain(String[]args)
TestKNNt=newTestKNN();
Stringdatafile=newFile("
"
).getAbsolutePath()+File.separator+"
cqudata\\datafile.txt"
;
Stringtestfile=newFile("
cqudata\\testfile.txt"
datas=newArrayList<
testDatas=newArrayList<
t.read(datas,datafile);
t.read(testDatas,testfile);
KNNknn=newKNN();
testDatas.size();
test=testDatas.get(i);
System.out.print("
测试元组:
for(intj=0;
j<
test.size();
j++)
System.out.print(test.get(j)+"
类别为:
System.out.println(Math.round(Float.parseFloat((knn.knn(datas,test,3)))));
五运行测试
训练数据:
1.01.11.22.10.32.31.40.51
1.71.21.42.00.22.51.20.81
1.21.81.62.50.12.21.80.21
1.92.16.21.10.93.32.45.50
1.00.81.62.10.22.31.60.51
1.62.15.21.10.83.62.44.50
实验数据:
1.01.11.22.10.32.31.40.5
1.71.21.42.00.22.51.20.8
1.21.81.62.50.12.21.80.2
1.92.16.21.10.93.32.45.5
1.00.81.62.10.22.31.60.5
1.62.15.21.10.83.62.44.5
程序运行结果:
1.01.11.22.10.32.31.40.5类别为:
1
1.71.21.42.00.22.51.20.8类别为:
1.21.81.62.50.12.21.80.2类别为:
1.92.16.21.10.93.32.45.5类别为:
0
1.00.81.62.10.22.31.60.5类别为:
1.62.15.21.10.83.62.44.5类别为:
0