安装对应的库文件
pip install ptflops
代码(get_model_complexity_info)
import torchvision.models as models
from ptflops import get_model_complexity_info
net = models.vgg16() #可以为自己搭建的模型
flops, params = get_model_complexity_info(model, (3,512,512), as_strings=True, print_per_layer_stat=True) #(3,512,512)输入图片的尺寸
print("Flops: {}".format(flops))
print("Params: " + params)
版权声明:本文为lijiahao1212原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。