Image data augmentation can be grouped into the following categories.
- Affine transforms (random crop, random flip, random rotation, random zoom, random shear, random translation…)
- Color distortion (random gamma, random brightness, random hue, random contrast, Gaussian noise…)
- Information dropping (GridMask, Cutout, Random Erasing, Hide-and-Seek…)
- Multi-image fusion (Mixup, CutMix, FMix…)
- Others (AugMix…)
Here I have collected most of the currently popular augmentations and implemented them with TensorFlow 2. Augmentations written with the TensorFlow API can be accelerated by the GPU/TPU during training, which makes them much faster than numpy implementations running on the CPU.
Also, I have not yet gotten the TensorFlow versions of some augmentations (FMix, AugMix) to work; I will upload them once I do.
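Every augmentation below is attached to a tf.data pipeline in the same way. Here is a minimal sketch of that pattern (my own example: some_aug is a hypothetical placeholder, and the num_parallel_calls/prefetch settings are my choice rather than part of the original post):

import tensorflow as tf

@tf.function
def some_aug(img_batch, label_batch):
    # hypothetical placeholder; the real augmentations are defined in the sections below
    return tf.image.random_brightness(img_batch, 0.05), label_batch

def make_pipeline(x, y, batch_size=32):
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    ds = ds.shuffle(1024).batch(batch_size)
    # run the batch-level augmentation in parallel and overlap it with the training steps
    ds = ds.map(some_aug, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return ds.prefetch(tf.data.experimental.AUTOTUNE)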
Load the dataset for testing
import numpy as np
import tensorflow as tf
from tensorflow.keras import *
import tensorflow.keras.backend as B
import matplotlib.pyplot as plt
import math
from functools import reduce  # needed by aug_affine below to chain the per-sample transform matrices
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
x_train = np.repeat(np.expand_dims(x_train / 255.0 / 2 , -1), 3, axis = -1).astype(np.float32) # expand to 3 channels
x_test = np.repeat(np.expand_dims(x_test / 255.0 / 2, -1), 3, axis = -1).astype(np.float32) # expand to 3 channels
y_train = np.eye(10)[np.reshape(y_train, -1)].astype(np.float32) # one-hot
y_test = np.eye(10)[np.reshape(y_test, - 1)].astype(np.float32) # one-hot
print(x_train.shape, y_train.shape) # (60000, 28, 28, 3) (60000, 10)
print(x_test.shape, y_test.shape) # (10000, 28, 28, 3) (10000, 10)
def xy_visualization(x, y):
    plt.imshow(x)
    plt.axis('off')
    plt.show()
    print(y.numpy())
Color distortion (random contrast, random brightness, random hue, random saturation, Gaussian noise)
@tf.function
def random_contrast(img_batch, lower=0.9, upper=1.0):
    # (x - mean) * contrast_factor + mean
    return tf.image.random_contrast(img_batch, lower, upper)

@tf.function
def random_brightness(img_batch, delta=0.05):
    # x + delta
    return tf.image.random_brightness(img_batch, delta)

@tf.function
def random_hue(img_batch, delta=0.01):
    # convert the RGB image to HSV, add a random offset to the hue channel, then convert back to RGB
    return tf.image.random_hue(img_batch, delta)

@tf.function
def random_saturation(img_batch, lower=0.0, upper=0.01):
    # convert the RGB image to HSV, scale the saturation channel by a random factor,
    # convert back to RGB and to the original dtype
    return tf.image.random_saturation(img_batch, lower, upper)

@tf.function
def random_gaussian_noise(img_batch, noise_scale=0.01):
    gaussian_noise = noise_scale * tf.random.normal(tf.shape(img_batch), mean=0.0, stddev=1., dtype=tf.float32)
    return tf.clip_by_value(img_batch + gaussian_noise, 0., 1.)

@tf.function
def color_aug(img_batch, label_batch):
    # img_batch   [N, img_h, img_w, img_channels]
    # label_batch [N, num_classes]
    img_batch = random_contrast(img_batch)
    img_batch = random_brightness(img_batch)
    img_batch = random_hue(img_batch)
    img_batch = random_saturation(img_batch)
    img_batch = random_gaussian_noise(img_batch)
    return img_batch, label_batch
batch_size = 32
batch = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(20).repeat(1).batch(batch_size)
batch = batch.map(lambda a, b : color_aug(a, b))
x, y = batch.__iter__().__next__()
for i in range(5):
    xy_visualization(x[i], y[i])
Results:
Affine transforms (random zoom, random crop, random flip, random rotation, random shear, random shift)
For the theory behind affine transforms, see:
https://www.cnblogs.com/happystudyeveryday/p/10547316.html
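To make the matrix bookkeeping in the function below easier to follow, here is a tiny standalone sketch (my own illustration with made-up numbers, not part of the original post) of the backward-mapping idea it uses: transforms are composed by multiplying 3x3 matrices, and each destination pixel coordinate (x, y, 1) is multiplied by the composed matrix to find the source pixel it should copy from.

import math
import tensorflow as tf

theta = 10. * math.pi / 180.  # a hypothetical 10-degree rotation
rotation = tf.constant([[ math.cos(theta), math.sin(theta), 0.],
                        [-math.sin(theta), math.cos(theta), 0.],
                        [0., 0., 1.]])
shift = tf.constant([[1., 0., 3.],   # shift the sampling grid by 3 pixels on one axis
                     [0., 1., -2.],  # and by -2 pixels on the other
                     [0., 0., 1.]])
m = tf.linalg.matmul(rotation, shift)  # composing transforms = multiplying their matrices
dst = tf.constant([[5.], [7.], [1.]])  # one destination pixel in centered coordinates
src = tf.linalg.matmul(m, dst)         # the source pixel whose value lands at dst
print(src[:2, 0].numpy())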
@tf.function
def aug_affine(img_batch, label_batch, rotation=2, shear=0, zoom=0, shift=16, flip=3):
    batch_shape, batch_size, img_h, img_w = tf.shape(img_batch), tf.shape(img_batch)[0], img_batch.shape[1], img_batch.shape[2]
    # builds one 3x3 transform matrix per image that maps destination pixel indices back to source indices
    one = tf.ones([batch_size, 1], dtype='float32')   # (Batch Size, 1)
    zero = tf.zeros([batch_size, 1], dtype='float32') # (Batch Size, 1)
    m_list = []
    # ROTATION MATRIX
    if rotation != 0:
        rotation = math.pi * tf.random.uniform(shape=[batch_size, 1], maxval=rotation) / 180.  # (Batch Size, 1)
        c1 = tf.math.cos(rotation)  # (Batch Size, 1)
        s1 = tf.math.sin(rotation)  # (Batch Size, 1)
        rotation_matrix = tf.concat([c1, s1, zero, -s1, c1, zero, zero, zero, one], axis=-1)  # (Batch Size, 9)
        rotation_matrix = tf.reshape(rotation_matrix, [-1, 3, 3])  # (Batch Size, 3, 3)
        m_list.append(rotation_matrix)
    # SHEAR MATRIX
    if shear != 0:
        shear = math.pi * tf.random.uniform(shape=[batch_size, 1], maxval=shear) / 180.  # (Batch Size, 1)
        c2 = tf.math.cos(shear)  # (Batch Size, 1)
        s2 = tf.math.sin(shear)  # (Batch Size, 1)
        shear_matrix = tf.concat([one, s2, zero, zero, c2, zero, zero, zero, one], axis=-1)  # (Batch Size, 9)
        shear_matrix = tf.reshape(shear_matrix, [-1, 3, 3])  # (Batch Size, 3, 3)
        m_list.append(shear_matrix)
    # ZOOM MATRIX
    if zoom != 0:
        width_zoom = tf.random.uniform(shape=[batch_size, 1], minval=1-zoom, maxval=1+zoom)   # (Batch Size, 1)
        height_zoom = tf.random.uniform(shape=[batch_size, 1], minval=1-zoom, maxval=1+zoom)  # (Batch Size, 1)
        zoom_matrix = tf.concat([one / height_zoom, zero, zero, zero, one / width_zoom, zero, zero, zero, one], axis=-1)  # (Batch Size, 9)
        zoom_matrix = tf.reshape(zoom_matrix, [-1, 3, 3])  # (Batch Size, 3, 3)
        m_list.append(zoom_matrix)
    # SHIFT MATRIX
    if shift != 0:
        height_shift = tf.random.uniform(shape=[batch_size, 1], minval=-shift, maxval=shift)  # (Batch Size, 1)
        width_shift = tf.random.uniform(shape=[batch_size, 1], minval=-shift, maxval=shift)   # (Batch Size, 1)
        shift_matrix = tf.concat([one, zero, height_shift, zero, one, width_shift, zero, zero, one], axis=-1)
        shift_matrix = tf.reshape(shift_matrix, [-1, 3, 3])  # (Batch Size, 3, 3)
        m_list.append(shift_matrix)
    # FLIP MATRIX
    if flip != 0:
        # 1: left_right 2: up_down 3: both
        flip_y = tf.where(tf.random.uniform(shape=[batch_size, 1]) >= (0.5 if (flip == 1 or flip == 3) else 0), 1., -1.)
        flip_x = tf.where(tf.random.uniform(shape=[batch_size, 1]) >= (0.5 if (flip == 2 or flip == 3) else 0), 1., -1.)
        flip_matrix = tf.concat([flip_x, zero, zero, zero, flip_y, zero, zero, zero, one], axis=-1)
        flip_matrix = tf.reshape(flip_matrix, [-1, 3, 3])  # (Batch Size, 3, 3)
        m_list.append(flip_matrix)
    if len(m_list) > 0:
        # MERGE MATRICES in a random order
        m_list = tf.unstack(tf.random.shuffle(tf.stack(m_list, axis=0)), axis=0)  # list of (Batch Size, 3, 3)
        m = reduce(lambda x, y: B.batch_dot(x, y), m_list)
        # LIST DESTINATION PIXEL INDICES
        x, y = tf.meshgrid(tf.range(img_w//2, -img_w//2, -1), tf.range(-img_h//2, img_h//2, 1))  # (IMG_H, IMG_W)
        x, y = tf.reshape(x, [-1]), tf.reshape(y, [-1])
        z = tf.ones([img_h * img_w], tf.int32)  # (IMG_H * IMG_W)
        idx = tf.stack([x, y, z])  # (3, IMG_H * IMG_W)
        idx = tf.repeat(tf.expand_dims(idx, axis=0), batch_size, axis=0)  # (Batch Size, 3, IMG_H * IMG_W)
        # MAP DESTINATION PIXELS BACK ONTO ORIGIN PIXELS
        idx = tf.cast(B.batch_dot(m, tf.cast(idx, tf.float32)), tf.int32)  # (Batch Size, 3, IMG_H * IMG_W)
        x_idx, y_idx = idx[:, 0, :], idx[:, 1, :]
        x_idx_valid = tf.math.logical_and(-img_w//2 + img_w % 2 + 1 <= x_idx, x_idx <= img_w//2)  # (Batch Size, IMG_H * IMG_W)
        y_idx_valid = tf.math.logical_and(-img_h//2 + img_h % 2 + 1 <= y_idx, y_idx <= img_h//2)  # (Batch Size, IMG_H * IMG_W)
        idx_valid = tf.math.logical_and(x_idx_valid, y_idx_valid)  # (Batch Size, IMG_H * IMG_W)
        x_idx = tf.where(x_idx_valid, x_idx, 0)
        y_idx = tf.where(y_idx_valid, y_idx, 0)
        # FIND ORIGIN PIXEL VALUES
        idx = tf.stack([img_h//2 - 1 + y_idx, img_w//2 - x_idx], axis=-1)  # (Batch Size, IMG_H * IMG_W, 2)
        img_batch = tf.gather_nd(img_batch, idx, batch_dims=1)  # (Batch Size, IMG_H * IMG_W, 3)
        img_batch = tf.where(tf.expand_dims(idx_valid, axis=-1), img_batch, 0)
        img_batch = tf.reshape(img_batch, batch_shape)
    return img_batch, label_batch
batch_size = 32
batch = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(20).repeat(1).batch(batch_size)
batch = batch.map(lambda a, b : aug_affine(a, b))
x, y = batch.__iter__().__next__()
for i in range(8):
    xy_visualization(x[i], y[i])
Results:
Occlusion (GridMask)
GridMask drops information by zeroing out a regular grid of square regions; the grid period and offset are randomized per image.
@tf.function
def aug_grid_mask(img_batch, label_batch, d1=4, d2=10, ratio=0.5):
    batch_size, img_h, img_w = tf.shape(img_batch)[0], img_batch.shape[1], img_batch.shape[2]
    d = tf.random.uniform(shape=[batch_size, 1], minval=d1, maxval=d2, dtype=tf.int32)  # (Batch Size, 1) grid period per image
    d_f32 = tf.cast(d, tf.float32)  # (Batch Size, 1)
    y_range = tf.reshape(tf.range(img_h), (1, -1)) + tf.cast(tf.random.uniform(shape=[batch_size, 1], dtype=tf.float32) * d_f32 - 1, tf.int32)  # (Batch Size, IMG_H)
    x_range = tf.reshape(tf.range(img_w), (1, -1)) + tf.cast(tf.random.uniform(shape=[batch_size, 1], dtype=tf.float32) * d_f32 - 1, tf.int32)  # (Batch Size, IMG_W)
    y_range = tf.expand_dims(tf.cast(y_range % d, tf.float32) / d_f32 >= ratio, axis=2)  # (Batch Size, IMG_H, 1)
    x_range = tf.expand_dims(tf.cast(x_range % d, tf.float32) / d_f32 >= ratio, axis=1)  # (Batch Size, 1, IMG_W)
    mask = tf.expand_dims(tf.math.logical_and(y_range, x_range) == False, axis=-1)  # (Batch Size, IMG_H, IMG_W, 1)
    img_batch = tf.where(mask, img_batch, 0)
    return img_batch, label_batch
batch_size = 32
batch = tf.data.Dataset.from_tensor_slices((1.0 - x_train, y_train)).repeat(1).batch(batch_size)
batch = batch.map(lambda a, b : aug_grid_mask(a, b))
x, y = batch.__iter__().__next__()
for i in range(5):
    xy_visualization(x[i], y[i])
Results:
Multi-image fusion (CutMix)
CutMix pastes a random rectangular region from another image in the batch and mixes the two labels in proportion to the pasted area.
@tf.function
def cutmix(img_batch, label_batch):
    batch_size, img_h, img_w = tf.shape(img_batch)[0], img_batch.shape[1], img_batch.shape[2]
    # CHOOSE RANDOM LOCATION
    cut_xs = tf.cast(tf.random.uniform([batch_size], 0, tf.cast(img_w, tf.float32)), tf.int32)  # (Batch Size)
    cut_ys = tf.cast(tf.random.uniform([batch_size], 0, tf.cast(img_h, tf.float32)), tf.int32)  # (Batch Size)
    cut_ratios = tf.math.sqrt(1 - tf.random.uniform([batch_size], 0, 1))  # box side ratio = sqrt(1 - lambda), lambda ~ Uniform(0, 1)
    cut_ws = tf.cast(tf.cast(img_w, tf.float32) * cut_ratios, tf.int32)
    cut_hs = tf.cast(tf.cast(img_h, tf.float32) * cut_ratios, tf.int32)
    yas = tf.math.maximum(0, cut_ys - cut_hs // 2)      # (Batch Size)
    ybs = tf.math.minimum(img_h, cut_ys + cut_hs // 2)  # (Batch Size)
    xas = tf.math.maximum(0, cut_xs - cut_ws // 2)      # (Batch Size)
    xbs = tf.math.minimum(img_w, cut_xs + cut_ws // 2)  # (Batch Size)
    # CHOOSE RANDOM IMAGE TO CUTMIX WITH
    index = tf.random.shuffle(tf.range(batch_size, dtype=tf.int32))
    x1, x2 = img_batch, tf.gather(img_batch, index)
    y1, y2 = label_batch, tf.gather(label_batch, index)
    X, Y = tf.meshgrid(tf.range(img_w), tf.range(img_h))
    X = tf.expand_dims(X, axis=0)  # (1, img_h, img_w)
    Y = tf.expand_dims(Y, axis=0)  # (1, img_h, img_w)
    img_weight = tf.math.logical_and(tf.math.logical_and(tf.reshape(xas, (-1, 1, 1)) <= X, X <= tf.reshape(xbs, (-1, 1, 1))),
                                     tf.math.logical_and(tf.reshape(yas, (-1, 1, 1)) <= Y, Y <= tf.reshape(ybs, (-1, 1, 1))))
    img_weight = tf.expand_dims(img_weight, axis=-1)  # (Batch Size, img_h, img_w, 1)
    img_batch = tf.where(img_weight, x2, x1)
    label_weight = tf.cast((ybs - yas) * (xbs - xas) / (img_h * img_w), tf.float32)  # (Batch Size) fraction of the image that was replaced
    label_weight = tf.expand_dims(label_weight, axis=-1)  # (Batch Size, 1)
    label_batch = label_weight * y2 + (1 - label_weight) * y1
    return img_batch, label_batch
batch_size = 32
batch = tf.data.Dataset.from_tensor_slices((x_train, y_train)).repeat(1).batch(batch_size)
batch = batch.map(lambda a, b : cutmix(a, b))
x, y = batch.__iter__().__next__()
for i in range(5):
    xy_visualization(x[i], y[i])
Results:
Multi-image fusion (Mixup)
Mixup blends each image with another randomly chosen image from the batch using a random convex weight, and blends the labels with the same weight.
@tf.function
def aug_mixup(img_batch, label_batch):
    # img_batch   [N, img_h, img_w, img_channels]
    # label_batch [N, num_classes]
    batch_size = tf.shape(img_batch)[0]
    weight = tf.math.sqrt(1 - tf.random.uniform([batch_size]))  # equivalent to sampling from Beta(2, 1)
    x_weight = tf.reshape(weight, [batch_size, 1, 1, 1])
    y_weight = tf.reshape(weight, [batch_size, 1])
    index = tf.random.shuffle(tf.range(batch_size, dtype=tf.int32))
    x1, x2 = img_batch, tf.gather(img_batch, index)
    img_batch = x1 * x_weight + x2 * (1. - x_weight)
    y1, y2 = label_batch, tf.gather(label_batch, index)
    label_batch = y1 * y_weight + y2 * (1. - y_weight)
    return img_batch, label_batch
batch = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(20).repeat(1).batch(512)
batch = batch.map(lambda a, b : aug_mixup(a, b))
x, y = batch.__iter__().__next__()
xy_visualization(x[0], y[0])
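Note that weight = sqrt(1 - U) with U uniform is equivalent to drawing from Beta(2, 1). If you prefer the Beta(alpha, alpha) weighting used in the Mixup paper, one possible substitute (my own sketch, not part of the original code; alpha = 0.2 is just a typical small value) builds it from two Gamma draws:

@tf.function
def sample_beta(batch_size, alpha=0.2):
    # Beta(alpha, alpha) = G1 / (G1 + G2) with G1, G2 ~ Gamma(alpha, 1)
    g1 = tf.random.gamma([batch_size], alpha)
    g2 = tf.random.gamma([batch_size], alpha)
    return g1 / (g1 + g2)

Swapping it in only changes one line of aug_mixup: weight = sample_beta(batch_size).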
Multi-image fusion (FMix)
Todo
Others (AugMix)
Todo