

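We use an FCN that was trained on a subset of COCO train2017 containing only the categories that also appear in the Pascal VOC dataset. The model recognizes 20 object classes plus background, i.e. 21 labels in total.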


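from torchvision import models  # load the model definitions
# Pretrained FCN model based on ResNet-101; .eval() switches it to inference mode.
# On torchvision >= 0.13, pretrained=True is deprecated in favor of
# weights=models.segmentation.FCN_ResNet101_Weights.DEFAULT
fcn = models.segmentation.fcn_resnet101(pretrained=True).eval()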




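from PIL import Image
import matplotlib.pyplot as plt
import torch
# convert('RGB') drops any alpha channel so the image always has the 3 channels Normalize expects
img = Image.open('C:/Users/ting/Desktop./qcs.jpg').convert('RGB')
plt.imshow(img); plt.show()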





  • Resize the image so that its shorter side is 256 pixels (T.Resize(256) keeps the aspect ratio; it does not produce 256 x 256)
  • CenterCrop it to (224 x 224)
  • Convert it to a Tensor: all pixel values are scaled to lie in [0, 1] instead of the original [0, 255] range
  • Normalize it with the ImageNet-specific values mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]
  • Finally, unsqueeze the tensor so its shape goes from [C x H x W] to [1 x C x H x W]. This is required because the network expects a batch of images as input.
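# Apply the transformations needed
import torchvision.transforms as T
trf = T.Compose([T.Resize(256),          # resize so the shorter side is 256 (aspect ratio preserved)
                 T.CenterCrop(224),      # crop the central 224 x 224 region
                 T.ToTensor(),           # convert to a [C x H x W] float tensor with values in [0, 1]
                 T.Normalize(mean = [0.485, 0.456, 0.406], 
                             std = [0.229, 0.224, 0.225])])  # normalize each channel with the given mean and std
inp = trf(img).unsqueeze(0)  # add a batch dimension: [C x H x W] -> [1 x C x H x W]
print(inp)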
tensor([[[[ 1.4098,  1.3584,  1.2899,  ...,  0.5364,  0.5364,  0.5364],
          [ 1.4612,  1.4098,  1.3413,  ...,  0.5536,  0.5536,  0.5536],
          [ 1.4954,  1.4440,  1.3755,  ...,  0.5536,  0.5536,  0.5536],
          [-1.6555, -1.6555, -1.6898,  ..., -2.0494, -2.0152, -2.0152],
          [-1.5357, -1.5699, -1.6042,  ..., -2.0494, -2.0152, -2.0152],
          [-1.5185, -1.5357, -1.5699,  ..., -2.0494, -2.0152, -2.0152]],

         [[ 1.4132,  1.3606,  1.2906,  ...,  0.7129,  0.7129,  0.7129],
          [ 1.4657,  1.4132,  1.3431,  ...,  0.7304,  0.7304,  0.7304],
          [ 1.5007,  1.4482,  1.3782,  ...,  0.7304,  0.7304,  0.7304],
          [-1.5455, -1.5455, -1.5805,  ..., -1.9657, -1.9657, -1.9657],
          [-1.4230, -1.4580, -1.4930,  ..., -1.9657, -1.9657, -1.9657],
          [-1.4055, -1.4230, -1.4580,  ..., -1.9657, -1.9657, -1.9657]],

         [[ 1.9254,  1.8731,  1.8034,  ...,  1.1934,  1.1934,  1.1934],
          [ 1.9777,  1.9254,  1.8557,  ...,  1.2108,  1.2108,  1.2108],
          [ 2.0125,  1.9603,  1.8905,  ...,  1.2108,  1.2108,  1.2108],
          [-1.2293, -1.2293, -1.2641,  ..., -1.6999, -1.6824, -1.6824],
          [-1.1073, -1.1421, -1.1770,  ..., -1.6999, -1.6824, -1.6824],
          [-1.0898, -1.1073, -1.1421,  ..., -1.6999, -1.6824, -1.6824]]]])
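The two numeric preprocessing steps can be checked by hand. The sketch below uses plain NumPy; the helper names resize_shorter_side and normalize_pixel are hypothetical, not torchvision functions. It computes the shape T.Resize(256) produces for a 480 x 640 image, assuming torchvision truncates the longer side with int(), and applies ToTensor's division by 255 followed by Normalize's (x - mean) / std to a single pixel. For example, an RGB value of (206, 197, 214) maps to about (1.4098, 1.4132, 1.9254), which is consistent with the first entry of each channel in the tensor printed above.

```python
import numpy as np

# ImageNet per-channel statistics used by T.Normalize (order: R, G, B)
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

def resize_shorter_side(h, w, size=256):
    """(H, W) after T.Resize(size) with an int size: the shorter side becomes
    `size` and the longer side is scaled to keep the aspect ratio.
    Assumption: the longer side is truncated with int(), as in torchvision."""
    if h <= w:
        return size, int(size * w / h)
    return int(size * h / w), size

def normalize_pixel(rgb):
    """ToTensor (divide by 255) followed by Normalize ((x - mean) / std) for one pixel."""
    x = np.asarray(rgb, dtype=np.float64) / 255.0
    return (x - MEAN) / STD

print(resize_shorter_side(480, 640))                  # (256, 341); CenterCrop(224) then gives 224 x 224
print(np.round(normalize_pixel([206, 197, 214]), 4))  # [1.4098 1.4132 1.9254]
```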

4. Forward pass through the network

# Pass the input through the net
out = fcn(inp)['out']  # out is the model's final output
print(out.shape)
torch.Size([1, 21, 224, 224])

out is the model's final output, a tensor of shape [1 x 21 x H x W]: one score map for each of the 21 classes.

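We need to turn this 21-channel output into a 2D image (a single-channel image) in which each pixel corresponds to a class. Each pixel of the 2D image (shape [H x W]) then holds its class label: every (x, y) pixel gets a number between 0 and 20 identifying its class. How? For every pixel position we take the index of the channel with the highest score, which is the predicted class.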



import numpy as np
# drop the batch dimension, then take the highest-scoring class index at every pixel
om = torch.argmax(out.squeeze(), dim=0).detach().cpu().numpy()
print(om.shape)
(224, 224)

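out.squeeze() removes every dimension whose size is 1; dimensions of size 2 or more are left unchanged. Here it removes the batch dimension, reducing out from [1 x 21 x H x W] to [21 x H x W].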

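torch.argmax() returns the index of the maximum value along the given dimension. The documentation describes dim as "the dimension to reduce": that dimension is collapsed, and each remaining position keeps the index of the largest value along it. dim=0 operates on the first dimension (the 21 classes), so the result has shape [H x W].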
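To make the squeeze and argmax shapes concrete, here is a tiny sketch on made-up scores for 1 image, 3 classes and 2 x 2 pixels. It uses NumPy's np.squeeze and np.argmax(axis=0), which behave like out.squeeze() and torch.argmax(..., dim=0) for this purpose; the score values are invented for illustration.

```python
import numpy as np

# Made-up scores for 1 image, 3 classes, 2 x 2 pixels: shape [1 x C x H x W]
out = np.array([[[[0.90, 0.10],
                  [0.20, 0.30]],    # class 0
                 [[0.05, 0.70],
                  [0.10, 0.20]],    # class 1
                 [[0.05, 0.20],
                  [0.70, 0.50]]]])  # class 2
scores = np.squeeze(out)        # drop the size-1 batch axis: (1, 3, 2, 2) -> (3, 2, 2)
om = np.argmax(scores, axis=0)  # best class per pixel: (3, 2, 2) -> (2, 2)
print(out.shape, scores.shape, om.shape)  # (1, 3, 2, 2) (3, 2, 2) (2, 2)
print(om.tolist())                        # [[0, 1], [2, 2]]
```

Note that squeeze() with no argument would also drop H or W if either happened to be 1; np.squeeze(out, axis=0), or out.squeeze(0) in PyTorch, removes only the batch dimension.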


print(np.unique(om))  # np.unique() removes duplicate elements and returns the remaining values sorted in ascending order, as a NumPy array
[ 0 15]
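The labels printed by np.unique can be turned into class names using the Pascal VOC order listed in the comments of decode_segmap below (so 15 is person). A minimal sketch: VOC_CLASSES and describe_labels are hypothetical helpers, and np.unique(..., return_counts=True) additionally reports how many pixels carry each label.

```python
import numpy as np

# Pascal VOC label order (0-20), matching the comments in decode_segmap
VOC_CLASSES = [
    "background", "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "dining table", "dog", "horse", "motorbike", "person",
    "potted plant", "sheep", "sofa", "train", "tv/monitor",
]

def describe_labels(om):
    """Return {class name: fraction of pixels} for every label present in om."""
    labels, counts = np.unique(om, return_counts=True)
    return {VOC_CLASSES[int(l)]: float(c) / om.size for l, c in zip(labels, counts)}

# Hypothetical 4 x 4 label map: a 2 x 2 "person" block on background
om = np.zeros((4, 4), dtype=np.int64)
om[1:3, 1:3] = 15
print(np.unique(om))        # [ 0 15]
print(describe_labels(om))  # {'background': 0.75, 'person': 0.25}
```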





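Now we need to create an RGB image from the 2D label map. To do this, we create an empty 2D array for each of the 3 channels: r, g and b hold the red, green and blue channels of the final image, and each has shape [H x W] (the same as the 2D label map).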

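label_colors stores one RGB color per class, indexed by class label. We loop over the classes and, for each one, find the pixels in the 2D label map that carry that label. Then, for each channel, we write the class's color component into those pixels. Finally, we stack the 3 separate channels to form the RGB image. Now let's use this function to look at the final output!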

# Define the helper function
def decode_segmap(image, nc=21):
  label_colors = np.array([(0, 0, 0),  # 0=background
               # 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle
               (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128),
               # 6=bus, 7=car, 8=cat, 9=chair, 10=cow
               (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0),
               # 11=dining table, 12=dog, 13=horse, 14=motorbike, 15=person
               (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
               # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor
               (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128)])
  r = np.zeros_like(image).astype(np.uint8)
  g = np.zeros_like(image).astype(np.uint8)
  b = np.zeros_like(image).astype(np.uint8)
  for l in range(0, nc):
    idx = image == l
    r[idx] = label_colors[l, 0]
    g[idx] = label_colors[l, 1]
    b[idx] = label_colors[l, 2]
  rgb = np.stack([r, g, b], axis=2)
  return rgb

rgb = decode_segmap(om)
plt.imshow(rgb); plt.show()
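The per-class loop in decode_segmap can also be written with NumPy integer-array indexing: indexing an (nc x 3) color table with the [H x W] label map returns an [H x W x 3] image in one step. A sketch, assuming the palette is passed in explicitly (decode_segmap_vectorized is a hypothetical name, and in the code above label_colors is local to decode_segmap):

```python
import numpy as np

def decode_segmap_vectorized(image, label_colors):
    """Map an [H x W] integer label map to an [H x W x 3] uint8 RGB image."""
    palette = np.asarray(label_colors, dtype=np.uint8)  # shape (nc, 3)
    # Fancy indexing picks palette row image[y, x] for every pixel (y, x)
    return palette[image]

# Hypothetical 3-class palette: background black, class 1 red, class 2 green
colors = [(0, 0, 0), (128, 0, 0), (0, 128, 0)]
om = np.array([[0, 1],
               [2, 2]])
rgb = decode_segmap_vectorized(om, colors)
print(rgb.shape)  # (2, 2, 3)
```

One behavioral difference: a label greater than or equal to the number of palette rows raises IndexError here, whereas the loop version silently leaves such pixels black.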


6. Final Result


def segment(net, path):
  img = Image.open(path).convert('RGB')
  plt.imshow(img); plt.axis('off'); plt.show()
  # Commenting out Resize and CenterCrop can give better inference results;
  # here CenterCrop is removed so the whole image is segmented
  trf = T.Compose([T.Resize(256), 
                   T.ToTensor(),  # required: Normalize works on tensors, not PIL images
                   T.Normalize(mean = [0.485, 0.456, 0.406], 
                               std = [0.229, 0.224, 0.225])])
  inp = trf(img).unsqueeze(0)
  with torch.no_grad():  # inference only, no gradients needed
    out = net(inp)['out']
  om = torch.argmax(out.squeeze(), dim=0).cpu().numpy()
  rgb = decode_segmap(om)
  plt.imshow(rgb); plt.axis('off'); plt.show()

segment(fcn, 'C:/Users/ting/Desktop./qcs.jpg')  # segment() only plots; it returns None


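Copyright notice: this is an original article by m0_70813473, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.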