diff --git a/python/mxnet/ndarray/sparse.py b/python/mxnet/ndarray/sparse.py index 7995da545258..06eb5b667827 100644 --- a/python/mxnet/ndarray/sparse.py +++ b/python/mxnet/ndarray/sparse.py @@ -310,7 +310,7 @@ def __getitem__(self, key): Parameters ---------- - key : slice + key : int or slice Indexing key. Examples @@ -320,14 +320,22 @@ def __getitem__(self, key): >>> data = np.array([1, 2, 3, 4, 5, 6]) >>> a = mx.nd.sparse.csr_matrix(data, indptr, indices, (3, 3)) >>> a.asnumpy() - array([[1, 0, 2], - [0, 0, 3], - [4, 5, 6]]) + array([[ 1., 0., 2.], + [ 0., 0., 3.], + [ 4., 5., 6.]], dtype=float32) >>> a[1:2].asnumpy() - array([[0, 0, 3]], dtype=float32) + array([[ 0., 0., 3.]], dtype=float32) + >>> a[1].asnumpy() + array([[ 0., 0., 3.]], dtype=float32) + >>> a[-1].asnumpy() + array([[ 4., 5., 6.]], dtype=float32) """ if isinstance(key, int): - raise ValueError("__getitem__ with int key is not implemented for CSRNDArray") + if key == -1: + begin = self.shape[0] - 1 + else: + begin = key + return op.slice(self, begin=begin, end=begin+1) if isinstance(key, py_slice): if key.step is not None: raise ValueError('CSRNDArray only supports continuous slicing on axis 0') diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index ce4719bf5a3c..08c87ad0f079 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -559,8 +559,11 @@ void SliceCsrImpl(const SliceParam ¶m, const OpContext& ctx, if (req == kNullOp) return; CHECK_NE(req, kAddTo) << "kAddTo for Slice on CSR input is not supported"; CHECK_NE(req, kWriteInplace) << "kWriteInplace for Slice on CSR input is not supported"; + const TShape ishape = in.shape(); int begin = *param.begin[0]; + if (begin < 0) begin += ishape[0]; int end = *param.end[0]; + if (end < 0) end += ishape[0]; int indptr_len = end - begin + 1; out.CheckAndAllocAuxData(kIndPtr, Shape1(indptr_len)); if (!in.storage_initialized()) { diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py index cebc275cab17..7a313edc0a6e 100644 --- a/tests/python/unittest/test_sparse_ndarray.py +++ b/tests/python/unittest/test_sparse_ndarray.py @@ -111,8 +111,11 @@ def check_sparse_nd_csr_slice(shape): start = rnd.randint(0, shape[0] - 1) end = rnd.randint(start + 1, shape[0]) assert same(A[start:end].asnumpy(), A2[start:end]) + assert same(A[start - shape[0]:end].asnumpy(), A2[start:end]) assert same(A[start:].asnumpy(), A2[start:]) assert same(A[:end].asnumpy(), A2[:end]) + ind = rnd.randint(-shape[0], shape[0] - 1) + assert same(A[ind].asnumpy(), A2[ind][np.newaxis, :]) def check_slice_nd_csr_fallback(shape): stype = 'csr'