BOW--创建和训练目标检测器
BOW--创建和训练目标检测器
词袋模型(Bag of Words,BOW)是一种常用的目标检测方法,通过将图像特征转化为固定长度的向量表示,实现对目标的检测和分类。本文将详细介绍BOW方法的实现步骤,并通过一个汽车检测的具体案例进行演示。
BOW--创建和训练目标检测器
1、bow方法的实现步骤
- 去一个样本数据集
- 对样本集中的每幅图像提取描述符(采用SIFT,SURF等方法)
- 将每一个描述符都加入到BOW训练器中
- 将描述符聚类到K簇中(聚类的中心就是视觉单词)
接下来要测试一下分类器,并尝试进行检测,这个过程和前面介绍的非常相似:给定测试的图像,提取特征,然后基于这些特征到最近簇心的距离来实现向量化,已形成直方图。
下图是bow过程的可视化表示:
2、K-means聚类
对于给定的数据集,k表示要分割的数据集的簇数,术语“means”指的是数学中的均值,从可视化的角度来讲,簇的均值是这个簇中所有点的几何中心。
BagOfWordsKMeansTrainer是一种执行目标检测的类(class),opencv官方文档给出的定义如下:
"kmeans() - based class to train a visal vocabulary using the bag-of-words approch"
k-means聚类操作结果如下图所示:(图形帮助理解)
3、汽车检测
采用uiuc的数据集,下载链接:
链接:https://pan.baidu.com/s/1Bx4uGysRpFhGvqqb43yUBA&shfl=sharepset
提取码:xjre
1、定义函数path返回训练数据的路径:(./表示当前文件夹,../表示上一级的文件夹)
datapath = "../trains/"
def path(cls,i):
return "%s%s%d.pgm" % (datapath,cls,i+1)
pos, neg = "pos-", "neg-"
2、创建两个sift实例,一个提取关键点,一个提取描述符。
detect = cv2.xfeatures2d.SIFT_create()
extract = cv2.xfeatures2d.SIFT_create()
3、创建flann的匹配器实例
flann_params = dict(algorithm=1, trees=5)
matcher = cv2.FlannBasedMatcher(flann_params, {})
4、创建BOW训练器
bow_kmeans_trainer = cv2.BOWKMeansTrainer(40)
5、为bow训练器制定的簇数为40,接下来初始化bow提取器(bow_extractor),视觉词汇作为bow的输入,在测试图像中会检测这些视觉词汇:
extract_bow = cv2.BOWImgDescriptorExtractor(extract, matcher)
6、每个类读入8个图像的sift特征,sift函数返回图像的描述符:
def extract_sift(fn):
im = cv2.imread(fn, 0)
return extract.compute(im, detect.detect(im))[1]
for i in range(8):
bow_kmeans_trainer.add(extract_sift(path(pos, i)))
bow_kmeans_trainer.add(extract_sift(path(neg, i)))
7、运用训练器的cluster函数来创建视觉词汇,降为bow_extractor指定返回的词汇,以便他能够从测试图像中提取描述符
voc = bow_kmeans_trainer.cluster()
extract_bow.setVocabulary(voc)
8、定义一个函数,该函数返回基于bow的描述符提取器计算得到的描述符,并用数组来存储20个样本的描述符和label
def bow_features(fn):
im = cv2.imread(fn, 0)
return extract_bow.compute(im, detect.detect(im))
traindata, trainlabels = [], []
for i in range(20):
traindata.extend(bow_features(path(pos, i)));
trainlabels.append(1)
traindata.extend(bow_features(path(neg, i)));
trainlabels.append(-1)
9、创建一个svm实例,并对其进行训练
svm = cv2.ml.SVM_create()
svm.train(np.array(traindata), cv2.ml.ROW_SAMPLE, np.array(trainlabels))
10、定义predict函数,并返回预测
def predict(fn):
f = bow_features(fn);
p = svm.predict(f)
print
fn, "\t", p[1][0][0]
return p
11、最后读取要检测的图像,并通过predict进行判断
car = "../images/car.jpg"
car_img = cv2.imread(car)
car_predict = predict(car)
font = cv2.FONT_HERSHEY_SIMPLEX
# 预测为正样本(检测到车)
if (car_predict[1][0][0] == 1.0):
cv2.putText(car_img, 'Car Detected', (10, 30), font, 1, (0, 255, 0), 2, cv2.LINE_AA)
# 预测为负样本(没有检测到车)
else:
cv2.putText(car_img, 'Car Not Detected', (10, 30), font, 1, (0, 0, 255), 2, cv2.LINE_AA)
cv2.imshow('BOW + SVM Success', car_img)
cv2.waitKey(0)
cv2.destroyAllWindows()
12、完整代码如下:
import cv2
import numpy as np
from os.path import join
datapath = "../trains/"
def path(cls,i):
return "%s%s%d.pgm" % (datapath,cls,i+1)
pos, neg = "pos-", "neg-"
detect = cv2.xfeatures2d.SIFT_create()
extract = cv2.xfeatures2d.SIFT_create()
flann_params = dict(algorithm=1, trees=5)
matcher = cv2.FlannBasedMatcher(flann_params, {})
bow_kmeans_trainer = cv2.BOWKMeansTrainer(40)
extract_bow = cv2.BOWImgDescriptorExtractor(extract, matcher)
def extract_sift(fn):
im = cv2.imread(fn, 0)
return extract.compute(im, detect.detect(im))[1]
for i in range(8):
bow_kmeans_trainer.add(extract_sift(path(pos, i)))
bow_kmeans_trainer.add(extract_sift(path(neg, i)))
voc = bow_kmeans_trainer.cluster()
extract_bow.setVocabulary(voc)
def bow_features(fn):
im = cv2.imread(fn, 0)
return extract_bow.compute(im, detect.detect(im))
traindata, trainlabels = [], []
for i in range(20):
traindata.extend(bow_features(path(pos, i)));
trainlabels.append(1)
traindata.extend(bow_features(path(neg, i)));
trainlabels.append(-1)
svm = cv2.ml.SVM_create()
svm.train(np.array(traindata), cv2.ml.ROW_SAMPLE, np.array(trainlabels))
def predict(fn):
f = bow_features(fn);
p = svm.predict(f)
print
fn, "\t", p[1][0][0]
return p
car = "../images/car.jpg"
car_img = cv2.imread(car)
car_predict = predict(car)
font = cv2.FONT_HERSHEY_SIMPLEX
if (car_predict[1][0][0] == 1.0):
cv2.putText(car_img, 'Car Detected', (10, 30), font, 1, (0, 255, 0), 2, cv2.LINE_AA)
else:
cv2.putText(car_img, 'Car Not Detected', (10, 30), font, 1, (0, 0, 255), 2, cv2.LINE_AA)
cv2.imshow('BOW + SVM Success', car_img)
cv2.waitKey(0)
cv2.destroyAllWindows()