
Weka Development [21]: IBk (KNN) Source Code Analysis

If you have not read the previous article on IB1, please read that first; material already covered there is not repeated here. Let's go straight to buildClassifier. Only the code that does not also appear in IB1 is listed:

    try {
        m_NumClasses = instances.numClasses();
        m_ClassType = instances.classAttribute().type();
    } catch (Exception ex) {
        throw new Error("This should never be reached");
    }

    // Throw away initial instances until within the specified window size
    if ((m_WindowSize > 0) && (instances.numInstances() > m_WindowSize)) {
        m_Train = new Instances(m_Train,
            m_Train.numInstances() - m_WindowSize, m_WindowSize);
    }

    // Compute the number of attributes that contribute
    // to each prediction
    m_NumAttributesUsed = 0.0;
    for (int i = 0; i < m_Train.numAttributes(); i++) {
        if ((i != m_Train.classIndex())
                && (m_Train.attribute(i).isNominal()
                    || m_Train.attribute(i).isNumeric())) {
            m_NumAttributesUsed += 1.0;
        }
    }

    // Invalidate any currently cross-validation selected k
    m_kNNValid = false;

IB1 does not care about m_NumClasses because it looks for only one neighbour, so there is only ever one value. m_WindowSize limits how many instances are kept for classification; the selection is not random: the Instances constructor call keeps the most recent m_WindowSize instances and throws the earlier ones away. The loop that follows counts how many attributes actually contribute to each prediction.

KNN is also a classifier capable of incremental learning. Here is its updateClassifier code:

    public void updateClassifier(Instance instance) throws Exception {
        if (m_Train.equalHeaders(instance.dataset()) == false) {
            throw new Exception("Incompatible instance types");
        }
        if (instance.classIsMissing()) {
            return;
        }
        if (!m_DontNormalize) {
            updateMinMax(instance);
        }
        m_Train.add(instance);
        m_kNNValid = false;
        if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
            while (m_Train.numInstances() > m_WindowSize) {
                m_Train.delete(0);
            }
        }
    }

Equally simple: updateMinMax is called, the new instance is added, and if the window size is exceeded, the oldest instance (index 0) is deleted in a loop until the training set fits inside the window again.

Note that IBk does not implement classifyInstance; it only implements distributionForInstance:

    public double[] distributionForInstance(Instance instance) throws Exception {
        if (m_Train.numInstances() == 0) {
            throw new Exception("No training instances!");
        }
        if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
            m_kNNValid = false;
            boolean deletedInstance = false;
            while (m_Train.numInstances() > m_WindowSize) {
                m_Train.delete(0);
                deletedInstance = true;
            }
            // rebuild datastructure KDTree currently can't delete
            if (deletedInstance == true)
                m_NNSearch.setInstances(m_Train);
        }

        // Select k by cross validation
        if (!m_kNNValid && (m_CrossValidate) && (m_kNNUpper >= 1)) {
            crossValidate();
        }

        m_NNSearch.addInstanceInfo(instance);
        Instances neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
        double[] distances = m_NNSearch.getDistances();
        double[] distribution = makeDistribution(neighbours, distances);
        return distribution;
    }

I will skip the first two checks. crossValidate() is discussed in a moment, and finding the K nearest neighbours was already covered in my article [18]. Now let's look at the makeDistribution function.

    protected double[] makeDistribution(Instances neighbours, double[] distances)
            throws Exception {
        double total = 0, weight;
        double[] distribution = new double[m_NumClasses];

        // Set up a correction to the estimator
        if (m_ClassType == Attribute.NOMINAL) {
            for (int i = 0; i < m_NumClasses; i++) {
                distribution[i] = 1.0 / Math.max(1, m_Train.numInstances());
            }
            total = (double) m_NumClasses / Math.max(1, m_Train.numInstances());
        }

        for (int i = 0; i < neighbours.numInstances(); i++) {
            // Collect class counts
            Instance current = neighbours.instance(i);
            distances[i] = distances[i] * distances[i];
            distances[i] = Math.sqrt(distances[i] / m_NumAttributesUsed);
            switch (m_DistanceWeighting) {
            case WEIGHT_INVERSE:
                weight = 1.0 / (distances[i] + 0.001); // to avoid div by zero
                break;
            case WEIGHT_SIMILARITY:
                weight = 1.0 - distances[i];
                break;
            default: // WEIGHT_NONE:
                weight = 1.0;
                break;
            }
            weight *= current.weight();
            try {
                switch (m_ClassType) {
                case Attribute.NOMINAL:
                    distribution[(int) current.classValue()] += weight;
                    break;
                case Attribute.NUMERIC:
                    distribution[0] += current.classValue() * weight;
                    break;
                }
            } catch (Exception ex) {
                throw new Error("Data has no class attribute!");
            }
            total += weight;
        }

        // Normalise distribution
        if (total > 0) {
            Utils.normalize(distribution, total);
        }
        return distribution;
    }

The comment at the top says "Set up a correction to the estimator". I do not think it is really necessary: this is not Bayes, where a zero count would cause a division-by-zero problem, so there is not much to correct. You can see that three distance weighting schemes are implemented: the inverse of the distance, one minus the distance (similarity), and a fixed weight of 1. Then, if the class is nominal, the weight is added to the slot of the corresponding class value; if the class is numeric, the class value multiplied by the weight is added to distribution[0].
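To make the voting arithmetic concrete, here is a small self-contained sketch of the nominal-class branch of this logic. It is not Weka code: the class name WeightedVoteSketch, the vote method, and its parameters are hypothetical names of mine, and the estimator correction and instance weights are omitted for brevity.

    // Hypothetical, self-contained sketch of the nominal-class voting above.
    public class WeightedVoteSketch {

        static final int WEIGHT_NONE = 1, WEIGHT_INVERSE = 2,
                WEIGHT_SIMILARITY = 4;

        // classValues[i] holds the nominal class index of neighbour i,
        // distances[i] its normalized distance to the query point.
        static double[] vote(int[] classValues, double[] distances,
                int numClasses, int weighting) {
            double[] distribution = new double[numClasses];
            double total = 0;
            for (int i = 0; i < classValues.length; i++) {
                double weight;
                switch (weighting) {
                case WEIGHT_INVERSE:
                    weight = 1.0 / (distances[i] + 0.001); // avoid div by zero
                    break;
                case WEIGHT_SIMILARITY:
                    weight = 1.0 - distances[i];
                    break;
                default: // WEIGHT_NONE
                    weight = 1.0;
                    break;
                }
                distribution[classValues[i]] += weight; // accumulate the vote
                total += weight;
            }
            for (int j = 0; j < numClasses; j++) {
                distribution[j] /= total; // normalize, as Utils.normalize does
            }
            return distribution;
        }

        public static void main(String[] args) {
            // Three neighbours with classes {0, 1, 1}; the class-0 one is closest.
            double[] d = vote(new int[] { 0, 1, 1 },
                    new double[] { 0.1, 0.6, 0.7 }, 2, WEIGHT_INVERSE);
            System.out.println(java.util.Arrays.toString(d));
        }
    }

With inverse distance weighting this prints roughly [0.76, 0.24]: the single, very close class-0 neighbour outvotes the two more distant class-1 neighbours, which is exactly the effect the weighting schemes exist to produce.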
Simply put, crossValidate uses brute force to find out how many neighbours work best: it loops over the instances in m_Train, finds the neighbours of each one, and keeps statistics on which number of neighbours performs best.

    protected void crossValidate() {
        double[] performanceStats = new double[m_kNNUpper];
        double[] performanceStatsSq = new double[m_kNNUpper];

        for (int i = 0; i < m_kNNUpper; i++) {
            performanceStats[i] = 0;
            performanceStatsSq[i] = 0;
        }

        m_kNN = m_kNNUpper;
        Instance instance;
        Instances neighbours;
        double[] origDistances, convertedDistances;
        for (int i = 0; i < m_Train.numInstances(); i++) {
            instance = m_Train.instance(i);
            neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
            origDistances = m_NNSearch.getDistances();

            for (int j = m_kNNUpper - 1; j >= 0; j--) {
                // Update the performance stats
                convertedDistances = new double[origDistances.length];
                System.arraycopy(origDistances, 0, convertedDistances, 0,
                    origDistances.length);
                double[] distribution = makeDistribution(neighbours,
                    convertedDistances);
                double thisPrediction = Utils.maxIndex(distribution);
                if (m_Train.classAttribute().isNumeric()) {
                    thisPrediction = distribution[0];
                    double err = thisPrediction - instance.classValue();
                    performanceStatsSq[j] += err * err; // Squared error
                    performanceStats[j] += Math.abs(err); // Absolute error
                } else {
                    if (thisPrediction != instance.classValue()) {
                        performanceStats[j]++; // Classification error
                    }
                }
                if (j >= 1) {
                    neighbours = pruneToK(neighbours, convertedDistances, j);
                }
            }
        }

        // Check through the performance stats and select the best
        // k value (or the lowest k if more than one best)
        double[] searchStats = performanceStats;
        if (m_Train.classAttribute().isNumeric() && m_MeanSquared) {
            searchStats = performanceStatsSq;
        }
        double bestPerformance = Double.NaN;
        int bestK = 1;
        for (int i = 0; i < m_kNNUpper; i++) {
            if (Double.isNaN(bestPerformance)
                    || (bestPerformance > searchStats[i])) {
                bestPerformance = searchStats[i];
                bestK = i + 1;
            }
        }
        m_kNN = bestK;
        m_kNNValid = true;
    }

m_kNNUpper is another parameter; it caps how many neighbours are considered. The outer loop enumerates every instance and finds its neighbours and the corresponding distances (origDistances). The inner loop then accumulates, for each candidate k from m_kNNUpper down to 1, the squared error (performanceStatsSq) and the absolute error (performanceStats) on top of the values from earlier instances; for a nominal class it simply counts misclassifications. pruneToK trims the neighbour set down to j instances (unless the (j+1)-th distance ties with the j-th). The rest is easy to follow: for a numeric class, m_MeanSquared selects between squared error and absolute error, and the final loop over m_kNNUpper finds the number of neighbours with the smallest accumulated error, which is taken as the best k.
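Finally, for context on how these internals are reached, here is a short usage sketch against Weka's public API as I recall it (setKNN, setCrossValidate, setWindowSize, setDistanceWeighting); treat the setter names as assumptions to verify against your Weka version, and iris.arff is a placeholder path.

    import weka.classifiers.lazy.IBk;
    import weka.core.Instances;
    import weka.core.SelectedTag;
    import weka.core.converters.ConverterUtils.DataSource;

    public class IBkDemo {
        public static void main(String[] args) throws Exception {
            Instances data = DataSource.read("iris.arff"); // placeholder path
            data.setClassIndex(data.numAttributes() - 1);

            IBk knn = new IBk();
            knn.setKNN(10);             // also sets the upper bound m_kNNUpper
            knn.setCrossValidate(true); // let crossValidate() pick the best k <= 10
            knn.setWindowSize(0);       // 0 means no training-window limit
            knn.setDistanceWeighting(
                    new SelectedTag(IBk.WEIGHT_INVERSE, IBk.TAGS_WEIGHTING));

            knn.buildClassifier(data);
            // crossValidate() runs lazily inside distributionForInstance,
            // so the brute-force k search happens on the first prediction.
            double[] dist = knn.distributionForInstance(data.instance(0));
            System.out.println(java.util.Arrays.toString(dist));
        }
    }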