在学习tensorflow模型时,经常需要用到xx.pb网络模型文件。由于自己是刚学tf,因此经常遇到找不到pb文件的情况,下面就将自己找到的一些方法共享给大家。
1、华为云modelzoo共享的pb文件
在华为云共享的这些modelzoo资源中,选择ATC_XX开始的这些资源,就能直接下载到pb文件。
缺点:提供的pb文件比较少,只有常见的resnet50,VGG16,VGG19等20个左右网络
2、github共享的tf models
https://github.com/tensorflow/models/tree/master/research/slim
通过下载xx.tar.gz可以得到对应模型的ckpt文件。因此需要自己写代码才能将ckpt文件转换成Pb文件。
ckpt转换pb代码的python脚本,我这边是参考了大神博客:
《.ckpt、.pb、.pbtxt模型相互转换》
,进行了简单修改。
示例代码如下:
# -*- coding: utf-8 -*-
import os
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.python.platform import gfile
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import graph_util
# print ckpt_node_name
def ckpt_node_name(filename):
checkpoint_path = os.path.join(filename)
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print('tensor_name: ', key)
# convert .ckpt to .pb to freeze a trained model
def convert_ckpt_to_pb(filename1, filename2):
# filename1 is a .meta file
saver = tf.train.import_meta_graph(filename1, clear_devices=True)
graph = tf.get_default_graph()
input_graph_def = graph.as_graph_def()
with tf.Session() as sess:
saver.restore(sess, filename1)
# you need to change the output node name ['embeddings'] to your model's real name.
output_graph_def = graph_util.convert_variables_to_constants(sess, input_graph_def, ['output_node_name'])
with tf.gfile.GFile(filename2, "wb") as f:
f.write(output_graph_def.SerializeToString())
# print pb_node_name
def pb_node_name(filename):
def create_graph():
with tf.gfile.FastGFile(filename, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
create_graph()
tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
for tensor_name in tensor_name_list:
print(tensor_name, '\n')
def convert_pb_to_pbtxt(filename):
with gfile.FastGFile(filename, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
# tf.train.write_graph(graph_def, './', 'protobuf.pbtxt', as_text=True)
tf.train.write_graph(graph_def, './tmp', 'LSTM111.pbtxt', as_text=True)
return
def convert_pbtxt_to_pb(filename):
"""Returns a `tf.GraphDef` proto representing the data in the given pbtxt file.
Args:
filename: The name of a file containing a GraphDef pbtxt (text-formatted
`tf.GraphDef` protocol buffer data).
"""
with tf.gfile.FastGFile(filename, 'r') as f:
graph_def = tf.GraphDef()
file_content = f.read()
# Merges the human-readable string in `file_content` into `graph_def`.
text_format.Merge(file_content, graph_def)
tf.train.write_graph(graph_def, './tmp/train', 'lstm.pb', as_text=False)
return
if __name__ == '__main__':
model_path = 'D:\\modelzoo\\inception_v1_2016_08_28\\'
ckpt_path = model_path + 'inception_v1.ckpt'
# 输出pb模型的路径
out_pb_path = model_path + 'inception_v1.pb'
convert_ckpt_to_pb(ckpt_path, out_pb_path)
print('Convert .ckpt to .pb has finished')
问题1:
但是在使用这段代码时,遇到如下报错:AttributeError: module ‘tensorflow._api.v2.train’ has no attribute ‘import_meta_graph’
解决方法是将头文件引入修改成如下:
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
问题2:
解决完以上问题后,又报错:ttributeError: ‘NoneType’ object has no attribute ‘restore’
关于saver.restore问题,一直找不到解决方法。因此我只有卸载tf v2.0,重装成tf v1.15。
由于直接通过命令:pip install tensorflow==1.15 安装比较慢,因此我是直接安装离线包。对应包路径:
https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/tensorflow/
当然也可以加上清华镜像来安装TF:如:
pip install tensorflow==1.15 -i https://pypi.tuna.tsinghua.edu.cn/simple/
安装成功后可以通过pip show tensorflow来确认版本。
但是发现,即使安装成v1.15版本后,仍然报错。截止目前,我仍未找到解决方法。另外要注意的是,ckpt文件只是checkpoint文件,还需要要对应.meta,但是没找到。可能是我没注意吧。
3、caffe网络的模型下载方式
https://github.com/BVLC/caffe/wiki/Model-Zoo
当然还有这个网络也能进行caffemodel下载(但是没有prototxt文件):
http://dl.caffe.berkeleyvision.org/
3、github 共享的tensorflow网络模型下载
https://github.com/IntelAI/models/tree/master/benchmarks
点击后面的FP32可以看到对应模型下载的方式,推荐使用wget xxx