本代码解决问题:
将通过VGG网络或其他网络得到的特征图feats进行可视化输出,其中feats.shape=[N, C, H, W]
N为一组图片数,
C为通道数,
H为图片的高度,
W为图片的宽度
代码说明:
代码命名为featsVisual.py
使用时 import featsVisual as FV
如下:其中conv2_2为通过网络得到的feats
在目录下的featsVisual文件夹自动生成以conv2_2命名的文件夹,将可视化的特征图保存到此文件夹中。
如下图所示:
FV.show_feature_map(conv2_2, FV.retrieve_name(conv2_2))
import torch
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import cv2
import os
import inspect
import shutil
def retrieve_name(var):
'''
utils:
get back the name of variables
'''
callers_local_vars = inspect.currentframe().f_back.f_locals.items()
return [var_name for var_name, var_val in callers_local_vars if var_val is var]
def show_feature_map(feature_map, name):
N, C, H, W = feature_map.shape
feature_map = feature_map.cpu().detach().numpy() # 进行卷积运算后转化为numpy格式
path_temp = './featsVisual'
featsVisual_dir_list = os.listdir(path_temp)
path_temp_1 = path_temp + '/' + name[0]
if name[0] in featsVisual_dir_list:
shutil.rmtree(path_temp_1)
os.mkdir(path_temp_1)
path_temp_2 = path_temp + '/' + name[0]
# shape = [N, 1, H, W]
for itemN in range(0, N):
for itemC in range(0, C):
feature_map_1 = feature_map[itemN, itemC, :, :]
# print(feature_map_1)
# plt.figure()
#plt.imshow(feature_map_1)
# plt.imshow(feature_map_1, cmap='gray')
# cv2.imwrite(path_temp_2 + "/{}-{}.jpg".format(itemN, itemC), feature_map_1 * 255)
# plt.show()
matplotlib.image.imsave(path_temp_2 + "/{}-{}.png".format(itemN, itemC), feature_map_1 * 255)
print("featsVisual over ...")
os.system("pause")
版权声明:本文为qq_37909691原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。