脚本之家,脚本语言编程技术及教程分享平台!
分类导航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|shell|

服务器之家 - 脚本之家 - Python - python神经网络tfrecords文件的写入读取及内容解析

python神经网络tfrecords文件的写入读取及内容解析

2022-12-13 12:05Bubbliiiing Python

这篇文章主要为大家介绍了python神经网络tfrecords文件的写入读取及内容解析,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

学习前言

前一段时间对SSD预测与训练的整体框架有了一定的了解,但是对其中很多细节还是把握的不清楚。今天我决定好好了解以下tfrecords文件的构造。

tfrecords格式是什么

tfrecords是一种二进制编码的文件格式,tensorflow专用。能将任意数据转换为tfrecords。更好的利用内存,更方便复制和移动,并且不需要单独的标签文件。

之所以使用到tfrecords格式是因为当今数据爆炸的情况下,使用普通的数据格式不仅麻烦,而且速度慢,这种专门为tensorflow定制的数据格式可以大大增快数据的读取,而且将所有内容规整,在保证速度的情况下,使得数据更加简单明晰。

tfrecords的写入

这个例子将会讲述如何将MNIST数据集写入到tfrecords,本次用到的MNIST数据集会利用tensorflow原有的库进行导入。

?
1
2
3
from tensorflow.examples.tutorials.mnist import input_data
# 读取MNIST数据集
mnist = input_data.read_data_sets('./MNIST_data', dtype=tf.float32, one_hot=True)

对于MNIST数据集而言,其中的训练集是mnist.train,而它的数据可以分为images和labels,可通过如下方式获得。

?
1
2
3
4
5
6
# 获得image,shape为(55000,784)
images = mnist.train.images
# 获得label,shape为(55000,10)
labels = mnist.train.labels
# 获得一共具有多少张图片
num_examples = mnist.train.num_examples

接下来定义存储TFRecord文件的地址,同时创建一个writer来写TFRecord文件。

?
1
2
3
4
# 存储TFRecord文件的地址
filename = 'record/output.tfrecords'
# 创建一个writer来写TFRecord文件
writer = tf.python_io.TFRecordWriter(filename)

此时便可以按照一定的格式写入了,此时需要对每一张图片进行循环并写入,在tf.train.Features中利用features字典定义了数据保存的方式。以image_raw为例,其经过函数_float_feature处理后,存储到tfrecords文件的’image/encoded’位置上。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 将每张图片都转为一个Example,并写入
for i in range(num_examples):
    image_raw = images[i]  # 读取每一幅图像
    image_string = images[i].tostring()
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'image/class/label': _int64_feature(np.argmax(labels[i])),
                'image/encoded': _float_feature(image_raw),
                'image/encoded_tostring': _bytes_feature(image_string)
            }
        )
    )
    print(i,"/",num_examples)
    writer.write(example.SerializeToString())  # 将Example写入TFRecord文件

在最终存入前,数据还需要经过处理,处理方式如下:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 生成整数的属性
def _int64_feature(value):
    if not isinstance(value,list) and not isinstance(value,np.ndarray):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
# 生成浮点数的属性
def _float_feature(value):
    if not isinstance(value,list) and not isinstance(value,np.ndarray):
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
# 生成字符串型的属性
def _bytes_feature(value):
    if not isinstance(value,list) and not isinstance(value,np.ndarray):
        value = [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

tfrecords的读取

tfrecords的读取首先要创建一个reader来读取TFRecord文件中的Example。

?
1
2
# 创建一个reader来读取TFRecord文件中的Example
reader = tf.TFRecordReader()

再创建一个队列来维护输入文件列表。

?
1
2
# 创建一个队列来维护输入文件列表
filename_queue = tf.train.string_input_producer(['record/output.tfrecords'])

利用reader读取输入文件列表队列,并用parse_single_example将读入的Example解析成tensor

?
1
2
3
4
5
6
7
8
9
10
11
# 从文件中读出一个Example
_, serialized_example = reader.read(filename_queue)
# 用parse_single_example将读入的Example解析成tensor
features = tf.parse_single_example(
    serialized_example,
    features={
        'image/class/label': tf.FixedLenFeature([], tf.int64),
        'image/encoded': tf.FixedLenFeature([784], tf.float32, default_value=tf.zeros([784], dtype=tf.float32)),
        'image/encoded_tostring': tf.FixedLenFeature([], tf.string)
    }
)

此时我们得到了一个features,实际上它是一个类似于字典的东西,我们额可以通过字典的方式读取它内部的内容,而字典的索引就是我们再写入tfrecord文件时所用的feature。

?
1
2
3
4
# 将字符串解析成图像对应的像素数组
labels = tf.cast(features['image/class/label'], tf.int32)
images = tf.cast(features['image/encoded'], tf.float32)
images_tostrings = tf.decode_raw(features['image/encoded_tostring'], tf.float32)

最后利用一个循环输出:

?
1
2
3
4
5
6
7
8
# 每次运行读取一个Example。当所有样例读取完之后,在此样例中程序会重头读取
for i in range(5):
    label, image = sess.run([labels, images])
    images_tostring = sess.run(images_tostrings)
    print(np.shape(image))
    print(np.shape(images_tostring))
    print(label)
    print("#########################")

测试代码

1、tfrecords文件的写入

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# 生成整数的属性
def _int64_feature(value):
    if not isinstance(value,list) and not isinstance(value,np.ndarray):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
# 生成浮点数的属性
def _float_feature(value):
    if not isinstance(value,list) and not isinstance(value,np.ndarray):
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
# 生成字符串型的属性
def _bytes_feature(value):
    if not isinstance(value,list) and not isinstance(value,np.ndarray):
        value = [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
# 读取MNIST数据集
mnist = input_data.read_data_sets('./MNIST_data', dtype=tf.float32, one_hot=True)
# 获得image,shape为(55000,784)
images = mnist.train.images
# 获得label,shape为(55000,10)
labels = mnist.train.labels
# 获得一共具有多少张图片
num_examples = mnist.train.num_examples
# 存储TFRecord文件的地址
filename = 'record/Mnist_Out.tfrecords'
# 创建一个writer来写TFRecord文件
writer = tf.python_io.TFRecordWriter(filename)
# 将每张图片都转为一个Example,并写入
for i in range(num_examples):
    image_raw = images[i]  # 读取每一幅图像
    image_string = images[i].tostring()
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'image/class/label': _int64_feature(np.argmax(labels[i])),
                'image/encoded': _float_feature(image_raw),
                'image/encoded_tostring': _bytes_feature(image_string)
            }
        )
    )
    print(i,"/",num_examples)
    writer.write(example.SerializeToString())  # 将Example写入TFRecord文件
print('data processing success')
writer.close()

运行结果为:

……
54993 / 55000
54994 / 55000
54995 / 55000
54996 / 55000
54997 / 55000
54998 / 55000
54999 / 55000
data processing success

2、tfrecords文件的读取

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import tensorflow as tf
import numpy as np
# 创建一个reader来读取TFRecord文件中的Example
reader = tf.TFRecordReader()
# 创建一个队列来维护输入文件列表
filename_queue = tf.train.string_input_producer(['record/Mnist_Out.tfrecords'])
# 从文件中读出一个Example
_, serialized_example = reader.read(filename_queue)
# 用parse_single_example将读入的Example解析成tensor
features = tf.parse_single_example(
    serialized_example,
    features={
        'image/class/label': tf.FixedLenFeature([], tf.int64),
        'image/encoded': tf.FixedLenFeature([784], tf.float32, default_value=tf.zeros([784], dtype=tf.float32)),
        'image/encoded_tostring': tf.FixedLenFeature([], tf.string)
    }
)
# 将字符串解析成图像对应的像素数组
labels = tf.cast(features['image/class/label'], tf.int32)
images = tf.cast(features['image/encoded'], tf.float32)
images_tostrings = tf.decode_raw(features['image/encoded_tostring'], tf.float32)
sess = tf.Session()
# 启动多线程处理输入数据
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# 每次运行读取一个Example。当所有样例读取完之后,在此样例中程序会重头读取
for i in range(5):
    label, image = sess.run([labels, images])
    images_tostring = sess.run(images_tostrings)
    print(np.shape(image))
    print(np.shape(images_tostring))
    print(label)
    print("#########################")

运行结果为:

#########################
(784,)
(784,)
7
#########################
#########################
(784,)
(784,)
4
#########################
#########################
(784,)
(784,)
1
#########################
#########################
(784,)
(784,)
1
#########################
#########################
(784,)
(784,)
9
#########################

以上就是python神经网络tfrecords文件的写入读取及内容解析的详细内容,更多关于python神经网络tfrecords写入读取的资料请关注服务器之家其它相关文章!

原文链接:https://blog.csdn.net/weixin_44791964/article/details/102566358

延伸 · 阅读

精彩推荐
  • Pythonpython 统计列表中不同元素的数量方法

    python 统计列表中不同元素的数量方法

    今天小编就为大家分享一篇python 统计列表中不同元素的数量方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...

    wenqiwenqi12318662021-03-11
  • PythonPython爬虫DNS解析缓存方法实例分析

    Python爬虫DNS解析缓存方法实例分析

    这篇文章主要介绍了Python爬虫DNS解析缓存方法,结合具体实例形式分析了Python使用socket模块解析DNS缓存的相关操作技巧与注意事项,需要的朋友可以参考下...

    九茶4932020-11-14
  • Pythonpandas.DataFrame选取/排除特定行的方法

    pandas.DataFrame选取/排除特定行的方法

    今天小编就为大家分享一篇pandas.DataFrame选取/排除特定行的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...

    软件大盗7282021-03-12
  • PythonPython使用matplotlib绘制余弦的散点图示例

    Python使用matplotlib绘制余弦的散点图示例

    这篇文章主要介绍了Python使用matplotlib绘制余弦的散点图,涉及Python操作matplotlib的基本技巧与散点的设置方法,需要的朋友可以参考下...

    chengqiuming8172021-01-22
  • PythonPython中glob库实现文件名的匹配

    Python中glob库实现文件名的匹配

    本文主要主要介绍了Python中glob库实现文件名的匹配,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧...

    一天一篇Python库7152021-12-03
  • Pythonpytorch预测之解决多次预测结果不一致问题

    pytorch预测之解决多次预测结果不一致问题

    这篇文章主要介绍了pytorch多次预测结果不一致的解决方案,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教...

    confusingbird7442021-11-19
  • Python详解Python中类的定义与使用

    详解Python中类的定义与使用

    本篇文章主要介绍了详解Python中类的定义与使用,介绍了什么叫做类和如何使用,具有一定的参考价值,想要学习Python的同学可以了解一下。...

    chenxiaoyong13452020-09-29
  • PythonPython实现Windows上气泡提醒效果的方法

    Python实现Windows上气泡提醒效果的方法

    这篇文章主要介绍了Python实现Windows上气泡提醒效果的方法,涉及Python针对windows窗口操作的相关技巧,需要的朋友可以参考下 ...

    xm13313053692020-07-12