Weka Development [21]: IBk (KNN) Source Code Analysis
If you have not read the previous post on IB1, please read it first; material already covered there is not repeated here.
Let's go straight to buildClassifier. Only the code that does not also appear in IB1 is listed:
try {
  m_NumClasses = instances.numClasses();
  m_ClassType = instances.classAttribute().type();
} catch (Exception ex) {
  throw new Error("This should never be reached");
}

// Throw away initial instances until within the specified window size
if ((m_WindowSize > 0)
    && (instances.numInstances() > m_WindowSize)) {
  m_Train = new Instances(m_Train,
      m_Train.numInstances() - m_WindowSize, m_WindowSize);
}

// Compute the number of attributes that contribute
// to each prediction
m_NumAttributesUsed = 0.0;
for (int i = 0; i < m_Train.numAttributes(); i++) {
  if ((i != m_Train.classIndex())
      && (m_Train.attribute(i).isNominal()
          || m_Train.attribute(i).isNumeric())) {
    m_NumAttributesUsed += 1.0;
  }
}

// Invalidate any currently cross-validation selected k
m_kNNValid = false;
IB1 does not care about m_NumClasses because it looks for only one neighbor, so there is just one value. m_WindowSize limits how many samples are used for classification; they are not chosen at random, the last m_WindowSize samples are simply kept (the offset m_Train.numInstances() - m_WindowSize in the Instances constructor drops the oldest ones). The loop that follows counts how many attributes participate in prediction.
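To make the windowing idea concrete, here is a minimal Weka-free sketch of "keep only the newest m_WindowSize examples". WindowSketch and trimToWindow are hypothetical names for illustration, not Weka API:

```java
import java.util.ArrayList;
import java.util.List;

// Sketch of IBk's window trimming: keep only the trailing windowSize
// training examples, dropping the oldest first.
public class WindowSketch {
    static <T> List<T> trimToWindow(List<T> train, int windowSize) {
        if (windowSize > 0 && train.size() > windowSize) {
            // Same idea as new Instances(m_Train, size - window, window):
            // copy only the last windowSize elements.
            return new ArrayList<>(
                train.subList(train.size() - windowSize, train.size()));
        }
        return train;
    }

    public static void main(String[] args) {
        List<Integer> data = new ArrayList<>(List.of(1, 2, 3, 4, 5));
        System.out.println(trimToWindow(data, 3)); // [3, 4, 5]
    }
}
```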
KNN is also a classifier capable of incremental learning. Here is its updateClassifier code:
public void updateClassifier(Instance instance) throws Exception {
  if (m_Train.equalHeaders(instance.dataset()) == false) {
    throw new Exception("Incompatible instance types");
  }
  if (instance.classIsMissing()) {
    return;
  }
  if (!m_DontNormalize) {
    updateMinMax(instance);
  }
  m_Train.add(instance);
  m_kNNValid = false;
  if ((m_WindowSize > 0)
      && (m_Train.numInstances() > m_WindowSize)) {
    while (m_Train.numInstances() > m_WindowSize) {
      m_Train.delete(0);
    }
  }
}
This is equally simple: updateMinMax refreshes the attribute ranges, and if the window size is exceeded, the loop repeatedly deletes the first (oldest) sample until the training set fits within the window again.
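The post does not show updateMinMax itself. As an illustration of what such range tracking typically does, here is a plain-Java sketch (MinMaxSketch is a hypothetical class, not Weka code): each incoming instance widens the observed per-attribute range, which the distance function then uses to normalize values into [0, 1].

```java
// Sketch of min/max range tracking for distance normalization.
public class MinMaxSketch {
    double[] min, max;

    MinMaxSketch(int numAttributes) {
        min = new double[numAttributes];
        max = new double[numAttributes];
        java.util.Arrays.fill(min, Double.POSITIVE_INFINITY);
        java.util.Arrays.fill(max, Double.NEGATIVE_INFINITY);
    }

    // Widen the observed range with one new instance's attribute values.
    void update(double[] instance) {
        for (int i = 0; i < instance.length; i++) {
            if (instance[i] < min[i]) min[i] = instance[i];
            if (instance[i] > max[i]) max[i] = instance[i];
        }
    }

    // Normalize one attribute value into [0, 1] using the observed range.
    double norm(double value, int attr) {
        if (max[attr] == min[attr]) return 0.0;
        return (value - min[attr]) / (max[attr] - min[attr]);
    }

    public static void main(String[] args) {
        MinMaxSketch s = new MinMaxSketch(1);
        s.update(new double[]{2.0});
        s.update(new double[]{6.0});
        System.out.println(s.norm(4.0, 0)); // 0.5
    }
}
```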
Note that IBk does not implement classifyInstance; it only implements distributionForInstance:
public double[] distributionForInstance(Instance instance) throws Exception {
  if (m_Train.numInstances() == 0) {
    throw new Exception("No training instances!");
  }
  if ((m_WindowSize > 0)
      && (m_Train.numInstances() > m_WindowSize)) {
    m_kNNValid = false;
    boolean deletedInstance = false;
    while (m_Train.numInstances() > m_WindowSize) {
      m_Train.delete(0);
      deletedInstance = true;
    }
    // rebuild datastructure KDTree currently can't delete
    if (deletedInstance == true)
      m_NNSearch.setInstances(m_Train);
  }

  // Select k by cross validation
  if (!m_kNNValid && (m_CrossValidate) && (m_kNNUpper >= 1)) {
    crossValidate();
  }

  m_NNSearch.addInstanceInfo(instance);
  Instances neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
  double[] distances = m_NNSearch.getDistances();
  double[] distribution = makeDistribution(neighbours, distances);
  return distribution;
}
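Since IBk only overrides distributionForInstance, the class label comes from the inherited default classifyInstance, which for a nominal class simply takes the index of the largest probability in the returned distribution. A minimal sketch of that argmax step (ArgmaxSketch is a hypothetical name):

```java
// Sketch of deriving a class label from a class-probability distribution:
// pick the index of the maximum entry (ties go to the lowest index).
public class ArgmaxSketch {
    static int maxIndex(double[] dist) {
        int best = 0;
        for (int i = 1; i < dist.length; i++) {
            if (dist[i] > dist[best]) best = i;
        }
        return best;
    }

    public static void main(String[] args) {
        System.out.println(maxIndex(new double[]{0.2, 0.5, 0.3})); // 1
    }
}
```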
The first two checks need no further discussion; crossValidate() is covered shortly, and finding the k nearest neighbours was already explained in post [18]. Now let's look at the makeDistribution function.
protected double[] makeDistribution(Instances neighbours, double[] distances)
    throws Exception {
  double total = 0, weight;
  double[] distribution = new double[m_NumClasses];

  // Set up a correction to the estimator
  if (m_ClassType == Attribute.NOMINAL) {
    for (int i = 0; i < m_NumClasses; i++) {
      distribution[i] = 1.0 / Math.max(1, m_Train.numInstances());
    }
    total = (double) m_NumClasses / Math.max(1, m_Train.numInstances());
  }

  for (int i = 0; i < neighbours.numInstances(); i++) {
    // Collect class counts
    Instance current = neighbours.instance(i);
    distances[i] = distances[i] * distances[i];
    distances[i] = Math.sqrt(distances[i] / m_NumAttributesUsed);
    switch (m_DistanceWeighting) {
      case WEIGHT_INVERSE:
        weight = 1.0 / (distances[i] + 0.001); // to avoid div by zero
        break;
      case WEIGHT_SIMILARITY:
        weight = 1.0 - distances[i];
        break;
      default: // WEIGHT_NONE:
        weight = 1.0;
        break;
    }
    weight *= current.weight();
    try {
      switch (m_ClassType) {
        case Attribute.NOMINAL:
          distribution[(int) current.classValue()] += weight;
          break;
        case Attribute.NUMERIC:
          distribution[0] += current.classValue() * weight;
          break;
      }
    } catch (Exception ex) {
      throw new Error("Data has no class attribute!");
    }
    total += weight;
  }

  // Normalise distribution
  if (total > 0) {
    Utils.normalize(distribution, total);
  }
  return distribution;
}
The first comment, "Set up a correction to the estimator", seems unnecessary to me: unlike Bayes there is no division-by-zero risk here, so there is nothing obvious to correct. You can see that three distance-weighting schemes are implemented: the inverse of the distance, one minus the distance (similarity), and a fixed weight of 1. Then, if the class is nominal, the weight is added to the entry for the corresponding class value; if the class is numeric, the current class value multiplied by the weight is accumulated into distribution[0].
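The weighting-and-normalization core of makeDistribution can be sketched without Weka types by passing neighbour classes and distances as plain arrays. DistributionSketch is a hypothetical class mirroring the logic above (the estimator "correction" and instance weights are omitted for brevity):

```java
// Sketch of makeDistribution for a nominal class: weight each neighbour
// by one of three schemes, accumulate per-class mass, then normalize.
public class DistributionSketch {
    static final int WEIGHT_NONE = 0, WEIGHT_INVERSE = 1, WEIGHT_SIMILARITY = 2;

    static double[] makeDistribution(int[] classes, double[] distances,
                                     int numClasses, int weighting) {
        double[] dist = new double[numClasses];
        double total = 0;
        for (int i = 0; i < classes.length; i++) {
            double weight;
            switch (weighting) {
                case WEIGHT_INVERSE:
                    weight = 1.0 / (distances[i] + 0.001); // avoid div by zero
                    break;
                case WEIGHT_SIMILARITY:
                    weight = 1.0 - distances[i];
                    break;
                default: // WEIGHT_NONE
                    weight = 1.0;
                    break;
            }
            dist[classes[i]] += weight; // accumulate, as IBk does
            total += weight;
        }
        if (total > 0) {
            for (int i = 0; i < numClasses; i++) dist[i] /= total;
        }
        return dist;
    }

    public static void main(String[] args) {
        // Two neighbours of class 0, one of class 1, unweighted:
        // class 0 gets 2/3 of the probability mass.
        double[] d = makeDistribution(new int[]{0, 0, 1},
                new double[]{0.1, 0.2, 0.3}, 2, WEIGHT_NONE);
        System.out.println(java.util.Arrays.toString(d));
    }
}
```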
Simply put, crossValidate uses brute force to find out how many neighbours work best: it loops over the samples in m_Train, finds the neighbours of each one, and records which neighbour count gives the best performance.
protected void crossValidate() {
  double[] performanceStats = new double[m_kNNUpper];
  double[] performanceStatsSq = new double[m_kNNUpper];

  for (int i = 0; i < m_kNNUpper; i++) {
    performanceStats[i] = 0;
    performanceStatsSq[i] = 0;
  }

  m_kNN = m_kNNUpper;
  Instance instance;
  Instances neighbours;
  double[] origDistances, convertedDistances;
  for (int i = 0; i < m_Train.numInstances(); i++) {
    instance = m_Train.instance(i);
    neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
    origDistances = m_NNSearch.getDistances();

    for (int j = m_kNNUpper - 1; j >= 0; j--) {
      // Update the performance stats
      convertedDistances = new double[origDistances.length];
      System.arraycopy(origDistances, 0, convertedDistances, 0,
          origDistances.length);
      double[] distribution = makeDistribution(neighbours, convertedDistances);
      double thisPrediction = Utils.maxIndex(distribution);
      if (m_Train.classAttribute().isNumeric()) {
        thisPrediction = distribution[0];
        double err = thisPrediction - instance.classValue();
        performanceStatsSq[j] += err * err; // Squared error
        performanceStats[j] += Math.abs(err); // Absolute error
      } else {
        if (thisPrediction != instance.classValue()) {
          performanceStats[j]++; // Classification error
        }
      }
      if (j >= 1) {
        neighbours = pruneToK(neighbours, convertedDistances, j);
      }
    }
  }

  // Check through the performance stats and select the best
  // k value (or the lowest k if more than one best)
  double[] searchStats = performanceStats;
  if (m_Train.classAttribute().isNumeric() && m_MeanSquared) {
    searchStats = performanceStatsSq;
  }
  double bestPerformance = Double.NaN;
  int bestK = 1;
  for (int i = 0; i < m_kNNUpper; i++) {
    if (Double.isNaN(bestPerformance)
        || (bestPerformance > searchStats[i])) {
      bestPerformance = searchStats[i];
      bestK = i + 1;
    }
  }
  m_kNN = bestK;
  m_kNNValid = true;
}
m_kNNUpper is another parameter: the maximum number of neighbours to consider. The method enumerates every sample (instance) and finds its neighbours (neighbours) and their distances (origDistances). For every neighbour count from 1 up to m_kNNUpper it accumulates the squared errors (performanceStatsSq) and the absolute errors (performanceStats). pruneToK trims the neighbour set down to j samples (keeping extras only when the (j+1)-th distance equals the j-th). The rest is easy to understand: for a numeric class, m_MeanSquared chooses between squared and absolute error; finally all candidates up to m_kNNUpper are compared, and the neighbour count with the smallest error is taken as the best k (ties go to the smallest k).
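The brute-force k selection can be sketched in a simplified, Weka-free form: leave-one-out over a 1-D dataset with a binary nominal class, counting misclassifications per k and keeping the lowest k among the best. SelectKSketch is a hypothetical illustration, not the actual crossValidate (which reuses one neighbour search per sample via pruneToK instead of re-sorting):

```java
import java.util.Arrays;

// Sketch of brute-force k selection by leave-one-out error counting.
public class SelectKSketch {
    // points[i] = 1-D feature, labels[i] = class (0 or 1)
    static int selectK(double[] points, int[] labels, int kUpper) {
        int n = points.length;
        int[] errors = new int[kUpper];
        for (int i = 0; i < n; i++) {
            // Sort indices by distance to the query point i.
            Integer[] idx = new Integer[n];
            for (int j = 0; j < n; j++) idx[j] = j;
            final int q = i;
            Arrays.sort(idx, (a, b) -> Double.compare(
                Math.abs(points[a] - points[q]), Math.abs(points[b] - points[q])));
            for (int k = 1; k <= kUpper; k++) {
                int votes = 0, count = 0;
                for (int j = 0; j < n && count < k; j++) {
                    if (idx[j] == i) continue; // leave the query point out
                    votes += labels[idx[j]];
                    count++;
                }
                int pred = (votes * 2 > count) ? 1 : 0; // majority vote
                if (pred != labels[i]) errors[k - 1]++;
            }
        }
        // Lowest error wins; ties go to the smallest k, as in IBk.
        int bestK = 1;
        for (int k = 2; k <= kUpper; k++) {
            if (errors[k - 1] < errors[bestK - 1]) bestK = k;
        }
        return bestK;
    }

    public static void main(String[] args) {
        double[] x = {1.0, 1.1, 1.2, 5.0, 5.1, 5.2};
        int[] y = {0, 0, 0, 1, 1, 1};
        // Two clean clusters: every k up to 3 is error-free,
        // so the tie-break picks k = 1.
        System.out.println(selectK(x, y, 3)); // 1
    }
}
```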