Tensorflow 获取model中的变量列表

  • Post author:
  • Post category:其他


1、动态获取

(1)朴素获取法

1) 朴素获取可训练变量:t_vars = tf.trainable_variables()

2)朴素获取全部变量,包含声明training=False变量:all_vars = tf.global_variables()

(2)使用tensorflow.contrib.slim

1) 获取常规变量(是slim里面与model变量对应的一个类型):regular_variables = slim.get_variables()

2)直接获取:vars = slim.get_variables_to_restore()

3)slim用于筛选方法

a. 通过name筛选: variables = slim.get_variables_by_name(“d_”)

b. 通过name后缀筛选:variables = slim.get_variables_by_suffix(“_b”)

c. 通过namespace筛选:variables = slim.get_variables(scope=”layer1″)

d. 通过include和exclude筛选

d0. variables_to_restore = slim.get_variables_to_restore(include=[“d_”])

d1. variables_to_restore = slim.get_var



版权声明:本文为NOT_GUY原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。