# -*- coding: utf-8 -*-
"""
@Author : zhwzhong
@License : (C) Copyright 2013-2018, hit
@Contact : zhwzhong.hit@gmail.com
@Software: PyCharm
@File : metrics.py
@Time : 2018/10/8 09:40
@Desc :
"""
import numpy as np
from skimage.measure import compare_psnr, compare_ssim
def compare_ergas(x_true, x_pred, ratio):
    """
    Calculate ERGAS (relative dimensionless global error in synthesis).

    ERGAS offers a global indication of the quality of a fused image; the
    ideal value is 0.

    :param x_true: reference image, laid out as (H, W, C).
    :param x_pred: reconstructed image, same layout as ``x_true``.
    :param ratio: upsampling factor; the result is scaled by ``100 / ratio``.
    :return: the ERGAS value.
    """
    bands_true, bands_pred = img_2d_mat(x_true=x_true, x_pred=x_pred)
    n_bands, n_pixels = bands_true.shape
    total = 0
    for ref_band, est_band in zip(bands_true, bands_pred):
        # Per-band RMSE, normalised by the mean of the reference band
        # (a zero-mean reference band therefore divides by zero).
        band_rmse = np.linalg.norm(ref_band - est_band) / np.sqrt(n_pixels)
        total += np.square(band_rmse / ref_band.mean())
    return (100 / ratio) * np.sqrt(total / n_bands)
def compare_sam(x_true, x_pred):
    """
    Compute the mean spectral angle (SAM) between two hyperspectral images.

    :param x_true: reference hyperspectral image, layout (H, W, C).
    :param x_pred: reconstructed hyperspectral image, same shape as ``x_true``.
    :return: mean spectral angle, in degrees, between the original and the
        reconstructed spectra. Only pixels where neither spectrum is all-zero
        are included. Returns ``nan`` when there is no such pixel.
    """
    # float64 keeps the cosine as close to its true value as possible.
    x_true = np.asarray(x_true, dtype=np.float64)
    x_pred = np.asarray(x_pred, dtype=np.float64)
    # One row per pixel, with the (flattened) spectrum along the columns.
    n_pixels = x_true.shape[0] * x_true.shape[1]
    spec_len = int(np.prod(x_true.shape[2:]))
    spec_true = x_true.reshape(n_pixels, spec_len)
    spec_pred = x_pred.reshape(n_pixels, spec_len)

    norm_true = np.linalg.norm(spec_true, axis=1)
    norm_pred = np.linalg.norm(spec_pred, axis=1)
    # The angle is undefined for an all-zero spectrum, so such pixels are skipped.
    valid = (norm_true != 0) & (norm_pred != 0)
    if not np.any(valid):
        return np.nan

    dots = np.sum(spec_true[valid] * spec_pred[valid], axis=1)
    cos = dots / (norm_true[valid] * norm_pred[valid])
    # Rounding can push |cos| slightly above 1 (e.g. for identical spectra);
    # arccos would return NaN there, so clamp to the valid domain.
    angles = np.arccos(np.clip(cos, -1.0, 1.0))
    return float(np.degrees(np.mean(angles)))
def compare_corr(x_true, x_pred):
    """
    Calculate the cross correlation between x_pred and x_true.

    The Pearson correlation coefficient is computed for each pair of
    corresponding bands, and those coefficients are then averaged.
    CC is a spatial measure. A constant band has zero variance, so its
    coefficient is 0/0 (NaN), which propagates into the mean.
    """
    ref, est = img_2d_mat(x_true=x_true, x_pred=x_pred)
    ref_centered = ref - ref.mean(axis=1, keepdims=True)
    est_centered = est - est.mean(axis=1, keepdims=True)
    covariance = (ref_centered * est_centered).sum(axis=1)
    scale = np.sqrt((ref_centered ** 2).sum(axis=1) * (est_centered ** 2).sum(axis=1))
    return (covariance / scale).mean()
def img_2d_mat(x_true, x_pred):
    """
    Convert two 3-D multispectral images into 2-D band-by-pixel matrices.

    :param x_true: image laid out as (H, W, C).
    :param x_pred: image with the same layout as ``x_true``.
    :return: tuple ``(x_mat, y_mat)`` of float32 arrays of shape (C, H * W);
        row ``i`` holds band ``i`` flattened in row-major (H, W) order.
    """
    h, w, c = x_true.shape

    def _bands_as_rows(img):
        # Bring the band axis to the front, then flatten each band.
        return np.moveaxis(img.astype(np.float32), 2, 0).reshape(c, h * w)

    return _bands_as_rows(x_true), _bands_as_rows(x_pred)
def compare_rmse(x_true, x_pred):
    """
    Calculate the root mean squared error between two arrays.

    :param x_true: reference array, e.g. an (H, W, C) image; any rank works.
    :param x_pred: estimate with the same shape as ``x_true``.
    :return: sqrt(mean((x_true - x_pred) ** 2)), with inputs cast to float32.
    """
    x_true, x_pred = x_true.astype(np.float32), x_pred.astype(np.float32)
    diff = x_true - x_pred
    # Normalise by the total element count (H * W * C for a 3-D image) rather
    # than a hard-coded product of three dimensions, so 2-D and higher-rank
    # inputs are handled correctly too.
    return np.linalg.norm(diff) / np.sqrt(diff.size)
def compare_mpsnr(x_true, x_pred, data_range):
    """
    Compute the mean PSNR over the bands of two images.

    :param x_true: reference image; must have three dimensions, (H, W, C).
    :param x_pred: reconstructed image, same layout as ``x_true``.
    :param data_range: dynamic range of the data, forwarded to ``compare_psnr``.
    :return: mean of the per-band PSNR values.
    """
    ref, est = (img.astype(np.float32) for img in (x_true, x_pred))
    per_band = [
        compare_psnr(im_true=ref[:, :, k], im_test=est[:, :, k], data_range=data_range)
        for k in range(ref.shape[2])
    ]
    return np.mean(per_band)
def compare_mssim(x_true, x_pred, data_range, multidimension):
    """
    Compute the mean structural similarity (SSIM) between two images.

    :param x_true: reference image.
    :param x_pred: reconstructed image, same shape as ``x_true``.
    :param data_range: dynamic range of the data, forwarded to ``compare_ssim``.
    :param multidimension: if True, treat the last axis as channels: SSIM is
        computed per channel and averaged (skimage's ``multichannel`` option).
    :return: the (mean) SSIM value returned by ``compare_ssim``.
    """
    # ``multidimension`` is not a compare_ssim keyword. Depending on the skimage
    # version it was either silently ignored via **kwargs or rejected, so the
    # flag never took effect. skimage's name for this option is ``multichannel``.
    return compare_ssim(X=x_true, Y=x_pred, data_range=data_range, multichannel=multidimension)
def compare_sid(x_true, x_pred):
    """
    SID is an information theoretic measure for spectral similarity and discriminability.

    For each band, this sums the symmetric log10 divergence over all pixels,
    offsetting both images by 1e-3 inside the logarithm. It then averages the
    per-band values after dividing by the pixel count H * W.

    NOTE(review): this is a per-band formulation on raw (unnormalised)
    intensities, not the per-pixel normalised-spectrum definition of SID, so
    values may not be comparable with other SID implementations. Confirm this
    is intended.

    :param x_true: reference image, (H, W, C).
    :param x_pred: reconstructed image, same layout as ``x_true``.
    :return: mean per-band divergence per pixel.
    """
    x_true, x_pred = x_true.astype(np.float32), x_pred.astype(np.float32)
    offset = 1e-3

    def _band_divergence(ref_band, est_band):
        # Symmetric term: est * log(est / ref) + ref * log(ref / est).
        shifted_ref = ref_band + offset
        shifted_est = est_band + offset
        forward = np.sum(est_band * np.log10(shifted_est / shifted_ref))
        backward = np.sum(ref_band * np.log10(shifted_ref / shifted_est))
        return abs(forward + backward)

    per_band = np.array(
        [_band_divergence(x_true[:, :, b], x_pred[:, :, b]) for b in range(x_true.shape[2])],
        dtype=np.float64,
    )
    n_pixels = x_true.shape[1] * x_true.shape[0]
    return np.mean(per_band / n_pixels)
def compare_appsa(x_true, x_pred):
    """
    Average per-pixel spectral angle, in radians.

    :param x_true: reference image, (H, W, C); spectra lie along axis 2.
    :param x_pred: reconstructed image, same layout as ``x_true``.
    :return: sum of the per-pixel angles divided by H * W.
    """
    ref = x_true.astype(np.float32)
    est = x_pred.astype(np.float32)
    dots = (ref * est).sum(axis=2)
    norms = np.linalg.norm(ref, axis=2) * np.linalg.norm(est, axis=2)
    # The 1e-3 keeps all-zero spectra finite: they get cosine 0, i.e. angle pi/2.
    # Clamping at 1 keeps arccos inside its domain.
    cosine = np.minimum(dots / (norms + 1e-3), 1)
    angles = np.arccos(cosine)
    n_pixels = x_true.shape[1] * x_true.shape[0]
    return np.sum(angles) / n_pixels
def compare_mare(x_true, x_pred):
    """
    Mean absolute relative error, with the reference offset by +1.

    :param x_true: reference array.
    :param x_pred: estimate, same shape as ``x_true``.
    :return: mean of |x_true - x_pred| / (x_true + 1), computed in float32.
    """
    ref = x_true.astype(np.float32)
    est = x_pred.astype(np.float32)
    # The +1 offset (not a small epsilon) keeps the denominator non-zero for
    # zero-valued references; it still divides by zero where x_true == -1.
    return np.mean(np.abs(ref - est) / (ref + 1))
def quality_assessment(x_true, x_pred, data_range, ratio, multi_dimension):
    """
    Compute every quality metric for a reconstruction and collect the results.

    :param x_true: reference image, (H, W, C).
    :param x_pred: reconstructed image, same layout as ``x_true``.
    :param data_range: dynamic range forwarded to the MPSNR and MSSIM metrics.
    :param ratio: upsampling factor forwarded to ERGAS.
    :param multi_dimension: flag forwarded to ``compare_mssim`` as ``multidimension``.
    :return: dict mapping metric names (MPSNR, MSSIM, ERGAS, SAM, SID,
        CrossCorrelation, RMSE, APPSA, MARE) to their values.
    """
    images = {'x_true': x_true, 'x_pred': x_pred}
    result = {}
    result['MPSNR'] = compare_mpsnr(data_range=data_range, **images)
    result['MSSIM'] = compare_mssim(data_range=data_range, multidimension=multi_dimension, **images)
    result['ERGAS'] = compare_ergas(ratio=ratio, **images)
    # The remaining metrics need nothing beyond the two images.
    plain_metrics = (
        ('SAM', compare_sam),
        ('SID', compare_sid),
        ('CrossCorrelation', compare_corr),
        ('RMSE', compare_rmse),
        ('APPSA', compare_appsa),
        ('MARE', compare_mare),
    )
    for name, metric in plain_metrics:
        result[name] = metric(**images)
    return result
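# 版权声明:本文为qq_34535410原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
# Copyright notice: this is an original article by qq_34535410, licensed under
# CC 4.0 BY-SA; any reproduction must include a link to the original source and
# this notice.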