The four kinds of tf.data iterators in TensorFlow 1.x

  • Post category: Other



1. dataset.make_one_shot_iterator()

import tensorflow as tf
dataset = tf.data.Dataset.range(10)
# .repeat(-1) also works here; it repeats the data indefinitely
dataset = dataset.map(lambda x:x+2).repeat(3).batch(4)
iterator = dataset.make_one_shot_iterator()
elem = iterator.get_next()
with tf.Session() as sess:
  for _ in range(2):
    print(sess.run(elem))
  print('**************')
  for _ in range(3):
    print(sess.run(elem))
  print('**************')
  for _ in range(4):
    try:
      print(sess.run(elem))
    except tf.errors.OutOfRangeError:  # raised once the data is exhausted
      print('end')
      break

output:
[2 3 4 5]
[6 7 8 9]
**************
[10 11  2  3]
[4 5 6 7]
[ 8  9 10 11]
**************
[2 3 4 5]
[6 7 8 9]
[10 11]
end

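Things to note:

1. It is usually paired with try/except: once all the data in the Dataset has been read, sess.run raises tf.errors.OutOfRangeError, and catching that exception lets you end the iteration cleanly.

2. It does not support parameterization. The one-shot iterator captures its inputs by value when it is created, so the Dataset's size (and any other parameter) must be fixed when the graph is built; it cannot come from a placeholder.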
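The batch layout in the output above and the end-of-data behaviour from point 1 can be reproduced with a small pure-Python model of range(10).map(x+2).repeat(3).batch(4). This is only a sketch under the assumption that the pipeline behaves like plain list operations (no shuffling, final partial batch kept); simulate_pipeline is a hypothetical helper, not a tf.data API, and StopIteration stands in for tf.errors.OutOfRangeError.

```python
import math
from typing import Callable, Iterator, List


def simulate_pipeline(n: int, fn: Callable[[int], int], repeats: int,
                      batch_size: int) -> Iterator[List[int]]:
    """Hypothetical pure-Python model of
    Dataset.range(n).map(fn).repeat(repeats).batch(batch_size),
    assuming batch() keeps the final partial batch."""
    # range(n).map(fn).repeat(repeats): the mapped sequence, repeated end to end
    flat = [fn(x) for x in range(n)] * repeats
    # batch(batch_size): consecutive chunks; the last one may be shorter
    for start in range(0, len(flat), batch_size):
        yield flat[start:start + batch_size]


batches = simulate_pipeline(10, lambda x: x + 2, repeats=3, batch_size=4)
count = 0
while True:
    try:
        batch = next(batches)  # plays the role of sess.run(elem)
    except StopIteration:      # plays the role of tf.errors.OutOfRangeError
        print('end')
        break
    print(batch)
    count += 1

# 10 elements x 3 repeats = 30 elements -> ceil(30 / 4) = 8 batches
print(count, math.ceil(10 * 3 / 4))  # 8 8
```

The last batch holds only 2 elements because 30 is not a multiple of 4, which matches the [10 11] line before "end".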

The following program raises an error:

import tensorflow as tf
number = tf.placeholder(dtype=tf.int64)
dataset = tf.data.Dataset.range(number)
dataset = dataset.map(lambda x:x+2).repeat(3).batch(4)
iterator = dataset.make_one_shot_iterator()
elem = iterator.get_next()
with tf.Session() as sess:
  for _ in range(2):
    print(sess.run(elem, feed_dict={number:10}))
  print('**************')
  for _ in range(3):
    print(sess.run(elem))
  print('**************')
  for _ in range(4):
    try:
      print(sess.run(elem))
    except tf.errors.OutOfRangeError:  # raised once the data is exhausted
      print('end')
      break

Error:
ValueError: Cannot capture a placeholder (name:Placeholder, type:Placeholder) by value.


2. dataset.make_initializable_iterator()

import tensorflow as tf
number = tf.placeholder(dtype=tf.int64)
dataset = tf.data.Dataset.range(number)
dataset = dataset.map(lambda x:x+1).repeat(5).batch(4)
iterator = dataset.make_initializable_iterator()

elem = iterator.get_next()

with tf.Session() as sess:
  sess.run(iterator.initializer, feed_dict={number:10})
  for _ in range(4):
    print(sess.run(elem))
  print("******************")
  for _ in range(5):
    print(sess.run(elem))
  print("******************")

  for _ in range(6):
    try:
      print(sess.run(elem))
    except tf.errors.OutOfRangeError:  # raised once the data is exhausted
      print('end')
      break
  print("******************")

output:
[1 2 3 4]
[5 6 7 8]
[ 9 10  1  2]
[3 4 5 6]
******************
[ 7  8  9 10]
[1 2 3 4]
[5 6 7 8]
[ 9 10  1  2]
[3 4 5 6]
******************
[ 7  8  9 10]
[1 2 3 4]
[5 6 7 8]
[ 9 10]
end
******************

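Note: the initializable iterator (I feel "iterable iterator" would be a more fitting name) does not need the dataset size to be defined in advance; the size can be fed in as a parameter when iterator.initializer is run. Running iterator.initializer again restarts the iterator from the beginning.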
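A rough pure-Python analogy of this behaviour, assuming we model the graph as a function of the value that would be fed to the placeholder: nothing is produced until the iterator is initialized, and re-initializing starts over, possibly with a different size. InitializableIterator and pipeline are hypothetical names used only for illustration; they are not TensorFlow APIs.

```python
from typing import Callable, Iterator, List, Optional


class InitializableIterator:
    """Hypothetical model of make_initializable_iterator(): the pipeline is a
    function of a parameter (standing in for the placeholder) that is only
    supplied when the iterator is initialized."""

    def __init__(self, make_batches: Callable[[int], Iterator[List[int]]]) -> None:
        self._make_batches = make_batches
        self._it: Optional[Iterator[List[int]]] = None

    def initialize(self, number: int) -> None:
        # Like sess.run(iterator.initializer, feed_dict={number: ...}):
        # (re)builds the pipeline and starts from the first batch.
        self._it = self._make_batches(number)

    def get_next(self) -> List[int]:
        if self._it is None:
            raise RuntimeError('iterator has not been initialized')
        return next(self._it)  # StopIteration plays the role of OutOfRangeError


def pipeline(number: int) -> Iterator[List[int]]:
    # Models range(number).map(lambda x: x + 1).repeat(5).batch(4)
    flat = [x + 1 for x in range(number)] * 5
    for start in range(0, len(flat), 4):
        yield flat[start:start + 4]


it = InitializableIterator(pipeline)
it.initialize(10)
print(it.get_next())  # [1, 2, 3, 4]
it.initialize(3)      # re-initialize with a different size
print(it.get_next())  # [1, 2, 3, 1]
```

With number=10 the model yields 13 batches (ceil(50 / 4)), the same count as the output above.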


3. Reinitializable iterator (one iterator that can be attached to different datasets)

import tensorflow as tf

train_num = tf.placeholder(dtype=tf.int64)
test_num = tf.placeholder(dtype=tf.int64)

train_dataset = tf.data.Dataset.range(train_num)
train_dataset = train_dataset.map(lambda x:x+1).repeat(5).batch(8)

test_dataset = tf.data.Dataset.range(test_num)
test_dataset = test_dataset.map(lambda x: x+2).repeat(2).batch(1)

iterator = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)

train_op = iterator.make_initializer(train_dataset)
test_op = iterator.make_initializer(test_dataset)

elem = iterator.get_next()

with tf.Session() as sess:
  sess.run(train_op, feed_dict={train_num:10})
  for _ in range(3):
    print(sess.run(elem))
  print('*********')


  sess.run(test_op, feed_dict={test_num:5})
  for _ in range(4):
    print(sess.run(elem))
  print('*********')

  for _ in range(5):
    print(sess.run(elem))
  print('*********')

  sess.run(train_op, feed_dict={train_num:10})
  for _ in range(3):
    print(sess.run(elem))
  print('*********')

output:
[1 2 3 4 5 6 7 8]
[ 9 10  1  2  3  4  5  6]
[ 7  8  9 10  1  2  3  4]
*********
[2]
[3]
[4]
[5]
*********
[6]
[2]
[3]
[4]
[5]
*********
[1 2 3 4 5 6 7 8]
[ 9 10  1  2  3  4  5  6]
[ 7  8  9 10  1  2  3  4]
*********

Interface documentation for tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes):

"""Creates a new, uninitialized `Iterator` with the given structure.

This iterator-constructing method can be used to create an iterator that
is reusable with many different datasets.

The returned iterator is not bound to a particular dataset, and it has
no `initializer`. To initialize the iterator, run the operation returned by
`Iterator.make_initializer(dataset)`."""

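Note: when training a model, we usually train on the training set for a while, check the model's performance on the test set, and then continue training on the training set. With this kind of iterator, however, after running through the test set you must re-initialize the iterator with the training set, i.e. feed the training data again and start the training data over from the beginning (compare the last block of output with the first). The next iterator solves this problem.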

4. Feedable iterator (the name sounds a bit awkward, but it is what this iterator is commonly called)

import tensorflow as tf
train_num = tf.placeholder(tf.int64)
test_num = tf.placeholder(tf.int64)

train_dataset = tf.data.Dataset.range(train_num).map(lambda x:x+1).repeat(5).batch(8)
test_dataset = tf.data.Dataset.range(test_num).map(lambda x:x*2).repeat(5).batch(1)

handle = tf.placeholder(tf.string)

iterator = tf.data.Iterator.from_string_handle(handle, train_dataset.output_types, train_dataset.output_shapes)
elem = iterator.get_next()

train_iterator = train_dataset.make_initializable_iterator()
test_iterator = test_dataset.make_initializable_iterator()

with tf.Session() as sess:
  # each string handle identifies one concrete iterator; feeding it to `handle` selects the data source
  train_handle = sess.run(train_iterator.string_handle())
  test_handle = sess.run(test_iterator.string_handle())

  sess.run(train_iterator.initializer, feed_dict={train_num:10})
  for _ in range(3):
    print(sess.run(elem, feed_dict={handle:train_handle}))
  print('***********')

  sess.run(test_iterator.initializer, feed_dict={test_num:5})
  for _ in range(2):
    print(sess.run(elem, feed_dict={handle: test_handle}))
  print('***********')


  for _ in range(4):
    print(sess.run(elem, feed_dict={handle:train_handle}))
  print('***********')

  for _ in range(3):
    print(sess.run(elem, feed_dict={handle: test_handle}))




output:
[1 2 3 4 5 6 7 8]
[ 9 10  1  2  3  4  5  6]
[ 7  8  9 10  1  2  3  4]
***********
[0]
[2]
***********
[ 5  6  7  8  9 10  1  2]
[ 3  4  5  6  7  8  9 10]
[1 2 3 4 5 6 7 8]
[ 9 10]
***********
[4]
[6]
[8]
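The key property in this output is that the training iterator resumes at [ 5  6  7  8  9 10  1  2] after the detour through the test set. A hypothetical pure-Python model of that routing is sketched below: FeedableIterator, register and batched are illustrative names, not TensorFlow APIs, and the handle strings are made up; each registered iterator simply keeps its own position while get_next dispatches on the handle.

```python
from typing import Dict, Iterator, List


def batched(values: List[int], batch_size: int) -> Iterator[List[int]]:
    """Yield consecutive chunks of batch_size; the last chunk may be shorter."""
    for start in range(0, len(values), batch_size):
        yield values[start:start + batch_size]


class FeedableIterator:
    """Hypothetical model of Iterator.from_string_handle(): each call to
    get_next() is routed to the registered iterator named by the fed handle,
    and every registered iterator keeps its own position."""

    def __init__(self) -> None:
        self._iterators: Dict[str, Iterator[List[int]]] = {}

    def register(self, it: Iterator[List[int]]) -> str:
        # Stands in for sess.run(some_iterator.string_handle()).
        handle = f'iterator_{len(self._iterators)}'
        self._iterators[handle] = it
        return handle

    def get_next(self, handle: str) -> List[int]:
        # Stands in for sess.run(elem, feed_dict={handle: ...}).
        return next(self._iterators[handle])


feedable = FeedableIterator()
# Models of the two datasets above, with train_num=10 and test_num=5
train_handle = feedable.register(batched([x + 1 for x in range(10)] * 5, 8))
test_handle = feedable.register(batched([x * 2 for x in range(5)] * 5, 1))

first = [feedable.get_next(train_handle) for _ in range(3)]
evaluated = [feedable.get_next(test_handle) for _ in range(2)]
resumed = feedable.get_next(train_handle)
print(evaluated)  # [[0], [2]]
print(resumed)    # [5, 6, 7, 8, 9, 10, 1, 2]: training resumes where it left off
```

Unlike the reinitializable model, switching sources here never rebuilds an iterator, which is why no training data has to be re-fed.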



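Copyright notice: this is an original article by biubiubiu888, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.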