学习前言
前一段时间对SSD预测与训练的整体框架有了一定的了解,但是对其中很多细节还是把握的不清楚。今天我决定好好了解以下tfrecords文件的构造。
tfrecords格式是什么
tfrecords是一种二进制编码的文件格式,tensorflow专用。能将任意数据转换为tfrecords。更好的利用内存,更方便复制和移动,并且不需要单独的标签文件。
之所以使用到tfrecords格式是因为当今数据爆炸的情况下,使用普通的数据格式不仅麻烦,而且速度慢,这种专门为tensorflow定制的数据格式可以大大增快数据的读取,而且将所有内容规整,在保证速度的情况下,使得数据更加简单明晰。
tfrecords的写入
这个例子将会讲述如何将MNIST数据集写入到tfrecords,本次用到的MNIST数据集会利用tensorflow原有的库进行导入。
1
2
3
|
from tensorflow.examples.tutorials.mnist import input_data # 读取MNIST数据集 mnist = input_data.read_data_sets( './MNIST_data' , dtype = tf.float32, one_hot = True ) |
对于MNIST数据集而言,其中的训练集是mnist.train,而它的数据可以分为images和labels,可通过如下方式获得。
1
2
3
4
5
6
|
# 获得image,shape为(55000,784) images = mnist.train.images # 获得label,shape为(55000,10) labels = mnist.train.labels # 获得一共具有多少张图片 num_examples = mnist.train.num_examples |
接下来定义存储TFRecord文件的地址,同时创建一个writer来写TFRecord文件。
1
2
3
4
|
# 存储TFRecord文件的地址 filename = 'record/output.tfrecords' # 创建一个writer来写TFRecord文件 writer = tf.python_io.TFRecordWriter(filename) |
此时便可以按照一定的格式写入了,此时需要对每一张图片进行循环并写入,在tf.train.Features中利用features字典定义了数据保存的方式。以image_raw为例,其经过函数_float_feature处理后,存储到tfrecords文件的’image/encoded’位置上。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
|
# 将每张图片都转为一个Example,并写入 for i in range (num_examples): image_raw = images[i] # 读取每一幅图像 image_string = images[i].tostring() example = tf.train.Example( features = tf.train.Features( feature = { 'image/class/label' : _int64_feature(np.argmax(labels[i])), 'image/encoded' : _float_feature(image_raw), 'image/encoded_tostring' : _bytes_feature(image_string) } ) ) print (i, "/" ,num_examples) writer.write(example.SerializeToString()) # 将Example写入TFRecord文件 |
在最终存入前,数据还需要经过处理,处理方式如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
|
# 生成整数的属性 def _int64_feature(value): if not isinstance (value, list ) and not isinstance (value,np.ndarray): value = [value] return tf.train.Feature(int64_list = tf.train.Int64List(value = value)) # 生成浮点数的属性 def _float_feature(value): if not isinstance (value, list ) and not isinstance (value,np.ndarray): value = [value] return tf.train.Feature(float_list = tf.train.FloatList(value = value)) # 生成字符串型的属性 def _bytes_feature(value): if not isinstance (value, list ) and not isinstance (value,np.ndarray): value = [value] return tf.train.Feature(bytes_list = tf.train.BytesList(value = value)) |
tfrecords的读取
tfrecords的读取首先要创建一个reader来读取TFRecord文件中的Example。
1
2
|
# 创建一个reader来读取TFRecord文件中的Example reader = tf.TFRecordReader() |
再创建一个队列来维护输入文件列表。
1
2
|
# 创建一个队列来维护输入文件列表 filename_queue = tf.train.string_input_producer([ 'record/output.tfrecords' ]) |
利用reader读取输入文件列表队列,并用parse_single_example将读入的Example解析成tensor
1
2
3
4
5
6
7
8
9
10
11
|
# 从文件中读出一个Example _, serialized_example = reader.read(filename_queue) # 用parse_single_example将读入的Example解析成tensor features = tf.parse_single_example( serialized_example, features = { 'image/class/label' : tf.FixedLenFeature([], tf.int64), 'image/encoded' : tf.FixedLenFeature([ 784 ], tf.float32, default_value = tf.zeros([ 784 ], dtype = tf.float32)), 'image/encoded_tostring' : tf.FixedLenFeature([], tf.string) } ) |
此时我们得到了一个features,实际上它是一个类似于字典的东西,我们额可以通过字典的方式读取它内部的内容,而字典的索引就是我们再写入tfrecord文件时所用的feature。
1
2
3
4
|
# 将字符串解析成图像对应的像素数组 labels = tf.cast(features[ 'image/class/label' ], tf.int32) images = tf.cast(features[ 'image/encoded' ], tf.float32) images_tostrings = tf.decode_raw(features[ 'image/encoded_tostring' ], tf.float32) |
最后利用一个循环输出:
1
2
3
4
5
6
7
8
|
# 每次运行读取一个Example。当所有样例读取完之后,在此样例中程序会重头读取 for i in range ( 5 ): label, image = sess.run([labels, images]) images_tostring = sess.run(images_tostrings) print (np.shape(image)) print (np.shape(images_tostring)) print (label) print ( "#########################" ) |
测试代码
1、tfrecords文件的写入
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
|
import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # 生成整数的属性 def _int64_feature(value): if not isinstance (value, list ) and not isinstance (value,np.ndarray): value = [value] return tf.train.Feature(int64_list = tf.train.Int64List(value = value)) # 生成浮点数的属性 def _float_feature(value): if not isinstance (value, list ) and not isinstance (value,np.ndarray): value = [value] return tf.train.Feature(float_list = tf.train.FloatList(value = value)) # 生成字符串型的属性 def _bytes_feature(value): if not isinstance (value, list ) and not isinstance (value,np.ndarray): value = [value] return tf.train.Feature(bytes_list = tf.train.BytesList(value = value)) # 读取MNIST数据集 mnist = input_data.read_data_sets( './MNIST_data' , dtype = tf.float32, one_hot = True ) # 获得image,shape为(55000,784) images = mnist.train.images # 获得label,shape为(55000,10) labels = mnist.train.labels # 获得一共具有多少张图片 num_examples = mnist.train.num_examples # 存储TFRecord文件的地址 filename = 'record/Mnist_Out.tfrecords' # 创建一个writer来写TFRecord文件 writer = tf.python_io.TFRecordWriter(filename) # 将每张图片都转为一个Example,并写入 for i in range (num_examples): image_raw = images[i] # 读取每一幅图像 image_string = images[i].tostring() example = tf.train.Example( features = tf.train.Features( feature = { 'image/class/label' : _int64_feature(np.argmax(labels[i])), 'image/encoded' : _float_feature(image_raw), 'image/encoded_tostring' : _bytes_feature(image_string) } ) ) print (i, "/" ,num_examples) writer.write(example.SerializeToString()) # 将Example写入TFRecord文件 print ( 'data processing success' ) writer.close() |
运行结果为:
……
54993 / 55000
54994 / 55000
54995 / 55000
54996 / 55000
54997 / 55000
54998 / 55000
54999 / 55000
data processing success
2、tfrecords文件的读取
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
|
import tensorflow as tf import numpy as np # 创建一个reader来读取TFRecord文件中的Example reader = tf.TFRecordReader() # 创建一个队列来维护输入文件列表 filename_queue = tf.train.string_input_producer([ 'record/Mnist_Out.tfrecords' ]) # 从文件中读出一个Example _, serialized_example = reader.read(filename_queue) # 用parse_single_example将读入的Example解析成tensor features = tf.parse_single_example( serialized_example, features = { 'image/class/label' : tf.FixedLenFeature([], tf.int64), 'image/encoded' : tf.FixedLenFeature([ 784 ], tf.float32, default_value = tf.zeros([ 784 ], dtype = tf.float32)), 'image/encoded_tostring' : tf.FixedLenFeature([], tf.string) } ) # 将字符串解析成图像对应的像素数组 labels = tf.cast(features[ 'image/class/label' ], tf.int32) images = tf.cast(features[ 'image/encoded' ], tf.float32) images_tostrings = tf.decode_raw(features[ 'image/encoded_tostring' ], tf.float32) sess = tf.Session() # 启动多线程处理输入数据 coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess = sess, coord = coord) # 每次运行读取一个Example。当所有样例读取完之后,在此样例中程序会重头读取 for i in range ( 5 ): label, image = sess.run([labels, images]) images_tostring = sess.run(images_tostrings) print (np.shape(image)) print (np.shape(images_tostring)) print (label) print ( "#########################" ) |
运行结果为:
#########################
(784,)
(784,)
7
#########################
#########################
(784,)
(784,)
4
#########################
#########################
(784,)
(784,)
1
#########################
#########################
(784,)
(784,)
1
#########################
#########################
(784,)
(784,)
9
#########################
以上就是python神经网络tfrecords文件的写入读取及内容解析的详细内容,更多关于python神经网络tfrecords写入读取的资料请关注服务器之家其它相关文章!
原文链接:https://blog.csdn.net/weixin_44791964/article/details/102566358