记录tensorflow、caffe模型文件的获取方式

  • Post author:
  • Post category:其他


在学习tensorflow模型时,经常需要用到xx.pb网络模型文件。由于自己是刚学tf,因此经常遇到找不到pb文件的情况,下面就将自己找到的一些方法共享给大家。


1、华为云modelzoo共享的pb文件


https://www.huaweicloud.com/ascend/resources/modelzoo?ticket=ST-2203190-w5VbNKUD4ZleGiBvoQk3ygqA-sso&locale=zh-cn

在华为云共享的这些modelzoo资源中,选择ATC_XX开始的这些资源,就能直接下载到pb文件。

缺点:提供的pb文件比较少,只有常见的resnet50,VGG16,VGG19等20个左右网络


2、github共享的tf models


https://github.com/tensorflow/models/tree/master/research/slim

通过下载xx.tar.gz可以得到对应模型的ckpt文件。因此需要自己写代码才能将ckpt文件转换成Pb文件。

ckpt转换pb代码的python脚本,我这边是参考了大神博客:

《.ckpt、.pb、.pbtxt模型相互转换》

,进行了简单修改。

示例代码如下:

# -*- coding: utf-8 -*-
import os
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.python.platform import gfile
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import graph_util


# print ckpt_node_name
def ckpt_node_name(filename):
    checkpoint_path = os.path.join(filename)
    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in var_to_shape_map:
        print('tensor_name: ', key)


# convert .ckpt to .pb to freeze a trained model
def convert_ckpt_to_pb(filename1, filename2):
    # filename1 is a .meta file
    saver = tf.train.import_meta_graph(filename1, clear_devices=True)
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()
    with tf.Session() as sess:
        saver.restore(sess, filename1)
        # you need to change the output node name ['embeddings'] to your model's real name.
        output_graph_def = graph_util.convert_variables_to_constants(sess, input_graph_def, ['output_node_name'])
        with tf.gfile.GFile(filename2, "wb") as f:
            f.write(output_graph_def.SerializeToString())


# print pb_node_name
def pb_node_name(filename):
    def create_graph():
        with tf.gfile.FastGFile(filename, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name='')

    create_graph()
    tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
    for tensor_name in tensor_name_list:
        print(tensor_name, '\n')


def convert_pb_to_pbtxt(filename):
    with gfile.FastGFile(filename, 'rb') as f:
        graph_def = tf.GraphDef()

        graph_def.ParseFromString(f.read())

        tf.import_graph_def(graph_def, name='')

        # tf.train.write_graph(graph_def, './', 'protobuf.pbtxt', as_text=True)
        tf.train.write_graph(graph_def, './tmp', 'LSTM111.pbtxt', as_text=True)
    return


def convert_pbtxt_to_pb(filename):
    """Returns a `tf.GraphDef` proto representing the data in the given pbtxt file.
    Args:
      filename: The name of a file containing a GraphDef pbtxt (text-formatted
        `tf.GraphDef` protocol buffer data).
    """
    with tf.gfile.FastGFile(filename, 'r') as f:
        graph_def = tf.GraphDef()

        file_content = f.read()

        # Merges the human-readable string in `file_content` into `graph_def`.
        text_format.Merge(file_content, graph_def)
        tf.train.write_graph(graph_def, './tmp/train', 'lstm.pb', as_text=False)
    return


if __name__ == '__main__':
    model_path = 'D:\\modelzoo\\inception_v1_2016_08_28\\'
    ckpt_path = model_path + 'inception_v1.ckpt'
    # 输出pb模型的路径
    out_pb_path = model_path + 'inception_v1.pb'
    convert_ckpt_to_pb(ckpt_path, out_pb_path)
    print('Convert .ckpt to .pb has finished')


问题1:

但是在使用这段代码时,遇到如下报错:AttributeError: module ‘tensorflow._api.v2.train’ has no attribute ‘import_meta_graph’

解决方法是将头文件引入修改成如下:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()


问题2:

解决完以上问题后,又报错:ttributeError: ‘NoneType’ object has no attribute ‘restore’

关于saver.restore问题,一直找不到解决方法。因此我只有卸载tf v2.0,重装成tf v1.15。

由于直接通过命令:pip install tensorflow==1.15 安装比较慢,因此我是直接安装离线包。对应包路径:

https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/tensorflow/

当然也可以加上清华镜像来安装TF:如:

pip install tensorflow==1.15 -i https://pypi.tuna.tsinghua.edu.cn/simple/

安装成功后可以通过pip show tensorflow来确认版本。

但是发现,即使安装成v1.15版本后,仍然报错。截止目前,我仍未找到解决方法。另外要注意的是,ckpt文件只是checkpoint文件,还需要要对应.meta,但是没找到。可能是我没注意吧。


3、caffe网络的模型下载方式


https://github.com/BVLC/caffe/wiki/Model-Zoo

当然还有这个网络也能进行caffemodel下载(但是没有prototxt文件):

http://dl.caffe.berkeleyvision.org/


3、github 共享的tensorflow网络模型下载


https://github.com/IntelAI/models/tree/master/benchmarks

点击后面的FP32可以看到对应模型下载的方式,推荐使用wget xxx



版权声明:本文为yuzipeng原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。