Skip to content

Commit 5e5b98a

Browse files
committed
add python code
1 parent e51e3d0 commit 5e5b98a

8 files changed

Lines changed: 1193 additions & 0 deletions

File tree

classification.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# -*- coding:utf-8 -*-
2+
# 用于模型的单张图像分类操作
3+
import os
4+
os.environ['GLOG_minloglevel'] = '2' # 将caffe的输出log信息不显示,必须放到import caffe前
5+
import caffe # caffe 模块
6+
from caffe.proto import caffe_pb2
7+
from google.protobuf import text_format
8+
import numpy as np
9+
import cv2
10+
import matplotlib.pyplot as plt
11+
import time
12+
13+
# Classify a single image with a Caffe classification network and show the result.
def classification(img, net, transformer, synset_words):
    """Run one forward pass on image file `img` and display the top-1 class.

    img          -- path to the image file to classify
    net          -- a caffe.Net loaded in TEST mode with 'data'/'prob' blobs
    transformer  -- caffe.io.Transformer matched to the net's 'data' blob
    synset_words -- path to a text file with one class label per line
    """
    im = caffe.io.load_image(img)
    # Feed the preprocessed image into the network's input blob.
    net.blobs['data'].data[...] = transformer.preprocess('data', im)

    # time.clock() was deprecated and removed in Python 3.8; wall-clock
    # time.time() behaves the same way on both Python 2 and 3 here.
    start = time.time()
    # Run the forward pass.
    net.forward()
    end = time.time()
    print('classification time: %f s' % (end - start))

    # Load the class labels and pick the highest-probability class.
    labels = np.loadtxt(synset_words, str, delimiter='\t')

    category = net.blobs['prob'].data[0].argmax()

    # Labels look like "id name1, name2, ..."; keep only the first name.
    class_str = labels[int(category)].split(',')
    class_name = class_str[0]
    # cv2.cv.CV_FONT_HERSHEY_SIMPLEX was removed in OpenCV 3;
    # cv2.FONT_HERSHEY_SIMPLEX exists in both OpenCV 2 and 3.
    cv2.putText(im, class_name, (0, im.shape[0]), cv2.FONT_HERSHEY_SIMPLEX, 1, (55, 255, 155), 2)

    # Show the annotated image. `im` is an H x W x 3 color array, so the
    # cmap argument ('brg' in the original) was ignored by imshow; drop it.
    plt.imshow(im)
    plt.show()
38+
39+
# Select the compute mode: CPU here; uncomment the two lines below for GPU.
caffe.set_mode_cpu()
#caffe.set_device(0)
#caffe.set_mode_gpu()

caffe_root = '../../'
# Network weights (trained parameters) file.
caffemodel = caffe_root + 'models/bvlc_alexnet/bvlc_alexnet.caffemodel'
# Deploy-time network structure definition.
deploy = caffe_root + 'models/bvlc_alexnet/deploy.prototxt'


img_root = caffe_root + 'data/VOCdevkit/VOC2007/JPEGImages/'
synset_words = caffe_root + 'data/ilsvrc12/synset_words.txt'

# Build the network for classification.
net = caffe.Net(deploy,       # defines the model structure
                caffemodel,   # contains the trained weights
                caffe.TEST)   # test mode (dropout is not applied)

# Load the ImageNet mean image (shipped with Caffe).
mu = np.load(caffe_root + 'python/caffe/imagenet/ilsvrc_2012_mean.npy')
mu = mu.mean(1).mean(1)  # average over pixels to get per-channel (BGR) mean values

# Image preprocessing pipeline for the network's 'data' blob.
transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
transformer.set_transpose('data', (2,0,1))     # H x W x C -> C x H x W
transformer.set_mean('data', mu)               # subtract the per-channel mean
transformer.set_raw_scale('data', 255)         # [0,1] floats -> [0,255]
transformer.set_channel_swap('data', (2,1,0))  # RGB -> BGR

# Interactive loop: classify images by 6-digit number until blank input.
# NOTE(review): raw_input is Python 2 only.
while 1:
    img_num = raw_input("Enter Img Number: ")
    if img_num == '': break
    img = img_root + '{:0>6}'.format(img_num) + '.jpg'
    classification(img,net,transformer,synset_words)
76+

detection.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# -*- coding:utf-8 -*-
2+
# 用于模型的单张图像分类操作
3+
import os
4+
os.environ['GLOG_minloglevel'] = '2' # 将caffe的输出log信息不显示,必须放到import caffe前
5+
import caffe # caffe 模块
6+
from caffe.proto import caffe_pb2
7+
from google.protobuf import text_format
8+
import numpy as np
9+
import cv2
10+
import matplotlib.pyplot as plt
11+
import time
12+
13+
# Run SSD-style object detection on a single image and display the boxes.
def detection(img, net, transformer, labels_file):
    """Run one forward pass on image file `img` and draw every detection
    whose confidence is at least 0.5.

    img         -- path to the image file
    net         -- a caffe.Net with a 'detection_out' output blob
    transformer -- caffe.io.Transformer matched to the net's 'data' blob
    labels_file -- path to a LabelMap prototxt mapping class ids to names
    """
    im = caffe.io.load_image(img)
    net.blobs['data'].data[...] = transformer.preprocess('data', im)

    # time.clock() was deprecated and removed in Python 3.8; use wall clock.
    start = time.time()
    # Run the forward pass.
    net.forward()
    end = time.time()
    print('detection time: %f s' % (end - start))

    # Parse the label map so class ids can be turned into display names.
    # (The original leaked the file handle; `with` closes it.)
    with open(labels_file, 'r') as file:
        labelmap = caffe_pb2.LabelMap()
        text_format.Merge(str(file.read()), labelmap)

    # Each row of detection_out is
    # [image_id, label, confidence, xmin, ymin, xmax, ymax], boxes normalized.
    loc = net.blobs['detection_out'].data[0][0]
    confidence_threshold = 0.5
    for det in loc:
        if det[2] >= confidence_threshold:
            # Scale the normalized box to pixel coordinates.
            xmin = int(det[3] * im.shape[1])
            ymin = int(det[4] * im.shape[0])
            xmax = int(det[5] * im.shape[1])
            ymax = int(det[6] * im.shape[0])
            # NOTE: the original created an unused blank 512x512 image here,
            # clobbering the `img` parameter; that dead code is removed.
            cv2.rectangle(im, (xmin, ymin), (xmax, ymax), (55 / 255.0, 255 / 255.0, 155 / 255.0), 2)

            # Look up the human-readable class name for this detection.
            class_name = labelmap.item[int(det[1])].display_name
            # cv2.cv.CV_FONT_HERSHEY_SIMPLEX was removed in OpenCV 3;
            # cv2.FONT_HERSHEY_SIMPLEX exists in both OpenCV 2 and 3.
            cv2.putText(im, class_name, (xmin, ymax), cv2.FONT_HERSHEY_SIMPLEX, 1, (55, 255, 155), 2)

    # Show the annotated image. `im` is an H x W x 3 color array, so the
    # cmap argument ('brg' in the original) was ignored by imshow; drop it.
    plt.imshow(im)
    plt.show()
48+
49+
# Select the compute mode: CPU here; uncomment the two lines below for GPU.
caffe.set_mode_cpu()
#caffe.set_device(0)
#caffe.set_mode_gpu()

caffe_root = '../../'
# Network weights (trained parameters) file.
caffemodel = caffe_root + 'models/SSD_300x300/VGG_VOC0712_SSD_300x300_iter_60000.caffemodel'
# Deploy-time network structure definition.
deploy = caffe_root + 'models/SSD_300x300/deploy.prototxt'


img_root = caffe_root + 'data/VOCdevkit/VOC2007/JPEGImages/'
labels_file = caffe_root + 'data/VOC0712/labelmap_voc.prototxt'

# Build the network for detection.
net = caffe.Net(deploy,       # defines the model structure
                caffemodel,   # contains the trained weights
                caffe.TEST)   # test mode (dropout is not applied)

# Load the ImageNet mean image (shipped with Caffe).
mu = np.load(caffe_root + 'python/caffe/imagenet/ilsvrc_2012_mean.npy')
mu = mu.mean(1).mean(1)  # average over pixels to get per-channel (BGR) mean values

# Image preprocessing pipeline for the network's 'data' blob.
transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
transformer.set_transpose('data', (2,0,1))     # H x W x C -> C x H x W
transformer.set_mean('data', mu)               # subtract the per-channel mean
transformer.set_raw_scale('data', 255)         # [0,1] floats -> [0,255]
transformer.set_channel_swap('data', (2,1,0))  # RGB -> BGR

# Interactive loop: detect objects in images by 6-digit number until blank input.
# NOTE(review): raw_input is Python 2 only.
while 1:
    img_num = raw_input("Enter Img Number: ")
    if img_num == '': break
    img = img_root + '{:0>6}'.format(img_num) + '.jpg'
    detection(img,net,transformer,labels_file)

generate_lmdb.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
# -*- coding:utf-8 -*-
2+
# 将图像数据生成lmdb数据集
3+
# 1. 生成分类图像数据集
4+
# 2. 生成目标检测图像数据集
5+
import os
6+
import sys
7+
import numpy as np
8+
import random
9+
from caffe.proto import caffe_pb2
10+
from xml.dom.minidom import parse
11+
12+
# Write a Caffe LabelMap prototxt built from a list of label dicts.
def labelmap(labelmap_file, label_info):
    """Serialize the entries of `label_info` into a LabelMap text-format file.

    labelmap_file -- output path for the prototxt
    label_info    -- list of dicts with 'name', 'label', 'display_name' keys
    """
    label_map = caffe_pb2.LabelMap()
    for info in label_info:
        # add() appends a fresh LabelMapItem; set its fields directly.
        entry = label_map.item.add()
        entry.name = info['name']
        entry.label = info['label']
        entry.display_name = info['display_name']
    with open(labelmap_file, 'w') as f:
        f.write(str(label_map))
23+
24+
def rename_img(Img_dir):
    """Rename every .jpg in `Img_dir` to a zero-padded sequential name
    (000000.jpg, 000001.jpg, ..., at most 6 digits).

    Fix vs. the original: os.listdir returns entries in arbitrary order, so
    the numbering was nondeterministic; sorting the listing first makes the
    image -> number mapping stable across runs.
    """
    total_num = 0  # running counter; also the total number of images renamed
    for name in sorted(os.listdir(Img_dir)):
        if name.endswith('.jpg'):
            newname = '{:0>6}'.format(total_num) + '.jpg'
            os.rename(os.path.join(Img_dir, name), os.path.join(Img_dir, newname))
            total_num += 1
34+
35+
def get_img_size():
    """Stub: intended to return image dimensions; not implemented yet."""
    pass
37+
38+
# Build and run the convert_annoset command line for one dataset.
def create_annoset(anno_args):
    """Invoke convert_annoset.exe with options taken from `anno_args`.

    anno_args -- dict of options (the callers in this file build a plain
                 dict, so options are read by subscripting; the original's
                 attribute access `anno_args.anno_type` raised
                 AttributeError at runtime).

    Returns the command string that was passed to os.system (the original
    returned None; the extra return value is backward-compatible).

    Raises ValueError for an unknown anno_type (the original hit a NameError
    on the undefined `cmd` instead).
    """
    tool = "E:\Code\windows-ssd/Build/x64/Release/convert_annoset.exe"
    if anno_args['anno_type'] == "detection":
        cmd = (tool +
               " --anno_type={anno_type}"
               " --label_type={label_type}"
               " --label_map_file={label_map_file}"
               " --check_label={check_label}"
               " --min_dim={min_dim}"
               " --max_dim={max_dim}"
               " --resize_height={resize_height}"
               " --resize_width={resize_width}"
               " --backend={backend}"
               " --shuffle={shuffle}"
               " --check_size={check_size}"
               " --encode_type={encode_type}"
               " --encoded={encoded}"
               " --gray={gray}"
               " {root_dir} {list_file} {out_dir}").format(**anno_args)
    elif anno_args['anno_type'] == "classification":
        # Classification needs no label_type / label_map_file / check_label.
        cmd = (tool +
               " --anno_type={anno_type}"
               " --min_dim={min_dim}"
               " --max_dim={max_dim}"
               " --resize_height={resize_height}"
               " --resize_width={resize_width}"
               " --backend={backend}"
               " --shuffle={shuffle}"
               " --check_size={check_size}"
               " --encode_type={encode_type}"
               " --encoded={encoded}"
               " --gray={gray}"
               " {root_dir} {list_file} {out_dir}").format(**anno_args)
    else:
        raise ValueError("unknown anno_type: %r" % anno_args['anno_type'])
    print(cmd)
    os.system(cmd)
    return cmd
78+
79+
def detection_list(Img_dir, Ano_dir, Data_dir, test_num):
    """Split the .jpg images in `Img_dir` into train/test sets and write the
    list files (trainval.txt, test.txt) consumed by convert_annoset.

    Img_dir  -- directory holding the images (assumed named 000000.jpg, ...)
    Ano_dir  -- directory holding the labelImg .xml annotations
    Data_dir -- directory where trainval.txt / test.txt are written
    test_num -- number of images reserved for the test set

    Fixes vs. the original:
    * np.random.randint(0, total_num-1, size=test_num) could draw the same
      test index twice (making train_list.remove raise ValueError on the
      duplicate) and could never select the last image (exclusive bound);
      random.sample draws `test_num` distinct indices over the full range.
    * List files are closed via `with` instead of leaking open handles.
    """
    # Count the .jpg images in the image directory.
    total_num = sum(1 for name in os.listdir(Img_dir) if name.endswith('.jpg'))

    # Distinct, sorted test indices; the remainder is the training set.
    test_list = sorted(random.sample(range(total_num), test_num))
    test_set = set(test_list)
    train_list = [i for i in range(total_num) if i not in test_set]
    # Shuffle the training order; the test list stays sorted.
    random.shuffle(train_list)

    # Each line: "<image path> <annotation path>" with 6-digit numbering.
    with open(Data_dir + '/trainval.txt', 'w') as train_list_file:
        for idx in train_list:
            train_list_file.write(Img_dir + '{:0>6}'.format(idx) + '.jpg ' +
                                  Ano_dir + '{:0>6}'.format(idx) + '.xml\n')

    with open(Data_dir + '/test.txt', 'w') as test_list_file:
        for idx in test_list:
            test_list_file.write(Img_dir + '{:0>6}'.format(idx) + '.jpg ' +
                                 Ano_dir + '{:0>6}'.format(idx) + '.xml\n')
116+
117+
118+
caffe_root = 'E:/Code/Github/windows_caffe/'
data_root = caffe_root + 'data/mnist/'
Img_dir = data_root + 'JPEGImages/'
Ano_dir = data_root + 'Annotations/'
anno_type = "detection"
test_num = 100

# Step 1: preprocess the images -- rename them to sequential numbered names,
# then create per-image annotations with labelImg (installable via
# `pip install labelImg`; if installation errors, the PyQt4 requirement
# line can be removed).
# rename_img(Img_dir)

# Step 2: write the label map file.
# Edit the label definitions here.
label_info = [
    dict(name='none', label=0, display_name='background'),  # background class
    dict(name="cat",label=1, display_name='cat'),  # class 1
    dict(name="dog",label=2, display_name='dog'),  # class 2
]
labelmap(data_root+'labelmap_voc.prototxt', label_info)

# Step 3: write the image/annotation list files.
if anno_type == "detection":
    detection_list(Img_dir, Ano_dir, data_root, test_num)
else:
    # Classification list generation is not implemented.
    pass

# Step 4: build the lmdb databases.
# Initialize the convert_annoset options.
anno_args = {}
anno_args['anno_type'] = anno_type
# Detection only: annotation file type, one of {xml, json, txt}.
anno_args['label_type'] = "xml"
# Detection only: path to the label map file.
anno_args['label_map_file'] = data_root+"labelmap_voc.prototxt"
# Whether to check that all data have the same size (default False).
anno_args['check_size'] = False
# Whether to check that label names are valid (default False).
anno_args['check_label'] = False
# 0 means the images are not resized.
anno_args['min_dim'] = 0
anno_args['max_dim'] = 0
anno_args['resize_height'] = 0
anno_args['resize_width'] = 0
anno_args['backend'] = "lmdb"  # database format (lmdb, leveldb)
anno_args['shuffle'] = False  # whether to shuffle images and their labels
anno_args['encode_type'] = ""  # image encoding ('png','jpg',...)
anno_args['encoded'] = False  # whether images are encoded (default False)
anno_args['gray'] = False  # treat images as grayscale (default False)
anno_args['root_dir'] = data_root  # root dir containing image and label folders
anno_args['list_file'] = data_root + ''  # list file path (set per dataset below)
anno_args['out_dir'] = data_root  # where the resulting lmdb is written

# Build the training set lmdb.
anno_args['list_file'] = data_root + 'trainval.txt'
create_annoset(anno_args)

# Build the test set lmdb.
anno_args['list_file'] = data_root + 'test.txt'
create_annoset(anno_args)
178+
179+
180+
181+
182+

0 commit comments

Comments
 (0)