递归矩阵乘法 [英] Recursive matrix multiplication

查看:414
本文介绍了递归矩阵乘法的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我读算法导论由CLRS。本书展示了伪code简单的分而治之的矩阵乘法:

  N = A.rows
令C是一个新的N×n矩阵
如果n == 1
    C11 = A11 * B11
其他分区A,B和C
    C11 = SquareMatrixMultiplyRecursive(A11,B11)
        + SquareMatrixMultiplyRecursive(A12,B21)
    // ...
回复C
 

,其中,例如,A11是A尺寸N / 2 XN / 2的子矩阵。 作者还暗示我应该使用,而不是创造新的矩阵重新present子矩阵指数计算,所以我这样做:

 的#include<的iostream>
#包括<载体>

模板<类T>
结构矩阵
{
    矩阵(为size_t R,为size_t三)
    {
        Data.resize(C,性病::矢量< T>(R,0));
    }

    无效SetSubMatrix(const int的R,const int的C,const int的N,常量矩阵< T>&安培; A,常量矩阵< T>和B)
    {
        对于(INT _c = C; _c n种++ _三)
        {
            对于(INT _r = R; _r n种++ _ R)
            {
                数据[_c] [_ R] = A.Data [_c] [_ R] + B.Data [_c] [_ R];
            }
        }
    }

    静态矩阵< T> SquareMultiplyRecursive(矩阵< T>&安培; A,矩阵< T>和B,INT AR,诠释交流,诠释BR,公元前INT,INT N)
    {
        矩阵< T> C(N,N);

        如果(正== 1)
        {
            C.Data [0] [0] = A.Data [AC] [AR] * B.Data [BC] [BR];
        }
        其他
        {
            C.SetSubMatrix(0,0,π/ 2,
                           SquareMultiplyRecursive(A,B,AR,交流,BR,公元前,N / 2),
                           SquareMultiplyRecursive(A,B,AR,AC +(N / 2),BR +(N / 2),BC,N / 2));

            C.SetSubMatrix(0,N / 2,N / 2,
                           SquareMultiplyRecursive(A,B,AR,交流,BR,公元前+(N / 2),N / 2),
                           SquareMultiplyRecursive(A,B,AR,AC +(n / 2个峰),br +(N / 2),BC +(N / 2)中,n / 2));

            C.SetSubMatrix(n / 2个,0,π/ 2,
                           SquareMultiplyRecursive(A,B,AR +(N / 2),交流,BR,公元前,N / 2),
                           SquareMultiplyRecursive(A,B,AR +(N / 2),交流+(N / 2),BR +(N / 2),BC,N / 2));

            C.SetSubMatrix(N / 2,N / 2,N / 2,
                           SquareMultiplyRecursive(A,B,AR +(N / 2),交流,BR,公元前+(N / 2),N / 2),
                           SquareMultiplyRecursive(A,B,芳+(N / 2),AC +(n / 2个峰),br +(N / 2),BC +(N / 2)中,n / 2));
        }

        返回℃;
    }

    无效打印()
    {
        对于(INT C = 0;℃下Data.size(); ++ C)
        {
            对(INT R = 0; R<数据[0] .size(); ++ r)的
            {
                性病::法院<<数据[C] [R]<< ;
            }
            性病::法院<< \ N的;
        }
        性病::法院<< \ N的;
    }

    的std ::矢量<的std ::矢量< T> >数据;
};

诠释的main()
{
    矩阵< INT> A(2,2);
    矩阵< INT> B(2,2);
    A.Data [0] [0] = 2;
    A.Data [0] [1] = 1;
    A.Data [1] [0] = 1;
    A.Data [1] [1] = 2;

    B.Data [0] [0] = 2;
    B.Data [0] [1] = 1;
    B.Data [1] [0] = 1;
    B.Data [1] [1] = 2;

    A.Print();
    B.Print();

    矩阵< INT> C(矩阵< INT> :: SquareMultiplyRecursive(A,B,0,0,0,0,2));

    C.Print();
}
 

这给了我不正确的结果,寿我不知道我在做什么错了?

解决方案

  C语言//递归天真的矩阵乘法,而不是STRASSEN。
// 2013-FEB-15周五12点28分,在/ Gmail的moshahmed /

#包括< ASSERT.H>
#包括< stdio.h中>
#包括< stdlib.h中>
#包括< time.h中>

#定义M 2
#定义N(1<< M)

的typedef INT垫[N] [N]; //垫[2 **男,2 ** M]的分而治之MULT。
typedef结构{INT RA,RB,CA,CB; }角落; //用于跟踪行和列。

//设置A [A] = K
空集(垫A,角落,诠释K){
  INT I,J;
  对于(i = a.ra; I< a.rb;我++)
    为(J = a.ca; J< a.cb; J ++)
      A [1] [J] = K;
}

//设置A [A] = [随机(l..h)。
无效randk(垫A,角落,诠释L,INT高){
  INT I,J;
  对于(i = a.ra; I< a.rb;我++)
    为(J = a.ca; J< a.cb; J ++)
      A [1] [J] = L +兰特()%(H-1);
}

//打印A [一]
无效打印(垫A,角落,字符*名称){
  INT I,J;
  的printf(%S = {\ N,名);
  对于(i = a.ra; I< a.rb;我++){
    为(J = a.ca; J< a.cb; J ++)
      的printf(%4d中,A [1] [J]);
    的printf(\ N);
  }
  的printf(} \ N);
}

//返回矩阵的1/4:上/下,左/右。
无效find_corners(角落,诠释我,诠释J,角* B){
  INT RM = a.ra +(a.rb  -  a.ra)/ 2;
  INT厘米= a.ca +(a.cb  -  a.ca)/ 2;
  * B = A;
  如果(我== 0)B-> RB = RM; //前行
  否则B-> RA = RM; // BOT行
  如果(j == 0)B-> CB =厘米; //左COLS
  否则B-> CA =厘米; //正确COLS
}

//天真乘法:A [A] * B [B] => C [C],递归。
无效MUL(垫A,板坯B,垫C,角落,角落B,角C){
  角AII [2] [2],BII [2] [2],CII [2] [2];
  INT I,J,M,N,P;

  //检查:A [M N] * B [N P] = C [M P]
  M = a.rb  -  a.ra;断言(M ==(c.rb-c.ra));
  N = a.cb  -  a.ca;断言(N = =(b.rb-b.ra));
  P = b.cb  -  b.ca;断言(P = =(c.cb-c.ca));
  断言(米大于0);

  如果(正== 1){
    C [c.ra] [c.ca] + = A [a.ra] [a.ca] * B [b.ra] [b.ca]
    返回;
  }

  //创建更小的矩阵:
  对于(I = 0; I&2;我++){
  为(J = 0; J&2; J ++){
        find_corners(A,I,J,和放大器; AII [I] [J]);
        find_corners(B,I,J,和放大器; BII [I] [J]);
        find_corners(C,I,J,和放大器; CII [I] [J]);
      }
  }

  //现在做8次矩阵乘法。
  // C00 = A00 * B00 + A01 * B10
  // C01 = A00 * B01 + A01 * B11
  // C10 = A10 * B00 + A11 * B10
  // C11 = A10 * B01 + A11 * B11

  穆尔(甲,乙,丙,AII [0] [0],BII [0] [0],CII [0] [0]);
  穆尔(甲,乙,丙,AII [0] [1],BII [1] [0],CII [0] [0]);

  穆尔(甲,乙,丙,AII [0] [0],BII [0] [1],CII [0] [1]);
  穆尔(甲,乙,丙,AII [0] [1],BII [1] [1],CII [0] [1]);

  穆尔(甲,乙,丙,AII [1] [0],BII [0] [0],CII [1] [0]);
  穆尔(甲,乙,丙,AII [1] [1],BII [1] [0],CII [1] [0]);

  穆尔(甲,乙,丙,AII [1] [0],BII [0] [1],CII [1] [1]);
  穆尔(甲,乙,丙,AII [1] [1],BII [1] [1],CII [1] [1]);

}

诠释的main(){
  垫A,B,C;
  角AI = {0,N,O,N};
  角双向= {0,N,O,N};
  角落CI = {0,N,O,N};
  //设置(A,AI,2);
  //设置(B,BI,2);
  函数srand(时间(0));
  randk(A,AI,0,2);
  randk(乙,双,0,2);
  组(C,CI,0); //设置为零多重峰之前。
  打印(A,AI,A);
  印刷(B,BI,B);
  MUL(A,B,C,AI,BI,CI);
  打印(C,CI,C);
  返回0;
}
 

I am reading Introduction to Algorithms by CLRS. Book shows pseudocode for simple divide and conquer matrix multiplication:

n = A.rows
let c be a new n x n matrix
if n == 1
    c11 = a11 * b11
else partition A, B, and C
    C11 = SquareMatrixMultiplyRecursive(A11, B11)
        + SquareMatrixMultiplyRecursive(A12, B21)
    //...
return C

Where for example, A11 is submatrix of A of size n/2 x n/2. Author also hints that I should use index calculations instead of creating new matrices to represent submatrices, so I did this:

#include <iostream>
#include <vector>

template<class T>
struct Matrix
{
    Matrix(size_t r, size_t c)
    {
        Data.resize(c, std::vector<T>(r, 0));
    }    

    void SetSubMatrix(const int r, const int c, const int n, const Matrix<T>& A, const Matrix<T>& B)
    {
        for(int _c=c; _c<n; ++_c)
        {
            for(int _r=r; _r<n; ++_r)
            {
                Data[_c][_r] = A.Data[_c][_r] + B.Data[_c][_r];
            }
        }
    }

    static Matrix<T> SquareMultiplyRecursive(Matrix<T>& A, Matrix<T>& B, int ar, int ac, int br, int bc, int n)
    {
        Matrix<T> C(n, n);

        if(n == 1)
        {
            C.Data[0][0] = A.Data[ac][ar] * B.Data[bc][br];
        }
        else
        {
            C.SetSubMatrix(0, 0, n / 2,
                           SquareMultiplyRecursive(A, B, ar, ac, br, bc, n / 2),
                           SquareMultiplyRecursive(A, B, ar, ac + (n / 2), br + (n / 2), bc, n / 2));

            C.SetSubMatrix(0, n / 2, n / 2,
                           SquareMultiplyRecursive(A, B, ar, ac, br, bc + (n / 2), n / 2),
                           SquareMultiplyRecursive(A, B, ar, ac + (n / 2), br + (n / 2), bc + (n / 2), n / 2));

            C.SetSubMatrix(n / 2, 0, n / 2,
                           SquareMultiplyRecursive(A, B, ar + (n / 2), ac, br, bc, n / 2),
                           SquareMultiplyRecursive(A, B, ar + (n / 2), ac + (n / 2), br + (n / 2), bc, n / 2));

            C.SetSubMatrix(n / 2, n / 2, n / 2,
                           SquareMultiplyRecursive(A, B, ar + (n / 2), ac, br, bc + (n / 2), n / 2),
                           SquareMultiplyRecursive(A, B, ar + (n / 2), ac + (n / 2), br + (n / 2), bc + (n / 2), n / 2));
        }

        return C;
    }

    void Print()
    {
        for(int c=0; c<Data.size(); ++c)
        {
            for(int r=0; r<Data[0].size(); ++r)
            {
                std::cout << Data[c][r] << " ";
            }
            std::cout << "\n";
        }
        std::cout << "\n";
    }

    std::vector<std::vector<T> > Data;
};

int main()
{
    Matrix<int> A(2, 2);
    Matrix<int> B(2, 2);
    A.Data[0][0] = 2;
    A.Data[0][1] = 1;
    A.Data[1][0] = 1;
    A.Data[1][1] = 2;

    B.Data[0][0] = 2;
    B.Data[0][1] = 1;
    B.Data[1][0] = 1;
    B.Data[1][1] = 2;

    A.Print();
    B.Print();

    Matrix<int> C(Matrix<int>::SquareMultiplyRecursive(A, B, 0, 0, 0, 0, 2));

    C.Print();
}

It gives me incorrect results, tho I am not sure what I'm doing wrong...

解决方案

// Recursive naive matrix multiplication in C, not strassen.
// 2013-Feb-15 Fri 12:28 moshahmed/at/gmail

#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

#define M 2
#define N (1<<M)

typedef int mat[N][N]; // mat[2**M,2**M]  for divide and conquer mult.
typedef struct { int ra, rb, ca, cb; } corners; // for tracking rows and columns.

// set A[a] = k
void set(mat A, corners a, int k){
  int i,j;
  for(i=a.ra;i<a.rb;i++)
    for(j=a.ca;j<a.cb;j++)
      A[i][j] = k;
}

// set A[a] = [random(l..h)].
void randk(mat A, corners a, int l, int h){
  int i,j;
  for(i=a.ra;i<a.rb;i++)
    for(j=a.ca;j<a.cb;j++)
      A[i][j] = l + rand()% (h-l);
}

// Print A[a]
void print(mat A, corners a, char *name) {
  int i,j;
  printf("%s = {\n",name);
  for(i=a.ra;i<a.rb;i++){
    for(j=a.ca;j<a.cb;j++)
      printf("%4d, ", A[i][j]);
    printf("\n");
  }
  printf("}\n");
}

// Return 1/4 of the matrix: top/bottom , left/right.
void find_corners(corners a, int i, int j, corners *b) {
  int rm = a.ra + (a.rb - a.ra)/2 ;
  int cm = a.ca + (a.cb - a.ca)/2 ;
  *b = a;
  if (i==0)  b->rb = rm;     // top rows
  else       b->ra = rm;     // bot rows
  if (j==0)  b->cb = cm;     // left cols
  else       b->ca = cm;     // right cols
}

// Naive Multiply: A[a] * B[b] => C[c], recursively.
void mul(mat A, mat B, mat C, corners a, corners b, corners c) {
  corners aii[2][2], bii[2][2], cii[2][2];
  int i, j, m, n, p;

  // Check: A[m n] * B[n p] = C[m p]
  m = a.rb - a.ra; assert(m==(c.rb-c.ra));
  n = a.cb - a.ca; assert(n==(b.rb-b.ra));
  p = b.cb - b.ca; assert(p==(c.cb-c.ca));
  assert(m>0);

  if (n==1) {
    C[c.ra][c.ca] += A[a.ra][a.ca] * B[b.ra][b.ca];
    return;
  }

  // Create the smaller matrices:
  for(i=0;i<2;i++) {
  for(j=0;j<2;j++) {
        find_corners(a, i, j, &aii[i][j]);
        find_corners(b, i, j, &bii[i][j]);
        find_corners(c, i, j, &cii[i][j]);
      }
  }

  // Now do the 8 sub matrix multiplications.
  // C00 = A00*B00 + A01*B10
  // C01 = A00*B01 + A01*B11
  // C10 = A10*B00 + A11*B10
  // C11 = A10*B01 + A11*B11

  mul( A, B, C, aii[0][0], bii[0][0], cii[0][0] );
  mul( A, B, C, aii[0][1], bii[1][0], cii[0][0] );

  mul( A, B, C, aii[0][0], bii[0][1], cii[0][1] );
  mul( A, B, C, aii[0][1], bii[1][1], cii[0][1] );

  mul( A, B, C, aii[1][0], bii[0][0], cii[1][0] );
  mul( A, B, C, aii[1][1], bii[1][0], cii[1][0] );

  mul( A, B, C, aii[1][0], bii[0][1], cii[1][1] );
  mul( A, B, C, aii[1][1], bii[1][1], cii[1][1] );

}

int main() {
  mat A, B, C;
  corners ai = {0,N,0,N};
  corners bi = {0,N,0,N};
  corners ci = {0,N,0,N};
  //set(A,ai,2);
  //set(B,bi,2);
  srand(time(0));
  randk(A,ai, 0, 2);
  randk(B,bi, 0, 2);
  set(C,ci,0); // set to zero before mult.
  print(A, ai, "A");
  print(B, bi, "B");
  mul(A,B,C, ai, bi, ci);
  print(C, ci, "C");
  return 0;
}  

这篇关于递归矩阵乘法的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆