基于多元线性回归去除图片水印(Java版)

  • Post author:
  • Post category:java




前提

采集的图片有淡淡的水印,为了避免不必要的麻烦,需要淡化或去除水印。图片如下所示:

在这里插入图片描述

Java自带的工具可以对图片指定位置(x,y)的颜色(r,g,b)进行替换。如果图片上下左右颜色一致,则可进行颜色区间的简单替换。

如上图所示,图片是手机拍摄的,颜色不一致。这就需要算法来识别图片中的水印部分。



方案

  1. 在图片四周(上、下、左、右四条边附近)取样点,对 R、G、B 三个通道分别建立以坐标 (x, y) 为自变量的二元线性回归方程,即 RGB(r,g,b)=f(x,y),例如 r = a0·x + a1·y + a2
  2. 遍历图片中的每个像素:若其 R、G、B 三个通道的实际值与回归方程预测值的偏差都在阈值范围内,则将该像素统一替换为方程预测的颜色



代码


多元线性回归

/**
 * Ordinary least-squares multiple linear regression, solved through the normal
 * equations with a Cholesky factorisation.
 *
 * <p>The fitted model is {@code y = a[0]*x[0] + ... + a[m-1]*x[m-1] + a[m]},
 * i.e. the constant term is stored LAST in the coefficient array.
 */
public class LinearRegressionTest {

    public static final Logger logger = LoggerFactory.getLogger(LinearRegressionTest.class);

    /**
     * Fits a multiple linear regression and reports goodness-of-fit statistics.
     *
     * @param x  independent variables, indexed as {@code x[variable][observation]}
     *           ({@code m} rows of {@code n} values each)
     * @param y  the {@code n} observed values of the dependent variable
     * @param m  number of independent variables
     * @param n  number of observations
     * @param a  output, length {@code m + 1}: {@code a[0..m-1]} are the coefficients of
     *           the independent variables and {@code a[m]} is the constant term
     * @param dt output, length 4: {@code dt[0]} residual sum of squares q,
     *           {@code dt[1]} standard deviation {@code sqrt(q / n)},
     *           {@code dt[2]} multiple correlation coefficient r,
     *           {@code dt[3]} regression sum of squares u
     * @param v  output, length {@code m}: partial correlation coefficient of each
     *           independent variable
     * @throws ArithmeticException if the normal-equation matrix is not positive definite
     *                             (for example collinear or constant independent variables),
     *                             in which case no meaningful coefficients exist
     */
    public static void sqt2(double[][] x, double[] y, int m, int n, double[] a,
                            double[] dt, double[] v) {
        int mm = m + 1;
        double[] b = new double[mm * mm];

        // Build the symmetric (m+1)x(m+1) normal-equation matrix, row-major.
        // Last row/column holds the plain sums of each variable; the corner holds n.
        b[mm * mm - 1] = n;
        for (int j = 0; j < m; j++) {
            double sum = 0.0;
            for (int i = 0; i < n; i++) {
                sum = sum + x[j][i];
            }
            b[m * mm + j] = sum;
            b[j * mm + m] = sum;
        }
        for (int i = 0; i < m; i++) {
            for (int j = i; j < m; j++) {
                double sum = 0.0;
                for (int k = 0; k < n; k++) {
                    sum = sum + x[i][k] * x[j][k];
                }
                b[j * mm + i] = sum;
                b[i * mm + j] = sum;
            }
        }

        // Right-hand side: sums of x[i]*y, with the sum of y in the last slot.
        a[m] = 0.0;
        for (int i = 0; i < n; i++) {
            a[m] = a[m] + y[i];
        }
        for (int i = 0; i < m; i++) {
            a[i] = 0.0;
            for (int j = 0; j < n; j++) {
                a[i] = a[i] + x[i][j] * y[j];
            }
        }

        // chlk leaves the right-hand side untouched when it fails, so continuing
        // would report raw sums as if they were regression coefficients.
        if (chlk(b, mm, 1, a) < 0) {
            throw new ArithmeticException(
                    "Normal-equation matrix is not positive definite; regression cannot be solved");
        }

        double yy = 0.0; // mean of y
        for (int i = 0; i < n; i++) {
            yy = yy + y[i] / n;
        }

        double q = 0.0; // residual sum of squares
        double e = 0.0; // total sum of squares
        double u = 0.0; // regression sum of squares
        for (int i = 0; i < n; i++) {
            double p = a[m];
            for (int j = 0; j < m; j++) {
                p = p + a[j] * x[j][i];
            }
            q = q + (y[i] - p) * (y[i] - p);
            e = e + (y[i] - yy) * (y[i] - yy);
            u = u + (yy - p) * (yy - p);
        }
        double s = Math.sqrt(q / n);
        // Clamp at zero: rounding can push 1 - q/e marginally negative, which would
        // otherwise produce NaN. A genuinely undefined ratio (0/0) still yields NaN.
        double r = Math.sqrt(Math.max(0.0, 1.0 - q / e));

        for (int j = 0; j < m; j++) {
            // Residual sum of squares when variable j is left out of the prediction.
            double p = 0.0;
            for (int i = 0; i < n; i++) {
                double pp = a[m];
                for (int k = 0; k < m; k++) {
                    if (k != j) {
                        pp = pp + a[k] * x[k][i];
                    }
                }
                p = p + (y[i] - pp) * (y[i] - pp);
            }
            v[j] = Math.sqrt(Math.max(0.0, 1.0 - q / p));
        }

        dt[0] = q;
        dt[1] = s;
        dt[2] = r;
        dt[3] = u;
    }

    /**
     * Solves {@code A * X = D} for a symmetric positive-definite matrix by Cholesky
     * factorisation followed by forward and backward substitution.
     *
     * @param a the {@code n x n} matrix in row-major order; its upper triangle is
     *          overwritten with the Cholesky factor
     * @param n order of the matrix
     * @param m number of right-hand-side columns
     * @param d the {@code n x m} right-hand side in row-major order; overwritten with the
     *          solution on success and left untouched on failure
     * @return {@code 2} on success, {@code -2} if a pivot is zero or negative
     */
    private static int chlk(double[] a, int n, int m, double[] d) {
        // "x + 1.0 == 1.0" treats values too small to affect 1.0 as zero.
        if ((a[0] + 1.0 == 1.0) || (a[0] < 0.0)) {
            logger.warn("Fail!");
            return (-2);
        }
        a[0] = Math.sqrt(a[0]);
        for (int j = 1; j < n; j++) {
            a[j] = a[j] / a[0];
        }

        // Factorise row by row; only the upper triangle is read and written.
        for (int i = 1; i < n; i++) {
            int diag = i * n + i;
            for (int k = 0; k < i; k++) {
                double above = a[k * n + i];
                a[diag] = a[diag] - above * above;
            }
            if ((a[diag] + 1.0 == 1.0) || (a[diag] < 0.0)) {
                logger.warn("Fail!");
                return (-2);
            }
            a[diag] = Math.sqrt(a[diag]);
            for (int j = i + 1; j < n; j++) {
                int cell = i * n + j;
                for (int k = 0; k < i; k++) {
                    a[cell] = a[cell] - a[k * n + i] * a[k * n + j];
                }
                a[cell] = a[cell] / a[diag];
            }
        }

        // Forward substitution with the transposed factor.
        for (int j = 0; j < m; j++) {
            d[j] = d[j] / a[0];
            for (int i = 1; i < n; i++) {
                int cell = i * m + j;
                for (int k = 0; k < i; k++) {
                    d[cell] = d[cell] - a[k * n + i] * d[k * m + j];
                }
                d[cell] = d[cell] / a[i * n + i];
            }
        }

        // Backward substitution with the upper-triangular factor.
        for (int j = 0; j < m; j++) {
            int last = (n - 1) * m + j;
            d[last] = d[last] / a[n * n - 1];
            for (int row = n - 2; row >= 0; row--) {
                int cell = row * m + j;
                for (int i = row + 1; i < n; i++) {
                    d[cell] = d[cell] - a[row * n + i] * d[i * m + j];
                }
                d[cell] = d[cell] / a[row * n + row];
            }
        }
        return (2);
    }

    /**
     * Convenience wrapper around {@link #sqt2} that returns only the coefficients.
     *
     * @param x independent variables, indexed as {@code x[variable][observation]}
     * @param y observed values of the dependent variable, one per observation
     * @return array of length {@code x.length + 1}: the coefficient of each independent
     *         variable followed by the constant term
     * @throws IllegalArgumentException if {@code x} is null or has no rows
     * @throws ArithmeticException      if the regression cannot be solved
     */
    public static double[] getRegressionCoefficient(double[][] x, double[] y) {
        if (x == null || x.length == 0) {
            throw new IllegalArgumentException("x must contain at least one independent variable");
        }
        int m = x.length;    // number of independent variables
        int n = x[0].length; // number of observations
        double[] a = new double[m + 1];
        double[] v = new double[m];
        double[] dt = new double[4];

        sqt2(x, y, m, n, a, dt, v);

        return a;
    }

    /** Runs the regression on a small fixed sample and logs every result. */
    public static void main(String[] args) {
        double[][] x = {{1.1, 1.0, 1.2, 1.1, 0.9},
                {2.0, 2.0, 1.8, 1.9, 2.1}};
        double[] y = {10.1, 10.2, 10.0, 10.1, 10.0};

        int m = x.length;    // number of independent variables
        int n = x[0].length; // number of observations
        double[] a = new double[m + 1];
        double[] v = new double[m];
        double[] dt = new double[4];

        sqt2(x, y, m, n, a, dt, v);

        logger.info("回归系数a0,...,am");
        for (int i = 0; i <= m; i++) {
            logger.info("a({})={}", i, a[i]);
        }

        logger.info("自变量的偏相关系数");
        for (int i = 0; i < m; i++) {
            logger.info("v({})={}", i, v[i]);
        }

        logger.info("dt[0]偏差平方和q,dt[1]平均标准偏差s,dt[2]复相关系数r,dt[3]回归平方和u");
        logger.info("q=" + dt[0] + "  s=" + dt[1] + "  r=" + dt[2] + "  u=" + dt[3]);
    }

}


图片去水印

/**
 * Fades faint watermarks out of images whose background colour drifts smoothly
 * across the picture (for example photos of paper taken with a phone).
 *
 * <p>For each colour channel a plane {@code value = c0*x + c1*y + c2} is fitted to
 * pixels sampled near the four edges. Every pixel whose three channels all lie
 * within a tolerance of that plane is overwritten with the plane's colour.
 */
public class WatermarkRemoveTest {

    public static final Logger logger = LoggerFactory.getLogger(WatermarkRemoveTest.class);

    /** Number of sample points taken along each of the four edges. */
    private static final int POINT_NUM = 30;

    /** Distance in pixels between the image edge and the sampling lines. */
    private static final int SAMPLE_BORDER = 30;

    /** A pixel is replaced only if every channel differs from the fit by less than this. */
    private static final int COLOR_DIFF_LIMIT = 20;

    /** Image files collected by {@link #loadImages(File)}; shared mutable state. */
    private static final List<File> fileList = new ArrayList<>();

    public static void main(String[] args) {
        // Converts every image under the first directory into the second one.
        convertAllImages("D:\\fan.zhou\\可行性研究\\正时\\Tmp-图片去水印\\timing_pic\\",
                "D:\\fan.zhou\\可行性研究\\正时\\Tmp-图片去水印\\timing_pic_handle\\");
    }

    /**
     * Processes every png/jpg file found below {@code dir} and writes the result to the
     * same relative path below {@code saveDir}.
     */
    private static void convertAllImages(String dir, String saveDir) {
        String srcRoot = new File(dir).getAbsolutePath();
        String dstRoot = new File(saveDir).getAbsolutePath();

        // Drop results of any earlier scan so files are not processed twice.
        fileList.clear();
        loadImages(new File(srcRoot));

        for (File file : fileList) {
            String filePath = file.getAbsolutePath();
            // Every collected file lives below srcRoot, so the remainder is its relative path.
            String dstPath = dstRoot + filePath.substring(srcRoot.length());
            logger.info("Converting: {}", filePath);
            replaceColor(filePath, dstPath);
        }
    }

    /**
     * Recursively collects image files into the shared file list.
     *
     * @param f a file or directory; {@code null} is ignored. Files are kept when their
     *          name ends with {@code .png} or {@code .jpg}, ignoring case.
     */
    public static void loadImages(File f) {
        if (f == null) {
            return;
        }
        if (f.isDirectory()) {
            File[] children = f.listFiles();
            if (children != null) {
                for (File child : children) {
                    loadImages(child);
                }
            }
            return;
        }
        String name = f.getName().toLowerCase(java.util.Locale.ROOT);
        if (name.endsWith(".png") || name.endsWith(".jpg")) {
            fileList.add(f);
        }
    }

    /** Converts one image, logging any failure so the rest of a batch still runs. */
    private static void replaceColor(String srcFile, String dstFile) {
        try {
            replaceImageColor(srcFile, dstFile);
        } catch (IOException | RuntimeException e) {
            logger.error("Failed to convert {}", srcFile, e);
        }
    }

    /**
     * Reads an image, flattens near-background pixels onto the fitted background plane
     * and writes the result.
     *
     * <p>Nothing is written when the source cannot be decoded as an image.
     *
     * @param file    local path, or a URL starting with {@code http://} or {@code https://};
     *                the text after its last {@code '.'} selects the output format
     * @param dstFile path of the file to write; missing parent directories are created
     * @throws IOException if reading or writing fails, if the image is too small to
     *                     sample, or if no writer exists for the format
     */
    public static void replaceImageColor(String file, String dstFile) throws IOException {
        BufferedImage bi = readImage(file);
        if (bi == null) {
            return;
        }

        int width = bi.getWidth();
        int height = bi.getHeight();
        // The sampling lines sit SAMPLE_BORDER pixels in from each edge; on smaller
        // images those coordinates fall outside the picture.
        if (width <= SAMPLE_BORDER || height <= SAMPLE_BORDER) {
            throw new IOException("Image is too small to sample (" + width + "x" + height
                    + "); each side must exceed " + SAMPLE_BORDER + " pixels: " + file);
        }

        double[][] xy = new double[2][4 * POINT_NUM];
        double[][] rgb = new double[3][4 * POINT_NUM];
        sampleEdges(bi, xy, rgb);

        // One plane per channel: coefficient of x, coefficient of y, constant term.
        double[] ar = LinearRegressionTest.getRegressionCoefficient(xy, rgb[0]);
        double[] ag = LinearRegressionTest.getRegressionCoefficient(xy, rgb[1]);
        double[] ab = LinearRegressionTest.getRegressionCoefficient(xy, rgb[2]);

        for (int i = 0; i < width; i++) {
            for (int j = 0; j < height; j++) {
                int redStd = clampChannel((int) (i * ar[0] + j * ar[1] + ar[2]));
                int greenStd = clampChannel((int) (i * ag[0] + j * ag[1] + ag[2]));
                int blueStd = clampChannel((int) (i * ab[0] + j * ab[1] + ab[2]));

                Color oriColor = new Color(bi.getRGB(i, j));
                boolean nearBackground =
                        Math.abs(redStd - oriColor.getRed()) < COLOR_DIFF_LIMIT
                                && Math.abs(greenStd - oriColor.getGreen()) < COLOR_DIFF_LIMIT
                                && Math.abs(blueStd - oriColor.getBlue()) < COLOR_DIFF_LIMIT;
                if (nearBackground) {
                    bi.setRGB(i, j, new Color(redStd, greenStd, blueStd).getRGB());
                }
            }
        }

        writeImage(bi, file, dstFile);
        bi.flush();
    }

    /** Opens a local path or http(s) URL and decodes it; returns null if it is not an image. */
    private static BufferedImage readImage(String file) throws IOException {
        String location = file.trim();
        URL source;
        if (location.startsWith("http://") || location.startsWith("https://")) {
            source = new URL(location);
        } else {
            source = new File(file).toURI().toURL();
        }
        // ImageIO.read does not close the stream it is given.
        try (java.io.InputStream in = source.openStream()) {
            return ImageIO.read(in);
        }
    }

    /**
     * Samples POINT_NUM evenly spaced pixels along each of the left, right, top and
     * bottom sampling lines, in that order, filling {@code xy} and {@code rgb}.
     */
    private static void sampleEdges(BufferedImage bi, double[][] xy, double[][] rgb) {
        int width = bi.getWidth();
        int height = bi.getHeight();
        double widthStep = (double) (width - SAMPLE_BORDER * 2) / (POINT_NUM - 1);
        double heightStep = (double) (height - SAMPLE_BORDER * 2) / (POINT_NUM - 1);

        for (int i = 0; i < POINT_NUM; i++) {
            int x = (int) (SAMPLE_BORDER + widthStep * i);
            int y = (int) (SAMPLE_BORDER + heightStep * i);
            recordSample(bi, xy, rgb, i, SAMPLE_BORDER, y);                          // left
            recordSample(bi, xy, rgb, POINT_NUM + i, width - SAMPLE_BORDER, y);      // right
            recordSample(bi, xy, rgb, 2 * POINT_NUM + i, x, SAMPLE_BORDER);          // top
            recordSample(bi, xy, rgb, 3 * POINT_NUM + i, x, height - SAMPLE_BORDER); // bottom
        }
    }

    /** Stores the coordinates and colour channels of pixel (x, y) at position {@code index}. */
    private static void recordSample(BufferedImage bi, double[][] xy, double[][] rgb,
                                     int index, int x, int y) {
        xy[0][index] = x;
        xy[1][index] = y;
        Color color = new Color(bi.getRGB(x, y));
        rgb[0][index] = color.getRed();
        rgb[1][index] = color.getGreen();
        rgb[2][index] = color.getBlue();
    }

    /** Limits a predicted channel value to the 0..255 range that Color accepts. */
    private static int clampChannel(int value) {
        return Math.max(0, Math.min(255, value));
    }

    /** Writes the image in the format named by the source file's extension. */
    private static void writeImage(BufferedImage bi, String file, String dstFile) throws IOException {
        String type = file.substring(file.lastIndexOf(".") + 1).toLowerCase(java.util.Locale.ROOT);
        File target = new File(dstFile);
        File parent = target.getParentFile();
        // mkdirs() also returns false when the directory already exists.
        if (parent != null && !parent.mkdirs() && !parent.isDirectory()) {
            throw new IOException("Cannot create directory: " + parent);
        }
        // ImageIO.write replaces any existing file and closes its own stream.
        if (!ImageIO.write(bi, type, target)) {
            throw new IOException("No image writer available for format '" + type + "': " + dstFile);
        }
    }

}



效果


处理前

在这里插入图片描述


处理后

在这里插入图片描述



总结

图片去水印主要的难点在于识别图片中的水印,本文基于多元线性回归识别;实际应用时,可根据需要调整识别算法。



版权声明:本文为sinat_32501475原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。