Concepts involved:
- Example
- Tensor
- SequenceExample
- Feature
Writing methods covered:
- Python
- Spark Scala
- Spark DataFrame
Data types written:
- int64
- float32
- string
Feature types written:
- VarLenFeature
- SparseFeature
- FixedLenFeature
The schema below declares how each of these features is parsed back out of the TFRecord; it is referenced throughout the rest of the post.
feature_schema = {
    # featureA: length-1 string feature
    "featureA": tf.io.FixedLenFeature(shape=(1,), dtype=tf.string, default_value="null"),
    # featureB: length-1 numeric feature
    "featureB": tf.io.FixedLenFeature(shape=(1,), dtype=tf.float32, default_value=0.0),
    # featureC: length-3 string feature
    "featureC": tf.io.FixedLenFeature(shape=(3,), dtype=tf.string, default_value=["null", "null", "null"]),
    # featureD: length-2 numeric feature
    "featureD": tf.io.FixedLenFeature(shape=(2,), dtype=tf.int64, default_value=[0, 0]),
    # featureE: variable-length string feature
    "featureE": tf.io.VarLenFeature(dtype=tf.string),
    # featureF: variable-length numeric feature
    "featureF": tf.io.VarLenFeature(dtype=tf.float32),
    # featureEweight: variable-length numeric feature (the weights paired with featureE)
    "featureEweight": tf.io.VarLenFeature(dtype=tf.float32),
    # featureG: sequence feature; each step is a length-2 string vector
    "featureG": tf.io.FixedLenSequenceFeature(shape=(2,), dtype=tf.string, allow_missing=True, default_value=None),
    # featureH: sequence feature; each step is a length-3 int64 vector
    "featureH": tf.io.FixedLenSequenceFeature(shape=(3,), dtype=tf.int64, allow_missing=True, default_value=None),
    # featureI: 21 x 4 x 10 sparse string feature
    "featureI": tf.io.SparseFeature(index_key=["featureI_Index0", "featureI_Index1", "featureI_Index2"],
                                    value_key="featureI_value", dtype=tf.string, size=[21, 4, 10], already_sorted=False)
}
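As a quick sanity check of the schema, here is a minimal parsing sketch, assuming TensorFlow 2.x eager mode; it reads the "./tfrecord" file produced in part 1 below, and drops featureG/featureH, which only occur in SequenceExample records:

import tensorflow as tf

# featureG/featureH are sequence features and never appear in plain Examples.
example_schema = {k: v for k, v in feature_schema.items()
                  if k not in ("featureG", "featureH")}

for parsed in tf.data.TFRecordDataset("./tfrecord").map(
        lambda rec: tf.io.parse_single_example(rec, example_schema)):
    # VarLenFeature and SparseFeature entries come back as tf.sparse.SparseTensor;
    # featureE and featureEweight share indices, forming (id, weight) pairs.
    print(parsed["featureA"].numpy(),
          tf.sparse.to_dense(parsed["featureE"], default_value=b"").numpy())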
1. Writing TFRecord in Python
# TensorFlow 2.x
import tensorflow as tf

writer = tf.io.TFRecordWriter("./tfrecord")
example_1 = tf.train.Example(features=tf.train.Features(feature={
    # Value length must be 1
    "featureA": tf.train.Feature(bytes_list=tf.train.BytesList(value=[u"valueA1".encode("utf-8")])),
    # Value length must be 1
    "featureB": tf.train.Feature(float_list=tf.train.FloatList(value=[2.3])),
    # Value length must be 3
    "featureC": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"valueC3C3C3", b"valueC2", b"valueC8"])),
    # Value length must be 2
    "featureD": tf.train.Feature(int64_list=tf.train.Int64List(value=[3, 4])),
    # Variable length
    "featureE": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"valueE3", b"valueE2", b"valueE4", b"valueE9"])),
    # Variable length (used as the weights paired with featureE)
    "featureEweight": tf.train.Feature(float_list=tf.train.FloatList(value=[3.0, 2.0, 4.0, 9.0])),
    # Variable length
    "featureF": tf.train.Feature(float_list=tf.train.FloatList(value=[4.5, 1.2, 2.1]))
}))
example_2 = tf.train.Example(features=tf.train.Features(feature={
    "featureA": tf.train.Feature(bytes_list=tf.train.BytesList(value=[u"valueA1".encode("utf-8")])),
    "featureB": tf.train.Feature(float_list=tf.train.FloatList(value=[2.3])),
    "featureC": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"valueC3C3C3", b"valueC2", b"valueC8"])),
    "featureD": tf.train.Feature(int64_list=tf.train.Int64List(value=[3, 4])),
    "featureE": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"valueE3", b"valueE5"])),
    "featureEweight": tf.train.Feature(float_list=tf.train.FloatList(value=[3.0, 5.0])),
    "featureF": tf.train.Feature(float_list=tf.train.FloatList(value=[5.5]))
}))
example_3 = tf.train.Example(features=tf.train.Features(feature={
    "featureA": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"value5"])),
    "featureB": tf.train.Feature(float_list=tf.train.FloatList(value=[2.3])),
    "featureC": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"valueC3C3C3", b"valueC2", b"valueC8"])),
    "featureD": tf.train.Feature(int64_list=tf.train.Int64List(value=[3, 4])),
    "featureE": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"valueE1", b"valueE2", b"valueE2"])),
    "featureEweight": tf.train.Feature(float_list=tf.train.FloatList(value=[1.0, 2.0, 2.0])),
    "featureF": tf.train.Feature(float_list=tf.train.FloatList(value=[1.5, 2.2]))
}))
example_4 = tf.train.Example(features=tf.train.Features(feature={
    "featureA": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"valueA4"])),
    "featureB": tf.train.Feature(float_list=tf.train.FloatList(value=[2.3])),
    "featureC": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"valueC3C3C3", b"valueC2", b"valueC8"])),
    "featureD": tf.train.Feature(int64_list=tf.train.Int64List(value=[3, 4])),
    "featureE": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"valueE3", b"valueE3", b"valueE1", b"valueE4"])),
    "featureEweight": tf.train.Feature(float_list=tf.train.FloatList(value=[3.0, 3.0, 1.0, 4.0])),
    "featureF": tf.train.Feature(float_list=tf.train.FloatList(value=[7.5, 1.4, 2.3]))
}))
sequence_example = tf.train.SequenceExample(
    context=tf.train.Features(feature={
        "featureA": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"valueA2"])),
        "featureB": tf.train.Feature(float_list=tf.train.FloatList(value=[4.1])),
        "featureC": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"valueC1", b"valueC2", b"valueC3"])),
        "featureE": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"valueE6", b"valueE1"])),
        "featureF": tf.train.Feature(float_list=tf.train.FloatList(value=[9.4, 6.6, 8.3, 7.2, 9.1])),
        "featureI_Index0": tf.train.Feature(int64_list=tf.train.Int64List(value=[5, 10, 2, 2, 6])),
        "featureI_Index1": tf.train.Feature(int64_list=tf.train.Int64List(value=[4, 2, 7, 6, 3])),
        "featureI_Index2": tf.train.Feature(int64_list=tf.train.Int64List(value=[3, 1, 8, 4, 2])),
        "featureI_value": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"valueI3", b"valueI5", b"valueI2", b"valueI7", b"valueI5"]))
    }),
    feature_lists=tf.train.FeatureLists(feature_list={
        "featureG": tf.train.FeatureList(feature=[
            tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"valueG2", b"valueG1"])),
            tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"valueG6", u"valueG6".encode("utf-8")]))
        ]),
        "featureH": tf.train.FeatureList(feature=[
            tf.train.Feature(int64_list=tf.train.Int64List(value=[1, 1, 3])),
            tf.train.Feature(int64_list=tf.train.Int64List(value=[4, 2, 9])),
            tf.train.Feature(int64_list=tf.train.Int64List(value=[8, 4, 2]))
        ])
    })
)
writer.write(example_1.SerializeToString())
writer.write(example_2.SerializeToString())
writer.write(example_3.SerializeToString())
writer.write(example_4.SerializeToString())
# sequence_example is deliberately not written into the same file: SequenceExample
# records are parsed with tf.io.parse_single_sequence_example and should not be
# mixed with plain Examples in one file.
# writer.write(sequence_example.SerializeToString())
writer.close()
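Because the SequenceExample is kept out of "./tfrecord", here is a minimal sketch, with an illustrative file name "./tfrecord_seq" and deliberately reduced schemas, of writing it to its own file and parsing it back:

# Write the SequenceExample to its own file (the name is an arbitrary choice).
with tf.io.TFRecordWriter("./tfrecord_seq") as seq_writer:
    seq_writer.write(sequence_example.SerializeToString())

# Context features and per-step sequence features use separate schemas.
context_schema = {
    "featureA": tf.io.FixedLenFeature(shape=(1,), dtype=tf.string, default_value="null"),
    "featureE": tf.io.VarLenFeature(dtype=tf.string),
}
sequence_schema = {
    "featureG": tf.io.FixedLenSequenceFeature(shape=(2,), dtype=tf.string),
    "featureH": tf.io.FixedLenSequenceFeature(shape=(3,), dtype=tf.int64),
}
for record in tf.data.TFRecordDataset("./tfrecord_seq"):
    context, sequences = tf.io.parse_single_sequence_example(
        record,
        context_features=context_schema,
        sequence_features=sequence_schema)
    print(context["featureA"], sequences["featureH"])  # featureH parses to shape (3, 3)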
2. Writing TFRecord with Spark (Scala)
The snippets below build the tf.train protos with the Java protobuf API for org.tensorflow.example (equally usable from Scala); only the final save in step 6 is Spark-specific.
// 1. Serialize raw feature values into a Feature
// int64 features: values has type Iterable<Long>
Int64List int64List = Int64List.newBuilder().addAllValue(values).build();
return Feature.newBuilder().setInt64List(int64List).build();
// float (or double) features: values has type Iterable<Float>
FloatList floatList = FloatList.newBuilder().addAllValue(values).build();
return Feature.newBuilder().setFloatList(floatList).build();
// string features: convert the value to a ByteString first, then wrap it in a Feature
// value has type String
List<ByteString> bytesStringList = new ArrayList<>();
bytesStringList.add(ByteString.copyFromUtf8(value));
BytesList bytesList = BytesList.newBuilder().addAllValue(bytesStringList).build();
return Feature.newBuilder().setBytesList(bytesList).build();
// 2. Serialize several Features into a FeatureList (skip this step if you do not need sequence features)
// features has type Iterable<Feature>, built in step 1
return FeatureList.newBuilder().addAllFeature(features).build();
// 3. Create maps (HashMap) that give each Feature / FeatureList its name
private Map<String, Feature> features;
private Map<String, FeatureList> featureLists;
// 4. Serialize the named maps from step 3 into Features and FeatureLists
return Features.newBuilder().putAllFeature(features).build();
return FeatureLists.newBuilder().putAllFeatureList(featureLists).build();
// 5.1 For non-sequence data: serialize the Features from step 4 into an Example
Example.Builder exampleBuilder = Example.newBuilder();
return exampleBuilder.setFeatures(features).build();
// or:
return exampleBuilder.mergeFeatures(features).build();
// 5.2 For sequence data: serialize the Features and FeatureLists from step 4 into a SequenceExample
SequenceExample.Builder sequenceExampleBuilder = SequenceExample.newBuilder();
sequenceExampleBuilder.setContext(features);
return sequenceExampleBuilder.setFeatureLists(featureLists).build();
// or:
sequenceExampleBuilder.mergeContext(features);
return sequenceExampleBuilder.mergeFeatureLists(featureLists).build();
// 6. Write the serialized Examples out as TFRecord from a Spark RDD
rdd
  .map(example => (new BytesWritable(example.toByteArray), NullWritable.get()))
  .saveAsNewAPIHadoopFile[TFRecordFileOutputFormat]("path")
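Whichever Spark writer produced the output, the records can be spot-checked from TensorFlow. A minimal sketch, assuming Spark's usual part-file naming under the "path" directory used above:

import tensorflow as tf

# Spark writes a directory of part files; glob them (naming is a Spark convention).
files = tf.io.gfile.glob("path/part-*")
for record in tf.data.TFRecordDataset(files):
    example = tf.train.Example.FromString(record.numpy())  # decode the raw proto
    print(example)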
3. Writing TFRecord with Spark DataFrame
// Define the schema (column names and types)
val schema = StructType(List(
  StructField("IntegerCol", IntegerType),
  StructField("LongCol", LongType),
  StructField("FloatCol", FloatType),
  StructField("DoubleCol", DoubleType),
  StructField("VectorCol", ArrayType(DoubleType, true)),
  StructField("StringCol", StringType),
  StructField("VectorStrCol", ArrayType(StringType, true))
))
// Build a DataFrame with the schema above and save it as TFRecord
// (requires spark-tensorflow-connector on the classpath)
spark.createDataFrame(rowRDD, schema)  // rowRDD: RDD[Row] matching the schema
  .coalesce(1)
  .write
  .format("tfrecords")
  .option("recordType", "Example")
  .save("path")