def set_row_csr(A, row_idx, new_row):
    """
    Replace a row in a CSR sparse matrix A.

    Every value of ``new_row`` (zeros included) is stored as an explicit
    entry of ``A``, so the replaced row is fully dense in the sparse
    structure.

    Parameters
    ----------
    A: csr_matrix or csr_array
        Matrix to change.
    row_idx: int
        Index of the row to be changed. Negative values count from the last
        row, as with ordinary Python indexing.
    new_row: np.array or list
        New values for the row of A.

    Returns
    -------
    None (the matrix A is changed in place)

    Raises
    ------
    AssertionError
        If A is not a CSR matrix, ``row_idx`` is out of range, ``new_row``
        has no length, or ``len(new_row) != A.shape[1]``. The checks use
        explicit ``raise`` statements, so they are not stripped by
        ``python -O``.

    Notes
    -----
    ``A.data`` is rebuilt with ``np.r_``, so its dtype follows NumPy type
    promotion: writing floats into an integer matrix turns ``A`` into a
    float matrix.
    """
    # Local imports: the function must not rely on module-level aliases.
    import numpy as np
    from scipy import sparse

    # issparse + format also accepts csr_array, which sparse.isspmatrix_csr
    # rejects in recent SciPy releases.
    if not (sparse.issparse(A) and A.format == 'csr'):
        raise AssertionError('A shall be a csr_matrix')

    n_rows, n_cols = A.shape
    if row_idx >= n_rows:
        raise AssertionError(
            'The row index ({0}) shall be smaller than the number of rows in A ({1})'
            .format(row_idx, n_rows))
    if row_idx < -n_rows:
        raise AssertionError(
            'The row index ({0}) shall not be smaller than minus the number of rows in A ({1})'
            .format(row_idx, n_rows))
    if row_idx < 0:
        # Without this, indptr[row_idx] / indptr[row_idx + 1] would address
        # the wrong rows and the splice below would corrupt A.
        row_idx += n_rows

    try:
        n_new = len(new_row)
    except TypeError as err:
        msg = ('Argument new_row shall be a list or numpy array, is now a {0}'
               .format(type(new_row)))
        raise AssertionError(msg) from err
    if n_new != n_cols:
        raise AssertionError(
            'The number of elements in new row ({0}) must be equal to '
            'the number of columns in matrix A ({1})'
            .format(n_new, n_cols))

    # Python ints avoid fixed-width integer arithmetic on the offsets.
    start = int(A.indptr[row_idx])
    end = int(A.indptr[row_idx + 1])
    extra_nnz = n_cols - (end - start)

    # indices and indptr must share one integer dtype; widen to int64 only
    # when the new number of stored entries no longer fits.
    idx_dtype = np.result_type(A.indices.dtype, A.indptr.dtype)
    if int(A.indptr[-1]) + extra_nnz > np.iinfo(idx_dtype).max:
        idx_dtype = np.int64

    A.data = np.r_[A.data[:start], new_row, A.data[end:]]
    A.indices = np.concatenate((
        A.indices[:start].astype(idx_dtype),
        np.arange(n_cols, dtype=idx_dtype),
        A.indices[end:].astype(idx_dtype),
    ))
    # Rows after the replaced one shift by the change in stored entries.
    A.indptr = np.concatenate((
        A.indptr[:row_idx + 1].astype(idx_dtype),
        A.indptr[row_idx + 1:].astype(idx_dtype) + extra_nnz,
    ))
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)