OpenCV3.x实现KNN算法(K近邻算法),并保存训练模型

     【尊重原创,转载请注明出处 <https://blog.csdn.net/guyuealian/article/details/80241473>
】https://blog.csdn.net/guyuealian/article/details/80241473

   OpenCV 3.x中cv::ml::Knearest类可以实现K-最近邻(KNN)算法,其详细用法可以参考官方说明文档:
https://docs.opencv.org/3.2.0/dd/de1/classcv_1_1ml_1_1KNearest.html

(1)、cv::ml::Knearest类:继承自cv::ml::StateModel,而cv::ml::StateModel又继承自cv::Algorithm;
(2)、create函数:为static,new一个KNearestImpl用来创建一个KNearest对象;
(3)、setDefaultK/getDefaultK函数:在预测时,设置/获取的K值;
(4)、setIsClassifier/getIsClassifier函数:设置/获取应用KNN是进行分类还是回归;
(5)、setEmax/getEmax函数:在使用KDTree算法时,设置/获取Emax参数值;
(6)、setAlgorithmType/getAlgorithmType函数:设置/获取KNN算法类型,目前支持两种:brute_force和KDTree;
(7)、findNearest函数:根据输入预测分类/回归结果。

关于KNN算法思路可以参考: http://blog.csdn.net/fengbingchun/article/details/78464169  

     
下面是本人使用OpenCV3.2实现的KNN算法,其中利用save()方法把训练数据保存下来,测试时重新加载load()训练模型,这样可以实现单独的测试,而不需要重新训练数据:
#include "stdafx.h" #include <opencv2/core/core.hpp> #include
<opencv2/highgui/highgui.hpp> #include <opencv2/ml/ml.hpp> #include
<opencv2/ml.hpp> using namespace cv; using namespace cv::ml; using namespace
std; int main() { float labels[10] = { 0.0, 1.0, 1.0, 2.0,2.0,0.0, 1.0,1.0,
2.0,2.0 }; Mat labelsMat(10, 1, CV_32FC1, labels); // Set up training data
float trainArray[10][3] = { { 510, 510,10 },{ 405, 10,510 },{ 501, 45,420 },{
10,20, 510 },{ 35,45,515 },{ 540,420,40 },{ 380,30,300 },{ 400,70,500 },{
30,60,410 },{ 54,23,543 } }; Mat trainDataMat(10, 3, CV_32FC1, trainArray);
/*******************************************训练过程******************************************/
//保存训练模型(在KNN中实质上是保存训练样本的原始数据) string knnPath = "D:/SmartAlbum/image1/knn.xml";
Ptr<KNearest> kclassifier = KNearest::create(); Ptr<TrainData> trainData;
trainData = TrainData::create(trainDataMat, SampleTypes::ROW_SAMPLE,
labelsMat); kclassifier->setIsClassifier(true);
kclassifier->setAlgorithmType(KNearest::Types::BRUTE_FORCE);
kclassifier->setDefaultK(1); kclassifier->train(trainData);
kclassifier->save(knnPath);//会把trainDataMat的原始数据全部保存为*.xml文件
/*******************************************测试过程******************************************/
//加载训练模型(在KNN中,实质上就是加载训练样本的原始数据) const int K = 4;//testModel->getDefaultK()
Ptr<KNearest> testModel = StatModel::load<KNearest>(knnPath); Mat sampleMat =
(Mat_<float>(1, 3) << 310, 5, 339);//测试样本 Mat matResults(0, 0, CV_32F);//保存测试结果
testModel->findNearest(sampleMat, K, matResults);//knn分类预测 cout <<
"matResults=" << matResults << endl; system("pause"); waitKey(); }

友情链接
KaDraw流程图
API参考文档
OK工具箱
云服务器优惠
阿里云优惠券
腾讯云优惠券
华为云优惠券
站点信息
问题反馈
邮箱:ixiaoyang8@qq.com
QQ群:637538335
关注微信