将tensorflow 训练模型pb文件转为Relay_IR

  • Post author:
  • Post category:其他



举例:将推荐场景的DCN网络转为Relay_IR

1 预训练模型保存为pb

# Fragment from inside a model class (note the `self.` references): initialise
# the variables, save a checkpoint, then export a frozen .pb whose output node
# is "Sigmoid".
self.saver = tf.train.Saver()
init = tf.global_variables_initializer()
self.sess = self._init_session()
self.sess.run(init)
# Checkpoint is written under ./checkpoint_dir/ before freezing.
self.saver.save(self.sess, './checkpoint_dir/')
# Replace variables with constants, keeping what the "Sigmoid" node needs.
constant_graph = tf.graph_util.convert_variables_to_constants(self.sess, self.sess.graph_def, ['Sigmoid'])
# NOTE(review): extracting the "Sigmoid" sub-graph again from the already
# frozen graph looks redundant with the call above - confirm before removing.
graph_def = tf.compat.v1.graph_util.extract_sub_graph(constant_graph,["Sigmoid"])
# Serialize the resulting GraphDef to ./dcn_ckpt.pb in binary mode.
with tf.gfile.FastGFile('./dcn_ckpt.pb', mode='wb') as f:
    f.write(graph_def.SerializeToString())

2 为输入占位符设定固定的shape(部分网络的输入如果没有确定的shape,就无法推断中间张量的shape)

# Fragment from inside the model class: input placeholders declared with fully
# fixed shapes. [1024, 39] matches the batch=1024 / 39-column feature shape
# used by the conversion script's shape_dict.
self.feature_index = tf.placeholder(tf.int32, shape=[1024, 39], name="feature_index")
self.feature_value = tf.placeholder(tf.float32, shape=[1024, 39], name="feature_value")
# One entry per dropout position, three in total.
# NOTE(review): presumably the keep-probabilities of the deep part - confirm.
self.dropout_keep_deep = tf.placeholder(tf.float32, shape=[3], name="dropout_keep_deep")

3 将pb模型转换为Relay IR

import tvm
import tx
from tvm.relay.transform import *
from tx.relay.graph_schedule import *
from tx.utils.utils import is_valid, stcprof_list_to_dict
import tx.relay.backend
import tx.runtime.graph_executor
import tvm.relay.testing.tf as tf_testing
import sys
import numpy as np
import tensorflow as tf
from tvm.target import Target
 
 
# Tensorflow imports
import tensorflow as tf
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops.gen_random_ops import *
from tensorflow.core.framework import attr_value_pb2
 
# Tensorflow utility functions
import tvm.relay.testing.tf as tf_testing
import pandas as pd
 
 
# Module-level configuration for the conversion script.
# NOTE(review): `dtype` is not referenced anywhere in the visible script.
dtype = "float16"
# Forwarded to get_graph() by the __main__ block.
target = "stc_tc"
# Frozen TensorFlow graph read by get_pb().
model_path = "./DCN-master/example/dcn_ckpt.pb"
#model_path = "./freeze_graph.pb"
 
def _rewrite_graph_nodes(graph_def):
    """Return a new GraphDef holding a rewritten copy of every node in *graph_def*.

    The input GraphDef is left untouched. Nodes are copied one by one with
    these substitutions:

    * ``RefSwitch``     -> op renamed to ``Switch``
    * ``AssignSub``     -> op renamed to ``Sub`` and ``use_locking`` dropped
    * ``RandomUniform`` -> replaced by an input-less ``Const`` node of the
      same name holding the scalar ``1.0`` and the original ``dtype`` attr
    * anything else     -> copied verbatim

    NOTE(review): presumably these ops are rewritten because the Relay
    TensorFlow frontend cannot import them - confirm against the TVM version
    in use.
    """
    rewritten = tf.GraphDef()
    for node in graph_def.node:
        new_node = rewritten.node.add()

        if node.op == "RandomUniform":
            # Built from scratch: only name, value and dtype are carried over,
            # so there is nothing to strip from the source node.
            new_node.op = "Const"
            new_node.name = node.name
            new_node.attr["value"].CopyFrom(
                attr_value_pb2.AttrValue(tensor=tf.make_tensor_proto(1.0))
            )
            new_node.attr["dtype"].CopyFrom(
                attr_value_pb2.AttrValue(type=node.attr["dtype"].type)
            )
            continue

        # Copy first, then patch the copy, so the parsed input graph is not
        # mutated as a side effect.
        new_node.CopyFrom(node)
        if node.op == "RefSwitch":
            new_node.op = "Switch"
        elif node.op == "AssignSub":
            new_node.op = "Sub"
            if "use_locking" in new_node.attr:
                del new_node.attr["use_locking"]
    return rewritten


def get_pb(output_name):
    """Load the frozen graph at ``model_path`` and prepare it for Relay import.

    The graph is parsed, its nodes are rewritten (see
    ``_rewrite_graph_nodes``), and the result is imported into the freshly
    reset default graph with ``feature_index`` / ``feature_value`` mapped onto
    new placeholders of shape ``[None, None]``. The session graph is then
    frozen with shape information attached and written to
    ``./DCN-master/example/dcn_relay_in.pb``.

    Args:
        output_name: Name of the output node to keep when freezing
            (e.g. ``"Sigmoid"``).

    Returns:
        The GraphDef produced by ``tf_testing.AddShapesToGraphDef``.
    """
    # Import from TensorFlow .pb model
    tf.reset_default_graph()
    with tf.gfile.GFile(model_path, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    new_model = _rewrite_graph_nodes(graph_def)

    feature_index = tf.placeholder(tf.int32, shape=[None, None], name="feature_index")
    feature_value = tf.placeholder(tf.float32, shape=[None, None], name="feature_value")
    tf.import_graph_def(
        new_model,
        name="",
        input_map={
            "feature_index": feature_index,
            "feature_value": feature_value,
        },
    )
    graph_def = tf_testing.ProcessGraphDefParam(new_model)
    with tf.Session() as sess:
        # Freeze against the requested output node rather than a hard-coded
        # "Sigmoid", so the function honours its `output_name` argument.
        # NOTE(review): the returned graph is superseded by AddShapesToGraphDef
        # below; the call is kept so a missing output node still fails here.
        tf.graph_util.convert_variables_to_constants(sess, graph_def, [output_name])
        graph_def = tf_testing.AddShapesToGraphDef(sess, output_name)
        with tf.gfile.GFile("./DCN-master/example/dcn_relay_in.pb", "wb") as f:
            f.write(graph_def.SerializeToString())
    return graph_def
 
def get_graph(
    graph_def=None,
    target=None,
    relay_path=None,
    shape_dict=None,
):
    """Convert a TensorFlow GraphDef to a Relay module and optionally dump it.

    Args:
        graph_def: TensorFlow GraphDef handed to the Relay frontend.
        target: Accepted for interface compatibility; not used in this
            function.
        relay_path: When truthy, the textual form of the module is written
            to this path.
        shape_dict: Input-name -> shape mapping passed to the frontend as
            ``shape``.

    Returns:
        The ``(mod, params)`` pair returned by
        ``relay.frontend.from_tensorflow``.
    """
    converted = relay.frontend.from_tensorflow(
        graph_def,
        layout="NHWC",
        shape=shape_dict,
    )
    module, weights = converted

    # Nothing to dump: hand the conversion result straight back.
    if not relay_path:
        return module, weights

    # save module
    with open(relay_path, "w") as dump_file:
        dump_file.write(module.astext())
    print("dump relay to {}".format(relay_path))
    return module, weights
 
 
if __name__ == "__main__":
    # Script entry point: prepare the frozen graph, then convert it to Relay
    # and dump the textual IR to ./dcn.relay.
    output_node = "Sigmoid"
    batch = 1024
    feature_shape = [batch, 39]

    # Both inputs share the same [batch, 39] shape.
    shape_dict = dict.fromkeys(("feature_index", "feature_value"), feature_shape)

    graph_def = get_pb(output_node)

    relay_mod, params = get_graph(graph_def, target, "./dcn.relay", shape_dict)



版权声明:本文为pxiongw原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。