TensorFlow系列——写tfrecord数据

  • Post author:
  • Post category:其他


涉及的概念:

  • Example
  • Tensor
  • SequenceExample
  • Feature

涉及的写入方式

  • python
  • spark scala
  • spark dataframe

写入的数据类型

  • int64
  • float32
  • string

写入的特征类型

  • VarLenFeature
  • SparseFeature
  • FixedLenFeature
# Parsing schema mapping feature names to tf.io feature specs.
# NOTE(review): this mixes Example specs (FixedLen/VarLen/Sparse) with
# SequenceExample specs (FixedLenSequenceFeature); parse_example and
# parse_single_sequence_example take them separately — confirm how this
# dict is split before use.
feature_schema = {
    # featureA: 1-D string feature.
    # NOTE(review): default_value is a scalar but shape is (1,) — tf.io requires
    # the default to be compatible with the declared shape; confirm ["null"]
    # isn't needed. Same concern for featureB below.
    "featureA": tf.io.FixedLenFeature(shape=(1,), dtype=tf.string, default_value="null"),
    # featureB: 1-D numeric (float) feature.
    "featureB": tf.io.FixedLenFeature(shape=(1,), dtype=tf.float32, default_value=0.0),
    # featureC: 3-element string feature; default matches the declared shape.
    "featureC": tf.io.FixedLenFeature(shape=(3,), dtype=tf.string, default_value=["null", "null", "null"]),
    # featureD: 2-element int64 feature.
    "featureD": tf.io.FixedLenFeature(shape=(2,), dtype=tf.int64, default_value=[0, 0]),
    # featureE: variable-length string feature (parsed as a SparseTensor).
    "featureE": tf.io.VarLenFeature(dtype=tf.string),
    # featureF: variable-length float feature.
    "featureF": tf.io.VarLenFeature(dtype=tf.float32),
    # Per-entry weights for featureE ("whight" is a typo for "weight", but the
    # key must match what the writers below emit, so it is kept as-is).
    "featureEwhight":tf.io.VarLenFeature(dtype=tf.float32),
    # featureG: sequence feature, 2 strings per step; allow_missing=True lets
    # examples omit it entirely.
    "featureG": tf.io.FixedLenSequenceFeature(shape=(2,), dtype=tf.string, allow_missing=True, default_value=None),
    # featureH: sequence feature, 3 int64s per step.
    "featureH": tf.io.FixedLenSequenceFeature(shape=(3,), dtype=tf.int64, allow_missing=True, default_value=None),
    # featureI: 21 x 4 x 10 sparse string feature assembled from three parallel
    # index features plus one value feature written by the SequenceExample below.
    "featureI": tf.io.SparseFeature(index_key=["featureI_Index0", "featureI_Index1", "featureI_Index2"],
                                    value_key="featureI_value", dtype=tf.string, size=[21, 4, 10], already_sorted=False)
}

一、python方式写tfrecord

    # TensorFlow 2.x: build four tf.train.Example records (plus one
    # tf.train.SequenceExample, left unwritten below) and stream them into a
    # single TFRecord file.
    writer = tf.io.TFRecordWriter("./tfrecord")

    def _bytes(values):
        # Wrap a list of byte strings in a Feature proto.
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))

    def _floats(values):
        # Wrap a list of floats in a Feature proto.
        return tf.train.Feature(float_list=tf.train.FloatList(value=values))

    def _ints(values):
        # Wrap a list of ints in a Feature proto.
        return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

    example_1 = tf.train.Example(features=tf.train.Features(feature={
        # featureA/featureB carry exactly one value each.
        "featureA": _bytes([u"valueA1".encode("utf-8")]),
        "featureB": _floats([2.3]),
        # featureC carries exactly three values, featureD exactly two.
        "featureC": _bytes([b"valueC3C3C3", b"valueC2", b"valueC8"]),
        "featureD": _ints([3, 4]),
        # featureE/featureF are variable-length; featureEwhight holds one
        # weight per featureE entry.
        "featureE": _bytes([b"valueE3", b"valueE2", b"valueE4", b"valueE9"]),
        "featureEwhight": _floats([3.0,2.0,4.0,9.0]),
        "featureF": _floats([4.5, 1.2, 2.1])
    }))
    example_2 = tf.train.Example(features=tf.train.Features(feature={
        "featureA": _bytes([u"valueA1".encode("utf-8")]),
        "featureB": _floats([2.3]),
        "featureC": _bytes([b"valueC3C3C3", b"valueC2", b"valueC8"]),
        "featureD": _ints([3, 4]),
        "featureE": _bytes([b"valueE3", b"valueE5"]),
        "featureEwhight": _floats([3.0,5.0]),
        "featureF": _floats([5.5])
    }))
    example_3 = tf.train.Example(features=tf.train.Features(feature={
        "featureA": _bytes([b"value5"]),
        "featureB": _floats([2.3]),
        "featureC": _bytes([b"valueC3C3C3", b"valueC2", b"valueC8"]),
        "featureD": _ints([3, 4]),
        "featureE": _bytes([b"valueE1", b"valueE2", b"valueE2"]),
        "featureEwhight": _floats([1.0,2.0,2.0]),
        "featureF": _floats([1.5, 2.2])
    }))
    example_4 = tf.train.Example(features=tf.train.Features(feature={
        "featureA": _bytes([b"valueA4"]),
        "featureB": _floats([2.3]),
        "featureC": _bytes([b"valueC3C3C3", b"valueC2", b"valueC8"]),
        "featureD": _ints([3, 4]),
        "featureE": _bytes([b"valueE3", b"valueE3", b"valueE1", b"valueE4"]),
        "featureEwhight": _floats([3.0,3.0,1.0,4.0]),
        "featureF": _floats([7.5, 1.4, 2.3])
    }))

    sequence_example = tf.train.SequenceExample(
        # Per-example (non-sequence) context features.
        context=tf.train.Features(feature={
            "featureA": _bytes([b"valueA2"]),
            "featureB": _floats([4.1]),
            "featureC": _bytes([b"valueC1", b"valueC2", b"valueC3"]),
            "featureE": _bytes([b"valueE6", b"valueE1"]),
            "featureF": _floats([9.4, 6.6, 8.3, 7.2, 9.1]),
            # Parallel index triples + values feeding the sparse featureI.
            "featureI_Index0": _ints([5, 10, 2, 2, 6]),
            "featureI_Index1": _ints([4, 2, 7, 6, 3]),
            "featureI_Index2": _ints([3, 1, 8, 4, 2]),
            "featureI_value": _bytes([b"valueI3", b"valueI5", b"valueI2", b"valueI7", b"valueI5"])
        }),
        # Per-step sequence features: 2 strings per step of featureG,
        # 3 ints per step of featureH.
        feature_lists=tf.train.FeatureLists(feature_list={
            "featureG": tf.train.FeatureList(feature=[
                _bytes([b"valueG2", b"valueG1"]),
                _bytes([b"valueG6", u"valueG6".encode("utf-8")])
            ]),
            "featureH": tf.train.FeatureList(feature=[
                _ints([1, 1, 3]),
                _ints([4, 2, 9]),
                _ints([8, 4, 2])
            ])
        })
    )

    for example in (example_1, example_2, example_3, example_4):
        writer.write(example.SerializeToString())
    #writer.write(sequence_example.SerializeToString())

    writer.close()

二、spark scala方式写tfrecord

// 1. Serialize raw feature values into Feature protos.
// Int features: values has type Iterable<Long>.
Int64List int64List = Int64List.newBuilder().addAllValue(values).build();
return Feature.newBuilder().setInt64List(int64List).build();

// Float (or double) features: values has type Iterable<Float>.
FloatList floatList = FloatList.newBuilder().addAllValue(values).build();
return Feature.newBuilder().setFloatList(floatList).build();

// String features: convert to ByteString first (value has type String).
List<ByteString> bytesStringList = new ArrayList<>();
bytesStringList.add(ByteString.copyFromUtf8(value));
BytesList bytesList = BytesList.newBuilder().addAllValue(bytesStringList).build();
return Feature.newBuilder().setBytesList(bytesList).build();

// 2. Bundle multiple Features into a FeatureList (skip if not needed).
return FeatureList.newBuilder().addAllFeature(features).build();

// 3. Name each Feature / FeatureList via maps.
private Map<String, Feature> features;
private Map<String, FeatureList> featureLists;

// 4. Serialize the named maps into Features and FeatureLists protos.
return Features.newBuilder().putAllFeature(features).build();
return FeatureLists.newBuilder().putAllFeatureList(featureLists).build();

// 5.1 Non-sequence data: wrap the Features proto (built in step 4,
// held here in a variable named `contextFeatures`) in an Example.
Example.Builder exampleBuilder = Example.newBuilder();
return exampleBuilder.setFeatures(contextFeatures).build();
// or:
return exampleBuilder.mergeFeatures(contextFeatures).build();

// 5.2 Sequence data: wrap the Features (context) and FeatureLists protos
// in a SequenceExample. Note: builders finish with build(), never with
// another newBuilder() call.
SequenceExample.Builder sequenceExampleBuilder = SequenceExample.newBuilder();
sequenceExampleBuilder.setContext(contextFeatures);
return sequenceExampleBuilder.setFeatureLists(featureListsProto).build();
// or:
sequenceExampleBuilder.mergeContext(contextFeatures);
return sequenceExampleBuilder.mergeFeatureLists(featureListsProto).build();

// 6. Write the serialized Examples to a TFRecord file (Spark/Scala):
RDD
.map(example => (new BytesWritable(example.toByteArray), NullWritable.get()))
.saveAsNewAPIHadoopFile[TFRecordFileOutputFormat]("path")

三、spark dataframe方式写tfrecord

// Define the DataFrame columns (one StructField per feature).
// NOTE(review): this schema declares 7 columns but the toDF call below names
// only 3 — align the schema with the actual columns of the RDD before use.
val schema = StructType(List(
    StructField("IntegerCol",IntegerType),
    StructField("LongCol",LongType),
    StructField("FloatCol",FloatType),
    StructField("DoubleCol",DoubleType),
    // Scala booleans are lowercase: `true` here is containsNull.
    StructField("VectorCol",ArrayType(DoubleType,true)),
    StructField("StringCol",StringType),
    StructField("VectorStrCol",ArrayType(StringType,true))
))
// Write the DataFrame as a single TFRecord file of Example records.
RDD
.toDF("col1","col2","col3")
.coalesce(1)
.write
.format("tfrecords")
.option("recordType","Example")
.save("path")



版权声明:本文为h_jlwg6688原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。