2013-04-13 31 views
6

Tôi có một ma trận thưa thớt lớn và tôi muốn nhận được giá trị tối đa cho mỗi hàng. Trong numpy, tôi có thể gọi numpy.max (mat, axis = 1), nhưng tôi không thể tìm thấy chức năng tương tự cho ma trận thưa thớt scipy. Có cách nào hiệu quả để có được tối đa của mỗi hàng cho một ma trận thưa thớt lớn?cách hiệu quả để lấy tối đa mỗi hàng cho ma trận thưa thớt lớn

Trả lời

4

Nếu ma trận của bạn, hãy gọi nó là a, được lưu trữ ở định dạng CSR, sau đó a.data có tất cả các mục khác 0 theo thứ tự hàng và a.indptr có chỉ mục của phần tử đầu tiên của mỗi hàng. Bạn có thể sử dụng tính năng này để tính toán số tiền sau:

def sparse_max_row(csr_mat): 
    ret = np.maximum.reduceat(csr_mat.data, csr_mat.indptr[:-1]) 
    ret[np.diff(csr_mat.indptr) == 0] = 0 
    return ret 
2

Tôi vừa gặp vấn đề tương tự này. Giải pháp của Jaime bị hỏng nếu bất kỳ hàng nào trong ma trận hoàn toàn trống. Dưới đây là một cách giải quyết:

def sparse_max_row(csr_mat): 
    ret = np.zeros(csr_mat.shape[0]) 
    ret[np.diff(csr_mat.indptr) != 0] = np.maximum.reduceat(csr_mat.data,csr_mat.indptr[:-1][np.diff(csr_mat.indptr)>0]) 
    return ret 
+0

này thất bại khi không ai trong số các mục dữ liệu lớn hơn 0: https://gist.github.com/jni/6120922#file-example-py – Juan