如何写出比 MATLAB 更快的矩阵运算程序

如何写出比 MATLAB 更快的矩阵运算程序,第1张

矩阵乘法是一个相对成熟的问题,根据矩阵的稀疏程度有不同的优化算法。

不使用GPU加速的MATLAB版本采用的是BLAS中的General Matrix Multiplication[1]。学术界有各种矩阵乘法算法将其复杂度降低到O(n^2.x),例如Strassen和Winograd算法,在BLAS中应该已经使用了Strassen算法。

如果你的MATLAB是安装了Parallel Computing Toolbox的话,那么很可能它已经在使用GPU进行计算了。这种情况下采用的是MAGMA[2]。我没有使用过MAGMA,但我猜测它应该使用了cuBLAS来计算矩阵乘法。

宏观角度上对矩阵乘法的优化包括对局部内存使用的优化(Blocked/Tiled)以及对中间运算步骤的优化(Strassen/Winograd),实现细节上的优化就非常繁多了。比如loop unrolling,多级的tiling,指令级并行等等。其中会牵扯到一些编译器和体系结构的知识,似乎对仅仅希望使用矩阵乘法函数的用户来讲没有什么太大必要去探究。

//我的算法分析与设计的作业就做的这个,你参考下吧.是用C写的

/////////////////////////////////////////

程序功能:用分而治之算法计算两个n维矩阵相乘的结果

其中n必须是2的正整数次幂。

运行过程:首先,根据提示输入矩阵的维数n

其次,根据提示分别输入矩阵A和B

最后,显示矩阵A和矩阵B以及其相乘结果矩阵C

****************************************/

#include "stdio.h"

#define mytype int//矩阵元素的数据类型

#define myinputmode "%d"//矩阵元素的输入格式

#define myprintmode "%4d"//矩阵元素的输出格式

/*以上参数的设置可根据所计算矩阵的元素的数值类型进行相应改变

如更改为浮点型数据则可以使用下面的设置

#define mytype float

#define myinputmode "%f"

#define myprintmode "%6.2f"

*/

/////////////////////////////////////////

/****************************************

函数名:is2

参数:m为长整型整数

功能:检测m是否是2的正整数次幂

返回值:返回布尔型变量

true则表示m为2的正整数次幂

false则表示m不是2的正整数次幂

****************************************/

bool is2(long m)

{

if(m<0)return false

if(m>=2)

{

if((m%2)==0) return is2(m/2)

else return false

}

else

{

if(m==1)return true

else return false

}

return false

}

/////////////////////////////////////////

/****************************************

函数名:inputmatrix

参数:M为指向数组指针,用来存储输入的矩阵

m长整型,是数组M所存矩阵的维数

name字符型数组,是需要进行数据输入的矩阵的名字

功能:矩阵数据输入的函数,通过输入矩阵的每个元素将

矩阵存入数组

返回值:无

****************************************/

void inputmatrix(mytype * M,long m,char *name)

{

long i,j

for(i=0i<mi++)

for(j=0j<mj++)

{

printf("Please input the %s(%d,%d):",name,i+1,j+1)

getchar()

scanf(myinputmode,&M[i*m+j])

}

}

/////////////////////////////////////////

/****************************************

函数名:printmatrix

参数:M为指向数组的指针,数组中存储着矩阵

m长整型,是数组M所存矩阵的维数

name字符型数组,是需要进行数据输入的矩阵的名字

功能:矩阵数据输出显示的函数,将矩阵元素一一显示一在屏幕上

返回值:无

****************************************/

void printmatrix(mytype * M,long m,char *name)

{

long i,j

printf("\nMatrix %s:\n",name)

for(i=0i<mi++)

{

for(j=0j<mj++)

{

printf(myprintmode,M[i*m+j])

}

printf("\n")

}

}

/////////////////////////////////////////

/****************************************

函数名:Matrix_add_sub

参数:A,B为指向数组的指针,数组中存储着矩阵

C为指向数组的指针,用来存储运算结果

m长整型,是数组A、B、C所存矩阵的维数

add为布尔型变量,为true则C=A+B,为false则C=A-B

功能:根据add值对A、B进行加减运算并将结果存入C

返回值:无

****************************************/

void Matrix_add_sub(mytype * A,mytype * B,mytype * C,long m,bool add)

{

long i

for(i=0i<m*mi++)

{

if(add)

C[i]=A[i]+B[i]

else

C[i]=A[i]-B[i]

}

}

/////////////////////////////////////////

/****************************************

函数名:GetHalfValue

参数:B为指向数组的指针,数组中存储着矩阵。其中B是指向m维矩阵中的一个元素。

A为指向数组的指针,用来接收B中的四分之一数据

m长整型,是数组B所指矩阵的维数

功能:从B所在位置向左和向右取矩阵的m/2维的子矩阵(子矩阵中包括B所指元素)并存入A

返回值:无

****************************************/

void GetHalfValue(mytype * A,mytype * B,long m)

{

long i,j

for(i=0i<m/2i++)

{

for(j=0j<m/2j++)

{

A[i*m/2+j]=B[i*m+j]

}

}

}

/////////////////////////////////////////

/****************************************

函数名:UpdateHalfValue

参数:B为指向数组的指针,数组中存储着矩阵。其中B是指向m维矩阵中的一个元素。

A为指向数组的指针,存储着一个m/2维矩阵

m长整型,是数组B所指矩阵的维数

功能:把A矩阵所有元素存入从B所在位置向左和向右的m/2维的子矩阵(子矩阵中包括B所指元素)

返回值:无

****************************************/

void UpdateHalfValue(mytype * A,mytype * B,long m)

{

long i,j

for(i=0i<m/2i++)

{

for(j=0j<m/2j++)

{

B[i*m+j]=A[i*m/2+j]

}

}

}

/////////////////////////////////////////

/****************************************

函数名:Matrix_multiplication

参数:A,B为指向数组的指针,数组中存储着矩阵。

C为指向数组的指针,用来存储计算结果

m长整型,是指针A、B所指矩阵的维数

功能:用分而治之算法和Strassen方法计算A与B的乘积并存入C

返回值:无

****************************************/

void Matrix_multiplication(mytype * A,mytype * B,mytype * C,long m)

{

if(m>2)//当矩阵维数大于2时

{

//将矩阵A、B分为四个小矩阵,分别为A1、A2、A3、A4、B1、B2、B3、B4

mytype *A1=new mytype[m*m/4],*A2=new mytype[m*m/4],*A3=new mytype[m*m/4],*A4=new mytype[m*m/4],*B1=new mytype[m*m/4],*B2=new mytype[m*m/4],*B3=new mytype[m*m/4],*B4=new mytype[m*m/4],*C1=new mytype[m*m/4],*C2=new mytype[m*m/4],*C3=new mytype[m*m/4],*C4=new mytype[m*m/4]

GetHalfValue(A1,&A[0],m)

GetHalfValue(A2,&A[m/2],m)

GetHalfValue(A3,&A[m*m/2],m)

GetHalfValue(A4,&A[m*m/2+m/2],m)

GetHalfValue(B1,&B[0],m)

GetHalfValue(B2,&B[m/2],m)

GetHalfValue(B3,&B[m*m/2],m)

GetHalfValue(B4,&B[m*m/2+m/2],m)

//利用Strassen方法计算D、E、F、G、H、I、J

mytype *D=new mytype[m*m/4],*E=new mytype[m*m/4],*F=new mytype[m*m/4],*G=new mytype[m*m/4],*H=new mytype[m*m/4],*I=new mytype[m*m/4],*J=new mytype[m*m/4]

mytype *temp1=new mytype[m*m/4],*temp2=new mytype[m*m/4]

//D=A1(B2-B4)

Matrix_add_sub(B2,B4,temp1,m/2,false)

Matrix_multiplication(A1,temp1,D,m/2)

//E=A4(B3-B1)

Matrix_add_sub(B3,B1,temp1,m/2,false)

Matrix_multiplication(A4,temp1,E,m/2)

//F=(A3+A4)B1

Matrix_add_sub(A3,A4,temp1,m/2,true)

Matrix_multiplication(temp1,B1,F,m/2)

//G=(A1+A2)B4

Matrix_add_sub(A1,A2,temp1,m/2,true)

Matrix_multiplication(temp1,B4,G,m/2)

//H=(A3-A1)(B1+B2)

Matrix_add_sub(A3,A1,temp1,m/2,false)

Matrix_add_sub(B1,B2,temp2,m/2,true)

Matrix_multiplication(temp1,temp2,H,m/2)

//I=(A2-A4)(B3+B4)

Matrix_add_sub(A2,A4,temp1,m/2,false)

Matrix_add_sub(B3,B4,temp2,m/2,true)

Matrix_multiplication(temp1,temp2,I,m/2)

//J=(A1+A4)(B1+B4)

Matrix_add_sub(A1,A4,temp1,m/2,true)

Matrix_add_sub(B1,B4,temp2,m/2,true)

Matrix_multiplication(temp1,temp2,J,m/2)

//利用Strassen方法计算C1、C2、C3、C4

//C1=E+I+J-G

Matrix_add_sub(E,I,temp1,m/2,true)

Matrix_add_sub(J,G,temp2,m/2,false)

Matrix_add_sub(temp1,temp2,C1,m/2,true)

//C2=D+G

Matrix_add_sub(D,G,C2,m/2,true)

//C3=E+F

Matrix_add_sub(E,F,C3,m/2,true)

//C4=D+H+J-F

Matrix_add_sub(D,H,temp1,m/2,true)

Matrix_add_sub(J,F,temp2,m/2,false)

Matrix_add_sub(temp1,temp2,C4,m/2,true)

//将计算结果存入数组C

UpdateHalfValue(C1,&C[0],m)

UpdateHalfValue(C2,&C[m/2],m)

UpdateHalfValue(C3,&C[m*m/2],m)

UpdateHalfValue(C4,&C[m*m/2+m/2],m)

//释放内存

delete[] A1,A2,A3,A4,B1,B2,B3,B4,C1,C2,C3,C4,D,E,F,G,H,I,J,temp1,temp2

}

else

{

//当矩阵维数小于2时用Strassen方法计算矩阵乘积

mytype D,E,F,G,H,I,J

//D=A1(B2-B4)

D=A[0]*(B[1]-B[3])

//E=A4(B3-B1)

E=A[3]*(B[2]-B[0])

//F=(A3+A4)B1

F=(A[2]+A[3])*B[0]

//G=(A1+A2)B4

G=(A[0]+A[1])*B[3]

//H=(A3-A1)(B1+B2)

H=(A[2]-A[0])*(B[0]+B[1])

//I=(A2-A4)(B3+B4)

I=(A[1]-A[3])*(B[2]+B[3])

//J=(A1+A4)(B1+B4)

J=(A[0]+A[3])*(B[0]+B[3])

//C1=E+I+J-G

C[0]=E+I+J-G

//C2=D+G

C[1]=D+G

//C3=E+F

C[2]=E+F

//C4=D+H+J-F

C[3]=D+H+J-F

}

}

/////////////////////////////////////////

int main()

{

long n

//提示输入n维矩阵的维数

printf("Please input the dimension of the Matrix.(n):")

//获得用户输入的n维矩阵维数

scanf("%d",&n)

while(!is2(n))//检查维数是否是2的幂,不是则要求重新输入

{

printf("Please reinput the dimension of the Matrix.(n):")

scanf("%d",&n)

}

//开辟空间存储用来存储n维矩阵元素

mytype *A=new mytype[n*n]

mytype *B=new mytype[n*n]

mytype *C=new mytype[n*n]

//输入矩阵A、B

inputmatrix(A,n,"A")

inputmatrix(B,n,"B")

if(n>1)//矩阵维数大于1则用分而治之算法计算

Matrix_multiplication(A,B,C,n)

else//矩阵维数为1则直接计算

*C=(*A)*(*B)

//输出矩阵A、B、C

printmatrix(A,n,"A")

printmatrix(B,n,"B")

printmatrix(C,n,"C")

//释放内存

delete[] A,B,C

getchar()getchar()

return 1

}

#include <iostream>

using namespace std

const int N=8//设置矩阵的大小,仅为测试用

//const int N=4//Matrix_Siz/2

template<typename T>//矩阵加法

void Matrix_Add(int n,T X[][N],T Y[][N],T Z[][N])

template<typename T>//矩阵减法

void Matrix_Sub(int n,T X[][N],T Y[][N],T Z[][N])

template<typename T>//矩阵数据输入

void input(int n,T p[][N])

template<typename T>

void output(int n,T C[][N])

template<typename T>//矩阵乘法

void Strassen_Matrix(int n,T A[][N],T B[][N],T C[][N])

int main()

{

/*double**X

double**Y

double**Z

X=new double*[N]

for(int i=0i<Ni++)

X[i]=new double[N]

Y=new double*[N]

for(int i=0i<Ni++)

Y[i]=new doubleN]

Y=new double*[N]

for(int i=0i<Ni++)

Y[i]=new double[N]*/

int X[N][N]={0},Y[N][N]={0},Z[N][N]={0}

cout<<"请输入第一个矩阵的值:"<<endl

input(N,X)

cout<<"请输入第二个矩阵的值:"<<endl

input(N,Y)

Strassen_Matrix(N,X,Y,Z)

output(N,Z)

system("pause")

}

template<typename T>//矩阵加法

void Matrix_Add(int n,T X[][N],T Y[][N],T Z[][N])

{

for(int i=0i<ni++)

for(int j=0j<nj++)

Z[i][j]=X[i][j]+Y[i][j]

}

template<typename T>//矩阵减法

void Matrix_Sub(int n,T X[][N],T Y[][N],T Z[][N])

{

for(int i=0i<ni++)

for(int j=0j<nj++)

Z[i][j]=X[i][j]-Y[i][j]

}

template<typename T>//矩阵数据输入

void input(int n,T p[][N])

{

for(int i=0i<ni++)

{

cout<<"请输入矩阵"<<i+1<<"行的"<<n<<"个数"<<endl

for(int j=0j<nj++)

{

cin>>p[i][j]

}

}

}

template<typename T>

void output(int n,T C[][N])

{

cout<<"输出矩阵是:"<<endl

for(int i=0i<ni++)

{

for(int j=0j<nj++)

{

cout<<C[i][j]<<" "

}

cout<<endl

}

}

template<typename T>

void Strassen_Matrix(int n,T A[][N],T B[][N],T C[][N])

{

if(n==2)//当为阶方阵时

for(int i=0i<2i++)

{

for(int j=0j<2j++)

{

C[i][j]=0

for(int t=0t<2t++)

{

C[i][j]=C[i][j]+A[i][t]*B[t][j]

}

}

}

else

{

/*int A11[n/2][n/2],A12[n/2][n/2],A21[n/2][n/2],A22[n/2][n/2],

B11[n/2][n/2],B12[n/2][n/2],B21[n/2][n/2],B22[n/2][n/2],

M1[n/2][n/2],M2[n/2][n/2],M3[n/2][n/2],M4[n/2][n/2],M5[n/2][n/2],

M6[n/2][n/2],M7[n/2][n/2],TMP[n/2][n/2],TMP1[n/2][n/2]//注:TMP,TMP1只用于存放中间变量*/

int A11[N][N],A12[N][N],A21[N][N],A22[N][N],

B11[N][N],B12[N][N],B21[N][N],B22[N][N],

M1[N][N],M2[N][N],M3[N][N],M4[N][N],M5[N][N],

M6[N][N],M7[N][N],TMP[N][N],TMP1[N][N],

C11[N][N],C12[N][N],C21[N][N],C22[N][N]

for(int i=0i<n/2i++)

for(int j=0j<n/2j++)

{

A11[i][j]=A[i][j]

A12[i][j]=A[i][n/2+j]

A21[i][j]=A[n/2+i][j]

A22[i][j]=A[n/2+i][n/2+j]

B11[i][j]=B[i][j]

B12[i][j]=B[i][n/2+j]

B21[i][j]=B[n/2+i][j]

B22[i][j]=B[n/2+i][n/2+j]

}

//计算M1

Matrix_Sub(n/2,B12,B22,TMP)

Strassen_Matrix(n/2,A11,TMP,M1)

//计算M2

Matrix_Add(n/2,A11,A12,TMP)

Strassen_Matrix(n/2,TMP,B22,M2)

//计算M3

Matrix_Add(n/2,A21,A22,TMP)

Strassen_Matrix(n/2,TMP,B11,M3)

//计算M4

Matrix_Sub(n/2,B21,B11,TMP)

Strassen_Matrix(n/2,A22,TMP,M4)

//计算M5;

Matrix_Add(n/2,A11,A22,TMP)

Matrix_Add(n/2,B11,B22,TMP1)

Strassen_Matrix(n/2,TMP,TMP1,M5)

//计算M6

Matrix_Sub(n/2,A12,A22,TMP)

Matrix_Add(n/2,B21,B22,TMP1)

Strassen_Matrix(n/2,TMP,TMP1,M6)

//计算M7

Matrix_Sub(n/2,A11,A21,TMP)

Matrix_Add(n/2,B11,B12,TMP1)

Strassen_Matrix(n/2,TMP,TMP1,M7)

//计算C11,

Matrix_Add(n/2,M5,M4,TMP)

Matrix_Sub(n/2,TMP,M2,TMP1)

Matrix_Add(n/2,TMP1,M6,C11)

//就算C22,

Matrix_Add(n/2,M1,M2,C12)

//计算C21

Matrix_Add(n/2,M3,M4,C21)

//计算C22

Matrix_Add(n/2,M5,M1,TMP)

Matrix_Sub(n/2,TMP,M3,TMP1)

Matrix_Sub(n/2,TMP1,M7,C22)

for(int i=0i<n/2i++)

{

for(int j=0j<n/2j++)

{

C[i][j]=C11[i][j]

C[i][j+n/2]=C12[i][j]

C[i+n/2][j]=C21[i][j]

C[i+n/2][j+n/2]=C22[i][j]

}

}

}

}

网上有一些strassen是错误的 会运行是负数 我毕业设计也差点被那个害死 郁闷


欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/yw/12080856.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2023-05-20
下一篇 2023-05-20

发表评论

登录后才能评论

评论列表(0条)

保存