博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
使用机器学习的方法识别手写数字0和1
阅读量:4107 次
发布时间:2019-05-25

本文共 8661 字,大约阅读时间需要 28 分钟。

最近在学习OpenCV中机器学习的相关部分,想了解机器学习中从模型的训练到用模型去预测的具体实现过程,所以做了一个识别别手写数字0和1的简单项目(此文只识别手写数字0和1,如果想识别数字0到9,可以根据示例自己扩展),下面进行详细讲解。

这里用两种方式实现,一种是将模型训练和数字识别分别写在两个项目里,逻辑清晰,便于理解训练模型的过程和查看训练的结果;一种是将所有的功能写在一个项目里,可以直接查看最后的识别结果。

用于实验的数据:

我是将data文件夹存放在D盘,所以后面使用数据的时候,用到的绝对的路径为"D:\\data\\train_image\\1"。可以根据自己下载后存放的路径做相应的修改。用于训练的数字0和1各有400张图片,用于测试的数字0和1各有100张图片。

分两个项目实现

1.模型训练

#include 
#include
#include
#include
#include
#include
#include
#include
#include
using namespace std;using namespace cv;void getFiles(string path, vector
& files);void get_1(Mat& trainingImages, vector
& trainingLabels);void get_0(Mat& trainingImages, vector
& trainingLabels);int main(){ //获取训练数据 Mat classes; Mat trainingData; Mat trainingImages; vector
trainingLabels; //数字1先存入trainingImages,接着数字0存入trainingImages get_1(trainingImages, trainingLabels);//数字1贴标签 get_0(trainingImages, trainingLabels);//数字0贴标签 Mat(trainingImages).copyTo(trainingData);//将写好的包含特征的矩阵拷贝给trainingData trainingData.convertTo(trainingData, CV_32FC1); Mat(trainingLabels).copyTo(classes);//将包含标签的vector容器进行类型转换后拷贝到classes里 //配置SVM训练器参数 CvSVMParams SVM_params; SVM_params.svm_type = CvSVM::C_SVC; SVM_params.kernel_type = CvSVM::LINEAR; SVM_params.degree = 0; SVM_params.gamma = 1; SVM_params.coef0 = 0; SVM_params.C = 1; SVM_params.nu = 0; SVM_params.p = 0; SVM_params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 1000, 0.01); //训练 CvSVM svm; svm.train(trainingData, classes, Mat(), Mat(), SVM_params); //保存模型 svm.save("svm.xml"); cout << "训练好了!!!" << endl; getchar(); return 0;}//通过文件路径遍历文件夹,实现图像批处理void getFiles(string path, vector
& files){ long hFile = 0; struct _finddata_t fileinfo; string p; if ((hFile = _findfirst(p.assign(path).append("\\*").c_str(), &fileinfo)) != -1) { do { if ((fileinfo.attrib & _A_SUBDIR)) { if (strcmp(fileinfo.name, ".") != 0 && strcmp(fileinfo.name, "..") != 0) getFiles(p.assign(path).append("\\").append(fileinfo.name), files); } else { files.push_back(p.assign(path).append("\\").append(fileinfo.name)); } } while (_findnext(hFile, &fileinfo) == 0); _findclose(hFile); }}//将数字1的图像贴标签void get_1(Mat& trainingImages, vector
& trainingLabels){ /*这里使用绝对路径,也可以使用相对路径,先得将data文件夹放到.cpp所在的文件夹 char * filePath = "data\\train_image\\1"; */ char * filePath = "D:\\data\\train_image\\1"; vector
files; getFiles(filePath, files); int number = files.size(); for (int i = 0; i < number; i++) { //利用循环遍历文件夹里的每一张图像 Mat SrcImage = imread(files[i].c_str()); /*特征提取 函数原型Mat reshape(int cn, int rows=0) const; cn为新的通道数,如果cn = 0,表示通道数不会改变 rows为新的行数,如果rows = 0,表示行数不会改变 reshape(1, 1)的结果就是原图像对应的矩阵将被拉伸成一个一行的向量,作为特征向量 */ SrcImage = SrcImage.reshape(1, 1); //获取一张图片后会将图片(特征)写入到容器中,紧接着会将标签写入另一个容器中,这样就保证了特征和标签是一一对应的关系 trainingImages.push_back(SrcImage);//将图片特征写入容器 trainingLabels.push_back(1);//将标签写入容器 }}//将数字0的图像贴标签void get_0(Mat& trainingImages, vector
& trainingLabels){ char * filePath = "D:\\data\\train_image\\0"; vector
files; getFiles(filePath, files);//将路径对应的文件夹下的图片都存放在vector容器中 int number = files.size(); for (int i = 0; i < number; i++)//遍历文件夹 { Mat SrcImage = imread(files[i].c_str()); SrcImage = SrcImage.reshape(1, 1); trainingImages.push_back(SrcImage); trainingLabels.push_back(0); }}

训练好的模型保存在“svm.xml”文件中,使用的时候加载就好。运行结果和生成的xml如下:

运行过程中如果遇到下图所示错误,可以参考博客:

2.数字识别

#include 
#include
#include
#include
#include
#include
#include
#include
#include
using namespace std;using namespace cv;void getFiles(string path, vector
& files);int main(){ int result = 0; //也可以使用相对路径char * filePath = "data\\test_image\\0",data文件夹需先与.cpp放在同一个文件夹下 char * filePath = "D:\\Projects\\visual studio 2013\\SVMTest1\\data\\test_image\\0";//识别数字0 vector
files; getFiles(filePath, files); int number = files.size(); cout << number << endl; CvSVM svm; svm.clear(); /*当.xml文件与.cpp文件在同一个文件夹下时可以使用相对路径string modelpath ="svm.xml" 在不同文件夹下时需使用绝对路径 */ string modelpath = "D:\\Projects\\visual studio 2013\\SVMTest2\\SVMTest2\\svm.xml"; FileStorage svm_fs(modelpath, FileStorage::READ); if (svm_fs.isOpened()) { svm.load(modelpath.c_str());//加载模型 } for (int i = 0; i < number; i++) { Mat inMat = imread(files[i].c_str()); Mat p = inMat.reshape(1, 1); p.convertTo(p, CV_32FC1); int response = (int)svm.predict(p); if (response == 0) { result++;//用result记录识别正确的个数 } } cout << result << endl; getchar(); return 0;}void getFiles(string path, vector
& files){ long hFile = 0; struct _finddata_t fileinfo; string p; if ((hFile = _findfirst(p.assign(path).append("\\*").c_str(), &fileinfo)) != -1) { do { if ((fileinfo.attrib & _A_SUBDIR)) { if (strcmp(fileinfo.name, ".") != 0 && strcmp(fileinfo.name, "..") != 0) getFiles(p.assign(path).append("\\").append(fileinfo.name), files); } else { files.push_back(p.assign(path).append("\\").append(fileinfo.name)); } } while (_findnext(hFile, &fileinfo) == 0); _findclose(hFile); }}

使用相对路径加载模型时,需要将生成好的.xml模型复制到.cpp文件所在的文件夹。

最终的 运行结果为:

主函数中测试的是数字0的100张测试图片,最后识别正确的图片也是100张。

一个项目实现

准备工作:将data文件夹复制到main.cpp文件所在的文件夹

#include 
#include
#include
#include
#include
#include
#include
#include
#include
using namespace std;using namespace cv;void getFiles(string path, vector
& files);void get_1(Mat& trainingImages, vector
& trainingLabels);void get_0(Mat& trainingImages, vector
& trainingLabels);void train();int main(){ int result = 0; //char * filePath = "D:\\data\\test_image\\0"; char * filePath = "data\\test_image\\0"; vector
files; getFiles(filePath, files); int number = files.size(); cout << number << endl; train(); CvSVM svm; svm.clear(); //string modelpath = "D:\\svm.xml"; string modelpath = "svm.xml"; FileStorage svm_fs(modelpath, FileStorage::READ); if (svm_fs.isOpened()) { svm.load(modelpath.c_str()); } for (int i = 0; i < number; i++) { Mat inMat = imread(files[i].c_str()); Mat p = inMat.reshape(1, 1); p.convertTo(p, CV_32FC1); int response = (int)svm.predict(p); if (response == 0) { result++; } } cout << result << endl; getchar(); return 0;}//通过文件路径遍历文件夹,实现图像批处理void getFiles(string path, vector
& files){ long hFile = 0; struct _finddata_t fileinfo; string p; if ((hFile = _findfirst(p.assign(path).append("\\*").c_str(), &fileinfo)) != -1) { do { if ((fileinfo.attrib & _A_SUBDIR)) { if (strcmp(fileinfo.name, ".") != 0 && strcmp(fileinfo.name, "..") != 0) getFiles(p.assign(path).append("\\").append(fileinfo.name), files); } else { files.push_back(p.assign(path).append("\\").append(fileinfo.name)); } } while (_findnext(hFile, &fileinfo) == 0); _findclose(hFile); }}//将数字1的图像贴标签void get_1(Mat& trainingImages, vector
& trainingLabels){ //char * filePath = "D:\\data\\train_image\\1"; char * filePath = "data\\train_image\\1"; vector
files; getFiles(filePath, files); int number = files.size(); for (int i = 0; i < number; i++) { //利用循环遍历文件夹里的每一张图像 Mat SrcImage = imread(files[i].c_str()); /*特征提取 函数原型Mat reshape(int cn, int rows=0) const; cn为新的通道数,如果cn = 0,表示通道数不会改变 rows为新的行数,如果rows = 0,表示行数不会改变 reshape(1, 1)的结果就是原图像对应的矩阵将被拉伸成一个一行的向量,作为特征向量 */ SrcImage = SrcImage.reshape(1, 1); //获取一张图片后会将图片(特征)写入到容器中,紧接着会将标签写入另一个容器中,这样就保证了特征和标签是一一对应的关系 trainingImages.push_back(SrcImage);//将图片特征写入容器 trainingLabels.push_back(1);//将标签写入容器 }}//将数字0的图像贴标签void get_0(Mat& trainingImages, vector
& trainingLabels){ //char * filePath = "D:\\Projects\\visual studio 2013\\SVMTest1\\data\\train_image\\0"; char * filePath = "data\\train_image\\0"; vector
files; getFiles(filePath, files);//将路径对应的文件夹下的图片都存放在vector容器中 int number = files.size(); for (int i = 0; i < number; i++)//遍历文件夹 { Mat SrcImage = imread(files[i].c_str()); SrcImage = SrcImage.reshape(1, 1); trainingImages.push_back(SrcImage); trainingLabels.push_back(0); }}//训练模型void train(){ Mat classes; Mat trainingData; Mat trainingImages; vector
trainingLabels; //数字1先存入trainingImages,接着数字0存入trainingImages get_1(trainingImages, trainingLabels);//数字1贴标签 get_0(trainingImages, trainingLabels);//数字0贴标签 Mat(trainingImages).copyTo(trainingData);//将写好的包含特征的矩阵拷贝给trainingData trainingData.convertTo(trainingData, CV_32FC1); Mat(trainingLabels).copyTo(classes);//将包含标签的vector容器进行类型转换后拷贝到classes里 //配置SVM训练器参数 CvSVMParams SVM_params; SVM_params.svm_type = CvSVM::C_SVC; SVM_params.kernel_type = CvSVM::LINEAR; SVM_params.degree = 0; SVM_params.gamma = 1; SVM_params.coef0 = 0; SVM_params.C = 1; SVM_params.nu = 0; SVM_params.p = 0; SVM_params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 1000, 0.01); //训练 CvSVM svm; svm.train(trainingData, classes, Mat(), Mat(), SVM_params); //保存模型 svm.save("svm.xml"); cout << "训练好了!!!" << endl;}

运行结果如下:

参考博客:

你可能感兴趣的文章
有return的情况下try catch finally的执行顺序(最有说服力的总结)
查看>>
String s1 = new String("abc"); String s2 = ("abc");
查看>>
JAVA数据类型
查看>>
Xshell 4 入门
查看>>
SoapUI-入门
查看>>
Oracle -常用命令
查看>>
JAVA技术简称
查看>>
ORACLE模糊查询优化浅谈
查看>>
2016——个人年度总结
查看>>
2017——新的开始,加油!
查看>>
【Python】学习笔记——-6.2、使用第三方模块
查看>>
【Python】学习笔记——-7.0、面向对象编程
查看>>
【Python】学习笔记——-7.1、类和实例
查看>>
【Python】学习笔记——-7.2、访问限制
查看>>
【Python】学习笔记——-7.3、继承和多态
查看>>
【Python】学习笔记——-7.4、获取对象信息
查看>>
【Python】学习笔记——-7.5、实例属性和类属性
查看>>
Linux设备模型(总线、设备、驱动程序和类)之四:class_register
查看>>
git中文安装教程
查看>>
虚拟机 CentOS7/RedHat7/OracleLinux7 配置静态IP地址 Ping 物理机和互联网
查看>>