Principles of linear regression:
The linear regression formula is y = b + w*x, where w is the weight and b is the bias.
In an actual implementation the formula can be written as y = w[0] * x[0] + w[1] * x[1] with x[0] = 1, which folds the bias into the weights and makes parameter solving convenient. With a small extension, y = w[0] * x[0] + w[1] * x[1] + … + w[n]*x[n], it becomes multiple regression.
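For example, the one-variable model y = 3 + 2x becomes y = w[0]*x[0] + w[1]*x[1] with x[0] = 1, w[0] = 3 and w[1] = 2; folding the bias into the weights this way lets the same update rule handle every parameter.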
The parameters are optimized with gradient descent over many iterations. Gradient descent computes the gradient of the error with respect to the parameters, in the following steps (a worked one-step example follows the list):
1. Compute the predicted value from the current parameters and the training data:
preY = sum(w[n] * x[n])
2. Compute the gradient:
weights_gradient[n] = sum(2 * (preY - y) * x[n] / N), where N is the total number of training rows
3. Update the parameters:
weights[n] = weights[n] - a * weights_gradient[n], where a is the learning rate, typically a small value in (0, 1), chosen according to the training data and how training behaves.
4. Iterate.
Each iteration computes the gradient over the entire training set once and updates the parameters once; repeated iterations drive the error toward its minimum.
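To make the steps concrete, here is a minimal single-update sketch on a made-up two-row dataset (one feature; the numbers are purely illustrative):

// toy data: rows of (x1, y), generated from y = 1 + 2*x1
double[][] data = { {1.0, 3.0}, {2.0, 5.0} };
double[] w = {0.0, 0.0};   // w[0] = bias, w[1] = weight of x1
double a = 0.1;            // learning rate
int N = data.length;
double[] gradient = new double[2];
for (double[] row : data) {
    double preY = w[0] * 1.0 + w[1] * row[0];        // step 1, with x[0] = 1
    gradient[0] += 2 * (preY - row[1]) * 1.0 / N;    // step 2, bias component
    gradient[1] += 2 * (preY - row[1]) * row[0] / N; // step 2, slope component
}
for (int n = 0; n < w.length; n++) {
    w[n] = w[n] - a * gradient[n];                   // step 3
}
// with these numbers: gradient = {-8.0, -13.0}, so w becomes {0.8, 1.3}

Repeating this update (step 4) drives w toward the true parameters {1, 2}.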
Implementation of linear regression (in Java; the same code handles both simple and multiple regression):
1. Read the data. It is stored as CSV: every column except the last is an x value, and the last column is y.
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;

public List<double[]> readTrainFile(String filepath) {
    File trainFile = new File(filepath);
    List<double[]> resultList = new ArrayList<double[]>();
    if (trainFile.exists()) {
        // try-with-resources closes the reader even if a line fails to parse
        try (BufferedReader reader = new BufferedReader(new FileReader(trainFile))) {
            String line;
            while ((line = reader.readLine()) != null) {
                String[] strs = line.split(",");
                double[] lines = new double[strs.length];
                for (int i = 0; i < strs.length; i++) {
                    lines[i] = Double.parseDouble(strs[i]);
                }
                resultList.add(lines);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
    return resultList;
}
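For reference, a file in the expected format for two features might look like this (the values are made up for illustration; they happen to follow y = 1 + 2*x1 + 2.5*x2):
1.0,2.0,8.0
2.0,0.5,6.25
3.0,1.5,10.75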
2. Training: set the learning rate and the number of iterations; the method returns the parameter array.
public double[] train(String filepath, double learningRate, int iterationNum) {
    List<double[]> trainData = readTrainFile(filepath);
    double[] weights = new double[trainData.get(0).length];
    for (int i = 0; i < weights.length; i++) {
        weights[i] = 0;
    }
    weights = updateWeights(trainData, weights, learningRate, iterationNum);
    return weights;
}
3. Compute the weight parameters. Each iteration makes one full pass over the dataset, computes the gradient by gradient descent, and updates the weights by learning rate * gradient.
public double[] updateWeights(List<double[]> trainData, double[] weights, double learningRate, int iterationNum) {
    for (int i = 0; i < iterationNum; i++) {
        double[] weights_gradient = new double[weights.length];
        for (int j = 0; j < trainData.size(); j++) {
            double[] line = trainData.get(j);
            // build the feature vector: x[0] = 1 for the bias, then the data columns
            double[] x = new double[line.length];
            x[0] = 1;
            double y = line[line.length - 1];
            for (int n = 1; n < x.length; n++) {
                x[n] = line[n - 1];
            }
            // predict preY from the current parameters and this row
            double preY = 0.0;
            for (int n = 0; n < weights.length; n++) {
                preY += x[n] * weights[n];
            }
            // accumulate the batch gradient 2 * (preY - y) * x[n], averaged over N rows
            for (int n = 0; n < weights.length; n++) {
                weights_gradient[n] += 2 * (preY - y) * x[n] / (double) trainData.size();
            }
        }
        // update the parameters
        for (int j = 0; j < weights.length; j++) {
            weights[j] = weights[j] - learningRate * weights_gradient[j];
        }
        // print the loss every 100 iterations
        if (i % 100 == 0) {
            double loss = computeError(trainData, weights);
            System.out.println(loss);
        }
    }
    return weights;
}
4. Compute the error (the mean squared error over the training set).
public double computeError(List<double[]> trainData, double[] weights) {
    double error = 0.0;
    for (int i = 0; i < trainData.size(); i++) {
        double[] line = trainData.get(i);
        double preY = 0.0;
        // same feature layout as in updateWeights: x[0] = 1, then the data columns
        double[] x = new double[line.length];
        x[0] = 1;
        double y = line[line.length - 1];
        for (int n = 1; n < x.length; n++) {
            x[n] = line[n - 1];
        }
        for (int j = 0; j < weights.length; j++) {
            preY += weights[j] * x[j];
        }
        error += (y - preY) * (y - preY);
    }
    return error / (double) trainData.size();
}
5. Test program: prints the computed parameters.
public static void main(String[] args) {
    MultiLineRegression lineRegression = new MultiLineRegression();
    double[] weights = lineRegression.train("E:/index/traindata.csv", 0.001, 1000);
    for (double w : weights) {
        System.out.print(w + ",");
    }
}
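As a quick sanity check, here is a sketch of a test on synthetic data; it assumes the methods above live in a class named MultiLineRegression (as in the test program) and that the file name synthetic.csv is writable. Noise-free points are generated from the known line y = 1 + 2x, so the learned weights should approach [1, 2]:

import java.io.FileWriter;

public static void main(String[] args) throws Exception {
    // write noise-free samples of y = 1 + 2*x as "x,y" rows
    try (FileWriter fw = new FileWriter("synthetic.csv")) {
        for (int i = 0; i < 100; i++) {
            double x = i / 100.0;
            fw.write(x + "," + (1 + 2 * x) + "\n");
        }
    }
    double[] w = new MultiLineRegression().train("synthetic.csv", 0.1, 5000);
    // w[0] should approach 1 (the bias) and w[1] should approach 2 (the slope)
    System.out.println(w[0] + "," + w[1]);
}

With x confined to [0, 1) a learning rate of 0.1 is stable for this data; on data with larger feature values a smaller rate, like the 0.001 used above, is safer.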