1. dataset.make_one_shot_iterator()
# Demo: dataset.make_one_shot_iterator() (TF 1.x API).
# A one-shot iterator needs no explicit initialization, but the dataset must
# be fully defined before the graph runs (no placeholder parameters).
import tensorflow as tf

dataset = tf.data.Dataset.range(10)
# .repeat(3) could also be .repeat(-1) for infinite repetition
dataset = dataset.map(lambda x: x + 2).repeat(3).batch(4)
iterator = dataset.make_one_shot_iterator()
elem = iterator.get_next()
with tf.Session() as sess:
    for _ in range(2):
        print(sess.run(elem))
    print('**************')
    for _ in range(3):
        print(sess.run(elem))
    print('**************')
    for _ in range(4):
        try:
            print(sess.run(elem))
        # Catch the specific end-of-sequence error, not a blanket Exception:
        # tf.data raises OutOfRangeError when the dataset is exhausted.
        except tf.errors.OutOfRangeError:
            print('end')
            break
output:
[2 3 4 5]
[6 7 8 9]
**************
[10 11 2 3]
[4 5 6 7]
[ 8 9 10 11]
**************
[2 3 4 5]
[6 7 8 9]
[10 11]
end
需要注意
1.通常用 try-except 配合使用,当 Dataset 中的数据被读取完毕的时候,程序会抛出 OutOfRangeError 异常,捕获这个异常就可以从容结束本次数据的迭代。
2. 它不支持参数化。它需要 Dataset 在程序运行之前就确认自己的大小。
下面的程序报错:
# Counter-example: a one-shot iterator does NOT support parameterization.
# Dataset.range(number) would capture the placeholder by value, so
# make_one_shot_iterator() raises:
#   ValueError: Cannot capture a placeholder (name:Placeholder,
#   type:Placeholder) by value.
import tensorflow as tf

number = tf.placeholder(dtype=tf.int64)
dataset = tf.data.Dataset.range(number)
dataset = dataset.map(lambda x: x + 2).repeat(3).batch(4)
iterator = dataset.make_one_shot_iterator()  # <-- ValueError raised here
elem = iterator.get_next()
with tf.Session() as sess:
    # None of the code below is reached; kept to show the intended usage.
    for _ in range(2):
        print(sess.run(elem, feed_dict={number: 10}))
    print('**************')
    for _ in range(3):
        print(sess.run(elem))
    print('**************')
    for _ in range(4):
        try:
            print(sess.run(elem))
        except tf.errors.OutOfRangeError:
            print('end')
            break
报错:
ValueError: Cannot capture a placeholder (name:Placeholder, type:Placeholder) by value.
2. dataset.make_initializable_iterator()
# Demo: dataset.make_initializable_iterator() (TF 1.x API).
# Unlike a one-shot iterator, an initializable iterator supports
# parameterization: the dataset size is fed through a placeholder when the
# initializer op runs.
import tensorflow as tf

number = tf.placeholder(dtype=tf.int64)
dataset = tf.data.Dataset.range(number)
dataset = dataset.map(lambda x: x + 1).repeat(5).batch(4)
iterator = dataset.make_initializable_iterator()
elem = iterator.get_next()
with tf.Session() as sess:
    # The dataset size is supplied here, at initialization time.
    sess.run(iterator.initializer, feed_dict={number: 10})
    for _ in range(4):
        print(sess.run(elem))
    print("******************")
    for _ in range(5):
        print(sess.run(elem))
    print("******************")
    for _ in range(6):
        try:
            print(sess.run(elem))
        # OutOfRangeError signals that all 5 repeats have been consumed.
        except tf.errors.OutOfRangeError:
            print('end')
            break
    print("******************")
output:
[1 2 3 4]
[5 6 7 8]
[ 9 10 1 2]
[3 4 5 6]
******************
[ 7 8 9 10]
[1 2 3 4]
[5 6 7 8]
[ 9 10 1 2]
[3 4 5 6]
******************
[ 7 8 9 10]
[1 2 3 4]
[5 6 7 8]
[ 9 10]
end
******************
注:可初始化的迭代器(感觉叫可迭代的迭代器更靠谱), 可以不用预先定义dataset的大小,作为参数feed进去。
3. 可接不同数据集的迭代器
# Demo: reinitializable iterator via tf.data.Iterator.from_structure().
# A single iterator can be bound to several datasets that share the same
# structure (output_types/output_shapes). make_initializer(dataset) returns
# the op that (re)binds the iterator to that dataset and resets its position.
import tensorflow as tf

train_num = tf.placeholder(dtype=tf.int64)
test_num = tf.placeholder(dtype=tf.int64)
train_dataset = tf.data.Dataset.range(train_num)
train_dataset = train_dataset.map(lambda x: x + 1).repeat(5).batch(8)
test_dataset = tf.data.Dataset.range(test_num)
test_dataset = test_dataset.map(lambda x: x + 2).repeat(2).batch(1)
# Defined by structure only; not yet bound to any dataset and has no
# `initializer` of its own.
iterator = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)
train_op = iterator.make_initializer(train_dataset)
test_op = iterator.make_initializer(test_dataset)
elem = iterator.get_next()
with tf.Session() as sess:
    sess.run(train_op, feed_dict={train_num: 10})
    for _ in range(3):
        print(sess.run(elem))
    print('*********')
    sess.run(test_op, feed_dict={test_num: 5})
    for _ in range(4):
        print(sess.run(elem))
    print('*********')
    for _ in range(5):
        print(sess.run(elem))
    print('*********')
    # Re-initializing for training restarts the train dataset from scratch —
    # the drawback this iterator kind has (see the feedable iterator below).
    sess.run(train_op, feed_dict={train_num: 10})
    for _ in range(3):
        print(sess.run(elem))
    print('*********')
output:
[1 2 3 4 5 6 7 8]
[ 9 10 1 2 3 4 5 6]
[ 7 8 9 10 1 2 3 4]
*********
[2]
[3]
[4]
[5]
*********
[6]
[2]
[3]
[4]
[5]
*********
[1 2 3 4 5 6 7 8]
[ 9 10 1 2 3 4 5 6]
[ 7 8 9 10 1 2 3 4]
*********
tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)
的接口说明:
"""Creates a new, uninitialized `Iterator` with the given structure.
This iterator-constructing method can be used to create an iterator that
is reusable with many different datasets.
The returned iterator is not bound to a particular dataset, and it has
no `initializer`. To initialize the iterator, run the operation returned by
`Iterator.make_initializer(dataset)`."""
注:一般我们在训练模型时,都是在训练集上跑一段时间,然后在测试集上看看模型的性能,紧接着在训练集上继续训练。但是这种迭代器跑完测试集后,必须重新初始化训练集迭代器,也即重新喂训练数据、从头开始跑模型。下一个迭代器解决此问题。
4. 可馈赠的迭代器(听起来挺别扭的,但是其他地方都这么叫)
# Demo: feedable iterator via tf.data.Iterator.from_string_handle().
# Each dataset keeps its own concrete iterator (and therefore its own
# position). The string handle fed at run time selects which iterator
# get_next() reads from, so switching datasets does NOT reset progress.
import tensorflow as tf

train_num = tf.placeholder(tf.int64)
test_num = tf.placeholder(tf.int64)
train_dataset = tf.data.Dataset.range(train_num).map(lambda x: x + 1).repeat(5).batch(8)
test_dataset = tf.data.Dataset.range(test_num).map(lambda x: x * 2).repeat(5).batch(1)
handle = tf.placeholder(tf.string)
iterator = tf.data.Iterator.from_string_handle(handle, train_dataset.output_types, train_dataset.output_shapes)
elem = iterator.get_next()
train_op = train_dataset.make_initializable_iterator()
test_op = test_dataset.make_initializable_iterator()
with tf.Session() as sess:
    # Evaluate each concrete iterator's string handle once, then feed the
    # desired handle on every sess.run(elem) call.
    train_handle = sess.run(train_op.string_handle())
    test_handle = sess.run(test_op.string_handle())
    sess.run(train_op.initializer, feed_dict={train_num: 10})
    for _ in range(3):
        print(sess.run(elem, feed_dict={handle: train_handle}))
    print('***********')
    sess.run(test_op.initializer, feed_dict={test_num: 5})
    for _ in range(2):
        print(sess.run(elem, feed_dict={handle: test_handle}))
    print('***********')
    # Training resumes exactly where it left off — no re-initialization.
    for _ in range(4):
        print(sess.run(elem, feed_dict={handle: train_handle}))
    print('***********')
    # Likewise the test iterator continues from its previous position.
    for _ in range(3):
        print(sess.run(elem, feed_dict={handle: test_handle}))
output:
[1 2 3 4 5 6 7 8]
[ 9 10 1 2 3 4 5 6]
[ 7 8 9 10 1 2 3 4]
***********
[0]
[2]
***********
[ 5 6 7 8 9 10 1 2]
[ 3 4 5 6 7 8 9 10]
[1 2 3 4 5 6 7 8]
[ 9 10]
***********
[4]
[6]
[8]