第十章分治算法(矩阵相乘)

  • Post author:
  • Post category:其他


矩阵乘法问题:

strassen算法



Strassen算法的基本思想是把每一个矩阵都分为4块


在这里插入图片描述

在求C = AB,设7个矩阵变量。

M1 = (A12-A22)(B21+B22)

M2 = (A11+A22)(B11+B22)

M3 = (A11-A21)(B11+B12)

M4 = (A11+A12)B22

M5 = A11(B12-B11)

M6 = A22(B21-B11)

M7 = (A21+A22)B11。

则 C可以通过这7个变量算出。

C11 = M1+M2-M4+M6。

C12 = M4+N5。

C21 = M6+M7。

C22 = M2-M3+M5-M7。

就可以求出C。


主函数

int main() {
	int **A, **B, **C;
	int n;
	cout<<"请输入矩阵的规模,将自动产生矩阵:";
	cin >> n;
	A = initMatrix(n);			//初始化A,申请空间
	randomMatrix(A, n);			//A矩阵内容随机产生
	B = initMatrix(n);
	randomMatrix(B, n);
	C = initMatrix(n);			//矩阵C申请空间
	printfMatrix(A, n);			//打印输出矩阵
	printfMatrix(B, n);
	StrassenMatrix(A, B,C, n);	//求C,C=AB,n为矩阵的规模
	printfMatrix(C, n);			//打印C
	return 0;
}


功能函数

int** initMatrix(int n) {
	int **Matrix = new int *[n];
	for (int i = 0; i < n; i++) {
		Matrix[i] = new int[n];
	}
	return Matrix;
}
void randomMatrix(int** Matrix,int n) {
	srand(time(NULL));
	Sleep(1000);
	for (int i = 0; i < n; i++) {
		for (int j = 0; j < n; j++) {
			Matrix[i][j] = rand() % 10;
		}
	}
}
void printfMatrix(int** Matrix,int n) {
	for (int i = 0; i < n;i++) {
		for (int j = 0; j < n; j++) {
			cout << Matrix[i][j] << " ";
		}
		cout << endl;
	}
}
//求两个矩阵相加结果
void AddMatrix(int** m1, int** m2, int** result, int n) {
	for (int i = 0; i < n; i++) {
		for (int j = 0; j < n; j++) {
			result[i][j] = m1[i][j] + m2[i][j];
		}
	}
}
//求两个矩阵相减结果
void SubMatrix(int** m1, int** m2, int ** result, int n) {
	//矩阵m1-m2
	for (int i = 0; i < n; i++) {
		for (int j = 0; j < n; j++) {
			result[i][j] = m1[i][j] - m2[i][j];
		}
	}
}


Strassen算法

void StrassenMatrix(int** A,int** B,int** C,int n) {
	//C=AB; n为矩阵的规模
	if (n==1) {
		C[0][0] = A[0][0] * B[0][0];
	}
	else {
		int m = n / 2;				//缩小问题规模
		//将ABC矩阵分块
		int **A_11 = new int *[m];
		int **A_12 = new int *[m];
		int **A_21 = new int *[m];
		int **A_22 = new int *[m];
		int **B_11 = new int *[m];
		int **B_12 = new int *[m];
		int **B_21 = new int *[m];
		int **B_22 = new int *[m];
		int **C_11 = new int *[m];
		int **C_12 = new int *[m];
		int **C_21 = new int *[m];
		int **C_22 = new int *[m];
		//定义7个变量
		int **M1 = new int *[m];
		int **M2 = new int *[m];
		int **M3 = new int *[m];
		int **M4 = new int *[m];
		int **M5 = new int *[m];
		int **M6 = new int *[m];
		int **M7 = new int *[m];
		int **t1 = new int *[m];
		int **t2 = new int *[m];
		//分配存储空间
		for (int i = 0; i < m; i++)
		{
			A_11[i] = new int[m];
			A_12[i] = new int[m];
			A_21[i] = new int[m];
			A_22[i] = new int[m];
			B_11[i] = new int[m];
			B_12[i] = new int[m];
			B_21[i] = new int[m];
			B_22[i] = new int[m];
			C_11[i] = new int[m];
			C_12[i] = new int[m];
			C_21[i] = new int[m];
			C_22[i] = new int[m];
			M1[i] = new int[m];
			M2[i] = new int[m];
			M3[i] = new int[m];
			M4[i] = new int[m];
			M5[i] = new int[m];
			M6[i] = new int[m];
			M7[i] = new int[m];
			t1[i] = new int[m];
			t2[i] = new int[m];
		}
		//将A,B分块
		for (int i = 0; i < m; i++)
		{
			for (int j = 0; j < m; j++)
			{
				A_11[i][j] = A[i][j];
				A_12[i][j] = A[i][j + m];
				A_21[i][j] = A[i + m][j];
				A_22[i][j] = A[i + m][j + m];
				B_11[i][j] = B[i][j];
				B_12[i][j] = B[i][j + m];
				B_21[i][j] = B[i + m][j];
				B_22[i][j] = B[i + m][j + m];
			}
		}
		//M1 = (A12 - A22)(B21 + B22)
		SubMatrix(A_12, A_22, t1, m);
		AddMatrix(B_21, B_22, t2, m);
		StrassenMatrix(t1, t2, M1, m);
		//M2 = (A11 + A22)(B11 + B22)
		AddMatrix(A_11, A_22, t1, m);
		AddMatrix(B_11, B_22, t2, m);
		StrassenMatrix(t1, t2, M2, m);
		//M3 = (A11 - A21)(B11 + B12)
		SubMatrix(A_11, A_21, t1, m);
		AddMatrix(B_11, B_12, t2, m);
		StrassenMatrix(t1, t2, M3, m);
		//M4 = (A11 + A12)B22
		AddMatrix(A_11, A_12, t1, m);
		StrassenMatrix(t1, B_22, M4, m);
		//M5 = A11(B12 - B22)
		SubMatrix(B_12, B_22, t1, m);
		StrassenMatrix(t1, A_11, M5, m);
		//M6 = A22(B21 - B11),
		SubMatrix(B_21, B_11, t1, m);
		StrassenMatrix(A_22, t1, M6, m);
		//M7 = (A21 + A22)B11。
		AddMatrix(A_21, A_22, t1, m);
		StrassenMatrix(t1, B_11, M7, m);
		//根据M1到M7,求C_11,C_12,C_21,C_22
		//C11 = M1 + M2 - M4 + M6。
		AddMatrix(M1, M2, t1, m);
		AddMatrix(t1, M6, t2, m);
		SubMatrix(t2, M4, C_11, m);
		//C12 = M4 + M5。
		AddMatrix(M4, M5, C_12, m);
		//C21 = M6 + M7。
		AddMatrix(M6, M7, C_21, m);
		//C22 = M2 - M3 + M5 - M7。
		SubMatrix(M2, M3, t1, m);
		SubMatrix(M5, M7, t2, m);
		AddMatrix(t1, t2, C_22, m);
		//求出C11,C12,C21,C22后拼接回C;
		for (int i = 0; i < m;i++) {
			for (int j = 0; j < m;j++) {
				C[i][j] = C_11[i][j];
				C[i][j + m] = C_12[i][j];
				C[i + m][j] = C_21[i][j];
				C[i + m][j + m] = C_22[i][j];
			}
		}
		//释放所有申请的空间
		for (int i = 0; i < m; i++)
		{
			delete[] A_11[i];
			delete[] A_12[i];
			delete[] A_21[i];
			delete[] A_22[i];
			delete[] B_11[i];
			delete[] B_12[i];
			delete[] B_21[i];
			delete[] B_22[i];
			delete[] C_11[i];
			delete[] C_12[i];
			delete[] C_21[i];
			delete[] C_22[i];
			delete[] M1[i];
			delete[] M2[i];
			delete[] M3[i];
			delete[] M4[i];
			delete[] M5[i];
			delete[] M6[i];
			delete[] M7[i];
			delete[] t1[i];
			delete[] t2[i];
		}
		delete[] A_11;
		delete[] A_12;
		delete[] A_21;
		delete[] A_22;
		delete[] B_11;
		delete[] B_12;
		delete[] B_21;
		delete[] B_22;
		delete[] C_11;
		delete[] C_12;
		delete[] C_21;
		delete[] C_22;
		delete[] M1;
		delete[] M2;
		delete[] M3;
		delete[] M4;
		delete[] M5;
		delete[] M6;
		delete[] M7;
		delete[] t1;
		delete[] t2;
	}	
}



版权声明:本文为weixin_40540957原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。