Strassen矩阵乘法

  • Post author:
  • Post category:其他





Strassen矩阵乘法

  • Strassen采用了类似于在大整数乘法中用过的分治技术,将计算2个n阶矩阵乘积所需的计算时间由O(n^ 3 )改进到O(n^log7)=0(n ^2.81)。其基本思想还是使用分治法。



– 问题描述

假设n是2的幂。两个大小为 2 * 2 的矩阵相乘,一般需要进行 8 次乘法。而Strassen矩阵乘法可以减少一次乘法,只需要 7 次,看似很少,但当数据量很大时,效率就会有显著提升。不过使用 Strassen矩阵乘法需要满足矩阵边长为 2 的幂次方。因为该算法会用到分治,如果分治后矩阵两边边长不等,结果会出错。

在这里插入图片描述

由此可得:

在这里插入图片描述

  • Strassen提出了一种新的算法来计算2个2阶方阵的乘积。该算法只用了7次乘法运算,但增加了加减法的运算次数。这7次乘法是:
  • 在这里插入图片描述

    做了着7次乘法后,再做若干次加减法,得到:

    在这里插入图片描述

    从而得到矩阵乘法的结果。



– 代码实现

public class test06 {
    public static void main(String[] args) {
        int[] a = {
                1, 1, 1, 1,
                2, 2, 2, 2,
                3, 3, 3, 3,
                4, 4, 4, 4
        };
        int[] b = {
                1, 2, 3, 4,
                1, 2, 3, 4,
                1, 2, 3, 4,
                1, 2, 3, 4
        };
        int length=4;
        int[] res = fun(a, b,length);
        for (int i = 0; i < res.length; i++) {
            System.out.print(res[i] + "\t");
            if ((i + 1) % length == 0) //换行
                System.out.println();
        }

    }

    public static int[] fun(int[] a, int[] b, int length) {
        if (length == 2)
            return getResult(a, b);
        int halfLength = length / 2;
        //把a数组分为四部分,进行分治递归`
        int[] aa = new int[halfLength * halfLength];
        int[] ab = new int[halfLength * halfLength];
        int[] ac = new int[halfLength * halfLength];
        int[] ad = new int[halfLength * halfLength];
        //把b数组分为四部分,进行分治递归
        int[] ba = new int[halfLength * halfLength];
        int[] bb = new int[halfLength * halfLength];
        int[] bc = new int[halfLength * halfLength];
        int[] bd = new int[halfLength * halfLength];
        /*划分子矩阵
         * 例子:将 4 * 4 的矩阵,变为 2 * 2 的矩阵,
         * 那么原矩阵左上、右上、左下、右下的四个元素分别归为新矩阵*/
        for (int i = 0; i < length; i++) {
            for (int j = 0; j < length; j++) {
                if (i < halfLength) {
                    if (j < halfLength) {
                        aa[i * halfLength + j] = a[i * length + j];
                        ba[i * halfLength + j] = b[i * length + j];
                    } else {
                        ab[i * halfLength + (j - halfLength)] = a[i * length + j];
                        bb[i * halfLength + (j - halfLength)] = b[i * length + j];
                    }
                } else {
                    if (j < halfLength) {
                        //i 大于 halfLength 时,需要减去 halfLength,j同理
                        //因为 b,c,d三个子矩阵有对应了父矩阵的后半部分
                        ac[(i - halfLength) * halfLength + j] = a[i * length + j];
                        bc[(i - halfLength) * halfLength + j] = b[i * length + j];
                    } else {
                        ad[(i - halfLength) * halfLength + (j - halfLength)] = a[i * length + j];
                        bd[(i - halfLength) * halfLength + (j - halfLength)] = b[i * length + j];
                    }
                }
            }
        }
        //分治递归
        int[] result = new int[length * length];
        //temp结果集的4个临时矩阵
        int[] t1 = add(fun(aa, ba, halfLength), fun(ab, bc, halfLength));
        int[] t2 = add(fun(aa, bb, halfLength), fun(ab, bd, halfLength));
        int[] t3 = add(fun(ac, ba, halfLength), fun(ad, bc, halfLength));
        int[] t4 = add(fun(ac, bb, halfLength), fun(ad, bd, halfLength));
        //归并结果
        for (int i = 0; i < length; i++) {
            for (int j = 0; j < length; j++) {
                if (i < halfLength) {
                    if (j < halfLength)
                        result[i * length + j] = t1[i * halfLength + j];
                    else
                        result[i * length + j] = t2[i * halfLength + (j - halfLength)];
                } else {
                    if (j < halfLength)
                        result[i * length + j] = t3[(i - halfLength) * halfLength + j];
                    else
                        result[i * length + j] = t4[(i - halfLength) * halfLength + (j - halfLength)];
                }
            }
        }
        return result;
    }
    public static int[] add(int[] a, int[] b) {
        int[] c = new int[a.length];
        for (int i = 0; i < a.length; i++)
            c[i] = a[i] + b[i];
        return c;
    }
    public static int[] getResult(int[] a, int[] b) {
        int p1 = a[0] * (b[1] - b[3]);
        int p2 = (a[0] + a[1]) * b[3];
        int p3 = (a[2] + a[3]) * b[0];
        int p4 = a[3] * (b[2] - b[0]);
        int p5 = (a[0] + a[3]) * (b[0] + b[3]);
        int p6 = (a[1] - a[3]) * (b[2] + b[3]);
        int p7 = (a[0] - a[2]) * (b[0] + b[1]);
        int c00 = p5 + p4 - p2 + p6;
        int c01 = p1 + p2;
        int c10 = p3 + p4;
        int c11 = p5 + p1 - p3 - p7;
        return new int[]{c00, c01, c10, c11};
    }
}



版权声明:本文为qq_51251599原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。