pytorch VGG网络特征可视化

  • Post author:
  • Post category:其他


本代码解决问题:

将通过VGG网络或其他网络得到的特征图feats进行可视化输出,其中feats.shape=[N, C, H, W]

N为一组图片数,

C为通道数,

H为图片的高度,

W为图片的宽度

代码说明:

代码命名为featsVisual.py

使用时 import featsVisual as FV

如下:其中conv2_2为通过网络得到的feats

在目录下的featsVisual文件夹自动生成以conv2_2命名的文件夹,将可视化的特征图保存到此文件夹中。

如下图所示:

在这里插入图片描述

FV.show_feature_map(conv2_2, FV.retrieve_name(conv2_2))
import torch
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import cv2

import os

import inspect
import shutil


def retrieve_name(var):
    '''
    utils:
    get back the name of variables
    '''
    callers_local_vars = inspect.currentframe().f_back.f_locals.items()
    return [var_name for var_name, var_val in callers_local_vars if var_val is var]

def show_feature_map(feature_map, name):
    N, C, H, W = feature_map.shape
    feature_map = feature_map.cpu().detach().numpy()  # 进行卷积运算后转化为numpy格式
    path_temp = './featsVisual'
    featsVisual_dir_list = os.listdir(path_temp)
    path_temp_1 = path_temp + '/' + name[0]
    if name[0] in featsVisual_dir_list:
        shutil.rmtree(path_temp_1)
    os.mkdir(path_temp_1)
    path_temp_2 = path_temp + '/' + name[0]
    # shape = [N, 1, H, W]
    for itemN in range(0, N):
        for itemC in range(0, C):
            feature_map_1 = feature_map[itemN, itemC, :, :]
            # print(feature_map_1)
            # plt.figure()
            #plt.imshow(feature_map_1)
            # plt.imshow(feature_map_1, cmap='gray')
            # cv2.imwrite(path_temp_2 + "/{}-{}.jpg".format(itemN, itemC), feature_map_1 * 255)
            # plt.show()
            matplotlib.image.imsave(path_temp_2 + "/{}-{}.png".format(itemN, itemC), feature_map_1 * 255)
    print("featsVisual over ...")
    os.system("pause")






版权声明:本文为qq_37909691原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。