How to Get FLOPs and the Parameter Count (parameters) in TensorFlow 2.X (2022)


0. Skip the detours

TensorFlow is less convenient than PyTorch in many respects; even basic information like a model's FLOPs and parameter count takes a while to dig up. When comparing models, besides accuracy or precision, the other axis is runtime efficiency. But everyone's hardware is different, so raw wall-clock comparisons are hard to make fair; the common practice is to compare FLOPs and parameter counts instead. I put the year in the title to save you time: most posts out there describe the 1.X approach, which is painful to use and, in my experience, often no longer works (most of those posts are three or four years old). The methods below were tested recently by me and do work.

First, the model to profile; let's pick a fairly common one (a small AlexNet-style network):

from tensorflow.keras.layers import (Input, Conv2D, BatchNormalization,
                                     Activation, MaxPooling2D, Flatten,
                                     Dense, Dropout)
from tensorflow.keras.models import Model


def Alexnet32():
    inputs1 = Input(shape=(32, 32, 1))
    conv1 = Conv2D(filters=16, kernel_size=3)(inputs1)
    BN1 = BatchNormalization()(conv1)
    act1 = Activation('relu')(BN1)
    pool1 = MaxPooling2D(pool_size=3, strides=1)(act1)
    conv4 = Conv2D(filters=32, kernel_size=3, padding='same')(pool1)
    BN2 = BatchNormalization()(conv4)
    act2 = Activation('relu')(BN2)
    pool2 = MaxPooling2D(pool_size=3, strides=1)(act2)
    conv5 = Conv2D(filters=128, kernel_size=3, padding='same',
                   activation='relu')(pool2)
    conv6 = Conv2D(filters=128, kernel_size=3, padding='same',
                   activation='relu')(conv5)
    conv7 = Conv2D(filters=128, kernel_size=3, strides=2,
                   activation='relu')(conv6)
    BN3 = BatchNormalization()(conv7)
    act3 = Activation('relu')(BN3)
    pool3 = MaxPooling2D(pool_size=3, strides=1)(act3)
    flat1 = Flatten()(pool3)
    dense1 = Dense(300)(flat1)
    BN4 = BatchNormalization()(dense1)
    drop1 = Dropout(0.2)(BN4)
    outputs = Dense(10, activation='softmax')(drop1)
    model = Model(inputs=inputs1, outputs=outputs)
    # model.summary()  # print the model structure
    return model

1. Two ways to get FLOPs

The first one is quite simple.

Install the package with pip install keras-flops, then import it and call it directly:

from keras_flops import get_flops

flops = get_flops(Alexnet32(), batch_size=1)
print(f"FLOPS: {flops / 10 ** 6:.03} M")

The result is fairly comprehensive, with per-node numbers:

==================Model Analysis Report======================

Doc:
scope: The nodes in the model graph are organized by their names, which is hierarchical like filesystem.
flops: Number of float operations. Note: Please read the implementation for the math behind it.

Profile:
node name | # float_ops
_TFProfRoot (--/307.61m flops)
  model/conv2d_3/Conv2D (199.36m/199.36m flops)
  model/conv2d_2/Conv2D (49.84m/49.84m flops)
  model/conv2d_4/Conv2D (42.47m/42.47m flops)
  model/dense/MatMul (7.68m/7.68m flops)
  model/conv2d_1/Conv2D (7.23m/7.23m flops)
  model/conv2d/Conv2D (259.20k/259.20k flops)
  model/max_pooling2d_1/MaxPool (194.69k/194.69k flops)
  model/max_pooling2d_2/MaxPool (115.20k/115.20k flops)
  model/max_pooling2d/MaxPool (112.90k/112.90k flops)
  model/conv2d_2/BiasAdd (86.53k/86.53k flops)
  model/conv2d_3/BiasAdd (86.53k/86.53k flops)
  model/batch_normalization_1/FusedBatchNormV3 (50.37k/50.37k flops)
  model/batch_normalization_2/FusedBatchNormV3 (37.63k/37.63k flops)
  model/batch_normalization/FusedBatchNormV3 (28.90k/28.90k flops)
  model/conv2d_1/BiasAdd (25.09k/25.09k flops)
  model/conv2d_4/BiasAdd (18.43k/18.43k flops)
  model/conv2d/BiasAdd (14.40k/14.40k flops)
  model/dense_1/MatMul (6.00k/6.00k flops)
  model/batch_normalization_3/batchnorm/Rsqrt (600/600 flops)
  model/batch_normalization_3/batchnorm/add (300/300 flops)
  model/batch_normalization_3/batchnorm/add_1 (300/300 flops)
  model/batch_normalization_3/batchnorm/mul (300/300 flops)
  model/batch_normalization_3/batchnorm/mul_1 (300/300 flops)
  model/batch_normalization_3/batchnorm/mul_2 (300/300 flops)
  model/batch_normalization_3/batchnorm/sub (300/300 flops)
  model/dense/BiasAdd (300/300 flops)
  model/dense_1/Softmax (50/50 flops)
  model/dense_1/BiasAdd (10/10 flops)

======================End of Report==========================
FLOPS: 3.08e+02 M
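
As a sanity check on the report, the conv entries follow the standard counting rule FLOPs ≈ 2 · H_out · W_out · C_out · (k · k · C_in), i.e., one multiply plus one add per MAC, with bias adds counted separately. A quick back-of-the-envelope check for the largest entry, model/conv2d_3/Conv2D (the second 128-filter 'same' conv, which runs on a 26×26 feature map):

# conv2d_3: 3x3 'same' conv, 128 -> 128 channels, 26x26 output map
h_out, w_out, c_out = 26, 26, 128
k, c_in = 3, 128
flops = 2 * h_out * w_out * c_out * (k * k * c_in)
print(flops)  # 199360512, i.e. ~199.36m, matching the report above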

The second works by calling a helper function.

It is said to have been lifted straight from a utility in a newer release; credit to whoever did that.

import logging
from typing import Any, Dict, Optional, Union

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2_as_graph)


def try_count_flops(model: Union[tf.Module, tf.keras.Model],
                    inputs_kwargs: Optional[Dict[str, Any]] = None,
                    output_path: Optional[str] = None):
    """Counts and returns model FLOPs.

    Args:
      model: A model instance.
      inputs_kwargs: An optional dictionary of argument pairs specifying inputs'
        shape specifications to get the corresponding concrete function.
      output_path: A file path to write the profiling results to.

    Returns:
      The model's FLOPs, or None if they could not be counted.
    """
    if hasattr(model, 'inputs'):
        try:
            # Get the input shapes and set the batch size to 1.
            if model.inputs:
                inputs = [
                    tf.TensorSpec([1] + input.shape[1:], input.dtype)
                    for input in model.inputs
                ]
                concrete_func = tf.function(model).get_concrete_function(inputs)
            # If model.inputs is invalid, try to use the input to get a concrete
            # function for model.call (subclassed model).
            else:
                concrete_func = tf.function(model.call).get_concrete_function(
                    **inputs_kwargs)
            frozen_func, _ = convert_variables_to_constants_v2_as_graph(concrete_func)

            # Calculate FLOPs by running the TF1-style profiler on the frozen graph.
            run_meta = tf.compat.v1.RunMetadata()
            opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
            if output_path is not None:
                opts['output'] = f'file:outfile={output_path}'
            else:
                opts['output'] = 'none'
            flops = tf.compat.v1.profiler.profile(
                graph=frozen_func.graph, run_meta=run_meta, options=opts)
            return flops.total_float_ops
        except Exception as e:  # pylint: disable=broad-except
            logging.info(
                'Failed to count model FLOPs with error %s, because the build() '
                'methods in keras layers were not called. This is probably because '
                'the model was not fed any input, e.g., the max train step was '
                'already reached before this run.', e)
            return None
    return None

flops = try_count_flops(Alexnet32())
print(flops / 1000000, "M Flops")

This time the output is just the total FLOPs, which matches the 307.61m total from the first method:

307.611928 M Flops
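
By default this method does not print the per-node breakdown. If you also want the detailed report, the output_path argument of the function above redirects the profiler report to a file; a minimal sketch (the file name here is just an example):

# Write the per-node profiler report to a file instead of discarding it
flops = try_count_flops(Alexnet32(), output_path='flops_profile.txt')
print(flops / 1000000, "M Flops")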

2. One way to get the parameter count (parameters)

I searched around quite a bit for this one and found mostly 1.X-era approaches that need compatibility hacks, which can easily cause other problems; I tried a few without success, and it is not worth the time when a built-in method exists. Just use model.summary(), which also gives plenty of per-layer detail. Uncomment the model.summary() line inside Alexnet32() and the table prints whenever you build the model, or call Alexnet32().summary() directly.

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 32, 32, 1)]       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 30, 30, 16)        160       
_________________________________________________________________
batch_normalization (BatchNo (None, 30, 30, 16)        64        
_________________________________________________________________
activation (Activation)      (None, 30, 30, 16)        0         
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 28, 28, 16)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 28, 28, 32)        4640      
_________________________________________________________________
batch_normalization_1 (Batch (None, 28, 28, 32)        128       
_________________________________________________________________
activation_1 (Activation)    (None, 28, 28, 32)        0         
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 26, 26, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 26, 26, 128)       36992     
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 26, 26, 128)       147584    
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 12, 12, 128)       147584    
_________________________________________________________________
batch_normalization_2 (Batch (None, 12, 12, 128)       512       
_________________________________________________________________
activation_2 (Activation)    (None, 12, 12, 128)       0         
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 10, 10, 128)       0         
_________________________________________________________________
flatten (Flatten)            (None, 12800)             0         
_________________________________________________________________
dense (Dense)                (None, 300)               3840300   
_________________________________________________________________
batch_normalization_3 (Batch (None, 300)               1200      
_________________________________________________________________
dropout (Dropout)            (None, 300)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 10)                3010      
=================================================================
Total params: 4,182,174
Trainable params: 4,181,222
Non-trainable params: 952

So the total parameter count is 4,182,174 — a simple and direct method.
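
If you only need the numbers rather than the full table, the counts are also available programmatically. A minimal sketch using the standard Keras APIs model.count_params() and tf.keras.backend.count_params():

import numpy as np
import tensorflow as tf

model = Alexnet32()
total = model.count_params()  # 4182174
trainable = int(np.sum([tf.keras.backend.count_params(w)
                        for w in model.trainable_weights]))
print(f"Total: {total:,}  Trainable: {trainable:,}  "
      f"Non-trainable: {total - trainable:,}")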


