A roundup of image data augmentation methods (affine transforms, color jitter, occlusion, multi-image fusion, ...) + TensorFlow 2 implementations


Image data augmentation can be viewed from the following angles:

  1. Affine transforms (Random crop, Random flip, Random rotation, Random zoom, Random shear, Random translation…)
  2. Color distortion (Random gamma, Random brightness, Random hue, Random contrast, Gaussian noise…)
  3. Information dropping (GridMask, Cutout, Random Erasing, Hide-and-Seek…)
  4. Multi-image fusion (Mixup, CutMix, FMix…)
  5. Others (AugMix…)

Here I have collected most of the currently popular augmentation methods and implemented them with TensorFlow 2. Augmentations written with the TensorFlow API can be accelerated on GPU/TPU during training, which is much faster than NumPy implementations running on the CPU.
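As a minimal usage sketch (not from the original post): every augmentation below is written as a @tf.function that maps an (images, labels) batch to an augmented (images, labels) batch, so it can be dropped straight into a tf.data input pipeline. Here color_aug refers to the function defined later in this article, x_train / y_train are loaded in the next section, and num_parallel_calls / prefetch are standard tf.data options I am assuming for performance:

import tensorflow as tf

# Sketch: plug a batch-wise augmentation function into the input pipeline.
# AUTOTUNE parallelizes the map; prefetch overlaps augmentation with training.
train_ds = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
            .shuffle(1024)
            .batch(32)
            .map(color_aug, num_parallel_calls=tf.data.AUTOTUNE)
            .prefetch(tf.data.AUTOTUNE))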

In addition, I have not yet gotten my TensorFlow versions of some augmentations (FMix, AugMix) to run; I will upload them once they work.


Loading a dataset for testing

import numpy as np
import tensorflow as tf
from tensorflow.keras import *
import tensorflow.keras.backend as B
import matplotlib.pyplot as plt
import math
from functools import reduce  # needed by aug_affine below

(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
x_train = np.repeat(np.expand_dims(x_train / 255.0 / 2 , -1), 3, axis = -1).astype(np.float32) # scale to [0, 0.5] and expand to 3 channels
x_test = np.repeat(np.expand_dims(x_test / 255.0 / 2, -1), 3, axis = -1).astype(np.float32) # scale to [0, 0.5] and expand to 3 channels
y_train = np.eye(10)[np.reshape(y_train, -1)].astype(np.float32) # one-hot
y_test = np.eye(10)[np.reshape(y_test, - 1)].astype(np.float32) # one-hot

print(x_train.shape, y_train.shape) # (60000, 28, 28, 3) (60000, 10)
print(x_test.shape, y_test.shape) 	# (10000, 28, 28, 3) (10000, 10)

def xy_visualization(x, y):  
    plt.imshow(x)
    plt.axis('off')
    plt.show()
    print(y.numpy())


Color distortion (contrast, brightness, hue, saturation, Gaussian noise)

@tf.function
def random_contrast(img_batch, lower = 0.9, upper=1.0):
    # (x - mean) * contrast_factor + mean.
    return tf.image.random_contrast(img_batch, lower, upper)

@tf.function
def random_brightness(img_batch, delta = 0.05):
    # x + delta
    return tf.image.random_brightness(img_batch, delta)

@tf.function
def random_hue(img_batch, delta = 0.01):
    # convert the RGB image to float, convert it to HSV, add a random offset to the hue channel, then convert back to RGB
    return tf.image.random_hue(img_batch, delta)

@tf.function
def random_saturation(img_batch, lower = 0.0, upper=0.01):
    # convert the RGB image to HSV, scale the saturation channel by a random factor in [lower, upper], convert back to RGB and the original dtype
    return tf.image.random_saturation(img_batch, lower, upper)

@tf.function
def random_gaussian_noise(img_batch, noise_scale = 0.01):
    gaussian_noise = (noise_scale * tf.random.normal(tf.shape(img_batch), mean=0.0, stddev=1., dtype=tf.float32))
    return tf.clip_by_value(img_batch + gaussian_noise, 0., 1.)
    
@tf.function
def color_aug(img_batch, label_batch):
    # img_batch [N, image_h, img_w, img_channels]
    # label_batch [N, num_classes]
    img_batch = random_contrast(img_batch)
    img_batch = random_brightness(img_batch)
    img_batch = random_hue(img_batch)
    img_batch = random_saturation(img_batch)
    img_batch = random_gaussian_noise(img_batch)
    return img_batch, label_batch

batch_size = 32
batch = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(20).repeat(1).batch(batch_size)
batch = batch.map(lambda a, b : color_aug(a, b))
    
x, y = batch.__iter__().__next__()
for i in range(5):
    xy_visualization(x[i], y[i])

Result: [image]


Affine transforms (random zoom, random crop, random flip, random rotation, random shear, random shift)


For the theory behind affine transforms, see:

https://www.cnblogs.com/happystudyeveryday/p/10547316.html
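
For reference (my own notation, summarizing what the code below builds): each transform is a per-sample 3x3 matrix in homogeneous coordinates. The matrices are composed in a random order and then applied to the destination pixel indices (an inverse mapping, which is why the zoom factors appear as reciprocals):

$$
R=\begin{bmatrix}\cos\theta & \sin\theta & 0\\ -\sin\theta & \cos\theta & 0\\ 0 & 0 & 1\end{bmatrix},\quad
S=\begin{bmatrix}1 & \sin\phi & 0\\ 0 & \cos\phi & 0\\ 0 & 0 & 1\end{bmatrix},\quad
Z=\begin{bmatrix}1/z_h & 0 & 0\\ 0 & 1/z_w & 0\\ 0 & 0 & 1\end{bmatrix},\quad
T=\begin{bmatrix}1 & 0 & t_h\\ 0 & 1 & t_w\\ 0 & 0 & 1\end{bmatrix},\quad
F=\begin{bmatrix}f_x & 0 & 0\\ 0 & f_y & 0\\ 0 & 0 & 1\end{bmatrix}
$$

where θ is the rotation angle, φ the shear angle, (z_h, z_w) the zoom factors, (t_h, t_w) the shift in pixels, and f_x, f_y ∈ {−1, +1} control flipping.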

@tf.function
def aug_affine(img_batch, label_batch, rotation=2, shear=0, zoom=0, shift=16, flip=3):
    batch_shape, batch_size, img_h, img_w = tf.shape(img_batch), tf.shape(img_batch)[0], img_batch.shape[1], img_batch.shape[2]
    # returns a 3x3 transform matrix which transforms indices
    # CONVERT
    one = tf.ones([batch_size, 1], dtype='float32') # (Batch Size, 1)
    zero = tf.zeros([batch_size, 1], dtype='float32') # (Batch Size, 1)
    m_list = []
    # ROTATION MATRIX
    if rotation != 0:
        rotation = math.pi * tf.random.uniform(shape=[batch_size, 1], maxval=rotation) / 180. # (Batch Size, 1)
        c1 = tf.math.cos(rotation) # (Batch Size, 1)
        s1 = tf.math.sin(rotation) # (Batch Size, 1)
        rotation_matrix = tf.concat([c1, s1, zero, -s1, c1, zero, zero, zero, one], axis=-1) # (Batch Size, 9)
        rotation_matrix = tf.reshape(rotation_matrix, [-1, 3, 3]) # (Batch Size, 3, 3)
        #m = rotation_matrix if (m is None) else B.batch_dot(m, rotation_matrix)
        m_list.append(rotation_matrix)
    # SHEAR MATRIX
    if shear != 0:
        shear = math.pi * tf.random.uniform(shape=[batch_size, 1], maxval=shear) / 180.       # (Batch Size, 1)
        c2 = tf.math.cos(shear) # (Batch Size, 1)
        s2 = tf.math.sin(shear) # (Batch Size, 1)
        shear_matrix = tf.concat([one, s2, zero, zero, c2, zero, zero, zero, one], axis=-1) # (Batch Size, 9)
        shear_matrix = tf.reshape(shear_matrix, [-1, 3, 3]) # (Batch Size, 3, 3)
        #m = shear_matrix if (m is None) else B.batch_dot(m, shear_matrix)
        m_list.append(shear_matrix)
    # ZOOM MATRIX
    if zoom != 0:
        width_zoom = tf.random.uniform(shape=[batch_size, 1], minval=1-zoom, maxval=1+zoom) # (Batch Size, 1)
        height_zoom = tf.random.uniform(shape=[batch_size, 1], minval=1-zoom, maxval=1+zoom) # (Batch Size, 1)
        zoom_matrix = tf.concat([one / height_zoom, zero, zero, zero, one / width_zoom, zero, zero, zero, one], axis=-1) # (Batch Size, 9)
        zoom_matrix = tf.reshape(zoom_matrix, [-1, 3, 3]) # (Batch Size, 3, 3)
        #m = zoom_matrix if (m is None) else B.batch_dot(m, zoom_matrix)
        m_list.append(zoom_matrix)
    # SHIFT MATRIX
    if shift != 0:
        height_shift = tf.random.uniform(shape=[batch_size, 1], minval=-shift, maxval=shift) # (Batch Size, 1)
        width_shift = tf.random.uniform(shape=[batch_size, 1], minval=-shift, maxval=shift) # (Batch Size, 1)
        shift_matrix = tf.concat([one, zero, height_shift, zero, one, width_shift, zero, zero, one], axis=-1)
        shift_matrix = tf.reshape(shift_matrix, [-1, 3, 3]) # (Batch Size, 3, 3)
        #m = shift_matrix if (m is None) else B.batch_dot(m, shift_matrix)
        m_list.append(shift_matrix)
    # FLIP MATRIX
    if flip != 0:
        # 1: left_right 2: up_down 3: both
        flip_y = tf.where(tf.random.uniform(shape=[batch_size, 1]) >= (0.5 if (flip == 1 or flip == 3) else 0), 1., -1.)
        flip_x = tf.where(tf.random.uniform(shape=[batch_size, 1]) >= (0.5 if (flip == 2 or flip == 3) else 0), 1., -1.)
        flip_matrix = tf.concat([flip_x, zero, zero, zero, flip_y, zero, zero, zero, one], axis=-1)
        flip_matrix = tf.reshape(flip_matrix, [-1, 3, 3]) # (Batch Size, 3, 3)
        #m = flip_matrix if (m is None) else B.batch_dot(m, flip_matrix)
        m_list.append(flip_matrix)
    if len(m_list) > 0:
        # MERGE MATRIX: shuffle the order of the transforms, then compose them into one matrix per sample
        m_list = tf.unstack(tf.random.shuffle(tf.stack(m_list, axis=0)), axis=0) # List of (Batch Size, 3, 3)
        m = reduce(lambda x, y: B.batch_dot(x, y), m_list)
        # LIST DESTINATION PIXEL INDICES
        
        x, y = tf.meshgrid(tf.range(img_w//2, -img_w//2, -1), tf.range(-img_h//2, img_h//2, 1)) # (Img_h, Img_w)
        x, y = tf.reshape(x, [-1]), tf.reshape(y, [-1])

        z = tf.ones([img_h * img_w], tf.int32) # (IMG_H * IMG_W)
        idx = tf.stack([x, y, z])              # (3, IMG_H * IMG_W)
        idx = tf.repeat(tf.expand_dims(idx, axis=0), batch_size, axis=0)  # (Batch Size, 3, IMG_H * IMG_W)
        # ROTATE DESTINATION PIXELS ONTO ORIGIN PIXELS
        idx = tf.cast(B.batch_dot(m, tf.cast(idx, tf.float32)), tf.int32) # (Batch Size, 3, IMG_H * IMG_W)
        x_idx, y_idx = idx[:,0,:], idx[:,1,:]
        x_idx_valid = tf.math.logical_and(-img_w//2 + img_w%2 + 1 <= x_idx, x_idx <= img_w//2) # (Batch Size, IMG_H * IMG_W)
        y_idx_valid = tf.math.logical_and(-img_h//2 + img_h%2 + 1 <= y_idx, y_idx <= img_h//2) # (Batch Size, IMG_H * IMG_W)
        idx_valid = tf.math.logical_and(x_idx_valid, y_idx_valid) # (Batch Size, IMG_H * IMG_W)
        x_idx = tf.where(x_idx_valid, x_idx, 0)
        y_idx = tf.where(y_idx_valid, y_idx, 0)
        # FIND ORIGIN PIXEL VALUES           
        idx = tf.stack([img_h//2-1+y_idx, img_w//2-x_idx],axis=-1) # (Batch Size, IMG_H * IMG_W, 2)
        img_batch = tf.gather_nd(img_batch, idx, batch_dims=1) # (Batch Size, IMG_H * IMG_W, 3)
        img_batch = tf.where(tf.expand_dims(idx_valid, axis=-1), img_batch, 0)
        img_batch = tf.reshape(img_batch, batch_shape)
    return img_batch, label_batch

batch_size = 32
batch = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(20).repeat(1).batch(batch_size)
batch = batch.map(lambda a, b : aug_affine(a, b))
    
x, y = batch.__iter__().__next__()
for i in range(8):
    xy_visualization(x[i], y[i])

Result: [image]


Occlusion (GridMask)
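
A quick summary of the implementation below (my own notation): for each image a grid period d is sampled uniformly from [d1, d2) together with a random phase offset (δy, δx), and the pixel at position (y, x) is zeroed exactly when

((y + δy) mod d) / d ≥ ratio  and  ((x + δx) mod d) / d ≥ ratio

so with ratio = 0.5 a regular grid of dark squares of side roughly d/2 and period d is cut out of the image, while the label is left unchanged.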

@tf.function
def aug_grid_mask(img_batch, label_batch, d1=4, d2=10, ratio=0.5):
    batch_size, img_h, img_w = tf.shape(img_batch)[0], img_batch.shape[1], img_batch.shape[2]
    d = tf.random.uniform(shape=[batch_size, 1], minval=d1, maxval=d2, dtype=tf.int32) # (Batch Size, 1)
    d_f32 = tf.cast(d, tf.float32) # (Batch Size, 1)
    y_range = tf.reshape(tf.range(img_h), (1, -1)) + tf.cast(tf.random.uniform(shape=[batch_size, 1], dtype=tf.float32) * d_f32 - 1, tf.int32)# (Batch Size, IMG_H)
    x_range = tf.reshape(tf.range(img_w), (1, -1)) + tf.cast(tf.random.uniform(shape=[batch_size, 1], dtype=tf.float32) * d_f32 - 1, tf.int32) # (Batch Size, IMG_W)
    y_range = tf.expand_dims(tf.cast(y_range % d, tf.float32) / d_f32 >= ratio, axis=2) # (Batch Size, IMG_H, 1)
    x_range = tf.expand_dims(tf.cast(x_range % d, tf.float32) / d_f32 >= ratio, axis=1) # (Batch Size, 1, IMG_W)
    mask = tf.expand_dims(tf.math.logical_not(tf.math.logical_and(y_range, x_range)), axis=-1) # (Batch Size, IMG_H, IMG_W, 1), True = keep pixel
    img_batch =  tf.where(mask, img_batch, 0)
    return img_batch, label_batch


batch_size = 32
batch = tf.data.Dataset.from_tensor_slices((1.0 - x_train, y_train)).repeat(1).batch(batch_size) # invert the images so the zeroed grid cells stand out
batch = batch.map(lambda a, b : aug_grid_mask(a, b))

x, y = batch.__iter__().__next__()
for i in range(5):
    xy_visualization(x[i], y[i])

Result: [image]


Multi-image fusion (CutMix)
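
In CutMix, a rectangular patch from a second image in the batch is pasted over the first, and the two labels are mixed in proportion to the pasted area. In the implementation below the mixing weight (my notation) is

λ = (xb − xa)(yb − ya) / (H · W),    new label = λ · y2 + (1 − λ) · y1

where (xa, ya, xb, yb) is the pasted box after clipping to the image and (H, W) is the image size.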

@tf.function
def cutmix(img_batch, label_batch):
    batch_size, img_h, img_w = tf.shape(img_batch)[0], img_batch.shape[1], img_batch.shape[2]
    # CHOOSE RANDOM LOCATION
    cut_xs = tf.cast(tf.random.uniform([batch_size], 0, tf.cast(img_w, tf.float32)), tf.int32) # (Batch Size)
    cut_ys = tf.cast(tf.random.uniform([batch_size], 0, tf.cast(img_h, tf.float32)), tf.int32) # (Batch Size)
    cut_ratios = tf.math.sqrt(1 - tf.random.uniform([batch_size], 0, 1)) # cut ratio
    cut_ws = tf.cast(tf.cast(img_w, tf.float32) * cut_ratios, tf.int32)
    cut_hs = tf.cast(tf.cast(img_h, tf.float32) * cut_ratios, tf.int32)
    yas = tf.math.maximum(0, cut_ys - cut_hs // 2)      # (Batch Size)
    ybs = tf.math.minimum(img_h, cut_ys + cut_hs // 2)  # (Batch Size)
    xas = tf.math.maximum(0, cut_xs - cut_ws // 2)      # (Batch Size)
    xbs = tf.math.minimum(img_w, cut_xs + cut_ws // 2)  # (Batch Size)
    # CHOOSE RANDOM IMAGE TO CUTMIX WITH
    index = tf.random.shuffle(tf.range(batch_size, dtype=tf.int32))
    x1, x2 = img_batch, tf.gather(img_batch, index)
    y1, y2 = label_batch, tf.gather(label_batch, index)
    
    X, Y = tf.meshgrid(tf.range(img_w), tf.range(img_h))
    X = tf.expand_dims(X, axis=0) # (1, img_h, img_w)
    Y = tf.expand_dims(Y, axis=0) # (1, img_h, img_w)
    img_weight = tf.math.logical_and(tf.math.logical_and(tf.reshape(xas, (-1, 1, 1)) <= X, X <= tf.reshape(xbs, (-1, 1, 1)) ), 
                                     tf.math.logical_and(tf.reshape(yas, (-1, 1, 1)) <= Y, Y <= tf.reshape(ybs, (-1, 1, 1)) ))
    img_weight = tf.expand_dims(img_weight, axis=-1) # (Batch Size, img_h, img_w, 1)
    img_batch = tf.where(img_weight, x2, x1)

    label_weight = tf.cast((ybs - yas) * (xbs - xas) / (img_h * img_w), tf.float32) # (Batch Size)
    label_weight = tf.expand_dims(label_weight, axis=-1) #(Batch Size, 1)
    label_batch = (label_weight) * y2 + (1 - label_weight) * y1
        
    return img_batch, label_batch

batch_size = 32
batch = tf.data.Dataset.from_tensor_slices((x_train, y_train)).repeat(1).batch(batch_size)
batch = batch.map(lambda a, b : cutmix(a, b))
    
x, y = batch.__iter__().__next__()
for i in range(5): 
    xy_visualization(x[i], y[i])

Result: [image]


Multi-image fusion (Mixup)

@tf.function
def aug_mixup(img_batch, label_batch):
    # img_batch [N, image_h, img_w, img_channels]
    # label_batch [N, num_classes]
    batch_size = tf.shape(img_batch)[0]
    weight = tf.math.sqrt(1 - tf.random.uniform([batch_size])) # sqrt(1 - U) with U ~ Uniform(0, 1), i.e. a Beta(2, 1) sample
    x_weight = tf.reshape(weight, [batch_size, 1, 1, 1])
    y_weight = tf.reshape(weight, [batch_size, 1])
    index = tf.random.shuffle(tf.range(batch_size, dtype=tf.int32))
    x1, x2 = img_batch, tf.gather(img_batch, index)
    img_batch = x1 * x_weight + x2 * (1. - x_weight)
    y1, y2 = label_batch, tf.gather(label_batch, index)
    label_batch = y1 * y_weight + y2 * (1. - y_weight)
    return img_batch, label_batch
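
The weight above is drawn as sqrt(1 - U) with U uniform, i.e. from Beta(2, 1). If you prefer the Beta(α, α) mixing weight used in the original Mixup paper, a sketch like the following could replace the weight line inside aug_mixup (batch_size as defined there); the alpha value and the gamma-ratio trick are my additions, not part of the original post:

import tensorflow as tf

# Sketch: lam ~ Beta(alpha, alpha), using the fact that G1 / (G1 + G2) is
# Beta(a, b) distributed when G1 ~ Gamma(a, 1) and G2 ~ Gamma(b, 1).
alpha = 0.2
g1 = tf.random.gamma([batch_size], alpha)
g2 = tf.random.gamma([batch_size], alpha)
weight = g1 / (g1 + g2)  # shape (batch_size,), values in (0, 1)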

batch = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(20).repeat(1).batch(512)
batch = batch.map(lambda a, b : aug_mixup(a, b))
    
x, y = batch.__iter__().__next__()
xy_visualization(x[0], y[0])



Multi-image fusion (FMix)



Todo


Others (AugMix)



Todo



Copyright notice: this is an original article by weixin_39576203, released under the CC 4.0 BY-SA license. Please include a link to the original source and this notice when reposting.