Fork Join Matrix Multiplication in Java

Question

I'm doing some performance research on the fork/join framework in Java 7. To improve the test results, I want to use different recursive algorithms during the tests. One of them is matrix multiplication.

I downloaded the following example from Doug Lea's website:

public class MatrixMultiply {

  static final int DEFAULT_GRANULARITY = 16;

  /** The quadrant size at which to stop recursing down
   * and instead directly multiply the matrices.
   * Must be a power of two. Minimum value is 2.
   **/
  static int granularity = DEFAULT_GRANULARITY;

  public static void main(String[] args) {

    final String usage = "Usage: java MatrixMultiply <threads> <matrix size (must be a power of two)> [<granularity>] \n Size and granularity must be powers of two.\n For example, try java MatrixMultiply 2 512 16";

    try {
      int procs;
      int n;
      try {
        procs = Integer.parseInt(args[0]);
        n = Integer.parseInt(args[1]);
        if (args.length > 2) granularity = Integer.parseInt(args[2]);
      }

      catch (Exception e) {
        System.out.println(usage);
        return;
      }

      if ( ((n & (n - 1)) != 0) || 
           ((granularity & (granularity - 1)) != 0) ||
           granularity < 2) {
        System.out.println(usage);
        return;
      }

      float[][] a = new float[n][n];
      float[][] b = new float[n][n];
      float[][] c = new float[n][n];
      init(a, b, n);

      FJTaskRunnerGroup g = new FJTaskRunnerGroup(procs);
      g.invoke(new Multiplier(a, 0, 0, b, 0, 0, c, 0, 0, n));
      g.stats();

      // check(c, n);
    }
    catch (InterruptedException ex) {}
  }


  // To simplify checking, fill with all 1's. Answer should be all n's.
  static void init(float[][] a, float[][] b, int n) {
    for (int i = 0; i < n; ++i) {
      for (int j = 0; j < n; ++j) {
        a[i][j] = 1.0F;
        b[i][j] = 1.0F;
      }
    }
  }

  static void check(float[][] c, int n) {
    for (int i = 0; i < n; i++ ) {
      for (int j = 0; j < n; j++ ) {
        if (c[i][j] != n) {
          throw new Error("Check Failed at [" + i +"]["+j+"]: " + c[i][j]);
        }
      }
    }
  }

  /** 
   * Multiply matrices AxB by dividing into quadrants, using algorithm:
   * <pre>
   *      A      x      B                             
   *
   *  A11 | A12     B11 | B12     A11*B11 | A11*B12     A12*B21 | A12*B22 
   * |----+----| x |----+----| = |--------+--------| + |---------+-------|
   *  A21 | A22     B21 | B22     A21*B11 | A21*B12     A22*B21 | A22*B22 
   * </pre>
   */


  static class Multiplier extends FJTask {
    final float[][] A;   // Matrix A
    final int aRow;      // first row    of current quadrant of A
    final int aCol;      // first column of current quadrant of A

    final float[][] B;   // Similarly for B
    final int bRow;
    final int bCol;

    final float[][] C;   // Similarly for result matrix C
    final int cRow;
    final int cCol;

    final int size;      // side length (rows and columns) of current quadrant

    Multiplier(float[][] A, int aRow, int aCol,
               float[][] B, int bRow, int bCol,
               float[][] C, int cRow, int cCol,
               int size) {
      this.A = A; this.aRow = aRow; this.aCol = aCol;
      this.B = B; this.bRow = bRow; this.bCol = bCol;
      this.C = C; this.cRow = cRow; this.cCol = cCol;
      this.size = size;
    }

    public void run() {

      if (size <= granularity) {
        multiplyStride2();
      }

      else {
        int h = size / 2;

        coInvoke(new FJTask[] {
          seq(new Multiplier(A, aRow,   aCol,    // A11
                             B, bRow,   bCol,    // B11
                             C, cRow,   cCol,    // C11
                             h),
              new Multiplier(A, aRow,   aCol+h,  // A12
                             B, bRow+h, bCol,    // B21
                             C, cRow,   cCol,    // C11
                             h)),

          seq(new Multiplier(A, aRow,   aCol,    // A11
                             B, bRow,   bCol+h,  // B12
                             C, cRow,   cCol+h,  // C12
                             h),
              new Multiplier(A, aRow,   aCol+h,  // A12
                             B, bRow+h, bCol+h,  // B22
                             C, cRow,   cCol+h,  // C12
                             h)),

          seq(new Multiplier(A, aRow+h, aCol,    // A21
                             B, bRow,   bCol,    // B11
                             C, cRow+h, cCol,    // C21
                             h),
              new Multiplier(A, aRow+h, aCol+h,  // A22
                             B, bRow+h, bCol,    // B21
                             C, cRow+h, cCol,    // C21
                             h)),

          seq(new Multiplier(A, aRow+h, aCol,    // A21
                             B, bRow,   bCol+h,  // B12
                             C, cRow+h, cCol+h,  // C22
                             h),
              new Multiplier(A, aRow+h, aCol+h,  // A22
                             B, bRow+h, bCol+h,  // B22
                             C, cRow+h, cCol+h,  // C22
                             h))
        });
      }
    }

    /** 
     * Version of matrix multiplication that steps 2 rows and columns
     * at a time. Adapted from Cilk demos.
     * Note that the results are added into C, not just set into C.
     * This works well here because Java array elements
     * are created with all zero values.
     **/

    void multiplyStride2() {
      for (int j = 0; j < size; j+=2) {
        for (int i = 0; i < size; i +=2) {

          float[] a0 = A[aRow+i];
          float[] a1 = A[aRow+i+1];

          float s00 = 0.0F; 
          float s01 = 0.0F; 
          float s10 = 0.0F; 
          float s11 = 0.0F; 

          for (int k = 0; k < size; k+=2) {

            float[] b0 = B[bRow+k];

            s00 += a0[aCol+k]   * b0[bCol+j];
            s10 += a1[aCol+k]   * b0[bCol+j];
            s01 += a0[aCol+k]   * b0[bCol+j+1];
            s11 += a1[aCol+k]   * b0[bCol+j+1];

            float[] b1 = B[bRow+k+1];

            s00 += a0[aCol+k+1] * b1[bCol+j];
            s10 += a1[aCol+k+1] * b1[bCol+j];
            s01 += a0[aCol+k+1] * b1[bCol+j+1];
            s11 += a1[aCol+k+1] * b1[bCol+j+1];
          }

          C[cRow+i]  [cCol+j]   += s00;
          C[cRow+i]  [cCol+j+1] += s01;
          C[cRow+i+1][cCol+j]   += s10;
          C[cRow+i+1][cCol+j+1] += s11;
        }
      }
    }

  }

}

This code was written for an older version of the fork/join framework, so I have to rewrite it. My rewritten code implements my own interface and looks like this:

public class Java7MatrixMultiply implements Algorithm { 
    private static final int SIZE = 32;
    private static final int THRESHOLD = 8;

    private float[][] a = new float[SIZE][SIZE];
    private float[][] b = new float[SIZE][SIZE];
    private float[][] c = new float[SIZE][SIZE];

    ForkJoinPool forkJoinPool;

    @Override
    public void initialize() {
        init(a, b, SIZE);
    }

    @Override
    public void execute() {
        MatrixMultiplyTask mainTask = new MatrixMultiplyTask(a, 0, 0, b, 0, 0, c, 0, 0, SIZE);
        forkJoinPool = new ForkJoinPool();
        forkJoinPool.invoke(mainTask);

        System.out.println("Terminated!");
    }

    @Override
    public void printResult() { 
        check(c, SIZE);

        for (int i = 0; i < SIZE; i++) {
            for (int j = 0; j < SIZE; j++) {
                System.out.print(c[i][j] + " ");
            }

            System.out.println();
        }
    }

    // To simplify checking, fill with all 1's. Answer should be all n's.
    static void init(float[][] a, float[][] b, int n) {
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                a[i][j] = 1.0F;
                b[i][j] = 1.0F;
            }
        }
    }

    static void check(float[][] c, int n) {
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (c[i][j] != n) {
                    //throw new Error("Check Failed at [" + i + "][" + j + "]: " + c[i][j]);
                    System.out.println("Check Failed at [" + i + "][" + j + "]: " + c[i][j]);
                }
            }
        }
    }

    private class MatrixMultiplyTask extends RecursiveAction {
        private final float[][] A; // Matrix A
        private final int aRow; // first row of current quadrant of A
        private final int aCol; // first column of current quadrant of A

        private final float[][] B; // Similarly for B
        private final int bRow;
        private final int bCol;

        private final float[][] C; // Similarly for result matrix C
        private final int cRow;
        private final int cCol;

        private final int size;

        MatrixMultiplyTask(float[][] A, int aRow, int aCol, float[][] B,
                int bRow, int bCol, float[][] C, int cRow, int cCol, int size) {
            this.A = A;
            this.aRow = aRow;
            this.aCol = aCol;
            this.B = B;
            this.bRow = bRow;
            this.bCol = bCol;
            this.C = C;
            this.cRow = cRow;
            this.cCol = cCol;
            this.size = size;
        }

        @Override
        protected void compute() {      
            if (size <= THRESHOLD) {
                multiplyStride2();
            } else {

                int h = size / 2;               

                invokeAll(new MatrixMultiplyTask[] {
                        new MatrixMultiplyTask(A, aRow, aCol, // A11
                                B, bRow, bCol, // B11
                                C, cRow, cCol, // C11
                                h),

                        new MatrixMultiplyTask(A, aRow, aCol + h, // A12
                                B, bRow + h, bCol, // B21
                                C, cRow, cCol, // C11
                                h),

                        new MatrixMultiplyTask(A, aRow, aCol, // A11
                                B, bRow, bCol + h, // B12
                                C, cRow, cCol + h, // C12
                                h),

                        new MatrixMultiplyTask(A, aRow, aCol + h, // A12
                                B, bRow + h, bCol + h, // B22
                                C, cRow, cCol + h, // C12
                                h),

                        new MatrixMultiplyTask(A, aRow + h, aCol, // A21
                                B, bRow, bCol, // B11
                                C, cRow + h, cCol, // C21
                                h),

                        new MatrixMultiplyTask(A, aRow + h, aCol + h, // A22
                                B, bRow + h, bCol, // B21
                                C, cRow + h, cCol, // C21
                                h),

                        new MatrixMultiplyTask(A, aRow + h, aCol, // A21
                                B, bRow, bCol + h, // B12
                                C, cRow + h, cCol + h, // C22
                                h),

                        new MatrixMultiplyTask(A, aRow + h, aCol + h, // A22
                                B, bRow + h, bCol + h, // B22
                                C, cRow + h, cCol + h, // C22
                                h) });

            }
        }

        /**
         * Version of matrix multiplication that steps 2 rows and columns at a
         * time. Adapted from Cilk demos. Note that the results are added into
         * C, not just set into C. This works well here because Java array
         * elements are created with all zero values.
         **/

        void multiplyStride2() {
            for (int j = 0; j < size; j += 2) {
                for (int i = 0; i < size; i += 2) {

                    float[] a0 = A[aRow + i];
                    float[] a1 = A[aRow + i + 1];

                    float s00 = 0.0F;
                    float s01 = 0.0F;
                    float s10 = 0.0F;
                    float s11 = 0.0F;

                    for (int k = 0; k < size; k += 2) {

                        float[] b0 = B[bRow + k];

                        s00 += a0[aCol + k] * b0[bCol + j];
                        s10 += a1[aCol + k] * b0[bCol + j];
                        s01 += a0[aCol + k] * b0[bCol + j + 1];
                        s11 += a1[aCol + k] * b0[bCol + j + 1];

                        float[] b1 = B[bRow + k + 1];

                        s00 += a0[aCol + k + 1] * b1[bCol + j];
                        s10 += a1[aCol + k + 1] * b1[bCol + j];
                        s01 += a0[aCol + k + 1] * b1[bCol + j + 1];
                        s11 += a1[aCol + k + 1] * b1[bCol + j + 1];
                    }

                    C[cRow + i][cCol + j] += s00;
                    C[cRow + i][cCol + j + 1] += s01;
                    C[cRow + i + 1][cCol + j] += s10;
                    C[cRow + i + 1][cCol + j + 1] += s11;
                }
            }
        }
    }
}

Sometimes my computation fails the check: some entries of the result matrix have a different value than expected. These inconsistencies are random and don't always occur. I suspect something goes wrong in the compute method, because I had to rewrite the parts where the Seq class is used. The Seq class executes tasks in order, unlike the invokeAll() method, and it no longer exists in the current version of the fork/join framework. I am not very familiar with the matrix multiplication algorithm, so it's very hard to see what goes wrong. Any suggestions?

Recommended answer

As you already noticed, sequential execution of the subtasks that belong to the same quadrant is important for this algorithm: both products for a quadrant are added into the same part of C, so running them concurrently can lose updates. You therefore need to implement your own seq() function, for example as follows, and use it as in the original code:

public ForkJoinTask<?> seq(final ForkJoinTask<?> a, final ForkJoinTask<?> b) {
    return adapt(new Runnable() {
        public void run() {
            a.invoke();
            b.invoke();
        }
    });
}
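
To show how such a seq() helper could fit into the rewritten task, here is a minimal, self-contained sketch. The names SeqMatrixMultiply, Task, sub and multiplyOnes are made up for this illustration and do not come from the original code. The base case is a plain triple loop standing in for multiplyStride2(), and the matrix size is assumed to be a power of two. Treat it as one possible arrangement under those assumptions, not as the definitive port.

```java
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveAction;

// Illustrative sketch: class and method names are hypothetical.
final class SeqMatrixMultiply {
    static final int THRESHOLD = 8; // assumed cut-off for direct multiplication

    // Runs a, then b, in order inside a single task (stand-in for the old seq()).
    static ForkJoinTask<?> seq(final ForkJoinTask<?> a, final ForkJoinTask<?> b) {
        return ForkJoinTask.adapt(new Runnable() {
            public void run() {
                a.invoke();
                b.invoke();
            }
        });
    }

    static final class Task extends RecursiveAction {
        final float[][] A, B, C;
        final int aRow, aCol, bRow, bCol, cRow, cCol, size;

        Task(float[][] A, int aRow, int aCol, float[][] B, int bRow, int bCol,
             float[][] C, int cRow, int cCol, int size) {
            this.A = A; this.aRow = aRow; this.aCol = aCol;
            this.B = B; this.bRow = bRow; this.bCol = bCol;
            this.C = C; this.cRow = cRow; this.cCol = cCol;
            this.size = size;
        }

        // Sub-task of side h whose quadrant offsets are relative to this task.
        Task sub(int ar, int ac, int br, int bc, int cr, int cc, int h) {
            return new Task(A, aRow + ar, aCol + ac, B, bRow + br, bCol + bc,
                            C, cRow + cr, cCol + cc, h);
        }

        @Override
        protected void compute() {
            if (size <= THRESHOLD) {
                multiplyDirect();
                return;
            }
            int h = size / 2;
            // The four quadrants of C run in parallel. The two products that
            // accumulate into the same quadrant run one after the other.
            invokeAll(
                seq(sub(0, 0, 0, 0, 0, 0, h), sub(0, h, h, 0, 0, 0, h)),  // C11 = A11*B11 + A12*B21
                seq(sub(0, 0, 0, h, 0, h, h), sub(0, h, h, h, 0, h, h)),  // C12 = A11*B12 + A12*B22
                seq(sub(h, 0, 0, 0, h, 0, h), sub(h, h, h, 0, h, 0, h)),  // C21 = A21*B11 + A22*B21
                seq(sub(h, 0, 0, h, h, h, h), sub(h, h, h, h, h, h, h))); // C22 = A21*B12 + A22*B22
        }

        // Simplified base case: adds the product of the two quadrants into C.
        void multiplyDirect() {
            for (int i = 0; i < size; i++) {
                for (int j = 0; j < size; j++) {
                    float s = 0.0F;
                    for (int k = 0; k < size; k++) {
                        s += A[aRow + i][aCol + k] * B[bRow + k][bCol + j];
                    }
                    C[cRow + i][cCol + j] += s;
                }
            }
        }
    }

    // Multiplies two n x n all-ones matrices; every entry of the result should be n.
    static float[][] multiplyOnes(int n) {
        float[][] a = new float[n][n];
        float[][] b = new float[n][n];
        float[][] c = new float[n][n];
        for (int i = 0; i < n; i++) {
            Arrays.fill(a[i], 1.0F);
            Arrays.fill(b[i], 1.0F);
        }
        ForkJoinPool pool = new ForkJoinPool();
        pool.invoke(new Task(a, 0, 0, b, 0, 0, c, 0, 0, n));
        pool.shutdown();
        return c;
    }

    public static void main(String[] args) {
        int n = 64;
        float[][] c = multiplyOnes(n);
        boolean ok = true;
        for (float[] row : c) {
            for (float v : row) {
                ok &= (v == n);
            }
        }
        System.out.println("all entries equal " + n + ": " + ok);
    }
}
```

Compared with the rewritten compute() in the question, the only structural change is that invokeAll() receives four seq() pairs instead of eight independent tasks, so at most one task writes to any given quadrant of C at a time.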
