Tensorflow笔记之自定义损失函数

  • Post author:
  • Post category:其他


Tensorflow不仅支持经典损失函数,还可以优化任意的自定义损失函数。下面以预测产品销量为例。

在预测产品销量时,如果预测多了,商家损失的是生产商品的成本,如果预测少了,损失的是商品的利润。前面所讲到的均方误差损失函数不能很好的最大化销售利润。下面的公式给出了一个当预测多于真实值和预测少于真实值时有不同损失系数的损失函数:

其中yi为一个batch中第i个数据的正确答案,yi‘为神经网络得到的预测值,a和b为常量,通过以下代码实现这个损失函数:

loss=tf.reduce_sum(tf.where(tf.greater(v1,v2),(v1-v2)*a,(v2-v1)*b))

上面代码用到了tf.greater和tf.where来实现选择操作。tf.greater的输入是两个张量,此函数会比较这两个输入张量中每一个元素的大小,并返回比较结果。当输入的张量维度不一样时,Tensorflow会进行类似NumPy广播操作的处理。tf.where函数有三个参数。第一个为选择条件根据,当选择条件为True时,tf.where函数会选择第二个参数的值,否则会使用第三个参数中的值。tf.where函数判断和选择都是在元素级别进行,以下代码展示了tf.where函数和tf.greater函数的用法。

import tensorflow as tf
v1=tf.constant([1.0,2.0,3.0,4.0])
v2=tf.constant([4.0,3.0,2.0,1.0])
sess=tf.InteractiveSession()
print tf.greater(v1,v2).eval()
#输出为[False False True True]
print tf.where(tf.greater(v1,v2),v1,v2).eval()
#输出[4. 3. 3. 4.]
sess.close()

下面的代码通过一个简单的神经网络程序来看损失函数对训练结果的影响。代码如下:

import tensorflow as tf
from numpy.random import RandomState
batch_size=8
#两个输入点
x=tf.placeholder(tf.float32,shape=(None,2),name='x-input')
#回归问题一般只有一个输出节点。
y_=tf.placeholder(tf.float32,shape=(None,1),name='y-input')
#定义一个单层的神经网络前向传播过程
w1=tf.Variable(tf.random_normal([2,1],stddev=1,seed=1))
y=tf.matmul(x,w1)
#定义预测多了和预测少了的成本
loss_less=10
loss_more=1
loss=tf.reduce_sum(tf.where(tf.greater(y,y_),(y-y_)*loss_more,(y_-y)*loss_less))
train_step=tf.train.AdamOptimizer(0.001).minimize(loss)
#通过随机数生成一个模拟数据集
rdm=RandomState(1)
dataset_size=128
X=rdm.rand(dataset_size,2)
#设置回归的正确值为两个输入的和加上一个随机变量。之所以要加上一个随机变量
#是为了不可预测的噪音,否则不同损失函数的意义就不大了,因为不同损失函数都会在能
#完全预测正确的时候最低。一般来说噪音为一个均值为0的小量,所以这里的噪音设置为
#-0.05-0.05的随机数
Y=[[x1+x2+rdm.rand()/10.0-0.05]for (x1,x2)in X]
#训练神经网络
with tf.Session() as sess:
    init_op=tf.global_variables_initializer()
    sess.run(init_op)
    STEPS=5000
    for i in range(STEPS):
        start=(i*batch_size)%dataset_size
        end=min(start+batch_size,dataset_size)
        sess.run(train_step,
                 feed_dict={x:X[start,end],y_:Y[start,end]})
        print(sess.run(w1))



版权声明:本文为wht18720080085原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。