网上从文件中读取样本和标签的资料很多,但大多讲的不全面,或只讲原理,或只有变为.tfrecords部分,或没有调用的栗子。寄几and男票一起捣鼓了两天,终于有了目前这个完整版的代码,希望对看到的朋友有所帮助。

1. 准备样本和标签

样本图示如图1,标签文件train_y.csv如图2,这是个2分类问题。



图1



图2

2.生成记录样本的记录文件

我们的图片存储路径如图3红框所示,标签文件train_y.csv存储路径如图3绿框所示。


我们用ray14_train.py进行train,这个.py文件和train_y.csv不在同一目录下。所以,在标签文件train_y.csv中,我们需要将图片名称这一列变为相对路径,如图4所示,这个新csv我们存为y_train.csv,测试集也这么处理。



图3

      

图4
import numpy as np import pandas as pd import cv2 import csv from os import
path as osp import osbase_path = os.path.join('images','images224')
train_y_path = os.path.join(base_path,'train_y.csv') train_y =
np.loadtxt(train_y_path, delimiter=",", skiprows=0, usecols=(0,1), dtype=str)
train_y_pd = pd.DataFrame(train_y) for i in range(train_y.shape[0]):
train_y_pd.iloc[i,0] = os.path.join(base_path,train_y[i,0])
train_y_pd.to_csv(os.path.join(base_path, 'y_train.csv'),header=None,index=None)
先将2运行,得到y_train.csv和y_test.csv,从3开始要正式读取了。

3.读取csv存于数组中,将图片路径和标签存于数组中
def load_file(example_list_file): lines =
np.genfromtxt(example_list_file,delimiter=",",dtype=[('col1', 'S120'), ('col2',
'i8')]) examples = [] labels = [] for example,label in lines:
examples.append(example) labels.append(label) #convert to numpy array return
np.asarray(examples),np.asarray(labels),len(lines)
4.使用cv2读取图片
def extract_image(filename,height,width): # print(filename) image =
cv2.imread(filename) # image = cv2.resize(image,(height,width)) b,g,r =
cv2.split(image) rgb_image = cv2.merge([r,g,b]) return rgb_image
5.将图片和标签转化为tfrecords文件


def trans2tfRecord(train_file,name,output_dir,height,width): if not
os.path.exists(output_dir) or os.path.isfile(output_dir):
os.makedirs(output_dir) _examples,_labels,examples_num = load_file(train_file)
filename = name + '.tfrecords' writer = tf.python_io.TFRecordWriter(filename)
for i,[example,label] in enumerate(zip(_examples,_labels)): #
print("NO{}".format(i)) #need to convert the example(bytes) to utf-8 example =
example.decode("UTF-8") image = extract_image(example,height,width) image_raw =
image.tostring() example =
tf.train.Example(features=tf.train.Features(feature={
'image_raw':_bytes_feature(image_raw), 'height':_int64_feature(image.shape[0]),
'width': _int64_feature(32), 'depth': _int64_feature(32), 'label':
_int64_feature(label) })) writer.write(example.SerializeToString())
writer.close() return filenamedef _int64_feature(value): return
tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def
_bytes_feature(value): return
tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


6.从tfrecords文件中读取训练数据
def read_tfRecord(file_tfRecord,shuffle=False): #
这个函数需要传入一个文件名,系统会自动将它转为一个文件名队列,这个队列存的是训练或测试过程用到的数据 #
tf.train.string_input_producer有两个重要的参数,一个是num_epochs,这个设成默认none就行,none表示无限次 #
它表示将全部样本入队次数,一般程序迭代几次就入队几次。程序运行开始,数据就开始出队,为了保证队列一直不空, #
我们设为none,使全部样本入队无数次(无限循环)。 #
另外一个就是shuffle,shuffle是指在一个epoch内文件的顺序是否被打乱(但是我测试时发现无论是True还是False,其实都打乱了)。
queue = tf.train.string_input_producer([file_tfRecord], shuffle=shuffle) reader
= tf.TFRecordReader() _,serialized_example = reader.read(queue) features =
tf.parse_single_example( serialized_example, features={ 'image_raw':
tf.FixedLenFeature([], tf.string), 'height': tf.FixedLenFeature([], tf.int64),
'width':tf.FixedLenFeature([], tf.int64), 'depth': tf.FixedLenFeature([],
tf.int64), 'label': tf.FixedLenFeature([], tf.int64) } ) image =
tf.decode_raw(features['image_raw'],tf.uint8) #height =
tf.cast(features['height'], tf.int64) #width = tf.cast(features['width'],
tf.int64) image = tf.reshape(image,[224,224,3]) image = tf.cast(image,
tf.float32) image = tf.image.per_image_standardization(image) label =
tf.cast(features['label'], tf.int64) print(image,label) return image,label
7.调用3-6,开始训练
with tf.Session() as sess: # 训练过程 base_path =
os.path.join('images','images224') data_train_path =
os.path.join(base_path,'y_train.csv') data_test_path =
os.path.join(base_path,'y_test.csv') #
首次执行程序需要运行一旦生成之后就可以注释掉了:利用csv生成y_train.tfrecords和y_test.tfrecords文件,这俩文件是训练集和测试集的样本与标签,
filename = trans2tfRecord(data_train_path, 'y_train', base_path, 224, 224)
filename2 = trans2tfRecord(data_train_path, 'y_test', base_path, 224, 224)
img_batch, path_batch = read_tfRecord(filename, shuffle=True) img_batch2,
path_batch2 = read_tfRecord(filename2, shuffle=False) image_batches,
label_batches = tf.train.batch([img_batch, path_batch], batch_size=batch,
capacity=4096) image_batches2, label_batches2 = tf.train.batch([img_batch2,
path_batch2], batch_size=batch, capacity=4096)
tf.local_variables_initializer().run() coord = tf.train.Coordinator() threads =
tf.train.start_queue_runners(sess=sess,coord=coord) # 定义一个模型
model=ATDA(sess=sess) model.create_model() #
训练模型:(image_batches,label_batches)是训练集,(image_batches2,label_batches2)是测试集,
model.fit_ATDA(source_train=image_batches, y_train=label_batches,
target_val=image_batches2, y_val=label_batches2, #
n是训练集总数,my_number是测试集总数,my_catelogy是标签种类,batch是迭代次数 nb_epoch=epochs, n = 86524,
my_number = 25596, my_catelogy = 2,batch = 16) coord.request_stop() # 请求线程结束
coord.join() # 等待线程结束
8.model.fit_ATDA(),这部分是训练模型。
def fit_ATDA(source_train, y_train, target_val, y_val, nb_epoch=30, n = 86524,
my_number = 25596, my_catelogy = 2, batch = 4): for e in range(nb_epoch):
n_batch = 0 for my_batch_train in range(int(n/batch)): Xu_batch, Yu_batch =
self.sess.run([source_train, y_train]) Xu_batch =
transform_batch_images(Xu_batch) Yu_batch = np_utils.to_categorical(Yu_batch,
2) # print('train label',Yu_batch) feed_dict = { self.x: Xu_batch, self.y_:
Yu_batch ,self.istrain:True} cost, Ft_loss = self.sess.run([cost, Ft_loss],
feed_dict=feed_dict) n_batch += 1 #every 1000 minibatch print loss if n_batch %
1000==0: print("Epoch %d total_loss %f Ft_loss %f" % (e + 1, cost,Ft_loss))
其中,从文件读取部分代码是:
Xu_batch, Yu_batch = self.sess.run([source_train, y_train])9.测试的代码就不写了,类似8。




参考资料:

1.https://zhuanlan.zhihu.com/p/27238630

2.https://www.cnblogs.com/wktwj/p/7257526.html


友情链接
KaDraw流程图
API参考文档
OK工具箱
云服务器优惠
阿里云优惠券
腾讯云优惠券
华为云优惠券
站点信息
问题反馈
邮箱:ixiaoyang8@qq.com
QQ群:637538335
关注微信