脚本之家,脚本语言编程技术及教程分享平台!
分类导航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|

服务器之家 - 脚本之家 - Python - TensorFlow 不同大小图片的TFRecord存取实例

TensorFlow 不同大小图片的TFRecord存取实例

2020-04-09 12:19 Wayne2019 Python

今天小编就为大家分享一篇TensorFlow 不同大小图片的TFRecord存取实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧。

将所有图片存入一个TFRecord文件,然后读取并显示第一张图片。

不多写了,直接贴代码。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
 
 
# Directory holding the demo images; the TFRecord file is created inside it.
IMAGE_PATH = 'test/'
tfrecord_file = '%stest.tfrecord' % IMAGE_PATH
# Opened once at import time and shared by write_to_tfrecord; main() closes it.
writer = tf.python_io.TFRecordWriter(tfrecord_file)
 
 
def _int64_feature(value):
    """Return a tf.train.Feature whose int64_list holds the single integer ``value``."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)
 
def _bytes_feature(value):
    """Return a tf.train.Feature whose bytes_list holds the single bytes value ``value``."""
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)
 
def get_image_binary(filename):
    """Load an image file and return ``(shape, raw_bytes)``.

    Reading with Pillow and NumPy avoids building a TensorFlow graph just to
    decode the file.

    Non-RGB images (grayscale, palette, RGBA, CMYK, ...) are converted to RGB
    so every stored record has exactly 3 channels. Without this, a grayscale
    image yields a 2-D array and ``shape[2]`` fails when the record is written.

    Returns:
        shape: int32 array ``[height, width, 3]``.
        raw_bytes: the uint8 pixel data as bytes, in row-major (C) order.
    """
    # ``with`` closes the underlying file handle once the pixels are copied.
    with Image.open(filename) as img:
        if img.mode != 'RGB':
            img = img.convert('RGB')
        # np.array copies the pixels, so they stay valid after the file closes.
        pixels = np.array(img, dtype=np.uint8)
    shape = np.array(pixels.shape, np.int32)
    return shape, pixels.tobytes()
 
def write_to_tfrecord(label, shape, binary_image, tfrecord_file):
    """Serialize one image sample as a tf.train.Example and write it out.

    The record holds the integer label, the height/width/channel count taken
    from ``shape[0]``, ``shape[1]`` and ``shape[2]``, and the raw pixel bytes.
    Call once per sample to write several samples.

    Note: ``tfrecord_file`` is not used. The record always goes to the
    module-level ``writer``, which was opened at import time on the
    module-level ``tfrecord_file`` path.
    """
    feature_map = {
        'label': _int64_feature(label),
        'h': _int64_feature(shape[0]),
        'w': _int64_feature(shape[1]),
        'c': _int64_feature(shape[2]),
        'image': _bytes_feature(binary_image),
    }
    record = tf.train.Example(features=tf.train.Features(feature=feature_map))
    writer.write(record.SerializeToString())
 
 
def write_tfrecord(label, image_file, tfrecord_file):
    """Load ``image_file`` and write it, together with ``label``, as one example."""
    image_shape, raw_bytes = get_image_binary(image_file)
    write_to_tfrecord(label, image_shape, raw_bytes, tfrecord_file)
 
 
 
def main():
    """Write two demo images to one TFRecord file, then read back and show one.

    The images may differ in size. Each record stores its own height, width
    and channel count, and these are used to restore the pixel array on read.
    """
    # Demo labels: a.jpg -> class 1, b.jpg -> class 2.
    labels = [1, 2]
    image_files = [IMAGE_PATH + 'a.jpg', IMAGE_PATH + 'b.jpg']

    # zip keeps each label paired with its file and never indexes past the end.
    for label, image_file in zip(labels, image_files):
        write_tfrecord(label, image_file, tfrecord_file)
    writer.close()

    filename_queue = tf.train.string_input_producer([tfrecord_file])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    img_features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'h': tf.FixedLenFeature([], tf.int64),
            'w': tf.FixedLenFeature([], tf.int64),
            'c': tf.FixedLenFeature([], tf.int64),
            'image': tf.FixedLenFeature([], tf.string),
        })

    h = tf.cast(img_features['h'], tf.int32)
    w = tf.cast(img_features['w'], tf.int32)
    c = tf.cast(img_features['c'], tf.int32)

    # The raw bytes carry no shape, so rebuild it from the stored h/w/c.
    image = tf.decode_raw(img_features['image'], tf.uint8)
    image = tf.reshape(image, [h, w, c])

    label = tf.cast(img_features['label'], tf.int32)
    label = tf.reshape(label, [1])

    # Images of different sizes cannot be stacked by tf.train.batch, so only a
    # single example is fetched here (resize first if batching is needed).
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            image_value, label_value = sess.run([image, label])
        finally:
            # Always stop the queue-runner threads, even if sess.run fails.
            coord.request_stop()
            coord.join(threads)

    print(label_value)

    plt.figure()
    plt.imshow(image_value)
    plt.show()
 
 
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()

将所有图片存入一个TFRecord文件,然后按batch_size批量读取。注意:必须先把图片缩放到相同尺寸,才能按batch_size组成批次。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
 
 
# Directory holding the demo images; the TFRecord file is created inside it.
IMAGE_PATH = 'test/'
tfrecord_file = '%stest.tfrecord' % IMAGE_PATH
# Opened once at import time and shared by write_to_tfrecord; main() closes it.
writer = tf.python_io.TFRecordWriter(tfrecord_file)
 
 
def _int64_feature(value):
    """Return a tf.train.Feature whose int64_list holds the single integer ``value``."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)
 
def _bytes_feature(value):
    """Return a tf.train.Feature whose bytes_list holds the single bytes value ``value``."""
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)
 
def get_image_binary(filename):
    """Load an image file and return ``(shape, raw_bytes)``.

    Reading with Pillow and NumPy avoids building a TensorFlow graph just to
    decode the file.

    Non-RGB images (grayscale, palette, RGBA, CMYK, ...) are converted to RGB
    so every stored record has exactly 3 channels. The reader reshapes to
    ``[224, 224, 3]``, and a grayscale image would otherwise give a 2-D array
    whose ``shape[2]`` does not exist.

    Returns:
        shape: int32 array ``[height, width, 3]``.
        raw_bytes: the uint8 pixel data as bytes, in row-major (C) order.
    """
    # ``with`` closes the underlying file handle once the pixels are copied.
    with Image.open(filename) as img:
        if img.mode != 'RGB':
            img = img.convert('RGB')
        # np.array copies the pixels, so they stay valid after the file closes.
        pixels = np.array(img, dtype=np.uint8)
    shape = np.array(pixels.shape, np.int32)
    return shape, pixels.tobytes()
 
def write_to_tfrecord(label, shape, binary_image, tfrecord_file):
    """Serialize one image sample as a tf.train.Example and write it out.

    The record holds the integer label, the height/width/channel count taken
    from ``shape[0]``, ``shape[1]`` and ``shape[2]``, and the raw pixel bytes.
    Call once per sample to write several samples.

    Note: ``tfrecord_file`` is not used. The record always goes to the
    module-level ``writer``, which was opened at import time on the
    module-level ``tfrecord_file`` path.
    """
    feature_map = {
        'label': _int64_feature(label),
        'h': _int64_feature(shape[0]),
        'w': _int64_feature(shape[1]),
        'c': _int64_feature(shape[2]),
        'image': _bytes_feature(binary_image),
    }
    record = tf.train.Example(features=tf.train.Features(feature=feature_map))
    writer.write(record.SerializeToString())
 
 
def write_tfrecord(label, image_file, tfrecord_file):
    """Load ``image_file`` and write it, together with ``label``, as one example."""
    image_shape, raw_bytes = get_image_binary(image_file)
    write_to_tfrecord(label, image_shape, raw_bytes, tfrecord_file)
 
 
 
def main():
    """Write two demo images to one TFRecord file, then read them back as a batch.

    Each record keeps its own height/width/channels. After decoding, every
    image is resized to 224x224 so that tf.train.batch can stack them.
    """
    # Demo labels: a.jpg -> class 1, b.jpg -> class 2.
    labels = [1, 2]
    image_files = [IMAGE_PATH + 'a.jpg', IMAGE_PATH + 'b.jpg']

    # zip keeps each label paired with its file and never indexes past the end.
    for label, image_file in zip(labels, image_files):
        write_tfrecord(label, image_file, tfrecord_file)
    writer.close()

    batch_size = 2

    filename_queue = tf.train.string_input_producer([tfrecord_file])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    img_features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'h': tf.FixedLenFeature([], tf.int64),
            'w': tf.FixedLenFeature([], tf.int64),
            'c': tf.FixedLenFeature([], tf.int64),
            'image': tf.FixedLenFeature([], tf.string),
        })

    h = tf.cast(img_features['h'], tf.int32)
    w = tf.cast(img_features['w'], tf.int32)
    c = tf.cast(img_features['c'], tf.int32)

    # The raw bytes carry no shape, so rebuild it from the stored h/w/c.
    image = tf.decode_raw(img_features['image'], tf.uint8)
    image = tf.reshape(image, [h, w, c])

    label = tf.cast(img_features['label'], tf.int32)
    label = tf.reshape(label, [1])

    # Bring every image to a common size; batching requires identical shapes.
    image = tf.image.resize_images(image, (224, 224))
    # The channel count is only known at run time (it comes from the record),
    # but tf.train.batch needs a fully defined static shape, so pin it here.
    image = tf.reshape(image, [224, 224, 3])
    image_batch, label_batch = tf.train.batch([image, label], batch_size=batch_size)

    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            images, label_values = sess.run([image_batch, label_batch])
        finally:
            # Always stop the queue-runner threads, even if sess.run fails.
            coord.request_stop()
            coord.join(threads)

    print(images.shape)
    print(label_values)

    # Show channels 0 and 1 of the first image (matplotlib applies a colormap).
    plt.figure()
    plt.imshow(images[0, :, :, 0])
    plt.show()

    plt.figure()
    plt.imshow(images[0, :, :, 1])
    plt.show()

    image1 = images[0, :, :, :]
    print(image1.shape)
    print(image1.dtype)
    # resize_images returns float32 here; convert back to uint8 for PIL.
    # NumPy <-> image conversion: http://blog.csdn.net/zywvvd/article/details/72810360
    im = Image.fromarray(np.uint8(image1))
    im.show()
 
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()

输出为:

?
1
2
3
4
5
(2, 224, 224, 3)
[[1]
 [2]]
 
第一张图片的三种显示(略)

封装成函数:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 8 14:38:15 2017
 
@author: wayne
 
 
"""
 
 
'''
本文参考了以下代码,在多个不同大小图片存取方面做了重新开发:
https://github.com/chiphuyen/stanford-tensorflow-tutorials/blob/master/examples/09_tfrecord_example.py
http://blog.csdn.net/hjxu2016/article/details/76165559
https://stackoverflow.com/questions/41921746/tensorflow-varlenfeature-vs-fixedlenfeature
https://github.com/tensorflow/tensorflow/issues/10492
 
后续:
-存入多个TFrecords文件的例子见
http://blog.csdn.net/xierhacker/article/details/72357651
-如何作shuffle和数据增强
string_input_producer (需要理解tf的数据流,标签队列的工作方式等等)
http://blog.csdn.net/liuchonge/article/details/73649251
'''
 
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
 
 
# Directory holding the demo images; the TFRecord file is created inside it.
IMAGE_PATH = 'test/'
tfrecord_file = '%stest.tfrecord' % IMAGE_PATH
# Opened once at import time and shared by write_to_tfrecord; main() closes it.
writer = tf.python_io.TFRecordWriter(tfrecord_file)
 
 
def _int64_feature(value):
    """Return a tf.train.Feature whose int64_list holds the single integer ``value``."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)
 
def _bytes_feature(value):
    """Return a tf.train.Feature whose bytes_list holds the single bytes value ``value``."""
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)
 
def get_image_binary(filename):
    """Load an image file and return ``(shape, raw_bytes)``.

    Reading with Pillow and NumPy avoids building a TensorFlow graph just to
    decode the file.

    Non-RGB images (grayscale, palette, RGBA, CMYK, ...) are converted to RGB
    so every stored record has exactly 3 channels. read_and_decode reshapes to
    ``[224, 224, 3]``, and a grayscale image would otherwise give a 2-D array
    whose ``shape[2]`` does not exist.

    Returns:
        shape: int32 array ``[height, width, 3]``.
        raw_bytes: the uint8 pixel data as bytes, in row-major (C) order.
    """
    # ``with`` closes the underlying file handle once the pixels are copied.
    with Image.open(filename) as img:
        if img.mode != 'RGB':
            img = img.convert('RGB')
        # np.array copies the pixels, so they stay valid after the file closes.
        pixels = np.array(img, dtype=np.uint8)
    shape = np.array(pixels.shape, np.int32)
    return shape, pixels.tobytes()
 
def write_to_tfrecord(label, shape, binary_image, tfrecord_file):
    """Serialize one image sample as a tf.train.Example and write it out.

    The record holds the integer label, the height/width/channel count taken
    from ``shape[0]``, ``shape[1]`` and ``shape[2]``, and the raw pixel bytes.
    Call once per sample to write several samples.

    Note: ``tfrecord_file`` is not used. The record always goes to the
    module-level ``writer``, which was opened at import time on the
    module-level ``tfrecord_file`` path.
    """
    feature_map = {
        'label': _int64_feature(label),
        'h': _int64_feature(shape[0]),
        'w': _int64_feature(shape[1]),
        'c': _int64_feature(shape[2]),
        'image': _bytes_feature(binary_image),
    }
    record = tf.train.Example(features=tf.train.Features(feature=feature_map))
    writer.write(record.SerializeToString())
 
 
def write_tfrecord(label, image_file, tfrecord_file):
    """Load ``image_file`` and write it, together with ``label``, as one example."""
    image_shape, raw_bytes = get_image_binary(image_file)
    write_to_tfrecord(label, image_shape, raw_bytes, tfrecord_file)
 
 
def read_and_decode(tfrecords_file, batch_size):
    """Read and decode a TFRecord file into batches of (image, label).

    Each record stores its own height/width/channels. After decoding, every
    image is resized to 224x224 so that records of different original sizes
    can be stacked into one batch.

    Args:
        tfrecords_file: path of the TFRecord file to read.
        batch_size: number of images in each batch.
    Returns:
        image_batch: 4-D float32 tensor [batch_size, 224, 224, 3]
            (height, width, channel).
        label_batch: 1-D int32 tensor [batch_size].
    """
    # Read from the file passed in; the original ignored this argument and
    # always used the module-level ``tfrecord_file``.
    filename_queue = tf.train.string_input_producer([tfrecords_file])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    img_features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'h': tf.FixedLenFeature([], tf.int64),
            'w': tf.FixedLenFeature([], tf.int64),
            'c': tf.FixedLenFeature([], tf.int64),
            'image': tf.FixedLenFeature([], tf.string),
        })

    h = tf.cast(img_features['h'], tf.int32)
    w = tf.cast(img_features['w'], tf.int32)
    c = tf.cast(img_features['c'], tf.int32)

    # The raw bytes carry no shape, so rebuild it from the stored h/w/c.
    image = tf.decode_raw(img_features['image'], tf.uint8)
    image = tf.reshape(image, [h, w, c])

    label = tf.cast(img_features['label'], tf.int32)
    label = tf.reshape(label, [1])

    # Data augmentation could be inserted here, e.g. tf.random_crop,
    # tf.image.random_flip_left_right, tf.image.random_brightness,
    # tf.image.random_contrast, tf.image.per_image_standardization.

    image = tf.image.resize_images(image, (224, 224))
    # The channel count is only known at run time (it comes from the record),
    # but tf.train.batch needs a fully defined static shape, so pin it here.
    image = tf.reshape(image, [224, 224, 3])

    image_batch, label_batch = tf.train.batch(
        [image, label],
        batch_size=batch_size,
        num_threads=64,
        capacity=2000)
    return image_batch, tf.reshape(label_batch, [batch_size])
 
def read_tfrecord2(tfrecord_file, batch_size):
    """Fetch one batch from ``tfrecord_file`` and return it as NumPy arrays.

    Returns:
        (images, labels) as produced by evaluating read_and_decode's tensors.
    """
    image_batch, label_batch = read_and_decode(tfrecord_file, batch_size)

    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            images, labels = sess.run([image_batch, label_batch])
        finally:
            # Always stop the queue-runner threads, even if sess.run fails.
            coord.request_stop()
            coord.join(threads)
    return images, labels
 
 
def main():
    """Write two demo images to one TFRecord file, then read them back as a batch."""
    # Demo labels: a.jpg -> class 1, b.jpg -> class 2.
    labels = [1, 2]
    image_files = [IMAGE_PATH + 'a.jpg', IMAGE_PATH + 'b.jpg']

    # zip keeps each label paired with its file and never indexes past the end.
    for label, image_file in zip(labels, image_files):
        write_tfrecord(label, image_file, tfrecord_file)
    writer.close()

    batch_size = 2
    train_batch, train_label_batch = read_tfrecord2(tfrecord_file, batch_size)

    print(train_batch.shape)
    print(train_label_batch)

    # Show channels 0 and 1 of the first image (matplotlib applies a colormap).
    plt.figure()
    plt.imshow(train_batch[0, :, :, 0])
    plt.show()

    plt.figure()
    plt.imshow(train_batch[0, :, :, 1])
    plt.show()

    train_batch1 = train_batch[0, :, :, :]
    # Shape of the single image just extracted; the original printed the
    # whole batch's shape a second time here.
    print(train_batch1.shape)
    print(train_batch1.dtype)
    # resize_images returns float32; convert back to uint8 for PIL.
    # NumPy <-> image conversion: http://blog.csdn.net/zywvvd/article/details/72810360
    im = Image.fromarray(np.uint8(train_batch1))
    im.show()
 
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()

以上这篇TensorFlow 不同大小图片的TFRecord存取实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。

原文链接:https://blog.csdn.net/Wayne2019/article/details/77894794

延伸 · 阅读

精彩推荐