使用分治法的矩阵乘法 [英] Matrix Multiplication using divide and conquer approach

查看:121
本文介绍了使用分治法的矩阵乘法的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我是编程的初学者,刚刚学习了新概念,并开始编写矩阵乘法的代码,但是我对指针和其他内容感到困惑,因此我将代码上传到这里以寻求指导.

I am a beginner in programming and just learned new concepts and started writing code for matrix multiplication but I got confused in pointers and others so I am uploading my code here in seek of guidelines.

#include <stdio.h>
#include <stdlib.h>

int **matrixMultiply(int A[][8], int B[][8], int row);

int main() {
    int **A = allocate_matrix(A, 8, 8);
    int **B = allocate_matrix(B, 8, 8);

    int i, j;
    for (i = 0; i < 8; i++) {
        for (j = 0; j < 8; j++) {
            A[i][j] = i + j;
            A[i][j] = i + j;
        }
    }

    int **C = allocate_matrix(C, 8, 8);
    C = matrixMultiply(A, B, 8);

    return 0;
}

int **matrixMultiply(int A[][8], int B[][8], int row) {
    int **C = allocate_matrix(C, row, row);
    if (row == 1) {
        C[1][1] = A[1][1] * B[1][1];
    } else {
        int a11[row/2][row/2], a12[row/2][row/2], a21[row/2][row/2], a22[row/2][row/2];
        int b11[row/2][row/2], b12[row/2][row/2], b21[row/2][row/2], b22[row/2][row/2];
        int **c11 = allocate_matrix(c11, row/2, row/2);
        int **c12 = allocate_matrix(c12, row/2, row/2);
        int **c21 = allocate_matrix(c21, row/2, row/2);
        int **c22 = allocate_matrix(c22, row/2, row/2);

        int i, j;
        for (i = 0; i < row/2; i++) {
            for (j = 0; j < row/2; j++) {
                a11[i][j] = A[i][j];
                a12[i][j] = A[i][j + (row/2)];
                a21[i][j] = A[i + (row/2)][j];
                a22[i][j] = A[i + (row/2)][j + (row/2)];
                b11[i][j] = B[i][j];
                b12[i][j] = B[i][j + (row/2)];
                b21[i][j] = B[i + (row/2)][j];
                b22[i][j] = B[i + (row/2)][j + (row/2)];
                c11[i][j] = C[i][j];
                c12[i][j] = C[i][j + (row/2)];
                c21[i][j] = C[i + (row/2)][j];
                c22[i][j] = C[i + (row/2)][j + (row/2)];
            }
        }

        c11 = addmatrix(matrixMultiply(a11, b11, row/2),
                        matrixMultiply(a12, b21, row/2), c11, row/2);
        c12 = addmatrix(matrixMultiply(a11, b12, row/2),
                        matrixMultiply(a22, b22, row/2), c12, row/2);
        c21 = addmatrix(matrixMultiply(a21, b11, row/2),
                        matrixMultiply(a22, b21, row/2), c21, row/2);
        c22 = addmatrix(matrixMultiply(a21, b12, row/2),
                        matrixMultiply(a22, b22, row/2), c22, row/2);

        // missing code???
        return C;
    }
}

int **allocate_matrix(int **matrix, int row, int column) {
    matrix = (int **)malloc(row * sizeof(int*));
    int i;
    for (i = 0; i < row; i++) {
        matrix[row] = (int *)malloc(row * sizeof(int));
    }
    return matrix;
}

void deallocate_matrix(int **matrix, int row) {
    int i;
    for (i = 0; i < row; i++) {
        free(matrix[row]);
    }
    free(matrix);
}

int **addMatrix(int **a, int **b, int **c, int row) {
    int i, j;
    for (i = 0; i < row; i++) {
        for (j = 0; j < row; j++) {
            c[i][j] = a[i][j] + b[i][j];
        }
    }
    return c;
}

推荐答案

我重新格式化了您的代码,以便可以对其进行分析.一致地缩进4个空格,在,;分隔符之后以及关键字和(之间插入二进制运算符,这大大提高了可读性.

I reformatted your code so I could analyze it. Indent consistently with 4 spaces, insert spaces around binary operators, after , and ; separators and between keywords and (, this improves readability a lot.

matrixMultiply函数中似乎缺少代码:您分配结果矩阵C,但将其用作初始化中间矩阵c11c21c21,除了琐碎的1x1情况外,切勿将任何东西实际存储到C中.

There seems to be missing code in the matrixMultiply function: you allocate the resulting matrix C but you use it as an input to initialize the intermediary matrices c11, c21, c21 and c22, and never actually store anything into C except for the trivial 1x1 case.

矩阵乘法代码似乎超出了此范围,该函数接受2个类型为int A[][8], int B[][8]的参数,但是您使用定义为int a11[row/2][row/2]的局部数组a11b22递归调用它.这些类型不同,我不知道代码如何编译.

The matrix multiplication code seems broken beyond this, the function takes 2 arguments of type int A[][8], int B[][8], but you recursively call it with local arrays a11 to b22 defined as int a11[row/2][row/2]. These types are different, I do not know how the code even compiles.

在矩阵分配代码中,分配的行大小不正确,而不是row而不是column.您应该为此使用calloc,以便将矩阵初始化为0,而且完全不要传递初始参数:

In the matrix allocation code, you allocate rows with in incorrect size row instead of column. You should use calloc for this so the matrix is initialized to 0, plus you should not pass the initial argument at all:

int **allocate_matrix(int row, int column) {
    int **matrix = malloc(row * sizeof(*matrix));
    for (int i = 0; i < row; i++) {
        matrix[i] = calloc(column, sizeof(*matrix[row]));
    }
    return matrix;
}

第二个子矩阵乘法也有一个错误,应该是

There is also a mistake for the second submatrix multiplication, it should be

    c12 = addmatrix(matrixMultiply(a11, b12, row/2),
                    matrixMultiply(a12, b22, row/2), c12, row/2);

此外,您永远不会释放用于中间结果的临时矩阵.与Java不同,C没有垃圾回收器,您有责任在不再需要内存块时才释放它们,直到它们变得不可访问为止.

Furthermore, you never free the temporary matrices used for intermediary results. Unlike java, C does not have a garbage collector, you are responsible for releasing blocks of memory when you no longer need them, before they become inaccessible.

这里是经过纠正的版本,具有其他功能来打印矩阵数据并验证矩阵乘法的正确性.我添加了时间:递归方法比直接方法要慢得多,这主要是因为中间结果的所有额外分配/取消分配.

Here is a corrected version, with extra functions to print the matrix data and verify the matrix multiplication correctness. I added timings: the recursive method is much slower than the direct method, mostly because of all the extra allocation/deallocation for the intermediary results.

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

int **matrix_allocate(int row, int column) {
    int **matrix = malloc(row * sizeof(*matrix));
    for (int i = 0; i < row; i++) {
        matrix[i] = calloc(column, sizeof(*matrix[i]));
    }
    return matrix;
}

void matrix_free(int **matrix, int row) {
    for (int i = 0; i < row; i++) {
        free(matrix[i]);
    }
    free(matrix);
}

void matrix_print(const char *str, int **a, int row) {
    int min, max, w = 0, n1, n2, nw;
    min = max = a[0][0];
    for (int i = 0; i < row; i++) {
        for (int j = 0; j < row; j++) {
            if (min > a[i][j])
                min = a[i][j];
            if (max < a[i][j])
                max = a[i][j];
        }
    }
    n1 = snprintf(NULL, 0, "%d", min);
    n2 = snprintf(NULL, 0, "%d", max);
    nw = n1 > n2 ? n1 : n2;

    for (int i = 0; i < row; i++) {
        if (i == 0)
            w = printf("%s = ", str);
        else
            printf("%*s", w, "");

        for (int j = 0; j < row; j++) {
            printf(" %*d", nw, a[i][j]);
        }
        printf("\n");
    }
    fflush(stdout);
}

int **matrix_add(int **a, int **b, int row, int deallocate) {
    int **c = matrix_allocate(row, row);
    for (int i = 0; i < row; i++) {
        for (int j = 0; j < row; j++) {
            c[i][j] = a[i][j] + b[i][j];
        }
    }
    if (deallocate & 1) matrix_free(a, row);
    if (deallocate & 2) matrix_free(b, row);

    return c;
}

int **matrix_multiply(int **A, int **B, int row, int deallocate) {
    int **C = matrix_allocate(row, row);
    if (row == 1) {
        C[0][0] = A[0][0] * B[0][0];
    } else {
        int row2 = row / 2;
        int **a11 = matrix_allocate(row2, row2);
        int **a12 = matrix_allocate(row2, row2);
        int **a21 = matrix_allocate(row2, row2);
        int **a22 = matrix_allocate(row2, row2);
        int **b11 = matrix_allocate(row2, row2);
        int **b12 = matrix_allocate(row2, row2);
        int **b21 = matrix_allocate(row2, row2);
        int **b22 = matrix_allocate(row2, row2);

        for (int i = 0; i < row2; i++) {
            for (int j = 0; j < row2; j++) {
                a11[i][j] = A[i][j];
                a12[i][j] = A[i][j + row2];
                a21[i][j] = A[i + row2][j];
                a22[i][j] = A[i + row2][j + row2];
                b11[i][j] = B[i][j];
                b12[i][j] = B[i][j + row2];
                b21[i][j] = B[i + row2][j];
                b22[i][j] = B[i + row2][j + row2];
            }
        }

        int **c11 = matrix_add(matrix_multiply(a11, b11, row2, 0),
                               matrix_multiply(a12, b21, row2, 0), row2, 1+2);
        int **c12 = matrix_add(matrix_multiply(a11, b12, row2, 1),
                               matrix_multiply(a12, b22, row2, 1), row2, 1+2);
        int **c21 = matrix_add(matrix_multiply(a21, b11, row2, 2),
                               matrix_multiply(a22, b21, row2, 2), row2, 1+2);
        int **c22 = matrix_add(matrix_multiply(a21, b12, row2, 1+2),
                               matrix_multiply(a22, b22, row2, 1+2), row2, 1+2);

        for (int i = 0; i < row2; i++) {
            for (int j = 0; j < row2; j++) {
                C[i][j] = c11[i][j];
                C[i][j + row2] = c12[i][j];
                C[i + row2][j] = c21[i][j];
                C[i + row2][j + row2] = c22[i][j];
            }
        }
        matrix_free(c11, row2);
        matrix_free(c12, row2);
        matrix_free(c21, row2);
        matrix_free(c22, row2);
    }
    if (deallocate & 1) matrix_free(A, row);
    if (deallocate & 2) matrix_free(B, row);

    return C;
}

int **matrix_multiply_direct(int **A, int **B, int row, int deallocate) {
    int **C = matrix_allocate(row, row);
    for (int i = 0; i < row; i++) {
        for (int j = 0; j < row; j++) {
            int x = 0;
            for (int k = 0; k < row; k++) {
                x += A[i][k] * B[k][j];
            }
            C[i][j] = x;
        }
    }
    if (deallocate & 1) matrix_free(A, row);
    if (deallocate & 2) matrix_free(B, row);

    return C;
}

int main(int argc, char **argv) {
    int n = argc < 2 ? 8 : atoi(argv[1]);
    int **A = matrix_allocate(n, n);
    int **B = matrix_allocate(n, n);

    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            A[i][j] = i + j;
            B[i][j] = i + j;
        }
    }

    matrix_print("A", A, n);
    matrix_print("B", B, n);

    if ((n & (n - 1)) == 0) {
        /* recursive method can be applied only to powers of 2 */
        clock_t ticks = -clock();
        int **C = matrix_multiply(A, B, n, 0);
        ticks += clock();
        matrix_print("C = A * B", C, n);
        printf("%d ticks\n", ticks);
        matrix_free(C, n);
    }

    clock_t ticks = -clock();
    int **D = matrix_multiply_direct(A, B, n, 1+2);
    ticks += clock();

    matrix_print("D = A * B", D, n);
    printf("%d ticks\n", ticks);
    matrix_free(D, n);

    return 0;
}

这篇关于使用分治法的矩阵乘法的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆