设置csr_matrix的行 [英] Set row of csr_matrix

查看:184
本文介绍了设置csr_matrix的行的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有一个稀疏的csr_matrix,我想将单行的值更改为不同的值.但是,我找不到一种简单有效的实现方式.这是它要做的:

I have a sparse csr_matrix, and I want to change the values of a single row to different values. I can't find an easy and efficient implementation however. This is what it has to do:

A = csr_matrix([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]])
new_row = np.array([-1, -1, -1])
print(set_row_csr(A, 2, new_row).todense())

>>> [[ 0,  1, 0],
     [ 1,  0, 1],
     [-1, -1, -1]]

这是我目前对set_row_csr的实现:

def set_row_csr(A, row_idx, new_row):
    A[row_idx, :] = new_row
    return A

但这给了我SparseEfficiencyWarning.有没有一种方法可以在不进行手动索引操作的情况下完成此操作,或者这是我唯一的出路吗?

But this gives me a SparseEfficiencyWarning. Is there a way of getting this done without manual index juggling, or is this my only way out?

推荐答案

最后,我设法通过索引变戏法来完成此操作.

In the end, I managed to get this done with index juggling.

def set_row_csr(A, row_idx, new_row):
    '''
    Replace a row in a CSR sparse matrix A.

    Parameters
    ----------
    A: csr_matrix
        Matrix to change
    row_idx: int
        index of the row to be changed
    new_row: np.array
        list of new values for the row of A

    Returns
    -------
    None (the matrix A is changed in place)

    Prerequisites
    -------------
    The row index shall be smaller than the number of rows in A
    The number of elements in new row must be equal to the number of columns in matrix A
    '''
    assert sparse.isspmatrix_csr(A), 'A shall be a csr_matrix'
    assert row_idx < A.shape[0], \
            'The row index ({0}) shall be smaller than the number of rows in A ({1})' \
            .format(row_idx, A.shape[0])
    try:
        N_elements_new_row = len(new_row)
    except TypeError:
        msg = 'Argument new_row shall be a list or numpy array, is now a {0}'\
        .format(type(new_row))
        raise AssertionError(msg)
    N_cols = A.shape[1]
    assert N_cols == N_elements_new_row, \
            'The number of elements in new row ({0}) must be equal to ' \
            'the number of columns in matrix A ({1})' \
            .format(N_elements_new_row, N_cols)

    idx_start_row = A.indptr[row_idx]
    idx_end_row = A.indptr[row_idx + 1]
    additional_nnz = N_cols - (idx_end_row - idx_start_row)

    A.data = np.r_[A.data[:idx_start_row], new_row, A.data[idx_end_row:]]
    A.indices = np.r_[A.indices[:idx_start_row], np.arange(N_cols), A.indices[idx_end_row:]]
    A.indptr = np.r_[A.indptr[:row_idx + 1], A.indptr[(row_idx + 1):] + additional_nnz]

这篇关于设置csr_matrix的行的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆