Spark中的逻辑回归算法中有两个参数,regParam正则化参数和elasticNetParam弹性网参数,通过设置这两个参数,可以间接确定算法中的L1和L2参数。
-
反推公式
公式比较简单,可以从源码中看到,如下val regParamL1 = $(elasticNetParam) * $(regParam) val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam)
反推公式如下:
val RegParam = L1Param + L2Param
val ElasticNetParam = L1Param / (L1Param + L2Param)
-
参数调优
当涉及到参数调优时,情况会复杂一些。因为L1参数的取值向量和L2参数的取值向量映射回RegParam,ElasticNetParam 时,向量中的元素个数会产生变化。
目标调优参数
val L1Param = Array(0.1,0.2)
val L2Param = Array(0.2,0.3)
因为在这里,参数调优直接调用ml包中的ParamMap类
以逻辑回归为例子,首先定义逻辑回归lr
val lr = new LogisticRegression()
.setFitIntercept(true)
.setStandardization(true)
.setFeaturesCol("scaledFeatures")
.setLabelCol(labelColumnName)
其次,定义调优参数
val paramGrid: Array[ParamMap] = new ParamGridBuilder()
.addGrid(lr.regParam,regParam)
.addGrid(lr.elasticNetParam,elasticNetParam)
.build()
代码如下
val regParam: Array[Double] = L1Param.flatMap{ x => L2Param.map(y=>(x+y).formatted("%.4f").toDouble)}.distinct
val elasticNetParam: Array[Double] = L1Param.flatMap{ x=> L2Param.map(y=>(x/(x+y)).formatted("%.4f").toDouble)}.distinct
regParam = Array(0.3, 0.4, 0.5)
elasticNetParam= Array(0.3333, 0.25, 0.5, 0.4)
容易发现,原来的目标是遍历4组调优参数,现在确需要遍历12组,其中有8组都是不需要去训练模型的。
备注:因为addGrid方法是目前我只会添加单一变量向量的方式定义,看源码的时候,感觉是可以定义多变量向量的方式添加,这样这篇文章的方法也就可以不用了,但是不会这么做,所以才“曲线救国”有了这篇文章。
12组参数中有8组参数是多余的,那么问题也就明朗了,我们只需要筛选中其中需要的四组参数。
那么需要做两件事情,第一件事先要有选择对象,第二件事要有筛选对象。
选择对象定义:
val paramGrid: Array[ParamMap] = new ParamGridBuilder()
.addGrid(lr.regParam,regParam)
.addGrid(lr.elasticNetParam,elasticNetParam)
.build()
筛选对象定义
val paramFilter: Array[Array[Double]] = L1Param.flatMap{ x => L2Param.map(y=>Array((x+y).formatted("%.4f").toDouble,(x/(x+y)).formatted("%.4f").toDouble))}
进行筛选
val paramG = paramGrid.filter(y=>
paramFilter.map(x=>x(0)==y.get(lr.regParam).get && x(1)==y.get(lr.elasticNetParam).get).contains(true)
)
这样我们就筛选出了我们的调优参数组合了,从L1,L2相关的数组转化为regParam和elasticNetParam相关数组。
最后一步筛选有个小插曲,使得原本简单的代码稍微复杂了一些,最后一步代码理论上可以写成下面的形式,
val paramG =paramGrid.filter(x=>paramFilter.contains(Array(x.get(lr.regParam).get,x.get(lr.elasticNetParam).get)))
但是scala中相同的Array也是不等的,例如Array(1,2)==Array(1,2)的结果是False,所以上面的代码筛选出的是参数组合是空的。