TensorFlow杂记 – 生成和读取TFRecord(二)

  • Post author:
  • Post category:其他



与上一篇不同的TFRecord生成和读取方法,抽取自SSD-TensorFlow,并做了一定的修改。


使用Slim生成和读取TFRecord。


软件平台:PyCharm 2019.2 + TensorFlow 1.13.1 + CUDA 10.0 + cuDNN 7.6.5


GPU: 1080Ti


上代码。



dataset_common.py

"""Provides data for the Pascal VOC Dataset (images + annotations).
"""
import os

import tensorflow as tf
import dataset_utils

slim = tf.contrib.slim

# Maps every class code to an (integer id, class code) pair. Ids are assigned
# in listing order, starting at 0.
VOC_LABELS = {
    code: (class_id, code)
    for class_id, code in enumerate((
        'B01', 'D01', 'G01', 'W01', 'W02', 'W03', 'W04', 'T01', 'R01', 'F01',
        'S01', 'G02', 'G03', 'W05', 'I18', 'I01', 'I02', 'I03', 'I04', 'I99',
    ))
}


def get_split(split_name, dataset_dir, file_pattern, reader,
              split_to_sizes, items_to_descriptions, num_classes):
    """Build the slim `Dataset` that describes one split of the TFRecord data.

    Args:
      split_name: Name of the split; must be a key of `split_to_sizes`.
      dataset_dir: Directory that holds the TFRecord files. It is also the
        directory checked for a labels file.
      file_pattern: Pattern of the TFRecord file names. `split_name` is
        substituted into it with the `%` operator, so it is expected to
        contain a '%s' placeholder.
      reader: TensorFlow reader class; `tf.TFRecordReader` is used when None.
      split_to_sizes: Mapping from split name to the number of samples.
      items_to_descriptions: Passed through unchanged to the `Dataset`.
      num_classes: Passed through unchanged to the `Dataset`.

    Returns:
      A `slim.dataset.Dataset`.

    Raises:
      ValueError: if `split_name` is not a key of `split_to_sizes`.
    """
    if split_name not in split_to_sizes:
        raise ValueError('split name %s was not recognized.' % split_name)

    data_sources = os.path.join(dataset_dir, file_pattern % split_name)
    # None is allowed so that callers can simply ask for the default reader.
    record_reader = tf.TFRecordReader if reader is None else reader

    fixed = tf.FixedLenFeature
    variable = tf.VarLenFeature
    handlers = slim.tfexample_decoder

    # How each key stored in a serialized Example is parsed.
    keys_to_features = {
        'image/encoded': fixed((), tf.string, default_value=''),
        'image/format': fixed((), tf.string, default_value='jpeg'),
        'image/height': fixed([1], tf.int64),
        'image/width': fixed([1], tf.int64),
        'image/channels': fixed([1], tf.int64),
        'image/shape': fixed([3], tf.int64),
        'image/object/bbox/xmin': variable(dtype=tf.float32),
        'image/object/bbox/ymin': variable(dtype=tf.float32),
        'image/object/bbox/xmax': variable(dtype=tf.float32),
        'image/object/bbox/ymax': variable(dtype=tf.float32),
        'image/object/bbox/label': variable(dtype=tf.int64),
        'image/object/bbox/difficult': variable(dtype=tf.int64),
        'image/object/bbox/truncated': variable(dtype=tf.int64),
    }
    # How the parsed features are turned into the items the dataset exposes.
    items_to_handlers = {
        'image': handlers.Image('image/encoded', 'image/format'),
        'shape': handlers.Tensor('image/shape'),
        'object/bbox': handlers.BoundingBox(
            ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
        'object/label': handlers.Tensor('image/object/bbox/label'),
        'object/difficult': handlers.Tensor('image/object/bbox/difficult'),
        'object/truncated': handlers.Tensor('image/object/bbox/truncated'),
    }
    decoder = handlers.TFExampleDecoder(keys_to_features, items_to_handlers)

    # The label map is optional: it is only loaded when a labels file is
    # present in the dataset directory.
    labels_to_names = (dataset_utils.read_label_file(dataset_dir)
                       if dataset_utils.has_labels(dataset_dir) else None)

    return slim.dataset.Dataset(
        data_sources=data_sources,
        reader=record_reader,
        decoder=decoder,
        num_samples=split_to_sizes[split_name],
        items_to_descriptions=items_to_descriptions,
        num_classes=num_classes,
        labels_to_names=labels_to_names)



dataset_utils.py

"""Contains utilities for downloading and converting datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import tarfile # 管理tar压缩包

from six.moves import urllib # six是用来兼容python 2 和 3的,我猜名字就是用的2和3的最小公倍数。
                             # six.moves 是用来处理那些在2 和 3里面函数的位置有变化的,直接用six.moves就可以屏蔽掉这些变化
import tensorflow as tf

LABELS_FILENAME = 'labels.txt'


def int64_feature(value):
    """Wrap an integer, or a sequence of integers, in an int64 Feature proto.

    Args:
      value: A single value, or a list/tuple of values.

    Returns:
      A `tf.train.Feature` whose `int64_list` holds the value(s).
    """
    # Tuples are ready-made sequences as well; wrapping one in a list would
    # hand Int64List a nested sequence instead of plain integers.
    if not isinstance(value, (list, tuple)):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def float_feature(value):
    """Wrap a float, or a sequence of floats, in a float Feature proto.

    Args:
      value: A single value, or a list/tuple of values.

    Returns:
      A `tf.train.Feature` whose `float_list` holds the value(s).
    """
    # Tuples are ready-made sequences as well; wrapping one in a list would
    # hand FloatList a nested sequence instead of plain floats.
    if not isinstance(value, (list, tuple)):
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def bytes_feature(value):
    """Wrap a bytes object, or a sequence of them, in a bytes Feature proto.

    Args:
      value: A single bytes value, or a list/tuple of bytes values.

    Returns:
      A `tf.train.Feature` whose `bytes_list` holds the value(s).
    """
    # Tuples are ready-made sequences as well; wrapping one in a list would
    # hand BytesList a nested sequence instead of plain bytes objects.
    if not isinstance(value, (list, tuple)):
        value = [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))


def image_to_tfexample(image_data, image_format, height, width, class_id):
    """Pack one encoded image and its class id into an Example proto.

    Args:
      image_data: Encoded image, stored under 'image/encoded'.
      image_format: Image format, stored under 'image/format'.
      height: Image height, stored under 'image/height'.
      width: Image width, stored under 'image/width'.
      class_id: Class id, stored under 'image/class/label'.

    Returns:
      A `tf.train.Example`.
    """
    feature_map = {
        'image/encoded': bytes_feature(image_data),
        'image/format': bytes_feature(image_format),
        'image/class/label': int64_feature(class_id),
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
    }
    features = tf.train.Features(feature=feature_map)
    return tf.train.Example(features=features)


def download_and_uncompress_tarball(tarball_url, dataset_dir):
    """Download a gzip-compressed tarball and extract it into `dataset_dir`.

    The archive is saved in `dataset_dir` under the last path component of
    the URL and is extracted into the same directory.

    Args:
      tarball_url: The URL of a tarball file.
      dataset_dir: Directory that receives the downloaded archive and its
        extracted contents.
    """
    filename = tarball_url.split('/')[-1]
    filepath = os.path.join(dataset_dir, filename)

    def _progress(count, block_size, total_size):
        # The total size may be reported as unknown (<= 0); a percentage
        # cannot be computed then, so show the transferred byte count.
        if total_size > 0:
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (
                filename,
                float(count * block_size) / float(total_size) * 100.0))
        else:
            sys.stdout.write('\r>> Downloading %s %d bytes' % (
                filename, count * block_size))
        sys.stdout.flush()

    filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    # NOTE(review): extractall() trusts the member paths stored in the
    # archive; only use this with tarballs from a trusted source.
    with tarfile.open(filepath, 'r:gz') as archive:
        archive.extractall(dataset_dir)


def write_label_file(labels_to_class_names, dataset_dir,
                     filename=LABELS_FILENAME):
    """Write the label map to a text file, one 'label:class_name' per line.

    Args:
      labels_to_class_names: A map of (integer) labels to class names.
      dataset_dir: The directory in which the labels file is written.
      filename: Name of the labels file inside `dataset_dir`.
    """
    target = os.path.join(dataset_dir, filename)
    with tf.gfile.Open(target, 'w') as out:
        for label_id in labels_to_class_names:
            out.write('%d:%s\n' % (label_id, labels_to_class_names[label_id]))


def has_labels(dataset_dir, filename=LABELS_FILENAME):
    """Report whether the labels file exists in the dataset directory.

    Args:
      dataset_dir: The directory in which the labels file is looked up.
      filename: Name of the labels file inside `dataset_dir`.

    Returns:
      The result of `tf.gfile.Exists` for the joined path.
    """
    label_path = os.path.join(dataset_dir, filename)
    return tf.gfile.Exists(label_path)


def read_label_file(dataset_dir, filename=LABELS_FILENAME):
    """Load the labels file into a dict keyed by integer label.

    Each non-empty line is split at its first ':'; the part before it is
    converted with `int` and used as the key, the rest is the class name.

    Args:
      dataset_dir: The directory in which the labels file is found.
      filename: Name of the labels file inside `dataset_dir`.

    Returns:
      A dict from integer label to class name. The file is read in binary
      mode, so the class names are byte strings.

    Raises:
      ValueError: if a non-empty line contains no ':'.
    """
    label_path = os.path.join(dataset_dir, filename)
    with tf.gfile.Open(label_path, 'rb') as label_file:
        raw = label_file.read()

    id_to_name = {}
    for entry in raw.split(b'\n'):
        if not entry:
            # Blank lines (e.g. after the final newline) carry no mapping.
            continue
        colon = entry.index(b':')
        id_to_name[int(entry[:colon])] = entry[colon+1:]
    return id_to_name



transfer_to_tfrecords.py

"""Converts data to TFRecords file format with Example protos.

The raw Pascal VOC data set is expected to reside in JPEG files located in the
directory 'JPEGImages'. Similarly, bounding box annotations are expected to be
stored as XML files in the 'Annotations' directory.

This TensorFlow script converts the data set into a sharded set of TFRecord
files named '<name>_<index>.tfrecord'.

Each TFRecord file contains at most SAMPLES_PER_FILES (100) records.
Each record within the TFRecord file is a
serialized Example proto. The Example proto contains the following fields:

    image/encoded: string containing JPEG encoded image in RGB colorspace
    image/height: integer, image height in pixels
    image/width: integer, image width in pixels
    image/channels: integer, specifying the number of channels, always 3
    image/format: string, specifying the format, always 'jpeg'


    image/object/bbox/xmin: list of float specifying the 0+ human annotated
        bounding boxes
    image/object/bbox/xmax: list of float specifying the 0+ human annotated
        bounding boxes
    image/object/bbox/ymin: list of float specifying the 0+ human annotated
        bounding boxes
    image/object/bbox/ymax: list of float specifying the 0+ human annotated
        bounding boxes
    image/object/bbox/label: list of integer specifying the classification index.
    image/object/bbox/label_text: list of string descriptions.

"""

# 1. 按照VOC格式存储图片数据和标注数据;
# 2. 在run函数中指定输入、输出路径和输出tfrecord数据库的名字;


import os
import sys
import random

import numpy as np
import tensorflow as tf

import xml.etree.ElementTree as ET

from PIL import Image  #注意Image,后面会用到

from dataset_utils import int64_feature, float_feature, bytes_feature
from dataset_common import VOC_LABELS

# Original dataset organisation.
# The sub-directory names end with the platform's path separator (instead of
# a hard-coded backslash) so that they also form valid paths outside Windows.
DIRECTORY_ANNOTATIONS = 'Annotations' + os.sep
DIRECTORY_IMAGES = 'JPEGImages' + os.sep

# TFRecords conversion parameters.
RANDOM_SEED = 4242  # seed used when the file list is shuffled
SAMPLES_PER_FILES = 100  # maximum number of samples per TFRecord file


def _read_flag(obj, tag):
    """Return the integer text of child `tag` of `obj`, or 0 if it is absent."""
    node = obj.find(tag)
    # Compare with None explicitly: an Element without children is falsy, so
    # a plain truth test would treat an existing tag as missing.
    return int(node.text) if node is not None else 0


def _process_image(directory, name):
    """Load one image together with its size and annotation data.

    Args:
      directory: Dataset root; the image is read from the DIRECTORY_IMAGES
        sub-directory and the annotation from DIRECTORY_ANNOTATIONS.
      name: Sample name without file extension.

    Returns:
      A tuple (image_data, shape, bboxes, labels, labels_text, difficult,
      truncated):
        image_data: raw content of the '.jpg' file.
        shape: [height, width, depth] read from the <size> element.
        bboxes: list of (ymin, xmin, ymax, xmax) tuples, each coordinate
          divided by the image height or width respectively.
        labels: list of integer class ids looked up in VOC_LABELS.
        labels_text: list of ASCII-encoded class names.
        difficult: list of per-object flags, 0 when the tag is absent.
        truncated: list of per-object flags, 0 when the tag is absent.

    Raises:
      KeyError: if an object name is not a key of VOC_LABELS.
    """
    # Full image path, including the extension.
    filename = os.path.join(directory, DIRECTORY_IMAGES, name + '.jpg')
    with tf.gfile.GFile(filename, 'rb') as image_file:
        image_data = image_file.read()

    # Read the XML annotation file.
    xmlfilename = os.path.join(directory, DIRECTORY_ANNOTATIONS, name + '.xml')
    print(xmlfilename)
    root = ET.parse(xmlfilename).getroot()

    # Image shape: shape[0] is height, shape[1] is width, shape[2] is depth.
    size = root.find('size')
    shape = [int(size.find('height').text),
             int(size.find('width').text),
             int(size.find('depth').text)]

    # Collect the annotations of every object.
    bboxes = []
    labels = []
    labels_text = []
    difficult = []
    truncated = []
    for obj in root.findall('object'):
        label = obj.find('name').text
        labels.append(int(VOC_LABELS[label][0]))
        labels_text.append(label.encode('ascii'))

        difficult.append(_read_flag(obj, 'difficult'))
        truncated.append(_read_flag(obj, 'truncated'))

        bbox = obj.find('bndbox')
        # Pixel coordinates are scaled by the image height / width.
        bboxes.append((float(bbox.find('ymin').text) / shape[0],
                       float(bbox.find('xmin').text) / shape[1],
                       float(bbox.find('ymax').text) / shape[0],
                       float(bbox.find('xmax').text) / shape[1]))
    return image_data, shape, bboxes, labels, labels_text, difficult, truncated


def _convert_to_example(image_data, labels, labels_text, bboxes, shape,
                        difficult, truncated):
    """Convert the data of one image into an Example proto.

    Args:
      image_data: Encoded JPEG image data.
      labels: List of ground-truth class ids, one per object.
      labels_text: List of human-readable labels, one per object.
      bboxes: List of bounding boxes; each box holds four values in the
        order (ymin, xmin, ymax, xmax).
      shape: Three integers; indices 0, 1 and 2 are stored as image height,
        width and channels.
      difficult: List of per-object 'difficult' flags.
      truncated: List of per-object 'truncated' flags.

    Returns:
      A `tf.train.Example`.
    """
    ymin, xmin, ymax, xmax = [], [], [], []
    for box in bboxes:
        assert len(box) == 4
        top, left, bottom, right = box
        ymin.append(top)
        xmin.append(left)
        ymax.append(bottom)
        xmax.append(right)

    image_format = b'jpeg'
    feature_map = {
        'image/height': int64_feature(shape[0]),
        'image/width': int64_feature(shape[1]),
        'image/channels': int64_feature(shape[2]),
        'image/shape': int64_feature(shape),
        'image/object/bbox/xmin': float_feature(xmin),
        'image/object/bbox/xmax': float_feature(xmax),
        'image/object/bbox/ymin': float_feature(ymin),
        'image/object/bbox/ymax': float_feature(ymax),
        'image/object/bbox/label': int64_feature(labels),
        'image/object/bbox/label_text': bytes_feature(labels_text),
        'image/object/bbox/difficult': int64_feature(difficult),
        'image/object/bbox/truncated': int64_feature(truncated),
        'image/format': bytes_feature(image_format),
        'image/encoded': bytes_feature(image_data),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature_map))


def _add_to_tfrecord(dataset_dir, name, tfrecord_writer):
    """Convert one image with its annotation and append it to a TFRecord.

    Args:
      dataset_dir: Dataset directory.
      name: Image name.
      tfrecord_writer: TFRecord writer that receives the serialized Example.
    """
    # Everything that belongs to this one image.
    sample = _process_image(dataset_dir, name)
    image_data, shape, bboxes, labels, labels_text, difficult, truncated = sample

    record = _convert_to_example(image_data=image_data,
                                 labels=labels,
                                 labels_text=labels_text,
                                 bboxes=bboxes,
                                 shape=shape,
                                 difficult=difficult,
                                 truncated=truncated)
    tfrecord_writer.write(record.SerializeToString())


def _get_output_filename(output_dir, name, idx):
    return '%s%s_%03d.tfrecord' % (output_dir, name, idx)


def run(dataset_dir, output_dir, name='train', shuffling=False):
    """Convert every annotated image of the dataset into TFRecord files.

    One sample is written per XML file found in the annotation directory;
    each TFRecord file receives at most SAMPLES_PER_FILES samples.

    Args:
      dataset_dir: Dataset directory.
      output_dir: Directory that receives the TFRecord files.
      name: Base name of the TFRecord files.
      shuffling: If True, shuffle the sorted file list with RANDOM_SEED
        before converting.
    """
    # The directory that has to exist before writing is the output one;
    # creating a missing input directory could not provide any data.
    if output_dir and not tf.gfile.Exists(output_dir):
        tf.gfile.MakeDirs(output_dir)

    # Dataset filenames, and shuffling.
    path = os.path.join(dataset_dir, DIRECTORY_ANNOTATIONS)  # XML directory
    print("xml path: ", path)
    # Only annotation files define samples; skip any other directory entry.
    filenames = sorted(entry for entry in os.listdir(path)
                       if entry.lower().endswith('.xml'))
    if shuffling:
        random.seed(RANDOM_SEED)
        random.shuffle(filenames)

    # Process dataset files, SAMPLES_PER_FILES samples per TFRecord file.
    total = len(filenames)
    for fidx, start in enumerate(range(0, total, SAMPLES_PER_FILES)):
        tf_filename = _get_output_filename(output_dir, name, fidx)
        print("tf_filename: ", tf_filename)
        with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:
            chunk = filenames[start:start + SAMPLES_PER_FILES]
            for position, filename in enumerate(chunk, start + 1):
                sys.stdout.write(
                    '\r>> Converting image %d/%d \n' % (position, total))
                sys.stdout.flush()

                # Sample name is the file name without its extension.
                img_name = os.path.splitext(filename)[0]
                _add_to_tfrecord(dataset_dir, img_name, tfrecord_writer)

    # NOTE: no labels file is written by this script.
    print('\nFinished converting the Pascal VOC dataset!')



if __name__ == "__main__":
    # Runs only when executed as a script: the TFRecord files are written
    # into the same directory the dataset is read from.
    dataset_root = "H:\\11_DataSet\\QD\\"
    run(dataset_root, dataset_root)



read_from_tfrecords.py

# coding=utf-8
import tensorflow as tf
from PIL import Image

slim = tf.contrib.slim

# Path of the TFRecord file to read.
tfrecords_filename = 'H:\\11_DataSet\\QD\\train_000.tfrecord'


def read_record_file():
    """Read samples from the TFRecord file with TF-Slim and save them as JPEGs.

    Builds a slim `Dataset` over `tfrecords_filename`, fetches 10 decoded
    images through a `DatasetDataProvider`, prints their shape, content and
    size, and saves each one as './<i>.jpg'.
    """
    # How each key stored in a serialized Example is parsed.
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/height': tf.FixedLenFeature([], tf.int64),
        'image/width': tf.FixedLenFeature([], tf.int64),
        'image/channels': tf.FixedLenFeature([], tf.int64),
        'image/shape': tf.FixedLenFeature([3], tf.int64),
        'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64),
        'image/object/bbox/difficult': tf.VarLenFeature(dtype=tf.int64),
        'image/object/bbox/truncated': tf.VarLenFeature(dtype=tf.int64),
    }
    # How the parsed features are turned into the items the dataset exposes.
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image('image/encoded', 'image/format', channels=3),
        'height': slim.tfexample_decoder.Tensor('image/height'),
        'width': slim.tfexample_decoder.Tensor('image/width'),
        'shape': slim.tfexample_decoder.Tensor('image/shape'),
        'object/bbox': slim.tfexample_decoder.BoundingBox(
                ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
        'object/label': slim.tfexample_decoder.Tensor('image/object/bbox/label'),
        'object/difficult': slim.tfexample_decoder.Tensor('image/object/bbox/difficult'),
        'object/truncated': slim.tfexample_decoder.Tensor('image/object/bbox/truncated'),
    }

    # The decoder turns serialized Examples into the items declared above.
    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
    # The dataset object bundles file location, reader and decoder.
    dataset = slim.dataset.Dataset(
        data_sources=tfrecords_filename,
        reader=tf.TFRecordReader,
        decoder=decoder,
        num_samples=100,  # total number of samples in the data sources
        items_to_descriptions=None,
        num_classes=20,
    )
    # The provider reads the data described by the dataset.
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        num_readers=1,
        common_queue_capacity=20,
        common_queue_min=10)

    # Tensors that yield one sample per evaluation.
    image, h, w = provider.get(['image', 'height', 'width'])
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            for i in range(10):
                img, h_, w_ = sess.run([image, h, w])
                # `img` is already a NumPy array here; reshaping it directly
                # avoids adding a new op to the graph on every iteration.
                img = img.reshape(h_, w_, 3)
                print(img.shape)
                print(img)
                print("h = %d" % h_)
                print("w = %d" % w_)

                # Convert the array into a PIL image and save it.
                Image.fromarray(img, 'RGB').save('./'+str(i)+'.jpg')
        finally:
            # Stop the queue-runner threads even when an iteration fails.
            coord.request_stop()
            coord.join(threads)


if __name__ == '__main__':
    # Script entry point: read the TFRecord file and save the decoded images.
    read_record_file()



版权声明:本文为tecsai原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。