图像分割经典方法2–Graph Based Segmentation之谱聚类(附代码)

  • Post author:
  • Post category:其他




引言

在之前的一篇文章中,我介绍了基于聚类的图像分割,在这一篇文章中,我会介绍另一种图像分割的方法–基于图的图像分割。具体用到的方法是谱聚类。OK, 我们先来简单了解一下谱聚类。



谱聚类(spectral clustering)

首先我们需要明确一点,谱聚类虽然是一种聚类的模型,但是事实上,它的设计初衷确是解决一个关于切割图的问题,因此它的算法也是从图论中演化而来的。具体来说,它的主要思想就是将所有的数据点看成空间中的一个个点(node), 这些点之间可以有边(edge)来进行连接。然后对于两个距离较远的点,连接它们的边权重(可以理解为相似度)会较低,而距离较近的两个点,连接它们的边的权重会比较高。现在我们要解决的问题就是,我们该如何去‘切图‘,让切图后不同的子图之间的权重和尽可能小,而子图内的的点之间的边权重之和尽可能大,通过实现这个目的,我们就间接实现了聚类的目的。

关于具体如何实现这个最优切割,大家可以参考刘建平(Pinard)老师的博客,其中有一篇关于谱聚类的原理总结

博客地址

,讲解的非常详细。在这里我就只对算法流程做一个梳理。

1.根据输入的数据以及定义好的数据点之间的相似度度量构建相似矩阵S以及聚类数k

2.根据得到的相似矩阵S,构建邻接矩阵W。(在构建W时,我们通常采用全连接法,也就是所有点之间的权重都大于0,我们可以定义不同的核函数来实现,最常用的就是高斯核)

3.构建度矩阵D(degree matrix), 假设样本数为n,那么D就是一个nxn的对角矩阵,Dii的值就是第i个样本与其他节点的边权重之和。

4.构建拉普拉斯矩阵L=D-W

5.计算得到L(或者标准化后)的最小的k个特征值对应的k个特征向量。注意这里每个个特征向量的维度都为n(样本个数),将其组合形成nxk的特征矩阵F。这里的每一行都代表一个数据点。(这种方法被称为spectral embedding)。

6.在得到F的基础上,再使用k-means等聚类方法对n个样本进行聚类,得到最后的聚类结果。

观察上面的算法流程,不难发现,谱聚类并不是直接使用数据点,而是借助数据之间的相似度矩阵进行后面的流程。那么这对处理稀松数据的聚类问题是非常有效的,非常巧妙地避免了数据稀松的问题。这一点上,传统的聚类算法比如k-means是很难企及的。


总结一下,个人认为谱聚类非常像是一种模型的连接,首先通过spectral embedding的方法,将图中的每个节点(数据点)转化为向量,再利用经典的聚类算法完成聚类。



基于谱聚类的图像分割

话不多说,直接上代码,说明的话见代码中的注释。

import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
from sklearn.feature_extraction import image
from sklearn.cluster import spectral_clustering


class img_seg():

	def load_img(self,path):
		sample_img=cv2.imread(path)[:,:,2] #对于彩色图片,作为例子我们只选取1个channel 
		sample_img=sample_img/255.
		
		#用img_to_graph将img转化为graph,每个位置计算的是相邻像素点这之间的差(梯度)
		graph=image.img_to_graph(sample_img)
		
		#转化为邻接矩阵(从 distance 转化为 similarity),这里做了归一化来保证更好的效果
		gamma=20
		graph.data=np.exp(-gamma*graph.data/graph.data.std())	
		return sample_img,graph

	def visualize(self,img,labels):#对结果可视化

		x1=np.arange(0,img.shape[1])
		x2=np.arange(0,img.shape[0])
		X1,X2=np.meshgrid(x1,x2)

		fig=plt.figure(figsize=(6,6))
		ax=fig.add_subplot(1,1,1)
		ax.imshow(img,cmap=plt.cm.gray,alpha=0.8) #展示原图
		ax.contour(X1,X2,labels,linewidths=1,linestyles='dashdot')#切割线
		ax.get_xaxis().set_visible(False)#去掉坐标轴
		ax.get_yaxis().set_visible(False)
		plt.show()

	def spectral_clustering_seg(self,path):
		img,graph=self.load_img(path)
		#运行谱聚类,这里 assign_labels也可以选择kmeans, 如果eigen_solver想使用'amg',需要先安装 pyamg包
		labels=spectral_clustering(graph,n_clusters=7,assign_labels='discretize',eigen_solver='amg')
		labels=labels.reshape(img.shape)
		self.visualize(img,labels)


_=img_seg()
_.spectral_clustering_seg('paris.jpg')

然后下面是原图以及切割后的效果

在这里插入图片描述

在这里插入图片描述



版权声明:本文为weixin_44607838原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。