矩阵乘法之strassen算法.docx
《矩阵乘法之strassen算法.docx》由会员分享,可在线阅读,更多相关《矩阵乘法之strassen算法.docx(8页珍藏版)》请在冰豆网上搜索。
![矩阵乘法之strassen算法.docx](https://file1.bdocx.com/fileroot1/2023-7/10/e8e20d3f-c1a7-41e0-ab9e-13e005a017c3/e8e20d3f-c1a7-41e0-ab9e-13e005a017c31.gif)
矩阵乘法之strassen算法
矩阵乘法之strassen算法
一般情况下矩阵乘法需要三个for循环,时间复杂度为O(n^3),现在我们将矩阵分块如图:
(来自MIT算法导论)
一般算法需要八次乘法
r=a*e+b*g;
s=a*f+b*h;
t=c*e+d*g;
u=c*f+d*h;
strassen将其变成7次乘法,因为大家都知道乘法比加减法消耗更多,所有时间复杂更高!
strassen的处理是:
令:
p1=a*(f-h)
p2=(a+b)*h
p3=(c+d)*e
p4=d*(g-e)
p5=(a+d)*(e+h)
p6=(b-d)*(g+h)
p7=(a-c)*(e+f)
那么我们可以知道:
r=p5+p4+p6-p2
s=p1+p2
t=p3+p4
u=p5+p1-p3-p7
我们可以看到上面只有7次乘法和多次加减法,最终达到降低复杂度为O(n^lg7)~=O(n^2.81);
代码实现如下:
[cpp]viewplaincopyprint?
//strassen算法:
将矩阵相乘的复杂度降到O(n^lg7)~=O(n^2.81)
//原理是将8次乘法减少到7次的处理
//现在理论上的最好的算法是O(n^2,367),仅仅是理论上的而已
//
//
//下面的代码仅仅是简单的实例而已,不必较真哦,呵呵~
//下面的空间可以优化的,此处就不麻烦了~
#include
#defineN10
//matrix+matrix
voidplus(intt[N/2][N/2],intr[N/2][N/2],ints[N/2][N/2]){
inti,j;
for(i=0;i{
for(j=0;j{
t[i][j]=r[i][j]+s[i][j];
}
}
}
//matrix-matrix
voidminus(intt[N/2][N/2],intr[N/2][N/2],ints[N/2][N/2]){
inti,j;
for(i=0;i{
for(j=0;j{
t[i][j]=r[i][j]-s[i][j];
}
}
}
//matrix*matrix
voidmul(intt[N/2][N/2],intr[N/2][N/2],ints[N/2][N/2]){
inti,j,k;
for(i=0;i{
for(j=0;j{
t[i][j]=0;
for(k=0;k{
t[i][j]+=r[i][k]*s[k][j];
}
}
}
}
intmain()
{
inti,j,k;
intmat[N][N];
intm1[N][N];
intm2[N][N];
inta[N/2][N/2],b[N/2][N/2],c[N/2][N/2],d[N/2][N/2];
inte[N/2][N/2],f[N/2][N/2],g[N/2][N/2],h[N/2][N/2];
intp1[N/2][N/2],p2[N/2][N/2],p3[N/2][N/2],p4[N/2][N/2];
intp5[N/2][N/2],p6[N/2][N/2],p7[N/2][N/2];
intr[N/2][N/2],s[N/2][N/2],t[N/2][N/2],u[N/2][N/2],t1[N/2][N/2],t2[N/2][N/2];
printf("\nInputthefirstmatrix...:
\n");
for(i=0;i{
for(j=0;j{
scanf("%d",&m1[i][j]);
}
}
printf("\nInputthesecondmatrix...:
\n");
for(i=0;i{
for(j=0;j{
scanf("%d",&m2[i][j]);
}
}
//abcdefgh
for(i=0;i{
for(j=0;j{
a[i][j]=m1[i][j];
b[i][j]=m1[i][j+N/2];
c[i][j]=m1[i+N/2][j];
d[i][j]=m1[i+N/2][j+N/2];
e[i][j]=m2[i][j];
f[i][j]=m2[i][j+N/2];
g[i][j]=m2[i+N/2][j];
h[i][j]=m2[i+N/2][j+N/2];
}
}
//p1
minus(r,f,h);
mul(p1,a,r);
//p2
plus(r,a,b);
mul(p2,r,h);
//p3
plus(r,c,d);
mul(p3,r,e);
//p4
minus(r,g,e);
mul(p4,d,r);
//p5
plus(r,a,d);
plus(s,e,f);
mul(p5,r,s);
//p6
minus(r,b,d);
plus(s,g,h);
mul(p6,r,s);
//p7
minus(r,a,c);
plus(s,e,f);
mul(p7,r,s);
//r=p5+p4-p2+p6
plus(t1,p5,p4);
minus(t2,t1,p2);
plus(r,t2,p6);
//s=p1+p2
plus(s,p1,p2);
//t=p3+p4
plus(t,p3,p4);
//u=p5+p1-p3-p7=p5+p1-(p3+p7)
plus(t1,p5,p1);
plus(t2,p3,p7);
minus(u,t1,t2);
for(i=0;i{
for(j=0;j{
mat[i][j]=r[i][j];
mat[i][j+N/2]=s[i][j];
mat[i+N/2][j]=t[i][j];
mat[i+N/2][j+N/2]=u[i][j];
}
}
printf("\n下面是strassen算法处理结果:
\n");for(i=0;i{
for(j=0;j{
printf("%d",mat[i][j]);
}
printf("\n");
}
//下面是朴素算法处理
printf("\n下面是朴素算法处理结果:
\n");
for(i=0;i{
for(j=0;j{
mat[i][j]=0;
for(k=0;k{
mat[i][j]+=m1[i][j]*m2[i][j];
}
}
}
for(i=0;i{
for(j=0;j{
printf("%d",mat[i][j]);